Skip to content

Commit d366f09

Browse files
committed
Zygote ambiguity
1 parent 3530c17 commit d366f09

1 file changed

Lines changed: 14 additions & 4 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,20 @@ function DI.prepare_hvp(
158158
end
159159

160160
function DI.hvp(
161-
f, prep::DI.HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C}
161+
f,
162+
prep::DI.ForwardOverReverseHVPPrep,
163+
backend::AutoZygote,
164+
x,
165+
tx::NTuple,
166+
contexts::Vararg{DI.Context,C},
162167
) where {C}
163168
return DI.hvp(f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...)
164169
end
165170

166171
function DI.hvp!(
167172
f,
168173
tg::NTuple,
169-
prep::DI.HVPPrep,
174+
prep::DI.ForwardOverReverseHVPPrep,
170175
backend::AutoZygote,
171176
x,
172177
tx::NTuple,
@@ -178,7 +183,12 @@ function DI.hvp!(
178183
end
179184

180185
function DI.gradient_and_hvp(
181-
f, prep::DI.HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{DI.Context,C}
186+
f,
187+
prep::DI.ForwardOverReverseHVPPrep,
188+
backend::AutoZygote,
189+
x,
190+
tx::NTuple,
191+
contexts::Vararg{DI.Context,C},
182192
) where {C}
183193
return DI.gradient_and_hvp(
184194
f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...
@@ -189,7 +199,7 @@ function DI.gradient_and_hvp!(
189199
f,
190200
grad,
191201
tg::NTuple,
192-
prep::DI.HVPPrep,
202+
prep::DI.ForwardOverReverseHVPPrep,
193203
backend::AutoZygote,
194204
x,
195205
tx::NTuple,

0 commit comments

Comments
 (0)