diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index f651c2d05..4c64b45bc 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -372,6 +372,35 @@ function _jacobian_aux( end end +function _jacobian_aux( + f_or_f!y::FY, + prep::PushforwardJacobianPrep{SIG, <:BatchSizeSettings{1, false, true}}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {FY, SIG, C} + (; batched_seeds, seed_example, pushforward_prep) = prep + + pushforward_prep_same = prepare_pushforward_same_point( + f_or_f!y..., pushforward_prep, backend, x, seed_example, contexts... + ) + + jac = stack(eachindex(batched_seeds); dims = 2) do a + dy = only( + pushforward( + f_or_f!y..., + pushforward_prep_same, + backend, + x, + batched_seeds[a], + contexts..., + ) + ) + return vec(dy) + end + return jac +end + function _jacobian_aux( f_or_f!y::FY, prep::PushforwardJacobianPrep{SIG, <:BatchSizeSettings{B, false, aligned}}, @@ -428,6 +457,34 @@ function _jacobian_aux( end end +function _jacobian_aux( + f_or_f!y::FY, + prep::PullbackJacobianPrep{SIG, <:BatchSizeSettings{1, false, true}}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {FY, SIG, C} + (; batched_seeds, seed_example, pullback_prep) = prep + + pullback_prep_same = prepare_pullback_same_point( + f_or_f!y..., pullback_prep, backend, x, seed_example, contexts... + ) + + jac = stack(eachindex(batched_seeds); dims = 1) do a + dx = only( + pullback( + f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts... + ) + ) + if eltype(x) <: Complex + return map(conj, vec(dx)) + else + return vec(dx) + end + end + return jac +end + function _jacobian_aux( f_or_f!y::FY, prep::PullbackJacobianPrep{SIG, <:BatchSizeSettings{B, false, aligned}}, diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index f8d74b4a4..5067f78c2 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -151,6 +151,27 @@ function hessian( return block end +function hessian( + f::F, + prep::HVPGradientHessianPrep{SIG, <:BatchSizeSettings{1, false, true}}, + backend::AbstractADType, + x, + contexts::Vararg{Context, C}, + ) where {F, SIG, C} + check_prep(f, prep, backend, x, contexts...) + (; batched_seeds, seed_example, hvp_prep) = prep + + hvp_prep_same = prepare_hvp_same_point( + f, hvp_prep, backend, x, seed_example, contexts... + ) + + hess = mapreduce(hcat, eachindex(batched_seeds)) do a + dg = only(hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...)) + return vec(dg) + end + return hess +end + function hessian( f::F, prep::HVPGradientHessianPrep{SIG, <:BatchSizeSettings{B, false, aligned}}, diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index 8ebf4bf95..0d8d0f703 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -74,7 +74,14 @@ end logging = LOGGING, ) - test_differentiation(backends, complex_scenarios(); logging = LOGGING) + test_differentiation( + vcat( + backends[2:3], + AutoReverseFromPrimitive(AutoSimpleFiniteDiff(; chunksize = 1)) + ), + complex_scenarios(); + logging = LOGGING + ) end @testset "Sparse" begin