forked from JuliaDiff/DifferentiationInterface.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtraits.jl
More file actions
217 lines (158 loc) · 5.91 KB
/
traits.jl
File metadata and controls
217 lines (158 loc) · 5.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
## Mutation
# Supertype of the traits returned by `inplace_support`.
abstract type InPlaceBehavior end

"""
    InPlaceSupported

Trait marking backends that can differentiate in-place functions of the form `f!(y, x)`.
"""
struct InPlaceSupported <: InPlaceBehavior
end

"""
    InPlaceNotSupported

Trait marking backends that cannot differentiate in-place functions of the form `f!(y, x)`.
"""
struct InPlaceNotSupported <: InPlaceBehavior
end
"""
inplace_support(backend)
Return [`InPlaceSupported`](@ref) or [`InPlaceNotSupported`](@ref) in a statically predictable way.
"""
inplace_support(::AbstractADType) = InPlaceSupported()
inplace_support(::ADTypes.NoAutoDiff) = throw(ADTypes.NoAutoDiffSelectedError())
# A second-order backend supports `f!(y, x)` only when both of its layers do.
# The outer layer is queried only if the inner layer already supports it.
function inplace_support(backend::SecondOrder)
    inner_ok = inplace_support(inner(backend)) isa InPlaceSupported
    both_ok = inner_ok && inplace_support(outer(backend)) isa InPlaceSupported
    return both_ok ? InPlaceSupported() : InPlaceNotSupported()
end
# A sparse backend reports the in-place capability of the backend returned by `dense_ad`.
function inplace_support(backend::AutoSparse)
    return inplace_support(dense_ad(backend))
end
# A mixed-mode backend supports `f!(y, x)` only when both its forward and its reverse
# component do. The reverse component is queried only if the forward one already does.
function inplace_support(backend::MixedMode)
    forward_ok = Bool(inplace_support(forward_backend(backend)))
    if forward_ok && Bool(inplace_support(reverse_backend(backend)))
        return InPlaceSupported()
    end
    return InPlaceNotSupported()
end
## Pushforward
# Supertype of the traits returned by `pushforward_performance`.
abstract type PushforwardPerformance end

"""
    PushforwardFast

Trait marking backends for which pushforwards (JVPs) are efficient.
"""
struct PushforwardFast <: PushforwardPerformance
end

"""
    PushforwardSlow

Trait marking backends for which pushforwards (JVPs) are not efficient.
"""
struct PushforwardSlow <: PushforwardPerformance
end
"""
pushforward_performance(backend)
Return [`PushforwardFast`](@ref) or [`PushforwardSlow`](@ref) in a statically predictable way.
"""
pushforward_performance(backend::AbstractADType) = pushforward_performance(mode(backend))
pushforward_performance(::ForwardMode) = PushforwardFast()
pushforward_performance(::ForwardOrReverseMode) = PushforwardFast()
pushforward_performance(::ReverseMode) = PushforwardSlow()
pushforward_performance(::SymbolicMode) = PushforwardFast()
# Sparse and second-order wrapper backends are rejected explicitly instead of being
# routed through the generic `mode`-based method.
# Fix: the message previously had an unbalanced trailing backtick; `$backend` is now
# fully wrapped in backticks.
function pushforward_performance(backend::Union{AutoSparse,SecondOrder})
    throw(ArgumentError("Pushforward performance not defined for `$backend`."))
end
## Pullback
# Supertype of the traits returned by `pullback_performance`.
abstract type PullbackPerformance end

"""
    PullbackFast

Trait marking backends for which pullbacks (VJPs) are efficient.
"""
struct PullbackFast <: PullbackPerformance
end

