Skip to content

Commit de23245

Browse files
authored
[BREAKING] Change order of arguments (#435)
* Move extras in core code * Update backend extensions * Update docs * Typos * Typos * Fixes * Typos * Fix * Fix ForwardDiff * Fixes * Fixes * Fix Enzyme * Bump versions and compats
1 parent 4eb9642 commit de23245

51 files changed

Lines changed: 954 additions & 946 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.5.17"
4+
version = "0.6.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/docs/src/implementations.md

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -47,75 +47,75 @@ function operators_and_types_f(backend::T) where {T<:AbstractADType}
4747
# (val_and_op, types_val_and_op),
4848
# (val_and_op!, types_val_and_op!),
4949
(
50-
(:derivative, (Any, T, Any, Any)),
51-
(:derivative!, (Any, Any, T, Any, Any)),
52-
(:value_and_derivative, (Any, T, Any, Any)),
53-
(:value_and_derivative!, (Any, Any, T, Any, Any)),
50+
(:derivative, (Any, Any, T, Any)),
51+
(:derivative!, (Any, Any, Any, T, Any)),
52+
(:value_and_derivative, (Any, Any, T, Any)),
53+
(:value_and_derivative!, (Any, Any, Any, T, Any)),
5454
),
5555
(
56-
(:gradient, (Any, T, Any, Any)),
57-
(:gradient!, (Any, Any, T, Any, Any)),
58-
(:value_and_gradient, (Any, T, Any, Any)),
59-
(:value_and_gradient!, (Any, Any, T, Any, Any)),
56+
(:gradient, (Any, Any, T, Any)),
57+
(:gradient!, (Any, Any, Any, T, Any)),
58+
(:value_and_gradient, (Any, Any, T, Any)),
59+
(:value_and_gradient!, (Any, Any, Any, T, Any)),
6060
),
6161
(
62-
(:jacobian, (Any, T, Any, Any)),
63-
(:jacobian!, (Any, Any, T, Any, Any)),
64-
(:value_and_jacobian, (Any, T, Any, Any)),
65-
(:value_and_jacobian!, (Any, Any, T, Any, Any)),
62+
(:jacobian, (Any, Any, T, Any)),
63+
(:jacobian!, (Any, Any, Any, T, Any)),
64+
(:value_and_jacobian, (Any, Any, T, Any)),
65+
(:value_and_jacobian!, (Any, Any, Any, T, Any)),
6666
),
6767
(
68-
(:hessian, (Any, T, Any, Any)),
69-
(:hessian!, (Any, Any, T, Any, Any)),
68+
(:hessian, (Any, Any, T, Any)),
69+
(:hessian!, (Any, Any, Any, T, Any)),
7070
(nothing, nothing),
7171
(nothing, nothing),
7272
),
7373
(
74-
(:hvp, (Any, T, Any, Any, Any)),
75-
(:hvp!, (Any, Any, T, Any, Any, Any)),
74+
(:hvp, (Any, Any, T, Any, Any)),
75+
(:hvp!, (Any, Any, Any, T, Any, Any)),
7676
(nothing, nothing),
7777
(nothing, nothing),
7878
),
7979
(
80-
(:pullback, (Any, T, Any, Any, Any)),
81-
(:pullback!, (Any, Any, T, Any, Any, Any)),
82-
(:value_and_pullback, (Any, T, Any, Any, Any)),
83-
(:value_and_pullback!, (Any, Any, T, Any, Any, Any)),
80+
(:pullback, (Any, Any, T, Any, Any)),
81+
(:pullback!, (Any, Any, Any, T, Any, Any)),
82+
(:value_and_pullback, (Any, Any, T, Any, Any)),
83+
(:value_and_pullback!, (Any, Any, Any, T, Any, Any)),
8484
),
8585
(
86-
(:pushforward, (Any, T, Any, Any, Any)),
87-
(:pushforward!, (Any, Any, T, Any, Any, Any)),
88-
(:value_and_pushforward, (Any, T, Any, Any, Any)),
89-
(:value_and_pushforward!, (Any, Any, T, Any, Any, Any)),
86+
(:pushforward, (Any, Any, T, Any, Any)),
87+
(:pushforward!, (Any, Any, Any, T, Any, Any)),
88+
(:value_and_pushforward, (Any, Any, T, Any, Any)),
89+
(:value_and_pushforward!, (Any, Any, Any, T, Any, Any)),
9090
),
9191
)
9292
end
9393
9494
function operators_and_types_f!(backend::T) where {T<:AbstractADType}
9595
return (
9696
(
97-
(:derivative, (Any, Any, T, Any, Any)),
98-
(:derivative!, (Any, Any, Any, T, Any, Any)),
99-
(:value_and_derivative, (Any, Any, T, Any, Any)),
100-
(:value_and_derivative!, (Any, Any, Any, T, Any, Any)),
97+
(:derivative, (Any, Any, Any, T, Any)),
98+
(:derivative!, (Any, Any, Any, Any, T, Any)),
99+
(:value_and_derivative, (Any, Any, Any, T, Any)),
100+
(:value_and_derivative!, (Any, Any, Any, Any, T, Any)),
101101
),
102102
(
103-
(:jacobian, (Any, Any, T, Any, Any)),
104-
(:jacobian!, (Any, Any, Any, T, Any, Any)),
105-
(:value_and_jacobian, (Any, Any, T, Any, Any)),
106-
(:value_and_jacobian!, (Any, Any, Any, T, Any, Any)),
103+
(:jacobian, (Any, Any, Any, T, Any)),
104+
(:jacobian!, (Any, Any, Any, Any, T, Any)),
105+
(:value_and_jacobian, (Any, Any, Any, T, Any)),
106+
(:value_and_jacobian!, (Any, Any, Any, Any, T, Any)),
107107
),
108108
(
109-
(:pullback, (Any, Any, T, Any, Any, Any)),
110-
(:pullback!, (Any, Any, Any, T, Any, Any, Any)),
111-
(:value_and_pullback, (Any, Any, T, Any, Any, Any)),
112-
(:value_and_pullback!, (Any, Any, Any, T, Any, Any, Any)),
109+
(:pullback, (Any, Any, Any, T, Any, Any)),
110+
(:pullback!, (Any, Any, Any, Any, T, Any, Any)),
111+
(:value_and_pullback, (Any, Any, Any, T, Any, Any)),
112+
(:value_and_pullback!, (Any, Any, Any, Any, T, Any, Any)),
113113
),
114114
(
115-
(:pushforward, (Any, Any, T, Any, Any, Any)),
116-
(:pushforward!, (Any, Any, Any, T, Any, Any, Any)),
117-
(:value_and_pushforward, (Any, Any, T, Any, Any, Any)),
118-
(:value_and_pushforward!, (Any, Any, Any, T, Any, Any, Any)),
115+
(:pushforward, (Any, Any, Any, T, Any, Any)),
116+
(:pushforward!, (Any, Any, Any, Any, T, Any, Any)),
117+
(:value_and_pushforward, (Any, Any, Any, T, Any, Any)),
118+
(:value_and_pushforward!, (Any, Any, Any, Any, T, Any, Any)),
119119
),
120120
)
121121
end

