Skip to content

Commit de6d614

Browse files
committed
Fixes
1 parent 26517dc commit de6d614

4 files changed

Lines changed: 56 additions & 35 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ end
4040

4141
function DI.pushforward(
4242
f,
43-
prep::FiniteDiffOneArgPushforwardPrep{Nothing},
43+
prep::FiniteDiffOneArgPushforwardPrep{SIG,Nothing},
4444
backend::AutoFiniteDiff,
4545
x,
4646
tx::NTuple,
4747
contexts::Vararg{DI.Context,C},
48-
) where {C}
48+
) where {SIG,C}
4949
DI.check_prep(f, prep, backend, x, tx, contexts...)
5050
(; relstep, absstep, dir) = prep
5151
step(t::Number, dx) = f(x .+ t .* dx, map(DI.unwrap, contexts)...)
@@ -59,12 +59,12 @@ end
5959

6060
function DI.value_and_pushforward(
6161
f,
62-
prep::FiniteDiffOneArgPushforwardPrep{Nothing},
62+
prep::FiniteDiffOneArgPushforwardPrep{SIG,Nothing},
6363
backend::AutoFiniteDiff,
6464
x,
6565
tx::NTuple,
6666
contexts::Vararg{DI.Context,C},
67-
) where {C}
67+
) where {SIG,C}
6868
DI.check_prep(f, prep, backend, x, tx, contexts...)
6969
(; relstep, absstep, dir) = prep
7070
step(t::Number, dx) = f(x .+ t .* dx, map(DI.unwrap, contexts)...)
@@ -86,12 +86,12 @@ end
8686

8787
function DI.pushforward(
8888
f,
89-
prep::FiniteDiffOneArgPushforwardPrep{<:JVPCache},
89+
prep::FiniteDiffOneArgPushforwardPrep{SIG,<:JVPCache},
9090
backend::AutoFiniteDiff,
9191
x,
9292
tx::NTuple,
9393
contexts::Vararg{DI.Context,C},
94-
) where {C}
94+
) where {SIG,C}
9595
DI.check_prep(f, prep, backend, x, tx, contexts...)
9696
(; relstep, absstep, dir) = prep
9797
fc = DI.with_contexts(f, contexts...)
@@ -103,12 +103,12 @@ end
103103

104104
function DI.value_and_pushforward(
105105
f,
106-
prep::FiniteDiffOneArgPushforwardPrep{<:JVPCache},
106+
prep::FiniteDiffOneArgPushforwardPrep{SIG,<:JVPCache},
107107
backend::AutoFiniteDiff,
108108
x,
109109
tx::NTuple,
110110
contexts::Vararg{DI.Context,C},
111-
) where {C}
111+
) where {SIG,C}
112112
DI.check_prep(f, prep, backend, x, tx, contexts...)
113113
(; relstep, absstep, dir) = prep
114114
fc = DI.with_contexts(f, contexts...)
@@ -159,11 +159,11 @@ end
159159

160160
function DI.derivative(
161161
f,
162-
prep::FiniteDiffOneArgDerivativePrep{Nothing},
162+
prep::FiniteDiffOneArgDerivativePrep{SIG,Nothing},
163163
backend::AutoFiniteDiff,
164164
x,
165165
contexts::Vararg{DI.Context,C},
166-
) where {C}
166+
) where {SIG,C}
167167
DI.check_prep(f, prep, backend, x, contexts...)
168168
(; relstep, absstep, dir) = prep
169169
fc = DI.with_contexts(f, contexts...)
@@ -172,11 +172,11 @@ end
172172

173173
function DI.value_and_derivative(
174174
f,
175-
prep::FiniteDiffOneArgDerivativePrep{Nothing},
175+
prep::FiniteDiffOneArgDerivativePrep{SIG,Nothing},
176176
backend::AutoFiniteDiff,
177177
x,
178178
contexts::Vararg{DI.Context,C},
179-
) where {C}
179+
) where {SIG,C}
180180
DI.check_prep(f, prep, backend, x, contexts...)
181181
(; relstep, absstep, dir) = prep
182182
fc = DI.with_contexts(f, contexts...)
@@ -193,11 +193,11 @@ end
193193

