|
1 | | -struct TapirAllocatingPullbackExtras{R} <: PullbackExtras |
| 1 | +struct TapirAllocatingPullbackExtras{Y,R} <: PullbackExtras |
| 2 | + y_prototype::Y |
2 | 3 | rrule::R |
3 | 4 | end |
4 | 5 |
|
5 | | -DI.prepare_pullback(f, ::AutoTapir, x) = TapirAllocatingPullbackExtras(build_rrule(f, x)) |
| 6 | +function DI.prepare_pullback(f, ::AutoTapir, x) |
| 7 | + y = f(x) |
| 8 | + rrule = build_rrule(f, x) |
| 9 | + return TapirAllocatingPullbackExtras(y, rrule) |
| 10 | +end |
6 | 11 |
|
7 | | -function DI.value_and_pullback(f, ::AutoTapir, x, dy, extras::TapirAllocatingPullbackExtras) |
8 | | - y = f(x) # TODO: one call too many, just for the conversion |
9 | | - dy_righttype = convert(tangent_type(typeof(y)), dy) |
| 12 | +function DI.value_and_pullback( |
| 13 | + f, ::AutoTapir, x, dy, extras::TapirAllocatingPullbackExtras{Y} |
| 14 | +) where {Y} |
| 15 | + dy_righttype = convert(tangent_type(Y), dy) |
10 | 16 | new_y, (new_df, new_dx) = value_and_pullback!!(extras.rrule, dy_righttype, f, x) |
11 | 17 | return new_y, new_dx |
12 | 18 | end |
13 | 19 |
|
14 | 20 | function DI.value_and_pullback!!( |
15 | | - f, dx, ::AutoTapir, x, dy, extras::TapirAllocatingPullbackExtras |
16 | | -) |
17 | | - y = f(x) # TODO: one call too many, just for the conversion |
18 | | - dy_righttype = convert(tangent_type(typeof(y)), dy) |
| 21 | + f, dx, ::AutoTapir, x, dy, extras::TapirAllocatingPullbackExtras{Y} |
| 22 | +) where {Y} |
| 23 | + dy_righttype = convert(tangent_type(Y), dy) |
19 | 24 | dx_righttype = convert(tangent_type(typeof(x)), dx) |
20 | 25 | dx_righttype = set_to_zero!!(dx_righttype) |
21 | 26 | new_y, (new_df, new_dx) = value_and_pullback!!( |
|
0 commit comments