Skip to content

Commit b6bc16c

Browse files
authored
Base reverse mode for allocating functions on split (#149)
* Base reverse mode for allocating functions on split * No fallback structure
1 parent 6454cbc commit b6bc16c

12 files changed

Lines changed: 187 additions & 224 deletions

File tree

DifferentiationInterface/docs/src/design.md

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,5 @@ For simplicity, we remove `value_` in the operator names below.
2424
Full edges in the following graphs require a single call to the destination.
2525
Dotted edges require multiple calls to the destination, the number is indicated on the edge.
2626

27-
### First order
28-
29-
```mermaid
30-
flowchart LR
31-
direction LR
32-
subgraph Out-of-place
33-
pushforward
34-
pullback
35-
derivative --> pushforward
36-
gradient --> pullback
37-
jacobian .-> |n|pushforward
38-
jacobian .-> |m|pullback
39-
end
40-
41-
subgraph In-place
42-
pushforward!! --> pushforward
43-
pullback!! --> pullback
44-
derivative!! --> pushforward!!
45-
gradient!! --> pullback!!
46-
jacobian!! .-> |n|pushforward!!
47-
jacobian!! .-> |m|pullback!!
48-
end
49-
50-
pushforward .-> |m|pullback
51-
pullback .-> |n|pushforward
52-
```
27+
!!! warning
28+
This is still in flux, come back later!

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,6 @@ DI.mode(::AutoReverseChainRules) = ADTypes.AbstractReverseMode
1919

2020
DI.prepare_pullback(f, ::AutoForwardChainRules, x) = NoPullbackExtras()
2121

22-
function DI.value_and_pullback(f, backend::AutoReverseChainRules, x, dy, ::NoPullbackExtras)
23-
rc = ruleconfig(backend)
24-
y, pullback = rrule_via_ad(rc, f, x)
25-
_, new_dx = pullback(dy)
26-
return y, new_dx
27-
end
28-
2922
function DI.value_and_pullback_split(
3023
f, backend::AutoReverseChainRules, x, ::NoPullbackExtras
3124
)
@@ -35,13 +28,4 @@ function DI.value_and_pullback_split(
3528
return y, pullbackfunc
3629
end
3730

38-
function DI.value_and_pullback!!_split(
39-
f, backend::AutoReverseChainRules, x, ::NoPullbackExtras
40-
)
41-
rc = ruleconfig(backend)
42-
y, pullback = rrule_via_ad(rc, f, x)
43-
pullbackfunc!!(_dx, dy) = last(pullback(dy))
44-
return y, pullbackfunc!!
45-
end
46-
4731
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,25 @@ function DI.value_and_pullback(
2424
return y, new_dx
2525
end
2626

27+
function DI.value_and_pullback(
28+
f, backend::AutoReverseEnzyme, x::AbstractArray, dy, extras::NoPullbackExtras
29+
)
30+
dx = similar(x)
31+
return DI.value_and_pullback!!(f, dx, backend, x, dy, extras)
32+
end
33+
34+
function DI.pullback(f, backend::AutoReverseEnzyme, x, dy, extras::NoPullbackExtras)
35+
return DI.value_and_pullback(f, backend, x, dy, extras)[2]
36+
end
37+
2738
### In-place
2839

40+
function DI.value_and_pullback!!(
41+
f, _dx, backend::AutoReverseEnzyme, x::Number, dy, extras::NoPullbackExtras
42+
)
43+
return DI.value_and_pullback(f, backend, x, dy, extras)
44+
end
45+
2946
function DI.value_and_pullback!!(
3047
f, dx, ::AutoReverseEnzyme, x::AbstractArray, dy::Number, ::NoPullbackExtras
3148
)
@@ -48,11 +65,26 @@ function DI.value_and_pullback!!(
4865
return y, dx_sametype
4966
end
5067

51-
function DI.value_and_pullback(
52-
f, backend::AutoReverseEnzyme, x::AbstractArray, dy, extras::NoPullbackExtras
68+
function DI.pullback!!(f, dx, backend::AutoReverseEnzyme, x, dy, extras::NoPullbackExtras)
69+
return DI.value_and_pullback!!(f, dx, backend, x, dy, extras)[2]
70+
end
71+
72+
### Closure
73+
74+
function DI.value_and_pullback_split(
75+
f, backend::AutoReverseEnzyme, x, extras::NoPullbackExtras
5376
)
54-
dx = similar(x)
55-
return DI.value_and_pullback!!(f, dx, backend, x, dy, extras)
77+
y = f(x)
78+
pullbackfunc(dy) = DI.pullback(f, backend, x, dy, extras)
79+
return y, pullbackfunc
80+
end
81+
82+
function DI.value_and_pullback!!_split(
83+
f, backend::AutoReverseEnzyme, x, extras::NoPullbackExtras
84+
)
85+
y = f(x)
86+
pullbackfunc!!(dx, dy) = DI.pullback!!(f, dx, backend, x, dy, extras)
87+
return y, pullbackfunc!!
5688
end
5789

5890
## Gradient

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/allocating.jl

Lines changed: 25 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,79 +2,40 @@
22

33
DI.prepare_pullback(f, ::AnyAutoReverseDiff, x) = NoPullbackExtras()
44

5-
function DI.value_and_pullback!!(
6-
f,
7-
dx::AbstractArray,
8-
backend::AnyAutoReverseDiff,
9-
x::AbstractArray,
10-
dy,
11-
extras::NoPullbackExtras,
12-
)
13-
return f(x), DI.pullback!!(f, dx, backend, x, dy, extras)
14-
end
15-
16-
function DI.value_and_pullback(
17-
f, backend::AnyAutoReverseDiff, x::AbstractArray, dy, extras::NoPullbackExtras
18-
)
19-
return f(x), DI.pullback(f, backend, x, dy, extras)
20-
end
21-
22-
### Number out
23-
24-
function DI.pullback!!(
25-
f,
26-
dx::AbstractArray,
27-
::AnyAutoReverseDiff,
28-
x::AbstractArray,
29-
dy::Number,
30-
::NoPullbackExtras,
5+
function DI.value_and_pullback_split(
6+
f, ::AnyAutoReverseDiff, x::AbstractArray, ::NoPullbackExtras
317
)
32-
dx = gradient!(dx, f, x)
33-
dx .*= dy
34-
return dx
35-
end
36-
37-
function DI.pullback(
38-
f, ::AnyAutoReverseDiff, x::AbstractArray, dy::Number, ::NoPullbackExtras
39-
)
40-
dx = gradient(f, x)
41-
dx .*= dy
42-
return dx
43-
end
44-
45-
### Array out
46-
47-
function DI.pullback!!(
48-
f,
49-
dx::AbstractArray,
50-
::AnyAutoReverseDiff,
51-
x::AbstractArray,
52-
dy::AbstractArray,
53-
::NoPullbackExtras,
54-
)
55-
dotproduct_closure(x) = dot(f(x), dy)
56-
dx = gradient!(dx, dotproduct_closure, x)
57-
return dx
8+
y = f(x)
9+
pullbackfunc = if y isa Number
10+
dy -> dy .* gradient(f, x)
11+
elseif y isa AbstractArray
12+
dy -> gradient(z -> dot(f(z), dy), x)
13+
end
14+
return y, pullbackfunc
5815
end
5916

60-
function DI.pullback(
61-
f, ::AnyAutoReverseDiff, x::AbstractArray, dy::AbstractArray, extras::NoPullbackExtras
17+
function DI.value_and_pullback!!_split(
18+
f, ::AnyAutoReverseDiff, x::AbstractArray, ::NoPullbackExtras
6219
)
63-
dotproduct_closure(x) = dot(f(x), dy)
64-
dx = gradient(dotproduct_closure, x)
65-
return dx
20+
y = f(x)
21+
pullbackfunc!! = if y isa Number
22+
(dx, dy) -> begin
23+
dx = gradient!(dx, f, x)
24+
dx .*= dy
25+
end
26+
elseif y isa AbstractArray
27+
(dx, dy) -> gradient!(dx, z -> dot(f(z), dy), x)
28+
end
29+
return y, pullbackfunc!!
6630
end
6731

68-
### Number in, not supported
69-
70-
function DI.value_and_pullback(
71-
f, backend::AnyAutoReverseDiff, x::Number, dy, ::NoPullbackExtras
32+
function DI.value_and_pullback_split(
33+
f, backend::AnyAutoReverseDiff, x::Number, ::NoPullbackExtras
7234
)
7335
x_array = [x]
7436
f_array = f only
75-
new_extras = DI.prepare_pullback(f_array, backend, x_array)
76-
y, dx_array = DI.value_and_pullback(f_array, backend, x_array, dy, new_extras)
77-
return y, only(dx_array)
37+
y, pullbackfunc = DI.value_and_pullback_split(f_array, backend, x_array)
38+
return y, only pullbackfunc
7839
end
7940

8041
## Gradient

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/DifferentiationInterfaceTapirExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using Tapir:
99
NoTangent,
1010
build_rrule,
1111
increment!!,
12+
primal,
1213
set_to_zero!!,
1314
tangent,
1415
tangent_type,

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/allocating.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,48 @@ function DI.value_and_pullback!!(
2323
)
2424
return new_y, new_dx
2525
end
26+
27+
function DI.pullback(f, backend::AutoTapir, x, dy, extras::TapirAllocatingPullbackExtras)
28+
return DI.value_and_pullback(f, backend, x, dy, extras)[2]
29+
end
30+
31+
function DI.pullback!!(
32+
f, dx, backend::AutoTapir, x, dy, extras::TapirAllocatingPullbackExtras
33+
)
34+
return DI.value_and_pullback!!(f, dx, backend, x, dy, extras)[2]
35+
end
36+
37+
#=
38+
# First try
39+
40+
function DI.value_and_pullback_split(f, ::AutoTapir, x, extras::TapirAllocatingPullbackExtras)
41+
tf = zero_tangent(f)
42+
tx = zero_tangent(x)
43+
out, pb!! = extras.rrule(CoDual(f, tf), CoDual(x, tx))
44+
y = copy(primal(out))
45+
function pullbackfunc(dy)
46+
dy_righttype = convert(tangent_type(typeof(y)), copy(dy))
47+
ty = increment!!(tangent(out), dy_righttype)
48+
res = pb!!(ty, tf, tx)
49+
extras.rrule(CoDual(f, tf), CoDual(x, tx))
50+
return last(res)
51+
end
52+
return y, pullbackfunc
53+
end
54+
=#
55+
56+
function DI.value_and_pullback_split(
57+
f, backend::AutoTapir, x, extras::TapirAllocatingPullbackExtras
58+
)
59+
y = f(x)
60+
pullbackfunc(dy) = DI.pullback(f, backend, x, dy, extras)
61+
return y, pullbackfunc
62+
end
63+
64+
function DI.value_and_pullback!!_split(
65+
f, backend::AutoTapir, x, extras::TapirAllocatingPullbackExtras
66+
)
67+
y = f(x)
68+
pullbackfunc!!(dx, dy) = DI.pullback!!(f, dx, backend, x, dy, extras)
69+
return y, pullbackfunc!!
70+
end

DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,12 @@ DI.supports_mutation(::AutoTracker) = DI.MutationNotSupported()
1111

1212
DI.prepare_pullback(f, ::AutoTracker, x) = NoPullbackExtras()
1313

14-
function DI.value_and_pullback(f, ::AutoTracker, x, dy, ::NoPullbackExtras)
15-
y, back = forward(f, x)
16-
return y, data(only(back(dy)))
17-
end
18-
1914
function DI.value_and_pullback_split(f, ::AutoTracker, x, ::NoPullbackExtras)
2015
y, back = forward(f, x)
2116
pullbackfunc(dy) = data(only(back(dy)))
2217
return y, pullbackfunc
2318
end
2419

25-
function DI.value_and_pullback!!_split(f, ::AutoTracker, x, ::NoPullbackExtras)
26-
y, back = forward(f, x)
27-
pullbackfunc!!(_dx, dy) = data(only(back(dy)))
28-
return y, pullbackfunc!!
29-
end
30-
3120
## Gradient
3221

3322
DI.prepare_gradient(f, ::AutoTracker, x) = NoGradientExtras()

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,12 @@ DI.supports_mutation(::AnyAutoZygote) = DI.MutationNotSupported()
1616

1717
DI.prepare_pullback(f, ::AnyAutoZygote, x) = NoPullbackExtras()
1818

19-
function DI.value_and_pullback(f, ::AnyAutoZygote, x, dy, ::NoPullbackExtras)
20-
y, back = pullback(f, x)
21-
dx = only(back(dy))
22-
return y, dx
23-
end
24-
2519
function DI.value_and_pullback_split(f, ::AnyAutoZygote, x, ::NoPullbackExtras)
2620
y, back = pullback(f, x)
2721
pullbackfunc(dy) = only(back(dy))
2822
return y, pullbackfunc
2923
end
3024

31-
function DI.value_and_pullback!!_split(f, ::AnyAutoZygote, x, ::NoPullbackExtras)
32-
y, back = pullback(f, x)
33-
pullbackfunc!!(_dx, dy) = only(back(dy))
34-
return y, pullbackfunc!!
35-
end
36-
3725
## Gradient
3826

3927
DI.prepare_gradient(f, ::AnyAutoZygote, x) = NoGradientExtras()

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -130,20 +130,4 @@ export prepare_second_derivative, prepare_hvp, prepare_hessian
130130

131131
export check_available, check_mutation, check_hessian
132132

133-
function __init__()
134-
Base.Experimental.register_error_hint(StackOverflowError) do io, exc
135-
print(
136-
io,
137-
"""\n
138-
HINT: One of DifferentiationInterface's functions might be missing a method, which would trigger an endless loop of `pullback` calling `pushforward` and vice-versa.
139-
Some possible fixes:
140-
- switch to another backend
141-
- if you don't want to switch, load the package extension corresponding to your backend
142-
- if your backend is already loaded, define the primitive operator for the right combination of argument types
143-
""",
144-
)
145-
return nothing
146-
end
147-
end
148-
149133
end # module

0 commit comments

Comments
 (0)