Skip to content

Commit 32b70a2

Browse files
authored
Throw informative errors when backends are not loaded (#214)
* More informative error when backend not loaded * Refactor to `MissingBackendError` * Test exceptions * Run JuliaFormatter
1 parent c1dbc01 commit 32b70a2

7 files changed

Lines changed: 81 additions & 13 deletions

File tree

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ include("utils/basis.jl")
4141
include("utils/printing.jl")
4242
include("utils/chunk.jl")
4343
include("utils/check.jl")
44+
include("utils/exceptions.jl")
4445

4546
include("first_order/pushforward.jl")
4647
include("first_order/pullback.jl")

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ function prepare_pullback_aux(f!, y, backend, x, dy, ::PullbackSlow)
8484
return PushforwardPullbackExtras(pushforward_extras)
8585
end
8686

87+
# Throw error if backend is missing
88+
prepare_pullback_aux(f, backend, x, dy, ::PullbackFast) = throw(MissingBackendError(backend))
89+
prepare_pullback_aux(f!, y, backend, x, dy, ::PullbackFast) = throw(MissingBackendError(backend))
90+
8791
## One argument
8892

8993
function value_and_pullback(

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ function prepare_pushforward_aux(f!, y, backend, x, dx, ::PushforwardSlow)
7373
return PullbackPushforwardExtras(pullback_extras)
7474
end
7575

76+
# Throw error if backend is missing
77+
prepare_pushforward_aux(f, backend, x, dy, ::PushforwardFast) = throw(MissingBackendError(backend))
78+
prepare_pushforward_aux(f!, y, backend, x, dy, ::PushforwardFast) = throw(MissingBackendError(backend))
79+
7680
## One argument
7781

7882
function value_and_pushforward(
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
struct MissingBackendError <: Exception
2+
backend::AbstractADType
3+
end
4+
function Base.showerror(io::IO, e::MissingBackendError)
5+
println(io, "failed to use $(backend_string(e.backend)) backend.")
6+
if !check_available(e.backend)
7+
print(
8+
io,
9+
"""Backend package is not loaded. To fix, run
10+
11+
using $(backend_package_name(e.backend))
12+
""",
13+
)
14+
else
15+
print(
16+
io,
17+
"Please open an issue: https://github.com/gdalle/DifferentiationInterface.jl/issues/new",
18+
)
19+
end
20+
end

DifferentiationInterface/src/utils/printing.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1-
backend_string_aux(b::AbstractADType) = string(b)
1+
backend_package_name(b::AbstractADType) = strip(string(b), ['(', ')'])
22

3-
backend_string_aux(::AutoChainRules) = "ChainRules"
4-
backend_string_aux(::AutoDiffractor) = "Diffractor"
5-
backend_string_aux(::AutoEnzyme) = "Enzyme"
6-
backend_string_aux(::AutoFastDifferentiation) = "FastDifferentiation"
7-
backend_string_aux(::AutoFiniteDiff) = "FiniteDiff"
8-
backend_string_aux(::AutoFiniteDifferences) = "FiniteDifferences"
9-
backend_string_aux(::AutoForwardDiff) = "ForwardDiff"
10-
backend_string_aux(::AutoPolyesterForwardDiff) = "PolyesterForwardDiff"
3+
backend_package_name(::AutoChainRules) = "ChainRules"
4+
backend_package_name(::AutoDiffractor) = "Diffractor"
5+
backend_package_name(::AutoEnzyme) = "Enzyme"
6+
backend_package_name(::AutoFastDifferentiation) = "FastDifferentiation"
7+
backend_package_name(::AutoFiniteDiff) = "FiniteDiff"
8+
backend_package_name(::AutoFiniteDifferences) = "FiniteDifferences"
9+
backend_package_name(::AutoForwardDiff) = "ForwardDiff"
10+
backend_package_name(::AutoPolyesterForwardDiff) = "PolyesterForwardDiff"
11+
backend_package_name(::AutoSymbolics) = "Symbolics"
12+
backend_package_name(::AutoTapir) = "Tapir"
13+
backend_package_name(::AutoTracker) = "Tracker"
14+
backend_package_name(::AutoZygote) = "Zygote"
15+
backend_package_name(::AutoReverseDiff) = "ReverseDiff"
16+
17+
backend_string_aux(b::AbstractADType) = backend_package_name(b)
1118
backend_string_aux(b::AutoReverseDiff) = "ReverseDiff$(b.compile ? "{compiled}" : "")"
12-
backend_string_aux(::AutoSymbolics) = "Symbolics"
13-
backend_string_aux(::AutoTapir) = "Tapir"
14-
backend_string_aux(::AutoTracker) = "Tracker"
15-
backend_string_aux(::AutoZygote) = "Zygote"
1619

1720
function backend_string(backend::AbstractADType)
1821
bs = backend_string_aux(backend)

DifferentiationInterface/test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ include("test_imports.jl")
2323

2424
Documenter.doctest(DifferentiationInterface)
2525

26+
@testset verbose = true "Exception handling" begin
27+
include("test_exceptions.jl")
28+
end
2629
@testset verbose = true "First order" begin
2730
include("first_order.jl")
2831
end
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using DifferentiationInterface: MissingBackendError
2+
3+
"""
4+
AutoBrokenForward <: ADTypes.AbstractADType
5+
6+
Available forward-mode backend with no pushforward implementation.
7+
Used to test error messages.
8+
"""
9+
struct AutoBrokenForward <: AbstractADType end
10+
ADTypes.mode(::AutoBrokenForward) = ADTypes.ForwardMode()
11+
DifferentiationInterface.check_available(::AutoBrokenForward) = true
12+
13+
"""
14+
AutoBrokenReverse <: ADTypes.AbstractADType
15+
16+
Available reverse-mode backend with no pullback implementation.
17+
Used to test error messages.
18+
"""
19+
struct AutoBrokenReverse <: AbstractADType end
20+
ADTypes.mode(::AutoBrokenReverse) = ADTypes.ReverseMode()
21+
DifferentiationInterface.check_available(::AutoBrokenReverse) = true
22+
23+
## Test exceptions
24+
@testset "MissingBackendError" begin
25+
f(x::AbstractArray) = sum(abs2, x)
26+
x = [1.0, 2.0, 3.0]
27+
28+
@test_throws MissingBackendError gradient(f, AutoBrokenForward(), x)
29+
@test_throws MissingBackendError gradient(f, AutoBrokenReverse(), x)
30+
31+
@test_throws MissingBackendError hvp(f, AutoBrokenForward(), x, x)
32+
@test_throws MissingBackendError hvp(f, AutoBrokenReverse(), x, x)
33+
end

0 commit comments

Comments
 (0)