Skip to content

Commit e1e9644

Browse files
authored
Add type annotations on f in DI (#231)
* Add type annotations on f in DI * Typo
1 parent e490e32 commit e1e9644

11 files changed

Lines changed: 240 additions & 208 deletions

File tree

DifferentiationInterface/src/first_order/derivative.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,12 @@ struct PushforwardDerivativeExtras{E<:PushforwardExtras} <: DerivativeExtras
4949
pushforward_extras::E
5050
end
5151

52-
function prepare_derivative(f, backend::AbstractADType, x)
52+
function prepare_derivative(f::F, backend::AbstractADType, x) where {F}
5353
dx = one(x)
5454
return PushforwardDerivativeExtras(prepare_pushforward(f, backend, x, dx))
5555
end
5656

57-
function prepare_derivative(f!, y, backend::AbstractADType, x)
57+
function prepare_derivative(f!::F, y, backend::AbstractADType, x) where {F}
5858
dx = one(x)
5959
pushforward_extras = prepare_pushforward(f!, y, backend, x, dx)
6060
return PushforwardDerivativeExtras(pushforward_extras)
@@ -63,83 +63,83 @@ end
6363
## One argument
6464

6565
function value_and_derivative(
66-
f,
66+
f::F,
6767
backend::AbstractADType,
6868
x,
6969
extras::DerivativeExtras=prepare_derivative(f, backend, x),
70-
)
70+
) where {F}
7171
return value_and_pushforward(f, backend, x, one(x), extras.pushforward_extras)
7272
end
7373

7474
function value_and_derivative!(
75-
f,
75+
f::F,
7676
der,
7777
backend::AbstractADType,
7878
x,
7979
extras::DerivativeExtras=prepare_derivative(f, backend, x),
80-
)
80+
) where {F}
8181
return value_and_pushforward!(f, der, backend, x, one(x), extras.pushforward_extras)
8282
end
8383

8484
function derivative(
85-
f,
85+
f::F,
8686
backend::AbstractADType,
8787
x,
8888
extras::DerivativeExtras=prepare_derivative(f, backend, x),
89-
)
89+
) where {F}
9090
return pushforward(f, backend, x, one(x), extras.pushforward_extras)
9191
end
9292

9393
function derivative!(
94-
f,
94+
f::F,
9595
der,
9696
backend::AbstractADType,
9797
x,
9898
extras::DerivativeExtras=prepare_derivative(f, backend, x),
99-
)
99+
) where {F}
100100
return pushforward!(f, der, backend, x, one(x), extras.pushforward_extras)
101101
end
102102

103103
## Two arguments
104104

105105
function value_and_derivative(
106-
f!,
106+
f!::F,
107107
y,
108108
backend::AbstractADType,
109109
x,
110110
extras::DerivativeExtras=prepare_derivative(f!, y, backend, x),
111-
)
111+
) where {F}
112112
return value_and_pushforward(f!, y, backend, x, one(x), extras.pushforward_extras)
113113
end
114114

115115
function value_and_derivative!(
116-
f!,
116+
f!::F,
117117
y,
118118
der,
119119
backend::AbstractADType,
120120
x,
121121
extras::DerivativeExtras=prepare_derivative(f!, y, backend, x),
122-
)
122+
) where {F}
123123
return value_and_pushforward!(f!, y, der, backend, x, one(x), extras.pushforward_extras)
124124
end
125125

126126
function derivative(
127-
f!,
127+
f!::F,
128128
y,
129129
backend::AbstractADType,
130130
x,
131131
extras::DerivativeExtras=prepare_derivative(f!, y, backend, x),
132-
)
132+
) where {F}
133133
return pushforward(f!, y, backend, x, one(x), extras.pushforward_extras)
134134
end
135135

136136
function derivative!(
137-
f!,
137+
f!::F,
138138
y,
139139
der,
140140
backend::AbstractADType,
141141
x,
142142
extras::DerivativeExtras=prepare_derivative(f!, y, backend, x),
143-
)
143+
) where {F}
144144
return pushforward!(f!, y, der, backend, x, one(x), extras.pushforward_extras)
145145
end

