|
33 | 33 |
|
34 | 34 | ## Gradient |
35 | 35 |
|
36 | | -DI.prepare_gradient(f, ::AutoForwardEnzyme, x) = NoGradientExtras() |
| 36 | +struct EnzymeForwardGradientExtras{C,O} |
| 37 | + shadow::O |
| 38 | +end |
| 39 | + |
| 40 | +function DI.prepare_gradient(f, ::AutoForwardEnzyme, x) |
| 41 | + C = pick_chunksize(length(x)) |
| 42 | + shadow = chunkedonehot(x, Val(C)) |
| 43 | + return EnzymeForwardGradientExtras{C,typeof(shadow)}(shadow) |
| 44 | +end |
37 | 45 |
|
38 | | -function DI.gradient(f, backend::AutoForwardEnzyme, x::AbstractArray, ::NoGradientExtras) |
39 | | - return reshape(collect(gradient(backend.mode, f, x)), size(x)) |
| 46 | +function DI.gradient( |
| 47 | + f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C} |
| 48 | +) where {C} |
| 49 | + grad_tup = gradient(backend.mode, f, x, Val{C}(); shadow=extras.shadow) |
| 50 | + return reshape(collect(grad_tup), size(x)) |
40 | 51 | end |
41 | 52 |
|
42 | 53 | function DI.value_and_gradient( |
43 | | - f, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoGradientExtras |
| 54 | + f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras |
44 | 55 | ) |
45 | 56 | return f(x), DI.gradient(f, backend, x, extras) |
46 | 57 | end |
47 | 58 |
|
48 | 59 | function DI.gradient!( |
49 | | - f, grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoGradientExtras |
50 | | -) |
51 | | - return copyto!(grad, DI.gradient(f, backend, x, extras)) |
| 60 | + f, grad, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C} |
| 61 | +) where {C} |
| 62 | + grad_tup = gradient(backend.mode, f, x, Val{C}(); shadow=extras.shadow) |
| 63 | + return copyto!(grad, grad_tup) |
52 | 64 | end |
53 | 65 |
|
54 | 66 | function DI.value_and_gradient!( |
55 | | - f, grad, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoGradientExtras |
56 | | -) |
57 | | - y, new_grad = DI.value_and_gradient(f, backend, x, extras) |
58 | | - return y, copyto!(grad, new_grad) |
| 67 | + f, grad, backend::AutoForwardEnzyme, x, extras::EnzymeForwardGradientExtras{C} |
| 68 | +) where {C} |
| 69 | + grad_tup = gradient(backend.mode, f, x, Val{C}(); shadow=extras.shadow) |
| 70 | + return f(x), copyto!(grad, grad_tup) |
59 | 71 | end |
60 | 72 |
|
61 | 73 | ## Jacobian |
62 | 74 |
|
63 | | -DI.prepare_jacobian(f, ::AutoForwardEnzyme, x) = NoJacobianExtras() |
| 75 | +struct EnzymeForwardOneArgJacobianExtras{C,O} |
| 76 | + shadow::O |
| 77 | +end |
| 78 | + |
| 79 | +function DI.prepare_jacobian(f, ::AutoForwardEnzyme, x) |
| 80 | + C = pick_chunksize(length(x)) |
| 81 | + shadow = chunkedonehot(x, Val(C)) |
| 82 | + return EnzymeForwardOneArgJacobianExtras{C,typeof(shadow)}(shadow) |
| 83 | +end |
64 | 84 |
|
65 | | -function DI.jacobian(f, backend::AutoForwardEnzyme, x::AbstractArray, ::NoJacobianExtras) |
66 | | - jac_wrongshape = jacobian(backend.mode, f, x) |
| 85 | +function DI.jacobian( |
| 86 | + f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras{C} |
| 87 | +) where {C} |
| 88 | + jac_wrongshape = jacobian(backend.mode, f, x, Val{C}(); shadow=extras.shadow) |
67 | 89 | nx = length(x) |
68 | 90 | ny = length(jac_wrongshape) ÷ length(x) |
69 | 91 | return reshape(jac_wrongshape, ny, nx) |
70 | 92 | end |
71 | 93 |
|
72 | 94 | function DI.value_and_jacobian( |
73 | | - f, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoJacobianExtras |
| 95 | + f, backend::AutoForwardEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras |
74 | 96 | ) |
75 | 97 | return f(x), DI.jacobian(f, backend, x, extras) |
76 | 98 | end |
77 | 99 |
|
78 | 100 | function DI.jacobian!( |
79 | | - f, jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoJacobianExtras |
| 101 | + f, jac, backend::AutoForwardEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras |
80 | 102 | ) |
81 | 103 | return copyto!(jac, DI.jacobian(f, backend, x, extras)) |
82 | 104 | end |
83 | 105 |
|
84 | 106 | function DI.value_and_jacobian!( |
85 | | - f, jac, backend::AutoForwardEnzyme, x::AbstractArray, extras::NoJacobianExtras |
| 107 | + f, jac, backend::AutoForwardEnzyme, x, extras::EnzymeForwardOneArgJacobianExtras |
86 | 108 | ) |
87 | 109 | y, new_jac = DI.value_and_jacobian(f, backend, x, extras) |
88 | 110 | return y, copyto!(jac, new_jac) |
|
0 commit comments