Skip to content

Commit 90d4e56

Browse files
committed
Merge remote-tracking branch 'origin/main' into gd/12
2 parents c101ac4 + ced97ee commit 90d4e56

32 files changed

Lines changed: 864 additions & 413 deletions

File tree

.github/workflows/Documentation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
dir: './DifferentiationInterfaceTest'
3636

3737
steps:
38-
- uses: actions/checkout@v4
38+
- uses: actions/checkout@v5
3939
- uses: julia-actions/setup-julia@v2
4040
with:
4141
version: '1' # TODO: 1

.github/workflows/PreCommit.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
pre-commit:
1919
runs-on: ubuntu-latest
2020
steps:
21-
- uses: actions/checkout@v4
21+
- uses: actions/checkout@v5
2222
- uses: julia-actions/setup-julia@v2
2323
- uses: julia-actions/cache@v2
2424
- run: julia -e 'using Pkg; Pkg.add("JuliaFormatter")'

.github/workflows/Test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ jobs:
6767
JULIA_DI_TEST_GROUP: ${{ matrix.group }}
6868
JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }}
6969
steps:
70-
- uses: actions/checkout@v4
70+
- uses: actions/checkout@v5
7171
- uses: julia-actions/setup-julia@v2
7272
with:
7373
version: ${{ matrix.version }}
@@ -126,7 +126,7 @@ jobs:
126126
JULIA_DIT_TEST_GROUP: ${{ matrix.group }}
127127
JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }}
128128
steps:
129-
- uses: actions/checkout@v4
129+
- uses: actions/checkout@v5
130130
- uses: julia-actions/setup-julia@v2
131131
with:
132132
version: ${{ matrix.version }}

