Skip to content

Commit d6ead1b

Browse files
committed
Add finer tests and comments
1 parent 2a8a179 commit d6ead1b

2 files changed

Lines changed: 48 additions & 18 deletions

File tree

  • DifferentiationInterfaceTest/src/scenarios
  • DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ end
2828
function get_f_and_df_prepared!(
2929
df, f::F, ::AutoEnzyme{M,<:AnyDuplicated}, ::Val{B}
3030
) where {F,M,B}
31+
#=
32+
It is not obvious why we don't need a `make_zero` here, in the case of mutable constant data in `f`.
33+
- In forward mode, `df` is never incremented if `f` is not mutated, so it remains equal to its initial value of `0`.
34+
- In reverse mode, `df` gets incremented but it does not influence the input cotangent `dx`.
35+
=#
3136
if B == 1
3237
return Duplicated(f, df)
3338
else
@@ -117,6 +122,11 @@ end
117122
function _translate_prepared!(
118123
dc, c_wrapped::Union{DI.ConstantOrCache,DI.FunctionContext}, ::Val{B}
119124
) where {B}
125+
#=
126+
It is not obvious why we don't need a `make_zero` here, in the case of mutable constant contexts.
127+
- In forward mode, `dc` is never incremented because `c` is not mutated, so it remains equal to its initial value of `0`.
128+
- In reverse mode, `dc` gets incremented but it does not influence the input cotangent `dx`.
129+
=#
120130
c = DI.unwrap(c_wrapped)
121131
if isnothing(dc)
122132
return Const(c)

DifferentiationInterfaceTest/src/scenarios/modify.jl

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -104,26 +104,31 @@ struct WritableClosure{pl_fun,F,X,Y} <: FunctionModifier
104104
f::F
105105
x_buffer::Vector{X}
106106
y_buffer::Vector{Y}
107+
a::Float64
108+
b::Vector{Float64}
107109
end
108110

109111
function WritableClosure{pl_fun}(
110-
f::F, x_buffer::Vector{X}, y_buffer::Vector{Y}
112+
f::F, x_buffer::Vector{X}, y_buffer::Vector{Y}, a, b
111113
) where {pl_fun,F,X,Y}
112-
return WritableClosure{pl_fun,F,X,Y}(f, x_buffer, y_buffer)
114+
return WritableClosure{pl_fun,F,X,Y}(f, x_buffer, y_buffer, a, b)
113115
end
114116

115117
Base.show(io::IO, f::WritableClosure) = print(io, "WritableClosure($(f.f))")
116118

117119
function (mc::WritableClosure{:out})(x)
118-
mc.x_buffer[1] = copy(x)
119-
mc.y_buffer[1] = mc.f(x)
120-
return copy(mc.y_buffer[1])
120+
(; f, x_buffer, y_buffer, a, b) = mc
121+
x_buffer[1] = copy(x)
122+
y_buffer[1] = (a + only(b)) * f(x)
123+
return copy(y_buffer[1])
121124
end
122125

123126
function (mc::WritableClosure{:in})(y, x)
124-
mc.x_buffer[1] = copy(x)
125-
mc.f(mc.y_buffer[1], mc.x_buffer[1])
126-
copyto!(y, mc.y_buffer[1])
127+
(; f, x_buffer, y_buffer, a, b) = mc
128+
x_buffer[1] = copy(x)
129+
f(y_buffer[1], x_buffer[1])
130+
y_buffer[1] .*= (a + only(b))
131+
copyto!(y, y_buffer[1])
127132
return nothing
128133
end
129134

@@ -132,13 +137,25 @@ end
132137
133138
Return a new `Scenario` identical to `scen` except for the function `f` which is made to close over differentiable data.
134139
"""
135-
function closurify(scen::Scenario)
140+
function closurify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
136141
(; f, x, y) = scen
137142
@assert isempty(scen.contexts)
138143
x_buffer = [zero(x)]
139144
y_buffer = [zero(y)]
140-
closure_f = WritableClosure{function_place(scen)}(f, x_buffer, y_buffer)
141-
return change_function(scen, closure_f; keep_smaller=false)
145+
a = 3.0
146+
b = [4.0]
147+
closure_f = WritableClosure{pl_fun}(f, x_buffer, y_buffer, a, b)
148+
return Scenario{op,pl_op,pl_fun}(
149+
closure_f;
150+
x = scen.x,
151+
y = mymultiply(scen.y, a + only(b)),
152+
tang = scen.tang,
153+
contexts = scen.contexts,
154+
res1 = mymultiply(scen.res1, a + only(b)),
155+
res2 = mymultiply(scen.res2, a + only(b)),
156+
smaller = nothing,
157+
name = isnothing(scen.name) ? nothing : scen.name * " [closurified]",
158+
)
142159
end
143160

144161
struct MultiplyByConstant{pl_fun,F} <: FunctionModifier
@@ -267,7 +284,8 @@ end
267284

268285
function (sc::MultiplyByConstantAndStoreInCache{:out})(x, constantorcache)
269286
(; constant, cache) = constantorcache
270-
y = constant * sc.f(x)
287+
(; a, b) = constant
288+
y = (a + only(b)) * sc.f(x)
271289
if eltype(y) == eltype(cache)
272290
newcache = cache
273291
else
@@ -285,14 +303,15 @@ end
285303

286304
function (sc::MultiplyByConstantAndStoreInCache{:in})(y, x, constantorcache)
287305
(; constant, cache) = constantorcache
306+
(; a, b) = constant
288307
if eltype(y) == eltype(cache)
289308
newcache = cache
290309
else
291310
# poor man's PreallocationTools
292311
newcache = similar(cache, eltype(y))
293312
end
294313
sc.f(newcache, x)
295-
newcache .*= constant
314+
newcache .*= (a + only(b))
296315
copyto!(y, newcache)
297316
return nothing
298317
end
@@ -307,19 +326,20 @@ function constantorcachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_f
307326
@assert isempty(scen.contexts)
308327
constantorcache_f = MultiplyByConstantAndStoreInCache{pl_fun}(f)
309328
a = 3.0
329+
b = [4.0]
310330
constantorcache = if scen.y isa Number
311-
(; cache=[myzero(scen.y)], constant=a)
331+
(; cache=[myzero(scen.y)], constant=(; a, b))
312332
else
313-
(; cache=mysimilar(scen.y), constant=a)
333+
(; cache=mysimilar(scen.y), constant=(; a, b))
314334
end
315335
return Scenario{op,pl_op,pl_fun}(
316336
constantorcache_f;
317337
x=scen.x,
318-
y=mymultiply(scen.y, a),
338+
y=mymultiply(scen.y, a + only(b)),
319339
tang=scen.tang,
320340
contexts=(ConstantOrCache(constantorcache),),
321-
res1=mymultiply(scen.res1, a),
322-
res2=mymultiply(scen.res2, a),
341+
res1=mymultiply(scen.res1, a + only(b)),
342+
res2=mymultiply(scen.res2, a + only(b)),
323343
smaller=isnothing(scen.smaller) ? nothing : constantorcachify(scen.smaller),
324344
name=isnothing(scen.name) ? nothing : scen.name * " [constantorcachified]",
325345
)

0 commit comments

Comments
 (0)