Skip to content

Commit 8c85a44

Browse files
committed
unittest: Add more tests, including test output.
Make it possible to run the unittest from within itself. This allows us to write tests that assert: - That the TestSuite had the correct outcomes (failed, errored, skipped, etc) - That the test output the correct information to the user (test names, test outcomes, etc) With this change, we no longer need to rely on raising assertions that are not subclasses of `AssertionError` to show unexpected behaviour in tests. This commit makes a (small) change to how tests for the unittest framework are written. It is now possible to run the entire test framework within itself, allowing us to have assertions against the TestResults and the generated text output. To write assertions against the output TestSuite, we need a way to redirect the output of `print`. This commit adds the `_stdout` variable to the unittest module to allow us to redirect the output. While this is hacky, this variable will go away in a later commit in the series. NOTE: Some of the added tests either had to be skipped or have the incorrect result encoded in them as the existing framework does not behave as expected. These will be fixed in latter commits in the series. Signed-off-by: Greg Darke <micropython@me.tsukasa.au>
1 parent 54e79af commit 8c85a44

5 files changed

Lines changed: 628 additions & 48 deletions

File tree

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import io
2+
import unittest
3+
import collections
4+
5+
6+
def _run_tests(suite: unittest.TestSuite) -> tuple[unittest.TestResult, str]:
7+
"""Runs a TestSuite, capturing its output.
8+
9+
Args:
10+
suite: The TestSuite to run
11+
12+
Returns:
13+
A tuple of (test_result, text_output)
14+
"""
15+
stdout = io.StringIO()
16+
tmp_stdout = unittest._stdout
17+
tmp_current_test = unittest.__current_test__
18+
tmp_test_result = unittest.__test_result__
19+
try:
20+
unittest._stdout = stdout
21+
result = unittest.TestResult()
22+
suite.run(result)
23+
return result, stdout.getvalue()
24+
finally:
25+
unittest._stdout = tmp_stdout
26+
unittest.__current_test__ = tmp_current_test
27+
unittest.__test_result__ = tmp_test_result
28+
29+
30+
def run_tests_in_module(parent_test: unittest.TestCase, module) -> tuple[unittest.TestResult, str]:
31+
test_name, parent_suite_name = unittest.__current_test__
32+
parent_suite_name = f"{parent_suite_name[1:-1]}.{test_name}"
33+
suite = unittest.TestSuite(name=parent_suite_name)
34+
suite._load_module(module)
35+
return _run_tests(suite)
36+
37+
38+
def run_tests_in_testcase(
39+
parent_test: unittest.TestCase, *testcase_classes: type[unittest.TestCase]
40+
) -> tuple[unittest.TestResult, str]:
41+
"""Runs tests in the given TestCase classes."""
42+
43+
class _FakeModule: ...
44+
45+
for tc in testcase_classes:
46+
setattr(_FakeModule, tc.__name__, tc)
47+
return run_tests_in_module(parent_test, _FakeModule)
48+
49+
50+
class _TestResultSummary(
51+
collections.namedtuple(
52+
"TestResultSummary", ("testsRun", "numFailures", "numErrors", "numSkipped")
53+
)
54+
):
55+
@classmethod
56+
def convert(cls, test_result: unittest.TestResult):
57+
return cls(
58+
test_result.testsRun,
59+
len(test_result.failures),
60+
len(test_result.errors),
61+
len(test_result.skipped),
62+
)
63+
64+
65+
class BaseTestCase(unittest.TestCase):
66+
def full_test_name(self):
67+
my_name, cls_name = unittest.__current_test__
68+
return f"{cls_name[1:-1]}.{my_name}"
69+
70+
def assertTestResult(
71+
self,
72+
result: unittest.TestResult,
73+
*,
74+
testsRun: int,
75+
numFailures: int,
76+
numErrors: int,
77+
numSkipped: int,
78+
):
79+
expected = _TestResultSummary(testsRun, numFailures, numErrors, numSkipped)
80+
actual = _TestResultSummary.convert(result)
81+
if expected == actual:
82+
return
83+
err_parts = [f"{actual} != (expected) {expected}"]
84+
if actual.numFailures != expected.numFailures:
85+
err_parts.append(f"failures={result.failures!r}")
86+
if actual.numErrors != expected.numErrors:
87+
err_parts.append(f"errors={result.errors!r}")
88+
if actual.numSkipped != expected.numSkipped:
89+
err_parts.append(f"skipped={result.skipped!r}")
90+
raise AssertionError("\n".join(err_parts))

python-stdlib/unittest/tests/test_basics.py

Lines changed: 202 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,215 @@
11
import unittest
2+
import helpers
23

34

