Skip to content

Commit 069554a

Browse files
authored
Tapir mutation, take 2 (#141)
* Tapir mutation, take 2 * Handle the case where `f!` is a closure * Define zero_tangent * Copy dy * Reactivate tests
1 parent 33475ed commit 069554a

3 files changed

Lines changed: 89 additions & 38 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceTapirExt/DifferentiationInterfaceTapirExt.jl

Lines changed: 14 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,43 +4,19 @@ using ADTypes: ADTypes
44
import DifferentiationInterface as DI
55
using DifferentiationInterface: AutoTapir
66
using DifferentiationInterface: PullbackExtras
7-
using Tapir: CoDual, build_rrule, value_and_pullback!!, zero_codual
8-
9-
DI.supports_mutation(::AutoTapir) = DI.MutationNotSupported()
10-
11-
function zero_sametype!!(x_target, x::Number)
12-
return zero(x)
13-
end
14-
15-
function zero_sametype!!(x_target, x::AbstractArray)
16-
x_sametype = convert(typeof(x), x_target)
17-
x_sametype .= zero(eltype(x))
18-
return x_sametype
19-
end
20-
21-
## Pullback
22-
23-
struct TapirPullbackExtras{R} <: PullbackExtras
24-
rrule::R
25-
end
26-
27-
DI.prepare_pullback(f, ::AutoTapir, x) = TapirPullbackExtras(build_rrule(f, x))
28-
29-
function DI.value_and_pullback(f, ::AutoTapir, x, dy, extras::TapirPullbackExtras)
30-
y = f(x)
31-
dy_righttype = convert(typeof(y), dy)
32-
_, (_, dx) = value_and_pullback!!(extras.rrule, dy_righttype, f, x)
33-
return y, dx
34-
end
35-
36-
function DI.value_and_pullback!!(f, dx, ::AutoTapir, x, dy, extras::TapirPullbackExtras)
37-
y = f(x)
38-
dy_righttype = convert(typeof(y), dy)
39-
dx_righttype = zero_sametype!!(dx, x)
40-
new_y, (_, new_dx) = value_and_pullback!!(
41-
extras.rrule, dy_righttype, zero_codual(f), CoDual(x, dx_righttype)
42-
)
43-
return new_y, new_dx
44-
end
7+
using Tapir:
8+
CoDual,
9+
NoTangent,
10+
build_rrule,
11+
increment!!,
12+
set_to_zero!!,
13+
tangent,
14+
tangent_type,
15+
value_and_pullback!!,
16+
zero_codual,
17+
zero_tangent
18+
19+
include("allocating.jl")
20+
include("mutating.jl")
4521

4622
end
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
struct TapirAllocatingPullbackExtras{R} <: PullbackExtras
2+
rrule::R
3+
end
4+
5+
DI.prepare_pullback(f, ::AutoTapir, x) = TapirAllocatingPullbackExtras(build_rrule(f, x))
6+
7+
function DI.value_and_pullback(f, ::AutoTapir, x, dy, extras::TapirAllocatingPullbackExtras)
8+
y = f(x) # TODO: one call too many, just for the conversion
9+
dy_righttype = convert(tangent_type(typeof(y)), dy)
10+
new_y, (new_df, new_dx) = value_and_pullback!!(extras.rrule, dy_righttype, f, x)
11+
return new_y, new_dx
12+
end
13+
14+
function DI.value_and_pullback!!(
15+
f, dx, ::AutoTapir, x, dy, extras::TapirAllocatingPullbackExtras
16+
)
17+
y = f(x) # TODO: one call too many, just for the conversion
18+
dy_righttype = convert(tangent_type(typeof(y)), dy)
19+
dx_righttype = convert(tangent_type(typeof(x)), dx)
20+
dx_righttype = set_to_zero!!(dx_righttype)
21+
new_y, (new_df, new_dx) = value_and_pullback!!(
22+
extras.rrule, dy_righttype, zero_codual(f), CoDual(x, dx_righttype)
23+
)
24+
return new_y, new_dx
25+
end
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
struct TapirMutatingPullbackExtras{R} <: PullbackExtras
2+
rrule::R
3+
end
4+
5+
function DI.prepare_pullback(f!, ::AutoTapir, y, x)
6+
return TapirMutatingPullbackExtras(build_rrule(f!, y, x))
7+
end
8+
9+
# see https://github.com/withbayes/Tapir.jl/issues/113#issuecomment-2036718992
10+
11+
function DI.value_and_pullback!!(
12+
f!, y, dx, ::AutoTapir, x, dy, extras::TapirMutatingPullbackExtras
13+
)
14+
dy_righttype = convert(tangent_type(typeof(y)), copy(dy))
15+
dx_righttype = convert(tangent_type(typeof(x)), dx)
16+
17+
# We want the VJP, not VJP + dx, so I'm going to zero-out `dx`. `set_to_zero!!` has the advantage
18+
# that it will also replace any immutable components of `dx` to zero.
19+
dx_righttype = set_to_zero!!(dx_righttype)
20+
21+
# We want `dy` to correspond to the cotangent of `y` _after_
22+
# running the forwards-pass, so I'm going to take a copy, and zero-out the original.
23+
dy_righttype_backup = copy(dy_righttype)
24+
dy_righttype = set_to_zero!!(dy_righttype)
25+
26+
# Mutate a copy of `y`, so that we can run the reverse-pass later on.
27+
y_copy = copy(y)
28+
29+
# In case `f!` is a closure
30+
df! = zero_tangent(f!)
31+
32+
# Run the forwards-pass.
33+
out, pb!! = extras.rrule(
34+
CoDual(f!, df!), CoDual(y_copy, dy_righttype), CoDual(x, dx_righttype)
35+
)
36+
37+
# Verify that the output is non-differentiable.
38+
@assert tangent(out) == NoTangent()
39+
40+
# Set the cotangent of `y` to be equal to the requested value.
41+
dy_righttype = increment!!(dy_righttype, dy_righttype_backup)
42+
43+
# Record the state of `y` before running the reverse-pass.
44+
y = copy!(y, y_copy)
45+
46+
# Run the reverse-pass.
47+
_, _, new_dx = pb!!(NoTangent(), df!, dy_righttype, dx_righttype)
48+
49+
return y, new_dx
50+
end

0 commit comments

Comments
 (0)