Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.52"
version = "0.6.53"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,91 +1,104 @@
## Pushforward

struct EnzymeOneArgPushforwardPrep{SIG,DF,DC} <: DI.PushforwardPrep{SIG}
_sig::Val{SIG}
df::DF
context_shadows::DC
end

function DI.prepare_pushforward_nokwarg(
strict::Val,
f::F,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
tx::NTuple{B},
contexts::Vararg{DI.Context,C};
) where {F,C}
) where {F,C,B}
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
return DI.NoPushforwardPrep(_sig)
df = function_shadow(f, backend, Val(B))
mode = forward_withprimal(backend)
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
return EnzymeOneArgPushforwardPrep(_sig, df, context_shadows)

Check warning on line 21 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L18-L21

Added lines #L18 - L21 were not covered by tests
end

function DI.value_and_pushforward(
f::F,
prep::DI.NoPushforwardPrep,
prep::EnzymeOneArgPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
(; df, context_shadows) = prep

Check warning on line 33 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L33

Added line #L33 was not covered by tests
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1))

Check warning on line 35 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L35

Added line #L35 was not covered by tests
dx = only(tx)
x_and_dx = Duplicated(x, dx)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1))

Check warning on line 38 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L38

Added line #L38 was not covered by tests
dy, y = autodiff(mode, f_and_df, x_and_dx, annotated_contexts...)
return y, (dy,)
end

function DI.value_and_pushforward(
f::F,
prep::DI.NoPushforwardPrep,
prep::EnzymeOneArgPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
(; df, context_shadows) = prep

Check warning on line 52 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L52

Added line #L52 was not covered by tests
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode, Val(B))
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))

Check warning on line 54 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L54

Added line #L54 was not covered by tests
x_and_tx = BatchDuplicated(x, tx)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))

Check warning on line 56 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L56

Added line #L56 was not covered by tests
ty, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)
return y, values(ty)
end

function DI.pushforward(
f::F,
prep::DI.NoPushforwardPrep,
prep::EnzymeOneArgPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
(; df, context_shadows) = prep

Check warning on line 70 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L70

Added line #L70 was not covered by tests
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1))

Check warning on line 72 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L72

Added line #L72 was not covered by tests
dx = only(tx)
x_and_dx = Duplicated(x, dx)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1))

Check warning on line 75 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L75

Added line #L75 was not covered by tests
dy = only(autodiff(mode, f_and_df, x_and_dx, annotated_contexts...))
return (dy,)
end

function DI.pushforward(
f::F,
prep::DI.NoPushforwardPrep,
prep::EnzymeOneArgPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
(; df, context_shadows) = prep

Check warning on line 89 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L89

Added line #L89 was not covered by tests
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode, Val(B))
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))

Check warning on line 91 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L91

Added line #L91 was not covered by tests
x_and_tx = BatchDuplicated(x, tx)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))

Check warning on line 93 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L93

Added line #L93 was not covered by tests
ty = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...))
return values(ty)
end

function DI.value_and_pushforward!(
f::F,
ty::NTuple,
prep::DI.NoPushforwardPrep,
prep::EnzymeOneArgPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
Expand All @@ -101,7 +114,7 @@
function DI.pushforward!(
f::F,
ty::NTuple,
prep::DI.NoPushforwardPrep,
prep::EnzymeOneArgPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
Expand All @@ -116,10 +129,12 @@

## Gradient

struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG}
struct EnzymeForwardGradientPrep{SIG,B,DF,DC,O} <: DI.GradientPrep{SIG}
_sig::Val{SIG}
_valB::Val{B}
shadows::O
df::DF
context_shadows::DC
basis_shadows::O
end

function DI.prepare_gradient_nokwarg(
Expand All @@ -131,8 +146,11 @@
) where {F,C}
_sig = DI.signature(f, backend, x, contexts...; strict)
valB = to_val(DI.pick_batchsize(backend, x))
shadows = create_shadows(valB, x)
return EnzymeForwardGradientPrep(_sig, valB, shadows)
df = function_shadow(f, backend, valB)
mode = forward_withprimal(backend)
context_shadows = make_context_shadows(backend, mode, valB, contexts...)
basis_shadows = create_shadows(valB, x)
return EnzymeForwardGradientPrep(_sig, valB, df, context_shadows, basis_shadows)

Check warning on line 153 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L149-L153

Added lines #L149 - L153 were not covered by tests
end

function DI.gradient(
Expand All @@ -143,11 +161,12 @@
contexts::Vararg{DI.Constant,C},
) where {F,SIG,B,C}
DI.check_prep(f, prep, backend, x, contexts...)
(; df, context_shadows, basis_shadows) = prep

Check warning on line 164 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L164

Added line #L164 was not covered by tests
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))

