14.1. Introduction to Adjoint Sensitivity Analysis
This section presents the SUNAdjointStepper
and
SUNAdjointCheckpointScheme
classes. The SUNAdjointStepper
represents a generic adjoint sensitivity analysis (ASA) procedure to obtain the adjoint
sensitivities of an IVP of the form
where \(p\) is some set of \(N_s\) problem parameters.
Note
The API itself does not implement ASA, but it provides a common interface for ASA capabilities implemented in the SUNDIALS packages. Right now it supports the ASA capabilities in ARKODE, while the ASA capabilities in CVODES and IDAS must be used directly.
Suppose we have a functional \(g(t_f, y(t_f), p)\) for which we would like to compute the gradients \(dg(t_f, y(t_f), p)/dy(t_0)\) and/or \(dg(t_f, y(t_f), p)/dp\). This most often arises in the form of an optimization problem such as
Warning
The CVODES documentation uses \(\lambda\) to represent the adjoint variables needed to obtain the gradient \(dG/dp\) where \(G\) is an integral of \(g\). Our use of \(\lambda\) in the following is akin to the use of \(\mu\) in the CVODES docs.
The adjoint method is one approach to obtaining the gradients that is particularly efficient when there are relatively few functionals and a large number of parameters. While CVODES and IDAS provide continuous adjoint methods (differentiate-then-discretize), ARKODE provides discrete adjoint methods (discretize-then-differentiate). For the continuous approach, we derive and solve the adjoint IVP backwards in time
where \(\lambda(t) \in \mathbb{R}^{N_s}\), \(f_y \equiv \partial f/\partial y \in \mathbb{R}^{N \times N}\) and \(g_y \equiv \partial g/\partial y \in \mathbb{R}^{N \times N}\) are the Jacobians with respect to the dependent variable, \(*\) denotes the Hermitian (conjugate) transpose, \(N\) is the size of the original IVP, and \(N_s\) is the number of parameters. When solved with a numerical time integration scheme, the solution to the continuous adjoint IVP is a numerical approximation of the continuous adjoint sensitivities,
The gradients with respect to the parameters can then be obtained as
where \(y_p(t) \equiv \partial y(t)/\partial p \in \mathbb{R}^{N \times N_s}\), and \(g_p \equiv \partial g/\partial p \in \mathbb{R}^{N \times N_s}\) and \(f_p \equiv \partial f/\partial p \in \mathbb{R}^{N \times N_s}\) are the Jacobians with respect to the parameters.
For the discrete adjoint approach, we first numerically discretize the original IVP (14.1) using a time integration scheme, \(\varphi\), so that
For linear multistep methods, \(k \geq 1\), and for one-step methods, \(k = 1\). Reformulating the optimization problem for the discrete case, we have
The gradients of (14.7) can be computed using the transposed chain rule backwards in time to obtain the discrete adjoint variables \(\lambda_n, \lambda_{n-1}, \cdots, \lambda_0\) and \(\mu_n, \mu_{n-1}, \cdots, \mu_0\). The discrete adjoint variables represent the gradients of the discrete cost function (14.7) with respect to changes in the discretized IVP (14.6),
14.1.1. Discrete vs. Continuous Adjoint Method
It is understood that the continuous adjoint method can be problematic in the context of optimization problems because the continuous adjoint method provides an approximation to the gradient of a continuous cost function while the optimizer is expecting the gradient of the discrete cost function. The discrepancy means that the optimizer can fail to converge due to inconsistent gradients [63, 64]. On the other hand, the discrete adjoint method provides the exact gradient of the discrete cost function, allowing the optimizer to fully converge. Consequently, the discrete adjoint method is often preferable in optimization despite its own drawbacks – such as its (relatively) increased memory usage and the possible introduction of unphysical computational modes [133]. This is not to say that the discrete adjoint approach is always the better choice over the continuous adjoint approach in optimization. Computational efficiency and stability of one approach over the other can be both problem and method dependent. Section 8 in the paper [110] discusses the tradeoffs further and provides numerous references that may help inform users in choosing between the discrete and continuous adjoint approaches.
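The exactness of the discrete adjoint can be illustrated on a toy problem. The following is a minimal sketch that does not use the SUNDIALS API: it assumes the scalar test problem \(y' = p y\), forward Euler with a fixed step \(h\), and the cost \(g = y_N\). The function names `euler_forward` and `euler_discrete_adjoint` are hypothetical and chosen only for this illustration.

```c
#include <assert.h>
#include <stddef.h>

/* Forward Euler for the scalar test problem y' = p*y. Every state is stored
 * (dense checkpointing) because the reverse sweep needs y_n at each step.
 * ys must have room for nsteps + 1 values. */
void euler_forward(double y0, double p, double h, size_t nsteps, double* ys)
{
  assert(ys != NULL);
  ys[0] = y0;
  for (size_t n = 0; n < nsteps; n++) { ys[n + 1] = ys[n] + h * p * ys[n]; }
}

/* Reverse sweep for the discrete cost g = y_N. One Euler step is
 *   y_{n+1} = phi(y_n, p) = (1 + h*p) * y_n,
 * so d(phi)/dy = 1 + h*p and d(phi)/dp = h * y_n.
 * On return, *dg_dy0 holds lambda_0 and *dg_dp holds mu_0. */
void euler_discrete_adjoint(const double* ys, double p, double h, size_t nsteps,
                            double* dg_dy0, double* dg_dp)
{
  assert(ys != NULL && dg_dy0 != NULL && dg_dp != NULL);
  double lambda = 1.0; /* lambda_N = dg/dy_N */
  double mu     = 0.0; /* mu_N = 0: g has no explicit dependence on p */
  for (size_t n = nsteps; n-- > 0;)
  {
    mu += h * ys[n] * lambda; /* mu_n = mu_{n+1} + (dphi/dp)^T lambda_{n+1} */
    lambda *= (1.0 + h * p);  /* lambda_n = (dphi/dy)^T lambda_{n+1} */
  }
  *dg_dy0 = lambda;
  *dg_dp  = mu;
}
```

For this problem the discrete solution is \(y_N = (1 + hp)^N y_0\), so the sweep returns \((1 + hp)^N\) and \(N h (1 + hp)^{N-1} y_0\) up to roundoff. These are the derivatives of the discrete cost, not of the continuous solution \(y(T) = e^{pT} y_0\), which is the distinction an optimizer working with the discretized problem cares about.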
14.2. The SUNAdjointStepper Class
Added in version 7.3.0.
-
type SUNAdjointStepper
The
SUNAdjointStepper
class provides a package-agnostic interface to SUNDIALS ASA capabilities. It currently only supports the discrete ASA capabilities in the ARKODE package, but in the future this support may be expanded.
14.2.1. Class Methods
The SUNAdjointStepper
class has the following methods:
-
SUNErrCode SUNAdjointStepper_Create(SUNStepper fwd_sunstepper, sunbooleantype own_fwd, SUNStepper adj_sunstepper, sunbooleantype own_adj, suncountertype final_step_idx, sunrealtype tf, N_Vector sf, SUNAdjointCheckpointScheme checkpoint_scheme, SUNContext sunctx, SUNAdjointStepper *adj_stepper)
Creates the SUNAdjointStepper object needed to solve the adjoint problem.
- Parameters:
fwd_sunstepper – The SUNStepper to be used for forward computations of the original ODE.
own_fwd – Should fwd_sunstepper be owned (and destroyed) by the SUNAdjointStepper or not.
adj_sunstepper – The SUNStepper to be used for the backward integration of the adjoint ODE.
own_adj – Should adj_sunstepper be owned (and destroyed) by the SUNAdjointStepper or not.
final_step_idx – The index (step number) of the step corresponding to t_f for the forward ODE.
tf – The terminal time for the forward ODE (the initial time for the adjoint ODE).
sf – The terminal condition for the adjoint ODE.
checkpoint_scheme – The SUNAdjointCheckpointScheme object that determines the checkpointing strategy to use. This should be the same object provided to the forward integrator/stepper.
sunctx – The SUNContext for the simulation.
adj_stepper – The SUNAdjointStepper to construct (will be NULL on failure).
- Returns:
A
SUNErrCode
indicating failure or success.
-
SUNErrCode SUNAdjointStepper_ReInit(SUNAdjointStepper self, sunrealtype t0, N_Vector y0, sunrealtype tf, N_Vector sf)
Reinitializes the adjoint stepper to solve a new problem of the same size.
- Parameters:
self – The adjoint solver object.
t0 – The new initial time.
y0 – The new initial condition.
tf – The time to start integrating the adjoint system from.
sf – The terminal condition vector of sensitivity solutions \(\partial g/\partial y_0\) and \(\partial g/\partial p\).
- Returns:
A
SUNErrCode
indicating failure or success.
-
SUNErrCode SUNAdjointStepper_Evolve(SUNAdjointStepper adj_stepper, sunrealtype tout, N_Vector sens, sunrealtype *tret)
Integrates the adjoint system.
- Parameters:
adj_stepper – The adjoint solver object.
tout – The time at which the adjoint solution is desired.
sens – The vector of sensitivity solutions \(\partial g/\partial y_0\) and \(\partial g/\partial p\).
tret – On return, the time reached by the adjoint solver.
- Returns:
A
SUNErrCode
indicating failure or success.
-
SUNErrCode SUNAdjointStepper_OneStep(SUNAdjointStepper adj_stepper, sunrealtype tout, N_Vector sens, sunrealtype *tret)
Evolves the adjoint system backwards one step.
- Parameters:
adj_stepper – The adjoint solver object.
tout – The time at which the adjoint solution is desired.
sens – The vector of sensitivity solutions \(\partial g/\partial y_0\) and \(\partial g/\partial p\).
tret – On return, the time reached by the adjoint solver.
- Returns:
A
SUNErrCode
indicating failure or success.
-
SUNErrCode SUNAdjointStepper_RecomputeFwd(SUNAdjointStepper adj_stepper, suncountertype start_idx, sunrealtype t0, N_Vector y0, sunrealtype tf)
Evolves the forward system in time from (start_idx, t0) to (stop_idx, tf) with dense checkpointing.
- Parameters:
adj_stepper – The SUNAdjointStepper object.
start_idx – the index of the step, w.r.t. the original forward integration, to begin forward integration from.
t0 – the initial time, w.r.t. the original forward integration, to start forward integration from.
y0 – the initial state, w.r.t. the original forward integration, to start forward integration from.
tf – the final time, w.r.t. the original forward integration, to stop forward integration at.
- Returns:
A
SUNErrCode
indicating failure or success.
-
SUNErrCode SUNAdjointStepper_SetUserData(SUNAdjointStepper adj_stepper, void *user_data)
Sets the user data pointer.
- Parameters:
adj_stepper – The SUNAdjointStepper object.
user_data – the user data pointer that will be passed back to user-supplied callback functions.
- Returns:
A
SUNErrCode
indicating failure or success.
-
SUNErrCode SUNAdjointStepper_GetNumSteps(SUNAdjointStepper adj_stepper, suncountertype *num_steps)
Retrieves the number of steps taken by the adjoint stepper.
- Parameters:
adj_stepper – The SUNAdjointStepper object.
num_steps – Pointer to store the number of steps.
- Returns:
A
SUNErrCode
indicating failure or success.
-
SUNErrCode SUNAdjointStepper_GetNumRecompute(SUNAdjointStepper adj_stepper, suncountertype *num_recompute)
Retrieves the number of recomputations performed by the adjoint stepper.
- Parameters:
adj_stepper – The SUNAdjointStepper object.
num_recompute – Pointer to store the number of recomputations.
- Returns:
A
SUNErrCode
indicating failure or success.
-
SUNErrCode SUNAdjointStepper_PrintAllStats(SUNAdjointStepper adj_stepper, FILE *outfile, SUNOutputFormat fmt)
Prints the adjoint stepper statistics/counters in a human-readable table format or CSV format.
- Parameters:
adj_stepper – The SUNAdjointStepper object.
outfile – A file to write the output to.
fmt – the format to write in (SUN_OUTPUTFORMAT_TABLE or SUN_OUTPUTFORMAT_CSV).
- Returns:
A
SUNErrCode
indicating failure or success.
14.2.2. User-Supplied Functions
-
typedef int (*SUNAdjRhsFn)(sunrealtype t, N_Vector y, N_Vector sens, N_Vector sens_dot, void *user_data)
These functions compute the adjoint ODE right-hand side.
For ARKODE, this is
\[\begin{split}\Lambda &= f_y^*(t, y, p) \lambda, \quad \text{and if the system has parameters}, \\ \nu &= f_p^*(t, y, p) \lambda,\end{split}\]
and corresponds to (2.74) for explicit Runge–Kutta methods.
Parameters:
t – the current value of the independent variable.
y – the current value of the forward solution vector.
sens – an NVECTOR_MANYVECTOR object with two subvectors; the first subvector holds \(\lambda\) and the second holds \(\mu\), which is unused in this function.
sens_dot – an NVECTOR_MANYVECTOR object with two subvectors; the first subvector holds \(\Lambda\) and the second holds \(\nu\).
user_data – the user_data pointer that was passed to SUNAdjointStepper_SetUserData().
Returns:
A SUNAdjRhsFn should return 0 if successful, a positive value if a recoverable error occurred (in which case the integrator may attempt to correct), or a negative value if it failed unrecoverably (in which case the integration is halted and an error is raised).
Note
Allocation of memory for y is handled within the integrator.
The vector sens_dot may be uninitialized on input; it is the user’s responsibility to fill this entire vector with meaningful values.
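The two products \(\Lambda = f_y^* \lambda\) and \(\nu = f_p^* \lambda\) can be sketched for a concrete right-hand side. The example below is an illustration only: it assumes a two-species Lotka-Volterra system with four parameters, and it uses plain `double` arrays in place of the N_Vector and NVECTOR_MANYVECTOR arguments that a real SUNAdjRhsFn receives. The name `lv_adjoint_rhs` is hypothetical.

```c
#include <assert.h>
#include <stddef.h>

/* Assumed forward right-hand side (u = y[0], v = y[1]):
 *   f0 =  p0*u - p1*u*v
 *   f1 = -p2*v + p3*u*v
 * The problem is real-valued, so the conjugate transpose is a plain transpose.
 * Returns 0 on success, following the SUNAdjRhsFn return convention. */
int lv_adjoint_rhs(const double y[2], const double p[4], const double lambda[2],
                   double Lambda[2], double nu[4])
{
  assert(y != NULL && p != NULL && lambda != NULL && Lambda != NULL);
  const double u = y[0];
  const double v = y[1];

  /* Lambda = f_y^T lambda, with
   * f_y = [[p0 - p1*v, -p1*u], [p3*v, -p2 + p3*u]] */
  Lambda[0] = (p[0] - p[1] * v) * lambda[0] + p[3] * v * lambda[1];
  Lambda[1] = -p[1] * u * lambda[0] + (-p[2] + p[3] * u) * lambda[1];

  /* nu = f_p^T lambda, with
   * f_p = [[u, -u*v, 0, 0], [0, 0, -v, u*v]].
   * Skipped when the caller passes NULL (system treated as parameter-free). */
  if (nu != NULL)
  {
    nu[0] = u * lambda[0];
    nu[1] = -u * v * lambda[0];
    nu[2] = -v * lambda[1];
    nu[3] = u * v * lambda[1];
  }
  return 0;
}
```

Every entry of both output arrays is written, which mirrors the requirement that sens_dot be filled completely because it may be uninitialized on input.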
14.3. The SUNAdjointCheckpointScheme Class
Added in version 7.3.0.
As with other SUNDIALS classes, the SUNAdjointCheckpointScheme
abstract base class is
implemented using a C structure containing a content
pointer to the derived class member data
and a structure of function pointers to the derived class implementations of the virtual methods.
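The content-pointer and function-pointer layout described above can be sketched in plain C. This is a generic illustration of the pattern, not the SUNDIALS definition: the type and function names (`Scheme`, `SchemeOps`, `Scheme_Create_EveryK`, and so on) are hypothetical, and the stage number and time arguments of the real interface are left out for brevity.

```c
#include <assert.h>
#include <stdlib.h>

typedef struct Scheme_* Scheme;

/* Table of "virtual methods" that a derived class fills in. */
typedef struct
{
  int (*needs_saving)(Scheme self, long step_num, int* yes_or_no);
  int (*destroy)(Scheme* self_ptr);
} SchemeOps;

struct Scheme_
{
  void* content; /* derived-class member data */
  SchemeOps ops; /* derived-class method implementations */
};

/* Base-class entry points: validate, then dispatch through the table. */
int Scheme_NeedsSaving(Scheme self, long step_num, int* yes_or_no)
{
  if (self == NULL || self->ops.needs_saving == NULL) { return -1; }
  return self->ops.needs_saving(self, step_num, yes_or_no);
}

int Scheme_Destroy(Scheme* self_ptr)
{
  if (self_ptr == NULL || *self_ptr == NULL) { return 0; }
  if ((*self_ptr)->ops.destroy != NULL) { return (*self_ptr)->ops.destroy(self_ptr); }
  free(*self_ptr);
  *self_ptr = NULL;
  return 0;
}

/* Derived class: request a checkpoint on every interval-th step. */
typedef struct
{
  long interval;
} EveryKContent;

static int everyk_needs_saving(Scheme self, long step_num, int* yes_or_no)
{
  assert(self->content != NULL);
  const EveryKContent* c = (const EveryKContent*)self->content;
  *yes_or_no = (step_num % c->interval == 0);
  return 0;
}

static int everyk_destroy(Scheme* self_ptr)
{
  free((*self_ptr)->content);
  free(*self_ptr);
  *self_ptr = NULL;
  return 0;
}

int Scheme_Create_EveryK(long interval, Scheme* out)
{
  if (out == NULL || interval <= 0) { return -1; }
  *out = NULL;
  Scheme s = calloc(1, sizeof(*s));
  EveryKContent* c = malloc(sizeof(*c));
  if (s == NULL || c == NULL)
  {
    free(s);
    free(c);
    return -1;
  }
  c->interval         = interval;
  s->content          = c;
  s->ops.needs_saving = everyk_needs_saving;
  s->ops.destroy      = everyk_destroy;
  *out                = s;
  return 0;
}
```

Callers only ever touch the base-class entry points, so a different derived class can be swapped in by changing the constructor call alone.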
-
type SUNAdjointCheckpointScheme
A class that provides an interface for checkpointing states during forward integration and accessing them as needed during the backwards integration of the adjoint model.
-
enum SUNDataIOMode
-
enumerator SUNDATAIOMODE_INMEM
The IO mode for data that is stored in addressable random access memory. The location of the memory (e.g., CPU or GPU) is not specified by this mode.
14.3.1. Base Class Methods
-
SUNErrCode SUNAdjointCheckpointScheme_NewEmpty(SUNContext sunctx, SUNAdjointCheckpointScheme *cs_ptr)
- Parameters:
sunctx – The SUNDIALS simulation context
cs_ptr – on output, a pointer to a new
SUNAdjointCheckpointScheme
object
- Returns:
A
SUNErrCode
indicating failure or success.
-
SUNErrCode SUNAdjointCheckpointScheme_NeedsSaving(SUNAdjointCheckpointScheme self, suncountertype step_num, suncountertype stage_num, sunrealtype t, sunbooleantype *yes_or_no)
Determines if the (step_num, stage_num) should be checkpointed or not.
- Parameters:
self – the SUNAdjointCheckpointScheme object
step_num – the step number of the checkpoint
stage_num – the stage number of the checkpoint
t – the time of the checkpoint
yes_or_no – boolean indicating if the checkpoint should be saved or not
- Returns:
A
SUNErrCode
indicating failure or success.
-
SUNErrCode SUNAdjointCheckpointScheme_InsertVector(SUNAdjointCheckpointScheme self, suncountertype step_num, suncountertype stage_num, sunrealtype t, N_Vector y)
Inserts the vector as the checkpoint for (step_num, stage_num).
- Parameters:
self – the SUNAdjointCheckpointScheme object
step_num – the step number of the checkpoint
stage_num – the stage number of the checkpoint
t – the time of the checkpoint
y – the state vector to checkpoint
- Returns:
A
SUNErrCode
indicating failure or success.
-
SUNErrCode SUNAdjointCheckpointScheme_LoadVector(SUNAdjointCheckpointScheme self, suncountertype step_num, suncountertype stage_num, sunrealtype t, sunbooleantype peek, N_Vector *yout, sunrealtype *tout)
Loads the checkpointed vector for (step_num, stage_num).
- Parameters:
self – the SUNAdjointCheckpointScheme object
step_num – the step number of the checkpoint
stage_num – the stage number of the checkpoint
t – the desired time of the checkpoint
peek – if true, then the checkpoint will be loaded but not deleted regardless of other implementation-specific settings. If false, then the checkpoint may be deleted depending on the implementation.
yout – the loaded state vector
tout – on output, the time of the checkpoint
- Returns:
A
SUNErrCode
indicating failure or success.
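The effect of the peek flag can be sketched with a toy in-memory store. This is a hypothetical illustration, not the SUNDIALS implementation: it keeps scalar states in a fixed-size table keyed by step number only, and it assumes that a non-peek load always deletes the entry, whereas a real scheme may or may not delete depending on its settings.

```c
#include <assert.h>
#include <stddef.h>

#define STORE_CAPACITY 16

/* Fixed-size checkpoint table; zero-initialize before first use. */
typedef struct
{
  long step[STORE_CAPACITY];
  double t[STORE_CAPACITY];
  double y[STORE_CAPACITY];
  int used[STORE_CAPACITY];
} Store;

/* Save (t, y) as the checkpoint for step_num. Returns -1 if the table is full. */
int store_insert(Store* s, long step_num, double t, double y)
{
  assert(s != NULL);
  for (size_t i = 0; i < STORE_CAPACITY; i++)
  {
    if (!s->used[i])
    {
      s->step[i] = step_num;
      s->t[i]    = t;
      s->y[i]    = y;
      s->used[i] = 1;
      return 0;
    }
  }
  return -1;
}

/* Load the checkpoint for step_num. With peek != 0 the entry is kept;
 * otherwise it is released. Returns -1 if no such checkpoint exists. */
int store_load(Store* s, long step_num, int peek, double* yout, double* tout)
{
  assert(s != NULL && yout != NULL && tout != NULL);
  for (size_t i = 0; i < STORE_CAPACITY; i++)
  {
    if (s->used[i] && s->step[i] == step_num)
    {
      *yout = s->y[i];
      *tout = s->t[i];
      if (!peek) { s->used[i] = 0; }
      return 0;
    }
  }
  return -1;
}
```

Peeking is useful when the same checkpoint serves as the restart point for several recomputations, while a consuming load frees the slot as soon as the backward sweep has passed it.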
-
SUNErrCode SUNAdjointCheckpointScheme_EnableDense(SUNAdjointCheckpointScheme self, sunbooleantype on_or_off)
Enables or disables dense checkpointing (checkpointing every step/stage). When dense checkpointing is disabled, the checkpointing interval that was set when the object was created is restored.
- Parameters:
self – the SUNAdjointCheckpointScheme object
on_or_off – if true, dense checkpointing will be turned on, if false it will be turned off.
- Returns:
A
SUNErrCode
indicating failure or success.
-
SUNErrCode SUNAdjointCheckpointScheme_Destroy(SUNAdjointCheckpointScheme *cs_ptr)
Destroys (deallocates) the SUNAdjointCheckpointScheme object.
- Parameters:
cs_ptr – pointer to a
SUNAdjointCheckpointScheme
object
- Returns:
A
SUNErrCode
indicating failure or success.
14.3.2. Implementation Specific Methods
This section describes the virtual methods defined by the SUNAdjointCheckpointScheme
abstract base class.
-
typedef SUNErrCode (*SUNAdjointCheckpointSchemeNeedsSavingFn)(SUNAdjointCheckpointScheme check_scheme, suncountertype step_num, suncountertype stage_num, sunrealtype t, sunbooleantype *yes_or_no)
This type represents a function with the signature of
SUNAdjointCheckpointScheme_NeedsSaving()
.
-
typedef SUNErrCode (*SUNAdjointCheckpointSchemeInsertVectorFn)(SUNAdjointCheckpointScheme check_scheme, suncountertype step_num, suncountertype stage_num, sunrealtype t, N_Vector y)
This type represents a function with the signature of
SUNAdjointCheckpointScheme_InsertVector()
.
-
typedef SUNErrCode (*SUNAdjointCheckpointSchemeLoadVectorFn)(SUNAdjointCheckpointScheme check_scheme, suncountertype step_num, suncountertype stage_num, sunrealtype t, sunbooleantype peek, N_Vector *yout, sunrealtype *tout)
This type represents a function with the signature of
SUNAdjointCheckpointScheme_LoadVector()
.
-
typedef SUNErrCode (*SUNAdjointCheckpointSchemeEnableDenseFn)(SUNAdjointCheckpointScheme check_scheme, sunbooleantype on_or_off)
This type represents a function with the signature of
SUNAdjointCheckpointScheme_EnableDense()
.
-
typedef SUNErrCode (*SUNAdjointCheckpointSchemeDestroyFn)(SUNAdjointCheckpointScheme *check_scheme_ptr)
This type represents a function with the signature of
SUNAdjointCheckpointScheme_Destroy()
.
14.3.3. Setting Content and Member Functions
These functions can be used to set the content pointer or virtual method pointers as needed when implementing the abstract base class.
-
SUNErrCode SUNAdjointCheckpointScheme_SetNeedsSavingFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeNeedsSavingFn fn)
This function attaches a SUNAdjointCheckpointSchemeNeedsSavingFn function to a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
fn – the
SUNAdjointCheckpointSchemeNeedsSavingFn
function to attach.
- Returns:
A
SUNErrCode
indicating success or failure.
-
SUNErrCode SUNAdjointCheckpointScheme_SetInsertVectorFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeInsertVectorFn fn)
This function attaches a SUNAdjointCheckpointSchemeInsertVectorFn function to a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
fn – the
SUNAdjointCheckpointSchemeInsertVectorFn
function to attach.
- Returns:
A
SUNErrCode
indicating success or failure.
-
SUNErrCode SUNAdjointCheckpointScheme_SetLoadVectorFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeLoadVectorFn fn)
This function attaches a SUNAdjointCheckpointSchemeLoadVectorFn function to a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
fn – the
SUNAdjointCheckpointSchemeLoadVectorFn
function to attach.
- Returns:
A
SUNErrCode
indicating success or failure.
-
SUNErrCode SUNAdjointCheckpointScheme_SetDestroyFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeDestroyFn fn)
This function attaches a SUNAdjointCheckpointSchemeDestroyFn function to a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
fn – the
SUNAdjointCheckpointSchemeDestroyFn
function to attach.
- Returns:
A
SUNErrCode
indicating success or failure.
-
SUNErrCode SUNAdjointCheckpointScheme_SetEnableDenseFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeEnableDenseFn fn)
This function attaches a SUNAdjointCheckpointSchemeEnableDenseFn function to a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
fn – the
SUNAdjointCheckpointSchemeEnableDenseFn
function to attach.
- Returns:
A
SUNErrCode
indicating success or failure.
-
SUNErrCode SUNAdjointCheckpointScheme_SetContent(SUNAdjointCheckpointScheme self, void *content)
This function attaches a member data (content) pointer to a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
content – a pointer to the checkpoint scheme member data.
- Returns:
A
SUNErrCode
indicating success or failure.
-
SUNErrCode SUNAdjointCheckpointScheme_GetContent(SUNAdjointCheckpointScheme self, void **content)
This function retrieves the member data (content) pointer from a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
content – a pointer to set to the checkpoint scheme member data pointer.
- Returns:
A
SUNErrCode
indicating success or failure.
14.4. The SUNAdjointCheckpointScheme_Fixed Module
The SUNAdjointCheckpointScheme_Fixed
module implements a scheme where a checkpoint is saved at some
fixed interval (in time steps). The module supports checkpointing of time step states only, or time step
states with intermediate stage states as well (for multistage methods). When used with a
fixed time step size, the number of checkpoints that will be saved is fixed. However, with
adaptive time steps the number of checkpoints stored with this scheme is unbounded.
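The checkpoint count for a fixed interval can be worked out with simple integer arithmetic. The sketch below assumes that a checkpoint is written at step 0 and at every interval-th step after it, and that only step states (no stage states) are stored; both are assumptions made for illustration, and the function names are hypothetical.

```c
#include <assert.h>

/* Number of checkpoints written over steps 0..nsteps under the assumed rule. */
long fixed_num_checkpoints(long nsteps, long interval)
{
  assert(nsteps >= 0 && interval > 0);
  return nsteps / interval + 1;
}

/* Most recent checkpointed step at or before step_num: the point from which
 * the forward problem would be recomputed to recover the state at step_num. */
long fixed_restart_step(long step_num, long interval)
{
  assert(step_num >= 0 && interval > 0);
  return (step_num / interval) * interval;
}
```

With a fixed step size the total number of steps is known before the run, so the first function gives the storage requirement up front. With adaptive steps the step count is only known after the forward integration finishes, which is why no bound can be stated in advance.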
The diagram below illustrates how checkpoints are stored with this scheme:

14.4.1. Base-class Method Overrides
The SUNAdjointCheckpointScheme_Fixed
module implements the following SUNAdjointCheckpointScheme
functions:
14.4.2. Implementation Specific Methods
The SUNAdjointCheckpointScheme_Fixed
module also implements the following module-specific functions:
-
SUNErrCode SUNAdjointCheckpointScheme_Create_Fixed(SUNDataIOMode io_mode, SUNMemoryHelper mem_helper, suncountertype interval, suncountertype estimate, sunbooleantype keep, SUNContext sunctx, SUNAdjointCheckpointScheme *check_scheme_ptr)
Creates a new SUNAdjointCheckpointScheme object that checkpoints at a fixed interval.
- Parameters:
io_mode – The IO mode used for storing the checkpoints.
mem_helper – Memory helper for managing memory.
interval – The interval (in steps) between checkpoints.
estimate – An estimate of the total number of checkpoints needed.
keep – Keep data stored even after it is not needed anymore.
sunctx – The SUNContext for the simulation.
check_scheme_ptr – Pointer to the newly constructed object.
- Returns:
A
SUNErrCode
indicating success or failure.