Skip to content

Commit e05f4fd

Browse files
committed
Fixes
1 parent 3424a55 commit e05f4fd

3 files changed

Lines changed: 16 additions & 9 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11

22
## Pushforward
33

4-
struct PolyesterForwardDiffOneArgPushforwardPrep{SIG,P} <:
5-
PolyesterForwardDiffOneArgPushforwardPrep{SIG}
4+
struct PolyesterForwardDiffOneArgPushforwardPrep{SIG,P} <: DI.PushforwardPrep{SIG}
65
_sig::Val{SIG}
76
single_threaded_prep::P
87
end

DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function DI.prepare_pullback_same_point(
3535
ty::NTuple,
3636
contexts::Vararg{DI.GeneralizedConstant,C},
3737
) where {C}
38-
_sig = DI.signature(f, prep, backend, x, ty, contexts...; strict)
38+
_sig = DI.signature(f, prep, backend, x, ty, contexts...; strict=DI.is_strict(prep))
3939
DI.check_prep(f, prep, backend, x, ty, contexts...)
4040
y, pb = forward(f, x, map(DI.unwrap, contexts)...)
4141
return TrackerPullbackPrepSamePoint(_sig, y, pb)

DifferentiationInterface/src/utils/prep.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,21 @@ end
7777
is_strict(::Prep{Nothing}) = Val(false)
7878
is_strict(::Prep) = Val(true)
7979

80-
function inconsistent_signatures_error(SIG, RUNTIME_SIG)
80+
struct PreparationMismatchError{SIG,RUNTIME_SIG} <: Exception end
81+
82+
function PreparationMismatchError(::Type{SIG}, ::Type{RUNTIME_SIG}) where {SIG,RUNTIME_SIG}
83+
return PreparationMismatchError{SIG,RUNTIME_SIG}()
84+
end
85+
86+
function Base.showerror(
87+
io::IO, e::PreparationMismatchError{SIG,RUNTIME_SIG}
88+
) where {SIG,RUNTIME_SIG}
8189
msg = """
8290
Inconsistent signatures:
8391
- at preparation time: $SIG
8492
- at execution time: $RUNTIME_SIG
8593
"""
86-
return ArgumentError(msg)
94+
return print(io, msg)
8795
end
8896

8997
function signature(
@@ -138,7 +146,7 @@ function check_prep(
138146
if SIG !== Nothing
139147
RUNTIME_SIG = typeof((f, backend, x, contexts))
140148
if SIG != RUNTIME_SIG
141-
throw(inconsistent_signatures_error(SIG, RUNTIME_SIG))
149+
throw(PreparationMismatchError(SIG, RUNTIME_SIG))
142150
end
143151
end
144152
end
@@ -149,7 +157,7 @@ function check_prep(
149157
if SIG !== Nothing
150158
RUNTIME_SIG = typeof((f!, y, backend, x, contexts))
151159
if SIG != RUNTIME_SIG
152-
throw(inconsistent_signatures_error(SIG, RUNTIME_SIG))
160+
throw(PreparationMismatchError(SIG, RUNTIME_SIG))
153161
end
154162
end
155163
end
@@ -160,7 +168,7 @@ function check_prep(
160168
if SIG !== Nothing
161169
RUNTIME_SIG = typeof((f, backend, x, t, contexts))
162170
if SIG != RUNTIME_SIG
163-
throw(inconsistent_signatures_error(SIG, RUNTIME_SIG))
171+
throw(PreparationMismatchError(SIG, RUNTIME_SIG))
164172
end
165173
end
166174
end
@@ -171,7 +179,7 @@ function check_prep(
171179
if SIG !== Nothing
172180
RUNTIME_SIG = typeof((f!, y, backend, x, t, contexts))
173181
if SIG != RUNTIME_SIG
174-
throw(inconsistent_signatures_error(SIG, RUNTIME_SIG))
182+
throw(PreparationMismatchError(SIG, RUNTIME_SIG))
175183
end
176184
end
177185
end

0 commit comments

Comments
 (0)