|
21 | 21 |
|
22 | 22 | ## Derivative |
23 | 23 |
|
24 | | -DI.prepare_derivative(f, ::AnyAutoFiniteDiff, x) = NoDerivativeExtras() |
| 24 | +struct FiniteDiffAllocatingDerivativeExtras{C} |
| 25 | + cache::C |
| 26 | +end |
| 27 | + |
| 28 | +function DI.prepare_derivative(f, backend::AnyAutoFiniteDiff, x) |
| 29 | + y = f(x) |
| 30 | + cache = if y isa Number |
| 31 | + nothing |
| 32 | + elseif y isa AbstractArray |
| 33 | + df = similar(y) |
| 34 | + cache = GradientCache(df, x, fdtype(backend), eltype(y), FUNCTION_NOT_INPLACE) |
| 35 | + end |
| 36 | + return FiniteDiffAllocatingDerivativeExtras(cache) |
| 37 | +end |
| 38 | + |
| 39 | +### Scalar to scalar |
25 | 40 |
|
26 | | -function DI.derivative(f, backend::AnyAutoFiniteDiff, x, ::NoDerivativeExtras) |
| 41 | +function DI.derivative( |
| 42 | + f, backend::AnyAutoFiniteDiff, x, ::FiniteDiffAllocatingDerivativeExtras{Nothing} |
| 43 | +) |
27 | 44 | return finite_difference_derivative(f, x, fdtype(backend)) |
28 | 45 | end |
29 | 46 |
|
30 | | -function DI.value_and_derivative(f, backend::AnyAutoFiniteDiff, x, ::NoDerivativeExtras) |
| 47 | +function DI.derivative!!( |
| 48 | + f, |
| 49 | + _der, |
| 50 | + backend::AnyAutoFiniteDiff, |
| 51 | + x, |
| 52 | + extras::FiniteDiffAllocatingDerivativeExtras{Nothing}, |
| 53 | +) |
| 54 | + return DI.derivative(f, backend, x, extras) |
| 55 | +end |
| 56 | + |
| 57 | +function DI.value_and_derivative( |
| 58 | + f, backend::AnyAutoFiniteDiff, x, ::FiniteDiffAllocatingDerivativeExtras{Nothing} |
| 59 | +) |
31 | 60 | y = f(x) |
32 | 61 | return y, finite_difference_derivative(f, x, fdtype(backend), eltype(y), y) |
33 | 62 | end |
34 | 63 |
|
35 | | -function DI.derivative!!(f, der, backend::AnyAutoFiniteDiff, x, extras::NoDerivativeExtras) |
36 | | - return DI.derivative(f, backend, x, extras) |
| 64 | +function DI.value_and_derivative!!( |
| 65 | + f, |
| 66 | + _der, |
| 67 | + backend::AnyAutoFiniteDiff, |
| 68 | + x, |
| 69 | + extras::FiniteDiffAllocatingDerivativeExtras{Nothing}, |
| 70 | +) |
| 71 | + return DI.value_and_derivative(f, backend, x, extras) |
| 72 | +end |
| 73 | + |
| 74 | +### Scalar to array |
| 75 | + |
| 76 | +function DI.derivative( |
| 77 | + f, ::AnyAutoFiniteDiff, x, extras::FiniteDiffAllocatingDerivativeExtras{<:GradientCache} |
| 78 | +) |
| 79 | + return finite_difference_gradient(f, x, extras.cache) |
| 80 | +end |
| 81 | + |
| 82 | +function DI.derivative!!( |
| 83 | + f, |
| 84 | + der, |
| 85 | + ::AnyAutoFiniteDiff, |
| 86 | + x, |
| 87 | + extras::FiniteDiffAllocatingDerivativeExtras{<:GradientCache}, |
| 88 | +) |
| 89 | + return finite_difference_gradient!(der, f, x, extras.cache) |
| 90 | +end |
| 91 | + |
| 92 | +function DI.value_and_derivative( |
| 93 | + f, ::AnyAutoFiniteDiff, x, extras::FiniteDiffAllocatingDerivativeExtras{<:GradientCache} |
| 94 | +) |
| 95 | + y = f(x) |
| 96 | + return y, finite_difference_gradient(f, x, extras.cache) |
37 | 97 | end |
38 | 98 |
|
39 | 99 | function DI.value_and_derivative!!( |
40 | | - f, der, backend::AnyAutoFiniteDiff, x, extras::NoDerivativeExtras |
| 100 | + f, |
| 101 | + der, |
| 102 | + ::AnyAutoFiniteDiff, |
| 103 | + x, |
| 104 | + extras::FiniteDiffAllocatingDerivativeExtras{<:GradientCache}, |
41 | 105 | ) |
42 | | - return DI.value_and_derivative(f, backend, x, extras) |
| 106 | + return f(x), finite_difference_gradient!(der, f, x, extras.cache) |
43 | 107 | end |
44 | 108 |
|
45 | 109 | ## Gradient |
46 | 110 |
|
47 | | -DI.prepare_gradient(f, ::AnyAutoFiniteDiff, x) = NoGradientExtras() |
| 111 | +struct FiniteDiffGradientExtras{C} |
| 112 | + cache::C |
| 113 | +end |
48 | 114 |
|
49 | | -function DI.gradient(f, backend::AnyAutoFiniteDiff, x::AbstractArray, ::NoGradientExtras) |
50 | | - return finite_difference_gradient(f, x, fdtype(backend)) |
| 115 | +function DI.prepare_gradient(f, backend::AnyAutoFiniteDiff, x) |
| 116 | + y = f(x) |
| 117 | + df = zero(y) .* x |
| 118 | + cache = GradientCache(df, x, fdtype(backend)) |
| 119 | + return FiniteDiffGradientExtras(cache) |
| 120 | +end |
| 121 | + |
| 122 | +function DI.gradient( |
| 123 | + f, ::AnyAutoFiniteDiff, x::AbstractArray, extras::FiniteDiffGradientExtras |
| 124 | +) |
| 125 | + return finite_difference_gradient(f, x, extras.cache) |
51 | 126 | end |
52 | 127 |
|
53 | 128 | function DI.value_and_gradient( |
54 | | - f, backend::AnyAutoFiniteDiff, x::AbstractArray, ::NoGradientExtras |
| 129 | + f, ::AnyAutoFiniteDiff, x::AbstractArray, extras::FiniteDiffGradientExtras |
55 | 130 | ) |
56 | | - y = f(x) |
57 | | - return y, finite_difference_gradient(f, x, fdtype(backend), typeof(y), y) |
| 131 | + return f(x), finite_difference_gradient(f, x, extras.cache) |
58 | 132 | end |
59 | 133 |
|
60 | 134 | function DI.gradient!!( |
61 | | - f, grad, backend::AnyAutoFiniteDiff, x::AbstractArray, ::NoGradientExtras |
| 135 | + f, grad, ::AnyAutoFiniteDiff, x::AbstractArray, extras::FiniteDiffGradientExtras |
62 | 136 | ) |
63 | | - return finite_difference_gradient!(grad, f, x, fdtype(backend)) |
| 137 | + return finite_difference_gradient!(grad, f, x, extras.cache) |
64 | 138 | end |
65 | 139 |
|
66 | 140 | function DI.value_and_gradient!!( |
67 | | - f, grad, backend::AnyAutoFiniteDiff, x::AbstractArray, ::NoGradientExtras |
| 141 | + f, grad, ::AnyAutoFiniteDiff, x::AbstractArray, extras::FiniteDiffGradientExtras |
68 | 142 | ) |
69 | | - y = f(x) |
70 | | - return y, finite_difference_gradient!(grad, f, x, fdtype(backend), typeof(y), y) |
| 143 | + return f(x), finite_difference_gradient!(grad, f, x, extras.cache) |
71 | 144 | end |
72 | 145 |
|
73 | 146 | ## Jacobian |
74 | 147 |
|
75 | | -DI.prepare_jacobian(f, ::AnyAutoFiniteDiff, x) = NoJacobianExtras() |
| 148 | +struct FiniteDiffAllocatingJacobianExtras{C} |
| 149 | + cache::C |
| 150 | +end |
| 151 | + |
| 152 | +function DI.prepare_jacobian(f, backend::AnyAutoFiniteDiff, x) |
| 153 | + y = f(x) |
| 154 | + x1 = similar(x) |
| 155 | + fx = similar(y) |
| 156 | + fx1 = similar(y) |
| 157 | + cache = JacobianCache(x1, fx, fx1, fdjtype(backend)) |
| 158 | + return FiniteDiffAllocatingJacobianExtras(cache) |
| 159 | +end |
76 | 160 |
|
77 | | -function DI.jacobian(f, backend::AnyAutoFiniteDiff, x, ::NoJacobianExtras) |
78 | | - return finite_difference_jacobian(f, x, fdjtype(backend)) |
| 161 | +function DI.jacobian(f, ::AnyAutoFiniteDiff, x, extras::FiniteDiffAllocatingJacobianExtras) |
| 162 | + return finite_difference_jacobian(f, x, extras.cache) |
79 | 163 | end |
80 | 164 |
|
81 | | -function DI.value_and_jacobian(f, backend::AnyAutoFiniteDiff, x, ::NoJacobianExtras) |
| 165 | +function DI.value_and_jacobian( |
| 166 | + f, ::AnyAutoFiniteDiff, x, extras::FiniteDiffAllocatingJacobianExtras |
| 167 | +) |
82 | 168 | y = f(x) |
83 | | - return y, finite_difference_jacobian(f, x, fdjtype(backend), eltype(y), y) |
| 169 | + return y, finite_difference_jacobian(f, x, extras.cache, y) |
84 | 170 | end |
85 | 171 |
|
86 | | -function DI.jacobian!!(f, jac, backend::AnyAutoFiniteDiff, x, extras::NoJacobianExtras) |
87 | | - return DI.jacobian(f, backend, x, extras) |
| 172 | +function DI.jacobian!!( |
| 173 | + f, jac, ::AnyAutoFiniteDiff, x, extras::FiniteDiffAllocatingJacobianExtras |
| 174 | +) |
| 175 | + return finite_difference_jacobian(f, x, extras.cache; jac_prototype=jac) |
88 | 176 | end |
89 | 177 |
|
90 | 178 | function DI.value_and_jacobian!!( |
91 | | - f, jac, backend::AnyAutoFiniteDiff, x, extras::NoJacobianExtras |
| 179 | + f, jac, ::AnyAutoFiniteDiff, x, extras::FiniteDiffAllocatingJacobianExtras |
92 | 180 | ) |
93 | | - return DI.value_and_jacobian(f, backend, x, extras) |
| 181 | + y = f(x) |
| 182 | + return y, finite_difference_jacobian(f, x, extras.cache, y; jac_prototype=jac) |
94 | 183 | end |
95 | 184 |
|
96 | 185 | ## Hessian |
97 | 186 |
|
98 | | -DI.prepare_hessian(f, ::AnyAutoFiniteDiff, x) = NoHessianExtras() |
| 187 | +struct FiniteDiffHessianExtras{C} |
| 188 | + cache::C |
| 189 | +end |
| 190 | + |
| 191 | +function DI.prepare_hessian(f, backend::AnyAutoFiniteDiff, x) |
| 192 | + cache = HessianCache(x, fdhtype(backend)) |
| 193 | + return FiniteDiffHessianExtras(cache) |
| 194 | +end |
99 | 195 |
|
100 | | -function DI.hessian(f, backend::AnyAutoFiniteDiff, x, ::NoHessianExtras) |
101 | | - return finite_difference_hessian(f, x, fdhtype(backend)) |
| 196 | +function DI.hessian(f, ::AnyAutoFiniteDiff, x, extras::FiniteDiffHessianExtras) |
| 197 | + return finite_difference_hessian(f, x, extras.cache) |
102 | 198 | end |
103 | 199 |
|
104 | | -function DI.hessian!!(f, hess, backend::AnyAutoFiniteDiff, x, ::NoHessianExtras) |
105 | | - return finite_difference_hessian!(hess, f, x, fdhtype(backend)) |
| 200 | +function DI.hessian!!(f, hess, ::AnyAutoFiniteDiff, x, extras::FiniteDiffHessianExtras) |
| 201 | + return finite_difference_hessian!(hess, f, x, extras.cache) |
106 | 202 | end |
0 commit comments