Skip to content

Commit 2411653

Browse files
authored
Remove fallback backend -> SecondOrder(backend, backend) (#459)
* Remove fallback `backend -> SecondOrder(backend, backend)` * Add hvp mode
1 parent 1155218 commit 2411653

3 files changed

Lines changed: 57 additions & 82 deletions

File tree

DifferentiationInterface/src/second_order/hvp.jl

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -58,123 +58,126 @@ struct ReverseOverReverseHVPExtras{G<:Gradient,E<:PullbackExtras} <: HVPExtras
5858
end
5959

6060
function prepare_hvp(f::F, backend::AbstractADType, x, tx::Tangents) where {F}
61-
return prepare_hvp(f, SecondOrder(backend, backend), x, tx)
62-
end
63-
64-
function prepare_hvp(f::F, backend::SecondOrder, x, tx::Tangents) where {F}
6561
return _prepare_hvp_aux(f, backend, x, tx, hvp_mode(backend))
6662
end
6763

6864
function _prepare_hvp_aux(
69-
f::F, backend::SecondOrder, x, tx::Tangents, ::ForwardOverForward
65+
f::F, backend::AbstractADType, x, tx::Tangents, ::ForwardOverForward
7066
) where {F}
7167
# pushforward of many pushforwards in theory, but pushforward of gradient in practice
72-
inner_gradient = Gradient(f, nested(inner(backend)))
73-
outer_pushforward_extras = prepare_pushforward(inner_gradient, outer(backend), x, tx)
68+
inner_gradient = Gradient(f, nested(maybe_inner(backend)))
69+
outer_pushforward_extras = prepare_pushforward(
70+
inner_gradient, maybe_outer(backend), x, tx
71+
)
7472
return ForwardOverForwardHVPExtras(inner_gradient, outer_pushforward_extras)
7573
end
7674

7775
function _prepare_hvp_aux(
78-
f::F, backend::SecondOrder, x, tx::Tangents, ::ForwardOverReverse
76+
f::F, backend::AbstractADType, x, tx::Tangents, ::ForwardOverReverse
7977
) where {F}
8078
# pushforward of gradient
81-
inner_gradient = Gradient(f, nested(inner(backend)))
82-
outer_pushforward_extras = prepare_pushforward(inner_gradient, outer(backend), x, tx)
79+
inner_gradient = Gradient(f, nested(maybe_inner(backend)))
80+
outer_pushforward_extras = prepare_pushforward(
81+
inner_gradient, maybe_outer(backend), x, tx
82+
)
8383
return ForwardOverReverseHVPExtras(inner_gradient, outer_pushforward_extras)
8484
end
8585

8686
function _prepare_hvp_aux(
87-
f::F, backend::SecondOrder, x, tx::Tangents, ::ReverseOverForward
87+
f::F, backend::AbstractADType, x, tx::Tangents, ::ReverseOverForward
8888
) where {F}
8989
# gradient of pushforward
9090
# uses dx in the closure so it can't be prepared
9191
return ReverseOverForwardHVPExtras()
9292
end
9393

9494
function _prepare_hvp_aux(
95-
f::F, backend::SecondOrder, x, tx::Tangents, ::ReverseOverReverse
95+
f::F, backend::AbstractADType, x, tx::Tangents, ::ReverseOverReverse
9696
) where {F}
9797
# pullback of gradient
98-
inner_gradient = Gradient(f, nested(inner(backend)))
99-
outer_pullback_extras = prepare_pullback(inner_gradient, outer(backend), x, tx)
98+
inner_gradient = Gradient(f, nested(maybe_inner(backend)))
99+
outer_pullback_extras = prepare_pullback(inner_gradient, maybe_outer(backend), x, tx)
100100
return ReverseOverReverseHVPExtras(inner_gradient, outer_pullback_extras)
101101
end
102102

103103
## One argument
104104

105-
function hvp(f::F, extras::HVPExtras, backend::AbstractADType, x, tx::Tangents) where {F}
106-
return hvp(f, extras, SecondOrder(backend, backend), x, tx)
107-
end
108-
109105
function hvp(
110-
f::F, extras::ForwardOverForwardHVPExtras, backend::SecondOrder, x, tx::Tangents
106+
f::F, extras::ForwardOverForwardHVPExtras, backend::AbstractADType, x, tx::Tangents
111107
) where {F}
112108
@compat (; inner_gradient, outer_pushforward_extras) = extras
113-
return pushforward(inner_gradient, outer_pushforward_extras, outer(backend), x, tx)
109+
return pushforward(
110+
inner_gradient, outer_pushforward_extras, maybe_outer(backend), x, tx
111+
)
114112
end
115113

116114
function hvp(
117-
f::F, extras::ForwardOverReverseHVPExtras, backend::SecondOrder, x, tx::Tangents
115+
f::F, extras::ForwardOverReverseHVPExtras, backend::AbstractADType, x, tx::Tangents
118116
) where {F}
119117
@compat (; inner_gradient, outer_pushforward_extras) = extras
120-
return pushforward(inner_gradient, outer_pushforward_extras, outer(backend), x, tx)
118+
return pushforward(
119+
inner_gradient, outer_pushforward_extras, maybe_outer(backend), x, tx
120+
)
121121
end
122122

123123
function hvp(
124-
f::F, ::ReverseOverForwardHVPExtras, backend::SecondOrder, x, tx::Tangents
124+
f::F, ::ReverseOverForwardHVPExtras, backend::AbstractADType, x, tx::Tangents
125125
) where {F}
126126
dgs = map(tx.d) do dx
127-
inner_pushforward = PushforwardFixedSeed(f, nested(inner(backend)), Tangents(dx))
128-
gradient(only ∘ inner_pushforward, outer(backend), x)
127+
inner_pushforward = PushforwardFixedSeed(f, nested(maybe_inner(backend)), Tangents(dx))
128+
gradient(only ∘ inner_pushforward, maybe_outer(backend), x)
129129
end
130130
return Tangents(dgs...)
131131
end
132132

133133
function hvp(
134-
f::F, extras::ReverseOverReverseHVPExtras, backend::SecondOrder, x, tx::Tangents
134+
f::F, extras::ReverseOverReverseHVPExtras, backend::AbstractADType, x, tx::Tangents
135135
) where {F}
136136
@compat (; inner_gradient, outer_pullback_extras) = extras
137-
return pullback(inner_gradient, outer_pullback_extras, outer(backend), x, tx)
138-
end
139-
140-
function hvp!(
141-
f::F, tg::Tangents, extras::HVPExtras, backend::AbstractADType, x, tx::Tangents
142-
) where {F}
143-
return hvp!(f, tg, extras, SecondOrder(backend, backend), x, tx)
137+
return pullback(inner_gradient, outer_pullback_extras, maybe_outer(backend), x, tx)
144138
end
145139

146140
function hvp!(
147141
f::F,
148142
tg::Tangents,
149143
extras::ForwardOverForwardHVPExtras,
150-
backend::SecondOrder,
144+
backend::AbstractADType,
151145
x,
152146
tx::Tangents,
153147
) where {F}
154148
@compat (; inner_gradient, outer_pushforward_extras) = extras
155-
return pushforward!(inner_gradient, tg, outer_pushforward_extras, outer(backend), x, tx)
149+
return pushforward!(
150+
inner_gradient, tg, outer_pushforward_extras, maybe_outer(backend), x, tx
151+
)
156152
end
157153

158154
function hvp!(
159155
f::F,
160156
tg::Tangents,
161157
extras::ForwardOverReverseHVPExtras,
162-
backend::SecondOrder,
158+
backend::AbstractADType,
163159
x,
164160
tx::Tangents,
165161
) where {F}
166162
@compat (; inner_gradient, outer_pushforward_extras) = extras
167-
return pushforward!(inner_gradient, tg, outer_pushforward_extras, outer(backend), x, tx)
163+
return pushforward!(
164+
inner_gradient, tg, outer_pushforward_extras, maybe_outer(backend), x, tx
165+
)
168166
end
169167

170168
function hvp!(
171-
f::F, tg::Tangents, ::ReverseOverForwardHVPExtras, backend::SecondOrder, x, tx::Tangents
169+
f::F,
170+
tg::Tangents,
171+
::ReverseOverForwardHVPExtras,
172+
backend::AbstractADType,
173+
x,
174+
tx::Tangents,
172175
) where {F}
173176
for b in eachindex(tx.d, tg.d)
174177
inner_pushforward = PushforwardFixedSeed(
175-
f, nested(inner(backend)), Tangents(tx.d[b])
178+
f, nested(maybe_inner(backend)), Tangents(tx.d[b])
176179
)
177-
gradient!(only ∘ inner_pushforward, tg.d[b], outer(backend), x)
180+
gradient!(only ∘ inner_pushforward, tg.d[b], maybe_outer(backend), x)
178181
end
179182
return tg
180183
end
@@ -183,10 +186,10 @@ function hvp!(
183186
f::F,
184187
tg::Tangents,
185188
extras::ReverseOverReverseHVPExtras,
186-
backend::SecondOrder,
189+
backend::AbstractADType,
187190
x,
188191
tx::Tangents,
189192
) where {F}
190193
@compat (; inner_gradient, outer_pullback_extras) = extras
191-
return pullback!(inner_gradient, tg, outer_pullback_extras, outer(backend), x, tx)
194+
return pullback!(inner_gradient, tg, outer_pullback_extras, maybe_outer(backend), x, tx)
192195
end

DifferentiationInterface/src/second_order/second_derivative.jl

Lines changed: 10 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -65,77 +65,47 @@ struct ClosureSecondDerivativeExtras{ID<:InnerDerivative,E<:DerivativeExtras} <:
6565
end
6666

6767
function prepare_second_derivative(f::F, backend::AbstractADType, x) where {F}
68-
return prepare_second_derivative(f, SecondOrder(backend, backend), x)
69-
end
70-
71-
function prepare_second_derivative(f::F, backend::SecondOrder, x) where {F}
72-
inner_derivative = InnerDerivative(f, nested(inner(backend)))
73-
outer_derivative_extras = prepare_derivative(inner_derivative, outer(backend), x)
68+
inner_derivative = InnerDerivative(f, nested(maybe_inner(backend)))
69+
outer_derivative_extras = prepare_derivative(inner_derivative, maybe_outer(backend), x)
7470
return ClosureSecondDerivativeExtras(inner_derivative, outer_derivative_extras)
7571
end
7672

7773
## One argument
7874

7975
function second_derivative(
80-
f::F, extras::SecondDerivativeExtras, backend::AbstractADType, x
81-
) where {F}
82-
return second_derivative(f, extras, SecondOrder(backend, backend), x)
83-
end
84-
85-
function second_derivative(
86-
f::F, extras::ClosureSecondDerivativeExtras, backend::SecondOrder, x
76+
f::F, extras::ClosureSecondDerivativeExtras, backend::AbstractADType, x
8777
) where {F}
8878
@compat (; inner_derivative, outer_derivative_extras) = extras
89-
return derivative(inner_derivative, outer_derivative_extras, outer(backend), x)
90-
end
91-
92-
function value_derivative_and_second_derivative(
93-
f::F, extras::SecondDerivativeExtras, backend::AbstractADType, x
94-
) where {F}
95-
return value_derivative_and_second_derivative(
96-
f, extras, SecondOrder(backend, backend), x
97-
)
79+
return derivative(inner_derivative, outer_derivative_extras, maybe_outer(backend), x)
9880
end
9981

10082
function value_derivative_and_second_derivative(
101-
f::F, extras::ClosureSecondDerivativeExtras, backend::SecondOrder, x
83+
f::F, extras::ClosureSecondDerivativeExtras, backend::AbstractADType, x
10284
) where {F}
10385
@compat (; inner_derivative, outer_derivative_extras) = extras
10486
y = f(x)
10587
der, der2 = value_and_derivative(
106-
inner_derivative, outer_derivative_extras, outer(backend), x
88+
inner_derivative, outer_derivative_extras, maybe_outer(backend), x
10789
)
10890
return y, der, der2
10991
end
11092

11193
function second_derivative!(
11294
f::F, der2, extras::SecondDerivativeExtras, backend::AbstractADType, x
113-
) where {F}
114-
return second_derivative!(f, der2, extras, SecondOrder(backend, backend), x)
115-
end
116-
117-
function second_derivative!(
118-
f::F, der2, extras::SecondDerivativeExtras, backend::SecondOrder, x
11995
) where {F}
12096
@compat (; inner_derivative, outer_derivative_extras) = extras
121-
return derivative!(inner_derivative, der2, outer_derivative_extras, outer(backend), x)
122-
end
123-
124-
function value_derivative_and_second_derivative!(
125-
f::F, der, der2, extras::SecondDerivativeExtras, backend::AbstractADType, x
126-
) where {F}
127-
return value_derivative_and_second_derivative!(
128-
f, der, der2, extras, SecondOrder(backend, backend), x
97+
return derivative!(
98+
inner_derivative, der2, outer_derivative_extras, maybe_outer(backend), x
12999
)
130100
end
131101

132102
function value_derivative_and_second_derivative!(
133-
f::F, der, der2, extras::SecondDerivativeExtras, backend::SecondOrder, x
103+
f::F, der, der2, extras::SecondDerivativeExtras, backend::AbstractADType, x
134104
) where {F}
135105
@compat (; inner_derivative, outer_derivative_extras) = extras
136106
y = f(x)
137107
new_der, _ = value_and_derivative!(
138-
inner_derivative, der2, outer_derivative_extras, outer(backend), x
108+
inner_derivative, der2, outer_derivative_extras, maybe_outer(backend), x
139109
)
140110
return y, copyto!(der, new_der), der2
141111
end

DifferentiationInterface/src/utils/traits.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ Traits identifying second-order backends that compute HVPs in forward over forwa
125125
"""
126126
struct ForwardOverForward <: HVPMode end
127127

128+
hvp_mode(backend::AbstractADType) = hvp_mode(SecondOrder(backend, backend))
129+
128130
function hvp_mode(ba::SecondOrder)
129131
if Bool(pushforward_performance(outer(ba))) && Bool(pullback_performance(inner(ba)))
130132
return ForwardOverReverse()

0 commit comments

Comments
 (0)