diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index 2b6fe4060..dc80a5b6d 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed + +- Replace `one` with `oneunit` in basis computation ([#826]) + ## [0.7.3] ### Fixed @@ -62,6 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [0.6.54]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...DifferentiationInterface-v0.6.54 [0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53 +[#826]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/826 [#823]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/823 [#818]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/818 [#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812 diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index fd7ae96cc..e75fa8d36 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -61,6 +61,7 @@ Aqua = "0.8.12" ChainRulesCore = "1.23.0" ComponentArrays = "0.15.27" DataFrames = "1.7.0" +Dates = "1" DiffResults = "1.1.0" Diffractor = "=0.2.6" Enzyme = "0.13.39" @@ -98,6 +99,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" @@ -130,6 +132,7 @@ test = [ "Aqua", "ComponentArrays", "DataFrames", + "Dates", "ExplicitImports", "JET", "JLArrays", diff --git a/DifferentiationInterface/docs/src/explanation/operators.md b/DifferentiationInterface/docs/src/explanation/operators.md index 6ca844ae5..55d3c0518 100644 --- a/DifferentiationInterface/docs/src/explanation/operators.md +++ b/DifferentiationInterface/docs/src/explanation/operators.md @@ -152,4 +152,4 @@ For same-point preparation, the same rules hold with two modifications: !!! warning These rules hold for the majority of backends, but there are some exceptions. - The most important exception is [ReverseDiff](@ref) and its taping mechanism, which is sensitive to control flow inside the function. + The most important exception is [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) and its taping mechanism, which is sensitive to control flow inside the function. diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index bdfcb54a9..631e27076 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -189,27 +189,27 @@ end function DI.value_and_derivative( f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {F,C} - y, ty = DI.value_and_pushforward(f, backend, x, (one(x),), contexts...) + y, ty = DI.value_and_pushforward(f, backend, x, (oneunit(x),), contexts...) return y, only(ty) end function DI.value_and_derivative!( f::F, der, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {F,C} - y, _ = DI.value_and_pushforward!(f, (der,), backend, x, (one(x),), contexts...) + y, _ = DI.value_and_pushforward!(f, (der,), backend, x, (oneunit(x),), contexts...) return y, der end function DI.derivative( f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {F,C} - return only(DI.pushforward(f, backend, x, (one(x),), contexts...)) + return only(DI.pushforward(f, backend, x, (oneunit(x),), contexts...)) end function DI.derivative!( f::F, der, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C} ) where {F,C} - DI.pushforward!(f, (der,), backend, x, (one(x),), contexts...) + DI.pushforward!(f, (der,), backend, x, (oneunit(x),), contexts...) return der end @@ -220,7 +220,7 @@ function DI.prepare_derivative_nokwarg( ) where {F,C} _sig = DI.signature(f, backend, x, contexts...; strict) pushforward_prep = DI.prepare_pushforward_nokwarg( - strict, f, backend, x, (one(x),), contexts... + strict, f, backend, x, (oneunit(x),), contexts... ) return ForwardDiffOneArgDerivativePrep(_sig, pushforward_prep) end @@ -234,7 +234,7 @@ function DI.value_and_derivative( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) y, ty = DI.value_and_pushforward( - f, prep.pushforward_prep, backend, x, (one(x),), contexts... + f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts... ) return y, only(ty) end @@ -249,7 +249,7 @@ function DI.value_and_derivative!( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) y, _ = DI.value_and_pushforward!( - f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... + f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts... ) return y, der end @@ -263,7 +263,7 @@ function DI.derivative( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) return only( - DI.pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...) + DI.pushforward(f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...) ) end @@ -276,7 +276,9 @@ function DI.derivative!( contexts::Vararg{DI.Context,C}, ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) - DI.pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) + DI.pushforward!( + f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts... + ) return der end @@ -638,9 +640,9 @@ function DI.second_derivative( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) - xdual = make_dual(T, x, one(x)) + xdual = make_dual(T, x, oneunit(x)) T2 = tag_type(f, backend, xdual) - xdual2 = make_dual(T2, xdual, one(xdual)) + xdual2 = make_dual(T2, xdual, oneunit(xdual)) contexts_dual = translate(typeof(xdual2), contexts) ydual = f(xdual2, contexts_dual...) return myderivative(T, myderivative(T2, ydual)) @@ -656,9 +658,9 @@ function DI.second_derivative!( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) - xdual = make_dual(T, x, one(x)) + xdual = make_dual(T, x, oneunit(x)) T2 = tag_type(f, backend, xdual) - xdual2 = make_dual(T2, xdual, one(xdual)) + xdual2 = make_dual(T2, xdual, oneunit(xdual)) contexts_dual = translate(typeof(xdual2), contexts) ydual = f(xdual2, contexts_dual...) return myderivative!(T, der2, myderivative(T2, ydual)) @@ -673,9 +675,9 @@ function DI.value_derivative_and_second_derivative( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) - xdual = make_dual(T, x, one(x)) + xdual = make_dual(T, x, oneunit(x)) T2 = tag_type(f, backend, xdual) - xdual2 = make_dual(T2, xdual, one(xdual)) + xdual2 = make_dual(T2, xdual, oneunit(xdual)) contexts_dual = translate(typeof(xdual2), contexts) ydual = f(xdual2, contexts_dual...) y = myvalue(T, myvalue(T2, ydual)) @@ -695,9 +697,9 @@ function DI.value_derivative_and_second_derivative!( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) T = tag_type(f, backend, x) - xdual = make_dual(T, x, one(x)) + xdual = make_dual(T, x, oneunit(x)) T2 = tag_type(f, backend, xdual) - xdual2 = make_dual(T2, xdual, one(xdual)) + xdual2 = make_dual(T2, xdual, oneunit(xdual)) contexts_dual = translate(typeof(xdual2), contexts) ydual = f(xdual2, contexts_dual...) y = myvalue(T, myvalue(T2, ydual)) @@ -756,7 +758,7 @@ function DI.value_gradient_and_hessian!( contexts isa NTuple{C,DI.GeneralizedConstant} ) fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) - result = DiffResult(one(eltype(x)), (grad, hess)) + result = DiffResult(oneunit(eltype(x)), (grad, hess)) result = hessian!(result, fc, x) y = DR.value(result) grad === DR.gradient(result) || copyto!(grad, DR.gradient(result)) @@ -855,7 +857,7 @@ function DI.value_gradient_and_hessian!( DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) fc = DI.fix_tail(f, contexts_dual...) - result = DiffResult(one(eltype(x)), (grad, hess)) + result = DiffResult(oneunit(eltype(x)), (grad, hess)) CHK = tag_type(backend) === Nothing if CHK checktag(prep.result_config, f, x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl index 161412cfc..60d1ef6c0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl @@ -6,14 +6,14 @@ using GPUArraysCore: @allowscalar, AbstractGPUArray function DI.basis(a::AbstractGPUArray{T}, i) where {T} b = similar(a) fill!(b, zero(T)) - @allowscalar b[i] = one(T) + @allowscalar b[i] = oneunit(T) return b end function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T} b = similar(a) fill!(b, zero(T)) - view(b, inds) .= one(T) + view(b, inds) .= oneunit(T) return b end diff --git a/DifferentiationInterface/src/first_order/derivative.jl b/DifferentiationInterface/src/first_order/derivative.jl index 4883f1c82..ca2314c97 100644 --- a/DifferentiationInterface/src/first_order/derivative.jl +++ b/DifferentiationInterface/src/first_order/derivative.jl @@ -143,7 +143,7 @@ function prepare_derivative_nokwarg( ) where {F,C} _sig = signature(f, backend, x, contexts...; strict) pushforward_prep = prepare_pushforward_nokwarg( - strict, f, backend, x, (one(x),), contexts... + strict, f, backend, x, (oneunit(x),), contexts... ) return PushforwardDerivativePrep(_sig, pushforward_prep) end @@ -153,7 +153,7 @@ function prepare_derivative_nokwarg( ) where {F,C} _sig = signature(f!, y, backend, x, contexts...; strict) pushforward_prep = prepare_pushforward_nokwarg( - strict, f!, y, backend, x, (one(x),), contexts... + strict, f!, y, backend, x, (oneunit(x),), contexts... ) return PushforwardDerivativePrep(_sig, pushforward_prep) end @@ -169,7 +169,7 @@ function value_and_derivative( ) where {F,C} check_prep(f, prep, backend, x, contexts...) y, ty = value_and_pushforward( - f, prep.pushforward_prep, backend, x, (one(x),), contexts... + f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts... ) return y, only(ty) end @@ -184,7 +184,7 @@ function value_and_derivative!( ) where {F,C} check_prep(f, prep, backend, x, contexts...) y, _ = value_and_pushforward!( - f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... + f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts... ) return y, der end @@ -197,7 +197,7 @@ function derivative( contexts::Vararg{Context,C}, ) where {F,C} check_prep(f, prep, backend, x, contexts...) - ty = pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...) + ty = pushforward(f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...) return only(ty) end @@ -210,7 +210,7 @@ function derivative!( contexts::Vararg{Context,C}, ) where {F,C} check_prep(f, prep, backend, x, contexts...) - pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) + pushforward!(f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...) return der end @@ -226,7 +226,7 @@ function value_and_derivative( ) where {F,C} check_prep(f!, y, prep, backend, x, contexts...) y, ty = value_and_pushforward( - f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts... + f!, y, prep.pushforward_prep, backend, x, (oneunit(x),), contexts... ) return y, only(ty) end @@ -242,7 +242,7 @@ function value_and_derivative!( ) where {F,C} check_prep(f!, y, prep, backend, x, contexts...) y, _ = value_and_pushforward!( - f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts... + f!, y, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts... ) return y, der end @@ -256,7 +256,7 @@ function derivative( contexts::Vararg{Context,C}, ) where {F,C} check_prep(f!, y, prep, backend, x, contexts...) - ty = pushforward(f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts...) + ty = pushforward(f!, y, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...) return only(ty) end @@ -270,7 +270,9 @@ function derivative!( contexts::Vararg{Context,C}, ) where {F,C} check_prep(f!, y, prep, backend, x, contexts...) - pushforward!(f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...) + pushforward!( + f!, y, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts... + ) return der end diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 448adba16..2720307c7 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -91,7 +91,7 @@ function prepare_gradient_nokwarg( _sig = signature(f, backend, x, contexts...; strict) y = f(x, map(unwrap, contexts)...) # TODO: replace with output type inference? pullback_prep = prepare_pullback_nokwarg( - strict, f, backend, x, (one(typeof(y)),), contexts... + strict, f, backend, x, (oneunit(typeof(y)),), contexts... ) return PullbackGradientPrep(_sig, y, pullback_prep) end @@ -106,7 +106,9 @@ function value_and_gradient( contexts::Vararg{Context,C}, ) where {F,SIG,Y,C} check_prep(f, prep, backend, x, contexts...) - y, tx = value_and_pullback(f, prep.pullback_prep, backend, x, (one(Y),), contexts...) + y, tx = value_and_pullback( + f, prep.pullback_prep, backend, x, (oneunit(Y),), contexts... + ) return y, only(tx) end @@ -120,7 +122,7 @@ function value_and_gradient!( ) where {F,SIG,Y,C} check_prep(f, prep, backend, x, contexts...) y, _ = value_and_pullback!( - f, (grad,), prep.pullback_prep, backend, x, (one(Y),), contexts... + f, (grad,), prep.pullback_prep, backend, x, (oneunit(Y),), contexts... ) return y, grad end @@ -133,7 +135,7 @@ function gradient( contexts::Vararg{Context,C}, ) where {F,SIG,Y,C} check_prep(f, prep, backend, x, contexts...) - tx = pullback(f, prep.pullback_prep, backend, x, (one(Y),), contexts...) + tx = pullback(f, prep.pullback_prep, backend, x, (oneunit(Y),), contexts...) return only(tx) end @@ -146,7 +148,7 @@ function gradient!( contexts::Vararg{Context,C}, ) where {F,SIG,Y,C} check_prep(f, prep, backend, x, contexts...) - pullback!(f, (grad,), prep.pullback_prep, backend, x, (one(Y),), contexts...) + pullback!(f, (grad,), prep.pullback_prep, backend, x, (oneunit(Y),), contexts...) return grad end diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 9d207d9e8..1606f01d6 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -280,7 +280,7 @@ function _prepare_pullback_aux( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f, backend, x, ty, contexts...; strict) - dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x))) + dx = x isa Number ? oneunit(x) : basis(x, first(CartesianIndices(x))) pushforward_prep = prepare_pushforward_nokwarg( strict, f, backend, x, (dx,), contexts... ) @@ -298,7 +298,7 @@ function _prepare_pullback_aux( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f!, y, backend, x, ty, contexts...; strict) - dx = x isa Number ? one(x) : basis(x, first(CartesianIndices(x))) + dx = x isa Number ? oneunit(x) : basis(x, first(CartesianIndices(x))) pushforward_prep = prepare_pushforward_nokwarg( strict, f!, y, backend, x, (dx,), contexts... ) @@ -315,7 +315,7 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - a = only(pushforward(f, pushforward_prep, backend, x, (one(x),), contexts...)) + a = only(pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...)) dx = dot(a, dy) return dx end @@ -328,8 +328,8 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - a = only(pushforward(f, pushforward_prep, backend, x, (one(x),), contexts...)) - b = only(pushforward(f, pushforward_prep, backend, x, (im * one(x),), contexts...)) + a = only(pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...)) + b = only(pushforward(f, pushforward_prep, backend, x, (im * oneunit(x),), contexts...)) dx = real(dot(a, dy)) + im * real(dot(b, dy)) return dx end @@ -436,7 +436,7 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - a = only(pushforward(f!, y, pushforward_prep, backend, x, (one(x),), contexts...)) + a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...)) dx = dot(a, dy) return dx end @@ -450,8 +450,10 @@ function _pullback_via_pushforward( dy, contexts::Vararg{Context,C}, ) where {F,C} - a = only(pushforward(f!, y, pushforward_prep, backend, x, (one(x),), contexts...)) - b = only(pushforward(f!, y, pushforward_prep, backend, x, (im * one(x),), contexts...)) + a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...)) + b = only( + pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...) + ) dx = real(dot(a, dy)) + im * real(dot(b, dy)) return dx end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 11cfbf185..d338c292d 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -285,7 +285,7 @@ function _prepare_pushforward_aux( ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) y = f(x, map(unwrap, contexts)...) - dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y))) + dy = y isa Number ? oneunit(y) : basis(y, first(CartesianIndices(y))) pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...) return PullbackPushforwardPrep(_sig, pullback_prep) end @@ -301,7 +301,7 @@ function _prepare_pushforward_aux( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) - dy = y isa Number ? one(y) : basis(y, first(CartesianIndices(y))) + dy = y isa Number ? oneunit(y) : basis(y, first(CartesianIndices(y))) pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...) return PullbackPushforwardPrep(_sig, pullback_prep) end @@ -317,7 +317,7 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context,C}, ) where {F,C} - a = only(pullback(f, pullback_prep, backend, x, (one(y),), contexts...)) + a = only(pullback(f, pullback_prep, backend, x, (oneunit(y),), contexts...)) dy = dot(a, dx) return dy end @@ -331,8 +331,8 @@ function _pushforward_via_pullback( dx, contexts::Vararg{Context,C}, ) where {F,C} - a = only(pullback(f, pullback_prep, backend, x, (one(y),), contexts...)) - b = only(pullback(f, pullback_prep, backend, x, (im * one(y),), contexts...)) + a = only(pullback(f, pullback_prep, backend, x, (oneunit(y),), contexts...)) + b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y),), contexts...)) dy = real(dot(a, dx)) + im * real(dot(b, dx)) return dy end diff --git a/DifferentiationInterface/src/utils/basis.jl b/DifferentiationInterface/src/utils/basis.jl index 6d1b5a55e..7fc38fcec 100644 --- a/DifferentiationInterface/src/utils/basis.jl +++ b/DifferentiationInterface/src/utils/basis.jl @@ -6,7 +6,7 @@ Construct the `i`-th standard basis array in the vector space of `a`. function basis(a::AbstractArray{T}, i) where {T} b = similar(a) fill!(b, zero(T)) - b[i] = one(T) + b[i] = oneunit(T) if ismutable_array(a) return b else @@ -23,7 +23,7 @@ function multibasis(a::AbstractArray{T}, inds) where {T} b = similar(a) fill!(b, zero(T)) for i in inds - b[i] = one(T) + b[i] = oneunit(T) end return ismutable_array(a) ? b : map(+, zero(a), b) end diff --git a/DifferentiationInterface/test/Core/Internals/basis.jl b/DifferentiationInterface/test/Core/Internals/basis.jl index 681d71bab..e79829990 100644 --- a/DifferentiationInterface/test/Core/Internals/basis.jl +++ b/DifferentiationInterface/test/Core/Internals/basis.jl @@ -2,6 +2,7 @@ using DifferentiationInterface: basis, multibasis using LinearAlgebra using StaticArrays, JLArrays using Test +using Dates @testset "Basis" begin b_ref = [0, 1, 0] @@ -22,4 +23,7 @@ using Test @test all(basis(jl(rand(3, 3)), 4) .== b_ref) @test basis(@SMatrix(rand(3, 3)), 4) isa SMatrix @test basis(@SMatrix(rand(3, 3)), 4) == b_ref + + t = [Time(1) - Time(0)] + @test basis(t, 1) isa Vector{Nanosecond} end diff --git a/DifferentiationInterfaceTest/src/scenarios/allocfree.jl b/DifferentiationInterfaceTest/src/scenarios/allocfree.jl index 2134faabc..643ebde95 100644 --- a/DifferentiationInterfaceTest/src/scenarios/allocfree.jl +++ b/DifferentiationInterfaceTest/src/scenarios/allocfree.jl @@ -2,7 +2,7 @@ function identity_scenarios(x::Number; dx::Number, dy::Number) f = identity dy_from_dx = dx dx_from_dy = dy - der = one(x) + der = oneunit(x) return [ Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)), @@ -16,7 +16,7 @@ function sum_scenarios(x::AbstractArray; dx::AbstractArray, dy::Number) dy_from_dx = sum(dx) dx_from_dy = (similar(x) .= dy) grad = similar(x) - grad .= one(eltype(x)) + grad .= oneunit(eltype(x)) return [ Scenario{:pushforward,:out}(f, x, (dx,); res1=(dy_from_dx,)), diff --git a/DifferentiationInterfaceTest/src/utils.jl b/DifferentiationInterfaceTest/src/utils.jl index 06e1747f2..04a167871 100644 --- a/DifferentiationInterfaceTest/src/utils.jl +++ b/DifferentiationInterfaceTest/src/utils.jl @@ -1,4 +1,3 @@ -mysimilar(x::Number) = one(x) mysimilar(x::AbstractArray) = similar(x) mysimilar(x::Union{Tuple,NamedTuple}) = map(mysimilar, x)