Skip to content

Commit bad9cb5

Browse files
committed
perf: compute in-place HVP from in-place gradient
1 parent 6ae9532 commit bad9cb5

4 files changed

Lines changed: 318 additions & 13 deletions

File tree

DifferentiationInterface/src/first_order/gradient.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,16 @@ function shuffled_gradient(
132132
) where {F,C}
133133
return gradient(f, prep, backend, x, rewrap(unannotated_contexts...)...)
134134
end
135+
136+
function shuffled_gradient!(
137+
grad,
138+
x,
139+
f::F,
140+
prep::GradientPrep,
141+
backend::AbstractADType,
142+
rewrap::Rewrap{C},
143+
unannotated_contexts::Vararg{Any,C},
144+
) where {F,C}
145+
gradient!(f, grad, prep, backend, x, rewrap(unannotated_contexts...)...)
146+
return nothing
147+
end

DifferentiationInterface/src/misc/from_primitive.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,16 @@ Wrapper which forces a given backend to act as a reverse-mode backend.
1414
1515
Used in internal testing.
1616
"""
17-
struct AutoReverseFromPrimitive{B} <: FromPrimitive
17+
struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive
1818
backend::B
1919
end
2020

21+
function AutoReverseFromPrimitive(backend::AbstractADType; inplace=false)
22+
return AutoReverseFromPrimitive{inplace,typeof(backend)}(backend)
23+
end
24+
25+
inplace_support(::AutoReverseFromPrimitive{true}) = InPlaceSupported()
26+
inplace_support(::AutoReverseFromPrimitive{false}) = InPlaceNotSupported()
2127
ADTypes.mode(::AutoReverseFromPrimitive) = ADTypes.ReverseMode()
2228

2329
function threshold_batchsize(fromprim::AutoReverseFromPrimitive, dimension::Integer)

0 commit comments

Comments
 (0)