Skip to content

Commit f27415c

Browse files
Use new ReverseDiff compile type parameter (#351)
1 parent e79bc20 commit f27415c

3 files changed

Lines changed: 14 additions & 10 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
4444
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]
4545

4646
[compat]
47-
ADTypes = "1.2.0"
47+
ADTypes = "1.5.0"
4848
ChainRulesCore = "1.23.0"
4949
Compat = "3,4"
5050
Diffractor = "=0.2.6"

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@ struct ReverseDiffGradientExtras{T} <: GradientExtras
4242
tape::T
4343
end
4444

45-
function DI.prepare_gradient(f, backend::AutoReverseDiff, x::AbstractArray)
45+
function DI.prepare_gradient(
46+
f, ::AutoReverseDiff{Compile}, x::AbstractArray
47+
) where {Compile}
4648
tape = GradientTape(f, x)
47-
if backend.compile
49+
if Compile
4850
tape = compile(tape)
4951
end
5052
return ReverseDiffGradientExtras(tape)
@@ -91,9 +93,11 @@ struct ReverseDiffOneArgJacobianExtras{T} <: JacobianExtras
9193
tape::T
9294
end
9395

94-
function DI.prepare_jacobian(f, backend::AutoReverseDiff, x::AbstractArray)
96+
function DI.prepare_jacobian(
97+
f, ::AutoReverseDiff{Compile}, x::AbstractArray
98+
) where {Compile}
9599
tape = JacobianTape(f, x)
96-
if backend.compile
100+
if Compile
97101
tape = compile(tape)
98102
end
99103
return ReverseDiffOneArgJacobianExtras(tape)
@@ -140,9 +144,9 @@ struct ReverseDiffHessianExtras{T} <: HessianExtras
140144
tape::T
141145
end
142146

143-
function DI.prepare_hessian(f, backend::AutoReverseDiff, x::AbstractArray)
147+
function DI.prepare_hessian(f, ::AutoReverseDiff{Compile}, x::AbstractArray) where {Compile}
144148
tape = HessianTape(f, x)
145-
if backend.compile
149+
if Compile
146150
tape = compile(tape)
147151
end
148152
return ReverseDiffHessianExtras(tape)

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ struct ReverseDiffTwoArgJacobianExtras{T} <: JacobianExtras
7272
end
7373

7474
function DI.prepare_jacobian(
75-
f!, y::AbstractArray, backend::AutoReverseDiff, x::AbstractArray
76-
)
75+
f!, y::AbstractArray, ::AutoReverseDiff{Compile}, x::AbstractArray
76+
) where {Compile}
7777
tape = JacobianTape(f!, y, x)
78-
if backend.compile
78+
if Compile
7979
tape = compile(tape)
8080
end
8181
return ReverseDiffTwoArgJacobianExtras(tape)

0 commit comments

Comments
 (0)