Check warning on line 167 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L166-L167

Added lines #L166 - L167 were not covered by tests
derivs = gradient(
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows
)
return first(derivs)
end
Expand All @@ -160,11 +179,12 @@
contexts::Vararg{DI.Constant,C},
) where {F,SIG,B,C}
DI.check_prep(f, prep, backend, x, contexts...)
(; df, context_shadows, basis_shadows) = prep

Check warning on line 182 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L182

Added line #L182 was not covered by tests
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))

Check warning on line 185 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L184-L185

Added lines #L184 - L185 were not covered by tests
(; derivs, val) = gradient(
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows
)
return val, first(derivs)
end
Expand Down Expand Up @@ -196,10 +216,12 @@

## Jacobian

struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG}
struct EnzymeForwardOneArgJacobianPrep{SIG,B,DF,DC,O} <: DI.JacobianPrep{SIG}
_sig::Val{SIG}
_valB::Val{B}
shadows::O
df::DF
context_shadows::DC
basis_shadows::O
output_length::Int
end

Expand All @@ -213,8 +235,13 @@
_sig = DI.signature(f, backend, x, contexts...; strict)
y = f(x, map(DI.unwrap, contexts)...)
valB = to_val(DI.pick_batchsize(backend, x))
shadows = create_shadows(valB, x)
return EnzymeForwardOneArgJacobianPrep(_sig, valB, shadows, length(y))
mode = forward_withprimal(backend)
df = function_shadow(f, backend, valB)
context_shadows = make_context_shadows(backend, mode, valB, contexts...)
basis_shadows = create_shadows(valB, x)
return EnzymeForwardOneArgJacobianPrep(

Check warning on line 242 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L238-L242

Added lines #L238 - L242 were not covered by tests
_sig, valB, df, context_shadows, basis_shadows, length(y)
)
end

function DI.jacobian(
Expand All @@ -225,14 +252,15 @@
contexts::Vararg{DI.Constant,C},
) where {F,SIG,B,C}
DI.check_prep(f, prep, backend, x, contexts...)
(; df, context_shadows, basis_shadows, output_length) = prep

Check warning on line 255 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L255

Added line #L255 was not covered by tests
mode = forward_noprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))

Check warning on line 258 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L257-L258

Added lines #L257 - L258 were not covered by tests
derivs = jacobian(
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows
)
jac_tensor = first(derivs)
return maybe_reshape(jac_tensor, prep.output_length, length(x))
return maybe_reshape(jac_tensor, output_length, length(x))

Check warning on line 263 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L263

Added line #L263 was not covered by tests
end

function DI.value_and_jacobian(
Expand All @@ -243,14 +271,15 @@
contexts::Vararg{DI.Constant,C},
) where {F,SIG,B,C}
DI.check_prep(f, prep, backend, x, contexts...)
(; df, context_shadows, basis_shadows, output_length) = prep

Check warning on line 274 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L274

Added line #L274 was not covered by tests
mode = forward_withprimal(backend)
f_and_df = get_f_and_df(f, backend, mode)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))

Check warning on line 277 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L276-L277

