Skip to content

Commit fb16fab

Browse files
authored
fix: use PolyesterForwardDiff's inner gradient in its HVP (#610)
* fix: use PolyesterForwardDiff's inner gradient in its HVP * SecondOrder
1 parent 3b978b8 commit fb16fab

3 files changed

Lines changed: 22 additions & 6 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.19"
4+
version = "0.6.20"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using DifferentiationInterface:
1515
NoJacobianPrep,
1616
PushforwardPrep,
1717
SecondDerivativePrep,
18+
SecondOrder,
1819
unwrap,
1920
with_contexts
2021
using LinearAlgebra: mul!

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,9 @@ end
283283
function DI.prepare_hvp(
284284
f, backend::AutoPolyesterForwardDiff, x, tx::NTuple, contexts::Vararg{Context,C}
285285
) where {C}
286-
return DI.prepare_hvp(f, single_threaded(backend), x, tx, contexts...)
286+
return DI.prepare_hvp(
287+
f, SecondOrder(single_threaded(backend), backend), x, tx, contexts...
288+
)
287289
end
288290

289291
function DI.hvp(
@@ -294,7 +296,9 @@ function DI.hvp(
294296
tx::NTuple,
295297
contexts::Vararg{Context,C},
296298
) where {C}
297-
return DI.hvp(f, prep, single_threaded(backend), x, tx, contexts...)
299+
return DI.hvp(
300+
f, prep, SecondOrder(single_threaded(backend), backend), x, tx, contexts...
301+
)
298302
end
299303

300304
function DI.hvp!(
@@ -306,7 +310,9 @@ function DI.hvp!(
306310
tx::NTuple,
307311
contexts::Vararg{Context,C},
308312
) where {C}
309-
return DI.hvp!(f, tg, prep, single_threaded(backend), x, tx, contexts...)
313+
return DI.hvp!(
314+
f, tg, prep, SecondOrder(single_threaded(backend), backend), x, tx, contexts...
315+
)
310316
end
311317

312318
function DI.gradient_and_hvp(
@@ -317,7 +323,9 @@ function DI.gradient_and_hvp(
317323
tx::NTuple,
318324
contexts::Vararg{Context,C},
319325
) where {C}
320-
return DI.gradient_and_hvp(f, prep, single_threaded(backend), x, tx, contexts...)
326+
return DI.gradient_and_hvp(
327+
f, prep, SecondOrder(single_threaded(backend), backend), x, tx, contexts...
328+
)
321329
end
322330

323331
function DI.gradient_and_hvp!(
@@ -331,7 +339,14 @@ function DI.gradient_and_hvp!(
331339
contexts::Vararg{Context,C},
332340
) where {C}
333341
return DI.gradient_and_hvp!(
334-
f, grad, tg, prep, single_threaded(backend), x, tx, contexts...
342+
f,
343+
grad,
344+
tg,
345+
prep,
346+
SecondOrder(single_threaded(backend), backend),
347+
x,
348+
tx,
349+
contexts...,
335350
)
336351
end
337352

0 commit comments

Comments
 (0)