Skip to content

Commit e549302

Browse files
authored
Add support for AutoEnzyme(mode=nothing) (#211)
* Add support for AutoEnzyme(mode=nothing) * Mode for nothing backend * Fix undefined
1 parent 5161e56 commit e549302

6 files changed

Lines changed: 86 additions & 42 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,15 @@ using Enzyme:
3030
make_zero
3131

3232
const AutoForwardEnzyme = AutoEnzyme{<:ForwardMode}
33+
const AutoForwardOrNothingEnzyme = Union{AutoEnzyme{<:ForwardMode},AutoEnzyme{Nothing}}
3334
const AutoReverseEnzyme = AutoEnzyme{<:ReverseMode}
35+
const AutoReverseOrNothingEnzyme = Union{AutoEnzyme{<:ReverseMode},AutoEnzyme{Nothing}}
36+
37+
forward_mode(backend::AutoEnzyme{<:ForwardMode}) = backend.mode
38+
forward_mode(::AutoEnzyme{Nothing}) = Forward
39+
40+
reverse_mode(backend::AutoEnzyme{<:ReverseMode}) = backend.mode
41+
reverse_mode(::AutoEnzyme{Nothing}) = Reverse
3442

3543
DI.check_available(::AutoEnzyme) = true
3644

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,35 @@
11
## Pushforward
22

3-
DI.prepare_pushforward(f, ::AutoForwardEnzyme, x, dx) = NoPushforwardExtras()
3+
DI.prepare_pushforward(f, ::AutoForwardOrNothingEnzyme, x, dx) = NoPushforwardExtras()
44

55
function DI.value_and_pushforward(
6-
f, backend::AutoForwardEnzyme, x, dx, ::NoPushforwardExtras
6+
f, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras
77
)
88
dx_sametype = convert(typeof(x), dx)
9-
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx_sametype))
9+
y, new_dy = autodiff(forward_mode(backend), f, Duplicated, Duplicated(x, dx_sametype))
1010
return y, new_dy
1111
end
1212

13-
function DI.pushforward(f, backend::AutoForwardEnzyme, x, dx, ::NoPushforwardExtras)
13+
function DI.pushforward(
14+
f, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras
15+
)
1416
dx_sametype = convert(typeof(x), dx)
15-
new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx_sametype)))
17+
new_dy = only(
18+
autodiff(forward_mode(backend), f, DuplicatedNoNeed, Duplicated(x, dx_sametype))
19+
)
1620
return new_dy
1721
end
1822

1923
function DI.value_and_pushforward!(
20-
f, dy, backend::AutoForwardEnzyme, x, dx, extras::NoPushforwardExtras
24+
f, dy, backend::AutoForwardOrNothingEnzyme, x, dx, extras::NoPushforwardExtras
2125
)
2226
# dy cannot be passed anyway
2327
y, new_dy = DI.value_and_pushforward(f, backend, x, dx, extras)
2428
return y, copyto!(dy, new_dy)
2529
end
2630

2731
function DI.pushforward!(
28-
f, dy, backend::AutoForwardEnzyme, x, dx, extras::NoPushforwardExtras
32+
f, dy, backend::AutoForwardOrNothingEnzyme, x, dx, extras::NoPushforwardExtras
2933
)
3034
# dy cannot be passed anyway
3135
return copyto!(dy, DI.pushforward(f, backend, x, dx, extras))
@@ -46,7 +50,7 @@ end
4650
function DI.gradient(
4751
f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C}
4852
) where {C}
49-
grad_tup = gradient(backend.mode, f, x, Val{C}(); shadow=extras.shadow)
53+
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
5054
return reshape(collect(grad_tup), size(x))
5155
end
5256

@@ -59,14 +63,14 @@ end
5963
function DI.gradient!(
6064
f, grad, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C}
6165
) where {C}
62-
grad_tup = gradient(backend.mode, f, x, Val{C}(); shadow=extras.shadow)
66+
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
6367
return copyto!(grad, grad_tup)
6468
end
6569

