diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index dc80a5b6d..341caa0ed 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Make `AutoForwardFromPrimitive` and `AutoReverseFromPrimitive` public ([#825]) + ### Fixed - Replace `one` with `oneunit` in basis computation ([#826]) @@ -67,6 +71,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53 [#826]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/826 +[#825]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/825 [#823]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/823 [#818]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/818 [#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812 diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index 486a5830f..ac50ab1ab 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -132,6 +132,13 @@ MixedMode DenseSparsityDetector ``` +### From primitive + +```@docs +DifferentiationInterface.AutoForwardFromPrimitive +DifferentiationInterface.AutoReverseFromPrimitive +``` + ## Internals The following is not part of the public API. diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 32e699572..929abff94 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -126,6 +126,7 @@ export AutoSparse ## Public but not exported @public inner, outer +@public AutoForwardFromPrimitive, AutoReverseFromPrimitive include("init.jl") diff --git a/DifferentiationInterface/src/first_order/mixed_mode.jl b/DifferentiationInterface/src/first_order/mixed_mode.jl index 839cb941f..5951b456c 100644 --- a/DifferentiationInterface/src/first_order/mixed_mode.jl +++ b/DifferentiationInterface/src/first_order/mixed_mode.jl @@ -41,3 +41,10 @@ Appropriate mode type for `MixedMode` backends. """ struct ForwardAndReverseMode <: ADTypes.AbstractMode end ADTypes.mode(::MixedMode) = ForwardAndReverseMode() + +function threshold_batchsize(backend::MixedMode, B::Integer) + return MixedMode( + threshold_batchsize(forward_backend(backend), B), + threshold_batchsize(reverse_backend(backend), B), + ) +end diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl index 00dd6b445..81f7ea0b6 100644 --- a/DifferentiationInterface/src/misc/from_primitive.jl +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -3,26 +3,45 @@ abstract type FromPrimitive{inplace} <: AbstractADType end check_available(backend::FromPrimitive) = check_available(backend.backend) inplace_support(::FromPrimitive{true}) = InPlaceSupported() inplace_support(::FromPrimitive{false}) = InPlaceNotSupported() -function inner_preparation_behavior(backend::FromPrimitive) - return inner_preparation_behavior(backend.backend) + +function pick_batchsize(backend::FromPrimitive, x_or_y::AbstractArray) + return pick_batchsize(backend.backend, x_or_y) end function pick_batchsize(backend::FromPrimitive, N::Integer) return pick_batchsize(backend.backend, N) end +function inner_preparation_behavior(backend::FromPrimitive) + return inner_preparation_behavior(backend.backend) +end + +function overloaded_input(::typeof(pushforward), f, backend::FromPrimitive, x, tx::NTuple) + return overloaded_input(pushforward, f, backend.backend, x, tx) +end + +function overloaded_input( + ::typeof(pushforward), f!, y, backend::FromPrimitive, x, tx::NTuple +) + return overloaded_input(pushforward, f!, y, backend.backend, x, tx) +end + """ - AutoForwardFromPrimitive + AutoForwardFromPrimitive(backend::AbstractADType) -Wrapper which forces a given backend to act as a reverse-mode backend. +Wrapper which forces a given backend to act as a forward-mode backend, using only its native `value_and_pushforward` primitive and re-implementing the rest from scratch. -Used in internal testing. +!!! tip + This can be useful to circumvent high-level operators when they have impractical limitations. + For instance, ForwardDiff.jl's `jacobian` does not support GPU arrays but its `pushforward` does, so `AutoForwardFromPrimitive(AutoForwardDiff())` has a GPU-friendly `jacobian`. """ struct AutoForwardFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace} backend::B end -function AutoForwardFromPrimitive(backend::AbstractADType; inplace=true) +function AutoForwardFromPrimitive( + backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend)) +) return AutoForwardFromPrimitive{inplace,typeof(backend)}(backend) end @@ -133,17 +152,17 @@ function value_and_pushforward!( end """ - AutoReverseFromPrimitive - -Wrapper which forces a given backend to act as a reverse-mode backend. + AutoReverseFromPrimitive(backend::AbstractADType) -Used in internal testing. +Wrapper which forces a given backend to act as a reverse-mode backend, using only its native `value_and_pullback` implementation and rebuilding the rest from scratch. """ struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace} backend::B end -function AutoReverseFromPrimitive(backend::AbstractADType; inplace=true) +function AutoReverseFromPrimitive( + backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend)) +) return AutoReverseFromPrimitive{inplace,typeof(backend)}(backend) end diff --git a/DifferentiationInterface/src/utils/batchsize.jl b/DifferentiationInterface/src/utils/batchsize.jl index b3d370d05..054d5c9b9 100644 --- a/DifferentiationInterface/src/utils/batchsize.jl +++ b/DifferentiationInterface/src/utils/batchsize.jl @@ -112,13 +112,6 @@ function threshold_batchsize(backend::SecondOrder, B::Integer) ) end -function threshold_batchsize(backend::MixedMode, B::Integer) - return MixedMode( - threshold_batchsize(forward_backend(backend), B), - threshold_batchsize(reverse_backend(backend), B), - ) -end - """ reasonable_batchsize(N::Integer, Bmax::Integer) diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 8a99c08f4..d840f0024 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -8,6 +8,7 @@ import DifferentiationInterface as DI import DifferentiationInterfaceTest as DIT using ForwardDiff: ForwardDiff using StaticArrays: StaticArrays, @SVector +using JLArrays: JLArrays using Test using ExplicitImports @@ -75,6 +76,9 @@ end @testset "Weird" begin test_differentiation(AutoForwardDiff(), component_scenarios(); logging=LOGGING) test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING) + test_differentiation( + DI.AutoForwardFromPrimitive(AutoForwardDiff()), gpu_scenarios(); logging=LOGGING + ) @testset "Batch size" begin @test DI.pick_batchsize(AutoForwardDiff(), rand(7)) isa DI.BatchSizeSettings{7} diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index beec9841f..573aae0af 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -141,6 +141,12 @@ end @testset "Weird arrays" begin test_differentiation( - AutoSimpleFiniteDiff(), vcat(static_scenarios(), gpu_scenarios()); logging=LOGGING + [ + AutoSimpleFiniteDiff(), + AutoForwardFromPrimitive(AutoSimpleFiniteDiff()), + AutoReverseFromPrimitive(AutoSimpleFiniteDiff()), + ], + vcat(static_scenarios(), gpu_scenarios()); + logging=LOGGING, ) end;