|
2 | 2 |
|
3 | 3 | DI.prepare_pullback(f, ::AnyAutoReverseDiff, x) = NoPullbackExtras() |
4 | 4 |
|
5 | | -function DI.value_and_pullback!!( |
6 | | - f, |
7 | | - dx::AbstractArray, |
8 | | - backend::AnyAutoReverseDiff, |
9 | | - x::AbstractArray, |
10 | | - dy, |
11 | | - extras::NoPullbackExtras, |
12 | | -) |
13 | | - return f(x), DI.pullback!!(f, dx, backend, x, dy, extras) |
14 | | -end |
15 | | - |
16 | | -function DI.value_and_pullback( |
17 | | - f, backend::AnyAutoReverseDiff, x::AbstractArray, dy, extras::NoPullbackExtras |
18 | | -) |
19 | | - return f(x), DI.pullback(f, backend, x, dy, extras) |
20 | | -end |
21 | | - |
22 | | -### Number out |
23 | | - |
24 | | -function DI.pullback!!( |
25 | | - f, |
26 | | - dx::AbstractArray, |
27 | | - ::AnyAutoReverseDiff, |
28 | | - x::AbstractArray, |
29 | | - dy::Number, |
30 | | - ::NoPullbackExtras, |
| 5 | +function DI.value_and_pullback_split( |
| 6 | + f, ::AnyAutoReverseDiff, x::AbstractArray, ::NoPullbackExtras |
31 | 7 | ) |
32 | | - dx = gradient!(dx, f, x) |
33 | | - dx .*= dy |
34 | | - return dx |
35 | | -end |
36 | | - |
37 | | -function DI.pullback( |
38 | | - f, ::AnyAutoReverseDiff, x::AbstractArray, dy::Number, ::NoPullbackExtras |
39 | | -) |
40 | | - dx = gradient(f, x) |
41 | | - dx .*= dy |
42 | | - return dx |
43 | | -end |
44 | | - |
45 | | -### Array out |
46 | | - |
47 | | -function DI.pullback!!( |
48 | | - f, |
49 | | - dx::AbstractArray, |
50 | | - ::AnyAutoReverseDiff, |
51 | | - x::AbstractArray, |
52 | | - dy::AbstractArray, |
53 | | - ::NoPullbackExtras, |
54 | | -) |
55 | | - dotproduct_closure(x) = dot(f(x), dy) |
56 | | - dx = gradient!(dx, dotproduct_closure, x) |
57 | | - return dx |
| 8 | + y = f(x) |
| 9 | + pullbackfunc = if y isa Number |
| 10 | + dy -> dy .* gradient(f, x) |
| 11 | + elseif y isa AbstractArray |
| 12 | + dy -> gradient(z -> dot(f(z), dy), x) |
| 13 | + end |
| 14 | + return y, pullbackfunc |
58 | 15 | end |
59 | 16 |
|
60 | | -function DI.pullback( |
61 | | - f, ::AnyAutoReverseDiff, x::AbstractArray, dy::AbstractArray, extras::NoPullbackExtras |
| 17 | +function DI.value_and_pullback!!_split( |
| 18 | + f, ::AnyAutoReverseDiff, x::AbstractArray, ::NoPullbackExtras |
62 | 19 | ) |
63 | | - dotproduct_closure(x) = dot(f(x), dy) |
64 | | - dx = gradient(dotproduct_closure, x) |
65 | | - return dx |
| 20 | + y = f(x) |
| 21 | + pullbackfunc!! = if y isa Number |
| 22 | + (dx, dy) -> begin |
| 23 | + dx = gradient!(dx, f, x) |
| 24 | + dx .*= dy |
| 25 | + end |
| 26 | + elseif y isa AbstractArray |
| 27 | + (dx, dy) -> gradient!(dx, z -> dot(f(z), dy), x) |
| 28 | + end |
| 29 | + return y, pullbackfunc!! |
66 | 30 | end |
67 | 31 |
|
68 | | -### Number in, not supported |
69 | | - |
70 | | -function DI.value_and_pullback( |
71 | | - f, backend::AnyAutoReverseDiff, x::Number, dy, ::NoPullbackExtras |
| 32 | +function DI.value_and_pullback_split( |
| 33 | + f, backend::AnyAutoReverseDiff, x::Number, ::NoPullbackExtras |
72 | 34 | ) |
73 | 35 | x_array = [x] |
74 | 36 | f_array = f ∘ only |
75 | | - new_extras = DI.prepare_pullback(f_array, backend, x_array) |
76 | | - y, dx_array = DI.value_and_pullback(f_array, backend, x_array, dy, new_extras) |
77 | | - return y, only(dx_array) |
| 37 | + y, pullbackfunc = DI.value_and_pullback_split(f_array, backend, x_array) |
| 38 | + return y, only ∘ pullbackfunc |
78 | 39 | end |
79 | 40 |
|
80 | 41 | ## Gradient |
|
0 commit comments