Skip to content

Commit e490e32

Browse files
authored
Tag depends on ForwardDiff backend object (#229)
* Tag depends on ForwardDiff backend object * Remove map in make_dual
1 parent 2fe6c41 commit e490e32

4 files changed

Lines changed: 14 additions & 32 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module DifferentiationInterfaceForwardDiffExt
22

33
using ADTypes: AbstractADType, AutoForwardDiff
4+
using Base: Fix1
45
import DifferentiationInterface as DI
56
using DifferentiationInterface:
67
DerivativeExtras,

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ struct ForwardDiffOneArgPushforwardExtras{T,X} <: PushforwardExtras
44
xdual_tmp::X
55
end
66

7-
function DI.prepare_pushforward(f, ::AutoForwardDiff, x, dx)
8-
T = tag_type(f, x)
7+
function DI.prepare_pushforward(f, backend::AutoForwardDiff, x, dx)
8+
T = tag_type(f, backend, x)
99
xdual_tmp = make_dual(T, x, dx)
1010
return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp)
1111
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ struct ForwardDiffTwoArgPushforwardExtras{T,X,Y} <: PushforwardExtras
55
ydual_tmp::Y
66
end
77

8-
function DI.prepare_pushforward(f!, y, ::AutoForwardDiff, x, dx)
9-
T = tag_type(f!, x)
8+
function DI.prepare_pushforward(f!, y, backend::AutoForwardDiff, x, dx)
9+
T = tag_type(f!, backend, x)
1010
xdual_tmp = make_dual(T, x, dx)
1111
ydual_tmp = make_dual(T, y, similar(y))
1212
return ForwardDiffTwoArgPushforwardExtras{T,typeof(xdual_tmp),typeof(ydual_tmp)}(
Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,22 @@
11
choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x)
22
choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{C}()
33

4-
tag_type(::F, x::Number) where {F} = Tag{F,typeof(x)}
5-
tag_type(::F, x::AbstractArray) where {F} = Tag{F,eltype(x)}
4+
tag_type(f, ::AutoForwardDiff{C,T}, x) where {C,T} = T
5+
tag_type(f, ::AutoForwardDiff{C,Nothing}, x) where {C} = Tag{typeof(f),eltype(x)}
66

77
make_dual(::Type{T}, x::Number, dx) where {T} = Dual{T}(x, dx)
8-
make_dual(::Type{T}, x::AbstractArray, dx) where {T} = Dual{T}.(x, dx)
8+
make_dual(::Type{T}, x, dx) where {T} = Dual{T}.(x, dx) # TODO: map causes Enzyme to fail
99

10-
function make_dual!(::Type{T}, xdual, x::AbstractArray, dx) where {T}
11-
for i in eachindex(xdual, x, dx)
12-
xdual[i] = Dual{T}(x[i], dx[i])
13-
end
14-
return nothing
15-
end
10+
make_dual!(::Type{T}, xdual, x, dx) where {T} = map!(Dual{T}, xdual, x, dx)
1611

1712
myvalue(::Type{T}, ydual::Number) where {T} = value(T, ydual)
18-
myvalue(::Type{T}, ydual::AbstractArray) where {T} = value.(T, ydual)
13+
myvalue(::Type{T}, ydual) where {T} = map(Fix1(value, T), ydual)
1914

20-
function myvalue!(::Type{T}, y::AbstractArray, ydual) where {T}
21-
for i in eachindex(y, ydual)
22-
y[i] = value(T, ydual[i])
23-
end
24-
return nothing
25-
end
15+
myvalue!(::Type{T}, y, ydual) where {T} = map!(Fix1(value, T), y, ydual)
2616

2717
myderivative(::Type{T}, ydual::Number) where {T} = extract_derivative(T, ydual)
28-
myderivative(::Type{T}, ydual::AbstractArray) where {T} = extract_derivative(T, ydual)
29-
30-
function myderivative!(::Type{T}, dy, ydual::AbstractArray) where {T}
31-
extract_derivative!(T, dy, ydual)
32-
return nothing
33-
end
18+
myderivative(::Type{T}, ydual) where {T} = map(Fix1(extract_derivative, T), ydual)
3419

35-
function myvalueandderivative!(::Type{T}, y, dy, ydual::AbstractArray) where {T}
36-
for i in eachindex(y, dy, ydual)
37-
y[i] = value(T, ydual[i])
38-
dy[i] = extract_derivative(T, ydual[i])
39-
end
40-
return nothing
20+
function myderivative!(::Type{T}, dy, ydual) where {T}
21+
return map!(Fix1(extract_derivative, T), dy, ydual)
4122
end

0 commit comments

Comments
 (0)