diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 5f2a9e96a..852718e42 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.52" +version = "0.6.53" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 9ce785d99..3dd70bbc7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -1,83 +1,96 @@ ## Pushforward +struct EnzymeOneArgPushforwardPrep{SIG,DF,DC} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + df::DF + context_shadows::DC +end + function DI.prepare_pushforward_nokwarg( strict::Val, f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, - tx::NTuple, + tx::NTuple{B}, contexts::Vararg{DI.Context,C}; -) where {F,C} +) where {F,C,B} _sig = DI.signature(f, backend, x, tx, contexts...; strict) - return DI.NoPushforwardPrep(_sig) + df = function_shadow(f, backend, Val(B)) + mode = forward_withprimal(backend) + context_shadows = make_context_shadows(backend, mode, Val(B), contexts...) + return EnzymeOneArgPushforwardPrep(_sig, df, context_shadows) end function DI.value_and_pushforward( f::F, - prep::DI.NoPushforwardPrep, + prep::EnzymeOneArgPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f, prep, backend, x, tx, contexts...) 
+ (; df, context_shadows) = prep mode = forward_withprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1)) dx = only(tx) x_and_dx = Duplicated(x, dx) - annotated_contexts = translate(backend, mode, Val(1), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) dy, y = autodiff(mode, f_and_df, x_and_dx, annotated_contexts...) return y, (dy,) end function DI.value_and_pushforward( f::F, - prep::DI.NoPushforwardPrep, + prep::EnzymeOneArgPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} DI.check_prep(f, prep, backend, x, tx, contexts...) + (; df, context_shadows) = prep mode = forward_withprimal(backend) - f_and_df = get_f_and_df(f, backend, mode, Val(B)) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B)) x_and_tx = BatchDuplicated(x, tx) - annotated_contexts = translate(backend, mode, Val(B), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) ty, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...) return y, values(ty) end function DI.pushforward( f::F, - prep::DI.NoPushforwardPrep, + prep::EnzymeOneArgPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f, prep, backend, x, tx, contexts...) + (; df, context_shadows) = prep mode = forward_noprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1)) dx = only(tx) x_and_dx = Duplicated(x, dx) - annotated_contexts = translate(backend, mode, Val(1), contexts...) 
+ annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) dy = only(autodiff(mode, f_and_df, x_and_dx, annotated_contexts...)) return (dy,) end function DI.pushforward( f::F, - prep::DI.NoPushforwardPrep, + prep::EnzymeOneArgPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} DI.check_prep(f, prep, backend, x, tx, contexts...) + (; df, context_shadows) = prep mode = forward_noprimal(backend) - f_and_df = get_f_and_df(f, backend, mode, Val(B)) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B)) x_and_tx = BatchDuplicated(x, tx) - annotated_contexts = translate(backend, mode, Val(B), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) ty = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)) return values(ty) end @@ -85,7 +98,7 @@ end function DI.value_and_pushforward!( f::F, ty::NTuple, - prep::DI.NoPushforwardPrep, + prep::EnzymeOneArgPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple, @@ -101,7 +114,7 @@ end function DI.pushforward!( f::F, ty::NTuple, - prep::DI.NoPushforwardPrep, + prep::EnzymeOneArgPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple, @@ -116,10 +129,12 @@ end ## Gradient -struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG} +struct EnzymeForwardGradientPrep{SIG,B,DF,DC,O} <: DI.GradientPrep{SIG} _sig::Val{SIG} _valB::Val{B} - shadows::O + df::DF + context_shadows::DC + basis_shadows::O end function DI.prepare_gradient_nokwarg( @@ -131,8 +146,11 @@ function DI.prepare_gradient_nokwarg( ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) valB = to_val(DI.pick_batchsize(backend, x)) - shadows = create_shadows(valB, x) - return EnzymeForwardGradientPrep(_sig, valB, shadows) + df = function_shadow(f, backend, valB) + mode = forward_withprimal(backend) + context_shadows = make_context_shadows(backend, 
mode, valB, contexts...) + basis_shadows = create_shadows(valB, x) + return EnzymeForwardGradientPrep(_sig, valB, df, context_shadows, basis_shadows) end function DI.gradient( @@ -143,11 +161,12 @@ function DI.gradient( contexts::Vararg{DI.Constant,C}, ) where {F,SIG,B,C} DI.check_prep(f, prep, backend, x, contexts...) + (; df, context_shadows, basis_shadows) = prep mode = forward_noprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) - annotated_contexts = translate(backend, mode, Val(B), contexts...) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B)) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) derivs = gradient( - mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows + mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows ) return first(derivs) end @@ -160,11 +179,12 @@ function DI.value_and_gradient( contexts::Vararg{DI.Constant,C}, ) where {F,SIG,B,C} DI.check_prep(f, prep, backend, x, contexts...) + (; df, context_shadows, basis_shadows) = prep mode = forward_withprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) - annotated_contexts = translate(backend, mode, Val(B), contexts...) 
+ f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B)) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) (; derivs, val) = gradient( - mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows + mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows ) return val, first(derivs) end @@ -196,10 +216,12 @@ end ## Jacobian -struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG} +struct EnzymeForwardOneArgJacobianPrep{SIG,B,DF,DC,O} <: DI.JacobianPrep{SIG} _sig::Val{SIG} _valB::Val{B} - shadows::O + df::DF + context_shadows::DC + basis_shadows::O output_length::Int end @@ -213,8 +235,13 @@ function DI.prepare_jacobian_nokwarg( _sig = DI.signature(f, backend, x, contexts...; strict) y = f(x, map(DI.unwrap, contexts)...) valB = to_val(DI.pick_batchsize(backend, x)) - shadows = create_shadows(valB, x) - return EnzymeForwardOneArgJacobianPrep(_sig, valB, shadows, length(y)) + mode = forward_withprimal(backend) + df = function_shadow(f, backend, valB) + context_shadows = make_context_shadows(backend, mode, valB, contexts...) + basis_shadows = create_shadows(valB, x) + return EnzymeForwardOneArgJacobianPrep( + _sig, valB, df, context_shadows, basis_shadows, length(y) + ) end function DI.jacobian( @@ -225,14 +252,15 @@ function DI.jacobian( contexts::Vararg{DI.Constant,C}, ) where {F,SIG,B,C} DI.check_prep(f, prep, backend, x, contexts...) + (; df, context_shadows, basis_shadows, output_length) = prep mode = forward_noprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) - annotated_contexts = translate(backend, mode, Val(B), contexts...) 
+ f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B)) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) derivs = jacobian( - mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows + mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows ) jac_tensor = first(derivs) - return maybe_reshape(jac_tensor, prep.output_length, length(x)) + return maybe_reshape(jac_tensor, output_length, length(x)) end function DI.value_and_jacobian( @@ -243,14 +271,15 @@ function DI.value_and_jacobian( contexts::Vararg{DI.Constant,C}, ) where {F,SIG,B,C} DI.check_prep(f, prep, backend, x, contexts...) + (; df, context_shadows, basis_shadows, output_length) = prep mode = forward_withprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) - annotated_contexts = translate(backend, mode, Val(B), contexts...) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B)) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) (; derivs, val) = jacobian( - mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows + mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows ) jac_tensor = first(derivs) - return val, maybe_reshape(jac_tensor, prep.output_length, length(x)) + return val, maybe_reshape(jac_tensor, output_length, length(x)) end function DI.jacobian!( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index 4d8328c3e..f0d2a2d91 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -1,35 +1,45 @@ ## Pushforward +struct EnzymeTwoArgPushforwardPrep{SIG,DF,DC} <: DI.PushforwardPrep{SIG} + _sig::Val{SIG} + df!::DF + context_shadows::DC +end + function DI.prepare_pushforward_nokwarg( strict::Val, 
f!::F, y, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, - tx::NTuple, + tx::NTuple{B}, contexts::Vararg{DI.Context,C}; -) where {F,C} +) where {F,B,C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) - return DI.NoPushforwardPrep(_sig) + df! = function_shadow(f!, backend, Val(B)) + mode = forward_noprimal(backend) + context_shadows = make_context_shadows(backend, mode, Val(B), contexts...) + return EnzymeTwoArgPushforwardPrep(_sig, df!, context_shadows) end function DI.value_and_pushforward( f!::F, y, - prep::DI.NoPushforwardPrep, + prep::EnzymeTwoArgPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{1}, contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + (; df!, context_shadows) = prep mode = forward_noprimal(backend) - f!_and_df! = get_f_and_df(f!, backend, mode) + f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(1)) dx = only(tx) dy = make_zero(y) x_and_dx = Duplicated(x, dx) y_and_dy = Duplicated(y, dy) - annotated_contexts = translate(backend, mode, Val(1), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) autodiff(mode, f!_and_df!, Const, y_and_dy, x_and_dx, annotated_contexts...) return y, (dy,) end @@ -37,19 +47,20 @@ end function DI.value_and_pushforward( f!::F, y, - prep::DI.NoPushforwardPrep, + prep::EnzymeTwoArgPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + (; df!, context_shadows) = prep mode = forward_noprimal(backend) - f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) + f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(B)) ty = ntuple(_ -> make_zero(y), Val(B)) x_and_tx = BatchDuplicated(x, tx) y_and_ty = BatchDuplicated(y, ty) - annotated_contexts = translate(backend, mode, Val(B), contexts...) 
+ annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...) return y, ty end @@ -57,7 +68,7 @@ end function DI.pushforward( f!::F, y, - prep::DI.NoPushforwardPrep, + prep::EnzymeTwoArgPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple, @@ -72,18 +83,19 @@ function DI.value_and_pushforward!( f!::F, y, ty::NTuple{B}, - prep::DI.NoPushforwardPrep, + prep::EnzymeTwoArgPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple{B}, contexts::Vararg{DI.Context,C}, ) where {F,B,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) + (; df!, context_shadows) = prep mode = forward_noprimal(backend) - f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) + f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(B)) x_and_tx = BatchDuplicated(x, tx) y_and_ty = BatchDuplicated(y, ty) - annotated_contexts = translate(backend, mode, Val(B), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...) 
return y, ty end @@ -92,7 +104,7 @@ function DI.pushforward!( f!::F, y, ty::NTuple, - prep::DI.NoPushforwardPrep, + prep::EnzymeTwoArgPushforwardPrep, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, tx::NTuple, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index b0a52fb92..67b3989f0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -47,8 +47,10 @@ end ## Pullback -struct EnzymeReverseOneArgPullbackPrep{SIG,Y} <: DI.PullbackPrep{SIG} +struct EnzymeReverseOneArgPullbackPrep{SIG,DF,DC,Y} <: DI.PullbackPrep{SIG} _sig::Val{SIG} + df::DF + context_shadows::DC y_example::Y # useful to create return activity end @@ -57,12 +59,15 @@ function DI.prepare_pullback_nokwarg( f::F, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, - ty::NTuple, + ty::NTuple{B}, contexts::Vararg{DI.Context,C}; -) where {F,C} +) where {F,B,C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) + df = function_shadow(f, backend, Val(B)) + mode = reverse_split_withprimal(backend) + context_shadows = make_context_shadows(backend, mode, Val(B), contexts...) y = f(x, map(DI.unwrap, contexts)...) - return EnzymeReverseOneArgPullbackPrep(_sig, y) + return EnzymeReverseOneArgPullbackPrep(_sig, df, context_shadows, y) end ### Out-of-place @@ -76,12 +81,13 @@ function DI.value_and_pullback( contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f, prep, backend, x, ty, contexts...) 
+ (; df, context_shadows, y_example) = prep mode = reverse_split_withprimal(backend) - f_and_df = force_annotation(get_f_and_df(f, backend, mode)) + f_and_df = force_annotation(get_f_and_df_prepared!(df, f, backend, Val(1))) IA = guess_activity(typeof(x), mode) - RA = guess_activity(typeof(prep.y_example), mode) + RA = guess_activity(typeof(y_example), mode) dx = make_zero(x) - annotated_contexts = translate(backend, mode, Val(1), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) dinputs, result = seeded_autodiff_thunk( mode, only(ty), f_and_df, RA, annotate(IA, x, dx), annotated_contexts... ) @@ -102,12 +108,13 @@ function DI.value_and_pullback( contexts::Vararg{DI.Context,C}, ) where {F,B,C} DI.check_prep(f, prep, backend, x, ty, contexts...) + (; df, context_shadows, y_example) = prep mode = reverse_split_withprimal(backend) - f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B))) + f_and_df = force_annotation(get_f_and_df_prepared!(df, f, backend, Val(B))) IA = batchify_activity(guess_activity(typeof(x), mode), Val(B)) - RA = batchify_activity(guess_activity(typeof(prep.y_example), mode), Val(B)) + RA = batchify_activity(guess_activity(typeof(y_example), mode), Val(B)) tx = ntuple(_ -> make_zero(x), Val(B)) - annotated_contexts = translate(backend, mode, Val(B), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) dinputs, result = batch_seeded_autodiff_thunk( mode, ty, f_and_df, RA, annotate(IA, x, tx), annotated_contexts... ) @@ -143,12 +150,13 @@ function DI.value_and_pullback!( contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f, prep, backend, x, ty, contexts...) 
+ (; df, context_shadows, y_example) = prep mode = reverse_split_withprimal(backend) - f_and_df = force_annotation(get_f_and_df(f, backend, mode)) - RA = guess_activity(typeof(prep.y_example), mode) + f_and_df = force_annotation(get_f_and_df_prepared!(df, f, backend, Val(1))) + RA = guess_activity(typeof(y_example), mode) dx = only(tx) make_zero!(dx) - annotated_contexts = translate(backend, mode, Val(1), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) _, result = seeded_autodiff_thunk( mode, only(ty), f_and_df, RA, Duplicated(x, dx), annotated_contexts... ) @@ -165,11 +173,12 @@ function DI.value_and_pullback!( contexts::Vararg{DI.Context,C}, ) where {F,B,C} DI.check_prep(f, prep, backend, x, ty, contexts...) + (; df, context_shadows, y_example) = prep mode = reverse_split_withprimal(backend) - f_and_df = force_annotation(get_f_and_df(f, backend, mode, Val(B))) - RA = batchify_activity(guess_activity(typeof(prep.y_example), mode), Val(B)) + f_and_df = force_annotation(get_f_and_df_prepared!(df, f, backend, Val(B))) + RA = batchify_activity(guess_activity(typeof(y_example), mode), Val(B)) make_zero!(tx) - annotated_contexts = translate(backend, mode, Val(B), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) _, result = batch_seeded_autodiff_thunk( mode, ty, f_and_df, RA, BatchDuplicated(x, tx), annotated_contexts... ) @@ -191,6 +200,12 @@ end ## Gradient +struct EnzymeGradientPrep{SIG,DF,DC} <: DI.GradientPrep{SIG} + _sig::Val{SIG} + df::DF + context_shadows::DC +end + function DI.prepare_gradient_nokwarg( strict::Val, f::F, @@ -199,37 +214,42 @@ function DI.prepare_gradient_nokwarg( contexts::Vararg{DI.Context,C}; ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) - return DI.NoGradientPrep(_sig) + df = function_shadow(f, backend, Val(1)) + mode = reverse_withprimal(backend) + context_shadows = make_context_shadows(backend, mode, Val(1), contexts...) 
+ return EnzymeGradientPrep(_sig, df, context_shadows) end ### Enzyme gradient API (only constants) function DI.gradient( f::F, - prep::DI.NoGradientPrep, + prep::EnzymeGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) + (; df, context_shadows) = prep mode = reverse_noprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) - annotated_contexts = translate(backend, mode, Val(1), contexts...) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1)) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) grads = gradient(mode, f_and_df, x, annotated_contexts...) return first(grads) end function DI.value_and_gradient( f::F, - prep::DI.NoGradientPrep, + prep::EnzymeGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, contexts::Vararg{DI.Constant,C}, ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) + (; df, context_shadows) = prep mode = reverse_withprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) - annotated_contexts = translate(backend, mode, Val(1), contexts...) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1)) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) grads, result = gradient(mode, f_and_df, x, annotated_contexts...) 
return result, first(grads) end @@ -237,13 +257,14 @@ end function DI.gradient!( f::F, grad, - prep::DI.NoGradientPrep, + prep::EnzymeGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F} DI.check_prep(f, prep, backend, x) + (; df) = prep mode = reverse_noprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1)) gradient!(mode, grad, f_and_df, x) return grad end @@ -251,13 +272,14 @@ end function DI.value_and_gradient!( f::F, grad, - prep::DI.NoGradientPrep, + prep::EnzymeGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}}, x, ) where {F} DI.check_prep(f, prep, backend, x) + (; df) = prep mode = reverse_withprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1)) _, result = gradient!(mode, grad, f_and_df, x) return result, grad end @@ -266,17 +288,18 @@ end function DI.gradient( f::F, - prep::DI.NoGradientPrep, + prep::EnzymeGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) + (; df, context_shadows) = prep mode = reverse_noprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1)) IA = guess_activity(typeof(x), mode) grad = make_zero(x) - annotated_contexts = translate(backend, mode, Val(1), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) dinputs = only( autodiff(mode, f_and_df, Active, annotate(IA, x, grad), annotated_contexts...) ) @@ -290,17 +313,18 @@ end function DI.value_and_gradient( f::F, - prep::DI.NoGradientPrep, + prep::EnzymeGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) 
+ (; df, context_shadows) = prep mode = reverse_withprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1)) IA = guess_activity(typeof(x), mode) grad = make_zero(x) - annotated_contexts = translate(backend, mode, Val(1), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) dinputs, result = autodiff( mode, f_and_df, Active, annotate(IA, x, grad), annotated_contexts... ) @@ -315,15 +339,16 @@ end function DI.gradient!( f::F, grad, - prep::DI.NoGradientPrep, + prep::EnzymeGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) + (; df, context_shadows) = prep mode = reverse_noprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) - annotated_contexts = translate(backend, mode, Val(1), contexts...) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1)) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) make_zero!(grad) autodiff(mode, f_and_df, Active, Duplicated(x, grad), annotated_contexts...) return grad @@ -332,15 +357,16 @@ end function DI.value_and_gradient!( f::F, grad, - prep::DI.NoGradientPrep, + prep::EnzymeGradientPrep, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) + (; df, context_shadows) = prep mode = reverse_withprimal(backend) - f_and_df = get_f_and_df(f, backend, mode) - annotated_contexts = translate(backend, mode, Val(1), contexts...) + f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1)) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) make_zero!(grad) _, y = autodiff(mode, f_and_df, Active, Duplicated(x, grad), annotated_contexts...) 
return y, grad diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index ae2b33923..18c4d9c68 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -1,7 +1,9 @@ ## Pullback -struct EnzymeReverseTwoArgPullbackPrep{SIG,TY} <: DI.PullbackPrep{SIG} +struct EnzymeReverseTwoArgPullbackPrep{SIG,DF,DC,TY} <: DI.PullbackPrep{SIG} _sig::Val{SIG} + df!::DF + context_shadows::DC ty_copy::TY end @@ -11,12 +13,15 @@ function DI.prepare_pullback_nokwarg( y, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, - ty::NTuple, + ty::NTuple{B}, contexts::Vararg{DI.Context,C}; -) where {F,C} +) where {F,B,C} _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) + df! = function_shadow(f!, backend, Val(B)) + mode = reverse_noprimal(backend) + context_shadows = make_context_shadows(backend, mode, Val(B), contexts...) ty_copy = map(copy, ty) - return EnzymeReverseTwoArgPullbackPrep(_sig, ty_copy) + return EnzymeReverseTwoArgPullbackPrep(_sig, df!, context_shadows, ty_copy) end function DI.value_and_pullback( @@ -29,12 +34,13 @@ function DI.value_and_pullback( contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) - copyto!(only(prep.ty_copy), only(ty)) + (; df!, context_shadows, ty_copy) = prep + copyto!(only(ty_copy), only(ty)) mode = reverse_noprimal(backend) - f!_and_df! = get_f_and_df(f!, backend, mode) - dy = only(prep.ty_copy) + f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(1)) + dy = only(ty_copy) y_and_dy = Duplicated(y, dy) - annotated_contexts = translate(backend, mode, Val(1), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) dinputs = only( autodiff(mode, f!_and_df!, Const, y_and_dy, Active(x), annotated_contexts...) 
) @@ -52,12 +58,13 @@ function DI.value_and_pullback( contexts::Vararg{DI.Context,C}, ) where {F,B,C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) - foreach(copyto!, prep.ty_copy, ty) + (; df!, context_shadows, ty_copy) = prep + foreach(copyto!, ty_copy, ty) mode = reverse_noprimal(backend) - f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) - ty = prep.ty_copy + f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(B)) + ty = ty_copy y_and_ty = BatchDuplicated(y, ty) - annotated_contexts = translate(backend, mode, Val(B), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) dinputs = only( autodiff(mode, f!_and_df!, Const, y_and_ty, Active(x), annotated_contexts...) ) @@ -75,14 +82,15 @@ function DI.value_and_pullback( contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) - copyto!(only(prep.ty_copy), only(ty)) + (; df!, context_shadows, ty_copy) = prep + copyto!(only(ty_copy), only(ty)) mode = reverse_noprimal(backend) - f!_and_df! = get_f_and_df(f!, backend, mode) + f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(1)) dx = make_zero(x) # allocates - dy = only(prep.ty_copy) + dy = only(ty_copy) x_and_dx = Duplicated(x, dx) y_and_dy = Duplicated(y, dy) - annotated_contexts = translate(backend, mode, Val(1), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) autodiff(mode, f!_and_df!, Const, y_and_dy, x_and_dx, annotated_contexts...) return y, (dx,) end @@ -97,14 +105,15 @@ function DI.value_and_pullback( contexts::Vararg{DI.Context,C}, ) where {F,B,C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) - foreach(copyto!, prep.ty_copy, ty) + (; df!, context_shadows, ty_copy) = prep + foreach(copyto!, ty_copy, ty) mode = reverse_noprimal(backend) - f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) + f!_and_df! 
= get_f_and_df_prepared!(df!, f!, backend, Val(B)) tx = ntuple(_ -> make_zero(x), Val(B)) # allocates - ty = prep.ty_copy + ty = ty_copy x_and_tx = BatchDuplicated(x, tx) y_and_ty = BatchDuplicated(y, ty) - annotated_contexts = translate(backend, mode, Val(B), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...) return y, tx end @@ -120,15 +129,16 @@ function DI.value_and_pullback!( contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) - copyto!(only(prep.ty_copy), only(ty)) + (; df!, context_shadows, ty_copy) = prep + copyto!(only(ty_copy), only(ty)) mode = reverse_noprimal(backend) - f!_and_df! = get_f_and_df(f!, backend, mode) + f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(1)) dx = only(tx) make_zero!(dx) - dy = only(prep.ty_copy) + dy = only(ty_copy) x_and_dx = Duplicated(x, dx) y_and_dy = Duplicated(y, dy) - annotated_contexts = translate(backend, mode, Val(1), contexts...) + annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1)) autodiff(mode, f!_and_df!, Const, y_and_dy, x_and_dx, annotated_contexts...) return y, (dx,) end @@ -144,14 +154,15 @@ function DI.value_and_pullback!( contexts::Vararg{DI.Context,C}, ) where {F,B,C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) - foreach(copyto!, prep.ty_copy, ty) + (; df!, context_shadows, ty_copy) = prep + foreach(copyto!, ty_copy, ty) mode = reverse_noprimal(backend) - f!_and_df! = get_f_and_df(f!, backend, mode, Val(B)) + f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(B)) make_zero!(tx) - ty = prep.ty_copy + ty = ty_copy x_and_tx = BatchDuplicated(x, tx) y_and_ty = BatchDuplicated(y, ty) - annotated_contexts = translate(backend, mode, Val(B), contexts...) 
+ annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B)) autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...) return y, tx end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 3c3180c79..991796bb1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -1,3 +1,12 @@ +const AnyDuplicated = Union{ + Duplicated, + MixedDuplicated, + BatchDuplicated, + BatchMixedDuplicated, + DuplicatedNoNeed, + BatchDuplicatedNoNeed, +} + # until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged function DI.pick_batchsize(::AutoEnzyme, N::Integer) B = DI.reasonable_batchsize(N, 16) @@ -8,84 +17,133 @@ to_val(::DI.BatchSizeSettings{B}) where {B} = Val(B) ## Annotations -@inline function get_f_and_df( - f::F, backend::AutoEnzyme{M,Nothing}, mode::Mode, (::Val{B})=Val(1) -) where {F,M,B} +function get_f_and_df_prepared!(_df, f::F, ::AutoEnzyme{M,Nothing}, ::Val{B}) where {F,M,B} return f end -@inline function get_f_and_df( - f::F, backend::AutoEnzyme{M,<:Const}, mode::Mode, (::Val{B})=Val(1) -) where {F,M,B} +function get_f_and_df_prepared!(_df, f::F, ::AutoEnzyme{M,<:Const}, ::Val{B}) where {F,M,B} return Const(f) end -@inline function get_f_and_df( - f::F, - backend::AutoEnzyme{ - M, - <:Union{ - Duplicated, - MixedDuplicated, - BatchDuplicated, - BatchMixedDuplicated, - DuplicatedNoNeed, - BatchDuplicatedNoNeed, - }, - }, - mode::Mode, - (::Val{B})=Val(1), +function get_f_and_df_prepared!( + df, f::F, ::AutoEnzyme{M,<:AnyDuplicated}, ::Val{B} ) where {F,M,B} - # TODO: needs more sophistication for mixed activities + #= + It is not obvious why we don't need a `make_zero` here, in the case of mutable constant data in `f`. 
+ - In forward mode, `df` is never incremented if `f` is not mutated, so it remains equal to its initial value of `0`. + - In reverse mode, `df` gets incremented but it does not influence the input cotangent `dx`. + =# if B == 1 - return Duplicated(f, make_zero(f)) + return Duplicated(f, df) else - return BatchDuplicated(f, ntuple(_ -> make_zero(f), Val(B))) + return BatchDuplicated(f, df) + end +end + +function function_shadow( + ::F, ::AutoEnzyme{M,<:Union{Const,Nothing}}, ::Val{B} +) where {M,B,F} + return nothing +end + +function function_shadow(f::F, ::AutoEnzyme{M,<:AnyDuplicated}, ::Val{B}) where {F,M,B} + if B == 1 + return make_zero(f) + else + return ntuple(_ -> make_zero(f), Val(B)) end end force_annotation(f::F) where {F<:Annotation} = f force_annotation(f::F) where {F} = Const(f) -@inline function _translate( - backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedConstant -) where {B} - return Const(DI.unwrap(c)) +function _shadow(::AutoEnzyme, ::Mode, ::Val{B}, c_wrapped::DI.Constant) where {B} + return nothing end -@inline function _translate( - backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache -) where {B} - # important to keep make_zero here for ConstantOrCache instead of similar +function _shadow(::AutoEnzyme, ::Mode, ::Val{B}, c_wrapped::DI.Cache) where {B} + c = DI.unwrap(c_wrapped) if B == 1 - return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) + return make_zero(c) else - return BatchDuplicated(DI.unwrap(c), ntuple(_ -> make_zero(DI.unwrap(c)), Val(B))) + return ntuple(_ -> make_zero(c), Val(B)) end end -@inline function _translate( - backend::AutoEnzyme, mode::Mode, valB::Val{B}, c::DI.ConstantOrCache +function _shadow( + ::AutoEnzyme, mode::Mode, valB::Val{B}, c_wrapped::DI.ConstantOrCache ) where {B} - IA = guess_activity(typeof(DI.unwrap(c)), mode) + c = DI.unwrap(c_wrapped) + IA = guess_activity(typeof(c), mode) if IA <: Const - return _translate(backend, mode, valB, DI.Constant(DI.unwrap(c))) + nothing else - return 
_translate(backend, mode, valB, DI.Cache(DI.unwrap(c))) + if B == 1 + return make_zero(c) + else + return ntuple(_ -> make_zero(c), Val(B)) + end end end -@inline function _translate( - backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext -) where {B} - return force_annotation(get_f_and_df(DI.unwrap(c), backend, mode, Val(B))) +function _shadow( + backend::AutoEnzyme, + ::Mode, + ::Val{B}, + c_wrapped::DI.FunctionContext, +) where {B} + f = DI.unwrap(c_wrapped) + return function_shadow(f, backend, Val(B)) end -@inline function translate( +function make_context_shadows( backend::AutoEnzyme, mode::Mode, ::Val{B}, contexts::Vararg{DI.Context,C} ) where {B,C} - new_contexts = map(contexts) do c - _translate(backend, mode, Val(B), c) + context_shadows = map(contexts) do c_wrapped + _shadow(backend, mode, Val(B), c_wrapped) + end + return context_shadows +end + +function _translate_prepared!(dc, c_wrapped::DI.Constant, ::Val{B}) where {B} + c = DI.unwrap(c_wrapped) + return Const(c) +end + +function _translate_prepared!(dc, c_wrapped::DI.Cache, ::Val{B}) where {B} + c = DI.unwrap(c_wrapped) + if B == 1 + return Duplicated(c, dc) + else + return BatchDuplicated(c, dc) + end +end + +function _translate_prepared!( + dc, c_wrapped::Union{DI.ConstantOrCache,DI.FunctionContext}, ::Val{B} +) where {B} + #= + It is not obvious why we don't need a `make_zero` here, in the case of mutable constant contexts. + - In forward mode, `dc` is never incremented because `c` is not mutated, so it remains equal to its initial value of `0`. + - In reverse mode, `dc` gets incremented but it does not influence the input cotangent `dx`.
+ =# + c = DI.unwrap(c_wrapped) + if isnothing(dc) + return Const(c) + else + if B == 1 + return Duplicated(c, dc) + else + return BatchDuplicated(c, dc) + end + end +end + +function translate_prepared!( + context_shadows::NTuple{C,Any}, contexts::NTuple{C,DI.Context}, ::Val{B} +) where {B,C} + new_contexts = map(context_shadows, contexts) do dc, c_wrapped + _translate_prepared!(dc, c_wrapped, Val(B)) end return new_contexts end diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index cd1fb521e..ee794892d 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -104,26 +104,31 @@ struct WritableClosure{pl_fun,F,X,Y} <: FunctionModifier f::F x_buffer::Vector{X} y_buffer::Vector{Y} + a::Float64 + b::Vector{Float64} end function WritableClosure{pl_fun}( - f::F, x_buffer::Vector{X}, y_buffer::Vector{Y} + f::F, x_buffer::Vector{X}, y_buffer::Vector{Y}, a, b ) where {pl_fun,F,X,Y} - return WritableClosure{pl_fun,F,X,Y}(f, x_buffer, y_buffer) + return WritableClosure{pl_fun,F,X,Y}(f, x_buffer, y_buffer, a, b) end Base.show(io::IO, f::WritableClosure) = print(io, "WritableClosure($(f.f))") function (mc::WritableClosure{:out})(x) - mc.x_buffer[1] = x - mc.y_buffer[1] = mc.f(x) - return copy(mc.y_buffer[1]) + (; f, x_buffer, y_buffer, a, b) = mc + x_buffer[1] = copy(x) + y_buffer[1] = (a + only(b)) * f(x) + return copy(y_buffer[1]) end function (mc::WritableClosure{:in})(y, x) - mc.x_buffer[1] = x - mc.f(mc.y_buffer[1], mc.x_buffer[1]) - copyto!(y, mc.y_buffer[1]) + (; f, x_buffer, y_buffer, a, b) = mc + x_buffer[1] = copy(x) + f(y_buffer[1], x_buffer[1]) + y_buffer[1] .*= (a + only(b)) + copyto!(y, y_buffer[1]) return nothing end @@ -132,13 +137,25 @@ end Return a new `Scenario` identical to `scen` except for the function `f` which is made to close over differentiable data. 
""" -function closurify(scen::Scenario) +function closurify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} (; f, x, y) = scen @assert isempty(scen.contexts) x_buffer = [zero(x)] y_buffer = [zero(y)] - closure_f = WritableClosure{function_place(scen)}(f, x_buffer, y_buffer) - return change_function(scen, closure_f; keep_smaller=false) + a = 3.0 + b = [4.0] + closure_f = WritableClosure{pl_fun}(f, x_buffer, y_buffer, a, b) + return Scenario{op,pl_op,pl_fun}( + closure_f; + x = scen.x, + y = mymultiply(scen.y, a + only(b)), + tang = scen.tang, + contexts = scen.contexts, + res1 = mymultiply(scen.res1, a + only(b)), + res2 = mymultiply(scen.res2, a + only(b)), + smaller = nothing, + name = isnothing(scen.name) ? nothing : scen.name * " [closurified]", + ) end struct MultiplyByConstant{pl_fun,F} <: FunctionModifier @@ -267,7 +284,8 @@ end function (sc::MultiplyByConstantAndStoreInCache{:out})(x, constantorcache) (; constant, cache) = constantorcache - y = constant * sc.f(x) + (; a, b) = constant + y = (a + only(b)) * sc.f(x) if eltype(y) == eltype(cache) newcache = cache else @@ -285,6 +303,7 @@ end function (sc::MultiplyByConstantAndStoreInCache{:in})(y, x, constantorcache) (; constant, cache) = constantorcache + (; a, b) = constant if eltype(y) == eltype(cache) newcache = cache else @@ -292,7 +311,7 @@ function (sc::MultiplyByConstantAndStoreInCache{:in})(y, x, constantorcache) newcache = similar(cache, eltype(y)) end sc.f(newcache, x) - newcache .*= constant + newcache .*= (a + only(b)) copyto!(y, newcache) return nothing end @@ -307,19 +326,20 @@ function constantorcachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_f @assert isempty(scen.contexts) constantorcache_f = MultiplyByConstantAndStoreInCache{pl_fun}(f) a = 3.0 + b = [4.0] constantorcache = if scen.y isa Number - (; cache=[myzero(scen.y)], constant=a) + (; cache=[myzero(scen.y)], constant=(; a, b)) else - (; cache=mysimilar(scen.y), constant=a) + (; cache=mysimilar(scen.y), constant=(; a, 
b)) end return Scenario{op,pl_op,pl_fun}( constantorcache_f; x=scen.x, - y=mymultiply(scen.y, a), + y=mymultiply(scen.y, a + only(b)), tang=scen.tang, contexts=(ConstantOrCache(constantorcache),), - res1=mymultiply(scen.res1, a), - res2=mymultiply(scen.res2, a), + res1=mymultiply(scen.res1, a + only(b)), + res2=mymultiply(scen.res2, a + only(b)), smaller=isnothing(scen.smaller) ? nothing : constantorcachify(scen.smaller), name=isnothing(scen.name) ? nothing : scen.name * " [constantorcachified]", )