|
| 1 | +struct MooncakeOneArgPullbackPrep{Y,R} <: PullbackPrep |
| 2 | + y_prototype::Y |
| 3 | + rrule::R |
| 4 | +end |
| 5 | + |
| 6 | +function DI.prepare_pullback( |
| 7 | + f, backend::AutoMooncake, x, ty::DI.Tangents, contexts::Vararg{Context,C} |
| 8 | +) where {C} |
| 9 | + y = f(x, map(unwrap, contexts)...) |
| 10 | + config = get_config(backend) |
| 11 | + rrule = build_rrule( |
| 12 | + get_interpreter(), |
| 13 | + Tuple{typeof(f),typeof(x),typeof.(map(unwrap, contexts))...}; |
| 14 | + debug_mode=config.debug_mode, |
| 15 | + silence_debug_messages=config.silence_debug_messages, |
| 16 | + ) |
| 17 | + prep = MooncakeOneArgPullbackPrep(y, rrule) |
| 18 | + DI.value_and_pullback(f, prep, backend, x, ty, contexts...) # warm up |
| 19 | + return prep |
| 20 | +end |
| 21 | + |
| 22 | +function DI.value_and_pullback( |
| 23 | + f, |
| 24 | + prep::MooncakeOneArgPullbackPrep{Y}, |
| 25 | + ::AutoMooncake, |
| 26 | + x, |
| 27 | + ty::DI.Tangents{1}, |
| 28 | + contexts::Vararg{Context,C}, |
| 29 | +) where {Y,C} |
| 30 | + dy = only(ty) |
| 31 | + dy_righttype = convert(tangent_type(Y), dy) |
| 32 | + new_y, (_, new_dx) = value_and_pullback!!( |
| 33 | + prep.rrule, dy_righttype, f, x, map(unwrap, contexts)... |
| 34 | + ) |
| 35 | + return new_y, DI.Tangents(new_dx) |
| 36 | +end |
| 37 | + |
| 38 | +function DI.value_and_pullback!( |
| 39 | + f, |
| 40 | + prep::MooncakeOneArgPullbackPrep{Y}, |
| 41 | + tx::DI.Tangents, |
| 42 | + ::AutoMooncake, |
| 43 | + x, |
| 44 | + ty::DI.Tangents{1}, |
| 45 | + contexts::Vararg{Context,C}, |
| 46 | +) where {Y,C} |
| 47 | + dx, dy = only(tx), only(ty) |
| 48 | + dy_righttype = convert(tangent_type(Y), dy) |
| 49 | + dx_righttype = set_to_zero!!(convert(tangent_type(typeof(x)), dx)) |
| 50 | + contexts_coduals = map(zero_fcodual ∘ unwrap, contexts) |
| 51 | + y, (_, new_dx) = __value_and_pullback!!( |
| 52 | + prep.rrule, |
| 53 | + dy_righttype, |
| 54 | + zero_codual(f), |
| 55 | + CoDual(x, dx_righttype), |
| 56 | + contexts_coduals..., |
| 57 | + ) |
| 58 | + copyto!(dx, new_dx) |
| 59 | + return y, tx |
| 60 | +end |
| 61 | + |
| 62 | +function DI.value_and_pullback( |
| 63 | + f, |
| 64 | + prep::MooncakeOneArgPullbackPrep, |
| 65 | + backend::AutoMooncake, |
| 66 | + x, |
| 67 | + ty::DI.Tangents, |
| 68 | + contexts::Vararg{Context,C}, |
| 69 | +) where {C} |
| 70 | + ys_and_dxs = map(ty.d) do dy |
| 71 | + y, tx = DI.value_and_pullback(f, prep, backend, x, DI.Tangents(dy), contexts...) |
| 72 | + y, only(tx) |
| 73 | + end |
| 74 | + y = first(ys_and_dxs[1]) |
| 75 | + dxs = last.(ys_and_dxs) |
| 76 | + return y, DI.Tangents(dxs...) |
| 77 | +end |
| 78 | + |
| 79 | +function DI.pullback( |
| 80 | + f, |
| 81 | + prep::MooncakeOneArgPullbackPrep, |
| 82 | + backend::AutoMooncake, |
| 83 | + x, |
| 84 | + ty::DI.Tangents, |
| 85 | + contexts::Vararg{Context,C}, |
| 86 | +) where {C} |
| 87 | + return DI.value_and_pullback(f, prep, backend, x, ty, contexts...)[2] |
| 88 | +end |
| 89 | + |
| 90 | +function DI.pullback!( |
| 91 | + f, |
| 92 | + tx::DI.Tangents, |
| 93 | + prep::MooncakeOneArgPullbackPrep, |
| 94 | + backend::AutoMooncake, |
| 95 | + x, |
| 96 | + ty::DI.Tangents, |
| 97 | + contexts::Vararg{Context,C}, |
| 98 | +) where {C} |
| 99 | + return DI.value_and_pullback!(f, tx, prep, backend, x, ty, contexts...)[2] |
| 100 | +end |
0 commit comments