diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index f6527dd12..b4411b272 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.51" +version = "0.6.52" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -17,6 +17,7 @@ FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" @@ -37,6 +38,7 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation" DifferentiationInterfaceFiniteDiffExt = "FiniteDiff" DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences" DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] +DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore" DifferentiationInterfaceGTPSAExt = "GTPSA" DifferentiationInterfaceMooncakeExt = "Mooncake" DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"] @@ -109,21 +111,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = [ - "ADTypes", - "Aqua", - "ComponentArrays", - "DataFrames", - "ExplicitImports", - "JET", - "JLArrays", - "JuliaFormatter", - "Pkg", - "Random", - "SparseArrays", - "SparseConnectivityTracer", - "SparseMatrixColorings", - "StableRNGs", - "StaticArrays", - "Test", -] +test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test"] diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl new file mode 100644 index 000000000..d9d9749ed --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl @@ -0,0 +1,43 @@ +module DifferentiationInterfaceGPUArraysCoreExt + +import DifferentiationInterface as DI +using GPUArraysCore: AbstractGPUArray + +""" + OneElement + +Efficient storage for a one-hot array, aka an array in the standard Euclidean basis. +""" +struct OneElement{I,N,T,A<:AbstractArray{T,N}} <: AbstractArray{T,N} + ind::I + val::T + a::A + + function OneElement(ind::Integer, val::T, a::A) where {N,T,A<:AbstractArray{T,N}} + right_ind = eachindex(a)[ind] + return new{typeof(right_ind),N,T,A}(right_ind, val, a) + end + + function OneElement( + ind::CartesianIndex{N}, val::T, a::A + ) where {N,T,A<:AbstractArray{T,N}} + linear_ind = LinearIndices(a)[ind] + right_ind = eachindex(a)[linear_ind] + return new{typeof(right_ind),N,T,A}(right_ind, val, a) + end +end + +Base.size(oe::OneElement) = size(oe.a) +Base.IndexStyle(oe::OneElement) = Base.IndexStyle(oe.a) + +function Base.getindex(oe::OneElement{<:Integer}, ind::Integer) + return ifelse(ind == oe.ind, oe.val, zero(eltype(oe.a))) +end + +function DI.basis(a::AbstractGPUArray{T}, i) where {T} + b = zero(a) + b .+= OneElement(i, one(T), a) + return b +end + +end diff --git a/DifferentiationInterface/src/utils/basis.jl b/DifferentiationInterface/src/utils/basis.jl index a6ed3f420..46c7162fc 100644 --- a/DifferentiationInterface/src/utils/basis.jl +++ b/DifferentiationInterface/src/utils/basis.jl @@ -1,46 +1,3 @@ -""" - OneElement - -Efficient storage for a one-hot array, aka an array in the standard Euclidean basis. -""" -struct OneElement{I,N,T,A<:AbstractArray{T,N}} <: AbstractArray{T,N} - ind::I - val::T - a::A - - function OneElement(ind::Integer, val::T, a::A) where {N,T,A<:AbstractArray{T,N}} - right_ind = eachindex(a)[ind] - return new{typeof(right_ind),N,T,A}(right_ind, val, a) - end - - function OneElement( - ind::CartesianIndex{N}, val::T, a::A - ) where {N,T,A<:AbstractArray{T,N}} - linear_ind = LinearIndices(a)[ind] - right_ind = eachindex(a)[linear_ind] - return new{typeof(right_ind),N,T,A}(right_ind, val, a) - end -end - -Base.size(oe::OneElement) = size(oe.a) -Base.IndexStyle(oe::OneElement) = Base.IndexStyle(oe.a) - -function Base.getindex(oe::OneElement{<:Integer}, ind::Integer) - if ind == oe.ind - return oe.val - else - return zero(eltype(oe.a)) - end -end - -function Base.getindex(oe::OneElement{<:CartesianIndex{N}}, ind::Vararg{Int,N}) where {N} - if ind == Tuple(oe.ind) - return oe.val - else - return zero(eltype(oe.a)) - end -end - """ basis(a::AbstractArray, i) @@ -49,7 +6,7 @@ Construct the `i`-th standard basis array in the vector space of `a`. function basis(a::AbstractArray{T}, i) where {T} b = similar(a) fill!(b, zero(T)) - b .+= OneElement(i, one(T), a) + b[i] = one(T) if ismutable_array(a) return b else @@ -61,12 +18,15 @@ end multibasis(a::AbstractArray, inds) Construct the sum of the `i`-th standard basis arrays in the vector space of `a` for all `i ∈ inds`. + +!!! warning + Does not work on GPU, since this is intended for sparse autodiff and SparseMatrixColorings.jl doesn't work on GPUs either. """ function multibasis(a::AbstractArray{T}, inds) where {T} b = similar(a) fill!(b, zero(T)) for i in inds - b .+= OneElement(i, one(T), a) + b[i] = one(T) end return ismutable_array(a) ? b : map(+, zero(a), b) end diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index 15cf4d723..3ceeb8ee2 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -5,6 +5,7 @@ using DifferentiationInterface: AutoReverseFromPrimitive, DenseSparsityDetector using SparseMatrixColorings +using JLArrays, StaticArrays using Test LOGGING = get(ENV, "CI", "false") == "false" @@ -137,3 +138,9 @@ end pushforward, copyto!, [1.0], AutoSimpleFiniteDiff(), [1.0], ([1.0], [1.0]) ) end + +@testset "Weird arrays" begin + test_differentiation( + AutoSimpleFiniteDiff(), vcat(static_scenarios(), gpu_scenarios()); logging=LOGGING + ) +end;