Skip to content

Commit 90160b9

Browse files
authored
fix: make wrong-mode pushforward/pullback return the correct array type (#974)
* fix: make wrong-mode pushforward/pullback return the correct array type * Relax typing, add Adapt bound * Fix method ambiguity * Add tests
1 parent 2b7c5e9 commit 90160b9

7 files changed

Lines changed: 58 additions & 19 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
3-
authors = ["Guillaume Dalle", "Adrian Hill"]
43
version = "0.7.16"
4+
authors = ["Guillaume Dalle", "Adrian Hill"]
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99

1010
[weakdeps]
11+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
1112
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1213
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1314
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
@@ -38,7 +39,7 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
3839
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
3940
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
4041
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
41-
DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
42+
DifferentiationInterfaceGPUArraysCoreExt = ["GPUArraysCore", "Adapt"]
4243
DifferentiationInterfaceGTPSAExt = "GTPSA"
4344
DifferentiationInterfaceMooncakeExt = "Mooncake"
4445
DifferentiationInterfacePolyesterForwardDiffExt = [
@@ -56,6 +57,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
5657
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]
5758

5859
[compat]
60+
Adapt = "4.5.0"
5961
ADTypes = "1.18.0"
6062
ChainRulesCore = "1.23.0"
6163
DiffResults = "1.1.0"

DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module DifferentiationInterfaceGPUArraysCoreExt
22

3+
using Adapt: adapt
34
import DifferentiationInterface as DI
45
using GPUArraysCore: @allowscalar, AbstractGPUArray
56

@@ -17,4 +18,10 @@ function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T}
1718
return b
1819
end
1920

