Skip to content

Commit 21a96ea

Browse files
authored
Better support for FiniteDiff and SparseDiffTools (#121)
1 parent 07d272a commit 21a96ea

39 files changed

Lines changed: 644 additions & 262 deletions

Project.toml

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2020
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2121
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2222
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
23-
Taped = "07d77754-e150-4737-8c94-cd238a1fb45b"
23+
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
24+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
25+
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
2426
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2527
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2628

@@ -34,6 +36,8 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
3436
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
3537
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
3638
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
39+
DifferentiationInterfaceSparseDiffToolsExt = ["SparseDiffTools", "Symbolics"]
40+
DifferentiationInterfaceTapirExt = "Tapir"
3741
DifferentiationInterfaceTrackerExt = "Tracker"
3842
DifferentiationInterfaceZygoteExt = "Zygote"
3943

@@ -52,7 +56,9 @@ ForwardDiff = "0.10"
5256
LinearAlgebra = "1"
5357
PolyesterForwardDiff = "0.1"
5458
ReverseDiff = "1.15"
55-
Taped = "0.1"
59+
SparseDiffTools = "2.17"
60+
Symbolics = "5.27"
61+
Tapir = "0.1"
5662
Test = "1"
5763
Tracker = "0.2"
5864
Zygote = "0.6"
@@ -75,9 +81,11 @@ JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
7581
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
7682
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
7783
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
84+
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
85+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
7886
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7987
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
8088
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8189

8290
[targets]
83-
test = ["ADTypes", "Aqua", "Chairmarks", "DataFrames", "Diffractor", "Documenter", "Enzyme", "FastDifferentiation", "FiniteDiff", "FiniteDifferences", "ForwardDiff", "JET", "JuliaFormatter", "Pkg", "PolyesterForwardDiff", "ReverseDiff", "Test", "Tracker", "Zygote"]
91+
test = ["ADTypes", "Aqua", "Chairmarks", "DataFrames", "Diffractor", "Documenter", "Enzyme", "FastDifferentiation", "FiniteDiff", "FiniteDifferences", "ForwardDiff", "JET", "JuliaFormatter", "Pkg", "PolyesterForwardDiff", "ReverseDiff", "SparseDiffTools", "Symbolics", "Test", "Tracker", "Zygote"]

docs/src/backends.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function all_backends()
1616
AutoEnzyme(Enzyme.Reverse),
1717
AutoFastDifferentiation(),
1818
AutoFiniteDiff(),
19-
AutoFiniteDifferences(FiniteDifferences.central_fdm(5, 1)),
19+
AutoFiniteDifferences(FiniteDifferences.central_fdm(3, 1)),
2020
AutoForwardDiff(),
2121
AutoPolyesterForwardDiff(; chunksize=2),
2222
AutoReverseDiff(),
@@ -60,6 +60,8 @@ We also provide a few of our own:
6060

6161
```@docs
6262
AutoFastDifferentiation
63+
AutoSparseFastDifferentiation
64+
AutoTapir
6365
```
6466

6567
## Availability

ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
module DifferentiationInterfaceFastDifferentiationExt
22

33
using ADTypes: ADTypes
4-
using DifferentiationInterface: AutoFastDifferentiation
4+
using DifferentiationInterface: AutoFastDifferentiation, AutoSparseFastDifferentiation
55
import DifferentiationInterface as DI
66
using FastDifferentiation:
77
derivative,
8+
hessian,
89
jacobian,
910
jacobian_times_v,
1011
jacobian_transpose_v,
@@ -13,9 +14,12 @@ using FastDifferentiation:
1314
using LinearAlgebra: dot
1415
using FastDifferentiation.RuntimeGeneratedFunctions: RuntimeGeneratedFunction
1516

16-
DI.mode(::AutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode
17-
DI.supports_mutation(::AutoFastDifferentiation) = DI.MutationNotSupported()
18-
DI.pullback_performance(::AutoFastDifferentiation) = DI.PullbackSlow()
17+
const AnyAutoFastDifferentiation = Union{
18+
AutoFastDifferentiation,AutoSparseFastDifferentiation
19+
}
20+
21+
DI.mode(::AnyAutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode
22+
DI.supports_mutation(::AnyAutoFastDifferentiation) = DI.MutationNotSupported()
1923

2024
myvec(x::Number) = [x]
2125
myvec(x::AbstractArray) = vec(x)

ext/DifferentiationInterfaceFastDifferentiationExt/allocating.jl

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Pushforward
22

3-
function DI.prepare_pushforward(f, ::AutoFastDifferentiation, x)
3+
function DI.prepare_pushforward(f, ::AnyAutoFastDifferentiation, x)
44
x_var = if x isa Number
55
only(make_variables(:x))
66
else
@@ -16,7 +16,7 @@ function DI.prepare_pushforward(f, ::AutoFastDifferentiation, x)
1616
end
1717

1818
function DI.value_and_pushforward(
19-
f, ::AutoFastDifferentiation, x, dx, jvp_exe::RuntimeGeneratedFunction
19+
f, ::AnyAutoFastDifferentiation, x, dx, jvp_exe::RuntimeGeneratedFunction
2020
)
2121
y = f(x)
2222
v_vec = vcat(myvec(x), myvec(dx))
@@ -28,9 +28,91 @@ function DI.value_and_pushforward(
2828
end
2929
end
3030

31-
function DI.value_and_pushforward(
32-
f, backend::AutoFastDifferentiation, x, dx, extras::Nothing
31+
## Pullback
32+
33+
#=
34+
35+
# TODO: this only fails for scalar -> matrix, not sure why
36+
37+
function DI.prepare_pullback(f, ::AnyAutoFastDifferentiation, x)
38+
x_var = if x isa Number
39+
only(make_variables(:x))
40+
else
41+
make_variables(:x, size(x)...)
42+
end
43+
y_var = f(x_var)
44+
45+
x_vec_var = x_var isa Number ? [x_var] : vec(x_var)
46+
y_vec_var = y_var isa Number ? [y_var] : vec(y_var)
47+
vj_vec_var, v_vec_var = jacobian_transpose_v(y_vec_var, x_vec_var)
48+
vjp_exe = make_function(vj_vec_var, [x_vec_var; v_vec_var]; in_place=false)
49+
return vjp_exe
50+
end
51+
52+
function DI.value_and_pullback(
53+
f, ::AnyAutoFastDifferentiation, x, dy, vjp_exe::RuntimeGeneratedFunction
3354
)
34-
jvp_exe = DI.prepare_pushforward(f, backend, x)
35-
return DI.value_and_pushforward(f, backend, x, dx, jvp_exe)
55+
y = f(x)
56+
v_vec = vcat(myvec(x), myvec(dy))
57+
vj_vec = vjp_exe(v_vec)
58+
if x isa Number
59+
return y, only(vj_vec)
60+
else
61+
return y, reshape(vj_vec, size(x))
62+
end
63+
end
64+
65+
=#
66+
67+
## Jacobian
68+
69+
function DI.prepare_jacobian(f, ::AnyAutoFastDifferentiation, x)
70+
x_vec_var = make_variables(:x, size(x)...)
71+
y_vec_var = f(x_vec_var)
72+
jac_var = jacobian(vec(y_vec_var), vec(x_vec_var))
73+
jac_exe = make_function(jac_var, vec(x_vec_var); in_place=false)
74+
return jac_exe
75+
end
76+
77+
function DI.jacobian(
78+
f, backend::AnyAutoFastDifferentiation, x, jac_exe::RuntimeGeneratedFunction
79+
)
80+
return jac_exe(vec(x))
81+
end
82+
83+
function DI.value_and_jacobian(f, backend, x, extras)
84+
return f(x), DI.jacobian(f, backend, x, extras)
85+
end
86+
87+
function DI.jacobian!!(f, backend::AnyAutoFastDifferentiation, x, extras)
88+
return DI.jacobian(f, backend, x, extras)
89+
end
90+
91+
function DI.value_and_jacobian!!(f, backend::AnyAutoFastDifferentiation, x, extras)
92+
return DI.value_and_jacobian(f, backend, x, extras)
93+
end
94+
95+
## Hessian
96+
97+
function DI.prepare_hessian(f, ::AnyAutoFastDifferentiation, x)
98+
x_vec_var = make_variables(:x, size(x)...)
99+
y_vec_var = f(x_vec_var)
100+
hess_var = hessian(y_vec_var, vec(x_vec_var))
101+
hess_exe = make_function(hess_var, vec(x_vec_var); in_place=false)
102+
return hess_exe
103+
end
104+
105+
function DI.hessian(
106+
f, backend::AnyAutoFastDifferentiation, x, hess_exe::RuntimeGeneratedFunction
107+
)
108+
return hess_exe(vec(x))
109+
end
110+
111+
function DI.hessian(f, backend::AnyAutoFastDifferentiation, x, extras::Nothing)
112+
hess_exe = prepare_hessian(f, backend, x)
113+
return DI.hessian(f, backend, x, hess_exe)
114+
end
115+
116+
function DI.hessian!!(f, backend::AnyAutoFastDifferentiation, x, extras)
117+
return DI.hessian(f, backend, x, extras)
36118
end

ext/DifferentiationInterfaceFiniteDiffExt/DifferentiationInterfaceFiniteDiffExt.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,29 @@
11
module DifferentiationInterfaceFiniteDiffExt
22

3-
using ADTypes: AutoFiniteDiff
3+
using ADTypes: AutoFiniteDiff, AutoSparseFiniteDiff
44
import DifferentiationInterface as DI
55
using FiniteDiff:
66
finite_difference_derivative,
77
finite_difference_gradient,
88
finite_difference_gradient!,
9+
finite_difference_hessian,
10+
finite_difference_hessian!,
911
finite_difference_jacobian,
1012
finite_difference_jacobian!
1113
using LinearAlgebra: dot, mul!
1214

15+
const AnyAutoFiniteDiff = Union{AutoFiniteDiff,AutoSparseFiniteDiff}
16+
17+
# see https://github.com/SciML/ADTypes.jl/issues/33
18+
19+
fdtype(::AutoFiniteDiff{fdt}) where {fdt} = fdt
20+
fdjtype(::AutoFiniteDiff{fdt,fdjt}) where {fdt,fdjt} = fdjt
21+
fdhtype(::AutoFiniteDiff{fdt,fdjt,fdht}) where {fdt,fdjt,fdht} = fdht
22+
23+
fdtype(::AutoSparseFiniteDiff) = Val{:central}()
24+
fdjtype(::AutoSparseFiniteDiff) = Val{:central}()
25+
fdhtype(::AutoSparseFiniteDiff) = Val{:hcentral}()
26+
1327
# see https://docs.sciml.ai/FiniteDiff/stable/#f-Definitions
1428
const FUNCTION_INPLACE = Val{true}
1529
const FUNCTION_NOT_INPLACE = Val{false}
Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,90 @@
11
## Pushforward
22

3-
function DI.value_and_pushforward!!(
4-
f, _dy::Number, ::AutoFiniteDiff{fdtype}, x, dx, extras::Nothing
5-
) where {fdtype}
3+
function DI.pushforward(f, backend::AnyAutoFiniteDiff, x, dx, extras::Nothing)
4+
step(t::Number) = f(x .+ t .* dx)
5+
new_dy = finite_difference_derivative(step, zero(eltype(x)), fdtype(backend))
6+
return new_dy
7+
end
8+
9+
function DI.value_and_pushforward(f, backend::AnyAutoFiniteDiff, x, dx, extras::Nothing)
610
y = f(x)
7-
step(t::Number)::Number = f(x .+ t .* dx)
8-
new_dy = finite_difference_derivative(step, zero(eltype(dx)), fdtype, eltype(y), y)
11+
step(t::Number) = f(x .+ t .* dx)
12+
new_dy = finite_difference_derivative(
13+
step, zero(eltype(x)), fdtype(backend), eltype(y), y
14+
)
915
return y, new_dy
1016
end
1117

12-
function DI.value_and_pushforward!!(
13-
f, dy::AbstractArray, ::AutoFiniteDiff{fdtype}, x, dx, extras::Nothing
14-
) where {fdtype}
18+
## Derivative
19+
20+
function DI.derivative(f, backend::AnyAutoFiniteDiff, x, extras::Nothing)
21+
return finite_difference_derivative(f, x, fdtype(backend))
22+
end
23+
24+
function DI.value_and_derivative(f, backend::AnyAutoFiniteDiff, x, extras::Nothing)
1525
y = f(x)
16-
step(t::Number)::AbstractArray = f(x .+ t .* dx)
17-
finite_difference_gradient!(
18-
dy, step, zero(eltype(dx)), fdtype, eltype(y), FUNCTION_NOT_INPLACE, y
19-
)
20-
return y, dy
26+
return y, finite_difference_derivative(f, x, fdtype(backend), eltype(y), y)
2127
end
2228

23-
function DI.value_and_pushforward(
24-
f, ::AutoFiniteDiff{fdtype}, x, dx, extras::Nothing
25-
) where {fdtype}
29+
## Gradient
30+
31+
function DI.gradient(f, backend::AnyAutoFiniteDiff, x::Number, extras::Nothing)
32+
return DI.derivative(f, backend, x, extras)
33+
end
34+
35+
function DI.value_and_gradient(f, backend::AnyAutoFiniteDiff, x::Number, extras::Nothing)
36+
return DI.value_and_derivative(f, backend, x, extras)
37+
end
38+
39+
function DI.gradient(f, backend::AnyAutoFiniteDiff, x::AbstractArray, extras::Nothing)
40+
return finite_difference_gradient(f, x, fdtype(backend))
41+
end
42+
43+
function DI.value_and_gradient(
44+
f, backend::AnyAutoFiniteDiff, x::AbstractArray, extras::Nothing
45+
)
2646
y = f(x)
27-
step(t::Number) = f(x .+ t .* dx)
28-
new_dy = finite_difference_derivative(step, zero(eltype(dx)), fdtype, eltype(y), y)
29-
return y, new_dy
47+
return y, finite_difference_gradient(f, x, fdtype(backend), typeof(y), y)
48+
end
49+
50+
function DI.gradient!!(
51+
f, grad, backend::AnyAutoFiniteDiff, x::AbstractArray, extras::Nothing
52+
)
53+
return finite_difference_gradient!(grad, f, x, fdtype(backend))
54+
end
55+
56+
function DI.value_and_gradient!!(
57+
f, grad, backend::AnyAutoFiniteDiff, x::AbstractArray, extras::Nothing
58+
)
59+
y = f(x)
60+
return y, finite_difference_gradient!(grad, f, x, fdtype(backend), typeof(y), y)
61+
end
62+
63+
## Jacobian
64+
65+
function DI.jacobian(f, backend::AnyAutoFiniteDiff, x, extras::Nothing)
66+
return finite_difference_jacobian(f, x, fdjtype(backend))
67+
end
68+
69+
function DI.value_and_jacobian(f, backend::AnyAutoFiniteDiff, x, extras::Nothing)
70+
y = f(x)
71+
return y, finite_difference_jacobian(f, x, fdjtype(backend), eltype(y), y)
72+
end
73+
74+
function DI.jacobian!!(f, jac, backend::AnyAutoFiniteDiff, x, extras::Nothing)
75+
return DI.jacobian(f, backend, x, extras)
76+
end
77+
78+
function DI.value_and_jacobian!!(f, jac, backend::AnyAutoFiniteDiff, x, extras::Nothing)
79+
return DI.value_and_jacobian(f, backend, x, extras)
80+
end
81+
82+
## Hessian
83+
84+
function DI.hessian(f, backend::AnyAutoFiniteDiff, x, extras::Nothing)
85+
return finite_difference_hessian(f, x, fdhtype(backend))
86+
end
87+
88+
function DI.hessian!!(f, hess, backend::AnyAutoFiniteDiff, x, extras::Nothing)
89+
return finite_difference_hessian!(hess, f, x, fdhtype(backend))
3090
end

ext/DifferentiationInterfaceFiniteDiffExt/mutating.jl

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,44 @@ function DI.value_and_pushforward!!(
44
f!,
55
y::AbstractArray,
66
dy::AbstractArray,
7-
::AutoFiniteDiff{fdtype},
7+
backend::AnyAutoFiniteDiff,
88
x,
99
dx,
1010
extras::Nothing,
11-
) where {fdtype}
11+
)
1212
function step(t::Number)::AbstractArray
1313
new_y = similar(y)
1414
f!(new_y, x .+ t .* dx)
1515
return new_y
1616
end
17-
finite_difference_gradient!(
18-
dy, step, zero(eltype(dx)), fdtype, eltype(y), FUNCTION_NOT_INPLACE, y
17+
f!(y, x)
18+
new_dy = finite_difference_derivative(
19+
step, zero(eltype(x)), fdtype(backend), eltype(y), y
1920
)
21+
return y, new_dy
22+
end
23+
24+
## Derivative
25+
26+
function DI.value_and_derivative!!(
27+
f!, y::AbstractArray, der::AbstractArray, backend::AnyAutoFiniteDiff, x, extras::Nothing
28+
)
29+
f!(y, x)
30+
finite_difference_gradient!(der, f!, x, fdtype(backend), eltype(y), FUNCTION_INPLACE, y)
31+
return y, der
32+
end
33+
34+
## Jacobian
35+
36+
function DI.value_and_jacobian!!(
37+
f!,
38+
y::AbstractArray,
39+
jac::AbstractMatrix,
40+
backend::AnyAutoFiniteDiff,
41+
x,
42+
extras::Nothing,
43+
)
2044
f!(y, x)
21-
return y, dy
45+
finite_difference_jacobian!(jac, f!, x, fdjtype(backend), eltype(y), y)
46+
return y, jac
2247
end

0 commit comments

Comments
 (0)