@@ -15,22 +15,27 @@ function DI.prepare_pushforward(f!, y, ::AutoForwardDiff, x, dx)
1515end
1616
1717function compute_ydual_twoarg (
18- f!, y, x, dx, extras:: ForwardDiffTwoArgPushforwardExtras{T}
18+ :: Type{T} , f!, y, x:: Number , dx, extras:: ForwardDiffTwoArgPushforwardExtras{T}
19+ ) where {T}
20+ (; ydual_tmp) = extras
21+ xdual_tmp = make_dual (T, x, dx)
22+ f! (ydual_tmp, xdual_tmp)
23+ return ydual_tmp
24+ end
25+
26+ function compute_ydual_twoarg (
27+ :: Type{T} , f!, y, x, dx, extras:: ForwardDiffTwoArgPushforwardExtras{T}
1928) where {T}
2029 (; xdual_tmp, ydual_tmp) = extras
21- xdual_tmp = if x isa Number
22- make_dual (T, x, dx)
23- else
24- make_dual! (T, xdual_tmp, x, dx)
25- end
30+ make_dual! (T, xdual_tmp, x, dx)
2631 f! (ydual_tmp, xdual_tmp)
2732 return ydual_tmp
2833end
2934
3035function DI. value_and_pushforward (
3136 f!, y, :: AutoForwardDiff , x, dx, extras:: ForwardDiffTwoArgPushforwardExtras{T}
3237) where {T}
33- ydual_tmp = compute_ydual_twoarg (f!, y, x, dx, extras)
38+ ydual_tmp = compute_ydual_twoarg (T, f!, y, x, dx, extras)
3439 myvalue! (T, y, ydual_tmp)
3540 dy = myderivative (T, ydual_tmp)
3641 return y, dy
3944function DI. pushforward (
4045 f!, y, :: AutoForwardDiff , x, dx, extras:: ForwardDiffTwoArgPushforwardExtras{T}
4146) where {T}
42- ydual_tmp = compute_ydual_twoarg (f!, y, x, dx, extras)
47+ ydual_tmp = compute_ydual_twoarg (T, f!, y, x, dx, extras)
4348 dy = myderivative (T, ydual_tmp)
4449 return dy
4550end
4651
4752function DI. value_and_pushforward! (
4853 f!, y, dy, :: AutoForwardDiff , x, dx, extras:: ForwardDiffTwoArgPushforwardExtras{T}
4954) where {T}
50- ydual_tmp = compute_ydual_twoarg (f!, y, x, dx, extras)
55+ ydual_tmp = compute_ydual_twoarg (T, f!, y, x, dx, extras)
5156 myvalue! (T, y, ydual_tmp)
5257 myderivative! (T, dy, ydual_tmp)
5358 return y, dy
5661function DI. pushforward! (
5762 f!, y, dy, :: AutoForwardDiff , x, dx, extras:: ForwardDiffTwoArgPushforwardExtras{T}
5863) where {T}
59- ydual_tmp = compute_ydual_twoarg (f!, y, x, dx, extras)
64+ ydual_tmp = compute_ydual_twoarg (T, f!, y, x, dx, extras)
6065 myderivative! (T, dy, ydual_tmp)
6166 return dy
6267end
0 commit comments