@@ -28,7 +28,7 @@ function batch_seeded_autodiff_thunk(
2828 :: Type{RA} ,
2929 args:: Vararg{Annotation,N} ,
3030) where {ReturnPrimal,B,FA<: Annotation ,RA<: Annotation ,N}
31- rmode_rightwidth = set_width (rmode, Val (B))
31+ rmode_rightwidth = ReverseSplitWidth (rmode, Val (B))
3232 forward, reverse = autodiff_thunk (rmode_rightwidth, FA, RA, typeof .(args)... )
3333 tape, result, shadow_results = forward (f, args... )
3434 if RA <: Active
@@ -70,7 +70,7 @@ function DI.value_and_pullback(
7070 contexts:: Vararg{Context,C} ,
7171) where {F,C}
7272 f_and_df = force_annotation (get_f_and_df (f, backend))
73- mode = reverse_mode_split_withprimal (backend)
73+ mode = reverse_split_withprimal (backend)
7474 RA = eltype (ty) <: Number ? Active : Duplicated
7575 dinputs, result = seeded_autodiff_thunk (
7676 mode, only (ty), f_and_df, RA, Active (x), map (translate, contexts)...
@@ -87,7 +87,7 @@ function DI.value_and_pullback(
8787 contexts:: Vararg{Context,C} ,
8888) where {F,B,C}
8989 f_and_df = force_annotation (get_f_and_df (f, backend, Val (B)))
90- mode = reverse_mode_split_withprimal (backend)
90+ mode = reverse_split_withprimal (backend)
9191 RA = eltype (ty) <: Number ? Active : BatchDuplicated
9292 dinputs, result = batch_seeded_autodiff_thunk (
9393 mode, ty, f_and_df, RA, Active (x), map (translate, contexts)...
@@ -104,7 +104,7 @@ function DI.value_and_pullback(
104104 contexts:: Vararg{Context,C} ,
105105) where {F,C}
106106 f_and_df = force_annotation (get_f_and_df (f, backend))
107- mode = reverse_mode_split_withprimal (backend)
107+ mode = reverse_split_withprimal (backend)
108108 RA = eltype (ty) <: Number ? Active : Duplicated
109109 dx = make_zero (x)
110110 _, result = seeded_autodiff_thunk (
@@ -122,7 +122,7 @@ function DI.value_and_pullback(
122122 contexts:: Vararg{Context,C} ,
123123) where {F,B,C}
124124 f_and_df = force_annotation (get_f_and_df (f, backend, Val (B)))
125- mode = reverse_mode_split_withprimal (backend)
125+ mode = reverse_split_withprimal (backend)
126126 RA = eltype (ty) <: Number ? Active : BatchDuplicated
127127 tx = ntuple (_ -> make_zero (x), Val (B))
128128 _, result = batch_seeded_autodiff_thunk (
@@ -154,7 +154,7 @@ function DI.value_and_pullback!(
154154 contexts:: Vararg{Context,C} ,
155155) where {F,C}
156156 f_and_df = force_annotation (get_f_and_df (f, backend))
157- mode = reverse_mode_split_withprimal (backend)
157+ mode = reverse_split_withprimal (backend)
158158 RA = eltype (ty) <: Number ? Active : Duplicated
159159 dx_righttype = convert (typeof (x), only (tx))
160160 make_zero! (dx_righttype)
@@ -180,7 +180,7 @@ function DI.value_and_pullback!(
180180 contexts:: Vararg{Context,C} ,
181181) where {F,B,C}
182182 f_and_df = force_annotation (get_f_and_df (f, backend, Val (B)))
183- mode = reverse_mode_split_withprimal (backend)
183+ mode = reverse_split_withprimal (backend)
184184 RA = eltype (ty) <: Number ? Active : BatchDuplicated
185185 tx_righttype = map (Fix1 (convert, typeof (x)), tx)
186186 make_zero! (tx_righttype)
@@ -227,9 +227,7 @@ function DI.gradient(
227227 contexts:: Vararg{Context,C} ,
228228) where {F,C}
229229 f_and_df = get_f_and_df (f, backend)
230- derivs = gradient (
231- reverse_mode_noprimal (backend), f_and_df, x, map (translate, contexts)...
232- )
230+ derivs = gradient (reverse_noprimal (backend), f_and_df, x, map (translate, contexts)... )
233231 return first (derivs)
234232end
235233
@@ -245,7 +243,7 @@ function DI.gradient!(
245243 dx_righttype = convert (typeof (x), grad)
246244 make_zero! (dx_righttype)
247245 autodiff (
248- reverse_mode_noprimal (backend),
246+ reverse_noprimal (backend),
249247 f_and_df,
250248 Active,
251249 Duplicated (x, dx_righttype),
@@ -264,7 +262,7 @@ function DI.value_and_gradient(
264262) where {F,C}
265263 f_and_df = get_f_and_df (f, backend)
266264 (; derivs, val) = gradient (
267- reverse_mode_withprimal (backend), f_and_df, x, map (translate, contexts)...
265+ reverse_withprimal (backend), f_and_df, x, map (translate, contexts)...
268266 )
269267 return val, first (derivs)
270268end
@@ -281,7 +279,7 @@ function DI.value_and_gradient!(
281279 dx_righttype = convert (typeof (x), grad)
282280 make_zero! (dx_righttype)
283281 _, y = autodiff (
284- reverse_mode_withprimal (backend),
282+ reverse_withprimal (backend),
285283 f_and_df,
286284 Active,
287285 Duplicated (x, dx_righttype),
@@ -308,7 +306,7 @@ function DI.jacobian(
308306 backend:: AutoEnzyme{<:ReverseMode,Nothing} ,
309307 x,
310308) where {F,Sy,B}
311- derivs = jacobian (reverse_mode_noprimal (backend), f, x; n_outs= Val (Sy), chunk= Val (B))
309+ derivs = jacobian (reverse_noprimal (backend), f, x; n_outs= Val (Sy), chunk= Val (B))
312310 jac_tensor = only (derivs)
313311 return maybe_reshape (jac_tensor, prod (Sy), length (x))
314312end
@@ -320,7 +318,7 @@ function DI.value_and_jacobian(
320318 x,
321319) where {F,Sy,B}
322320 (; derivs, val) = jacobian (
323- reverse_mode_withprimal (backend), f, x; n_outs= Val (Sy), chunk= Val (B)
321+ reverse_withprimal (backend), f, x; n_outs= Val (Sy), chunk= Val (B)
324322 )
325323 jac_tensor = only (derivs)
326324 return val, maybe_reshape (jac_tensor, prod (Sy), length (x))
0 commit comments