Skip to content

Commit 4dd20e1

Browse files
committed
feat: recursive similar for caches
1 parent c83ba21 commit 4dd20e1

3 files changed

Lines changed: 23 additions & 2 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ function _translate(
8989
end
9090
function _translate(::Type{D}, c::DI.Cache) where {D<:Dual}
9191
c0 = DI.unwrap(c)
92-
return similar(c0, D)
92+
return DI.recursive_similar(c0, D)
9393
end
9494

9595
function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C}
@@ -106,7 +106,7 @@ function _translate_toprep(
106106
end
107107
function _translate_toprep(::Type{D}, c::DI.Cache) where {D<:Dual}
108108
c0 = DI.unwrap(c)
109-
return similar(c0, D)
109+
return DI.recursive_similar(c0, D)
110110
end
111111

112112
function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C}

DifferentiationInterface/src/utils/linalg.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,15 @@ At the moment, this only returns `false` for `StaticArrays.SArray`.
1010
"""
1111
ismutable_array(::Type) = true
1212
ismutable_array(x) = ismutable_array(typeof(x))
13+
14+
"""
15+
recursive_similar(x, T)
16+
17+
Apply `similar(_, T)` recursively to `x` or its components.
18+
19+
Works if `x` is an `AbstractArray` or a (nested) `NTuple` / `NamedTuple` of `AbstractArray`s.
20+
"""
21+
recursive_similar(x::AbstractArray, ::Type{T}) where {T} = similar(x, T)
22+
function recursive_similar(x::Union{Tuple,NamedTuple}, ::Type{T}) where {T}
23+
return map(xi -> recursive_similar(xi, T), x)
24+
end
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using DifferentiationInterface: recursive_similar
2+
using Test
3+
4+
@test recursive_similar(ones(Int, 2), Float32) isa Vector{Float32}
5+
@test recursive_similar((ones(Int, 2), ones(Bool, 3, 4)), Float32) isa
6+
Tuple{Vector{Float32},Matrix{Float32}}
7+
@test recursive_similar((a=ones(Int, 2), b=(ones(Bool, 3, 4),)), Float32) isa
8+
@NamedTuple{a::Vector{Float32}, b::Tuple{Matrix{Float32}}}
9+
@test_throws MethodError recursive_similar(1, Float32)

0 commit comments

Comments
 (0)