-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathDifferentiationInterfaceDiffractorExt.jl
More file actions
38 lines (31 loc) · 1.11 KB
/
DifferentiationInterfaceDiffractorExt.jl
File metadata and controls
38 lines (31 loc) · 1.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
module DifferentiationInterfaceDiffractorExt
using ADTypes: ADTypes, AutoDiffractor
import DifferentiationInterface as DI
using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, ∂☆
# Backend traits reported to DifferentiationInterface for `AutoDiffractor`.

# This extension is loaded, so the backend is reported as available.
function DI.check_available(::AutoDiffractor)
    return true
end

# Reports `DI.InPlaceNotSupported()` as the in-place trait of this backend.
function DI.inplace_support(::AutoDiffractor)
    return DI.InPlaceNotSupported()
end

# Reports `DI.PullbackSlow()` as the pullback performance trait; only
# pushforward methods are defined for this backend in this module.
function DI.pullback_performance(::AutoDiffractor)
    return DI.PullbackSlow()
end
## Pushforward
# Build the pushforward preparation object for `AutoDiffractor`.
#
# The only thing stored is the result of `DI.signature(f, backend, x, tx; strict)`,
# wrapped in a `DI.NoPushforwardPrep`; no other work is done here.
function DI.prepare_pushforward(
    strict::Val, f, backend::AutoDiffractor, x, tx::NTuple
)
    return DI.NoPushforwardPrep(DI.signature(f, backend, x, tx; strict))
end
# Compute one output tangent per input tangent in `tx` using Diffractor.
#
# `prep` is first validated against the arguments with `DI.check_prep`. Each
# `dx` in `tx` is then bundled with `x`, pushed through `f`, and the entry at
# `TaylorTangentIndex(1)` of the result is kept. Because `tx` is an `NTuple`,
# `map` returns a tuple with the same number of entries.
function DI.pushforward(
    f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
)
    DI.check_prep(f, prep, backend, x, tx)
    # Call pattern copied from Diffractor.jl.
    function tangent_of(dx)
        taylor = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
        return taylor[TaylorTangentIndex(1)]
    end
    return map(tangent_of, tx)
end
# Return the primal value `f(x)` together with the pushforward tangents.
#
# `prep` is validated with `DI.check_prep`, then `f(x)` is evaluated directly,
# and the tangents are delegated to `DI.pushforward` with the same arguments.
# The result is the 2-tuple `(y, ty)`.
function DI.value_and_pushforward(
    f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
)
    DI.check_prep(f, prep, backend, x, tx)
    y = f(x)
    ty = DI.pushforward(f, prep, backend, x, tx)
    return y, ty
end
end