Skip to content

Commit b7d264c

Browse files
committed
Adds an abstract type to NamedArrayPartition
In order to resolve #443 where I would like to be able to subtype from NamedArrayPartition. This is a large restructuring of NamedArrayPartition but all tests within RecursiveArrayTools do pass. I'm not sure of other tests that I would need to run to ensure further compatibility. This maintains all current functionality but allows a user to create a new subtype of AbstractNamedArrayPartition as long as it maintains the same structure
1 parent 3450c85 commit b7d264c

3 files changed

Lines changed: 78 additions & 58 deletions

File tree

src/RecursiveArrayTools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,6 @@ export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_pus
150150
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
151151
recursive_unitless_bottom_eltype, recursive_unitless_eltype
152152

153-
export ArrayPartition, NamedArrayPartition
153+
export ArrayPartition, NamedArrayPartition, AbstractNamedArrayPartition
154154

155155
end # module

src/named_array_partition.jl

Lines changed: 76 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
abstract type AbstractNamedArrayPartition{T, A, NT} <: AbstractVector{T} end
2+
13
"""
24
NamedArrayPartition(; kwargs...)
35
NamedArrayPartition(x::NamedTuple)
@@ -6,137 +8,155 @@ Similar to an `ArrayPartition` but the individual arrays can be accessed via the
68
constructor-specified names. However, unlike `ArrayPartition`, each individual array
79
must have the same element type.
810
"""
9-
struct NamedArrayPartition{T, A <: ArrayPartition{T}, NT <: NamedTuple} <: AbstractVector{T}
11+
struct NamedArrayPartition{T, A <: ArrayPartition{T}, NT <: NamedTuple} <: AbstractNamedArrayPartition{T, A, NT}
1012
array_partition::A
1113
names_to_indices::NT
1214
end
13-
NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs))
14-
function NamedArrayPartition(x::NamedTuple)
15+
(::Type{T})(; kwargs...) where {T<:AbstractNamedArrayPartition} = T(NamedTuple(kwargs))
16+
function (::Type{T})(x::NamedTuple) where {T<:AbstractNamedArrayPartition}
1517
names_to_indices = NamedTuple(Pair(symbol, index)
1618
for (index, symbol) in enumerate(keys(x)))
1719

1820
# enforce homogeneity of eltypes
1921
@assert all(eltype.(values(x)) .== eltype(first(x)))
20-
T = eltype(first(x))
22+
R = eltype(first(x))
2123
S = typeof(values(x))
22-
return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices)
24+
return T(ArrayPartition{R, S}(values(x)), names_to_indices)
25+
end
26+
27+
function named_partition_constructor(X::T) where {T<:AbstractNamedArrayPartition}
28+
getfield(parentmodule(T), nameof(T))
2329
end
2430

2531
# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
2632
# fields except through `getfield` and accessor functions.
27-
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition)
33+
ArrayPartition(x::AbstractNamedArrayPartition) = getfield(x, :array_partition)
2834