DifferentiationInterface/src/first_order/gradient.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct PullbackGradientExtras{E<:PullbackExtras} <: GradientExtras
4242
pullback_extras::E
4343
end
4444

45-
function prepare_gradient(f, backend::AbstractADType, x)
45+
function prepare_gradient(f::F, backend::AbstractADType, x) where {F}
4646
y = f(x)
4747
dy = one(y)
4848
pullback_extras = prepare_pullback(f, backend, x, dy)
@@ -52,33 +52,33 @@ end
5252
## One argument
5353

5454
function value_and_gradient(
55-
f, backend::AbstractADType, x, extras::GradientExtras=prepare_gradient(f, backend, x)
56-
)
55+
f::F, backend::AbstractADType, x, extras::GradientExtras=prepare_gradient(f, backend, x)
56+
) where {F}
5757
return value_and_pullback(f, backend, x, one(eltype(x)), extras.pullback_extras)
5858
end
5959

6060
function value_and_gradient!(
61-
f,
61+
f::F,
6262
grad,
6363
backend::AbstractADType,
6464
x,
6565
extras::GradientExtras=prepare_gradient(f, backend, x),
66-
)
66+
) where {F}
6767
return value_and_pullback!(f, grad, backend, x, one(eltype(x)), extras.pullback_extras)
6868
end
6969

7070
function gradient(
71-
f, backend::AbstractADType, x, extras::GradientExtras=prepare_gradient(f, backend, x)
72-
)
71+
f::F, backend::AbstractADType, x, extras::GradientExtras=prepare_gradient(f, backend, x)
72+
) where {F}
7373
return pullback(f, backend, x, one(eltype(x)), extras.pullback_extras)
7474
end
7575

7676
function gradient!(
77-
f,
77+
f::F,
7878
grad,
7979
backend::AbstractADType,
8080
x,
8181
extras::GradientExtras=prepare_gradient(f, backend, x),
82-
)
82+
) where {F}
8383
return pullback!(f, grad, backend, x, one(eltype(x)), extras.pullback_extras)
8484
end

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -53,34 +53,34 @@ struct PullbackJacobianExtras{E<:PullbackExtras} <: JacobianExtras
5353
pullback_extras::E
5454
end
5555

56-
function prepare_jacobian(f, backend::AbstractADType, x)
56+
function prepare_jacobian(f::F, backend::AbstractADType, x) where {F}
5757
return prepare_jacobian_aux(f, backend, x, pushforward_performance(backend))
5858
end
5959

60-
function prepare_jacobian(f!, y, backend::AbstractADType, x)
60+
function prepare_jacobian(f!::F, y, backend::AbstractADType, x) where {F}
6161
return prepare_jacobian_aux(f!, y, backend, x, pushforward_performance(backend))
6262
end
6363

64-
function prepare_jacobian_aux(f, backend, x, ::PushforwardFast)
64+
function prepare_jacobian_aux(f::F, backend, x, ::PushforwardFast) where {F}
6565
dx = basis(backend, x, first(CartesianIndices(x)))
6666
pushforward_extras = prepare_pushforward(f, backend, x, dx)
6767
return PushforwardJacobianExtras(pushforward_extras)
6868
end
6969

70-
function prepare_jacobian_aux(f!, y, backend, x, ::PushforwardFast)
70+
function prepare_jacobian_aux(f!::F, y, backend, x, ::PushforwardFast) where {F}
7171
dx = basis(backend, x, first(CartesianIndices(x)))
7272
pushforward_extras = prepare_pushforward(f!, y, backend, x, dx)
7373
return PushforwardJacobianExtras(pushforward_extras)
7474
end
7575

76-
function prepare_jacobian_aux(f, backend, x, ::PushforwardSlow)
76+
function prepare_jacobian_aux(f::F, backend, x, ::PushforwardSlow) where {F}
7777
y = f(x)
7878
dy = basis(backend, y, first(CartesianIndices(y)))
7979
pullback_extras = prepare_pullback(f, backend, x, dy)
8080
return PullbackJacobianExtras(pullback_extras)
8181
end
8282

