Skip to content

Commit c63c956

Browse files
typecheck for array rule.
1 parent 2472ecc commit c63c956

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,15 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra
3333
# output is a vector, so we need to use the vector pullback
3434
function pullback_array!!(dy::NoRData)
3535
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
36+
@assert only(tx) isa rdata_type(typeof(primal_x))
3637
fdata_arg .+= only(tx)
3738
return NoRData(), dy
3839
end
3940

4041
# output is a scalar, so we can use the scalar pullback
4142
function pullback_scalar!!(dy::Number)
4243
tx = DI.pullback(f, backend, primal_x, (dy,))
44+
@assert only(tx) isa rdata_type(typeof(primal_x))
4345
fdata_arg .+= only(tx)
4446
return NoRData(), NoRData()
4547
end

0 commit comments

Comments
 (0)