194194
function DI.derivative(
195195
f,
196-
prep::FiniteDiffOneArgDerivativePrep{<:GradientCache},
196+
prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache},
197197
backend::AutoFiniteDiff,
198198
x,
199199
contexts::Vararg{DI.Context,C},
200-
) where {C}
200+
) where {SIG,C}
201201
DI.check_prep(f, prep, backend, x, contexts...)
202202
(; relstep, absstep, dir) = prep
203203
fc = DI.with_contexts(f, contexts...)
@@ -207,11 +207,11 @@ end
207207
function DI.derivative!(
208208
f,
209209
der,
210-
prep::FiniteDiffOneArgDerivativePrep{<:GradientCache},
210+
prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache},
211211
backend::AutoFiniteDiff,
212212
x,
213213
contexts::Vararg{DI.Context,C},
214-
) where {C}
214+
) where {SIG,C}
215215
DI.check_prep(f, prep, backend, x, contexts...)
216216
(; relstep, absstep, dir) = prep
217217
fc = DI.with_contexts(f, contexts...)
@@ -220,11 +220,11 @@ end
220220

221221
function DI.value_and_derivative(
222222
f,
223-
prep::FiniteDiffOneArgDerivativePrep{<:GradientCache},
223+
prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache},
224224
backend::AutoFiniteDiff,
225225
x,
226226
contexts::Vararg{DI.Context,C},
227-
) where {C}
227+
) where {SIG,C}
228228
DI.check_prep(f, prep, backend, x, contexts...)
229229
fc = DI.with_contexts(f, contexts...)
230230
(; relstep, absstep, dir) = prep
@@ -235,11 +235,11 @@ end
235235
function DI.value_and_derivative!(
236236
f,
237237
der,
238-
prep::FiniteDiffOneArgDerivativePrep{<:GradientCache},
238+
prep::FiniteDiffOneArgDerivativePrep{SIG,<:GradientCache},
239239
backend::AutoFiniteDiff,
240240
x,
241241
contexts::Vararg{DI.Context,C},
242-
) where {C}
242+
) where {SIG,C}
243243
DI.check_prep(f, prep, backend, x, contexts...)
244244
(; relstep, absstep, dir) = prep
245245
fc = DI.with_contexts(f, contexts...)

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ end
4040
function DI.value_and_pushforward(
4141
f!,
4242
y,
43-
prep::FiniteDiffTwoArgPushforwardPrep{Nothing},
43+
prep::FiniteDiffTwoArgPushforwardPrep{SIG,Nothing},
4444
backend::AutoFiniteDiff,
4545
x,
4646
tx::NTuple,
4747
contexts::Vararg{DI.Context,C},
48-
) where {C}
48+
) where {SIG,C}
4949
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
5050
(; relstep, absstep, dir) = prep
5151
function step(t::Number, dx)
@@ -72,12 +72,12 @@ end
7272
function DI.pushforward(
7373
f!,
7474
y,
75-
prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache},
75+
prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache},
7676
backend::AutoFiniteDiff,
7777
x,
7878
tx::NTuple,
7979
contexts::Vararg{DI.Context,C},
80-
) where {C}
80+
) where {SIG,C}
8181
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
8282
(; relstep, absstep, dir) = prep
8383
fc! = DI.with_contexts(f!, contexts...)
@@ -92,12 +92,12 @@ end
9292
function DI.value_and_pushforward(
9393
f!,
9494
y,
95-
prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache},
95+
prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache},
9696
backend::AutoFiniteDiff,
9797
x,
9898
tx::NTuple,
9999
contexts::Vararg{DI.Context,C},
100-
) where {C}
100+
) where {SIG,C}
101101
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
102102
(; relstep, absstep, dir) = prep
103103
fc! = DI.with_contexts(f!, contexts...)
@@ -114,12 +114,12 @@ function DI.pushforward!(
114114
f!,
115115
y,
116116
ty::NTuple,
117-
prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache},
117+
prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache},
118118
backend::AutoFiniteDiff,
119119
x,
120120
tx::NTuple,
121121
contexts::Vararg{DI.Context,C},
122-
) where {C}
122+
) where {SIG,C}
123123
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
124124
(; relstep, absstep, dir) = prep
125125
fc! = DI.with_contexts(f!, contexts...)
@@ -134,12 +134,12 @@ function DI.value_and_pushforward!(
134134
f!,
135135
y,
136136
ty::NTuple,
137-
prep::FiniteDiffTwoArgPushforwardPrep{<:JVPCache},
137+
prep::FiniteDiffTwoArgPushforwardPrep{SIG,<:JVPCache},
138138
backend::AutoFiniteDiff,
139139
x,
140140
tx::NTuple,
141141
contexts::Vararg{DI.Context,C},
142-
) where {C}
142+
) where {SIG,C}
143143
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
144144
(; relstep, absstep, dir) = prep
145145
fc! = DI.with_contexts(f!, contexts...)

