forked from JuliaDiff/DifferentiationInterface.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: differentiate_with.jl
More file actions
22 lines (17 loc) · 777 Bytes
/
differentiate_with.jl
File metadata and controls
22 lines (17 loc) · 777 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Mark one-argument calls of a `DI.DifferentiateWith` as Mooncake primitives, so Mooncake
# uses the `rrule!!` defined for it instead of differentiating through the wrapped function.
# `@is_primitive` expects the *primal* call signature (callable type first, then argument
# types) -- not `CoDual`-wrapped types, which would never match an actual call.
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:AbstractArray{<:Number}}
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Number}
# Select the seed (output cotangent) handed to `DI.pullback`.
# Scalar outputs: Mooncake passes the cotangent to the pullback as rdata (`dy`).
_dw_output_seed(::Number, y_fdata, dy) = dy
# Array outputs: the cotangent is accumulated into the output CoDual's fdata by downstream
# pullbacks, and `dy` carries no information.
_dw_output_seed(::AbstractArray, y_fdata, dy) = y_fdata

# Return the input's rdata for a gradient `dx` computed by `DI.pullback`.
# Scalar inputs: the gradient itself is the rdata.
_dw_input_rdata!(::CoDual{<:Number}, dx) = dx
# Array inputs: the gradient is accumulated in place into the input's fdata, and the
# rdata is empty.
# NOTE(review): assumes the fdata of the array supports broadcasting `.+=` with `dx`
# (true for plain `Array{<:Number}`); confirm for exotic `AbstractArray` subtypes.
function _dw_input_rdata!(x::CoDual{<:AbstractArray{<:Number}}, dx)
    Mooncake.tangent(x) .+= dx
    return NoRData()
end

# Reverse rule for `DI.DifferentiateWith(f, backend)`: the forward pass evaluates `f`
# directly, and the pullback delegates the vector-Jacobian product to `DI.pullback` with
# `backend`. The wrapped function and backend are treated as constants, so the rdata
# returned for `dw` is always zero.
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, args::CoDual...)
    # Only one-argument calls are declared primitive; `only` throws an `ArgumentError`
    # for any other arity.
    x = only(args)
    primal_func = Mooncake.primal(dw)
    primal_x = Mooncake.primal(x)
    (; f, backend) = primal_func

    y = f(primal_x)
    ydual = zero_fcodual(y)

    # Prepare once on the forward pass. The seed prototype comes from the same helper used
    # in the pullback, so strict preparation sees exactly the seed type used later.
    seed_prototype = _dw_output_seed(y, Mooncake.tangent(ydual), Mooncake.zero_rdata(y))
    prep_same = DI.prepare_pullback_same_point_nokwarg(
        Val(true), f, backend, primal_x, (seed_prototype,)
    )

    function pullback!!(dy)
        seed = _dw_output_seed(y, Mooncake.tangent(ydual), dy)
        tx = DI.pullback(f, prep_same, backend, primal_x, (seed,))
        return Mooncake.zero_rdata(primal_func), _dw_input_rdata!(x, only(tx))
    end

    return ydual, pullback!!
end