Skip to content

Commit f663225

Browse files
authored
Improve Enzyme batch size and gradient (#557)
* Adaptive Enzyme batch size
* Use Enzyme.gradient whenever possible
1 parent 3470e14 commit f663225

3 files changed

Lines changed: 14 additions & 33 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.8"
4+
version = "0.6.9"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -219,14 +219,8 @@ function DI.gradient(
219219
contexts::Vararg{Context,C},
220220
) where {F,C}
221221
f_and_df = get_f_and_df(f, backend)
222-
grad = make_zero(x)
223-
autodiff(
224-
reverse_noprimal(backend),
225-
f_and_df,
226-
Active,
227-
Duplicated(x, grad),
228-
map(translate, contexts)...,
229-
)
222+
ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
223+
grad = first(ders)
230224
return grad
231225
end
232226

@@ -237,14 +231,10 @@ function DI.value_and_gradient(
237231
contexts::Vararg{Context,C},
238232
) where {F,C}
239233
f_and_df = get_f_and_df(f, backend)
240-
grad = make_zero(x)
241-
_, y = autodiff(
242-
reverse_withprimal(backend),
243-
f_and_df,
244-
Active,
245-
Duplicated(x, grad),
246-
map(translate, contexts)...,
234+
ders, y = gradient(
235+
reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
247236
)
237+
grad = first(ders)
248238
return y, grad
249239
end
250240

@@ -272,13 +262,8 @@ function DI.gradient(
272262
contexts::Vararg{Context,C},
273263
) where {F,C}
274264
f_and_df = get_f_and_df(f, backend)
275-
grad = make_zero(x)
276-
autodiff(
277-
reverse_noprimal(backend),
278-
f_and_df,
279-
Duplicated(x, grad),
280-
map(translate, contexts)...,
281-
)
265+
ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
266+
grad = first(ders)
282267
return grad
283268
end
284269

@@ -300,7 +285,7 @@ function DI.gradient!(
300285
Duplicated(x, grad_righttype),
301286
map(translate, contexts)...,
302287
)
303-
grad isa typeof(x) || copyto!(grad, grad_righttype)
288+
grad === grad_righttype || copyto!(grad, grad_righttype)
304289
return grad
305290
end
306291

@@ -312,14 +297,10 @@ function DI.value_and_gradient(
312297
contexts::Vararg{Context,C},
313298
) where {F,C}
314299
f_and_df = get_f_and_df(f, backend)
315-
grad = make_zero(x)
316-
_, y = autodiff(
317-
reverse_withprimal(backend),
318-
f_and_df,
319-
Active,
320-
Duplicated(x, grad),
321-
map(translate, contexts)...,
300+
ders, y = gradient(
301+
reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
322302
)
303+
grad = first(ders)
323304
return y, grad
324305
end
325306

@@ -341,7 +322,7 @@ function DI.value_and_gradient!(
341322
Duplicated(x, grad_righttype),
342323
map(translate, contexts)...,
343324
)
344-
grad isa typeof(x) || copyto!(grad, grad_righttype)
325+
grad === grad_righttype || copyto!(grad, grad_righttype)
345326
return y, grad
346327
end
347328

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
2-
DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = Val(16)
2+
DI.pick_batchsize(::AutoEnzyme, dimension::Integer) = Val(min(dimension, 16))
33

44
## Annotations
55

0 commit comments

Comments
 (0)