Skip to content

Commit 3a76d3a

Browse files
committed
feat: make FromPrimitive wrappers public
1 parent ea73473 commit 3a76d3a

5 files changed

Lines changed: 33 additions & 10 deletions

File tree

DifferentiationInterface/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
12+
- Make `AutoForwardFromPrimitive` and `AutoReverseFromPrimitive` public ([#824])
13+
1014
## [0.7.3]
1115

1216
### Fixed
@@ -62,6 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
6266
[0.6.54]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...DifferentiationInterface-v0.6.54
6367
[0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53
6468

69+
[#824]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/824
6570
[#823]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/823
6671
[#818]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/818
6772
[#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812

DifferentiationInterface/docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,13 @@ MixedMode
132132
DenseSparsityDetector
133133
```
134134

135+
### From primitive
136+
137+
```@docs
138+
AutoForwardFromPrimitive
139+
AutoReverseFromPrimitive
140+
```
141+
135142
## Internals
136143

137144
The following is not part of the public API.

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ export AutoSparse
126126
## Public but not exported
127127

128128
@public inner, outer
129+
@public AutoForwardFromPrimitive, AutoReverseFromPrimitive
129130

130131
include("init.jl")
131132

DifferentiationInterface/src/misc/from_primitive.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,21 @@ function pick_batchsize(backend::FromPrimitive, N::Integer)
1212
end
1313

1414
"""
15-
AutoForwardFromPrimitive
15+
AutoForwardFromPrimitive(backend::AbstractADType)
1616
17-
Wrapper which forces a given backend to act as a reverse-mode backend.
17+
Wrapper which forces a given backend to act as a forward-mode backend, using only its native `value_and_pushforward` primitive and re-implementing the rest from scratch.
1818
19-
Used in internal testing.
19+
!!! tip
20+
This can be useful to circumvent high-level operators when they have impractical limitations.
21+
For instance, ForwardDiff.jl's `jacobian` does not support GPU arrays but its `pushforward` does, so `AutoForwardFromPrimitive(AutoForwardDiff())` has a GPU-friendly `jacobian`.
2022
"""
2123
struct AutoForwardFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace}
2224
backend::B
2325
end
2426

25-
function AutoForwardFromPrimitive(backend::AbstractADType; inplace=true)
27+
function AutoForwardFromPrimitive(
28+
backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend))
29+
)
2630
return AutoForwardFromPrimitive{inplace,typeof(backend)}(backend)
2731
end
2832

@@ -133,17 +137,17 @@ function value_and_pushforward!(
133137
end
134138

135139
"""
136-
AutoReverseFromPrimitive
140+
AutoReverseFromPrimitive(backend::AbstractADType)
137141
138-
Wrapper which forces a given backend to act as a reverse-mode backend.
139-
140-
Used in internal testing.
142+
Wrapper which forces a given backend to act as a reverse-mode backend, using only its native `value_and_pullback` implementation and rebuilding the rest from scratch.
141143
"""
142144
struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace}
143145
backend::B
144146
end
145147

146-
function AutoReverseFromPrimitive(backend::AbstractADType; inplace=true)
148+
function AutoReverseFromPrimitive(
149+
backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend))
150+
)
147151
return AutoReverseFromPrimitive{inplace,typeof(backend)}(backend)
148152
end
149153

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ end
141141

142142
@testset "Weird arrays" begin
143143
test_differentiation(
144-
AutoSimpleFiniteDiff(), vcat(static_scenarios(), gpu_scenarios()); logging=LOGGING
144+
[
145+
AutoSimpleFiniteDiff(),
146+
AutoForwardFromPrimitive(AutoSimpleFiniteDiff()),
147+
AutoReverseFromPrimitive(AutoSimpleFiniteDiff()),
148+
],
149+
vcat(static_scenarios(), gpu_scenarios());
150+
logging=LOGGING,
145151
)
146152
end;

0 commit comments

Comments
 (0)