-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathforward_twoarg.jl
More file actions
104 lines (98 loc) · 2.95 KB
/
forward_twoarg.jl
File metadata and controls
104 lines (98 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
## Pushforward
# Preparation for forward-mode Enzyme pushforwards of an in-place `f!(y, x)`.
# No buffers are allocated up front: the returned `DI.NoPushforwardPrep` only records
# the call signature (built with the `strict` flag) so that later `DI.check_prep`
# calls can compare it against the arguments they receive.
function DI.prepare_pushforward_nokwarg(
    strict::Val,
    f!::F,
    y,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context,C};
) where {F,C}
    return DI.NoPushforwardPrep(DI.signature(f!, y, backend, x, tx, contexts...; strict))
end
# Single-tangent pushforward of the in-place function `f!(y, x)`.
# Uses Enzyme forward mode without primal return and plain `Duplicated` shadows (no
# batching). The tangent of `x` is the only element of `tx`. A freshly zeroed copy of
# `y` serves as the output shadow. `y` is passed as the primal output buffer and is
# returned as the value.
# Returns `(y, (dy,))`, where `dy` is that output shadow.
function DI.value_and_pushforward(
    f!::F,
    y,
    prep::DI.NoPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
    x,
    tx::NTuple{1},
    contexts::Vararg{DI.Context,C},
) where {F,C}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    fwd_mode = forward_noprimal(backend)
    annotated_f! = get_f_and_df(f!, backend, fwd_mode)
    y_tangent = make_zero(y)
    # Build the input annotation before the output one, as in the original ordering.
    x_dual = Duplicated(x, only(tx))
    autodiff(
        fwd_mode,
        annotated_f!,
        Const,
        Duplicated(y, y_tangent),
        x_dual,
        translate(backend, fwd_mode, Val(1), contexts...)...,
    )
    return y, (y_tangent,)
end
# Batched pushforward of the in-place function `f!(y, x)` for `B` tangents at once.
# All `B` input tangents in `tx` go through a single Enzyme forward-mode call
# (no primal return) using `BatchDuplicated` shadows. Each output tangent starts as
# its own freshly zeroed copy of `y`.
# Returns `(y, ty)`, where `ty` is the `B`-tuple of output shadows.
function DI.value_and_pushforward(
    f!::F,
    y,
    prep::DI.NoPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
    x,
    tx::NTuple{B},
    contexts::Vararg{DI.Context,C},
) where {F,B,C}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    fwd_mode = forward_noprimal(backend)
    width = Val(B)  # batch width lifted to the type domain
    batched_f! = get_f_and_df(f!, backend, fwd_mode, width)
    y_tangents = ntuple(_ -> make_zero(y), width)
    x_dual = BatchDuplicated(x, tx)
    autodiff(
        fwd_mode,
        batched_f!,
        Const,
        BatchDuplicated(y, y_tangents),
        x_dual,
        translate(backend, fwd_mode, width, contexts...)...,
    )
    return y, y_tangents
end
# Tangent-only pushforward of `f!(y, x)`.
# Checks `prep`, then delegates to `DI.value_and_pushforward` and keeps only the
# second element of the returned `(value, tangents)` pair.
function DI.pushforward(
    f!::F,
    y,
    prep::DI.NoPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context,C},
) where {F,C}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    return last(DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...))
end
# In-place batched pushforward of `f!(y, x)`.
# The caller-provided tuple `ty` serves directly as the `BatchDuplicated` output shadow,
# so Enzyme writes the output tangents into it. The method is not specialised on `B`,
# so this path also handles a single tangent.
# NOTE(review): `ty` is not zeroed here (unlike the out-of-place methods, which use
# `make_zero(y)`). If `f!` reads `y` before fully overwriting it, the initial contents
# of `ty` may feed into the result. Confirm the intended contract for `f!`.
# Returns `(y, ty)`.
function DI.value_and_pushforward!(
    f!::F,
    y,
    ty::NTuple{B},
    prep::DI.NoPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
    x,
    tx::NTuple{B},
    contexts::Vararg{DI.Context,C},
) where {F,B,C}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    fwd_mode = forward_noprimal(backend)
    width = Val(B)  # batch width lifted to the type domain
    batched_f! = get_f_and_df(f!, backend, fwd_mode, width)
    x_dual = BatchDuplicated(x, tx)
    autodiff(
        fwd_mode,
        batched_f!,
        Const,
        BatchDuplicated(y, ty),
        x_dual,
        translate(backend, fwd_mode, width, contexts...)...,
    )
    return y, ty
end
# In-place tangent-only pushforward of `f!(y, x)`.
# Checks `prep`, then lets `DI.value_and_pushforward!` fill `ty`. The delegate's return
# value is discarded, and the caller's `ty` is returned.
function DI.pushforward!(
    f!::F,
    y,
    ty::NTuple,
    prep::DI.NoPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context,C},
) where {F,C}
    DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
    _ = DI.value_and_pushforward!(f!, y, ty, prep, backend, x, tx, contexts...)
    return ty
end