1+ const AnyDuplicated = Union{
2+ Duplicated,
3+ MixedDuplicated,
4+ BatchDuplicated,
5+ BatchMixedDuplicated,
6+ DuplicatedNoNeed,
7+ BatchDuplicatedNoNeed,
8+ }
9+
110# until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
211function DI. pick_batchsize (:: AutoEnzyme , N:: Integer )
312 B = DI. reasonable_batchsize (N, 16 )
@@ -8,26 +17,17 @@ to_val(::DI.BatchSizeSettings{B}) where {B} = Val(B)
817
918# # Annotations
1019
11- const AnyDuplicated = Union{
12- Duplicated,
13- MixedDuplicated,
14- BatchDuplicated,
15- BatchMixedDuplicated,
16- DuplicatedNoNeed,
17- BatchDuplicatedNoNeed,
18- }
19-
20- function get_f_and_df (f:: F , :: AutoEnzyme{M,Nothing} , :: Val{B} ) where {F,M,B}
20+ function get_f_and_df_prepared! (_df, f:: F , :: AutoEnzyme{M,Nothing} , :: Val{B} ) where {F,M,B}
2121 return f
2222end
2323
24- function get_f_and_df ( f:: F , :: AutoEnzyme{M,<:Const} , :: Val{B} ) where {F,M,B}
24+ function get_f_and_df_prepared! (_df, f:: F , :: AutoEnzyme{M,<:Const} , :: Val{B} ) where {F,M,B}
2525 return Const (f)
2626end
2727
28- function get_f_and_df (f :: F , backend :: AutoEnzyme{M,<:AnyDuplicated} , :: Val{B} ) where {F,M,B}
29- # TODO : needs more sophistication for mixed activities
30- df = function_shadow (f, backend, Val (B))
28+ function get_f_and_df_prepared! (
29+ df, f :: F , :: AutoEnzyme{M,<:AnyDuplicated} , :: Val{B}
30+ ) where {F,M,B}
3131 if B == 1
3232 return Duplicated (f, df)
3333 else
@@ -49,71 +49,9 @@ function function_shadow(f::F, ::AutoEnzyme{M,<:AnyDuplicated}, ::Val{B}) where
4949 end
5050end
5151
52- function get_f_and_df_prepared! (_df, f:: F , :: AutoEnzyme{M,Nothing} , :: Val{B} ) where {F,M,B}
53- return f
54- end
55-
56- function get_f_and_df_prepared! (_df, f:: F , :: AutoEnzyme{M,<:Const} , :: Val{B} ) where {F,M,B}
57- return Const (f)
58- end
59-
60- function get_f_and_df_prepared! (
61- df, f:: F , :: AutoEnzyme{M,<:AnyDuplicated} , :: Val{B}
62- ) where {F,M,B}
63- if B == 1
64- return Duplicated (f, df)
65- else
66- return BatchDuplicated (f, df)
67- end
68- end
69-
7052force_annotation (f:: F ) where {F<: Annotation } = f
7153force_annotation (f:: F ) where {F} = Const (f)
7254
73- function _translate (:: AutoEnzyme , :: Mode , :: Val{B} , c_wrapped:: DI.Constant ) where {B}
74- c = DI. unwrap (c_wrapped)
75- return Const (c)
76- end
77-
78- function _translate (:: AutoEnzyme , :: Mode , :: Val{B} , c_wrapped:: DI.Cache ) where {B}
79- c = DI. unwrap (c_wrapped)
80- if B == 1
81- dc = make_zero (c)
82- return Duplicated (c, dc)
83- else
84- dc = ntuple (_ -> make_zero (c), Val (B))
85- return BatchDuplicated (c, dc)
86- end
87- end
88-
89- function _translate (
90- backend:: AutoEnzyme , mode:: Mode , :: Val{B} , c_wrapped:: DI.ConstantOrCache
91- ) where {B}
92- c = DI. unwrap (c_wrapped)
93- IA = guess_activity (typeof (c), mode)
94- if IA <: Const
95- return _translate (backend, mode, Val (B), DI. Constant (c))
96- else
97- return _translate (backend, mode, Val (B), DI. Cache (c))
98- end
99- end
100-
101- function _translate (
102- backend:: AutoEnzyme , :: Mode , :: Val{B} , c_wrapped:: DI.FunctionContext
103- ) where {B}
104- f = DI. unwrap (c_wrapped)
105- return force_annotation (get_f_and_df (f, backend, Val (B)))
106- end
107-
108- function translate (
109- backend:: AutoEnzyme , mode:: Mode , :: Val{B} , contexts:: Vararg{DI.Context,C}
110- ) where {B,C}
111- new_contexts = map (contexts) do c_wrapped
112- _translate (backend, mode, Val (B), c_wrapped)
113- end
114- return new_contexts
115- end
116-
11755function _shadow (:: AutoEnzyme , :: Mode , :: Val{B} , c_wrapped:: DI.Constant ) where {B}
11856 return nothing
11957end
@@ -128,14 +66,18 @@ function _shadow(::AutoEnzyme, ::Mode, ::Val{B}, c_wrapped::DI.Cache) where {B}
12866end
12967
13068function _shadow (
131- backend :: AutoEnzyme , mode:: Mode , valB:: Val{B} , c_wrapped:: DI.ConstantOrCache
69+ :: AutoEnzyme , mode:: Mode , valB:: Val{B} , c_wrapped:: DI.ConstantOrCache
13270) where {B}
13371 c = DI. unwrap (c_wrapped)
13472 IA = guess_activity (typeof (c), mode)
13573 if IA <: Const
136- return _shadow (backend, mode, valB, DI . Constant (c))
74+ nothing
13775 else
138- return _shadow (backend, mode, valB, DI. Cache (c))
76+ if B == 1
77+ return make_zero (c)
78+ else
79+ return ntuple (_ -> make_zero (c), Val (B))
80+ end
13981 end
14082end
14183
@@ -149,7 +91,7 @@ function _shadow(
14991 return function_shadow (f, backend, Val (B))
15092end
15193
152- function shadows (
94+ function make_context_shadows (
15395 backend:: AutoEnzyme , mode:: Mode , :: Val{B} , contexts:: Vararg{DI.Context,C}
15496) where {B,C}
15597 context_shadows = map (contexts) do c_wrapped
0 commit comments