Skip to content

Commit 7607ec2

Browse files
authored
Custom stacking for StaticArrays (#564)
* Improve type stability tests and benchmarking * Remove `first_order` and `second_order` * Docs * Zero allocs * Fixes * Call count * Fix * Fix * Add count calls * Default count calls * Fix * Custom stacking for StaticArrays * Bump * Clearer modulo * Woops * Undo mo1
1 parent 3698dbe commit 7607ec2

6 files changed

Lines changed: 20 additions & 4 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.9"
4+
version = "0.6.10"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -20,6 +20,7 @@ PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2020
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2121
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2222
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
23+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2324
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
2425
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2526
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
@@ -37,6 +38,7 @@ DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
3738
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
3839
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
3940
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
41+
DifferentiationInterfaceStaticArraysExt = "StaticArrays"
4042
DifferentiationInterfaceSymbolicsExt = "Symbolics"
4143
DifferentiationInterfaceTrackerExt = "Tracker"
4244
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]
@@ -56,6 +58,7 @@ PolyesterForwardDiff = "0.1.2"
5658
ReverseDiff = "1.15.1"
5759
SparseArrays = "<0.0.1,1"
5860
SparseConnectivityTracer = "0.5.0,0.6"
61+
StaticArrays = "1.9.7"
5962
SparseMatrixColorings = "0.4.5"
6063
Symbolics = "5.27.1, 6"
6164
Tracker = "0.2.33"
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module DifferentiationInterfaceStaticArraysExt
2+
3+
import DifferentiationInterface as DI
4+
using StaticArrays: SArray
5+
6+
function DI.stack_vec_col(t::NTuple{B,<:SArray}) where {B}
7+
return hcat(map(vec, t)...)
8+
end
9+
10+
end

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ include("utils/check.jl")
4343
include("utils/exceptions.jl")
4444
include("utils/printing.jl")
4545
include("utils/context.jl")
46+
include("utils/linalg.jl")
4647

4748
include("first_order/pushforward.jl")
4849
include("first_order/pullback.jl")

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ function _jacobian_aux(
241241
batched_seeds[a],
242242
contexts...,
243243
)
244-
block = stack(vec, dy_batch; dims=2)
244+
block = stack_vec_col(dy_batch)
245245
if N % B != 0 && a == lastindex(batched_seeds)
246246
block = block[:, 1:(N - (a - 1) * B)]
247247
end
@@ -269,7 +269,7 @@ function _jacobian_aux(
269269
dx_batch = pullback(
270270
f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts...
271271
)
272-
block = stack(vec, dx_batch; dims=1)
272+
block = stack_vec_row(dx_batch)
273273
if M % B != 0 && a == lastindex(batched_seeds)
274274
block = block[1:(M - (a - 1) * B), :]
275275
end

DifferentiationInterface/src/second_order/hessian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ function hessian(
113113

114114
hess_blocks = map(eachindex(batched_seeds)) do a
115115
dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...)
116-
block = stack(vec, dg_batch; dims=2)
116+
block = stack_vec_col(dg_batch)
117117
if N % B != 0 && a == lastindex(batched_seeds)
118118
block = block[:, 1:(N - (a - 1) * B)]
119119
end
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
stack_vec_col(t::NTuple) = stack(vec, t; dims=2)
2+
stack_vec_row(t::NTuple) = stack(vec, t; dims=1)

0 commit comments

Comments
 (0)