Skip to content

Commit a22ccc7

Browse files
authored
Prepare pushforward and pullback with seed (#181)
1 parent 21cf0f4 commit a22ccc7

34 files changed

Lines changed: 178 additions & 167 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ DI.mode(::AutoReverseChainRules) = ADTypes.AbstractReverseMode
1818

1919
## Pullback
2020

21-
DI.prepare_pullback(f, ::AutoReverseChainRules, x) = NoPullbackExtras()
21+
DI.prepare_pullback(f, ::AutoReverseChainRules, x, dy) = NoPullbackExtras()
2222

2323
function DI.value_and_pullback_split(
2424
f, backend::AutoReverseChainRules, x, ::NoPullbackExtras

DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ DI.mode(::AutoChainRules{<:DiffractorRuleConfig}) = ADTypes.AbstractForwardMode
1212

1313
## Pushforward
1414

15-
DI.prepare_pushforward(f, ::AutoDiffractor, x) = NoPushforwardExtras()
15+
DI.prepare_pushforward(f, ::AutoDiffractor, x, dx) = NoPushforwardExtras()
1616

1717
function DI.pushforward(f, ::AutoDiffractor, x, dx, ::NoPushforwardExtras)
1818
# code copied from Diffractor.jl

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Pushforward
22

3-
DI.prepare_pushforward(f, ::AutoForwardEnzyme, x) = NoPushforwardExtras()
3+
DI.prepare_pushforward(f, ::AutoForwardEnzyme, x, dx) = NoPushforwardExtras()
44

55
function DI.value_and_pushforward(
66
f, backend::AutoForwardEnzyme, x, dx, ::NoPushforwardExtras

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Pushforward
22

3-
DI.prepare_pushforward(f!, y, ::AutoForwardEnzyme, x) = NoPushforwardExtras()
3+
DI.prepare_pushforward(f!, y, ::AutoForwardEnzyme, x, dx) = NoPushforwardExtras()
44

55
function DI.value_and_pushforward(
66
f!, y, backend::AutoForwardEnzyme, x, dx, ::NoPushforwardExtras

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Pullback
22

3-
DI.prepare_pullback(f, ::AutoReverseEnzyme, x) = NoPullbackExtras()
3+
DI.prepare_pullback(f, ::AutoReverseEnzyme, x, dy) = NoPullbackExtras()
44

55
### Out-of-place
66

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Pullback
22

3-
DI.prepare_pullback(f!, y, ::AutoReverseEnzyme, x) = NoPullbackExtras()
3+
DI.prepare_pullback(f!, y, ::AutoReverseEnzyme, x, dy) = NoPullbackExtras()
44

55
function DI.value_and_pullback(
66
f!, y, ::AutoReverseEnzyme, x::Number, dy, ::NoPullbackExtras

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E2} <: PushforwardExtras
66
jvp_exe!::E2
77
end
88

9-
function DI.prepare_pushforward(f, ::AnyAutoFastDifferentiation, x)
9+
function DI.prepare_pushforward(f, ::AnyAutoFastDifferentiation, x, dx)
1010
y_prototype = f(x)
1111
x_var = if x isa Number
1212
only(make_variables(:x))

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ struct FastDifferentiationTwoArgPushforwardExtras{E1,E2} <: PushforwardExtras
55
jvp_exe!::E2
66
end
77

8-
function DI.prepare_pushforward(f!, y, ::AnyAutoFastDifferentiation, x)
8+
function DI.prepare_pushforward(f!, y, ::AnyAutoFastDifferentiation, x, dx)
99
x_var = if x isa Number
1010
only(make_variables(:x))
1111
else

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Pushforward
22

3-
DI.prepare_pushforward(f, ::AutoFiniteDiff, x) = NoPushforwardExtras()
3+
DI.prepare_pushforward(f, ::AutoFiniteDiff, x, dx) = NoPushforwardExtras()
44

55
function DI.pushforward(f, backend::AutoFiniteDiff, x, dx, ::NoPushforwardExtras)
66
step(t::Number) = f(x .+ t .* dx)

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Pushforward
22

3-
DI.prepare_pushforward(f!, y, ::AutoFiniteDiff, x) = NoPushforwardExtras()
3+
DI.prepare_pushforward(f!, y, ::AutoFiniteDiff, x, dx) = NoPushforwardExtras()
44

55
function DI.value_and_pushforward(
66
f!, y, backend::AutoFiniteDiff, x, dx, ::NoPushforwardExtras

0 commit comments

Comments
 (0)