Skip to content

Commit b3c9b0a

Browse files
authored
Get rid of functors for second order (#460)
* Get rid of functors for second order * Types * Fix
1 parent 2411653 commit b3c9b0a

5 files changed

Lines changed: 19 additions & 98 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ function tag_backend_hvp(f, backend::AutoForwardDiff, x)
1818
return backend
1919
end
2020

21-
struct ForwardDiffOverSomethingHVPExtras{
22-
B<:AutoForwardDiff,G<:DI.Gradient,E<:PushforwardExtras
23-
} <: HVPExtras
21+
struct ForwardDiffOverSomethingHVPExtras{B<:AutoForwardDiff,G,E<:PushforwardExtras} <:
22+
HVPExtras
2423
tagged_outer_backend::B
2524
inner_gradient::G
2625
outer_pushforward_extras::E
@@ -33,7 +32,7 @@ function DI.prepare_hvp(
3332
T = tag_type(f, tagged_outer_backend, x)
3433
xdual = make_dual(T, x, tx)
3534
gradient_extras = DI.prepare_gradient(f, inner(backend), xdual)
36-
inner_gradient = DI.Gradient(f, inner(backend), gradient_extras)
35+
inner_gradient(x) = DI.gradient(f, gradient_extras, inner(backend), x)
3736
outer_pushforward_extras = DI.prepare_pushforward(
3837
inner_gradient, tagged_outer_backend, x, tx
3938
)

DifferentiationInterface/src/first_order/gradient.jl

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -88,53 +88,3 @@ function gradient!(
8888
pullback!(f, Tangents(grad), extras.pullback_extras, backend, x, Tangents(true))
8989
return grad
9090
end
91-
92-
## Functors
93-
94-
"""
95-
Gradient
96-
97-
Functor computing the gradient of `f` with a fixed `backend`.
98-
99-
!!! warning
100-
This type is not part of the public API.
101-
102-
# Constructor
103-
104-
Gradient(f, backend, extras=nothing)
105-
106-
If `extras` is provided, the gradient closure will skip preparation.
107-
108-
# Example
109-
110-
```jldoctest
111-
using DifferentiationInterface
112-
import Zygote
113-
114-
g = DifferentiationInterface.Gradient(x -> sum(abs2, x), AutoZygote())
115-
g([2.0, 3.0])
116-
117-
# output
118-
119-
2-element Vector{Float64}:
120-
4.0
121-
6.0
122-
```
123-
"""
124-
struct Gradient{F,B,E}
125-
f::F
126-
backend::B
127-
extras::E
128-
end
129-
130-
Gradient(f, backend::AbstractADType) = Gradient(f, backend, nothing)
131-
132-
function (g::Gradient{F,B,Nothing})(x) where {F,B}
133-
@compat (; f, backend) = g
134-
return gradient(f, backend, x)
135-
end
136-
137-
function (g::Gradient{F,B,<:GradientExtras})(x) where {F,B}
138-
@compat (; f, backend, extras) = g
139-
return gradient(f, extras, backend, x)
140-
end

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -242,22 +242,3 @@ function pushforward!(
242242
) where {F}
243243
return value_and_pushforward!(f!, y, ty, extras, backend, x, tx)[2]
244244
end
245-
246-
## Functors
247-
248-
struct PushforwardFixedSeed{F,B,TX,E}
249-
f::F
250-
backend::B
251-
tx::TX
252-
extras::E
253-
end
254-
255-
function PushforwardFixedSeed(f, backend::AbstractADType, tx)
256-
return PushforwardFixedSeed(f, backend, tx, nothing)
257-
end
258-
259-
# not covered but don't remove, Enzyme messes with code coverage
260-
function (pfs::PushforwardFixedSeed{F,B,TX,Nothing})(x) where {F,B,TX}
261-
@compat (; f, backend, tx) = pfs
262-
return pushforward(f, backend, x, tx)
263-
end

DifferentiationInterface/src/second_order/hvp.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,19 @@ function hvp! end
4040

4141
## Preparation
4242

43-
struct ForwardOverForwardHVPExtras{G<:Gradient,E<:PushforwardExtras} <: HVPExtras
43+
struct ForwardOverForwardHVPExtras{G,E<:PushforwardExtras} <: HVPExtras
4444
inner_gradient::G
4545
outer_pushforward_extras::E
4646
end
4747

48-
struct ForwardOverReverseHVPExtras{G<:Gradient,E<:PushforwardExtras} <: HVPExtras
48+
struct ForwardOverReverseHVPExtras{G,E<:PushforwardExtras} <: HVPExtras
4949
inner_gradient::G
5050
outer_pushforward_extras::E
5151
end
5252

5353
struct ReverseOverForwardHVPExtras <: HVPExtras end
5454

55-
struct ReverseOverReverseHVPExtras{G<:Gradient,E<:PullbackExtras} <: HVPExtras
55+
struct ReverseOverReverseHVPExtras{G,E<:PullbackExtras} <: HVPExtras
5656
inner_gradient::G
5757
outer_pullback_extras::E
5858
end
@@ -65,7 +65,7 @@ function _prepare_hvp_aux(
6565
f::F, backend::AbstractADType, x, tx::Tangents, ::ForwardOverForward
6666
) where {F}
6767
# pushforward of many pushforwards in theory, but pushforward of gradient in practice
68-
inner_gradient = Gradient(f, nested(maybe_inner(backend)))
68+
inner_gradient(x) = gradient(f, nested(maybe_inner(backend)), x)
6969
outer_pushforward_extras = prepare_pushforward(
7070
inner_gradient, maybe_outer(backend), x, tx
7171
)
@@ -76,7 +76,7 @@ function _prepare_hvp_aux(
7676
f::F, backend::AbstractADType, x, tx::Tangents, ::ForwardOverReverse
7777
) where {F}
7878
# pushforward of gradient
79-
inner_gradient = Gradient(f, nested(maybe_inner(backend)))
79+
inner_gradient(x) = gradient(f, nested(maybe_inner(backend)), x)
8080
outer_pushforward_extras = prepare_pushforward(
8181
inner_gradient, maybe_outer(backend), x, tx
8282
)
@@ -95,7 +95,7 @@ function _prepare_hvp_aux(
9595
f::F, backend::AbstractADType, x, tx::Tangents, ::ReverseOverReverse
9696
) where {F}
9797
# pullback of gradient
98-
inner_gradient = Gradient(f, nested(maybe_inner(backend)))
98+
inner_gradient(x) = gradient(f, nested(maybe_inner(backend)), x)
9999
outer_pullback_extras = prepare_pullback(inner_gradient, maybe_outer(backend), x, tx)
100100
return ReverseOverReverseHVPExtras(inner_gradient, outer_pullback_extras)
101101
end
@@ -123,11 +123,13 @@ end
123123
function hvp(
124124
f::F, ::ReverseOverForwardHVPExtras, backend::AbstractADType, x, tx::Tangents
125125
) where {F}
126-
dgs = map(tx.d) do dx
127-
inner_pushforward = PushforwardFixedSeed(f, nested(maybe_inner(backend)), Tangents(dx))
126+
tg = map(tx) do dx
127+
function inner_pushforward(x)
128+
return only(pushforward(f, nested(maybe_inner(backend)), x, Tangents(dx)))
129+
end
128130
gradient(only inner_pushforward, maybe_outer(backend), x)
129131
end
130-
return Tangents(dgs...)
132+
return tg
131133
end
132134

133135
function hvp(
@@ -174,9 +176,9 @@ function hvp!(
174176
tx::Tangents,
175177
) where {F}
176178
for b in eachindex(tx.d, tg.d)
177-
inner_pushforward = PushforwardFixedSeed(
178-
f, nested(maybe_inner(backend)), Tangents(tx.d[b])
179-
)
179+
function inner_pushforward(x)
180+
return only(pushforward(f, nested(maybe_inner(backend)), x, Tangents(tx.d[b])))
181+
end
180182
gradient!(only inner_pushforward, tg.d[b], maybe_outer(backend), x)
181183
end
182184
return tg

DifferentiationInterface/src/second_order/second_derivative.jl

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,24 +48,13 @@ function value_derivative_and_second_derivative! end
4848

4949
## Preparation
5050

51-
struct InnerDerivative{F,B}
52-
f::F
53-
backend::B
54-
end
55-
56-
function (id::InnerDerivative)(x)
57-
@compat (; f, backend) = id
58-
return derivative(f, backend, x)
59-
end
60-
61-
struct ClosureSecondDerivativeExtras{ID<:InnerDerivative,E<:DerivativeExtras} <:
62-
SecondDerivativeExtras
51+
struct ClosureSecondDerivativeExtras{ID,E<:DerivativeExtras} <: SecondDerivativeExtras
6352
inner_derivative::ID
6453
outer_derivative_extras::E
6554
end
6655

6756
function prepare_second_derivative(f::F, backend::AbstractADType, x) where {F}
68-
inner_derivative = InnerDerivative(f, nested(maybe_inner(backend)))
57+
inner_derivative(x) = derivative(f, nested(maybe_inner(backend)), x)
6958
outer_derivative_extras = prepare_derivative(inner_derivative, maybe_outer(backend), x)
7059
return ClosureSecondDerivativeExtras(inner_derivative, outer_derivative_extras)
7160
end

0 commit comments

Comments
 (0)