Skip to content

Commit d4b17c1

Browse files
authored
Support static arrays with reverse Enzyme (#585)
1 parent 94f9bc5 commit d4b17c1

4 files changed

Lines changed: 82 additions & 78 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.13"
4+
version = "0.6.14"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 58 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -61,40 +61,6 @@ end
6161

6262
### Out-of-place
6363

64-
function DI.value_and_pullback(
65-
f::F,
66-
::NoPullbackPrep,
67-
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
68-
x::Number,
69-
ty::NTuple{1},
70-
contexts::Vararg{Context,C},
71-
) where {F,C}
72-
f_and_df = force_annotation(get_f_and_df(f, backend))
73-
mode = reverse_split_withprimal(backend)
74-
RA = eltype(ty) <: Number ? Active : Duplicated
75-
dinputs, result = seeded_autodiff_thunk(
76-
mode, only(ty), f_and_df, RA, Active(x), map(translate, contexts)...
77-
)
78-
return result, (first(dinputs),)
79-
end
80-
81-
function DI.value_and_pullback(
82-
f::F,
83-
::NoPullbackPrep,
84-
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
85-
x::Number,
86-
ty::NTuple{B},
87-
contexts::Vararg{Context,C},
88-
) where {F,B,C}
89-
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
90-
mode = reverse_split_withprimal(backend)
91-
RA = eltype(ty) <: Number ? Active : BatchDuplicated
92-
dinputs, result = batch_seeded_autodiff_thunk(
93-
mode, ty, f_and_df, RA, Active(x), map(translate, contexts)...
94-
)
95-
return result, values(first(dinputs))
96-
end
97-
9864
function DI.value_and_pullback(
9965
f::F,
10066
::NoPullbackPrep,
@@ -105,12 +71,18 @@ function DI.value_and_pullback(
10571
) where {F,C}
10672
f_and_df = force_annotation(get_f_and_df(f, backend))
10773
mode = reverse_split_withprimal(backend)
108-
RA = eltype(ty) <: Number ? Active : Duplicated
74+
IA = guess_activity(typeof(x), mode)
75+
RA = guess_activity(eltype(ty), mode)
10976
dx = make_zero(x)
110-
_, result = seeded_autodiff_thunk(
111-
mode, only(ty), f_and_df, RA, Duplicated(x, dx), map(translate, contexts)...
77+
dinputs, result = seeded_autodiff_thunk(
78+
mode, only(ty), f_and_df, RA, annotate(IA, x, dx), map(translate, contexts)...
11279
)
113-
return result, (dx,)
80+
new_dx = first(dinputs)
81+
if isnothing(new_dx)
82+
return result, (dx,)
83+
else
84+
return result, (new_dx,)
85+
end
11486
end
11587

11688
function DI.value_and_pullback(
@@ -123,12 +95,18 @@ function DI.value_and_pullback(
12395
) where {F,B,C}
12496
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
12597
mode = reverse_split_withprimal(backend)
126-
RA = eltype(ty) <: Number ? Active : BatchDuplicated
98+
IA = batchify_activity(guess_activity(typeof(x), mode), Val(B))
99+
RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B))
127100
tx = ntuple(_ -> make_zero(x), Val(B))
128-
_, result = batch_seeded_autodiff_thunk(
129-
mode, ty, f_and_df, RA, BatchDuplicated(x, tx), map(translate, contexts)...
101+
dinputs, result = batch_seeded_autodiff_thunk(
102+
mode, ty, f_and_df, RA, annotate(IA, x, tx), map(translate, contexts)...
130103
)
131-
return result, tx
104+
new_tx = values(first(dinputs))
105+
if isnothing(new_tx)
106+
return result, tx
107+
else
108+
return result, new_tx
109+
end
132110
end
133111

134112
function DI.pullback(
@@ -155,7 +133,7 @@ function DI.value_and_pullback!(
155133
) where {F,C}
156134
f_and_df = force_annotation(get_f_and_df(f, backend))
157135
mode = reverse_split_withprimal(backend)
158-
RA = eltype(ty) <: Number ? Active : Duplicated
136+
RA = guess_activity(eltype(ty), mode)
159137
dx_righttype = convert(typeof(x), only(tx))
160138
make_zero!(dx_righttype)
161139
_, result = seeded_autodiff_thunk(
@@ -181,7 +159,7 @@ function DI.value_and_pullback!(
181159
) where {F,B,C}
182160
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
183161
mode = reverse_split_withprimal(backend)
184-
RA = eltype(ty) <: Number ? Active : BatchDuplicated
162+
RA = batchify_activity(guess_activity(eltype(ty), mode), Val(B))
185163
tx_righttype = map(Fix1(convert, typeof(x)), tx)
186164
make_zero!(tx_righttype)
187165
_, result = batch_seeded_autodiff_thunk(
@@ -213,29 +191,39 @@ end
213191
### Without preparation
214192

215193
function DI.gradient(
216-
f::F,
217-
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
218-
x,
219-
contexts::Vararg{Context,C},
194+
f::F, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{Context,C}
220195
) where {F,C}
221196
f_and_df = get_f_and_df(f, backend)
222-
ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
223-
grad = first(ders)
224-
return grad
197+
mode = reverse_noprimal(backend)
198+
IA = guess_activity(typeof(x), mode)
199+
grad = make_zero(x)
200+
dinputs = only(
201+
autodiff(mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)...)
202+
)
203+
new_grad = first(dinputs)
204+
if isnothing(new_grad)
205+
return grad
206+
else
207+
return new_grad
208+
end
225209
end
226210

227211
function DI.value_and_gradient(
228-
f::F,
229-
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
230-
x,
231-
contexts::Vararg{Context,C},
212+
f::F, backend::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{Context,C}
232213
) where {F,C}
233214
f_and_df = get_f_and_df(f, backend)
234-
ders, y = gradient(
235-
reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
215+
mode = reverse_withprimal(backend)
216+
IA = guess_activity(typeof(x), mode)
217+
grad = make_zero(x)
218+
dinputs, result = autodiff(
219+
mode, f_and_df, Active, annotate(IA, x, grad), map(translate, contexts)...
236220
)
237-
grad = first(ders)
238-
return y, grad
221+
new_grad = first(dinputs)
222+
if isnothing(new_grad)
223+
return result, grad
224+
else
225+
return result, new_grad
226+
end
239227
end
240228

241229
### With preparation
@@ -245,10 +233,7 @@ struct EnzymeGradientPrep{G} <: GradientPrep
245233
end
246234

247235
function DI.prepare_gradient(
248-
f::F,
249-
::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
250-
x,
251-
contexts::Vararg{Context,C},
236+
f::F, ::AutoEnzyme{<:Union{ReverseMode,Nothing}}, x, contexts::Vararg{Context,C}
252237
) where {F,C}
253238
grad_righttype = make_zero(x)
254239
return EnzymeGradientPrep(grad_righttype)
@@ -257,21 +242,18 @@ end
257242
function DI.gradient(
258243
f::F,
259244
::EnzymeGradientPrep,
260-
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
245+
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
261246
x,
262247
contexts::Vararg{Context,C},
263248
) where {F,C}
264-
f_and_df = get_f_and_df(f, backend)
265-
ders = gradient(reverse_noprimal(backend), f_and_df, x, map(translate, contexts)...)
266-
grad = first(ders)
267-
return grad
249+
return DI.gradient(f, backend, x, contexts...)
268250
end
269251

270252
function DI.gradient!(
271253
f::F,
272254
grad,
273255
prep::EnzymeGradientPrep,
274-
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
256+
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
275257
x,
276258
contexts::Vararg{Context,C},
277259
) where {F,C}
@@ -292,23 +274,18 @@ end
292274
function DI.value_and_gradient(
293275
f::F,
294276
::EnzymeGradientPrep,
295-
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
277+
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
296278
x,
297279
contexts::Vararg{Context,C},
298280
) where {F,C}
299-
f_and_df = get_f_and_df(f, backend)
300-
ders, y = gradient(
301-
reverse_withprimal(backend), f_and_df, x, map(translate, contexts)...
302-
)
303-
grad = first(ders)
304-
return y, grad
281+
return DI.value_and_gradient(f, backend, x, contexts...)
305282
end
306283

307284
function DI.value_and_gradient!(
308285
f::F,
309286
grad,
310287
prep::EnzymeGradientPrep,
311-
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
288+
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
312289
x,
313290
contexts::Vararg{Context,C},
314291
) where {F,C}
@@ -328,6 +305,9 @@ end
328305

329306
## Jacobian
330307

308+
# TODO: does not support static arrays
309+
310+
#=
331311
struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end
332312
333313
function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B}
@@ -385,3 +365,4 @@ function DI.value_and_jacobian!(
385365
y, new_jac = DI.value_and_jacobian(f, prep, backend, x)
386366
return y, copyto!(jac, new_jac)
387367
end
368+
=#

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,13 @@ end
7676
function maybe_reshape(A::AbstractArray, m, n)
7777
return reshape(A, m, n)
7878
end
79+
80+
annotate(::Type{Active{T}}, x, dx) where {T} = Active(x)
81+
annotate(::Type{Duplicated{T}}, x, dx) where {T} = Duplicated(x, dx)
82+
83+
function annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B}
84+
return BatchDuplicated(x, tx)
85+
end
86+
87+
batchify_activity(::Type{Active{T}}, ::Val{B}) where {T,B} = Active{T}
88+
batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T,B} = BatchDuplicated{T,B}

DifferentiationInterface/test/Back/Enzyme/test.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,16 @@ test_differentiation(
9595
sparsity=true,
9696
logging=LOGGING,
9797
);
98+
99+
##
100+
101+
filtered_static_scenarios = filter(static_scenarios()) do s
102+
DIT.operator_place(s) == :out && DIT.function_place(s) == :out
103+
end
104+
105+
test_differentiation(
106+
[AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)],
107+
filtered_static_scenarios;
108+
excluded=SECOND_ORDER,
109+
logging=LOGGING,
110+
)

0 commit comments

Comments
 (0)