11# # Pullback
22
3- struct ChainRulesPullbackPrepSamePoint{SIG,Y, PB} <: DI.PullbackPrep{SIG}
3+ struct ChainRulesPullbackPrepSamePoint{SIG, Y, PB} <: DI.PullbackPrep{SIG}
44 _sig:: Val{SIG}
55 y:: Y
66 pb:: PB
77end
88
99function DI. prepare_pullback_nokwarg (
10- strict:: Val ,
11- f,
12- backend:: AutoReverseChainRules ,
13- x,
14- ty:: NTuple ,
15- contexts:: Vararg{DI.GeneralizedConstant,C} ;
16- ) where {C}
10+ strict:: Val ,
11+ f,
12+ backend:: AutoReverseChainRules ,
13+ x,
14+ ty:: NTuple ,
15+ contexts:: Vararg{DI.GeneralizedConstant, C}
16+ ) where {C}
1717 _sig = DI. signature (f, backend, x, ty, contexts... ; strict)
1818 return DI. NoPullbackPrep (_sig)
1919end
2020
2121function DI. prepare_pullback_same_point (
22- f,
23- prep:: DI.NoPullbackPrep ,
24- backend:: AutoReverseChainRules ,
25- x,
26- ty:: NTuple ,
27- contexts:: Vararg{DI.GeneralizedConstant,C} ;
28- ) where {C}
22+ f,
23+ prep:: DI.NoPullbackPrep ,
24+ backend:: AutoReverseChainRules ,
25+ x,
26+ ty:: NTuple ,
27+ contexts:: Vararg{DI.GeneralizedConstant, C}
28+ ) where {C}
2929 DI. check_prep (f, prep, backend, x, ty, contexts... )
30- _sig = DI. signature (f, backend, x, ty, contexts... ; strict= DI. is_strict (prep))
30+ _sig = DI. signature (f, backend, x, ty, contexts... ; strict = DI. is_strict (prep))
3131 rc = ruleconfig (backend)
3232 y, pb = rrule_via_ad (rc, f, x, map (DI. unwrap, contexts)... )
3333 return ChainRulesPullbackPrepSamePoint (_sig, y, pb)
3434end
3535
3636function DI. value_and_pullback (
37- f,
38- prep:: DI.NoPullbackPrep ,
39- backend:: AutoReverseChainRules ,
40- x,
41- ty:: NTuple ,
42- contexts:: Vararg{DI.GeneralizedConstant,C} ,
43- ) where {C}
37+ f,
38+ prep:: DI.NoPullbackPrep ,
39+ backend:: AutoReverseChainRules ,
40+ x,
41+ ty:: NTuple ,
42+ contexts:: Vararg{DI.GeneralizedConstant, C} ,
43+ ) where {C}
4444 DI. check_prep (f, prep, backend, x, ty, contexts... )
4545 rc = ruleconfig (backend)
4646 y, pb = rrule_via_ad (rc, f, x, map (DI. unwrap, contexts)... )
@@ -51,13 +51,13 @@ function DI.value_and_pullback(
5151end
5252
5353function DI. value_and_pullback (
54- f,
55- prep:: ChainRulesPullbackPrepSamePoint ,
56- backend:: AutoReverseChainRules ,
57- x,
58- ty:: NTuple ,
59- contexts:: Vararg{DI.GeneralizedConstant,C} ,
60- ) where {C}
54+ f,
55+ prep:: ChainRulesPullbackPrepSamePoint ,
56+ backend:: AutoReverseChainRules ,
57+ x,
58+ ty:: NTuple ,
59+ contexts:: Vararg{DI.GeneralizedConstant, C} ,
60+ ) where {C}
6161 DI. check_prep (f, prep, backend, x, ty, contexts... )
6262 (; y, pb) = prep
6363 tx = map (ty) do dy
@@ -67,13 +67,13 @@ function DI.value_and_pullback(
6767end
6868
6969function DI. pullback (
70- f,
71- prep:: ChainRulesPullbackPrepSamePoint ,
72- backend:: AutoReverseChainRules ,
73- x,
74- ty:: NTuple ,
75- contexts:: Vararg{DI.GeneralizedConstant,C} ,
76- ) where {C}
70+ f,
71+ prep:: ChainRulesPullbackPrepSamePoint ,
72+ backend:: AutoReverseChainRules ,
73+ x,
74+ ty:: NTuple ,
75+ contexts:: Vararg{DI.GeneralizedConstant, C} ,
76+ ) where {C}
7777 DI. check_prep (f, prep, backend, x, ty, contexts... )
7878 (; pb) = prep
7979 tx = map (ty) do dy
0 commit comments