Skip to content

Commit 1155218

Browse files
authored
Better type annotations in fallbacks (#458)
1 parent ccf5247 commit 1155218

1 file changed

Lines changed: 24 additions & 12 deletions

File tree

DifferentiationInterface/src/fallbacks/no_extras.jl

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,56 +87,68 @@ for op in (:pushforward, :pullback, :hvp)
8787
HVPExtras
8888
end
8989
# 1-arg
90-
@eval function $prep_op_same_point(f::F, backend::AbstractADType, x, seed) where {F}
90+
@eval function $prep_op_same_point(
91+
f::F, backend::AbstractADType, x, seed::Tangents
92+
) where {F}
9193
ex = $prep_op(f, backend, x, seed)
9294
return $prep_op_same_point(f, ex, backend, x, seed)
9395
end
9496
@eval function $prep_op_same_point(
95-
f::F, ex::$E, backend::AbstractADType, x, seed
97+
f::F, ex::$E, backend::AbstractADType, x, seed::Tangents
9698
) where {F}
9799
return ex
98100
end
99-
@eval function $op(f::F, backend::AbstractADType, x, seed) where {F}
101+
@eval function $op(f::F, backend::AbstractADType, x, seed::Tangents) where {F}
100102
ex = $prep_op(f, backend, x, seed)
101103
return $op(f, ex, backend, x, seed)
102104
end
103-
@eval function $op!(f::F, result, backend::AbstractADType, x, seed) where {F}
105+
@eval function $op!(
106+
f::F, result::Tangents, backend::AbstractADType, x, seed::Tangents
107+
) where {F}
104108
ex = $prep_op(f, backend, x, seed)
105109
return $op!(f, result, ex, backend, x, seed)
106110
end
107111
op == :hvp && continue
108-
@eval function $val_and_op(f::F, backend::AbstractADType, x, seed) where {F}
112+
@eval function $val_and_op(f::F, backend::AbstractADType, x, seed::Tangents) where {F}
109113
ex = $prep_op(f, backend, x, seed)
110114
return $val_and_op(f, ex, backend, x, seed)
111115
end
112-
@eval function $val_and_op!(f::F, result, backend::AbstractADType, x, seed) where {F}
116+
@eval function $val_and_op!(
117+
f::F, result::Tangents, backend::AbstractADType, x, seed::Tangents
118+
) where {F}
113119
ex = $prep_op(f, backend, x, seed)
114120
return $val_and_op!(f, result, ex, backend, x, seed)
115121
end
116122
# 2-arg
117-
@eval function $prep_op_same_point(f!::F, y, backend::AbstractADType, x, seed) where {F}
123+
@eval function $prep_op_same_point(
124+
f!::F, y, backend::AbstractADType, x, seed::Tangents
125+
) where {F}
118126
ex = $prep_op(f!, y, backend, x, seed)
119127
return $prep_op_same_point(f!, y, ex, backend, x, seed)
120128
end
121129
@eval function $prep_op_same_point(
122-
f!::F, y, ex::$E, backend::AbstractADType, x, seed
130+
f!::F, y, ex::$E, backend::AbstractADType, x, seed::Tangents
123131
) where {F}
124132
return ex
125133
end
126-
@eval function $op(f!::F, y, backend::AbstractADType, x, seed) where {F}
134+
@eval function $op(f!::F, y, backend::AbstractADType, x, seed::Tangents) where {F}
127135
ex = $prep_op(f!, y, backend, x, seed)
128136
return $op(f!, y, ex, backend, x, seed)
129137
end
130-
@eval function $op!(f!::F, y, result, backend::AbstractADType, x, seed) where {F}
138+
@eval function $op!(
139+
f!::F, y, result::Tangents, backend::AbstractADType, x, seed::Tangents
140+
) where {F}
131141
ex = $prep_op(f!, y, backend, x, seed)
132142
return $op!(f!, y, result, ex, backend, x, seed)
133143
end
134-
@eval function $val_and_op(f!::F, y, backend::AbstractADType, x, seed) where {F}
144+
@eval function $val_and_op(
145+
f!::F, y, backend::AbstractADType, x, seed::Tangents
146+
) where {F}
135147
ex = $prep_op(f!, y, backend, x, seed)
136148
return $val_and_op(f!, y, ex, backend, x, seed)
137149
end
138150
@eval function $val_and_op!(
139-
f!::F, y, result, backend::AbstractADType, x, seed
151+
f!::F, y, result::Tangents, backend::AbstractADType, x, seed::Tangents
140152
) where {F}
141153
ex = $prep_op(f!, y, backend, x, seed)
142154
return $val_and_op!(f!, y, result, ex, backend, x, seed)

0 commit comments

Comments
 (0)