Skip to content

Commit 618c8ae

Browse files
committed
Add tests
1 parent 3d17483 commit 618c8ae

7 files changed

Lines changed: 234 additions & 52 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ function DI.gradient!(
241241
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
242242
x,
243243
) where {F}
244-
DI.check_prep(f, prep, backend, x, contexts...)
244+
DI.check_prep(f, prep, backend, x)
245245
mode = reverse_noprimal(backend)
246246
f_and_df = get_f_and_df(f, backend, mode)
247247
gradient!(mode, grad, f_and_df, x)
@@ -255,7 +255,7 @@ function DI.value_and_gradient!(
255255
backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
256256
x,
257257
) where {F}
258-
DI.check_prep(f, prep, backend, x, contexts...)
258+
DI.check_prep(f, prep, backend, x)
259259
mode = reverse_withprimal(backend)
260260
f_and_df = get_f_and_df(f, backend, mode)
261261
_, result = gradient!(mode, grad, f_and_df, x)

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ function DI.jacobian(
271271
x,
272272
contexts::Vararg{DI.Context,C},
273273
) where {C}
274-
DI.check_prep(f, backend, x, contexts...)
274+
DI.check_prep(f, prep, backend, x, contexts...)
275275
return prep.jac_exe(x, map(DI.unwrap, contexts)...)
276276
end
277277

@@ -283,7 +283,7 @@ function DI.jacobian!(
283283
x,
284284
contexts::Vararg{DI.Context,C},
285285
) where {C}
286-
DI.check_prep(f, backend, x, contexts...)
286+
DI.check_prep(f, prep, backend, x, contexts...)
287287
prep.jac_exe!(jac, x, map(DI.unwrap, contexts)...)
288288
return jac
289289
end
@@ -295,7 +295,7 @@ function DI.value_and_jacobian(
295295
x,
296296
contexts::Vararg{DI.Context,C},
297297
) where {C}
298-
DI.check_prep(f, backend, x, contexts...)
298+
DI.check_prep(f, prep, backend, x, contexts...)
299299
return f(x, map(DI.unwrap, contexts)...), DI.jacobian(f, prep, backend, x, contexts...)
300300
end
301301

@@ -307,7 +307,7 @@ function DI.value_and_jacobian!(
307307
x,
308308
contexts::Vararg{DI.Context,C},
309309
) where {C}
310-
DI.check_prep(f, backend, x, contexts...)
310+
DI.check_prep(f, prep, backend, x, contexts...)
311311
return f(x, map(DI.unwrap, contexts)...),
312312
DI.jacobian!(f, jac, prep, backend, x, contexts...)
313313
end

