Skip to content

Commit fd7580c

Browse files
authored
Pick correct batch size for hessian (#574)
1 parent 4e70d75 commit fd7580c

5 files changed

Lines changed: 8 additions & 4 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.11"
4+
version = "0.6.12"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using DifferentiationInterface:
2020
PushforwardSlow,
2121
inner,
2222
multibasis,
23-
pick_batchsize,
23+
pick_hessian_batchsize,
2424
pick_jacobian_batchsize,
2525
pushforward_performance,
2626
unwrap,

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ SMC.column_groups(prep::SparseHessianPrep) = column_groups(prep.coloring_result)
4242
function DI.prepare_hessian(
4343
f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
4444
) where {F,C}
45-
valB = pick_batchsize(dense_ad(backend), length(x))
45+
valB = pick_hessian_batchsize(dense_ad(backend), length(x))
4646
return _prepare_sparse_hessian_aux(valB, f, backend, x, contexts...)
4747
end
4848

DifferentiationInterface/src/second_order/hessian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ end
7272
function prepare_hessian(
7373
f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
7474
) where {F,C}
75-
valB = pick_batchsize(backend, length(x))
75+
valB = pick_hessian_batchsize(backend, length(x))
7676
return _prepare_hessian_aux(valB, f, backend, x, contexts...)
7777
end
7878

DifferentiationInterface/src/utils/batchsize.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,8 @@ function pick_jacobian_batchsize(
2323
return pick_batchsize(backend, M)
2424
end
2525

26+
function pick_hessian_batchsize(backend::AbstractADType, N::Integer)
27+
return pick_batchsize(outer(backend), N)
28+
end
29+
2630
threshold_batchsize(backend::AbstractADType, ::Integer) = backend

0 commit comments

Comments
 (0)