Skip to content

Commit bcf644d

Browse files
committed
Positional arguments
1 parent b13a31d commit bcf644d

43 files changed

Lines changed: 290 additions & 459 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,25 @@ struct ChainRulesPullbackPrepSamePoint{SIG,Y,PB} <: DI.PullbackPrep{SIG}
77
end
88

99
function DI.prepare_pullback(
10+
strict::Val,
1011
f,
1112
backend::AutoReverseChainRules,
1213
x,
1314
ty::NTuple,
1415
contexts::Vararg{DI.GeneralizedConstant,C};
15-
strict::Val=Val(false),
1616
) where {C}
1717
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
1818
return DI.NoPullbackPrep(_sig)
1919
end
2020

2121
function DI.prepare_pullback_same_point(
22+
strict,
2223
f,
2324
prep::DI.NoPullbackPrep,
2425
backend::AutoReverseChainRules,
2526
x,
2627
ty::NTuple,
2728
contexts::Vararg{DI.GeneralizedConstant,C};
28-
strict::Val=Val(false),
2929
) where {C}
3030
DI.check_prep(f, prep, backend, x, ty, contexts...)
3131
_sig = DI.signature(f, backend, x, ty, contexts...; strict)

DifferentiationInterface/ext/DifferentiationInterfaceDiffractorExt/DifferentiationInterfaceDiffractorExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()
1010

1111
## Pushforward
1212

13-
function DI.prepare_pushforward(f, backend::AutoDiffractor, x, tx::NTuple)
14-
_sig = DI.signature(f, backend, x, tx)
13+
function DI.prepare_pushforward(strict::Val, f, backend::AutoDiffractor, x, tx::NTuple)
14+
_sig = DI.signature(f, backend, x, tx; strict)
1515
return DI.NoPushforwardPrep(_sig)
1616
end
1717

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
## Pushforward
22

33
function DI.prepare_pushforward(
4+
strict::Val,
45
f::F,
56
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
67
x,
78
tx::NTuple,
89
contexts::Vararg{DI.Context,C};
9-
strict::Val=Val(false),
1010
) where {F,C}
1111
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
1212
return DI.NoPushforwardPrep(_sig)
@@ -123,11 +123,11 @@ struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG}
123123
end
124124

125125
function DI.prepare_gradient(
126+
strict::Val,
126127
f::F,
127128
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
128129
x,
129130
contexts::Vararg{DI.Constant,C};
130-
strict::Val=Val(false),
131131
) where {F,C}
132132
_sig = DI.signature(f, backend, x, contexts...; strict)
133133
valB = to_val(DI.pick_batchsize(backend, x))
@@ -204,11 +204,11 @@ struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG}
204204
end
205205

206206
function DI.prepare_jacobian(
207+
strict::Val,
207208
f::F,
208209
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
209210
x,
210211
contexts::Vararg{DI.Constant,C};
211-
strict::Val=Val(false),
212212
) where {F,C}
213213
_sig = DI.signature(f, backend, x, contexts...; strict)
214214
y = f(x, map(DI.unwrap, contexts)...)

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
## Pushforward
22

33
function DI.prepare_pushforward(
4+
strict::Val,
45
f!::F,
56
y,
67
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
78
x,
89
tx::NTuple,
910
contexts::Vararg{DI.Context,C};
10-
strict::Val=Val(false),
1111
) where {F,C}
1212
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
1313
return DI.NoPushforwardPrep(_sig)

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ struct EnzymeReverseOneArgPullbackPrep{SIG,Y} <: DI.PullbackPrep{SIG}
5353
end
5454

5555
function DI.prepare_pullback(
56+
strict::Val,
5657
f::F,
5758
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
5859
x,
5960
ty::NTuple,
6061
contexts::Vararg{DI.Context,C};
61-
strict::Val=Val(false),
6262
) where {F,C}
6363
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
6464
y = f(x, map(DI.unwrap, contexts)...)
@@ -192,11 +192,11 @@ end
192192
## Gradient
193193

194194
function DI.prepare_gradient(
195+
strict::Val,
195196
f::F,
196197
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
197198
x,
198199
contexts::Vararg{DI.Context,C};
199-
strict::Val=Val(false),
200200
) where {F,C}
201201
_sig = DI.signature(f, backend, x, contexts...; strict)
202202
return DI.NoGradientPrep(_sig)

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ struct EnzymeReverseTwoArgPullbackPrep{SIG,TY} <: DI.PullbackPrep{SIG}
66
end
77

