@@ -14,7 +14,7 @@ By order of preference:
1414 hvp(f, backend, x, v, [extras]) -> p
1515"""
1616function hvp (f, backend:: AbstractADType , x, v, extras= prepare_hvp (f, backend, x))
17- new_backend = SecondOrder (backend, backend )
17+ new_backend = SecondOrder (backend)
1818 new_extras = prepare_hvp (f, new_backend, x)
1919 return hvp (f, new_backend, x, v, new_extras)
2020end
3232
3333function hvp_aux (f, backend, x, v, extras, :: ForwardOverReverse )
3434 # JVP of the gradient
35- inner_extras = prepare_gradient (extras, f, inner (backend), x)
36- gradient_closure (z) = gradient (f, inner (backend), z, inner_extras)
35+ function gradient_closure (z)
36+ inner_extras = prepare_gradient (extras, f, inner (backend), z)
37+ return gradient (f, inner (backend), z, inner_extras)
38+ end
3739 outer_extras = prepare_pushforward (extras, gradient_closure, outer (backend), x)
3840 p = pushforward (gradient_closure, outer (backend), x, v, outer_extras)
3941 return p
4042end
4143
4244function hvp_aux (f, backend, x, v, extras, :: ReverseOverForward )
4345 # gradient of the JVP
44- inner_extras = prepare_pushforward (extras, f, inner (backend), x)
45- jvp_closure (z) = pushforward (f, inner (backend), z, v, inner_extras)
46+ function jvp_closure (z)
47+ inner_extras = prepare_pushforward (extras, f, inner (backend), z)
48+ return pushforward (f, inner (backend), z, v, inner_extras)
49+ end
4650 outer_extras = prepare_gradient (extras, jvp_closure, outer (backend), x)
4751 p = gradient (jvp_closure, outer (backend), x, outer_extras)
4852 return p
4953end
5054
5155function hvp_aux (f, backend, x, v, extras, :: ReverseOverReverse )
5256 # VJP of the gradient
53- inner_extras = prepare_gradient (extras, f, inner (backend), x)
54- gradient_closure (z) = gradient (f, inner (backend), z, inner_extras)
57+ function gradient_closure (z)
58+ inner_extras = prepare_gradient (extras, f, inner (backend), z)
59+ return gradient (f, inner (backend), z, inner_extras)
60+ end
5561 outer_extras = prepare_pullback (extras, gradient_closure, outer (backend), x)
5662 p = pullback (gradient_closure, outer (backend), x, v, outer_extras)
5763 return p
6066function hvp_aux (f, backend, x, v, extras, :: ForwardOverForward )
6167 # JVPs of JVPs in theory
6268 # also pushforward of gradient in practice
63- inner_extras = prepare_gradient (extras, f, inner (backend), x)
64- gradient_closure (z) = gradient (f, inner (backend), z, nothing ) # TODO : fix
69+ function gradient_closure (z)
70+ inner_extras = prepare_gradient (extras, f, inner (backend), z)
71+ return gradient (f, inner (backend), z, inner_extras)
72+ end
6573 outer_extras = prepare_pushforward (extras, gradient_closure, outer (backend), x)
6674 p = pushforward (gradient_closure, outer (backend), x, v, outer_extras)
6775 return p
7179 hvp!!(f, p, backend, x, v, [extras]) -> p
7280"""
7381function hvp!! (f, p, backend:: AbstractADType , x, v, extras= prepare_hvp (f, backend, x))
74- new_backend = SecondOrder (backend, backend )
82+ new_backend = SecondOrder (backend)
7583 new_extras = prepare_hvp (f, new_backend, x)
7684 return hvp!! (f, p, new_backend, x, v, new_extras)
7785end
@@ -87,32 +95,40 @@ function hvp!!(f, p, backend::SecondOrder, x, v, extras=prepare_hvp(f, backend,
8795end
8896
8997function hvp_aux!! (f, p, backend, x, v, extras, :: ForwardOverReverse )
90- inner_extras = prepare_gradient (extras, f, inner (backend), x)
91- gradient_closure (z) = gradient (f, inner (backend), z, inner_extras)
98+ function gradient_closure (z)
99+ inner_extras = prepare_gradient (extras, f, inner (backend), z)
100+ return gradient (f, inner (backend), z, inner_extras)
101+ end
92102 outer_extras = prepare_pushforward (extras, gradient_closure, outer (backend), x)
93103 p = pushforward!! (gradient_closure, p, outer (backend), x, v, outer_extras)
94104 return p
95105end
96106
97107function hvp_aux!! (f, p, backend, x, v, extras, :: ReverseOverForward )
98- inner_extras = prepare_pushforward (extras, f, inner (backend), x)
99- jvp_closure (z) = pushforward (f, inner (backend), z, v, inner_extras)
108+ function jvp_closure (z)
109+ inner_extras = prepare_pushforward (extras, f, inner (backend), z)
110+ return pushforward (f, inner (backend), z, v, inner_extras)
111+ end
100112 outer_extras = prepare_gradient (extras, jvp_closure, outer (backend), x)
101113 p = gradient!! (jvp_closure, p, outer (backend), x, outer_extras)
102114 return p
103115end
104116
105117function hvp_aux!! (f, p, backend, x, v, extras, :: ReverseOverReverse )
106- inner_extras = prepare_gradient (extras, f, inner (backend), x)
107- gradient_closure (z) = gradient (f, inner (backend), z, inner_extras)
118+ function gradient_closure (z)
119+ inner_extras = prepare_gradient (extras, f, inner (backend), z)
120+ return gradient (f, inner (backend), z, inner_extras)
121+ end
108122 outer_extras = prepare_pullback (extras, gradient_closure, outer (backend), x)
109123 p = pullback!! (gradient_closure, p, outer (backend), x, v, outer_extras)
110124 return p
111125end
112126
113127function hvp_aux!! (f, p, backend, x, v, extras, :: ForwardOverForward )
114- inner_extras = prepare_gradient (extras, f, inner (backend), x)
115- gradient_closure (z) = gradient (f, inner (backend), z, nothing ) # TODO : fix
128+ function gradient_closure (z)
129+ inner_extras = prepare_gradient (extras, f, inner (backend), z)
130+ return gradient (f, inner (backend), z, inner_extras)
131+ end
116132 outer_extras = prepare_pushforward (extras, gradient_closure, outer (backend), x)
117133 p = pushforward!! (gradient_closure, p, outer (backend), x, v, outer_extras)
118134 return p
0 commit comments