Skip to content

Commit 73f7314

Browse files
authored
Optimize Enzyme gradient (#515)
* Optimize Enzyme gradient * Active
1 parent 92ccd1c commit 73f7314

1 file changed

Lines changed: 74 additions & 18 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 74 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -210,82 +210,138 @@ end
210210

211211
## Gradient
212212

213+
### Without preparation
214+
215+
function DI.gradient(
216+
f::F,
217+
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
218+
x,
219+
contexts::Vararg{Context,C},
220+
) where {F,C}
221+
f_and_df = get_f_and_df(f, backend)
222+
grad = make_zero(x)
223+
autodiff(
224+
reverse_noprimal(backend),
225+
f_and_df,
226+
Active,
227+
Duplicated(x, grad),
228+
map(translate, contexts)...,
229+
)
230+
return grad
231+
end
232+
233+
function DI.value_and_gradient(
234+
f::F,
235+
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
236+
x,
237+
contexts::Vararg{Context,C},
238+
) where {F,C}
239+
f_and_df = get_f_and_df(f, backend)
240+
grad = make_zero(x)
241+
_, y = autodiff(
242+
reverse_withprimal(backend),
243+
f_and_df,
244+
Active,
245+
Duplicated(x, grad),
246+
map(translate, contexts)...,
247+
)
248+
return y, grad
249+
end
250+
251+
### With preparation
252+
253+
struct EnzymeGradientPrep{G} <: GradientPrep
254+
grad_righttype::G
255+
end
256+
213257
function DI.prepare_gradient(
214258
f::F,
215259
::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
216260
x,
217261
contexts::Vararg{Context,C},
218262
) where {F,C}
219-
return NoGradientPrep()
263+
grad_righttype = make_zero(x)
264+
return EnzymeGradientPrep(grad_righttype)
220265
end
221266

222267
function DI.gradient(
223268
f::F,
224-
::NoGradientPrep,
269+
::EnzymeGradientPrep,
225270
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
226271
x,
227272
contexts::Vararg{Context,C},
228273
) where {F,C}
229274
f_and_df = get_f_and_df(f, backend)
230-
derivs = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
231-
return first(derivs)
275+
grad = make_zero(x)
276+
autodiff(
277+
reverse_noprimal(backend),
278+
f_and_df,
279+
Duplicated(x, grad),
280+
map(translate, contexts)...,
281+
)
282+
return grad
232283
end
233284

234285
function DI.gradient!(
235286
f::F,
236287
grad,
237-
::NoGradientPrep,
288+
prep::EnzymeGradientPrep,
238289
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
239290
x,
240291
contexts::Vararg{Context,C},
241292
) where {F,C}
242293
f_and_df = get_f_and_df(f, backend)
243-
dx_righttype = convert(typeof(x), grad)
244-
make_zero!(dx_righttype)
294+
grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype
295+
make_zero!(grad_righttype)
245296
autodiff(
246297
reverse_noprimal(backend),
247298
f_and_df,
248299
Active,
249-
Duplicated(x, dx_righttype),
300+
Duplicated(x, grad_righttype),
250301
map(translate, contexts)...,
251302
)
252-
dx_righttype === grad || copyto!(grad, dx_righttype)
303+
grad isa typeof(x) || copyto!(grad, grad_righttype)
253304
return grad
254305
end
255306

256307
function DI.value_and_gradient(
257308
f::F,
258-
::NoGradientPrep,
309+
::EnzymeGradientPrep,
259310
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
260311
x,
261312
contexts::Vararg{Context,C},
262313
) where {F,C}
263314
f_and_df = get_f_and_df(f, backend)
264-
(; derivs, val) = gradient(
265-
reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
315+
grad = make_zero(x)
316+
_, y = autodiff(
317+
reverse_withprimal(backend),
318+
f_and_df,
319+
Active,
320+
Duplicated(x, grad),
321+
map(translate, contexts)...,
266322
)
267-
return val, first(derivs)
323+
return y, grad
268324
end
269325

270326
function DI.value_and_gradient!(
271327
f::F,
272328
grad,
273-
::NoGradientPrep,
329+
prep::EnzymeGradientPrep,
274330
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
275331
x,
276332
contexts::Vararg{Context,C},
277333
) where {F,C}
278334
f_and_df = get_f_and_df(f, backend)
279-
dx_righttype = convert(typeof(x), grad)
280-
make_zero!(dx_righttype)
335+
grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype
336+
make_zero!(grad_righttype)
281337
_, y = autodiff(
282338
reverse_withprimal(backend),
283339
f_and_df,
284340
Active,
285-
Duplicated(x, dx_righttype),
341+
Duplicated(x, grad_righttype),
286342
map(translate, contexts)...,
287343
)
288-
dx_righttype === grad || copyto!(grad, dx_righttype)
344+
grad isa typeof(x) || copyto!(grad, grad_righttype)
289345
return y, grad
290346
end
291347

0 commit comments

Comments
 (0)