-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathDifferentiationInterfaceSparseConnectivityTracerExt.jl
More file actions
66 lines (57 loc) · 2.05 KB
/
DifferentiationInterfaceSparseConnectivityTracerExt.jl
File metadata and controls
66 lines (57 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
module DifferentiationInterfaceSparseConnectivityTracerExt
using ADTypes: jacobian_sparsity, hessian_sparsity
import DifferentiationInterface as DI
using SparseConnectivityTracer:
TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer
@inline _jacobian_translate(detector, c::DI.Constant) = DI.unwrap(c)
@inline function _jacobian_translate(detector, c::DI.Cache{<:AbstractArray})
return jacobian_buffer(DI.unwrap(c), detector)
end
function jacobian_translate(detector, contexts::Vararg{DI.Context,C}) where {C}
new_contexts = map(contexts) do c
_jacobian_translate(detector, c)
end
return new_contexts
end
@inline _hessian_translate(detector, c::DI.Constant) = DI.unwrap(c)
@inline function _hessian_translate(detector, c::DI.Cache{<:AbstractArray})
return hessian_buffer(DI.unwrap(c), detector)
end
function hessian_translate(detector, contexts::Vararg{DI.Context,C}) where {C}
new_contexts = map(contexts) do c
_hessian_translate(detector, c)
end
return new_contexts
end
function DI.jacobian_sparsity_with_contexts(
f::F,
detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector},
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
contexts_tracer = jacobian_translate(detector, contexts...)
fc = DI.FixTail(f, contexts_tracer)
return jacobian_sparsity(fc, x, detector)
end
function DI.jacobian_sparsity_with_contexts(
f!::F,
y,
detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector},
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
contexts_tracer = jacobian_translate(detector, contexts...)
fc! = DI.FixTail(f!, contexts_tracer)
return jacobian_sparsity(fc!, y, x, detector)
end
function DI.hessian_sparsity_with_contexts(
f::F,
detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector},
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
contexts_tracer = hessian_translate(detector, contexts...)
fc = DI.FixTail(f, contexts_tracer)
return hessian_sparsity(fc, x, detector)
end
end