Skip to content

Commit eff8ff5

Browse files
committed
Fix FromPrimitive tests
1 parent f863413 commit eff8ff5

4 files changed

Lines changed: 152 additions & 229 deletions

File tree

DifferentiationInterface/src/misc/from_primitive.jl

Lines changed: 109 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,135 @@
1-
abstract type FromPrimitive <: AbstractADType end
1+
abstract type FromPrimitive{inplace} <: AbstractADType end
22

33
check_available(fromprim::FromPrimitive) = check_available(fromprim.backend)
4-
inplace_support(fromprim::FromPrimitive) = inplace_support(fromprim.backend)
4+
inplace_support(::FromPrimitive{true}) = InPlaceSupported()
5+
inplace_support(::FromPrimitive{false}) = InPlaceNotSupported()
56

67
function pick_batchsize(fromprim::FromPrimitive, N::Integer)
78
return pick_batchsize(fromprim.backend, N)
89
end
910

11+
"""
12+
AutoForwardFromPrimitive
13+
14+
Wrapper which forces a given backend to act as a reverse-mode backend.
15+
16+
Used in internal testing.
17+
"""
18+
struct AutoForwardFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace}
19+
backend::B
20+
end
21+
22+
function AutoForwardFromPrimitive(backend::AbstractADType; inplace=true)
23+
return AutoForwardFromPrimitive{inplace,typeof(backend)}(backend)
24+
end
25+
26+
ADTypes.mode(::AutoForwardFromPrimitive) = ADTypes.ForwardMode()
27+
28+
function threshold_batchsize(
29+
fromprim::AutoForwardFromPrimitive{inplace}, dimension::Integer
30+
) where {inplace}
31+
return AutoForwardFromPrimitive(
32+
threshold_batchsize(fromprim.backend, dimension); inplace
33+
)
34+
end
35+
36+
struct FromPrimitivePushforwardPrep{E<:PushforwardPrep} <: PushforwardPrep
37+
pushforward_prep::E
38+
end
39+
40+
function prepare_pushforward(
41+
f::F, fromprim::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}
42+
) where {F,C}
43+
primitive_prep = prepare_pushforward(f, fromprim.backend, x, tx, contexts...)
44+
return FromPrimitivePushforwardPrep(primitive_prep)
45+
end
46+
47+
function prepare_pushforward(
48+
f!::F, y, fromprim::AutoForwardFromPrimitive, x, tx::NTuple, contexts::Vararg{Context,C}
49+
) where {F,C}
50+
primitive_prep = prepare_pushforward(f!, y, fromprim.backend, x, tx, contexts...)
51+
return FromPrimitivePushforwardPrep(primitive_prep)
52+
end
53+
54+
function value_and_pushforward(
55+
f::F,
56+
prep::FromPrimitivePushforwardPrep,
57+
fromprim::AutoForwardFromPrimitive,
58+
x,
59+
tx::NTuple,
60+
contexts::Vararg{Context,C},
61+
) where {F,C}
62+
return value_and_pushforward(
63+
f, prep.pushforward_prep, fromprim.backend, x, tx, contexts...
64+
)
65+
end
66+
67+
function value_and_pushforward(
68+
f!::F,
69+
y,
70+
prep::FromPrimitivePushforwardPrep,
71+
fromprim::AutoForwardFromPrimitive,
72+
x,
73+
tx::NTuple,
74+
contexts::Vararg{Context,C},
75+
) where {F,C}
76+
return value_and_pushforward(
77+
f!, y, prep.pushforward_prep, fromprim.backend, x, tx, contexts...
78+
)
79+
end
80+
81+
function value_and_pushforward!(
82+
f::F,
83+
ty::NTuple,
84+
prep::FromPrimitivePushforwardPrep,
85+
fromprim::AutoForwardFromPrimitive,
86+
x,
87+
tx::NTuple,
88+
contexts::Vararg{Context,C},
89+
) where {F,C}
90+
return value_and_pushforward!(
91+
f, ty, prep.pushforward_prep, fromprim.backend, x, tx, contexts...
92+
)
93+
end
94+
95+
function value_and_pushforward!(
96+
f!::F,
97+
y,
98+
ty::NTuple,
99+
prep::FromPrimitivePushforwardPrep,
100+
fromprim::AutoForwardFromPrimitive,
101+
x,
102+
tx::NTuple,
103+
contexts::Vararg{Context,C},
104+
) where {F,C}
105+
return value_and_pushforward!(
106+
f!, y, ty, prep.pushforward_prep, fromprim.backend, x, tx, contexts...
107+
)
108+
end
109+
10110
"""
11111
AutoReverseFromPrimitive
12112
13113
Wrapper which forces a given backend to act as a reverse-mode backend.
14114
15115
Used in internal testing.
16116
"""
17-
struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive
117+
struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace}
18118
backend::B
19119
end
20120

21121
function AutoReverseFromPrimitive(backend::AbstractADType; inplace=true)
22122
return AutoReverseFromPrimitive{inplace,typeof(backend)}(backend)
23123
end
24124

25-
inplace_support(::AutoReverseFromPrimitive{true}) = InPlaceSupported()
26-
inplace_support(::AutoReverseFromPrimitive{false}) = InPlaceNotSupported()
27125
ADTypes.mode(::AutoReverseFromPrimitive) = ADTypes.ReverseMode()
28126

29-
function threshold_batchsize(fromprim::AutoReverseFromPrimitive, dimension::Integer)
30-
return AutoReverseFromPrimitive(threshold_batchsize(fromprim.backend, dimension))
127+
function threshold_batchsize(
128+
fromprim::AutoReverseFromPrimitive{inplace}, dimension::Integer
129+
) where {inplace}
130+
return AutoReverseFromPrimitive(
131+
threshold_batchsize(fromprim.backend, dimension); inplace
132+
)
31133
end
32134

33135
struct FromPrimitivePullbackPrep{E<:PullbackPrep} <: PullbackPrep

0 commit comments

Comments
 (0)