Skip to content

Commit 38fd496

Browse files
authored
Exploit more of FiniteDifferences (#114)
* Exploit more of FiniteDifferences * Remove exception
1 parent 6155b2a commit 38fd496

4 files changed

Lines changed: 79 additions & 37 deletions

File tree

ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module DifferentiationInterfaceFiniteDifferencesExt
33
using ADTypes: AutoFiniteDifferences
44
import DifferentiationInterface as DI
55
using FillArrays: OneElement
6-
using FiniteDifferences: FiniteDifferences, jvp, j′vp
6+
using FiniteDifferences: FiniteDifferences, grad, jacobian, jvp, j′vp
77
using LinearAlgebra: dot
88

99
DI.supports_mutation(::AutoFiniteDifferences) = DI.MutationNotSupported()
@@ -12,22 +12,62 @@ function FiniteDifferences.to_vec(a::OneElement) # TODO: remove type piracy (ht
1212
return FiniteDifferences.to_vec(collect(a))
1313
end
1414

15-
function DI.value_and_pushforward(
16-
f, backend::AutoFiniteDifferences{fdm}, x, dx, extras::Nothing
17-
) where {fdm}
18-
y = f(x)
19-
return y, jvp(backend.fdm, f, (x, dx))
15+
## Pushforward
16+
17+
function DI.pushforward(f, backend::AutoFiniteDifferences, x, dx, extras::Nothing)
18+
return jvp(backend.fdm, f, (x, dx))
19+
end
20+
21+
function DI.value_and_pushforward(f, backend::AutoFiniteDifferences, x, dx, extras::Nothing)
22+
return f(x), DI.pushforward(f, backend, x, dx, extras)
23+
end
24+
25+
## Pullback
26+
27+
function DI.pullback(f, backend::AutoFiniteDifferences, x, dy, extras::Nothing)
28+
return only(j′vp(backend.fdm, f, dy, x))
2029
end
2130

22-
#=
23-
# TODO: why does this fail?
31+
function DI.value_and_pullback(f, backend::AutoFiniteDifferences, x, dy, extras::Nothing)
32+
return f(x), DI.pullback(f, backend, x, dy, extras)
33+
end
34+
35+
## Gradient
36+
37+
function DI.gradient(f, backend::AutoFiniteDifferences, x, extras::Nothing)
38+
return only(grad(backend.fdm, f, x))
39+
end
40+
41+
function DI.value_and_gradient(f, backend::AutoFiniteDifferences, x, extras::Nothing)
42+
return f(x), DI.gradient(f, backend, x, extras)
43+
end
44+
45+
function DI.gradient!!(f, grad, backend::AutoFiniteDifferences, x, extras::Nothing)
46+
return DI.gradient(f, backend, x, extras)
47+
end
48+
49+
function DI.value_and_gradient!!(
50+
f, grad, backend::AutoFiniteDifferences, x, extras::Nothing
51+
)
52+
return DI.value_and_gradient(f, backend, x)
53+
end
54+
55+
## Jacobian
56+
57+
function DI.jacobian(f, backend::AutoFiniteDifferences, x, extras::Nothing)
58+
return only(jacobian(backend.fdm, f, x))
59+
end
60+
61+
function DI.value_and_jacobian(f, backend::AutoFiniteDifferences, x, extras::Nothing)
62+
return f(x), DI.jacobian(f, backend, x, extras)
63+
end
64+
65+
function DI.jacobian!!(f, jac, backend::AutoFiniteDifferences, x, extras::Nothing)
66+
return DI.jacobian(f, backend, x, extras)
67+
end
2468

25-
function DI.value_and_pullback(
26-
f, backend::AutoFiniteDifferences{fdm}, x, dy, extras::Nothing
27-
) where {fdm}
28-
y = f(x)
29-
return y, j′vp(backend.fdm, f, x, dy)[1]
69+
function DI.value_and_jacobian!!(f, jac, backend::AutoFiniteDifferences, x, extras::Nothing)
70+
return DI.value_and_jacobian(f, backend, x)
3071
end
31-
=#
3272

3373
end

src/backends.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ function check_available(backend::AbstractADType)
1212
return true
1313
catch exception
1414
@warn "Backend $backend not available" exception
15-
throw(exception)
1615
if exception isa MethodError
1716
return false
1817
else

src/gradient.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
function value_and_gradient(
77
f, backend::AbstractADType, x, extras=prepare_gradient(f, backend, x)
88
)
9-
return value_and_pullback(f, backend, x, true, extras)
9+
return value_and_pullback(f, backend, x, one(eltype(x)), extras)
1010
end
1111

1212
"""
@@ -15,7 +15,7 @@ end
1515
function value_and_gradient!!(
1616
f, grad, backend::AbstractADType, x, extras=prepare_gradient(f, backend, x)
1717
)
18-
return value_and_pullback!!(f, grad, backend, x, true, extras)
18+
return value_and_pullback!!(f, grad, backend, x, one(eltype(x)), extras)
1919
end
2020

2121
"""

src/jacobian.jl

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,15 @@ end
3333
value_and_jacobian!!(f, jac, backend, x, [extras]) -> (y, jac)
3434
"""
3535
function value_and_jacobian!!(
36-
f,
37-
jac::AbstractMatrix,
38-
backend::AbstractADType,
39-
x,
40-
extras=prepare_jacobian(f, backend, x),
36+
f, jac, backend::AbstractADType, x, extras=prepare_jacobian(f, backend, x)
4137
)
4238
return value_and_jacobian_aux!!(
4339
f, jac, backend, x, extras, pushforward_performance(backend)
4440
)
4541
end
4642

4743
function value_and_jacobian_aux!!(
48-
f, jac, backend, x::AbstractArray, extras, ::PushforwardFast
44+
f, jac::AbstractMatrix, backend, x::AbstractArray, extras, ::PushforwardFast
4945
)
5046
y = f(x)
5147
for (k, j) in enumerate(CartesianIndices(x))
@@ -59,7 +55,7 @@ function value_and_jacobian_aux!!(
5955
end
6056

6157
function value_and_jacobian_aux!!(
62-
f, jac, backend, x::AbstractArray, extras, ::PushforwardSlow
58+
f, jac::AbstractMatrix, backend, x::AbstractArray, extras, ::PushforwardSlow
6359
)
6460
y = f(x)
6561
for (k, i) in enumerate(CartesianIndices(y))
@@ -83,11 +79,7 @@ end
8379
jacobian!!(f, jac, backend, x, [extras]) -> jac
8480
"""
8581
function jacobian!!(
86-
f,
87-
jac::AbstractMatrix,
88-
backend::AbstractADType,
89-
x,
90-
extras=prepare_jacobian(f, backend, x),
82+
f, jac, backend::AbstractADType, x, extras=prepare_jacobian(f, backend, x)
9183
)
9284
return value_and_jacobian!!(f, jac, backend, x, extras)[2]
9385
end
@@ -98,19 +90,22 @@ end
9890
value_and_jacobian!!(f!, y, jac, backend, x, [extras]) -> (y, jac)
9991
"""
10092
function value_and_jacobian!!(
101-
f!,
102-
y::AbstractArray,
103-
jac::AbstractMatrix,
104-
backend::AbstractADType,
105-
x::AbstractArray,
106-
extras=prepare_jacobian(f!, backend, y, x),
93+
f!, y, jac, backend::AbstractADType, x, extras=prepare_jacobian(f!, backend, y, x)
10794
)
10895
return value_and_jacobian_aux!!(
10996
f!, y, jac, backend, x, extras, pushforward_performance(backend)
11097
)
11198
end
11299

113-
function value_and_jacobian_aux!!(f!, y, jac, backend, x, extras, ::PushforwardFast)
100+
function value_and_jacobian_aux!!(
101+
f!,
102+
y::AbstractArray,
103+
jac::AbstractMatrix,
104+
backend,
105+
x::AbstractArray,
106+
extras,
107+
::PushforwardFast,
108+
)
114109
f!(y, x)
115110
for (k, j) in enumerate(CartesianIndices(x))
116111
dx_j = basis(backend, x, j)
@@ -124,7 +119,15 @@ function value_and_jacobian_aux!!(f!, y, jac, backend, x, extras, ::PushforwardF
124119
return y, jac
125120
end
126121

127-
function value_and_jacobian_aux!!(f!, y, jac, backend, x, extras, ::PushforwardSlow)
122+
function value_and_jacobian_aux!!(
123+
f!,
124+
y::AbstractArray,
125+
jac::AbstractMatrix,
126+
backend,
127+
x::AbstractArray,
128+
extras,
129+
::PushforwardSlow,
130+
)
128131
f!(y, x)
129132
for (k, i) in enumerate(CartesianIndices(y))
130133
dy_i = basis(backend, y, i)

0 commit comments

Comments
 (0)