Skip to content

Commit 65e449d

Browse files
authored
Preparation of pushforward, pullback and hvp for same point x (#255)
1 parent 8843031 commit 65e449d

28 files changed

Lines changed: 580 additions & 529 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.3.4"
4+
version = "0.4.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/docs/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2222
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2323

2424
[compat]
25-
DifferentiationInterface = "0.3"
2625
Documenter = "1"

DifferentiationInterface/docs/src/api.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ second_derivative!
5353

5454
```@docs
5555
prepare_hvp
56+
prepare_hvp_same_point
5657
hvp
5758
hvp!
5859
```
@@ -67,6 +68,7 @@ hessian!
6768

6869
```@docs
6970
prepare_pushforward
71+
prepare_pushforward_same_point
7072
pushforward
7173
pushforward!
7274
value_and_pushforward
@@ -75,12 +77,11 @@ value_and_pushforward!
7577

7678
```@docs
7779
prepare_pullback
80+
prepare_pullback_same_point
7881
pullback
7982
pullback!
8083
value_and_pullback
8184
value_and_pullback!
82-
value_and_pullback_split
83-
value_and_pullback!_split
8485
```
8586

8687
## Backend queries

DifferentiationInterface/docs/src/overview.md

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,16 @@ However they have different signatures:
6262
In many cases, AD can be accelerated if the function has been run at least once (e.g. to create a config or record a tape) and if some cache objects are provided.
6363
This is a backend-specific procedure, but we expose a common syntax to achieve it.
6464

65-
| operator | preparation function |
66-
| :------------------ | :---------------------------------- |
67-
| `derivative` | [`prepare_derivative`](@ref) |
68-
| `gradient` | [`prepare_gradient`](@ref) |
69-
| `jacobian` | [`prepare_jacobian`](@ref) |
70-
| `second_derivative` | [`prepare_second_derivative`](@ref) |
71-
| `hessian` | [`prepare_hessian`](@ref) |
72-
| `pushforward` | [`prepare_pushforward`](@ref) |
73-
| `pullback` | [`prepare_pullback`](@ref) |
74-
| `hvp` | [`prepare_hvp`](@ref) |
65+
| operator | preparation function | preparation function (same point) |
66+
| :------------------ | :---------------------------------- | ---------------------------------------- |
67+
| `derivative` | [`prepare_derivative`](@ref) | - |
68+
| `gradient` | [`prepare_gradient`](@ref) | - |
69+
| `jacobian` | [`prepare_jacobian`](@ref) | - |
70+
| `second_derivative` | [`prepare_second_derivative`](@ref) | - |
71+
| `hessian` | [`prepare_hessian`](@ref) | - |
72+
| `pushforward` | [`prepare_pushforward`](@ref) | [`prepare_pushforward_same_point`](@ref) |
73+
| `pullback` | [`prepare_pullback`](@ref) | [`prepare_pullback_same_point`](@ref) |
74+
| `hvp` | [`prepare_hvp`](@ref) | [`prepare_hvp_same_point`](@ref) |
7575

7676
Unsurprisingly, preparation syntax depends on the number of arguments:
7777

@@ -89,6 +89,9 @@ This is especially worth it if you plan to call `operator` several times in simi
8989
!!! warning
9090
The `extras` object is nearly always mutated when given to an operator, even when said operator does not have a bang `!` in its name.
9191

92+
With `pushforward`, `pullback` and `hvp`, you can also choose to prepare for the same point `x`, assuming only the seed `v` will change.
93+
Such is the purpose of `prepare_operator_same_point(f, backend, x, v)`, which is otherwise similar to standard preparation.
94+
9295
### Second order
9396

9497
We offer two ways to perform second-order differentiation (for [`second_derivative`](@ref), [`hvp`](@ref) and [`hessian`](@ref)):
@@ -115,15 +118,6 @@ We offer two ways to perform second-order differentiation (for [`second_derivati
115118
Just wrap it around any backend, with an appropriate choice of sparsity detector and coloring algorithm, and call `jacobian` or `hessian`: the result will be sparse.
116119
See the [tutorial section on sparsity](@ref sparsity-tutorial) for details.
117120

118-
### Split reverse mode
119-
120-
Some reverse mode AD backends expose a "split" option, which runs only the forward sweep, and encapsulates the reverse sweep in a closure.
121-
We make this available for all backends with the following operators:
122-
123-
| out-of-place | in-place |
124-
| :--------------------------------- | :---------------------------------- |
125-
| [`value_and_pullback_split`](@ref) | [`value_and_pullback!_split`](@ref) |
126-
127121
### Translation
128122

129123
The wrapper [`DifferentiateWith`](@ref) allows you to translate between AD backends.

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ using ChainRulesCore:
1111
rrule_via_ad
1212
using Compat
1313
import DifferentiationInterface as DI
14-
using DifferentiationInterface: DifferentiateWith, NoPullbackExtras, NoPushforwardExtras
14+
using DifferentiationInterface:
15+
DifferentiateWith, NoPullbackExtras, NoPushforwardExtras, PullbackExtras
1516

1617
ruleconfig(backend::AutoChainRules) = backend.ruleconfig
1718

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ end
66

77
function ChainRulesCore.rrule(dw::DifferentiateWith, x)
88
@compat (; f, backend) = dw
9-
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x)
10-
pullbackfunc_adjusted(dy) = (NoTangent(), pullbackfunc(dy))
11-
return y, pullbackfunc_adjusted
9+
y = f(x)
10+
extras_same = DI.prepare_pullback_same_point(f, backend, x, y)
11+
pullbackfunc(dy) = (NoTangent(), DI.pullback(f, backend, x, dy, extras_same))
12+
return y, pullbackfunc
1213
end
Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,36 @@
11
## Pullback
22

