Skip to content

Commit 0afa039

Browse files
authored
fix: simplify basis (#692)
* fix: propagate traits properly * Remove backend-specific basis * Remove tests * Fix JLArrays * Imports * Coverage
1 parent 66f7ff3 commit 0afa039

13 files changed

Lines changed: 71 additions & 208 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/DifferentiationInterfaceReverseDiffExt.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@ using ReverseDiff:
2626

2727
DI.check_available(::AutoReverseDiff) = true
2828

29-
function DI.basis(::AutoReverseDiff, a::AbstractArray{T}, i) where {T<:Real}
30-
return DI.OneElement(i, one(T), a)
31-
end
32-
3329
include("onearg.jl")
3430
include("twoarg.jl")
3531
include("utils.jl")

DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ function ADTypes.jacobian_sparsity(f, x, detector::DI.DenseSparsityDetector{:ite
2828
if DI.pushforward_performance(backend) isa DI.PushforwardFast
2929
p = similar(y)
3030
prep = DI.prepare_pushforward_same_point(
31-
f, backend, x, (DI.basis(backend, x, first(eachindex(x))),)
31+
f, backend, x, (DI.basis(x, first(eachindex(x))),)
3232
)
3333
for (kj, j) in enumerate(eachindex(x))
34-
pushforward!(f, (p,), prep, backend, x, (DI.basis(backend, x, j),))
34+
pushforward!(f, (p,), prep, backend, x, (DI.basis(x, j),))
3535
for ki in LinearIndices(p)
3636
if abs(p[ki]) > atol
3737
push!(I, ki)
@@ -42,10 +42,10 @@ function ADTypes.jacobian_sparsity(f, x, detector::DI.DenseSparsityDetector{:ite
4242
else
4343
p = similar(x)
4444
prep = DI.prepare_pullback_same_point(
45-
f, backend, x, (DI.basis(backend, y, first(eachindex(y))),)
45+
f, backend, x, (DI.basis(y, first(eachindex(y))),)
4646
)
4747
for (ki, i) in enumerate(eachindex(y))
48-
pullback!(f, (p,), prep, backend, x, (DI.basis(backend, y, i),))
48+
pullback!(f, (p,), prep, backend, x, (DI.basis(y, i),))
4949
for kj in LinearIndices(p)
5050
if abs(p[kj]) > atol
5151
push!(I, ki)
@@ -64,10 +64,10 @@ function ADTypes.jacobian_sparsity(f!, y, x, detector::DI.DenseSparsityDetector{
6464
if DI.pushforward_performance(backend) isa DI.PushforwardFast
6565
p = similar(y)
6666
prep = DI.prepare_pushforward_same_point(
67-
f!, y, backend, x, (DI.basis(backend, x, first(eachindex(x))),)
67+
f!, y, backend, x, (DI.basis(x, first(eachindex(x))),)
6868
)
6969
for (kj, j) in enumerate(eachindex(x))
70-
pushforward!(f!, y, (p,), prep, backend, x, (DI.basis(backend, x, j),))
70+
pushforward!(f!, y, (p,), prep, backend, x, (DI.basis(x, j),))
7171
for ki in LinearIndices(p)
7272
if abs(p[ki]) > atol
7373
push!(I, ki)
@@ -78,10 +78,10 @@ function ADTypes.jacobian_sparsity(f!, y, x, detector::DI.DenseSparsityDetector{
7878
else
7979
p = similar(x)
8080
prep = DI.prepare_pullback_same_point(
81-
f!, y, backend, x, (DI.basis(backend, y, first(eachindex(y))),)
81+
f!, y, backend, x, (DI.basis(y, first(eachindex(y))),)
8282
)
8383
for (ki, i) in enumerate(eachindex(y))
84-
pullback!(f!, y, (p,), prep, backend, x, (DI.basis(backend, y, i),))
84+
pullback!(f!, y, (p,), prep, backend, x, (DI.basis(y, i),))
8585
for kj in LinearIndices(p)
8686
if abs(p[kj]) > atol
8787
push!(I, ki)
@@ -98,11 +98,9 @@ function ADTypes.hessian_sparsity(f, x, detector::DI.DenseSparsityDetector{:iter
9898
n = length(x)
9999
I, J = Int[], Int[]
100100
p = similar(x)
101-
prep = DI.prepare_hvp_same_point(
102-
f, backend, x, (DI.basis(backend, x, first(eachindex(x))),)
103-
)
101+
prep = DI.prepare_hvp_same_point(f, backend, x, (DI.basis(x, first(eachindex(x))),))
104102
for (kj, j) in enumerate(eachindex(x))
105-
hvp!(f, (p,), prep, backend, x, (DI.basis(backend, x, j),))
103+
hvp!(f, (p,), prep, backend, x, (DI.basis(x, j),))
106104
for ki in LinearIndices(p)
107105
if abs(p[ki]) > atol
108106
push!(I, ki)

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ function _prepare_sparse_hessian_aux(
5252
(; N, A) = batch_size_settings
5353
dense_backend = dense_ad(backend)
5454
groups = column_groups(coloring_result)
55-
seeds = [DI.multibasis(backend, x, eachindex(x)[group]) for group in groups]
55+
seeds = [DI.multibasis(x, eachindex(x)[group]) for group in groups]
5656
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=2)
5757
batched_seeds = [
5858
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ function _prepare_sparse_jacobian_aux_aux(
9595
(; N, A) = batch_size_settings
9696
dense_backend = dense_ad(backend)
9797
groups = column_groups(coloring_result)
98-
seeds = [DI.multibasis(backend, x, eachindex(x)[group]) for group in groups]
98+
seeds = [DI.multibasis(x, eachindex(x)[group]) for group in groups]
9999
compressed_matrix = stack(_ -> vec(similar(y)), groups; dims=2)
100100
batched_seeds = [
101101
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
@@ -126,7 +126,7 @@ function _prepare_sparse_jacobian_aux_aux(
126126
(; N, A) = batch_size_settings
127127
dense_backend = dense_ad(backend)
128128
groups = row_groups(coloring_result)
129-
seeds = [DI.multibasis(backend, y, eachindex(y)[group]) for group in groups]
129+
seeds = [DI.multibasis(y, eachindex(y)[group]) for group in groups]
130130
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=1)
131131
batched_seeds = [
132132
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,8 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
8888
groups_forward = column_groups(coloring_result)
8989
groups_reverse = row_groups(coloring_result)
9090

91-
seeds_forward = [
92-
DI.multibasis(backend, x, eachindex(x)[group]) for group in groups_forward
93-
]
94-
seeds_reverse = [
95-
DI.multibasis(backend, y, eachindex(y)[group]) for group in groups_reverse
96-
]
91+
seeds_forward = [DI.multibasis(x, eachindex(x)[group]) for group in groups_forward]
92+
seeds_reverse = [DI.multibasis(y, eachindex(y)[group]) for group in groups_reverse]
9793

9894
compressed_matrix_forward = stack(_ -> vec(similar(y)), groups_forward; dims=2)
9995
compressed_matrix_reverse = stack(_ -> vec(similar(x)), groups_reverse; dims=1)

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ function _prepare_jacobian_aux(
136136
contexts::Vararg{Context,C},
137137
) where {B,FY,C}
138138
(; N, A) = batch_size_settings
139-
seeds = [basis(backend, x, ind) for ind in eachindex(x)]
139+
seeds = [basis(x, ind) for ind in eachindex(x)]
140140
batched_seeds = [
141141
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
142142
]
@@ -159,7 +159,7 @@ function _prepare_jacobian_aux(
159159
contexts::Vararg{Context,C},
160160
) where {B,FY,C}
161161
(; N, A) = batch_size_settings
162-
seeds = [basis(backend, y, ind) for ind in eachindex(y)]
162+
seeds = [basis(y, ind) for ind in eachindex(y)]
163163
batched_seeds = [
164164
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
165165
]

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ function _prepare_pullback_aux(
126126
ty::NTuple,
127127
contexts::Vararg{Context,C},
128128
) where {F,C}
129-
dx = x isa Number ? one(x) : basis(backend, x, first(CartesianIndices(x)))
129+
dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x)))
130130
pushforward_prep = prepare_pushforward(f, backend, x, (dx,), contexts...)
131131
return PushforwardPullbackPrep(pushforward_prep)
132132
end
@@ -140,7 +140,7 @@ function _prepare_pullback_aux(
140140
ty::NTuple,
141141
contexts::Vararg{Context,C},
142142
) where {F,C}
143-
dx = x isa Number ? one(x) : basis(backend, x, first(CartesianIndices(x)))
143+
dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x)))
144144
pushforward_prep = prepare_pushforward(f!, y, backend, x, (dx,), contexts...)
145145
return PushforwardPullbackPrep(pushforward_prep)
146146
end
@@ -169,7 +169,7 @@ function _pullback_via_pushforward(
169169
contexts::Vararg{Context,C},
170170
) where {F,C}
171171
dx = map(CartesianIndices(x)) do j
172-
t1 = pushforward(f, pushforward_prep, backend, x, (basis(backend, x, j),), contexts...)
172+
t1 = pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)
173173
dot(only(t1), dy)
174174
end
175175
return dx
@@ -255,9 +255,7 @@ function _pullback_via_pushforward(
255255
contexts::Vararg{Context,C},
256256
) where {F,C}
257257
dx = map(CartesianIndices(x)) do j # preserve shape
258-
t1 = pushforward(
259-
f!, y, pushforward_prep, backend, x, (basis(backend, x, j),), contexts...
260-
)
258+
t1 = pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)
261259
dot(only(t1), dy)
262260
end
263261
return dx

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ function _prepare_pushforward_aux(
127127
contexts::Vararg{Context,C},
128128
) where {F,C}
129129
y = f(x, map(unwrap, contexts)...)
130-
dy = y isa Number ? one(y) : basis(backend, y, first(CartesianIndices(y)))
130+
dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y)))
131131
pullback_prep = prepare_pullback(f, backend, x, (dy,), contexts...)
132132
return PullbackPushforwardPrep(pullback_prep)
133133
end
@@ -141,7 +141,7 @@ function _prepare_pushforward_aux(
141141
tx::NTuple,
142142
contexts::Vararg{Context,C},
143143
) where {F,C}
144-
dy = y isa Number ? one(y) : basis(backend, y, first(CartesianIndices(y)))
144+
dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y)))
145145
pullback_prep = prepare_pullback(f!, y, backend, x, (dy,), contexts...)
146146
return PullbackPushforwardPrep(pullback_prep)
147147
end
@@ -172,7 +172,7 @@ function _pushforward_via_pullback(
172172
contexts::Vararg{Context,C},
173173
) where {F,C}
174174
dy = map(CartesianIndices(y)) do i
175-
t1 = pullback(f, pullback_prep, backend, x, (basis(backend, y, i),), contexts...)
175+
t1 = pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...)
176176
dot(only(t1), dx)
177177
end
178178
return dy
@@ -244,7 +244,7 @@ function _pushforward_via_pullback(
244244
contexts::Vararg{Context,C},
245245
) where {F,C}
246246
dy = map(CartesianIndices(y)) do i # preserve shape
247-
t1 = pullback(f!, y, pullback_prep, backend, x, (basis(backend, y, i),), contexts...)
247+
t1 = pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)
248248
dot(only(t1), dx)
249249
end
250250
return dy

