Skip to content

Commit 10d63da

Browse files
authored
Fix ForwardDiff tags & improve StaticArrays support (#571)
* Improve StaticArrays support * Fix tags * Add an option to avoid testing when benchmarks fail * Fix
1 parent 9ee81a4 commit 10d63da

11 files changed

Lines changed: 332 additions & 186 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.10"
4+
version = "0.6.11"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 134 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,51 @@
11
## Pushforward
22

3+
### Unprepared (avoid working on `similar(x)`)
4+
5+
# Unprepared, out-of-place pushforward that also returns the primal value.
#
# The dual input is built directly from `x` and `tx` (a tuple of `B` entries,
# presumably the input tangents) instead of going through a preparation object.
# Each context is unwrapped and passed to `f` as an extra positional argument.
#
# Returns the tuple `(y, ty)`: the value extracted from the dual output and its
# `B` partials.
function DI.value_and_pushforward(
    f::F, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{Context,C}
) where {F,B,C}
    Tag = tag_type(f, backend, x)
    dual_input = make_dual(Tag, x, tx)
    dual_output = f(dual_input, map(unwrap, contexts)...)
    # Value first, partials second: same evaluation order as separate statements.
    return myvalue(Tag, dual_output), mypartials(Tag, Val(B), dual_output)
end
15+
16+
# Unprepared pushforward that returns the primal value and writes the partials
# into the caller-provided tuple `ty`.
#
# The dual input is built directly from `x` and `tx`, without a preparation
# object. Each context is unwrapped and passed to `f` as an extra positional
# argument.
#
# Returns the tuple `(y, ty)`, where `ty` is the same object that was passed in.
function DI.value_and_pushforward!(
    f::F, ty::NTuple, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{Context,C}
) where {F,C}
    Tag = tag_type(f, backend, x)
    dual_input = make_dual(Tag, x, tx)
    dual_output = f(dual_input, map(unwrap, contexts)...)
    primal = myvalue(Tag, dual_output)
    # Fill `ty` from the partials of the dual output.
    mypartials!(Tag, ty, dual_output)
    return primal, ty
end
26+
27+
# Unprepared, out-of-place pushforward returning only the partials.
#
# The dual input is built directly from `x` and `tx` (a tuple of `B` entries,
# presumably the input tangents), without a preparation object. Each context is
# unwrapped and passed to `f` as an extra positional argument.
#
# Returns `ty`, the `B` partials extracted from the dual output.
function DI.pushforward(
    f::F, backend::AutoForwardDiff, x, tx::NTuple{B}, contexts::Vararg{Context,C}
) where {F,B,C}
    Tag = tag_type(f, backend, x)
    dual_input = make_dual(Tag, x, tx)
    dual_output = f(dual_input, map(unwrap, contexts)...)
    return mypartials(Tag, Val(B), dual_output)
end
36+
37+
# Unprepared pushforward that writes the partials into the caller-provided
# tuple `ty`.
#
# The dual input is built directly from `x` and `tx`, without a preparation
# object. Each context is unwrapped and passed to `f` as an extra positional
# argument.
#
# Returns `ty`, the same object that was passed in.
function DI.pushforward!(
    f::F, ty::NTuple, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{Context,C}
) where {F,C}
    Tag = tag_type(f, backend, x)
    dual_input = make_dual(Tag, x, tx)
    dual_output = f(dual_input, map(unwrap, contexts)...)
    # Fill `ty` from the partials of the dual output.
    mypartials!(Tag, ty, dual_output)
    return ty
end
46+
47+
### Prepared
48+
349
# Preparation object for the prepared one-argument pushforward.
#
# `T` does not appear in any field, so it cannot be inferred from the field
# value and has to be given explicitly at construction
# (presumably the ForwardDiff tag type; confirm against `prepare_pushforward`).
struct ForwardDiffOneArgPushforwardPrep{T,X} <: PushforwardPrep
    # Stored temporary of type `X`
    # (presumably a reusable buffer for the dual-number input; confirm).
    xdual_tmp::X
end
@@ -159,12 +205,12 @@ end
159205

160206
## Gradient
161207

162-
### Unprepared, only when chunk size not specified
208+
### Unprepared, only when chunk size and tag are not specified
163209

164210
function DI.value_and_gradient!(
165-
f::F, grad, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
166-
) where {F,C,chunksize}
167-
if isnothing(chunksize)
211+
f::F, grad, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C}
212+
) where {F,C,chunksize,T}
213+
if isnothing(chunksize) && T === Nothing
168214
fc = with_contexts(f, contexts...)
169215
result = DiffResult(zero(eltype(x)), (grad,))
170216
result = gradient!(result, fc, x)
@@ -178,9 +224,9 @@ function DI.value_and_gradient!(
178224
end
179225

180226
function DI.value_and_gradient(
181-
f::F, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
182-
) where {F,C,chunksize}
183-
if isnothing(chunksize)
227+
f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C}
228+
) where {F,C,chunksize,T}
229+
if isnothing(chunksize) && T === Nothing
184230
fc = with_contexts(f, contexts...)
185231
result = GradientResult(x)
186232
result = gradient!(result, fc, x)
@@ -192,9 +238,9 @@ function DI.value_and_gradient(
192238
end
193239

194240
function DI.gradient!(
195-
f::F, grad, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
196-
) where {F,C,chunksize}
197-
if isnothing(chunksize)
241+
f::F, grad, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C}
242+
) where {F,C,chunksize,T}
243+
if isnothing(chunksize) && T === Nothing
198244
fc = with_contexts(f, contexts...)
199245
return gradient!(grad, fc, x)
200246
else
@@ -204,9 +250,9 @@ function DI.gradient!(
204250
end
205251

206252
function DI.gradient(
207-
f::F, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
208-
) where {F,C,chunksize}
209-
if isnothing(chunksize)
253+
f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C}
254+
) where {F,C,chunksize,T}
255+
if isnothing(chunksize) && T === Nothing
210256
fc = with_contexts(f, contexts...)
211257
return gradient(fc, x)
212258
else
@@ -225,7 +271,10 @@ function DI.prepare_gradient(
225271
f::F, backend::AutoForwardDiff, x::AbstractArray, contexts::Vararg{Context,C}
226272
) where {F,C}
227273
fc = with_contexts(f, contexts...)
228-
return ForwardDiffGradientPrep(GradientConfig(fc, x, choose_chunk(backend, x)))
274+
chunk = choose_chunk(backend, x)
275+
tag = get_tag(fc, backend, x)
276+
config = GradientConfig(fc, x, chunk, tag)
277+
return ForwardDiffGradientPrep(config)
229278
end
230279

