Differentiable SciPy ODE solvers in JAX
This tutorial demonstrates how to incorporate ordinary differential equation (ODE) solvers from SciPy into JAX and make them compatible with reverse-mode automatic differentiation.
This tutorial introduces a bilevel optimization framework for parameter estimation in ordinary differential equations using JAX. We will walk through two examples: one where all states are measured, and another where only some of the states are observed.
This tutorial derives the reverse- and forward-mode sensitivities of hybrid dynamical systems. We will then implement a custom ODE solver in JAX that can handle events in parallel.
This tutorial shows how to compute gradients through the solution of an optimization problem using the implicit function theorem in JAX.
This tutorial introduces differentiable cubic spline interpolation in JAX. We will begin with a basic implementation and then improve its performance by exploiting sparsity.
This tutorial explains three different parameter estimation methods: single shooting, multiple shooting, and orthogonal collocation. We will walk through a simple example and implement it in CasADi.