Skip to content

Commit e19fe6b

Browse files
committed
Replace one with oneunit everywhere
1 parent 2eb6299 commit e19fe6b

8 files changed

Lines changed: 46 additions & 37 deletions

File tree

DifferentiationInterface/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Fixed
11+
12+
- Replace `one` with `oneunit` in basis computation ([#825])
13+
1014
## [0.7.3]
1115

1216
### Fixed
@@ -62,6 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
6266
[0.6.54]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...DifferentiationInterface-v0.6.54
6367
[0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53
6468

69+
[#825]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/825
6570
[#823]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/823
6671
[#818]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/818
6772
[#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812

DifferentiationInterface/docs/src/explanation/operators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,4 +152,4 @@ For same-point preparation, the same rules hold with two modifications:
152152

153153
!!! warning
154154
These rules hold for the majority of backends, but there are some exceptions.
155-
The most important exception is [ReverseDiff](@ref) and its taping mechanism, which is sensitive to control flow inside the function.
155+
The most important exception is [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) and its taping mechanism, which is sensitive to control flow inside the function.

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -189,27 +189,27 @@ end
189189
function DI.value_and_derivative(
190190
f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}
191191
) where {F,C}
192-
y, ty = DI.value_and_pushforward(f, backend, x, (one(x),), contexts...)
192+
y, ty = DI.value_and_pushforward(f, backend, x, (oneunit(x),), contexts...)
193193
return y, only(ty)
194194
end
195195

196196
function DI.value_and_derivative!(
197197
f::F, der, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}
198198
) where {F,C}
199-
y, _ = DI.value_and_pushforward!(f, (der,), backend, x, (one(x),), contexts...)
199+
y, _ = DI.value_and_pushforward!(f, (der,), backend, x, (oneunit(x),), contexts...)
200200
return y, der
201201
end
202202

203203
function DI.derivative(
204204
f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}
205205
) where {F,C}
206-
return only(DI.pushforward(f, backend, x, (one(x),), contexts...))
206+
return only(DI.pushforward(f, backend, x, (oneunit(x),), contexts...))
207207
end
208208

209209
function DI.derivative!(
210210
f::F, der, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}
211211
) where {F,C}
212-
DI.pushforward!(f, (der,), backend, x, (one(x),), contexts...)
212+
DI.pushforward!(f, (der,), backend, x, (oneunit(x),), contexts...)
213213
return der
214214
end
215215

@@ -220,7 +220,7 @@ function DI.prepare_derivative_nokwarg(
220220
) where {F,C}
221221
_sig = DI.signature(f, backend, x, contexts...; strict)
222222
pushforward_prep = DI.prepare_pushforward_nokwarg(
223-
strict, f, backend, x, (one(x),), contexts...
223+
strict, f, backend, x, (oneunit(x),), contexts...
224224
)
225225
return ForwardDiffOneArgDerivativePrep(_sig, pushforward_prep)
226226
end
@@ -234,7 +234,7 @@ function DI.value_and_derivative(
234234
) where {F,C}
235235
DI.check_prep(f, prep, backend, x, contexts...)
236236
y, ty = DI.value_and_pushforward(
237-
f, prep.pushforward_prep, backend, x, (one(x),), contexts...
237+
f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
238238
)
239239
return y, only(ty)
240240
end
@@ -249,7 +249,7 @@ function DI.value_and_derivative!(
249249
) where {F,C}
250250
DI.check_prep(f, prep, backend, x, contexts...)
251251
y, _ = DI.value_and_pushforward!(
252-
f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...
252+
f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
253253
)
254254
return y, der
255255
end
@@ -263,7 +263,7 @@ function DI.derivative(
263263
) where {F,C}
264264
DI.check_prep(f, prep, backend, x, contexts...)
265265
return only(
266-
DI.pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...)
266+
DI.pushforward(f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...)
267267
)
268268
end
269269

@@ -276,7 +276,9 @@ function DI.derivative!(
276276
contexts::Vararg{DI.Context,C},
277277
) where {F,C}
278278
DI.check_prep(f, prep, backend, x, contexts...)
279-
DI.pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...)
279+
DI.pushforward!(
280+
f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
281+
)
280282
return der
281283
end
282284

@@ -638,9 +640,9 @@ function DI.second_derivative(
638640
) where {F,C}
639641
DI.check_prep(f, prep, backend, x, contexts...)
640642
T = tag_type(f, backend, x)
641-
xdual = make_dual(T, x, one(x))
643+
xdual = make_dual(T, x, oneunit(x))
642644
T2 = tag_type(f, backend, xdual)
643-
xdual2 = make_dual(T2, xdual, one(xdual))
645+
xdual2 = make_dual(T2, xdual, oneunit(xdual))
644646
contexts_dual = translate(typeof(xdual2), contexts)
645647
ydual = f(xdual2, contexts_dual...)
646648
return myderivative(T, myderivative(T2, ydual))
@@ -656,9 +658,9 @@ function DI.second_derivative!(
656658
) where {F,C}
657659
DI.check_prep(f, prep, backend, x, contexts...)
658660
T = tag_type(f, backend, x)
659-
xdual = make_dual(T, x, one(x))
661+
xdual = make_dual(T, x, oneunit(x))
660662
T2 = tag_type(f, backend, xdual)
661-
xdual2 = make_dual(T2, xdual, one(xdual))
663+
xdual2 = make_dual(T2, xdual, oneunit(xdual))
662664
contexts_dual = translate(typeof(xdual2), contexts)
663665
ydual = f(xdual2, contexts_dual...)
664666
return myderivative!(T, der2, myderivative(T2, ydual))
@@ -673,9 +675,9 @@ function DI.value_derivative_and_second_derivative(
673675
) where {F,C}
674676
DI.check_prep(f, prep, backend, x, contexts...)
675677
T = tag_type(f, backend, x)
676-
xdual = make_dual(T, x, one(x))
678+
xdual = make_dual(T, x, oneunit(x))
677679
T2 = tag_type(f, backend, xdual)
678-
xdual2 = make_dual(T2, xdual, one(xdual))
680+
xdual2 = make_dual(T2, xdual, oneunit(xdual))
679681
contexts_dual = translate(typeof(xdual2), contexts)
680682
ydual = f(xdual2, contexts_dual...)
681683
y = myvalue(T, myvalue(T2, ydual))
@@ -695,9 +697,9 @@ function DI.value_derivative_and_second_derivative!(
695697
) where {F,C}
696698
DI.check_prep(f, prep, backend, x, contexts...)
697699
T = tag_type(f, backend, x)
698-
xdual = make_dual(T, x, one(x))
700+
xdual = make_dual(T, x, oneunit(x))
699701
T2 = tag_type(f, backend, xdual)
700-
xdual2 = make_dual(T2, xdual, one(xdual))
702+
xdual2 = make_dual(T2, xdual, oneunit(xdual))
701703
contexts_dual = translate(typeof(xdual2), contexts)
702704
ydual = f(xdual2, contexts_dual...)
703705
y = myvalue(T, myvalue(T2, ydual))
@@ -756,7 +758,7 @@ function DI.value_gradient_and_hessian!(
756758
contexts isa NTuple{C,DI.GeneralizedConstant}
757759
)
758760
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
759-
result = DiffResult(one(eltype(x)), (grad, hess))
761+
result = DiffResult(oneunit(eltype(x)), (grad, hess))
760762
result = hessian!(result, fc, x)
761763
y = DR.value(result)
762764
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
@@ -855,7 +857,7 @@ function DI.value_gradient_and_hessian!(
855857
DI.check_prep(f, prep, backend, x, contexts...)
856858
contexts_dual = translate_prepared(contexts, prep.contexts_dual)
857859
fc = DI.fix_tail(f, contexts_dual...)
858-
result = DiffResult(one(eltype(x)), (grad, hess))
860+
result = DiffResult(oneunit(eltype(x)), (grad, hess))
859861
CHK = tag_type(backend) === Nothing
860862
if CHK
861863
checktag(prep.result_config, f, x)

DifferentiationInterface/src/first_order/gradient.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ function prepare_gradient_nokwarg(
9191
_sig = signature(f, backend, x, contexts...; strict)
9292
y = f(x, map(unwrap, contexts)...) # TODO: replace with output type inference?
9393
pullback_prep = prepare_pullback_nokwarg(
94-
strict, f, backend, x, (one(typeof(y)),), contexts...
94+
strict, f, backend, x, (oneunit(typeof(y)),), contexts...
9595
)
9696
return PullbackGradientPrep(_sig, y, pullback_prep)
9797
end

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ function _prepare_pullback_aux(
280280
contexts::Vararg{Context,C};
281281
) where {F,C}
282282
_sig = signature(f, backend, x, ty, contexts...; strict)
283-
dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x)))
283+
dx = x isa Number ? oneunit(x) : basis(x, first(CartesianIndices(x)))
284284
pushforward_prep = prepare_pushforward_nokwarg(
285285
strict, f, backend, x, (dx,), contexts...
286286
)
@@ -298,7 +298,7 @@ function _prepare_pullback_aux(
298298
contexts::Vararg{Context,C};
299299
) where {F,C}
300300
_sig = signature(f!, y, backend, x, ty, contexts...; strict)
301-
dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x)))
301+
dx = x isa Number ? oneunit(x) : basis(x, first(CartesianIndices(x)))
302302
pushforward_prep = prepare_pushforward_nokwarg(
303303
strict, f!, y, backend, x, (dx,), contexts...
304304
)
@@ -315,7 +315,7 @@ function _pullback_via_pushforward(
315315
dy,
316316
contexts::Vararg{Context,C},
317317
) where {F,C}
318-
a = only(pushforward(f, pushforward_prep, backend, x, (one(x),), contexts...))
318+
a = only(pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
319319
dx = dot(a, dy)
320320
return dx
321321
end
@@ -328,8 +328,8 @@ function _pullback_via_pushforward(
328328
dy,
329329
contexts::Vararg{Context,C},
330330
) where {F,C}
331-
a = only(pushforward(f, pushforward_prep, backend, x, (one(x),), contexts...))
332-
b = only(pushforward(f, pushforward_prep, backend, x, (im * one(x),), contexts...))
331+
a = only(pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
332+
b = only(pushforward(f, pushforward_prep, backend, x, (im * oneunit(x),), contexts...))
333333
dx = real(dot(a, dy)) + im * real(dot(b, dy))
334334
return dx
335335
end
@@ -436,7 +436,7 @@ function _pullback_via_pushforward(
436436
dy,
437437
contexts::Vararg{Context,C},
438438
) where {F,C}
439-
a = only(pushforward(f!, y, pushforward_prep, backend, x, (one(x),), contexts...))
439+
a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
440440
dx = dot(a, dy)
441441
return dx
442442
end
@@ -450,8 +450,10 @@ function _pullback_via_pushforward(
450450
dy,
451451
contexts::Vararg{Context,C},
452452
) where {F,C}
453-
a = only(pushforward(f!, y, pushforward_prep, backend, x, (one(x),), contexts...))
454-
b = only(pushforward(f!, y, pushforward_prep, backend, x, (im * one(x),), contexts...))
453+
a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
454+
b = only(
455+
pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...)
456+
)
455457
dx = real(dot(a, dy)) + im * real(dot(b, dy))
456458
return dx
457459
end

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ function _prepare_pushforward_aux(
285285
) where {F,C}
286286
_sig = signature(f, backend, x, tx, contexts...; strict)
287287
y = f(x, map(unwrap, contexts)...)
288-
dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y)))
288+
dy = y isa Number ? oneunit(y) : basis(y, first(CartesianIndices(y)))
289289
pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...)
290290
return PullbackPushforwardPrep(_sig, pullback_prep)
291291
end
@@ -301,7 +301,7 @@ function _prepare_pushforward_aux(
301301
contexts::Vararg{Context,C};
302302
) where {F,C}
303303
_sig = signature(f!, y, backend, x, tx, contexts...; strict)
304-
dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y)))
304+
dy = y isa Number ? oneunit(y) : basis(y, first(CartesianIndices(y)))
305305
pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...)
306306
return PullbackPushforwardPrep(_sig, pullback_prep)
307307
end
@@ -317,7 +317,7 @@ function _pushforward_via_pullback(
317317
dx,
318318
contexts::Vararg{Context,C},
319319
) where {F,C}
320-
a = only(pullback(f, pullback_prep, backend, x, (one(y),), contexts...))
320+
a = only(pullback(f, pullback_prep, backend, x, (oneunit(y),), contexts...))
321321
dy = dot(a, dx)
322322
return dy
323323
end
@@ -331,8 +331,8 @@ function _pushforward_via_pullback(
331331
dx,
332332
contexts::Vararg{Context,C},
333333
) where {F,C}
334-
a = only(pullback(f, pullback_prep, backend, x, (one(y),), contexts...))
335-
b = only(pullback(f, pullback_prep, backend, x, (im * one(y),), contexts...))
334+
a = only(pullback(f, pullback_prep, backend, x, (oneunit(y),), contexts...))
335+
b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y),), contexts...))
336336
dy = real(dot(a, dx)) + im * real(dot(b, dx))
337337
return dy
338338
end

DifferentiationInterfaceTest/src/scenarios/allocfree.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ function identity_scenarios(x::Number; dx::Number, dy::Number)
22
f = identity
33
dy_from_dx = dx
44
dx_from_dy = dy
5-
der = one(x)
5+
der = oneunit(x)
66

77
return [
88
Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)),
@@ -16,7 +16,7 @@ function sum_scenarios(x::AbstractArray; dx::AbstractArray, dy::Number)
1616
dy_from_dx = sum(dx)
1717
dx_from_dy = (similar(x) .= dy)
1818
grad = similar(x)
19-
grad .= one(eltype(x))
19+
grad .= oneunit(eltype(x))
2020

2121
return [
2222
Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)),

DifferentiationInterfaceTest/src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
mysimilar(x::Number) = one(x)
1+
mysimilar(x::Number) = oneunit(x)
22
mysimilar(x::AbstractArray) = similar(x)
33
mysimilar(x::Union{Tuple,NamedTuple}) = map(mysimilar, x)
44

0 commit comments

Comments
 (0)