83-
function prepare_jacobian_aux(f!, y, backend, x, ::PushforwardSlow)
83+
function prepare_jacobian_aux(f!::F, y, backend, x, ::PushforwardSlow) where {F}
8484
dy = basis(backend, y, first(CartesianIndices(y)))
8585
pullback_extras = prepare_pullback(f!, y, backend, x, dy)
8686
return PullbackJacobianExtras(pullback_extras)
@@ -89,14 +89,14 @@ end
8989
## One argument
9090

9191
function value_and_jacobian(
92-
f, backend::AbstractADType, x, extras::JacobianExtras=prepare_jacobian(f, backend, x)
93-
)
92+
f::F, backend::AbstractADType, x, extras::JacobianExtras=prepare_jacobian(f, backend, x)
93+
) where {F}
9494
return value_and_jacobian_onearg_aux(f, backend, x, extras)
9595
end
9696

9797
function value_and_jacobian_onearg_aux(
98-
f, backend, x::AbstractArray, extras::PushforwardJacobianExtras
99-
)
98+
f::F, backend, x::AbstractArray, extras::PushforwardJacobianExtras
99+
) where {F}
100100
y = f(x)
101101
jac = stack(CartesianIndices(x); dims=2) do j
102102
dx_j = basis(backend, x, j)
@@ -107,8 +107,8 @@ function value_and_jacobian_onearg_aux(
107107
end
108108

109109
function value_and_jacobian_onearg_aux(
110-
f, backend, x::AbstractArray, extras::PullbackJacobianExtras
111-
)
110+
f::F, backend, x::AbstractArray, extras::PullbackJacobianExtras
111+
) where {F}
112112
y, pullbackfunc = value_and_pullback_split(f, backend, x, extras.pullback_extras)
113113
jac = stack(CartesianIndices(y); dims=1) do i
114114
dy_i = basis(backend, y, i)
@@ -119,18 +119,18 @@ function value_and_jacobian_onearg_aux(
119119
end
120120

121121
function value_and_jacobian!(
122-
f,
122+
f::F,
123123
jac,
124124
backend::AbstractADType,
125125
x,
126126
extras::JacobianExtras=prepare_jacobian(f, backend, x),
127-
)
127+
) where {F}
128128
return value_and_jacobian_onearg_aux!(f, jac, backend, x, extras)
129129
end
130130

131131
function value_and_jacobian_onearg_aux!(
132-
f, jac::AbstractMatrix, backend, x::AbstractArray, extras::PushforwardJacobianExtras
133-
)
132+
f::F, jac::AbstractMatrix, backend, x::AbstractArray, extras::PushforwardJacobianExtras
133+
) where {F}
134134
y = f(x)
135135
for (k, j) in enumerate(CartesianIndices(x))
136136
dx_j = basis(backend, x, j)
@@ -141,8 +141,8 @@ function value_and_jacobian_onearg_aux!(
141141
end
142142

143143
function value_and_jacobian_onearg_aux!(
144-
f, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras
145-
)
144+
f::F, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras
145+
) where {F}
146146
y, pullbackfunc! = value_and_pullback!_split(f, backend, x, extras.pullback_extras)
147147
for (k, i) in enumerate(CartesianIndices(y))
148148
dy_i = basis(backend, y, i)
@@ -153,36 +153,36 @@ function value_and_jacobian_onearg_aux!(
153153
end
154154

155155
function jacobian(
156-
f, backend::AbstractADType, x, extras::JacobianExtras=prepare_jacobian(f, backend, x)
157-
)
156+
f::F, backend::AbstractADType, x, extras::JacobianExtras=prepare_jacobian(f, backend, x)
157+
) where {F}
158158
return value_and_jacobian(f, backend, x, extras)[2]
159159
end
160160

161161
function jacobian!(
162-
f,
162+
f::F,
163163
jac,
164164
backend::AbstractADType,
165165
x,
166166
extras::JacobianExtras=prepare_jacobian(f, backend, x),
167-
)
167+
) where {F}
168168
return value_and_jacobian!(f, jac, backend, x, extras)[2]
169169
end
170170

