summaryrefslogtreecommitdiff
path: root/lib/testtools/testtools/testsuite.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/testtools/testtools/testsuite.py')
-rw-r--r--lib/testtools/testtools/testsuite.py40
1 files changed, 39 insertions, 1 deletions
diff --git a/lib/testtools/testtools/testsuite.py b/lib/testtools/testtools/testsuite.py
index 41eb6f7d3a..67ace56110 100644
--- a/lib/testtools/testtools/testsuite.py
+++ b/lib/testtools/testtools/testsuite.py
@@ -6,9 +6,10 @@ __metaclass__ = type
__all__ = [
'ConcurrentTestSuite',
'iterate_tests',
+ 'sorted_tests',
]
-from testtools.helpers import try_imports
+from testtools.helpers import safe_hasattr, try_imports
Queue = try_imports(['Queue.Queue', 'queue.Queue'])
@@ -114,3 +115,40 @@ class FixtureSuite(unittest.TestSuite):
super(FixtureSuite, self).run(result)
finally:
self._fixture.cleanUp()
+
+ def sort_tests(self):
+ self._tests = sorted_tests(self, True)
+
+
+def _flatten_tests(suite_or_case, unpack_outer=False):
+ try:
+ tests = iter(suite_or_case)
+ except TypeError:
+ # Not iterable, assume it's a test case.
+ return [(suite_or_case.id(), suite_or_case)]
+ if (type(suite_or_case) in (unittest.TestSuite,) or
+ unpack_outer):
+ # Plain old test suite (or any others we may add).
+ result = []
+ for test in tests:
+ # Recurse to flatten.
+ result.extend(_flatten_tests(test))
+ return result
+ else:
+ # Find any old actual test and grab its id.
+ suite_id = None
+ tests = iterate_tests(suite_or_case)
+ for test in tests:
+ suite_id = test.id()
+ break
+ # If it has a sort_tests method, call that.
+ if safe_hasattr(suite_or_case, 'sort_tests'):
+ suite_or_case.sort_tests()
+ return [(suite_id, suite_or_case)]
+
+
+def sorted_tests(suite_or_case, unpack_outer=False):
+ """Sort suite_or_case while preserving non-vanilla TestSuites."""
+ tests = _flatten_tests(suite_or_case, unpack_outer=unpack_outer)
+ tests.sort()
+ return unittest.TestSuite([test for (sort_key, test) in tests])