Skip to content

Commit 93c9a39

Browse files
willtebbuttgdalle
andauthored
Updates for Tapir (#241)
* Bump Tapir version * Update Tapir usage * Bump patch * Update DifferentiationInterface/Project.toml Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent 0b04ee8 commit 93c9a39

4 files changed

Lines changed: 17 additions & 13 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.3.1"
4+
version = "0.3.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -57,7 +57,7 @@ PolyesterForwardDiff = "0.1.1"
5757
ReverseDiff = "1.15.1"
5858
SparseArrays = "1"
5959
Symbolics = "5.27.1"
60-
Tapir = "0.1.2"
60+
Tapir = "0.2.4"
6161
Test = "1"
6262
Tracker = "0.2.33"
6363
Zygote = "0.6.69"

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/DifferentiationInterfaceTapirExt.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ using Tapir:
1414
tangent_type,
1515
value_and_pullback!!,
1616
zero_codual,
17-
zero_tangent
17+
zero_tangent,
18+
NoRData,
19+
fdata,
20+
rdata,
21+
__value_and_pullback!!
1822

1923
DI.check_available(::AutoTapir) = true
2024

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/onearg.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,23 @@ end
55

66
function DI.prepare_pullback(f, ::AutoTapir, x, dy)
77
y = f(x)
8-
rrule = build_rrule(f, x)
9-
return TapirOneArgPullbackExtras(y, rrule)
8+
return TapirOneArgPullbackExtras(y, build_rrule(f, x))
109
end
1110

1211
function DI.value_and_pullback(
1312
f, ::AutoTapir, x, dy, extras::TapirOneArgPullbackExtras{Y}
1413
) where {Y}
1514
dy_righttype = convert(tangent_type(Y), dy)
16-
new_y, (new_df, new_dx) = value_and_pullback!!(extras.rrule, dy_righttype, f, x)
15+
new_y, (_, new_dx) = value_and_pullback!!(extras.rrule, dy_righttype, f, x)
1716
return new_y, new_dx
1817
end
1918

2019
function DI.value_and_pullback!(
2120
f, dx, ::AutoTapir, x, dy, extras::TapirOneArgPullbackExtras{Y}
2221
) where {Y}
2322
dy_righttype = convert(tangent_type(Y), dy)
24-
dx_righttype = convert(tangent_type(typeof(x)), dx)
25-
dx_righttype = set_to_zero!!(dx_righttype)
26-
y, (new_df, new_dx) = value_and_pullback!!(
23+
dx_righttype = set_to_zero!!(convert(tangent_type(typeof(x)), dx))
24+
y, (_, new_dx) = __value_and_pullback!!(
2725
extras.rrule, dy_righttype, zero_codual(f), CoDual(x, dx_righttype)
2826
)
2927
return y, copyto!(dx, new_dx)

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/twoarg.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@ function DI.value_and_pullback(f!, y, ::AutoTapir, x, dy, extras::TapirTwoArgPul
2929

3030
# Run the forwards-pass.
3131
out, pb!! = extras.rrule(
32-
CoDual(f!, df!), CoDual(y_copy, dy_righttype), CoDual(x, dx_righttype)
32+
CoDual(f!, fdata(df!)),
33+
CoDual(y_copy, fdata(dy_righttype)),
34+
CoDual(x, fdata(dx_righttype)),
3335
)
3436

3537
# Verify that the output is non-differentiable.
36-
@assert tangent(out) == NoTangent()
38+
@assert primal(out) === nothing
3739

3840
# Set the cotangent of `y` to be equal to the requested value.
3941
dy_righttype = increment!!(dy_righttype, dy_righttype_backup)
@@ -42,7 +44,7 @@ function DI.value_and_pullback(f!, y, ::AutoTapir, x, dy, extras::TapirTwoArgPul
4244
y = copyto!(y, y_copy)
4345

4446
# Run the reverse-pass.
45-
_, _, new_dx = pb!!(NoTangent(), df!, dy_righttype, dx_righttype)
47+
_, _, new_dx = pb!!(NoRData())
4648

47-
return y, new_dx
49+
return y, tangent(fdata(dx_righttype), new_dx)
4850
end

0 commit comments

Comments
 (0)