forked from JuliaDiff/DifferentiationInterface.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmisc.jl
More file actions
36 lines (31 loc) · 1.23 KB
/
misc.jl
File metadata and controls
36 lines (31 loc) · 1.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
## Pushforward
# For pushforward preps (both the out-of-place and in-place variants), the
# overloaded input type is read off the prep's stored `xdual_tmp` field
# (presumably the dual-number container built for `x` — confirm in prep code).
function DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep)
    return typeof(prep.xdual_tmp)
end

function DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep)
    return typeof(prep.xdual_tmp)
end
# Out-of-place pushforward: produce the dual-number input that would be passed
# to `f`. The tag type is computed from `(f, backend, x)` via `tag_type`, then
# `make_dual` combines `x` with the tuple of tangents `tx`.
function DI.overloaded_input(
    ::typeof(DI.pushforward), f::F, backend::AutoForwardDiff, x, tx::NTuple{B}
) where {F, B}
    tag = tag_type(f, backend, x)
    return make_dual(tag, x, tx)
end
# In-place pushforward: same as the out-of-place method, but keyed on the
# mutating function `f!`. The output buffer `y` is not used here; it is part
# of the signature so this method is selected for the two-argument form.
function DI.overloaded_input(
    ::typeof(DI.pushforward), f!::F, y, backend::AutoForwardDiff, x, tx::NTuple{B}
) where {F, B}
    tag = tag_type(f!, backend, x)
    return make_dual(tag, x, tx)
end
## Derivative
# The out-of-place derivative prep carries a `pushforward_prep`; defer to the
# pushforward method for the overloaded input type.
DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep) =
    DI.overloaded_input_type(prep.pushforward_prep)
# In-place derivative prep: build a sample `Dual` tagged with the config's tag
# `T`, whose value and single partial are both `one(X)`, and return its type.
# An instance is constructed (rather than writing the type out by hand) so the
# `Dual` constructor decides the resulting type parameters.
function DI.overloaded_input_type(
    prep::ForwardDiffTwoArgDerivativePrep{SIG, X, <:DerivativeConfig{T}}
) where {SIG, X, T}
    unit = one(X)
    sample = Dual{T}(unit, unit)
    return typeof(sample)
end
## Gradient
function DI.overloaded_input_type(prep::ForwardDiffGradientPrep)
    # The gradient config's `duals` field determines the overloaded input type
    # (presumably the dual container for `x` — see ForwardDiff's `GradientConfig`).
    return typeof(prep.config.duals)
end
## Jacobian
# Out-of-place Jacobian prep: the config's `duals` field is used directly.
function DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep)
    return typeof(prep.config.duals)
end

# In-place Jacobian prep: `config.duals` is indexed and the type of its second
# entry is reported (presumably the `x` container, with the first holding `y`
# — see ForwardDiff's two-argument `JacobianConfig`).
function DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep)
    xduals = prep.config.duals[2]
    return typeof(xduals)
end