29-
function Base.similar(A::NamedArrayPartition)
30-
NamedArrayPartition(
35+
# With new type structure this function does the same as Base.similar(x::AbstractNamedArrayPartition{T, S, NT}) where {T, S, NT}
36+
#= function Base.similar(A::T) where {T<:AbstractNamedArrayPartition}
37+
Tconstr = named_partition_constructor(A)
38+
Tconstr(
3139
similar(getfield(A, :array_partition)), getfield(A, :names_to_indices))
32-
end
40+
end =#
3341

3442
# return ArrayPartition when possible, otherwise next best thing of the correct size
35-
function Base.similar(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N}
36-
NamedArrayPartition(
43+
function Base.similar(A::T, dims::NTuple{N, Int}) where {T<:AbstractNamedArrayPartition, N}
44+
Tconstr = named_partition_constructor(A)
45+
Tconstr(
3746
similar(getfield(A, :array_partition), dims), getfield(A, :names_to_indices))
3847
end
3948

4049
# similar array partition of common type
41-
@inline function Base.similar(A::NamedArrayPartition, ::Type{T}) where {T}
42-
NamedArrayPartition(
50+
@inline function Base.similar(A::S, ::Type{T}) where {S<:AbstractNamedArrayPartition, T}
51+
Tconstr = named_partition_constructor(A)
52+
Tconstr(
4353
similar(getfield(A, :array_partition), T), getfield(A, :names_to_indices))
4454
end
4555

4656
# return ArrayPartition when possible, otherwise next best thing of the correct size
47-
function Base.similar(A::NamedArrayPartition, ::Type{T}, dims::NTuple{N, Int}) where {T, N}
48-
NamedArrayPartition(
57+
function Base.similar(A::S, ::Type{T}, dims::NTuple{N, Int}) where {T, N, S<:AbstractNamedArrayPartition}
58+
Tconstr = named_partition_constructor(A)
59+
Tconstr(
4960
similar(getfield(A, :array_partition), T, dims), getfield(A, :names_to_indices))
5061
end
5162

5263
# similar array partition with different types
5364
function Base.similar(
54-
A::NamedArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S}
55-
NamedArrayPartition(
65+
A::U, ::Type{T}, ::Type{S}, R::DataType...) where {T, S, U<:AbstractNamedArrayPartition}
66+
Tconstr = named_partition_constructor(A)
67+
Tconstr(
5668
similar(getfield(A, :array_partition), T, S, R), getfield(A, :names_to_indices))
5769
end
5870

59-
Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x))
71+
Base.Array(x::AbstractNamedArrayPartition) = Array(ArrayPartition(x))
6072

61-
function Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN}
62-
NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices))
73+
function Base.zero(x::R) where {R <: AbstractNamedArrayPartition}
74+
R(zero(ArrayPartition(x)), getfield(x, :names_to_indices))
6375
end
64-
Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors
76+
Base.zero(A::AbstractNamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors
6577

66-
Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices))
67-
function Base.getproperty(x::NamedArrayPartition, s::Symbol)
78+
Base.propertynames(x::AbstractNamedArrayPartition) = propertynames(getfield(x, :names_to_indices))
79+
function Base.getproperty(x::AbstractNamedArrayPartition, s::Symbol)
6880
getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s))
6981
end
7082

7183
# this enables x.s = some_array.
72-
@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v)
84+
@inline function Base.setproperty!(x::AbstractNamedArrayPartition, s::Symbol, v)
7385
index = getproperty(getfield(x, :names_to_indices), s)
7486
ArrayPartition(x).x[index] .= v
7587
end
7688

7789
# print out NamedArrayPartition as a NamedTuple
78-
Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:")
79-
function Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition)
90+
Base.summary(x::AbstractNamedArrayPartition) = string(typeof(x), " with arrays:")
91+
function Base.show(io::IO, m::MIME"text/plain", x::AbstractNamedArrayPartition)
8092
show(
8193
io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x)))
8294
end
8395

84-
Base.size(x::NamedArrayPartition) = size(ArrayPartition(x))
85-
Base.length(x::NamedArrayPartition) = length(ArrayPartition(x))
86-
Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...)
96+
Base.size(x::AbstractNamedArrayPartition) = size(ArrayPartition(x))
97+
Base.length(x::AbstractNamedArrayPartition) = length(ArrayPartition(x))
98+
Base.getindex(x::AbstractNamedArrayPartition, args...) = getindex(ArrayPartition(x), args...)
8799

88-
Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...)
89-
function Base.map(f, x::NamedArrayPartition)
90-
NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
100+
Base.setindex!(x::AbstractNamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...)
101+
function Base.map(f, x::T) where {T<:AbstractNamedArrayPartition}
102+
Tconstr = named_partition_constructor(x)
103+
Tconstr(map(f, ArrayPartition(x)), getfield(x, :names_to_indices))
91104
end
92-
Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x))
93-
# Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x))
105+
Base.mapreduce(f, op, x::AbstractNamedArrayPartition) = mapreduce(f, op, ArrayPartition(x))
106+
# Base.filter(f, x::AbstractNamedArrayPartition) = filter(f, ArrayPartition(x))
94107

95-
function Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT}
96-
NamedArrayPartition{T, S, NT}(
97-
similar(ArrayPartition(x)), getfield(x, :names_to_indices))
98-
end
108+
function Base.similar(x::AbstractNamedArrayPartition{T, A, NT}) where {T, A, NT}
109+
# Safely extract the concrete type parameters
99110