231280
function DI.value_and_gradient!(
@@ -274,12 +323,12 @@ end
274323

275324
## Jacobian
276325

277-
### Unprepared, only when chunk size not specified
326+
### Unprepared, only when chunk size and tag are not specified
278327

279328
function DI.value_and_jacobian!(
280-
f::F, jac, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
281-
) where {F,C,chunksize}
282-
if isnothing(chunksize)
329+
f::F, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C}
330+
) where {F,C,chunksize,T}
331+
if isnothing(chunksize) && T === Nothing
283332
fc = with_contexts(f, contexts...)
284333
y = fc(x)
285334
result = DiffResult(y, (jac,))
@@ -294,9 +343,9 @@ function DI.value_and_jacobian!(
294343
end
295344

296345
function DI.value_and_jacobian(
297-
f::F, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
298-
) where {F,C,chunksize}
299-
if isnothing(chunksize)
346+
f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C}
347+
) where {F,C,chunksize,T}
348+
if isnothing(chunksize) && T === Nothing
300349
fc = with_contexts(f, contexts...)
301350
return fc(x), jacobian(fc, x)
302351
else
@@ -306,9 +355,9 @@ function DI.value_and_jacobian(
306355
end
307356

308357
function DI.jacobian!(
309-
f::F, jac, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
310-
) where {F,C,chunksize}
311-
if isnothing(chunksize)
358+
f::F, jac, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C}
359+
) where {F,C,chunksize,T}
360+
if isnothing(chunksize) && T === Nothing
312361
fc = with_contexts(f, contexts...)
313362
return jacobian!(jac, fc, x)
314363
else
@@ -318,9 +367,9 @@ function DI.jacobian!(
318367
end
319368

320369
function DI.jacobian(
321-
f::F, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
322-
) where {F,C,chunksize}
323-
if isnothing(chunksize)
370+
f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C}
371+
) where {F,C,chunksize,T}
372+
if isnothing(chunksize) && T === Nothing
324373
fc = with_contexts(f, contexts...)
325374
return jacobian(fc, x)
326375
else
@@ -339,7 +388,10 @@ function DI.prepare_jacobian(
339388
f::F, backend::AutoForwardDiff, x, contexts::Vararg{Context,C}
340389
) where {F,C}
341390
fc = with_contexts(f, contexts...)
342-
return ForwardDiffOneArgJacobianPrep(JacobianConfig(fc, x, choose_chunk(backend, x)))
391+
chunk = choose_chunk(backend, x)
392+
tag = get_tag(fc, backend, x)
393+
config = JacobianConfig(fc, x, chunk, tag)
394+
return ForwardDiffOneArgJacobianPrep(config)
343395
end
344396