DifferentiationInterface/src/first_order/gradient.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ function prepare_gradient(
6161
) where {F,C}
6262
SIG = signature(f, backend, x, contexts...; strict)
6363
y = f(x, map(unwrap, contexts)...) # TODO: replace with output type inference?
64-
pullback_prep = prepare_pullback(f, backend, x, (true,), contexts...; strict)
64+
pullback_prep = prepare_pullback(f, backend, x, (one(typeof(y)),), contexts...; strict)
6565
return PullbackGradientPrep{SIG,typeof(y),typeof(pullback_prep)}(pullback_prep)
6666
end
6767

DifferentiationInterface/src/utils/prep.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ struct NoSecondDerivativePrep{SIG} <: SecondDerivativePrep{SIG} end
5252

5353
is_strict(::Prep{SIG}) where {SIG} = SIG !== Nothing
5454

55+
function inconsistent_signatures_error(SIG, RUNTIME_SIG)
56+
msg = """
57+
Inconsistent signatures:
58+
- at preparation time: $SIG
59+
- at execution time: $RUNTIME_SIG
60+
"""
61+
return ArgumentError(msg)
62+
end
63+
5564
function signature(
5665
f, backend::AbstractADType, x, contexts::Vararg{Context,C}; strict::Bool
5766
) where {C}
@@ -96,30 +105,42 @@ function check_prep(
96105
f, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C}
97106
) where {SIG,C}
98107
if SIG !== Nothing
99-
@assert SIG == typeof((f, backend, x, contexts))
108+
RUNTIME_SIG = typeof((f, backend, x, contexts))
109+
if SIG != RUNTIME_SIG
110+
throw(inconsistent_signatures_error(SIG, RUNTIME_SIG))
111+
end
100112
end
101113
end
102114

103115
function check_prep(
104116
f!, y, ::Prep{SIG}, backend::AbstractADType, x, contexts::Vararg{Context,C}
105117
) where {SIG,C}
106118
if SIG !== Nothing
107-
@assert SIG == typeof((f!, y, backend, x, contexts))
119+
RUNTIME_SIG = typeof((f!, y, backend, x, contexts))
120+
if SIG != RUNTIME_SIG
121+
throw(inconsistent_signatures_error(SIG, RUNTIME_SIG))
122+
end
108123
end
109124
end
110125

111126
function check_prep(
112127
f, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C}
113128
) where {SIG,C}
114129
if SIG !== Nothing
115-
@assert SIG == typeof((f, backend, x, t, contexts))
130+
RUNTIME_SIG = typeof((f, backend, x, t, contexts))
131+
if SIG != RUNTIME_SIG
132+
throw(inconsistent_signatures_error(SIG, RUNTIME_SIG))
133+
end
116134
end
117135
end
118136

119137
function check_prep(
120138
f!, y, ::Prep{SIG}, backend::AbstractADType, x, t::NTuple, contexts::Vararg{Context,C}
121139
) where {SIG,C}
122140
if SIG !== Nothing
123-
@assert SIG == typeof((f!, y, backend, x, t, contexts))
141+
RUNTIME_SIG = typeof((f!, y, backend, x, t, contexts))
142+
if SIG != RUNTIME_SIG
143+
throw(inconsistent_signatures_error(SIG, RUNTIME_SIG))
144+
end
124145
end
125146
end

0 commit comments

Comments
 (0)