3+
struct ChainRulesPullbackExtrasSamePoint{Y,PB} <: PullbackExtras
4+
y::Y
5+
pb::PB
6+
end
7+
38
DI.prepare_pullback(f, ::AutoReverseChainRules, x, dy) = NoPullbackExtras()
49

5-
function DI.value_and_pullback_split(
6-
f, backend::AutoReverseChainRules, x, ::NoPullbackExtras
10+
function DI.prepare_pullback_same_point(
11+
f, backend::AutoReverseChainRules, x, dy, ::PullbackExtras=NoPullbackExtras()
712
)
813
rc = ruleconfig(backend)
9-
y, pullback = rrule_via_ad(rc, f, x)
10-
pullbackfunc(dy) = last(pullback(dy))
11-
return y, pullbackfunc
14+
y, pb = rrule_via_ad(rc, f, x)
15+
return ChainRulesPullbackExtrasSamePoint(y, pb)
1216
end
1317

14-
function DI.value_and_pullback!_split(
15-
f, backend::AutoReverseChainRules, x, extras::NoPullbackExtras
16-
)
17-
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
18-
pullbackfunc!(dx, dy) = copyto!(dx, pullbackfunc(dy))
19-
return y, pullbackfunc!
18+
function DI.value_and_pullback(f, backend::AutoReverseChainRules, x, dy, ::NoPullbackExtras)
19+
rc = ruleconfig(backend)
20+
y, pb = rrule_via_ad(rc, f, x)
21+
return y, last(pb(dy))
2022
end
2123

2224
function DI.value_and_pullback(
23-
f, backend::AutoReverseChainRules, x, dy, extras::NoPullbackExtras
25+
f, ::AutoReverseChainRules, x, dy, extras::ChainRulesPullbackExtrasSamePoint
26+
)
27+
@compat (; y, pb) = extras
28+
return copy(y), last(pb(dy))
29+
end
30+
31+
function DI.pullback(
32+
f, ::AutoReverseChainRules, x, dy, extras::ChainRulesPullbackExtrasSamePoint
2433
)
25-
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
26-
return y, pullbackfunc(dy)
34+
@compat (; pb) = extras
35+
return last(pb(dy))
2736
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ end
66

77
function DI.prepare_pushforward(f::F, backend::AutoForwardDiff, x, dx) where {F}
88
T = tag_type(f, backend, x)
9-
xdual_tmp = make_dual(T, x, dx)
9+
xdual_tmp = make_dual_similar(T, x)
1010
return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp)
1111
end
1212

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ end
77

88
function DI.prepare_pushforward(f!::F, y, backend::AutoForwardDiff, x, dx) where {F}
99
T = tag_type(f!, backend, x)
10-
xdual_tmp = make_dual(T, x, dx)
11-
ydual_tmp = make_dual(T, y, similar(y))
10+
xdual_tmp = make_dual_similar(T, x)
11+
ydual_tmp = make_dual_similar(T, y)
1212
return ForwardDiffTwoArgPushforwardExtras{T,typeof(xdual_tmp),typeof(ydual_tmp)}(
1313
xdual_tmp, ydual_tmp
1414
)

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ tag_type(f, ::AutoForwardDiff{C,Nothing}, x) where {C} = Tag{typeof(f),eltype(x)
77
make_dual(::Type{T}, x::Number, dx) where {T} = Dual{T}(x, dx)
88
make_dual(::Type{T}, x, dx) where {T} = Dual{T}.(x, dx) # TODO: map causes Enzyme to fail
99

10+
make_dual_similar(::Type{T}, x::Number) where {T} = Dual{T}(x, x)
11+
make_dual_similar(::Type{T}, x) where {T} = similar(x, Dual{T,eltype(x),1})
12+
1013
make_dual!(::Type{T}, xdual, x, dx) where {T} = map!(Dual{T}, xdual, x, dx)
1114

1215
myvalue(::Type{T}, ydual::Number) where {T} = value(T, ydual)

0 commit comments

Comments
 (0)