Skip to content

Commit ee11b70

Browse files
authored
Contexts for Zygote (#474)
1 parent 71ea7fd commit ee11b70

2 files changed

Lines changed: 50 additions & 21 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ module DifferentiationInterfaceZygoteExt
33
using ADTypes: AutoForwardDiff, AutoZygote
44
import DifferentiationInterface as DI
55
using DifferentiationInterface:
6+
Context,
67
HVPExtras,
78
NoGradientExtras,
89
NoHessianExtras,
910
NoJacobianExtras,
1011
NoPullbackExtras,
1112
PullbackExtras,
12-
Tangents
13+
Tangents,
14+
unwrap
1315
using ForwardDiff: ForwardDiff
1416
using Zygote:
1517
ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian
@@ -25,63 +27,83 @@ struct ZygotePullbackExtrasSamePoint{Y,PB} <: PullbackExtras
2527
pb::PB
2628
end
2729

28-
DI.prepare_pullback(f, ::AutoZygote, x, ty::Tangents) = NoPullbackExtras()
30+
function DI.prepare_pullback(f, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Context})
31+
return NoPullbackExtras()
32+
end
2933

3034
function DI.prepare_pullback_same_point(
31-
f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents
35+
f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Context}
3236
)
33-
y, pb = pullback(f, x)
37+
y, pb = pullback(f, x, map(unwrap, contexts)...)
3438
return ZygotePullbackExtrasSamePoint(y, pb)
3539
end
3640

37-
function DI.value_and_pullback(f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents)
38-
y, pb = pullback(f, x)
41+
function DI.value_and_pullback(
42+
f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Context}
43+
)
44+
y, pb = pullback(f, x, map(unwrap, contexts)...)
3945
tx = map(ty) do dy
40-
only(pb(dy))
46+
first(pb(dy))
4147
end
4248
return y, tx
4349
end
4450

4551
function DI.value_and_pullback(
46-
f, extras::ZygotePullbackExtrasSamePoint, ::AutoZygote, x, ty::Tangents
52+
f,
53+
extras::ZygotePullbackExtrasSamePoint,
54+
::AutoZygote,
55+
x,
56+
ty::Tangents,
57+
contexts::Vararg{Context},
4758
)
4859
@compat (; y, pb) = extras
4960
tx = map(ty) do dy
50-
only(pb(dy))
61+
first(pb(dy))
5162
end
5263
return copy(y), tx
5364
end
5465

5566
function DI.pullback(
56-
f, extras::ZygotePullbackExtrasSamePoint, ::AutoZygote, x, ty::Tangents
67+
f,
68+
extras::ZygotePullbackExtrasSamePoint,
69+
::AutoZygote,
70+
x,
71+
ty::Tangents,
72+
contexts::Vararg{Context},
5773
)
5874
@compat (; pb) = extras
5975
tx = map(ty) do dy
60-
only(pb(dy))
76+
first(pb(dy))
6177
end
6278
return tx
6379
end
6480

6581
## Gradient
6682

67-
DI.prepare_gradient(f, ::AutoZygote, x) = NoGradientExtras()
83+
DI.prepare_gradient(f, ::AutoZygote, x, contexts::Vararg{Context}) = NoGradientExtras()
6884

69-
function DI.value_and_gradient(f, ::NoGradientExtras, ::AutoZygote, x)
70-
@compat (; val, grad) = withgradient(f, x)
71-
return val, only(grad)
85+
function DI.value_and_gradient(
86+
f, ::NoGradientExtras, ::AutoZygote, x, contexts::Vararg{Context}
87+
)
88+
@compat (; val, grad) = withgradient(f, x, map(unwrap, contexts)...)
89+
return val, first(grad)
7290
end
7391

74-
function DI.gradient(f, ::NoGradientExtras, ::AutoZygote, x)
75-
return only(gradient(f, x))
92+
function DI.gradient(f, ::NoGradientExtras, ::AutoZygote, x, contexts::Vararg{Context})
93+
return first(gradient(f, x, map(unwrap, contexts)...))
7694
end
7795

78-
function DI.value_and_gradient!(f, grad, extras::NoGradientExtras, backend::AutoZygote, x)
79-
y, new_grad = DI.value_and_gradient(f, extras, backend, x)
96+
function DI.value_and_gradient!(
97+
f, grad, extras::NoGradientExtras, backend::AutoZygote, x, contexts::Vararg{Context}
98+
)
99+
y, new_grad = DI.value_and_gradient(f, extras, backend, x, contexts...)
80100
return y, copyto!(grad, new_grad)
81101
end
82102

83-
function DI.gradient!(f, grad, extras::NoGradientExtras, backend::AutoZygote, x)
84-
return copyto!(grad, DI.gradient(f, extras, backend, x))
103+
function DI.gradient!(
104+
f, grad, extras::NoGradientExtras, backend::AutoZygote, x, contexts::Vararg{Context}
105+
)
106+
return copyto!(grad, DI.gradient(f, extras, backend, x, contexts...))
85107
end
86108

87109
## Jacobian

DifferentiationInterface/test/Back/Zygote/test.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ end
3030

3131
test_differentiation(AutoZygote(); excluded=[:second_derivative], logging=LOGGING);
3232

33+
test_differentiation(
34+
AutoZygote(),
35+
default_scenarios(; include_normal=false, include_constantified=true);
36+
second_order=false,
37+
logging=LOGGING,
38+
);
39+
3340
if VERSION >= v"1.10"
3441
test_differentiation(
3542
AutoZygote(),

0 commit comments

Comments
 (0)