Skip to content

Commit 1e30916

Browse files
authored
Alloc-free pushforward with ForwardDiff (#227)
1 parent 07b67e8 commit 1e30916

4 files changed

Lines changed: 113 additions & 15 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using DifferentiationInterface:
88
HessianExtras,
99
JacobianExtras,
1010
NoDerivativeExtras,
11-
NoPushforwardExtras
11+
PushforwardExtras
1212
using ForwardDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult
1313
using ForwardDiff:
1414
Chunk,

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,62 @@
11
## Pushforward
22

3-
DI.prepare_pushforward(f, ::AutoForwardDiff, x, dx) = NoPushforwardExtras()
3+
struct ForwardDiffOneArgPushforwardExtras{T,X} <: PushforwardExtras
4+
xdual_tmp::X
5+
end
46

5-
function DI.value_and_pushforward(f, ::AutoForwardDiff, x, dx, ::NoPushforwardExtras)
7+
function DI.prepare_pushforward(f, ::AutoForwardDiff, x, dx)
68
T = tag_type(f, x)
7-
xdual = make_dual(T, x, dx)
8-
ydual = f(xdual)
9+
xdual_tmp = make_dual(T, x, dx)
10+
return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp)
11+
end
12+
13+
function compute_ydual_onearg(
14+
f, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
15+
) where {T}
16+
(; xdual_tmp) = extras
17+
xdual_tmp = if x isa Number
18+
make_dual(T, x, dx)
19+
else
20+
make_dual!(T, xdual_tmp, x, dx)
21+
end
22+
ydual = f(xdual_tmp)
23+
return ydual
24+
end
25+
26+
function DI.value_and_pushforward(
27+
f, ::AutoForwardDiff, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
28+
) where {T}
29+
ydual = compute_ydual_onearg(f, x, dx, extras)
930
y = myvalue(T, ydual)
1031
new_dy = myderivative(T, ydual)
1132
return y, new_dy
1233
end
1334

35+
function DI.pushforward(
36+
f, ::AutoForwardDiff, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
37+
) where {T}
38+
ydual = compute_ydual_onearg(f, x, dx, extras)
39+
new_dy = myderivative(T, ydual)
40+
return new_dy
41+
end
42+
43+
function DI.value_and_pushforward!(
44+
f, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
45+
) where {T}
46+
ydual = compute_ydual_onearg(f, x, dx, extras)
47+
y = myvalue(T, ydual)
48+
myderivative!(T, dy, ydual)
49+
return y, dy
50+
end
51+
52+
function DI.pushforward!(
53+
f, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
54+
) where {T}
55+
ydual = compute_ydual_onearg(f, x, dx, extras)
56+
myderivative!(T, dy, ydual)
57+
return dy
58+
end
59+
1460
## Gradient
1561

1662
struct ForwardDiffGradientExtras{C} <: GradientExtras

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,66 @@
1-
DI.prepare_pushforward(f!, y, ::AutoForwardDiff, x, dx) = NoPushforwardExtras()
1+
## Pushforward
22

3-
function DI.value_and_pushforward(f!, y, ::AutoForwardDiff, x, dx, ::NoPushforwardExtras)
3+
struct ForwardDiffTwoArgPushforwardExtras{T,X,Y} <: PushforwardExtras
4+
xdual_tmp::X
5+
ydual_tmp::Y
6+
end
7+
8+
function DI.prepare_pushforward(f!, y, ::AutoForwardDiff, x, dx)
49
T = tag_type(f!, x)
5-
xdual = make_dual(T, x, dx)
6-
ydual = make_dual(T, y, similar(y))
7-
f!(ydual, xdual)
8-
y = myvalue!(T, y, ydual)
9-
dy = myderivative(T, ydual)
10+
xdual_tmp = make_dual(T, x, dx)
11+
ydual_tmp = make_dual(T, y, similar(y))
12+
return ForwardDiffTwoArgPushforwardExtras{T,typeof(xdual_tmp),typeof(ydual_tmp)}(
13+
xdual_tmp, ydual_tmp
14+
)
15+
end
16+
17+
function compute_ydual_twoarg(
18+
f!, y, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
19+
) where {T}
20+
(; xdual_tmp, ydual_tmp) = extras
21+
xdual_tmp = if x isa Number
22+
make_dual(T, x, dx)
23+
else
24+
make_dual!(T, xdual_tmp, x, dx)
25+
end
26+
f!(ydual_tmp, xdual_tmp)
27+
return ydual_tmp
28+
end
29+
30+
function DI.value_and_pushforward(
31+
f!, y, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
32+
) where {T}
33+
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
34+
myvalue!(T, y, ydual_tmp)
35+
dy = myderivative(T, ydual_tmp)
1036
return y, dy
1137
end
1238

39+
function DI.pushforward(
40+
f!, y, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
41+
) where {T}
42+
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
43+
dy = myderivative(T, ydual_tmp)
44+
return dy
45+
end
46+
47+
function DI.value_and_pushforward!(
48+
f!, y, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
49+
) where {T}
50+
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
51+
myvalue!(T, y, ydual_tmp)
52+
myderivative!(T, dy, ydual_tmp)
53+
return y, dy
54+
end
55+
56+
function DI.pushforward!(
57+
f!, y, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
58+
) where {T}
59+
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
60+
myderivative!(T, dy, ydual_tmp)
61+
return dy
62+
end
63+
1364
## Derivative
1465

1566
struct ForwardDiffTwoArgDerivativeExtras{C} <: DerivativeExtras
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
21
choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x)
32
choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{C}()
43

54
tag_type(::F, x::Number) where {F} = Tag{F,typeof(x)}
65
tag_type(::F, x::AbstractArray) where {F} = Tag{F,eltype(x)}
76

8-
make_dual(::Type{T}, x::Number, dx::Number) where {T} = Dual{T}(x, dx)
7+
make_dual(::Type{T}, x::Number, dx) where {T} = Dual{T}(x, dx)
98
make_dual(::Type{T}, x::AbstractArray, dx) where {T} = Dual{T}.(x, dx)
109

10+
make_dual!(::Type{T}, xdual, x::AbstractArray, dx) where {T} = xdual .= Dual{T}.(x, dx)
11+
1112
myvalue(::Type{T}, ydual::Number) where {T} = value(T, ydual)
1213
myvalue(::Type{T}, ydual::AbstractArray) where {T} = value.(T, ydual)
1314

@@ -16,6 +17,6 @@ myvalue!(::Type{T}, y::AbstractArray, ydual) where {T} = y .= value.(T, ydual)
1617
myderivative(::Type{T}, ydual::Number) where {T} = extract_derivative(T, ydual)
1718
myderivative(::Type{T}, ydual::AbstractArray) where {T} = extract_derivative(T, ydual)
1819

19-
function myderivative!(::Type{T}, dy::AbstractArray, ydual::AbstractArray) where {T}
20+
function myderivative!(::Type{T}, dy, ydual::AbstractArray) where {T}
2021
return extract_derivative!(T, dy, ydual)
2122
end

0 commit comments

Comments
 (0)