@@ -104,26 +104,31 @@ struct WritableClosure{pl_fun,F,X,Y} <: FunctionModifier
104104 f:: F
105105 x_buffer:: Vector{X}
106106 y_buffer:: Vector{Y}
107+ a:: Float64
108+ b:: Vector{Float64}
107109end
108110
109111function WritableClosure {pl_fun} (
110- f:: F , x_buffer:: Vector{X} , y_buffer:: Vector{Y}
112+ f:: F , x_buffer:: Vector{X} , y_buffer:: Vector{Y} , a, b
111113) where {pl_fun,F,X,Y}
112- return WritableClosure {pl_fun,F,X,Y} (f, x_buffer, y_buffer)
114+ return WritableClosure {pl_fun,F,X,Y} (f, x_buffer, y_buffer, a, b )
113115end
114116
115117Base. show (io:: IO , f:: WritableClosure ) = print (io, " WritableClosure($(f. f) )" )
116118
117119function (mc:: WritableClosure{:out} )(x)
118- mc. x_buffer[1 ] = copy (x)
119- mc. y_buffer[1 ] = mc. f (x)
120- return copy (mc. y_buffer[1 ])
120+ (; f, x_buffer, y_buffer, a, b) = mc
121+ x_buffer[1 ] = copy (x)
122+ y_buffer[1 ] = (a + only (b)) * f (x)
123+ return copy (y_buffer[1 ])
121124end
122125
123126function (mc:: WritableClosure{:in} )(y, x)
124- mc. x_buffer[1 ] = copy (x)
125- mc. f (mc. y_buffer[1 ], mc. x_buffer[1 ])
126- copyto! (y, mc. y_buffer[1 ])
127+ (; f, x_buffer, y_buffer, a, b) = mc
128+ x_buffer[1 ] = copy (x)
129+ f (y_buffer[1 ], x_buffer[1 ])
130+ y_buffer[1 ] .*= (a + only (b))
131+ copyto! (y, y_buffer[1 ])
127132 return nothing
128133end
129134
@@ -132,13 +137,25 @@ end
132137
133138Return a new `Scenario` identical to `scen` except for the function `f` which is made to close over differentiable data.
134139"""
135- function closurify (scen:: Scenario )
140+ function closurify (scen:: Scenario{op,pl_op,pl_fun} ) where {op,pl_op,pl_fun}
136141 (; f, x, y) = scen
137142 @assert isempty (scen. contexts)
138143 x_buffer = [zero (x)]
139144 y_buffer = [zero (y)]
140- closure_f = WritableClosure {function_place(scen)} (f, x_buffer, y_buffer)
141- return change_function (scen, closure_f; keep_smaller= false )
145+ a = 3.0
146+ b = [4.0 ]
147+ closure_f = WritableClosure {pl_fun} (f, x_buffer, y_buffer, a, b)
148+ return Scenario {op,pl_op,pl_fun} (
149+ closure_f;
150+ x = scen. x,
151+ y = mymultiply (scen. y, a + only (b)),
152+ tang = scen. tang,
153+ contexts = scen. contexts,
154+ res1 = mymultiply (scen. res1, a + only (b)),
155+ res2 = mymultiply (scen. res2, a + only (b)),
156+ smaller = nothing ,
157+ name = isnothing (scen. name) ? nothing : scen. name * " [closurified]" ,
158+ )
142159end
143160
144161struct MultiplyByConstant{pl_fun,F} <: FunctionModifier
267284
268285function (sc:: MultiplyByConstantAndStoreInCache{:out} )(x, constantorcache)
269286 (; constant, cache) = constantorcache
270- y = constant * sc. f (x)
287+ (; a, b) = constant
288+ y = (a + only (b)) * sc. f (x)
271289 if eltype (y) == eltype (cache)
272290 newcache = cache
273291 else
@@ -285,14 +303,15 @@ end
285303
286304function (sc:: MultiplyByConstantAndStoreInCache{:in} )(y, x, constantorcache)
287305 (; constant, cache) = constantorcache
306+ (; a, b) = constant
288307 if eltype (y) == eltype (cache)
289308 newcache = cache
290309 else
291310 # poor man's PreallocationTools
292311 newcache = similar (cache, eltype (y))
293312 end
294313 sc. f (newcache, x)
295- newcache .*= constant
314+ newcache .*= (a + only (b))
296315 copyto! (y, newcache)
297316 return nothing
298317end
@@ -307,19 +326,20 @@ function constantorcachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_f
307326 @assert isempty (scen. contexts)
308327 constantorcache_f = MultiplyByConstantAndStoreInCache {pl_fun} (f)
309328 a = 3.0
329+ b = [4.0 ]
310330 constantorcache = if scen. y isa Number
311- (; cache= [myzero (scen. y)], constant= a )
331+ (; cache= [myzero (scen. y)], constant= (; a, b) )
312332 else
313- (; cache= mysimilar (scen. y), constant= a )
333+ (; cache= mysimilar (scen. y), constant= (; a, b) )
314334 end
315335 return Scenario {op,pl_op,pl_fun} (
316336 constantorcache_f;
317337 x= scen. x,
318- y= mymultiply (scen. y, a),
338+ y= mymultiply (scen. y, a + only (b) ),
319339 tang= scen. tang,
320340 contexts= (ConstantOrCache (constantorcache),),
321- res1= mymultiply (scen. res1, a),
322- res2= mymultiply (scen. res2, a),
341+ res1= mymultiply (scen. res1, a + only (b) ),
342+ res2= mymultiply (scen. res2, a + only (b) ),
323343 smaller= isnothing (scen. smaller) ? nothing : constantorcachify (scen. smaller),
324344 name= isnothing (scen. name) ? nothing : scen. name * " [constantorcachified]" ,
325345 )
0 commit comments