Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions DifferentiationInterface/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Make `AutoForwardFromPrimitive` and `AutoReverseFromPrimitive` public ([#825])

### Fixed

- Replace `one` with `oneunit` in basis computation ([#826])
Expand Down Expand Up @@ -67,6 +71,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53

[#826]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/826
[#825]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/825
Comment thread
gdalle marked this conversation as resolved.
[#823]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/823
[#818]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/818
[#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812
Expand Down
7 changes: 7 additions & 0 deletions DifferentiationInterface/docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ MixedMode
DenseSparsityDetector
```

### From primitive

```@docs
DifferentiationInterface.AutoForwardFromPrimitive
DifferentiationInterface.AutoReverseFromPrimitive
```

## Internals

The following is not part of the public API.
Expand Down
1 change: 1 addition & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ export AutoSparse
## Public but not exported

@public inner, outer
@public AutoForwardFromPrimitive, AutoReverseFromPrimitive

include("init.jl")

Expand Down
7 changes: 7 additions & 0 deletions DifferentiationInterface/src/first_order/mixed_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,10 @@ Appropriate mode type for `MixedMode` backends.
"""
struct ForwardAndReverseMode <: ADTypes.AbstractMode end
ADTypes.mode(::MixedMode) = ForwardAndReverseMode()

function threshold_batchsize(backend::MixedMode, B::Integer)
return MixedMode(
threshold_batchsize(forward_backend(backend), B),
threshold_batchsize(reverse_backend(backend), B),
)
end
41 changes: 30 additions & 11 deletions DifferentiationInterface/src/misc/from_primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,45 @@ abstract type FromPrimitive{inplace} <: AbstractADType end
check_available(backend::FromPrimitive) = check_available(backend.backend)
inplace_support(::FromPrimitive{true}) = InPlaceSupported()
inplace_support(::FromPrimitive{false}) = InPlaceNotSupported()
function inner_preparation_behavior(backend::FromPrimitive)
return inner_preparation_behavior(backend.backend)

function pick_batchsize(backend::FromPrimitive, x_or_y::AbstractArray)
return pick_batchsize(backend.backend, x_or_y)
end

function pick_batchsize(backend::FromPrimitive, N::Integer)
return pick_batchsize(backend.backend, N)
end

function inner_preparation_behavior(backend::FromPrimitive)
return inner_preparation_behavior(backend.backend)
end

function overloaded_input(::typeof(pushforward), f, backend::FromPrimitive, x, tx::NTuple)
return overloaded_input(pushforward, f, backend.backend, x, tx)
end

function overloaded_input(
::typeof(pushforward), f!, y, backend::FromPrimitive, x, tx::NTuple
)
return overloaded_input(pushforward, f!, y, backend.backend, x, tx)
end

"""
AutoForwardFromPrimitive
AutoForwardFromPrimitive(backend::AbstractADType)

Wrapper which forces a given backend to act as a reverse-mode backend.
Wrapper which forces a given backend to act as a forward-mode backend, using only its native `value_and_pushforward` primitive and re-implementing the rest from scratch.

Used in internal testing.
!!! tip
This can be useful to circumvent high-level operators when they have impractical limitations.
For instance, ForwardDiff.jl's `jacobian` does not support GPU arrays but its `pushforward` does, so `AutoForwardFromPrimitive(AutoForwardDiff())` has a GPU-friendly `jacobian`.
"""
struct AutoForwardFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace}
backend::B
end

function AutoForwardFromPrimitive(backend::AbstractADType; inplace=true)
function AutoForwardFromPrimitive(
backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend))
)
return AutoForwardFromPrimitive{inplace,typeof(backend)}(backend)
end

Expand Down Expand Up @@ -133,17 +152,17 @@ function value_and_pushforward!(
end

"""
AutoReverseFromPrimitive

Wrapper which forces a given backend to act as a reverse-mode backend.
AutoReverseFromPrimitive(backend::AbstractADType)

Used in internal testing.
Wrapper which forces a given backend to act as a reverse-mode backend, using only its native `value_and_pullback` implementation and rebuilding the rest from scratch.
"""
struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace}
backend::B
end

function AutoReverseFromPrimitive(backend::AbstractADType; inplace=true)
function AutoReverseFromPrimitive(
backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend))
)
return AutoReverseFromPrimitive{inplace,typeof(backend)}(backend)
end

Expand Down
7 changes: 0 additions & 7 deletions DifferentiationInterface/src/utils/batchsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,6 @@ function threshold_batchsize(backend::SecondOrder, B::Integer)
)
end

function threshold_batchsize(backend::MixedMode, B::Integer)
return MixedMode(
threshold_batchsize(forward_backend(backend), B),
threshold_batchsize(reverse_backend(backend), B),
)
end

"""
reasonable_batchsize(N::Integer, Bmax::Integer)

Expand Down
4 changes: 4 additions & 0 deletions DifferentiationInterface/test/Back/ForwardDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import DifferentiationInterface as DI
import DifferentiationInterfaceTest as DIT
using ForwardDiff: ForwardDiff
using StaticArrays: StaticArrays, @SVector
using JLArrays: JLArrays
using Test

using ExplicitImports
Expand Down Expand Up @@ -75,6 +76,9 @@ end
@testset "Weird" begin
test_differentiation(AutoForwardDiff(), component_scenarios(); logging=LOGGING)
test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING)
test_differentiation(
DI.AutoForwardFromPrimitive(AutoForwardDiff()), gpu_scenarios(); logging=LOGGING
)

@testset "Batch size" begin
@test DI.pick_batchsize(AutoForwardDiff(), rand(7)) isa DI.BatchSizeSettings{7}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ end

@testset "Weird arrays" begin
test_differentiation(
AutoSimpleFiniteDiff(), vcat(static_scenarios(), gpu_scenarios()); logging=LOGGING
[
AutoSimpleFiniteDiff(),
AutoForwardFromPrimitive(AutoSimpleFiniteDiff()),
AutoReverseFromPrimitive(AutoSimpleFiniteDiff()),
],
vcat(static_scenarios(), gpu_scenarios());
logging=LOGGING,
)
end;