Skip to content

Commit b8b023b

Browse files
committed
fix: make wrong-mode pushforward/pullback return the correct array type
1 parent 2b7c5e9 commit b8b023b

5 files changed

Lines changed: 34 additions & 18 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
3-
authors = ["Guillaume Dalle", "Adrian Hill"]
43
version = "0.7.16"
4+
authors = ["Guillaume Dalle", "Adrian Hill"]
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99

1010
[weakdeps]
11+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
1112
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1213
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1314
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
@@ -38,7 +39,7 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
3839
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
3940
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
4041
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
41-
DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
42+
DifferentiationInterfaceGPUArraysCoreExt = ["GPUArraysCore", "Adapt"]
4243
DifferentiationInterfaceGTPSAExt = "GTPSA"
4344
DifferentiationInterfaceMooncakeExt = "Mooncake"
4445
DifferentiationInterfacePolyesterForwardDiffExt = [

DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module DifferentiationInterfaceGPUArraysCoreExt
22

3+
using Adapt: adapt
34
import DifferentiationInterface as DI
45
using GPUArraysCore: @allowscalar, AbstractGPUArray
56

@@ -17,4 +18,10 @@ function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T}
1718
return b
1819
end
1920

21+
function DI.arroftup_to_tupofarr(
22+
tx::AbstractArray{NTuple{B, T}}, x::AbstractGPUArray{T}
23+
) where {B, T}
24+
return ntuple(b -> adapt(typeof(x), getindex.(tx, b)), Val(B))
25+
end
26+
2027
end

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ function _value_and_pullback_via_pushforward(
332332
tx = map(ty) do dy
333333
dot(a, dy)
334334
end
335-
return y, arroftup_to_tupofarr(tx)
335+
return y, arroftup_to_tupofarr(tx, x)
336336
end
337337

338338
function _value_and_pullback_via_pushforward(
@@ -348,7 +348,7 @@ function _value_and_pullback_via_pushforward(
348348
tx = map(ty) do dy
349349
real(dot(a, dy)) + im * real(dot(b, dy))
350350
end
351-
return y, arroftup_to_tupofarr(tx)
351+
return y, arroftup_to_tupofarr(tx, x)
352352
end
353353

354354
function _value_and_pullback_via_pushforward(
@@ -366,7 +366,7 @@ function _value_and_pullback_via_pushforward(
366366
dot(a, dy)
367367
end
368368
end
369-
return y, arroftup_to_tupofarr(tx)
369+
return y, arroftup_to_tupofarr(tx, x)
370370
end
371371

372372
function _value_and_pullback_via_pushforward(
@@ -387,7 +387,7 @@ function _value_and_pullback_via_pushforward(
387387
real(dot(a, dy)) + im * real(dot(b, dy))
388388
end
389389
end
390-
return y, arroftup_to_tupofarr(tx)
390+
return y, arroftup_to_tupofarr(tx, x)
391391
end
392392

393393
function value_and_pullback(
@@ -458,7 +458,7 @@ function _value_and_pullback_via_pushforward(
458458
tx = map(ty) do dy
459459
dot(a, dy)
460460
end
461-
return y, arroftup_to_tupofarr(tx)
461+
return y, arroftup_to_tupofarr(tx, x)
462462
end
463463

464464
function _value_and_pullback_via_pushforward(
@@ -477,7 +477,7 @@ function _value_and_pullback_via_pushforward(
477477
tx = map(ty) do dy
478478
real(dot(a, dy)) + im * real(dot(b, dy))
479479
end
480-
return y, arroftup_to_tupofarr(tx)
480+
return y, arroftup_to_tupofarr(tx, x)
481481
end
482482

483483
function _value_and_pullback_via_pushforward(
@@ -495,7 +495,7 @@ function _value_and_pullback_via_pushforward(
495495
dot(a, dy)
496496
end
497497
end
498-
return y, arroftup_to_tupofarr(tx)
498+
return y, arroftup_to_tupofarr(tx, x)
499499
end
500500

501501
function _value_and_pullback_via_pushforward(
@@ -518,7 +518,7 @@ function _value_and_pullback_via_pushforward(
518518
real(dot(a, dy)) + im * real(dot(b, dy))
519519
end
520520
end
521-
return y, arroftup_to_tupofarr(tx)
521+
return y, arroftup_to_tupofarr(tx, x)
522522
end
523523

524524
function value_and_pullback(

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ function _value_and_pushforward_via_pullback(
331331
ty = map(tx) do dx
332332
dot(a, dx)
333333
end
334-
return y, arroftup_to_tupofarr(ty)
334+
return y, arroftup_to_tupofarr(ty, y)
335335
end
336336

337337
function _value_and_pushforward_via_pullback(
@@ -348,7 +348,7 @@ function _value_and_pushforward_via_pullback(
348348
ty = map(tx) do dx
349349
real(dot(a, dx)) + im * real(dot(b, dx))
350350
end
351-
return y, arroftup_to_tupofarr(ty)
351+
return y, arroftup_to_tupofarr(ty, y)
352352
end
353353

354354
function _value_and_pushforward_via_pullback(
@@ -367,7 +367,7 @@ function _value_and_pushforward_via_pullback(
367367
dot(a, dx)
368368
end
369369
end
370-
return y, arroftup_to_tupofarr(ty)
370+
return y, arroftup_to_tupofarr(ty, y)
371371
end
372372

373373
function _value_and_pushforward_via_pullback(
@@ -387,7 +387,7 @@ function _value_and_pushforward_via_pullback(
387387
real(dot(a, dx)) + im * real(dot(b, dx))
388388
end
389389
end
390-
return y, arroftup_to_tupofarr(ty)
390+
return y, arroftup_to_tupofarr(ty, y)
391391
end
392392

393393
function value_and_pushforward(
@@ -460,7 +460,7 @@ function _value_and_pushforward_via_pullback(
460460
dot(a, dx)
461461
end
462462
end
463-
return y, arroftup_to_tupofarr(ty)
463+
return y, arroftup_to_tupofarr(ty, y)
464464
end
465465

466466
function _value_and_pushforward_via_pullback(
@@ -481,7 +481,7 @@ function _value_and_pushforward_via_pullback(
481481
real(dot(a, dx)) + im * real(dot(b, dx))
482482
end
483483
end
484-
return y, arroftup_to_tupofarr(ty)
484+
return y, arroftup_to_tupofarr(ty, y)
485485
end
486486

487487
function value_and_pushforward(

DifferentiationInterface/src/utils/linalg.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,13 @@ get_pattern(M::AbstractMatrix) = trues(size(M))
3636

3737
onlysecond((a, b)) = (a, only(b))
3838

39-
arroftup_to_tupofarr(x::NTuple) = x
40-
arroftup_to_tupofarr(x::AbstractArray{<:NTuple{B}}) where {B} = ntuple(b -> getindex.(x, b), Val(B))
39+
"""
40+
arroftup_to_tupofarr(tx, x)
41+
42+
Convert an array of tuples `tx` into a tuple of arrays, while respecting the array type of the primal `x`.
43+
"""
44+
arroftup_to_tupofarr(tx::NTuple{B, T}, x::T) where {B, T} = tx
45+
46+
function arroftup_to_tupofarr(tx::AbstractArray{NTuple{B, T}}, x::AbstractArray{T}) where {B, T}
47+
return ntuple(b -> similar(x) .= getindex.(tx, b), Val(B))
48+
end

0 commit comments

Comments
 (0)