Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion DifferentiationInterface/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Make `AutoForwardFromPrimitive` and `AutoReverseFromPrimitive` public ([#825])

### Fixed

- Replace `one` with `oneunit` in basis computation ([#826])
Expand Down Expand Up @@ -66,7 +70,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[0.6.54]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...DifferentiationInterface-v0.6.54
[0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53

[#826]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/826
[#825]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/825
Comment thread
gdalle marked this conversation as resolved.
[#824]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/824
Comment thread
gdalle marked this conversation as resolved.
Outdated
[#823]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/823
[#818]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/818
[#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812
Expand Down
7 changes: 7 additions & 0 deletions DifferentiationInterface/docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ MixedMode
DenseSparsityDetector
```

### From primitive

```@docs
DifferentiationInterface.AutoForwardFromPrimitive
DifferentiationInterface.AutoReverseFromPrimitive
```

## Internals

The following is not part of the public API.
Expand Down
1 change: 1 addition & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ export AutoSparse
## Public but not exported

@public inner, outer
@public AutoForwardFromPrimitive, AutoReverseFromPrimitive

include("init.jl")

Expand Down
7 changes: 7 additions & 0 deletions DifferentiationInterface/src/first_order/mixed_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,10 @@ Appropriate mode type for `MixedMode` backends.
"""
struct ForwardAndReverseMode <: ADTypes.AbstractMode end
ADTypes.mode(::MixedMode) = ForwardAndReverseMode()

function threshold_batchsize(backend::MixedMode, B::Integer)
return MixedMode(
threshold_batchsize(forward_backend(backend), B),
threshold_batchsize(reverse_backend(backend), B),
)
end
41 changes: 30 additions & 11 deletions DifferentiationInterface/src/misc/from_primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,45 @@ abstract type FromPrimitive{inplace} <: AbstractADType end
check_available(backend::FromPrimitive) = check_available(backend.backend)
inplace_support(::FromPrimitive{true}) = InPlaceSupported()
inplace_support(::FromPrimitive{false}) = InPlaceNotSupported()
function inner_preparation_behavior(backend::FromPrimitive)
return inner_preparation_behavior(backend.backend)

function pick_batchsize(backend::FromPrimitive, x_or_y::AbstractArray)
return pick_batchsize(backend.backend, x_or_y)
end

function pick_batchsize(backend::FromPrimitive, N::Integer)
return pick_batchsize(backend.backend, N)
end

function inner_preparation_behavior(backend::FromPrimitive)
return inner_preparation_behavior(backend.backend)
end

function overloaded_input(::typeof(pushforward), f, backend::FromPrimitive, x, tx::NTuple)
return overloaded_input(pushforward, f, backend.backend, x, tx)
end

function overloaded_input(
::typeof(pushforward), f!, y, backend::FromPrimitive, x, tx::NTuple
)
return overloaded_input(pushforward, f!, y, backend.backend, x, tx)
end

"""
AutoForwardFromPrimitive
AutoForwardFromPrimitive(backend::AbstractADType)

Wrapper which forces a given backend to act as a reverse-mode backend.
Wrapper which forces a given backend to act as a forward-mode backend, using only its native `value_and_pushforward` primitive and re-implementing the rest from scratch.

Used in internal testing.
!!! tip
This can be useful to circumvent high-level operators when they have impractical limitations.
For instance, ForwardDiff.jl's `jacobian` does not support GPU arrays but its `pushforward` does, so `AutoForwardFromPrimitive(AutoForwardDiff())` has a GPU-friendly `jacobian`.
"""
struct AutoForwardFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace}
backend::B
end

function AutoForwardFromPrimitive(backend::AbstractADType; inplace=true)
function AutoForwardFromPrimitive(
backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend))
)
return AutoForwardFromPrimitive{inplace,typeof(backend)}(backend)
end

Expand Down Expand Up @@ -133,17 +152,17 @@ function value_and_pushforward!(
end

"""
AutoReverseFromPrimitive

Wrapper which forces a given backend to act as a reverse-mode backend.
AutoReverseFromPrimitive(backend::AbstractADType)

Used in internal testing.
Wrapper which forces a given backend to act as a reverse-mode backend, using only its native `value_and_pullback` implementation and rebuilding the rest from scratch.
"""
struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace}
backend::B
end

function AutoReverseFromPrimitive(backend::AbstractADType; inplace=true)
function AutoReverseFromPrimitive(
backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend))
)
return AutoReverseFromPrimitive{inplace,typeof(backend)}(backend)
end

Expand Down
7 changes: 0 additions & 7 deletions DifferentiationInterface/src/utils/batchsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,6 @@ function threshold_batchsize(backend::SecondOrder, B::Integer)
)
end

function threshold_batchsize(backend::MixedMode, B::Integer)
return MixedMode(
threshold_batchsize(forward_backend(backend), B),
threshold_batchsize(reverse_backend(backend), B),
)
end

"""
reasonable_batchsize(N::Integer, Bmax::Integer)

Expand Down
4 changes: 4 additions & 0 deletions DifferentiationInterface/test/Back/ForwardDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import DifferentiationInterface as DI
import DifferentiationInterfaceTest as DIT
using ForwardDiff: ForwardDiff
using StaticArrays: StaticArrays, @SVector
using JLArrays: JLArrays
using Test

using ExplicitImports
Expand Down Expand Up @@ -75,6 +76,9 @@ end
@testset "Weird" begin
test_differentiation(AutoForwardDiff(), component_scenarios(); logging=LOGGING)
test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING)
test_differentiation(
DI.AutoForwardFromPrimitive(AutoForwardDiff()), gpu_scenarios(); logging=LOGGING
)

@testset "Batch size" begin
@test DI.pick_batchsize(AutoForwardDiff(), rand(7)) isa DI.BatchSizeSettings{7}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ end

@testset "Weird arrays" begin
test_differentiation(
AutoSimpleFiniteDiff(), vcat(static_scenarios(), gpu_scenarios()); logging=LOGGING
[
AutoSimpleFiniteDiff(),
AutoForwardFromPrimitive(AutoSimpleFiniteDiff()),
AutoReverseFromPrimitive(AutoSimpleFiniteDiff()),
],
vcat(static_scenarios(), gpu_scenarios());
logging=LOGGING,
)
end;
Loading