Skip to content

Commit 8cfc501

Browse files
committed
Recursive caches
1 parent 6c8ea28 commit 8cfc501

8 files changed

Lines changed: 42 additions & 32 deletions

File tree

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
actions: write
2626
contents: read
2727
strategy:
28-
fail-fast: true # TODO: toggle
28+
fail-fast: false # TODO: toggle
2929
matrix:
3030
version:
3131
- "1.10"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,7 @@ force_annotation(f::F) where {F} = Const(f)
5353
return Const(DI.unwrap(c))
5454
end
5555

56-
@inline function _translate(
57-
backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache
58-
) where {B}
56+
@inline function _translate(backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.Cache) where {B}
5957
if B == 1
6058
return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))
6159
else

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ myvec(x::AbstractArray) = vec(x)
2323

2424
variablize(::Number, name::Symbol) = only(make_variables(name))
2525
variablize(x::AbstractArray, name::Symbol) = make_variables(name, size(x)...)
26+
function variablize(x::Union{Tuple,NamedTuple}, name::Symbol)
27+
return map(x) do xi
28+
variablize(xi, gensym()) # TODO: fix symbol?
29+
end
30+
end
2631

2732
function variablize(contexts::NTuple{C,DI.Context}) where {C}
2833
map(enumerate(contexts)) do (k, c)

DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,23 @@ import DifferentiationInterface as DI
55
using SparseConnectivityTracer:
66
TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer
77

8-
@inline _jacobian_translate(detector, c::DI.Constant) = DI.unwrap(c)
9-
@inline function _jacobian_translate(detector, c::DI.Cache{<:AbstractArray})
10-
return jacobian_buffer(DI.unwrap(c), detector)
8+
@inline _translate(::Type, c::DI.Constant) = DI.unwrap(c)
9+
@inline function _translate(::Type{T}, c::DI.Cache) where {T}
10+
return DI.recursive_similar(DI.unwrap(c), T)
1111
end
1212

13-
function jacobian_translate(detector, contexts::Vararg{DI.Context,C}) where {C}
13+
function jacobian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C}
14+
T = eltype(jacobian_buffer(x, detector))
1415
new_contexts = map(contexts) do c
15-
_jacobian_translate(detector, c)
16+
_translate(T, c)
1617
end
1718
return new_contexts
1819
end
1920

20-
@inline _hessian_translate(detector, c::DI.Constant) = DI.unwrap(c)
21-
@inline function _hessian_translate(detector, c::DI.Cache{<:AbstractArray})
22-
return hessian_buffer(DI.unwrap(c), detector)
23-
end
24-
25-
function hessian_translate(detector, contexts::Vararg{DI.Context,C}) where {C}
21+
function hessian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C}
22+
T = eltype(hessian_buffer(x, detector))
2623
new_contexts = map(contexts) do c
27-
_hessian_translate(detector, c)
24+
_translate(T, c)
2825
end
2926
return new_contexts
3027
end
@@ -35,7 +32,7 @@ function DI.jacobian_sparsity_with_contexts(
3532
x,
3633
contexts::Vararg{DI.Context,C},
3734
) where {F,C}
38-
contexts_tracer = jacobian_translate(detector, contexts...)
35+
contexts_tracer = jacobian_translate(detector, x, contexts...)
3936
fc = DI.FixTail(f, contexts_tracer...)
4037
return jacobian_sparsity(fc, x, detector)
4138
end
@@ -47,7 +44,7 @@ function DI.jacobian_sparsity_with_contexts(
4744
x,
4845
contexts::Vararg{DI.Context,C},
4946
) where {F,C}
50-
contexts_tracer = jacobian_translate(detector, contexts...)
47+
contexts_tracer = jacobian_translate(detector, x, contexts...)
5148
fc! = DI.FixTail(f!, contexts_tracer...)
5249
return jacobian_sparsity(fc!, y, x, detector)
5350
end

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ DI.check_available(::AutoZygote) = true
1717
DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported()
1818

1919
translate(c::DI.Context) = DI.unwrap(c)
20-
translate(c::DI.Cache) = Buffer(DI.unwrap(c))
20+
translate(c::DI.Cache{<:AbstractArray}) = Buffer(DI.unwrap(c))
21+
function translate(c::DI.Cache{<:Union{NTuple,NamedTuple}})
22+
return map(translate, map(DI.Cache, DI.unwrap(c)))
23+
end
2124

