Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions DifferentiationInterface/docs/src/dev_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ Most operators have 4 variants, which look like this in the first order: `operat
To implement a new operator for an existing backend, you need to write 5 methods: 1 for [preparation](@ref Preparation) and 4 corresponding to the variants of the operator (see above).
For first-order operators, you may also want to support [in-place functions](@ref "Mutation and signatures"), which requires another 5 methods (defined on `f!` instead of `f`).

The method `prepare_operator` must output a `prep` object of the correct type.
For instance, `prepare_gradient(f, backend, x)` must return a [`DifferentiationInterface.GradientPrep`](@ref).
Assuming you don't need any preparation for said operator, you can use the trivial prep that are already defined, like `DifferentiationInterface.NoGradientPrep`.
Otherwise, define a custom struct like `MyGradientPrep <: DifferentiationInterface.GradientPrep` and put the necessary storage in there.
The method `prepare_operator_nokwarg` must output a `prep` object of the correct type.
For instance, `prepare_gradient(strict, f, backend, x)` must return a [`DifferentiationInterface.GradientPrep`](@ref).
Assuming you don't need any preparation for said operator, you can use the trivial prep that are already defined, like `DifferentiationInterface.NoGradientPrep{SIG}`.
Otherwise, define a custom struct like `MyGradientPrep{SIG} <: DifferentiationInterface.GradientPrep{SIG}` and put the necessary storage in there.

## New backend

Expand Down Expand Up @@ -75,4 +75,4 @@ GROUP = get(ENV, "JULIA_DI_TEST_GROUP", "Back/SuperDiff")

but don't forget to switch it back before pushing.

Finally, you need to add your backend to the documentation, modifying every page that involves a list of backends (including the `README.md`).
Finally, you need to add your backend to the documentation, modifying every page that involves a list of backends (including the `README.md`).
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
(; f, backend) = dw
y = f(x)
prep_same = DI.prepare_pullback_same_point(Val(true), f, backend, x, (y,))
prep_same = DI.prepare_pullback_same_point_nokwarg(Val(true), f, backend, x, (y,))
function pullbackfunc(dy)
tx = DI.pullback(f, prep_same, backend, x, (dy,))
return (NoTangent(), only(tx))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
pb::PB
end

function DI.prepare_pullback(
function DI.prepare_pullback_nokwarg(

Check warning on line 9 in DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl#L9

Added line #L9 was not covered by tests
strict::Val,
f,
backend::AutoReverseChainRules,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

## Pushforward

function DI.prepare_pushforward(strict::Val, f, backend::AutoDiffractor, x, tx::NTuple)
function DI.prepare_pushforward_nokwarg(

Check warning on line 13 in DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl#L13

Added line #L13 was not covered by tests
strict::Val, f, backend::AutoDiffractor, x, tx::NTuple
)
_sig = DI.signature(f, backend, x, tx; strict)
return DI.NoPushforwardPrep(_sig)
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## Pushforward

function DI.prepare_pushforward(
function DI.prepare_pushforward_nokwarg(
strict::Val,
f::F,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
Expand Down Expand Up @@ -122,7 +122,7 @@ struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG}
shadows::O
end

function DI.prepare_gradient(
function DI.prepare_gradient_nokwarg(
strict::Val,
f::F,
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
Expand Down Expand Up @@ -203,7 +203,7 @@ struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG}
output_length::Int
end

function DI.prepare_jacobian(
function DI.prepare_jacobian_nokwarg(
strict::Val,
f::F,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## Pushforward

function DI.prepare_pushforward(
function DI.prepare_pushforward_nokwarg(
strict::Val,
f!::F,
y,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ struct EnzymeReverseOneArgPullbackPrep{SIG,Y} <: DI.PullbackPrep{SIG}
y_example::Y # useful to create return activity
end

function DI.prepare_pullback(
function DI.prepare_pullback_nokwarg(
strict::Val,
f::F,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
Expand Down Expand Up @@ -191,7 +191,7 @@ end

## Gradient

function DI.prepare_gradient(
function DI.prepare_gradient_nokwarg(
strict::Val,
f::F,
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ struct EnzymeReverseTwoArgPullbackPrep{SIG,TY} <: DI.PullbackPrep{SIG}
ty_copy::TY
end

function DI.prepare_pullback(
function DI.prepare_pullback_nokwarg(
strict::Val,
f!::F,
y,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ struct FastDifferentiationOneArgPushforwardPrep{SIG,Y,E1,E1!} <: DI.PushforwardP
jvp_exe!::E1!
end

function DI.prepare_pushforward(
function DI.prepare_pushforward_nokwarg(
strict::Val,
f,
backend::AutoFastDifferentiation,
Expand Down Expand Up @@ -105,7 +105,7 @@ struct FastDifferentiationOneArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG}
vjp_exe!::E1!
end

function DI.prepare_pullback(
function DI.prepare_pullback_nokwarg(
strict::Val,
f,
backend::AutoFastDifferentiation,
Expand Down Expand Up @@ -204,7 +204,7 @@ struct FastDifferentiationOneArgDerivativePrep{SIG,Y,E1,E1!} <: DI.DerivativePre
der_exe!::E1!
end

function DI.prepare_derivative(
function DI.prepare_derivative_nokwarg(
strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C};
) where {C}
_sig = DI.signature(f, backend, x, contexts...; strict)
Expand Down Expand Up @@ -284,7 +284,7 @@ struct FastDifferentiationOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG}
jac_exe!::E1!
end

function DI.prepare_gradient(
function DI.prepare_gradient_nokwarg(
strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C};
) where {C}
_sig = DI.signature(f, backend, x, contexts...; strict)
Expand Down Expand Up @@ -360,7 +360,7 @@ struct FastDifferentiationOneArgJacobianPrep{SIG,Y,E1,E1!} <: DI.JacobianPrep{SI
jac_exe!::E1!
end

function DI.prepare_jacobian(
function DI.prepare_jacobian_nokwarg(
strict::Val,
f,
backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}},
Expand Down Expand Up @@ -445,7 +445,7 @@ struct FastDifferentiationAllocatingSecondDerivativePrep{SIG,Y,D,E2,E2!} <:
der2_exe!::E2!
end

function DI.prepare_second_derivative(
function DI.prepare_second_derivative_nokwarg(
strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C};
) where {C}
_sig = DI.signature(f, backend, x, contexts...; strict)
Expand All @@ -462,7 +462,7 @@ function DI.prepare_second_derivative(
der2_exe = make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place=false)
der2_exe! = make_function(der2_vec_var, x_vec_var, context_vec_vars...; in_place=true)

derivative_prep = DI.prepare_derivative(f, backend, x, contexts...)
derivative_prep = DI.prepare_derivative_nokwarg(strict, f, backend, x, contexts...)
return FastDifferentiationAllocatingSecondDerivativePrep(
_sig, y_prototype, derivative_prep, der2_exe, der2_exe!
)
Expand Down Expand Up @@ -534,7 +534,7 @@ struct FastDifferentiationHVPPrep{SIG,E2,E2!,E1} <: DI.HVPPrep{SIG}
gradient_prep::E1
end

function DI.prepare_hvp(
function DI.prepare_hvp_nokwarg(
strict::Val,
f,
backend::AutoFastDifferentiation,
Expand All @@ -557,7 +557,7 @@ function DI.prepare_hvp(
hv_vec_var, x_vec_var, v_vec_var, context_vec_vars...; in_place=true
)

gradient_prep = DI.prepare_gradient(f, backend, x, contexts...)
gradient_prep = DI.prepare_gradient_nokwarg(strict, f, backend, x, contexts...)
return FastDifferentiationHVPPrep(_sig, hvp_exe, hvp_exe!, gradient_prep)
end

Expand Down Expand Up @@ -633,7 +633,7 @@ struct FastDifferentiationHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG}
hess_exe!::E2!
end

function DI.prepare_hessian(
function DI.prepare_hessian_nokwarg(
strict::Val,
f,
backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}},
Expand All @@ -656,7 +656,9 @@ function DI.prepare_hessian(
hess_exe = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=false)
hess_exe! = make_function(hess_var, x_vec_var, context_vec_vars...; in_place=true)

gradient_prep = DI.prepare_gradient(f, dense_ad(backend), x, contexts...)
gradient_prep = DI.prepare_gradient_nokwarg(
strict, f, dense_ad(backend), x, contexts...
)
return FastDifferentiationHessianPrep(_sig, gradient_prep, hess_exe, hess_exe!)
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ struct FastDifferentiationTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPre
jvp_exe!::E1!
end

function DI.prepare_pushforward(
function DI.prepare_pushforward_nokwarg(
strict::Val,
f!,
y,
Expand Down Expand Up @@ -107,7 +107,7 @@ struct FastDifferentiationTwoArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG}
vjp_exe!::E1!
end

function DI.prepare_pullback(
function DI.prepare_pullback_nokwarg(
strict::Val,
f!,
y,
Expand Down Expand Up @@ -213,7 +213,7 @@ struct FastDifferentiationTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{
der_exe!::E1!
end

function DI.prepare_derivative(
function DI.prepare_derivative_nokwarg(
strict::Val, f!, y, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C};
) where {C}
_sig = DI.signature(f!, y, backend, x, contexts...; strict)
Expand Down Expand Up @@ -295,7 +295,7 @@ struct FastDifferentiationTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG}
jac_exe!::E1!
end

function DI.prepare_jacobian(
function DI.prepare_jacobian_nokwarg(
strict::Val,
f!,
y,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ struct FiniteDiffOneArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG}
dir::D
end

function DI.prepare_pushforward(
function DI.prepare_pushforward_nokwarg(
strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C};
) where {C}
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
Expand Down Expand Up @@ -124,7 +124,7 @@ struct FiniteDiffOneArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG}
dir::D
end

function DI.prepare_derivative(
function DI.prepare_derivative_nokwarg(
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
) where {C}
_sig = DI.signature(f, backend, x, contexts...; strict)
Expand Down Expand Up @@ -253,7 +253,7 @@ struct FiniteDiffGradientPrep{SIG,C,R,A,D} <: DI.GradientPrep{SIG}
dir::D
end

function DI.prepare_gradient(
function DI.prepare_gradient_nokwarg(
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
) where {C}
_sig = DI.signature(f, backend, x, contexts...; strict)
Expand Down Expand Up @@ -341,7 +341,7 @@ struct FiniteDiffOneArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG}
dir::D
end

function DI.prepare_jacobian(
function DI.prepare_jacobian_nokwarg(
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
) where {C}
_sig = DI.signature(f, backend, x, contexts...; strict)
Expand Down Expand Up @@ -446,7 +446,7 @@ struct FiniteDiffHessianPrep{SIG,C1,C2,RG,AG,RH,AH} <: DI.HessianPrep{SIG}
absstep_h::AH
end

function DI.prepare_hessian(
function DI.prepare_hessian_nokwarg(
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
) where {C}
_sig = DI.signature(f, backend, x, contexts...; strict)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ struct FiniteDiffTwoArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG}
dir::D
end

function DI.prepare_pushforward(
function DI.prepare_pushforward_nokwarg(
strict::Val,
f!,
y,
Expand Down Expand Up @@ -161,7 +161,7 @@ struct FiniteDiffTwoArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG}
dir::D
end

function DI.prepare_derivative(
function DI.prepare_derivative_nokwarg(
strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C};
) where {C}
_sig = DI.signature(f!, y, backend, x, contexts...; strict)
Expand Down Expand Up @@ -198,7 +198,9 @@ function DI.prepare!_derivative(
cache.c3 isa Union{Number,Nothing} || resize!(cache.c3, length(y))
return old_prep
else
return DI.prepare_derivative(DI.is_strict(old_prep), f!, y, backend, x, contexts...)
return DI.prepare_derivative_nokwarg(
DI.is_strict(old_prep), f!, y, backend, x, contexts...
)
end
end

Expand Down Expand Up @@ -277,7 +279,7 @@ struct FiniteDiffTwoArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG}
dir::D
end

function DI.prepare_jacobian(
function DI.prepare_jacobian_nokwarg(
strict::Val, f!, y, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C};
) where {C}
_sig = DI.signature(f!, y, backend, x, contexts...; strict)
Expand Down Expand Up @@ -318,7 +320,9 @@ function DI.prepare!_jacobian(
cache.sparsity = nothing
return old_prep
else
return DI.prepare_jacobian(DI.is_strict(old_prep), f!, y, backend, x, contexts...)
return DI.prepare_jacobian_nokwarg(
DI.is_strict(old_prep), f!, y, backend, x, contexts...
)
end
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ DI.inner_preparation_behavior(::AutoFiniteDifferences) = DI.PrepareInnerSimple()

## Pushforward

function DI.prepare_pushforward(
function DI.prepare_pushforward_nokwarg(
strict::Val,
f,
backend::AutoFiniteDifferences,
Expand Down Expand Up @@ -54,7 +54,7 @@ end

## Pullback

function DI.prepare_pullback(
function DI.prepare_pullback_nokwarg(
strict::Val,
f,
backend::AutoFiniteDifferences,
Expand Down Expand Up @@ -97,7 +97,7 @@ end

## Gradient

function DI.prepare_gradient(
function DI.prepare_gradient_nokwarg(
strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C};
) where {C}
_sig = DI.signature(f, backend, x, contexts...; strict)
Expand Down Expand Up @@ -154,7 +154,7 @@ end

## Jacobian

function DI.prepare_jacobian(
function DI.prepare_jacobian_nokwarg(
strict::Val, f, backend::AutoFiniteDifferences, x, contexts::Vararg{DI.Context,C};
) where {C}
_sig = DI.signature(f, backend, x, contexts...; strict)
Expand Down
Loading
Loading