Skip to content

Commit dfd6022

Browse files
authored
Pullback for FastDifferentiation (#188)
1 parent 06841e0 commit dfd6022

5 files changed

Lines changed: 137 additions & 7 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ ChainRulesCore = "1.23.0"
4747
Diffractor = "=0.2.6"
4848
DocStringExtensions = "0.9.3"
4949
Enzyme = "0.11.20"
50-
FastDifferentiation = "0.3.7"
50+
FastDifferentiation = "0.3.9"
5151
FillArrays = "1.9.3"
5252
FiniteDiff = "2.23.1"
5353
FiniteDifferences = "0.12.31"

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,6 @@ const AnyAutoFastDifferentiation = Union{
3232
}
3333

3434
DI.check_available(::AutoFastDifferentiation) = true
35-
DI.mode(::AutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode
36-
DI.pushforward_performance(::AutoFastDifferentiation) = DI.PushforwardFast()
37-
DI.pullback_performance(::AutoFastDifferentiation) = DI.PullbackSlow()
3835

3936
monovec(x::Number) = Fill(x, 1)
4037

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,66 @@ end
7070

7171
## Pullback
7272

73-
# TODO: fix https://github.com/gdalle/DifferentiationInterface.jl/issues/131
73+
struct FastDifferentiationOneArgPullbackExtras{E1,E2} <: PullbackExtras
74+
vjp_exe::E1
75+
vjp_exe!::E2
76+
end
77+
78+
function DI.prepare_pullback(f, ::AutoFastDifferentiation, x, dy)
79+
x_var = if x isa Number
80+
only(make_variables(:x))
81+
else
82+
make_variables(:x, size(x)...)
83+
end
84+
y_var = f(x_var)
85+
86+
x_vec_var = x_var isa Number ? monovec(x_var) : vec(x_var)
87+
y_vec_var = y_var isa Number ? monovec(y_var) : vec(y_var)
88+
vj_vec_var, v_vec_var = jacobian_transpose_v(y_vec_var, x_vec_var)
89+
vjp_exe = make_function(vj_vec_var, vcat(x_vec_var, v_vec_var); in_place=false)
90+
vjp_exe! = make_function(vj_vec_var, vcat(x_vec_var, v_vec_var); in_place=true)
91+
return FastDifferentiationOneArgPullbackExtras(vjp_exe, vjp_exe!)
92+
end
93+
94+
function DI.pullback(
95+
f, ::AutoFastDifferentiation, x, dy, extras::FastDifferentiationOneArgPullbackExtras
96+
)
97+
v_vec = vcat(myvec(x), myvec(dy))
98+
if x isa Number
99+
return only(extras.vjp_exe(v_vec))
100+
else
101+
return reshape(extras.vjp_exe(v_vec), size(x))
102+
end
103+
end
104+
105+
function DI.pullback!(
106+
f, dx, ::AutoFastDifferentiation, x, dy, extras::FastDifferentiationOneArgPullbackExtras
107+
)
108+
v_vec = vcat(myvec(x), myvec(dy))
109+
extras.vjp_exe!(vec(dx), v_vec)
110+
return dx
111+
end
112+
113+
function DI.value_and_pullback(
114+
f,
115+
backend::AutoFastDifferentiation,
116+
x,
117+
dy,
118+
extras::FastDifferentiationOneArgPullbackExtras,
119+
)
120+
return f(x), DI.pullback(f, backend, x, dy, extras)
121+
end
122+
123+
function DI.value_and_pullback!(
124+
f,
125+
dx,
126+
backend::AutoFastDifferentiation,
127+
x,
128+
dy,
129+
extras::FastDifferentiationOneArgPullbackExtras,
130+
)
131+
return f(x), DI.pullback!(f, dx, backend, x, dy, extras)
132+
end
74133

75134
## Derivative
76135

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,82 @@ function DI.pushforward!(
7878
return dy
7979
end
8080

81+
## Pullback
82+
83+
struct FastDifferentiationTwoArgPullbackExtras{E1,E2} <: PullbackExtras
84+
vjp_exe::E1
85+
vjp_exe!::E2
86+
end
87+
88+
function DI.prepare_pullback(f!, y, ::AutoFastDifferentiation, x, dy)
89+
x_var = if x isa Number
90+
only(make_variables(:x))
91+
else
92+
make_variables(:x, size(x)...)
93+
end
94+
y_var = make_variables(:y, size(y)...)
95+
f!(y_var, x_var)
96+
97+
x_vec_var = x_var isa Number ? monovec(x_var) : vec(x_var)
98+
y_vec_var = y_var isa Number ? monovec(y_var) : vec(y_var)
99+
vj_vec_var, v_vec_var = jacobian_transpose_v(y_vec_var, x_vec_var)
100+
vjp_exe = make_function(vj_vec_var, vcat(x_vec_var, v_vec_var); in_place=false)
101+
vjp_exe! = make_function(vj_vec_var, vcat(x_vec_var, v_vec_var); in_place=true)
102+
return FastDifferentiationTwoArgPullbackExtras(vjp_exe, vjp_exe!)
103+
end
104+
105+
function DI.pullback(
106+
f!, y, ::AutoFastDifferentiation, x, dy, extras::FastDifferentiationTwoArgPullbackExtras
107+
)
108+
v_vec = vcat(myvec(x), myvec(dy))
109+
if x isa Number
110+
return only(extras.vjp_exe(v_vec))
111+
else
112+
return reshape(extras.vjp_exe(v_vec), size(x))
113+
end
114+
end
115+
116+
function DI.pullback!(
117+
f!,
118+
y,
119+
dx,
120+
::AutoFastDifferentiation,
121+
x,
122+
dy,
123+
extras::FastDifferentiationTwoArgPullbackExtras,
124+
)
125+
v_vec = vcat(myvec(x), myvec(dy))
126+
extras.vjp_exe!(vec(dx), v_vec)
127+
return dx
128+
end
129+
130+
function DI.value_and_pullback(
131+
f!,
132+
y,
133+
backend::AutoFastDifferentiation,
134+
x,
135+
dy,
136+
extras::FastDifferentiationTwoArgPullbackExtras,
137+
)
138+
dx = DI.pullback(f!, y, backend, x, dy, extras)
139+
f!(y, x)
140+
return y, dx
141+
end
142+
143+
function DI.value_and_pullback!(
144+
f!,
145+
y,
146+
dx,
147+
backend::AutoFastDifferentiation,
148+
x,
149+
dy,
150+
extras::FastDifferentiationTwoArgPullbackExtras,
151+
)
152+
DI.pullback!(f!, y, dx, backend, x, dy, extras)
153+
f!(y, x)
154+
return y, dx
155+
end
156+
81157
## Derivative
82158

83159
struct FastDifferentiationTwoArgDerivativeExtras{E1,E2} <: DerivativeExtras

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ using Symbolics.RuntimeGeneratedFunctions: RuntimeGeneratedFunction
3030
const AnyAutoSymbolics = Union{AutoSymbolics,AutoSparseSymbolics}
3131

3232
DI.check_available(::AutoSymbolics) = true
33-
DI.mode(::AutoSymbolics) = ADTypes.AbstractSymbolicDifferentiationMode
34-
DI.pushforward_performance(::AutoSymbolics) = DI.PushforwardFast()
3533
DI.pullback_performance(::AutoSymbolics) = DI.PullbackSlow()
3634

3735
monovec(x::Number) = Fill(x, 1)

0 commit comments

Comments
 (0)