|
6 | 6 |
|
7 | 7 | function DI.prepare_pushforward(f::F, backend::AutoForwardDiff, x, dx) where {F} |
8 | 8 | T = tag_type(f, backend, x) |
9 | | - xdual_tmp = make_dual_similar(T, x) |
| 9 | + xdual_tmp = make_dual_similar(T, x, dx) |
| 10 | + return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp) |
| 11 | +end |
| 12 | + |
| 13 | +function DI.prepare_pushforward_batched( |
| 14 | + f::F, backend::AutoForwardDiff, x, dx::Batch |
| 15 | +) where {F} |
| 16 | + T = tag_type(f, backend, x) |
| 17 | + xdual_tmp = make_dual_similar(T, x, dx) |
10 | 18 | return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp) |
11 | 19 | end |
12 | 20 |
|
@@ -61,56 +69,25 @@ function DI.pushforward!( |
61 | 69 | return dy |
62 | 70 | end |
63 | 71 |
|
64 | | -## Second derivative |
65 | | - |
66 | | -function DI.prepare_second_derivative(f::F, backend::AutoForwardDiff, x) where {F} |
67 | | - return NoSecondDerivativeExtras() |
68 | | -end |
69 | | - |
70 | | -function DI.second_derivative( |
71 | | - f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras |
72 | | -) where {F} |
73 | | - T = tag_type(f, backend, x) |
74 | | - xdual = make_dual(T, x, one(x)) |
75 | | - T2 = tag_type(f, backend, xdual) |
76 | | - ydual = f(make_dual(T2, xdual, one(xdual))) |
77 | | - return myderivative(T, myderivative(T2, ydual)) |
78 | | -end |
79 | | - |
80 | | -function DI.second_derivative!( |
81 | | - f::F, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras |
82 | | -) where {F} |
83 | | - T = tag_type(f, backend, x) |
84 | | - xdual = make_dual(T, x, one(x)) |
85 | | - T2 = tag_type(f, backend, xdual) |
86 | | - ydual = f(make_dual(T2, xdual, one(xdual))) |
87 | | - return myderivative!(T, der2, myderivative(T2, ydual)) |
88 | | -end |
89 | | - |
90 | | -function DI.value_derivative_and_second_derivative( |
91 | | - f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras |
92 | | -) where {F} |
93 | | - T = tag_type(f, backend, x) |
94 | | - xdual = make_dual(T, x, one(x)) |
95 | | - T2 = tag_type(f, backend, xdual) |
96 | | - ydual = f(make_dual(T2, xdual, one(xdual))) |
97 | | - y = myvalue(T, myvalue(T2, ydual)) |
98 | | - der = myderivative(T, myvalue(T2, ydual)) |
99 | | - der2 = myderivative(T, myderivative(T2, ydual)) |
100 | | - return y, der, der2 |
| 72 | +function DI.pushforward_batched( |
| 73 | + f::F, ::AutoForwardDiff, x, dx::Batch{B}, extras::ForwardDiffOneArgPushforwardExtras{T} |
| 74 | +) where {F,T,B} |
| 75 | + ydual = compute_ydual_onearg(f, x, dx, extras) |
| 76 | + new_dy = mypartials(T, Val(B), ydual) |
| 77 | + return new_dy |
101 | 78 | end |
102 | 79 |
|
103 | | -function DI.value_derivative_and_second_derivative!( |
104 | | - f::F, der, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras |
105 | | -) where {F} |
106 | | - T = tag_type(f, backend, x) |
107 | | - xdual = make_dual(T, x, one(x)) |
108 | | - T2 = tag_type(f, backend, xdual) |
109 | | - ydual = f(make_dual(T2, xdual, one(xdual))) |
110 | | - y = myvalue(T, myvalue(T2, ydual)) |
111 | | - myderivative!(T, der, myvalue(T2, ydual)) |
112 | | - myderivative!(T, der2, myderivative(T2, ydual)) |
113 | | - return y, der, der2 |
| 80 | +function DI.pushforward_batched!( |
| 81 | + f::F, |
| 82 | + dy::Batch{B}, |
| 83 | + ::AutoForwardDiff, |
| 84 | + x, |
| 85 | + dx::Batch{B}, |
| 86 | + extras::ForwardDiffOneArgPushforwardExtras{T}, |
| 87 | +) where {F,T,B} |
| 88 | + ydual = compute_ydual_onearg(f, x, dx, extras) |
| 89 | + mypartials!(T, dy, ydual) |
| 90 | + return dy |
114 | 91 | end |
115 | 92 |
|
116 | 93 | ## Gradient |
@@ -188,6 +165,58 @@ function DI.jacobian( |
188 | 165 | return jacobian(f, x, extras.config) |
189 | 166 | end |
190 | 167 |
|
| 168 | +## Second derivative |
| 169 | + |
| 170 | +function DI.prepare_second_derivative(f::F, backend::AutoForwardDiff, x) where {F} |
| 171 | + return NoSecondDerivativeExtras() |
| 172 | +end |
| 173 | + |
| 174 | +function DI.second_derivative( |
| 175 | + f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras |
| 176 | +) where {F} |
| 177 | + T = tag_type(f, backend, x) |
| 178 | + xdual = make_dual(T, x, one(x)) |
| 179 | + T2 = tag_type(f, backend, xdual) |
| 180 | + ydual = f(make_dual(T2, xdual, one(xdual))) |
| 181 | + return myderivative(T, myderivative(T2, ydual)) |
| 182 | +end |
| 183 | + |
| 184 | +function DI.second_derivative!( |
| 185 | + f::F, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras |
| 186 | +) where {F} |
| 187 | + T = tag_type(f, backend, x) |
| 188 | + xdual = make_dual(T, x, one(x)) |
| 189 | + T2 = tag_type(f, backend, xdual) |
| 190 | + ydual = f(make_dual(T2, xdual, one(xdual))) |
| 191 | + return myderivative!(T, der2, myderivative(T2, ydual)) |
| 192 | +end |
| 193 | + |
| 194 | +function DI.value_derivative_and_second_derivative( |
| 195 | + f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras |
| 196 | +) where {F} |
| 197 | + T = tag_type(f, backend, x) |
| 198 | + xdual = make_dual(T, x, one(x)) |
| 199 | + T2 = tag_type(f, backend, xdual) |
| 200 | + ydual = f(make_dual(T2, xdual, one(xdual))) |
| 201 | + y = myvalue(T, myvalue(T2, ydual)) |
| 202 | + der = myderivative(T, myvalue(T2, ydual)) |
| 203 | + der2 = myderivative(T, myderivative(T2, ydual)) |
| 204 | + return y, der, der2 |
| 205 | +end |
| 206 | + |
| 207 | +function DI.value_derivative_and_second_derivative!( |
| 208 | + f::F, der, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras |
| 209 | +) where {F} |
| 210 | + T = tag_type(f, backend, x) |
| 211 | + xdual = make_dual(T, x, one(x)) |
| 212 | + T2 = tag_type(f, backend, xdual) |
| 213 | + ydual = f(make_dual(T2, xdual, one(xdual))) |
| 214 | + y = myvalue(T, myvalue(T2, ydual)) |
| 215 | + myderivative!(T, der, myvalue(T2, ydual)) |
| 216 | + myderivative!(T, der2, myderivative(T2, ydual)) |
| 217 | + return y, der, der2 |
| 218 | +end |
| 219 | + |
191 | 220 | ## Hessian |
192 | 221 |
|
193 | 222 | struct ForwardDiffHessianExtras{C1,C2,C3} <: HessianExtras |
|
0 commit comments