Skip to content

Commit 9b17e3e

Browse files
authored
fix: unthunk ChainRules pullback outputs (#674)
1 parent 6604be2 commit 9b17e3e

2 files changed

Lines changed: 5 additions & 4 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ using ChainRulesCore:
88
NoTangent,
99
RuleConfig,
1010
frule_via_ad,
11-
rrule_via_ad
11+
rrule_via_ad,
12+
unthunk
1213
import DifferentiationInterface as DI
1314

1415
ruleconfig(backend::AutoChainRules) = backend.ruleconfig

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function DI.value_and_pullback(
3939
rc = ruleconfig(backend)
4040
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
4141
tx = map(ty) do dy
42-
pb(dy)[2]
42+
unthunk(pb(dy)[2])
4343
end
4444
return y, tx
4545
end
@@ -54,7 +54,7 @@ function DI.value_and_pullback(
5454
) where {C}
5555
(; y, pb) = prep
5656
tx = map(ty) do dy
57-
pb(dy)[2]
57+
unthunk(pb(dy)[2])
5858
end
5959
return copy(y), tx
6060
end
@@ -69,7 +69,7 @@ function DI.pullback(
6969
) where {C}
7070
(; pb) = prep
7171
tx = map(ty) do dy
72-
pb(dy)[2]
72+
unthunk(pb(dy)[2])
7373
end
7474
return tx
7575
end

0 commit comments

Comments
 (0)