171171
## Two arguments
172172

173173
function value_and_jacobian(
174-
f!,
174+
f!::F,
175175
y,
176176
backend::AbstractADType,
177177
x,
178178
extras::JacobianExtras=prepare_jacobian(f!, y, backend, x),
179-
)
179+
) where {F}
180180
return value_and_jacobian_twoarg_aux(f!, y, backend, x, extras)
181181
end
182182

183183
function value_and_jacobian_twoarg_aux(
184-
f!, y, backend, x::AbstractArray, extras::PushforwardJacobianExtras
185-
)
184+
f!::F, y, backend, x::AbstractArray, extras::PushforwardJacobianExtras
185+
) where {F}
186186
jac = stack(CartesianIndices(x); dims=2) do j
187187
dx_j = basis(backend, x, j)
188188
jac_col_j = pushforward(f!, y, backend, x, dx_j, extras.pushforward_extras)
@@ -193,8 +193,8 @@ function value_and_jacobian_twoarg_aux(
193193
end
194194

195195
function value_and_jacobian_twoarg_aux(
196-
f!, y, backend, x::AbstractArray, extras::PullbackJacobianExtras
197-
)
196+
f!::F, y, backend, x::AbstractArray, extras::PullbackJacobianExtras
197+
) where {F}
198198
y, pullbackfunc = value_and_pullback_split(f!, y, backend, x, extras.pullback_extras)
199199
jac = stack(CartesianIndices(y); dims=1) do i
200200
dy_i = basis(backend, y, i)
@@ -206,19 +206,24 @@ function value_and_jacobian_twoarg_aux(
206206
end
207207

208208
function value_and_jacobian!(
209-
f!,
209+
f!::F,
210210
y,
211211
jac,
212212
backend::AbstractADType,
213213
x,
214214
extras::JacobianExtras=prepare_jacobian(f!, y, backend, x),
215-
)
215+
) where {F}
216216
return value_and_jacobian_twoarg_aux!(f!, y, jac, backend, x, extras)
217217
end
218218

219219
function value_and_jacobian_twoarg_aux!(
220-
f!, y, jac::AbstractMatrix, backend, x::AbstractArray, extras::PushforwardJacobianExtras
221-
)
220+
f!::F,
221+
y,
222+
jac::AbstractMatrix,
223+
backend,
224+
x::AbstractArray,
225+
extras::PushforwardJacobianExtras,
226+
) where {F}
222227
for (k, j) in enumerate(CartesianIndices(x))
223228
dx_j = basis(backend, x, j)
224229
jac_col_j = reshape(view(jac, :, k), size(y))
@@ -229,8 +234,8 @@ function value_and_jacobian_twoarg_aux!(
229234
end
230235

231236
function value_and_jacobian_twoarg_aux!(
232-
f!, y, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras
233-
)
237+
f!::F, y, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras
238+
) where {F}
234239
y, pullbackfunc! = value_and_pullback!_split(f!, y, backend, x, extras.pullback_extras)
235240
for (k, i) in enumerate(CartesianIndices(y))
236241
dy_i = basis(backend, y, i)
@@ -242,22 +247,22 @@ function value_and_jacobian_twoarg_aux!(
242247
end
243248

244249
function jacobian(
245-
f!,
250+
f!::F,
246251
y,
247252
backend::AbstractADType,
248253
x,
249254
extras::JacobianExtras=prepare_jacobian(f!, y, backend, x),
250-
)
255+
) where {F}
251256
return value_and_jacobian(f!, y, backend, x, extras)[2]
252257
end
253258

254259
function jacobian!(
255-
f!,
260+
f!::F,
256261
y,
257262
jac,
258263
backend::AbstractADType,
259264
x,
260265
extras::JacobianExtras=prepare_jacobian(f!, y, backend, x),
261-
)
266+
) where {F}
262267
return value_and_jacobian!(f!, y, jac, backend, x, extras)[2]
263268
end

0 commit comments

Comments
 (0)