DifferentiationInterface/CHANGELOG.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,22 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8-
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.5...main)
8+
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.7...main)
9+
10+
## [0.7.7](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.6...DifferentiationInterface-v0.7.7)
11+
12+
- Improve support for empty inputs (still not guaranteed) ([#835](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/835))
13+
14+
## [0.7.6](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.5...DifferentiationInterface-v0.7.6)
15+
16+
### Fixed
17+
18+
- Put test deps into `test/Project.toml` ([#840](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/840))
19+
- Set up `pre-commit` ([#837](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/837))
20+
21+
### Fixed
22+
23+
- Put test deps into `test/Project.toml` ([#840](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/840))
924

1025
## [0.7.5](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.4...DifferentiationInterface-v0.7.5)
1126

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.7.5"
4+
version = "0.7.7"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,14 @@ struct PushforwardJacobianPrep{
138138
BS<:BatchSizeSettings,
139139
S<:AbstractVector{<:NTuple},
140140
R<:AbstractVector{<:NTuple},
141+
SE<:NTuple,
141142
E<:PushforwardPrep,
142143
} <: StandardJacobianPrep{SIG}
143144
_sig::Val{SIG}
144145
batch_size_settings::BS
145146
batched_seeds::S
146147
batched_results::R
148+
seed_example::SE
147149
pushforward_prep::E
148150
end
149151

@@ -152,12 +154,14 @@ struct PullbackJacobianPrep{
152154
BS<:BatchSizeSettings,
153155
S<:AbstractVector{<:NTuple},
154156
R<:AbstractVector{<:NTuple},
157+
SE<:NTuple,
155158
E<:PullbackPrep,
156159
} <: StandardJacobianPrep{SIG}
157160
_sig::Val{SIG}
158161
batch_size_settings::BS
159162
batched_seeds::S
160163
batched_results::R
164+
seed_example::SE
161165
pullback_prep::E
162166
end
163167

@@ -211,11 +215,17 @@ function _prepare_jacobian_aux(
211215
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
212216
]
213217
batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds]
218+
seed_example = ntuple(b -> basis(x), Val(B))
214219
pushforward_prep = prepare_pushforward_nokwarg(
215-
strict, f_or_f!y..., backend, x, batched_seeds[1], contexts...
220+
strict, f_or_f!y..., backend, x, seed_example, contexts...
216221
)
217222
return PushforwardJacobianPrep(
218-
_sig, batch_size_settings, batched_seeds, batched_results, pushforward_prep
223+
_sig,
224+
batch_size_settings,
225+
batched_seeds,
226+
batched_results,
227+
seed_example,
228+
pushforward_prep,
219229
)
220230
end
221231

@@ -236,11 +246,17 @@ function _prepare_jacobian_aux(
236246
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
237247
]
238248
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
249+
seed_example = ntuple(b -> basis(y), Val(B))
239250
pullback_prep = prepare_pullback_nokwarg(
240-
strict, f_or_f!y..., backend, x, batched_seeds[1], contexts...
251+
strict, f_or_f!y..., backend, x, seed_example, contexts...
241252
)
242253
return PullbackJacobianPrep(
243-
_sig, batch_size_settings, batched_seeds, batched_results, pullback_prep
254+
_sig,
255+
batch_size_settings,
256+
batched_seeds,
257+
batched_results,
258+
seed_example,
259+
pullback_prep,
244260
)
245261
end
246262

@@ -363,11 +379,11 @@ function _jacobian_aux(
363379
x,
364380
contexts::Vararg{Context,C},
365381
) where {FY,SIG,B,aligned,C}
366-
(; batch_size_settings, batched_seeds, pushforward_prep) = prep
382+
(; batch_size_settings, batched_seeds, seed_example, pushforward_prep) = prep
367383
(; A, B_last) = batch_size_settings
368384

369385
pushforward_prep_same = prepare_pushforward_same_point(
370-
f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts...
386+
f_or_f!y..., pushforward_prep, backend, x, seed_example, contexts...
371387
)
372388

373389
jac = mapreduce(hcat, eachindex(batched_seeds)) do a
@@ -419,11 +435,11 @@ function _jacobian_aux(
419435
x,
420436
contexts::Vararg{Context,C},
421437
) where {FY,SIG,B,aligned,C}
422-
(; batch_size_settings, batched_seeds, pullback_prep) = prep
438+
(; batch_size_settings, batched_seeds, seed_example, pullback_prep) = prep
423439
(; A, B_last) = batch_size_settings
424440

425441
pullback_prep_same = prepare_pullback_same_point(
426-
f_or_f!y..., prep.pullback_prep, backend, x, batched_seeds[1], contexts...
442+
f_or_f!y..., pullback_prep, backend, x, seed_example, contexts...
427443
)
428444

429445
jac = mapreduce(vcat, eachindex(batched_seeds)) do a
@@ -451,11 +467,13 @@ function _jacobian_aux!(
451467
x,
452468
contexts::Vararg{Context,C},
453469
) where {FY,SIG,B,C}
454-
(; batch_size_settings, batched_seeds, batched_results, pushforward_prep) = prep
470+
(;
471+
batch_size_settings, batched_seeds, batched_results, seed_example, pushforward_prep
472+
) = prep
455473
(; N) = batch_size_settings
456474

457475
pushforward_prep_same = prepare_pushforward_same_point(
458-
f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts...
476+
f_or_f!y..., pushforward_prep, backend, x, seed_example, contexts...
459477
)
460478

461479
for a in eachindex(batched_seeds, batched_results)
@@ -487,11 +505,12 @@ function _jacobian_aux!(
487505
x,
488506
contexts::Vararg{Context,C},
489507
) where {FY,SIG,B,C}
490-
(; batch_size_settings, batched_seeds, batched_results, pullback_prep) = prep
508+
(; batch_size_settings, batched_seeds, batched_results, seed_example, pullback_prep) =
509+
prep
491510
(; N) = batch_size_settings
492511

493512
pullback_prep_same = prepare_pullback_same_point(
494-
f_or_f!y..., pullback_prep, backend, x, batched_seeds[1], contexts...
513+
f_or_f!y..., pullback_prep, backend, x, seed_example, contexts...
495514
)
496515

497516
for a in eachindex(batched_seeds, batched_results)

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,11 @@ function _prepare_pullback_aux(
285285
contexts::Vararg{Context,C};
286286
) where {F,C}
287287
_sig = signature(f, backend, x, ty, contexts...; strict)
288-
dx = x isa Number ? oneunit(x) : basis(x, first(CartesianIndices(x)))
288+
dx = if x isa Number
289+
oneunit(x)
290+
else
291+
basis(x)
292+
end
289293
pushforward_prep = prepare_pushforward_nokwarg(
290294
strict, f, backend, x, (dx,), contexts...
291295
)
@@ -303,7 +307,11 @@ function _prepare_pullback_aux(
303307
contexts::Vararg{Context,C};
304308
) where {F,C}
305309
_sig = signature(f!, y, backend, x, ty, contexts...; strict)
306-
dx = x isa Number ? oneunit(x) : basis(x, first(CartesianIndices(x)))
310+
dx = if x isa Number
311+
oneunit(x)
312+
else
313+
basis(x)
314+
end
307315
pushforward_prep = prepare_pushforward_nokwarg(
308316
strict, f!, y, backend, x, (dx,), contexts...
309317
)

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,11 @@ function _prepare_pushforward_aux(
290290
) where {F,C}
291291
_sig = signature(f, backend, x, tx, contexts...; strict)
292292
y = f(x, map(unwrap, contexts)...)
293-
dy = y isa Number ? oneunit(y) : basis(y, first(CartesianIndices(y)))
293+
dy = if y isa Number
294+
oneunit(y)
295+
else
296+
basis(y)
297+
end
294298
pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...)
295299
return PullbackPushforwardPrep(_sig, pullback_prep)
296300
end
@@ -306,7 +310,7 @@ function _prepare_pushforward_aux(
306310
contexts::Vararg{Context,C};
307311
) where {F,C}
308312
_sig = signature(f!, y, backend, x, tx, contexts...; strict)
309-
dy = y isa Number ? oneunit(y) : basis(y, first(CartesianIndices(y)))
313+
dy = basis(y)
310314
pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...)
311315
return PullbackPushforwardPrep(_sig, pullback_prep)
312316
end

DifferentiationInterface/src/second_order/hessian.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,15 @@ struct HVPGradientHessianPrep{
8484
BS<:BatchSizeSettings,
8585
S<:AbstractVector{<:NTuple},
8686
R<:AbstractVector{<:NTuple},
87+
SE<:NTuple,
8788
E2<:HVPPrep,
8889
E1<:GradientPrep,
8990
} <: HessianPrep{SIG}
9091
_sig::Val{SIG}
9192
batch_size_settings::BS
9293
batched_seeds::S
9394
batched_results::R
95+
seed_example::SE
9496
hvp_prep::E2
9597
gradient_prep::E1
9698
end
@@ -119,10 +121,17 @@ function _prepare_hessian_aux(
119121
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
120122
]
121123
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
122-
hvp_prep = prepare_hvp_nokwarg(strict, f, backend, x, batched_seeds[1], contexts...)
124+
seed_example = ntuple(b -> basis(x), Val(B))
125+
hvp_prep = prepare_hvp_nokwarg(strict, f, backend, x, seed_example, contexts...)
123126
gradient_prep = prepare_gradient_nokwarg(strict, f, inner(backend), x, contexts...)
124127
return HVPGradientHessianPrep(
125-
_sig, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep
128+
_sig,
129+
batch_size_settings,
130+
batched_seeds,
131+
batched_results,
132+
seed_example,
133+
hvp_prep,
134+
gradient_prep,
126135
)
127136
end
128137

@@ -150,11 +159,11 @@ function hessian(
150159
contexts::Vararg{Context,C},
151160
) where {F,SIG,B,aligned,C}
152161
check_prep(f, prep, backend, x, contexts...)
153-
(; batch_size_settings, batched_seeds, hvp_prep) = prep
162+
(; batch_size_settings, batched_seeds, seed_example, hvp_prep) = prep
154163
(; A, B_last) = batch_size_settings
155164

156165
hvp_prep_same = prepare_hvp_same_point(
157-
f, hvp_prep, backend, x, batched_seeds[1], contexts...
166+
f, hvp_prep, backend, x, seed_example, contexts...
158167
)
159168

160169
hess = mapreduce(hcat, eachindex(batched_seeds)) do a
@@ -178,11 +187,11 @@ function hessian!(
178187
contexts::Vararg{Context,C},
179188
) where {F,SIG,B,C}
180189
check_prep(f, prep, backend, x, contexts...)
181-
(; batch_size_settings, batched_seeds, batched_results, hvp_prep) = prep
190+
(; batch_size_settings, batched_seeds, batched_results, seed_example, hvp_prep) = prep
182191
(; N) = batch_size_settings
183192

184193
hvp_prep_same = prepare_hvp_same_point(
185-
f, hvp_prep, backend, x, batched_seeds[1], contexts...
194+
f, hvp_prep, backend, x, seed_example, contexts...
186195
)
187196

188197
for a in eachindex(batched_seeds, batched_results)
Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,39 @@
1-
"""
2-
basis(a::AbstractArray, i)
1+
pre_basis(a::AbstractArray{T}) where {T} = fill!(similar(a), zero(T))
32

4-
Construct the `i`-th standard basis array in the vector space of `a`.
5-
"""
6-
function basis(a::AbstractArray{T}, i) where {T}
7-
b = similar(a)
8-
fill!(b, zero(T))
9-
b[i] = oneunit(T)
3+
function post_basis(b::AbstractArray, a::AbstractArray)
104
if ismutable_array(a)
115
return b
126
else
137
return map(+, zero(a), b)
148
end
159
end
1610

11+
"""
12+
basis(a::AbstractArray, i)
13+
14+
Construct the `i`-th standard basis array in the vector space of `a`.
15+
"""
16+
function basis(a::AbstractArray, i)
17+
b = pre_basis(a)
18+
b[i] = oneunit(eltype(b))
19+
return post_basis(b, a)
20+
end
21+
22+
# compatible with zero-length vectors
23+
function basis(a::AbstractArray)
24+
b = pre_basis(a)
25+
return post_basis(b, a)
26+
end
27+
1728
"""
1829
multibasis(a::AbstractArray, inds)
1930
2031
Construct the sum of the `i`-th standard basis arrays in the vector space of `a` for all `i ∈ inds`.
2132
"""
22-
function multibasis(a::AbstractArray{T}, inds) where {T}
23-
b = similar(a)
24-
fill!(b, zero(T))
33+
function multibasis(a::AbstractArray, inds)
34+
b = pre_basis(a)
2535
for i in inds
26-
b[i] = oneunit(T)
36+
b[i] = oneunit(eltype(b))
2737
end
28-
return ismutable_array(a) ? b : map(+, zero(a), b)
38+
return post_basis(b, a)
2939
end

0 commit comments

Comments
 (0)