-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathutils.jl
More file actions
137 lines (113 loc) · 3.98 KB
/
utils.jl
File metadata and controls
137 lines (113 loc) · 3.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
# Batch size selection for `AutoEnzyme`: ask `DI.reasonable_batchsize` for a size based
# on `N` and the constant 16 (presumably an upper bound — confirm against DI), then
# store that size as the type parameter of the returned `DI.BatchSizeSettings`.
function DI.pick_batchsize(::AutoEnzyme, N::Integer)
    batchsize = DI.reasonable_batchsize(N, 16)
    settings = DI.BatchSizeSettings{batchsize}(N)
    return settings
end
# Lift the type parameter `B` of a `DI.BatchSizeSettings{B}` into a `Val(B)`.
function to_val(::DI.BatchSizeSettings{B}) where {B}
    return Val(B)
end
## Annotations
# Backend whose second type parameter is `Nothing`: return `f` itself, unwrapped.
# The backend, mode and batch size arguments are only used for dispatch.
@inline get_f_and_df(
    f::F, backend::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1)
) where {F,M,B} = f
# Backend whose second type parameter is a `Const` annotation: wrap `f` in `Const`.
# The backend, mode and batch size arguments are only used for dispatch.
@inline get_f_and_df(
    f::F, backend::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1)
) where {F,M,B} = Const(f)
# Backend whose second type parameter is one of the duplicated annotation flavours
# listed in the `Union` below: pair `f` with shadows built by `make_zero`.
# A batch size of 1 yields `Duplicated(f, shadow)`; any other `B` yields
# `BatchDuplicated(f, shadows)` with a `B`-tuple of independently created shadows.
# Whichever flavour was requested, the result is always one of those two wrappers.
@inline function get_f_and_df(
    f::F,
    backend::AutoEnzyme{
        M,
        <:Union{
            Duplicated,
            MixedDuplicated,
            BatchDuplicated,
            BatchMixedDuplicated,
            DuplicatedNoNeed,
            BatchDuplicatedNoNeed,
        },
    },
    mode::Mode,
    ::Val{B}=Val(1),
) where {F,M,B}
    # TODO: needs more sophistication for mixed activities
    B == 1 && return Duplicated(f, make_zero(f))
    shadows = ntuple(Val(B)) do _
        make_zero(f)
    end
    return BatchDuplicated(f, shadows)
end
# Make sure the result is an `Annotation`: values that already are one pass through
# unchanged, anything else is wrapped in `Const`.
function force_annotation(f::F) where {F<:Annotation}
    return f
end
function force_annotation(f::F) where {F}
    return Const(f)
end
# A `DI.GeneralizedConstant` context becomes a `Const` around its unwrapped value;
# mode and batch size are ignored.
@inline function _translate(
    backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedConstant
) where {B}
    value = DI.unwrap(c)
    return Const(value)
end
# A `DI.Cache` context becomes a duplicated annotation around its unwrapped value:
# `Duplicated` with one `make_zero` shadow when `B == 1`, otherwise `BatchDuplicated`
# with a `B`-tuple of `make_zero` shadows.
@inline function _translate(
    backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache
) where {B}
    # important to keep make_zero here for ConstantOrCache instead of similar
    x = DI.unwrap(c)
    B == 1 && return Duplicated(x, make_zero(x))
    return BatchDuplicated(x, ntuple(_ -> make_zero(x), Val(B)))
end
# A `DI.ConstantOrCache` context is resolved through `guess_activity` on the type of
# the unwrapped value: if the guessed activity is a subtype of `Const`, the value is
# translated as a `DI.Constant`, otherwise as a `DI.Cache`.
@inline function _translate(
    backend::AutoEnzyme, mode::Mode, valB::Val{B}, c::DI.ConstantOrCache
) where {B}
    x = DI.unwrap(c)
    activity = guess_activity(typeof(x), mode)
    activity <: Const && return _translate(backend, mode, valB, DI.Constant(x))
    return _translate(backend, mode, valB, DI.Cache(x))
end
# A `DI.FunctionContext` holds a function: annotate it with `get_f_and_df` according to
# the backend, then pass the result through `force_annotation` so that an `Annotation`
# is always returned.
@inline function _translate(
    backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext
) where {B}
    func = DI.unwrap(c)
    maybe_annotated = get_f_and_df(func, backend, mode, Val(B))
    return force_annotation(maybe_annotated)
