Skip to content

Commit 2eb6299

Browse files
committed
fix: replace one with oneunit for basis computation
1 parent ea73473 commit 2eb6299

6 files changed

Lines changed: 29 additions & 18 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ Aqua = "0.8.12"
6161
ChainRulesCore = "1.23.0"
6262
ComponentArrays = "0.15.27"
6363
DataFrames = "1.7.0"
64+
Dates = "1"
6465
DiffResults = "1.1.0"
6566
Diffractor = "=0.2.6"
6667
Enzyme = "0.13.39"
@@ -98,6 +99,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
9899
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99100
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
100101
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
102+
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
101103
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
102104
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
103105
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
@@ -130,6 +132,7 @@ test = [
130132
"Aqua",
131133
"ComponentArrays",
132134
"DataFrames",
135+
"Dates",
133136
"ExplicitImports",
134137
"JET",
135138
"JLArrays",

DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ using GPUArraysCore: @allowscalar, AbstractGPUArray
66
function DI.basis(a::AbstractGPUArray{T}, i) where {T}
77
b = similar(a)
88
fill!(b, zero(T))
9-
@allowscalar b[i] = one(T)
9+
@allowscalar b[i] = oneunit(T)
1010
return b
1111
end
1212

1313
function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T}
1414
b = similar(a)
1515
fill!(b, zero(T))
16-
view(b, inds) .= one(T)
16+
view(b, inds) .= oneunit(T)
1717
return b
1818
end
1919

DifferentiationInterface/src/first_order/derivative.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ function prepare_derivative_nokwarg(
143143
) where {F,C}
144144
_sig = signature(f, backend, x, contexts...; strict)
145145
pushforward_prep = prepare_pushforward_nokwarg(
146-
strict, f, backend, x, (one(x),), contexts...
146+
strict, f, backend, x, (oneunit(x),), contexts...
147147
)
148148
return PushforwardDerivativePrep(_sig, pushforward_prep)
149149
end
@@ -153,7 +153,7 @@ function prepare_derivative_nokwarg(
153153
) where {F,C}
154154
_sig = signature(f!, y, backend, x, contexts...; strict)
155155
pushforward_prep = prepare_pushforward_nokwarg(
156-
strict, f!, y, backend, x, (one(x),), contexts...
156+
strict, f!, y, backend, x, (oneunit(x),), contexts...
157157
)
158158
return PushforwardDerivativePrep(_sig, pushforward_prep)
159159
end
@@ -169,7 +169,7 @@ function value_and_derivative(
169169
) where {F,C}
170170
check_prep(f, prep, backend, x, contexts...)
171171
y, ty = value_and_pushforward(
172-
f, prep.pushforward_prep, backend, x, (one(x),), contexts...
172+
f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
173173
)
174174
return y, only(ty)
175175
end
@@ -184,7 +184,7 @@ function value_and_derivative!(
184184
) where {F,C}
185185
check_prep(f, prep, backend, x, contexts...)
186186
y, _ = value_and_pushforward!(
187-
f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...
187+
f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
188188
)
189189
return y, der
190190
end
@@ -197,7 +197,7 @@ function derivative(
197197
contexts::Vararg{Context,C},
198198
) where {F,C}
199199
check_prep(f, prep, backend, x, contexts...)
200-
ty = pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...)
200+
ty = pushforward(f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...)
201201
return only(ty)
202202
end
203203

@@ -210,7 +210,7 @@ function derivative!(
210210
contexts::Vararg{Context,C},
211211
) where {F,C}
212212
check_prep(f, prep, backend, x, contexts...)
213-
pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...)
213+
pushforward!(f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...)
214214
return der
215215
end
216216

@@ -226,7 +226,7 @@ function value_and_derivative(
226226
) where {F,C}
227227
check_prep(f!, y, prep, backend, x, contexts...)
228228
y, ty = value_and_pushforward(
229-
f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts...
229+
f!, y, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
230230
)
231231
return y, only(ty)
232232
end
@@ -242,7 +242,7 @@ function value_and_derivative!(
242242
) where {F,C}
243243
check_prep(f!, y, prep, backend, x, contexts...)
244244
y, _ = value_and_pushforward!(
245-
f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...
245+
f!, y, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
246246
)
247247
return y, der
248248
end
@@ -256,7 +256,7 @@ function derivative(
256256
contexts::Vararg{Context,C},
257257
) where {F,C}
258258
check_prep(f!, y, prep, backend, x, contexts...)
259-
ty = pushforward(f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts...)
259+
ty = pushforward(f!, y, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...)
260260
return only(ty)
261261
end
262262