6670
function DI.value_and_gradient!(
6771
f, grad, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C}
6872
) where {C}
69-
grad_tup = gradient(backend.mode, f, x, Val{C}(); shadow=extras.shadow)
73+
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
7074
return f(x), copyto!(grad, grad_tup)
7175
end
7276

@@ -76,35 +80,43 @@ struct EnzymeForwardOneArgJacobianExtras{C,O}
7680
shadow::O
7781
end
7882

79-
function DI.prepare_jacobian(f, ::AutoForwardEnzyme, x)
83+
function DI.prepare_jacobian(f, ::AutoForwardOrNothingEnzyme, x)
8084
C = pick_chunksize(length(x))
8185
shadow = chunkedonehot(x, Val(C))
8286
return EnzymeForwardOneArgJacobianExtras{C,typeof(shadow)}(shadow)
8387
end
8488

8589
function DI.jacobian(
86-
f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras{C}
90+
f, backend::AutoForwardOrNothingEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras{C}
8791
) where {C}
88-
jac_wrongshape = jacobian(backend.mode, f, x, Val{C}(); shadow=extras.shadow)
92+
jac_wrongshape = jacobian(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
8993
nx = length(x)
9094
ny = length(jac_wrongshape) ÷ length(x)
9195
return reshape(jac_wrongshape, ny, nx)
9296
end
9397

9498
function DI.value_and_jacobian(
95-
f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras
99+
f, backend::AutoForwardOrNothingEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras
96100
)
97101
return f(x), DI.jacobian(f, backend, x, extras)
98102
end
99103

100104
function DI.jacobian!(
101-
f, jac, backend::AutoForwardEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras
105+
f,
106+
jac,
107+
backend::AutoForwardOrNothingEnzyme,
108+
x,
109+
extras::EnzymeForwardOneArgJacobianExtras,
102110
)
103111
return copyto!(jac, DI.jacobian(f, backend, x, extras))
104112
end
105113

106114
function DI.value_and_jacobian!(
107-
f, jac, backend::AutoForwardEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras
115+
f,
116+
jac,
117+
backend::AutoForwardOrNothingEnzyme,
118+
x,
119+
extras::EnzymeForwardOneArgJacobianExtras,
108120
)
109121
y, new_jac = DI.value_and_jacobian(f, backend, x, extras)
110122
return y, copyto!(jac, new_jac)
Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
## Pushforward
22

3-
DI.prepare_pushforward(f!, y, ::AutoForwardEnzyme, x, dx) = NoPushforwardExtras()
3+
DI.prepare_pushforward(f!, y, ::AutoForwardOrNothingEnzyme, x, dx) = NoPushforwardExtras()
44

55
function DI.value_and_pushforward(
6-
f!, y, backend::AutoForwardEnzyme, x, dx, ::NoPushforwardExtras
6+
f!, y, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras
77
)
88
dx_sametype = convert(typeof(x), dx)
99
dy_sametype = zero(y)
1010
autodiff(
11-
backend.mode, f!, Const, Duplicated(y, dy_sametype), Duplicated(x, dx_sametype)
11+
forward_mode(backend),
12+
f!,
13+
Const,
14+
Duplicated(y, dy_sametype),
15+
Duplicated(x, dx_sametype),
1216
)
1317
return y, dy_sametype
1418
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
## Pullback
22

3-
DI.prepare_pullback(f, ::AutoReverseEnzyme, x, dy) = NoPullbackExtras()
3+
DI.prepare_pullback(f, ::AutoReverseOrNothingEnzyme, x, dy) = NoPullbackExtras()
44

55
### Out-of-place
66

77
function DI.value_and_pullback(
8-
f, ::AutoReverseEnzyme, x::Number, dy::Number, ::NoPullbackExtras
8+
f, ::AutoReverseOrNothingEnzyme, x::Number, dy::Number, ::NoPullbackExtras
99
)
1010
der, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
1111
new_dx = dy * only(der)
1212
return y, new_dx
1313
end
1414

1515
function DI.value_and_pullback(
16-
f, ::AutoReverseEnzyme, x::Number, dy::AbstractArray, ::NoPullbackExtras
16+
f, ::AutoReverseOrNothingEnzyme, x::Number, dy::AbstractArray, ::NoPullbackExtras
1717
)
1818
forw, rev = autodiff_thunk(
1919
ReverseSplitWithPrimal, Const{typeof(f)}, Duplicated, Active{typeof(x)}
@@ -25,20 +25,22 @@ function DI.value_and_pullback(
2525
end
2626

2727
function DI.value_and_pullback(
28-
f, backend::AutoReverseEnzyme, x::AbstractArray, dy, extras::NoPullbackExtras
28+
f, backend::AutoReverseOrNothingEnzyme, x::AbstractArray, dy, extras::NoPullbackExtras
2929
)
3030
dx = similar(x)
3131
return DI.value_and_pullback!(f, dx, backend, x, dy, extras)
3232
end
3333

34-
function DI.pullback(f, backend::AutoReverseEnzyme, x, dy, extras::NoPullbackExtras)
34+
function DI.pullback(
35+
f, backend::AutoReverseOrNothingEnzyme, x, dy, extras::NoPullbackExtras
36+
)
3537
return DI.value_and_pullback(f, backend, x, dy, extras)[2]
3638
end
3739

3840
### In-place
3941

4042
function DI.value_and_pullback!(
41-
f, dx, ::AutoReverseEnzyme, x::AbstractArray, dy::Number, ::NoPullbackExtras
43+
f, dx, ::AutoReverseOrNothingEnzyme, x::AbstractArray, dy::Number, ::NoPullbackExtras
4244
)
4345
dx_sametype = zero_sametype!(dx, x)
4446
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx_sametype))
@@ -47,7 +49,12 @@ function DI.value_and_pullback!(
4749
end
4850

4951
function DI.value_and_pullback!(
50-
f, dx, ::AutoReverseEnzyme, x::AbstractArray, dy::AbstractArray, ::NoPullbackExtras
52+
f,
53+
dx,
54+
::AutoReverseOrNothingEnzyme,
55+
x::AbstractArray,
56+
dy::AbstractArray,
57+
::NoPullbackExtras,
5158
)
5259
dx_sametype = zero_sametype!(dx, x)
5360
forw, rev = autodiff_thunk(
@@ -65,23 +72,27 @@ end
6572

6673
## Gradient
6774

68-
DI.prepare_gradient(f, ::AutoReverseEnzyme, x) = NoGradientExtras()
75+
DI.prepare_gradient(f, ::AutoReverseOrNothingEnzyme, x) = NoGradientExtras()
6976

70-
function DI.gradient(f, ::AutoReverseEnzyme, x, ::NoGradientExtras)
71-
return gradient(Reverse, f, x)
77+
function DI.gradient(f, backend::AutoReverseOrNothingEnzyme, x, ::NoGradientExtras)
78+
return gradient(reverse_mode(backend), f, x)
7279
end
7380

74-
function DI.gradient!(f, grad, ::AutoReverseEnzyme, x, ::NoGradientExtras)
81+
function DI.gradient!(f, grad, backend::AutoReverseOrNothingEnzyme, x, ::NoGradientExtras)
7582
grad_sametype = convert(typeof(x), grad)
76-
gradient!(Reverse, grad_sametype, f, x)
83+
gradient!(reverse_mode(backend), grad_sametype, f, x)
7784
return copyto!(grad, grad_sametype)
7885
end
7986

80-
function DI.value_and_gradient(f, backend::AutoReverseEnzyme, x, ::NoGradientExtras)
87+
function DI.value_and_gradient(
88+
f, backend::AutoReverseOrNothingEnzyme, x, ::NoGradientExtras
89+
)
8190
return DI.value_and_pullback(f, backend, x, one(eltype(x)), NoPullbackExtras())
8291
end
8392

84-
function DI.value_and_gradient!(f, grad, backend::AutoReverseEnzyme, x, ::NoGradientExtras)
93+
function DI.value_and_gradient!(
94+
f, grad, backend::AutoReverseOrNothingEnzyme, x, ::NoGradientExtras
95+
)
8596
return DI.value_and_pullback!(f, grad, backend, x, one(eltype(x)), NoPullbackExtras())
8697
end
8798

@@ -106,7 +117,7 @@ function DI.jacobian(
106117
x::AbstractArray,
107118
::EnzymeReverseOneArgJacobianExtras{C,N},
108119
) where {C,N}
109-
jac_wrongshape = jacobian(backend.mode, f, x, Val{N}(), Val{C}())
120+
jac_wrongshape = jacobian(reverse_mode(backend), f, x, Val{N}(), Val{C}())
110121
nx = length(x)
111122
ny = length(jac_wrongshape) ÷ length(x)
112123
jac_rightshape = reshape(jac_wrongshape, ny, nx)
Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,28 @@
11
## Pullback
22

3-
DI.prepare_pullback(f!, y, ::AutoReverseEnzyme, x, dy) = NoPullbackExtras()
3+
DI.prepare_pullback(f!, y, ::AutoReverseOrNothingEnzyme, x, dy) = NoPullbackExtras()
44

55
function DI.value_and_pullback(
6-
f!, y, ::AutoReverseEnzyme, x::Number, dy, ::NoPullbackExtras
6+
f!, y, backend::AutoReverseOrNothingEnzyme, x::Number, dy, ::NoPullbackExtras
77
)
88
dy_sametype = convert(typeof(y), copy(dy))
9-
_, new_dx = only(autodiff(Reverse, f!, Const, Duplicated(y, dy_sametype), Active(x)))
9+
_, new_dx = only(
10+
autodiff(reverse_mode(backend), f!, Const, Duplicated(y, dy_sametype), Active(x))
11+
)
1012
return y, new_dx
1113
end
1214

1315
function DI.value_and_pullback(
14-
f!, y, ::AutoReverseEnzyme, x::AbstractArray, dy, ::NoPullbackExtras
16+
f!, y, backend::AutoReverseOrNothingEnzyme, x::AbstractArray, dy, ::NoPullbackExtras
1517
)
1618
dx_sametype = zero(x)
1719
dy_sametype = convert(typeof(y), copy(dy))
18-
autodiff(Reverse, f!, Const, Duplicated(y, dy_sametype), Duplicated(x, dx_sametype))
20+
autodiff(
21+
reverse_mode(backend),
22+
f!,
23+
Const,
24+
Duplicated(y, dy_sametype),
25+
Duplicated(x, dx_sametype),
26+
)
1927
return y, dx_sametype
2028
end

DifferentiationInterface/test/first_order.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
dense_backends = [
22
AutoChainRules(Zygote.ZygoteRuleConfig()),
33
AutoDiffractor(),
4-
AutoEnzyme(Enzyme.Forward),
5-
AutoEnzyme(Enzyme.Reverse),
4+
AutoEnzyme(; mode=nothing),
5+
AutoEnzyme(; mode=Enzyme.Forward),
6+
AutoEnzyme(; mode=Enzyme.Reverse),
67
AutoFastDifferentiation(),
78
AutoFiniteDiff(),
8-
AutoFiniteDifferences(FiniteDifferences.central_fdm(3, 1)),
9+
AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)),
910
AutoForwardDiff(),
1011
AutoPolyesterForwardDiff(; chunksize=1),
1112
AutoReverseDiff(; compile=true),

0 commit comments

Comments
 (0)