diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index c14bca747..f6527dd12 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.50" +version = "0.6.51" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 2e46bc7c4..d637a01bd 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -17,7 +17,6 @@ function DI.prepare_pullback_nokwarg( y = f(x, map(DI.unwrap, contexts)...) dy_righttype = zero_tangent(y) prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype) - DI.value_and_pullback(f, prep, backend, x, ty, contexts...) return prep end @@ -111,11 +110,10 @@ function DI.prepare_gradient_nokwarg( ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) config = get_config(backend) - cache = prepare_pullback_cache( + cache = prepare_gradient_cache( f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages ) prep = MooncakeGradientPrep(_sig, cache) - DI.value_and_gradient(f, prep, backend, x, contexts...) return prep end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index e89fbc37e..24cb39e41 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -31,7 +31,6 @@ function DI.prepare_pullback_nokwarg( ) dy_righttype_after = zero_tangent(y) prep = MooncakeTwoArgPullbackPrep(_sig, cache, dy_righttype_after, target_function) - DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...) return prep end