Skip to content

Commit d9a5cab

Browse files
authored
Fix Enzyme's batched pullback and Jacobian (#499)
* Fix Enzyme's batched pullback and Jacobian * Version
1 parent ea192f6 commit d9a5cab

3 files changed

Lines changed: 52 additions & 26 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ ADTypes = "1.7.0"
4848
ChainRulesCore = "1.23.0"
4949
Compat = "3.46,4.2"
5050
Diffractor = "=0.2.6"
51-
Enzyme = "0.13.1"
51+
Enzyme = "0.13.2"
5252
FastDifferentiation = "0.3.17"
5353
FiniteDiff = "2.23.1"
5454
FiniteDifferences = "0.12.31"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@ end
2323

2424
function batch_seeded_autodiff_thunk(
2525
rmode::ReverseModeSplit{ReturnPrimal},
26-
dresults::NTuple,
26+
dresults::NTuple{B},
2727
f::FA,
2828
::Type{RA},
2929
args::Vararg{Annotation,N},
30-
) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N}
31-
forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
30+
) where {ReturnPrimal,B,FA<:Annotation,RA<:Annotation,N}
31+
rmode_rightwidth = set_width(rmode, Val(B))
32+
forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...)
3233
tape, result, shadow_results = forward(f, args...)
3334
if RA <: Active
3435
dresults_righttype = map(Fix1(convert, typeof(result)), dresults)
@@ -79,20 +80,19 @@ end
7980

8081
function DI.value_and_pullback(
8182
f::F,
82-
prep::NoPullbackPrep,
83+
::NoPullbackPrep,
8384
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
8485
x::Number,
8586
ty::Tangents{B},
8687
contexts::Vararg{Context,C},
8788
) where {F,B,C}
88-
# TODO: improve
89-
ys_and_dxs = map(ty.d) do dy
90-
y, tx = DI.value_and_pullback(f, prep, backend, x, Tangents(dy), contexts...)
91-
y, only(tx)
92-
end
93-
y = first(ys_and_dxs[1])
94-
dxs = last.(ys_and_dxs)
95-
return y, Tangents(dxs...)
89+
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
90+
mode = reverse_mode_split_withprimal(backend)
91+
RA = eltype(ty) <: Number ? Active : BatchDuplicated
92+
dinputs, result = batch_seeded_autodiff_thunk(
93+
mode, NTuple(ty), f_and_df, RA, Active(x), map(translate, contexts)...
94+
)
95+
return result, Tangents(first(dinputs)...)
9696
end
9797

9898
function DI.value_and_pullback(
@@ -293,37 +293,37 @@ end
293293

294294
## Jacobian
295295

296-
struct EnzymeReverseOneArgJacobianPrep{M,B} <: JacobianPrep end
296+
struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end
297297

298298
function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F}
299299
y = f(x)
300-
M = length(y)
301-
B = pick_batchsize(backend, M)
302-
return EnzymeReverseOneArgJacobianPrep{M,B}()
300+
Sy = size(y)
301+
B = pick_batchsize(backend, prod(Sy))
302+
return EnzymeReverseOneArgJacobianPrep{Sy,B}()
303303
end
304304

305305
function DI.jacobian(
306306
f::F,
307-
::EnzymeReverseOneArgJacobianPrep{M,B},
307+
::EnzymeReverseOneArgJacobianPrep{Sy,B},
308308
backend::AutoEnzyme{<:ReverseMode,Nothing},
309309
x,
310-
) where {F,M,B}
311-
derivs = jacobian(reverse_mode_noprimal(backend), f, x; n_outs=Val((M,)), chunk=Val(B))
310+
) where {F,Sy,B}
311+
derivs = jacobian(reverse_mode_noprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B))
312312
jac_tensor = only(derivs)
313-
return maybe_reshape(jac_tensor, M, length(x))
313+
return maybe_reshape(jac_tensor, prod(Sy), length(x))
314314
end
315315

316316
function DI.value_and_jacobian(
317317
f::F,
318-
prep::EnzymeReverseOneArgJacobianPrep{M,B},
318+
::EnzymeReverseOneArgJacobianPrep{Sy,B},
319319
backend::AutoEnzyme{<:ReverseMode,Nothing},
320320
x,
321-
) where {F,M,B}
321+
) where {F,Sy,B}
322322
(; derivs, val) = jacobian(
323-
reverse_mode_withprimal(backend), f, x; n_outs=Val((M,)), chunk=Val(B)
323+
reverse_mode_withprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B)
324324
)
325-
jac_tensor = derivs
326-
return val, maybe_reshape(jac_tensor, M, length(x))
325+
jac_tensor = only(derivs)
326+
return val, maybe_reshape(jac_tensor, prod(Sy), length(x))
327327
end
328328

329329
function DI.jacobian!(

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,32 @@ function mode_split(
115115
}()
116116
end
117117

118+
function set_width(
119+
::ReverseModeSplit{
120+
ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI,ErrIfFuncWritten
121+
},
122+
::Val{NewWidth},
123+
) where {
124+
ReturnPrimal,
125+
ReturnShadow,
126+
RuntimeActivity,
127+
Width,
128+
ModifiedBetween,
129+
ABI,
130+
ErrIfFuncWritten,
131+
NewWidth,
132+
}
133+
return ReverseModeSplit{
134+
ReturnPrimal,
135+
ReturnShadow,
136+
RuntimeActivity,
137+
NewWidth,
138+
ModifiedBetween,
139+
ABI,
140+
ErrIfFuncWritten,
141+
}()
142+
end
143+
118144
mode_noprimal(mode::Mode) = mode_noprimal(typeof(mode))
119145
mode_withprimal(mode::Mode) = mode_withprimal(typeof(mode))
120146
mode_split(mode::Mode) = mode_split(typeof(mode))

0 commit comments

Comments
 (0)