Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.7.16"
authors = ["Guillaume Dalle", "Adrian Hill"]

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
Expand Down Expand Up @@ -38,7 +39,7 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
DifferentiationInterfaceGPUArraysCoreExt = ["GPUArraysCore", "Adapt"]
DifferentiationInterfaceGTPSAExt = "GTPSA"
DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfacePolyesterForwardDiffExt = [
Expand All @@ -56,6 +57,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]

[compat]
Adapt = "4.5.0"
ADTypes = "1.18.0"
ChainRulesCore = "1.23.0"
DiffResults = "1.1.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module DifferentiationInterfaceGPUArraysCoreExt

using Adapt: adapt
import DifferentiationInterface as DI
using GPUArraysCore: @allowscalar, AbstractGPUArray

Expand All @@ -17,4 +18,10 @@ function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T}
return b
end

function DI.arroftup_to_tupofarr(
tx::AbstractArray{<:NTuple{B, <:Number}}, x::AbstractGPUArray{<:Number}
) where {B}
return ntuple(b -> adapt(typeof(x), getindex.(tx, b)), Val(B))
end

end
16 changes: 8 additions & 8 deletions DifferentiationInterface/src/first_order/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ function _value_and_pullback_via_pushforward(
tx = map(ty) do dy
dot(a, dy)
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -348,7 +348,7 @@ function _value_and_pullback_via_pushforward(
tx = map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -366,7 +366,7 @@ function _value_and_pullback_via_pushforward(
dot(a, dy)
end
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -387,7 +387,7 @@ function _value_and_pullback_via_pushforward(
real(dot(a, dy)) + im * real(dot(b, dy))
end
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function value_and_pullback(
Expand Down Expand Up @@ -458,7 +458,7 @@ function _value_and_pullback_via_pushforward(
tx = map(ty) do dy
dot(a, dy)
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -477,7 +477,7 @@ function _value_and_pullback_via_pushforward(
tx = map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -495,7 +495,7 @@ function _value_and_pullback_via_pushforward(
dot(a, dy)
end
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -518,7 +518,7 @@ function _value_and_pullback_via_pushforward(
real(dot(a, dy)) + im * real(dot(b, dy))
end
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function value_and_pullback(
Expand Down
12 changes: 6 additions & 6 deletions DifferentiationInterface/src/first_order/pushforward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ function _value_and_pushforward_via_pullback(
ty = map(tx) do dx
dot(a, dx)
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -348,7 +348,7 @@ function _value_and_pushforward_via_pullback(
ty = map(tx) do dx
real(dot(a, dx)) + im * real(dot(b, dx))
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -367,7 +367,7 @@ function _value_and_pushforward_via_pullback(
dot(a, dx)
end
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -387,7 +387,7 @@ function _value_and_pushforward_via_pullback(
real(dot(a, dx)) + im * real(dot(b, dx))
end
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function value_and_pushforward(
Expand Down Expand Up @@ -460,7 +460,7 @@ function _value_and_pushforward_via_pullback(
dot(a, dx)
end
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -481,7 +481,7 @@ function _value_and_pushforward_via_pullback(
real(dot(a, dx)) + im * real(dot(b, dx))
end
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function value_and_pushforward(
Expand Down
12 changes: 10 additions & 2 deletions DifferentiationInterface/src/utils/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,13 @@ get_pattern(M::AbstractMatrix) = trues(size(M))

onlysecond((a, b)) = (a, only(b))

arroftup_to_tupofarr(x::NTuple) = x
arroftup_to_tupofarr(x::AbstractArray{<:NTuple{B}}) where {B} = ntuple(b -> getindex.(x, b), Val(B))
"""
arroftup_to_tupofarr(tx, x)

Convert an array of tuples `tx` into a tuple of arrays, while respecting the array type of the primal `x`.
"""
arroftup_to_tupofarr(tx::NTuple{B, <:Number}, x::Number) where {B} = tx

function arroftup_to_tupofarr(tx::AbstractArray{<:NTuple{B, <:Number}}, x::AbstractArray{<:Number}) where {B}
return ntuple(b -> similar(x) .= getindex.(tx, b), Val(B))
end
19 changes: 18 additions & 1 deletion DifferentiationInterface/test/Core/Internals/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using DifferentiationInterface: recursive_similar, get_pattern
using DifferentiationInterface: recursive_similar, get_pattern, arroftup_to_tupofarr
using SparseArrays
using Test
using JLArrays, ComponentArrays

@testset "Recursive similar" begin
@test recursive_similar(ones(Int, 2), Float32) isa Vector{Float32}
Expand All @@ -16,3 +17,19 @@ end
@test_broken get_pattern(D) == Diagonal(trues(10))
@test get_pattern(sparse(D)) == Diagonal(trues(10))
end

@testset "Wrong-mode array conversion" begin
x = [1.0, 3.0, 5.0]
xt = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]
y = ComponentVector(a = [1.0, 3.0], b = [5.0])
yt = ComponentVector(a = [(1.0, 2.0), (3.0, 4.0)], b = [(5.0, 6.0)])
z = jl([1.0, 3.0, 5.0])
zt = jl([(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)])
@test arroftup_to_tupofarr((1.0, 2.0), 1.0) == (1.0, 2.0)
@test arroftup_to_tupofarr(xt, x) == ([1.0, 3.0, 5.0], [2.0, 4.0, 6.0])
@test arroftup_to_tupofarr(yt, y) == (ComponentVector(a = [1.0, 3.0], b = [5.0]), ComponentVector(a = [2.0, 4.0], b = [6.0]))
@test arroftup_to_tupofarr(zt, z) == (jl([1.0, 3.0, 5.0]), jl([2.0, 4.0, 6.0]))
@test arroftup_to_tupofarr(xt, x)[1] isa Vector
@test arroftup_to_tupofarr(yt, y)[1] isa ComponentVector
@test arroftup_to_tupofarr(zt, z)[1] isa JLVector
end
5 changes: 5 additions & 0 deletions DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,8 @@ end
logging = LOGGING,
)
end;

@testset "Array format preservation in wrong mode" begin
@test gradient(sum, AutoSimpleFiniteDiff(), jl(ones(2))) isa JLVector
@test derivative(t -> jl(fill(t, 2)), AutoReverseFromPrimitive(AutoSimpleFiniteDiff()), 1.0) isa JLVector
end
Loading