Sparse Identification of Nonlinear Dynamics with Integral Terms

Table of Contents

  1. Introduction
  2. SINDy with Integral Terms
  3. Implementation in JAX
  4. References

1. Introduction

Sparse Identification of Nonlinear Dynamics (SINDy) [1] is a popular data-driven system identification technique that discovers dynamic models from measurements. Since the underlying mathematical expression (i.e., the dynamic model) is unknown, SINDy assumes that it can be constructed by selecting relevant terms from a large library of candidate basis functions (for example, polynomial combinations of the measured states). Once the active terms are identified, a linear regression problem is solved to estimate their coefficients. This method has at least two limitations. First, it can only represent systems as additive combinations of the basis functions included in the library, which makes it incapable of discovering rational functions or nonlinear terms with embedded parameters. Second, because SINDy relies on finite-difference approximations to estimate derivatives from data, it is extremely sensitive to measurement noise: differencing divides the noise by the sampling interval. The first challenge is addressed, to some extent, in a separate post; in this tutorial we focus on the second challenge by using integral formulations instead of computing derivatives directly. This approach is based on our work on Derivative-Free SINDy (DF-SINDy) [2].
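To illustrate the noise argument, the following NumPy sketch (not from the original post; the noise level and sampling interval are arbitrary assumptions) applies a forward finite difference and a cumulative trapezoidal integral to pure measurement noise. Differencing divides the noise by the sampling interval $\Delta t$, whereas integrating multiplies it by $\Delta t$, so the two effects differ by orders of magnitude.

```python
import numpy as np

rng = np.random.default_rng(0)
dt = 0.01                                    # sampling interval (assumed)
sigma = 0.01                                 # noise standard deviation (assumed)
t = np.arange(0.0, 10.0, dt)
noise = sigma * rng.standard_normal(t.size)  # pure measurement noise, no signal

# Forward finite difference: std is about sigma * sqrt(2) / dt
fd_noise = np.diff(noise) / dt

# Cumulative trapezoidal integral: std grows only like sigma * dt * sqrt(number of steps)
int_noise = np.concatenate(([0.0], np.cumsum(0.5 * (noise[1:] + noise[:-1]) * dt)))

print("std of differentiated noise:", fd_noise.std())   # close to 0.01 * sqrt(2) / 0.01 = 1.41
print("max |integrated noise|:     ", np.abs(int_noise).max())
```

With these settings the differentiated noise is of order one, while the integrated noise remains far smaller, which is the motivation for working with integrals.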

2. SINDy with Integral Terms

Consider a system of ordinary differential equations (ODEs) of the form

\[\begin{equation} \frac{d}{dt}X(t) = f(X, p) \end{equation}\]

where the state vector $ X(t) = [x_1(t), \ x_2(t), \ \ldots , \ x_n(t)]^T \in \mathbb{R}^{n \times 1}$ collects the $n$ measured states of the system at time $t$, and $p \in \mathbb{R}^m$ are the parameters of the differential equations. The function $f: \mathbb{R}^n \times \mathbb{R}^m \rightarrow \mathbb{R}^n $ describes the dynamics that we wish to learn from the measurements. We assume that the parameters appear linearly in $f$ and are constant, i.e., they do not vary with time. Each component $f_k$ can therefore be written as a dot product ($ \cdot $) between a feature vector $g_k(X)$ and a parameter vector $p_k$, $f_k(X, p) = g_k(X) \cdot p_k$, and integrating both sides of Equation 1 from $0$ to $t$ gives

\[\begin{equation} \begin{aligned} x_k(t) & = \ x_k(t = 0) + \int_0^t g_k(X(t)) \cdot p_k \ dt \quad \forall \ k \in \{1, 2, \ldots , n \} \\ & = \ x_k(t = 0) + p_k \cdot \int_0^t g_k(X(t)) \ dt \quad \forall \ k \in \{1, 2, \ldots , n \} \end{aligned} \end{equation}\]
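As a quick numerical check of Equation 2 (a sketch, not part of the original implementation), the snippet below takes the Lotka-Volterra system used later in this tutorial, writes each right-hand side as $g_k(X) \cdot p_k$, generates a reference trajectory with a hand-written fixed-step RK4 integrator, and compares $x_k(t) - x_k(0)$ with the trapezoidal integral of $g_k$ multiplied by $p_k$. The helper names rk4 and cumtrapz are our own.

```python
import numpy as np

# Lotka-Volterra parameters from the implementation section: p_1 = [2/3, -4/3], p_2 = [-1, 1]
p1 = np.array([2 / 3, -4 / 3])
p2 = np.array([-1.0, 1.0])

def g1(x):
    # features of the first equation: [x_1, x_1 x_2]
    return np.array([x[0], x[0] * x[1]])

def g2(x):
    # features of the second equation: [x_2, x_1 x_2]
    return np.array([x[1], x[0] * x[1]])

def f(x):
    # f_k(X, p) = g_k(X) . p_k
    return np.array([g1(x) @ p1, g2(x) @ p2])

def rk4(f, x0, t):
    # classic fixed-step 4th-order Runge-Kutta, used only to generate a reference trajectory
    xs = [np.asarray(x0, dtype=float)]
    for h in np.diff(t):
        x = xs[-1]
        k1 = f(x)
        k2 = f(x + 0.5 * h * k1)
        k3 = f(x + 0.5 * h * k2)
        k4 = f(x + h * k3)
        xs.append(x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4))
    return np.array(xs)

def cumtrapz(y, t):
    # cumulative trapezoidal integral along axis 0, starting at 0
    steps = 0.5 * (y[1:] + y[:-1]) * np.diff(t)[:, None]
    return np.concatenate([np.zeros((1,) + y.shape[1:]), np.cumsum(steps, axis=0)])

t = np.linspace(0.0, 20.0, 10001)       # fine grid, dt = 2e-3
X = rk4(f, [0.1, 0.2], t)

G1 = np.array([g1(x) for x in X])        # shape (N, 2)
G2 = np.array([g2(x) for x in X])

# Equation 2: x_k(t) - x_k(0) should equal (integral of g_k) . p_k
lhs = X - X[0]
rhs = np.column_stack([cumtrapz(G1, t) @ p1, cumtrapz(G2, t) @ p2])
print(np.abs(lhs - rhs).max())           # small: limited by the trapezoidal quadrature error
```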

To find the parameters $p_k$, one can formulate a least-squares problem on the measured values $\hat{x}$, dropping the subscript $k$ and the explicit time dependence of $X(t)$ for convenience:

\[\begin{equation} p^{*} = \arg\min _p \Big| \Big| (\hat{x}(t) - \hat{x}(t = 0)) - \left( \int_0^t g(X) \ dt \right) p \Big| \Big|_2^2 \end{equation}\]

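Because the parameters $p$ appear linearly, this is an ordinary linear least-squares problem. It is convex, and provided the integrated features are linearly independent (so that the matrix below is invertible), its solution is given in closed form by the normal equations: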

\[\begin{equation} p^{*} = \left [ \left ( \int _0^t g(X) \ dt \right) ^T \left ( \int _0^t g(X) \ dt \right ) \right ]^{-1} \left( \int_0^t g(X) \ dt \right) ^T (\hat{x}(t) - \hat{x}(t = 0)) \end{equation}\]
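A minimal sketch of Equation 4, assuming noise-free samples of the logistic equation $\frac{dx}{dt} = r x - (r/K) x^2$ as the measurements (an example chosen here only because its exact solution is known). The features $g(x) = [x, \ x^2]$ enter linearly with $p = [r, \ -r/K]$; their integrals are approximated with the trapezoidal rule, and the normal equations are solved without estimating any derivative.

```python
import numpy as np

r, K, x0 = 1.0, 2.0, 0.1
t = np.linspace(0.0, 10.0, 1001)
x_hat = K / (1 + (K - x0) / x0 * np.exp(-r * t))   # exact logistic solution used as "measurements"

def cumtrapz(y, t):
    # cumulative trapezoidal integral along axis 0, starting at 0
    steps = 0.5 * (y[1:] + y[:-1]) * np.diff(t)[:, None]
    return np.concatenate([np.zeros((1, y.shape[1])), np.cumsum(steps, axis=0)])

G = np.column_stack([x_hat, x_hat**2])   # features g(x) at each sample, shape (N, 2)
A = cumtrapz(G, t)                       # row i: integral of g from 0 to t_i
y = x_hat - x_hat[0]                     # left-hand side x(t) - x(0)

# Equation 4: p* = (A^T A)^{-1} A^T y, solved without forming the inverse explicitly
p_star = np.linalg.solve(A.T @ A, A.T @ y)
print(p_star)   # close to [r, -r/K] = [1.0, -0.5]
```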

Since we do not know $ g(X)$ a priori, we consider a functional library vector $ \Theta (X) $ of $b$ terms that contains all polynomial combinations of the states up to a chosen degree, and assume that $g(X) \in \Theta (X)$, i.e., every term of $g$ appears in the library.

\[\begin{equation} \Theta (X) = [x_1 \quad x_2 \ \cdots \ x_n \quad x_1^2 \quad x_1x_2 \ \cdots \ x_n^2 \quad x_1^3 \quad x_1 x_2 x_n \ \cdots] \end{equation}\]
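The library in Equation 5 can also be generated programmatically. The helpers below (polynomial_library_terms, evaluate_library, and term_name are hypothetical names, not from the original post) enumerate all monomials of degree 1 to $d$ in $n$ states with itertools.combinations_with_replacement. For $n = 2$ and $d = 3$ this yields $b = \binom{n+d}{d} - 1 = 9$ terms, in the same order as the Library function in the implementation section.

```python
from itertools import combinations_with_replacement
from math import comb

import numpy as np

def polynomial_library_terms(n, degree):
    # Each term is a tuple of state indices, e.g. (0, 0, 1) means x_1^2 x_2
    return [c for d in range(1, degree + 1)
            for c in combinations_with_replacement(range(n), d)]

def evaluate_library(X, terms):
    # X has shape (N, n); returns Theta(X) with shape (N, b)
    return np.column_stack([np.prod(X[:, list(c)], axis=1) for c in terms])

def term_name(c):
    # human-readable label such as "x1^2 x2"
    return " ".join(f"x{i + 1}" + (f"^{c.count(i)}" if c.count(i) > 1 else "")
                    for i in sorted(set(c)))

terms = polynomial_library_terms(n=2, degree=3)
print([term_name(c) for c in terms])
# ['x1', 'x2', 'x1^2', 'x1 x2', 'x2^2', 'x1^3', 'x1^2 x2', 'x1 x2^2', 'x2^3']
print(len(terms) == comb(2 + 3, 3) - 1)   # True: b = C(n + d, d) - 1 (no constant term)
```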

As in the original SINDy method, these polynomial terms are user-defined. Since $ g(X) \in \Theta (X)$, computing $ \int _0^t g(X) \ dt $ requires each $x_k \in X$ as an explicit function of time. We approximate this function for every state $x_k \in X$ by cubic spline interpolation of the measurements $\hat{x}_k$. Let $\Psi (\hat{X}) = \{ \psi _1 (\hat{x}_1), \ \psi _2 (\hat{x}_2), \ \cdots, \ \psi _n (\hat{x}_n) \} $ be the set of interpolating functions obtained from the measurements $\hat{X}$; integrating the library then gives

\[\begin{equation} \begin{aligned} \int _{t_1}^{t_2} \Theta (X) \ dt & \approx \int _{t_1}^{t_2} \Theta (\Psi (\hat{X})) \ dt \\ & = \left [\int _{t_1}^{t_2} \psi _1 \ dt \quad \int _{t_1}^{t_2} \psi _2 \ dt \ \cdots \ \int _{t_1}^{t_2} \psi _n \ dt \quad \int _{t_1}^{t_2}\psi _1^2 \ dt \quad \int _{t_1}^{t_2}\psi _1 \psi _2 \ dt \ \cdots \ \int _{t_1}^{t_2}\psi _n^2 \ dt \quad \int _{t_1}^{t_2}\psi _1^3 \ dt \ \cdots \right] \end{aligned} \end{equation}\]
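A minimal SciPy sketch of Equation 6, assuming coarse samples of $x_1 = \sin t$ and $x_2 = \cos t$ on $[0, \pi]$ (chosen only because the exact integrals are known). CubicSpline.integrate integrates the linear terms $\psi_k$ exactly for the piecewise cubics, but products such as $\psi_1^2$ or $\psi_1 \psi_2$ are not cubic splines, so this sketch evaluates the splines on a dense grid and applies the trapezoidal rule to them. That quadrature choice is our assumption; the implementation section integrates the library with odeint instead.

```python
import numpy as np
from scipy.interpolate import CubicSpline

def trapz(y, t):
    # plain trapezoidal rule over the whole interval
    return np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(t))

t_meas = np.linspace(0.0, np.pi, 21)                      # coarse measurement times
X_meas = np.column_stack([np.sin(t_meas), np.cos(t_meas)])
psi = CubicSpline(t_meas, X_meas)                         # psi(t) -> [psi_1(t), psi_2(t)]

# Linear terms: exact integrals of the piecewise cubics
lin = psi.integrate(0.0, np.pi)                           # approx [2, 0]

# Nonlinear terms: dense evaluation of the splines + trapezoidal rule
t_fine = np.linspace(0.0, np.pi, 2001)
P = psi(t_fine)                                           # shape (2001, 2)
quad = [trapz(P[:, 0] ** 2, t_fine),                      # integral of psi_1^2, approx pi / 2
        trapz(P[:, 0] * P[:, 1], t_fine)]                 # integral of psi_1 psi_2, approx 0
print(lin, quad)
```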

Let $ \Xi = [\xi _1, \ \xi _2, \ \ldots , \ \xi _n] \in \mathbb{R}^{b \times n}$ be the coefficient matrix, where each column $\xi _{i}$ holds the coefficients of the terms in the functional library $ \Theta (\Psi (\hat{X} ))$ for state $i$. The linear system identification problem is then

\[\begin{equation} \begin{aligned} \begin{bmatrix} \hat{x}_1(t) - \hat{x}_1(t_0) \\ \hat{x}_2(t) - \hat{x}_2(t_0) \\ \vdots \\ \hat{x}_{n - 1}(t) - \hat{x}_{n - 1}(t_0) \\ \hat{x}_n(t) - \hat{x}_{n}(t_0) \\ \end{bmatrix} _{n \times 1} = \begin{bmatrix} \int _{0}^{t} \Theta (\Psi) \ dt \cdot \xi _{1} \\ \vdots \\ \int _{0}^{t} \Theta (\Psi) \ dt \cdot \xi _{n} \\ \end{bmatrix} _{n \times 1} \end{aligned} \end{equation}\]
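Since every row of Equation 7 uses the same integrated library, all $n$ columns of $\Xi$ can be estimated with a single least-squares call. The sketch below assumes a harmonic oscillator, $\dot{x}_1 = x_2$ and $\dot{x}_2 = -x_1$, whose exact solution $x_1 = \cos t$, $x_2 = -\sin t$ stands in for the measurements, and a small library $[x_1, \ x_2]$; no sparsity is enforced here.

```python
import numpy as np

def cumtrapz(y, t):
    # cumulative trapezoidal integral along axis 0, starting at 0
    steps = 0.5 * (y[1:] + y[:-1]) * np.diff(t)[:, None]
    return np.concatenate([np.zeros((1, y.shape[1])), np.cumsum(steps, axis=0)])

t = np.linspace(0.0, 10.0, 1001)
X_hat = np.column_stack([np.cos(t), -np.sin(t)])   # "measurements", shape (N, n)

Theta = X_hat                      # library [x_1, x_2] at each sample, shape (N, b)
Theta_int = cumtrapz(Theta, t)     # integral of Theta from t_0 to t_i, shape (N, b)
Y = X_hat - X_hat[0]               # x_i(t) - x_i(t_0) for every state, shape (N, n)

# One least-squares solve recovers every column xi_i of Xi (shape (b, n))
Xi, *_ = np.linalg.lstsq(Theta_int, Y, rcond=None)
print(np.round(Xi, 3))             # approximately [[0, -1], [1, 0]]
```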

The optimal coefficient matrix $ \Xi $ is obtained by solving the following problem, where the summation runs over all measurements taken between the start time $t = t_0$ and the end time $t = t_f$ (the residual at $t_0$ itself is identically zero, so the sum starts at $t_1$).

\[\begin{equation} \Xi ^* = \text{arg}\min _{\Xi} \ \frac{1}{2} \ \sum _{t = t_1}^{t_f} \ \text{MSE(Equation 7)} \\ \end{equation}\]

Since $g(X)$ contains only a few of the terms in $\Theta (X)$, we need a sparse solution to the above optimization problem. We use the sequentially thresholded least-squares (STLSQ) algorithm, which minimizes a least-squares objective with a ridge penalty $(\lambda)$, sets the coefficients whose magnitude is below the thresholding parameter $(\epsilon)$ to zero, and solves the optimization problem again with the remaining terms. Other variable selection algorithms, such as LASSO, Elastic Net, or SR3, can also be used. The modified optimization problem is

\[\begin{equation} \Xi ^* = \text{arg}\min _{\Xi} \ \frac{1}{2} \left[ \sum _{t = t_1}^{t_f} \ \text{MSE(Equation 7)} + \lambda \sum _{i = 1}^n \big|\big| \xi _i \big|\big|_2 ^2 \right] \\ \end{equation}\]

The algorithm terminates under one of two conditions: 1) all the coefficients in $ \Xi $ have been eliminated, in which case the threshold is too large and no model is obtained; or 2) no further coefficients are eliminated, in which case the optimal solution with the remaining coefficients has been found. If the algorithm eliminates all the coefficients, consider lowering the thresholding parameter or changing the terms in the polynomial library. Once the algorithm terminates successfully, the values of the coefficients are returned. In practice, the loop is also capped at a maximum number of iterations.
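For readers who want to see the thresholding loop without an external solver, here is a hedged NumPy sketch of STLSQ. It swaps IPOPT for the closed-form ridge solution $(A^T A + \lambda I)^{-1} A^T y$ on the active terms of each state, uses a sum of squared errors rather than a mean (which only rescales $\lambda$), and stops under the two terminating conditions described above or at an iteration cap. The function name stlsq_ridge and the synthetic data in the usage example are illustrative assumptions.

```python
import numpy as np

def stlsq_ridge(A, Y, threshold, lam=0.0, max_iters=10):
    """Sequentially thresholded ridge regression (illustrative sketch).

    A: (N, b) integrated library, Y: (N, n) targets x(t) - x(t_0).
    Returns Xi with shape (b, n); an all-zero result means every term was eliminated.
    """
    b, n = A.shape[1], Y.shape[1]
    active = np.ones((b, n), dtype=bool)           # start with every term included
    Xi = np.zeros((b, n))
    for _ in range(max_iters):
        Xi = np.zeros((b, n))
        for k in range(n):                         # each state has its own regression
            idx = np.flatnonzero(active[:, k])
            if idx.size == 0:
                continue
            Ak = A[:, idx]
            Xi[idx, k] = np.linalg.solve(Ak.T @ Ak + lam * np.eye(idx.size), Ak.T @ Y[:, k])
        new_active = active & (np.abs(Xi) >= threshold)
        if not new_active.any():                   # condition 1: everything eliminated
            return np.zeros((b, n))
        if np.array_equal(new_active, active):     # condition 2: nothing left to eliminate
            return Xi
        active = new_active
    return Xi * active                             # iteration cap reached

# Usage on synthetic data (assumed): a random stand-in library with a sparse ground truth
rng = np.random.default_rng(1)
A = rng.standard_normal((200, 6))
Xi_true = np.array([[1.5, 0.0], [0.0, -2.0], [0.0, 0.0],
                    [0.7, 0.0], [0.0, 0.0], [0.0, 0.9]])
Y = A @ Xi_true
Xi = stlsq_ridge(A, Y, threshold=0.1)
print(np.round(Xi, 6))
```

Because the synthetic targets are exactly $A \Xi_{\text{true}}$, the first solve already reproduces the true coefficients, one thresholding pass removes the spurious terms, and the second pass confirms the support.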

Now, let's implement this in Python using JAX.

3. Implementation in JAX

We begin by importing the necessary packages.

from functools import partial

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from cyipopt import minimize_ipopt
from scipy.interpolate import CubicSpline
from scipy.integrate import odeint

We consider the classic Lotka-Volterra system and generate data by simulating its dynamics.

def LotkaVolterra(x, t, p) : 
    # Lotka-Volterra dynamics
    return jnp.array([
        p[0] * x[0] + p[1] * x[0] * x[1], 
        p[2] * x[1] + p[3] * x[0] * x[1]
    ])

t = jnp.arange(0, 20., 0.1) # Time span
x = jnp.array([0.1, 0.2]) # Initial conditions
p = jnp.array([2/3, -4/3, -1, 1]) # parameters
solution = odeint(LotkaVolterra, x, t, args = (p, )) # Integrate
interp = CubicSpline(t, solution) # Cubic spline interpolation

Next, we construct a library of basis functions consisting of all polynomial combinations of the two states up to third order, evaluated along the cubic spline, and integrate each basis function over time with odeint. These integrals are computed once, before the optimization, and are not part of the optimization loop.

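def Library(x, t, _interp):
    # Library of basis functions evaluated along the interpolated trajectory.
    # The first argument (the integrator state passed in by odeint) is ignored;
    # the states are taken from the cubic spline instead.
    x = _interp(t)
    return jnp.array([
        x[0], x[1], # degree 1 polynomials
        x[0]**2, x[0] * x[1], x[1]**2, # degree 2 polynomials
        x[0]**3, x[0]**2 * x[1], x[0] * x[1]**2, x[1]**3 # degree 3 polynomials
    ])

# Integrate the library over time. The starting value of the integrator is
# arbitrary (here the library at t[0]) because it is subtracted afterwards.
ThetaInitial = Library(x, t[0], interp)
Theta = odeint(Library, ThetaInitial, t, args = (interp, )) - ThetaInitial # shape (len(t), 9)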
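The odeint call above uses the ODE solver as a quadrature routine: the right-hand side depends only on time (through the spline), so integrating it simply accumulates $\int_{t_0}^{t} \Theta(\Psi) \ dt$, and because the starting value is subtracted afterwards it can be arbitrary. The standalone check below (our own example with known antiderivatives, not part of the original code) applies the same trick to $[\sin t, \ \sin^2 t]$.

```python
import numpy as np
from scipy.integrate import odeint

def rhs(y, t):
    # right-hand side depends only on t, so odeint simply accumulates the integral
    return [np.sin(t), np.sin(t) ** 2]

t = np.linspace(0.0, 10.0, 101)
y0 = np.array([5.0, -3.0])                 # arbitrary starting value, removed below
integral = odeint(rhs, y0, t) - y0         # integral of [sin, sin^2] from 0 to t
exact = np.column_stack([1 - np.cos(t), t / 2 - np.sin(2 * t) / 4])
print(np.abs(integral - exact).max())      # small, set by odeint's default tolerances
```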

We now define the objective function (Equation 9). The optimization problem is solved with IPOPT, while JAX computes the gradients and Hessians. This implementation is not the fastest or most efficient, but it is intuitive and straightforward, which suits the purpose of this tutorial.

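def objective(p, include, target, reg):
    # Loss function (Equation 9) for the flattened coefficient vector p
    p = p.reshape(2, -1) # shape = (nstates, nterms in library)
    include = include.reshape(2, -1) # 0/1 mask of terms still in the model
    # prediction[i, k] = sum_j Theta[i, j] * p[k, j] * include[k, j]
    mse = jnp.mean((target - jnp.einsum("ij,kj,kj->ik", Theta, p, include))**2) # regression loss
    penalty = reg * jnp.sum(p**2) # L2 (ridge) penalty
    return mse + penalty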


def stlsq(thresholding, regularization, maxiters):
    # Sequential Threshold Least Square Algorithm

    iteration = 0 # Iteration count
    big = jnp.ones(2 * 9) # Initially, include all parameters
    previous_big = big # Mask of included parameters from the previous iteration
    p_guess = jnp.zeros(2 * 9) # Initial guess

    # Jit compiled objective, gradient and Hessian functions
    _obj = jax.jit(partial(objective, target = solution - x, reg = regularization))
    _obj_grad = jax.jit(jax.grad(_obj))
    _obj_hess = jax.jit(jax.jacfwd(_obj_grad))

    while iteration < maxiters :

        if iteration > 0 and jnp.allclose(previous_big, big) :
            print("Optimal solution found") # no more coefficients were eliminated
            break

        if jnp.sum(big) == 0 :
            print("All coefficients eliminated; lower the threshold or change the library")
            break

        # Solve the optimization problem using IPOPT
        solution_object = minimize_ipopt(
            partial(_obj, include = big),
            x0 = p_guess, # Restart from initial guess
            jac = partial(_obj_grad, include = big),
            hess = partial(_obj_hess, include = big),
            tol = 1e-5,
            options = {"maxiter" : 1000, "disp" : 5}
            )

        p = jnp.array(solution_object.x)
        previous_big = big # remember the mask used in this solve
        big = jnp.minimum(
            jnp.where(jnp.abs(p) < thresholding, 0, 1), # new indices
            previous_big # previous indices
        )
        iteration += 1

    return p * big # optimal solution, with eliminated coefficients set to zero


optimal = stlsq(thresholding = 0.1, regularization = 0., maxiters = 10)
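The einsum string "ij,kj,kj->ik" in objective may be easier to read as a masked matrix product. The NumPy check below uses random stand-in arrays with the tutorial's shapes (np.einsum follows the same subscript convention as jnp.einsum) and confirms that the prediction equals Theta @ (p * include).T, i.e. a sum over library terms j for each sample i and state k.

```python
import numpy as np

rng = np.random.default_rng(0)
Theta = rng.standard_normal((200, 9))        # integrated library, (n_samples, n_terms)
p = rng.standard_normal((2, 9))              # coefficients, (n_states, n_terms)
include = rng.integers(0, 2, size=(2, 9))    # 0/1 mask of active terms

pred_einsum = np.einsum("ij,kj,kj->ik", Theta, p, include)
pred_matmul = Theta @ (p * include).T        # same contraction as a masked matrix product
print(pred_einsum.shape, np.allclose(pred_einsum, pred_matmul))   # (200, 2) True
```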

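The result below shows that the algorithm identified the correct terms, $x_1$ and $x_1 x_2$ in the first equation and $x_2$ and $x_1 x_2$ in the second, with coefficients that agree with the true values $2/3$, $-4/3$, $-1$, and $1$ to within $10^{-5}$.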

>>> optimal
Array([ 0.66666611,  0.        ,  0.        , -1.33332863,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
       -0.99999867,  0.        ,  0.99999614,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ], dtype=float64)
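To read the flat coefficient vector as equations, it can be reshaped to (states, terms) and paired with the term labels of Library. The snippet below copies the printed values from the output above and compares them with the parameters $p = [2/3, \ -4/3, \ -1, \ 1]$ used to generate the data.

```python
import numpy as np

# Values copied from the `optimal` array printed above
optimal = np.array([
    0.66666611, 0.0, 0.0, -1.33332863, 0.0, 0.0, 0.0, 0.0, 0.0,
    0.0, -0.99999867, 0.0, 0.99999614, 0.0, 0.0, 0.0, 0.0, 0.0,
])
names = ["x1", "x2", "x1^2", "x1*x2", "x2^2", "x1^3", "x1^2*x2", "x1*x2^2", "x2^3"]

Xi = optimal.reshape(2, -1)    # same (nstates, nterms) layout as in objective
for k, row in enumerate(Xi):
    rhs = " ".join(f"{c:+.4f}*{names[j]}" for j, c in enumerate(row) if c != 0)
    print(f"dx{k + 1}/dt = {rhs}")
# dx1/dt = +0.6667*x1 -1.3333*x1*x2
# dx2/dt = -1.0000*x2 +1.0000*x1*x2

true = np.zeros((2, 9))
true[0, [0, 3]] = [2 / 3, -4 / 3]
true[1, [1, 3]] = [-1, 1]
print(np.abs(Xi - true).max())   # about 4.7e-06
```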

4. References