Skip to content

Commit 17cdd2a

Browse files
authored
Better benchmarking + type annotations for ForwardDiff (#232)
1 parent e1e9644 commit 17cdd2a

11 files changed

Lines changed: 460 additions & 402 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 35 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,58 +4,58 @@ struct ForwardDiffOneArgPushforwardExtras{T,X} <: PushforwardExtras
44
xdual_tmp::X
55
end
66

7-
function DI.prepare_pushforward(f, backend::AutoForwardDiff, x, dx)
7+
function DI.prepare_pushforward(f::F, backend::AutoForwardDiff, x, dx) where {F}
88
T = tag_type(f, backend, x)
99
xdual_tmp = make_dual(T, x, dx)
1010
return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp)
1111
end
1212

1313
function compute_ydual_onearg(
14-
f, x::Number, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
15-
) where {T}
14+
f::F, x::Number, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
15+
) where {F,T}
1616
xdual_tmp = make_dual(T, x, dx)
1717
ydual = f(xdual_tmp)
1818
return ydual
1919
end
2020

2121
function compute_ydual_onearg(
22-
f, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
23-
) where {T}
22+
f::F, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
23+
) where {F,T}
2424
(; xdual_tmp) = extras
2525
make_dual!(T, xdual_tmp, x, dx)
2626
ydual = f(xdual_tmp)
2727
return ydual
2828
end
2929

3030
function DI.value_and_pushforward(
31-
f, ::AutoForwardDiff, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
32-
) where {T}
31+
f::F, ::AutoForwardDiff, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
32+
) where {F,T}
3333
ydual = compute_ydual_onearg(f, x, dx, extras)
3434
y = myvalue(T, ydual)
3535
new_dy = myderivative(T, ydual)
3636
return y, new_dy
3737
end
3838

3939
function DI.pushforward(
40-
f, ::AutoForwardDiff, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
41-
) where {T}
40+
f::F, ::AutoForwardDiff, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
41+
) where {F,T}
4242
ydual = compute_ydual_onearg(f, x, dx, extras)
4343
new_dy = myderivative(T, ydual)
4444
return new_dy
4545
end
4646

4747
function DI.value_and_pushforward!(
48-
f, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
49-
) where {T}
48+
f::F, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
49+
) where {F,T}
5050
ydual = compute_ydual_onearg(f, x, dx, extras)
5151
y = myvalue(T, ydual)
5252
myderivative!(T, dy, ydual)
5353
return y, dy
5454
end
5555

5656
function DI.pushforward!(
57-
f, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
58-
) where {T}
57+
f::F, dy, ::AutoForwardDiff, x, dx, extras::ForwardDiffOneArgPushforwardExtras{T}
58+
) where {F,T}
5959
ydual = compute_ydual_onearg(f, x, dx, extras)
6060
myderivative!(T, dy, ydual)
6161
return dy
@@ -67,42 +67,34 @@ struct ForwardDiffGradientExtras{C} <: GradientExtras
6767
config::C
6868
end
6969

70-
function DI.prepare_gradient(f, backend::AutoForwardDiff, x::AbstractArray)
70+
function DI.prepare_gradient(f::F, backend::AutoForwardDiff, x::AbstractArray) where {F}
7171
return ForwardDiffGradientExtras(GradientConfig(f, x, choose_chunk(backend, x)))
7272
end
7373

7474
function DI.value_and_gradient!(
75-
f,
76-
grad::AbstractArray,
77-
::AutoForwardDiff,
78-
x::AbstractArray,
79-
extras::ForwardDiffGradientExtras,
80-
)
75+
f::F, grad, ::AutoForwardDiff, x, extras::ForwardDiffGradientExtras
76+
) where {F}
8177
result = MutableDiffResult(zero(eltype(x)), (grad,))
8278
result = gradient!(result, f, x, extras.config)
8379
return DiffResults.value(result), DiffResults.gradient(result)
8480
end
8581

8682
function DI.value_and_gradient(
87-
f, backend::AutoForwardDiff, x::AbstractArray, extras::ForwardDiffGradientExtras
88-
)
83+
f::F, backend::AutoForwardDiff, x, extras::ForwardDiffGradientExtras
84+
) where {F}
8985
grad = similar(x)
9086
return DI.value_and_gradient!(f, grad, backend, x, extras)
9187
end
9288

