Skip to content

Commit 40b629d

Browse files
authored
Implement gradient_and_hvp (#588)
* Implement `gradient_and_hvp` * Use inner * Da fixes * Add symbolic backends * Fix * Fix JuliaFormatter to v1
1 parent 495d988 commit 40b629d

17 files changed

Lines changed: 576 additions & 140 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ FastDifferentiation = "0.4.1"
5353
FiniteDiff = "2.23.1"
5454
FiniteDifferences = "0.12.31"
5555
ForwardDiff = "0.10.36"
56+
JuliaFormatter = "1"
5657
LinearAlgebra = "<0.0.1,1"
5758
Mooncake = "0.4.0"
5859
PolyesterForwardDiff = "0.1.2"

DifferentiationInterface/docs/src/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ prepare_hvp
9393
prepare_hvp_same_point
9494
hvp
9595
hvp!
96+
gradient_and_hvp
97+
gradient_and_hvp!
9698
```
9799

98100
### Hessian

DifferentiationInterface/docs/src/explanation/operators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ Several variants of each operator are defined:
5858
| [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) |
5959
| [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) |
6060
| [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) |
61-
| [`hvp`](@ref) | [`hvp!`](@ref) | - | - |
61+
| [`hvp`](@ref) | [`hvp!`](@ref) | [`gradient_and_hvp`](@ref) | [`gradient_and_hvp!`](@ref) |
6262

6363
## Mutation and signatures
6464

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,20 +396,23 @@ end
396396

397397
## HVP
398398

399-
struct FastDifferentiationHVPPrep{E2,E2!} <: HVPPrep
399+
struct FastDifferentiationHVPPrep{E2,E2!,E1} <: HVPPrep
400400
hvp_exe::E2
401401
hvp_exe!::E2!
402+
gradient_prep::E1
402403
end
403404

404-
function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, tx::NTuple)
405+
function DI.prepare_hvp(f, backend::AutoFastDifferentiation, x, tx::NTuple)
405406
x_var = make_variables(:x, size(x)...)
406407
y_var = f(x_var)
407408

408409
x_vec_var = vec(x_var)
409410
hv_vec_var, v_vec_var = hessian_times_v(y_var, x_vec_var)
410411
hvp_exe = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=false)
411412
hvp_exe! = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=true)
412-
return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!)
413+
414+
gradient_prep = DI.prepare_gradient(f, backend, x)
415+
return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!, gradient_prep)
413416
end
414417

415418
function DI.hvp(
@@ -439,6 +442,28 @@ function DI.hvp!(
439442
return tg
440443
end
441444

445+
function DI.gradient_and_hvp(
446+
f, prep::FastDifferentiationHVPPrep, backend::AutoFastDifferentiation, x, tx::NTuple
447+
)
448+
tg = DI.hvp(f, prep, backend, x, tx)
449+
grad = DI.gradient(f, prep.gradient_prep, backend, x)
450+
return grad, tg
451+
end
452+
453+
function DI.gradient_and_hvp!(
454+
f,
455+
grad,
456+
tg::NTuple,
457+
prep::FastDifferentiationHVPPrep,
458+
backend::AutoFastDifferentiation,
459+
x,
460+
tx::NTuple,
461+
)
462+
DI.hvp!(f, tg, prep, backend, x, tx)
463+
DI.gradient!(f, grad, prep.gradient_prep, backend, x)
464+
return grad, tg
465+
end
466+
442467
## Hessian
443468

444469
struct FastDifferentiationHessianPrep{G,E2,E2!} <: HessianPrep

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,32 @@ function DI.hvp!(
541541
return DI.hvp!(f, tg, prep, SecondOrder(backend, backend), x, tx, contexts...)
542542
end
543543

544+
function DI.gradient_and_hvp(
545+
f::F,
546+
prep::HVPPrep,
547+
backend::AutoForwardDiff,
548+
x,
549+
tx::NTuple,
550+
contexts::Vararg{Context,C},
551+
) where {F,C}
552+
return DI.gradient_and_hvp(f, prep, SecondOrder(backend, backend), x, tx, contexts...)
553+
end
554+
555+
function DI.gradient_and_hvp!(
556+
f::F,
557+
grad,
558+
tg::NTuple,
559+
prep::HVPPrep,
560+
backend::AutoForwardDiff,
561+
x,
562+
tx::NTuple,
563+
contexts::Vararg{Context,C},
564+
) where {F,C}
565+
return DI.gradient_and_hvp!(
566+
f, grad, tg, prep, SecondOrder(backend, backend), x, tx, contexts...
567+
)
568+
end
569+
544570
## Hessian
545571

546572
### Unprepared, only when chunk size and tag are not specified

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ function DI.prepare_hvp(
3232
T = tag_type(f, tagged_outer_backend, x)
3333
xdual = make_dual(T, x, tx)
3434
gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...)
35+
# TODO: get rid of closure?
3536
function inner_gradient(x, unannotated_contexts...)
3637
annotated_contexts = rewrap(unannotated_contexts...)
3738
return DI.gradient(f, gradient_prep, inner(backend), x, annotated_contexts...)
@@ -73,3 +74,34 @@ function DI.hvp!(
7374
)
7475
return tg
7576
end
77+
78+
function DI.gradient_and_hvp(
79+
f::F,
80+
prep::ForwardDiffOverSomethingHVPPrep,
81+
::SecondOrder{<:AutoForwardDiff},
82+
x,
83+
tx::NTuple,
84+
contexts::Vararg{Context,C},
85+
) where {F,C}
86+
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
87+
return DI.value_and_pushforward(
88+
inner_gradient, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
89+
)
90+
end
91+
92+
function DI.gradient_and_hvp!(
93+
f::F,
94+
grad,
95+
tg::NTuple,
96+
prep::ForwardDiffOverSomethingHVPPrep,
97+
::SecondOrder{<:AutoForwardDiff},
98+
x,
99+
tx::NTuple,
100+
contexts::Vararg{Context,C},
101+
) where {F,C}
102+
(; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep
103+
new_grad, _ = DI.value_and_pushforward!(
104+
inner_gradient, tg, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts...
105+
)
106+
return copyto!(grad, new_grad), tg
107+
end

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,32 @@ function DI.hvp!(
309309
return DI.hvp!(f, tg, prep, single_threaded(backend), x, tx, contexts...)
310310
end
311311

312+
function DI.gradient_and_hvp(
313+
f,
314+
prep::HVPPrep,
315+
backend::AutoPolyesterForwardDiff,
316+
x,
317+
tx::NTuple,
318+
contexts::Vararg{Context,C},
319+
) where {C}
320+
return DI.gradient_and_hvp(f, prep, single_threaded(backend), x, tx, contexts...)
321+
end
322+
323+
function DI.gradient_and_hvp!(
324+
f,
325+
grad,
326+
tg::NTuple,
327+
prep::HVPPrep,
328+
backend::AutoPolyesterForwardDiff,
329+
x,
330+
tx::NTuple,
331+
contexts::Vararg{Context,C},
332+
) where {C}
333+
return DI.gradient_and_hvp!(
334+
f, grad, tg, prep, single_threaded(backend), x, tx, contexts...
335+
)
336+
end
337+
312338
## Second derivative
313339

314340
function DI.prepare_second_derivative(

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,22 @@ function DI.hvp!(
317317
return tg
318318
end
319319

320+
function DI.gradient_and_hvp(
321+
f, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple
322+
)
323+
tg = DI.hvp(f, prep, backend, x, tx)
324+
grad = DI.gradient(f, prep.gradient_prep, backend, x)
325+
return grad, tg
326+
end
327+
328+
function DI.gradient_and_hvp!(
329+
f, grad, tg::NTuple, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple
330+
)
331+
DI.hvp!(f, tg, prep, backend, x, tx)
332+
DI.gradient!(f, grad, prep.gradient_prep, backend, x)
333+
return grad, tg
334+
end
335+
320336
## Second derivative
321337

322338
struct SymbolicsOneArgSecondDerivativePrep{D,E1,E1!} <: SecondDerivativePrep

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,29 @@ function DI.hvp!(
171171
return DI.hvp!(f, tg, prep, SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...)
172172
end
173173

174+
function DI.gradient_and_hvp(
175+
f, prep::HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{Constant,C}
176+
) where {C}
177+
return DI.gradient_and_hvp(
178+
f, prep, SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...
179+
)
180+
end
181+
182+
function DI.gradient_and_hvp!(
183+
f,
184+
grad,
185+
tg::NTuple,
186+
prep::HVPPrep,
187+
backend::AutoZygote,
188+
x,
189+
tx::NTuple,
190+
contexts::Vararg{Constant,C},
191+
) where {C}
192+
return DI.gradient_and_hvp!(
193+
f, grad, tg, prep, SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...
194+
)
195+
end
196+
174197
## Hessian
175198

176199
function DI.prepare_hessian(f, ::AutoZygote, x, contexts::Vararg{Constant,C}) where {C}

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ export jacobian!, jacobian
8484

8585
export second_derivative!, second_derivative
8686
export value_derivative_and_second_derivative, value_derivative_and_second_derivative!
87-
export hvp!, hvp
87+
export hvp!, hvp, gradient_and_hvp, gradient_and_hvp!
8888
export hessian!, hessian
8989
export value_gradient_and_hessian, value_gradient_and_hessian!
9090

0 commit comments

Comments
 (0)