@@ -87,56 +87,68 @@ for op in (:pushforward, :pullback, :hvp)
8787 HVPExtras
8888 end
8989 # 1-arg
90- @eval function $prep_op_same_point (f:: F , backend:: AbstractADType , x, seed) where {F}
90+ @eval function $prep_op_same_point (
91+ f:: F , backend:: AbstractADType , x, seed:: Tangents
92+ ) where {F}
9193 ex = $ prep_op (f, backend, x, seed)
9294 return $ prep_op_same_point (f, ex, backend, x, seed)
9395 end
9496 @eval function $prep_op_same_point (
95- f:: F , ex:: $E , backend:: AbstractADType , x, seed
97+ f:: F , ex:: $E , backend:: AbstractADType , x, seed:: Tangents
9698 ) where {F}
9799 return ex
98100 end
99- @eval function $op (f:: F , backend:: AbstractADType , x, seed) where {F}
101+ @eval function $op (f:: F , backend:: AbstractADType , x, seed:: Tangents ) where {F}
100102 ex = $ prep_op (f, backend, x, seed)
101103 return $ op (f, ex, backend, x, seed)
102104 end
103- @eval function $op! (f:: F , result, backend:: AbstractADType , x, seed) where {F}
105+ @eval function $op! (
106+ f:: F , result:: Tangents , backend:: AbstractADType , x, seed:: Tangents
107+ ) where {F}
104108 ex = $ prep_op (f, backend, x, seed)
105109 return $ op! (f, result, ex, backend, x, seed)
106110 end
107111 op == :hvp && continue
108- @eval function $val_and_op (f:: F , backend:: AbstractADType , x, seed) where {F}
112+ @eval function $val_and_op (f:: F , backend:: AbstractADType , x, seed:: Tangents ) where {F}
109113 ex = $ prep_op (f, backend, x, seed)
110114 return $ val_and_op (f, ex, backend, x, seed)
111115 end
112- @eval function $val_and_op! (f:: F , result, backend:: AbstractADType , x, seed) where {F}
116+ @eval function $val_and_op! (
117+ f:: F , result:: Tangents , backend:: AbstractADType , x, seed:: Tangents
118+ ) where {F}
113119 ex = $ prep_op (f, backend, x, seed)
114120 return $ val_and_op! (f, result, ex, backend, x, seed)
115121 end
116122 # 2-arg
117- @eval function $prep_op_same_point (f!:: F , y, backend:: AbstractADType , x, seed) where {F}
123+ @eval function $prep_op_same_point (
124+ f!:: F , y, backend:: AbstractADType , x, seed:: Tangents
125+ ) where {F}
118126 ex = $ prep_op (f!, y, backend, x, seed)
119127 return $ prep_op_same_point (f!, y, ex, backend, x, seed)
120128 end
121129 @eval function $prep_op_same_point (
122- f!:: F , y, ex:: $E , backend:: AbstractADType , x, seed
130+ f!:: F , y, ex:: $E , backend:: AbstractADType , x, seed:: Tangents
123131 ) where {F}
124132 return ex
125133 end
126- @eval function $op (f!:: F , y, backend:: AbstractADType , x, seed) where {F}
134+ @eval function $op (f!:: F , y, backend:: AbstractADType , x, seed:: Tangents ) where {F}
127135 ex = $ prep_op (f!, y, backend, x, seed)
128136 return $ op (f!, y, ex, backend, x, seed)
129137 end
130- @eval function $op! (f!:: F , y, result, backend:: AbstractADType , x, seed) where {F}
138+ @eval function $op! (
139+ f!:: F , y, result:: Tangents , backend:: AbstractADType , x, seed:: Tangents
140+ ) where {F}
131141 ex = $ prep_op (f!, y, backend, x, seed)
132142 return $ op! (f!, y, result, ex, backend, x, seed)
133143 end
134- @eval function $val_and_op (f!:: F , y, backend:: AbstractADType , x, seed) where {F}
144+ @eval function $val_and_op (
145+ f!:: F , y, backend:: AbstractADType , x, seed:: Tangents
146+ ) where {F}
135147 ex = $ prep_op (f!, y, backend, x, seed)
136148 return $ val_and_op (f!, y, ex, backend, x, seed)
137149 end
138150 @eval function $val_and_op! (
139- f!:: F , y, result, backend:: AbstractADType , x, seed
151+ f!:: F , y, result:: Tangents , backend:: AbstractADType , x, seed:: Tangents
140152 ) where {F}
141153 ex = $ prep_op (f!, y, backend, x, seed)
142154 return $ val_and_op! (f!, y, result, ex, backend, x, seed)
0 commit comments