Current Interface
The existing DifferentiateWith(f, backend) interface in DifferentiationInterface.jl presents a significant limitation: it inherently supports only single-argument functions. This design makes it cumbersome to:
- Differentiate functions with multiple arguments.
- Pass additional context or non-differentiable arguments (constants, pre-allocated caches) to the differentiation backend.
Proposed Interface
To address these limitations, we propose a more expressive interface for DifferentiateWith:
Tfunc_sig = Tuple{typeof(f), T_arg1, T_arg2, ..., T_argN}
DifferentiateWith(Tfunc_sig, backend_to_use::AbstractADType)
Where Tfunc_sig represents the function signature. The first element is the function f itself (or its type), and subsequent elements T_arg1, T_arg2, ..., T_argN represent the types of arguments to f.
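As a small illustration of the signature layout described above, the following plain-Julia sketch builds such a tuple type and splits it back into the function type and the argument types. The function `area` and the names `Tsig`, `Tf` and `Targs` are made up for this example and are not part of the proposal.

```julia
# Toy function used only for this illustration.
area(w, h) = w * h

# Signature type: the function's type first, then one type per argument.
Tsig = Tuple{typeof(area), Float64, Int}

# Split the signature back into its parts.
parts = fieldtypes(Tsig)
Tf, Targs = parts[1], parts[2:end]
```

Here `Tf` is the function's type and `Targs` is a tuple of the argument types, which is the information a constructor taking `Tfunc_sig` would have available.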
Argument Type Wrappers:
To provide more context to the backend about how each argument should be treated, we can introduce wrapper types:
- Default (no wrapper): The argument is assumed to be "active", i.e., the function is differentiated with respect to it.
- Constant{T}: Indicates that an argument of type T is a constant and should not be differentiated.
- Cache{T}: Signals that an argument of type T is a pre-allocated cache that the backend can utilise.
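One possible way to read these wrappers out of a signature is sketched below. It is only a sketch under assumptions: the `Constant` and `Cache` structs are standalone marker types defined here for illustration (not the package's own definitions), and `role`, `unwrap` and `argument_roles` are hypothetical helper names.

```julia
# Hypothetical marker types for this sketch only.
struct Constant{T} end
struct Cache{T} end

# Classify one signature entry; anything unwrapped is treated as active.
role(::Type{<:Constant}) = :constant
role(::Type{<:Cache}) = :cache
role(::Type) = :active

# Recover the underlying argument type from a possibly wrapped entry.
unwrap(::Type{Constant{T}}) where {T} = T
unwrap(::Type{Cache{T}}) where {T} = T
unwrap(::Type{T}) where {T} = T

# Roles of every argument in Tuple{typeof(f), args...}, skipping the function slot.
function argument_roles(::Type{S}) where {S<:Tuple}
    params = fieldtypes(S)
    return map(role, params[2:end])
end
```

Because the roles are encoded in the type, this classification needs only the signature and never the argument values.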
Example Usage:
Consider a function f(x, y, z, c) where x and y are active arguments, z is a constant, and c is a cache. The Tfunc_sig would be constructed as:
Targtypes = (typeof(x), typeof(y), Constant{typeof(z)}, Cache{typeof(c)})
Tfunc_sig = Tuple{typeof(f), Targtypes...}
# or more explicitly:
# Tfunc_sig = Tuple{typeof(f), typeof(x), typeof(y), Constant{typeof(z)}, Cache{typeof(c)}}
dw = DifferentiateWith(Tfunc_sig, backend)
Internal Handling:
With this richer Tfunc_sig, DifferentiateWith can internally manage functions with multiple arguments. For backends that fundamentally operate on single-argument functions, DifferentiateWith can pack the active arguments into a single input (e.g., a tuple) and unpack them again automatically before invoking the backend's pushforward or pullback implementations. This keeps the backend APIs simpler while providing a user-friendly multi-argument interface.
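The packing step could look roughly like the sketch below. It is written in plain Julia, calls no AD backend, and is not DifferentiationInterface.jl code: `pack_active` is a hypothetical helper, and the roles are passed as symbols purely to keep the example self-contained.

```julia
# Hypothetical helper: turn a multi-argument call into a single-argument one
# that takes a tuple of the active inputs. Constants and caches are captured.
function pack_active(f, args::Tuple, roles::Tuple)
    length(args) == length(roles) ||
        throw(ArgumentError("one role per argument is required"))
    active_idx = findall(==(:active), collect(roles))
    packed = Tuple(args[i] for i in active_idx)

    # Single-argument function: rebuild the full argument list, then call f.
    function g(active::Tuple)
        full = collect(Any, args)
        for (k, i) in enumerate(active_idx)
            full[i] = active[k]
        end
        return f(full...)
    end
    return g, packed
end

# Usage with a toy function: x and y are active, z is a constant, c is a cache.
h(x, y, z, c) = x * y + z + length(c)
g, packed = pack_active(h, (2.0, 3.0, 1, zeros(4)),
                        (:active, :active, :constant, :cache))
result = g(packed)  # same value as h(2.0, 3.0, 1, zeros(4))
```

A backend that only understands `g(packed)` can then be applied unchanged, and the resulting derivative with respect to `packed` would be split back into one piece per active argument.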