Skip to content

Commit 84378d7

Browse files
authored
Type-stable jacobian & gradient for unprepared ForwardDiff (#541)
* Make ForwardDiff's unprepared jacobian and gradient type-stable with fixed chunk size * Avoid ambiguity
1 parent c47e26c commit 84378d7

3 files changed

Lines changed: 130 additions & 66 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.5"
4+
version = "0.6.6"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 87 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -159,38 +159,60 @@ end
159159

160160
## Gradient
161161

162-
### Unprepared
162+
### Unprepared, only when chunk size not specified
163163

164164
function DI.value_and_gradient!(
165-
f::F, grad, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
166-
) where {F,C}
167-
fc = with_contexts(f, contexts...)
168-
result = DiffResult(zero(eltype(x)), (grad,))
169-
result = gradient!(result, fc, x)
170-
y = DR.value(result)
171-
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
172-
return y, grad
165+
f::F, grad, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
166+
) where {F,C,chunksize}
167+
if isnothing(chunksize)
168+
fc = with_contexts(f, contexts...)
169+
result = DiffResult(zero(eltype(x)), (grad,))
170+
result = gradient!(result, fc, x)
171+
y = DR.value(result)
172+
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
173+
return y, grad
174+
else
175+
prep = DI.prepare_gradient(f, backend, x, contexts...)
176+
return DI.value_and_gradient!(f, grad, prep, backend, x, contexts...)
177+
end
173178
end
174179

175180
function DI.value_and_gradient(
176-
f::F, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
177-
) where {F,C}
178-
fc = with_contexts(f, contexts...)
179-
result = GradientResult(x)
180-
result = gradient!(result, fc, x)
181-
return DR.value(result), DR.gradient(result)
181+
f::F, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
182+
) where {F,C,chunksize}
183+
if isnothing(chunksize)
184+
fc = with_contexts(f, contexts...)
185+
result = GradientResult(x)
186+
result = gradient!(result, fc, x)
187+
return DR.value(result), DR.gradient(result)
188+
else
189+
prep = DI.prepare_gradient(f, backend, x, contexts...)
190+
return DI.value_and_gradient(f, prep, backend, x, contexts...)
191+
end
182192
end
183193

184194
function DI.gradient!(
185-
f::F, grad, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
186-
) where {F,C}
187-
fc = with_contexts(f, contexts...)
188-
return gradient!(grad, fc, x)
195+
f::F, grad, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
196+
) where {F,C,chunksize}
197+
if isnothing(chunksize)
198+
fc = with_contexts(f, contexts...)
199+
return gradient!(grad, fc, x)
200+
else
201+
prep = DI.prepare_gradient(f, backend, x, contexts...)
202+
return DI.gradient!(f, grad, prep, backend, x, contexts...)
203+
end
189204
end
190205

191-
function DI.gradient(f::F, ::AutoForwardDiff, x, contexts::Vararg{Context,C}) where {F,C}
192-
fc = with_contexts(f, contexts...)
193-
return gradient(fc, x)
206+
function DI.gradient(
207+
f::F, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
208+
) where {F,C,chunksize}
209+
if isnothing(chunksize)
210+
fc = with_contexts(f, contexts...)
211+
return gradient(fc, x)
212+
else
213+
prep = DI.prepare_gradient(f, backend, x, contexts...)
214+
return DI.gradient(f, prep, backend, x, contexts...)
215+
end
194216
end
195217

196218
### Prepared
@@ -252,37 +274,59 @@ end
252274

253275
## Jacobian
254276

255-
### Unprepared
277+
### Unprepared, only when chunk size not specified
256278

257279
function DI.value_and_jacobian!(
258-
f::F, jac, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
259-
) where {F,C}
260-
fc = with_contexts(f, contexts...)
261-
y = fc(x)
262-
result = DiffResult(y, (jac,))
263-
result = jacobian!(result, fc, x)
264-
y = DR.value(result)
265-
jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result))
266-
return y, jac
280+
f::F, jac, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
281+
) where {F,C,chunksize}
282+
if isnothing(chunksize)
283+
fc = with_contexts(f, contexts...)
284+
y = fc(x)
285+
result = DiffResult(y, (jac,))
286+
result = jacobian!(result, fc, x)
287+
y = DR.value(result)
288+
jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result))
289+
return y, jac
290+
else
291+
prep = DI.prepare_jacobian(f, backend, x, contexts...)
292+
return DI.value_and_jacobian!(f, jac, prep, backend, x, contexts...)
293+
end
267294
end
268295

269296
function DI.value_and_jacobian(
270-
f::F, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
271-
) where {F,C}
272-
fc = with_contexts(f, contexts...)
273-
return fc(x), jacobian(fc, x)
297+
f::F, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
298+
) where {F,C,chunksize}
299+
if isnothing(chunksize)
300+
fc = with_contexts(f, contexts...)
301+
return fc(x), jacobian(fc, x)
302+
else
303+
prep = DI.prepare_jacobian(f, backend, x, contexts...)
304+
return DI.value_and_jacobian(f, prep, backend, x, contexts...)
305+
end
274306
end
275307

