Skip to content

Commit 031fb52

Browse files
committed
Fixes
1 parent be6caff commit 031fb52

9 files changed

Lines changed: 189 additions & 81 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ function DI.jacobian(
224224
x,
225225
contexts::Vararg{DI.Constant,C},
226226
) where {F,B,C}
227-
DI.check_prep(f, prep, backend, contexts...)
227+
DI.check_prep(f, prep, backend, x, contexts...)
228228
mode = forward_noprimal(backend)
229229
f_and_df = get_f_and_df(f, backend, mode)
230230
annotated_contexts = translate(backend, mode, Val(B), contexts...)
@@ -242,7 +242,7 @@ function DI.value_and_jacobian(
242242
x,
243243
contexts::Vararg{DI.Constant,C},
244244
) where {F,B,C}
245-
DI.check_prep(f, prep, backend, contexts...)
245+
DI.check_prep(f, prep, backend, x, contexts...)
246246
mode = forward_withprimal(backend)
247247
f_and_df = get_f_and_df(f, backend, mode)
248248
annotated_contexts = translate(backend, mode, Val(B), contexts...)
@@ -261,7 +261,7 @@ function DI.jacobian!(
261261
x,
262262
contexts::Vararg{DI.Constant,C},
263263
) where {F,C}
264-
DI.check_prep(f, prep, backend, contexts...)
264+
DI.check_prep(f, prep, backend, x, contexts...)
265265
return copyto!(jac, DI.jacobian(f, prep, backend, x, contexts...))
266266
end
267267

@@ -273,7 +273,7 @@ function DI.value_and_jacobian!(
273273
x,
274274
contexts::Vararg{DI.Constant,C},
275275
) where {F,C}
276-
DI.check_prep(f, prep, backend, contexts...)
276+
DI.check_prep(f, prep, backend, x, contexts...)
277277
y, new_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...)
278278
return y, copyto!(jac, new_jac)
279279
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,13 @@ function compute_ydual_twoarg(
4848
tx::NTuple{B},
4949
contexts::Vararg{DI.Context,C},
5050
) where {F,SIG,T,B,C}
51-
(; xdual_tmp, ydual_tmp) = prep
52-
make_dual!(T, xdual_tmp, x, tx)
51+
(; ydual_tmp) = prep
52+
if DI.ismutable_array(x)
53+
make_dual!(T, prep.xdual_tmp, x, tx)
54+
xdual_tmp = prep.xdual_tmp
55+
else
56+
xdual_tmp = make_dual(T, x, tx)
57+
end
5358
contexts_dual = translate_prepared(contexts, prep.contexts_dual)
5459
f!(ydual_tmp, xdual_tmp, contexts_dual...)
5560
return ydual_tmp

DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function DI.prepare_pushforward(
2525
end
2626
if x isa Number
2727
xt = TPS{promote_type(typeof(first(tx)), typeof(x), Float64)}(; use=d)
28-
return GTPSAOneArgPushforwardPrep(xt)
28+
return GTPSAOneArgPushforwardPrep(_sig, xt)
2929
else
3030
xt = similar(x, TPS{promote_type(eltype(first(tx)), eltype(x), Float64)})
3131
for i in eachindex(xt)

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,33 +43,21 @@ function DI.value_and_pullback(
4343
return new_y, (mycopy(new_dx),)
4444
end
4545

46-
function DI.value_and_pullback!(
47-
f::F,
48-
tx::NTuple{1},
49-
prep::MooncakeOneArgPullbackPrep{Y},
50-
backend::AutoMooncake,
51-
x,
52-
ty::NTuple{1},
53-
contexts::Vararg{DI.Context,C},
54-
) where {F,Y,C}
55-
DI.check_prep(f, prep, backend, x, ty, contexts...)
56-
y, (new_dx,) = DI.value_and_pullback(f, prep, backend, x, ty, contexts...)
57-
copyto!(only(tx), new_dx)
58-
return y, tx
59-
end
60-
6146
function DI.value_and_pullback(
6247
f::F,
63-
prep::MooncakeOneArgPullbackPrep,
48+
prep::MooncakeOneArgPullbackPrep{Y},
6449
backend::AutoMooncake,
6550
x,
6651
ty::NTuple,
6752
contexts::Vararg{DI.Context,C},
68-
) where {F,C}
53+
) where {F,Y,C}
6954
DI.check_prep(f, prep, backend, x, ty, contexts...)
7055
ys_and_tx = map(ty) do dy
71-
y, tx = DI.value_and_pullback(f, prep, backend, x, (dy,), contexts...)
72-
y, only(tx)
56+
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
57+
y, (_, new_dx) = value_and_pullback!!(
58+
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
59+
)
60+
y, mycopy(new_dx)
7361
end
7462
y = first(ys_and_tx[1])
7563
tx = last.(ys_and_tx)
@@ -86,11 +74,8 @@ function DI.value_and_pullback!(
8674
contexts::Vararg{DI.Context,C},
8775
) where {F,C}
8876
DI.check_prep(f, prep, backend, x, ty, contexts...)
89-
ys = map(tx, ty) do dx, dy
90-
y, _ = DI.value_and_pullback!(f, (dx,), prep, backend, x, (dy,), contexts...)
91-
y
92-
end
93-
y = ys[1]
77+
y, new_tx = DI.value_and_pullback(f, prep, backend, x, ty, contexts...)
78+
foreach(copyto!, tx, new_tx)
9479
return y, tx
9580
end
9681

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,18 @@ function DI.value_and_pullback(
4545
contexts::Vararg{DI.Context,C},
4646
) where {F,C}
4747
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
48-
# Prepare cotangent to add after the forward pass.
4948
dy = only(ty)
49+
# Prepare cotangent to add after the forward pass.
5050
dy_righttype_after = copyto!(prep.dy_righttype, dy)
51-
5251
# Run the reverse-pass and return the results.
53-
contexts = map(DI.unwrap, contexts)
5452
y_after, (_, _, _, dx) = value_and_pullback!!(
55-
prep.cache, dy_righttype_after, prep.target_function, f!, y, x, contexts...
53+
prep.cache,
54+
dy_righttype_after,
55+
prep.target_function,
56+
f!,
57+
y,
58+
x,
59+
map(DI.unwrap, contexts)...,
5660
)
5761
copyto!(y, y_after)
5862
return y, (mycopy(dx),)
@@ -69,8 +73,18 @@ function DI.value_and_pullback(
6973
) where {F,C}
7074
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
7175
tx = map(ty) do dy
72-
_, tx = DI.value_and_pullback(f!, y, prep, backend, x, (dy,), contexts...)
73-
only(tx)
76+
dy_righttype_after = copyto!(prep.dy_righttype, dy)
77+
y_after, (_, _, _, dx) = value_and_pullback!!(
78+
prep.cache,
79+
dy_righttype_after,
80+
prep.target_function,
81+
f!,
82+
y,
83+
x,
84+
map(DI.unwrap, contexts)...,
85+
)
86+
copyto!(y, y_after)
87+
mycopy(dx)
7488
end
7589
return y, tx
7690
end

0 commit comments

Comments
 (0)