14.1. Introduction to Adjoint Sensitivity Analysis

This section presents the SUNAdjointStepper and SUNAdjointCheckpointScheme classes. The SUNAdjointStepper represents a generic adjoint sensitivity analysis (ASA) procedure to obtain the adjoint sensitivities of an IVP of the form

(14.1)\[\dot{y}(t) = f(t, y, p), \qquad y(t_0) = y_0(p), \qquad y \in \mathbb{R}^N,\]

where \(p\) is some set of \(N_s\) problem parameters.

Note

The API itself does not implement ASA, but it provides a common interface for ASA capabilities implemented in the SUNDIALS packages. Currently it supports the ASA capabilities in ARKODE; the ASA capabilities in CVODES and IDAS must be used directly.

Suppose we have a functional \(g(t_f, y(t_f), p)\) for which we would like to compute the gradients \(dg(t_f, y(t_f), p)/dy(t_0)\) and/or \(dg(t_f, y(t_f), p)/dp\). This most often arises in the form of an optimization problem such as

(14.2)\[\min_{y(t_0), p} g(t_f, y(t_f), p)\]

Warning

The CVODES documentation uses \(\lambda\) to represent the adjoint variables needed to obtain the gradient \(dG/dp\) where \(G\) is an integral of \(g\). Our use of \(\lambda\) in the following is akin to the use of \(\mu\) in the CVODES docs.

The adjoint method is one approach to obtaining the gradients that is particularly efficient when there are relatively few functionals and a large number of parameters. While CVODES and IDAS provide continuous adjoint methods (differentiate-then-discretize), ARKODE provides discrete adjoint methods (discretize-then-differentiate). For the continuous approach, we derive and solve the adjoint IVP backwards in time

(14.3)\[\dot{\lambda}(t) = -f_y^*(t, y, p) \lambda,\quad \lambda(t_f) = g_y^*(t_f, y(t_f), p)\]

where \(\lambda(t) \in \mathbb{R}^{N}\) is the adjoint variable, \(f_y \equiv \partial f/\partial y \in \mathbb{R}^{N \times N}\) and \(g_y \equiv \partial g/\partial y \in \mathbb{R}^{1 \times N}\) are the Jacobians with respect to the dependent variable, \(*\) denotes the Hermitian (conjugate) transpose, \(N\) is the size of the original IVP, and \(N_s\) is the number of parameters. When solved with a numerical time integration scheme, the solution to the continuous adjoint IVP is a numerical approximation of the continuous adjoint sensitivities,

(14.4)\[\lambda(t_n) \approx g_y(t_f, y(t_n), p), \quad \lambda(t_0) \approx g_y(t_f, y(t_0), p).\]

The gradients with respect to the parameters can then be obtained as

(14.5)\[\frac{d g(t_f, y(t_n), p)}{dp} = \lambda^*(t_n) y_p(t_n) + g_p(t_f, y(t_n), p) + \int_{t_n}^{t_f} \lambda^*(t) f_p(t, y(t), p)~ dt,\]

where \(y_p(t) \equiv \partial y(t)/\partial p \in \mathbb{R}^{N \times N_s}\) is the forward sensitivity matrix, and \(g_p \equiv \partial g/\partial p \in \mathbb{R}^{1 \times N_s}\) and \(f_p \equiv \partial f/\partial p \in \mathbb{R}^{N \times N_s}\) are the Jacobians with respect to the parameters.

For the discrete adjoint approach, we first numerically discretize the original IVP (14.1) using a time integration scheme, \(\varphi\), so that

(14.6)\[y_0 = y(t_0),\quad y_n = \varphi(y_{n-k}, \cdots, y_{n-1}, p), \quad k = n, \cdots, 1.\]

For linear multistep methods \(k \geq 1\), while for one-step methods \(k = 1\). Reformulating the optimization problem for the discrete case, we have

(14.7)\[\min_{y_0, p} g(t_f, y_n, p)\]

The gradients of (14.7) can be computed using the transposed chain rule backwards in time to obtain the discrete adjoint variables \(\lambda_n, \lambda_{n-1}, \cdots, \lambda_0\) and \(\mu_n, \mu_{n-1}, \cdots, \mu_0\). The discrete adjoint variables represent the gradients of the discrete cost function (14.7) with respect to changes in the discretized IVP (14.6),

(14.8)\[\frac{dg}{dy_0} = \lambda_0 , \quad \frac{dg}{dp} = \mu_0 + \lambda_0^* \left(\frac{\partial y_0}{\partial p} \right).\]
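The discrete recursion can be sketched on the same kind of model problem, \(\dot{y} = -p y\) with \(g = y\) at the final step. In the hypothetical C example below, forward Euler plays the role of \(\varphi\) in (14.6) (a one-step method, so \(k = 1\)); every forward state is stored, and the transposed chain rule is applied from \(\lambda_M = 1\), \(\mu_M = 0\) at the final step \(M\) down to \(\lambda_0\) and \(\mu_0\). Since the discrete solution is \(y_M = (1 - hp)^M y_0\), the results should match \((1 - hp)^M\) and \(-M h (1 - hp)^{M-1} y_0\) up to roundoff, i.e. the exact derivatives of the discrete cost rather than approximations of the continuous ones.

```c
#include <assert.h>
#include <stdlib.h>

/* Discrete adjoint of forward Euler, y_n = phi(y_{n-1}, p) = y_{n-1} + h*f(y_{n-1}, p),
   for the illustrative model f(y, p) = -p*y and cost g = y_M (M = nsteps).
   Outputs dg/dy_0 = lambda_0 and dg/dp = mu_0 (dy_0/dp = 0 for this model). */
static void euler_discrete_adjoint(double p, double y0, double tf, int nsteps,
                                   double *dg_dy0, double *dg_dp) {
  assert(nsteps > 0);
  const double h = tf / nsteps;
  double *y = malloc((size_t)(nsteps + 1) * sizeof *y);
  assert(y != NULL);

  /* Forward sweep: store every state (no checkpointing in this sketch). */
  y[0] = y0;
  for (int n = 1; n <= nsteps; n++) y[n] = y[n - 1] + h * (-p * y[n - 1]);

  /* Backward sweep: transposed chain rule through each step. */
  double lambda = 1.0; /* lambda_M = dg/dy_M */
  double mu = 0.0;     /* mu_M = g_p = 0 */
  for (int n = nsteps; n >= 1; n--) {
    mu += lambda * (-h * y[n - 1]); /* (d phi/d p)^* lambda_n */
    lambda *= 1.0 - h * p;          /* (d phi/d y_{n-1})^* lambda_n */
  }
  free(y);
  *dg_dy0 = lambda;
  *dg_dp = mu;
}
```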

14.1.1. Discrete vs. Continuous Adjoint Method

It is well known that the continuous adjoint method can be problematic in the context of optimization problems: it provides an approximation to the gradient of a continuous cost function, while the optimizer expects the gradient of the discrete cost function. This discrepancy means that the optimizer can fail to converge due to inconsistent gradients [63, 64]. On the other hand, the discrete adjoint method provides the exact gradient of the discrete cost function, allowing the optimizer to fully converge. Consequently, the discrete adjoint method is often preferable in optimization despite its own drawbacks, such as its (relatively) increased memory usage and the possible introduction of unphysical computational modes [133]. This is not to say that the discrete adjoint approach is always the better choice over the continuous adjoint approach in optimization. The computational efficiency and stability of one approach over the other can be both problem and method dependent. Section 8 of the paper [110] discusses the tradeoffs further and provides numerous references that may help users choose between the discrete and continuous adjoint approaches.

14.2. The SUNAdjointStepper Class

Added in version 7.3.0.

type SUNAdjointStepper

The SUNAdjointStepper class provides a package-agnostic interface to SUNDIALS ASA capabilities. It currently only supports the discrete ASA capabilities in the ARKODE package, but in the future this support may be expanded.

14.2.1. Class Methods

The SUNAdjointStepper class has the following methods:

SUNErrCode SUNAdjointStepper_Create(SUNStepper fwd_sunstepper, sunbooleantype own_fwd, SUNStepper adj_sunstepper, sunbooleantype own_adj, suncountertype final_step_idx, sunrealtype tf, N_Vector sf, SUNAdjointCheckpointScheme checkpoint_scheme, SUNContext sunctx, SUNAdjointStepper *adj_stepper)

Creates the SUNAdjointStepper object needed to solve the adjoint problem.

Parameters:
  • fwd_sunstepper – The SUNStepper to be used for forward computations of the original ODE.

  • own_fwd – Whether fwd_sunstepper should be owned (and destroyed) by the SUNAdjointStepper.

  • adj_sunstepper – The SUNStepper to be used for the backward integration of the adjoint ODE.

  • own_adj – Whether adj_sunstepper should be owned (and destroyed) by the SUNAdjointStepper.

  • final_step_idx – The index (step number) of the step corresponding to t_f for the forward ODE.

  • tf – The terminal time for the forward ODE (the initial time for the adjoint ODE).

  • sf – The terminal condition for the adjoint ODE.

  • checkpoint_scheme – The SUNAdjointCheckpointScheme object that determines the checkpointing strategy to use. This should be the same object provided to the forward integrator/stepper.

  • sunctx – The SUNContext for the simulation.

  • adj_stepper – The SUNAdjointStepper to construct (will be NULL on failure).

Returns:

A SUNErrCode indicating failure or success.
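For illustration, suppose \(g(t_f, y, p) = \tfrac{1}{2}\|y(t_f)\|^2\). Then the terminal adjoint values are \(\lambda(t_f) = g_y^* = y(t_f)\) and \(\mu(t_f) = g_p^* = 0\). The plain-array helper below is hypothetical and not part of SUNDIALS; it assumes that sf has the same two-part layout (\(\lambda\), then \(\mu\)) as the sens argument of SUNAdjRhsFn, in which case these values would be copied into the two subvectors of sf before calling SUNAdjointStepper_Create().

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper (not a SUNDIALS function): terminal values for
   g(t_f, y, p) = 0.5 * sum_i y_i^2, which gives
     lambda(t_f) = g_y^*(t_f, y(t_f), p) = y(t_f)   (length N)
     mu(t_f)     = g_p^*(t_f, y(t_f), p) = 0        (length Ns) */
static void fill_terminal_condition(const double *y_tf, size_t N, size_t Ns,
                                    double *lambda_f, double *mu_f) {
  assert(y_tf != NULL && lambda_f != NULL && (Ns == 0 || mu_f != NULL));
  for (size_t i = 0; i < N; i++) lambda_f[i] = y_tf[i]; /* dg/dy_i = y_i */
  for (size_t j = 0; j < Ns; j++) mu_f[j] = 0.0;         /* no explicit p dependence */
}
```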

SUNErrCode SUNAdjointStepper_ReInit(SUNAdjointStepper self, sunrealtype t0, N_Vector y0, sunrealtype tf, N_Vector sf)

Reinitializes the adjoint stepper to solve a new problem of the same size.

Parameters:
  • self – The adjoint solver object.

  • t0 – The new initial time.

  • y0 – The new initial condition.

  • tf – The time to start integrating the adjoint system from.

  • sf – The terminal condition vector of sensitivity solutions \(\partial g/\partial y_0\) and \(\partial g/\partial p\).

Returns:

A SUNErrCode indicating failure or success.

SUNErrCode SUNAdjointStepper_Evolve(SUNAdjointStepper adj_stepper, sunrealtype tout, N_Vector sens, sunrealtype *tret)

Integrates the adjoint system.

Parameters:
  • adj_stepper – The adjoint solver object.

  • tout – The time at which the adjoint solution is desired.

  • sens – The vector of sensitivity solutions \(\partial g/\partial y_0\) and \(\partial g/\partial p\).

  • tret – On return, the time reached by the adjoint solver.

Returns:

A SUNErrCode indicating failure or success.

SUNErrCode SUNAdjointStepper_OneStep(SUNAdjointStepper adj_stepper, sunrealtype tout, N_Vector sens, sunrealtype *tret)

Evolves the adjoint system backwards one step.

Parameters:
  • adj_stepper – The adjoint solver object.

  • tout – The time at which the adjoint solution is desired.

  • sens – The vector of sensitivity solutions \(\partial g/\partial y_0\) and \(\partial g/\partial p\).

  • tret – On return, the time reached by the adjoint solver.

Returns:

A SUNErrCode indicating failure or success.

SUNErrCode SUNAdjointStepper_RecomputeFwd(SUNAdjointStepper adj_stepper, suncountertype start_idx, sunrealtype t0, N_Vector y0, sunrealtype tf)

Evolves the forward system in time from step start_idx at time t0 to time tf with dense checkpointing.

Parameters:
  • adj_stepper – The SUNAdjointStepper object.

  • start_idx – the index of the step, w.r.t. the original forward integration, to begin forward integration from.

  • t0 – the initial time, w.r.t. the original forward integration, to start forward integration from.

  • y0 – the initial state, w.r.t. the original forward integration, to start forward integration from.

  • tf – the final time, w.r.t. the original forward integration, to stop forward integration at.

Returns:

A SUNErrCode indicating failure or success.
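The trade-off between sparse checkpoints and forward recomputation can be sketched without SUNDIALS. In the hypothetical C example below, only every interval-th forward Euler state of the model \(\dot{y} = -p y\) is kept; during the backward sweep each segment between checkpoints is recomputed before the adjoint passes through it, which is conceptually the job SUNAdjointStepper_RecomputeFwd() performs. This illustrates the idea only and is not how SUNDIALS implements it.

```c
#include <assert.h>
#include <stdlib.h>

/* One forward Euler step for the illustrative model y' = -p*y. */
static double euler_step(double y, double p, double h) { return y + h * (-p * y); }

/* dg/dp for g = y_M via the discrete adjoint, keeping only every `interval`-th
   forward state as a checkpoint and recomputing the states in between during
   the backward sweep. *num_recomputed counts recomputed forward steps. */
static double checkpointed_gradient(double p, double y0, double tf, int nsteps,
                                    int interval, int *num_recomputed) {
  assert(nsteps > 0 && interval > 0);
  const double h = tf / nsteps;
  double *ckpt = malloc((size_t)(nsteps / interval + 1) * sizeof *ckpt);
  double *seg = malloc((size_t)interval * sizeof *seg);
  assert(ckpt != NULL && seg != NULL);

  /* Forward sweep: checkpoint y_n only when n is a multiple of interval. */
  double y = y0;
  for (int n = 0; n <= nsteps; n++) {
    if (n % interval == 0) ckpt[n / interval] = y;
    if (n < nsteps) y = euler_step(y, p, h);
  }

  /* Backward sweep, one segment (start, n] at a time. */
  double lambda = 1.0, mu = 0.0;
  *num_recomputed = 0;
  for (int n = nsteps; n > 0;) {
    const int start = ((n - 1) / interval) * interval; /* nearest checkpoint below n */
    /* Recompute y_start, ..., y_{n-1} from the checkpoint at step start. */
    seg[0] = ckpt[start / interval];
    for (int k = start; k < n - 1; k++) {
      seg[k - start + 1] = euler_step(seg[k - start], p, h);
      (*num_recomputed)++;
    }
    /* Adjoint update through steps n, n-1, ..., start+1. */
    for (int k = n; k > start; k--) {
      mu += lambda * (-h * seg[k - 1 - start]);
      lambda *= 1.0 - h * p;
    }
    n = start;
  }
  free(ckpt);
  free(seg);
  return mu;
}
```

With 10 steps and an interval of 3, this sketch recomputes 6 forward steps and returns the same gradient as storing every state; with an interval of 1 nothing is recomputed. Larger intervals trade memory for extra forward work.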

SUNErrCode SUNAdjointStepper_SetUserData(SUNAdjointStepper adj_stepper, void *user_data)

Sets the user data pointer.

Parameters:
  • adj_stepper – The SUNAdjointStepper object.

  • user_data – the user data pointer that will be passed back to user-supplied callback functions.

Returns:

A SUNErrCode indicating failure or success.

SUNErrCode SUNAdjointStepper_GetNumSteps(SUNAdjointStepper adj_stepper, suncountertype *num_steps)

Retrieves the number of steps taken by the adjoint stepper.

Parameters:
  • adj_stepper – The SUNAdjointStepper object.

  • num_steps – Pointer to store the number of steps.

Returns:

A SUNErrCode indicating failure or success.

SUNErrCode SUNAdjointStepper_GetNumRecompute(SUNAdjointStepper adj_stepper, suncountertype *num_recompute)

Retrieves the number of recomputations performed by the adjoint stepper.

Parameters:
  • adj_stepper – The SUNAdjointStepper object.

  • num_recompute – Pointer to store the number of recomputations.

Returns:

A SUNErrCode indicating failure or success.

SUNErrCode SUNAdjointStepper_PrintAllStats(SUNAdjointStepper adj_stepper, FILE *outfile, SUNOutputFormat fmt)

Prints the adjoint stepper statistics/counters in a human-readable table format or CSV format.

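Parameters:
  • adj_stepper – The SUNAdjointStepper object.

  • outfile – A pointer to the output file.

  • fmt – The output format: SUN_OUTPUTFORMAT_TABLE prints a table of values and SUN_OUTPUTFORMAT_CSV prints a comma-separated list of key and value pairs.

Returns: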

A SUNErrCode indicating failure or success.

14.2.2. User-Supplied Functions

typedef int (*SUNAdjRhsFn)(sunrealtype t, N_Vector y, N_Vector sens, N_Vector sens_dot, void *user_data)

These functions compute the adjoint ODE right-hand side.

For ARKODE, this is

\[\begin{split}\Lambda &= f_y^*(t, y, p) \lambda, \quad \text{and if the system has parameters}, \\ \nu &= f_p^*(t, y, p) \lambda.\end{split}\]

and corresponds to (2.74) for explicit Runge–Kutta methods.

Parameters:

  • t – the current value of the independent variable.

  • y – the current value of the forward solution vector.

  • sens – an NVECTOR_MANYVECTOR object with two subvectors: the first holds \(\lambda\), and the second holds \(\mu\) (unused in this function).

  • sens_dot – an NVECTOR_MANYVECTOR object with two subvectors: the first holds \(\Lambda\), and the second holds \(\nu\).

  • user_data – the user_data pointer that was passed to SUNAdjointStepper_SetUserData().

Returns:

A SUNAdjRhsFn should return 0 if successful, a positive value if a recoverable error occurred (in which case the integrator may attempt to correct), or a negative value if it failed unrecoverably (in which case the integration is halted and an error is raised).

Note

Allocation of memory for y is handled within the integrator.

The vector sens_dot may be uninitialized on input; it is the user’s responsibility to fill this entire vector with meaningful values.
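To make the formulas above concrete, the sketch below computes \(\Lambda = f_y^* \lambda\) and \(\nu = f_p^* \lambda\) for the Lotka–Volterra model \(\dot{y}_1 = a y_1 - b y_1 y_2\), \(\dot{y}_2 = -c y_2 + d y_1 y_2\) with \(p = (a, b, c, d)\). It uses plain arrays and a hypothetical signature so that it compiles on its own; an actual SUNAdjRhsFn receives N_Vector arguments and would instead read \(\lambda\) from the first subvector of sens and write \(\Lambda\) and \(\nu\) into the two subvectors of sens_dot (for example, via N_VGetSubvector_ManyVector()).

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical plain-array analogue of a SUNAdjRhsFn for the Lotka-Volterra model
     y1' = a*y1 - b*y1*y2,   y2' = -c*y2 + d*y1*y2,   p = (a, b, c, d).
   Computes Lambda = f_y^* lambda and nu = f_p^* lambda; for real data the
   Hermitian transpose * is just the transpose. Returns 0 on success. */
static int lotka_volterra_adj_rhs(double t, const double y[2], const double lambda[2],
                                  const double p[4], double Lambda[2], double nu[4]) {
  (void)t; /* autonomous system: no explicit time dependence */
  assert(y != NULL && lambda != NULL && p != NULL && Lambda != NULL && nu != NULL);
  const double a = p[0], b = p[1], c = p[2], d = p[3];

  /* f_y = [[a - b*y2, -b*y1], [d*y2, -c + d*y1]]; Lambda = f_y^T lambda */
  Lambda[0] = (a - b * y[1]) * lambda[0] + d * y[1] * lambda[1];
  Lambda[1] = -b * y[0] * lambda[0] + (-c + d * y[0]) * lambda[1];

  /* Columns of f_p: df/da = (y1, 0), df/db = (-y1*y2, 0),
     df/dc = (0, -y2),  df/dd = (0, y1*y2); nu = f_p^T lambda */
  nu[0] = y[0] * lambda[0];
  nu[1] = -y[0] * y[1] * lambda[0];
  nu[2] = -y[1] * lambda[1];
  nu[3] = y[0] * y[1] * lambda[1];
  return 0;
}
```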

14.3. The SUNAdjointCheckpointScheme Class

Added in version 7.3.0.

As with other SUNDIALS classes, the SUNAdjointCheckpointScheme abstract base class is implemented using a C structure containing a content pointer to the derived class member data and a structure of function pointers to the derived class implementations of the virtual methods.

type SUNAdjointCheckpointScheme

A class that provides an interface for checkpointing states during forward integration and accessing them as needed during the backwards integration of the adjoint model.

enum SUNDataIOMode
enumerator SUNDATAIOMODE_INMEM

The IO mode for data that is stored in addressable random access memory. The location of the memory (e.g., CPU or GPU) is not specified by this mode.

14.3.1. Base Class Methods

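SUNErrCode SUNAdjointCheckpointScheme_NewEmpty(SUNContext sunctx, SUNAdjointCheckpointScheme *cs_ptr)

Creates a new, empty SUNAdjointCheckpointScheme object whose content and virtual methods are then set by the implementation (see Section 14.3.3).

Parameters:
  • sunctx – The SUNContext for the simulation.

  • cs_ptr – On output, a pointer to the newly created object.

Returns: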

A SUNErrCode indicating failure or success.

SUNErrCode SUNAdjointCheckpointScheme_NeedsSaving(SUNAdjointCheckpointScheme self, suncountertype step_num, suncountertype stage_num, sunrealtype t, sunbooleantype *yes_or_no)

Determines whether the state at (step_num, stage_num) should be checkpointed.

Parameters:
  • self – the SUNAdjointCheckpointScheme object

  • step_num – the step number of the checkpoint

  • stage_num – the stage number of the checkpoint

  • t – the time of the checkpoint

  • yes_or_no – boolean indicating if the checkpoint should be saved or not

Returns:

A SUNErrCode indicating failure or success.

SUNErrCode SUNAdjointCheckpointScheme_InsertVector(SUNAdjointCheckpointScheme self, suncountertype step_num, suncountertype stage_num, sunrealtype t, N_Vector y)

Inserts the vector as the checkpoint for (step_num, stage_num).

Parameters:
  • self – the SUNAdjointCheckpointScheme object

  • step_num – the step number of the checkpoint

  • stage_num – the stage number of the checkpoint

  • t – the time of the checkpoint

  • y – the state vector to checkpoint

Returns:

A SUNErrCode indicating failure or success.

SUNErrCode SUNAdjointCheckpointScheme_LoadVector(SUNAdjointCheckpointScheme self, suncountertype step_num, suncountertype stage_num, sunrealtype t, sunbooleantype peek, N_Vector *yout, sunrealtype *tout)

Loads the checkpointed vector for (step_num, stage_num).

Parameters:
  • self – the SUNAdjointCheckpointScheme object

  • step_num – the step number of the checkpoint

  • stage_num – the stage number of the checkpoint

  • t – the desired time of the checkpoint

  • peek – if true, then the checkpoint will be loaded but not deleted regardless of other implementation-specific settings. If false, then the checkpoint may be deleted depending on the implementation.

  • yout – the loaded state vector

  • tout – on output, the time of the checkpoint

Returns:

A SUNErrCode indicating failure or success.
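A toy in-memory table can illustrate the InsertVector/LoadVector contract, including the peek flag. Everything in the sketch below (CkptStore, ckpt_insert, ckpt_load, the fixed capacity and state length) is hypothetical; unlike SUNAdjointCheckpointScheme_LoadVector(), it looks checkpoints up by (step_num, stage_num) only, ignores the requested time t, and works on plain arrays rather than N_Vector objects.

```c
#include <assert.h>
#include <stdbool.h>
#include <string.h>

#define CKPT_CAPACITY 16 /* maximum number of stored checkpoints in this toy */
#define CKPT_LEN 4       /* length of each stored state vector in this toy */

/* Hypothetical in-memory checkpoint table keyed by (step_num, stage_num). */
typedef struct {
  long step[CKPT_CAPACITY];
  long stage[CKPT_CAPACITY];
  double t[CKPT_CAPACITY];
  double y[CKPT_CAPACITY][CKPT_LEN];
  bool used[CKPT_CAPACITY];
} CkptStore;

static void ckpt_init(CkptStore *s) { memset(s, 0, sizeof *s); }

/* InsertVector analogue: copies y (CKPT_LEN values) into a free slot.
   Returns 0 on success, -1 if the table is full. Duplicate keys are not checked. */
static int ckpt_insert(CkptStore *s, long step, long stage, double t, const double *y) {
  assert(s != NULL && y != NULL);
  for (int i = 0; i < CKPT_CAPACITY; i++) {
    if (!s->used[i]) {
      s->used[i] = true;
      s->step[i] = step;
      s->stage[i] = stage;
      s->t[i] = t;
      memcpy(s->y[i], y, sizeof s->y[i]);
      return 0;
    }
  }
  return -1;
}

/* LoadVector analogue: copies the stored state into yout and its time into tout.
   If peek is false, this toy deletes the entry after loading.
   Returns 0 on success, -1 if no checkpoint matches (step, stage). */
static int ckpt_load(CkptStore *s, long step, long stage, bool peek,
                     double *yout, double *tout) {
  assert(s != NULL && yout != NULL && tout != NULL);
  for (int i = 0; i < CKPT_CAPACITY; i++) {
    if (s->used[i] && s->step[i] == step && s->stage[i] == stage) {
      memcpy(yout, s->y[i], sizeof s->y[i]);
      *tout = s->t[i];
      if (!peek) s->used[i] = false;
      return 0;
    }
  }
  return -1;
}
```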

SUNErrCode SUNAdjointCheckpointScheme_EnableDense(SUNAdjointCheckpointScheme self, sunbooleantype on_or_off)

Enables or disables dense checkpointing (checkpointing every step/stage). When dense checkpointing is disabled, the checkpointing interval that was set when the object was created is restored.

Parameters:
  • self – the SUNAdjointCheckpointScheme object

  • on_or_off – if true, dense checkpointing will be turned on, if false it will be turned off.

Returns:

A SUNErrCode indicating failure or success.
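The decision logic of a fixed-interval scheme with a dense mode fits in a few lines. In this hypothetical sketch (not the SUNDIALS data structure), stage_num == 0 is assumed to denote the step state itself. Dense mode is kept as a separate flag, so turning it off automatically restores the interval chosen at creation, matching the behavior described for SUNAdjointCheckpointScheme_EnableDense().

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical state of a fixed-interval checkpointing scheme (not the SUNDIALS struct). */
typedef struct {
  long interval;    /* checkpoint every `interval` steps */
  bool save_stages; /* also save intermediate stage states at checkpointed steps */
  bool dense;       /* dense mode: save every step and stage */
} FixedScheme;

static void fixed_init(FixedScheme *cs, long interval, bool save_stages) {
  assert(cs != NULL && interval > 0);
  cs->interval = interval;
  cs->save_stages = save_stages;
  cs->dense = false;
}

/* NeedsSaving analogue. Assumption: stage_num == 0 denotes the step state itself. */
static bool fixed_needs_saving(const FixedScheme *cs, long step_num, long stage_num) {
  if (cs->dense) return true;
  if (step_num % cs->interval != 0) return false;
  return stage_num == 0 || cs->save_stages;
}

/* EnableDense analogue: the interval is left untouched, so turning dense mode
   off restores the creation-time behavior. */
static void fixed_enable_dense(FixedScheme *cs, bool on_or_off) { cs->dense = on_or_off; }
```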

SUNErrCode SUNAdjointCheckpointScheme_Destroy(SUNAdjointCheckpointScheme *cs_ptr)

Destroys (deallocates) the SUNAdjointCheckpointScheme object.

Parameters:
  • cs_ptr – A pointer to the SUNAdjointCheckpointScheme object to destroy.

Returns:

A SUNErrCode indicating failure or success.

14.3.2. Implementation Specific Methods

This section describes the virtual methods defined by the SUNAdjointCheckpointScheme abstract base class.

typedef SUNErrCode (*SUNAdjointCheckpointSchemeNeedsSavingFn)(SUNAdjointCheckpointScheme check_scheme, suncountertype step_num, suncountertype stage_num, sunrealtype t, sunbooleantype *yes_or_no)

This type represents a function with the signature of SUNAdjointCheckpointScheme_NeedsSaving().

typedef SUNErrCode (*SUNAdjointCheckpointSchemeInsertVectorFn)(SUNAdjointCheckpointScheme check_scheme, suncountertype step_num, suncountertype stage_num, sunrealtype t, N_Vector y)

This type represents a function with the signature of SUNAdjointCheckpointScheme_InsertVector().

typedef SUNErrCode (*SUNAdjointCheckpointSchemeLoadVectorFn)(SUNAdjointCheckpointScheme check_scheme, suncountertype step_num, suncountertype stage_num, sunrealtype t, sunbooleantype peek, N_Vector *yout, sunrealtype *tout)

This type represents a function with the signature of SUNAdjointCheckpointScheme_LoadVector().

typedef SUNErrCode (*SUNAdjointCheckpointSchemeEnableDenseFn)(SUNAdjointCheckpointScheme check_scheme, sunbooleantype on_or_off)

This type represents a function with the signature of SUNAdjointCheckpointScheme_EnableDense().

typedef SUNErrCode (*SUNAdjointCheckpointSchemeDestroyFn)(SUNAdjointCheckpointScheme *check_scheme_ptr)

This type represents a function with the signature of SUNAdjointCheckpointScheme_Destroy().

14.3.3. Setting Content and Member Functions

These functions can be used to set the content pointer or virtual method pointers as needed when implementing the abstract base class.

SUNErrCode SUNAdjointCheckpointScheme_SetNeedsSavingFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeNeedsSavingFn fn)

This function attaches a SUNAdjointCheckpointSchemeNeedsSavingFn function to a SUNAdjointCheckpointScheme object.

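Parameters:
  • self – a checkpoint scheme object.

  • fn – the SUNAdjointCheckpointSchemeNeedsSavingFn function to attach.

Returns: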

A SUNErrCode indicating success or failure.

SUNErrCode SUNAdjointCheckpointScheme_SetInsertVectorFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeInsertVectorFn fn)

This function attaches a SUNAdjointCheckpointSchemeInsertVectorFn function to a SUNAdjointCheckpointScheme object.

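Parameters:
  • self – a checkpoint scheme object.

  • fn – the SUNAdjointCheckpointSchemeInsertVectorFn function to attach.

Returns: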

A SUNErrCode indicating success or failure.

SUNErrCode SUNAdjointCheckpointScheme_SetLoadVectorFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeLoadVectorFn fn)

This function attaches a SUNAdjointCheckpointSchemeLoadVectorFn function to a SUNAdjointCheckpointScheme object.

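Parameters:
  • self – a checkpoint scheme object.

  • fn – the SUNAdjointCheckpointSchemeLoadVectorFn function to attach.

Returns: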

A SUNErrCode indicating success or failure.

SUNErrCode SUNAdjointCheckpointScheme_SetDestroyFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeDestroyFn fn)

This function attaches a SUNAdjointCheckpointSchemeDestroyFn function to a SUNAdjointCheckpointScheme object.

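Parameters:
  • self – a checkpoint scheme object.

  • fn – the SUNAdjointCheckpointSchemeDestroyFn function to attach.

Returns: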

A SUNErrCode indicating success or failure.

SUNErrCode SUNAdjointCheckpointScheme_SetEnableDenseFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeEnableDenseFn fn)

This function attaches a SUNAdjointCheckpointSchemeEnableDenseFn function to a SUNAdjointCheckpointScheme object.

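Parameters:
  • self – a checkpoint scheme object.

  • fn – the SUNAdjointCheckpointSchemeEnableDenseFn function to attach.

Returns: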

A SUNErrCode indicating success or failure.

SUNErrCode SUNAdjointCheckpointScheme_SetContent(SUNAdjointCheckpointScheme self, void *content)

This function attaches a member data (content) pointer to a SUNAdjointCheckpointScheme object.

Parameters:
  • self – a checkpoint scheme object.

  • content – a pointer to the checkpoint scheme member data.

Returns:

A SUNErrCode indicating success or failure.

SUNErrCode SUNAdjointCheckpointScheme_GetContent(SUNAdjointCheckpointScheme self, void **content)

This function retrieves the member data (content) pointer from a SUNAdjointCheckpointScheme object.

Parameters:
  • self – a checkpoint scheme object.

  • content – a pointer to set to the checkpoint scheme member data pointer.

Returns:

A SUNErrCode indicating success or failure.
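The content-pointer-plus-method-table pattern described at the start of Section 14.3 can be mirrored in standalone C. The Mini* and EveryK_* names below are hypothetical stand-ins; with SUNDIALS one would instead create the object with SUNAdjointCheckpointScheme_NewEmpty() and attach the member data and methods with SUNAdjointCheckpointScheme_SetContent() and SUNAdjointCheckpointScheme_SetNeedsSavingFn().

```c
#include <assert.h>
#include <stdbool.h>
#include <stdlib.h>

/* Hypothetical miniature of the "content pointer + virtual method" pattern. */
typedef struct MiniScheme_ *MiniScheme;
typedef int (*MiniNeedsSavingFn)(MiniScheme self, long step_num, long stage_num,
                                 bool *yes_or_no);

struct MiniScheme_ {
  void *content;                  /* derived-class member data */
  MiniNeedsSavingFn needs_saving; /* virtual method slot (NULL if not implemented) */
};

/* Base-class dispatcher, analogous to SUNAdjointCheckpointScheme_NeedsSaving(). */
static int MiniScheme_NeedsSaving(MiniScheme self, long step_num, long stage_num,
                                  bool *yes_or_no) {
  if (self == NULL || self->needs_saving == NULL) return -1;
  return self->needs_saving(self, step_num, stage_num, yes_or_no);
}

/* Derived class "EveryK": content is a single long holding the interval. */
static int EveryK_NeedsSaving(MiniScheme self, long step_num, long stage_num,
                              bool *yes_or_no) {
  const long *interval = (const long *)self->content;
  (void)stage_num; /* this toy ignores stages */
  *yes_or_no = (step_num % *interval == 0);
  return 0;
}

static MiniScheme EveryK_Create(long interval) {
  assert(interval > 0);
  MiniScheme s = calloc(1, sizeof *s);
  long *content = malloc(sizeof *content);
  if (s == NULL || content == NULL) {
    free(s);
    free(content);
    return NULL;
  }
  *content = interval;
  s->content = content;                 /* cf. SUNAdjointCheckpointScheme_SetContent() */
  s->needs_saving = EveryK_NeedsSaving; /* cf. SUNAdjointCheckpointScheme_SetNeedsSavingFn() */
  return s;
}

static void EveryK_Destroy(MiniScheme *s_ptr) {
  if (s_ptr == NULL || *s_ptr == NULL) return;
  free((*s_ptr)->content);
  free(*s_ptr);
  *s_ptr = NULL;
}
```

Returning an error from the dispatcher when no method is attached is a choice made for this sketch.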

14.4. The SUNAdjointCheckpointScheme_Fixed Module

The SUNAdjointCheckpointScheme_Fixed module implements a scheme where a checkpoint is saved at some fixed interval (in time steps). The module supports checkpointing of time step states only, or of time step states together with intermediate stage states (for multistage methods). When used with a fixed time step size, the number of checkpoints that will be saved is therefore fixed. With adaptive time steps, however, the number of checkpoints stored by this scheme is unbounded.

The diagram below illustrates how checkpoints are stored with this scheme:

[Figure: how checkpoints are stored with the SUNAdjointCheckpointScheme_Fixed scheme]

14.4.1. Base-class Method Overrides

The SUNAdjointCheckpointScheme_Fixed module implements the following SUNAdjointCheckpointScheme functions:

14.4.2. Implementation Specific Methods

The SUNAdjointCheckpointScheme_Fixed module also implements the following module-specific functions:

SUNErrCode SUNAdjointCheckpointScheme_Create_Fixed(SUNDataIOMode io_mode, SUNMemoryHelper mem_helper, suncountertype interval, suncountertype estimate, sunbooleantype keep, SUNContext sunctx, SUNAdjointCheckpointScheme *check_scheme_ptr)

Creates a new SUNAdjointCheckpointScheme object that checkpoints at a fixed interval.

Parameters:
  • io_mode – The IO mode used for storing the checkpoints.

  • mem_helper – Memory helper for managing memory.

  • interval – The interval (in steps) between checkpoints.

  • estimate – An estimate of the total number of checkpoints needed.

  • keep – If true, checkpoint data is kept stored even after it is no longer needed.

  • sunctx – The SUNContext for the simulation.

  • check_scheme_ptr – Pointer to the newly constructed object.

Returns:

A SUNErrCode indicating success or failure.