4-
class TestWithRunTest(unittest.TestCase):
5-
run = False
5+
class Basics(helpers.BaseTestCase):
6+
def test_bare_function__passes(self):
7+
class FakeModule:
8+
@staticmethod
9+
def test_func1(): ...
10+
@staticmethod
11+
def test_func2(): ...
612

7-
def runTest(self):
8-
TestWithRunTest.run = True
13+
result, output = helpers.run_tests_in_module(self, FakeModule)
14+
self.assertTestResult(result, testsRun=2, numFailures=0, numErrors=0, numSkipped=0)
15+
self.assertTrue(result.wasSuccessful())
16+
self.assertEqual(
17+
output,
18+
f"test_func1 ({self.full_test_name()}) ... ok\ntest_func2 ({self.full_test_name()}) ... ok\n",
19+
)
920

10-
def testRunTest(self):
11-
self.fail()
21+
def test_bare_function__fail(self):
22+
class FakeModule:
23+
@staticmethod
24+
def test_func_fail():
25+
assert False
1226

13-
@staticmethod
14-
def tearDownClass():
15-
if not TestWithRunTest.run:
16-
raise ValueError()
27+
result, output = helpers.run_tests_in_module(self, FakeModule)
28+
self.assertTestResult(result, testsRun=1, numFailures=1, numErrors=0, numSkipped=0)
29+
self.assertFalse(result.wasSuccessful())
30+
self.assertEqual(output, f"test_func_fail ({self.full_test_name()}) ... FAIL\n")
1731

32+
def test_bare_function__error(self):
33+
class FakeModule:
34+
@staticmethod
35+
def test_func_error():
36+
raise ValueError
1837

19-
def test_func():
20-
pass
38+
result, output = helpers.run_tests_in_module(self, FakeModule)
39+
self.assertTestResult(result, testsRun=1, numFailures=0, numErrors=1, numSkipped=0)
40+
self.assertFalse(result.wasSuccessful())
41+
self.assertEqual(output, f"test_func_error ({self.full_test_name()}) ... ERROR\n")
2142

43+
def test_bare_function__expect_failure__fail(self):
44+
class FakeModule:
45+
@unittest.expectedFailure
46+
def test_func_fail():
47+
assert False
2248

