Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions DifferentiationInterface/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [0.7.1]

### Feat

- Use Mooncake's internal copy utilities ([#809])

### Fixed

- Take `absstep` into account for FiniteDiff ([#812])
Expand Down Expand Up @@ -42,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

[#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812
[#810]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/810
[#809]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/809
[#799]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/799
[#795]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/795
[#790]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/790
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ JET = "0.9"
JLArrays = "0.2.0"
JuliaFormatter = "1,2"
LinearAlgebra = "1"
Mooncake = "0.4.88"
Mooncake = "0.4.122"
Pkg = "1"
PolyesterForwardDiff = "0.1.2"
Random = "1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ using Mooncake:
tangent_type,
value_and_gradient!!,
value_and_pullback!!,
zero_tangent
zero_tangent,
_copy_output,
_copy_to_output!!

DI.check_available(::AutoMooncake) = true

copyto!!(dst::Number, src::Number) = convert(typeof(dst), src)
copyto!!(dst, src) = DI.ismutable_array(dst) ? copyto!(dst, src) : convert(typeof(dst), src)

get_config(::AutoMooncake{Nothing}) = Config()
get_config(backend::AutoMooncake{<:Config}) = backend.config

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
) where {F,Y,C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
dy = only(ty)
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)

Check warning on line 33 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl#L33

Added line #L33 was not covered by tests
new_y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
)
return new_y, (mycopy(new_dx),)
return new_y, (_copy_output(new_dx),)

Check warning on line 37 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl#L37

Added line #L37 was not covered by tests
end

function DI.value_and_pullback(
Expand All @@ -47,11 +47,12 @@
) where {F,Y,C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
ys_and_tx = map(ty) do dy
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
dy_righttype =

Check warning on line 50 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl#L50

Added line #L50 was not covered by tests
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
)
y, mycopy(new_dx)
y, _copy_output(new_dx)

Check warning on line 55 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl#L55

Added line #L55 was not covered by tests
end
y = first(ys_and_tx[1])
tx = last.(ys_and_tx)
Expand Down Expand Up @@ -126,7 +127,7 @@
) where {F,C}
DI.check_prep(f, prep, backend, x, contexts...)
y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...)
return y, mycopy(new_grad)
return y, _copy_output(new_grad)

Check warning on line 130 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl#L130

Added line #L130 was not covered by tests
end

function DI.value_and_gradient!(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
map(DI.unwrap, contexts)...,
)
copyto!(y, y_after)
return y, (mycopy(dx),)
return y, (_copy_output(dx),)

Check warning on line 61 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl#L61

Added line #L61 was not covered by tests
end

function DI.value_and_pullback(
Expand All @@ -83,7 +83,7 @@
map(DI.unwrap, contexts)...,
)
copyto!(y, y_after)
mycopy(dx)
_copy_output(dx)

Check warning on line 86 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl#L86

Added line #L86 was not covered by tests
end
return y, tx
end
Expand Down
Loading