|
1 | 1 | ## Pushforward |
2 | 2 |
|
3 | | -function DI.value_and_pushforward(f, backend::AutoForwardEnzyme, x, dx, extras::Nothing) |
| 3 | +DI.prepare_pushforward(f, ::AutoForwardEnzyme, x) = NoPushforwardExtras() |
| 4 | + |
| 5 | +function DI.value_and_pushforward( |
| 6 | + f, backend::AutoForwardEnzyme, x, dx, ::NoPushforwardExtras |
| 7 | +) |
4 | 8 | dx_sametype = convert(typeof(x), dx) |
5 | 9 | y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx_sametype)) |
6 | 10 | return y, new_dy |
7 | 11 | end |
8 | 12 |
|
9 | | -function DI.pushforward(f, backend::AutoForwardEnzyme, x, dx, extras::Nothing) |
| 13 | +function DI.pushforward(f, backend::AutoForwardEnzyme, x, dx, ::NoPushforwardExtras) |
10 | 14 | dx_sametype = convert(typeof(x), dx) |
11 | 15 | new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx_sametype))) |
12 | 16 | return new_dy |
13 | 17 | end |
14 | 18 |
|
15 | 19 | function DI.value_and_pushforward!!( |
16 | | - f, _dy, backend::AutoForwardEnzyme, x, dx, extras::Nothing |
| 20 | + f, _dy, backend::AutoForwardEnzyme, x, dx, extras::NoPushforwardExtras |
17 | 21 | ) |
18 | 22 | # dy cannot be passed anyway |
19 | 23 | return DI.value_and_pushforward(f, backend, x, dx, extras) |
20 | 24 | end |
21 | 25 |
|
22 | | -function DI.pushforward!!(f, _dy, backend::AutoForwardEnzyme, x, dx, extras::Nothing) |
| 26 | +function DI.pushforward!!( |
| 27 | + f, _dy, backend::AutoForwardEnzyme, x, dx, extras::NoPushforwardExtras |
| 28 | +) |
23 | 29 | # dy cannot be passed anyway |
24 | 30 | return DI.pushforward(f, backend, x, dx, extras) |
25 | 31 | end |
26 | 32 |
|
27 | 33 | ## Gradient |
28 | 34 |
|
29 | | -function DI.gradient(f, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing) |
| 35 | +DI.prepare_gradient(f, ::AutoForwardEnzyme, x) = NoGradientExtras() |
| 36 | + |
| 37 | +function DI.gradient(f, backend::AutoForwardEnzyme, x::AbstractArray, ::NoGradientExtras) |
30 | 38 | return reshape(collect(gradient(backend.mode, f, x)), size(x)) |
31 | 39 | end |
32 | 40 |
|
33 | 41 | function DI.value_and_gradient( |
34 | | - f, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing |
| 42 | + f, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoGradientExtras |
35 | 43 | ) |
36 | 44 | return f(x), DI.gradient(f, backend, x, extras) |
37 | 45 | end |
38 | 46 |
|
39 | 47 | function DI.gradient!!( |
40 | | - f, _grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing |
| 48 | + f, _grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoGradientExtras |
41 | 49 | ) |
42 | 50 | return DI.gradient(f, backend, x, extras) |
43 | 51 | end |
44 | 52 |
|
45 | 53 | function DI.value_and_gradient!!( |
46 | | - f, _grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing |
| 54 | + f, _grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoGradientExtras |
47 | 55 | ) |
48 | 56 | return DI.value_and_gradient(f, backend, x, extras) |
49 | 57 | end |
50 | 58 |
|
51 | 59 | ## Jacobian |
52 | 60 |
|
53 | | -function DI.jacobian(f, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing) |
| 61 | +DI.prepare_jacobian(f, ::AutoForwardEnzyme, x) = NoJacobianExtras() |
| 62 | + |
| 63 | +function DI.jacobian(f, backend::AutoForwardEnzyme, x::AbstractArray, ::NoJacobianExtras) |
54 | 64 | jac_wrongshape = jacobian(backend.mode, f, x) |
55 | 65 | nx = length(x) |
56 | 66 | ny = length(jac_wrongshape) ÷ length(x) |
57 | 67 | return reshape(jac_wrongshape, ny, nx) |
58 | 68 | end |
59 | 69 |
|
60 | 70 | function DI.value_and_jacobian( |
61 | | - f, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing |
| 71 | + f, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoJacobianExtras |
62 | 72 | ) |
63 | 73 | return f(x), DI.jacobian(f, backend, x, extras) |
64 | 74 | end |
65 | 75 |
|
66 | 76 | function DI.jacobian!!( |
67 | | - f, _jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing |
| 77 | + f, _jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoJacobianExtras |
68 | 78 | ) |
69 | 79 | return DI.jacobian(f, backend, x, extras) |
70 | 80 | end |
71 | 81 |
|
72 | 82 | function DI.value_and_jacobian!!( |
73 | | - f, _jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::Nothing |
| 83 | + f, _jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoJacobianExtras |
74 | 84 | ) |
75 | 85 | return DI.value_and_jacobian(f, backend, x, extras) |
76 | 86 | end |
0 commit comments