Skip to content

Commit c3f5360

Browse files
authored
Stop relying on Enzyme internals (#511)
* Stop relying on Enzyme internals * Remove useless converter
1 parent b93c17e commit c3f5360

7 files changed

Lines changed: 43 additions & 153 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ ADTypes = "1.9.0"
4848
ChainRulesCore = "1.23.0"
4949
Compat = "3.46,4.2"
5050
Diffractor = "=0.2.6"
51-
Enzyme = "0.13.2"
51+
Enzyme = "0.13.6"
5252
FastDifferentiation = "0.3.17"
5353
FiniteDiff = "2.23.1"
5454
FiniteDifferences = "0.12.31"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,15 @@ using Enzyme:
3131
ForwardWithPrimal,
3232
MixedDuplicated,
3333
Mode,
34+
NoPrimal,
3435
Reverse,
3536
ReverseMode,
3637
ReverseModeSplit,
38+
ReverseSplitNoPrimal,
39+
ReverseSplitWidth,
3740
ReverseSplitWithPrimal,
3841
ReverseWithPrimal,
42+
WithPrimal,
3943
autodiff,
4044
autodiff_thunk,
4145
create_shadows,

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function DI.value_and_pushforward(
2222
dx_sametype = convert(typeof(x), only(tx))
2323
x_and_dx = Duplicated(x, dx_sametype)
2424
dy, y = autodiff(
25-
forward_mode_withprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...
25+
forward_withprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...
2626
)
2727
return y, (dy,)
2828
end
@@ -39,7 +39,7 @@ function DI.value_and_pushforward(
3939
tx_sametype = map(Fix1(convert, typeof(x)), tx)
4040
x_and_tx = BatchDuplicated(x, tx_sametype)
4141
ty, y = autodiff(
42-
forward_mode_withprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...
42+
forward_withprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...
4343
)
4444
return y, values(ty)
4545
end
@@ -56,9 +56,7 @@ function DI.pushforward(
5656
dx_sametype = convert(typeof(x), only(tx))
5757
x_and_dx = Duplicated(x, dx_sametype)
5858
dy = only(
59-
autodiff(
60-
forward_mode_noprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...
61-
),
59+
autodiff(forward_noprimal(backend), f_and_df, x_and_dx, map(translate, contexts)...)
6260
)
6361
return (dy,)
6462
end
@@ -75,9 +73,7 @@ function DI.pushforward(
7573
tx_sametype = map(Fix1(convert, typeof(x)), tx)
7674
x_and_tx = BatchDuplicated(x, tx_sametype)
7775
ty = only(
78-
autodiff(
79-
forward_mode_noprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...
80-
),
76+
autodiff(forward_noprimal(backend), f_and_df, x_and_tx, map(translate, contexts)...)
8177
)
8278
return values(ty)
8379
end
@@ -134,7 +130,7 @@ function DI.gradient(
134130
) where {F,B}
135131
f_and_df = get_f_and_df(f, backend)
136132
derivs = gradient(
137-
forward_mode_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
133+
forward_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
138134
)
139135
return only(derivs)
140136
end
@@ -147,7 +143,7 @@ function DI.value_and_gradient(
147143
) where {F,B}
148144
f_and_df = get_f_and_df(f, backend)
149145
(; derivs, val) = gradient(
150-
forward_mode_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
146+
forward_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
151147
)
152148
return val, only(derivs)
153149
end
@@ -197,7 +193,7 @@ function DI.jacobian(
197193
) where {F,B}
198194
f_and_df = get_f_and_df(f, backend)
199195
derivs = jacobian(
200-
forward_mode_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
196+
forward_noprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
201197
)
202198
jac_tensor = only(derivs)
203199
return maybe_reshape(jac_tensor, prep.output_length, length(x))
@@ -211,7 +207,7 @@ function DI.value_and_jacobian(
211207
) where {F,B}
212208
f_and_df = get_f_and_df(f, backend)
213209
(; derivs, val) = jacobian(
214-
forward_mode_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
210+
forward_withprimal(backend), f_and_df, x; chunk=Val(B), shadows=prep.shadows
215211
)
216212
jac_tensor = only(derivs)
217213
return val, maybe_reshape(jac_tensor, prep.output_length, length(x))

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function DI.value_and_pushforward(
2626
x_and_dx = Duplicated(x, dx_sametype)
2727
y_and_dy = Duplicated(y, dy_sametype)
2828
autodiff(
29-
forward_mode_noprimal(backend),
29+
forward_noprimal(backend),
3030
f!_and_df!,
3131
Const,
3232
y_and_dy,
@@ -51,7 +51,7 @@ function DI.value_and_pushforward(
5151
x_and_tx = BatchDuplicated(x, tx_sametype)
5252
y_and_ty = BatchDuplicated(y, ty_sametype)
5353
autodiff(
54-
forward_mode_noprimal(backend),
54+
forward_noprimal(backend),
5555
f!_and_df!,
5656
Const,
5757
y_and_ty,

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function batch_seeded_autodiff_thunk(
2828
::Type{RA},
2929
args::Vararg{Annotation,N},
3030
) where {ReturnPrimal,B,FA<:Annotation,RA<:Annotation,N}
31-
rmode_rightwidth = set_width(rmode, Val(B))
31+
rmode_rightwidth = ReverseSplitWidth(rmode, Val(B))
3232
forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...)
3333
tape, result, shadow_results = forward(f, args...)
3434
if RA <: Active
@@ -70,7 +70,7 @@ function DI.value_and_pullback(
7070
contexts::Vararg{Context,C},
7171
) where {F,C}
7272
f_and_df = force_annotation(get_f_and_df(f, backend))
73-
mode = reverse_mode_split_withprimal(backend)
73+
mode = reverse_split_withprimal(backend)
7474
RA = eltype(ty) <: Number ? Active : Duplicated
7575
dinputs, result = seeded_autodiff_thunk(
7676
mode, only(ty), f_and_df, RA, Active(x), map(translate, contexts)...
@@ -87,7 +87,7 @@ function DI.value_and_pullback(
8787
contexts::Vararg{Context,C},
8888
) where {F,B,C}
8989
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
90-
mode = reverse_mode_split_withprimal(backend)
90+
mode = reverse_split_withprimal(backend)
9191
RA = eltype(ty) <: Number ? Active : BatchDuplicated
9292
dinputs, result = batch_seeded_autodiff_thunk(
9393
mode, ty, f_and_df, RA, Active(x), map(translate, contexts)...
@@ -104,7 +104,7 @@ function DI.value_and_pullback(
104104
contexts::Vararg{Context,C},
105105
) where {F,C}
106106
f_and_df = force_annotation(get_f_and_df(f, backend))
107-
mode = reverse_mode_split_withprimal(backend)
107+
mode = reverse_split_withprimal(backend)
108108
RA = eltype(ty) <: Number ? Active : Duplicated
109109
dx = make_zero(x)
110110
_, result = seeded_autodiff_thunk(
@@ -122,7 +122,7 @@ function DI.value_and_pullback(
122122
contexts::Vararg{Context,C},
123123
) where {F,B,C}
124124
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
125-
mode = reverse_mode_split_withprimal(backend)
125+
mode = reverse_split_withprimal(backend)
126126
RA = eltype(ty) <: Number ? Active : BatchDuplicated
127127
tx = ntuple(_ -> make_zero(x), Val(B))
128128
_, result = batch_seeded_autodiff_thunk(
@@ -154,7 +154,7 @@ function DI.value_and_pullback!(
154154
contexts::Vararg{Context,C},
155155
) where {F,C}
156156
f_and_df = force_annotation(get_f_and_df(f, backend))
157-
mode = reverse_mode_split_withprimal(backend)
157+
mode = reverse_split_withprimal(backend)
158158
RA = eltype(ty) <: Number ? Active : Duplicated
159159
dx_righttype = convert(typeof(x), only(tx))
160160
make_zero!(dx_righttype)
@@ -180,7 +180,7 @@ function DI.value_and_pullback!(
180180
contexts::Vararg{Context,C},
181181
) where {F,B,C}
182182
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
183-
mode = reverse_mode_split_withprimal(backend)
183+
mode = reverse_split_withprimal(backend)
184184
RA = eltype(ty) <: Number ? Active : BatchDuplicated
185185
tx_righttype = map(Fix1(convert, typeof(x)), tx)
186186
make_zero!(tx_righttype)
@@ -227,9 +227,7 @@ function DI.gradient(
227227
contexts::Vararg{Context,C},
228228
) where {F,C}
229229
f_and_df = get_f_and_df(f, backend)
230-
derivs = gradient(
231-
reverse_mode_noprimal(backend), f_and_df, x, map(translate, contexts)...
232-
)
230+
derivs = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
233231
return first(derivs)
234232
end
235233

@@ -245,7 +243,7 @@ function DI.gradient!(
245243
dx_righttype = convert(typeof(x), grad)
246244
make_zero!(dx_righttype)
247245
autodiff(
248-
reverse_mode_noprimal(backend),
246+
reverse_noprimal(backend),
249247
f_and_df,
250248
Active,
251249
Duplicated(x, dx_righttype),
@@ -264,7 +262,7 @@ function DI.value_and_gradient(
264262
) where {F,C}
265263
f_and_df = get_f_and_df(f, backend)
266264
(; derivs, val) = gradient(
267-
reverse_mode_withprimal(backend), f_and_df, x, map(translate, contexts)...
265+
reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
268266
)
269267
return val, first(derivs)
270268
end
@@ -281,7 +279,7 @@ function DI.value_and_gradient!(
281279
dx_righttype = convert(typeof(x), grad)
282280
make_zero!(dx_righttype)
283281
_, y = autodiff(
284-
reverse_mode_withprimal(backend),
282+
reverse_withprimal(backend),
285283
f_and_df,
286284
Active,
287285
Duplicated(x, dx_righttype),
@@ -308,7 +306,7 @@ function DI.jacobian(
308306
backend::AutoEnzyme{<:ReverseMode,Nothing},
309307
x,
310308
) where {F,Sy,B}
311-
derivs = jacobian(reverse_mode_noprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B))
309+
derivs = jacobian(reverse_noprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B))
312310
jac_tensor = only(derivs)
313311
return maybe_reshape(jac_tensor, prod(Sy), length(x))
314312
end
@@ -320,7 +318,7 @@ function DI.value_and_jacobian(
320318
x,
321319
) where {F,Sy,B}
322320
(; derivs, val) = jacobian(
323-
reverse_mode_withprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B)
321+
reverse_withprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B)
324322
)
325323
jac_tensor = only(derivs)
326324
return val, maybe_reshape(jac_tensor, prod(Sy), length(x))

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function DI.value_and_pullback(
2525
y_and_dy = Duplicated(y, dy_sametype)
2626
dinputs = only(
2727
autodiff(
28-
reverse_mode_noprimal(backend),
28+
reverse_noprimal(backend),
2929
f!_and_df!,
3030
Const,
3131
y_and_dy,
@@ -51,7 +51,7 @@ function DI.value_and_pullback(
5151
y_and_ty = BatchDuplicated(y, ty_sametype)
5252
dinputs = only(
5353
autodiff(
54-
reverse_mode_noprimal(backend),
54+
reverse_noprimal(backend),
5555
f!_and_df!,
5656
Const,
5757
y_and_ty,
@@ -78,7 +78,7 @@ function DI.value_and_pullback(
7878
x_and_dx = Duplicated(x, dx_sametype)
7979
y_and_dy = Duplicated(y, dy_sametype)
8080
autodiff(
81-
reverse_mode_noprimal(backend),
81+
reverse_noprimal(backend),
8282
f!_and_df!,
8383
Const,
8484
y_and_dy,
@@ -103,7 +103,7 @@ function DI.value_and_pullback(
103103
x_and_tx = BatchDuplicated(x, tx_sametype)
104104
y_and_ty = BatchDuplicated(y, ty_sametype)
105105
autodiff(
106-
reverse_mode_noprimal(backend),
106+
reverse_noprimal(backend),
107107
f!_and_df!,
108108
Const,
109109
y_and_ty,

0 commit comments

Comments
 (0)