Skip to content

Commit 6639952

Browse files
committed
Retoggle tests
1 parent 00b7f09 commit 6639952

2 files changed

Lines changed: 10 additions & 10 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
## Pushforward
22

3-
struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, FT, CT} <: DI.PushforwardPrep{SIG}
3+
struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, FT0, FT, YT, CT} <: DI.PushforwardPrep{SIG}
44
_sig::Val{SIG}
55
cache::Tcache
6+
dcall::FT0
67
df!::FT
8+
dy::YT
79
context_tangents::CT
810
end
911

@@ -26,9 +28,11 @@ function DI.prepare_pushforward_nokwarg(
2628
map(DI.unwrap, contexts)...;
2729
config
2830
)
31+
dcall = zero_tangent_or_primal(call_and_return, backend)
2932
df! = zero_tangent_or_primal(f!, backend)
33+
dy = zero_tangent_or_primal(y, backend)
3034
context_tangents = map(zero_tangent_unwrap, contexts)
31-
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, df!, context_tangents)
35+
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dcall, df!, dy, context_tangents)
3236
return prep
3337
end
3438

@@ -43,13 +47,11 @@ function DI.value_and_pushforward(
4347
) where {F, C, X}
4448
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
4549
ty = map(tx) do dx
46-
dy = zero_tangent_or_primal(y, backend) # TODO: remove allocation?
47-
dcall = zero_tangent_or_primal(call_and_return, backend)
4850
_, new_dy = value_and_derivative!!(
4951
prep.cache,
50-
(call_and_return, dcall),
52+
(call_and_return, prep.dcall),
5153
(f!, prep.df!),
52-
(y, dy),
54+
(y, prep.dy),
5355
(x, dx),
5456
map(first_unwrap, contexts, prep.context_tangents)...,
5557
)
@@ -83,10 +85,9 @@ function DI.value_and_pushforward!(
8385
) where {F, C, X, Y}
8486
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
8587
foreach(tx, ty) do dx, dy
86-
dcall = zero_tangent_or_primal(call_and_return, backend)
8788
_, new_dy = value_and_derivative!!(
8889
prep.cache,
89-
(call_and_return, dcall),
90+
(call_and_return, prep.dcall),
9091
(f!, prep.df!),
9192
(y, dy),
9293
(x, dx),

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ end
2222
test_differentiation(
2323
backends,
2424
default_scenarios(;
25-
include_batchified = false,
26-
include_constantified = false, include_cachified = false, use_tuples = true
25+
include_constantified = true, include_cachified = true, use_tuples = true
2726
);
2827
excluded = SECOND_ORDER,
2928
logging = LOGGING,

0 commit comments

Comments
 (0)