Skip to content

Commit 86f1e02

Browse files
authored
Batched ForwardDiff pushforward (#328)
* Batched ForwardDiff pushforward * Ambiguity * Typo * Add batched mode to AutoFromPrimitive * Remove type test * Fix type instability
1 parent 4390675 commit 86f1e02

7 files changed

Lines changed: 304 additions & 136 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
module DifferentiationInterfaceForwardDiffExt
22

33
using ADTypes: AbstractADType, AutoForwardDiff
4-
using Base: Fix1
4+
using Base: Fix1, Fix2
55
using Compat
66
import DifferentiationInterface as DI
77
using DifferentiationInterface:
8+
Batch,
89
DerivativeExtras,
910
GradientExtras,
1011
HessianExtras,
@@ -33,6 +34,8 @@ using ForwardDiff:
3334
hessian!,
3435
jacobian,
3536
jacobian!,
37+
npartials,
38+
partials,
3639
value
3740
using LinearAlgebra: dot, mul!
3841

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 78 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,15 @@ end
66

77
# Build the reusable extras for a one-argument pushforward with `AutoForwardDiff`.
# - `T` is the tag type chosen by `tag_type(f, backend, x)`.
# - `xdual_tmp` is created by `make_dual_similar`; the added (`9+`) line passes `dx`
#   as a third argument, replacing the two-argument call on the removed (`9-`) line.
# Returns a `ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}` wrapping `xdual_tmp`.
function DI.prepare_pushforward(f::F, backend::AutoForwardDiff, x, dx) where {F}
88
T = tag_type(f, backend, x)
9-
xdual_tmp = make_dual_similar(T, x)
9+
xdual_tmp = make_dual_similar(T, x, dx)
10+
return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp)
11+
end
12+
13+
# Batched variant of the one-argument pushforward preparation: `dx` is a `Batch`.
# The body mirrors `DI.prepare_pushforward`: pick the tag type `T`, build `xdual_tmp`
# with `make_dual_similar(T, x, dx)`, and wrap it in the same
# `ForwardDiffOneArgPushforwardExtras` type.
# Because `dx` is a `Batch`, the `make_dual_similar` call is expected to hit the
# `dx::Batch{B}` method, which requests `Dual{T,eltype(x),B}` elements.
function DI.prepare_pushforward_batched(
14+
f::F, backend::AutoForwardDiff, x, dx::Batch
15+
) where {F}
16+
T = tag_type(f, backend, x)
17+
xdual_tmp = make_dual_similar(T, x, dx)
1018
return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp)
1119
end
1220