345397
function DI.value_and_jacobian!(
@@ -491,62 +543,80 @@ end
491543

492544
## Hessian
493545

494-
### Unprepared
546+
### Unprepared, only when chunk size and tag are not specified
495547

496548
function DI.hessian!(
497-
f::F, hess, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
498-
) where {F,C}
499-
fc = with_contexts(f, contexts...)
500-
return hessian!(hess, fc, x)
549+
f::F, hess, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C}
550+
) where {F,C,chunksize,T}
551+
if isnothing(chunksize) && T === Nothing
552+
fc = with_contexts(f, contexts...)
553+
return hessian!(hess, fc, x)
554+
else
555+
prep = DI.prepare_hessian(f, backend, x, contexts...)
556+
return DI.hessian!(f, hess, prep, backend, x, contexts...)
557+
end
501558
end
502559

503-
function DI.hessian(f::F, ::AutoForwardDiff, x, contexts::Vararg{Context,C}) where {F,C}
504-
fc = with_contexts(f, contexts...)
505-
return hessian(fc, x)
560+
function DI.hessian(
561+
f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C}
562+
) where {F,C,chunksize,T}
563+
if isnothing(chunksize) && T === Nothing
564+
fc = with_contexts(f, contexts...)
565+
return hessian(fc, x)
566+
else
567+
prep = DI.prepare_hessian(f, backend, x, contexts...)
568+
return DI.hessian(f, prep, backend, x, contexts...)
569+
end
506570
end
507571

508572
function DI.value_gradient_and_hessian!(
509-
f::F, grad, hess, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
510-
) where {F,C}
511-
fc = with_contexts(f, contexts...)
512-
result = DiffResult(one(eltype(x)), (grad, hess))
513-
result = hessian!(result, fc, x)
514-
y = DR.value(result)
515-
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
516-
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
517-
return (y, grad, hess)
573+
f::F, grad, hess, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C}
574+
) where {F,C,chunksize,T}
575+
if isnothing(chunksize) && T === Nothing
576+
fc = with_contexts(f, contexts...)
577+
result = DiffResult(one(eltype(x)), (grad, hess))
578+
result = hessian!(result, fc, x)
579+
y = DR.value(result)
580+
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
581+
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
582+
return (y, grad, hess)
583+
else
584+
prep = DI.prepare_hessian(f, backend, x, contexts...)
585+
return DI.value_gradient_and_hessian!(f, grad, hess, prep, backend, x, contexts...)
586+
end
518587
end
519588

