Skip to content

Commit 2751235

Browse files
authored
Fallback from sparse to dense (#183)
* Fallback from sparse to dense * Fix ambiguities * Missing Any * Fix more fixes
1 parent 19b5be4 commit 2751235

16 files changed

Lines changed: 269 additions & 139 deletions

File tree

DifferentiationInterface/docs/src/backends.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,7 @@ AutoZygote
6565

6666
### Sparse
6767

68-
!!! warning
69-
For sparse backends, only the Jacobian and Hessian operators are implemented.
68+
For sparse backends, only the Jacobian and Hessian operators are implemented differently, the other operators behave the same as for the corresponding dense backend.
7069

7170
```@docs
7271
AutoSparseFastDifferentiation

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ const AnyAutoFastDifferentiation = Union{
3131
AutoFastDifferentiation,AutoSparseFastDifferentiation
3232
}
3333

34-
DI.check_available(::AnyAutoFastDifferentiation) = true
35-
DI.mode(::AnyAutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode
36-
DI.pushforward_performance(::AnyAutoFastDifferentiation) = DI.PushforwardFast()
37-
DI.pullback_performance(::AnyAutoFastDifferentiation) = DI.PullbackSlow()
34+
DI.check_available(::AutoFastDifferentiation) = true
35+
DI.mode(::AutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode
36+
DI.pushforward_performance(::AutoFastDifferentiation) = DI.PushforwardFast()
37+
DI.pullback_performance(::AutoFastDifferentiation) = DI.PullbackSlow()
3838

3939
monovec(x::Number) = Fill(x, 1)
4040

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl

Lines changed: 21 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E2} <: PushforwardExtras
66
jvp_exe!::E2
77
end
88

9-
function DI.prepare_pushforward(f, ::AnyAutoFastDifferentiation, x, dx)
9+
function DI.prepare_pushforward(f, ::AutoFastDifferentiation, x, dx)
1010
y_prototype = f(x)
1111
x_var = if x isa Number
1212
only(make_variables(:x))
@@ -24,11 +24,7 @@ function DI.prepare_pushforward(f, ::AnyAutoFastDifferentiation, x, dx)
2424
end
2525

2626
function DI.pushforward(
27-
f,
28-
::AnyAutoFastDifferentiation,
29-
x,
30-
dx,
31-
extras::FastDifferentiationOneArgPushforwardExtras,
27+
f, ::AutoFastDifferentiation, x, dx, extras::FastDifferentiationOneArgPushforwardExtras
3228
)
3329
v_vec = vcat(myvec(x), myvec(dx))
3430
if extras.y_prototype isa Number
@@ -41,7 +37,7 @@ end
4137
function DI.pushforward!(
4238
f,
4339
dy,
44-
::AnyAutoFastDifferentiation,
40+
::AutoFastDifferentiation,
4541
x,
4642
dx,
4743
extras::FastDifferentiationOneArgPushforwardExtras,
@@ -53,7 +49,7 @@ end
5349

5450
function DI.value_and_pushforward(
5551
f,
56-
backend::AnyAutoFastDifferentiation,
52+
backend::AutoFastDifferentiation,
5753
x,
5854
dx,
5955
extras::FastDifferentiationOneArgPushforwardExtras,
@@ -64,7 +60,7 @@ end
6460
function DI.value_and_pushforward!(
6561
f,
6662
dy,
67-
backend::AnyAutoFastDifferentiation,
63+
backend::AutoFastDifferentiation,
6864
x,
6965
dx,
7066
extras::FastDifferentiationOneArgPushforwardExtras,
@@ -84,7 +80,7 @@ struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E2} <: DerivativeExtras
8480
der_exe!::E2
8581
end
8682

87-
function DI.prepare_derivative(f, ::AnyAutoFastDifferentiation, x)
83+
function DI.prepare_derivative(f, ::AutoFastDifferentiation, x)
8884
y_prototype = f(x)
8985
x_var = only(make_variables(:x))
9086
y_var = f(x_var)
@@ -98,7 +94,7 @@ function DI.prepare_derivative(f, ::AnyAutoFastDifferentiation, x)
9894
end
9995

10096
function DI.derivative(
101-
f, ::AnyAutoFastDifferentiation, x, extras::FastDifferentiationOneArgDerivativeExtras
97+
f, ::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgDerivativeExtras
10298
)
10399
if extras.y_prototype isa Number
104100
return only(extras.der_exe(monovec(x)))
@@ -108,19 +104,15 @@ function DI.derivative(
108104
end
109105

110106
function DI.derivative!(
111-
f,
112-
der,
113-
::AnyAutoFastDifferentiation,
114-
x,
115-
extras::FastDifferentiationOneArgDerivativeExtras,
107+
f, der, ::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgDerivativeExtras
116108
)
117109
extras.der_exe!(vec(der), monovec(x))
118110
return der
119111
end
120112

121113
function DI.value_and_derivative(
122114
f,
123-
backend::AnyAutoFastDifferentiation,
115+
backend::AutoFastDifferentiation,
124116
x,
125117
extras::FastDifferentiationOneArgDerivativeExtras,
126118
)
@@ -130,7 +122,7 @@ end
130122
function DI.value_and_derivative!(
131123
f,
132124
der,
133-
backend::AnyAutoFastDifferentiation,
125+
backend::AutoFastDifferentiation,
134126
x,
135127
extras::FastDifferentiationOneArgDerivativeExtras,
136128
)
@@ -144,7 +136,7 @@ struct FastDifferentiationOneArgGradientExtras{E1,E2} <: GradientExtras
144136
jac_exe!::E2
145137
end
146138

147-
function DI.prepare_gradient(f, backend::AnyAutoFastDifferentiation, x)
139+
function DI.prepare_gradient(f, backend::AutoFastDifferentiation, x)
148140
y_prototype = f(x)
149141
x_var = make_variables(:x, size(x)...)
150142
y_var = f(x_var)
@@ -158,37 +150,30 @@ function DI.prepare_gradient(f, backend::AnyAutoFastDifferentiation, x)
158150
end
159151

160152
function DI.gradient(
161-
f, ::AnyAutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras
153+
f, ::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras
162154
)
163155
jac = extras.jac_exe(vec(x))
164156
grad_vec = @view jac[1, :]
165157
return reshape(grad_vec, size(x))
166158
end
167159

168160
function DI.gradient!(
169-
f,
170-
grad,
171-
::AnyAutoFastDifferentiation,
172-
x,
173-
extras::FastDifferentiationOneArgGradientExtras,
161+
f, grad, ::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras
174162
)
175163
extras.jac_exe!(reshape(grad, 1, length(grad)), vec(x))
176164
return grad
177165
end
178166

179167
function DI.value_and_gradient(
180-
f,
181-
backend::AnyAutoFastDifferentiation,
182-
x,
183-
extras::FastDifferentiationOneArgGradientExtras,
168+
f, backend::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras
184169
)
185170
return f(x), DI.gradient(f, backend, x, extras)
186171
end
187172

188173
function DI.value_and_gradient!(
189174
f,
190175
grad,
191-
backend::AnyAutoFastDifferentiation,
176+
backend::AutoFastDifferentiation,
192177
x,
193178
extras::FastDifferentiationOneArgGradientExtras,
194179
)
@@ -261,7 +246,7 @@ struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E1,E2} <:
261246
der2_exe!::E2
262247
end
263248

264-
function DI.prepare_second_derivative(f, ::AnyAutoFastDifferentiation, x)
249+
function DI.prepare_second_derivative(f, ::AutoFastDifferentiation, x)
265250
y_prototype = f(x)
266251
x_var = only(make_variables(:x))
267252
y_var = f(x_var)
@@ -278,7 +263,7 @@ end
278263

279264
function DI.second_derivative(
280265
f,
281-
::AnyAutoFastDifferentiation,
266+
::AutoFastDifferentiation,
282267
x,
283268
extras::FastDifferentiationAllocatingSecondDerivativeExtras,
284269
)
@@ -292,7 +277,7 @@ end
292277
function DI.second_derivative!(
293278
f,
294279
der2,
295-
backend::AnyAutoFastDifferentiation,
280+
backend::AutoFastDifferentiation,
296281
x,
297282
extras::FastDifferentiationAllocatingSecondDerivativeExtras,
298283
)
@@ -307,7 +292,7 @@ struct FastDifferentiationHVPExtras{E1,E2} <: HVPExtras
307292
hvp_exe!::E2
308293
end
309294

310-
function DI.prepare_hvp(f, ::AnyAutoFastDifferentiation, x, v)
295+
function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, v)
311296
x_var = make_variables(:x, size(x)...)
312297
y_var = f(x_var)
313298

@@ -318,14 +303,14 @@ function DI.prepare_hvp(f, ::AnyAutoFastDifferentiation, x, v)
318303
return FastDifferentiationHVPExtras(hvp_exe, hvp_exe!)
319304
end
320305

321-
function DI.hvp(f, ::AnyAutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras)
306+
function DI.hvp(f, ::AutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras)
322307
v_vec = vcat(vec(x), vec(v))
323308
hv_vec = extras.hvp_exe(v_vec)
324309
return reshape(hv_vec, size(x))
325310
end
326311

327312
function DI.hvp!(
328-
f, p, ::AnyAutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras
313+
f, p, ::AutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras
329314
)
330315
v_vec = vcat(vec(x), vec(v))
331316
extras.hvp_exe!(p, v_vec)

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ struct FastDifferentiationTwoArgPushforwardExtras{E1,E2} <: PushforwardExtras
55
jvp_exe!::E2
66
end
77

8-
function DI.prepare_pushforward(f!, y, ::AnyAutoFastDifferentiation, x, dx)
8+
function DI.prepare_pushforward(f!, y, ::AutoFastDifferentiation, x, dx)
99
x_var = if x isa Number
1010
only(make_variables(:x))
1111
else
@@ -25,7 +25,7 @@ end
2525
function DI.value_and_pushforward(
2626
f!,
2727
y,
28-
::AnyAutoFastDifferentiation,
28+
::AutoFastDifferentiation,
2929
x,
3030
dx,
3131
extras::FastDifferentiationTwoArgPushforwardExtras,
@@ -40,7 +40,7 @@ function DI.value_and_pushforward!(
4040
f!,
4141
y,
4242
dy,
43-
::AnyAutoFastDifferentiation,
43+
::AutoFastDifferentiation,
4444
x,
4545
dx,
4646
extras::FastDifferentiationTwoArgPushforwardExtras,
@@ -54,7 +54,7 @@ end
5454
function DI.pushforward(
5555
f!,
5656
y,
57-
::AnyAutoFastDifferentiation,
57+
::AutoFastDifferentiation,
5858
x,
5959
dx,
6060
extras::FastDifferentiationTwoArgPushforwardExtras,
@@ -68,7 +68,7 @@ function DI.pushforward!(
6868
f!,
6969
y,
7070
dy,
71-
::AnyAutoFastDifferentiation,
71+
::AutoFastDifferentiation,
7272
x,
7373
dx,
7474
extras::FastDifferentiationTwoArgPushforwardExtras,
@@ -85,7 +85,7 @@ struct FastDifferentiationTwoArgDerivativeExtras{E1,E2} <: DerivativeExtras
8585
der_exe!::E2
8686
end
8787

88-
function DI.prepare_derivative(f!, y, ::AnyAutoFastDifferentiation, x)
88+
function DI.prepare_derivative(f!, y, ::AutoFastDifferentiation, x)
8989
x_var = only(make_variables(:x))
9090
y_var = make_variables(:y, size(y)...)
9191
f!(y_var, x_var)
@@ -99,11 +99,7 @@ function DI.prepare_derivative(f!, y, ::AnyAutoFastDifferentiation, x)
9999
end
100100

101101
function DI.value_and_derivative(
102-
f!,
103-
y,
104-
::AnyAutoFastDifferentiation,
105-
x,
106-
extras::FastDifferentiationTwoArgDerivativeExtras,
102+
f!, y, ::AutoFastDifferentiation, x, extras::FastDifferentiationTwoArgDerivativeExtras
107103
)
108104
f!(y, x)
109105
der = reshape(extras.der_exe(monovec(x)), size(y))
@@ -114,7 +110,7 @@ function DI.value_and_derivative!(
114110
f!,
115111
y,
116112
der,
117-
::AnyAutoFastDifferentiation,
113+
::AutoFastDifferentiation,
118114
x,
119115
extras::FastDifferentiationTwoArgDerivativeExtras,
120116
)
@@ -124,11 +120,7 @@ function DI.value_and_derivative!(
124120
end
125121

126122
function DI.derivative(
127-
f!,
128-
y,
129-
::AnyAutoFastDifferentiation,
130-
x,
131-
extras::FastDifferentiationTwoArgDerivativeExtras,
123+
f!, y, ::AutoFastDifferentiation, x, extras::FastDifferentiationTwoArgDerivativeExtras
132124
)
133125
der = reshape(extras.der_exe(monovec(x)), size(y))
134126
return der
@@ -138,7 +130,7 @@ function DI.derivative!(
138130
f!,
139131
y,
140132
der,
141-
::AnyAutoFastDifferentiation,
133+
::AutoFastDifferentiation,
142134
x,
143135
extras::FastDifferentiationTwoArgDerivativeExtras,
144136
)

DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/DifferentiationInterfaceSparseDiffToolsExt.jl

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module DifferentiationInterfaceSparseDiffToolsExt
33
using ADTypes
44
import DifferentiationInterface as DI
55
using DifferentiationInterface:
6-
HessianExtras, JacobianExtras, NoHessianExtras, SecondOrder, inner, outer
6+
AnyAutoSparse, HessianExtras, JacobianExtras, NoHessianExtras, SecondOrder, inner, outer
77
using SparseDiffTools:
88
JacPrototypeSparsityDetection,
99
SymbolicsSparsityDetection,
@@ -12,35 +12,14 @@ using SparseDiffTools:
1212
sparse_jacobian_cache
1313
using Symbolics: Symbolics
1414

15-
AnyOneArgAutoSparse = Union{
15+
AnyAutoSparseNoSymbolic = Union{
1616
AutoSparseFiniteDiff,
1717
AutoSparseForwardDiff,
1818
AutoSparsePolyesterForwardDiff,
1919
AutoSparseReverseDiff,
2020
AutoSparseZygote,
2121
}
2222

23-
AnyTwoArgAutoSparse = Union{
24-
AutoSparseFiniteDiff,
25-
AutoSparseForwardDiff,
26-
AutoSparsePolyesterForwardDiff,
27-
AutoSparseReverseDiff,
28-
}
29-
30-
dense(::AutoSparseFiniteDiff) = AutoFiniteDiff()
31-
dense(backend::AutoSparseReverseDiff) = AutoReverseDiff(backend.compile)
32-
dense(::AutoSparseZygote) = AutoZygote()
33-
34-
function dense(backend::AutoSparseForwardDiff{chunksize,T}) where {chunksize,T}
35-
return AutoForwardDiff{chunksize,T}(backend.tag)
36-
end
37-
38-
function dense(::AutoSparsePolyesterForwardDiff{chunksize}) where {chunksize}
39-
return AutoSparsePolyesterForwardDiff{chunksize}()
40-
end
41-
42-
DI.check_available(backend::AnyOneArgAutoSparse) = DI.check_available(dense(backend))
43-
4423
include("onearg.jl")
4524
include("twoarg.jl")
4625

0 commit comments

Comments
 (0)