-
Notifications
You must be signed in to change notification settings - Fork 32
feat: backend switching for Mooncake #768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
1a389a6
08b176a
ba0c9e6
1340d92
2ce1ee2
08de6df
84f27c9
2e95299
13233e5
1e8df98
afdddd4
233c312
7a07127
f3e436d
6a0d937
e543958
2472ecc
c63c956
36da036
d2b5a8c
c389a80
b4fe0f8
ec4b75d
0f0b9fc
3c5f99e
d94f146
c982f46
749fea5
9e5ecfd
1e85f17
ff5c4e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| @is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:AbstractArray} | ||
| @is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Number} | ||
|
|
||
| function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) | ||
| primal_func = primal(dw) | ||
| primal_x = primal(x) | ||
| (; f, backend) = primal_func | ||
| y = zero_fcodual(f(primal_x)) | ||
|
Check warning on line 8 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
|
gdalle marked this conversation as resolved.
|
||
|
|
||
| # output is a vector, so we need to use the vector pullback | ||
|
AstitvaAggarwal marked this conversation as resolved.
|
||
| function pullback!!(dy::NoRData) | ||
| tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) | ||
| return NoRData(), only(tx) | ||
|
Check warning on line 13 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
|
AstitvaAggarwal marked this conversation as resolved.
Outdated
AstitvaAggarwal marked this conversation as resolved.
Outdated
|
||
| end | ||
|
|
||
| # output is a scalar, so we can use the scalar pullback | ||
| function pullback!!(dy) | ||
|
AstitvaAggarwal marked this conversation as resolved.
Outdated
|
||
| tx = DI.pullback(f, backend, primal_x, (dy,)) | ||
| return NoRData(), only(tx) | ||
|
Check warning on line 19 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
| end | ||
|
|
||
| return y, pullback!! | ||
|
Check warning on line 22 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
| end | ||
|
|
||
| function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray}) | ||
| primal_func = primal(dw) | ||
| primal_x = primal(x) | ||
| fdata_arg = fdata(x.dx) | ||
| (; f, backend) = primal_func | ||
| y = zero_fcodual(f(primal_x)) | ||
|
Check warning on line 30 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
|
|
||
| # output is a vector, so we need to use the vector pullback | ||
| function pullback!!(dy::NoRData) | ||
| tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),)) | ||
| fdata_arg .+= only(tx) | ||
| return NoRData(), dy | ||
|
Check warning on line 36 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
| end | ||
|
|
||
| # output is a scalar, so we can use the scalar pullback | ||
| function pullback!!(dy) | ||
| tx = DI.pullback(f, backend, primal_x, (dy,)) | ||
| fdata_arg .+= only(tx) | ||
| return NoRData(), NoRData() | ||
|
Check warning on line 43 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
| end | ||
|
|
||
| # in case x is mutated when passed into f | ||
| x = CoDual(primal_x, x.dx) | ||
| return y, pullback!! | ||
|
Check warning on line 48 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
|
||
| end | ||
Uh oh!
There was an error while loading. Please reload this page.