21+
function DI.arroftup_to_tupofarr(
22+
tx::AbstractArray{<:NTuple{B, <:Number}}, x::AbstractGPUArray{<:Number}
23+
) where {B}
24+
return ntuple(b -> adapt(typeof(x), getindex.(tx, b)), Val(B))
25+
end
26+
2027
end

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ function _value_and_pullback_via_pushforward(
332332
tx = map(ty) do dy
333333
dot(a, dy)
334334
end
335-
return y, arroftup_to_tupofarr(tx)
335+
return y, arroftup_to_tupofarr(tx, x)
336336
end
337337

338338
function _value_and_pullback_via_pushforward(
@@ -348,7 +348,7 @@ function _value_and_pullback_via_pushforward(
348348
tx = map(ty) do dy
349349
real(dot(a, dy)) + im * real(dot(b, dy))
350350
end
351-
return y, arroftup_to_tupofarr(tx)
351+
return y, arroftup_to_tupofarr(tx, x)
352352
end
353353

354354
function _value_and_pullback_via_pushforward(
@@ -366,7 +366,7 @@ function _value_and_pullback_via_pushforward(
366366
dot(a, dy)
367367
end
368368
end
369-
return y, arroftup_to_tupofarr(tx)
369+
return y, arroftup_to_tupofarr(tx, x)
370370
end
371371

372372
function _value_and_pullback_via_pushforward(
@@ -387,7 +387,7 @@ function _value_and_pullback_via_pushforward(
387387
real(dot(a, dy)) + im * real(dot(b, dy))
388388
end
389389
end
390-
return y, arroftup_to_tupofarr(tx)
390+
return y, arroftup_to_tupofarr(tx, x)
391391
end
392392

393393
function value_and_pullback(
@@ -458,7 +458,7 @@ function _value_and_pullback_via_pushforward(
458458
tx = map(ty) do dy
459459
dot(a, dy)
460460
end
461-
return y, arroftup_to_tupofarr(tx)
461+
return y, arroftup_to_tupofarr(tx, x)
462462
end
463463

464464
function _value_and_pullback_via_pushforward(
@@ -477,7 +477,7 @@ function _value_and_pullback_via_pushforward(
477477
tx = map(ty) do dy
478478
real(dot(a, dy)) + im * real(dot(b, dy))
479479
end
480-
return y, arroftup_to_tupofarr(tx)
480+
return y, arroftup_to_tupofarr(tx, x)
481481
end
482482

483483
function _value_and_pullback_via_pushforward(
@@ -495,7 +495,7 @@ function _value_and_pullback_via_pushforward(
495495
dot(a, dy)
496496
end
497497
end
498-
return y, arroftup_to_tupofarr(tx)
498+
return y, arroftup_to_tupofarr(tx, x)
499499
end
500500

501501
function _value_and_pullback_via_pushforward(
@@ -518,7 +518,7 @@ function _value_and_pullback_via_pushforward(
518518
real(dot(a, dy)) + im * real(dot(b, dy))
519519
end
520520
end
521-
return y, arroftup_to_tupofarr(tx)
521+
return y, arroftup_to_tupofarr(tx, x)
522522
end
523523

524524
function value_and_pullback(

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ function _value_and_pushforward_via_pullback(
331331
ty = map(tx) do dx
332332
dot(a, dx)
333333
end
334-
return y, arroftup_to_tupofarr(ty)
334+
return y, arroftup_to_tupofarr(ty, y)
335335
end
336336

337337
function _value_and_pushforward_via_pullback(
@@ -348,7 +348,7 @@ function _value_and_pushforward_via_pullback(
348348
ty = map(tx) do dx
349349
real(dot(a, dx)) + im * real(dot(b, dx))
350350
end
351-
return y, arroftup_to_tupofarr(ty)
351+
return y, arroftup_to_tupofarr(ty, y)
352352
end
353353

354354
function _value_and_pushforward_via_pullback(
@@ -367,7 +367,7 @@ function _value_and_pushforward_via_pullback(
367367
dot(a, dx)
368368
end
369369
end
370-
return y, arroftup_to_tupofarr(ty)
370+
return y, arroftup_to_tupofarr(ty, y)
371371
end
372372

373373
function _value_and_pushforward_via_pullback(
@@ -387,7 +387,7 @@ function _value_and_pushforward_via_pullback(
387387
real(dot(a, dx)) + im * real(dot(b, dx))
388388
end
389389
end
390-
return y, arroftup_to_tupofarr(ty)
390+
return y, arroftup_to_tupofarr(ty, y)
391391
end
392392

393393
function value_and_pushforward(
@@ -460,7 +460,7 @@ function _value_and_pushforward_via_pullback(
460460
dot(a, dx)
461461
end
462462
end
463-
return y, arroftup_to_tupofarr(ty)
463+
return y, arroftup_to_tupofarr(ty, y)
464464
end
465465

466466
function _value_and_pushforward_via_pullback(
@@ -481,7 +481,7 @@ function _value_and_pushforward_via_pullback(
481481
real(dot(a, dx)) + im * real(dot(b, dx))
482482
end
483483
end
484-
return y, arroftup_to_tupofarr(ty)
484+
return y, arroftup_to_tupofarr(ty, y)
485485
end
486486

487487
function value_and_pushforward(

DifferentiationInterface/src/utils/linalg.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,13 @@ get_pattern(M::AbstractMatrix) = trues(size(M))
3636

3737
onlysecond((a, b)) = (a, only(b))
3838

39-
arroftup_to_tupofarr(x::NTuple) = x
40-
arroftup_to_tupofarr(x::AbstractArray{<:NTuple{B}}) where {B} = ntuple(b -> getindex.(x, b), Val(B))
39+
"""
40+
arroftup_to_tupofarr(tx, x)
41+
42+
Convert an array of tuples `tx` into a tuple of arrays, while respecting the array type of the primal `x`.
43+
"""
44+
arroftup_to_tupofarr(tx::NTuple{B, <:Number}, x::Number) where {B} = tx
45+
46+
function arroftup_to_tupofarr(tx::AbstractArray{<:NTuple{B, <:Number}}, x::AbstractArray{<:Number}) where {B}
47+
return ntuple(b -> similar(x) .= getindex.(tx, b), Val(B))
48+
end

DifferentiationInterface/test/Core/Internals/linalg.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
using DifferentiationInterface: recursive_similar, get_pattern
1+
using DifferentiationInterface: recursive_similar, get_pattern, arroftup_to_tupofarr
22
using SparseArrays
33
using Test
4+
using JLArrays, ComponentArrays
45

56
@testset "Recursive similar" begin
67
@test recursive_similar(ones(Int, 2), Float32) isa Vector{Float32}
@@ -16,3 +17,19 @@ end
1617
@test_broken get_pattern(D) == Diagonal(trues(10))
1718
@test get_pattern(sparse(D)) == Diagonal(trues(10))
1819
end
20+
21+
@testset "Wrong-mode array conversion" begin
22+
x = [1.0, 3.0, 5.0]
23+
xt = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]
24+
y = ComponentVector(a = [1.0, 3.0], b = [5.0])
25+
yt = ComponentVector(a = [(1.0, 2.0), (3.0, 4.0)], b = [(5.0, 6.0)])
26+
z = jl([1.0, 3.0, 5.0])
27+
zt = jl([(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)])
28+
@test arroftup_to_tupofarr((1.0, 2.0), 1.0) == (1.0, 2.0)
29+
@test arroftup_to_tupofarr(xt, x) == ([1.0, 3.0, 5.0], [2.0, 4.0, 6.0])
30+
@test arroftup_to_tupofarr(yt, y) == (ComponentVector(a = [1.0, 3.0], b = [5.0]), ComponentVector(a = [2.0, 4.0], b = [6.0]))
31+
@test arroftup_to_tupofarr(zt, z) == (jl([1.0, 3.0, 5.0]), jl([2.0, 4.0, 6.0]))
32+
@test arroftup_to_tupofarr(xt, x)[1] isa Vector
33+
@test arroftup_to_tupofarr(yt, y)[1] isa ComponentVector
34+
@test arroftup_to_tupofarr(zt, z)[1] isa JLVector
35+
end

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,8 @@ end
169169
logging = LOGGING,
170170
)
171171
end;
172+
173+
@testset "Array format preservation in wrong mode" begin
174+
@test gradient(sum, AutoSimpleFiniteDiff(), jl(ones(2))) isa JLVector
175+
@test derivative(t -> jl(fill(t, 2)), AutoReverseFromPrimitive(AutoSimpleFiniteDiff()), 1.0) isa JLVector
176+
end

0 commit comments

Comments
 (0)