Skip to content

Commit 267023a

Browse files
authored
Multi-argument support: basic infrastructure (#461)
1 parent 41bade9 commit 267023a

31 files changed

Lines changed: 1825 additions & 731 deletions

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Base: Fix1, Fix2
55
using Compat
66
import DifferentiationInterface as DI
77
using DifferentiationInterface:
8+
Context,
89
DerivativeExtras,
910
GradientExtras,
1011
HessianExtras,
@@ -16,7 +17,8 @@ using DifferentiationInterface:
1617
SecondOrder,
1718
Tangents,
1819
inner,
19-
outer
20+
outer,
21+
unwrap
2022
using ForwardDiff.DiffResults:
2123
DiffResults, DiffResult, GradientResult, HessianResult, MutableDiffResult
2224
using ForwardDiff:

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,36 @@ struct ForwardDiffOneArgPushforwardExtras{T,X} <: PushforwardExtras
44
xdual_tmp::X
55
end
66

7-
function DI.prepare_pushforward(f::F, backend::AutoForwardDiff, x, tx::Tangents) where {F}
7+
function DI.prepare_pushforward(
8+
f::F, backend::AutoForwardDiff, x, tx::Tangents, contexts::Vararg{Context,C}
9+
) where {F,C}
810
T = tag_type(f, backend, x)
911
xdual_tmp = make_dual_similar(T, x, tx)
1012
return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp)
1113
end
1214

1315
function compute_ydual_onearg(
14-
f::F, extras::ForwardDiffOneArgPushforwardExtras{T}, x::Number, tx::Tangents
15-
) where {F,T}
16+
f::F,
17+
extras::ForwardDiffOneArgPushforwardExtras{T},
18+
x::Number,
19+
tx::Tangents,
20+
contexts::Vararg{Context,C},
21+
) where {F,T,C}
1622
xdual_tmp = make_dual(T, x, tx)
17-
ydual = f(xdual_tmp)
23+
ydual = f(xdual_tmp, map(unwrap, contexts)...)
1824
return ydual
1925
end
2026

2127
function compute_ydual_onearg(
22-
f::F, extras::ForwardDiffOneArgPushforwardExtras{T}, x, tx::Tangents
23-
) where {F,T}
28+
f::F,
29+
extras::ForwardDiffOneArgPushforwardExtras{T},
30+
x,
31+
tx::Tangents,
32+
contexts::Vararg{Context,C},
33+
) where {F,T,C}
2434
@compat (; xdual_tmp) = extras
2535
make_dual!(T, xdual_tmp, x, tx)
26-
ydual = f(xdual_tmp)
36+
ydual = f(xdual_tmp, map(unwrap, contexts)...)
2737
return ydual
2838
end
2939

