-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathmisc.jl
More file actions
58 lines (50 loc) · 1.39 KB
/
misc.jl
File metadata and controls
58 lines (50 loc) · 1.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
## Pushforward
# Dual-number input that `f` receives during a ForwardDiff pushforward.
# The tag type comes from `tag_type(f, backend, x)`; the dual value itself is built by
# `make_dual` from the primal `x` and the tangent tuple `tx`.
# `contexts` is only used for dispatch here and is not read.
function DI.overloaded_input(
    ::typeof(DI.pushforward),
    f::F,
    backend::AutoForwardDiff,
    x,
    tx::NTuple{B},
    contexts::Vararg{DI.Context,C},
) where {F,B,C}
    tag_T = tag_type(f, backend, x)
    return make_dual(tag_T, x, tx)
end
#=
function DI.overloaded_input(
::typeof(DI.pushforward),
f!::F,
y,
backend::AutoForwardDiff,
x,
tx::NTuple{B},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
T = tag_type(f, backend, x)
xdual = if x isa Number
make_dual(T, x, tx)
else
make_dual_similar(T, x, tx)
end
return xdual
end
=#
# Type of the cached dual input buffer (`xdual_tmp`) in a one-argument pushforward prep.
function DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep)
    xdual = prep.xdual_tmp
    return typeof(xdual)
end
# Type of the cached dual input buffer (`xdual_tmp`) in a two-argument pushforward prep.
function DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep)
    xdual = prep.xdual_tmp
    return typeof(xdual)
end
## Derivative
# A one-argument derivative prep wraps a pushforward prep; report that prep's
# overloaded input type.
DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep) =
    DI.overloaded_input_type(prep.pushforward_prep)
# Type of the dual buffer stored in the two-argument derivative config.
# NOTE(review): ForwardDiff's two-argument DerivativeConfig appears to size this buffer
# like the output `y` rather than the scalar input `x` — confirm this is the intended type.
DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) = typeof(prep.config.duals)
## Gradient
# Type of the dual input buffer held by the gradient config.
function DI.overloaded_input_type(prep::ForwardDiffGradientPrep)
    config = prep.config
    return typeof(config.duals)
end
## Jacobian
# Type of the dual input buffer held by a one-argument Jacobian config.
# ForwardDiff's one-argument `JacobianConfig` stores the input duals array itself in
# `duals`, with the same layout as `GradientConfig` (see the `ForwardDiffGradientPrep` method).
# Only the two-argument config wraps its buffers in a `(ydual, xdual)` tuple.
# Indexing `duals[2]` here would return the type of a single `Dual` element instead of
# the array type, and would throw a `BoundsError` when `length(x) < 2`.
function DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep)
    return typeof(prep.config.duals)
end
# Type of the input-side dual buffer of a two-argument Jacobian config.
# `config.duals` is a pair; the second entry is the one reported (presumably the
# `x` duals, with `y` duals first — confirm against ForwardDiff's JacobianConfig).
function DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep)
    _, xduals = prep.config.duals
    return typeof(xduals)
end