|
1 | 1 | choose_chunk(::AutoForwardDiff{nothing}, x) = Chunk(x) |
2 | 2 | choose_chunk(::AutoForwardDiff{C}, x) where {C} = Chunk{C}() |
3 | 3 |
|
4 | | -tag_type(::F, x::Number) where {F} = Tag{F,typeof(x)} |
5 | | -tag_type(::F, x::AbstractArray) where {F} = Tag{F,eltype(x)} |
| 4 | +tag_type(f, ::AutoForwardDiff{C,T}, x) where {C,T} = T |
| 5 | +tag_type(f, ::AutoForwardDiff{C,Nothing}, x) where {C} = Tag{typeof(f),eltype(x)} |
6 | 6 |
|
7 | 7 | make_dual(::Type{T}, x::Number, dx) where {T} = Dual{T}(x, dx) |
8 | | -make_dual(::Type{T}, x::AbstractArray, dx) where {T} = Dual{T}.(x, dx) |
| 8 | +make_dual(::Type{T}, x, dx) where {T} = Dual{T}.(x, dx) # TODO: map causes Enzyme to fail |
9 | 9 |
|
10 | | -function make_dual!(::Type{T}, xdual, x::AbstractArray, dx) where {T} |
11 | | - for i in eachindex(xdual, x, dx) |
12 | | - xdual[i] = Dual{T}(x[i], dx[i]) |
13 | | - end |
14 | | - return nothing |
15 | | -end |
| 10 | +make_dual!(::Type{T}, xdual, x, dx) where {T} = map!(Dual{T}, xdual, x, dx) |
16 | 11 |
|
17 | 12 | myvalue(::Type{T}, ydual::Number) where {T} = value(T, ydual) |
18 | | -myvalue(::Type{T}, ydual::AbstractArray) where {T} = value.(T, ydual) |
| 13 | +myvalue(::Type{T}, ydual) where {T} = map(Fix1(value, T), ydual) |
19 | 14 |
|
20 | | -function myvalue!(::Type{T}, y::AbstractArray, ydual) where {T} |
21 | | - for i in eachindex(y, ydual) |
22 | | - y[i] = value(T, ydual[i]) |
23 | | - end |
24 | | - return nothing |
25 | | -end |
| 15 | +myvalue!(::Type{T}, y, ydual) where {T} = map!(Fix1(value, T), y, ydual) |
26 | 16 |
|
27 | 17 | myderivative(::Type{T}, ydual::Number) where {T} = extract_derivative(T, ydual) |
28 | | -myderivative(::Type{T}, ydual::AbstractArray) where {T} = extract_derivative(T, ydual) |
29 | | - |
30 | | -function myderivative!(::Type{T}, dy, ydual::AbstractArray) where {T} |
31 | | - extract_derivative!(T, dy, ydual) |
32 | | - return nothing |
33 | | -end |
| 18 | +myderivative(::Type{T}, ydual) where {T} = map(Fix1(extract_derivative, T), ydual) |
34 | 19 |
|
35 | | -function myvalueandderivative!(::Type{T}, y, dy, ydual::AbstractArray) where {T} |
36 | | - for i in eachindex(y, dy, ydual) |
37 | | - y[i] = value(T, ydual[i]) |
38 | | - dy[i] = extract_derivative(T, ydual[i]) |
39 | | - end |
40 | | - return nothing |
| 20 | +function myderivative!(::Type{T}, dy, ydual) where {T} |
| 21 | + return map!(Fix1(extract_derivative, T), dy, ydual) |
41 | 22 | end |
0 commit comments