Skip to content

Commit ea192f6

Browse files
authored
More batch mode for Enzyme (#495)
1 parent 2f06fb1 commit ea192f6

3 files changed

Lines changed: 112 additions & 66 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 87 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,31 @@ function seeded_autodiff_thunk(
2121
end
2222
end
2323

24+
function batch_seeded_autodiff_thunk(
25+
rmode::ReverseModeSplit{ReturnPrimal},
26+
dresults::NTuple,
27+
f::FA,
28+
::Type{RA},
29+
args::Vararg{Annotation,N},
30+
) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N}
31+
forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
32+
tape, result, shadow_results = forward(f, args...)
33+
if RA <: Active
34+
dresults_righttype = map(Fix1(convert, typeof(result)), dresults)
35+
dinputs = only(reverse(f, args..., dresults_righttype, tape))
36+
else
37+
foreach(shadow_results, dresults) do d0, d
38+
d0 .+= d # use recursive_add here?
39+
end
40+
dinputs = only(reverse(f, args..., tape))
41+
end
42+
if ReturnPrimal
43+
return (dinputs, result)
44+
else
45+
return (dinputs,)
46+
end
47+
end
48+
2449
## Pullback
2550

2651
function DI.prepare_pullback(
@@ -35,24 +60,6 @@ end
3560

3661
### Out-of-place
3762

38-
function DI.value_and_pullback(
39-
f::F,
40-
prep::NoPullbackPrep,
41-
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
42-
x,
43-
ty::Tangents,
44-
contexts::Vararg{Context,C},
45-
) where {F,C}
46-
ys_and_dxs = map(ty.d) do dy
47-
y, tx = DI.value_and_pullback(f, prep, backend, x, Tangents(dy), contexts...)
48-
y, only(tx)
49-
end
50-
y = first(ys_and_dxs[1])
51-
dxs = last.(ys_and_dxs)
52-
tx = Tangents(dxs...)
53-
return y, tx
54-
end
55-
5663
function DI.value_and_pullback(
5764
f::F,
5865
::NoPullbackPrep,
@@ -70,6 +77,24 @@ function DI.value_and_pullback(
7077
return result, Tangents(first(dinputs))
7178
end
7279

80+
function DI.value_and_pullback(
81+
f::F,
82+
prep::NoPullbackPrep,
83+
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
84+
x::Number,
85+
ty::Tangents{B},
86+
contexts::Vararg{Context,C},
87+
) where {F,B,C}
88+
# TODO: improve
89+
ys_and_dxs = map(ty.d) do dy
90+
y, tx = DI.value_and_pullback(f, prep, backend, x, Tangents(dy), contexts...)
91+
y, only(tx)
92+
end
93+
y = first(ys_and_dxs[1])
94+
dxs = last.(ys_and_dxs)
95+
return y, Tangents(dxs...)
96+
end
97+
7398
function DI.value_and_pullback(
7499
f::F,
75100
::NoPullbackPrep,
@@ -88,53 +113,37 @@ function DI.value_and_pullback(
88113
return result, Tangents(dx)
89114
end
90115

91-
function DI.pullback(
92-
f::F,
93-
prep::NoPullbackPrep,
94-
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
95-
x::Number,
96-
ty::Tangents{1},
97-
contexts::Vararg{Context,C},
98-
) where {F,C}
99-
return last(DI.value_and_pullback(f, prep, backend, x, ty, contexts...))
100-
end
101-
102-
### In-place
103-
104-
function DI.value_and_pullback!(
116+
function DI.value_and_pullback(
105117
f::F,
106-
tx::Tangents,
107-
prep::NoPullbackPrep,
118+
::NoPullbackPrep,
108119
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
109120
x,
110-
ty::Tangents,
121+
ty::Tangents{B},
111122
contexts::Vararg{Context,C},
112-
) where {F,C}
113-
ys = map(tx.d, ty.d) do dx, dy
114-
y, _ = DI.value_and_pullback!(
115-
f, Tangents(dx), prep, backend, x, Tangents(dy), contexts...
116-
)
117-
y
118-
end
119-
y = first(ys)
120-
return y, tx
123+
) where {F,B,C}
124+
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
125+
mode = reverse_mode_split_withprimal(backend)
126+
RA = eltype(ty) <: Number ? Active : BatchDuplicated
127+
dxs = ntuple(_ -> make_zero(x), Val(B))
128+
_, result = batch_seeded_autodiff_thunk(
129+
mode, NTuple(ty), f_and_df, RA, BatchDuplicated(x, dxs), map(translate, contexts)...
130+
)
131+
return result, Tangents(dxs...)
121132
end
122133

123-
function DI.pullback!(
134+
function DI.pullback(
124135
f::F,
125-
tx::Tangents,
126136
prep::NoPullbackPrep,
127137
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
128138
x,
129139
ty::Tangents,
130140
contexts::Vararg{Context,C},
131141
) where {F,C}
132-
for b in eachindex(tx.d, ty.d)
133-
DI.pullback!(f, Tangents(tx.d[b]), prep, backend, x, Tangents(ty.d[b]), contexts...)
134-
end
135-
return tx
142+
return last(DI.value_and_pullback(f, prep, backend, x, ty, contexts...))
136143
end
137144

145+
### In-place
146+
138147
function DI.value_and_pullback!(
139148
f::F,
140149
tx::Tangents{1},
@@ -161,13 +170,39 @@ function DI.value_and_pullback!(
161170
return result, tx
162171
end
163172

173+
function DI.value_and_pullback!(
174+
f::F,
175+
tx::Tangents{B},
176+
::NoPullbackPrep,
177+
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
178+
x,
179+
ty::Tangents{B},
180+
contexts::Vararg{Context,C},
181+
) where {F,B,C}
182+
f_and_df = force_annotation(get_f_and_df(f, backend, Val(B)))
183+
mode = reverse_mode_split_withprimal(backend)
184+
RA = eltype(ty) <: Number ? Active : BatchDuplicated
185+
dxs_righttype = map(Fix1(convert, typeof(x)), NTuple(tx))
186+
make_zero!(dxs_righttype)
187+
_, result = batch_seeded_autodiff_thunk(
188+
mode,
189+
NTuple(ty),
190+
f_and_df,
191+
RA,
192+
BatchDuplicated(x, dxs_righttype),
193+
map(translate, contexts)...,
194+
)
195+
foreach(copyto!, NTuple(tx), dxs_righttype)
196+
return result, tx
197+
end
198+
164199
function DI.pullback!(
165200
f::F,
166-
tx::Tangents{1},
201+
tx::Tangents,
167202
prep::NoPullbackPrep,
168203
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
169204
x,
170-
ty::Tangents{1},
205+
ty::Tangents,
171206
contexts::Vararg{Context,C},
172207
) where {F,C}
173208
return last(DI.value_and_pullback!(f, tx, prep, backend, x, ty, contexts...))

DifferentiationInterface/src/utils/tangents.jl

Lines changed: 23 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,12 @@ pick_batchsize(::AbstractADType, dimension::Integer) = 1
1010
"""
1111
Tangents{B}
1212
13-
Storage for a batch of `B` tangents (behaves like an `NTuple`).
13+
Storage for a batch of `B` tangents (wrapper around an `NTuple`).
1414
15-
Must be passed as an argument to [`pushforward`](@ref), [`pullback`](@ref) and [`hvp`](@ref), in addition to the input `x`.
15+
The operators [`pushforward`](@ref), [`pullback`](@ref) and [`hvp`](@ref) require a `Tangents` argument in addition to the input `x`.
16+
17+
The underlying `NTuple` of `t::Tangents` can be retrieved with `NTuple(t)`.
18+
We also define a few utility functions, as shown below.
1619
1720
# Constructors
1821
@@ -27,6 +30,9 @@ julia> using DifferentiationInterface
2730
julia> t = Tangents(2.0)
2831
Tangents{1, Float64}((2.0,))
2932
33+
julia> NTuple(t)
34+
(2.0,)
35+
3036
julia> length(t)
3137
1
3238
@@ -36,6 +42,9 @@ julia> only(t)
3642
julia> t = Tangents([2.0], [4.0], [6.0])
3743
Tangents{3, Vector{Float64}}(([2.0], [4.0], [6.0]))
3844
45+
julia> NTuple(t)
46+
([2.0], [4.0], [6.0])
47+
3948
julia> length(t)
4049
3
4150
@@ -59,25 +68,25 @@ end
5968
Base.length(::Tangents{B,T}) where {B,T} = B
6069
Base.eltype(::Tangents{B,T}) where {B,T} = T
6170

62-
Base.only(t::Tangents) = only(t.d)
63-
Base.getindex(t::Tangents, ind) = t.d[ind]
64-
Base.firstindex(t::Tangents) = firstindex(t.d)
65-
Base.lastindex(t::Tangents) = lastindex(t.d)
71+
Base.NTuple(t::Tangents) = t.d
6672

67-
Base.iterate(t::Tangents) = iterate(t.d)
68-
Base.iterate(t::Tangents, state) = iterate(t.d, state)
73+
Base.only(t::Tangents) = only(NTuple(t))
74+
Base.getindex(t::Tangents, ind) = NTuple(t)[ind]
75+
Base.firstindex(t::Tangents) = firstindex(NTuple(t))
76+
Base.lastindex(t::Tangents) = lastindex(NTuple(t))
6977

70-
Base.map(f, t::Tangents) = Tangents(map(f, t.d)...)
78+
Base.iterate(t::Tangents) = iterate(NTuple(t))
79+
Base.iterate(t::Tangents, state) = iterate(NTuple(t), state)
7180

72-
Base.:(==)(t1::Tangents{B}, t2::Tangents{B}) where {B} = t1.d == t2.d
81+
Base.map(f, t::Tangents) = Tangents(map(f, NTuple(t))...)
82+
83+
Base.:(==)(t1::Tangents{B}, t2::Tangents{B}) where {B} = NTuple(t1) == NTuple(t2)
7384

7485
function Base.isapprox(t1::Tangents{B}, t2::Tangents{B}; kwargs...) where {B}
75-
return all(isapprox.(t1.d, t2.d; kwargs...))
86+
return all(isapprox.(NTuple(t1), NTuple(t2); kwargs...))
7687
end
7788

7889
function Base.copyto!(t1::Tangents{B}, t2::Tangents{B}) where {B}
79-
for b in eachindex(t1.d, t2.d)
80-
copyto!(t1.d[b], t2.d[b])
81-
end
90+
foreach(copyto!, NTuple(t1), NTuple(t2))
8291
return t1
8392
end

DifferentiationInterface/test/Internals/tangents.jl

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,12 +4,14 @@ using Test
44
@test_throws ArgumentError Tangents()
55

66
t = Tangents([2.0])
7+
@test NTuple(t) == ([2.0],)
78
@test length(t) == 1
89
@test eltype(t) == Vector{Float64}
910
@test only(t) == [2.0]
1011
@test copyto!(map(zero, t), t) t
1112

1213
t = Tangents(2.0, 4.0, 6.0)
14+
@test NTuple(t) == (2.0, 4.0, 6.0)
1315
@test length(t) == 3
1416
@test eltype(t) == Float64
1517
@test t[begin] == first(t) == 2.0

0 commit comments

Comments
 (0)