Skip to content

Commit 6280762

Browse files
committed
Merge remote-tracking branch 'origin/main' into gd/forward_mooncake
2 parents 2f9b365 + cfb1a94 commit 6280762

25 files changed

Lines changed: 213 additions & 88 deletions

File tree

.buildkite/pipeline.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
steps:
22
- label: "DI GPU tests"
3-
if: build.pull_request.labels includes "gpu"
43
plugins:
54
- JuliaCI/julia#v1:
65
version: "1"

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
name: ${{ matrix.version }} - DI (${{ matrix.group }})
2121
runs-on: ubuntu-latest
2222
if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }}
23-
timeout-minutes: 60
23+
timeout-minutes: 120
2424
permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created
2525
actions: write
2626
contents: read

DifferentiationInterface/CHANGELOG.md

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,32 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [0.7.4]
11+
12+
### Added
13+
14+
- Make `AutoForwardFromPrimitive` and `AutoReverseFromPrimitive` public ([#825])
15+
16+
### Fixed
17+
18+
- Replace `one` with `oneunit` in basis computation ([#826])
19+
20+
## [0.7.3]
21+
22+
### Fixed
23+
24+
- Bump compat for SparseConnectivityTracer v1 ([#823])
25+
26+
## [0.7.2]
27+
28+
### Feat
29+
30+
- Backend switching for Mooncake ([#768])
31+
32+
### Fixed
33+
34+
- Speed up sparse preparation for GPU arrays ([#818])
35+
1036
## [0.7.1]
1137

1238
### Feat
@@ -38,17 +64,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3864

3965
- Allocate Enzyme shadow memory during preparation ([#782])
4066

41-
[unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.1...main
67+
[unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.4...main
68+
[0.7.4]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.3...DifferentiationInterface-v0.7.4
69+
[0.7.3]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.2...DifferentiationInterface-v0.7.3
70+
[0.7.2]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.1...DifferentiationInterface-v0.7.2
4271
[0.7.1]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.0...DifferentiationInterface-v0.7.1
4372
[0.7.0]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.54...DifferentiationInterface-v0.7.0
4473
[0.6.54]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...DifferentiationInterface-v0.6.54
4574
[0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53
4675

76+
[#826]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/826
77+
[#825]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/825
78+
[#823]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/823
79+
[#818]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/818
4780
[#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812
4881
[#810]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/810
4982
[#809]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/809
5083
[#799]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/799
5184
[#795]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/795
5285
[#790]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/790
5386
[#788]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/788
54-
[#782]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/782
87+
[#782]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/782
88+
[#768]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/768

DifferentiationInterface/Project.toml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.7.1"
4+
version = "0.7.4"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -42,7 +42,9 @@ DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
4242
DifferentiationInterfaceGTPSAExt = "GTPSA"
4343
DifferentiationInterfaceMooncakeExt = "Mooncake"
4444
DifferentiationInterfacePolyesterForwardDiffExt = [
45-
"PolyesterForwardDiff", "ForwardDiff", "DiffResults"
45+
"PolyesterForwardDiff",
46+
"ForwardDiff",
47+
"DiffResults",
4648
]
4749
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
4850
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
@@ -54,11 +56,12 @@ DifferentiationInterfaceTrackerExt = "Tracker"
5456
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]
5557

5658
[compat]
57-
Aqua = "0.8.12"
5859
ADTypes = "1.13.0"
60+
Aqua = "0.8.12"
5961
ChainRulesCore = "1.23.0"
6062
ComponentArrays = "0.15.27"
6163
DataFrames = "1.7.0"
64+
Dates = "1"
6265
DiffResults = "1.1.0"
6366
Diffractor = "=0.2.6"
6467
Enzyme = "0.13.39"
@@ -80,7 +83,7 @@ PolyesterForwardDiff = "0.1.2"
8083
Random = "1"
8184
ReverseDiff = "1.15.1"
8285
SparseArrays = "1"
83-
SparseConnectivityTracer = "0.6.14"
86+
SparseConnectivityTracer = "0.6.14, 1"
8487
SparseMatrixColorings = "0.4.9"
8588
StableRNGs = "1.0.1"
8689
StaticArrays = "1.9.7"
@@ -96,6 +99,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
9699
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
97100
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
98101
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
102+
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
99103
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
100104
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
101105
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
@@ -128,6 +132,7 @@ test = [
128132
"Aqua",
129133
"ComponentArrays",
130134
"DataFrames",
135+
"Dates",
131136
"ExplicitImports",
132137
"JET",
133138
"JLArrays",

DifferentiationInterface/docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,13 @@ MixedMode
132132
DenseSparsityDetector
133133
```
134134

135+
### From primitive
136+
137+
```@docs
138+
DifferentiationInterface.AutoForwardFromPrimitive
139+
DifferentiationInterface.AutoReverseFromPrimitive
140+
```
141+
135142
## Internals
136143

137144
The following is not part of the public API.

DifferentiationInterface/docs/src/explanation/advanced.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,16 @@ AutoSparse(
8686
)
8787
```
8888

89-
At the moment, mixed mode tends to work best when the [`GreedyColoringAlgorithm`](@extref SparseMatrixColorings.GreedyColoringAlgorithm) is provided with a [`RandomOrder`](@extref SparseMatrixColorings.RandomOrder) instead of the usual [`NaturalOrder`](@extref SparseMatrixColorings.NaturalOrder).
89+
At the moment, mixed mode tends to work best (output fewer colors) when the [`GreedyColoringAlgorithm`](@extref SparseMatrixColorings.GreedyColoringAlgorithm) is provided with a [`RandomOrder`](@extref SparseMatrixColorings.RandomOrder) instead of the usual [`NaturalOrder`](@extref SparseMatrixColorings.NaturalOrder), and when "post-processing" is activated after coloring.
90+
For full reproducibility, you should use a random number generator from [StableRNGs.jl](https://github.com/JuliaRandom/StableRNGs.jl).
91+
Thus, the right setup looks like:
92+
93+
```julia
94+
using StableRNGs
95+
96+
seed = 3
97+
coloring_algorithm = GreedyColoringAlgorithm(RandomOrder(StableRNG(seed), seed); postprocessing=true)
98+
```
9099

91100
## Batch mode
92101

DifferentiationInterface/docs/src/explanation/operators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,4 +152,4 @@ For same-point preparation, the same rules hold with two modifications:
152152

153153
!!! warning
154154
These rules hold for the majority of backends, but there are some exceptions.
155-
The most important exception is [ReverseDiff](@ref) and its taping mechanism, which is sensitive to control flow inside the function.
155+
The most important exception is [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) and its taping mechanism, which is sensitive to control flow inside the function.

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -189,27 +189,27 @@ end
189189
function DI.value_and_derivative(
190190
f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}
191191
) where {F,C}
192-
y, ty = DI.value_and_pushforward(f, backend, x, (one(x),), contexts...)
192+
y, ty = DI.value_and_pushforward(f, backend, x, (oneunit(x),), contexts...)
193193
return y, only(ty)
194194
end
195195

196196
function DI.value_and_derivative!(
197197
f::F, der, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}
198198
) where {F,C}
199-
y, _ = DI.value_and_pushforward!(f, (der,), backend, x, (one(x),), contexts...)
199+
y, _ = DI.value_and_pushforward!(f, (der,), backend, x, (oneunit(x),), contexts...)
200200
return y, der
201201
end
202202

203203
function DI.derivative(
204204
f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}
205205
) where {F,C}
206-
return only(DI.pushforward(f, backend, x, (one(x),), contexts...))
206+
return only(DI.pushforward(f, backend, x, (oneunit(x),), contexts...))
207207
end
208208

209209
function DI.derivative!(
210210
f::F, der, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}
211211
) where {F,C}
212-
DI.pushforward!(f, (der,), backend, x, (one(x),), contexts...)
212+
DI.pushforward!(f, (der,), backend, x, (oneunit(x),), contexts...)
213213
return der
214214
end
215215

@@ -220,7 +220,7 @@ function DI.prepare_derivative_nokwarg(
220220
) where {F,C}
221221
_sig = DI.signature(f, backend, x, contexts...; strict)
222222
pushforward_prep = DI.prepare_pushforward_nokwarg(
223-
strict, f, backend, x, (one(x),), contexts...
223+
strict, f, backend, x, (oneunit(x),), contexts...
224224
)
225225
return ForwardDiffOneArgDerivativePrep(_sig, pushforward_prep)
226226
end
@@ -234,7 +234,7 @@ function DI.value_and_derivative(
234234
) where {F,C}
235235
DI.check_prep(f, prep, backend, x, contexts...)
236236
y, ty = DI.value_and_pushforward(
237-
f, prep.pushforward_prep, backend, x, (one(x),), contexts...
237+
f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
238238
)
239239
return y, only(ty)
240240
end
@@ -249,7 +249,7 @@ function DI.value_and_derivative!(
249249
) where {F,C}
250250
DI.check_prep(f, prep, backend, x, contexts...)
251251
y, _ = DI.value_and_pushforward!(
252-
f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...
252+
f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
253253
)
254254
return y, der
255255
end
@@ -263,7 +263,7 @@ function DI.derivative(
263263
) where {F,C}
264264
DI.check_prep(f, prep, backend, x, contexts...)
265265
return only(
266-
DI.pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...)
266+
DI.pushforward(f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...)
267267
)
268268
end
269269

@@ -276,7 +276,9 @@ function DI.derivative!(
276276
contexts::Vararg{DI.Context,C},
277277
) where {F,C}
278278
DI.check_prep(f, prep, backend, x, contexts...)
279-
DI.pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...)
279+
DI.pushforward!(
280+
f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
281+
)
280282
return der
281283
end
282284

@@ -638,9 +640,9 @@ function DI.second_derivative(
638640
) where {F,C}
639641
DI.check_prep(f, prep, backend, x, contexts...)
640642
T = tag_type(f, backend, x)
641-
xdual = make_dual(T, x, one(x))
643+
xdual = make_dual(T, x, oneunit(x))
642644
T2 = tag_type(f, backend, xdual)
643-
xdual2 = make_dual(T2, xdual, one(xdual))
645+
xdual2 = make_dual(T2, xdual, oneunit(xdual))
644646
contexts_dual = translate(typeof(xdual2), contexts)
645647
ydual = f(xdual2, contexts_dual...)
646648
return myderivative(T, myderivative(T2, ydual))
@@ -656,9 +658,9 @@ function DI.second_derivative!(
656658
) where {F,C}
657659
DI.check_prep(f, prep, backend, x, contexts...)
658660
T = tag_type(f, backend, x)
659-
xdual = make_dual(T, x, one(x))
661+
xdual = make_dual(T, x, oneunit(x))
660662
T2 = tag_type(f, backend, xdual)
661-
xdual2 = make_dual(T2, xdual, one(xdual))
663+
xdual2 = make_dual(T2, xdual, oneunit(xdual))
662664
contexts_dual = translate(typeof(xdual2), contexts)
663665
ydual = f(xdual2, contexts_dual...)
664666
return myderivative!(T, der2, myderivative(T2, ydual))
@@ -673,9 +675,9 @@ function DI.value_derivative_and_second_derivative(
673675
) where {F,C}
674676
DI.check_prep(f, prep, backend, x, contexts...)
675677
T = tag_type(f, backend, x)
676-
xdual = make_dual(T, x, one(x))
678+
xdual = make_dual(T, x, oneunit(x))
677679
T2 = tag_type(f, backend, xdual)
678-
xdual2 = make_dual(T2, xdual, one(xdual))
680+
xdual2 = make_dual(T2, xdual, oneunit(xdual))
679681
contexts_dual = translate(typeof(xdual2), contexts)
680682
ydual = f(xdual2, contexts_dual...)
681683
y = myvalue(T, myvalue(T2, ydual))
@@ -695,9 +697,9 @@ function DI.value_derivative_and_second_derivative!(
695697
) where {F,C}
696698
DI.check_prep(f, prep, backend, x, contexts...)
697699
T = tag_type(f, backend, x)
698-
xdual = make_dual(T, x, one(x))
700+
xdual = make_dual(T, x, oneunit(x))
699701
T2 = tag_type(f, backend, xdual)
700-
xdual2 = make_dual(T2, xdual, one(xdual))
702+
xdual2 = make_dual(T2, xdual, oneunit(xdual))
701703
contexts_dual = translate(typeof(xdual2), contexts)
702704
ydual = f(xdual2, contexts_dual...)
703705
y = myvalue(T, myvalue(T2, ydual))
@@ -756,7 +758,7 @@ function DI.value_gradient_and_hessian!(
756758
contexts isa NTuple{C,DI.GeneralizedConstant}
757759
)
758760
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
759-
result = DiffResult(one(eltype(x)), (grad, hess))
761+
result = DiffResult(oneunit(eltype(x)), (grad, hess))
760762
result = hessian!(result, fc, x)
761763
y = DR.value(result)
762764
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
@@ -855,7 +857,7 @@ function DI.value_gradient_and_hessian!(
855857
DI.check_prep(f, prep, backend, x, contexts...)
856858
contexts_dual = translate_prepared(contexts, prep.contexts_dual)
857859
fc = DI.fix_tail(f, contexts_dual...)
858-
result = DiffResult(one(eltype(x)), (grad, hess))
860+
result = DiffResult(oneunit(eltype(x)), (grad, hess))
859861
CHK = tag_type(backend) === Nothing
860862
if CHK
861863
checktag(prep.result_config, f, x)

DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,14 @@ using GPUArraysCore: @allowscalar, AbstractGPUArray
66
function DI.basis(a::AbstractGPUArray{T}, i) where {T}
77
b = similar(a)
88
fill!(b, zero(T))
9-
@allowscalar b[i] = one(T)
9+
@allowscalar b[i] = oneunit(T)
1010
return b
1111
end
1212

1313
function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T}
1414
b = similar(a)
1515
fill!(b, zero(T))
16-
for i in inds
17-
@allowscalar b[i] = one(T)
18-
end
16+
view(b, inds) .= oneunit(T)
1917
return b
2018
end
2119

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ export AutoSparse
128128
## Public but not exported
129129

130130
@public inner, outer
131+
@public AutoForwardFromPrimitive, AutoReverseFromPrimitive
131132

132133
include("init.jl")
133134

0 commit comments

Comments
 (0)