Skip to content

Commit b09faab

Browse files
committed
Fix PolyesterForwardDiff
1 parent dfd5379 commit b09faab

16 files changed

Lines changed: 212 additions & 121 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
3939
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
4040
DifferentiationInterfaceGTPSAExt = "GTPSA"
4141
DifferentiationInterfaceMooncakeExt = "Mooncake"
42-
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
42+
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
4343
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
4444
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
4545
DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer"
@@ -115,7 +115,6 @@ test = [
115115
"ComponentArrays",
116116
"DataFrames",
117117
"ExplicitImports",
118-
"ForwardDiff",
119118
"JET",
120119
"JLArrays",
121120
"JuliaFormatter",

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,13 +48,13 @@ force_annotation(f::F) where {F<:Annotation} = f
4848
force_annotation(f::F) where {F} = Const(f)
4949

5050
@inline function _translate(
51-
::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext}
51+
::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedConstant
5252
) where {B}
5353
return Const(DI.unwrap(c))
5454
end
5555

5656
@inline function _translate(
57-
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache
57+
::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache
5858
) where {B}
5959
if B == 1
6060
return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl

Lines changed: 4 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -3,27 +3,16 @@ DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.x
33
DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp)
44

55
function DI.overloaded_input(
6-
::typeof(DI.pushforward),
7-
f::F,
8-
backend::AutoForwardDiff,
9-
x,
10-
tx::NTuple{B},
11-
contexts::Vararg{DI.Context,C},
12-
) where {F,B,C}
6+
::typeof(DI.pushforward), f::F, backend::AutoForwardDiff, x, tx::NTuple{B}
7+
) where {F,B}
138
T = tag_type(f, backend, x)
149
xdual = make_dual(T, x, tx)
1510
return xdual
1611
end
1712

1813
function DI.overloaded_input(
19-
::typeof(DI.pushforward),
20-
f!::F,
21-
y,
22-
backend::AutoForwardDiff,
23-
x,
24-
tx::NTuple{B},
25-
contexts::Vararg{DI.Context,C},
26-
) where {F,B,C}
14+
::typeof(DI.pushforward), f!::F, y, backend::AutoForwardDiff, x, tx::NTuple{B}
15+
) where {F,B}
2716
T = tag_type(f!, backend, x)
2817
xdual = make_dual(T, x, tx)
2918
return xdual

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,11 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B}
8282
return ty
8383
end
8484

85-
_translate(::Type{D}, c::DI.GeneralizedConstant) where {D<:Dual} = DI.unwrap(c)
85+
function _translate(
86+
::Type{D}, c::Union{DI.GeneralizedConstant,DI.PrepContext}
87+
) where {D<:Dual}
88+
return DI.unwrap(c)
89+
end
8690
function _translate(::Type{D}, c::DI.Cache) where {D<:Dual}
8791
c0 = DI.unwrap(c)
8892
return similar(c0, D)
@@ -95,7 +99,11 @@ function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C}
9599
return new_contexts
96100
end
97101

98-
_translate_toprep(::Type{D}, c::DI.GeneralizedConstant) where {D<:Dual} = nothing
102+
function _translate_toprep(
103+
::Type{D}, c::Union{DI.GeneralizedConstant,DI.PrepContext}
104+
) where {D<:Dual}
105+
return nothing
106+
end
99107
function _translate_toprep(::Type{D}, c::DI.Cache) where {D<:Dual}
100108
c0 = DI.unwrap(c)
101109
return similar(c0, D)
@@ -108,7 +116,7 @@ function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:D
108116
return new_contexts
109117
end
110118

111-
_translate_prepared(c::DI.GeneralizedConstant, _pc) = DI.unwrap(c)
119+
_translate_prepared(c::Union{DI.GeneralizedConstant,DI.PrepContext}, _pc) = DI.unwrap(c)
112120
_translate_prepared(_c::DI.Cache, pc) = pc
113121

114122
function translate_prepared(

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl

Lines changed: 8 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -4,31 +4,22 @@ using ADTypes: AutoForwardDiff, AutoPolyesterForwardDiff
44
import DifferentiationInterface as DI
55
using LinearAlgebra: mul!
66
using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian!
7-
using PolyesterForwardDiff.ForwardDiff: Chunk
8-
using PolyesterForwardDiff.ForwardDiff.DiffResults: DiffResults
7+
using ForwardDiff: Chunk
8+
using DiffResults: DiffResults
9+
10+
const FDExt = Base.get_extension(DI, :DifferentiationInterfaceForwardDiffExt)
11+
@assert !isnothing(FDExt)
912

1013
function single_threaded(backend::AutoPolyesterForwardDiff{chunksize,T}) where {chunksize,T}
1114
return AutoForwardDiff(; chunksize, tag=backend.tag)
1215
end
1316

1417
DI.check_available(::AutoPolyesterForwardDiff) = true
18+
DI.inner_preparation_behavior(::AutoPolyesterForwardDiff) = DI.PrepareInnerOverload()
1519

16-
function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, x::AbstractArray)
17-
return DI.pick_batchsize(single_threaded(backend), x)
18-
end
19-
20-
function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, N::Integer)
21-
return DI.pick_batchsize(single_threaded(backend), N)
22-
end
23-
24-
function DI.threshold_batchsize(
25-
backend::AutoPolyesterForwardDiff{chunksize1}, chunksize2::Integer
26-
) where {chunksize1}
27-
chunksize = isnothing(chunksize1) ? nothing : min(chunksize1, chunksize2)
28-
return AutoPolyesterForwardDiff(; chunksize, tag=backend.tag)
29-
end
30-
20+
include("utils.jl")
3121
include("onearg.jl")
3222
include("twoarg.jl")
23+
include("misc.jl")
3324