520589
function DI.value_gradient_and_hessian(
521-
f::F, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
522-
) where {F,C}
523-
fc = with_contexts(f, contexts...)
524-
result = HessianResult(x)
525-
result = hessian!(result, fc, x)
526-
return (DR.value(result), DR.gradient(result), DR.hessian(result))
590+
f::F, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{Context,C}
591+
) where {F,C,chunksize,T}
592+
if isnothing(chunksize) && T === Nothing
593+
fc = with_contexts(f, contexts...)
594+
result = HessianResult(x)
595+
result = hessian!(result, fc, x)
596+
return (DR.value(result), DR.gradient(result), DR.hessian(result))
597+
else
598+
prep = DI.prepare_hessian(f, backend, x, contexts...)
599+
return DI.value_gradient_and_hessian(f, prep, backend, x, contexts...)
600+
end
527601
end
528602

529603
### Prepared
530604

531-
struct ForwardDiffHessianPrep{C1,C2,C3} <: HessianPrep
605+
struct ForwardDiffHessianPrep{C1,C2} <: HessianPrep
532606
array_config::C1
533-
manual_result_config::C2
534-
auto_result_config::C3
607+
result_config::C2
535608
end
536609

537610
function DI.prepare_hessian(
538611
f::F, backend::AutoForwardDiff, x, contexts::Vararg{Context,C}
539612
) where {F,C}
540613
fc = with_contexts(f, contexts...)
541-
manual_result = MutableDiffResult(
542-
one(eltype(x)), (similar(x), similar(x, length(x), length(x)))
543-
)
544-
auto_result = HessianResult(x)
545614
chunk = choose_chunk(backend, x)
546-
array_config = HessianConfig(fc, x, chunk)
547-
manual_result_config = HessianConfig(fc, manual_result, x, chunk)
548-
auto_result_config = HessianConfig(fc, auto_result, x, chunk)
549-
return ForwardDiffHessianPrep(array_config, manual_result_config, auto_result_config)
615+
tag = get_tag(fc, backend, x)
616+
result = HessianResult(x)
617+
array_config = HessianConfig(fc, x, chunk, tag)
618+
result_config = HessianConfig(fc, result, x, chunk, tag)
619+
return ForwardDiffHessianPrep(array_config, result_config)
550620
end
551621

552622
function DI.hessian!(
@@ -579,7 +649,7 @@ function DI.value_gradient_and_hessian!(
579649
) where {F,C}
580650
fc = with_contexts(f, contexts...)
581651
result = DiffResult(one(eltype(x)), (grad, hess))
582-
result = hessian!(result, fc, x, prep.manual_result_config)
652+
result = hessian!(result, fc, x, prep.result_config)
583653
y = DR.value(result)
584654
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
585655
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
@@ -591,6 +661,6 @@ function DI.value_gradient_and_hessian(
591661
) where {F,C}
592662
fc = with_contexts(f, contexts...)
593663
result = HessianResult(x)
594-
result = hessian!(result, fc, x, prep.auto_result_config)
664+
result = hessian!(result, fc, x, prep.result_config)
595665
return (DR.value(result), DR.gradient(result), DR.hessian(result))
596666
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,11 @@ end
77
88
Return a new `AutoForwardDiff` backend with a fixed tag linked to `f`, so that we know how to prepare the inner gradient of the HVP without depending on what that gradient closure looks like.
99
"""
10-
function tag_backend_hvp(f::F, ::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksize}
11-
return AutoForwardDiff(;
12-
chunksize=chunksize,
13-
tag=ForwardDiff.Tag(ForwardDiffOverSomethingHVPWrapper(f), eltype(x)),
14-
)
15-
end
10+
tag_backend_hvp(f, backend::AutoForwardDiff, x) = backend
1611

17-
function tag_backend_hvp(f, backend::AutoForwardDiff, x)
18-
return backend
12+
function tag_backend_hvp(f::F, ::AutoForwardDiff{chunksize,Nothing}, x) where {F,chunksize}
13+
tag = ForwardDiff.Tag(ForwardDiffOverSomethingHVPWrapper(f), eltype(x))
14+
return AutoForwardDiff{chunksize,typeof(tag)}(tag)
1915
end
2016

2117
struct ForwardDiffOverSomethingHVPPrep{B<:AutoForwardDiff,G,E<:PushforwardPrep} <: HVPPrep

0 commit comments

Comments
 (0)