9389
function DI.gradient!(
94-
f,
95-
grad::AbstractArray,
96-
::AutoForwardDiff,
97-
x::AbstractArray,
98-
extras::ForwardDiffGradientExtras,
99-
)
90+
f::F, grad, ::AutoForwardDiff, x, extras::ForwardDiffGradientExtras
91+
) where {F}
10092
return gradient!(grad, f, x, extras.config)
10193
end
10294

10395
function DI.gradient(
104-
f, ::AutoForwardDiff, x::AbstractArray, extras::ForwardDiffGradientExtras
105-
)
96+
f::F, ::AutoForwardDiff, x, extras::ForwardDiffGradientExtras
97+
) where {F}
10698
return gradient(f, x, extras.config)
10799
end
108100

@@ -112,42 +104,34 @@ struct ForwardDiffOneArgJacobianExtras{C} <: JacobianExtras
112104
config::C
113105
end
114106

115-
function DI.prepare_jacobian(f, backend::AutoForwardDiff, x::AbstractArray)
107+
function DI.prepare_jacobian(f, backend::AutoForwardDiff, x)
116108
return ForwardDiffOneArgJacobianExtras(JacobianConfig(f, x, choose_chunk(backend, x)))
117109
end
118110

119111
function DI.value_and_jacobian!(
120-
f,
121-
jac::AbstractMatrix,
122-
::AutoForwardDiff,
123-
x::AbstractArray,
124-
extras::ForwardDiffOneArgJacobianExtras,
125-
)
112+
f::F, jac, ::AutoForwardDiff, x, extras::ForwardDiffOneArgJacobianExtras
113+
) where {F}
126114
y = f(x)
127115
result = MutableDiffResult(y, (jac,))
128116
result = jacobian!(result, f, x, extras.config)
129117
return DiffResults.value(result), DiffResults.jacobian(result)
130118
end
131119

132120
function DI.value_and_jacobian(
133-
f, ::AutoForwardDiff, x::AbstractArray, extras::ForwardDiffOneArgJacobianExtras
134-
)
121+
f::F, ::AutoForwardDiff, x, extras::ForwardDiffOneArgJacobianExtras
122+
) where {F}
135123
return f(x), jacobian(f, x, extras.config)
136124
end
137125

138126
function DI.jacobian!(
139-
f,
140-
jac::AbstractMatrix,
141-
::AutoForwardDiff,
142-
x::AbstractArray,
143-
extras::ForwardDiffOneArgJacobianExtras,
144-
)
127+
f::F, jac, ::AutoForwardDiff, x, extras::ForwardDiffOneArgJacobianExtras
128+
) where {F}
145129
return jacobian!(jac, f, x, extras.config)
146130
end
147131

148132
function DI.jacobian(
149-
f, ::AutoForwardDiff, x::AbstractArray, extras::ForwardDiffOneArgJacobianExtras
150-
)
133+
f::F, ::AutoForwardDiff, x, extras::ForwardDiffOneArgJacobianExtras
134+
) where {F}
151135
return jacobian(f, x, extras.config)
152136
end
153137

@@ -157,22 +141,16 @@ struct ForwardDiffHessianExtras{C} <: HessianExtras
157141
config::C
158142
end
159143

160-
function DI.prepare_hessian(f, backend::AutoForwardDiff, x::AbstractArray)
144+
function DI.prepare_hessian(f, backend::AutoForwardDiff, x)
161145
return ForwardDiffHessianExtras(HessianConfig(f, x, choose_chunk(backend, x)))
162146
end
163147

164148
function DI.hessian!(
165-
f,
166-
hess::AbstractMatrix,
167-
::AutoForwardDiff,
168-
x::AbstractArray,
169-
extras::ForwardDiffHessianExtras,
170-
)
149+
f::F, hess, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras
150+
) where {F}
171151
return hessian!(hess, f, x, extras.config)
172152
end
173153

174-
function DI.hessian(
175-
f, ::AutoForwardDiff, x::AbstractArray, extras::ForwardDiffHessianExtras
176-
)
154+
function DI.hessian(f::F, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras) where {F}
177155
return hessian(f, x, extras.config)
178156
end

0 commit comments

Comments
 (0)