3425
end # module
DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/misc.jl

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,11 @@
1+
function DI.overloaded_input(
2+
::typeof(DI.pushforward), f::F, backend::AutoPolyesterForwardDiff, x, tx::NTuple{B}
3+
) where {F,B}
4+
return DI.overloaded_input(DI.pushforward, f, single_threaded(backend), x, tx)
5+
end
6+
7+
function DI.overloaded_input(
8+
::typeof(DI.pushforward), f!::F, y, backend::AutoPolyesterForwardDiff, x, tx::NTuple{B}
9+
) where {F,B}
10+
return DI.overloaded_input(DI.pushforward, f!, y, single_threaded(backend), x, tx)
11+
end

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl

Lines changed: 58 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -107,45 +107,61 @@ end
107107

108108
## Gradient
109109

110-
struct PolyesterForwardDiffGradientPrep{chunksize} <: DI.GradientPrep
110+
struct PolyesterForwardDiffGradientPrep{chunksize,P} <: DI.GradientPrep
111111
chunk::Chunk{chunksize}
112+
single_threaded_prep::P
112113
end
113114

114115
function DI.prepare_gradient(
115-
f, ::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C}
116+
f, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C}
116117
) where {chunksize,C}
117118
if isnothing(chunksize)
118119
chunk = Chunk(x)
119120
else
120121
chunk = Chunk{chunksize}()
121122
end
122-
return PolyesterForwardDiffGradientPrep(chunk)
123+
single_threaded_prep = DI.prepare_gradient(f, single_threaded(backend), x, contexts...)
124+
return PolyesterForwardDiffGradientPrep(chunk, single_threaded_prep)
123125
end
124126

125127
function DI.value_and_gradient!(
126128
f,
127129
grad,
128130
prep::PolyesterForwardDiffGradientPrep,
129-
::AutoPolyesterForwardDiff,
131+
backend::AutoPolyesterForwardDiff,
130132
x,
131133
contexts::Vararg{DI.Context,C},
132134
) where {C}
133-
fc = DI.with_contexts(f, contexts...)
134-
threaded_gradient!(fc, grad, x, prep.chunk)
135-
return fc(x), grad
135+
if contexts isa NTuple{C,DI.GeneralizedConstant}
136+
fc = DI.with_contexts(f, contexts...)
137+
threaded_gradient!(fc, grad, x, prep.chunk)
138+
return fc(x), grad
139+
else
140+
# TODO: optimize
141+
return DI.value_and_gradient!(
142+
f, grad, prep.single_threaded_prep, single_threaded(backend), x, contexts...
143+
)
144+
end
136145
end
137146

138147
function DI.gradient!(
139148
f,
140149
grad,
141150
prep::PolyesterForwardDiffGradientPrep,
142-
::AutoPolyesterForwardDiff,
151+
backend::AutoPolyesterForwardDiff,
143152
x,
144153
contexts::Vararg{DI.Context,C},
145154
) where {C}
146-
fc = DI.with_contexts(f, contexts...)
147-
threaded_gradient!(fc, grad, x, prep.chunk)
148-
return grad
155+
if contexts isa NTuple{C,DI.GeneralizedConstant}
156+
fc = DI.with_contexts(f, contexts...)
157+
threaded_gradient!(fc, grad, x, prep.chunk)
158+
return grad
159+
else
160+
# TODO: optimize
161+
return DI.gradient!(
162+
f, grad, prep.single_threaded_prep, single_threaded(backend), x, contexts...
163+
)
164+
end
149165
end
150166

