-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathgradient.jl
More file actions
134 lines (107 loc) · 3.16 KB
/
gradient.jl
File metadata and controls
134 lines (107 loc) · 3.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
## Docstrings
"""
    prepare_gradient(f, backend, x, [contexts...]) -> prep

$(docstring_prepare("gradient"))
"""
function prepare_gradient end
"""
    prepare!_gradient(f, prep, backend, x, [contexts...]) -> new_prep

$(docstring_prepare!("gradient"))
"""
function prepare!_gradient end
"""
    value_and_gradient(f, [prep,] backend, x, [contexts...]) -> (y, grad)

Compute the value and the gradient of the function `f` at point `x`.

$(docstring_preparation_hint("gradient"))
"""
function value_and_gradient end
"""
    value_and_gradient!(f, grad, [prep,] backend, x, [contexts...]) -> (y, grad)

Compute the value and the gradient of the function `f` at point `x`, overwriting `grad`.

$(docstring_preparation_hint("gradient"))
"""
function value_and_gradient! end
"""
    gradient(f, [prep,] backend, x, [contexts...]) -> grad

Compute the gradient of the function `f` at point `x`.

$(docstring_preparation_hint("gradient"))
"""
function gradient end
"""
    gradient!(f, grad, [prep,] backend, x, [contexts...]) -> grad

Compute the gradient of the function `f` at point `x`, overwriting `grad`.

$(docstring_preparation_hint("gradient"))
"""
function gradient! end
## Preparation
# Gradient preparation object that wraps the preparation of an underlying pullback.
#
# `Y` is a phantom type parameter (no field has this type): `prepare_gradient` sets it
# to the type of the function output, and the gradient operators use it to build the
# pullback seed `one(Y)`.
struct PullbackGradientPrep{Y,P<:PullbackPrep} <: GradientPrep
    # preparation result forwarded to the pullback operators
    pullback_prep::P
end
# Gradient preparation expressed through a pullback.
#
# - `f` is called once on `x` and the unwrapped `contexts`; the type of its output is
#   stored as the first type parameter of the returned `PullbackGradientPrep`.
# - The pullback is prepared with the seed `(one(Y),)`, where `Y` is that output type.
#   This is the same seed the gradient operators of this file pass at execution time,
#   so preparation and execution agree on the seed type (a `Bool` seed such as `true`
#   would only match when the output itself is a `Bool`).
function prepare_gradient(
    f::F, backend::AbstractADType, x, contexts::Vararg{Context,C}
) where {F,C}
    y = f(x, map(unwrap, contexts)...)  # TODO: replace with output type inference?
    Y = typeof(y)
    pullback_prep = prepare_pullback(f, backend, x, (one(Y),), contexts...)
    return PullbackGradientPrep{Y,typeof(pullback_prep)}(pullback_prep)
end
## One argument
# Value and gradient obtained from a single `value_and_pullback` call, seeded with the
# multiplicative identity of the output type `Y` recorded in `prep`.
function value_and_gradient(
    f::F,
    prep::PullbackGradientPrep{Y},
    backend::AbstractADType,
    x,
    contexts::Vararg{Context,C},
) where {F,Y,C}
    seed = (one(Y),)
    val, grads = value_and_pullback(
        f, prep.pullback_prep, backend, x, seed, contexts...
    )
    # one seed in, exactly one result expected out
    return val, only(grads)
end
# Variant of `value_and_gradient` that takes a `grad` buffer: `grad` is handed to
# `value_and_pullback!` as a one-element tuple and is itself returned next to the value.
function value_and_gradient!(
    f::F,
    grad,
    prep::PullbackGradientPrep{Y},
    backend::AbstractADType,
    x,
    contexts::Vararg{Context,C},
) where {F,Y,C}
    seed = (one(Y),)
    buffers = (grad,)
    # the second returned element is discarded: `grad` is returned instead
    val, _ = value_and_pullback!(
        f, buffers, prep.pullback_prep, backend, x, seed, contexts...
    )
    return val, grad
end
# Gradient obtained as the single result of a `pullback` call seeded with `one(Y)`,
# where `Y` is the output type recorded in `prep`.
function gradient(
    f::F,
    prep::PullbackGradientPrep{Y},
    backend::AbstractADType,
    x,
    contexts::Vararg{Context,C},
) where {F,Y,C}
    seed = (one(Y),)
    return only(pullback(f, prep.pullback_prep, backend, x, seed, contexts...))
end
# Variant of `gradient` that takes a `grad` buffer: `grad` is passed to `pullback!` as a
# one-element tuple, the return value of `pullback!` is ignored, and `grad` is returned.
function gradient!(
    f::F,
    grad,
    prep::PullbackGradientPrep{Y},
    backend::AbstractADType,
    x,
    contexts::Vararg{Context,C},
) where {F,Y,C}
    seed = (one(Y),)
    buffers = (grad,)
    pullback!(f, buffers, prep.pullback_prep, backend, x, seed, contexts...)
    return grad
end
## Shuffled
# Unprepared `gradient` with a reordered signature: `x` comes first, and the contexts
# arrive unannotated, so they are passed through `rewrap` before being forwarded.
# NOTE(review): presumably `x` is first so an outer operator can differentiate with
# respect to it — confirm with callers.
function shuffled_gradient(
    x, f::F, backend::AbstractADType, rewrap::Rewrap{C}, unannotated_contexts::Vararg{Any,C}
) where {F,C}
    contexts = rewrap(unannotated_contexts...)
    return gradient(f, backend, x, contexts...)
end
# Prepared `gradient` with a reordered signature: `x` comes first, and the contexts
# arrive unannotated, so they are passed through `rewrap` before being forwarded
# together with `prep`.
# NOTE(review): presumably `x` is first so an outer operator can differentiate with
# respect to it — confirm with callers.
function shuffled_gradient(
    x,
    f::F,
    prep::GradientPrep,
    backend::AbstractADType,
    rewrap::Rewrap{C},
    unannotated_contexts::Vararg{Any,C},
) where {F,C}
    contexts = rewrap(unannotated_contexts...)
    return gradient(f, prep, backend, x, contexts...)
end