DifferentiationInterface/src/docstrings.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@ function docstring_prepare(operator; samepoint=false, inplace=false)
2121
Depending on the backend, this can have several effects (preallocating memory, recording an execution trace) which are transparent to the user.
2222
2323
!!! warning
24-
The preparation result is only reusable as long as the arguments to `$operator` do not change type or size, and the function and backend themselves are not modified.
25-
Otherwise, preparation will be invalidated and you will need to run it again.
26-
The keyword argument `strict` activates automatic type checking, but ensuring size consistency is up to the user.
24+
The preparation result `prep` is only reusable as long as the arguments to `$operator` do not change type or size, and the function and backend themselves are not modified.
25+
Otherwise, preparation becomes invalid and you need to run it again.
26+
In some settings, invalid preparations may still give correct results (e.g. for backends that require no preparation), but this is not a semantic guarantee and should not be relied upon.
27+
28+
When `strict=Val(true)`, type checking is enforced between preparation and execution (but size checking is left to the user).
29+
2730
$(inplace ? "\nFor in-place functions, `y` is mutated by `f!` during preparation." : "")
2831
"""
2932
end

DifferentiationInterface/src/utils/prep.jl

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -77,30 +77,35 @@ end
7777
is_strict(::Prep{Nothing}) = Val(false)
7878
is_strict(::Prep) = Val(true)
7979

80-
struct PreparationMismatchError{SIG,RUNTIME_SIG} <: Exception
80+
struct PreparationMismatchError{SIG,EXEC_SIG} <: Exception
8181
format::Vector{Symbol}
8282
end
8383

8484
function PreparationMismatchError(
85-
::Type{SIG}, ::Type{RUNTIME_SIG}; format
86-
) where {SIG,RUNTIME_SIG}
87-
return PreparationMismatchError{SIG,RUNTIME_SIG}(format)
85+
::Type{SIG}, ::Type{EXEC_SIG}; format
86+
) where {SIG,EXEC_SIG}
87+
return PreparationMismatchError{SIG,EXEC_SIG}(format)
8888
end
8989

9090
function Base.showerror(
91-
io::IO, e::PreparationMismatchError{SIG,RUNTIME_SIG}
92-
) where {SIG<:Tuple,RUNTIME_SIG<:Tuple}
91+
io::IO, e::PreparationMismatchError{SIG,EXEC_SIG}
92+
) where {SIG<:Tuple,EXEC_SIG<:Tuple}
9393
println(
9494
io,
9595
"PreparationMismatchError (inconsistent types between preparation and execution):",
9696
)
97-
for (s, pt, et) in zip(e.format, SIG.types, RUNTIME_SIG.types)
97+
for (s, pt, et) in zip(e.format, SIG.types, EXEC_SIG.types)
9898
if pt == et
9999
println(io, " - $s: ✅")
100100
else
101101
println(io, " - $s: ❌\n - prep: $pt\n - exec: $et")
102102
end
103103
end
104+
println(
105+
io,
106+
"To disable this check (not recommended), run preparation with the keyword argument `strict=Val(false)` when using DifferentiationInterface.",
107+
)
108+
return nothing
104109
end
105110

106111
function signature(
@@ -153,11 +158,11 @@ function check_prep(
153158
f, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C}
154159
) where {SIG,C}
155160
if SIG !== Nothing
156-
RUNTIME_SIG = typeof((f, backend, x, contexts))
157-
if SIG != RUNTIME_SIG
161+
EXEC_SIG = typeof((f, backend, x, contexts))
162+
if SIG != EXEC_SIG
158163
throw(
159164
PreparationMismatchError(
160-
SIG, RUNTIME_SIG; format=[:f, :backend, :x, :contexts]
165+
SIG, EXEC_SIG; format=[:f, :backend, :x, :contexts]
161166
),
162167
)
163168
end
@@ -168,11 +173,11 @@ function check_prep(
168173
f!, y, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C}
169174
) where {SIG,C}
170175
if SIG !== Nothing
171-
RUNTIME_SIG = typeof((f!, y, backend, x, contexts))
172-
if SIG != RUNTIME_SIG
176+
EXEC_SIG = typeof((f!, y, backend, x, contexts))
177+
if SIG != EXEC_SIG
173178
throw(
174179
PreparationMismatchError(
175-
SIG, RUNTIME_SIG; format=[:f!, :y, :backend, :x, :contexts]
180+
SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :contexts]
176181
),
177182
)
178183
end
@@ -183,11 +188,11 @@ function check_prep(
183188
f, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C}
184189
) where {SIG,C}
185190
if SIG !== Nothing
186-
RUNTIME_SIG = typeof((f, backend, x, t, contexts))
187-
if SIG != RUNTIME_SIG
191+
EXEC_SIG = typeof((f, backend, x, t, contexts))
192+
if SIG != EXEC_SIG
188193
throw(
189194
PreparationMismatchError(
190-
SIG, RUNTIME_SIG; format=[:f, :backend, :x, :t, :contexts]
195+
SIG, EXEC_SIG; format=[:f, :backend, :x, :tang, :contexts]
191196
),
192197
)
193198
end
@@ -198,11 +203,11 @@ function check_prep(
198203
f!, y, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C}
199204
) where {SIG,C}
200205
if SIG !== Nothing
201-
RUNTIME_SIG = typeof((f!, y, backend, x, t, contexts))
202-
if SIG != RUNTIME_SIG
206+
EXEC_SIG = typeof((f!, y, backend, x, t, contexts))
207+
if SIG != EXEC_SIG
203208
throw(
204209
PreparationMismatchError(
205-
SIG, RUNTIME_SIG; format=[:f!, :y, :backend, :x, :t, :contexts]
210+
SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :tang, :contexts]
206211
),
207212
)
208213
end
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
using DifferentiationInterface
2+
using DifferentiationInterface: AutoZeroReverse, AutoZeroForward
3+
using Test
4+
5+
backend = AutoZeroForward()
6+
other_backend = AutoZeroReverse()
7+
f(x, c) = x + c
8+
f!(y, x) = y .= x
9+
x = 1.0
10+
y = zeros(2)
11+
c = 2.0
12+
13+
@testset "Out of place, no tangents" begin
14+
prep = prepare_derivative(f, backend, x, Constant(c); strict=Val(true))
15+
prep_chill = prepare_derivative(f, backend, x, Constant(c); strict=Val(false))
16+
17+
@test_throws MethodError derivative(nothing, prep_chill, backend, x, Constant(c))
18+
19+
@test_throws """
20+
PreparationMismatchError (inconsistent types between preparation and execution):
21+
- f: ❌
22+
- prep: typeof(f)
23+
- exec: Nothing
24+
- backend: ✅
25+
- x: ✅
26+
- contexts: ✅
27+
""" derivative(nothing, prep, backend, x, Constant(c))
28+
29+
@test_throws """
30+
PreparationMismatchError (inconsistent types between preparation and execution):
31+
- f: ✅
32+
- backend: ❌
33+
- prep: AutoZeroForward
34+
- exec: AutoZeroReverse
35+
- x: ✅
36+
- contexts: ✅
37+
""" derivative(f, prep, other_backend, x, Constant(c))
38+
39+
@test_throws """
40+
PreparationMismatchError (inconsistent types between preparation and execution):
41+
- f: ✅
42+
- backend: ✅
43+
- x: ❌
44+
- prep: Float64
45+
- exec: Int64
46+
- contexts: ✅
47+
""" derivative(f, prep, backend, 1, Constant(c))
48+
49+
@test_throws """
50+
PreparationMismatchError (inconsistent types between preparation and execution):
51+
- f: ✅
52+
- backend: ✅
53+
- x: ✅
54+
- contexts: ❌
55+
- prep: Tuple{Constant{Float64}}
56+
- exec: Tuple{Constant{Int64}}
57+
""" derivative(f, prep, backend, x, Constant(2))
58+
59+
@test_throws """
60+
PreparationMismatchError (inconsistent types between preparation and execution):
61+
- f: ✅
62+
- backend: ✅
63+
- x: ✅
64+
- contexts: ❌
65+
- prep: Tuple{Constant{Float64}}
66+
- exec: Tuple{Constant{Int64}, Constant{Int64}}
67+
""" derivative(f, prep, backend, x, Constant(2), Constant(3))
68+
end
69+
70+
@testset "In place, no tangents" begin
71+
prep = prepare_derivative(f!, y, backend, x; strict=Val(true))
72+
prep_chill = prepare_derivative(f!, y, backend, x; strict=Val(false))
73+
74+
@test_throws MethodError derivative(nothing, y, prep_chill, backend, x, Constant(c))
75+
76+
@test_throws """
77+
PreparationMismatchError (inconsistent types between preparation and execution):
78+
- f!: ❌
79+
- prep: typeof(f!)
80+
- exec: Nothing
81+
- y: ✅
82+
- backend: ✅
83+
- x: ✅
84+
- contexts: ✅
85+
""" derivative(nothing, y, prep, backend, x)
86+
end
87+
88+
@testset "Out of place, with tangents" begin
89+
prep = prepare_pushforward(f, backend, x, (x,), Constant(c); strict=Val(true))
90+
prep_chill = prepare_pushforward(f, backend, x, (x,), Constant(c); strict=Val(false))
91+
92+
@test_throws MethodError pushforward(nothing, prep_chill, backend, x, (x,))
93+
94+
@test_throws """
95+
PreparationMismatchError (inconsistent types between preparation and execution):
96+
- f: ❌
97+
- prep: typeof(f)
98+
- exec: Nothing
99+
- backend: ✅
100+
- x: ✅
101+
- tang: ✅
102+
- contexts: ✅
103+
""" pushforward(nothing, prep, backend, x, (x,), Constant(c))
104+
end
105+
106+
@testset "In place, with tangents" begin
107+
prep = prepare_pushforward(f!, y, backend, x, (x,); strict=Val(true))
108+
prep_chill = prepare_pushforward(
109+
f!, y, backend, x, (x,), Constant(c); strict=Val(false)
110+
)
111+
112+
@test_throws MethodError pushforward(nothing, y, prep_chill, backend, x, (x,))
113+
114+
@test_throws """
115+
PreparationMismatchError (inconsistent types between preparation and execution):
116+
- f!: ❌
117+
- prep: typeof(f!)
118+
- exec: Nothing
119+
- y: ✅
120+
- backend: ✅
121+
- x: ✅
122+
- tang: ✅
123+
- contexts: ✅
124+
""" pushforward(nothing, y, prep, backend, x, (x,))
125+
end

DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ using DifferentiationInterface:
9090
pushforward_performance,
9191
pullback_performance
9292
using DifferentiationInterface: Rewrap, Context, Constant, Cache, unwrap
93+
using DifferentiationInterface: PreparationMismatchError
9394
using DocStringExtensions: TYPEDFIELDS, TYPEDSIGNATURES
9495
using JET: @test_opt
9596
using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent

0 commit comments

Comments
 (0)