You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: DifferentiationInterface/docs/src/explanation/advanced.md
+15-8Lines changed: 15 additions & 8 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -10,22 +10,29 @@ However, the release v0.6 introduced the possibility of additional "context" arg
10
10
Contexts can be useful if you have a function `y = f(x, a, b, c, ...)` or `f!(y, x, a, b, c, ...)` and you want derivatives of `y` with respect to `x` only.
11
11
Another option would be creating a closure, but that is sometimes undesirable.
12
12
13
-
!!! warning
14
-
This feature is still experimental, and will likely not be supported by all backends.
15
-
At the moment, it only works with certain backends, among which ForwardDiff, Zygote and Enzyme.
16
-
17
13
### Types of contexts
18
14
19
15
Every context argument must be wrapped in a subtype of [`Context`](@ref) and come after the differentiated input `x`.
20
-
Right now, there is only one kind of context, namely [`Constant`](@ref), but we might add more.
21
-
Semantically, calling
16
+
Right now, there are two kinds of context: [`Constant`](@ref) and [`Cache`](@ref).
17
+
18
+
!!! warning
19
+
This feature is still experimental and will not be supported by all backends.
20
+
At the moment:
21
+
- `Constant` is supported by all backends except symbolic ones
22
+
- `Cache` is only supported by finite difference backends
23
+
24
+
Semantically, both of these calls compute the partial gradient of `f(x, c)` with respect to `x`, but they consider `c` differently:
22
25
23
26
```julia
24
27
gradient(f, backend, x, Constant(c))
28
+
gradient(f, backend, x, Cache(c))
25
29
```
26
30
27
-
computes the partial gradient of `f(x, c)` with respect to `x`, while keeping `c` constant.
28
-
Importantly, one can prepare an operator with an arbitrary value `c'` of the constant (subject to the usual restrictions on preparation).
31
+
In the first call, `c` is kept unchanged throughout the function evaluation.
32
+
In the second call, `c` can be mutated with values computed during the function.
33
+
34
+
Importantly, one can prepare an operator with an arbitrary value `c'` of the `Constant` (subject to the usual restrictions on preparation).
35
+
The values in a provided `Cache` never matter anyway.
Copy file name to clipboardExpand all lines: DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl
+2-2Lines changed: 2 additions & 2 deletions
Original file line number
Diff line number
Diff line change
@@ -21,13 +21,13 @@ function (f::DIT.NumToArr{JLMatrix{T}})(x::Number) where {T}
Copy file name to clipboardExpand all lines: DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl
+23-23Lines changed: 23 additions & 23 deletions
Original file line number
Diff line number
Diff line change
@@ -7,45 +7,45 @@ using Random: AbstractRNG, default_rng
7
7
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
8
8
using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector
9
9
10
-
mySArray(f::Function) = f
11
-
mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(SVector{6,T})
12
-
mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(SMatrix{2,3,T,6})
13
-
mySArray(f::DIT.MultiplyByConstant) = f
14
-
mySArray(f::DIT.WritableClosure) = f
10
+
mystatic(f::Function) = f
11
+
mystatic(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(SVector{6,T})
12
+
mystatic(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(SMatrix{2,3,T,6})
13
+
mystatic(f::DIT.FunctionModifier) = f
15
14
16
-
mySArray(x::Number) = x
17
-
myMArray(x::Number) = x
15
+
mystatic(x::Number) = x
16
+
mymutablestatic(x::Number) = x
18
17
19
-
mySArray(x::AbstractVector{T}) where {T} =convert(SVector{length(x),T}, x)
20
-
myMArray(x::AbstractVector{T}) where {T} =convert(MVector{length(x),T}, x)
18
+
mystatic(x::AbstractVector{T}) where {T} =convert(SVector{length(x),T}, x)
19
+
mymutablestatic(x::AbstractVector{T}) where {T} =convert(MVector{length(x),T}, x)
0 commit comments