|
1 | | -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith, <:Any} |
| 1 | +const NumberOrArray = Union{Number, AbstractArray{<:Number}} |
| 2 | +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, <:NumberOrArray} |
| 3 | +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, <:NumberOrArray, <:NumberOrArray} |
| 4 | +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray} |
| 5 | +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray} |
| 6 | +# TODO: generate more cases programmatically |
2 | 7 |
|
3 | 8 | struct MooncakeDifferentiateWithError <: Exception |
4 | 9 | F::Type |
|
12 | 17 | function Base.showerror(io::IO, e::MooncakeDifferentiateWithError) |
13 | 18 | return print( |
14 | 19 | io, |
15 | | - "MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported.", |
| 20 | + "MooncakeDifferentiateWithError: For the function type $(e.F) and argument types $(e.X), the output type $(e.Y) is currently not supported.", |
16 | 21 | ) |
17 | 22 | end |
18 | 23 |
|
19 | | -function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) |
| 24 | +function Mooncake.rrule!!( |
| 25 | + dw::CoDual{<:DI.DifferentiateWith{C}}, |
| 26 | + x::CoDual{<:Number}, |
| 27 | + contexts::Vararg{CoDual, C} |
| 28 | + ) where {C} |
| 29 | + @assert tangent_type(typeof(dw)) == NoTangent |
20 | 30 | primal_func = primal(dw) |
21 | 31 | primal_x = primal(x) |
22 | | - (; f, backend) = primal_func |
23 | | - y = zero_fcodual(f(primal_x)) |
| 32 | + primal_contexts = map(primal, contexts) |
| 33 | + (; f, backend, context_wrappers) = primal_func |
| 34 | + y = zero_fcodual(f(primal_x, primal_contexts...)) |
| 35 | + wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts) |
24 | 36 |
|
25 | 37 | # output is a vector, so we need to use the vector pullback |
26 | 38 | function pullback_array!!(dy::NoRData) |
27 | | - tx = DI.pullback(f, backend, primal_x, (y.dx,)) |
28 | | - @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x))) |
29 | | - return NoRData(), rdata(only(tx)) |
| 39 | + dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only |
| 40 | + @assert rdata(only(dx)) isa rdata_type(tangent_type(typeof(primal_x))) |
| 41 | + rc = nanify_fdata_and_rdata!!(contexts...) |
| 42 | + return (NoRData(), rdata(dx), rc...) |
30 | 43 | end |
31 | 44 |
|
32 | 45 | # output is a scalar, so we can use the scalar pullback |
33 | 46 | function pullback_scalar!!(dy::Number) |
34 | | - tx = DI.pullback(f, backend, primal_x, (dy,)) |
35 | | - @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x))) |
36 | | - return NoRData(), rdata(only(tx)) |
| 47 | + dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only |
| 48 | + @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x))) |
| 49 | + rc = nanify_fdata_and_rdata!!(contexts...) |
| 50 | + return (NoRData(), rdata(dx), rc...) |
37 | 51 | end |
38 | 52 |
|
39 | 53 | pullback = if primal(y) isa Number |
40 | 54 | pullback_scalar!! |
41 | 55 | elseif primal(y) isa AbstractArray |
42 | 56 | pullback_array!! |
43 | 57 | else |
44 | | - throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y))) |
| 58 | + throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y))) |
45 | 59 | end |
46 | 60 |
|
47 | 61 | return y, pullback |
48 | 62 | end |
49 | 63 |
|
50 | 64 | function Mooncake.rrule!!( |
51 | | - dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}} |
52 | | - ) |
| 65 | + dw::CoDual{<:DI.DifferentiateWith{C}}, |
| 66 | + x::CoDual{<:AbstractArray{<:Number}}, |
| 67 | + contexts::Vararg{CoDual, C} |
| 68 | + ) where {C} |
| 69 | + @assert tangent_type(typeof(dw)) == NoTangent |
53 | 70 | primal_func = primal(dw) |
54 | 71 | primal_x = primal(x) |
55 | | - fdata_arg = x.dx |
56 | | - (; f, backend) = primal_func |
57 | | - y = zero_fcodual(f(primal_x)) |
| 72 | + primal_contexts = map(primal, contexts) |
| 73 | + (; f, backend, context_wrappers) = primal_func |
| 74 | + y = zero_fcodual(f(primal_x, primal_contexts...)) |
| 75 | + wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts) |
58 | 76 |
|
59 | 77 | # output is a vector, so we need to use the vector pullback |
60 | 78 | function pullback_array!!(dy::NoRData) |
61 | | - tx = DI.pullback(f, backend, primal_x, (y.dx,)) |
62 | | - @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x)))) |
63 | | - fdata_arg .+= only(tx) |
64 | | - return NoRData(), dy |
| 79 | + dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only |
| 80 | + @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x))) |
| 81 | + x.dx .+= dx |
| 82 | + rc = nanify_fdata_and_rdata!!(contexts...) |
| 83 | + return (NoRData(), dy, rc...) |
65 | 84 | end |
66 | 85 |
|
67 | 86 | # output is a scalar, so we can use the scalar pullback |
68 | 87 | function pullback_scalar!!(dy::Number) |
69 | | - tx = DI.pullback(f, backend, primal_x, (dy,)) |
70 | | - @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x)))) |
71 | | - fdata_arg .+= only(tx) |
72 | | - return NoRData(), NoRData() |
| 88 | + dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only |
| 89 | + @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x))) |
| 90 | + x.dx .+= dx |
| 91 | + rc = nanify_fdata_and_rdata!!(contexts...) |
| 92 | + return (NoRData(), NoRData(), rc...) |
73 | 93 | end |
74 | 94 |
|
75 | 95 | pullback = if primal(y) isa Number |
76 | 96 | pullback_scalar!! |
77 | 97 | elseif primal(y) isa AbstractArray |
78 | 98 | pullback_array!! |
79 | 99 | else |
80 | | - throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y))) |
| 100 | + throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y))) |
81 | 101 | end |
82 | 102 |
|
83 | 103 | return y, pullback |
|
0 commit comments