@@ -33,8 +43,9 @@ function DI.value_and_pushforward(
3343
::AutoForwardDiff,
3444
x,
3545
tx::Tangents{B},
36-
) where {F,T,B}
37-
ydual = compute_ydual_onearg(f, extras, x, tx)
46+
contexts::Vararg{Context,C},
47+
) where {F,T,B,C}
48+
ydual = compute_ydual_onearg(f, extras, x, tx, contexts...)
3849
y = myvalue(T, ydual)
3950
ty = mypartials(T, Val(B), ydual)
4051
return y, ty
@@ -47,8 +58,9 @@ function DI.value_and_pushforward!(
4758
::AutoForwardDiff,
4859
x,
4960
tx::Tangents,
50-
) where {F,T}
51-
ydual = compute_ydual_onearg(f, extras, x, tx)
61+
contexts::Vararg{Context,C},
62+
) where {F,T,C}
63+
ydual = compute_ydual_onearg(f, extras, x, tx, contexts...)
5264
y = myvalue(T, ydual)
5365
mypartials!(T, ty, ydual)
5466
return y, ty
@@ -60,8 +72,9 @@ function DI.pushforward(
6072
::AutoForwardDiff,
6173
x,
6274
tx::Tangents{B},
63-
) where {F,T,B}
64-
ydual = compute_ydual_onearg(f, extras, x, tx)
75+
contexts::Vararg{Context,C},
76+
) where {F,T,B,C}
77+
ydual = compute_ydual_onearg(f, extras, x, tx, contexts...)
6578
ty = mypartials(T, Val(B), ydual)
6679
return ty
6780
end
@@ -73,8 +86,9 @@ function DI.pushforward!(
7386
::AutoForwardDiff,
7487
x,
7588
tx::Tangents,
76-
) where {F,T}
77-
ydual = compute_ydual_onearg(f, extras, x, tx)
89+
contexts::Vararg{Context,C},
90+
) where {F,T,C}
91+
ydual = compute_ydual_onearg(f, extras, x, tx, contexts...)
7892
mypartials!(T, ty, ydual)
7993
return ty
8094
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,45 +26,60 @@ struct ForwardDiffOverSomethingHVPExtras{B<:AutoForwardDiff,G,E<:PushforwardExtr
2626
end
2727

2828
function DI.prepare_hvp(
29-
f::F, backend::SecondOrder{<:AutoForwardDiff}, x, tx::Tangents
30-
) where {F}
29+
f::F,
30+
backend::SecondOrder{<:AutoForwardDiff},
31+
x,
32+
tx::Tangents,
33+
contexts::Vararg{Context,C},
34+
) where {F,C}
3135
tagged_outer_backend = tag_backend_hvp(f, outer(backend), x)
3236
T = tag_type(f, tagged_outer_backend, x)
3337
xdual = make_dual(T, x, tx)
34-
gradient_extras = DI.prepare_gradient(f, inner(backend), xdual)
35-
inner_gradient(x) = DI.gradient(f, gradient_extras, inner(backend), x)
38+
gradient_extras = DI.prepare_gradient(f, inner(backend), xdual, contexts...)
39+
function inner_gradient(x, unannotated_contexts...)
40+
annotated_contexts = map.(typeof.(contexts), unannotated_contexts)
41+
return DI.gradient(f, gradient_extras, inner(backend), x, unannotated_contexts...)
42+
end
3643
outer_pushforward_extras = DI.prepare_pushforward(
37-
inner_gradient, tagged_outer_backend, x, tx
44+
inner_gradient, tagged_outer_backend, x, tx, contexts...
3845
)
3946
return ForwardDiffOverSomethingHVPExtras(
4047
tagged_outer_backend, inner_gradient, outer_pushforward_extras
4148
)
4249
end
4350

4451
function DI.hvp(
45-
f,
52+
f::F,
4653
extras::ForwardDiffOverSomethingHVPExtras,
4754
::SecondOrder{<:AutoForwardDiff},
4855
x,
4956
tx::Tangents,
50-
)
57+
contexts::Vararg{Context,C},
58+
) where {F,C}
5159
@compat (; tagged_outer_backend, inner_gradient, outer_pushforward_extras) = extras
5260
return DI.pushforward(
53-
inner_gradient, outer_pushforward_extras, tagged_outer_backend, x, tx
61+
inner_gradient, outer_pushforward_extras, tagged_outer_backend, x, tx, contexts...
5462
)
5563
end
5664

5765
function DI.hvp!(
58-
f,
66+
f::F,
5967
tg::Tangents,
6068
extras::ForwardDiffOverSomethingHVPExtras,
6169
::SecondOrder{<:AutoForwardDiff},
6270
x,
6371
tx::Tangents,
64-
)
72+
contexts::Vararg{Context,C},
73+
) where {F,C}
6574
@compat (; tagged_outer_backend, inner_gradient, outer_pushforward_extras) = extras
6675
DI.pushforward!(
67-
inner_gradient, tg, outer_pushforward_extras, tagged_outer_backend, x, tx
76+
inner_gradient,
77+
tg,
78+
outer_pushforward_extras,
79+
tagged_outer_backend,
80+
x,
81+
tx,
82+
contexts...,
6883
)
6984
return tg
7085
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ struct ForwardDiffTwoArgPushforwardExtras{T,X,Y} <: PushforwardExtras
66
end
77

