Skip to content

Commit f71edc4

Browse files
authored
Revert mistakes in "Better explanation of preparation" (#538)
* Revert "Better explanation of preparation (#536)" This reverts commit 4d39e77. * Fix
1 parent 4d39e77 commit f71edc4

2 files changed

Lines changed: 6 additions & 10 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ end
117117
function DI.prepare_gradient(
118118
f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x
119119
) where {F}
120-
valB = pick_batchsize(backend, length(x))
121-
shadows = create_shadows(valB, x)
120+
B = pick_batchsize(backend, length(x))
121+
shadows = create_shadows(Val(B), x)
122122
return EnzymeForwardGradientPrep{B,typeof(shadows)}(shadows)
123123
end
124124

@@ -180,8 +180,8 @@ function DI.prepare_jacobian(
180180
f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x
181181
) where {F}
182182
y = f(x)
183-
valB = pick_batchsize(backend, length(x))
184-
shadows = create_shadows(valB, x)
183+
B = pick_batchsize(backend, length(x))
184+
shadows = create_shadows(Val(B), x)
185185
return EnzymeForwardOneArgJacobianPrep{B,typeof(shadows)}(shadows, length(y))
186186
end
187187

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -349,15 +349,11 @@ end
349349

350350
struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end
351351

352-
function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B}
353-
return EnzymeReverseOneArgJacobianPrep{Sy,B}()
354-
end
355-
356352
function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F}
357353
y = f(x)
358354
Sy = size(y)
359-
valB = pick_batchsize(backend, prod(Sy))
360-
return EnzymeReverseOneArgJacobianPrep(Val(Sy), valB)
355+
B = pick_batchsize(backend, prod(Sy))
356+
return EnzymeReverseOneArgJacobianPrep{Sy,B}()
361357
end
362358

363359
function DI.jacobian(

0 commit comments

Comments
 (0)