summaryrefslogtreecommitdiff
path: root/source4/scripting/python/subunit
diff options
context:
space:
mode:
Diffstat (limited to 'source4/scripting/python/subunit')
-rw-r--r--source4/scripting/python/subunit/__init__.py386
1 files changed, 386 insertions, 0 deletions
diff --git a/source4/scripting/python/subunit/__init__.py b/source4/scripting/python/subunit/__init__.py
new file mode 100644
index 0000000000..e44dd766cc
--- /dev/null
+++ b/source4/scripting/python/subunit/__init__.py
@@ -0,0 +1,386 @@
+#
+# subunit: extensions to python unittest to get test results from subprocesses.
+# Copyright (C) 2005 Robert Collins <robertc@robertcollins.net>
+# Copyright (C) 2007 Jelmer Vernooij <jelmer@samba.org>
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+#
+
+import os
+from StringIO import StringIO
+import subprocess
+import sys
+import unittest
+
+def test_suite():
+ import subunit.tests
+ return subunit.tests.test_suite()
+
+
+def join_dir(base_path, path):
+ """
+ Returns an absolute path to C{path}, calculated relative to the parent
+ of C{base_path}.
+
+ @param base_path: A path to a file or directory.
+ @param path: An absolute path, or a path relative to the containing
+ directory of C{base_path}.
+
+ @return: An absolute path to C{path}.
+ """
+ return os.path.join(os.path.dirname(os.path.abspath(base_path)), path)
+
+
+class TestProtocolServer(object):
+ """A class for receiving results from a TestProtocol client."""
+
+ OUTSIDE_TEST = 0
+ TEST_STARTED = 1
+ READING_FAILURE = 2
+ READING_ERROR = 3
+
+ def __init__(self, client, stream=sys.stdout):
+ """Create a TestProtocol server instance.
+
+ client should be an object that provides
+ - startTest
+ - addSuccess
+ - addFailure
+ - addError
+ - stopTest
+ methods, i.e. a TestResult.
+ """
+ self.state = TestProtocolServer.OUTSIDE_TEST
+ self.client = client
+ self._stream = stream
+
+ def _addError(self, offset, line):
+ if (self.state == TestProtocolServer.TEST_STARTED and
+ self.current_test_description == line[offset:-1]):
+ self.state = TestProtocolServer.OUTSIDE_TEST
+ self.current_test_description = None
+ self.client.addError(self._current_test, RemoteError(""))
+ self.client.stopTest(self._current_test)
+ self._current_test = None
+ elif (self.state == TestProtocolServer.TEST_STARTED and
+ self.current_test_description + " [" == line[offset:-1]):
+ self.state = TestProtocolServer.READING_ERROR
+ self._message = ""
+ else:
+ self.stdOutLineReceived(line)
+
+ def _addFailure(self, offset, line):
+ if (self.state == TestProtocolServer.TEST_STARTED and
+ self.current_test_description == line[offset:-1]):
+ self.state = TestProtocolServer.OUTSIDE_TEST
+ self.current_test_description = None
+ self.client.addFailure(self._current_test, RemoteError())
+ self.client.stopTest(self._current_test)
+ elif (self.state == TestProtocolServer.TEST_STARTED and
+ self.current_test_description + " [" == line[offset:-1]):
+ self.state = TestProtocolServer.READING_FAILURE
+ self._message = ""
+ else:
+ self.stdOutLineReceived(line)
+
+ def _addSuccess(self, offset, line):
+ if (self.state == TestProtocolServer.TEST_STARTED and
+ self.current_test_description == line[offset:-1]):
+ self.client.addSuccess(self._current_test)
+ self.client.stopTest(self._current_test)
+ self.current_test_description = None
+ self._current_test = None
+ self.state = TestProtocolServer.OUTSIDE_TEST
+ else:
+ self.stdOutLineReceived(line)
+
+ def _appendMessage(self, line):
+ if line[0:2] == " ]":
+ # quoted ] start
+ self._message += line[1:]
+ else:
+ self._message += line
+
+ def endQuote(self, line):
+ if self.state == TestProtocolServer.READING_FAILURE:
+ self.state = TestProtocolServer.OUTSIDE_TEST
+ self.current_test_description = None
+ self.client.addFailure(self._current_test,
+ RemoteError(self._message))
+ self.client.stopTest(self._current_test)
+ elif self.state == TestProtocolServer.READING_ERROR:
+ self.state = TestProtocolServer.OUTSIDE_TEST
+ self.current_test_description = None
+ self.client.addError(self._current_test,
+ RemoteError(self._message))
+ self.client.stopTest(self._current_test)
+ else:
+ self.stdOutLineReceived(line)
+
+ def lineReceived(self, line):
+ """Call the appropriate local method for the received line."""
+ if line == "]\n":
+ self.endQuote(line)
+ elif (self.state == TestProtocolServer.READING_FAILURE or
+ self.state == TestProtocolServer.READING_ERROR):
+ self._appendMessage(line)
+ else:
+ parts = line.split(None, 1)
+ if len(parts) == 2:
+ cmd, rest = parts
+ offset = len(cmd) + 1
+ cmd = cmd.strip(':')
+ if cmd in ('test', 'testing'):
+ self._startTest(offset, line)
+ elif cmd == 'error':
+ self._addError(offset, line)
+ elif cmd == 'failure':
+ self._addFailure(offset, line)
+ elif cmd in ('success', 'successful'):
+ self._addSuccess(offset, line)
+ else:
+ self.stdOutLineReceived(line)
+ else:
+ self.stdOutLineReceived(line)
+
+ def lostConnection(self):
+ """The input connection has finished."""
+ if self.state == TestProtocolServer.TEST_STARTED:
+ self.client.addError(self._current_test,
+ RemoteError("lost connection during test '%s'"
+ % self.current_test_description))
+ self.client.stopTest(self._current_test)
+ elif self.state == TestProtocolServer.READING_ERROR:
+ self.client.addError(self._current_test,
+ RemoteError("lost connection during "
+ "error report of test "
+ "'%s'" %
+ self.current_test_description))
+ self.client.stopTest(self._current_test)
+ elif self.state == TestProtocolServer.READING_FAILURE:
+ self.client.addError(self._current_test,
+ RemoteError("lost connection during "
+ "failure report of test "
+ "'%s'" %
+ self.current_test_description))
+ self.client.stopTest(self._current_test)
+
+ def readFrom(self, pipe):
+ for line in pipe.readlines():
+ self.lineReceived(line)
+ self.lostConnection()
+
+ def _startTest(self, offset, line):
+ """Internal call to change state machine. Override startTest()."""
+ if self.state == TestProtocolServer.OUTSIDE_TEST:
+ self.state = TestProtocolServer.TEST_STARTED
+ self._current_test = RemotedTestCase(line[offset:-1])
+ self.current_test_description = line[offset:-1]
+ self.client.startTest(self._current_test)
+ else:
+ self.stdOutLineReceived(line)
+
+ def stdOutLineReceived(self, line):
+ self._stream.write(line)
+
+
+class RemoteException(Exception):
+ """An exception that occured remotely to python."""
+
+ def __eq__(self, other):
+ try:
+ return self.args == other.args
+ except AttributeError:
+ return False
+
+
+class TestProtocolClient(unittest.TestResult):
+ """A class that looks like a TestResult and informs a TestProtocolServer."""
+
+ def __init__(self, stream):
+ unittest.TestResult.__init__(self)
+ self._stream = stream
+
+ def addError(self, test, error):
+ """Report an error in test test."""
+ self._stream.write("error: %s [\n" % (test.shortDescription() or str(test)))
+ for line in self._exc_info_to_string(error, test).split():
+ self._stream.write("%s\n" % line)
+ self._stream.write("]\n")
+
+ def addFailure(self, test, error):
+ """Report a failure in test test."""
+ self._stream.write("failure: %s [\n" % (test.shortDescription() or str(test)))
+ for line in self._exc_info_to_string(error, test).split():
+ self._stream.write("%s\n" % line)
+ self._stream.write("]\n")
+
+ def addSuccess(self, test):
+ """Report a success in a test."""
+ self._stream.write("successful: %s\n" % (test.shortDescription() or str(test)))
+
+ def startTest(self, test):
+ """Mark a test as starting its test run."""
+ self._stream.write("test: %s\n" % (test.shortDescription() or str(test)))
+
+
+def RemoteError(description=""):
+ if description == "":
+ description = "\n"
+ return (RemoteException, RemoteException(description), None)
+
+
+class RemotedTestCase(unittest.TestCase):
+ """A class to represent test cases run in child processes."""
+
+ def __eq__ (self, other):
+ try:
+ return self.__description == other.__description
+ except AttributeError:
+ return False
+
+ def __init__(self, description):
+ """Create a psuedo test case with description description."""
+ self.__description = description
+
+ def error(self, label):
+ raise NotImplementedError("%s on RemotedTestCases is not permitted." %
+ label)
+
+ def setUp(self):
+ self.error("setUp")
+
+ def tearDown(self):
+ self.error("tearDown")
+
+ def shortDescription(self):
+ return self.__description
+
+ def id(self):
+ return "%s.%s" % (self._strclass(), self.__description)
+
+ def __str__(self):
+ return "%s (%s)" % (self.__description, self._strclass())
+
+ def __repr__(self):
+ return "<%s description='%s'>" % \
+ (self._strclass(), self.__description)
+
+ def run(self, result=None):
+ if result is None: result = self.defaultTestResult()
+ result.startTest(self)
+ result.addError(self, RemoteError("Cannot run RemotedTestCases.\n"))
+ result.stopTest(self)
+
+ def _strclass(self):
+ cls = self.__class__
+ return "%s.%s" % (cls.__module__, cls.__name__)
+
+
+class ExecTestCase(unittest.TestCase):
+ """A test case which runs external scripts for test fixtures."""
+
+ def __init__(self, methodName='runTest'):
+ """Create an instance of the class that will use the named test
+ method when executed. Raises a ValueError if the instance does
+ not have a method with the specified name.
+ """
+ unittest.TestCase.__init__(self, methodName)
+ testMethod = getattr(self, methodName)
+ self.script = join_dir(sys.modules[self.__class__.__module__].__file__,
+ testMethod.__doc__)
+
+ def countTestCases(self):
+ return 1
+
+ def run(self, result=None):
+ if result is None: result = self.defaultTestResult()
+ self._run(result)
+
+ def debug(self):
+ """Run the test without collecting errors in a TestResult"""
+ self._run(unittest.TestResult())
+
+ def _run(self, result):
+ protocol = TestProtocolServer(result)
+ output = subprocess.Popen([self.script],
+ stdout=subprocess.PIPE).communicate()[0]
+ protocol.readFrom(StringIO(output))
+
+
+class IsolatedTestCase(unittest.TestCase):
+ """A TestCase which runs its tests in a forked process."""
+
+ def run(self, result=None):
+ if result is None: result = self.defaultTestResult()
+ run_isolated(unittest.TestCase, self, result)
+
+
+class IsolatedTestSuite(unittest.TestSuite):
+ """A TestCase which runs its tests in a forked process."""
+
+ def run(self, result=None):
+ if result is None: result = unittest.TestResult()
+ run_isolated(unittest.TestSuite, self, result)
+
+
+def run_isolated(klass, self, result):
+ """Run a test suite or case in a subprocess, using the run method on klass.
+ """
+ c2pread, c2pwrite = os.pipe()
+ # fixme - error -> result
+ # now fork
+ pid = os.fork()
+ if pid == 0:
+ # Child
+ # Close parent's pipe ends
+ os.close(c2pread)
+ # Dup fds for child
+ os.dup2(c2pwrite, 1)
+ # Close pipe fds.
+ os.close(c2pwrite)
+
+ # at this point, sys.stdin is redirected, now we want
+ # to filter it to escape ]'s.
+ ### XXX: test and write that bit.
+
+ result = TestProtocolClient(sys.stdout)
+ klass.run(self, result)
+ sys.stdout.flush()
+ sys.stderr.flush()
+ # exit HARD, exit NOW.
+ os._exit(0)
+ else:
+ # Parent
+ # Close child pipe ends
+ os.close(c2pwrite)
+ # hookup a protocol engine
+ protocol = TestProtocolServer(result)
+ protocol.readFrom(os.fdopen(c2pread, 'rU'))
+ os.waitpid(pid, 0)
+ # TODO return code evaluation.
+ return result
+
+
+class SubunitTestRunner:
+ def __init__(self, stream=sys.stdout):
+ self.stream = stream
+
+ def run(self, test):
+ "Run the given test case or test suite."
+ result = TestProtocolClient(self.stream)
+ test(result)
+ return result
+