Skip to content

Commit 17863c6

Browse files
authored
Replace closures with argument shuffling in second order (#572)
* Replace closures with argument shuffling in second order * Fix symbolics * Remove mutating shuffled
1 parent 10d63da commit 17863c6

11 files changed

Lines changed: 236 additions & 87 deletions

File tree

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ In practice, many AD backends have custom implementations for high-level operato
4848
| `AutoMooncake` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
4949
| `AutoPolyesterForwardDiff` | 🔀 | ❌ | 🔀 | ✅ | ✅ | 🔀 | 🔀 | 🔀 |
5050
| `AutoReverseDiff` | ❌ | 🔀 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
51-
| `AutoSymbolics` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | | |
51+
| `AutoSymbolics` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | | |
5252
| `AutoTracker` | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
5353
| `AutoZygote` | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | 🔀 | ❌ |
5454

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,96 @@ function DI.value_gradient_and_hessian!(
275275
DI.hessian!(f, hess, prep, backend, x)
276276
return y, grad, hess
277277
end
278+
279+
## HVP
280+
281+
# Preparation result for Hessian-vector products with the Symbolics backend.
struct SymbolicsOneArgHVPPrep{GP,Exe,Exe!} <: HVPPrep
    # result of `DI.prepare_gradient` for the same function, backend and input
    gradient_prep::GP
    # compiled out-of-place HVP executable, called on the stacked vector `[x; dx]`
    hvp_exe::Exe
    # compiled in-place HVP executable, called as `hvp_exe!(out, [x; dx])`
    hvp_exe!::Exe!
end
286+
287+
# Build the symbolic Hessian-vector product of `f` at a symbolic point and compile
# it into out-of-place and in-place executables. Both executables take a single
# stacked vector made of the flattened input followed by the flattened tangent.
# The tangents `tx` are only used for dispatch; their values are never read here.
function DI.prepare_hvp(f, backend::AutoSymbolics, x, tx::NTuple)
    x_sym = variables(:x, axes(x)...)
    dx_sym = variables(:dx, axes(x)...)
    x_sym_vec, dx_sym_vec = vec(x_sym), vec(dx_sym)

    # Symbolics.hessian only accepts vectors
    hess_sym = hessian(f(x_sym), x_sym_vec)
    hvp_sym = hess_sym * dx_sym_vec

    hvp_exe, hvp_exe! = build_function(
        hvp_sym, vcat(x_sym_vec, dx_sym_vec); expression=Val(false)
    )

    gradient_prep = DI.prepare_gradient(f, backend, x)
    return SymbolicsOneArgHVPPrep(gradient_prep, hvp_exe, hvp_exe!)
end
300+
301+
# Evaluate the prepared HVP executable once per tangent in `tx` and return a
# tuple of results, each reshaped to the size of `x`.
function DI.hvp(f, prep::SymbolicsOneArgHVPPrep, ::AutoSymbolics, x, tx::NTuple)
    x_vec = vec(x)
    out_size = size(x)
    return map(tx) do dx
        # the executable expects the flattened input followed by the flattened tangent
        reshape(prep.hvp_exe(vcat(x_vec, vec(dx))), out_size)
    end
end
308+
309+
# In-place counterpart of `DI.hvp`: for every index shared by `tx` and `tg`, run
# the prepared in-place executable on the stacked vector `[x; tx[b]]`, writing
# into `vec(tg[b])`. Returns `tg`.
function DI.hvp!(
    f, tg::NTuple, prep::SymbolicsOneArgHVPPrep, ::AutoSymbolics, x, tx::NTuple
)
    for b in eachindex(tx, tg)
        stacked = vcat(vec(x), vec(tx[b]))
        prep.hvp_exe!(vec(tg[b]), stacked)
    end
    return tg
end
319+
320+
## Second derivative
321+
322+
# Preparation result for second derivatives with the Symbolics backend.
struct SymbolicsOneArgSecondDerivativePrep{DP,Exe,Exe!} <: SecondDerivativePrep
    # result of `DI.prepare_derivative` for the same function, backend and input
    derivative_prep::DP
    # compiled out-of-place executable, called as `der2_exe(x)`
    der2_exe::Exe
    # compiled in-place executable, called as `der2_exe!(der2, x)`;
    # set to `nothing` when `build_function` produced a single function
    der2_exe!::Exe!
end
327+
328+
# Differentiate `f` twice with respect to a scalar symbolic variable and compile
# the result into executables stored in a `SymbolicsOneArgSecondDerivativePrep`.
#
# `build_function` is handled in two shapes here: a tuple `(out_of_place, in_place)`
# or a single `RuntimeGeneratedFunction` (in which case no in-place executable is
# stored). Any other result raises an explicit error instead of silently producing
# `nothing` and failing later at the destructuring.
function DI.prepare_second_derivative(f, backend::AutoSymbolics, x)
    x_var = variable(:x)
    der_var = derivative(f(x_var), x_var)
    der2_var = derivative(der_var, x_var)

    res = build_function(der2_var, x_var; expression=Val(false))
    (der2_exe, der2_exe!) = if res isa Tuple
        res
    elseif res isa RuntimeGeneratedFunction
        res, nothing
    else
        error(
            "Unexpected result of type $(typeof(res)) returned by `build_function` " *
            "while preparing the second derivative",
        )
    end
    derivative_prep = DI.prepare_derivative(f, backend, x)
    return SymbolicsOneArgSecondDerivativePrep(derivative_prep, der2_exe, der2_exe!)
end
342+
343+
# Evaluate the prepared out-of-place second-derivative executable at `x`.
function DI.second_derivative(
    f, prep::SymbolicsOneArgSecondDerivativePrep, ::AutoSymbolics, x
)
    der2 = prep.der2_exe(x)
    return der2
end
348+
349+
# Run the prepared in-place second-derivative executable on `(der2, x)` and
# return `der2`.
function DI.second_derivative!(
    f, der2, prep::SymbolicsOneArgSecondDerivativePrep, ::AutoSymbolics, x
)
    exe! = prep.der2_exe!
    exe!(der2, x)
    return der2
end
355+
356+
# Return `(y, der, der2)`: value and first derivative come from the stored
# first-order preparation, the second derivative from the prepared executable.
function DI.value_derivative_and_second_derivative(
    f, prep::SymbolicsOneArgSecondDerivativePrep, backend::AutoSymbolics, x
)
    first_order_prep = prep.derivative_prep
    val, d1 = DI.value_and_derivative(f, first_order_prep, backend, x)
    d2 = DI.second_derivative(f, prep, backend, x)
    return val, d1, d2
end
363+
364+
# In-place variant: `der` is passed to `DI.value_and_derivative!` and `der2` to
# `DI.second_derivative!`. Returns `(y, der, der2)` with the caller's buffers.
function DI.value_derivative_and_second_derivative!(
    f, der, der2, prep::SymbolicsOneArgSecondDerivativePrep, backend::AutoSymbolics, x
)
    # only the value is kept; `der` itself is what gets returned
    val = first(DI.value_and_derivative!(f, der, prep.derivative_prep, backend, x))
    DI.second_derivative!(f, der2, prep, backend, x)
    return val, der, der2
end

DifferentiationInterface/src/first_order/derivative.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,11 @@ function derivative!(
192192
pushforward!(f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...)
193193
return der
194194
end
195+
196+
## Shuffled
197+
198+
# Variant of `derivative` with `x` as the first positional argument.
# `rewrap` is applied to the unannotated contexts and its result is splatted
# into the `derivative` call as the trailing context arguments.
function shuffled_derivative(
    x,
    f::F,
    backend::AbstractADType,
    rewrap::Rewrap{C},
    unannotated_contexts::Vararg{Any,C},
) where {F,C}
    contexts = rewrap(unannotated_contexts...)
    return derivative(f, backend, x, contexts...)
end

DifferentiationInterface/src/first_order/gradient.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,11 @@ function gradient!(
120120
pullback!(f, (grad,), prep.pullback_prep, backend, x, (true,), contexts...)
121121
return grad
122122
end
123+
124+
## Shuffled
125+
126+
# Variant of `gradient` with `x` as the first positional argument.
# `rewrap` is applied to the unannotated contexts and its result is splatted
# into the `gradient` call as the trailing context arguments.
function shuffled_gradient(
    x,
    f::F,
    backend::AbstractADType,
    rewrap::Rewrap{C},
    unannotated_contexts::Vararg{Any,C},
) where {F,C}
    contexts = rewrap(unannotated_contexts...)
    return gradient(f, backend, x, contexts...)
end

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,17 @@ function pushforward!(
326326
) where {F,C}
327327
return value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...)[2]
328328
end
329+
330+
## Shuffled
331+
332+
# Variant of `pushforward` for a single tangent `dx`, with `x` as the first
# positional argument. The tangent is wrapped in a 1-tuple for the call and the
# single result is unwrapped with `only`.
function shuffled_single_pushforward(
    x,
    f::F,
    backend::AbstractADType,
    dx,
    rewrap::Rewrap{C},
    unannotated_contexts::Vararg{Any,C},
) where {F,C}
    contexts = rewrap(unannotated_contexts...)
    ty = pushforward(f, backend, x, (dx,), contexts...)
    return only(ty)
end

0 commit comments

Comments
 (0)