6565
6666# # Forward over anything
6767
68- struct ForwardOverAnythingHVPPrep{G,PO<: PushforwardPrep ,PI<: PushforwardPrep } <: HVPPrep
68+ struct ForwardOverAnythingHVPPrep{
69+ G,PO<: PushforwardPrep ,PI<: Union{Nothing,PushforwardPrep}
70+ } <: HVPPrep
6971 # pushforward of many pushforwards in theory, but pushforward of gradient in practice
7072 grad_buffer:: G
7173 outer_pushforward_prep:: PO
@@ -88,9 +90,13 @@ function _prepare_hvp_aux(
8890 outer_pushforward_prep = prepare_pushforward (
8991 shuffled_gradient, outer (backend), x, tx, new_contexts...
9092 )
91- outer_pushforward_in_prep = prepare_pushforward (
92- shuffled_gradient!, grad_buffer, outer (backend), x, tx, new_contexts...
93- )
93+ outer_pushforward_in_prep = if inplace_support (outer (backend)) isa InPlaceSupported
94+ prepare_pushforward (
95+ shuffled_gradient!, grad_buffer, outer (backend), x, tx, new_contexts...
96+ )
97+ else
98+ nothing
99+ end
94100 return ForwardOverAnythingHVPPrep (
95101 grad_buffer, outer_pushforward_prep, outer_pushforward_in_prep
96102 )
386392
387393# # Reverse over reverse
388394
389- struct ReverseOverReverseHVPPrep{G,PO<: PullbackPrep ,PI<: PullbackPrep } <: HVPPrep
395+ struct ReverseOverReverseHVPPrep{G,PO<: PullbackPrep ,PI<: Union{Nothing,PullbackPrep} } < :
396+ HVPPrep
390397 # pullback of gradient
391398 grad_buffer:: G
392399 outer_pullback_prep:: PO
@@ -409,9 +416,13 @@ function _prepare_hvp_aux(
409416 outer_pullback_prep = prepare_pullback (
410417 shuffled_gradient, outer (backend), x, tx, new_contexts...
411418 )
412- outer_pullback_in_prep = prepare_pullback (
413- shuffled_gradient!, grad_buffer, outer (backend), x, tx, new_contexts...
414- )
419+ outer_pullback_in_prep = if inplace_support (outer (backend)) isa InPlaceSupported
420+ prepare_pullback (
421+ shuffled_gradient!, grad_buffer, outer (backend), x, tx, new_contexts...
422+ )
423+ else
424+ nothing
425+ end
415426 return ReverseOverReverseHVPPrep (
416427 grad_buffer, outer_pullback_prep, outer_pullback_in_prep
417428 )
0 commit comments