Skip to content

Commit 932913a

Browse files
authored
Better HVP in reverse over forward (#494)
1 parent 472c027 commit 932913a

2 files changed

Lines changed: 48 additions & 25 deletions

File tree

  • DifferentiationInterface

DifferentiationInterface/src/second_order/hvp.jl

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ struct ForwardOverReverseHVPPrep{G,E<:PushforwardPrep} <: HVPPrep
5050
outer_pushforward_prep::E
5151
end
5252

53-
struct ReverseOverForwardHVPPrep <: HVPPrep end
53+
struct ReverseOverForwardHVPPrep{P,E} <: HVPPrep
54+
inner_pushforward::P
55+
outer_gradient_prep::E
56+
end
5457

5558
struct ReverseOverReverseHVPPrep{G,E<:PullbackPrep} <: HVPPrep
5659
inner_gradient::G
@@ -111,9 +114,19 @@ function _prepare_hvp_aux(
111114
tx::Tangents,
112115
contexts::Vararg{Context,C},
113116
) where {F,C}
117+
rewrap = Rewrap(contexts...)
114118
# gradient of pushforward
115-
# uses dx in the closure so it can't be prepared
116-
return ReverseOverForwardHVPPrep()
119+
function inner_pushforward(_x, _dx, unannotated_contexts...)
120+
annotated_contexts = rewrap(unannotated_contexts...)
121+
ty = pushforward(
122+
f, nested(inner(backend)), _x, Tangents(_dx), annotated_contexts...
123+
)
124+
return only(ty)
125+
end
126+
outer_gradient_prep = prepare_gradient(
127+
inner_pushforward, outer(backend), x, contexts...
128+
)
129+
return ReverseOverForwardHVPPrep(inner_pushforward, outer_gradient_prep)
117130
end
118131

119132
function _prepare_hvp_aux(
@@ -168,23 +181,15 @@ end
168181

169182
function hvp(
170183
f::F,
171-
::ReverseOverForwardHVPPrep,
184+
prep::ReverseOverForwardHVPPrep,
172185
backend::AbstractADType,
173186
x,
174187
tx::Tangents,
175188
contexts::Vararg{Context,C},
176189
) where {F,C}
177-
rewrap = Rewrap(contexts...)
190+
@compat (; inner_pushforward, outer_gradient_prep) = prep
178191
tg = map(tx) do dx
179-
function inner_pushforward(_x, unannotated_contexts...)
180-
annotated_contexts = rewrap(unannotated_contexts...)
181-
return only(
182-
pushforward(
183-
f, nested(inner(backend)), _x, Tangents(dx), annotated_contexts...
184-
),
185-
)
186-
end
187-
gradient(only inner_pushforward, outer(backend), x, contexts...)
192+
gradient(inner_pushforward, outer(backend), x, Constant(dx), contexts...)
188193
end
189194
return tg
190195
end
@@ -234,23 +239,23 @@ end
234239
function hvp!(
235240
f::F,
236241
tg::Tangents,
237-
::ReverseOverForwardHVPPrep,
242+
prep::ReverseOverForwardHVPPrep,
238243
backend::AbstractADType,
239244
x,
240245
tx::Tangents,
241246
contexts::Vararg{Context,C},
242247
) where {F,C}
243-
rewrap = Rewrap(contexts...)
248+
@compat (; inner_pushforward, outer_gradient_prep) = prep
244249
for b in eachindex(tx.d, tg.d)
245-
function inner_pushforward(_x, unannotated_contexts...)
246-
annotated_contexts = rewrap(unannotated_contexts...)
247-
return only(
248-
pushforward(
249-
f, nested(inner(backend)), _x, Tangents(tx.d[b]), annotated_contexts...
250-
),
251-
)
252-
end
253-
gradient!(only inner_pushforward, tg.d[b], outer(backend), x, contexts...)
250+
gradient!(
251+
inner_pushforward,
252+
tg.d[b],
253+
outer_gradient_prep,
254+
outer(backend),
255+
x,
256+
Constant(tx.d[b]),
257+
contexts...,
258+
)
254259
end
255260
return tg
256261
end

DifferentiationInterface/test/Misc/FromPrimitive/test.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,17 @@ fromprimitive_backends = [ #
1111
AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=5)),
1212
]
1313

14+
fromprimitive_secondorder_backends = [ #
15+
SecondOrder(
16+
AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=5)),
17+
AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=5)),
18+
),
19+
SecondOrder(
20+
AutoReverseFromPrimitive(AutoForwardDiff(; chunksize=5)),
21+
AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=5)),
22+
),
23+
]
24+
1425
for backend in vcat(fromprimitive_backends)
1526
@test check_available(backend)
1627
@test check_inplace(backend)
@@ -19,3 +30,10 @@ end
1930
test_differentiation(
2031
fromprimitive_backends, default_scenarios(; include_constantified=true); logging=LOGGING
2132
);
33+
34+
test_differentiation(
35+
fromprimitive_secondorder_backends,
36+
default_scenarios(; include_constantified=true);
37+
first_order=false,
38+
logging=LOGGING,
39+
);

0 commit comments

Comments
 (0)