Skip to content

Commit bb0d874

Browse files
authored
Test batched operators (#331)
* Add batched tests * Full batched interface with tests * Fixes * Fix * Don't batchify by default * Coverage in DIT * Typos and coverage
1 parent b8f82b0 commit bb0d874

27 files changed

Lines changed: 1053 additions & 494 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,41 @@ tag_type(f, ::AutoForwardDiff{C,Nothing}, x) where {C} = Tag{typeof(f),eltype(x)
77
make_dual_similar(::Type{T}, x::Number, dx::Number) where {T} = Dual{T}(x, dx)
88
make_dual_similar(::Type{T}, x, dx) where {T} = similar(x, Dual{T,eltype(x),1})
99

10+
function make_dual_similar(::Type{T}, x::Number, dx::Batch{B,<:Number}) where {T,B}
11+
return Dual{T}(x, dx.elements)
12+
end
13+
1014
function make_dual_similar(::Type{T}, x, dx::Batch{B}) where {T,B}
1115
return similar(x, Dual{T,eltype(x),B})
1216
end
1317

14-
make_dual(::Type{T}, x::Number, dx::Number) where {T} = Dual{T}(x, dx)
15-
make_dual!(::Type{T}, xdual, x, dx) where {T} = map!(Dual{T}, xdual, x, dx)
18+
function make_dual(::Type{T}, x::Number, dx::Number) where {T}
19+
return Dual{T}(x, dx)
20+
end
1621

17-
function make_dual!(::Type{T}, xdual, x, dx::Batch{B}) where {T,B}
18-
return map!(Dual{T}, xdual, x, dx.elements...)
22+
function make_dual(::Type{T}, x::Number, dx::Batch{B,<:Number}) where {T,B}
23+
return Dual{T}(x, dx.elements...)
1924
end
2025

21-
myvalue(::Type{T}, ydual::Dual{T}) where {T} = value(T, ydual)
22-
myvalue(::Type{T}, ydual) where {T} = map(Fix1(myvalue, T), ydual)
26+
function make_dual!(::Type{T}, xdual, x, dx) where {T}
27+
return xdual .= Dual{T}.(x, dx)
28+
end
2329

24-
function myvalue!(::Type{T}, y, ydual) where {T}
25-
return map!(Fix1(myvalue, T), y, ydual)
30+
function make_dual!(::Type{T}, xdual, x, dx::Batch{B}) where {T,B}
31+
return xdual .= Dual{T}.(x, dx.elements...)
2632
end
2733

34+
myvalue(::Type{T}, ydual::Dual{T}) where {T} = value(T, ydual)
35+
myvalue(::Type{T}, ydual) where {T} = myvalue.(T, ydual)
36+
myvalue!(::Type{T}, y, ydual) where {T} = y .= myvalue.(T, ydual)
37+
2838
myderivative(::Type{T}, ydual::Dual{T}) where {T} = extract_derivative(T, ydual)
29-
myderivative(::Type{T}, ydual) where {T} = map(Fix1(myderivative, T), ydual)
39+
myderivative(::Type{T}, ydual) where {T} = myderivative.(T, ydual)
40+
myderivative!(::Type{T}, dy, ydual) where {T} = dy .= myderivative.(T, ydual)
3041

31-
function myderivative!(::Type{T}, dy, ydual) where {T}
32-
return map!(Fix1(myderivative, T), dy, ydual)
42+
function mypartials(::Type{T}, ::Val{B}, ydual::Dual) where {T,B}
43+
elements = partials(T, ydual).values
44+
return Batch(elements)
3345
end
3446

3547
function mypartials(::Type{T}, ::Val{B}, ydual) where {T,B}

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,16 @@ include("utils/exceptions.jl")
5656
include("utils/maybe.jl")
5757

5858
include("first_order/pushforward.jl")
59+
include("first_order/pushforward_batched.jl")
5960
include("first_order/pullback.jl")
61+
include("first_order/pullback_batched.jl")
6062
include("first_order/derivative.jl")
6163
include("first_order/gradient.jl")
6264
include("first_order/jacobian.jl")
6365

6466
include("second_order/second_derivative.jl")
6567
include("second_order/hvp.jl")
68+
include("second_order/hvp_batched.jl")
6669
include("second_order/hessian.jl")
6770

6871
include("sparse/fallbacks.jl")

DifferentiationInterface/src/first_order/derivative.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ end
7272

7373
## One argument
7474

75+
### Without extras
76+
7577
function value_and_derivative(f::F, backend::AbstractADType, x) where {F}
7678
return value_and_derivative(f, backend, x, prepare_derivative(f, backend, x))
7779
end
@@ -88,6 +90,8 @@ function derivative!(f::F, der, backend::AbstractADType, x) where {F}
8890
return derivative!(f, der, backend, x, prepare_derivative(f, backend, x))
8991
end
9092

93+
### With extras
94+
9195
function value_and_derivative(
9296
f::F, backend::AbstractADType, x, extras::PushforwardDerivativeExtras
9397
) where {F}
@@ -114,6 +118,8 @@ end
114118

115119
## Two arguments
116120

121+
### Without extras
122+
117123
function value_and_derivative(f!::F, y, backend::AbstractADType, x) where {F}
118124
return value_and_derivative(f!, y, backend, x, prepare_derivative(f!, y, backend, x))
119125
end
@@ -132,6 +138,8 @@ function derivative!(f!::F, y, der, backend::AbstractADType, x) where {F}
132138
return derivative!(f!, y, der, backend, x, prepare_derivative(f!, y, backend, x))
133139
end
134140

141+
### With extras
142+
135143
function value_and_derivative(
136144
f!::F, y, backend::AbstractADType, x, extras::PushforwardDerivativeExtras
137145
) where {F}

DifferentiationInterface/src/first_order/gradient.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ end
6262

6363
## One argument
6464

65+
### Without extras
66+
6567
function value_and_gradient(f::F, backend::AbstractADType, x) where {F}
6668
return value_and_gradient(f, backend, x, prepare_gradient(f, backend, x))
6769
end
@@ -78,6 +80,8 @@ function gradient!(f::F, der, backend::AbstractADType, x) where {F}
7880
return gradient!(f, der, backend, x, prepare_gradient(f, backend, x))
7981
end
8082

83+
### With extras
84+
8185
function value_and_gradient(
8286
f::F, backend::AbstractADType, x, extras::PullbackGradientExtras
8387
) where {F}

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ end
104104

105105
## One argument
106106

107+
### Without extras
108+
107109
function jacobian(f::F, backend::AbstractADType, x) where {F}
108110
return jacobian(f, backend, x, prepare_jacobian(f, backend, x))
109111
end
@@ -120,6 +122,8 @@ function value_and_jacobian!(f::F, jac, backend::AbstractADType, x) where {F}
120122
return value_and_jacobian!(f, jac, backend, x, prepare_jacobian(f, backend, x))
121123
end
122124

125+
### With extras
126+
123127
function jacobian(f::F, backend::AbstractADType, x, extras::JacobianExtras) where {F}
124128
return jacobian_aux((f,), backend, x, extras)
125129
end
@@ -142,6 +146,8 @@ end
142146

143147
## Two arguments
144148

149+
### Without extras
150+
145151
function jacobian(f!::F, y, backend::AbstractADType, x) where {F}
146152
return jacobian(f!, y, backend, x, prepare_jacobian(f!, y, backend, x))
147153
end
@@ -158,6 +164,8 @@ function value_and_jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F
158164
return value_and_jacobian!(f!, y, jac, backend, x, prepare_jacobian(f!, y, backend, x))
159165
end
160166

167+
### With extras
168+
161169
function jacobian(f!::F, y, backend::AbstractADType, x, extras::JacobianExtras) where {F}
162170
return jacobian_aux((f!, y), backend, x, extras)
163171
end

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 8 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ Create an `extras` object that can be given to [`pullback`](@ref) and its varian
1212
"""
1313
function prepare_pullback end
1414

15-
function prepare_pullback_batched end
16-
1715
"""
1816
prepare_pullback_same_point(f, backend, x, dy) -> extras_same
1917
prepare_pullback_same_point(f!, y, backend, x, dy) -> extras_same
@@ -26,8 +24,6 @@ Create an `extras_same` object that can be given to [`pullback`](@ref) and its v
2624
"""
2725
function prepare_pullback_same_point end
2826

29-
function prepare_pullback_batched_same_point end
30-
3127
"""
3228
value_and_pullback(f, backend, x, dy, [extras]) -> (y, dx)
3329
value_and_pullback(f!, y, backend, x, dy, [extras]) -> (y, dx)
@@ -55,8 +51,6 @@ Compute the pullback of the function `f` at point `x` with seed `dy`.
5551
"""
5652
function pullback end
5753

58-
function pullback_batched end
59-
6054
"""
6155
pullback!(f, dx, backend, x, dy, [extras]) -> dx
6256
pullback!(f!, y, dx, backend, x, dy, [extras]) -> dx
@@ -65,8 +59,6 @@ Compute the pullback of the function `f` at point `x` with seed `dy`, overwritin
6559
"""
6660
function pullback! end
6761

68-
function pullback_batched! end
69-
7062
## Preparation
7163

7264
### Extras types
@@ -84,7 +76,7 @@ struct PushforwardPullbackExtras{E} <: PullbackExtras
8476
pushforward_extras::E
8577
end
8678

87-
## Standard
79+
## Different point
8880

8981
function prepare_pullback(f::F, backend::AbstractADType, x, dy) where {F}
9082
return prepare_pullback_aux(f, backend, x, dy, pullback_performance(backend))
@@ -114,7 +106,7 @@ function prepare_pullback_aux(f!, y, backend, x, dy, ::PullbackFast)
114106
throw(MissingBackendError(backend))
115107
end
116108

117-
### Standard, same point
109+
### Same point
118110

119111
function prepare_pullback_same_point(
120112
f::F, backend::AbstractADType, x, dy, extras::PullbackExtras
@@ -138,33 +130,9 @@ function prepare_pullback_same_point(f!::F, y, backend::AbstractADType, x, dy) w
138130
return prepare_pullback_same_point(f!, y, backend, x, dy, extras)
139131
end
140132

141-
### Batched
142-
143-
function prepare_pullback_batched(f::F, backend::AbstractADType, x, dy::Batch) where {F}
144-
return prepare_pullback(f, backend, x, first(dy.elements))
145-
end
146-
147-
function prepare_pullback_batched(f!::F, y, backend::AbstractADType, x, dy::Batch) where {F}
148-
return prepare_pullback(f!, y, backend, x, first(dy.elements))
149-
end
150-
151-
### Batched, same point
152-
153-
function prepare_pullback_batched_same_point(
154-
f::F, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
155-
) where {F}
156-
return prepare_pullback_same_point(f, backend, x, first(dy.elements), extras)
157-
end
158-
159-
function prepare_pullback_batched_same_point(
160-
f!::F, y, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
161-
) where {F}
162-
return prepare_pullback_same_point(f!, y, backend, x, first(dy.elements), extras)
163-
end
164-
165133
## One argument
166134

167-
### Standard
135+
### Without extras
168136

169137
function value_and_pullback(f::F, backend::AbstractADType, x, dy) where {F}
170138
return value_and_pullback(f, backend, x, dy, prepare_pullback(f, backend, x, dy))
@@ -182,6 +150,8 @@ function pullback!(f::F, dx, backend::AbstractADType, x, dy) where {F}
182150
return pullback!(f, dx, backend, x, dy, prepare_pullback(f, backend, x, dy))
183151
end
184152

153+
### With extras
154+
185155
function value_and_pullback(
186156
f::F, backend, x, dy, extras::PushforwardPullbackExtras
187157
) where {F}
@@ -220,29 +190,9 @@ function pullback!(
220190
return value_and_pullback!(f, dx, backend, x, dy, extras)[2]
221191
end
222192

223-
### Batched
224-
225-
function pullback_batched(
226-
f::F, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
227-
) where {F,B}
228-
dx_elements = ntuple(Val(B)) do b
229-
pullback(f, backend, x, dy.elements[b], extras)
230-
end
231-
return Batch(dx_elements)
232-
end
233-
234-
function pullback_batched!(
235-
f::F, dx::Batch, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
236-
) where {F}
237-
for b in eachindex(dx.elements, dy.elements)
238-
pullback!(f, dx.elements[b], backend, x, dy.elements[b], extras)
239-
end
240-
return dx
241-
end
242-
243193
## Two arguments
244194

245-
### Standard
195+
### Without extras
246196

247197
function value_and_pullback(f!::F, y, backend::AbstractADType, x, dy) where {F}
248198
return value_and_pullback(
@@ -264,6 +214,8 @@ function pullback!(f!::F, y, dx, backend::AbstractADType, x, dy) where {F}
264214
return pullback!(f!, y, dx, backend, x, dy, prepare_pullback(f!, y, backend, x, dy))
265215
end
266216

217+
### With extras
218+
267219
function value_and_pullback(
268220
f!::F, y, backend, x, dy, extras::PushforwardPullbackExtras
269221
) where {F}
@@ -297,23 +249,3 @@ function pullback!(
297249
) where {F}
298250
return value_and_pullback!(f!, y, dx, backend, x, dy, extras)[2]
299251
end
300-
301-
### Batched
302-
303-
function pullback_batched(
304-
f!::F, y, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
305-
) where {F,B}
306-
dx_elements = ntuple(Val(B)) do b
307-
pullback(f!, y, backend, x, dy.elements[b], extras)
308-
end
309-
return Batch(dx_elements)
310-
end
311-
312-
function pullback_batched!(
313-
f!::F, y, dx::Batch, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
314-
) where {F}
315-
for b in eachindex(dx.elements, dy.elements)
316-
pullback!(f!, y, dx.elements[b], backend, x, dy.elements[b], extras)
317-
end
318-
return dx
319-
end

0 commit comments

Comments
 (0)