88
function DI.prepare_pullback(
9+
strict::Val,
910
f!::F,
1011
y,
1112
backend::AutoEnzyme{<:Union{ReverseMode,Nothing}},
1213
x,
1314
ty::NTuple,
1415
contexts::Vararg{DI.Context,C};
15-
strict::Val=Val(false),
1616
) where {F,C}
1717
_sig = DI.signature(f!, y, backend, x, ty, contexts...; strict)
1818
ty_copy = map(copy, ty)

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ struct FastDifferentiationOneArgPushforwardPrep{SIG,Y,E1,E1!} <: DI.PushforwardP
88
end
99

1010
function DI.prepare_pushforward(
11+
strict::Val,
1112
f,
1213
backend::AutoFastDifferentiation,
1314
x,
1415
tx::NTuple,
1516
contexts::Vararg{DI.Context,C};
16-
strict::Val=Val(false),
1717
) where {C}
1818
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
1919
y_prototype = f(x, map(DI.unwrap, contexts)...)
@@ -106,12 +106,12 @@ struct FastDifferentiationOneArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG}
106106
end
107107

108108
function DI.prepare_pullback(
109+
strict::Val,
109110
f,
110111
backend::AutoFastDifferentiation,
111112
x,
112113
ty::NTuple,
113114
contexts::Vararg{DI.Context,C};
114-
strict::Val=Val(false),
115115
) where {C}
116116
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
117117
x_var = variablize(x, :x)
@@ -205,11 +205,7 @@ struct FastDifferentiationOneArgDerivativePrep{SIG,Y,E1,E1!} <: DI.DerivativePre
205205
end
206206

207207
function DI.prepare_derivative(
208-
f,
209-
backend::AutoFastDifferentiation,
210-
x,
211-
contexts::Vararg{DI.Context,C};
212-
strict::Val=Val(false),
208+
strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C};
213209
) where {C}
214210
_sig = DI.signature(f, backend, x, contexts...; strict)
215211
y_prototype = f(x, map(DI.unwrap, contexts)...)
@@ -289,11 +285,7 @@ struct FastDifferentiationOneArgGradientPrep{SIG,E1,E1!} <: DI.GradientPrep{SIG}
289285
end
290286

291287
function DI.prepare_gradient(
292-
f,
293-
backend::AutoFastDifferentiation,
294-
x,
295-
contexts::Vararg{DI.Context,C};
296-
strict::Val=Val(false),
288+
strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C};
297289
) where {C}
298290
_sig = DI.signature(f, backend, x, contexts...; strict)
299291
x_var = variablize(x, :x)
@@ -369,11 +361,11 @@ struct FastDifferentiationOneArgJacobianPrep{SIG,Y,E1,E1!} <: DI.JacobianPrep{SI
369361
end
370362

371363
function DI.prepare_jacobian(
364+
strict::Val,
372365
f,
373366
backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}},
374367
x,
375368
contexts::Vararg{DI.Context,C};
376-
strict::Val=Val(false),
377369
) where {C}
378370
_sig = DI.signature(f, backend, x, contexts...; strict)
379371
y_prototype = f(x, map(DI.unwrap, contexts)...)
@@ -454,11 +446,7 @@ struct FastDifferentiationAllocatingSecondDerivativePrep{SIG,Y,D,E2,E2!} <:
454446
end
455447

456448
function DI.prepare_second_derivative(
457-
f,
458-
backend::AutoFastDifferentiation,
459-
x,
460-
contexts::Vararg{DI.Context,C};
461-
strict::Val=Val(false),
449+
strict::Val, f, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C};
462450
) where {C}
463451
_sig = DI.signature(f, backend, x, contexts...; strict)
464452
y_prototype = f(x, map(DI.unwrap, contexts)...)
@@ -547,12 +535,12 @@ struct FastDifferentiationHVPPrep{SIG,E2,E2!,E1} <: DI.HVPPrep{SIG}
547535
end
548536

549537
function DI.prepare_hvp(
538+
strict::Val,
550539
f,
551540
backend::AutoFastDifferentiation,
552541
x,
553542
tx::NTuple,
554543
contexts::Vararg{DI.Context,C};
555-
strict::Val=Val(false),
556544
) where {C}
557545
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
558546
x_var = variablize(x, :x)
@@ -646,11 +634,11 @@ struct FastDifferentiationHessianPrep{SIG,G,E2,E2!} <: DI.HessianPrep{SIG}
646634
end
647635

648636
function DI.prepare_hessian(
637+
strict::Val,
649638
f,
650639
backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}},
651640
x,
652641
contexts::Vararg{DI.Context,C};
653-
strict::Val=Val(false),
654642
) where {C}
655643
_sig = DI.signature(f, backend, x, contexts...; strict)
656644
x_var = variablize(x, :x)

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ struct FastDifferentiationTwoArgPushforwardPrep{SIG,E1,E1!} <: DI.PushforwardPre
77
end
88

