@@ -6,6 +6,24 @@ using ForwardDiff: ForwardDiff
66using Zygote:
77 ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian
88
9+ struct ZygoteNothingError <: Exception
10+ f
11+ x
12+ contexts
13+ end
14+
15+ function Base. showerror (io:: IO , e:: ZygoteNothingError )
16+ (; f, x, contexts) = e
17+ sig = (typeof (x), map (typeof ∘ DI. unwrap, contexts)... )
18+ return print (
19+ io,
20+ " Zygote failed to differentiate function `$f ` with argument types `$sig ` (the pullback returned `nothing`)." ,
21+ )
22+ end
23+
24+ check_nothing (:: Nothing , f, x, contexts) = throw (ZygoteNothingError (f, x, contexts))
25+ check_nothing (:: Any , f, x, contexts) = nothing
26+
927DI. check_available (:: AutoZygote ) = true
1028DI. inplace_support (:: AutoZygote ) = DI. InPlaceNotSupported ()
1129
@@ -46,6 +64,7 @@ function DI.value_and_pullback(
4664 tx = map (ty) do dy
4765 first (pb (dy))
4866 end
67+ check_nothing (first (tx), f, x, contexts)
4968 return y, tx
5069end
5170
@@ -61,6 +80,7 @@ function DI.value_and_pullback(
6180 tx = map (ty) do dy
6281 first (pb (dy))
6382 end
83+ check_nothing (first (tx), f, x, contexts)
6484 return copy (y), tx
6585end
6686
@@ -76,6 +96,7 @@ function DI.pullback(
7696 tx = map (ty) do dy
7797 first (pb (dy))
7898 end
99+ check_nothing (first (tx), f, x, contexts)
79100 return tx
80101end
81102
@@ -95,6 +116,7 @@ function DI.value_and_gradient(
95116 contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
96117) where {C}
97118 (; val, grad) = withgradient (f, x, map (DI. unwrap, contexts)... )
119+ check_nothing (first (grad), f, x, contexts)
98120 return val, first (grad)
99121end
100122
@@ -105,7 +127,9 @@ function DI.gradient(
105127 x,
106128 contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
107129) where {C}
108- return first (gradient (f, x, map (DI. unwrap, contexts)... ))
130+ grad = gradient (f, x, map (DI. unwrap, contexts)... )
131+ check_nothing (first (grad), f, x, contexts)
132+ return first (grad)
109133end
110134
111135function DI. value_and_gradient! (
@@ -146,8 +170,11 @@ function DI.value_and_jacobian(
146170 x,
147171 contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
148172) where {C}
149- return f (x, map (DI. unwrap, contexts)... ),
150- first (jacobian (f, x, map (DI. unwrap, contexts)... )) # https://github.com/FluxML/Zygote.jl/issues/1506
173+ y = f (x, map (DI. unwrap, contexts)... )
174+ # https://github.com/FluxML/Zygote.jl/issues/1506
175+ jac = jacobian (f, x, map (DI. unwrap, contexts)... )
176+ check_nothing (first (jac), f, x, contexts)
177+ return y, first (jac)
151178end
152179
153180function DI. jacobian (
@@ -157,7 +184,9 @@ function DI.jacobian(
157184 x,
158185 contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
159186) where {C}
160- return first (jacobian (f, x, map (DI. unwrap, contexts)... ))
187+ jac = jacobian (f, x, map (DI. unwrap, contexts)... )
188+ check_nothing (first (jac), f, x, contexts)
189+ return first (jac)
161190end
162191
163192function DI. value_and_jacobian! (
@@ -266,7 +295,9 @@ function DI.hessian(
266295 contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
267296) where {C}
268297 fc = DI. with_contexts (f, contexts... )
269- return hessian (fc, x)
298+ hess = hessian (fc, x)
299+ check_nothing (hess, f, x, contexts)
300+ return hess
270301end
271302
272303function DI. hessian! (
0 commit comments