diff options
Diffstat (limited to 'lib/testtools/testtools/testsuite.py')
-rw-r--r-- | lib/testtools/testtools/testsuite.py | 25 |
1 files changed, 20 insertions, 5 deletions
diff --git a/lib/testtools/testtools/testsuite.py b/lib/testtools/testtools/testsuite.py index 18de8b89e1..41eb6f7d3a 100644 --- a/lib/testtools/testtools/testsuite.py +++ b/lib/testtools/testtools/testsuite.py @@ -33,7 +33,7 @@ def iterate_tests(test_suite_or_case): class ConcurrentTestSuite(unittest.TestSuite): """A TestSuite whose run() calls out to a concurrency strategy.""" - def __init__(self, suite, make_tests): + def __init__(self, suite, make_tests, wrap_result=None): """Create a ConcurrentTestSuite to execute suite. :param suite: A suite to run concurrently. @@ -42,9 +42,24 @@ class ConcurrentTestSuite(unittest.TestSuite): sub-suites. make_tests must take a suite, and return an iterable of TestCase-like object, each of which must have a run(result) method. + :param wrap_result: An optional function that takes a thread-safe + result and a thread number and must return a ``TestResult`` + object. If not provided, then ``ConcurrentTestSuite`` will just + use a ``ThreadsafeForwardingResult`` wrapped around the result + passed to ``run()``. """ super(ConcurrentTestSuite, self).__init__([suite]) self.make_tests = make_tests + if wrap_result: + self._wrap_result = wrap_result + + def _wrap_result(self, thread_safe_result, thread_number): + """Wrap a thread-safe result before sending it test results. + + You can either override this in a subclass or pass your own + ``wrap_result`` in to the constructor. The latter is preferred. + """ + return thread_safe_result def run(self, result): """Run the tests concurrently. @@ -63,10 +78,10 @@ class ConcurrentTestSuite(unittest.TestSuite): try: threads = {} queue = Queue() - result_semaphore = threading.Semaphore(1) - for test in tests: - process_result = testtools.ThreadsafeForwardingResult(result, - result_semaphore) + semaphore = threading.Semaphore(1) + for i, test in enumerate(tests): + process_result = self._wrap_result( + testtools.ThreadsafeForwardingResult(result, semaphore), i) reader_thread = threading.Thread( target=self._run_test, args=(test, process_result, queue)) threads[test] = reader_thread, process_result |