Skip to content

Commit c05dfb6

Browse files
committed
Adapt to latest Mooncake
1 parent 972a526 commit c05dfb6

4 files changed

Lines changed: 10 additions & 4 deletions

File tree

DifferentiationInterface/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
## [0.7.1]
1111

12+
### Feat
13+
14+
- Use Mooncake's internal copy utilities ([#809])
15+
1216
### Fixed
1317

1418
- Make basis work for `CuArray` ([#810])
@@ -40,6 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4044
[0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53
4145

4246
[#810]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/810
47+
[#809]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/809
4348
[#799]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/799
4449
[#795]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/795
4550
[#790]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/790

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ JET = "0.9"
7474
JLArrays = "0.2.0"
7575
JuliaFormatter = "1,2"
7676
LinearAlgebra = "1"
77-
Mooncake = "0.4.121"
77+
Mooncake = "0.4.122"
7878
Pkg = "1"
7979
PolyesterForwardDiff = "0.1.2"
8080
Random = "1"

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using Mooncake:
1212
value_and_pullback!!,
1313
zero_tangent,
1414
_copy_output,
15-
_copy_to_output!
15+
_copy_to_output!!
1616

1717
DI.check_available(::AutoMooncake) = true
1818

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function DI.value_and_pullback(
3030
) where {F,Y,C}
3131
DI.check_prep(f, prep, backend, x, ty, contexts...)
3232
dy = only(ty)
33-
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!(prep.dy_righttype, dy)
33+
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
3434
new_y, (_, new_dx) = value_and_pullback!!(
3535
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
3636
)
@@ -47,7 +47,8 @@ function DI.value_and_pullback(
4747
) where {F,Y,C}
4848
DI.check_prep(f, prep, backend, x, ty, contexts...)
4949
ys_and_tx = map(ty) do dy
50-
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!(prep.dy_righttype, dy)
50+
dy_righttype =
51+
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
5152
y, (_, new_dx) = value_and_pullback!!(
5253
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
5354
)

0 commit comments

Comments
 (0)