end
# Translate every context in `contexts` with `_translate`, using the same backend,
# mode and batch size for each. Returns the mapped collection of translated contexts.
@inline function translate(
    backend::AutoEnzyme, mode::Mode, ::Val{B}, contexts::Vararg{DI.Context,C}
) where {B,C}
    return map(c -> _translate(backend, mode, Val(B), c), contexts)
end
## Modes
# Forward mode, no primal: apply `NoPrimal` to the backend's own mode when it carries a
# `ForwardMode`; fall back to `Forward` when the backend has no mode (`Nothing`).
function forward_noprimal(backend::AutoEnzyme{<:ForwardMode})
    return NoPrimal(backend.mode)
end
function forward_noprimal(::AutoEnzyme{Nothing})
    return Forward
end
# Forward mode, with primal: apply `WithPrimal` to the backend's own mode when it
# carries a `ForwardMode`; fall back to `ForwardWithPrimal` when the backend has no
# mode (`Nothing`).
function forward_withprimal(backend::AutoEnzyme{<:ForwardMode})
    return WithPrimal(backend.mode)
end
function forward_withprimal(::AutoEnzyme{Nothing})
    return ForwardWithPrimal
end
# Reverse mode, no primal: apply `NoPrimal` to the backend's own mode when it carries a
# `ReverseMode`; fall back to `Reverse` when the backend has no mode (`Nothing`).
function reverse_noprimal(backend::AutoEnzyme{<:ReverseMode})
    return NoPrimal(backend.mode)
end
function reverse_noprimal(::AutoEnzyme{Nothing})
    return Reverse
end
# Reverse mode, with primal: apply `WithPrimal` to the backend's own mode when it
# carries a `ReverseMode`; fall back to `ReverseWithPrimal` when the backend has no
# mode (`Nothing`).
function reverse_withprimal(backend::AutoEnzyme{<:ReverseMode})
    return WithPrimal(backend.mode)
end
function reverse_withprimal(::AutoEnzyme{Nothing})
    return ReverseWithPrimal
end
# Split reverse mode, with primal. With a `ReverseMode` backend the backend's own mode
# goes through `Split` and then `WithPrimal`; with no mode (`Nothing`) the default
# `ReverseSplitWithPrimal` is used. Either way the mode is finally passed to `set_err`
# together with the backend.
function reverse_split_withprimal(backend::AutoEnzyme{<:ReverseMode})
    split_mode = WithPrimal(Split(backend.mode))
    return set_err(split_mode, backend)
end
function reverse_split_withprimal(backend::AutoEnzyme{Nothing})
    default_mode = ReverseSplitWithPrimal
    return set_err(default_mode, backend)
end
# Adjust `mode` depending on the backend's function annotation (second type parameter):
# with no annotation (`Nothing`) the mode goes through
# `EnzymeCore.set_err_if_func_written`; with an explicit `Annotation` it is returned
# unchanged.
set_err(mode::Mode, backend::AutoEnzyme{<:Any,Nothing}) =
    EnzymeCore.set_err_if_func_written(mode)
function set_err(mode::Mode, backend::AutoEnzyme{<:Any,<:Annotation})
    return mode
end
# Return `A` viewed as an `m`-by-`n` array.
# A matrix is returned as is, after asserting that it already has size `(m, n)`.
function maybe_reshape(A::AbstractMatrix, m, n)
    @assert size(A) == (m, n)
    return A
end
# Any other array is reshaped to `(m, n)` with `reshape`.
maybe_reshape(A::AbstractArray, m, n) = reshape(A, m, n)
# Build the annotation named by the first (type) argument around `x`:
# - `Active{T}`: `Active(x)`; the derivative argument `dx` is ignored.
# - `Duplicated{T}`: `Duplicated(x, dx)`.
# - `BatchDuplicated{T,B}`: `BatchDuplicated(x, tx)`, where `tx` must be a `B`-tuple.
function annotate(::Type{Active{T}}, x, dx) where {T}
    return Active(x)
end
function annotate(::Type{Duplicated{T}}, x, dx) where {T}
    return Duplicated(x, dx)
end
annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B} =
    BatchDuplicated(x, tx)
# Map an activity type to its counterpart for batch size `B`:
# `Active{T}` stays `Active{T}`, `Duplicated{T}` becomes `BatchDuplicated{T,B}`.
function batchify_activity(::Type{Active{T}}, ::Val{B}) where {T,B}
    return Active{T}
end
function batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T,B}
    return BatchDuplicated{T,B}
end