Skip to content

Commit 495d988

Browse files
authored
Start implementing Cache contexts (#587)
* Start implementing caches * First Cache implementation and tests * Function modifiers * Scenario conversion
1 parent d4b17c1 commit 495d988

14 files changed

Lines changed: 148 additions & 46 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.14"
4+
version = "0.6.15"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ DifferentiationInterface
1313
```@docs
1414
Context
1515
Constant
16+
Cache
1617
```
1718

1819
## First order

DifferentiationInterface/docs/src/explanation/advanced.md

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,29 @@ However, the release v0.6 introduced the possibility of additional "context" arg
1010
Contexts can be useful if you have a function `y = f(x, a, b, c, ...)` or `f!(y, x, a, b, c, ...)` and you want derivatives of `y` with respect to `x` only.
1111
Another option would be creating a closure, but that is sometimes undesirable.
1212

13-
!!! warning
14-
This feature is still experimental, and will likely not be supported by all backends.
15-
At the moment, it only works with certain backends, among which ForwardDiff, Zygote and Enzyme.
16-
1713
### Types of contexts
1814

1915
Every context argument must be wrapped in a subtype of [`Context`](@ref) and come after the differentiated input `x`.
20-
Right now, there is only one kind of context, namely [`Constant`](@ref), but we might add more.
21-
Semantically, calling
16+
Right now, there are two kinds of context: [`Constant`](@ref) and [`Cache`](@ref).
17+
18+
!!! warning
19+
This feature is still experimental and will not be supported by all backends.
20+
At the moment:
21+
- `Constant` is supported by all backends except symbolic ones
22+
- `Cache` is only supported by finite difference backends
23+
24+
Semantically, both of these calls compute the partial gradient of `f(x, c)` with respect to `x`, but they consider `c` differently:
2225

2326
```julia
2427
gradient(f, backend, x, Constant(c))
28+
gradient(f, backend, x, Cache(c))
2529
```
2630

27-
computes the partial gradient of `f(x, c)` with respect to `x`, while keeping `c` constant.
28-
Importantly, one can prepare an operator with an arbitrary value `c'` of the constant (subject to the usual restrictions on preparation).
31+
In the first call, `c` is kept unchanged throughout the function evaluation.
32+
In the second call, `c` can be mutated with values computed during the function evaluation.
33+
34+
Importantly, one can prepare an operator with an arbitrary value `c'` of the `Constant` (subject to the usual restrictions on preparation).
35+
The initial values stored in a provided `Cache` never matter.
2936

3037
## Sparsity
3138

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ include("misc/zero_backends.jl")
6565

6666
## Exported
6767

68-
export Context, Constant
68+
export Context, Constant, Cache
6969
export SecondOrder
7070

7171
export value_and_pushforward!, value_and_pushforward

DifferentiationInterface/src/utils/context.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Abstract supertype for additional context arguments, which can be passed to diff
1515
# See also
1616
1717
- [`Constant`](@ref)
18+
- [`Cache`](@ref)
1819
"""
1920
abstract type Context end
2021

@@ -58,6 +59,25 @@ end
5859

5960
Base.convert(::Type{Constant{T}}, x) where {T} = Constant(convert(T, x))
6061

62+
"""
63+
Cache
64+
65+
Concrete type of [`Context`](@ref) argument which can be mutated with active values during differentiation.
66+
67+
The initial values present inside the cache do not matter.
68+
"""
69+
struct Cache{T} <: Context
70+
data::T
71+
end
72+
73+
unwrap(c::Cache) = c.data
74+
75+
function Base.convert(::Type{Cache{T}}, x::Cache) where {T}
76+
return Cache(convert(T, x.data))
77+
end
78+
79+
Base.convert(::Type{Cache{T}}, x) where {T} = Cache(convert(T, x))
80+
6181
struct Rewrap{C,T}
6282
function Rewrap(contexts::Vararg{Context,C}) where {C}
6383
T = typeof(contexts)

DifferentiationInterface/test/Back/FiniteDiff/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ end
1414

1515
test_differentiation(
1616
AutoFiniteDiff(),
17-
default_scenarios(; include_constantified=true);
17+
default_scenarios(; include_constantified=true, include_cachified=true);
1818
excluded=[:second_derivative, :hvp],
1919
logging=LOGGING,
2020
);

DifferentiationInterface/test/Back/FiniteDifferences/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ end
1414

1515
test_differentiation(
1616
AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)),
17-
default_scenarios(; include_constantified=true);
17+
default_scenarios(; include_constantified=true, include_cachified=true);
1818
excluded=SECOND_ORDER,
1919
logging=LOGGING,
2020
);

DifferentiationInterfaceTest/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterfaceTest"
22
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.8.3"
4+
version = "0.8.4"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ function (f::DIT.NumToArr{JLMatrix{T}})(x::Number) where {T}
2121
return sin.(x .* a)
2222
end
2323

24-
myjl(f::DIT.MultiplyByConstant) = f
25-
myjl(f::DIT.WritableClosure) = f
24+
myjl(f::DIT.FunctionModifier) = f
2625

2726
myjl(x::Number) = x
2827
myjl(x::AbstractArray) = jl(x)
2928
myjl(x::Tuple) = map(myjl, x)
3029
myjl(x::DI.Constant) = DI.Constant(myjl(DI.unwrap(x)))
30+
myjl(x::DI.Cache) = DI.Cache(myjl(DI.unwrap(x)))
3131
myjl(::Nothing) = nothing
3232

3333
function myjl(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,45 +7,45 @@ using Random: AbstractRNG, default_rng
77
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
88
using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector
99

10-
mySArray(f::Function) = f
11-
mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(SVector{6,T})
12-
mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(SMatrix{2,3,T,6})
13-
mySArray(f::DIT.MultiplyByConstant) = f
14-
mySArray(f::DIT.WritableClosure) = f
10+
mystatic(f::Function) = f
11+
mystatic(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(SVector{6,T})
12+
mystatic(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(SMatrix{2,3,T,6})
13+
mystatic(f::DIT.FunctionModifier) = f
1514

16-
mySArray(x::Number) = x
17-
myMArray(x::Number) = x
15+
mystatic(x::Number) = x
16+
mymutablestatic(x::Number) = x
1817

19-
mySArray(x::AbstractVector{T}) where {T} = convert(SVector{length(x),T}, x)
20-
myMArray(x::AbstractVector{T}) where {T} = convert(MVector{length(x),T}, x)
18+
mystatic(x::AbstractVector{T}) where {T} = convert(SVector{length(x),T}, x)
19+
mymutablestatic(x::AbstractVector{T}) where {T} = convert(MVector{length(x),T}, x)
2120

22-
function mySArray(x::AbstractMatrix{T}) where {T}
21+
function mystatic(x::AbstractMatrix{T}) where {T}
2322
return convert(SMatrix{size(x, 1),size(x, 2),T,length(x)}, x)
2423
end
25-
function myMArray(x::AbstractMatrix{T}) where {T}
24+
function mymutablestatic(x::AbstractMatrix{T}) where {T}
2625
return convert(MMatrix{size(x, 1),size(x, 2),T,length(x)}, x)
2726
end
2827

29-
mySArray(x::Tuple) = map(mySArray, x)
30-
mySArray(x::DI.Constant) = DI.Constant(mySArray(DI.unwrap(x)))
31-
mySArray(::Nothing) = nothing
28+
mystatic(x::Tuple) = map(mystatic, x)
29+
mystatic(x::DI.Constant) = DI.Constant(mystatic(DI.unwrap(x)))
30+
mystatic(x::DI.Cache) = DI.Cache(mymutablestatic(DI.unwrap(x)))
31+
mystatic(::Nothing) = nothing
3232

33-
function mySArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
33+
function mystatic(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
3434
(; f, x, y, tang, contexts, res1, res2) = scen
3535
return Scenario{op,pl_op,pl_fun}(
36-
mySArray(f);
37-
x=mySArray(x),
38-
y=pl_fun == :in ? myMArray(y) : mySArray(y),
39-
tang=mySArray(tang),
40-
contexts=mySArray(contexts),
41-
res1=mySArray(res1),
42-
res2=mySArray(res2),
36+
mystatic(f);
37+
x=mystatic(x),
38+
y=pl_fun == :in ? mymutablestatic(y) : mystatic(y),
39+
tang=mystatic(tang),
40+
contexts=mystatic(contexts),
41+
res1=mystatic(res1),
42+
res2=mystatic(res2),
4343
)
4444
end
4545

4646
function DIT.static_scenarios(args...; kwargs...)
4747
scens = DIT.default_scenarios(args...; kwargs...)
48-
return mySArray.(scens)
48+
return mystatic.(scens)
4949
end
5050

5151
end

0 commit comments

Comments
 (0)