Skip to content

Commit d537c45

Browse files
authored
perf: check mutability of array before preallocating dual buffer (#619)
* perf: check mutability of array before preallocating ForwardDiff dual buffer * Add allocation testing
1 parent 9a524d3 commit d537c45

6 files changed

Lines changed: 36 additions & 5 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.22"
4+
version = "0.6.23"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ using DifferentiationInterface:
2525
outer,
2626
shuffled_gradient,
2727
unwrap,
28-
with_contexts
28+
with_contexts,
29+
ismutable_array
2930
import ForwardDiff.DiffResults as DR
3031
using ForwardDiff.DiffResults:
3132
DiffResults, DiffResult, GradientResult, HessianResult, MutableDiffResult

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ function DI.prepare_pushforward(
6868
f::F, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{Context,C}
6969
) where {F,C}
7070
T = tag_type(f, backend, x)
71-
xdual_tmp = make_dual_similar(T, x, tx)
71+
if ismutable_array(x)
72+
xdual_tmp = make_dual_similar(T, x, tx)
73+
else
74+
xdual_tmp = nothing
75+
end
7276
return ForwardDiffOneArgPushforwardPrep{T,typeof(xdual_tmp)}(xdual_tmp)
7377
end
7478

@@ -92,8 +96,12 @@ function compute_ydual_onearg(
9296
tx::NTuple{B},
9397
contexts::Vararg{Context,C},
9498
) where {F,T,B,C}
95-
(; xdual_tmp) = prep
96-
make_dual!(T, xdual_tmp, x, tx)
99+
if ismutable_array(x)
100+
make_dual!(T, prep.xdual_tmp, x, tx)
101+
xdual_tmp = prep.xdual_tmp
102+
else
103+
xdual_tmp = make_dual(T, x, tx)
104+
end
97105
contexts_dual = translate(T, Val(B), contexts...)
98106
ydual = f(xdual_tmp, contexts_dual...)
99107
return ydual

DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ function DI.stack_vec_row(t::NTuple{B,<:StaticArray}) where {B}
1313
return vcat(transpose.(map(vec, t))...)
1414
end
1515

16+
DI.ismutable_array(::Type{<:SArray}) = false
17+
1618
function DI.BatchSizeSettings(::AutoForwardDiff{nothing}, x::StaticArray)
1719
return BatchSizeSettings{length(x),true,true}(length(x))
1820
end
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
stack_vec_col(t::NTuple) = stack(vec, t; dims=2)
22
stack_vec_row(t::NTuple) = stack(vec, t; dims=1)
3+
4+
ismutable_array(::Type) = true
5+
ismutable_array(x) = ismutable_array(typeof(x))

DifferentiationInterface/test/Back/ForwardDiff/test.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Pkg.add("ForwardDiff")
33

44
using ComponentArrays: ComponentArrays
55
using DifferentiationInterface, DifferentiationInterfaceTest
6+
import DifferentiationInterfaceTest as DIT
67
using ForwardDiff: ForwardDiff
78
using StaticArrays: StaticArrays
89
using Test
@@ -65,3 +66,19 @@ test_differentiation(
6566
## Static
6667

6768
test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING)
69+
70+
@testset verbose = true "No allocations on StaticArrays" begin
71+
filtered_static_scenarios = filter(static_scenarios(; include_batchified=false)) do scen
72+
DIT.function_place(scen) == :out && DIT.operator_place(scen) == :out
73+
end
74+
data = benchmark_differentiation(
75+
AutoForwardDiff(),
76+
filtered_static_scenarios;
77+
benchmark=:prepared,
78+
excluded=[:hessian, :pullback], # TODO: figure this out
79+
logging=LOGGING,
80+
)
81+
@testset "$(row[:scenario])" for row in eachrow(data)
82+
@test row[:allocs] == 0
83+
end
84+
end;

0 commit comments

Comments
 (0)