|
59 | 59 |
|
60 | 60 | ## Gradient |
61 | 61 |
|
62 | | -DI.prepare_gradient(f, ::AnyAutoPolyForwardDiff, x) = NoGradientExtras() |
| 62 | +function DI.prepare_gradient(f, backend::AnyAutoPolyForwardDiff, x) |
| 63 | + return DI.prepare_gradient(f, single_threaded(backend), x) |
| 64 | +end |
63 | 65 |
|
64 | 66 | function DI.value_and_gradient!!( |
65 | | - f, |
66 | | - grad::AbstractVector, |
67 | | - ::AnyAutoPolyForwardDiff{C}, |
68 | | - x::AbstractVector, |
69 | | - ::NoGradientExtras, |
| 67 | + f, grad, ::AnyAutoPolyForwardDiff{C}, x::AbstractVector, ::GradientExtras |
70 | 68 | ) where {C} |
71 | 69 | threaded_gradient!(f, grad, x, Chunk{C}()) |
72 | 70 | return f(x), grad |
73 | 71 | end |
74 | 72 |
|
75 | 73 | function DI.gradient!!( |
76 | | - f, |
77 | | - grad::AbstractVector, |
78 | | - ::AnyAutoPolyForwardDiff{C}, |
79 | | - x::AbstractVector, |
80 | | - ::NoGradientExtras, |
| 74 | + f, grad, ::AnyAutoPolyForwardDiff{C}, x::AbstractVector, ::GradientExtras |
81 | 75 | ) where {C} |
82 | 76 | threaded_gradient!(f, grad, x, Chunk{C}()) |
83 | 77 | return grad |
84 | 78 | end |
85 | 79 |
|
| 80 | +function DI.value_and_gradient!!( |
| 81 | + f, grad, backend::AnyAutoPolyForwardDiff{C}, x::AbstractArray, extras::GradientExtras |
| 82 | +) where {C} |
| 83 | + return DI.value_and_gradient!!(f, grad, single_threaded(backend), x, extras) |
| 84 | +end |
| 85 | + |
| 86 | +function DI.gradient!!( |
| 87 | + f, grad, backend::AnyAutoPolyForwardDiff{C}, x::AbstractArray, extras::GradientExtras |
| 88 | +) where {C} |
| 89 | + return DI.gradient!!(f, grad, single_threaded(backend), x, extras) |
| 90 | +end |
| 91 | + |
86 | 92 | function DI.value_and_gradient( |
87 | | - f, backend::AnyAutoPolyForwardDiff, x::AbstractVector, extras::NoGradientExtras |
| 93 | + f, backend::AnyAutoPolyForwardDiff, x::AbstractArray, extras::GradientExtras |
88 | 94 | ) |
89 | 95 | return DI.value_and_gradient!!(f, similar(x), backend, x, extras) |
90 | 96 | end |
91 | 97 |
|
92 | 98 | function DI.gradient( |
93 | | - f, backend::AnyAutoPolyForwardDiff, x::AbstractVector, extras::NoGradientExtras |
| 99 | + f, backend::AnyAutoPolyForwardDiff, x::AbstractArray, extras::GradientExtras |
94 | 100 | ) |
95 | 101 | return DI.gradient!!(f, similar(x), backend, x, extras) |
96 | 102 | end |
|
0 commit comments