|
1 | 1 | const NumberOrArray = Union{Number, AbstractArray{<:Number}} |
2 | | -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, <:NumberOrArray} |
3 | | -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, <:NumberOrArray, <:NumberOrArray} |
4 | | -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray} |
5 | | -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray} |
| 2 | +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, Any} |
| 3 | +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, Any, Any} |
| 4 | +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, Any, Any, Any} |
| 5 | +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, Any, Any, Any, Any} |
6 | 6 | # TODO: generate more cases programmatically |
7 | 7 |
|
8 | 8 | struct MooncakeDifferentiateWithError <: Exception |
|
17 | 17 | function Base.showerror(io::IO, e::MooncakeDifferentiateWithError) |
18 | 18 | return print( |
19 | 19 | io, |
20 | | - "MooncakeDifferentiateWithError: For the function type $(e.F) and argument types $(e.X), the output type $(e.Y) is currently not supported.", |
| 20 | + "MooncakeDifferentiateWithError: For the function type `$(e.F)` and input types `$(e.X)`, the output type `$(e.Y)` is currently not supported.", |
21 | 21 | ) |
22 | 22 | end |
23 | 23 |
|
24 | 24 | function Mooncake.rrule!!( |
25 | 25 | dw::CoDual{<:DI.DifferentiateWith{C}}, |
26 | 26 | x::CoDual{<:Number}, |
27 | | - contexts::Vararg{CoDual, C} |
| 27 | + contexts::Vararg{CoDual{<:NumberOrArray}, C} |
28 | 28 | ) where {C} |
29 | 29 | @assert tangent_type(typeof(dw)) == NoTangent |
30 | 30 | primal_func = primal(dw) |
|
64 | 64 | function Mooncake.rrule!!( |
65 | 65 | dw::CoDual{<:DI.DifferentiateWith{C}}, |
66 | 66 | x::CoDual{<:AbstractArray{<:Number}}, |
67 | | - contexts::Vararg{CoDual, C} |
| 67 | + contexts::Vararg{CoDual{<:NumberOrArray}, C} |
68 | 68 | ) where {C} |
69 | 69 | @assert tangent_type(typeof(dw)) == NoTangent |
70 | 70 | primal_func = primal(dw) |
|
0 commit comments