|
| 1 | +module DifferentiationInterfaceSparseConnectivityTracerExt |
| 2 | + |
| 3 | +using ADTypes: jacobian_sparsity, hessian_sparsity |
| 4 | +import DifferentiationInterface as DI |
| 5 | +using SparseConnectivityTracer: |
| 6 | + TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer |
| 7 | + |
| 8 | +@inline _jacobian_translate(detector, c::DI.Constant) = DI.unwrap(c) |
| 9 | +@inline function _jacobian_translate(detector, c::DI.Cache{<:AbstractArray}) |
| 10 | + return jacobian_buffer(DI.unwrap(c), detector) |
| 11 | +end |
| 12 | + |
| 13 | +function jacobian_translate(detector, contexts::Vararg{DI.Context,C}) where {C} |
| 14 | + new_contexts = map(contexts) do c |
| 15 | + _jacobian_translate(detector, c) |
| 16 | + end |
| 17 | + return new_contexts |
| 18 | +end |
| 19 | + |
| 20 | +@inline _hessian_translate(detector, c::DI.Constant) = DI.unwrap(c) |
| 21 | +@inline function _hessian_translate(detector, c::DI.Cache{<:AbstractArray}) |
| 22 | + return hessian_buffer(DI.unwrap(c), detector) |
| 23 | +end |
| 24 | + |
| 25 | +function hessian_translate(detector, contexts::Vararg{DI.Context,C}) where {C} |
| 26 | + new_contexts = map(contexts) do c |
| 27 | + _hessian_translate(detector, c) |
| 28 | + end |
| 29 | + return new_contexts |
| 30 | +end |
| 31 | + |
| 32 | +function DI.jacobian_sparsity_with_contexts( |
| 33 | + f::F, |
| 34 | + detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector}, |
| 35 | + x, |
| 36 | + contexts::Vararg{DI.Context,C}, |
| 37 | +) where {F,C} |
| 38 | + contexts_tracer = jacobian_translate(detector, contexts...) |
| 39 | + fc = DI.FixTail(f, contexts_tracer) |
| 40 | + return jacobian_sparsity(fc, x, detector) |
| 41 | +end |
| 42 | + |
| 43 | +function DI.jacobian_sparsity_with_contexts( |
| 44 | + f!::F, |
| 45 | + y, |
| 46 | + detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector}, |
| 47 | + x, |
| 48 | + contexts::Vararg{DI.Context,C}, |
| 49 | +) where {F,C} |
| 50 | + contexts_tracer = jacobian_translate(detector, contexts...) |
| 51 | + fc! = DI.FixTail(f!, contexts_tracer) |
| 52 | + return jacobian_sparsity(fc!, y, x, detector) |
| 53 | +end |
| 54 | + |
| 55 | +function DI.hessian_sparsity_with_contexts( |
| 56 | + f::F, |
| 57 | + detector::Union{TracerSparsityDetector,TracerLocalSparsityDetector}, |
| 58 | + x, |
| 59 | + contexts::Vararg{DI.Context,C}, |
| 60 | +) where {F,C} |
| 61 | + contexts_tracer = hessian_translate(detector, contexts...) |
| 62 | + fc = DI.FixTail(f, contexts_tracer) |
| 63 | + return hessian_sparsity(fc, x, detector) |
| 64 | +end |
| 65 | + |
| 66 | +end |
0 commit comments