151167
function DI.value_and_gradient(
@@ -170,43 +186,57 @@ end
170186

171187
## Jacobian
172188

173-
struct PolyesterForwardDiffOneArgJacobianPrep{chunksize} <: DI.JacobianPrep
189+
struct PolyesterForwardDiffOneArgJacobianPrep{chunksize,P} <: DI.JacobianPrep
174190
chunk::Chunk{chunksize}
191+
single_threaded_prep::P
175192
end
176193

177194
function DI.prepare_jacobian(
178-
f, ::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C}
195+
f, backend::AutoPolyesterForwardDiff{chunksize}, x, contexts::Vararg{DI.Context,C}
179196
) where {chunksize,C}
180197
if isnothing(chunksize)
181198
chunk = Chunk(x)
182199
else
183200
chunk = Chunk{chunksize}()
184201
end
185-
return PolyesterForwardDiffOneArgJacobianPrep(chunk)
202+
single_threaded_prep = DI.prepare_jacobian(f, single_threaded(backend), x, contexts...)
203+
return PolyesterForwardDiffOneArgJacobianPrep(chunk, single_threaded_prep)
186204
end
187205

188206
function DI.value_and_jacobian!(
189207
f,
190208
jac,
191209
prep::PolyesterForwardDiffOneArgJacobianPrep,
192-
::AutoPolyesterForwardDiff,
210+
backend::AutoPolyesterForwardDiff,
193211
x,
194212
contexts::Vararg{DI.Context,C},
195213
) where {C}
196-
fc = DI.with_contexts(f, contexts...)
197-
return fc(x), threaded_jacobian!(fc, jac, x, prep.chunk)
214+
if contexts isa NTuple{C,DI.GeneralizedConstant}
215+
fc = DI.with_contexts(f, contexts...)
216+
return fc(x), threaded_jacobian!(fc, jac, x, prep.chunk)
217+
else
218+
return DI.value_and_jacobian!(
219+
f, jac, prep.single_threaded_prep, single_threaded(backend), x, contexts...
220+
)
221+
end
198222
end
199223

200224
function DI.jacobian!(
201225
f,
202226
jac,
203227
prep::PolyesterForwardDiffOneArgJacobianPrep,
204-
::AutoPolyesterForwardDiff,
228+
backend::AutoPolyesterForwardDiff,
205229
x,
206230
contexts::Vararg{DI.Context,C},
207231
) where {C}
208-
fc = DI.with_contexts(f, contexts...)
209-
return threaded_jacobian!(fc, jac, x, prep.chunk)
232+
if contexts isa NTuple{C,DI.GeneralizedConstant}
233+
fc = DI.with_contexts(f, contexts...)
234+
return threaded_jacobian!(fc, jac, x, prep.chunk)
235+
else
236+
return DI.jacobian!(
237+
f, jac, prep.single_threaded_prep, single_threaded(backend), x, contexts...
238+
)
239+
end
210240
end
211241

212242
function DI.value_and_jacobian(
@@ -217,9 +247,8 @@ function DI.value_and_jacobian(
217247
contexts::Vararg{DI.Context,C},
218248
) where {C}
219249
y = f(x, map(DI.unwrap, contexts)...)
220-
return DI.value_and_jacobian!(
221-
f, similar(y, length(y), length(x)), prep, backend, x, contexts...
222-
)
250+
jac = similar(y, length(y), length(x))
251+
return DI.value_and_jacobian!(f, jac, prep, backend, x, contexts...)
223252
end
224253

225254
function DI.jacobian(
@@ -230,7 +259,8 @@ function DI.jacobian(
230259
contexts::Vararg{DI.Context,C},
231260
) where {C}
232261
y = f(x, map(DI.unwrap, contexts)...)
233-
return DI.jacobian!(f, similar(y, length(y), length(x)), prep, backend, x, contexts...)
262+
jac = similar(y, length(y), length(x))
263+
return DI.jacobian!(f, jac, prep, backend, x, contexts...)
234264
end
235265

236266
## Hessian
@@ -299,7 +329,7 @@ end
299329
300330
function DI.hvp(
301331
f,
302-
prep::DI.HVPPrep,
332+
prep::DI.ForwardOverAnythingHVPPrep,
303333
backend::AutoPolyesterForwardDiff,
304334
x,
305335
tx::NTuple,
@@ -313,7 +343,7 @@ end
313343
function DI.hvp!(
314344
f,
315345
tg::NTuple,
316-
prep::DI.HVPPrep,
346+
prep::DI.ForwardOverAnythingHVPPrep,
317347
backend::AutoPolyesterForwardDiff,
318348
x,
319349
tx::NTuple,
@@ -326,7 +356,7 @@ end
326356
327357
function DI.gradient_and_hvp(
328358
f,
329-
prep::DI.HVPPrep,
359+
prep::DI.ForwardOverAnythingHVPPrep,
330360
backend::AutoPolyesterForwardDiff,
331361
x,
332362
tx::NTuple,
@@ -341,7 +371,7 @@ function DI.gradient_and_hvp!(
341371
f,
342372
grad,
343373
tg::NTuple,
344-
prep::DI.HVPPrep,
374+
prep::DI.ForwardOverAnythingHVPPrep,
345375
backend::AutoPolyesterForwardDiff,
346376
x,
347377
tx::NTuple,

0 commit comments

Comments
 (0)