@@ -270,7 +270,9 @@ function derivative!(
270270
contexts::Vararg{Context,C},
271271
) where {F,C}
272272
check_prep(f!, y, prep, backend, x, contexts...)
273-
pushforward!(f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...)
273+
pushforward!(
274+
f!, y, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
275+
)
274276
return der
275277
end
276278

DifferentiationInterface/src/first_order/gradient.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ function value_and_gradient(
106106
contexts::Vararg{Context,C},
107107
) where {F,SIG,Y,C}
108108
check_prep(f, prep, backend, x, contexts...)
109-
y, tx = value_and_pullback(f, prep.pullback_prep, backend, x, (one(Y),), contexts...)
109+
y, tx = value_and_pullback(
110+
f, prep.pullback_prep, backend, x, (oneunit(Y),), contexts...
111+
)
110112
return y, only(tx)
111113
end
112114

@@ -120,7 +122,7 @@ function value_and_gradient!(
120122
) where {F,SIG,Y,C}
121123
check_prep(f, prep, backend, x, contexts...)
122124
y, _ = value_and_pullback!(
123-
f, (grad,), prep.pullback_prep, backend, x, (one(Y),), contexts...
125+
f, (grad,), prep.pullback_prep, backend, x, (oneunit(Y),), contexts...
124126
)
125127
return y, grad
126128
end
@@ -133,7 +135,7 @@ function gradient(
133135
contexts::Vararg{Context,C},
134136
) where {F,SIG,Y,C}
135137
check_prep(f, prep, backend, x, contexts...)
136-
tx = pullback(f, prep.pullback_prep, backend, x, (one(Y),), contexts...)
138+
tx = pullback(f, prep.pullback_prep, backend, x, (oneunit(Y),), contexts...)
137139
return only(tx)
138140
end
139141

@@ -146,7 +148,7 @@ function gradient!(
146148
contexts::Vararg{Context,C},
147149
) where {F,SIG,Y,C}
148150
check_prep(f, prep, backend, x, contexts...)
149-
pullback!(f, (grad,), prep.pullback_prep, backend, x, (one(Y),), contexts...)
151+
pullback!(f, (grad,), prep.pullback_prep, backend, x, (oneunit(Y),), contexts...)
150152
return grad
151153
end
152154

DifferentiationInterface/src/utils/basis.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Construct the `i`-th standard basis array in the vector space of `a`.
66
function basis(a::AbstractArray{T}, i) where {T}
77
b = similar(a)
88
fill!(b, zero(T))
9-
b[i] = one(T)
9+
b[i] = oneunit(T)
1010
if ismutable_array(a)
1111
return b
1212
else
@@ -23,7 +23,7 @@ function multibasis(a::AbstractArray{T}, inds) where {T}
2323
b = similar(a)
2424
fill!(b, zero(T))
2525
for i in inds
26-
b[i] = one(T)
26+
b[i] = oneunit(T)
2727
end
2828
return ismutable_array(a) ? b : map(+, zero(a), b)
2929
end

DifferentiationInterface/test/Core/Internals/basis.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using DifferentiationInterface: basis, multibasis
22
using LinearAlgebra
33
using StaticArrays, JLArrays
44
using Test
5+
using Dates
56

67
@testset "Basis" begin
78
b_ref = [0, 1, 0]
@@ -22,4 +23,7 @@ using Test
2223
@test all(basis(jl(rand(3, 3)), 4) .== b_ref)
2324
@test basis(@SMatrix(rand(3, 3)), 4) isa SMatrix
2425
@test basis(@SMatrix(rand(3, 3)), 4) == b_ref
26+
27+
t = [Time(1) - Time(0)]
28+
@test basis(t, 1) isa Vector{Nanosecond}
2529
end

0 commit comments

Comments
 (0)