Skip to content

Commit 6e9b50d

Browse files
authored
Switch order of arguments in preparation (#177)
1 parent 84d98a3 commit 6e9b50d

19 files changed

Lines changed: 118 additions & 110 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Pushforward
22

3-
DI.prepare_pushforward(f!, ::AutoForwardEnzyme, y, x) = NoPushforwardExtras()
3+
DI.prepare_pushforward(f!, y, ::AutoForwardEnzyme, x) = NoPushforwardExtras()
44

55
function DI.value_and_pushforward(
66
f!, y, backend::AutoForwardEnzyme, x, dx, ::NoPushforwardExtras

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Pullback
22

3-
DI.prepare_pullback(f!, ::AutoReverseEnzyme, y, x) = NoPullbackExtras()
3+
DI.prepare_pullback(f!, y, ::AutoReverseEnzyme, x) = NoPullbackExtras()
44

55
function DI.value_and_pullback(
66
f!, y, ::AutoReverseEnzyme, x::Number, dy, ::NoPullbackExtras

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ struct FastDifferentiationTwoArgPushforwardExtras{E1,E2} <: PushforwardExtras
55
jvp_exe!::E2
66
end
77

8-
function DI.prepare_pushforward(f!, ::AnyAutoFastDifferentiation, y, x)
8+
function DI.prepare_pushforward(f!, y, ::AnyAutoFastDifferentiation, x)
99
x_var = if x isa Number
1010
only(make_variables(:x))
1111
else
@@ -85,7 +85,7 @@ struct FastDifferentiationTwoArgDerivativeExtras{E1,E2} <: DerivativeExtras
8585
der_exe!::E2
8686
end
8787

88-
function DI.prepare_derivative(f!, ::AnyAutoFastDifferentiation, y, x)
88+
function DI.prepare_derivative(f!, y, ::AnyAutoFastDifferentiation, x)
8989
x_var = only(make_variables(:x))
9090
y_var = make_variables(:y, size(y)...)
9191
f!(y_var, x_var)
@@ -153,7 +153,7 @@ struct FastDifferentiationTwoArgJacobianExtras{E1,E2} <: JacobianExtras
153153
jac_exe!::E2
154154
end
155155

156-
function DI.prepare_jacobian(f!, backend::AnyAutoFastDifferentiation, y, x)
156+
function DI.prepare_jacobian(f!, y, backend::AnyAutoFastDifferentiation, x)
157157
x_var = make_variables(:x, size(x)...)
158158
y_var = make_variables(:y, size(y)...)
159159
f!(y_var, x_var)

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Pushforward
22

3-
DI.prepare_pushforward(f!, ::AnyAutoFiniteDiff, y, x) = NoPushforwardExtras()
3+
DI.prepare_pushforward(f!, y, ::AnyAutoFiniteDiff, x) = NoPushforwardExtras()
44

55
function DI.value_and_pushforward(
66
f!, y, backend::AnyAutoFiniteDiff, x, dx, ::NoPushforwardExtras
@@ -23,7 +23,7 @@ struct FiniteDiffTwoArgDerivativeExtras{C}
2323
cache::C
2424
end
2525

26-
function DI.prepare_derivative(f!, ::AnyAutoFiniteDiff, y, x)
26+
function DI.prepare_derivative(f!, y, ::AnyAutoFiniteDiff, x)
2727
cache = nothing
2828
return FiniteDiffTwoArgDerivativeExtras(cache)
2929
end
@@ -65,7 +65,7 @@ struct FiniteDiffTwoArgJacobianExtras{C}
6565
cache::C
6666
end
6767

68-
function DI.prepare_jacobian(f!, backend::AnyAutoFiniteDiff, y, x)
68+
function DI.prepare_jacobian(f!, y, backend::AnyAutoFiniteDiff, x)
6969
x1 = similar(x)
7070
fx = similar(y)
7171
fx1 = similar(y)

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
DI.prepare_pushforward(f!, ::AnyAutoForwardDiff, y, x) = NoPushforwardExtras()
1+
DI.prepare_pushforward(f!, y, ::AnyAutoForwardDiff, x) = NoPushforwardExtras()
22

33
function DI.value_and_pushforward(f!, y, ::AnyAutoForwardDiff, x, dx, ::NoPushforwardExtras)
44
T = tag_type(f!, x)
@@ -16,7 +16,7 @@ struct ForwardDiffTwoArgDerivativeExtras{C} <: DerivativeExtras
1616
config::C
1717
end
1818

19-
function DI.prepare_derivative(f!, ::AnyAutoForwardDiff, y::AbstractArray, x::Number)
19+
function DI.prepare_derivative(f!, y::AbstractArray, ::AnyAutoForwardDiff, x::Number)
2020
return ForwardDiffTwoArgDerivativeExtras(DerivativeConfig(f!, y, x))
2121
end
2222

@@ -75,7 +75,7 @@ struct ForwardDiffTwoArgJacobianExtras{C} <: JacobianExtras
7575
end
7676

7777
function DI.prepare_jacobian(
78-
f!, backend::AnyAutoForwardDiff, y::AbstractArray, x::AbstractArray
78+
f!, y::AbstractArray, backend::AnyAutoForwardDiff, x::AbstractArray
7979
)
8080
return ForwardDiffTwoArgJacobianExtras(
8181
JacobianConfig(f!, y, x, choose_chunk(backend, x))

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
## Pushforward
22

3-
function DI.prepare_pushforward(f!, backend::AnyAutoPolyForwardDiff, y, x)
4-
return DI.prepare_pushforward(f!, single_threaded(backend), y, x)
3+
function DI.prepare_pushforward(f!, y, backend::AnyAutoPolyForwardDiff, x)
4+
return DI.prepare_pushforward(f!, y, single_threaded(backend), x)
55
end
66

77
function DI.value_and_pushforward(
@@ -30,8 +30,8 @@ end
3030

3131
## Derivative
3232

33-
function DI.prepare_derivative(f!, backend::AnyAutoPolyForwardDiff, y, x)
34-
return DI.prepare_derivative(f!, single_threaded(backend), y, x)
33+
function DI.prepare_derivative(f!, y, backend::AnyAutoPolyForwardDiff, x)
34+
return DI.prepare_derivative(f!, y, single_threaded(backend), x)
3535
end
3636

3737
function DI.value_and_derivative(
@@ -58,7 +58,7 @@ end
5858

5959
## Jacobian
6060

61-
DI.prepare_jacobian(f!, ::AnyAutoPolyForwardDiff, y, x) = NoJacobianExtras()
61+
DI.prepare_jacobian(f!, y, ::AnyAutoPolyForwardDiff, x) = NoJacobianExtras()
6262

6363
function DI.value_and_jacobian(
6464
f!, y, ::AnyAutoPolyForwardDiff{C}, x, ::NoJacobianExtras

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Pullback
22

3-
DI.prepare_pullback(f!, ::AnyAutoReverseDiff, y, x) = NoPullbackExtras()
3+
DI.prepare_pullback(f!, y, ::AnyAutoReverseDiff, x) = NoPullbackExtras()
44

55
### Array in
66

@@ -60,7 +60,7 @@ function DI.value_and_pullback(
6060
x_array = [x]
6161
dx_array = similar(x_array)
6262
f!_array(_y::AbstractArray, _x_array) = f!(_y, only(_x_array))
63-
new_extras = DI.prepare_pullback(f!_array, backend, y, x_array)
63+
new_extras = DI.prepare_pullback(f!_array, y, backend, x_array)
6464
y, dx_array = DI.value_and_pullback(f!_array, y, backend, x_array, dy, new_extras)
6565
return y, only(dx_array)
6666
end
@@ -72,7 +72,7 @@ struct ReverseDiffTwoArgJacobianExtras{T} <: JacobianExtras
7272
end
7373

7474
function DI.prepare_jacobian(
75-
f!, backend::AnyAutoReverseDiff, y::AbstractArray, x::AbstractArray
75+
f!, y::AbstractArray, backend::AnyAutoReverseDiff, x::AbstractArray
7676
)
7777
tape = JacobianTape(f!, y, x)
7878
if backend.compile

DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ for AutoSparse in SPARSE_BACKENDS
77
## Jacobian
88

99
function DI.prepare_jacobian(
10-
f!, backend::$AutoSparse, y::AbstractArray, x::AbstractArray
10+
f!, y::AbstractArray, backend::$AutoSparse, x::AbstractArray
1111
)
1212
cache = sparse_jacobian_cache(
1313
backend, SymbolicsSparsityDetection(), f!, similar(y), x

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ struct TapirTwoArgPullbackExtras{R} <: PullbackExtras
22
rrule::R
33
end
44

5-
function DI.prepare_pullback(f!, ::AutoTapir, y, x)
5+
function DI.prepare_pullback(f!, y, ::AutoTapir, x)
66
return TapirTwoArgPullbackExtras(build_rrule(f!, y, x))
77
end
88

DifferentiationInterface/src/derivative.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
## Docstrings
22

33
"""
4-
prepare_derivative(f, backend, x) -> extras
5-
prepare_derivative(f!, backend, y, x) -> extras
4+
prepare_derivative(f, backend, x) -> extras
5+
prepare_derivative(f!, y, backend, x) -> extras
66
77
Create an `extras` object subtyping [`DerivativeExtras`](@ref) that can be given to derivative operators.
8+
9+
Beware that in the two-argument case, `y` is mutated by `f!` during preparation.
810
"""
911
function prepare_derivative end
1012

@@ -51,8 +53,8 @@ function prepare_derivative(f, backend::AbstractADType, x)
5153
return PushforwardDerivativeExtras(prepare_pushforward(f, backend, x))
5254
end
5355

54-
function prepare_derivative(f!, backend::AbstractADType, y, x)
55-
return PushforwardDerivativeExtras(prepare_pushforward(f!, backend, y, x))
56+
function prepare_derivative(f!, y, backend::AbstractADType, x)
57+
return PushforwardDerivativeExtras(prepare_pushforward(f!, y, backend, x))
5658
end
5759

5860
## One argument
@@ -102,7 +104,7 @@ function value_and_derivative(
102104
y,
103105
backend::AbstractADType,
104106
x,
105-
extras::DerivativeExtras=prepare_derivative(f!, backend, y, x),
107+
extras::DerivativeExtras=prepare_derivative(f!, y, backend, x),
106108
)
107109
return value_and_pushforward(f!, y, backend, x, one(x), extras.pushforward_extras)
108110
end
@@ -113,7 +115,7 @@ function value_and_derivative!(
113115
der,
114116
backend::AbstractADType,
115117
x,
116-
extras::DerivativeExtras=prepare_derivative(f!, backend, y, x),
118+
extras::DerivativeExtras=prepare_derivative(f!, y, backend, x),
117119
)
118120
return value_and_pushforward!(f!, y, der, backend, x, one(x), extras.pushforward_extras)
119121
end
@@ -123,7 +125,7 @@ function derivative(
123125
y,
124126
backend::AbstractADType,
125127
x,
126-
extras::DerivativeExtras=prepare_derivative(f!, backend, y, x),
128+
extras::DerivativeExtras=prepare_derivative(f!, y, backend, x),
127129
)
128130
return pushforward(f!, y, backend, x, one(x), extras.pushforward_extras)
129131
end
@@ -134,7 +136,7 @@ function derivative!(
134136
der,
135137
backend::AbstractADType,
136138
x,
137-
extras::DerivativeExtras=prepare_derivative(f!, backend, y, x),
139+
extras::DerivativeExtras=prepare_derivative(f!, y, backend, x),
138140
)
139141
return pushforward!(f!, y, der, backend, x, one(x), extras.pushforward_extras)
140142
end

0 commit comments

Comments
 (0)