88
function DI.prepare_pushforward(
9-
f!::F, y, backend::AutoForwardDiff, x, tx::Tangents
10-
) where {F}
9+
f!::F, y, backend::AutoForwardDiff, x, tx::Tangents, contexts::Vararg{Context,C}
10+
) where {F,C}
1111
T = tag_type(f!, backend, x)
1212
xdual_tmp = make_dual_similar(T, x, tx)
1313
ydual_tmp = make_dual_similar(T, y, tx) # dx only for batch size
@@ -17,20 +17,30 @@ function DI.prepare_pushforward(
1717
end
1818

1919
function compute_ydual_twoarg(
20-
f!::F, y, extras::ForwardDiffTwoArgPushforwardExtras{T}, x::Number, tx::Tangents
21-
) where {F,T}
20+
f!::F,
21+
y,
22+
extras::ForwardDiffTwoArgPushforwardExtras{T},
23+
x::Number,
24+
tx::Tangents,
25+
contexts::Vararg{Context,C},
26+
) where {F,T,C}
2227
@compat (; ydual_tmp) = extras
2328
xdual_tmp = make_dual(T, x, tx)
24-
f!(ydual_tmp, xdual_tmp)
29+
f!(ydual_tmp, xdual_tmp, map(unwrap, contexts)...)
2530
return ydual_tmp
2631
end
2732

2833
function compute_ydual_twoarg(
29-
f!::F, y, extras::ForwardDiffTwoArgPushforwardExtras{T}, x, tx::Tangents
30-
) where {F,T}
34+
f!::F,
35+
y,
36+
extras::ForwardDiffTwoArgPushforwardExtras{T},
37+
x,
38+
tx::Tangents,
39+
contexts::Vararg{Context,C},
40+
) where {F,T,C}
3141
@compat (; xdual_tmp, ydual_tmp) = extras
3242
make_dual!(T, xdual_tmp, x, tx)
33-
f!(ydual_tmp, xdual_tmp)
43+
f!(ydual_tmp, xdual_tmp, map(unwrap, contexts)...)
3444
return ydual_tmp
3545
end
3646

@@ -41,8 +51,9 @@ function DI.value_and_pushforward(
4151
::AutoForwardDiff,
4252
x,
4353
tx::Tangents{B},
44-
) where {F,T,B}
45-
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx)
54+
contexts::Vararg{Context,C},
55+
) where {F,T,B,C}
56+
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx, contexts...)
4657
myvalue!(T, y, ydual_tmp)
4758
ty = mypartials(T, Val(B), ydual_tmp)
4859
return y, ty
@@ -56,8 +67,9 @@ function DI.value_and_pushforward!(
5667
::AutoForwardDiff,
5768
x,
5869
tx::Tangents,
59-
) where {F,T}
60-
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx)
70+
contexts::Vararg{Context,C},
71+
) where {F,T,C}
72+
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx, contexts...)
6173
myvalue!(T, y, ydual_tmp)
6274
mypartials!(T, ty, ydual_tmp)
6375
return y, ty
@@ -70,8 +82,9 @@ function DI.pushforward(
7082
::AutoForwardDiff,
7183
x,
7284
tx::Tangents{B},
73-
) where {F,T,B}
74-
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx)
85+
contexts::Vararg{Context,C},
86+
) where {F,T,B,C}
87+
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx, contexts...)
7588
ty = mypartials(T, Val(B), ydual_tmp)
7689
return ty
7790
end
@@ -84,8 +97,9 @@ function DI.pushforward!(
8497
::AutoForwardDiff,
8598
x,
8699
tx::Tangents,
87-
) where {F,T}
88-
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx)
100+
contexts::Vararg{Context,C},
101+
) where {F,T,C}
102+
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx, contexts...)
89103
mypartials!(T, ty, ydual_tmp)
90104
return ty
91105
end

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ using DifferentiationInterface:
2626
maybe_outer,
2727
multibasis,
2828
pick_batchsize,
29-
pushforward_performance
29+
pushforward_performance,
30+
unwrap
3031
import DifferentiationInterface as DI
3132
using SparseMatrixColorings:
3233
AbstractColoringResult,

0 commit comments

Comments
 (0)