Differentiable Optimization in JAX
1. Introduction
Optimization problems are very common in machine learning. In this context, we often need to learn the parameters of an optimization problem, which requires computing the gradients of the optimal solution with respect to those parameters. In this tutorial, we will explore how to compute these gradients in JAX. Let’s start by importing the necessary packages:
import functools
from typing import Callable, Optional, Tuple
import jax
import jax.numpy as jnp
from jax import flatten_util, tree_util
jax.config.update("jax_enable_x64", True)
from cyipopt import minimize_ipopt
For the purpose of this tutorial, we will focus only on convex optimization problems. However, the same techniques can be extended to nonconvex problems.
2. General Convex Optimization Problem
A convex optimization problem1 has the following structure
\[\begin{equation} \begin{aligned} & \min _{x} \ \ f(x, p) \\ \text{subject to} & \\ & g(x, p) = 0 \\ & h(x, p) \leq 0 \end{aligned} \end{equation}\]where $x \in \mathbb{R}^n$ is the decision variable, $p \in \mathbb{R}^p$ are the learnable parameters, $f : \mathbb{R}^n \times \mathbb{R}^p \to \mathbb{R}$ is the objective function, $g : \mathbb{R}^n \times \mathbb{R}^p \to \mathbb{R}^g$ are the equality constraints and $h : \mathbb{R}^n \times \mathbb{R}^p \to \mathbb{R}^h$ are the inequality constraints. For the optimization problem to be convex in $x$, the objective function $f$ and the inequality constraints $h$ have to be convex in $x$, while the equality constraints $g$ must be affine in $x$. Let $x^{\ast}$ denote the primal optimal solution, and let $[\lambda ^{\ast} \in \mathbb{R}^g, \mu ^{\ast} \in \mathbb{R}^h]$ denote the dual optimal solution, corresponding to the equality and inequality constraints respectively. We define the Lagrangian of this optimization problem as
\[\begin{equation} L(x, \lambda, \mu) = f(x, p) + \lambda ^T g(x, p) + \mu ^T h(x, p) \end{equation}\]Our goal here is to compute the gradient of the optimal primal and dual variables with respect to the parameters $p$. Before we compute these gradients, let's look at how to use IPOPT to solve this optimization problem. Note that if the problem is convex, an off-the-shelf convex optimization solver also suffices; however, we would like to keep the code applicable to a broader class of optimization problems. We will create a function called _constraint_differentiable_optimization that takes as input arguments the objective function and the equality and inequality constraints as Callables, the values of the parameters, the initial guess of the decision variables, and any additional arguments that these functions might take. Now, because we will be using an external function (not compatible with JAX), we will have to wrap it with jax.pure_callback2.
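Before writing that function, here is a minimal, self-contained illustration of the jax.pure_callback pattern, wrapping a plain NumPy routine. The names numpy_log1p and wrapped_log1p are hypothetical and only serve to show the mechanism; they are not part of the tutorial's pipeline.

import numpy as np

def numpy_log1p(x) :
    # Plain NumPy routine; runs on the host, outside of JAX tracing
    return np.log1p(np.asarray(x))

def wrapped_log1p(x) :
    # jax.pure_callback needs the output shape/dtype up front, just like in the solver wrapper below
    return jax.pure_callback(numpy_log1p, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

print(jax.jit(wrapped_log1p)(jnp.array([0.5, 1.0, 2.0]))) # works under jit, but has no gradient rule by itself

The same pattern is used below: the IPOPT call runs outside of JAX, and a custom JVP rule is attached later to make it differentiable.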
def _constraint_differentiable_optimization(f : Callable, g : Callable, h : Callable, p : jnp.ndarray, x_guess : jnp.ndarray, args : Tuple[jnp.ndarray]) -> Tuple[jnp.ndarray] :
    # f = objective function, arguments (p, x) -> scalar output (assumed to be only reverse-mode compatible), shape = ()
    # g = equality constraints, arguments (p, x) -> vector output (should be forward- and reverse-mode compatible), shape = (ng, )
    # h = inequality constraints (h >= 0), arguments (p, x) -> vector output (should be forward- and reverse-mode compatible), shape = (nh, )
    # p = parameters
    # x_guess = initial guess for the decision variables
    x_flat, unravel = flatten_util.ravel_pytree(x_guess)
    eps = 10. * max(*x_flat.shape, 1) * jnp.array(jnp.finfo(x_flat.dtype).eps)
    ng = len(g(p, x_guess)) # number of equality constraints
    nh = len(h(p, x_guess)) # number of inequality constraints

    def _minimize(x_flat, p, args):
        # External (non-JAX) call that solves the optimization problem with IPOPT
        obj = jax.jit(lambda x : f(p, unravel(x), *args))
        jac = jax.jit(jax.grad(obj))
        hess = jax.jit(jax.hessian(obj))
        _g = jax.jit(lambda x : g(p, unravel(x)))
        _h = jax.jit(lambda x : h(p, unravel(x)))
        res = minimize_ipopt(
            obj,
            jac = jac,
            hess = hess,
            x0 = x_flat,
            constraints = [
                {"type" : "eq", "fun" : _g, "jac" : jax.jacobian(_g), "hess" : lambda x, lam : jax.hessian(lambda _x : lam @ _g(_x))(x)},
                {"type" : "ineq", "fun" : _h, "jac" : jax.jacobian(_h), "hess" : lambda x, lam : jax.hessian(lambda _x : lam @ _h(_x))(x)}
            ],
            tol = 1e-5,
            options = {"maxiter" : 1000, "disp" : 0, "sb" : "yes"}
        )
        if not res.success : print("Inner optimization problem failed with message : ", res.message)
        x = jnp.array(res.x) # optimal primal variables
        v, m = map(jnp.array, (res.info["mult_g"][:ng], res.info["mult_g"][ng:])) # optimal dual variables corresponding to the equality and inequality constraints respectively
        return x, v, m

    x_opt, v_opt, m_opt = jax.pure_callback(
        _minimize,
        (jax.ShapeDtypeStruct(x_flat.shape, x_flat.dtype), jax.ShapeDtypeStruct((ng, ), x_flat.dtype), jax.ShapeDtypeStruct((nh, ), x_flat.dtype)),
        x_flat, p, args
    )

    def L(p, x, v, m) : return f(p, unravel(x), *args) + v @ g(p, unravel(x)) + m @ h(p, unravel(x)) # Lagrange function

    gx_jvp = jax.linearize(lambda x : g(p, unravel(x)), x_opt)[-1]
    hx_vjp = lambda ct, p : jax.vjp(lambda x : h(p, unravel(x)), x_opt)[-1](ct)[0]
    # B_xx = L_xx - h_x^T diag(mu / h) h_x, i.e. the reduced matrix \hat{L}_{xx} derived in the next section
    B_xx = jax.hessian(L, argnums = 1)(p, x_opt, v_opt, m_opt) - jax.vmap(hx_vjp, in_axes = (0, None))(jax.vmap(hx_vjp, in_axes = (0, None))(jnp.diag(m_opt / h(p, unravel(x_opt))), p).T, p)
    (u, s, vh) = jnp.linalg.svd(B_xx, hermitian = True, full_matrices = False) # shape = (nx, nx), (nx, ), (nx, nx)
    sinv = jnp.where(s <= eps, 0., 1/s) # pseudo-inverse of the singular values; works for stiff problems
    (gu, gs, gvh) = jnp.linalg.svd(jax.vmap(gx_jvp)(vh).T @ jnp.diag(sinv) @ jax.vmap(gx_jvp)(u.T), hermitian = True, full_matrices = False)
    gsinv = jnp.where(gs <= eps * jnp.max(gs, initial = -jnp.inf), 0., 1/gs) # initial value is provided to deal with zero-dimensional arrays
    return (unravel(x_opt), v_opt, m_opt), (u, sinv, vh, gu, gsinv, gvh) # (optimal primal and dual variables), auxiliary SVD factors
In addition to returning the optimal primal and dual solutions, the function also returns certain matrices, which we will discuss in the next section. Furthermore, it is possible to precompute the sparsity structure and supply it to the solver. However, we will defer that discussion to a future tutorial.
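As a quick sanity check of the interface (not part of the original pipeline; the names f_toy, g_toy and h_toy are purely illustrative), the function can be called on a small projection-onto-the-simplex problem:

def f_toy(p, x) : return jnp.sum((x - p)**2) # squared distance to the parameter vector
def g_toy(p, x) : return jnp.sum(x, keepdims = True) - 1.0 # single affine equality constraint
def h_toy(p, x) : return x # x >= 0 under the "ineq" convention used above
p_toy = jnp.array([0.2, 0.7, 0.4])
(x_toy, v_toy, m_toy), aux_toy = _constraint_differentiable_optimization(f_toy, g_toy, h_toy, p_toy, jnp.ones(3) / 3, ())
print(x_toy) # approximately [0.1, 0.6, 0.3]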
3. Computing Gradients
To compute the gradients of the optimal solution with respect to its parameters, we will use the implicit function theorem3. It states that for a function $R : \mathbb{R}^n \times \mathbb{R}^p \to \mathbb{R}^n $, such that
\[\begin{equation} \begin{aligned} R(z(p), p) & = 0 \\ \frac{\partial R}{\partial z} \frac{dz}{dp} + \frac{\partial R}{\partial p} & = 0 \\ \frac{dz}{dp} & = - \left[ \frac{\partial R}{\partial z} \right]^{-1} \frac{\partial R}{\partial p} \end{aligned} \end{equation}\]here we assume $R$ is continuously differentiable with non-singular Jacobian $ \frac{\partial R}{\partial z} $. Then the Jacobian $\frac{dz}{dp}$ can be calculated using the above formula. In the context of optimization, one can view the problem as finding the root of its first-order KKT conditions, which naturally fits into this framework. So if
\[\begin{equation} R(z(p), p) = \begin{bmatrix} L_{x} \\[3pt] g \\[3pt] \text{diag}(\mu) h \end{bmatrix} \end{equation}\]where $ L_x$ is the gradient of the Lagrangian with respect to $x$. Then the optimal solution $z^{\ast} = [x^{\ast}, \lambda ^{\ast}, \mu ^{\ast}]$ is such that $R(z^{\ast}, p) = 0$, $\ h(x^{\ast}, p) \leq 0 $, and $\mu ^{\ast} \geq 0 $. We will also assume that the optimal solution is regular, i.e. a small perturbation of the parameters does not change the active set ($ \left\{ i \mid h_i(x^{\ast}, p) = 0 \right\} $). Consequently, applying the implicit function theorem gives
\[\begin{equation} \begin{aligned} \begin{bmatrix} L_{xx} & g_x^T & h_x^T\\[3pt] g_{x} & 0 & 0\\[3pt] \text{diag}(\mu) h_x & 0 & \text{diag}(h) \end{bmatrix}_{x^*, \lambda ^*, \mu ^*} \begin{bmatrix} \frac{dx^*}{dp} \\[5pt] \frac{d\lambda ^*}{dp} \\[5pt] \frac{d\mu ^*}{dp} \end{bmatrix} & = - \begin{bmatrix} L_{xp} \\[3pt] g_p \\[3pt] \text{diag}(\mu) h_p \end{bmatrix}_{x^*, \lambda ^*, \mu ^*} \\ \begin{bmatrix} \frac{dx^*}{dp} \\[5pt] \frac{d\lambda ^*}{dp} \\[5pt] \frac{d\mu ^*}{dp} \end{bmatrix} & = - \begin{bmatrix} L_{xx} & g_x^T & h_x^T\\[3pt] g_{x} & 0 & 0\\[3pt] \text{diag}(\mu) h_x & 0 & \text{diag}(h) \end{bmatrix} ^{-1} \begin{bmatrix} L_{xp} \\[3pt] g_p \\[3pt] \text{diag}(\mu) h_p \end{bmatrix} \end{aligned} \end{equation}\]We will now derive the Jacobian-vector-product (jvp) rule that is required to define a custom gradient rule in JAX. The jvp rule maps the primal vector $z$ and the tangent vector $v$ such that $(z, v) \to (f(z), \partial f(z) v)$. Extending it to the above equation in our case gives
\[\begin{equation} \begin{aligned} \underbrace{ \begin{bmatrix} \frac{dx^*}{dp} \\[5pt] \frac{d\lambda ^*}{dp} \\[5pt] \frac{d\mu ^*}{dp} \end{bmatrix} v}_{\text{To find}} & = - \begin{bmatrix} L_{xx} & g_x^T & h_x^T\\[3pt] g_{x} & 0 & 0\\[3pt] \text{diag}(\mu) h_x & 0 & \text{diag}(h) \end{bmatrix} ^{-1} \begin{bmatrix} L_{xp} \\[3pt] g_p \\[3pt] \text{diag}(\mu) h_p \end{bmatrix} v \end{aligned} \end{equation}\]Instead of directly inverting the entire matrix, we will decompose it so that only certain subcomponents need to be inverted. One might recognize this matrix as the KKT matrix, whose decomposition can often be obtained directly from the optimization solver. However, in our case, we will assume that this decomposition is unavailable. To begin, we rewrite the equation as the following linear system
\[\begin{equation} \begin{aligned} \begin{bmatrix} L_{xx} & g_x^T & h_x^T\\[3pt] g_{x} & 0 & 0\\[3pt] \text{diag}(\mu) h_x & 0 & \text{diag}(h) \end{bmatrix} \begin{bmatrix} w_1 \\[3pt] w_2 \\[3pt] w_3 \end{bmatrix} & = \begin{bmatrix} v_1 \\[3pt] v_2 \\[3pt] v_3 \end{bmatrix} = - \begin{bmatrix} L_{xp} v \\[3pt] g_{p} v \\[3pt] \text{diag}(\mu ^*)h_{p} v \end{bmatrix} \end{aligned} \end{equation}\]where we need to find $ w_1 = \frac{dx}{dp}v, \ w_2 = \frac{d\lambda}{dp}v, \ w_3 = \frac{d\mu}{dp}v $. On the other hand, $v_1 = -L_{xp}v, \ v_2 = -g_{p}v, \ v_3 = -\text{diag}(\mu ^{\ast})h_{p}v$ are already expressed in a form that can be efficiently computed using jvps in JAX. The corresponding system of equations can be written out explicitly as
\[\begin{equation} \begin{aligned} L_{xx} w_1 + g_x^T w_2 + h_x^T w_3 & = v_1 \\ g_x w_1 & = v_2 \\ w_3 & = [\text{diag}(h)]^{-1} \left[ v_3 - \text{diag}(\mu) h_x w_1 \right] \end{aligned} \end{equation}\]Let $H = \text{diag}(h)$ and $M = \text{diag}(\mu)$; then, substituting for $w_3$ gives
\[\begin{equation} \begin{aligned} \begin{bmatrix} L_{xx} - h_x^T M H^{-1} h_x & g_x^T \\[3pt] g_{x} & 0 \\ \end{bmatrix} \begin{bmatrix} w_1 \\[3pt] w_2 \\ \end{bmatrix} & = \begin{bmatrix} v_1 - h_x^T H^{-1} v_3 \\[3pt] v_2 \end{bmatrix} \end{aligned} \end{equation}\]Let \(\hat{L}_{xx} = L_{xx} - h_x^T MH^{-1} h_x, \quad \hat{v}_1 = v_1 - h_x^T H^{-1}v_3\). Note that for an equality-constrained optimization problem, we get \(\hat{L}_{xx} = L_{xx}, \quad \hat{v}_1 = v_1\). We obtain the remaining components of $w$ as follows
\[\begin{equation} \begin{aligned} w_2 & = \left[ g_x \hat{L}_{xx}^{-1}g_x^T \right]^{-1} [- v_2 + g_x \hat{L}_{xx}^{-1} \hat{v}_1] \\ w_1 & = \hat{L}_{xx}^{-1} \left[ \hat{v}_1 - g_x^T w_2 \right] \end{aligned} \end{equation}\]Finally, we return the vector $w$. When we defined the function _constraint_differentiable_optimization, in addition to the optimal solution, we also returned some auxiliary data. This data is essentially the SVD decompositions of the matrices $\hat{L}_{xx}$ and $g_x \hat{L}_{xx}^{-1} g_x^T$ appearing in the above equations. Let's look at the rest of the code.
@functools.partial(jax.custom_jvp, nondiff_argnums = (0, 1, 2))
def constraint_differentiable_optimization(f : Callable, g : Callable, h : Callable, p : jnp.ndarray, x_guess : jnp.ndarray, args : Tuple[jnp.ndarray]) -> Tuple[jnp.ndarray] :
    return _constraint_differentiable_optimization(f, g, h, p, x_guess, args)

@constraint_differentiable_optimization.defjvp
def constraint_differentiable_optimization_fwd(f : Callable, g : Callable, h : Callable, primals : Tuple[jnp.ndarray], tangents : Tuple[jnp.ndarray]) -> Tuple[jnp.ndarray] :
    p, x_guess, args = primals
    p_dot, _, _ = tangents
    (x_opt, v_opt, m_opt), aux = constraint_differentiable_optimization(f, g, h, p, x_guess, args)
    u, sinv, vh, gu, gsinv, gvh = aux
    x_opt, unravel = flatten_util.ravel_pytree(x_opt)
    _f = lambda p, x : f(p, unravel(x), *args)
    _g = lambda p, x : g(p, unravel(x))
    _h = lambda p, x : h(p, unravel(x))
    gx_jvp = jax.linearize(lambda x : _g(p, x), x_opt)[-1]
    gx_vjp = lambda ct : jax.vjp(lambda x : _g(p, x), x_opt)[-1](ct)[0]
    hx_vjp = lambda ct, p : jax.vjp(lambda x : _h(p, x), x_opt)[-1](ct)[0]
    def L(p, x, v, m) : return _f(p, x) + v @ _g(p, x) + m @ _h(p, x) # Lagrange function
    # Right-hand side [v_1, v_2, v_3] = -[L_xp p_dot, g_p p_dot, diag(mu) h_p p_dot]
    v = jax.tree_util.tree_map(
        jnp.negative,
        [
            jax.jvp(lambda _p : jax.grad(L, argnums = 1)(_p, x_opt, v_opt, m_opt), (p, ), (p_dot, ))[-1],
            jax.jvp(lambda _p : _g(_p, x_opt), (p, ), (p_dot, ))[-1],
            jax.jvp(lambda _p : m_opt * _h(_p, x_opt), (p, ), (p_dot, ))[-1]
        ])
    v_hat = v[0] - hx_vjp(v[2] / _h(p, x_opt), p) # v_hat = v_1 - h_x^T H^{-1} v_3
    mu_v = inv_vp(gu, gsinv, gvh, gx_jvp(inv_vp(u, sinv, vh, v_hat)) - v[1]) # w_2, shape = (ng, )
    mu_x = inv_vp(u, sinv, vh, v_hat - gx_vjp(mu_v)) # w_1, shape = (nx, )
    mu_m = (v[2] - m_opt * jax.jvp(lambda x : _h(p, x), (x_opt, ), (mu_x, ))[-1]) / _h(p, x_opt) # w_3, shape = (nh, )
    return ((unravel(x_opt), v_opt, m_opt), aux), ((unravel(mu_x), mu_v, mu_m), tree_util.tree_map(jnp.zeros_like, aux))
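The helper inv_vp used in the tangent computation above is not defined in this post's listings; a minimal sketch consistent with how it is called, applying the pseudo-inverse assembled from the stored SVD factors to a vector, would be:

def inv_vp(u, sinv, vh, b) :
    # Applies the pseudo-inverse reconstructed from the SVD factors, i.e. vh.T @ diag(sinv) @ u.T @ b
    return vh.T @ (sinv * (u.T @ b))

With this in place, inv_vp(u, sinv, vh, b) applies $\hat{L}_{xx}^{-1}$ to $b$, and inv_vp(gu, gsinv, gvh, b) applies $[g_x \hat{L}_{xx}^{-1} g_x^T]^{-1}$, matching the formulas for $w_1$ and $w_2$ above.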
4. Testing
We test the above implementation and compare the gradients against finite differences.
eps = 1e-5
p = jnp.array([2, 0.1, 5])
def f(p, x) : return jnp.mean((jnp.array([p[1], p[0] * p[1], p[0] * p[1]**2 * p[2]]) * x - jnp.array([1.3, 1.2, 1.05]))**2)
def g(p, x) : return jnp.log(p) * (jnp.array([1, 2, 3, 0, 1, 4, 5, 6, 0]).reshape(3, -1) @ x) - jnp.array([0.6931472 * 10, -2.3025851 * 6, 1.609438 * 27]) # Linear constraints to be convex
def h(p, x) : return p * x**2 # geq ; Convex constraints
def objcon(p) :
    (x_opt, v_opt, m_opt), _ = constraint_differentiable_optimization(f, g, h, p, jnp.zeros(3), ())
    return f(p, x_opt) + v_opt @ g(p, x_opt) + m_opt @ h(p, x_opt)

value, gradient = jax.value_and_grad(objcon)(p)
gradient_fd = jax.vmap(lambda v : (objcon(p + v * eps) - value) / eps)(jnp.eye(len(p)))
print("Gradients (autodiff)", gradient)
print("Gradients (fd)", gradient_fd)

# Hessian calculations
hessian = jax.hessian(objcon)(p)
_grad = jax.grad(objcon)
hessian_fd = jax.vmap(lambda v : (_grad(p + v * eps) - _grad(p - v * eps)) / eps / 2)(jnp.eye(len(p)))
assert jnp.allclose(hessian_fd, hessian, atol = 100 * eps), "Hessian does not match"
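As an additional check (not in the original test), one can also differentiate the primal solution itself with forward-mode autodiff, which goes through the custom JVP rule defined above, and compare against central finite differences:

def x_star(p) :
    (x_opt, _, _), _ = constraint_differentiable_optimization(f, g, h, p, jnp.zeros(3), ())
    return x_opt

dxdp = jax.jacfwd(x_star)(p) # Jacobian dx*/dp via the custom JVP rule, shape = (3, 3)
dxdp_fd = jax.vmap(lambda v : (x_star(p + v * eps) - x_star(p - v * eps)) / eps / 2)(jnp.eye(len(p))).T
assert jnp.allclose(dxdp, dxdp_fd, atol = 100 * eps), "Primal Jacobian does not match"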
5. Caveat
The custom JVP rule defined above is specifically tailored to bilevel optimization problems. In this setting, the second-order derivatives of the optimal solution with respect to the parameters, though implicitly computed, are never actually required, since they cancel out to zero. When computing second-order derivatives in general, one would also need to differentiate the KKT matrix a second time, i.e. propagate terms involving $\frac{\partial ^2R}{\partial z^2}$. However, in our case, saving the decomposition of the KKT matrix during the forward solve and reusing it in the custom JVP rule prevents second-order derivatives from being propagated through the KKT system. While this means incorrect second-order derivatives are computed in general, it leads to significantly faster computation of first-order derivatives. A quick fix to obtain accurate second-order derivatives is to transfer the KKT matrix calculation from the function _constraint_differentiable_optimization to the function constraint_differentiable_optimization_fwd.
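A hedged sketch of that change (untested; it simply relocates the exact expressions that currently build the decomposition in _constraint_differentiable_optimization, whose auxiliary output would shrink accordingly): inside constraint_differentiable_optimization_fwd, after x_opt, v_opt and m_opt are available, recompute

# Recompute the reduced KKT blocks inside the JVP rule so they remain differentiable with respect to p
eps = 10. * max(*x_opt.shape, 1) * jnp.array(jnp.finfo(x_opt.dtype).eps)
B_xx = jax.hessian(L, argnums = 1)(p, x_opt, v_opt, m_opt) - jax.vmap(hx_vjp, in_axes = (0, None))(jax.vmap(hx_vjp, in_axes = (0, None))(jnp.diag(m_opt / _h(p, x_opt)), p).T, p)
(u, s, vh) = jnp.linalg.svd(B_xx, hermitian = True, full_matrices = False)
sinv = jnp.where(s <= eps, 0., 1/s)
(gu, gs, gvh) = jnp.linalg.svd(jax.vmap(gx_jvp)(vh).T @ jnp.diag(sinv) @ jax.vmap(gx_jvp)(u.T), hermitian = True, full_matrices = False)
gsinv = jnp.where(gs <= eps * jnp.max(gs, initial = -jnp.inf), 0., 1/gs)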
6. References
1. Stephen Boyd and Lieven Vandenberghe. Convex Optimization. Cambridge University Press, 2004.
2. James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang. JAX: composable transformations of Python+NumPy programs, 2018.
3. Steven George Krantz and Harold R. Parks. The Implicit Function Theorem: History, Theory, and Applications. Springer Science & Business Media, 2002.