DifferentiationInterface/src/misc/from_primitive.jl

Lines changed: 2 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,114 +1,12 @@
11
abstract type FromPrimitive <: AbstractADType end
22

3-
function basis(fromprim::FromPrimitive, x::AbstractArray, i)
4-
return basis(fromprim.backend, x, i)
5-
end
6-
7-
function multibasis(fromprim::FromPrimitive, x::AbstractArray, inds)
8-
return multibasis(fromprim.backend, x, inds)
9-
end
10-
113
check_available(fromprim::FromPrimitive) = check_available(fromprim.backend)
124
inplace_support(fromprim::FromPrimitive) = inplace_support(fromprim.backend)
135

14-
function BatchSizeSettings(fromprim::FromPrimitive, x::AbstractArray)
15-
return BatchSizeSettings(fromprim.backend, x)
16-
end
17-
18-
function BatchSizeSettings(fromprim::FromPrimitive, N::Integer)
19-
return BatchSizeSettings(fromprim.backend, N)
20-
end
21-
22-
## Forward (no longer used)
23-
24-
#=
25-
struct AutoForwardFromPrimitive{B} <: FromPrimitive
26-
backend::B
27-
end
28-
29-
ADTypes.mode(::AutoForwardFromPrimitive) = ADTypes.ForwardMode()
30-
31-
function threshold_batchsize(fromprim::AutoForwardFromPrimitive, dimension::Integer)
32-
return AutoForwardFromPrimitive(threshold_batchsize(fromprim.backend, dimension))
33-
end
34-
35-
struct FromPrimitivePushforwardPrep{E<:PushforwardPrep} <: PushforwardPrep
36-
pushforward_prep::E
37-
end
38-
39-
function prepare_pushforward(
40-
f::F, fromprim::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}
41-
) where {F,C}
42-
primitive_prep = prepare_pushforward(f, fromprim.backend, x, tx, contexts...)
43-
return FromPrimitivePushforwardPrep(primitive_prep)
44-
end
45-
46-
function prepare_pushforward(
47-
f!::F, y, fromprim::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}
48-
) where {F,C}
49-
primitive_prep = prepare_pushforward(f!, y, fromprim.backend, x, tx, contexts...)
50-
return FromPrimitivePushforwardPrep(primitive_prep)
51-
end
52-
53-
function value_and_pushforward(
54-
f::F,
55-
prep::FromPrimitivePushforwardPrep,
56-
fromprim::AutoForwardFromPrimitive,
57-
x,
58-
tx::NTuple,
59-
contexts::Vararg{Context,C},
60-
) where {F,C}
61-
return value_and_pushforward(
62-
f, prep.pushforward_prep, fromprim.backend, x, tx, contexts...
63-
)
64-
end
65-
66-
function value_and_pushforward(
67-
f!::F,
68-
y,
69-
prep::FromPrimitivePushforwardPrep,
70-
fromprim::AutoForwardFromPrimitive,
71-
x,
72-
tx::NTuple,
73-
contexts::Vararg{Context,C},
74-
) where {F,C}
75-
return value_and_pushforward(
76-
f!, y, prep.pushforward_prep, fromprim.backend, x, tx, contexts...
77-
)
6+
function pick_batchsize(fromprim::FromPrimitive, x_or_N::Union{AbstractArray,Integer})
7+
return pick_batchsize(fromprim.backend, x_or_N)
788
end
799