99
function DI.prepare_pushforward(
10+
strict::Val,
1011
f!,
1112
y,
1213
backend::AutoFastDifferentiation,
1314
x,
1415
tx::NTuple,
1516
contexts::Vararg{DI.Context,C};
16-
strict::Val=Val(false),
1717
) where {C}
1818
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
1919
x_var = variablize(x, :x)
@@ -108,13 +108,13 @@ struct FastDifferentiationTwoArgPullbackPrep{SIG,E1,E1!} <: DI.PullbackPrep{SIG}
108108
end
109109

110110
function DI.prepare_pullback(
111+
strict::Val,
111112
f!,
112113
y,
113114
backend::AutoFastDifferentiation,
114115
x,
115116
ty::NTuple,
116117
contexts::Vararg{DI.Context,C};
117-
strict::Val=Val(false),
118118
) where {C}
119119
_sig = DI.signature(f!, y, backend, x, ty, contexts...; strict)
120120
x_var = variablize(x, :x)
@@ -214,12 +214,7 @@ struct FastDifferentiationTwoArgDerivativePrep{SIG,E1,E1!} <: DI.DerivativePrep{
214214
end
215215

216216
function DI.prepare_derivative(
217-
f!,
218-
y,
219-
backend::AutoFastDifferentiation,
220-
x,
221-
contexts::Vararg{DI.Context,C};
222-
strict::Val=Val(false),
217+
strict::Val, f!, y, backend::AutoFastDifferentiation, x, contexts::Vararg{DI.Context,C};
223218
) where {C}
224219
_sig = DI.signature(f!, y, backend, x, contexts...; strict)
225220
x_var = variablize(x, :x)
@@ -301,12 +296,12 @@ struct FastDifferentiationTwoArgJacobianPrep{SIG,E1,E1!} <: DI.JacobianPrep{SIG}
301296
end
302297

303298
function DI.prepare_jacobian(
299+
strict::Val,
304300
f!,
305301
y,
306302
backend::Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}},
307303
x,
308304
contexts::Vararg{DI.Context,C};
309-
strict::Val=Val(false),
310305
) where {C}
311306
_sig = DI.signature(f!, y, backend, x, contexts...; strict)
312307
x_var = variablize(x, :x)

DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,7 @@ struct FiniteDiffOneArgPushforwardPrep{SIG,C,R,A,D} <: DI.PushforwardPrep{SIG}
99
end
1010

1111
function DI.prepare_pushforward(
12-
f,
13-
backend::AutoFiniteDiff,
14-
x,
15-
tx::NTuple,
16-
contexts::Vararg{DI.Context,C};
17-
strict::Val=Val(false),
12+
strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C};
1813
) where {C}
1914
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
2015
fc = DI.with_contexts(f, contexts...)
@@ -130,7 +125,7 @@ struct FiniteDiffOneArgDerivativePrep{SIG,C,R,A,D} <: DI.DerivativePrep{SIG}
130125
end
131126

132127
function DI.prepare_derivative(
133-
f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false)
128+
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
134129
) where {C}
135130
_sig = DI.signature(f, backend, x, contexts...; strict)
136131
fc = DI.with_contexts(f, contexts...)
@@ -259,7 +254,7 @@ struct FiniteDiffGradientPrep{SIG,C,R,A,D} <: DI.GradientPrep{SIG}
259254
end
260255

261256
function DI.prepare_gradient(
262-
f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false)
257+
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
263258
) where {C}
264259
_sig = DI.signature(f, backend, x, contexts...; strict)
265260
fc = DI.with_contexts(f, contexts...)
@@ -347,7 +342,7 @@ struct FiniteDiffOneArgJacobianPrep{SIG,C,R,A,D} <: DI.JacobianPrep{SIG}
347342
end
348343

349344
function DI.prepare_jacobian(
350-
f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false)
345+
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
351346
) where {C}
352347
_sig = DI.signature(f, backend, x, contexts...; strict)
353348
fc = DI.with_contexts(f, contexts...)
@@ -452,7 +447,7 @@ struct FiniteDiffHessianPrep{SIG,C1,C2,RG,AG,RH,AH} <: DI.HessianPrep{SIG}
452447
end
453448

454449
function DI.prepare_hessian(
455-
f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}; strict::Val=Val(false)
450+
strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C}
456451
) where {C}
457452
_sig = DI.signature(f, backend, x, contexts...; strict)
458453
fc = DI.with_contexts(f, contexts...)

0 commit comments

Comments
 (0)