Skip to content

Commit 346834e

Browse files
authored
fix: improve ForwardDiff tagging for HVP (#596)
* Improve ForwardDiff tagging
* Remove tag unwrapping for FixTail
* Cov
* Bump DI
1 parent 16c0194 commit 346834e

6 files changed

Lines changed: 94 additions & 47 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.17"
4+
version = "0.6.18"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ using DifferentiationInterface:
77
BatchSizeSettings,
88
Cache,
99
Constant,
10+
PrepContext,
1011
Context,
12+
FixTail,
1113
DerivativePrep,
1214
DifferentiateWith,
1315
GradientPrep,
@@ -21,6 +23,7 @@ using DifferentiationInterface:
2123
SecondOrder,
2224
inner,
2325
outer,
26+
shuffled_gradient,
2427
unwrap,
2528
with_contexts
2629
import ForwardDiff.DiffResults as DR
Lines changed: 72 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,6 @@
1-
struct ForwardDiffOverSomethingHVPWrapper{F}
2-
f::F
3-
end
4-
5-
"""
6-
tag_backend_hvp(f, ::AutoForwardDiff, x)
7-
8-
Return a new `AutoForwardDiff` backend with a fixed tag linked to `f`, so that we know how to prepare the inner gradient of the HVP without depending on what that gradient closure looks like.
9-
"""
10-
tag_backend_hvp(f, backend::AutoForwardDiff, x) = backend
11-
12-
function tag_backend_hvp(f::F, ::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksize}
13-
tag = ForwardDiff.Tag(ForwardDiffOverSomethingHVPWrapper(f), eltype(x))
14-
return AutoForwardDiff{chunksize,typeof(tag)}(tag)
15-
end
16-
17-
struct ForwardDiffOverSomethingHVPPrep{B<:AutoForwardDiff,G,E<:PushforwardPrep} <: HVPPrep
18-
tagged_outer_backend::B
19-
inner_gradient::G
20-
outer_pushforward_prep::E
1+
struct ForwardDiffOverSomethingHVPPrep{E1<:GradientPrep,E2<:PushforwardPrep} <: HVPPrep
2+
inner_gradient_prep::E1
3+
outer_pushforward_prep::E2
214
end
225

236
function DI.prepare_hvp(
@@ -27,65 +10,94 @@ function DI.prepare_hvp(
2710
tx::NTuple,
2811
contexts::Vararg{Context,C},
2912
) where {F,C}
30-
rewrap = Rewrap(contexts...)
31-
tagged_outer_backend = tag_backend_hvp(f, outer(backend), x)
32-
T = tag_type(f, tagged_outer_backend, x)
13+
T = tag_type(shuffled_gradient, outer(backend), x)
3314
xdual = make_dual(T, x, tx)
34-
gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...)
35-
# TODO: get rid of closure?
36-
function inner_gradient(x, unannotated_contexts...)
37-
annotated_contexts = rewrap(unannotated_contexts...)
38-
return DI.gradient(f, gradient_prep, inner(backend), x, annotated_contexts...)
39-
end
40-
outer_pushforward_prep = DI.prepare_pushforward(
41-
inner_gradient, tagged_outer_backend, x, tx, contexts...
15+
inner_gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...)
16+
rewrap = Rewrap(contexts...)
17+
new_contexts = (
18+
Constant(f),
19+
PrepContext(inner_gradient_prep),
20+
Constant(inner(backend)),
21+
Constant(rewrap),
22+
contexts...,
4223
)
43-
return ForwardDiffOverSomethingHVPPrep(
44-
tagged_outer_backend, inner_gradient, outer_pushforward_prep
24+
outer_pushforward_prep = DI.prepare_pushforward(
25+
shuffled_gradient, outer(backend), x, tx, new_contexts...
4526
)
27+
return ForwardDiffOverSomethingHVPPrep(inner_gradient_prep, outer_pushforward_prep)
4628
end
4729

4830
function DI.hvp(
4931
f::F,
5032
prep::ForwardDiffOverSomethingHVPPrep,
51-
::SecondOrder{<:AutoForwardDiff},
33+
backend::SecondOrder{<:AutoForwardDiff},
5234
x,
5335
tx::NTuple,
5436
contexts::Vararg{Context,C},
5537
) where {F,C}
56-
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
38+
(; inner_gradient_prep, outer_pushforward_prep) = prep
39+
rewrap = Rewrap(contexts...)
40+
new_contexts = (
41+
Constant(f),
42+
PrepContext(inner_gradient_prep),
43+
Constant(inner(backend)),
44+
Constant(rewrap),
45+
contexts...,
46+
)
5747
return DI.pushforward(
58-
inner_gradient, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
48+
shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts...
5949
)
6050
end
6151

6252
function DI.hvp!(
6353
f::F,
6454
tg::NTuple,
6555
prep::ForwardDiffOverSomethingHVPPrep,
66-
::SecondOrder{<:AutoForwardDiff},
56+
backend::SecondOrder{<:AutoForwardDiff},
6757
x,
6858
tx::NTuple,
6959
contexts::Vararg{Context,C},
7060
) where {F,C}
71-
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
72-
DI.pushforward!(
73-
inner_gradient, tg, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
61+
(; inner_gradient_prep, outer_pushforward_prep) = prep
62+
rewrap = Rewrap(contexts...)
63+
new_contexts = (
64+
Constant(f),
65+
PrepContext(inner_gradient_prep),
66+
Constant(inner(backend)),
67+
Constant(rewrap),
68+
contexts...,
69+
)
70+
return DI.pushforward!(
71+
shuffled_gradient,
72+
tg,
73+
outer_pushforward_prep,
74+
outer(backend),
75+
x,
76+
tx,
77+
new_contexts...,
7478
)
7579
return tg
7680
end
7781

7882
function DI.gradient_and_hvp(
7983
f::F,
8084
prep::ForwardDiffOverSomethingHVPPrep,
81-
::SecondOrder{<:AutoForwardDiff},
85+
backend::SecondOrder{<:AutoForwardDiff},
8286
x,
8387
tx::NTuple,
8488
contexts::Vararg{Context,C},
8589
) where {F,C}
86-
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
90+
(; inner_gradient_prep, outer_pushforward_prep) = prep
91+
rewrap = Rewrap(contexts...)
92+
new_contexts = (
93+
Constant(f),
94+
PrepContext(inner_gradient_prep),
95+
Constant(inner(backend)),
96+
Constant(rewrap),
97+
contexts...,
98+
)
8799
return DI.value_and_pushforward(
88-
inner_gradient, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
100+
shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts...
89101
)
90102
end
91103

@@ -94,14 +106,28 @@ function DI.gradient_and_hvp!(
94106
grad,
95107
tg::NTuple,
96108
prep::ForwardDiffOverSomethingHVPPrep,
97-
::SecondOrder{<:AutoForwardDiff},
109+
backend::SecondOrder{<:AutoForwardDiff},
98110
x,
99111
tx::NTuple,
100112
contexts::Vararg{Context,C},
101113
) where {F,C}
102-
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
114+
(; inner_gradient_prep, outer_pushforward_prep) = prep
115+
rewrap = Rewrap(contexts...)
116+
new_contexts = (
117+
Constant(f),
118+
PrepContext(inner_gradient_prep),
119+
Constant(inner(backend)),
120+
Constant(rewrap),
121+
contexts...,
122+
)
103123
new_grad, _ = DI.value_and_pushforward!(
104-
inner_gradient, tg, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
124+
shuffled_gradient,
125+
tg,
126+
outer_pushforward_prep,
127+
outer(backend),
128+
x,
129+
tx,
130+
new_contexts...,
105131
)
106132
return copyto!(grad, new_grad), tg
107133
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B}
8585
end
8686

8787
_translate(::Type{T}, ::Val{B}, c::Constant) where {T,B} = unwrap(c)
88+
_translate(::Type{T}, ::Val{B}, c::PrepContext) where {T,B} = unwrap(c)
8889

8990
function _translate(::Type{T}, ::Val{B}, c::Cache) where {T,B}
9091
c0 = unwrap(c)

DifferentiationInterface/src/first_order/gradient.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,14 @@ function shuffled_gradient(
128128
) where {F,C}
129129
return gradient(f, backend, x, rewrap(unannotated_contexts...)...)
130130
end
131+
132+
function shuffled_gradient(
133+
x,
134+
f::F,
135+
prep::GradientPrep,
136+
backend::AbstractADType,
137+
rewrap::Rewrap{C},
138+
unannotated_contexts::Vararg{Any,C},
139+
) where {F,C}
140+
return gradient(f, prep, backend, x, rewrap(unannotated_contexts...)...)
141+
end

DifferentiationInterface/src/utils/context.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ unwrap(c::Cache) = c.data
7474

7575
Base.:(==)(c1::Cache, c2::Cache) = c1.data == c2.data
7676

77+
struct PrepContext{T<:Prep} <: Context
78+
data::T
79+
end
80+
81+
unwrap(c::PrepContext) = c.data
82+
7783
struct Rewrap{C,T}
7884
context_makers::T
7985
function Rewrap(contexts::Vararg{Context,C}) where {C}

0 commit comments

Comments
 (0)