Skip to content

Commit f3f7b29

Browse files
authored
fix: check nothing output for Zygote (#667)
1 parent 9df2763 commit f3f7b29

2 files changed

Lines changed: 68 additions & 26 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,24 @@ using ForwardDiff: ForwardDiff
66
using Zygote:
77
ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian
88

9+
struct ZygoteNothingError <: Exception
10+
f
11+
x
12+
contexts
13+
end
14+
15+
function Base.showerror(io::IO, e::ZygoteNothingError)
16+
(; f, x, contexts) = e
17+
sig = (typeof(x), map(typeof DI.unwrap, contexts)...)
18+
return print(
19+
io,
20+
"Zygote failed to differentiate function `$f` with argument types `$sig` (the pullback returned `nothing`).",
21+
)
22+
end
23+
24+
check_nothing(::Nothing, f, x, contexts) = throw(ZygoteNothingError(f, x, contexts))
25+
check_nothing(::Any, f, x, contexts) = nothing
26+
927
DI.check_available(::AutoZygote) = true
1028
DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported()
1129

@@ -46,6 +64,7 @@ function DI.value_and_pullback(
4664
tx = map(ty) do dy
4765
first(pb(dy))
4866
end
67+
check_nothing(first(tx), f, x, contexts)
4968
return y, tx
5069
end
5170

@@ -61,6 +80,7 @@ function DI.value_and_pullback(
6180
tx = map(ty) do dy
6281
first(pb(dy))
6382
end
83+
check_nothing(first(tx), f, x, contexts)
6484
return copy(y), tx
6585
end
6686

@@ -76,6 +96,7 @@ function DI.pullback(
7696
tx = map(ty) do dy
7797
first(pb(dy))
7898
end
99+
check_nothing(first(tx), f, x, contexts)
79100
return tx
80101
end
81102

@@ -95,6 +116,7 @@ function DI.value_and_gradient(
95116
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
96117
) where {C}
97118
(; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...)
119+
check_nothing(first(grad), f, x, contexts)
98120
return val, first(grad)
99121
end
100122

@@ -105,7 +127,9 @@ function DI.gradient(
105127
x,
106128
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
107129
) where {C}
108-
return first(gradient(f, x, map(DI.unwrap, contexts)...))
130+
grad = gradient(f, x, map(DI.unwrap, contexts)...)
131+
check_nothing(first(grad), f, x, contexts)
132+
return first(grad)
109133
end
110134

111135
function DI.value_and_gradient!(
@@ -146,8 +170,11 @@ function DI.value_and_jacobian(
146170
x,
147171
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
148172
) where {C}
149-
return f(x, map(DI.unwrap, contexts)...),
150-
first(jacobian(f, x, map(DI.unwrap, contexts)...)) # https://github.com/FluxML/Zygote.jl/issues/1506
173+
y = f(x, map(DI.unwrap, contexts)...)
174+
# https://github.com/FluxML/Zygote.jl/issues/1506
175+
jac = jacobian(f, x, map(DI.unwrap, contexts)...)
176+
check_nothing(first(jac), f, x, contexts)
177+
return y, first(jac)
151178
end
152179

153180
function DI.jacobian(
@@ -157,7 +184,9 @@ function DI.jacobian(
157184
x,
158185
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
159186
) where {C}
160-
return first(jacobian(f, x, map(DI.unwrap, contexts)...))
187+
jac = jacobian(f, x, map(DI.unwrap, contexts)...)
188+
check_nothing(first(jac), f, x, contexts)
189+
return first(jac)
161190
end
162191

163192
function DI.value_and_jacobian!(
@@ -266,7 +295,9 @@ function DI.hessian(
266295
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
267296
) where {C}
268297
fc = DI.with_contexts(f, contexts...)
269-
return hessian(fc, x)
298+
hess = hessian(fc, x)
299+
check_nothing(hess, f, x, contexts)
300+
return hess
270301
end
271302

272303
function DI.hessian!(

DifferentiationInterface/test/Back/Zygote/test.jl

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,38 @@ end
2424

2525
## Dense
2626

27-
test_differentiation(
28-
backends,
29-
default_scenarios(; include_constantified=true);
30-
excluded=[:second_derivative],
31-
logging=LOGGING,
32-
);
33-
34-
test_differentiation(second_order_backends; logging=LOGGING);
35-
36-
test_differentiation(
37-
backends[1],
38-
vcat(component_scenarios(), gpu_scenarios());
39-
excluded=SECOND_ORDER,
40-
logging=LOGGING,
41-
)
27+
@testset "Dense" begin
28+
test_differentiation(
29+
backends,
30+
default_scenarios(; include_constantified=true);
31+
excluded=[:second_derivative],
32+
logging=LOGGING,
33+
)
34+
35+
test_differentiation(second_order_backends; logging=LOGGING)
36+
37+
test_differentiation(
38+
backends[1],
39+
vcat(component_scenarios(), gpu_scenarios());
40+
excluded=SECOND_ORDER,
41+
logging=LOGGING,
42+
)
43+
end
4244

4345
## Sparse
4446

45-
test_differentiation(
46-
MyAutoSparse.(vcat(backends, second_order_backends)),
47-
sparse_scenarios(; band_sizes=0:-1);
48-
sparsity=true,
49-
logging=LOGGING,
50-
)
47+
@testset "Sparse" begin
48+
test_differentiation(
49+
MyAutoSparse.(vcat(backends, second_order_backends)),
50+
sparse_scenarios(; band_sizes=0:-1);
51+
sparsity=true,
52+
logging=LOGGING,
53+
)
54+
end
55+
56+
## Errors
57+
58+
@testset "Errors" begin
59+
safe_log(x) = x > zero(x) ? log(x) : convert(typeof(x), NaN)
60+
@test_throws "Zygote failed to differentiate" derivative(safe_log, AutoZygote(), 0.0)
61+
end

0 commit comments

Comments
 (0)