2225
## Pullback
2326

DifferentiationInterface/src/utils/context.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ Abstract supertype for additional context arguments, which can be passed to diff
2323
abstract type Context end
2424

2525
abstract type GeneralizedConstant <: Context end
26-
abstract type GeneralizedCache <: Context end
2726

2827
unwrap(c::Context) = c.data
2928
Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2)
@@ -78,7 +77,7 @@ The initial values present inside the cache do not matter.
7877
For some backends, preparation allocates the required memory for `Cache` contexts with the right element type, similar to [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl).
7978
8079
!!! warning
81-
Most backends require any `Cache` context to be an `AbstractArray`.
80+
Some backends require any `Cache` context to be an `AbstractArray or a (named) tuple of `AbstractArray`s.
8281
8382
# Example
8483
@@ -97,7 +96,7 @@ julia> gradient(f, prep, AutoForwardDiff(), [3.0, 4.0], Cache(zeros(2)))
9796
1.0
9897
````
9998
"""
100-
struct Cache{T} <: GeneralizedCache
99+
struct Cache{T} <: Context
101100
data::T
102101
end
103102

@@ -114,12 +113,10 @@ struct BackendContext{T} <: GeneralizedConstant
114113
data::T
115114
end
116115

117-
struct PrepContext{T} <: GeneralizedCache
116+
struct PrepContext{T} <: Context
118117
data::T
119118
end
120119

121-
struct UnknownContext <: Context end
122-
123120
## Context manipulation
124121

125122
struct Rewrap{C,T}
@@ -146,4 +143,4 @@ function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N}
146143
end
147144

148145
adapt_eltype(c::Constant, ::Type) = c
149-
adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(similar(unwrap(c), T))
146+
adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(recursive_similar(unwrap(c), T))

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ using DifferentiationInterface:
77
using SparseMatrixColorings
88
using Test
99

10+
test_differentiation(
11+
AutoSimpleFiniteDiff(),
12+
default_scenarios(; include_normal=false, include_cachified=true);
13+
logging=true,
14+
)
15+
1016
LOGGING = get(ENV, "CI", "false") == "false"
1117

1218
backends = [ #

DifferentiationInterfaceTest/src/scenarios/modify.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ end
147147
"""
148148
constantify(scen::Scenario)
149149
150-
Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional constant argument `a` by which the output is multiplied.
150+
Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional constant argument by which the output is multiplied.
151151
The output and result fields are updated accordingly.
152152
"""
153153
function constantify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
@@ -178,7 +178,8 @@ end
178178

179179
Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))")
180180

181-
function (sc::StoreInCache{:out})(x, y_cache)
181+
function (sc::StoreInCache{:out})(x, y_cache_tup)
182+
y_cache = y_cache_tup.cache
182183
y = sc.f(x)
183184
if y isa Number
184185
y_cache[1] = y
@@ -189,7 +190,8 @@ function (sc::StoreInCache{:out})(x, y_cache)
189190
end
190191
end
191192

192-
function (sc::StoreInCache{:in})(y, x, y_cache)
193+
function (sc::StoreInCache{:in})(y, x, y_cache_tup)
194+
y_cache = y_cache_tup.cache
193195
sc.f(y_cache, x)
194196
copyto!(y, y_cache)
195197
return nothing
@@ -198,16 +200,18 @@ end
198200
"""
199201
cachify(scen::Scenario)
200202
201-
Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional cache argument `a` to store the result before it is returned.
203+
Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional cache argument to store the result before it is returned.
204+
205+
If `tup=true` the cache is a tuple of arrays, otherwise just an array.
202206
"""
203207
function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
204208
(; f,) = scen
205209
@assert isempty(scen.contexts)
206210
cache_f = StoreInCache{pl_fun}(f)
207211
y_cache = if scen.y isa Number
208-
[myzero(scen.y)]
212+
(; cache=[myzero(scen.y)])
209213
else
210-
mysimilar(scen.y)
214+
(; cache=mysimilar(scen.y))
211215
end
212216
return Scenario{op,pl_op,pl_fun}(
213217
cache_f;

0 commit comments

Comments
 (0)