diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index af16a6ae2..d4058e7ac 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,13 +1,14 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -authors = ["Guillaume Dalle", "Adrian Hill"] version = "0.7.16" +authors = ["Guillaume Dalle", "Adrian Hill"] [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" @@ -38,7 +39,7 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation" DifferentiationInterfaceFiniteDiffExt = "FiniteDiff" DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences" DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] -DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore" +DifferentiationInterfaceGPUArraysCoreExt = ["GPUArraysCore", "Adapt"] DifferentiationInterfaceGTPSAExt = "GTPSA" DifferentiationInterfaceMooncakeExt = "Mooncake" DifferentiationInterfacePolyesterForwardDiffExt = [ @@ -56,6 +57,7 @@ DifferentiationInterfaceTrackerExt = "Tracker" DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] [compat] +Adapt = "4.5.0" ADTypes = "1.18.0" ChainRulesCore = "1.23.0" DiffResults = "1.1.0" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl index 60d1ef6c0..03d79c89c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl @@ -1,5 +1,6 @@ module DifferentiationInterfaceGPUArraysCoreExt +using Adapt: adapt import DifferentiationInterface as DI using GPUArraysCore: @allowscalar, AbstractGPUArray @@ -17,4 +18,10 @@ function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T} return b end +function DI.arroftup_to_tupofarr( + tx::AbstractArray{<:NTuple{B, <:Number}}, x::AbstractGPUArray{<:Number} + ) where {B} + return ntuple(b -> adapt(typeof(x), getindex.(tx, b)), Val(B)) +end + end diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 333b12967..57c2c8513 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -332,7 +332,7 @@ function _value_and_pullback_via_pushforward( tx = map(ty) do dy dot(a, dy) end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function _value_and_pullback_via_pushforward( @@ -348,7 +348,7 @@ function _value_and_pullback_via_pushforward( tx = map(ty) do dy real(dot(a, dy)) + im * real(dot(b, dy)) end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function _value_and_pullback_via_pushforward( @@ -366,7 +366,7 @@ function _value_and_pullback_via_pushforward( dot(a, dy) end end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function _value_and_pullback_via_pushforward( @@ -387,7 +387,7 @@ function _value_and_pullback_via_pushforward( real(dot(a, dy)) + im * real(dot(b, dy)) end end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function value_and_pullback( @@ -458,7 +458,7 @@ function _value_and_pullback_via_pushforward( tx = map(ty) do dy dot(a, dy) end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function _value_and_pullback_via_pushforward( @@ -477,7 +477,7 @@ function _value_and_pullback_via_pushforward( tx = map(ty) do dy real(dot(a, dy)) + im * real(dot(b, dy)) end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function _value_and_pullback_via_pushforward( @@ -495,7 +495,7 @@ function _value_and_pullback_via_pushforward( dot(a, dy) end end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function _value_and_pullback_via_pushforward( @@ -518,7 +518,7 @@ function _value_and_pullback_via_pushforward( real(dot(a, dy)) + im * real(dot(b, dy)) end end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function value_and_pullback( diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index e49203615..0aadac34e 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -331,7 +331,7 @@ function _value_and_pushforward_via_pullback( ty = map(tx) do dx dot(a, dx) end - return y, arroftup_to_tupofarr(ty) + return y, arroftup_to_tupofarr(ty, y) end function _value_and_pushforward_via_pullback( @@ -348,7 +348,7 @@ function _value_and_pushforward_via_pullback( ty = map(tx) do dx real(dot(a, dx)) + im * real(dot(b, dx)) end - return y, arroftup_to_tupofarr(ty) + return y, arroftup_to_tupofarr(ty, y) end function _value_and_pushforward_via_pullback( @@ -367,7 +367,7 @@ function _value_and_pushforward_via_pullback( dot(a, dx) end end - return y, arroftup_to_tupofarr(ty) + return y, arroftup_to_tupofarr(ty, y) end function _value_and_pushforward_via_pullback( @@ -387,7 +387,7 @@ function _value_and_pushforward_via_pullback( real(dot(a, dx)) + im * real(dot(b, dx)) end end - return y, arroftup_to_tupofarr(ty) + return y, arroftup_to_tupofarr(ty, y) end function value_and_pushforward( @@ -460,7 +460,7 @@ function _value_and_pushforward_via_pullback( dot(a, dx) end end - return y, arroftup_to_tupofarr(ty) + return y, arroftup_to_tupofarr(ty, y) end function _value_and_pushforward_via_pullback( @@ -481,7 +481,7 @@ function _value_and_pushforward_via_pullback( real(dot(a, dx)) + im * real(dot(b, dx)) end end - return y, arroftup_to_tupofarr(ty) + return y, arroftup_to_tupofarr(ty, y) end function value_and_pushforward( diff --git a/DifferentiationInterface/src/utils/linalg.jl b/DifferentiationInterface/src/utils/linalg.jl index a5fb15cab..e11b7e266 100644 --- a/DifferentiationInterface/src/utils/linalg.jl +++ b/DifferentiationInterface/src/utils/linalg.jl @@ -36,5 +36,13 @@ get_pattern(M::AbstractMatrix) = trues(size(M)) onlysecond((a, b)) = (a, only(b)) -arroftup_to_tupofarr(x::NTuple) = x -arroftup_to_tupofarr(x::AbstractArray{<:NTuple{B}}) where {B} = ntuple(b -> getindex.(x, b), Val(B)) +""" + arroftup_to_tupofarr(tx, x) + +Convert an array of tuples `tx` into a tuple of arrays, while respecting the array type of the primal `x`. +""" +arroftup_to_tupofarr(tx::NTuple{B, <:Number}, x::Number) where {B} = tx + +function arroftup_to_tupofarr(tx::AbstractArray{<:NTuple{B, <:Number}}, x::AbstractArray{<:Number}) where {B} + return ntuple(b -> similar(x) .= getindex.(tx, b), Val(B)) +end diff --git a/DifferentiationInterface/test/Core/Internals/linalg.jl b/DifferentiationInterface/test/Core/Internals/linalg.jl index bc12b1b17..c9a372d63 100644 --- a/DifferentiationInterface/test/Core/Internals/linalg.jl +++ b/DifferentiationInterface/test/Core/Internals/linalg.jl @@ -1,6 +1,7 @@ -using DifferentiationInterface: recursive_similar, get_pattern +using DifferentiationInterface: recursive_similar, get_pattern, arroftup_to_tupofarr using SparseArrays using Test +using JLArrays, ComponentArrays @testset "Recursive similar" begin @test recursive_similar(ones(Int, 2), Float32) isa Vector{Float32} @@ -16,3 +17,19 @@ end @test_broken get_pattern(D) == Diagonal(trues(10)) @test get_pattern(sparse(D)) == Diagonal(trues(10)) end + +@testset "Wrong-mode array conversion" begin + x = [1.0, 3.0, 5.0] + xt = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)] + y = ComponentVector(a = [1.0, 3.0], b = [5.0]) + yt = ComponentVector(a = [(1.0, 2.0), (3.0, 4.0)], b = [(5.0, 6.0)]) + z = jl([1.0, 3.0, 5.0]) + zt = jl([(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]) + @test arroftup_to_tupofarr((1.0, 2.0), 1.0) == (1.0, 2.0) + @test arroftup_to_tupofarr(xt, x) == ([1.0, 3.0, 5.0], [2.0, 4.0, 6.0]) + @test arroftup_to_tupofarr(yt, y) == (ComponentVector(a = [1.0, 3.0], b = [5.0]), ComponentVector(a = [2.0, 4.0], b = [6.0])) + @test arroftup_to_tupofarr(zt, z) == (jl([1.0, 3.0, 5.0]), jl([2.0, 4.0, 6.0])) + @test arroftup_to_tupofarr(xt, x)[1] isa Vector + @test arroftup_to_tupofarr(yt, y)[1] isa ComponentVector + @test arroftup_to_tupofarr(zt, z)[1] isa JLVector +end diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index fe739284e..1ff483e17 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -169,3 +169,8 @@ end logging = LOGGING, ) end; + +@testset "Array format preservation in wrong mode" begin + @test gradient(sum, AutoSimpleFiniteDiff(), jl(ones(2))) isa JLVector + @test derivative(t -> jl(fill(t, 2)), AutoReverseFromPrimitive(AutoSimpleFiniteDiff()), 1.0) isa JLVector +end