Skip to content

Commit 04b4097

Browse files
committed
feat: use Mooncake's copy utilities
1 parent d8905f5 commit 04b4097

4 files changed

Lines changed: 34 additions & 16 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.7.0"
4+
version = "0.7.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -41,7 +41,9 @@ DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
4141
DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
4242
DifferentiationInterfaceGTPSAExt = "GTPSA"
4343
DifferentiationInterfaceMooncakeExt = "Mooncake"
44-
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
44+
DifferentiationInterfacePolyesterForwardDiffExt = [
45+
"PolyesterForwardDiff", "ForwardDiff", "DiffResults"
46+
]
4547
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
4648
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
4749
DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer"
@@ -53,7 +55,7 @@ DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]
5355

5456
[compat]
5557
Aqua = "0.8.12"
56-
ADTypes = "1.13.0"
58+
ADTypes = "1.15.0"
5759
ChainRulesCore = "1.23.0"
5860
ComponentArrays = "0.15.27"
5961
DataFrames = "1.7.0"
@@ -72,7 +74,7 @@ JET = "0.9"
7274
JLArrays = "0.2.0"
7375
JuliaFormatter = "1,2"
7476
LinearAlgebra = "1"
75-
Mooncake = "0.4.88"
77+
Mooncake = "0.4.121"
7678
Pkg = "1"
7779
PolyesterForwardDiff = "0.1.2"
7880
Random = "1"
@@ -121,4 +123,21 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
121123
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
122124

123125
[targets]
124-
test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test"]
126+
test = [
127+
"ADTypes",
128+
"Aqua",
129+
"ComponentArrays",
130+
"DataFrames",
131+
"ExplicitImports",
132+
"JET",
133+
"JLArrays",
134+
"JuliaFormatter",
135+
"Pkg",
136+
"Random",
137+
"SparseArrays",
138+
"SparseConnectivityTracer",
139+
"SparseMatrixColorings",
140+
"StableRNGs",
141+
"StaticArrays",
142+
"Test",
143+
]

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@ using Mooncake:
1010
tangent_type,
1111
value_and_gradient!!,
1212
value_and_pullback!!,
13-
zero_tangent
13+
zero_tangent,
14+
_copy_output,
15+
_copy_to_output!
1416

1517
DI.check_available(::AutoMooncake) = true
1618

17-
copyto!!(dst::Number, src::Number) = convert(typeof(dst), src)
18-
copyto!!(dst, src) = DI.ismutable_array(dst) ? copyto!(dst, src) : convert(typeof(dst), src)
19-
2019
get_config(::AutoMooncake{Nothing}) = Config()
2120
get_config(backend::AutoMooncake{<:Config}) = backend.config
2221

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ function DI.value_and_pullback(
3030
) where {F,Y,C}
3131
DI.check_prep(f, prep, backend, x, ty, contexts...)
3232
dy = only(ty)
33-
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
33+
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!(prep.dy_righttype, dy)
3434
new_y, (_, new_dx) = value_and_pullback!!(
3535
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
3636
)
37-
return new_y, (mycopy(new_dx),)
37+
return new_y, (_copy_output(new_dx),)
3838
end
3939

4040
function DI.value_and_pullback(
@@ -47,11 +47,11 @@ function DI.value_and_pullback(
4747
) where {F,Y,C}
4848
DI.check_prep(f, prep, backend, x, ty, contexts...)
4949
ys_and_tx = map(ty) do dy
50-
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
50+
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!(prep.dy_righttype, dy)
5151
y, (_, new_dx) = value_and_pullback!!(
5252
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
5353
)
54-
y, mycopy(new_dx)
54+
y, _copy_output(new_dx)
5555
end
5656
y = first(ys_and_tx[1])
5757
tx = last.(ys_and_tx)
@@ -126,7 +126,7 @@ function DI.value_and_gradient(
126126
) where {F,C}
127127
DI.check_prep(f, prep, backend, x, contexts...)
128128
y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...)
129-
return y, mycopy(new_grad)
129+
return y, _copy_output(new_grad)
130130
end
131131

132132
function DI.value_and_gradient!(

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function DI.value_and_pullback(
5858
map(DI.unwrap, contexts)...,
5959
)
6060
copyto!(y, y_after)
61-
return y, (mycopy(dx),)
61+
return y, (_copy_output(dx),)
6262
end
6363

6464
function DI.value_and_pullback(
@@ -83,7 +83,7 @@ function DI.value_and_pullback(
8383
map(DI.unwrap, contexts)...,
8484
)
8585
copyto!(y, y_after)
86-
mycopy(dx)
86+
_copy_output(dx)
8787
end
8888
return y, tx
8989
end

0 commit comments

Comments
 (0)