Skip to content

Commit 5750b31

Browse files
authored
Clarify role of constants in preparation (#485)
1 parent ee11b70 commit 5750b31

2 files changed

Lines changed: 12 additions & 11 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module DifferentiationInterfaceZygoteExt
33
using ADTypes: AutoForwardDiff, AutoZygote
44
import DifferentiationInterface as DI
55
using DifferentiationInterface:
6-
Context,
6+
Constant,
77
HVPExtras,
88
NoGradientExtras,
99
NoHessianExtras,
@@ -27,19 +27,19 @@ struct ZygotePullbackExtrasSamePoint{Y,PB} <: PullbackExtras
2727
pb::PB
2828
end
2929

30-
function DI.prepare_pullback(f, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Context})
30+
function DI.prepare_pullback(f, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Constant})
3131
return NoPullbackExtras()
3232
end
3333

3434
function DI.prepare_pullback_same_point(
35-
f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Context}
35+
f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Constant}
3636
)
3737
y, pb = pullback(f, x, map(unwrap, contexts)...)
3838
return ZygotePullbackExtrasSamePoint(y, pb)
3939
end
4040

4141
function DI.value_and_pullback(
42-
f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Context}
42+
f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Constant}
4343
)
4444
y, pb = pullback(f, x, map(unwrap, contexts)...)
4545
tx = map(ty) do dy
@@ -54,7 +54,7 @@ function DI.value_and_pullback(
5454
::AutoZygote,
5555
x,
5656
ty::Tangents,
57-
contexts::Vararg{Context},
57+
contexts::Vararg{Constant},
5858
)
5959
@compat (; y, pb) = extras
6060
tx = map(ty) do dy
@@ -69,7 +69,7 @@ function DI.pullback(
6969
::AutoZygote,
7070
x,
7171
ty::Tangents,
72-
contexts::Vararg{Context},
72+
contexts::Vararg{Constant},
7373
)
7474
@compat (; pb) = extras
7575
tx = map(ty) do dy
@@ -80,28 +80,28 @@ end
8080

8181
## Gradient
8282

83-
DI.prepare_gradient(f, ::AutoZygote, x, contexts::Vararg{Context}) = NoGradientExtras()
83+
DI.prepare_gradient(f, ::AutoZygote, x, contexts::Vararg{Constant}) = NoGradientExtras()
8484

8585
function DI.value_and_gradient(
86-
f, ::NoGradientExtras, ::AutoZygote, x, contexts::Vararg{Context}
86+
f, ::NoGradientExtras, ::AutoZygote, x, contexts::Vararg{Constant}
8787
)
8888
@compat (; val, grad) = withgradient(f, x, map(unwrap, contexts)...)
8989
return val, first(grad)
9090
end
9191

92-
function DI.gradient(f, ::NoGradientExtras, ::AutoZygote, x, contexts::Vararg{Context})
92+
function DI.gradient(f, ::NoGradientExtras, ::AutoZygote, x, contexts::Vararg{Constant})
9393
return first(gradient(f, x, map(unwrap, contexts)...))
9494
end
9595

9696
function DI.value_and_gradient!(
97-
f, grad, extras::NoGradientExtras, backend::AutoZygote, x, contexts::Vararg{Context}
97+
f, grad, extras::NoGradientExtras, backend::AutoZygote, x, contexts::Vararg{Constant}
9898
)
9999
y, new_grad = DI.value_and_gradient(f, extras, backend, x, contexts...)
100100
return y, copyto!(grad, new_grad)
101101
end
102102

103103
function DI.gradient!(
104-
f, grad, extras::NoGradientExtras, backend::AutoZygote, x, contexts::Vararg{Context}
104+
f, grad, extras::NoGradientExtras, backend::AutoZygote, x, contexts::Vararg{Constant}
105105
)
106106
return copyto!(grad, DI.gradient(f, extras, backend, x, contexts...))
107107
end

DifferentiationInterface/src/utils/context.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ abstract type Context end
2424
Concrete type of [`Context`](@ref) argument which is kept constant during differentiation.
2525
2626
Note that an operator can be prepared with an arbitrary value of the constant.
27+
However, same-point preparation must occur with the exact value that will be reused later.
2728
2829
# Example
2930

0 commit comments

Comments
 (0)