Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 4 additions & 19 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.51"
version = "0.6.52"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -17,6 +17,7 @@ FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Expand All @@ -37,6 +38,7 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
DifferentiationInterfaceGTPSAExt = "GTPSA"
DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
Expand Down Expand Up @@ -109,21 +111,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = [
"ADTypes",
"Aqua",
"ComponentArrays",
"DataFrames",
"ExplicitImports",
"JET",
"JLArrays",
"JuliaFormatter",
"Pkg",
"Random",
"SparseArrays",
"SparseConnectivityTracer",
"SparseMatrixColorings",
"StableRNGs",
"StaticArrays",
"Test",
]
test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
module DifferentiationInterfaceGPUArraysCoreExt

import DifferentiationInterface as DI
using GPUArraysCore: AbstractGPUArray

"""
OneElement

Efficient storage for a one-hot array, aka an array in the standard Euclidean basis.
"""
struct OneElement{I,N,T,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
ind::I
val::T
a::A

function OneElement(ind::Integer, val::T, a::A) where {N,T,A<:AbstractArray{T,N}}
right_ind = eachindex(a)[ind]
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
end

function OneElement(

Check warning on line 21 in DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl#L21

Added line #L21 was not covered by tests
ind::CartesianIndex{N}, val::T, a::A
) where {N,T,A<:AbstractArray{T,N}}
linear_ind = LinearIndices(a)[ind]
right_ind = eachindex(a)[linear_ind]
return new{typeof(right_ind),N,T,A}(right_ind, val, a)

Check warning on line 26 in DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl#L24-L26

Added lines #L24 - L26 were not covered by tests
end
end

Base.size(oe::OneElement) = size(oe.a)
Base.IndexStyle(oe::OneElement) = Base.IndexStyle(oe.a)

function Base.getindex(oe::OneElement{<:Integer}, ind::Integer)
return ifelse(ind == oe.ind, oe.val, zero(eltype(oe.a)))
end

function DI.basis(a::AbstractGPUArray{T}, i) where {T}
b = zero(a)
b .+= OneElement(i, one(T), a)
return b
end

end
50 changes: 5 additions & 45 deletions DifferentiationInterface/src/utils/basis.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,3 @@
"""
OneElement

Efficient storage for a one-hot array, aka an array in the standard Euclidean basis.
"""
struct OneElement{I,N,T,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
ind::I
val::T
a::A

function OneElement(ind::Integer, val::T, a::A) where {N,T,A<:AbstractArray{T,N}}
right_ind = eachindex(a)[ind]
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
end

function OneElement(
ind::CartesianIndex{N}, val::T, a::A
) where {N,T,A<:AbstractArray{T,N}}
linear_ind = LinearIndices(a)[ind]
right_ind = eachindex(a)[linear_ind]
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
end
end

Base.size(oe::OneElement) = size(oe.a)
Base.IndexStyle(oe::OneElement) = Base.IndexStyle(oe.a)

function Base.getindex(oe::OneElement{<:Integer}, ind::Integer)
if ind == oe.ind
return oe.val
else
return zero(eltype(oe.a))
end
end

function Base.getindex(oe::OneElement{<:CartesianIndex{N}}, ind::Vararg{Int,N}) where {N}
if ind == Tuple(oe.ind)
return oe.val
else
return zero(eltype(oe.a))
end
end

"""
basis(a::AbstractArray, i)

Expand All @@ -49,7 +6,7 @@
function basis(a::AbstractArray{T}, i) where {T}
b = similar(a)
fill!(b, zero(T))
b .+= OneElement(i, one(T), a)
b[i] = one(T)
if ismutable_array(a)
return b
else
Expand All @@ -61,12 +18,15 @@
multibasis(a::AbstractArray, inds)

Construct the sum of the `i`-th standard basis arrays in the vector space of `a` for all `i ∈ inds`.

!!! warning
Does not work on GPU, since this is intended for sparse autodiff and SparseMatrixColorings.jl doesn't work on GPUs either.
"""
function multibasis(a::AbstractArray{T}, inds) where {T}
b = similar(a)
fill!(b, zero(T))
for i in inds
b .+= OneElement(i, one(T), a)
b[i] = one(T)

Check warning on line 29 in DifferentiationInterface/src/utils/basis.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/src/utils/basis.jl#L29

Added line #L29 was not covered by tests
end
return ismutable_array(a) ? b : map(+, zero(a), b)
end
7 changes: 7 additions & 0 deletions DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using DifferentiationInterface:
AutoReverseFromPrimitive,
DenseSparsityDetector
using SparseMatrixColorings
using JLArrays, StaticArrays
using Test

LOGGING = get(ENV, "CI", "false") == "false"
Expand Down Expand Up @@ -137,3 +138,9 @@ end
pushforward, copyto!, [1.0], AutoSimpleFiniteDiff(), [1.0], ([1.0], [1.0])
)
end

@testset "Weird arrays" begin
test_differentiation(
AutoSimpleFiniteDiff(), vcat(static_scenarios(), gpu_scenarios()); logging=LOGGING
)
end;
Loading