@@ -61,56 +69,25 @@ function DI.pushforward!(
6169
return dy
6270
end
6371

64-
## Second derivative
65-
66-
function DI.prepare_second_derivative(f::F, backend::AutoForwardDiff, x) where {F}
67-
return NoSecondDerivativeExtras()
68-
end
69-
70-
function DI.second_derivative(
71-
f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
72-
) where {F}
73-
T = tag_type(f, backend, x)
74-
xdual = make_dual(T, x, one(x))
75-
T2 = tag_type(f, backend, xdual)
76-
ydual = f(make_dual(T2, xdual, one(xdual)))
77-
return myderivative(T, myderivative(T2, ydual))
78-
end
79-
80-
function DI.second_derivative!(
81-
f::F, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
82-
) where {F}
83-
T = tag_type(f, backend, x)
84-
xdual = make_dual(T, x, one(x))
85-
T2 = tag_type(f, backend, xdual)
86-
ydual = f(make_dual(T2, xdual, one(xdual)))
87-
return myderivative!(T, der2, myderivative(T2, ydual))
88-
end
89-
90-
function DI.value_derivative_and_second_derivative(
91-
f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
92-
) where {F}
93-
T = tag_type(f, backend, x)
94-
xdual = make_dual(T, x, one(x))
95-
T2 = tag_type(f, backend, xdual)
96-
ydual = f(make_dual(T2, xdual, one(xdual)))
97-
y = myvalue(T, myvalue(T2, ydual))
98-
der = myderivative(T, myvalue(T2, ydual))
99-
der2 = myderivative(T, myderivative(T2, ydual))
100-
return y, der, der2
72+
# Out-of-place batched pushforward for a one-argument function.
# `B` is taken from the type of `dx::Batch{B}` and `T` from the extras type parameter.
# `compute_ydual_onearg` is defined outside this hunk; presumably it evaluates `f` on
# duals seeded from `x` and `dx` using the buffer in `extras` — TODO confirm.
# `mypartials(T, Val(B), ydual)` then extracts partials 1..B from `ydual` and returns
# them wrapped in a new `Batch`, which is the return value here.
function DI.pushforward_batched(
73+
f::F, ::AutoForwardDiff, x, dx::Batch{B}, extras::ForwardDiffOneArgPushforwardExtras{T}
74+
) where {F,T,B}
75+
ydual = compute_ydual_onearg(f, x, dx, extras)
76+
new_dy = mypartials(T, Val(B), ydual)
77+
return new_dy
10178
end
10279

103-
function DI.value_derivative_and_second_derivative!(
104-
f::F, der, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
105-
) where {F}
106-
T = tag_type(f, backend, x)
107-
xdual = make_dual(T, x, one(x))
108-
T2 = tag_type(f, backend, xdual)
109-
ydual = f(make_dual(T2, xdual, one(xdual)))
110-
y = myvalue(T, myvalue(T2, ydual))
111-
myderivative!(T, der, myvalue(T2, ydual))
112-
myderivative!(T, der2, myderivative(T2, ydual))
113-
return y, der, der2
80+
# In-place batched pushforward for a one-argument function.
# `dy` and `dx` must carry the same batch size `B` (both are typed `Batch{B}`).
# `compute_ydual_onearg` is defined outside this hunk; presumably it returns the dual
# output of `f` — TODO confirm.
# `mypartials!(T, dy, ydual)` overwrites each element of `dy` with the matching partial
# of `ydual`; the mutated `dy` is returned.
function DI.pushforward_batched!(
81+
f::F,
82+
dy::Batch{B},
83+
::AutoForwardDiff,
84+
x,
85+
dx::Batch{B},
86+
extras::ForwardDiffOneArgPushforwardExtras{T},
87+
) where {F,T,B}
88+
ydual = compute_ydual_onearg(f, x, dx, extras)
89+
mypartials!(T, dy, ydual)
90+
return dy
11491
end
11592

11693
## Gradient
@@ -188,6 +165,58 @@ function DI.jacobian(
188165
return jacobian(f, x, extras.config)
189166
end
190167

168+
## Second derivative
169+
170+
# Second-derivative preparation for `AutoForwardDiff`: nothing is precomputed, so this
# always returns an empty `NoSecondDerivativeExtras()` regardless of `f`, `backend`, `x`.
function DI.prepare_second_derivative(f::F, backend::AutoForwardDiff, x) where {F}
171+
return NoSecondDerivativeExtras()
172+
end
173+
174+
# Second derivative of `f` at `x` via two nested layers of duals.
# 1. `xdual` seeds `x` with `one(x)` under tag `T`.
# 2. A second tag `T2` is computed from `xdual`, and `xdual` is seeded again with
#    `one(xdual)` under `T2`.
# 3. `f` is called once on that doubly-seeded value.
# 4. The `T2` derivative is extracted first, then the `T` derivative of that result,
#    and this is returned.
function DI.second_derivative(
175+
f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
176+
) where {F}
177+
T = tag_type(f, backend, x)
178+
xdual = make_dual(T, x, one(x))
179+
T2 = tag_type(f, backend, xdual)
180+
ydual = f(make_dual(T2, xdual, one(xdual)))
181+
return myderivative(T, myderivative(T2, ydual))
182+
end
183+
184+
# In-place second derivative: same nested-dual evaluation as `DI.second_derivative`
# (seed under `T`, seed again under `T2`, call `f` once), but the final extraction under
# tag `T` is written into `der2` through `myderivative!`.
# The return value is whatever `myderivative!` returns (a `map!` into `der2`).
function DI.second_derivative!(
185+
f::F, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
186+
) where {F}
187+
T = tag_type(f, backend, x)
188+
xdual = make_dual(T, x, one(x))
189+
T2 = tag_type(f, backend, xdual)
190+
ydual = f(make_dual(T2, xdual, one(xdual)))
191+
return myderivative!(T, der2, myderivative(T2, ydual))
192+
end
193+
194+
# Compute value, first derivative and second derivative from a single call to `f` on a
# doubly-seeded input (tag `T` first, then tag `T2` on top).
# From the resulting `ydual`:
# - `y`    = value under `T2`, then value under `T`
# - `der`  = value under `T2`, then derivative under `T`
# - `der2` = derivative under `T2`, then derivative under `T`
# Returns the tuple `(y, der, der2)`.
function DI.value_derivative_and_second_derivative(
195+
f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
196+
) where {F}
197+
T = tag_type(f, backend, x)
198+
xdual = make_dual(T, x, one(x))
199+
T2 = tag_type(f, backend, xdual)
200+
ydual = f(make_dual(T2, xdual, one(xdual)))
201+
y = myvalue(T, myvalue(T2, ydual))
202+
der = myderivative(T, myvalue(T2, ydual))
203+
der2 = myderivative(T, myderivative(T2, ydual))
204+
return y, der, der2
205+
end
206+
207+
# In-place variant of `DI.value_derivative_and_second_derivative`.
# `f` is called once on the doubly-seeded input (tags `T`, then `T2`).
# `y` is computed out of place with `myvalue`; `der` and `der2` are overwritten through
# `myderivative!` (first and second derivative respectively).
# Returns the tuple `(y, der, der2)` where `der` and `der2` are the mutated arguments.
function DI.value_derivative_and_second_derivative!(
208+
f::F, der, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
209+
) where {F}
210+
T = tag_type(f, backend, x)
211+
xdual = make_dual(T, x, one(x))
212+
T2 = tag_type(f, backend, xdual)
213+
ydual = f(make_dual(T2, xdual, one(xdual)))
214+
y = myvalue(T, myvalue(T2, ydual))
215+
myderivative!(T, der, myvalue(T2, ydual))
216+
myderivative!(T, der2, myderivative(T2, ydual))
217+
return y, der, der2
218+
end
219+
191220
## Hessian
192221

193222
struct ForwardDiffHessianExtras{C1,C2,C3} <: HessianExtras

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,19 @@ end
77

88
# Build the reusable extras for a two-argument (`f!(y, x)`-style) pushforward.
# The tag type `T` comes from `tag_type(f!, backend, x)`.
# Both temporaries are created by `make_dual_similar`; the added (`10+`/`11+`) lines pass
# `dx` as a third argument, replacing the two-argument calls on the removed lines.
# Returns a `ForwardDiffTwoArgPushforwardExtras` parameterised by `T` and by the types of
# the two temporaries.
function DI.prepare_pushforward(f!::F, y, backend::AutoForwardDiff, x, dx) where {F}
99
T = tag_type(f!, backend, x)
10-
xdual_tmp = make_dual_similar(T, x)
11-
ydual_tmp = make_dual_similar(T, y)
10+
xdual_tmp = make_dual_similar(T, x, dx)
11+
ydual_tmp = make_dual_similar(T, y, dx) # dx only for batch size
12+
return ForwardDiffTwoArgPushforwardExtras{T,typeof(xdual_tmp),typeof(ydual_tmp)}(
13+
xdual_tmp, ydual_tmp
14+
)
15+
end
16+
17+
function DI.prepare_pushforward_batched(
18+
f!::F, y, backend::AutoForwardDiff, x, dx::Batch
19+
) where {F}
20+
T = tag_type(f!, backend, x)
21+
xdual_tmp = make_dual_similar(T, x, dx)
22+
ydual_tmp = make_dual_similar(T, y, dx) # dx only for batch size
1223
return ForwardDiffTwoArgPushforwardExtras{T,typeof(xdual_tmp),typeof(ydual_tmp)}(
1324
xdual_tmp, ydual_tmp
1425
)
@@ -66,6 +77,33 @@ function DI.pushforward!(
6677
return dy
6778
end
6879

80+
# Out-of-place batched pushforward for a two-argument function `f!` with output `y`.
# `B` is read from `dx::Batch{B}` and `T` from the extras type parameter.
# `compute_ydual_twoarg` is defined outside this hunk; presumably it runs `f!` on dual
# buffers from `extras` and returns the dual output — TODO confirm.
# `mypartials(T, Val(B), ydual_tmp)` builds a new `Batch` holding partials 1..B, which
# is returned as `dy`.
function DI.pushforward_batched(
81+
f!::F,
82+
y,
83+
::AutoForwardDiff,
84+
x,
85+
dx::Batch{B},
86+
extras::ForwardDiffTwoArgPushforwardExtras{T},
87+
) where {F,T,B}
88+
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
89+
dy = mypartials(T, Val(B), ydual_tmp)
90+
return dy
91+
end
92+
93+
# In-place batched pushforward for a two-argument function `f!` with output `y`.
# `dy` and `dx` share the batch size `B` (both are typed `Batch{B}`).
# `compute_ydual_twoarg` is defined outside this hunk; presumably it returns the dual
# output buffer after calling `f!` — TODO confirm.
# `mypartials!(T, dy, ydual_tmp)` overwrites each element of `dy`; `dy` is returned.
function DI.pushforward_batched!(
94+
f!::F,
95+
y,
96+
dy::Batch{B},
97+
::AutoForwardDiff,
98+
x,
99+
dx::Batch{B},
100+
extras::ForwardDiffTwoArgPushforwardExtras{T},
101+
) where {F,T,B}
102+
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
103+
mypartials!(T, dy, ydual_tmp)
104+
return dy
105+
end
106+
69107
## Derivative
70108

71109
struct ForwardDiffTwoArgDerivativeExtras{C} <: DerivativeExtras
Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,47 @@
11
# Select the ForwardDiff chunk for input `x`.
# - Backend with chunk parameter `nothing`: defer to `Chunk(x)`.
# - Backend with an explicit chunk size `C`: the added (`2+`) line caps it at
#   `length(x)`, replacing the removed (`2-`) line that used `C` unconditionally.
# NOTE(review): `min(length(x), C)` is a runtime value used as a type parameter, so the
# resulting `Chunk` type depends on `length(x)` — confirm this is acceptable for inference.
choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x)
2-
choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{C}()
2+
choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{min(length(x), C)}()
33

44
# Resolve the tag type used for duals.
# - If the backend's second type parameter `T` is set, return it unchanged.
# - If it is `Nothing`, build `Tag{typeof(f),eltype(x)}` from the function and input.
tag_type(f, ::AutoForwardDiff{C,T}, x) where {C,T} = T
55
tag_type(f, ::AutoForwardDiff{C,Nothing}, x) where {C} = Tag{typeof(f),eltype(x)}
66

7-
make_dual_similar(::Type{T}, x::Number) where {T} = Dual{T}(x, x)
8-
make_dual_similar(::Type{T}, x) where {T} = similar(x, Dual{T,eltype(x),1})
7+
# Create a dual-valued temporary shaped like `x`, with tag `T`.
# - Scalar `x` and scalar `dx`: return the actual dual `Dual{T}(x, dx)`.
# - Generic `x`, `dx`: `similar(x, ...)` with one partial per element (`Dual{T,eltype(x),1}`).
# - `dx::Batch{B}`: `similar(x, ...)` with `B` partials per element.
# In the two `similar` methods `dx` is used only for dispatch / the batch size `B`.
# NOTE(review): no method is visible here for a scalar `x` with `dx::Batch`; that
# combination would reach `similar(x, ...)` — confirm it is handled elsewhere.
make_dual_similar(::Type{T}, x::Number, dx::Number) where {T} = Dual{T}(x, dx)
8+
make_dual_similar(::Type{T}, x, dx) where {T} = similar(x, Dual{T,eltype(x),1})
99

10-
make_dual(::Type{T}, x::Number, dx) where {T} = Dual{T}(x, dx)
10+
function make_dual_similar(::Type{T}, x, dx::Batch{B}) where {T,B}
11+
return similar(x, Dual{T,eltype(x),B})
12+
end
13+
14+
# Dual construction helpers with tag `T`.
# - `make_dual`: scalar `x` and scalar seed `dx` -> `Dual{T}(x, dx)`.
# - `make_dual!` (generic `dx`): fill `xdual` elementwise with `Dual{T}(x[i], dx[i])`
#   via `map!`.
# - `make_dual!` (`dx::Batch{B}`): splat the batch elements so each dual receives one
#   partial per batch element, i.e. `Dual{T}(x[i], dx1[i], ..., dxB[i])`.
# Both `make_dual!` methods return the result of `map!`.
make_dual(::Type{T}, x::Number, dx::Number) where {T} = Dual{T}(x, dx)
1115
make_dual!(::Type{T}, xdual, x, dx) where {T} = map!(Dual{T}, xdual, x, dx)
1216

13-
myvalue(::Type{T}, ydual::Number) where {T} = value(T, ydual)
14-
myvalue(::Type{T}, ydual) where {T} = map(Fix1(value, T), ydual)
17+
function make_dual!(::Type{T}, xdual, x, dx::Batch{B}) where {T,B}
18+
return map!(Dual{T}, xdual, x, dx.elements...)
19+
end
1520

16-
myvalue!(::Type{T}, y, ydual) where {T} = map!(Fix1(value, T), y, ydual)
21+
# Extract the primal value associated with tag `T`.
# - A `Dual{T}` is unwrapped with `value(T, ydual)`.
# - Anything else is treated as a container and mapped with `myvalue` itself, so
#   elements are dispatched again (to the `Dual{T}` method or to this one).
# - `myvalue!` does the same mapping in place into `y` and returns the `map!` result.
myvalue(::Type{T}, ydual::Dual{T}) where {T} = value(T, ydual)
22+
myvalue(::Type{T}, ydual) where {T} = map(Fix1(myvalue, T), ydual)
23+
24+
function myvalue!(::Type{T}, y, ydual) where {T}
25+
return map!(Fix1(myvalue, T), y, ydual)
26+
end
1727

18-
myderivative(::Type{T}, ydual::Number) where {T} = extract_derivative(T, ydual)
19-
myderivative(::Type{T}, ydual) where {T} = map(Fix1(extract_derivative, T), ydual)
28+
# Extract the derivative associated with tag `T`.
# - A `Dual{T}` goes through `extract_derivative(T, ydual)`.
# - Anything else is treated as a container and mapped with `myderivative` itself, so
#   elements are dispatched again.
# - `myderivative!` writes the mapped result into `dy`; the added (`32+`) line maps
#   `myderivative`, replacing the removed (`22-`) line that mapped `extract_derivative`.
myderivative(::Type{T}, ydual::Dual{T}) where {T} = extract_derivative(T, ydual)
29+
myderivative(::Type{T}, ydual) where {T} = map(Fix1(myderivative, T), ydual)
2030

2131
function myderivative!(::Type{T}, dy, ydual) where {T}
22-
return map!(Fix1(extract_derivative, T), dy, ydual)
32+
return map!(Fix1(myderivative, T), dy, ydual)
33+
end
34+
35+
# Extract `B` partials from `ydual` and return them as a new `Batch`.
# `ntuple(Val(B))` builds a tuple of length `B`; entry `b` is `partials.(T, ydual, b)`,
# i.e. the `b`-th partial broadcast over `ydual`.
# `B` is passed as a `Val` so the tuple length comes from the type parameter.
function mypartials(::Type{T}, ::Val{B}, ydual) where {T,B}
36+
elements = ntuple(Val(B)) do b
37+
partials.(T, ydual, b)
38+
end
39+
return Batch(elements)
40+
end
41+
42+
# In-place counterpart of `mypartials`: for every index `b` of `dy.elements`, overwrite
# that element with the `b`-th partial of `ydual` (broadcast assignment `.=`).
# Returns the mutated `dy`.
# NOTE(review): `.=` requires each `dy.elements[b]` to be a mutable container — confirm
# that scalar outputs never reach this method.
function mypartials!(::Type{T}, dy::Batch{B}, ydual) where {T,B}
43+
for b in eachindex(dy.elements)
44+
dy.elements[b] .= partials.(T, ydual, b)
45+
end
46+
return dy
2347
end

0 commit comments

Comments
 (0)