Skip to content

Commit d5d5b31

Browse files
gdalleCopilot
andauthored
Apply suggestions from code review
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent b19bcde commit d5d5b31

3 files changed

Lines changed: 34 additions & 8 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
const NumberOrArray = Union{Number, AbstractArray{<:Number}}
2-
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, Any}
3-
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, Any, Any}
4-
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, Any, Any, Any}
5-
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, Any, Any, Any, Any}
6-
# TODO: generate more cases programmatically
72

3+
# Mark DifferentiateWith with a range of context arities as primitives.
4+
# For C contexts, the corresponding call tuple type is
5+
# Tuple{DI.DifferentiateWith{C}, Any, Vararg{Any, C}}:
6+
# one slot for the primal input x and C slots for contexts.
7+
for C in 0:16
8+
@eval @is_primitive MinimalCtx Tuple{DI.DifferentiateWith{$C}, Vararg{Any, $(C + 1)}}
9+
end
810
struct MooncakeDifferentiateWithError <: Exception
911
F::Type
1012
X::Type
@@ -37,7 +39,7 @@ function Mooncake.rrule!!(
3739
# output is a vector, so we need to use the vector pullback
3840
function pullback_array!!(dy::NoRData)
3941
dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only
40-
@assert rdata(only(dx)) isa rdata_type(tangent_type(typeof(primal_x)))
42+
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
4143
rc = nanify_fdata_and_rdata!!(contexts...)
4244
return (NoRData(), rdata(dx), rc...)
4345
end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
1919
end
2020

2121
nanify(x::AbstractFloat) = convert(typeof(x), NaN)
22-
nanify(x::AbstractArray) = map(nan_tangent, x)
23-
nanify(x::Union{Tuple, NamedTuple}) = map(nan_tangent, x)
22+
nanify(x::AbstractArray) = map(nanify, x)
23+
nanify(x::NamedTuple) = NamedTuple{keys(x)}(map(nanify, values(x)))
24+
nanify(x::Tuple) = map(nanify, x)
2425
nanify(::NoFData) = NoFData()
2526
nanify(::NoRData) = NoRData()
2627

DifferentiationInterface/src/misc/differentiate_with.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,29 @@ struct DifferentiateWith{C, F, B <: AbstractADType, N <: NTuple{C, Any}}
7676
context_wrappers::N
7777
end
7878

79+
function DifferentiateWith(
80+
f::F,
81+
backend::B,
82+
context_wrappers::NTuple{C, Any},
83+
) where {F, B <: AbstractADType, C}
84+
for (i, wrapper) in pairs(context_wrappers)
85+
# Accept typical constructor-like values: functions or types.
86+
if !(wrapper isa Function || wrapper isa Type)
87+
throw(
88+
ArgumentError(
89+
"Each context wrapper must be a callable object or type " *
90+
"(e.g., a wrapper constructor like `Constant` or `Cache`), " *
91+
"but element $i has type $(typeof(wrapper)).",
92+
),
93+
)
94+
end
95+
end
96+
return DifferentiateWith{C, F, B, typeof(context_wrappers)}(
97+
f,
98+
backend,
99+
context_wrappers,
100+
)
101+
end
79102
DifferentiateWith(f::F, backend::AbstractADType) where {F} = DifferentiateWith(f, backend, ())
80103

81104
function (dw::DifferentiateWith{C})(x, args::Vararg{Any, C}) where {C}

0 commit comments

Comments
 (0)