Skip to content

Commit cc73842

Browse files
authored
Less allocs in ForwardDiff (#228)
1 parent 1e30916 commit cc73842

3 files changed

Lines changed: 46 additions & 18 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,19 @@ function DI.prepare_pushforward(f, ::AutoForwardDiff, x, dx)
1010
return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp)
1111
end
1212

13+
function compute_ydual_onearg(
14+
f, x::Number, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
15+
) where {T}
16+
xdual_tmp = make_dual(T, x, dx)
17+
ydual = f(xdual_tmp)
18+
return ydual
19+
end
20+
1321
function compute_ydual_onearg(
1422
f, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
1523
) where {T}
1624
(; xdual_tmp) = extras
17-
xdual_tmp = if x isa Number
18-
make_dual(T, x, dx)
19-
else
20-
make_dual!(T, xdual_tmp, x, dx)
21-
end
25+
make_dual!(T, xdual_tmp, x, dx)
2226
ydual = f(xdual_tmp)
2327
return ydual
2428
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,27 @@ function DI.prepare_pushforward(f!, y, ::AutoForwardDiff, x, dx)
1515
end
1616

1717
function compute_ydual_twoarg(
18-
f!, y, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
18+
::Type{T}, f!, y, x::Number, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
19+
) where {T}
20+
(; ydual_tmp) = extras
21+
xdual_tmp = make_dual(T, x, dx)
22+
f!(ydual_tmp, xdual_tmp)
23+
return ydual_tmp
24+
end
25+
26+
function compute_ydual_twoarg(
27+
::Type{T}, f!, y, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
1928
) where {T}
2029
(; xdual_tmp, ydual_tmp) = extras
21-
xdual_tmp = if x isa Number
22-
make_dual(T, x, dx)
23-
else
24-
make_dual!(T, xdual_tmp, x, dx)
25-
end
30+
make_dual!(T, xdual_tmp, x, dx)
2631
f!(ydual_tmp, xdual_tmp)
2732
return ydual_tmp
2833
end
2934

3035
function DI.value_and_pushforward(
3136
f!, y, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
3237
) where {T}
33-
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
38+
ydual_tmp = compute_ydual_twoarg(T, f!, y, x, dx, extras)
3439
myvalue!(T, y, ydual_tmp)
3540
dy = myderivative(T, ydual_tmp)
3641
return y, dy
@@ -39,15 +44,15 @@ end
3944
function DI.pushforward(
4045
f!, y, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
4146
) where {T}
42-
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
47+
ydual_tmp = compute_ydual_twoarg(T, f!, y, x, dx, extras)
4348
dy = myderivative(T, ydual_tmp)
4449
return dy
4550
end
4651

4752
function DI.value_and_pushforward!(
4853
f!, y, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
4954
) where {T}
50-
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
55+
ydual_tmp = compute_ydual_twoarg(T, f!, y, x, dx, extras)
5156
myvalue!(T, y, ydual_tmp)
5257
myderivative!(T, dy, ydual_tmp)
5358
return y, dy
@@ -56,7 +61,7 @@ end
5661
function DI.pushforward!(
5762
f!, y, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffTwoArgPushforwardExtras{T}
5863
) where {T}
59-
ydual_tmp = compute_ydual_twoarg(f!, y, x, dx, extras)
64+
ydual_tmp = compute_ydual_twoarg(T, f!, y, x, dx, extras)
6065
myderivative!(T, dy, ydual_tmp)
6166
return dy
6267
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,35 @@ tag_type(::F, x::AbstractArray) where {F} = Tag{F,eltype(x)}
77
make_dual(::Type{T}, x::Number, dx) where {T} = Dual{T}(x, dx)
88
make_dual(::Type{T}, x::AbstractArray, dx) where {T} = Dual{T}.(x, dx)
99

10-
make_dual!(::Type{T}, xdual, x::AbstractArray, dx) where {T} = xdual .= Dual{T}.(x, dx)
10+
function make_dual!(::Type{T}, xdual, x::AbstractArray, dx) where {T}
11+
for i in eachindex(xdual, x, dx)
12+
xdual[i] = Dual{T}(x[i], dx[i])
13+
end
14+
return nothing
15+
end
1116

1217
myvalue(::Type{T}, ydual::Number) where {T} = value(T, ydual)
1318
myvalue(::Type{T}, ydual::AbstractArray) where {T} = value.(T, ydual)
1419

15-
myvalue!(::Type{T}, y::AbstractArray, ydual) where {T} = y .= value.(T, ydual)
20+
function myvalue!(::Type{T}, y::AbstractArray, ydual) where {T}
21+
for i in eachindex(y, ydual)
22+
y[i] = value(T, ydual[i])
23+
end
24+
return nothing
25+
end
1626

1727
myderivative(::Type{T}, ydual::Number) where {T} = extract_derivative(T, ydual)
1828
myderivative(::Type{T}, ydual::AbstractArray) where {T} = extract_derivative(T, ydual)
1929

2030
function myderivative!(::Type{T}, dy, ydual::AbstractArray) where {T}
21-
return extract_derivative!(T, dy, ydual)
31+
extract_derivative!(T, dy, ydual)
32+
return nothing
33+
end
34+
35+
function myvalueandderivative!(::Type{T}, y, dy, ydual::AbstractArray) where {T}
36+
for i in eachindex(y, dy, ydual)
37+
y[i] = value(T, ydual[i])
38+
dy[i] = extract_derivative(T, ydual[i])
39+
end
40+
return nothing
2241
end

0 commit comments

Comments
 (0)