diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index ec6be730f..2cbd312ff 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -36,6 +36,7 @@ using ADTypes: using LinearAlgebra: dot include("compat.jl") +include("docstrings.jl") include("first_order/mixed_mode.jl") include("second_order/second_order.jl") @@ -45,7 +46,7 @@ include("utils/traits.jl") include("utils/basis.jl") include("utils/batchsize.jl") include("utils/check.jl") -include("utils/printing.jl") +include("utils/errors.jl") include("utils/context.jl") include("utils/linalg.jl") include("utils/sparse.jl") diff --git a/DifferentiationInterface/src/docstrings.jl b/DifferentiationInterface/src/docstrings.jl new file mode 100644 index 000000000..66834d1e2 --- /dev/null +++ b/DifferentiationInterface/src/docstrings.jl @@ -0,0 +1,51 @@ +function docstring_preptype(preptype::AbstractString, operator::AbstractString) + return """ + $(preptype) + + Abstract type for additional information needed by [`$(operator)`](@ref) and its variants. + """ +end + +function samepoint_warning(samepoint::Bool) + if samepoint + ", _if they are applied at the same point `x` and with the same `contexts`_" + else + "" + end +end + +function docstring_prepare(operator; samepoint=false, inplace=false) + return """ + Create a `prep` object that can be given to [`$(operator)`](@ref) and its variants to speed them up$(samepoint_warning(samepoint)). + + Depending on the backend, this can have several effects (preallocating memory, recording an execution trace) which are transparent to the user. + + !!! warning + The preparation result is only reusable as long as the arguments to `$operator` do not change type or size, and the function and backend themselves are not modified. + Otherwise, preparation will be invalidated and you will need to run it again. + $(inplace ? "\nFor in-place functions, `y` is mutated by `f!` during preparation." : "") + """ +end + +function docstring_prepare!(operator) + return """ + Same behavior as [`prepare_$(operator)`](@ref) but can resize the contents of an existing `prep` object to avoid some allocations. + + There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. + + !!! danger + Compared to when `prep` was first created, the only authorized modification is a size change for input `x` or output `y`. + Any other modification (like a change of type for the input) is not supported and will give erroneous results. + + !!! danger + For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. + """ +end + +function docstring_preparation_hint(operator::AbstractString; same_point=false) + if same_point + return "To improve performance via operator preparation, refer to [`prepare_$(operator)`](@ref) and [`prepare_$(operator)_same_point`](@ref)." + else + return "To improve performance via operator preparation, refer to [`prepare_$(operator)`](@ref)." + end +end diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index 8ede1cf9c..84b080ae5 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -4,11 +4,7 @@ prepare_derivative(f, backend, x, [contexts...]) -> prep prepare_derivative(f!, y, backend, x, [contexts...]) -> prep -Create a `prep` object that can be given to [`derivative`](@ref) and its variants. - -!!! warning - If the function changes in any way, the result of preparation will be invalidated, and you will need to run it again. - For in-place functions, `y` is mutated by `f!` during preparation. +$(docstring_prepare("derivative"; inplace=true)) """ function prepare_derivative end @@ -16,12 +12,7 @@ function prepare_derivative end prepare!_derivative(f, prep, backend, x, [contexts...]) -> new_prep prepare!_derivative(f!, y, prep, backend, x, [contexts...]) -> new_prep -Same behavior as [`prepare_derivative`](@ref) but can modify an existing `prep` object to avoid some allocations. - -There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. - -!!! danger - For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +$(docstring_prepare!("derivative")) """ function prepare!_derivative end @@ -31,7 +22,7 @@ function prepare!_derivative end Compute the value and the derivative of the function `f` at point `x`. -$(document_preparation("derivative")) +$(docstring_preparation_hint("derivative")) """ function value_and_derivative end @@ -41,7 +32,7 @@ function value_and_derivative end Compute the value and the derivative of the function `f` at point `x`, overwriting `der`. -$(document_preparation("derivative")) +$(docstring_preparation_hint("derivative")) """ function value_and_derivative! end @@ -51,7 +42,7 @@ function value_and_derivative! end Compute the derivative of the function `f` at point `x`. -$(document_preparation("derivative")) +$(docstring_preparation_hint("derivative")) """ function derivative end @@ -61,7 +52,7 @@ function derivative end Compute the derivative of the function `f` at point `x`, overwriting `der`. -$(document_preparation("derivative")) +$(docstring_preparation_hint("derivative")) """ function derivative! end diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 9b72e561b..3b6041f8e 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -3,22 +3,14 @@ """ prepare_gradient(f, backend, x, [contexts...]) -> prep -Create a `prep` object that can be given to [`gradient`](@ref) and its variants. - -!!! warning - If the function changes in any way, the result of preparation will be invalidated, and you will need to run it again. +$(docstring_prepare("gradient")) """ function prepare_gradient end """ prepare!_gradient(f, prep, backend, x, [contexts...]) -> new_prep -Same behavior as [`prepare_gradient`](@ref) but can modify an existing `prep` object to avoid some allocations. - -There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. - -!!! danger - For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +$(docstring_prepare!("gradient")) """ function prepare!_gradient end @@ -27,7 +19,7 @@ function prepare!_gradient end Compute the value and the gradient of the function `f` at point `x`. -$(document_preparation("gradient")) +$(docstring_preparation_hint("gradient")) """ function value_and_gradient end @@ -36,7 +28,7 @@ function value_and_gradient end Compute the value and the gradient of the function `f` at point `x`, overwriting `grad`. -$(document_preparation("gradient")) +$(docstring_preparation_hint("gradient")) """ function value_and_gradient! end @@ -45,7 +37,7 @@ function value_and_gradient! end Compute the gradient of the function `f` at point `x`. -$(document_preparation("gradient")) +$(docstring_preparation_hint("gradient")) """ function gradient end @@ -54,7 +46,7 @@ function gradient end Compute the gradient of the function `f` at point `x`, overwriting `grad`. -$(document_preparation("gradient")) +$(docstring_preparation_hint("gradient")) """ function gradient! end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 6d3193a64..546c2b80f 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -4,11 +4,7 @@ prepare_jacobian(f, backend, x, [contexts...]) -> prep prepare_jacobian(f!, y, backend, x, [contexts...]) -> prep -Create a `prep` object that can be given to [`jacobian`](@ref) and its variants. - -!!! warning - If the function changes in any way, the result of preparation will be invalidated, and you will need to run it again. - For in-place functions, `y` is mutated by `f!` during preparation. +$(docstring_prepare("jacobian"; inplace=true)) """ function prepare_jacobian end @@ -16,12 +12,7 @@ function prepare_jacobian end prepare!_jacobian(f, prep, backend, x, [contexts...]) -> new_prep prepare!_jacobian(f!, y, prep, backend, x, [contexts...]) -> new_prep -Same behavior as [`prepare_jacobian`](@ref) but can modify an existing `prep` object to avoid some allocations. - -There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. - -!!! danger - For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +$(docstring_prepare!("jacobian")) """ function prepare!_jacobian end @@ -31,7 +22,7 @@ function prepare!_jacobian end Compute the value and the Jacobian matrix of the function `f` at point `x`. -$(document_preparation("jacobian")) +$(docstring_preparation_hint("jacobian")) """ function value_and_jacobian end @@ -41,7 +32,7 @@ function value_and_jacobian end Compute the value and the Jacobian matrix of the function `f` at point `x`, overwriting `jac`. -$(document_preparation("jacobian")) +$(docstring_preparation_hint("jacobian")) """ function value_and_jacobian! end @@ -51,7 +42,7 @@ function value_and_jacobian! end Compute the Jacobian matrix of the function `f` at point `x`. -$(document_preparation("jacobian")) +$(docstring_preparation_hint("jacobian")) """ function jacobian end @@ -61,7 +52,7 @@ function jacobian end Compute the Jacobian matrix of the function `f` at point `x`, overwriting `jac`. -$(document_preparation("jacobian")) +$(docstring_preparation_hint("jacobian")) """ function jacobian! end diff --git a/DifferentiationInterface/src/first_order/mixed_mode.jl b/DifferentiationInterface/src/first_order/mixed_mode.jl index 20dcf3504..839cb941f 100644 --- a/DifferentiationInterface/src/first_order/mixed_mode.jl +++ b/DifferentiationInterface/src/first_order/mixed_mode.jl @@ -1,10 +1,10 @@ """ MixedMode -Combination of a forward and a reverse mode backend for mixed-mode Jacobian computation. +Combination of a forward and a reverse mode backend for mixed-mode sparse Jacobian computation. !!! danger - `MixedMode` backends only support [`jacobian`](@ref) and its variants. + `MixedMode` backends only support [`jacobian`](@ref) and its variants, and it should be used inside an [`AutoSparse`](@extref ADTypes.AutoSparse) wrapper. # Constructor @@ -20,8 +20,24 @@ struct MixedMode{F<:AbstractADType,R<:AbstractADType} <: AbstractADType end end +""" + forward_backend(m::MixedMode) + +Return the forward-mode part of a `MixedMode` backend. +""" forward_backend(m::MixedMode) = m.forward + +""" + reverse_backend(m::MixedMode) + +Return the reverse-mode part of a `MixedMode` backend. +""" reverse_backend(m::MixedMode) = m.reverse +""" + ForwardAndReverseMode <: ADTypes.AbstractMode + +Appropriate mode type for `MixedMode` backends. +""" struct ForwardAndReverseMode <: ADTypes.AbstractMode end ADTypes.mode(::MixedMode) = ForwardAndReverseMode() diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index e281dee64..d80a4ff7e 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -4,11 +4,7 @@ prepare_pullback(f, backend, x, ty, [contexts...]) -> prep prepare_pullback(f!, y, backend, x, ty, [contexts...]) -> prep -Create a `prep` object that can be given to [`pullback`](@ref) and its variants. - -!!! warning - If the function changes in any way, the result of preparation will be invalidated, and you will need to run it again. - For in-place functions, `y` is mutated by `f!` during preparation. +$(docstring_prepare("pullback"; inplace=true)) """ function prepare_pullback end @@ -16,12 +12,7 @@ function prepare_pullback end prepare!_pullback(f, prep, backend, x, ty, [contexts...]) -> new_prep prepare!_pullback(f!, y, prep, backend, x, ty, [contexts...]) -> new_prep -Same behavior as [`prepare_pullback`](@ref) but can modify an existing `prep` object to avoid some allocations. - -There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. - -!!! danger - For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +$(docstring_prepare!("pullback")) """ function prepare!_pullback end @@ -29,11 +20,7 @@ function prepare!_pullback end prepare_pullback_same_point(f, backend, x, ty, [contexts...]) -> prep_same prepare_pullback_same_point(f!, y, backend, x, ty, [contexts...]) -> prep_same -Create an `prep_same` object that can be given to [`pullback`](@ref) and its variants _if they are applied at the same point `x` and with the same `contexts`_. - -!!! warning - If the function or the point changes in any way, the result of preparation will be invalidated, and you will need to run it again. - For in-place functions, `y` is mutated by `f!` during preparation. +$(docstring_prepare("pullback"; samepoint=true, inplace=true)) """ function prepare_pullback_same_point end @@ -43,7 +30,7 @@ function prepare_pullback_same_point end Compute the value and the pullback of the function `f` at point `x` with a tuple of tangents `ty`. -$(document_preparation("pullback"; same_point=true)) +$(docstring_preparation_hint("pullback"; same_point=true)) !!! tip Pullbacks are also commonly called vector-Jacobian products or VJPs. @@ -60,7 +47,7 @@ function value_and_pullback end Compute the value and the pullback of the function `f` at point `x` with a tuple of tangents `ty`, overwriting `dx`. -$(document_preparation("pullback"; same_point=true)) +$(docstring_preparation_hint("pullback"; same_point=true)) !!! tip Pullbacks are also commonly called vector-Jacobian products or VJPs. @@ -74,7 +61,7 @@ function value_and_pullback! end Compute the pullback of the function `f` at point `x` with a tuple of tangents `ty`. -$(document_preparation("pullback"; same_point=true)) +$(docstring_preparation_hint("pullback"; same_point=true)) !!! tip Pullbacks are also commonly called vector-Jacobian products or VJPs. @@ -88,7 +75,7 @@ function pullback end Compute the pullback of the function `f` at point `x` with a tuple of tangents `ty`, overwriting `dx`. -$(document_preparation("pullback"; same_point=true)) +$(docstring_preparation_hint("pullback"; same_point=true)) !!! tip Pullbacks are also commonly called vector-Jacobian products or VJPs. diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index e80f733a4..91a967ad6 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -4,11 +4,7 @@ prepare_pushforward(f, backend, x, tx, [contexts...]) -> prep prepare_pushforward(f!, y, backend, x, tx, [contexts...]) -> prep -Create a `prep` object that can be given to [`pushforward`](@ref) and its variants. - -!!! warning - If the function changes in any way, the result of preparation will be invalidated, and you will need to run it again. - For in-place functions, `y` is mutated by `f!` during preparation. +$(docstring_prepare("pushforward"; inplace=true)) """ function prepare_pushforward end @@ -16,12 +12,7 @@ function prepare_pushforward end prepare!_pushforward(f, prep, backend, x, tx, [contexts...]) -> new_prep prepare!_pushforward(f!, y, prep, backend, x, tx, [contexts...]) -> new_prep -Same behavior as [`prepare_pushforward`](@ref) but can modify an existing `prep` object to avoid some allocations. - -There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. - -!!! danger - For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +$(docstring_prepare!("pushforward")) """ function prepare!_pushforward end @@ -29,11 +20,7 @@ function prepare!_pushforward end prepare_pushforward_same_point(f, backend, x, tx, [contexts...]) -> prep_same prepare_pushforward_same_point(f!, y, backend, x, tx, [contexts...]) -> prep_same -Create an `prep_same` object that can be given to [`pushforward`](@ref) and its variants _if they are applied at the same point `x` and with the same `contexts`_. - -!!! warning - If the function or the point changes in any way, the result of preparation will be invalidated, and you will need to run it again. - For in-place functions, `y` is mutated by `f!` during preparation. +$(docstring_prepare("pushforward"; samepoint=true, inplace=true)) """ function prepare_pushforward_same_point end @@ -43,7 +30,7 @@ function prepare_pushforward_same_point end Compute the value and the pushforward of the function `f` at point `x` with a tuple of tangents `tx`. -$(document_preparation("pushforward"; same_point=true)) +$(docstring_preparation_hint("pushforward"; same_point=true)) !!! tip Pushforwards are also commonly called Jacobian-vector products or JVPs. @@ -60,7 +47,7 @@ function value_and_pushforward end Compute the value and the pushforward of the function `f` at point `x` with a tuple of tangents `tx`, overwriting `ty`. -$(document_preparation("pushforward"; same_point=true)) +$(docstring_preparation_hint("pushforward"; same_point=true)) !!! tip Pushforwards are also commonly called Jacobian-vector products or JVPs. @@ -74,7 +61,7 @@ function value_and_pushforward! end Compute the pushforward of the function `f` at point `x` with a tuple of tangents `tx`. -$(document_preparation("pushforward"; same_point=true)) +$(docstring_preparation_hint("pushforward"; same_point=true)) !!! tip Pushforwards are also commonly called Jacobian-vector products or JVPs. @@ -88,7 +75,7 @@ function pushforward end Compute the pushforward of the function `f` at point `x` with a tuple of tangents `tx`, overwriting `ty`. -$(document_preparation("pushforward"; same_point=true)) +$(docstring_preparation_hint("pushforward"; same_point=true)) !!! tip Pushforwards are also commonly called Jacobian-vector products or JVPs. diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 59c2663c1..ecdc88445 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -7,6 +7,13 @@ function pick_batchsize(fromprim::FromPrimitive, N::Integer) return pick_batchsize(fromprim.backend, N) end +""" + AutoReverseFromPrimitive + +Wrapper which forces a given backend to act as a reverse-mode backend. + +Used in internal testing. +""" struct AutoReverseFromPrimitive{B} <: FromPrimitive backend::B end diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 80b204633..3dc8b7916 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -3,22 +3,14 @@ """ prepare_hessian(f, backend, x, [contexts...]) -> prep -Create a `prep` object that can be given to [`hessian`](@ref) and its variants. - -!!! warning - If the function changes in any way, the result of preparation will be invalidated, and you will need to run it again. +$(docstring_prepare("hessian")) """ function prepare_hessian end """ prepare!_hessian(f, backend, x, [contexts...]) -> new_prep -Same behavior as [`prepare_hessian`](@ref) but can modify an existing `prep` object to avoid some allocations. - -There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. - -!!! danger - For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +$(docstring_prepare!("hessian")) """ function prepare!_hessian end @@ -27,7 +19,7 @@ function prepare!_hessian end Compute the Hessian matrix of the function `f` at point `x`. -$(document_preparation("hessian")) +$(docstring_preparation_hint("hessian")) """ function hessian end @@ -36,7 +28,7 @@ function hessian end Compute the Hessian matrix of the function `f` at point `x`, overwriting `hess`. -$(document_preparation("hessian")) +$(docstring_preparation_hint("hessian")) """ function hessian! end @@ -45,7 +37,7 @@ function hessian! end Compute the value, gradient vector and Hessian matrix of the function `f` at point `x`. -$(document_preparation("hessian")) +$(docstring_preparation_hint("hessian")) """ function value_gradient_and_hessian end @@ -54,7 +46,7 @@ function value_gradient_and_hessian end Compute the value, gradient vector and Hessian matrix of the function `f` at point `x`, overwriting `grad` and `hess`. -$(document_preparation("hessian")) +$(docstring_preparation_hint("hessian")) """ function value_gradient_and_hessian! end diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index cea59e7ef..3fd696daa 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -3,32 +3,21 @@ """ prepare_hvp(f, backend, x, tx, [contexts...]) -> prep -Create a `prep` object that can be given to [`hvp`](@ref) and its variants. - -!!! warning - If the function changes in any way, the result of preparation will be invalidated, and you will need to run it again. +$(docstring_prepare("hvp")) """ function prepare_hvp end """ prepare!_hvp(f, backend, x, tx, [contexts...]) -> new_prep -Same behavior as [`prepare_hvp`](@ref) but can modify an existing `prep` object to avoid some allocations. - -There is no guarantee that `prep` will be mutated, or that performance will be improved compared to preparation from scratch. - -!!! danger - For efficiency, this function needs to rely on backend package internals, therefore it not protected by semantic versioning. +$(docstring_prepare("hvp")) """ function prepare!_hvp end """ prepare_hvp_same_point(f, backend, x, tx, [contexts...]) -> prep_same -Create an `prep_same` object that can be given to [`hvp`](@ref) and its variants _if they are applied at the same point `x` and with the same `contexts`_. - -!!! warning - If the function or the point changes in any way, the result of preparation will be invalidated, and you will need to run it again. +$(docstring_prepare("hvp"; samepoint=true)) """ function prepare_hvp_same_point end @@ -37,7 +26,7 @@ function prepare_hvp_same_point end Compute the Hessian-vector product of `f` at point `x` with a tuple of tangents `tx`. -$(document_preparation("hvp"; same_point=true)) +$(docstring_preparation_hint("hvp"; same_point=true)) """ function hvp end @@ -46,7 +35,7 @@ function hvp end Compute the Hessian-vector product of `f` at point `x` with a tuple of tangents `tx`, overwriting `tg`. -$(document_preparation("hvp"; same_point=true)) +$(docstring_preparation_hint("hvp"; same_point=true)) """ function hvp! end @@ -55,7 +44,7 @@ function hvp! end Compute the gradient and the Hessian-vector product of `f` at point `x` with a tuple of tangents `tx`. -$(document_preparation("hvp"; same_point=true)) +$(docstring_preparation_hint("hvp"; same_point=true)) """ function gradient_and_hvp end @@ -64,7 +53,7 @@ function gradient_and_hvp end Compute the gradient and the Hessian-vector product of `f` at point `x` with a tuple of tangents `tx`, overwriting `grad` and `tg`. -$(document_preparation("hvp"; same_point=true)) +$(docstring_preparation_hint("hvp"; same_point=true)) """ function gradient_and_hvp! end diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index c39c65cf4..52f07d74b 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -3,19 +3,23 @@ """ prepare_second_derivative(f, backend, x, [contexts...]) -> prep -Create a `prep` object that can be given to [`second_derivative`](@ref) and its variants. - -!!! warning - If the function changes in any way, the result of preparation will be invalidated, and you will need to run it again. +$(docstring_prepare("second_derivative")) """ function prepare_second_derivative end +""" + prepare!_second_derivative(f, prep, backend, x, [contexts...]) -> new_prep + +$(docstring_prepare!("second_derivative")) +""" +function prepare!_second_derivative end + """ second_derivative(f, [prep,] backend, x, [contexts...]) -> der2 Compute the second derivative of the function `f` at point `x`. -$(document_preparation("second_derivative")) +$(docstring_preparation_hint("second_derivative")) """ function second_derivative end @@ -24,7 +28,7 @@ function second_derivative end Compute the second derivative of the function `f` at point `x`, overwriting `der2`. -$(document_preparation("second_derivative")) +$(docstring_preparation_hint("second_derivative")) """ function second_derivative! end @@ -33,7 +37,7 @@ function second_derivative! end Compute the value, first derivative and second derivative of the function `f` at point `x`. -$(document_preparation("second_derivative")) +$(docstring_preparation_hint("second_derivative")) """ function value_derivative_and_second_derivative end @@ -42,7 +46,7 @@ function value_derivative_and_second_derivative end Compute the value, first derivative and second derivative of the function `f` at point `x`, overwriting `der` and `der2`. -$(document_preparation("second_derivative")) +$(docstring_preparation_hint("second_derivative")) """ function value_derivative_and_second_derivative! end diff --git a/DifferentiationInterface/src/utils/basis.jl b/DifferentiationInterface/src/utils/basis.jl index a28ebf384..a6ed3f420 100644 --- a/DifferentiationInterface/src/utils/basis.jl +++ b/DifferentiationInterface/src/utils/basis.jl @@ -1,3 +1,8 @@ +""" + OneElement + +Efficient storage for a one-hot array, aka an array in the standard Euclidean basis. +""" struct OneElement{I,N,T,A<:AbstractArray{T,N}} <: AbstractArray{T,N} ind::I val::T diff --git a/DifferentiationInterface/src/utils/batchsize.jl b/DifferentiationInterface/src/utils/batchsize.jl index b93bae3f6..b3d370d05 100644 --- a/DifferentiationInterface/src/utils/batchsize.jl +++ b/DifferentiationInterface/src/utils/batchsize.jl @@ -88,6 +88,14 @@ function check_batchsize_pickable(backend::AbstractADType) end end +""" + threshold_batchsize(backend::AbstractADType, B::Integer) + +If the backend object has a fixed batch size `B0`, return a new backend where the fixed batch size is `min(B0, B)`. +Otherwise, act as the identity. +""" +function threshold_batchsize end + threshold_batchsize(backend::AbstractADType, ::Integer) = backend function threshold_batchsize(backend::AutoSparse, B::Integer) @@ -111,8 +119,17 @@ function threshold_batchsize(backend::MixedMode, B::Integer) ) end +""" + reasonable_batchsize(N::Integer, Bmax::Integer) + +Reproduces the heuristic from ForwardDiff to minimize + +1. the number of batches necessary to cover an array of length `N` +2. the number of leftover indices in the last partial batch + +Source: https://github.com/JuliaDiff/ForwardDiff.jl/blob/ec74fbc32b10bbf60b3c527d8961666310733728/src/prelude.jl#L19-L29 +""" function reasonable_batchsize(N::Integer, Bmax::Integer) - # borrowed from https://github.com/JuliaDiff/ForwardDiff.jl/blob/ec74fbc32b10bbf60b3c527d8961666310733728/src/prelude.jl#L19-L29 if N <= Bmax return N else diff --git a/DifferentiationInterface/src/utils/check.jl b/DifferentiationInterface/src/utils/check.jl index 6e2d947a3..9e0510533 100644 --- a/DifferentiationInterface/src/utils/check.jl +++ b/DifferentiationInterface/src/utils/check.jl @@ -1,7 +1,7 @@ """ check_available(backend) -Check whether `backend` is available (i.e. whether the extension is loaded). +Check whether `backend` is available (i.e. whether the extension is loaded) and return a `Bool`. """ check_available(backend::AbstractADType) = false @@ -19,6 +19,6 @@ end """ check_inplace(backend) -Check whether `backend` supports differentiation of in-place functions. +Check whether `backend` supports differentiation of in-place functions and return a `Bool`. """ check_inplace(backend::AbstractADType) = Bool(inplace_support(backend)) diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index af8d1f622..0b1f8f591 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -15,7 +15,7 @@ end Abstract supertype for additional context arguments, which can be passed to differentiation operators after the active input `x` but are not differentiated. -# See also +# Subtypes - [`Constant`](@ref) - [`Cache`](@ref) @@ -72,8 +72,27 @@ Concrete type of [`Context`](@ref) argument which can be mutated with active val The initial values present inside the cache do not matter. +For some backends, preparation allocates the required memory for `Cache` contexts with the right element type, similar to [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl). + !!! warning Most backends require any `Cache` context to be an `AbstractArray`. + +# Example + +```jldoctest +julia> using DifferentiationInterface + +julia> import ForwardDiff + +julia> f(x, c) = sum(copyto!(c, x)); + +julia> prep = prepare_gradient(f, AutoForwardDiff(), [1.0, 2.0], Cache(zeros(2))); + +julia> gradient(f, prep, AutoForwardDiff(), [3.0, 4.0], Cache(zeros(2))) +2-element Vector{Float64}: + 1.0 + 1.0 +```` """ struct Cache{T} <: Context data::T diff --git a/DifferentiationInterface/src/utils/printing.jl b/DifferentiationInterface/src/utils/errors.jl similarity index 67% rename from DifferentiationInterface/src/utils/printing.jl rename to DifferentiationInterface/src/utils/errors.jl index 6960ce0e1..64308ab22 100644 --- a/DifferentiationInterface/src/utils/printing.jl +++ b/DifferentiationInterface/src/utils/errors.jl @@ -27,11 +27,3 @@ end function required_packages(::Type{<:AutoSparse{D}}) where {D} return unique(vcat(required_packages(D), "SparseMatrixColorings")) end - -function document_preparation(operator_name::AbstractString; same_point=false) - if same_point - return "To improve performance via operator preparation, refer to [`prepare_$(operator_name)`](@ref) and [`prepare_$(operator_name)_same_point`](@ref)." - else - return "To improve performance via operator preparation, refer to [`prepare_$(operator_name)`](@ref)." - end -end diff --git a/DifferentiationInterface/src/utils/linalg.jl b/DifferentiationInterface/src/utils/linalg.jl index 21f05823e..fcabdb143 100644 --- a/DifferentiationInterface/src/utils/linalg.jl +++ b/DifferentiationInterface/src/utils/linalg.jl @@ -1,5 +1,12 @@ stack_vec_col(t::NTuple) = stack(vec, t; dims=2) stack_vec_row(t::NTuple) = stack(vec, t; dims=1) +""" + ismutable_array(x) + +Check whether `x` is a mutable array and return a `Bool`. + +At the moment, this only returns `false` for `StaticArrays.SArray`. +""" ismutable_array(::Type) = true ismutable_array(x) = ismutable_array(typeof(x)) diff --git a/DifferentiationInterface/src/utils/prep.jl b/DifferentiationInterface/src/utils/prep.jl index cc7dc376d..0fb7af0df 100644 --- a/DifferentiationInterface/src/utils/prep.jl +++ b/DifferentiationInterface/src/utils/prep.jl @@ -1,65 +1,49 @@ abstract type Prep end """ - PushforwardPrep - -Abstract type for additional information needed by [`pushforward`](@ref) and its variants. +$(docstring_preptype("PushforwardPrep", "pushforward")) """ abstract type PushforwardPrep <: Prep end struct NoPushforwardPrep <: PushforwardPrep end """ - PullbackPrep - -Abstract type for additional information needed by [`pullback`](@ref) and its variants. +$(docstring_preptype("PullbackPrep", "pullback")) """ abstract type PullbackPrep <: Prep end struct NoPullbackPrep <: PullbackPrep end """ - DerivativePrep - -Abstract type for additional information needed by [`derivative`](@ref) and its variants. +$(docstring_preptype("DerivativePrep", "derivative")) """ abstract type DerivativePrep <: Prep end struct NoDerivativePrep <: DerivativePrep end """ - GradientPrep - -Abstract type for additional information needed by [`gradient`](@ref) and its variants. +$(docstring_preptype("GradientPrep", "gradient")) """ abstract type GradientPrep <: Prep end struct NoGradientPrep <: GradientPrep end """ - JacobianPrep - -Abstract type for additional information needed by [`jacobian`](@ref) and its variants. +$(docstring_preptype("JacobianPrep", "jacobian")) """ abstract type JacobianPrep <: Prep end struct NoJacobianPrep <: JacobianPrep end """ - HVPPrep - -Abstract type for additional information needed by [`hvp`](@ref) and its variants. +$(docstring_preptype("HVPPrep", "hvp")) """ abstract type HVPPrep <: Prep end struct NoHVPPrep <: HVPPrep end """ - HessianPrep - -Abstract type for additional information needed by [`hessian`](@ref) and its variants. +$(docstring_preptype("HessianPrep", "hessian")) """ abstract type HessianPrep <: Prep end struct NoHessianPrep <: HessianPrep end """ - SecondDerivativePrep - -Abstract type for additional information needed by [`second_derivative`](@ref) and its variants. +$(docstring_preptype("SecondDerivativePrep", "second_derivative")) """ abstract type SecondDerivativePrep <: Prep end struct NoSecondDerivativePrep <: SecondDerivativePrep end diff --git a/DifferentiationInterface/src/utils/sparse.jl b/DifferentiationInterface/src/utils/sparse.jl index dc5749138..68fe591e1 100644 --- a/DifferentiationInterface/src/utils/sparse.jl +++ b/DifferentiationInterface/src/utils/sparse.jl @@ -1,3 +1,9 @@ +""" + jacobian_sparsity_with_contexts(f, detector, x, contexts...) + jacobian_sparsity_with_contexts(f!, y, detector, x, contexts...) + +Wrapper around [`ADTypes.jacobian_sparsity`](@extref ADTypes.jacobian_sparsity) enabling the allocation of caches with proper element types. +""" function jacobian_sparsity_with_contexts( f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C} ) where {F,C} @@ -10,6 +16,11 @@ function jacobian_sparsity_with_contexts( return jacobian_sparsity(with_contexts(f!, contexts...), y, x, detector) end +""" + hessian_sparsity_with_contexts(f, detector, x, contexts...) + +Wrapper around [`ADTypes.hessian_sparsity`](@extref ADTypes.hessian_sparsity) enabling the allocation of caches with proper element types. +""" function hessian_sparsity_with_contexts( f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C} ) where {F,C} diff --git a/DifferentiationInterface/src/utils/traits.jl b/DifferentiationInterface/src/utils/traits.jl index 055d32fd8..f684aea36 100644 --- a/DifferentiationInterface/src/utils/traits.jl +++ b/DifferentiationInterface/src/utils/traits.jl @@ -141,6 +141,16 @@ Traits identifying second-order backends that compute HVPs in forward over forwa """ struct ForwardOverForward <: HVPMode end +""" + hvp_mode(backend) + +Return the best combination of modes for [`hvp`](@ref) and its variants, among the following options: + +- [`ForwardOverForward`](@ref) +- [`ForwardOverReverse`](@ref) +- [`ReverseOverForward`](@ref) +- [`ReverseOverReverse`](@ref) +""" hvp_mode(backend::AbstractADType) = hvp_mode(SecondOrder(backend, backend)) function hvp_mode(ba::SecondOrder)