From d53d943e15494550145045e909fc9cda5937ab62 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 9 May 2025 19:46:38 +0200 Subject: [PATCH 1/6] feat: error hints for Enzyme --- DifferentiationInterface/CHANGELOG.md | 11 ++++- DifferentiationInterface/Project.toml | 4 +- .../DifferentiationInterfaceEnzymeExt.jl | 3 ++ .../DifferentiationInterfaceEnzymeExt/init.jl | 44 +++++++++++++++++++ 4 files changed, 58 insertions(+), 4 deletions(-) create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/init.jl diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index 7682ef9b4..9ec766c29 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -7,13 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -### Changed +## [0.6.54] - 2025-05-07 -- Allocate Enzyme shadow memory during preparation ([#782]) +### Added + +- Error hints for Enzyme ([#788]) ## [0.6.53] - 2025-05-07 +### Changed + +- Allocate Enzyme shadow memory during preparation ([#782]) + [unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...main [0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53 +[#788]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/788 [#782]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/782 \ No newline at end of file diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 852718e42..ab28e08c0 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.53" +version = "0.6.54" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -56,7 +56,7 @@ ADTypes = "1.13.0" ChainRulesCore = "1.23.0" DiffResults = "1.1.0" Diffractor = "=0.2.6" -Enzyme = "0.13.17" +Enzyme = "0.13.39" EnzymeCore = "0.8.8" ExplicitImports = "1.10.1" FastDifferentiation = "0.4.3" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 328bffaf3..dbf1e4f5c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -30,6 +30,7 @@ using EnzymeCore: Split, WithPrimal using Enzyme: + Enzyme, autodiff, autodiff_thunk, create_shadows, @@ -53,4 +54,6 @@ include("forward_twoarg.jl") include("reverse_onearg.jl") include("reverse_twoarg.jl") +include("init.jl") + end # module diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/init.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/init.jl new file mode 100644 index 000000000..ebf59c14a --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/init.jl @@ -0,0 +1,44 @@ +function __init__() + # robust against internal changes + condition = ( + isdefined(Enzyme, :Compiler) && + Enzyme.Compiler isa Module && + isdefined(Enzyme.Compiler, :EnzymeError) && + Enzyme.Compiler.EnzymeError isa DataType + ) + condition || return nothing + # see https://github.com/JuliaLang/julia/issues/58367 for why this isn't easier + for n in names(Enzyme.Compiler; all=true) + T = getfield(Enzyme.Compiler, n) + if T isa DataType && T <: Enzyme.Compiler.EnzymeError + # robust against internal changes + Base.Experimental.register_error_hint(T) do io, exc + if occursin("EnzymeMutabilityException", string(nameof(T))) + printstyled( + io, + "\nIf you are using Enzyme through DifferentiationInterface, you may want to try modifying the ADTypes backend object as follows:"; + bold=true, + ) + printstyled( + io, + "\n\n\tAutoEnzyme(; function_annotation=Enzyme.Duplicated)\n\n"; + color=:cyan, + bold=true, + ) + elseif occursin("EnzymeRuntimeActivityError", string(nameof(T))) + printstyled( + io, + "\nIf you are using Enzyme through DifferentiationInterface, you may want to try modifying the ADTypes backend object as follows:"; + bold=true, + ) + printstyled( + io, + "\n\n\tAutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward))\n\tAutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))\n\n"; + color=:cyan, + bold=true, + ) + end + end + end + end +end From 52a91fe9188a4e1853463555560133309d50a853 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 10 May 2025 07:10:25 +0200 Subject: [PATCH 2/6] Clearer message --- .../DifferentiationInterfaceEnzymeExt/init.jl | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/init.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/init.jl index ebf59c14a..7681e49c0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/init.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/init.jl @@ -1,3 +1,9 @@ +const HINT_END = "\n\nThis hint appears because DifferentiationInterface and Enzyme are both loaded. It does not necessarily imply that Enzyme is being called through DifferentiationInterface.\n\n" + +function HINT_START(option) + return "\nIf you are using Enzyme by selecting the `AutoEnzyme` object from ADTypes, you may want to try setting the `$option` option as follows:" +end + function __init__() # robust against internal changes condition = ( @@ -14,29 +20,23 @@ function __init__() # robust against internal changes Base.Experimental.register_error_hint(T) do io, exc if occursin("EnzymeMutabilityException", string(nameof(T))) + printstyled(io, HINT_START("function_annotation"); bold=true) printstyled( io, - "\nIf you are using Enzyme through DifferentiationInterface, you may want to try modifying the ADTypes backend object as follows:"; - bold=true, - ) - printstyled( - io, - "\n\n\tAutoEnzyme(; function_annotation=Enzyme.Duplicated)\n\n"; + "\n\n\tAutoEnzyme(; function_annotation=Enzyme.Duplicated)"; color=:cyan, bold=true, ) + printstyled(io, HINT_END; italic=true) elseif occursin("EnzymeRuntimeActivityError", string(nameof(T))) + printstyled(io, HINT_START("mode"); bold=true) printstyled( io, - "\nIf you are using Enzyme through DifferentiationInterface, you may want to try modifying the ADTypes backend object as follows:"; - bold=true, - ) - printstyled( - io, - "\n\n\tAutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward))\n\tAutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))\n\n"; + "\n\n\tAutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward))\n\tAutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))"; color=:cyan, bold=true, ) + printstyled(io, HINT_END; italic=true) end end end From 9ceeb97a97caafee41a3a0bb6fbfbe2c91276013 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 10 May 2025 07:30:48 +0200 Subject: [PATCH 3/6] Add tests --- .../test/Back/Enzyme/hints.jl | 56 +++++++++++++++++++ .../test/Core/Internals/hints.jl | 15 +++++ 2 files changed, 71 insertions(+) create mode 100644 DifferentiationInterface/test/Back/Enzyme/hints.jl create mode 100644 DifferentiationInterface/test/Core/Internals/hints.jl diff --git a/DifferentiationInterface/test/Back/Enzyme/hints.jl b/DifferentiationInterface/test/Back/Enzyme/hints.jl new file mode 100644 index 000000000..a6d5a3742 --- /dev/null +++ b/DifferentiationInterface/test/Back/Enzyme/hints.jl @@ -0,0 +1,56 @@ +using DifferentiationInterface +using Enzyme: Enzyme +using Test + +@testset "MutabilityError" begin + f = let + cache = [0.0] + x -> sum(copyto!(cache, x)) + end + + msg = try + gradient(f, AutoEnzyme(), [1.0]) + catch e + buf = IOBuffer() + showerror(buf, e) + String(take!(buf)) + end + @test occursin("AutoEnzyme", msg) + @test occursin("function_annotation", msg) + @test occursin("ADTypes", msg) + @test occursin("DifferentiationInterface", msg) +end + +@testset "RuntimeActivityError" begin + function g(active_var, constant_var, cond) + if cond + return active_var + else + return constant_var + end + end + + function h(active_var, constant_var, cond) + return [g(active_var, constant_var, cond), g(active_var, constant_var, cond)] + end + + msg = try + pushforward( + h, + AutoEnzyme(; mode=Enzyme.Forward), + [1.0], + ([1.0],), + Constant([1.0]), + Constant(true), + ) + catch e + buf = IOBuffer() + showerror(buf, e) + String(take!(buf)) + end + @test occursin("AutoEnzyme", msg) + @test occursin("mode", msg) + @test occursin("set_runtime_activity", msg) + @test occursin("ADTypes", msg) + @test occursin("DifferentiationInterface", msg) +end diff --git a/DifferentiationInterface/test/Core/Internals/hints.jl b/DifferentiationInterface/test/Core/Internals/hints.jl new file mode 100644 index 000000000..754782730 --- /dev/null +++ b/DifferentiationInterface/test/Core/Internals/hints.jl @@ -0,0 +1,15 @@ +using ADTypes +using DifferentiationInterface +import DifferentiationInterface as DI +using Test + +@testset "Missing backend" begin + msg = try + gradient(sum, AutoZygote(), [1.0]) + catch e + buf = IOBuffer() + showerror(buf, e) + String(take!(buf)) + end + @test occursin("import Zygote", msg) +end From 8f226610e1e21943f7335dc6696351d8a8218b51 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 10 May 2025 07:34:00 +0200 Subject: [PATCH 4/6] Simpler show --- DifferentiationInterface/test/Back/Enzyme/hints.jl | 14 ++++++-------- .../test/Core/Internals/hints.jl | 7 +++---- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/DifferentiationInterface/test/Back/Enzyme/hints.jl b/DifferentiationInterface/test/Back/Enzyme/hints.jl index a6d5a3742..5a11c6b07 100644 --- a/DifferentiationInterface/test/Back/Enzyme/hints.jl +++ b/DifferentiationInterface/test/Back/Enzyme/hints.jl @@ -8,13 +8,12 @@ using Test x -> sum(copyto!(cache, x)) end - msg = try + e = nothing + try gradient(f, AutoEnzyme(), [1.0]) catch e - buf = IOBuffer() - showerror(buf, e) - String(take!(buf)) end + msg = sprint(showerror, e) @test occursin("AutoEnzyme", msg) @test occursin("function_annotation", msg) @test occursin("ADTypes", msg) @@ -34,7 +33,8 @@ end return [g(active_var, constant_var, cond), g(active_var, constant_var, cond)] end - msg = try + e = nothing + try pushforward( h, AutoEnzyme(; mode=Enzyme.Forward), @@ -44,10 +44,8 @@ end Constant(true), ) catch e - buf = IOBuffer() - showerror(buf, e) - String(take!(buf)) end + msg = sprint(showerror, e) @test occursin("AutoEnzyme", msg) @test occursin("mode", msg) @test occursin("set_runtime_activity", msg) diff --git a/DifferentiationInterface/test/Core/Internals/hints.jl b/DifferentiationInterface/test/Core/Internals/hints.jl index 754782730..83f43edaa 100644 --- a/DifferentiationInterface/test/Core/Internals/hints.jl +++ b/DifferentiationInterface/test/Core/Internals/hints.jl @@ -4,12 +4,11 @@ import DifferentiationInterface as DI using Test @testset "Missing backend" begin - msg = try + e = nothing + try gradient(sum, AutoZygote(), [1.0]) catch e - buf = IOBuffer() - showerror(buf, e) - String(take!(buf)) end + msg = sprint(showerror, e) @test occursin("import Zygote", msg) end From d10ab87b85a0ec789efb23e902cb865fe33b37aa Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 10 May 2025 07:35:04 +0200 Subject: [PATCH 5/6] Unreleased --- DifferentiationInterface/CHANGELOG.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index 9ec766c29..248a1ff8a 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -7,8 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -## [0.6.54] - 2025-05-07 - ### Added - Error hints for Enzyme ([#788]) From 20e8be795101b6f036725c2f4163fe4d70f575af Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 10 May 2025 08:17:50 +0200 Subject: [PATCH 6/6] Move --- .../test/Back/Enzyme/hints.jl | 54 ------------------- .../test/Back/Enzyme/test.jl | 53 ++++++++++++++++++ 2 files changed, 53 insertions(+), 54 deletions(-) delete mode 100644 DifferentiationInterface/test/Back/Enzyme/hints.jl diff --git a/DifferentiationInterface/test/Back/Enzyme/hints.jl b/DifferentiationInterface/test/Back/Enzyme/hints.jl deleted file mode 100644 index 5a11c6b07..000000000 --- a/DifferentiationInterface/test/Back/Enzyme/hints.jl +++ /dev/null @@ -1,54 +0,0 @@ -using DifferentiationInterface -using Enzyme: Enzyme -using Test - -@testset "MutabilityError" begin - f = let - cache = [0.0] - x -> sum(copyto!(cache, x)) - end - - e = nothing - try - gradient(f, AutoEnzyme(), [1.0]) - catch e - end - msg = sprint(showerror, e) - @test occursin("AutoEnzyme", msg) - @test occursin("function_annotation", msg) - @test occursin("ADTypes", msg) - @test occursin("DifferentiationInterface", msg) -end - -@testset "RuntimeActivityError" begin - function g(active_var, constant_var, cond) - if cond - return active_var - else - return constant_var - end - end - - function h(active_var, constant_var, cond) - return [g(active_var, constant_var, cond), g(active_var, constant_var, cond)] - end - - e = nothing - try - pushforward( - h, - AutoEnzyme(; mode=Enzyme.Forward), - [1.0], - ([1.0],), - Constant([1.0]), - Constant(true), - ) - catch e - end - msg = sprint(showerror, e) - @test occursin("AutoEnzyme", msg) - @test occursin("mode", msg) - @test occursin("set_runtime_activity", msg) - @test occursin("ADTypes", msg) - @test occursin("DifferentiationInterface", msg) -end diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index a772e48ba..7e1b02451 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -148,3 +148,56 @@ end f_nocontext, AutoEnzyme(; mode=Enzyme.Reverse), rand(10), ConstantOrCache(nothing) ) end + +@testset "Hints" begin + @testset "MutabilityError" begin + f = let + cache = [0.0] + x -> sum(copyto!(cache, x)) + end + + e = nothing + try + gradient(f, AutoEnzyme(), [1.0]) + catch e + end + msg = sprint(showerror, e) + @test occursin("AutoEnzyme", msg) + @test occursin("function_annotation", msg) + @test occursin("ADTypes", msg) + @test occursin("DifferentiationInterface", msg) + end + + @testset "RuntimeActivityError" begin + function g(active_var, constant_var, cond) + if cond + return active_var + else + return constant_var + end + end + + function h(active_var, constant_var, cond) + return [g(active_var, constant_var, cond), g(active_var, constant_var, cond)] + end + + e = nothing + try + pushforward( + h, + AutoEnzyme(; mode=Enzyme.Forward), + [1.0], + ([1.0],), + Constant([1.0]), + Constant(true), + ) + catch e + end + msg = sprint(showerror, e) + @test occursin("AutoEnzyme", msg) + @test occursin("mode", msg) + @test occursin("set_runtime_activity", msg) + @test occursin("ADTypes", msg) + @test occursin("DifferentiationInterface", msg) + end +end