"""
    PullbackSlow

Trait marking backends for which pullbacks (VJPs) are not efficient.
"""
struct PullbackSlow <: PullbackPerformance
end
"""
pullback_performance(backend)
Return [`PullbackFast`](@ref) or [`PullbackSlow`](@ref) in a statically predictable way.
"""
pullback_performance(backend::AbstractADType) = pullback_performance(mode(backend))
pullback_performance(::ForwardMode) = PullbackSlow()
pullback_performance(::ForwardOrReverseMode) = PullbackFast()
pullback_performance(::ReverseMode) = PullbackFast()
pullback_performance(::SymbolicMode) = PullbackFast()
# Sparse and second-order wrapper backends are rejected explicitly instead of being
# routed through the generic `mode`-based method.
# Fix: the message previously had an unbalanced trailing backtick; `$backend` is now
# fully wrapped in backticks.
function pullback_performance(backend::Union{AutoSparse,SecondOrder})
    throw(ArgumentError("Pullback performance not defined for `$backend`."))
end
## HVP
# Supertype of the traits returned by `hvp_mode`.
abstract type HVPMode end

"""
    ForwardOverReverse

Trait identifying second-order backends that compute HVPs in forward over reverse mode.
"""
struct ForwardOverReverse <: HVPMode end

"""
    ReverseOverForward

Trait identifying second-order backends that compute HVPs in reverse over forward mode.
"""
struct ReverseOverForward <: HVPMode end

"""
    ReverseOverReverse

Trait identifying second-order backends that compute HVPs in reverse over reverse mode.
"""
struct ReverseOverReverse <: HVPMode end

"""
    ForwardOverForward

Trait identifying second-order backends that compute HVPs in forward over forward mode (inefficient).
"""
struct ForwardOverForward <: HVPMode end

# The HVP modes whose outer differentiation layer runs in forward mode.
const ForwardOverAnything = Union{ForwardOverForward,ForwardOverReverse}
"""
hvp_mode(backend)
Return the best combination of modes for [`hvp`](@ref) and its variants, among the following options:
- [`ForwardOverForward`](@ref)
- [`ForwardOverReverse`](@ref)
- [`ReverseOverForward`](@ref)
- [`ReverseOverReverse`](@ref)
"""
hvp_mode(backend::AbstractADType) = hvp_mode(SecondOrder(backend, backend))
function hvp_mode(ba::SecondOrder)
if Bool(pushforward_performance(outer(ba))) && Bool(pullback_performance(inner(ba)))
return ForwardOverReverse()
elseif Bool(pullback_performance(outer(ba))) && Bool(pushforward_performance(inner(ba)))
return ReverseOverForward()
elseif Bool(pullback_performance(outer(ba))) && Bool(pullback_performance(inner(ba)))
return ReverseOverReverse()
else
return ForwardOverForward()
end
end
# Sparse backends are rejected explicitly instead of being wrapped into a `SecondOrder`
# by the generic method.
# Fix: the message previously had an unbalanced trailing backtick; `$backend` is now
# fully wrapped in backticks.
function hvp_mode(backend::AutoSparse)
    throw(ArgumentError("HVP mode not defined for `$backend`."))
end
## Inner prep
# Supertype of the traits returned by `inner_preparation_behavior`.
abstract type InnerPreparationBehavior end

"""
    PrepareInnerSimple

Trait marking outer backends whose inner backend, in second-order autodiff, should be prepared with the same input type.
"""
struct PrepareInnerSimple <: InnerPreparationBehavior
end

"""
    PrepareInnerOverload

Trait marking outer backends whose inner backend, in second-order autodiff, should be prepared with an overloaded input type.
"""
struct PrepareInnerOverload <: InnerPreparationBehavior
end

"""
    DontPrepareInner

Trait marking outer backends whose inner backend, in second-order autodiff, should not be prepared at all.
"""
struct DontPrepareInner <: InnerPreparationBehavior
end
"""
inner_preparation_behavior(backend::AbstractADType)
Return [`PrepareInnerSimple`](@ref), [`PrepareInnerOverload`](@ref) or [`DontPrepareInner`](@ref) in a statically predictable way.
"""
inner_preparation_behavior(::AbstractADType) = DontPrepareInner()
## Conversions
# Supported/fast traits convert to `true` and unsupported/slow ones to `false`,
# so trait values can be combined with `&&` in the dispatch helpers.
Base.Bool(::Union{InPlaceSupported,PushforwardFast,PullbackFast}) = true
Base.Bool(::Union{InPlaceNotSupported,PushforwardSlow,PullbackSlow}) = false