Skip to content

Commit 604bd8e

Browse files
committed
Fixes
1 parent 4bf5dc6 commit 604bd8e

5 files changed

Lines changed: 19 additions & 18 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function DI.value_and_pushforward(
4646
dy = _copy_output(last(y_and_dy))
4747
return y, dy
4848
end
49-
y = first(ys_and_ty[1])
49+
y = _copy_output(first(ys_and_ty[1]))
5050
ty = map(last, ys_and_ty)
5151
return y, ty
5252
end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ function DI.prepare_pushforward_nokwarg(
1919
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
2020
config = get_config(backend)
2121
cache = prepare_derivative_cache(
22+
call_and_return,
2223
f!,
2324
y,
2425
x,
@@ -43,14 +44,15 @@ function DI.value_and_pushforward(
4344
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
4445
ty = map(tx) do dx
4546
dy = zero_tangent(y) # TODO: remove allocation?
46-
value_and_derivative!!(
47+
_, new_dy = value_and_derivative!!(
4748
prep.cache,
49+
(call_and_return, zero_tangent(call_and_return)),
4850
(f!, prep.df!),
4951
(y, dy),
5052
(x, dx),
5153
map(first_unwrap, contexts, prep.context_tangents)...,
5254
)
53-
return dy
55+
return _copy_output(new_dy)
5456
end
5557
return y, ty
5658
end
@@ -80,13 +82,15 @@ function DI.value_and_pushforward!(
8082
) where {F, C, X, Y}
8183
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
8284
foreach(tx, ty) do dx, dy
83-
value_and_derivative!!(
85+
_, new_dy = value_and_derivative!!(
8486
prep.cache,
87+
(call_and_return, zero_tangent(call_and_return)),
8588
(f!, prep.df!),
8689
(y, dy),
8790
(x, dx),
8891
map(first_unwrap, contexts, prep.context_tangents)...,
8992
)
93+
copyto!(dy, new_dy)
9094
end
9195
return y, ty
9296
end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG}
1+
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
22
_sig::Val{SIG}
33
cache::Tcache
44
dy_backup::DY
5-
target_function::F
65
args_to_zero::NTuple{N, Bool}
76
end
87

@@ -16,13 +15,9 @@ function DI.prepare_pullback_nokwarg(
1615
contexts::Vararg{DI.Context, C}
1716
) where {F, C}
1817
_sig = DI.signature(f!, y, backend, x, ty, contexts...; strict)
19-
target_function = function (f!, y, x, contexts...)
20-
f!(y, x, contexts...)
21-
return y
22-
end
2318
config = get_config(backend)
2419
cache = prepare_pullback_cache(
25-
target_function,
20+
call_and_return,
2621
f!,
2722
y,
2823
x,
@@ -32,14 +27,14 @@ function DI.prepare_pullback_nokwarg(
3227
dy_backup_after = zero_tangent(y)
3328
contexts_tup_false = map(_ -> false, contexts)
3429
args_to_zero = (
35-
false, # target_function
30+
false, # call_and_return
3631
false, # f!
3732
false, # y
3833
true, # x
3934
contexts_tup_false...,
4035
)
4136
prep = MooncakeTwoArgPullbackPrep(
42-
_sig, cache, dy_backup_after, target_function, args_to_zero
37+
_sig, cache, dy_backup_after, args_to_zero
4338
)
4439
return prep
4540
end
@@ -61,7 +56,7 @@ function DI.value_and_pullback(
6156
y_after, (_, _, _, dx) = value_and_pullback!!(
6257
prep.cache,
6358
dy_backup_after,
64-
prep.target_function,
59+
call_and_return,
6560
f!,
6661
y,
6762
x,
@@ -87,7 +82,7 @@ function DI.value_and_pullback(
8782
y_after, (_, _, _, dx) = value_and_pullback!!(
8883
prep.cache,
8984
dy_backup_after,
90-
prep.target_function,
85+
call_and_return,
9186
f!,
9287
y,
9388
x,

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,8 @@ get_config(backend::AnyAutoMooncake{<:Config}) = backend.config
33

44
@inline zero_tangent_unwrap(c::DI.Context) = zero_tangent(DI.unwrap(c))
55
@inline first_unwrap(c, dc) = (DI.unwrap(c), dc)
6+
7+
function call_and_return(f!::F, y, x, contexts...) where {F}
8+
f!(y, x, contexts...)
9+
return y
10+
end

DifferentiationInterface/test/Back/Mooncake/Project.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,3 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
77
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
88
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
99
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
10-
11-
[sources]
12-
Mooncake = { url = "https://github.com/gdalle/Mooncake.jl", rev = "gd/forwardcache" }

0 commit comments

Comments
 (0)