276308
function DI.jacobian!(
277-
f::F, jac, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
278-
) where {F,C}
279-
fc = with_contexts(f, contexts...)
280-
return jacobian!(jac, fc, x)
309+
f::F, jac, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
310+
) where {F,C,chunksize}
311+
if isnothing(chunksize)
312+
fc = with_contexts(f, contexts...)
313+
return jacobian!(jac, fc, x)
314+
else
315+
prep = DI.prepare_jacobian(f, backend, x, contexts...)
316+
return DI.jacobian!(f, jac, prep, backend, x, contexts...)
317+
end
281318
end
282319

283-
function DI.jacobian(f::F, ::AutoForwardDiff, x, contexts::Vararg{Context,C}) where {F,C}
284-
fc = with_contexts(f, contexts...)
285-
return jacobian(fc, x)
320+
function DI.jacobian(
321+
f::F, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
322+
) where {F,C,chunksize}
323+
if isnothing(chunksize)
324+
fc = with_contexts(f, contexts...)
325+
return jacobian(fc, x)
326+
else
327+
prep = DI.prepare_jacobian(f, backend, x, contexts...)
328+
return DI.jacobian(f, prep, backend, x, contexts...)
329+
end
286330
end
287331

288332
### Prepared

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -209,39 +209,59 @@ end
209209

210210
## Jacobian
211211

212-
### Unprepared
212+
### Unprepared, only when chunk size is not specified
213213

214214
function DI.value_and_jacobian(
215-
f!::F, y, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
216-
) where {F,C}
217-
fc! = with_contexts(f!, contexts...)
218-
jac = similar(y, length(y), length(x))
219-
result = MutableDiffResult(y, (jac,))
220-
result = jacobian!(result, fc!, y, x)
221-
return DiffResults.value(result), DiffResults.jacobian(result)
215+
f!::F, y, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
216+
) where {F,C,chunksize}
217+
if isnothing(chunksize)
218+
fc! = with_contexts(f!, contexts...)
219+
jac = similar(y, length(y), length(x))
220+
result = MutableDiffResult(y, (jac,))
221+
result = jacobian!(result, fc!, y, x)
222+
return DiffResults.value(result), DiffResults.jacobian(result)
223+
else
224+
prep = DI.prepare_jacobian(f!, y, backend, x, contexts...)
225+
return DI.value_and_jacobian(f!, y, prep, backend, x, contexts...)
226+
end
222227
end
223228

224229
function DI.value_and_jacobian!(
225-
f!::F, y, jac, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
226-
) where {F,C}
227-
fc! = with_contexts(f!, contexts...)
228-
result = MutableDiffResult(y, (jac,))
229-
result = jacobian!(result, fc!, y, x)
230-
return DiffResults.value(result), DiffResults.jacobian(result)
230+
f!::F, y, jac, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
231+
) where {F,C,chunksize}
232+
if isnothing(chunksize)
233+
fc! = with_contexts(f!, contexts...)
234+
result = MutableDiffResult(y, (jac,))
235+
result = jacobian!(result, fc!, y, x)
236+
return DiffResults.value(result), DiffResults.jacobian(result)
237+
else
238+
prep = DI.prepare_jacobian(f!, y, backend, x, contexts...)
239+
return DI.value_and_jacobian!(f!, y, jac, prep, backend, x, contexts...)
240+
end
231241
end
232242

233243
function DI.jacobian(
234-
f!::F, y, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
235-
) where {F,C}
236-
fc! = with_contexts(f!, contexts...)
237-
return jacobian(fc!, y, x)
244+
f!::F, y, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
245+
) where {F,C,chunksize}
246+
if isnothing(chunksize)
247+
fc! = with_contexts(f!, contexts...)
248+
return jacobian(fc!, y, x)
249+
else
250+
prep = DI.prepare_jacobian(f!, y, backend, x, contexts...)
251+
return DI.jacobian(f!, y, prep, backend, x, contexts...)
252+
end
238253
end
239254

240255
function DI.jacobian!(
241-
f!::F, y, jac, ::AutoForwardDiff, x, contexts::Vararg{Context,C}
242-
) where {F,C}
243-
fc! = with_contexts(f!, contexts...)
244-
return jacobian!(jac, fc!, y, x)
256+
f!::F, y, jac, backend::AutoForwardDiff{chunksize}, x, contexts::Vararg{Context,C}
257+
) where {F,C,chunksize}
258+
if isnothing(chunksize)
259+
fc! = with_contexts(f!, contexts...)
260+
return jacobian!(jac, fc!, y, x)
261+
else
262+
prep = DI.prepare_jacobian(f!, y, backend, x, contexts...)
263+
return DI.jacobian!(f!, y, jac, prep, backend, x, contexts...)
264+
end
245265
end
246266

247267
### Prepared

0 commit comments

Comments
 (0)