23-
@unittest.expectedFailure
24-
def test_foo():
25-
raise ValueError()
49+
result, output = helpers.run_tests_in_module(self, FakeModule)
50+
self.assertTestResult(result, testsRun=1, numFailures=0, numErrors=0, numSkipped=0)
51+
self.assertTrue(result.wasSuccessful())
52+
# FIXME: This should be "test_func_fail", but the existing
53+
# implementation pulls the wrong name
54+
self.assertEqual(output, f"test_exp_fail ({self.full_test_name()}) ... ok\n")
55+
56+
@unittest.skip("expectedFailure incorrectly consumes all failure types")
57+
def test_bare_function__expect_failure__error(self):
58+
class FakeModule:
59+
@unittest.expectedFailure
60+
def test_func_error():
61+
raise ValueError
62+
63+
result, output = helpers.run_tests_in_module(self, FakeModule)
64+
self.assertTestResult(result, testsRun=1, numFailures=0, numErrors=1, numSkipped=0)
65+
self.assertFalse(result.wasSuccessful())
66+
self.assertEqual(output, f"test_func_error ({self.full_test_name()}) ... ERROR\n")
67+
68+
@unittest.skip("expectedFailure incorrectly consumes the SkipTest exception")
69+
def test_bare_function__expect_failure__skip(self):
70+
class FakeModule:
71+
@unittest.expectedFailure
72+
@unittest.skip("reason1")
73+
def test_func_error():
74+
raise ValueError
75+
76+
result, output = helpers.run_tests_in_module(self, FakeModule)
77+
self.assertTestResult(result, testsRun=1, numFailures=0, numErrors=0, numSkipped=1)
78+
self.assertTrue(result.wasSuccessful())
79+
self.assertEqual(
80+
output, f"test_func_error ({self.full_test_name()}) ... skipped: reason1\n"
81+
)
82+
83+
@unittest.skip("expectedFailure incorrectly consumes the SkipTest exception")
84+
def test_testcase__expect_failure__skip(self):
85+
class FakeModule:
86+
class Test(unittest.TestCase):
87+
@unittest.expectedFailure
88+
def test_func_skip_in_test(self):
89+
self.skipTest("reason1")
90+
91+
@unittest.expectedFailure
92+
@unittest.skip("reason2")
93+
def test_func_skip_wrap_test(self):
94+
pass
95+
96+
result, output = helpers.run_tests_in_module(self, FakeModule)
97+
self.assertTestResult(result, testsRun=2, numFailures=0, numErrors=0, numSkipped=2)
98+
self.assertTrue(result.wasSuccessful())
99+
self.assertEqual(
100+
output,
101+
f"test_func_skip_in_test ({self.full_test_name()}.Test) ... skipped: reason1\n"
102+
f"test_func_skip_wrap_test ({self.full_test_name()}.Test) ... skipped: reason2\n",
103+
)
104+
105+
106+
class TestTestCase(helpers.BaseTestCase):
107+
def test_method_called__passes(self):
108+
class FakeModule:
109+
class Test(unittest.TestCase):
110+
def test1(self): ...
111+
def test2(self): ...
112+
113+
result, output = helpers.run_tests_in_module(self, FakeModule)
114+
self.assertTestResult(result, testsRun=2, numFailures=0, numErrors=0, numSkipped=0)
115+
self.assertEqual(
116+
output,
117+
f"test1 ({self.full_test_name()}.Test) ... ok\n"
118+
f"test2 ({self.full_test_name()}.Test) ... ok\n",
119+
)
120+
121+
def test_method_called__fail(self):
122+
class FakeModule:
123+
class Test(unittest.TestCase):
124+
def test1(self):
125+
self.fail("reason1")
126+
127+
result, output = helpers.run_tests_in_module(self, FakeModule)
128+
self.assertTestResult(result, testsRun=1, numFailures=1, numErrors=0, numSkipped=0)
129+
self.assertEqual(output, f"test1 ({self.full_test_name()}.Test) ... FAIL\n")
130+
131+
def test_method_called__error(self):
132+
class FakeModule:
133+
class Test(unittest.TestCase):
134+
def test1(self):
135+
raise ValueError("reason1")
136+
137+
result, output = helpers.run_tests_in_module(self, FakeModule)
138+
self.assertTestResult(result, testsRun=1, numFailures=0, numErrors=1, numSkipped=0)
139+
self.assertEqual(output, f"test1 ({self.full_test_name()}.Test) ... ERROR\n")
140+
141+
@unittest.skip("unittest framework incorrectly calls `Test.test3`")
142+
def test_only_calls_methods(self):
143+
class FakeModule:
144+
class Test(unittest.TestCase):
145+
def test1(self): ...
146+
147+
test_copy = test1
148+
149+
test2 = None
150+
151+
@property
152+
def test3(self):
153+
raise ValueError
154+
155+
result, output = helpers.run_tests_in_module(self, FakeModule)
156+
self.assertTestResult(result, testsRun=2, numFailures=0, numErrors=0, numSkipped=0)
157+
self.assertEqual(
158+
output,
159+
f"test1 ({self.full_test_name()}.Test) ... ok\n"
160+
f"test_copy ({self.full_test_name()}.Test) ... ok\n",
161+
)
162+
163+
def test_prefers_runTest(self):
164+
class FakeModule:
165+
class Test(unittest.TestCase):
166+
def test1(self):
167+
self.fail("wrong method called")
168+
169+
def runTest(self):
170+
pass
171+
172+
def __repr__(self):
173+
# FIXME: Remove this method, it should not be needed
174+
return "runTest"
175+
176+
result, output = helpers.run_tests_in_module(self, FakeModule)
177+
self.assertTestResult(result, testsRun=1, numFailures=0, numErrors=0, numSkipped=0)
178+
self.assertEqual(output, f"runTest ({self.full_test_name()}.Test) ... ok\n")
179+
180+
def test_keyboard_interrupt_not_captured(self):
181+
class FakeModule:
182+
class Test(unittest.TestCase):
183+
def test(self):
184+
raise KeyboardInterrupt
185+
186+
with self.assertRaises(KeyboardInterrupt):
187+
helpers.run_tests_in_module(self, FakeModule)
188+
189+
@unittest.skip("unittest framework does not call `TestCase.run` method")
190+
def test_run_method_overridable(self):
191+
class FakeModule:
192+
class Test(unittest.TestCase):
193+
def run(self, result: unittest.TestResult | None = None):
194+
if result is None:
195+
result = self.defaultTestResult()
196+
197+
tmp_result = unittest.TestResult(stream=result._stream)
198+
super().run(tmp_result)
199+
if tmp_result.failures:
200+
test, err_msg = tmp_result.failures[-1]
201+
err_msg += "\nSome extra debugging info"
202+
tmp_result.failures[-1] = test, err_msg
203+
result += tmp_result
204+
return result
205+
206+
def test1(self):
207+
self.fail("reason1")
208+
209+
result, output = helpers.run_tests_in_module(self, FakeModule)
210+
self.assertTestResult(result, testsRun=1, numFailures=1, numErrors=0, numSkipped=0)
211+
self.assertEqual(output, f"test1 ({self.full_test_name()}.Test) ... FAIL\n")
212+
self.assertIn("\nSome extra debugging info", result.failures[-1][1])
26213

27214

28215
if __name__ == "__main__":

0 commit comments

Comments
 (0)