Skip to content

Commit 65c693a

Browse files
authored
Call Diffractor's native pushforward (#153)
1 parent 99a3f37 commit 65c693a

2 files changed

Lines changed: 13 additions & 12 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111

1212
[weakdeps]
13-
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
1413
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1514
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
1615
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
@@ -28,10 +27,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2827

2928
[extensions]
3029
DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore"
31-
DifferentiationInterfaceDiffractorExt = [
32-
"AbstractDifferentiation",
33-
"Diffractor",
34-
]
30+
DifferentiationInterfaceDiffractorExt = "Diffractor"
3531
DifferentiationInterfaceEnzymeExt = "Enzyme"
3632
DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
3733
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
@@ -46,7 +42,6 @@ DifferentiationInterfaceZygoteExt = "Zygote"
4642

4743
[compat]
4844
ADTypes = "0.2.7"
49-
AbstractDifferentiation = "0.6.2"
5045
ChainRulesCore = "1.23.0"
5146
Diffractor = "0.2.6"
5247
DocStringExtensions = "0.9.3"
Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
module DifferentiationInterfaceDiffractorExt
22

3-
import AbstractDifferentiation as AD # public API for Diffractor
43
using ADTypes: ADTypes, AutoChainRules, AutoDiffractor
54
import DifferentiationInterface as DI
65
using DifferentiationInterface: NoPushforwardExtras
7-
using Diffractor: DiffractorForwardBackend, DiffractorRuleConfig
6+
using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, ∂☆
87

98
DI.supports_mutation(::AutoDiffractor) = DI.MutationNotSupported()
109
DI.mode(::AutoDiffractor) = ADTypes.AbstractForwardMode
@@ -14,10 +13,17 @@ DI.mode(::AutoChainRules{<:DiffractorRuleConfig}) = ADTypes.AbstractForwardMode
1413

1514
DI.prepare_pushforward(f, ::AutoDiffractor, x) = NoPushforwardExtras()
1615

17-
function DI.value_and_pushforward(f, ::AutoDiffractor, x, dx, ::NoPushforwardExtras)
18-
vpff = AD.value_and_pushforward_function(DiffractorForwardBackend(), f, x)
19-
y, dy = vpff((dx,))
20-
return y, dy
16+
function DI.pushforward(f, ::AutoDiffractor, x, dx, ::NoPushforwardExtras)
17+
# code copied from Diffractor.jl
18+
z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
19+
dy = z[TaylorTangentIndex(1)]
20+
return dy
21+
end
22+
23+
function DI.value_and_pushforward(
24+
f, backend::AutoDiffractor, x, dx, extras::NoPushforwardExtras
25+
)
26+
return f(x), DI.pushforward(f, backend, x, dx, extras)
2127
end
2228

2329
end

0 commit comments

Comments
 (0)