Skip to content

Commit 401096d

Browse files
committed
Fix tangents for prep same point
1 parent 92f337b commit 401096d

4 files changed

Lines changed: 34 additions & 35 deletions

File tree

DifferentiationInterfaceTest/src/scenarios/modify.jl

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,26 @@ abstract type FunctionModifier end
66
Return a new `Scenario` identical to `scen` except for the first- and second-order results which are set to zero.
77
"""
88
function Base.zero(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
9+
zero_res1 = if op in (:pushforward, :pullback)
10+
map(zero, scen.res1)
11+
else
12+
zero(scen.res1)
13+
end
14+
zero_res2 = if isnothing(scen.res2)
15+
nothing
16+
elseif op == :hvp
17+
map(zero, scen.res2)
18+
else
19+
zero(scen.res2)
20+
end
921
return Scenario{op,pl_op,pl_fun}(;
1022
f=scen.f,
1123
x=scen.x,
1224
y=scen.y,
1325
t=scen.t,
1426
contexts=scen.contexts,
15-
res1=myzero(scen.res1),
16-
res2=myzero(scen.res2),
27+
res1=zero_res1,
28+
res2=zero_res2,
1729
prep_args=scen.prep_args,
1830
name=isnothing(scen.name) ? nothing : scen.name * " [zero]",
1931
)
@@ -239,15 +251,15 @@ function cachify(scen::Scenario{op,pl_op,pl_fun}; use_tuples) where {op,pl_op,pl
239251
cache_f = StoreInCache{pl_fun}(f)
240252
if use_tuples
241253
y_cache = if scen.y isa Number
242-
(; useful_cache=([myzero(scen.y)],), useless_cache=[myzero(scen.y)])
254+
(; useful_cache=([zero(scen.y)],), useless_cache=[zero(scen.y)])
243255
else
244-
(; useful_cache=(mysimilar(scen.y),), useless_cache=mysimilar(scen.y))
256+
(; useful_cache=(similar(scen.y),), useless_cache=similar(scen.y))
245257
end
246258
else
247259
y_cache = if scen.y isa Number
248-
[myzero(scen.y)]
260+
[zero(scen.y)]
249261
else
250-
mysimilar(scen.y)
262+
similar(scen.y)
251263
end
252264
end
253265
return Scenario{op,pl_op,pl_fun}(;
@@ -321,14 +333,14 @@ function constantorcachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_f
321333
a = 3.0
322334
b = [4.0]
323335
constantorcache = if scen.y isa Number
324-
(; cache=[myzero(scen.y)], constant=(; a, b))
336+
(; cache=[zero(scen.y)], constant=(; a, b))
325337
else
326-
(; cache=mysimilar(scen.y), constant=(; a, b))
338+
(; cache=similar(scen.y), constant=(; a, b))
327339
end
328340
prep_constantorcache = if scen.y isa Number
329-
(; cache=[myzero(scen.y)], constant=(; a=2a, b=3b))
341+
(; cache=[zero(scen.y)], constant=(; a=2a, b=3b))
330342
else
331-
(; cache=mysimilar(scen.y), constant=(; a=2a, b=3b))
343+
(; cache=similar(scen.y), constant=(; a=2a, b=3b))
332344
end
333345
return Scenario{op,pl_op,pl_fun}(;
334346
f=constantorcache_f,

DifferentiationInterfaceTest/src/scenarios/scenario.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ struct Scenario{op,pl_op,pl_fun,F,X,Y,T<:Union{Nothing,NTuple},C<:Tuple,R1,R2,P<
6060
end
6161
end
6262

63-
function myzero_contexts(contexts...)
63+
function zero_contexts(contexts...)
6464
rewrap = Rewrap(contexts...)
65-
return rewrap(map(myzero unwrap, contexts)...)
65+
return rewrap(map(zero unwrap, contexts)...)
6666
end
6767

6868
function Scenario{op,pl_op}(
@@ -71,7 +71,7 @@ function Scenario{op,pl_op}(
7171
contexts::Vararg{Context};
7272
res1=nothing,
7373
res2=nothing,
74-
prep_args=(; x=myzero(x), contexts=myzero_contexts(contexts...)),
74+
prep_args=(; x=zero(x), contexts=zero_contexts(contexts...)),
7575
name=nothing,
7676
) where {op,pl_op}
7777
y = f(x, map(unwrap, contexts)...)
@@ -87,7 +87,7 @@ function Scenario{op,pl_op}(
8787
contexts::Vararg{Context};
8888
res1=nothing,
8989
res2=nothing,
90-
prep_args=(; y=myzero(y), x=myzero(x), contexts=myzero_contexts(contexts...)),
90+
prep_args=(; y=zero(y), x=zero(x), contexts=zero_contexts(contexts...)),
9191
name=nothing,
9292
) where {op,pl_op}
9393
f(y, x, map(unwrap, contexts)...)
@@ -103,7 +103,7 @@ function Scenario{op,pl_op}(
103103
contexts::Vararg{Context};
104104
res1=nothing,
105105
res2=nothing,
106-
prep_args=(; x=myzero(x), t=map(myzero, t), contexts=myzero_contexts(contexts...)),
106+
prep_args=(; x=zero(x), t=map(zero, t), contexts=zero_contexts(contexts...)),
107107
name=nothing,
108108
) where {op,pl_op}
109109
y = f(x, map(unwrap, contexts)...)
@@ -118,9 +118,7 @@ function Scenario{op,pl_op}(
118118
contexts::Vararg{Context};
119119
res1=nothing,
120120
res2=nothing,
121-
prep_args=(;
122-
y=myzero(y), x=myzero(x), t=map(myzero, t), contexts=myzero_contexts(contexts...)
123-
),
121+
prep_args=(; y=zero(y), x=zero(x), t=map(zero, t), contexts=zero_contexts(contexts...)),
124122
name=nothing,
125123
) where {op,pl_op}
126124
f(y, x, map(unwrap, contexts)...)

DifferentiationInterfaceTest/src/tests/correctness_eval.jl

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -470,11 +470,10 @@ for op in ALL_OPS
470470
prep_args.contexts...;
471471
strict=Val(true),
472472
)
473-
prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...)
473+
prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...)
474474
if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x))
475475
prep = $prep_op!(f, prep, ba, x, t, contexts...)
476476
prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...)
477-
prep_same = $prep_op_same(f, ba, x, t, contexts...)
478477
end
479478
[(), (prep,), (prepstrict,), (prep_same,)]
480479
end
@@ -527,11 +526,10 @@ for op in ALL_OPS
527526
prep_args.contexts...;
528527
strict=Val(true),
529528
)
530-
prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...)
529+
prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...)
531530
if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x))
532531
prep = $prep_op!(f, prep, ba, x, t, contexts...)
533532
prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...)
534-
prep_same = $prep_op_same(f, ba, x, t, contexts...)
535533
end
536534
[(), (prep,), (prepstrict,), (prep_same,)]
537535
end
@@ -603,14 +601,13 @@ for op in ALL_OPS
603601
prep_args.contexts...;
604602
strict=Val(true),
605603
)
606-
prep_same = $prep_op_same(f, prep_args.y, ba, x, prep_args.t, contexts...)
604+
prep_same = $prep_op_same(f, y, ba, x, map(zero, t), contexts...)
607605
if reprepare &&
608606
has_size(x) &&
609607
has_size(y) &&
610608
(size(x) != size(prep_args.x) || size(y) != prep_args.y)
611609
prep = $prep_op!(f, y, prep, ba, x, t, contexts...)
612610
prepstrict = $prep_op!(f, y, prepstrict, ba, x, t, contexts...)
613-
prep_same = $prep_op_same(f, y, ba, x, t, contexts...)
614611
end
615612
[(), (prep,), (prepstrict,), (prep_same,)]
616613
end
@@ -678,14 +675,13 @@ for op in ALL_OPS
678675
prep_args.contexts...;
679676
strict=Val(true),
680677
)
681-
prep_same = $prep_op_same(f, prep_args.y, ba, x, prep_args.t, contexts...)
678+
prep_same = $prep_op_same(f, y, ba, x, map(zero, t), contexts...)
682679
if reprepare &&
683680
has_size(x) &&
684681
has_size(y) &&
685682
(size(x) != size(prep_args.x) || size(y) != prep_args.y)
686683
prep = $prep_op!(f, y, prep, ba, x, t, contexts...)
687684
prepstrict = $prep_op!(f, y, prepstrict, ba, x, t, contexts...)
688-
prep_same = $prep_op_same(f, y, ba, x, t, contexts...)
689685
end
690686
[(), (prep,), (prepstrict,), (prep_same,)]
691687
end
@@ -757,11 +753,10 @@ for op in ALL_OPS
757753
prep_args.contexts...;
758754
strict=Val(true),
759755
)
760-
prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...)
756+
prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...)
761757
if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x))
762758
prep = $prep_op!(f, prep, ba, x, t, contexts...)
763759
prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...)
764-
prep_same = $prep_op_same(f, ba, x, t, contexts...)
765760
end
766761
[(), (prep,), (prepstrict,), (prep_same,)]
767762
end
@@ -814,11 +809,10 @@ for op in ALL_OPS
814809
prep_args.contexts...;
815810
strict=Val(true),
816811
)
817-
prep_same = $prep_op_same(f, ba, x, prep_args.t, contexts...)
812+
prep_same = $prep_op_same(f, ba, x, map(zero, t), contexts...)
818813
if reprepare && has_size(x) && has_size(y) && (size(x) != size(prep_args.x))
819814
prep = $prep_op!(f, prep, ba, x, t, contexts...)
820815
prepstrict = $prep_op!(f, prepstrict, ba, x, t, contexts...)
821-
prep_same = $prep_op_same(f, ba, x, t, contexts...)
822816
end
823817
[(), (prep,), (prepstrict,), (prep_same,)]
824818
end

DifferentiationInterfaceTest/src/utils.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
myzero(x::Number) = zero(x)
2-
myzero(x::AbstractArray) = zero(x)
3-
myzero(x::Union{Tuple,NamedTuple}) = map(myzero, x)
4-
myzero(::Nothing) = nothing
5-
61
mysimilar(x::Number) = one(x)
72
mysimilar(x::AbstractArray) = similar(x)
83
mysimilar(x::Union{Tuple,NamedTuple}) = map(mysimilar, x)

0 commit comments

Comments
 (0)