feat: test type consistency between preparation and execution #745

Merged
gdalle merged 22 commits into main from gd/Strict on Mar 17, 2025

gdalle (Member) commented Mar 16, 2025

Core

  • Add a keyword argument strict to every prepare_operator function, with default value strict=Val(false). This kwarg is then passed as a first argument to a hidden method of prepare_operator, in order to reduce the depth of the call stack and help the compiler infer types properly.
  • When strict=Val(true), store the signature SIG = typeof((f, backend, x, contexts)) inside the type of the preparation result. Otherwise store SIG = Nothing.
  • Add a check inside each differentiation operator: whenever SIG !== Nothing, the signature recorded during preparation must match the signature seen during execution.
  • Augment every preparation type to keep track of SIG.
  • Adjust preparation modification and same-point preparation.
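
A minimal sketch of the mechanism described in this list, using only Base Julia. Every name here (ToyPrep, toy_prepare, _toy_prepare, check_prep) is hypothetical and chosen for illustration; it is not DI's actual code, and the position of SIG among the type parameters is an assumption.

```julia
# Toy model of strict preparation. All names are hypothetical, not DI's API.

# Preparation result that carries the signature type SIG as a type parameter.
struct ToyPrep{C,SIG}
    cache::C
end

# Hidden positional methods: Val(true) records the signature, Val(false) stores Nothing.
function _toy_prepare(::Val{true}, f, backend, x, contexts...)
    SIG = typeof((f, backend, x, contexts))
    cache = zero(x)
    return ToyPrep{typeof(cache),SIG}(cache)
end
function _toy_prepare(::Val{false}, f, backend, x, contexts...)
    cache = zero(x)
    return ToyPrep{typeof(cache),Nothing}(cache)
end

# Public entry point: the keyword is forwarded as the first positional argument.
toy_prepare(f, backend, x, contexts...; strict::Val=Val(false)) =
    _toy_prepare(strict, f, backend, x, contexts...)

# Consistency check run by each operator: a no-op when SIG is Nothing.
check_prep(::ToyPrep{C,Nothing}, f, backend, x, contexts...) where {C} = nothing
function check_prep(::ToyPrep{C,SIG}, f, backend, x, contexts...) where {C,SIG}
    got = typeof((f, backend, x, contexts))
    got === SIG || throw(ArgumentError("prepared for $SIG, executed with $got"))
    return nothing
end

# Usage with a stand-in backend (a Symbol, purely for illustration).
g(x, c) = sum(abs2, x) * c
prep = toy_prepare(g, :toy_backend, [1.0, 2.0], 3; strict=Val(true))
check_prep(prep, g, :toy_backend, [5.0, 6.0], 3)       # same types: passes
# check_prep(prep, g, :toy_backend, Float32[5, 6], 3)  # different x type: would throw
```

Because the check compares types only, different values of the same types pass it.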

Tests and docs

  • Add tests for the correct behavior inside DifferentiationInterfaceTest.
  • Add tests for the errors inside DI's tests.
  • Update docstrings to point to the strict keyword.

Other modifications

  • Logic changed:
    • No more direct fallbacks between backends; intermediate structs are now necessary everywhere (e.g. no jumping from PolyesterForwardDiff to ForwardDiff)
    • No more direct fallback from batched pushforward/pullback to individual ones (e.g. in Mooncake)
  • Bugs fixed:
    • Incorrect preparation for pullback-based gradient fallback
    • Incorrect preparation in GTPSA

Warning

Setting strict=Val(true) is recommended from now on, but can break existing code which happens to work even though it uses preparation incorrectly. That is why the default value will remain strict=Val(false) until the next breaking release.
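
To make "happens to work even though it uses preparation incorrectly" concrete, here is a small Base-only sketch. The names BufferPrep, sig_type, prepare_double and double! are hypothetical and do not come from DI; the point is only that a non-strict preparation result can be silently reused with an input of a different type.

```julia
# Hypothetical toy, not DI's code: a preparation result that owns a Float64 buffer.
struct BufferPrep{SIG}
    buffer::Vector{Float64}
end

# Record the input signature only in strict mode.
sig_type(::Val{true}, x) = typeof((x,))
sig_type(::Val{false}, x) = Nothing

prepare_double(x; strict::Val=Val(false)) =
    BufferPrep{sig_type(strict, x)}(zeros(length(x)))

function double!(prep::BufferPrep{SIG}, x) where {SIG}
    # Skip the check when nothing was recorded; otherwise require identical types.
    SIG === Nothing || typeof((x,)) === SIG ||
        throw(ArgumentError("preparation was done for $SIG, got $(typeof((x,)))"))
    prep.buffer .= 2 .* x
    return prep.buffer
end

x64 = [1.0, 2.0]
x32 = Float32[1, 2]

loose = prepare_double(x64)                    # default strict=Val(false)
double!(loose, x32)                            # runs, but the result is a Vector{Float64}

tight = prepare_double(x64; strict=Val(true))
double!(tight, x64)                            # types match the preparation
# double!(tight, x32) would throw an ArgumentError
```

In this toy the mismatch is harmless; with a real backend the reused cache may be sized or typed for the wrong input, which is the kind of misuse the strict check is meant to surface.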


Lessons learned:

  • Add type parameters to the back, not the front
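
One plausible reading of this lesson, sketched with hypothetical types (BackPrep, FrontPrep and cache_type are not DI names): code that already dispatches on the first type parameter of a preparation type keeps working when a new parameter is appended, but silently binds the wrong parameter when one is prepended.

```julia
# Hypothetical prep types: the new SIG parameter appended vs. prepended.
struct BackPrep{C,SIG}
    cache::C
end
struct FrontPrep{SIG,C}
    cache::C
end

# Methods written when the type had a single parameter C.
cache_type(::BackPrep{C}) where {C} = C    # still binds C to the cache type
cache_type(::FrontPrep{C}) where {C} = C   # now binds C to SIG instead

back = BackPrep{Vector{Float64},Nothing}([0.0])
front = FrontPrep{Nothing,Vector{Float64}}([0.0])

cache_type(back)    # the cache type, as the old method intended
cache_type(front)   # the signature parameter, which is not what the method meant
```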

codecov Bot commented Mar 16, 2025

Codecov Report

Attention: Patch coverage is 38.69694% with 621 lines in your changes missing coverage. Please review.

Project coverage is 41.27%. Comparing base (bac2d02) to head (594d720).

Files with missing lines                                 Patch %   Lines
...tiationInterfaceTest/src/tests/correctness_eval.jl      6.66%   56 Missing ⚠️
...ntiationInterfacePolyesterForwardDiffExt/onearg.jl      0.00%   54 Missing ⚠️
...entiationInterfaceFastDifferentiationExt/onearg.jl      0.00%   47 Missing ⚠️
...ext/DifferentiationInterfaceSymbolicsExt/onearg.jl      0.00%   45 Missing ⚠️
...t/DifferentiationInterfaceReverseDiffExt/onearg.jl     36.53%   33 Missing ⚠️
...ifferentiationInterface/src/misc/from_primitive.jl      0.00%   26 Missing ⚠️
...xt/DifferentiationInterfaceFiniteDiffExt/onearg.jl     21.87%   25 Missing ⚠️
...ntiationInterfacePolyesterForwardDiffExt/twoarg.jl      0.00%   25 Missing ⚠️
...entiationInterfaceFastDifferentiationExt/twoarg.jl      0.00%   24 Missing ⚠️
...xt/DifferentiationInterfaceFiniteDiffExt/twoarg.jl      0.00%   23 Missing ⚠️
... and 26 more

❗ There is a different number of reports uploaded between BASE (bac2d02) and HEAD (594d720).

HEAD has 78 uploads less than BASE
Flag   BASE (bac2d02)   HEAD (594d720)
DI     74               9
DIT    16               3
@@             Coverage Diff             @@
##             main     #745       +/-   ##
===========================================
- Coverage   97.19%   41.27%   -55.93%     
===========================================
  Files         128      118       -10     
  Lines        6678     7208      +530     
===========================================
- Hits         6491     2975     -3516     
- Misses        187     4233     +4046     
Flag   Coverage Δ
DI     36.14% <40.71%> (-61.70%) ⬇️
DIT    54.00% <6.66%> (-41.74%) ⬇️


@gdalle gdalle changed the title from "feat!: enforce consistency between preparation result and signature" to "feat: enforce consistency between preparation result and signature" Mar 16, 2025
@gdalle gdalle changed the title from "feat: enforce consistency between preparation result and signature" to "feat: test consistency between preparation result and signature" Mar 16, 2025
@gdalle gdalle changed the title from "feat: test consistency between preparation result and signature" to "feat: test type consistency between preparation and execution" Mar 17, 2025
@gdalle gdalle marked this pull request as ready for review March 17, 2025 10:18
gdalle (Member Author) commented Mar 17, 2025

MWE for type instability, to run on Julia 1.10 (doesn't appear on 1.11) with the DI version from this branch:

using DifferentiationInterface, JET
f(x, c) = exp(x) * c
backend = DifferentiationInterface.AutoZeroReverse()
@report_opt prepare_second_derivative(f, backend, 1.0, Constant(1))

JET report:

═════ 9 possible errors found ═════
┌ second_derivative(f::typeof(f), backend::DifferentiationInterface.AutoZeroReverse, x::Float64, contexts::Constant{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/fallbacks/no_prep.jl:101
│┌ kwcall(::@NamedTuple{}, ::typeof(prepare_second_derivative), f::typeof(f), backend::DifferentiationInterface.AutoZeroReverse, x::Float64, contexts::Constant{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/second_order/second_derivative.jl:60
││┌ prepare_second_derivative(f::typeof(f), backend::DifferentiationInterface.AutoZeroReverse, x::Float64, contexts::Constant{…}; strict::Val{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/second_order/second_derivative.jl:68
│││┌ kwcall(::@NamedTuple{}, ::typeof(prepare_derivative), ::typeof(DifferentiationInterface.shuffled_derivative), ::DifferentiationInterface.AutoZeroReverse, ::Float64, ::DifferentiationInterface.FunctionContext{…}, ::DifferentiationInterface.BackendContext{…}, ::Constant{…}, ::Constant{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/derivative.jl:66
││││┌ prepare_derivative(::typeof(DifferentiationInterface.shuffled_derivative), ::DifferentiationInterface.AutoZeroReverse, ::Float64, ::DifferentiationInterface.FunctionContext{…}, ::DifferentiationInterface.BackendContext{…}, ::Constant{…}, ::Constant{…}; strict::Val{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/derivative.jl:70
│││││┌ kwcall(::@NamedTuple{}, ::typeof(prepare_pushforward), ::typeof(DifferentiationInterface.shuffled_derivative), ::DifferentiationInterface.AutoZeroReverse, ::Float64, ::Tuple{…}, ::DifferentiationInterface.FunctionContext{…}, ::DifferentiationInterface.BackendContext{…}, ::Constant{…}, ::Constant{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/pushforward.jl:93
││││││┌ prepare_pushforward(::typeof(DifferentiationInterface.shuffled_derivative), ::DifferentiationInterface.AutoZeroReverse, ::Float64, ::Tuple{…}, ::DifferentiationInterface.FunctionContext{…}, ::DifferentiationInterface.BackendContext{…}, ::Constant{…}, ::Constant{…}; strict::Val{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/pushforward.jl:101
│││││││┌ kwcall(::@NamedTuple{}, ::typeof(DifferentiationInterface._prepare_pushforward_aux), ::DifferentiationInterface.PushforwardSlow, ::typeof(DifferentiationInterface.shuffled_derivative), ::DifferentiationInterface.AutoZeroReverse, ::Float64, ::Tuple{…}, ::DifferentiationInterface.FunctionContext{…}, ::DifferentiationInterface.BackendContext{…}, ::Constant{…}, ::Constant{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/pushforward.jl:120
││││││││┌ _prepare_pushforward_aux(::DifferentiationInterface.PushforwardSlow, ::typeof(DifferentiationInterface.shuffled_derivative), ::DifferentiationInterface.AutoZeroReverse, ::Float64, ::Tuple{…}, ::DifferentiationInterface.FunctionContext{…}, ::DifferentiationInterface.BackendContext{…}, ::Constant{…}, ::Constant{…}; strict::Val{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/pushforward.jl:130
│││││││││┌ shuffled_derivative(x::Float64, f::typeof(f), backend::DifferentiationInterface.AutoZeroReverse, rewrap::DifferentiationInterface.Rewrap{…}, unannotated_contexts::Int64) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/derivative.jl:210
││││││││││┌ derivative(f::typeof(f), backend::DifferentiationInterface.AutoZeroReverse, x::Float64, contexts::Constant{Int64}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/fallbacks/no_prep.jl:48
│││││││││││┌ kwcall(::@NamedTuple{}, ::typeof(prepare_derivative), f::typeof(f), backend::DifferentiationInterface.AutoZeroReverse, x::Float64, contexts::Constant{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/derivative.jl:66
││││││││││││┌ prepare_derivative(f::typeof(f), backend::DifferentiationInterface.AutoZeroReverse, x::Float64, contexts::Constant{…}; strict::Val{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/derivative.jl:70
│││││││││││││┌ kwcall(::NamedTuple, ::typeof(prepare_pushforward), f::typeof(f), backend::DifferentiationInterface.AutoZeroReverse, x::Float64, tx::Tuple{…}, contexts::Constant{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/pushforward.jl:93
││││││││││││││┌ pairs(nt::NamedTuple) @ Base.Iterators ./iterators.jl:279
│││││││││││││││┌ (Base.Pairs{Symbol})(data::NamedTuple, itr::Tuple{Vararg{Symbol}}) @ Base ./essentials.jl:343
││││││││││││││││┌ eltype(::Type{A} where A<:NamedTuple) @ Base ./namedtuple.jl:233
│││││││││││││││││┌ nteltype(::Type{NamedTuple{names, T}} where names) where T<:Tuple @ Base ./namedtuple.jl:235
││││││││││││││││││┌ eltype(t::Type{<:Tuple{Vararg{E}}}) where E @ Base ./tuple.jl:208
│││││││││││││││││││┌ _compute_eltype(t::Type{<:Tuple{Vararg{E}}} where E) @ Base ./tuple.jl:231
││││││││││││││││││││┌ afoldl(op::Base.var"#54#55", a::Any, bs::Vararg{Any}) @ Base ./operators.jl:544
│││││││││││││││││││││┌ (::Base.var"#54#55")(a::Any, b::Any) @ Base ./tuple.jl:235
││││││││││││││││││││││┌ promote_typejoin(a::Any, b::Any) @ Base ./promotion.jl:172
│││││││││││││││││││││││┌ typejoin(a::Any, b::Any) @ Base ./promotion.jl:127
││││││││││││││││││││││││ runtime dispatch detected: Base.UnionAll(%403::Any, %405::Any)::Any
│││││││││││││││││││││││└────────────────────
││││││││││││││││││││┌ afoldl(op::Base.var"#54#55", a::Any, bs::Vararg{Any}) @ Base ./operators.jl:545
│││││││││││││││││││││┌ (::Base.var"#54#55")(a::Type, b::Any) @ Base ./tuple.jl:235
││││││││││││││││││││││┌ promote_typejoin(a::Type, b::Any) @ Base ./promotion.jl:172
│││││││││││││││││││││││┌ typejoin(a::Type, b::Any) @ Base ./promotion.jl:127
││││││││││││││││││││││││ runtime dispatch detected: Base.UnionAll(%398::Any, %400::Any)::Any
│││││││││││││││││││││││└────────────────────
││││││││││││┌ prepare_derivative(f::typeof(f), backend::DifferentiationInterface.AutoZeroReverse, x::Float64, contexts::Constant{…}; strict::Val{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/derivative.jl:66
│││││││││││││ failed to optimize due to recursion: DifferentiationInterface.var"#prepare_derivative#42"(::Val{…}, ::typeof(prepare_derivative), ::typeof(f), ::DifferentiationInterface.AutoZeroReverse, ::Float64, ::Constant{…})
││││││││││││└────────────────────
│││││││││││┌ kwcall(::@NamedTuple{}, ::typeof(prepare_derivative), f::typeof(f), backend::DifferentiationInterface.AutoZeroReverse, x::Float64, contexts::Constant{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/derivative.jl:66
││││││││││││ failed to optimize due to recursion: Core.kwcall(::@NamedTuple{}, ::typeof(prepare_derivative), ::typeof(f), ::DifferentiationInterface.AutoZeroReverse, ::Float64, ::Constant{…})
│││││││││││└────────────────────
││││││││││┌ derivative(f::typeof(f), backend::DifferentiationInterface.AutoZeroReverse, x::Float64, contexts::Constant{Int64}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/fallbacks/no_prep.jl:49
│││││││││││┌ derivative(f::typeof(f), prep::DifferentiationInterface.PushforwardDerivativePrep{…} where E<:DifferentiationInterface.PullbackPushforwardPrep, backend::DifferentiationInterface.AutoZeroReverse, x::Float64, contexts::Constant{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/derivative.jl:128
││││││││││││┌ pushforward(f::typeof(f), prep::DifferentiationInterface.PullbackPushforwardPrep, backend::DifferentiationInterface.AutoZeroReverse, x::Float64, tx::Tuple{…}, contexts::Constant{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/pushforward.jl:258
│││││││││││││┌ ntuple(f::DifferentiationInterface.var"#34#35"{}, ::Val{…}) @ Base ./ntuple.jl:48
││││││││││││││ runtime dispatch detected: f::DifferentiationInterface.var"#34#35"{typeof(f), DifferentiationInterface.AutoZeroReverse, Float64, Tuple{}, Tuple{}}(1)::Float64
│││││││││││││└────────────────────
│││││││││││││┌ ntuple(f::DifferentiationInterface.var"#18#19"{}, ::Val{…}) @ Base ./ntuple.jl:48
││││││││││││││ runtime dispatch detected: f::DifferentiationInterface.var"#18#19"{}(1)::Float64
│││││││││││││└────────────────────
││││││││││┌ derivative(f::typeof(f), backend::DifferentiationInterface.AutoZeroReverse, x::Float64, contexts::Constant{Int64}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/fallbacks/no_prep.jl:45
│││││││││││ failed to optimize due to recursion: DifferentiationInterface.derivative(::typeof(f), ::DifferentiationInterface.AutoZeroReverse, ::Float64, ::Constant{…})
││││││││││└────────────────────
│││││││││┌ shuffled_derivative(x::Float64, f::typeof(f), backend::DifferentiationInterface.AutoZeroReverse, rewrap::DifferentiationInterface.Rewrap{…}, unannotated_contexts::Int64) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/derivative.jl:207
││││││││││ failed to optimize due to recursion: DifferentiationInterface.shuffled_derivative(::Float64, ::typeof(f), ::DifferentiationInterface.AutoZeroReverse, ::DifferentiationInterface.Rewrap{…}, ::Int64)
│││││││││└────────────────────
││││││││┌ _prepare_pushforward_aux(::DifferentiationInterface.PushforwardSlow, ::typeof(DifferentiationInterface.shuffled_derivative), ::DifferentiationInterface.AutoZeroReverse, ::Float64, ::Tuple{…}, ::DifferentiationInterface.FunctionContext{…}, ::DifferentiationInterface.BackendContext{…}, ::Constant{…}, ::Constant{…}; strict::Val{…}) @ DifferentiationInterface /Users/guillaumedalle/Documents/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/src/first_order/pushforward.jl:120
│││││││││ failed to optimize due to recursion: DifferentiationInterface.var"#_prepare_pushforward_aux#12"(::Val{…}, ::typeof(DifferentiationInterface._prepare_pushforward_aux), ::DifferentiationInterface.PushforwardSlow, ::typeof(DifferentiationInterface.shuffled_derivative), ::DifferentiationInterface.AutoZeroReverse, ::Float64, ::Tuple{…}, ::DifferentiationInterface.FunctionContext{…}, ::DifferentiationInterface.BackendContext{…}, ::Constant{…}, ::Constant{…})

@gdalle gdalle marked this pull request as draft March 17, 2025 14:17
@gdalle gdalle marked this pull request as ready for review March 17, 2025 14:50
@gdalle gdalle merged commit 435dea6 into main Mar 17, 2025
47 of 48 checks passed
franckgaga (Contributor) commented Mar 18, 2025

@gdalle, in your opinion, should I use strict=Val(true) in MPC.jl?

It would be non-breaking for me, since the version I release will be the first one that uses DI.jl.

gdalle (Member Author) commented Mar 18, 2025

Definitely, go for it. I just made it optional because some criminals over in SciML misuse my API, but technically it wouldn't even have been breaking to make it the default right away.

@gdalle gdalle deleted the gd/Strict branch May 12, 2025 07:56
2 participants