Skip to content

Commit 34476d6

Browse files
authored
fix: better warning for MethodError (#704)
1 parent 1f87272 commit 34476d6

3 files changed

Lines changed: 47 additions & 20 deletions

File tree

DifferentiationInterface/src/init.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,27 @@ function __init__()
22
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs
33
if exc.f in (_prepare_pushforward_aux, _prepare_pullback_aux)
44
B = first(T for T in argtypes if T <: AbstractADType)
5-
printstyled(
6-
io,
7-
"\n\nThe autodiff backend package you want to use may not be loaded. Please run the following command and try again:";
8-
bold=true,
9-
)
10-
printstyled(io, "\n\n\timport $(package_name(B))"; color=:cyan, bold=true)
5+
packages = required_packages(B)
6+
loaded = map(string, values(Base.loaded_modules))
7+
missing_package = any(!(p in loaded) for p in packages)
8+
if missing_package
9+
import_statement = "import $(packages[1])"
10+
for p in packages[2:end]
11+
import_statement *= ", $p"
12+
end
13+
printstyled(
14+
io,
15+
"\n\nThe autodiff backend you chose requires a package which may not be loaded. Please run the following command and try again:";
16+
bold=true,
17+
)
18+
printstyled(io, "\n\n\t$import_statement"; color=:cyan, bold=true)
19+
else
20+
printstyled(
21+
io,
22+
"\n\nThe autodiff backend you chose may not be compatible with the operation you want to perform. Please refer to the documentation of DifferentiationInterface.jl and open an issue if necessary.";
23+
bold=true,
24+
)
25+
end
1126
end
1227
end
1328
end

DifferentiationInterface/src/utils/printing.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,32 @@
1-
package_name(b::AbstractADType) = package_name(typeof(b))
1+
required_packages(b::AbstractADType) = required_packages(typeof(b))
22

3-
function package_name(::Type{B}) where {B<:AbstractADType}
3+
function required_packages(::Type{B}) where {B<:AbstractADType}
44
s = string(B)
55
s = chopprefix(s, "ADTypes.")
66
s = chopprefix(s, "Auto")
77
k = findfirst('{', s)
88
if isnothing(k)
9-
return s
9+
return [s]
1010
else
11-
return s[begin:(k - 1)]
11+
return [s[begin:(k - 1)]]
1212
end
1313
end
1414

15-
function package_name(::Type{SecondOrder{O,I}}) where {O,I}
16-
p1 = package_name(O)
17-
p2 = package_name(I)
18-
return p1 == p2 ? p1 : "$p1, $p2"
15+
function required_packages(::Type{SecondOrder{O,I}}) where {O,I}
16+
p1 = required_packages(O)
17+
p2 = required_packages(I)
18+
return unique(vcat(p1, p2))
1919
end
2020

21-
package_name(::Type{<:AutoSparse{D}}) where {D} = package_name(D)
21+
function required_packages(::Type{MixedMode{F,R}}) where {F,R}
22+
p1 = required_packages(F)
23+
p2 = required_packages(R)
24+
return unique(vcat(p1, p2))
25+
end
26+
27+
function required_packages(::Type{<:AutoSparse{D}}) where {D}
28+
return unique(vcat(required_packages(D), "SparseMatrixColorings"))
29+
end
2230

2331
function document_preparation(operator_name::AbstractString; same_point=false)
2432
if same_point
Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using ADTypes
22
using DifferentiationInterface
3+
using DifferentiationInterface: required_packages
34
using Test
45

56
backend = SecondOrder(AutoForwardDiff(), AutoZygote())
@@ -12,8 +13,11 @@ detector = DenseSparsityDetector(AutoForwardDiff(); atol=1e-23)
1213
diffwith = DifferentiateWith(exp, AutoForwardDiff())
1314
@test string(diffwith) == "DifferentiateWith(exp, AutoForwardDiff())"
1415

15-
@test DifferentiationInterface.package_name(AutoForwardDiff()) == "ForwardDiff"
16-
@test DifferentiationInterface.package_name(AutoZygote()) == "Zygote"
17-
@test DifferentiationInterface.package_name(AutoSparse(AutoForwardDiff())) == "ForwardDiff"
18-
@test DifferentiationInterface.package_name(SecondOrder(AutoForwardDiff(), AutoZygote())) ==
19-
"ForwardDiff, Zygote"
16+
@test required_packages(AutoForwardDiff()) == ["ForwardDiff"]
17+
@test required_packages(AutoZygote()) == ["Zygote"]
18+
@test required_packages(AutoSparse(AutoForwardDiff())) ==
19+
["ForwardDiff", "SparseMatrixColorings"]
20+
@test required_packages(SecondOrder(AutoForwardDiff(), AutoZygote())) ==
21+
["ForwardDiff", "Zygote"]
22+
@test required_packages(MixedMode(AutoForwardDiff(), AutoZygote())) ==
23+
["ForwardDiff", "Zygote"]

0 commit comments

Comments
 (0)