-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathmisc.jl
More file actions
32 lines (27 loc) · 1.13 KB
/
misc.jl
File metadata and controls
32 lines (27 loc) · 1.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
## Pushforward
# Pushforward preparations for both the one-argument (`f(x)`) and the two-argument
# (`f!(y, x)`) forms keep their dual input in the `xdual_tmp` field; its type is
# reported as the overloaded input type.
function DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep)
    dual_input = prep.xdual_tmp
    return typeof(dual_input)
end
function DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep)
    dual_input = prep.xdual_tmp
    return typeof(dual_input)
end
# Overloaded input for a ForwardDiff pushforward of the one-argument function `f` at `x`.
# The tag type is derived from `(f, backend, x)` with `tag_type`. `make_dual` then
# combines that tag with `x` and the `B` tangents in `tx`.
function DI.overloaded_input(
    ::typeof(DI.pushforward), f::F, backend::AutoForwardDiff, x, tx::NTuple{B}
) where {F,B}
    return make_dual(tag_type(f, backend, x), x, tx)
end
# Overloaded input for a ForwardDiff pushforward of the in-place function `f!(y, x)`.
# `y` is only part of the signature so that this method is selected for the two-argument
# form; it is not used. The tag is derived from `(f!, backend, x)` and handed to
# `make_dual` together with `x` and the tangents `tx`.
function DI.overloaded_input(
    ::typeof(DI.pushforward), f!::F, y, backend::AutoForwardDiff, x, tx::NTuple{B}
) where {F,B}
    return make_dual(tag_type(f!, backend, x), x, tx)
end
## Derivative
# One-argument derivative: the preparation wraps a pushforward preparation in
# `pushforward_prep`, so the answer is delegated to that inner object.
function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep)
    inner_prep = prep.pushforward_prep
    return DI.overloaded_input_type(inner_prep)
end
# Two-argument derivative: report the type of `prep.config.duals`.
# NOTE(review): if `config` is a ForwardDiff `DerivativeConfig` built for `f!(y, x)`,
# `duals` may be the dual buffer for the output `y` rather than the (scalar) dual that
# is passed as `x`. Confirm that this is the intended "overloaded input" type.
function DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep)
    config_duals = prep.config.duals
    return typeof(config_duals)
end
## Gradient
# Gradient preparation: the overloaded input type is the type of the dual buffer
# stored in `prep.config.duals`.
function DI.overloaded_input_type(prep::ForwardDiffGradientPrep)
    dual_buffer = prep.config.duals
    return typeof(dual_buffer)
end
## Jacobian
# One-argument Jacobian `f(x)`: a ForwardDiff `JacobianConfig` built for `f(x)` stores the
# dual input array directly in `config.duals`, the same layout as `GradientConfig`. Its
# type is therefore the overloaded input type. Indexing it with `[2]` would instead give
# the type of a single dual number, and would throw a `BoundsError` when `length(x) < 2`.
function DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep)
    return typeof(prep.config.duals)
end
# Two-argument Jacobian `f!(y, x)`: `config.duals` holds the pair `(ydual, xdual)`,
# so the dual input buffer is the second entry.
function DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep)
    return typeof(prep.config.duals[2])
end