-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathforward_twoarg.jl
More file actions
116 lines (109 loc) · 3.69 KB
/
forward_twoarg.jl
File metadata and controls
116 lines (109 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
## Pushforward
# Preparation object for pushforwards of two-argument (in-place) functions `f!(y, x, ...)`
# with an Enzyme backend. It is built by `DI.prepare_pushforward_nokwarg` and consumed by
# the `pushforward` / `value_and_pushforward` methods in this file.
struct EnzymeTwoArgPushforwardPrep{SIG, DF, DC} <: DI.PushforwardPrep{SIG}
# Value returned by `DI.signature(f!, y, backend, x, tx, contexts...; strict)`.
_sig::Val{SIG}
# Value returned by `function_shadow(f!, backend, mode, Val(B))`; later handed to
# `get_f_and_df_prepared!` together with `f!`.
# NOTE(review): presumably shadow storage for data captured by `f!` — confirm in
# `function_shadow`.
df!::DF
# Value returned by `make_context_shadows(backend, mode, Val(B), contexts...)`; later
# handed to `translate_prepared!` together with the actual contexts.
context_shadows::DC
end
# Build the preparation object for a batch of `B` tangents `tx`.
# The shadows for `f!` and for the contexts are created once here, using the mode
# returned by `forward_noprimal(backend)`, and stored next to the call signature.
function DI.prepare_pushforward_nokwarg(
    strict::Val,
    f!::F,
    y,
    backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
    x,
    tx::NTuple{B},
    contexts::Vararg{DI.Context, C}
) where {F, B, C}
    sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
    fwd_mode = forward_noprimal(backend)
    width = Val(B)  # batch size lifted to the type domain
    f!_shadow = function_shadow(f!, backend, fwd_mode, width)
    ctx_shadows = make_context_shadows(backend, fwd_mode, width, contexts...)
    return EnzymeTwoArgPushforwardPrep(sig, f!_shadow, ctx_shadows)
end
# Single-tangent method (`tx` holds exactly one tangent): uses `Duplicated` annotations.
# Returns `y` together with a 1-tuple holding the freshly allocated output tangent.
function DI.value_and_pushforward(
    f!::F,
    y,
    prep::EnzymeTwoArgPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
    x,
    tx::NTuple{1},
    contexts::Vararg{DI.Context, C},
) where {F, C}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    fwd_mode = forward_noprimal(backend)
    annotated_f! = get_f_and_df_prepared!(prep.df!, f!, backend, Val(1))
    # The output tangent starts from a zero copy of `y` and is passed as the shadow of `y`.
    dy = make_zero(y)
    annotated_x = Duplicated(x, only(tx))
    annotated_y = Duplicated(y, dy)
    annotated_ctx = translate_prepared!(prep.context_shadows, contexts, Val(1))
    # `Const` return annotation: the return value of `f!` itself is not differentiated.
    autodiff(fwd_mode, annotated_f!, Const, annotated_y, annotated_x, annotated_ctx...)
    return y, (dy,)
end
# Batched method (`tx` holds `B` tangents): uses `BatchDuplicated` annotations.
# Returns `y` together with a `B`-tuple of freshly allocated output tangents.
function DI.value_and_pushforward(
    f!::F,
    y,
    prep::EnzymeTwoArgPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
    x,
    tx::NTuple{B},
    contexts::Vararg{DI.Context, C},
) where {F, B, C}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    width = Val(B)  # batch size lifted to the type domain
    fwd_mode = forward_noprimal(backend)
    annotated_f! = get_f_and_df_prepared!(prep.df!, f!, backend, width)
    # One zero copy of `y` per tangent direction, passed as the shadows of `y`.
    ty = ntuple(width) do _
        make_zero(y)
    end
    annotated_x = BatchDuplicated(x, tx)
    annotated_y = BatchDuplicated(y, ty)
    annotated_ctx = translate_prepared!(prep.context_shadows, contexts, width)
    # `Const` return annotation: the return value of `f!` itself is not differentiated.
    autodiff(fwd_mode, annotated_f!, Const, annotated_y, annotated_x, annotated_ctx...)
    return y, ty
end
# Pushforward without the primal: delegates to `DI.value_and_pushforward` and keeps only
# the tangent part of the `(y, ty)` pair it returns.
function DI.pushforward(
    f!::F,
    y,
    prep::EnzymeTwoArgPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context, C},
) where {F, C}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    y_and_ty = DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)
    return y_and_ty[2]
end
# In-place batched method: the caller-provided `ty` is used directly as the shadows of `y`.
# Unlike the allocating methods, `ty` is handed to Enzyme as-is (it is not zeroed here).
# Returns `y` and the same `ty` that was passed in.
function DI.value_and_pushforward!(
    f!::F,
    y,
    ty::NTuple{B},
    prep::EnzymeTwoArgPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
    x,
    tx::NTuple{B},
    contexts::Vararg{DI.Context, C},
) where {F, B, C}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    width = Val(B)  # batch size lifted to the type domain
    fwd_mode = forward_noprimal(backend)
    annotated_f! = get_f_and_df_prepared!(prep.df!, f!, backend, width)
    annotated_x = BatchDuplicated(x, tx)
    annotated_y = BatchDuplicated(y, ty)
    annotated_ctx = translate_prepared!(prep.context_shadows, contexts, width)
    # `Const` return annotation: the return value of `f!` itself is not differentiated.
    autodiff(fwd_mode, annotated_f!, Const, annotated_y, annotated_x, annotated_ctx...)
    return y, ty
end
# In-place pushforward without the primal: delegates to `DI.value_and_pushforward!`, whose
# return value is discarded, and returns the caller-provided `ty`.
function DI.pushforward!(
    f!::F,
    y,
    ty::NTuple,
    prep::EnzymeTwoArgPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode, Nothing}},
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context, C},
) where {F, C}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    DI.value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...)
    return ty
end