Added lines #L276 - L277 were not covered by tests
(; derivs, val) = jacobian(
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows
)
jac_tensor = first(derivs)
return val, maybe_reshape(jac_tensor, prep.output_length, length(x))
return val, maybe_reshape(jac_tensor, output_length, length(x))

Check warning on line 282 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl#L282

Added line #L282 was not covered by tests
end

function DI.jacobian!(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,63 +1,74 @@
## Pushforward

struct EnzymeTwoArgPushforwardPrep{SIG,DF,DC} <: DI.PushforwardPrep{SIG}
_sig::Val{SIG}
df!::DF
context_shadows::DC
end

function DI.prepare_pushforward_nokwarg(
strict::Val,
f!::F,
y,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
tx::NTuple{B},
contexts::Vararg{DI.Context,C};
) where {F,C}
) where {F,B,C}
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
return DI.NoPushforwardPrep(_sig)
df! = function_shadow(f!, backend, Val(B))
mode = forward_noprimal(backend)
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
return EnzymeTwoArgPushforwardPrep(_sig, df!, context_shadows)

Check warning on line 22 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L19-L22

Added lines #L19 - L22 were not covered by tests
end

function DI.value_and_pushforward(
f!::F,
y,
prep::DI.NoPushforwardPrep,
prep::EnzymeTwoArgPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{DI.Context,C},
) where {F,C}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
(; df!, context_shadows) = prep

Check warning on line 35 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L35

Added line #L35 was not covered by tests
mode = forward_noprimal(backend)
f!_and_df! = get_f_and_df(f!, backend, mode)
f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(1))

Check warning on line 37 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L37

Added line #L37 was not covered by tests
dx = only(tx)
dy = make_zero(y)
x_and_dx = Duplicated(x, dx)
y_and_dy = Duplicated(y, dy)
annotated_contexts = translate(backend, mode, Val(1), contexts...)
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1))

Check warning on line 42 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L42

Added line #L42 was not covered by tests
autodiff(mode, f!_and_df!, Const, y_and_dy, x_and_dx, annotated_contexts...)
return y, (dy,)
end

function DI.value_and_pushforward(
f!::F,
y,
prep::DI.NoPushforwardPrep,
prep::EnzymeTwoArgPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
(; df!, context_shadows) = prep

Check warning on line 57 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L57

Added line #L57 was not covered by tests
mode = forward_noprimal(backend)
f!_and_df! = get_f_and_df(f!, backend, mode, Val(B))
f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(B))

Check warning on line 59 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L59

Added line #L59 was not covered by tests
ty = ntuple(_ -> make_zero(y), Val(B))
x_and_tx = BatchDuplicated(x, tx)
y_and_ty = BatchDuplicated(y, ty)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))

Check warning on line 63 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L63

Added line #L63 was not covered by tests
autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...)
return y, ty
end

function DI.pushforward(
f!::F,
y,
prep::DI.NoPushforwardPrep,
prep::EnzymeTwoArgPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
Expand All @@ -72,18 +83,19 @@
f!::F,
y,
ty::NTuple{B},
prep::DI.NoPushforwardPrep,
prep::EnzymeTwoArgPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
(; df!, context_shadows) = prep

Check warning on line 93 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L93

Added line #L93 was not covered by tests
mode = forward_noprimal(backend)
f!_and_df! = get_f_and_df(f!, backend, mode, Val(B))
f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(B))

Check warning on line 95 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L95

Added line #L95 was not covered by tests
x_and_tx = BatchDuplicated(x, tx)
y_and_ty = BatchDuplicated(y, ty)
annotated_contexts = translate(backend, mode, Val(B), contexts...)
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))

Check warning on line 98 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl#L98

Added line #L98 was not covered by tests
autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...)
return y, ty
end
Expand All @@ -92,7 +104,7 @@
f!::F,
y,
ty::NTuple,
prep::DI.NoPushforwardPrep,
prep::EnzymeTwoArgPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
Expand Down
Loading
Loading