Skip to content

Commit 1251518

Browse files
committed
Propagate with value types
1 parent 0a4b355 commit 1251518

43 files changed

Lines changed: 610 additions & 515 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
22
(; f, backend) = dw
33
y = f(x)
4-
prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,); strict=true)
4+
prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,); strict=Val(true))
55
function pullbackfunc(dy)
66
tx = DI.pullback(f, prep_same, backend, x, (dy,))
77
return (NoTangent(), only(tx))

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
## Pullback
22

33
struct ChainRulesPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG}
4-
_sig::Type{SIG}
4+
_sig::Val{SIG}
55
y::Y
66
pb::PB
77
end
@@ -12,10 +12,10 @@ function DI.prepare_pullback(
1212
x,
1313
ty::NTuple,
1414
contexts::Vararg{DI.GeneralizedConstant,C};
15-
strict::Bool=false,
15+
strict::Val=Val(false),
1616
) where {C}
17-
SIG = DI.signature(f, backend, x, ty, contexts...; strict)
18-
return DI.NoPullbackPrep{SIG}()
17+
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
18+
return DI.NoPullbackPrep(_sig)
1919
end
2020

2121
function DI.prepare_pullback_same_point(
@@ -25,13 +25,13 @@ function DI.prepare_pullback_same_point(
2525
x,
2626
ty::NTuple,
2727
contexts::Vararg{DI.GeneralizedConstant,C};
28-
strict::Bool=false,
28+
strict::Val=Val(false),
2929
) where {C}
3030
DI.check_prep(f, prep, backend, x, ty, contexts...)
31-
SIG = DI.signature(f, backend, x, ty, contexts...; strict)
31+
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
3232
rc = ruleconfig(backend)
3333
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
34-
return ChainRulesPullbackPrepSamePoint(SIG, y, pb)
34+
return ChainRulesPullbackPrepSamePoint(_sig, y, pb)
3535
end
3636

3737
function DI.value_and_pullback(

DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()
1111
## Pushforward
1212

1313
function DI.prepare_pushforward(f, backend::AutoDiffractor, x, tx::NTuple)
14-
SIG = DI.signature(f, backend, x, tx)
15-
return DI.NoPushforwardPrep{SIG}()
14+
_sig = DI.signature(f, backend, x, tx)
15+
return DI.NoPushforwardPrep(_sig)
1616
end
1717

1818
function DI.pushforward(

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ function DI.prepare_pushforward(
66
x,
77
tx::NTuple,
88
contexts::Vararg{DI.Context,C};
9-
strict::Bool=false,
9+
strict::Val=Val(false),
1010
) where {F,C}
11-
SIG = DI.signature(f, backend, x, tx, contexts...; strict)
12-
return DI.NoPushforwardPrep{SIG}()
11+
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
12+
return DI.NoPushforwardPrep(_sig)
1313
end
1414

1515
function DI.value_and_pushforward(
@@ -117,24 +117,22 @@ end
117117
## Gradient
118118

119119
struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG}
120+
_sig::Val{SIG}
121+
_valB::Val{B}
120122
shadows::O
121123
end
122124

123-
function EnzymeForwardGradientPrep(::Type{SIG}, ::Val{B}, shadows::O) where {SIG,B,O}
124-
return EnzymeForwardGradientPrep{SIG,B,O}(shadows)
125-
end
126-
127125
function DI.prepare_gradient(
128126
f::F,
129127
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
130128
x,
131129
contexts::Vararg{DI.Constant,C};
132-
strict::Bool=false,
130+
strict::Val=Val(false),
133131
) where {F,C}
134-
SIG = DI.signature(f, backend, x, contexts...; strict)
132+
_sig = DI.signature(f, backend, x, contexts...; strict)
135133
valB = to_val(DI.pick_batchsize(backend, x))
136134
shadows = create_shadows(valB, x)
137-
return EnzymeForwardGradientPrep(SIG, valB, shadows)
135+
return EnzymeForwardGradientPrep(_sig, valB, shadows)
138136
end
139137

140138
function DI.gradient(
@@ -199,28 +197,24 @@ end
199197
## Jacobian
200198

201199
struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG}
200+
_sig::Val{SIG}
201+
_valB::Val{B}
202202
shadows::O
203203
output_length::Int
204204
end
205205

206-
function EnzymeForwardOneArgJacobianPrep(
207-
::Type{SIG}, ::Val{B}, shadows::O, output_length::Integer
208-
) where {SIG,B,O}
209-
return EnzymeForwardOneArgJacobianPrep{SIG,B,O}(shadows, output_length)
210-
end
211-
212206
function DI.prepare_jacobian(
213207
f::F,
214208
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
215209
x,
216210
contexts::Vararg{DI.Constant,C};
217-
strict::Bool=false,
211+
strict::Val=Val(false),
218212
) where {F,C}
219-
SIG = DI.signature(f, backend, x, contexts...; strict)
213+
_sig = DI.signature(f, backend, x, contexts...; strict)
220214
y = f(x, map(DI.unwrap, contexts)...)
221215
valB = to_val(DI.pick_batchsize(backend, x))
222216
shadows = create_shadows(valB, x)
223-
return EnzymeForwardOneArgJacobianPrep(SIG, valB, shadows, length(y))
217+
return EnzymeForwardOneArgJacobianPrep(_sig, valB, shadows, length(y))
224218
end
225219

226220
function DI.jacobian(

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ function DI.prepare_pushforward(
77
x,
88
tx::NTuple,
99
contexts::Vararg{DI.Context,C};
10-
strict::Bool=false,
10+
strict::Val=Val(false),
1111
) where {F,C}
12-
SIG = DI.signature(f!, y, backend, x, tx, contexts...; strict)
13-
return DI.NoPushforwardPrep{SIG}()
12+
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
13+
return DI.NoPushforwardPrep(_sig)
1414
end
1515

1616
function DI.value_and_pushforward(

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ end
4848
## Pullback
4949

5050
struct EnzymeReverseOneArgPullbackPrep{SIG,Y} <: DI.PullbackPrep{SIG}
51+
_sig::Val{SIG}
5152
y_example::Y # useful to create return activity
5253
end
5354

@@ -57,11 +58,11 @@ function DI.prepare_pullback(
5758
x,
5859
ty::NTuple,
5960
contexts::Vararg{DI.Context,C};
60-
strict::Bool=false,
61+
strict::Val=Val(false),
6162
) where {F,C}
62-
SIG = DI.signature(f, backend, x, ty, contexts...; strict)
63+
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
6364
y = f(x, map(DI.unwrap, contexts)...)
64-
return EnzymeReverseOneArgPullbackPrep{SIG,typeof(y)}(y)
65+
return EnzymeReverseOneArgPullbackPrep(_sig, y)
6566
end
6667

6768
### Out-of-place
@@ -195,10 +196,10 @@ function DI.prepare_gradient(
195196
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
196197
x,
197198
contexts::Vararg{DI.Context,C};
198-
strict::Bool=false,
199+
strict::Val=Val(false),
199200
) where {F,C}
200-
SIG = DI.signature(f, backend, x, contexts...; strict)
201-
return DI.NoGradientPrep{SIG}()
201+
_sig = DI.signature(f, backend, x, contexts...; strict)
202+
return DI.NoGradientPrep(_sig)
202203
end
203204

204205
### Enzyme gradient API (only constants)

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## Pullback
22

33
struct EnzymeReverseTwoArgPullbackPrep{SIG,TY} <: DI.PullbackPrep{SIG}
4+
_sig::Val{SIG}
45
ty_copy::TY
56
end
67

@@ -11,11 +12,11 @@ function DI.prepare_pullback(
1112
x,
1213
ty::NTuple,
1314
contexts::Vararg{DI.Context,C};
14-
strict::Bool=false,
15+
strict::Val=Val(false),
1516
) where {F,C}
16-
SIG = DI.signature(f!, y, backend, x, ty, contexts...; strict)
17+
_sig = DI.signature(f!, y, backend, x, ty, contexts...; strict)
1718
ty_copy = map(copy, ty)
18-
return EnzymeReverseTwoArgPullbackPrep{SIG,typeof(ty_copy)}(ty_copy)
19+
return EnzymeReverseTwoArgPullbackPrep(_sig, ty_copy)
1920
end
2021

2122
function DI.value_and_pullback(

0 commit comments

Comments
 (0)