diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 54fdfeeb7..89223f597 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -51,8 +51,6 @@ jobs: - Back/Symbolics - Back/Tracker - Back/Zygote - - Down/Flux - - Down/Lux skip_lts: - ${{ github.event.pull_request.draft }} skip_pre: @@ -79,14 +77,22 @@ jobs: arch: x64 - uses: julia-actions/cache@v2 - name: Install dependencies & run tests - # how to add the local DIT to the DI test env? - run: julia --project=./DifferentiationInterface --color=yes -e ' + run: julia --color=yes -e ' using Pkg; Pkg.Registry.update(); + Pkg.activate("./DifferentiationInterface/test"); + if VERSION < v"1.11"; + Pkg.rm("DifferentiationInterfaceTest"); + Pkg.resolve(); + else; + Pkg.develop(; path="./DifferentiationInterfaceTest"); + end; + Pkg.activate("./DifferentiationInterface"); + test_kwargs = (; allow_reresolve=false, coverage=true); if ENV["JULIA_DI_PR_DRAFT"] == "true"; - Pkg.test("DifferentiationInterface"; coverage=true, julia_args=["-O1"]); + Pkg.test("DifferentiationInterface"; julia_args=["-O1"], test_kwargs...); else; - Pkg.test("DifferentiationInterface"; coverage=true); + Pkg.test("DifferentiationInterface"; test_kwargs...); end;' - uses: julia-actions/julia-processcoverage@v1 with: diff --git a/DifferentiationInterface/test/Down/Flux/test.jl b/DifferentiationInterface/test/Down/Flux/test.jl deleted file mode 100644 index da9a76b3a..000000000 --- a/DifferentiationInterface/test/Down/Flux/test.jl +++ /dev/null @@ -1,26 +0,0 @@ -using Pkg -Pkg.add(["FiniteDifferences", "Enzyme", "Flux", "Zygote"]) - -using DifferentiationInterface, DifferentiationInterfaceTest -import DifferentiationInterfaceTest as DIT -using Enzyme: Enzyme -using FiniteDifferences: FiniteDifferences -using Flux: Flux -using Random -using Zygote: Zygote -using Test - -LOGGING = get(ENV, "CI", "false") == "false" - -test_differentiation( - [ - AutoZygote(), - # AutoEnzyme(), # TODO a few scenarios fail - ], - DIT.flux_scenarios(Random.MersenneTwister(0)); - isapprox = DIT.flux_isapprox, - rtol = 1.0e-2, - atol = 1.0e-4, - scenario_intact = false, # TODO: why? - logging = LOGGING, -) diff --git a/DifferentiationInterface/test/Down/Lux/test.jl b/DifferentiationInterface/test/Down/Lux/test.jl deleted file mode 100644 index b1d0652ff..000000000 --- a/DifferentiationInterface/test/Down/Lux/test.jl +++ /dev/null @@ -1,22 +0,0 @@ -using Pkg -Pkg.add(["ForwardDiff", "Lux", "LuxTestUtils", "Zygote"]) - -using ComponentArrays: ComponentArrays -using DifferentiationInterface, DifferentiationInterfaceTest -import DifferentiationInterfaceTest as DIT -using ForwardDiff: ForwardDiff -using Lux: Lux -using LuxTestUtils: LuxTestUtils -using Random - -LOGGING = get(ENV, "CI", "false") == "false" - -test_differentiation( - AutoZygote(), - DIT.lux_scenarios(Random.Xoshiro(63)); - isapprox = DIT.lux_isapprox, - rtol = 1.0f-2, - atol = 1.0f-3, - scenario_intact = false, # TODO: why? - logging = LOGGING, -) diff --git a/DifferentiationInterface/test/Project.toml b/DifferentiationInterface/test/Project.toml index eee6a767e..350fcbb58 100644 --- a/DifferentiationInterface/test/Project.toml +++ b/DifferentiationInterface/test/Project.toml @@ -5,6 +5,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" @@ -26,6 +27,7 @@ ComponentArrays = "0.15.27" DataFrames = "1.7.0" Dates = "1" DifferentiationInterface = "0.7.10" +DifferentiationInterfaceTest = "0.10.3" ExplicitImports = "1.10.1" InteractiveUtils = "1" JET = "0.9,0.10,0.11" @@ -40,4 +42,4 @@ Test = "1" julia = "1.10.10" [sources] -DifferentiationInterface = { path = ".." } +DifferentiationInterface = { path = ".." } \ No newline at end of file diff --git a/DifferentiationInterface/test/runtests.jl b/DifferentiationInterface/test/runtests.jl index df7884a27..66159c3d9 100644 --- a/DifferentiationInterface/test/runtests.jl +++ b/DifferentiationInterface/test/runtests.jl @@ -2,11 +2,13 @@ using DifferentiationInterface using Pkg using Test -DIT_PATH = joinpath(@__DIR__, "..", "..", "DifferentiationInterfaceTest") -if isdir(DIT_PATH) - Pkg.develop(; path = DIT_PATH) -else - Pkg.add("DifferentiationInterfaceTest") +@static if VERSION < v"1.11" + DIT_PATH = joinpath(@__DIR__, "..", "..", "DifferentiationInterfaceTest") + if isdir(DIT_PATH) + Pkg.develop(; path = DIT_PATH) + else + Pkg.add("DifferentiationInterfaceTest") + end end include("testutils.jl") diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index f73e7d1ba..bca102eda 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -21,26 +21,14 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" -Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays" -DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux", "Functors"] DifferentiationInterfaceTestJLArraysExt = "JLArrays" -DifferentiationInterfaceTestLuxExt = [ - "ComponentArrays", - "ForwardDiff", - "Lux", - "LuxTestUtils", -] DifferentiationInterfaceTestStaticArraysExt = "StaticArrays" [compat] @@ -51,15 +39,10 @@ ComponentArrays = "0.15" DataFrames = "1.6.1" DifferentiationInterface = "0.7.7" DocStringExtensions = "0.8,0.9" -FiniteDifferences = "0.12" -Flux = "0.16" ForwardDiff = "0.10.36,1" -Functors = "0.4, 0.5" JET = "0.9,0.10,0.11" JLArrays = "0.1,0.2,0.3" LinearAlgebra = "1" -Lux = "1.1.0" -LuxTestUtils = "1.3.1, 2" PrecompileTools = "1.2.1" ProgressMeter = "1" Random = "1" diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl deleted file mode 100644 index 3296995ec..000000000 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl +++ /dev/null @@ -1,214 +0,0 @@ -module DifferentiationInterfaceTestFluxExt - -import DifferentiationInterface as DI -import DifferentiationInterfaceTest as DIT -using FiniteDifferences: FiniteDifferences -using Flux: - Flux, - Bilinear, - Chain, - Conv, - ConvTranspose, - Dense, - GRU, - GRUCell, - LSTM, - LSTMCell, - Maxout, - MeanPool, - RNN, - RNNCell, - SamePad, - Scale, - SkipConnection, - destructure, - f64, - glorot_uniform, - relu -using Functors: @functor, fmapstructure_with_path, fleaves -using LinearAlgebra -using Statistics: mean -using Random: AbstractRNG, default_rng - -#= -Relevant discussions: - -- https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/105 -- https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/343 -- https://github.com/FluxML/Flux.jl/issues/2469 -=# - -function gradient_finite_differences(loss, model, x) - v, re = destructure(model) - fdm = FiniteDifferences.central_fdm(5, 1) - gs = FiniteDifferences.grad(fdm, model -> loss(re(model), x), f64(v)) - return re(only(gs)) -end - -function DIT.flux_isapprox(a, b; atol, rtol) - isapprox_results = fmapstructure_with_path(a, b) do kp, x, y - if x isa AbstractArray{<:Number} - return isapprox(x, y; atol, rtol) - else # ignore non-arrays - return true - end - end - return all(fleaves(isapprox_results)) -end - -square_loss(model, x) = mean(abs2, model(x)) - -function square_loss_iterated(cell, x) - y, st = cell(x) # uses default initial state - for _ in 1:2 - y, st = cell(x, st) - end - return mean(abs2, y) -end - -struct SimpleDense{W, B, F} - w::W - b::B - σ::F -end - -(m::SimpleDense)(x) = m.σ.(m.w * x .+ m.b) - -@functor SimpleDense - -function DIT.flux_scenarios(rng::AbstractRNG = default_rng()) - init = glorot_uniform(rng) - - scens = DIT.Scenario[] - - # Simple dense - - d_in, d_out = 4, 2 - w = randn(rng, d_out, d_in) - b = randn(rng, d_out) - model = SimpleDense(w, b, Flux.σ) - - x = randn(rng, d_in) - g = gradient_finite_differences(square_loss, model, x) - - scen = DIT.Scenario{:gradient, :out}( - square_loss, - model, - DI.Constant(x); - prep_args = (x = model, contexts = (DI.Constant(x),)), - res1 = g, - ) - push!(scens, scen) - - # Layers - - models_and_xs = [ - #! format: off - ( - Dense(2, 4; init), - randn(rng, Float32, 2) - ), - ( - Chain(Dense(2, 4, relu; init), Dense(4, 3; init)), - randn(rng, Float32, 2)), - ( - f64(Chain(Dense(2, 4; init), Dense(4, 2; init))), - randn(rng, Float64, 2, 1)), - ( - Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), - randn(rng, Float32, 2)), - ( - Conv((3, 3), 2 => 3; init), - randn(rng, Float32, 3, 3, 2, 1)), - ( - Chain(Conv((3, 3), 2 => 3, relu; init), Conv((3, 3), 3 => 1, relu; init)), - rand(rng, Float32, 5, 5, 2, 1), - ), - ( - Chain(Conv((4, 4), 2 => 2; pad=SamePad(), init), MeanPool((5, 5); pad=SamePad())), - rand(rng, Float32, 5, 5, 2, 2), - ), - ( - Maxout(() -> Dense(5 => 4, tanh; init), 3), - randn(rng, Float32, 5, 1) - ), - ( - SkipConnection(Dense(2 => 2; init), vcat), - randn(rng, Float32, 2, 3) - ), - ( - Bilinear((2, 2) => 3; init), - randn(rng, Float32, 2, 1) - ), - ( - ConvTranspose((3, 3), 3 => 2; stride=2, init), - rand(rng, Float32, 5, 5, 3, 1) - ), - ( - RNN(3 => 4; init_kernel=init, init_recurrent_kernel=init), - randn(rng, Float32, 3, 2, 1) - ), - ( - LSTM(3 => 4; init_kernel=init, init_recurrent_kernel=init), - randn(rng, Float32, 3, 2, 1) - ), - ( - GRU(3 => 4; init_kernel=init, init_recurrent_kernel=init), - randn(rng, Float32, 3, 2, 1) - ), - ( - Chain(LSTM(3 => 4), RNN(4 => 5), Dense(5 => 2)), - randn(rng, Float32, 3, 2, 1) - ), - #! format: on - ] - - for (model, x) in models_and_xs - Flux.trainmode!(model) - g = gradient_finite_differences(square_loss, model, x) - scen = DIT.Scenario{:gradient, :out}( - square_loss, - model, - DI.Constant(x); - prep_args = (; x = model, contexts = (DI.Constant(x),)), - res1 = g, - ) - push!(scens, scen) - end - - # Recurrent Cells - - recurrent_models_and_xs = [ - #! format: off - ( - RNNCell(3 => 3; init_kernel=init, init_recurrent_kernel=init), - randn(rng, Float32, 3, 2) - ), - ( - LSTMCell(3 => 3; init_kernel=init, init_recurrent_kernel=init), - randn(rng, Float32, 3, 2) - ), - ( - GRUCell(3 => 3; init_kernel=init, init_recurrent_kernel=init), - randn(rng, Float32, 3, 2) - ), - #! format: on - ] - - for (model, x) in recurrent_models_and_xs - Flux.trainmode!(model) - g = gradient_finite_differences(square_loss_iterated, model, x) - scen = DIT.Scenario{:gradient, :out}( - square_loss_iterated, - model, - DI.Constant(x); - prep_args = (; x = model, contexts = (DI.Constant(x),)), - res1 = g, - ) - push!(scens, scen) - end - - return scens -end - -end diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl deleted file mode 100644 index 4dddb944d..000000000 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl +++ /dev/null @@ -1,218 +0,0 @@ -module DifferentiationInterfaceTestLuxExt - -using ComponentArrays: ComponentArray -import DifferentiationInterface as DI -import DifferentiationInterfaceTest as DIT -using ForwardDiff: ForwardDiff -using Lux: - Lux, - BatchNorm, - Bilinear, - Chain, - Conv, - ConvTranspose, - Dense, - GroupNorm, - GRUCell, - InstanceNorm, - LayerNorm, - LSTMCell, - Maxout, - MaxPool, - MeanPool, - RNNCell, - SamePad, - Scale, - SkipConnection, - StatefulRecurrentCell, - gelu, - relu -using LuxTestUtils: check_approx -using Random: AbstractRNG, default_rng - -#= -Relevant discussions: - -- https://github.com/LuxDL/Lux.jl/issues/769 -=# - -function DIT.lux_isapprox(a, b; atol, rtol) - return check_approx(a, b; atol, rtol) -end - -function square_loss(ps, model, x, st) - return sum(abs2, first(model(x, ps, st))) -end - -function DIT.lux_scenarios(rng::AbstractRNG = default_rng()) - models_and_xs = [ - #! format: off - ( - Dense(2, 4), - randn(rng, Float32, 2, 3) - ), - ( - Dense(2, 4, gelu), - randn(rng, Float32, 2, 3) - ), - ( - Dense(2, 4, gelu; use_bias=false), - randn(rng, Float32, 2, 3) - ), - ( - Chain(Dense(2, 4, relu), Dense(4, 3)), - randn(rng, Float32, 2, 3) - ), - ( - Scale(2), - randn(rng, Float32, 2, 3) - ), - ( - Conv((3, 3), 2 => 3), - randn(rng, Float32, 3, 3, 2, 2) - ), - ( - Conv((3, 3), 2 => 3, gelu; pad=SamePad()), - randn(rng, Float32, 3, 3, 2, 2) - ), - ( - Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), - randn(rng, Float32, 3, 3, 2, 2), - ), - ( - Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), - rand(rng, Float32, 5, 5, 2, 2), - ), - ( - Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), - rand(rng, Float32, 5, 5, 2, 2), - ), - ( - Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), - rand(rng, Float32, 5, 5, 2, 2), - ), - ( - Maxout(() -> Dense(5 => 4, tanh), 3), - randn(rng, Float32, 5, 2) - ), - ( - Bilinear((2, 2) => 3), - randn(rng, Float32, 2, 3) - ), - ( - SkipConnection(Dense(2 => 2), vcat), - randn(rng, Float32, 2, 3) - ), - ( - ConvTranspose((3, 3), 3 => 2; stride=2), - rand(rng, Float32, 5, 5, 3, 1) - ), - ( - StatefulRecurrentCell(RNNCell(3 => 5)), - rand(rng, Float32, 3, 2) - ), - ( - StatefulRecurrentCell(RNNCell(3 => 5, gelu)), - rand(rng, Float32, 3, 2) - ), - ( - StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), - rand(rng, Float32, 3, 2), - ), - ( - Chain(StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3)),), - rand(rng, Float32, 3, 2), - ), - ( - StatefulRecurrentCell(LSTMCell(3 => 5)), - rand(rng, Float32, 3, 2) - ), - ( - Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3)),), - rand(rng, Float32, 3, 2), - ), - ( - StatefulRecurrentCell(GRUCell(3 => 5)), - rand(rng, Float32, 3, 10) - ), - ( - Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3)),), - rand(rng, Float32, 3, 10), - ), - ( - Chain(Dense(2, 4), BatchNorm(4)), - randn(rng, Float32, 2, 3) - ), - ( - Chain(Dense(2, 4), BatchNorm(4, gelu)), - randn(rng, Float32, 2, 3) - ), - ( - Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), - randn(rng, Float32, 2, 3), - ), - ( - Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), - randn(rng, Float32, 6, 6, 2, 2) - ), - ( - Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), - randn(rng, Float32, 6, 6, 2, 2) - ), - ( - Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), - randn(rng, Float32, 2, 3) - ), - ( - Chain(Dense(2, 4), GroupNorm(4, 2)), - randn(rng, Float32, 2, 3) - ), - ( - Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), - randn(rng, Float32, 6, 6, 2, 2) - ), - ( - Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), - randn(rng, Float32, 6, 6, 2, 2), - ), - ( - Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), - randn(rng, Float32, 4, 4, 2, 2), - ), - ( - Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), - randn(rng, Float32, 6, 6, 2, 2) - ), - ( - Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), - randn(rng, Float32, 6, 6, 2, 2), - ), - #! format: on - ] - - scens = DIT.Scenario[] - - for (model, x) in models_and_xs - ps, st = Lux.setup(rng, model) - g = DI.gradient( - ps -> square_loss(ps, model, x, st), DI.AutoForwardDiff(), ComponentArray(ps) - ) - scen = DIT.Scenario{:gradient, :out}( - square_loss, - ComponentArray(ps), - DI.Constant(model), - DI.Constant(x), - DI.Constant(st); - prep_args = ( - x = ComponentArray(ps), - contexts = (DI.Constant(model), DI.Constant(x), DI.Constant(st)), - ), - res1 = g, - ) - push!(scens, scen) - end - - return scens -end - -end diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index 1a121ce36..f56831690 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -1,18 +1,12 @@ -using Pkg -Pkg.add(["FiniteDiff", "Lux", "LuxTestUtils"]) - using ADTypes using ComponentArrays: ComponentArrays using DifferentiationInterface using DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT -using FiniteDifferences: FiniteDifferences +using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff -using Flux: Flux using ForwardDiff: ForwardDiff using JLArrays: JLArrays -using Lux: Lux -using LuxTestUtils: LuxTestUtils using Random using SparseConnectivityTracer using SparseMatrixColorings @@ -64,25 +58,3 @@ test_differentiation( excluded = SECOND_ORDER, logging = LOGGING, ); - -## Neural nets - -test_differentiation( - AutoZygote(), - DIT.flux_scenarios(Random.MersenneTwister(0)); - isapprox = DIT.flux_isapprox, - rtol = 1.0e-2, - atol = 1.0e-4, - scenario_intact = false, - logging = LOGGING, -) - -test_differentiation( - AutoZygote(), - DIT.lux_scenarios(Random.Xoshiro(63)); - isapprox = DIT.lux_isapprox, - rtol = 1.0f-2, - atol = 1.0f-3, - scenario_intact = false, - logging = LOGGING, -)