Skip to content

Commit 7b672a0

Browse files
authored
Safe benchmark (#213)
1 parent b85ea79 commit 7b672a0

11 files changed

Lines changed: 691 additions & 448 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using DifferentiationInterface:
99
JacobianExtras,
1010
NoDerivativeExtras,
1111
NoPushforwardExtras
12-
using ForwardDiff.DiffResults: DiffResults, DiffResult, GradientResult
12+
using ForwardDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult
1313
using ForwardDiff:
1414
Chunk,
1515
Dual,

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function DI.value_and_gradient!(
2828
x::AbstractArray,
2929
extras::ForwardDiffGradientExtras,
3030
)
31-
result = DiffResult(zero(eltype(x)), grad)
31+
result = MutableDiffResult(zero(eltype(x)), (grad,))
3232
result = gradient!(result, f, x, extras.config)
3333
return DiffResults.value(result), DiffResults.gradient(result)
3434
end
@@ -74,7 +74,7 @@ function DI.value_and_jacobian!(
7474
extras::ForwardDiffOneArgJacobianExtras,
7575
)
7676
y = f(x)
77-
result = DiffResult(y, jac)
77+
result = MutableDiffResult(y, (jac,))
7878
result = jacobian!(result, f, x, extras.config)
7979
return DiffResults.value(result), DiffResults.jacobian(result)
8080
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function DI.value_and_derivative(
2727
x::Number,
2828
extras::ForwardDiffTwoArgDerivativeExtras,
2929
)
30-
result = DiffResult(y, similar(y))
30+
result = MutableDiffResult(y, (similar(y),))
3131
result = derivative!(result, f!, y, x, extras.config)
3232
return DiffResults.value(result), DiffResults.derivative(result)
3333
end
@@ -40,7 +40,7 @@ function DI.value_and_derivative!(
4040
x::Number,
4141
extras::ForwardDiffTwoArgDerivativeExtras,
4242
)
43-
result = DiffResult(y, der)
43+
result = MutableDiffResult(y, (der,))
4444
result = derivative!(result, f!, y, x, extras.config)
4545
return DiffResults.value(result), DiffResults.derivative(result)
4646
end
@@ -90,7 +90,7 @@ function DI.value_and_jacobian(
9090
extras::ForwardDiffTwoArgJacobianExtras,
9191
)
9292
jac = similar(y, length(y), length(x))
93-
result = DiffResult(y, jac)
93+
result = MutableDiffResult(y, (jac,))
9494
result = jacobian!(result, f!, y, x, extras.config)
9595
return DiffResults.value(result), DiffResults.jacobian(result)
9696
end
@@ -103,7 +103,7 @@ function DI.value_and_jacobian!(
103103
x::AbstractArray,
104104
extras::ForwardDiffTwoArgJacobianExtras,
105105
)
106-
result = DiffResult(y, jac)
106+
result = MutableDiffResult(y, (jac,))
107107
result = jacobian!(result, f!, y, x, extras.config)
108108
return DiffResults.value(result), DiffResults.jacobian(result)
109109
end

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ADTypes: AutoReverseDiff
44
import DifferentiationInterface as DI
55
using DifferentiationInterface:
66
DerivativeExtras, GradientExtras, HessianExtras, JacobianExtras, NoPullbackExtras
7-
using ReverseDiff.DiffResults: DiffResults, DiffResult, GradientResult
7+
using ReverseDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult
88
using DocStringExtensions
99
using LinearAlgebra: dot, mul!
1010
using ReverseDiff:

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ function DI.value_and_gradient!(
5757
x::AbstractArray,
5858
extras::ReverseDiffGradientExtras,
5959
)
60-
result = DiffResult(zero(eltype(x)), grad)
60+
result = MutableDiffResult(zero(eltype(x)), (grad,))
6161
result = gradient!(result, extras.tape, x)
6262
return DiffResults.value(result), DiffResults.derivative(result)
6363
end
@@ -107,7 +107,7 @@ function DI.value_and_jacobian!(
107107
extras::ReverseDiffOneArgJacobianExtras,
108108
)
109109
y = f(x)
110-
result = DiffResult(y, jac)
110+
result = MutableDiffResult(y, (jac,))
111111
result = jacobian!(result, extras.tape, x)
112112
return DiffResults.value(result), DiffResults.derivative(result)
113113
end

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,15 @@ function DI.value_and_jacobian(
8585
_f!, y, ::AutoReverseDiff, x, extras::ReverseDiffTwoArgJacobianExtras
8686
)
8787
jac = similar(y, length(y), length(x))
88-
result = DiffResults.DiffResult(y, jac)
88+
result = MutableDiffResult(y, (jac,))
8989
result = jacobian!(result, extras.tape, x)
9090
return DiffResults.value(result), DiffResults.derivative(result)
9191
end
9292

9393
function DI.value_and_jacobian!(
9494
_f!, y, jac, ::AutoReverseDiff, x, extras::ReverseDiffTwoArgJacobianExtras
9595
)
96-
result = DiffResults.DiffResult(y, jac)
96+
result = MutableDiffResult(y, (jac,))
9797
result = jacobian!(result, extras.tape, x)
9898
return DiffResults.value(result), DiffResults.derivative(result)
9999
end

DifferentiationInterfaceTest/src/scenarios/scenario.jl

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,26 @@ The reference keyword `ref` should be a function that takes `x` (and a potential
3636
3737
The operator behavior keyword `operator` should be either `:inplace` or `:outofplace` depending on what must be tested.
3838
"""
39-
abstract type AbstractScenario{A,O,F,X,Y,R} end
39+
abstract type AbstractScenario{args,op,F,X,Y,R} end
4040

41-
abstract type AbstractFirstOrderScenario{A,O,F,X,Y,R} <: AbstractScenario{A,O,F,X,Y,R} end
42-
abstract type AbstractSecondOrderScenario{A,O,F,X,Y,R} <: AbstractScenario{A,O,F,X,Y,R} end
41+
abstract type AbstractFirstOrderScenario{args,op,F,X,Y,R} <:
42+
AbstractScenario{args,op,F,X,Y,R} end
43+
abstract type AbstractSecondOrderScenario{args,op,F,X,Y,R} <:
44+
AbstractScenario{args,op,F,X,Y,R} end
4345

44-
nbargs(::AbstractScenario{A}) where {A} = A
45-
operator(::AbstractScenario{A,O}) where {A,O} = O
46+
scen_type(scenario::AbstractScenario) = nameof(typeof(scenario))
47+
nb_args(::AbstractScenario{args}) where {args} = args
48+
operator_place(::AbstractScenario{args,op}) where {args,op} = op
4649

4750
function compatible(backend::AbstractADType, scen::AbstractScenario)
48-
if nbargs(scen) == 2
51+
if nb_args(scen) == 2
4952
return Bool(mutation_support(backend))
5053
end
5154
return true
5255
end
5356

54-
function Base.string(scen::S) where {A,O,F,X,Y,S<:AbstractScenario{A,O,F,X,Y}}
55-
return "$(S.name.name){$A,$O} $(string(scen.f)) : $X -> $Y"
57+
function Base.string(scen::S) where {args,op,F,X,Y,S<:AbstractScenario{args,op,F,X,Y}}
58+
return "$(S.name.name){$args,$op} $(string(scen.f)) : $X -> $Y"
5659
end
5760

5861
## Struct definitions
@@ -62,7 +65,8 @@ end
6265
6366
See [`AbstractScenario`](@ref) for details.
6467
"""
65-
struct PushforwardScenario{A,O,F,X,Y,DX,R} <: AbstractFirstOrderScenario{A,O,F,X,Y,R}
68+
struct PushforwardScenario{args,op,F,X,Y,DX,R} <:
69+
AbstractFirstOrderScenario{args,op,F,X,Y,R}
6670
"function"
6771
f::F
6872
"input"
@@ -80,7 +84,7 @@ end
8084
8185
See [`AbstractScenario`](@ref) for details.
8286
"""
83-
struct PullbackScenario{A,O,F,X,Y,DY,R} <: AbstractFirstOrderScenario{A,O,F,X,Y,R}
87+
struct PullbackScenario{args,op,F,X,Y,DY,R} <: AbstractFirstOrderScenario{args,op,F,X,Y,R}
8488
"function"
8589
f::F
8690
"input"
@@ -98,7 +102,8 @@ end
98102
99103
See [`AbstractScenario`](@ref) for details.
100104
"""
101-
struct DerivativeScenario{A,O,F,X<:Number,Y,R} <: AbstractFirstOrderScenario{A,O,F,X,Y,R}
105+
struct DerivativeScenario{args,op,F,X<:Number,Y,R} <:
106+
AbstractFirstOrderScenario{args,op,F,X,Y,R}
102107
"function"
103108
f::F
104109
"input"
@@ -114,7 +119,8 @@ end
114119
115120
See [`AbstractScenario`](@ref) for details.
116121
"""
117-
struct GradientScenario{A,O,F,X,Y<:Number,R} <: AbstractFirstOrderScenario{A,O,F,X,Y,R}
122+
struct GradientScenario{args,op,F,X,Y<:Number,R} <:
123+
AbstractFirstOrderScenario{args,op,F,X,Y,R}
118124
"function"
119125
f::F
120126
"input"
@@ -130,8 +136,8 @@ end
130136
131137
See [`AbstractScenario`](@ref) for details.
132138
"""
133-
struct JacobianScenario{A,O,F,X<:AbstractArray,Y<:AbstractArray,R} <:
134-
AbstractFirstOrderScenario{A,O,F,X,Y,R}
139+
struct JacobianScenario{args,op,F,X<:AbstractArray,Y<:AbstractArray,R} <:
140+
AbstractFirstOrderScenario{args,op,F,X,Y,R}
135141
"function"
136142
f::F
137143
"input"
@@ -147,8 +153,8 @@ end
147153
148154
See [`AbstractScenario`](@ref) for details.
149155
"""
150-
struct SecondDerivativeScenario{A,O,F,X<:Number,Y,R} <:
151-
AbstractSecondOrderScenario{A,O,F,X,Y,R}
156+
struct SecondDerivativeScenario{args,op,F,X<:Number,Y,R} <:
157+
AbstractSecondOrderScenario{args,op,F,X,Y,R}
152158
"function"
153159
f::F
154160
"input"
@@ -164,7 +170,8 @@ end
164170
165171
See [`AbstractScenario`](@ref) for details.
166172
"""
167-
struct HVPScenario{A,O,F,X,Y<:Number,DX,R} <: AbstractSecondOrderScenario{A,O,F,X,Y,R}
173+
struct HVPScenario{args,op,F,X,Y<:Number,DX,R} <:
174+
AbstractSecondOrderScenario{args,op,F,X,Y,R}
168175
"function"
169176
f::F
170177
"input"
@@ -182,8 +189,8 @@ end
182189
183190
See [`AbstractScenario`](@ref) for details.
184191
"""
185-
struct HessianScenario{A,O,F,X<:AbstractArray,Y<:Number,R} <:
186-
AbstractSecondOrderScenario{A,O,F,X,Y,R}
192+
struct HessianScenario{args,op,F,X<:AbstractArray,Y<:Number,R} <:
193+
AbstractSecondOrderScenario{args,op,F,X,Y,R}
187194
"function"
188195
f::F
189196
"input"
@@ -205,13 +212,13 @@ for S in (
205212
)
206213
@eval begin
207214
function $S(f::F; x::X, y=nothing, ref::R=nothing, operator=:inplace) where {F,X,R}
208-
A = isnothing(y) ? 1 : 2
209-
if A == 2
215+
args = isnothing(y) ? 1 : 2
216+
if args == 2
210217
f(y, x)
211218
else
212219
y = f(x)
213220
end
214-
return ($S){A,operator,F,X,typeof(y),R}(f, x, y, ref)
221+
return ($S){args,operator,F,X,typeof(y),R}(f, x, y, ref)
215222
end
216223
end
217224
end
@@ -221,16 +228,16 @@ for S in (:PushforwardScenario, :HVPScenario)
221228
function $S(
222229
f::F; x::X, y=nothing, ref::R=nothing, dx=nothing, operator=:inplace
223230
) where {F,X,R}
224-
A = isnothing(y) ? 1 : 2
225-
if A == 2
231+
args = isnothing(y) ? 1 : 2
232+
if args == 2
226233
f(y, x)
227234
else
228235
y = f(x)
229236
end
230237
if isnothing(dx)
231238
dx = mysimilar_random(x)
232239
end
233-
return ($S){A,operator,F,X,typeof(y),typeof(dx),R}(f, x, y, dx, ref)
240+
return ($S){args,operator,F,X,typeof(y),typeof(dx),R}(f, x, y, dx, ref)
234241
end
235242
end
236243
end
@@ -240,16 +247,16 @@ for S in (:PullbackScenario,)
240247
function $S(
241248
f::F; x::X, y=nothing, ref::R=nothing, dy=nothing, operator=:inplace
242249
) where {F,X,R}
243-
A = isnothing(y) ? 1 : 2
244-
if A == 2
250+
args = isnothing(y) ? 1 : 2
251+
if args == 2
245252
f(y, x)
246253
else
247254
y = f(x)
248255
end
249256
if isnothing(dy)
250257
dy = mysimilar_random(y)
251258
end
252-
return ($S){A,operator,F,X,typeof(y),typeof(dy),R}(f, x, y, dy, ref)
259+
return ($S){args,operator,F,X,typeof(y),typeof(dy),R}(f, x, y, dy, ref)
253260
end
254261
end
255262
end

DifferentiationInterfaceTest/src/scenarios/static.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function static_scenarios()
2626
mat_to_mat_scenarios_twoarg(MMatrix{2,3}(randn(2, 3))),
2727
)
2828
scens = filter(scens) do s
29-
operator(s) == :outofplace || typeof(s.x) isa Union{Number,MVector,MMatrix}
29+
operator_place(s) == :outofplace || s.x isa Union{Number,MVector,MMatrix}
3030
end
3131
return scens
3232
end

DifferentiationInterfaceTest/src/test_differentiation.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ function test_differentiation(
7777
(:backend, "$(backend_string(backend)) - $i/$(length(backends))"),
7878
(
7979
:scenario,
80-
"$(typeof(scen).name.name) - $j/$(length(filtered_scenarios))",
80+
"$(scen_type(scen)) - $j/$(length(filtered_scenarios))",
8181
),
82-
(:arguments, typeof(scen).parameters[1]),
83-
(:operator, typeof(scen).parameters[2]),
82+
(:arguments, nb_args(scen)),
83+
(:operator, operator_place(scen)),
8484
(:function, scen.f),
8585
(:input, typeof(scen.x)),
8686
(:output, typeof(scen.y)),
@@ -145,12 +145,9 @@ function benchmark_differentiation(
145145
prog;
146146
showvalues=[
147147
(:backend, "$(backend_string(backend)) - $i/$(length(backends))"),
148-
(
149-
:scenario,
150-
"$(typeof(scen).name.name) - $j/$(length(filtered_scenarios))",
151-
),
152-
(:arguments, typeof(scen).parameters[1]),
153-
(:operator, typeof(scen).parameters[2]),
148+
(:scenario, "$(scen_type(scen)) - $j/$(length(filtered_scenarios))"),
149+
(:arguments, nb_args(scen)),
150+
(:operator, operator_place(scen)),
154151
(:function, scen.f),
155152
(:input, typeof(scen.x)),
156153
(:output, typeof(scen.y)),

0 commit comments

Comments
 (0)