80-
function value_and_pushforward!(
81-
f::F,
82-
ty::NTuple,
83-
prep::FromPrimitivePushforwardPrep,
84-
fromprim::AutoForwardFromPrimitive,
85-
x,
86-
tx::NTuple,
87-
contexts::Vararg{Context,C},
88-
) where {F,C}
89-
return value_and_pushforward!(
90-
f, ty, prep.pushforward_prep, fromprim.backend, x, tx, contexts...
91-
)
92-
end
93-
94-
function value_and_pushforward!(
95-
f!::F,
96-
y,
97-
ty::NTuple,
98-
prep::FromPrimitivePushforwardPrep,
99-
fromprim::AutoForwardFromPrimitive,
100-
x,
101-
tx::NTuple,
102-
contexts::Vararg{Context,C},
103-
) where {F,C}
104-
return value_and_pushforward!(
105-
f!, y, ty, prep.pushforward_prep, fromprim.backend, x, tx, contexts...
106-
)
107-
end
108-
=#
109-
110-
## Reverse
111-
11210
struct AutoReverseFromPrimitive{B} <: FromPrimitive
11311
backend::B
11412
end

DifferentiationInterface/src/second_order/hessian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ function _prepare_hessian_aux(
9191
contexts::Vararg{Context,C},
9292
) where {B,F,C}
9393
(; N, A) = batch_size_settings
94-
seeds = [basis(backend, x, ind) for ind in eachindex(x)]
94+
seeds = [basis(x, ind) for ind in eachindex(x)]
9595
batched_seeds = [
9696
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
9797
]

0 commit comments

Comments
 (0)