DifferentiationInterface/docs/src/operators.md

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ The same operators are defined for both cases, but they have different signature
7070

7171
| signature | out-of-place | in-place |
7272
| :--------- | :------------------------------------------- | :---------------------------------------------------- |
73-
| `f(x)` | `operator(f, backend, x, [v], [extras])` | `operator!(f, result, backend, x, [v], [extras])` |
74-
| `f!(y, x)` | `operator(f!, y, backend, x, [v], [extras])` | `operator!(f!, y, result, backend, x, [v], [extras])` |
73+
| `f(x)` | `operator(f, [extras,] backend, x, [v])` | `operator!(f, result, [extras,] backend, x, [v])` |
74+
| `f!(y, x)` | `operator(f!, y, [extras,] backend, x, [v])` | `operator!(f!, y, result, [extras,] backend, x, [v])` |
7575

7676
!!! warning
7777
The positional arguments between `f`/`f!` and `backend` are always mutated.
@@ -108,12 +108,11 @@ The idea is that you prepare only once, which can be costly, but then call the o
108108

109109
```julia
110110
operator(f, backend, x, [v]) # slow because it includes preparation
111-
operator(f, backend, x, [v], extras) # fast because it skips preparation
111+
operator(f, extras, backend, x, [v]) # fast because it skips preparation
112112
```
113113

114114
!!! warning
115-
The `extras` object is always mutated when given to an operator, even though it is the last argument.
116-
This convention holds regardless of the bang `!` in the operator name.
115+
The `extras` object is always mutated, regardless of the bang `!` in the operator name.
117116

118117
### Reusing preparation
119118

@@ -123,7 +122,7 @@ Here are the general rules that we strive to implement:
123122
| | different point | same point |
124123
| :------------------------ | :------------------------------------------- | :------------------------------------------- |
125124
| the output `extras` of... | `prepare_operator(f, b, x)` | `prepare_operator_same_point(f, b, x, v)` |
126-
| can be used in... | `operator(f, b, other_x, extras)` | `operator(f, b, x, other_v, extras)` |
125+
| can be used in... | `operator(f, extras, b, other_x)` | `operator(f, extras, b, x, other_v)` |
127126
| provided that... | `other_x` has the same type and shape as `x` | `other_v` has the same type and shape as `v` |
128127

129128
These rules hold for the majority of backends, but there are some exceptions: see [this page](@ref "Preparation") to know more.

DifferentiationInterface/docs/src/tutorial1.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,14 @@ You can thus reuse the `extras` for different values of the input.
9393

9494
```@example tuto1
9595
grad = similar(x)
96-
gradient!(f, grad, backend, x, extras)
96+
gradient!(f, grad, extras, backend, x)
9797
grad # has been mutated
9898
```
9999

100100
Preparation makes the gradient computation much faster, and (in this case) allocation-free.
101101

