-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathDifferentiationInterfaceSparseConnectivityTracerExt.jl
More file actions
63 lines (55 loc) · 1.91 KB
/
DifferentiationInterfaceSparseConnectivityTracerExt.jl
File metadata and controls
63 lines (55 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Package extension connecting SparseConnectivityTracer (SCT) detectors to
# DifferentiationInterface's `*_sparsity_with_contexts` entry points.
#
# Before tracing, every `DI.Context` argument is turned into a plain value:
# `DI.Constant` contexts are simply unwrapped, while `DI.Cache` contexts are rebuilt
# with `DI.recursive_similar` using the element type of SCT's tracer buffer for `x`.
# The translated values are attached to the user function with `DI.FixTail`, and the
# result is handed to ADTypes' `jacobian_sparsity` / `hessian_sparsity` together with
# the detector.
module DifferentiationInterfaceSparseConnectivityTracerExt
using ADTypes: jacobian_sparsity, hessian_sparsity
import DifferentiationInterface as DI
using SparseConnectivityTracer:
    TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer

# Both SCT detector flavours are accepted by every method below.
const TracerDetector = Union{TracerSparsityDetector,TracerLocalSparsityDetector}

# Translate one context for a traced call whose tracer element type is `T`.
# Constants are passed through unwrapped. The target type is ignored for them.
@inline _translate(::Type, c::DI.Constant) = DI.unwrap(c)

# Caches are replaced by a buffer built with `DI.recursive_similar` from the wrapped
# storage, using element type `T`.
# NOTE(review): presumably this lets the traced function write tracer values into the
# cache. Confirm against DI's `recursive_similar` contract.
@inline function _translate(::Type{T}, c::DI.Cache) where {T}
    storage = DI.unwrap(c)
    return DI.recursive_similar(storage, T)
end

# Apply `_translate(T, ctx)` to each context and return a tuple in the original order.
function _translate_contexts(::Type{T}, contexts::Tuple) where {T}
    return map(Base.Fix1(_translate, T), contexts)
end

# Translate `contexts` for Jacobian-sparsity tracing of `x` with `detector`.
# Only the element type of SCT's Jacobian tracer buffer for `x` is used.
function jacobian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C}
    tracer_eltype = eltype(jacobian_buffer(x, detector))
    return _translate_contexts(tracer_eltype, contexts)
end

# Translate `contexts` for Hessian-sparsity tracing of `x` with `detector`.
# Only the element type of SCT's Hessian tracer buffer for `x` is used.
function hessian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C}
    tracer_eltype = eltype(hessian_buffer(x, detector))
    return _translate_contexts(tracer_eltype, contexts)
end

# Jacobian sparsity pattern of `f` at `x`, with the translated contexts fixed as
# trailing arguments via `DI.FixTail`.
function DI.jacobian_sparsity_with_contexts(
    f::F, detector::TracerDetector, x, contexts::Vararg{DI.Context,C}
) where {F,C}
    fixed = DI.FixTail(f, jacobian_translate(detector, x, contexts...)...)
    return jacobian_sparsity(fixed, x, detector)
end

# Variant for a mutating `f!` with output `y`. The translated contexts are fixed as
# trailing arguments, and the 4-argument `jacobian_sparsity` form is used.
function DI.jacobian_sparsity_with_contexts(
    f!::F, y, detector::TracerDetector, x, contexts::Vararg{DI.Context,C}
) where {F,C}
    fixed! = DI.FixTail(f!, jacobian_translate(detector, x, contexts...)...)
    return jacobian_sparsity(fixed!, y, x, detector)
end

# Hessian sparsity pattern of `f` at `x`, with the translated contexts fixed as
# trailing arguments via `DI.FixTail`.
function DI.hessian_sparsity_with_contexts(
    f::F, detector::TracerDetector, x, contexts::Vararg{DI.Context,C}
) where {F,C}
    fixed = DI.FixTail(f, hessian_translate(detector, x, contexts...)...)
    return hessian_sparsity(fixed, x, detector)
end
end