Skip to content

Commit 2cc261f

Browse files
committed
Avoid prep
1 parent 6ab9b93 commit 2cc261f

1 file changed

Lines changed: 19 additions & 8 deletions

File tree

  • DifferentiationInterface/src/second_order

DifferentiationInterface/src/second_order/hvp.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ end
6565

6666
## Forward over anything
6767

68-
struct ForwardOverAnythingHVPPrep{G,PO<:PushforwardPrep,PI<:PushforwardPrep} <: HVPPrep
68+
struct ForwardOverAnythingHVPPrep{
69+
G,PO<:PushforwardPrep,PI<:Union{Nothing,PushforwardPrep}
70+
} <: HVPPrep
6971
# pushforward of many pushforwards in theory, but pushforward of gradient in practice
7072
grad_buffer::G
7173
outer_pushforward_prep::PO
@@ -88,9 +90,13 @@ function _prepare_hvp_aux(
8890
outer_pushforward_prep = prepare_pushforward(
8991
shuffled_gradient, outer(backend), x, tx, new_contexts...
9092
)
91-
outer_pushforward_in_prep = prepare_pushforward(
92-
shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts...
93-
)
93+
outer_pushforward_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported
94+
prepare_pushforward(
95+
shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts...
96+
)
97+
else
98+
nothing
99+
end
94100
return ForwardOverAnythingHVPPrep(
95101
grad_buffer, outer_pushforward_prep, outer_pushforward_in_prep
96102
)
@@ -386,7 +392,8 @@ end
386392

387393
## Reverse over reverse
388394

389-
struct ReverseOverReverseHVPPrep{G,PO<:PullbackPrep,PI<:PullbackPrep} <: HVPPrep
395+
struct ReverseOverReverseHVPPrep{G,PO<:PullbackPrep,PI<:Union{Nothing,PullbackPrep}} <:
396+
HVPPrep
390397
# pullback of gradient
391398
grad_buffer::G
392399
outer_pullback_prep::PO
@@ -409,9 +416,13 @@ function _prepare_hvp_aux(
409416
outer_pullback_prep = prepare_pullback(
410417
shuffled_gradient, outer(backend), x, tx, new_contexts...
411418
)
412-
outer_pullback_in_prep = prepare_pullback(
413-
shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts...
414-
)
419+
outer_pullback_in_prep = if inplace_support(outer(backend)) isa InPlaceSupported
420+
prepare_pullback(
421+
shuffled_gradient!, grad_buffer, outer(backend), x, tx, new_contexts...
422+
)
423+
else
424+
nothing
425+
end
415426
return ReverseOverReverseHVPPrep(
416427
grad_buffer, outer_pullback_prep, outer_pullback_in_prep
417428
)

0 commit comments

Comments
 (0)