102102
```@example tuto1
103-
@benchmark gradient!($f, _grad, $backend, $x, _extras) evals=1 setup=(
103+
@benchmark gradient!($f, _grad, _extras, $backend, $x) evals=1 setup=(
104104
_grad=similar($x);
105105
_extras=prepare_gradient($f, $backend, $x)
106106
)
@@ -128,7 +128,7 @@ gradient(f, backend2, x)
128128
And you can run the same benchmarks to see what you gained (although such a small input may not be realistic):
129129

130130
```@example tuto1
131-
@benchmark gradient!($f, _grad, $backend2, $x, _extras) evals=1 setup=(
131+
@benchmark gradient!($f, _grad, _extras, $backend2, $x) evals=1 setup=(
132132
_grad=similar($x);
133133
_extras=prepare_gradient($f, $backend2, $x)
134134
)

DifferentiationInterface/docs/src/tutorial2.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ nothing # hide
8282
```
8383

8484
```@example tuto2
85-
@benchmark jacobian($f_sparse_vector, $dense_first_order_backend, $(randn(n)), $jac_extras_dense) evals=1
85+
@benchmark jacobian($f_sparse_vector, $jac_extras_dense, $dense_first_order_backend, $(randn(n))) evals=1
8686
```
8787

8888
```@example tuto2
89-
@benchmark jacobian($f_sparse_vector, $sparse_first_order_backend, $(randn(n)), $jac_extras_sparse) evals=1
89+
@benchmark jacobian($f_sparse_vector, $jac_extras_sparse, $sparse_first_order_backend, $(randn(n))) evals=1
9090
```

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ function ChainRulesCore.rrule(dw::DifferentiateWith, x)
1111
@compat (; f, backend) = dw
1212
y = f(x)
1313
extras_same = DI.prepare_pullback_same_point(f, backend, x, y)
14-
pullbackfunc(dy) = (NoTangent(), DI.pullback(f, backend, x, dy, extras_same))
14+
pullbackfunc(dy) = (NoTangent(), DI.pullback(f, extras_same, backend, x, dy))
1515
return y, pullbackfunc
1616
end

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,30 @@ end
88
DI.prepare_pullback(f, ::AutoReverseChainRules, x, ty::Tangents) = NoPullbackExtras()
99

1010
function DI.prepare_pullback_same_point(
11-
f, backend::AutoReverseChainRules, x, ty::Tangents, ::NoPullbackExtras
11+
f, ::NoPullbackExtras, backend::AutoReverseChainRules, x, ty::Tangents
1212
)
1313
rc = ruleconfig(backend)
1414
y, pb = rrule_via_ad(rc, f, x)
1515
return ChainRulesPullbackExtrasSamePoint(y, pb)
1616
end
1717

1818
function DI.value_and_pullback(
19-
f, backend::AutoReverseChainRules, x, ty::Tangents, ::NoPullbackExtras
19+
f, ::NoPullbackExtras, backend::AutoReverseChainRules, x, ty::Tangents
2020
)
2121
rc = ruleconfig(backend)
2222
y, pb = rrule_via_ad(rc, f, x)
2323
return y, Tangents(last.(pb.(ty.d)))
2424
end
2525

2626
function DI.value_and_pullback(
27-
f, ::AutoReverseChainRules, x, ty::Tangents, extras::ChainRulesPullbackExtrasSamePoint
27+
f, extras::ChainRulesPullbackExtrasSamePoint, ::AutoReverseChainRules, x, ty::Tangents
2828
)
2929
@compat (; y, pb) = extras
3030
return copy(y), Tangents(last.(pb.(ty.d)))
3131
end
3232

3333
function DI.pullback(
34-
f, ::AutoReverseChainRules, x, ty::Tangents, extras::ChainRulesPullbackExtrasSamePoint
34+
f, extras::ChainRulesPullbackExtrasSamePoint, ::AutoReverseChainRules, x, ty::Tangents
3535
)
3636
@compat (; pb) = extras
3737
return Tangents(last.(pb.(ty.d)))

DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()
1313

1414
DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::Tangents) = NoPushforwardExtras()
1515

16-
function DI.pushforward(f, ::AutoDiffractor, x, tx::Tangents, ::NoPushforwardExtras)
16+
function DI.pushforward(f, ::NoPushforwardExtras, ::AutoDiffractor, x, tx::Tangents)
1717
dys = map(tx.d) do dx
1818
# code copied from Diffractor.jl
1919
z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
@@ -23,9 +23,9 @@ function DI.pushforward(f, ::AutoDiffractor, x, tx::Tangents, ::NoPushforwardExt
2323
end
2424

2525
function DI.value_and_pushforward(
26-
f, backend::AutoDiffractor, x, tx::Tangents, extras::NoPushforwardExtras
26+
f, extras::NoPushforwardExtras, backend::AutoDiffractor, x, tx::Tangents
2727
)
28-
return f(x), DI.pushforward(f, backend, x, tx, extras)
28+
return f(x), DI.pushforward(f, extras, backend, x, tx)
2929
end
3030

3131
end

0 commit comments

Comments
 (0)