Skip to content

Commit 95fe3bd

Browse files
authored
Add y prototype in Tapir (#156)
1 parent be5fc79 commit 95fe3bd

2 files changed

Lines changed: 14 additions & 10 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/allocating.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
1-
struct TapirAllocatingPullbackExtras{R} <: PullbackExtras
1+
struct TapirAllocatingPullbackExtras{Y,R} <: PullbackExtras
2+
y_prototype::Y
23
rrule::R
34
end
45

5-
DI.prepare_pullback(f, ::AutoTapir, x) = TapirAllocatingPullbackExtras(build_rrule(f, x))
6+
function DI.prepare_pullback(f, ::AutoTapir, x)
7+
y = f(x)
8+
rrule = build_rrule(f, x)
9+
return TapirAllocatingPullbackExtras(y, rrule)
10+
end
611

7-
function DI.value_and_pullback(f, ::AutoTapir, x, dy, extras::TapirAllocatingPullbackExtras)
8-
y = f(x) # TODO: one call too many, just for the conversion
9-
dy_righttype = convert(tangent_type(typeof(y)), dy)
12+
function DI.value_and_pullback(
13+
f, ::AutoTapir, x, dy, extras::TapirAllocatingPullbackExtras{Y}
14+
) where {Y}
15+
dy_righttype = convert(tangent_type(Y), dy)
1016
new_y, (new_df, new_dx) = value_and_pullback!!(extras.rrule, dy_righttype, f, x)
1117
return new_y, new_dx
1218
end
1319

1420
function DI.value_and_pullback!!(
15-
f, dx, ::AutoTapir, x, dy, extras::TapirAllocatingPullbackExtras
16-
)
17-
y = f(x) # TODO: one call too many, just for the conversion
18-
dy_righttype = convert(tangent_type(typeof(y)), dy)
21+
f, dx, ::AutoTapir, x, dy, extras::TapirAllocatingPullbackExtras{Y}
22+
) where {Y}
23+
dy_righttype = convert(tangent_type(Y), dy)
1924
dx_righttype = convert(tangent_type(typeof(x)), dx)
2025
dx_righttype = set_to_zero!!(dx_righttype)
2126
new_y, (new_df, new_dx) = value_and_pullback!!(

DifferentiationInterface/test/test_imports.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ Pkg.develop(
88

99
using DifferentiationInterface
1010
using DifferentiationInterfaceTest
11-
using DifferentiationInterfaceTest: AutoZeroForward, AutoZeroReverse
1211

1312
using Aqua: Aqua
1413
using Documenter: Documenter

0 commit comments

Comments
 (0)