|
23 | 23 |
|
24 | 24 | function batch_seeded_autodiff_thunk( |
25 | 25 | rmode::ReverseModeSplit{ReturnPrimal}, |
26 | | - dresults::NTuple, |
| 26 | + dresults::NTuple{B}, |
27 | 27 | f::FA, |
28 | 28 | ::Type{RA}, |
29 | 29 | args::Vararg{Annotation,N}, |
30 | | -) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N} |
31 | | - forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...) |
| 30 | +) where {ReturnPrimal,B,FA<:Annotation,RA<:Annotation,N} |
| 31 | + rmode_rightwidth = set_width(rmode, Val(B)) |
| 32 | + forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...) |
32 | 33 | tape, result, shadow_results = forward(f, args...) |
33 | 34 | if RA <: Active |
34 | 35 | dresults_righttype = map(Fix1(convert, typeof(result)), dresults) |
|
79 | 80 |
|
80 | 81 | function DI.value_and_pullback( |
81 | 82 | f::F, |
82 | | - prep::NoPullbackPrep, |
| 83 | + ::NoPullbackPrep, |
83 | 84 | backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, |
84 | 85 | x::Number, |
85 | 86 | ty::Tangents{B}, |
86 | 87 | contexts::Vararg{Context,C}, |
87 | 88 | ) where {F,B,C} |
88 | | - # TODO: improve |
89 | | - ys_and_dxs = map(ty.d) do dy |
90 | | - y, tx = DI.value_and_pullback(f, prep, backend, x, Tangents(dy), contexts...) |
91 | | - y, only(tx) |
92 | | - end |
93 | | - y = first(ys_and_dxs[1]) |
94 | | - dxs = last.(ys_and_dxs) |
95 | | - return y, Tangents(dxs...) |
| 89 | + f_and_df = force_annotation(get_f_and_df(f, backend, Val(B))) |
| 90 | + mode = reverse_mode_split_withprimal(backend) |
| 91 | + RA = eltype(ty) <: Number ? Active : BatchDuplicated |
| 92 | + dinputs, result = batch_seeded_autodiff_thunk( |
| 93 | + mode, NTuple(ty), f_and_df, RA, Active(x), map(translate, contexts)... |
| 94 | + ) |
| 95 | + return result, Tangents(first(dinputs)...) |
96 | 96 | end |
97 | 97 |
|
98 | 98 | function DI.value_and_pullback( |
@@ -293,37 +293,37 @@ end |
293 | 293 |
|
294 | 294 | ## Jacobian |
295 | 295 |
|
296 | | -struct EnzymeReverseOneArgJacobianPrep{M,B} <: JacobianPrep end |
| 296 | +struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end |
297 | 297 |
|
298 | 298 | function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F} |
299 | 299 | y = f(x) |
300 | | - M = length(y) |
301 | | - B = pick_batchsize(backend, M) |
302 | | - return EnzymeReverseOneArgJacobianPrep{M,B}() |
| 300 | + Sy = size(y) |
| 301 | + B = pick_batchsize(backend, prod(Sy)) |
| 302 | + return EnzymeReverseOneArgJacobianPrep{Sy,B}() |
303 | 303 | end |
304 | 304 |
|
305 | 305 | function DI.jacobian( |
306 | 306 | f::F, |
307 | | - ::EnzymeReverseOneArgJacobianPrep{M,B}, |
| 307 | + ::EnzymeReverseOneArgJacobianPrep{Sy,B}, |
308 | 308 | backend::AutoEnzyme{<:ReverseMode,Nothing}, |
309 | 309 | x, |
310 | | -) where {F,M,B} |
311 | | - derivs = jacobian(reverse_mode_noprimal(backend), f, x; n_outs=Val((M,)), chunk=Val(B)) |
| 310 | +) where {F,Sy,B} |
| 311 | + derivs = jacobian(reverse_mode_noprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B)) |
312 | 312 | jac_tensor = only(derivs) |
313 | | - return maybe_reshape(jac_tensor, M, length(x)) |
| 313 | + return maybe_reshape(jac_tensor, prod(Sy), length(x)) |
314 | 314 | end |
315 | 315 |
|
316 | 316 | function DI.value_and_jacobian( |
317 | 317 | f::F, |
318 | | - prep::EnzymeReverseOneArgJacobianPrep{M,B}, |
| 318 | + ::EnzymeReverseOneArgJacobianPrep{Sy,B}, |
319 | 319 | backend::AutoEnzyme{<:ReverseMode,Nothing}, |
320 | 320 | x, |
321 | | -) where {F,M,B} |
| 321 | +) where {F,Sy,B} |
322 | 322 | (; derivs, val) = jacobian( |
323 | | - reverse_mode_withprimal(backend), f, x; n_outs=Val((M,)), chunk=Val(B) |
| 323 | + reverse_mode_withprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B) |
324 | 324 | ) |
325 | | - jac_tensor = derivs |
326 | | - return val, maybe_reshape(jac_tensor, M, length(x)) |
| 325 | + jac_tensor = only(derivs) |
| 326 | + return val, maybe_reshape(jac_tensor, prod(Sy), length(x)) |
327 | 327 | end |
328 | 328 |
|
329 | 329 | function DI.jacobian!( |
|
0 commit comments