Skip to content

Commit b19bcde

Browse files
committed
Fix errors
1 parent bad6a01 commit b19bcde

2 files changed

Lines changed: 9 additions & 9 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
const NumberOrArray = Union{Number, AbstractArray{<:Number}}
2-
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, <:NumberOrArray}
3-
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, <:NumberOrArray, <:NumberOrArray}
4-
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray}
5-
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray}
2+
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, Any}
3+
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, Any, Any}
4+
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, Any, Any, Any}
5+
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, Any, Any, Any, Any}
66
# TODO: generate more cases programmatically
77

88
struct MooncakeDifferentiateWithError <: Exception
@@ -17,14 +17,14 @@ end
1717
function Base.showerror(io::IO, e::MooncakeDifferentiateWithError)
1818
return print(
1919
io,
20-
"MooncakeDifferentiateWithError: For the function type $(e.F) and argument types $(e.X), the output type $(e.Y) is currently not supported.",
20+
"MooncakeDifferentiateWithError: For the function type `$(e.F)` and input types `$(e.X)`, the output type `$(e.Y)` is currently not supported.",
2121
)
2222
end
2323

2424
function Mooncake.rrule!!(
2525
dw::CoDual{<:DI.DifferentiateWith{C}},
2626
x::CoDual{<:Number},
27-
contexts::Vararg{CoDual, C}
27+
contexts::Vararg{CoDual{<:NumberOrArray}, C}
2828
) where {C}
2929
@assert tangent_type(typeof(dw)) == NoTangent
3030
primal_func = primal(dw)
@@ -64,7 +64,7 @@ end
6464
function Mooncake.rrule!!(
6565
dw::CoDual{<:DI.DifferentiateWith{C}},
6666
x::CoDual{<:AbstractArray{<:Number}},
67-
contexts::Vararg{CoDual, C}
67+
contexts::Vararg{CoDual{<:NumberOrArray}, C}
6868
) where {C}
6969
@assert tangent_type(typeof(dw)) == NoTangent
7070
primal_func = primal(dw)

DifferentiationInterface/test/Back/DifferentiateWith/test.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ end;
9090
MooncakeDifferentiateWithError =
9191
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError
9292

93-
e = MooncakeDifferentiateWithError(identity, 1.0, 2.0)
93+
e = MooncakeDifferentiateWithError(identity, (1.0,), 2.0)
9494
@test sprint(showerror, e) ==
95-
"MooncakeDifferentiateWithError: For the function type typeof(identity) and input types (Float64,), the output type Float64 is currently not supported."
95+
"MooncakeDifferentiateWithError: For the function type `typeof(identity)` and input types `Tuple{Float64}`, the output type `Float64` is currently not supported."
9696

9797
f_num2tup(x::Number) = (x,)
9898
f_vec2tup(x::Vector) = (first(x),)

0 commit comments

Comments
 (0)