111+
Tconstr = named_partition_constructor(x)
112+
return Tconstr{T, A, NT}(
113+
similar(getfield(x, :array_partition)),
114+
getfield(x, :names_to_indices)
115+
)
116+
end
100117
# broadcasting
101-
function Base.BroadcastStyle(::Type{<:NamedArrayPartition})
102-
Broadcast.ArrayStyle{NamedArrayPartition}()
118+
function Base.BroadcastStyle(::Type{T}) where{T<:AbstractNamedArrayPartition}
119+
Broadcast.ArrayStyle{T}()
103120
end
104-
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}},
105-
::Type{ElType}) where {ElType}
121+
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{T}},
122+
::Type{ElType}) where {ElType, T<:AbstractNamedArrayPartition}
106123
x = find_NamedArrayPartition(bc)
107-
return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices))
124+
Tconstr = named_partition_constructor(x)
125+
return Tconstr(similar(ArrayPartition(x)), getfield(x, :names_to_indices))
108126
end
109127

110128
# when broadcasting with ArrayPartition + another array type, the output is the other array tupe
111129
function Base.BroadcastStyle(
112-
::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1})
130+
::Broadcast.ArrayStyle{<:AbstractNamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1})
113131
Broadcast.DefaultArrayStyle{1}()
114132
end
115133

116134
# hook into ArrayPartition broadcasting routines
117-
@inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x))
118-
@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) = Broadcast.Broadcasted(
135+
@inline RecursiveArrayTools.npartitions(x::AbstractNamedArrayPartition) = npartitions(ArrayPartition(x))
136+
@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{<:AbstractNamedArrayPartition}}, i) = Broadcast.Broadcasted(
119137
bc.f, RecursiveArrayTools.unpack_args(i, bc.args))
120-
@inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i)
138+
@inline RecursiveArrayTools.unpack(x::AbstractNamedArrayPartition, i) = unpack(ArrayPartition(x), i)
121139

122-
function Base.copy(A::NamedArrayPartition{T, S, NT}) where {T, S, NT}
123-
NamedArrayPartition{T, S, NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices))
140+
function Base.copy(A::AbstractNamedArrayPartition{T, S, NT}) where {T, S, NT}
141+
Tconstr = named_partition_constructor(A)
142+
Tconstr{T, S, NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices))
124143
end
125144

126-
@inline NamedArrayPartition(f::F, N, names_to_indices) where {F <: Function} = NamedArrayPartition(
145+
@inline (::Type{T})(f::F, N, names_to_indices) where {F <: Function, T<:AbstractNamedArrayPartition} = T(
127146
ArrayPartition(ntuple(f, Val(N))), names_to_indices)
128147

129-
@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
148+
@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{T}}) where {T<:AbstractNamedArrayPartition}
130149
N = npartitions(bc)
131150
@inline function f(i)
132151
copy(unpack(bc, i))
133152
end
134153
x = find_NamedArrayPartition(bc)
135-
NamedArrayPartition(f, N, getfield(x, :names_to_indices))
154+
Tconstr = named_partition_constructor(x)
155+
Tconstr(f, N, getfield(x, :names_to_indices))
136156
end
137157

138-
@inline function Base.copyto!(dest::NamedArrayPartition,
139-
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
158+
@inline function Base.copyto!(dest::AbstractNamedArrayPartition,
159+
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{<:AbstractNamedArrayPartition}})
140160
N = npartitions(dest, bc)
141161
@inline function f(i)
142162
copyto!(ArrayPartition(dest).x[i], unpack(bc, i))
@@ -146,7 +166,7 @@ end
146166
end
147167

148168
#Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq
149-
function ArrayInterface.zeromatrix(A::NamedArrayPartition)
169+
function ArrayInterface.zeromatrix(A::AbstractNamedArrayPartition)
150170
B = ArrayPartition(A)
151171
x = reduce(vcat,vec.(B.x))
152172
x .* x' .* false
@@ -159,5 +179,5 @@ function find_NamedArrayPartition(args::Tuple)
159179
end
160180
find_NamedArrayPartition(x) = x
161181
find_NamedArrayPartition(::Tuple{}) = nothing
162-
find_NamedArrayPartition(x::NamedArrayPartition, rest) = x
182+
find_NamedArrayPartition(x::AbstractNamedArrayPartition, rest) = x
163183
find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest)

test/named_array_partition_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using RecursiveArrayTools, Test
1+
using RecursiveArrayTools, Test, ArrayInterface
22

33
@testset "NamedArrayPartition tests" begin
44
x = NamedArrayPartition(a = ones(10), b = rand(20))

0 commit comments

Comments
 (0)