-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathcontext.jl
More file actions
146 lines (104 loc) · 3.39 KB
/
context.jl
File metadata and controls
146 lines (104 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Callable that stores trailing arguments for `f`: `FixTail(f, b, c)(a)` evaluates
# `f(a, b, c)`, i.e. the stored arguments always come after the ones passed at call time.
# The `Vararg{Any,N}` signatures force specialization on the number of arguments.
struct FixTail{F,A<:Tuple}
    f::F
    tail_args::A
    function FixTail(f::F, tail_args::Vararg{Any,N}) where {F,N}
        Targs = typeof(tail_args)
        return new{F,Targs}(f, tail_args)
    end
end

function (wrapper::FixTail)(leading_args::Vararg{Any,N}) where {N}
    fn, trailing = wrapper.f, wrapper.tail_args
    return fn(leading_args..., trailing...)
end
"""
    Context

Abstract supertype for additional context arguments, which can be passed to differentiation operators after the active input `x` but are not differentiated.

# Subtypes

- [`Constant`](@ref)
- [`Cache`](@ref)
"""
abstract type Context end

# Internal abstract supertype for the constant-like contexts: `Constant`, `FunctionContext`
# and `BackendContext` subtype it, whereas `Cache` and `PrepContext` subtype `Context`
# directly.
abstract type GeneralizedConstant <: Context end

# Return the payload wrapped by a context. Every context type in this file stores it in a
# field named `data`; a subtype without that field would need its own method.
unwrap(c::Context) = c.data

# Contexts compare by payload only, regardless of the wrapper type, so e.g.
# `Constant(1) == Cache(1)` holds.
Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2)

# `isequal` and `hash` must agree for `Dict`/`Set` lookups. Without these methods,
# `isequal` falls back to the payload-based `==` above while `hash` falls back to the
# default identity-based hash, so two contexts wrapping equal (but distinct) arrays were
# `isequal` yet hashed differently. Both are defined from the payload, consistently with `==`.
Base.isequal(c1::Context, c2::Context) = isequal(unwrap(c1), unwrap(c2))
Base.hash(c::Context, h::UInt) = hash(unwrap(c), hash(:Context, h))
## Public contexts
"""
    Constant

Concrete type of [`Context`](@ref) argument which is kept constant during differentiation.

Note that an operator can be prepared with an arbitrary value of the constant.
However, same-point preparation must occur with the exact value that will be reused later.

!!! warning
    Some backends require any `Constant` context to be a `Number` or an `AbstractArray`.

# Example

```jldoctest
julia> using DifferentiationInterface

julia> import ForwardDiff

julia> f(x, c) = c * sum(abs2, x);

julia> gradient(f, AutoForwardDiff(), [1.0, 2.0], Constant(10))
2-element Vector{Float64}:
 20.0
 40.0

julia> gradient(f, AutoForwardDiff(), [1.0, 2.0], Constant(100))
2-element Vector{Float64}:
 200.0
 400.0
```
"""
struct Constant{T} <: GeneralizedConstant
    data::T
end

# Rebuilds a `Constant` around a new payload. `maker` returns this function for any
# `Constant`, and `Rewrap` collects these functions to re-annotate plain values later.
constant_maker(c) = Constant(c)
maker(::Constant) = constant_maker
"""
    Cache

Concrete type of [`Context`](@ref) argument which can be mutated with active values during differentiation.

The initial values present inside the cache do not matter.

For some backends, preparation allocates the required memory for `Cache` contexts with the right element type, similar to [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl).

!!! warning
    Some backends require any `Cache` context to be an `AbstractArray`, others accept nested (named) tuples of `AbstractArray`s.

# Example

```jldoctest
julia> using DifferentiationInterface

julia> import ForwardDiff

julia> f(x, c) = sum(copyto!(c, x));

julia> prep = prepare_gradient(f, AutoForwardDiff(), [1.0, 2.0], Cache(zeros(2)));

julia> gradient(f, prep, AutoForwardDiff(), [3.0, 4.0], Cache(zeros(2)))
2-element Vector{Float64}:
 1.0
 1.0
```
"""
struct Cache{T} <: Context
    data::T
end

# Rebuilds a `Cache` around a new payload. `maker` returns this function for any `Cache`,
# and `Rewrap` collects these functions to re-annotate plain values later.
cache_maker(c) = Cache(c)
maker(::Cache) = cache_maker
## Internal contexts for passing stuff around
# Internal context wrapping an arbitrary payload in `data`. Like `Constant`, it subtypes
# `GeneralizedConstant`.
# NOTE(review): presumably used to pass a function through operators as a context — confirm
# against callers.
struct FunctionContext{T} <: GeneralizedConstant
data::T
end
# Internal context wrapping an arbitrary payload in `data`. Like `Constant`, it subtypes
# `GeneralizedConstant`.
# NOTE(review): presumably used to pass a backend object through operators — confirm
# against callers.
struct BackendContext{T} <: GeneralizedConstant
data::T
end
# Internal context wrapping an arbitrary payload in `data`. Unlike `FunctionContext` and
# `BackendContext`, it subtypes `Context` directly rather than `GeneralizedConstant`.
# NOTE(review): presumably carries a preparation object — confirm against callers.
struct PrepContext{T} <: Context
data::T
end
## Context manipulation
# Records, for each context it is built from, the function that re-creates a context of
# the same kind (obtained via `maker`). Calling the resulting object on plain values
# re-annotates them positionally, e.g. `Rewrap(Constant(a), Cache(b))(x, y)` gives
# `(Constant(x), Cache(y))`. `C` is the number of contexts, `T` the tuple of makers.
struct Rewrap{C,T}
    context_makers::T
    function Rewrap(contexts::Vararg{Context,C}) where {C}
        makers = map(maker, contexts)
        return new{C,typeof(makers)}(makers)
    end
end

# No contexts were recorded: there is nothing to re-annotate.
(::Rewrap{0})() = ()

function (rewrap::Rewrap{C,T})(unannotated_contexts::Vararg{Any,C}) where {C,T}
    # Pair each stored maker with the value at the same position and apply it.
    return map((make, value) -> make(value), rewrap.context_makers, unannotated_contexts)
end
# With no contexts, the function is returned unchanged.
with_contexts(f) = f

# Otherwise strip the context wrappers and bind the payloads as trailing arguments, so the
# returned callable evaluates `f(args..., unwrap(contexts[1]), unwrap(contexts[2]), ...)`.
function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N}
    return FixTail(f, map(unwrap, contexts)...)
end
# A `Constant` is returned as-is, whatever the requested element type.
adapt_eltype(c::Constant, ::Type) = c

# A `Cache` is rebuilt around `recursive_similar(unwrap(c), T)`. That helper is defined
# elsewhere; presumably it produces a `similar`-like container with element type `T` —
# confirm.
function adapt_eltype(c::Cache, ::Type{T}) where {T}
    new_storage = recursive_similar(unwrap(c), T)
    return Cache(new_storage)
end