11module DifferentiationInterfaceDiffractorExt
22
3- import AbstractDifferentiation as AD # public API for Diffractor
43using ADTypes: ADTypes, AutoChainRules, AutoDiffractor
54import DifferentiationInterface as DI
65using DifferentiationInterface: NoPushforwardExtras
7- using Diffractor: DiffractorForwardBackend, DiffractorRuleConfig
6+ using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, ∂☆
87
98DI. supports_mutation (:: AutoDiffractor ) = DI. MutationNotSupported ()
109DI. mode (:: AutoDiffractor ) = ADTypes. AbstractForwardMode
@@ -14,10 +13,17 @@ DI.mode(::AutoChainRules{<:DiffractorRuleConfig}) = ADTypes.AbstractForwardMode
1413
1514DI. prepare_pushforward (f, :: AutoDiffractor , x) = NoPushforwardExtras ()
1615
17- function DI. value_and_pushforward (f, :: AutoDiffractor , x, dx, :: NoPushforwardExtras )
18- vpff = AD. value_and_pushforward_function (DiffractorForwardBackend (), f, x)
19- y, dy = vpff ((dx,))
20- return y, dy
16+ function DI. pushforward (f, :: AutoDiffractor , x, dx, :: NoPushforwardExtras )
17+ # code copied from Diffractor.jl
18+ z = ∂☆ {1} ()(ZeroBundle {1} (f), bundle (x, dx))
19+ dy = z[TaylorTangentIndex (1 )]
20+ return dy
21+ end
22+
23+ function DI. value_and_pushforward (
24+ f, backend:: AutoDiffractor , x, dx, extras:: NoPushforwardExtras
25+ )
26+ return f (x), DI. pushforward (f, backend, x, dx, extras)
2127end
2228
2329end
0 commit comments