diff options
Diffstat (limited to 'lib/testtools/testtools/tests/test_spinner.py')
-rw-r--r-- | lib/testtools/testtools/tests/test_spinner.py | 325 |
1 files changed, 325 insertions, 0 deletions
diff --git a/lib/testtools/testtools/tests/test_spinner.py b/lib/testtools/testtools/tests/test_spinner.py new file mode 100644 index 0000000000..f89895653b --- /dev/null +++ b/lib/testtools/testtools/tests/test_spinner.py @@ -0,0 +1,325 @@ +# Copyright (c) 2010 Jonathan M. Lange. See LICENSE for details. + +"""Tests for the evil Twisted reactor-spinning we do.""" + +import os +import signal + +from testtools import ( + skipIf, + TestCase, + ) +from testtools.helpers import try_import +from testtools.matchers import ( + Equals, + Is, + MatchesException, + Raises, + ) + +_spinner = try_import('testtools._spinner') + +defer = try_import('twisted.internet.defer') +Failure = try_import('twisted.python.failure.Failure') + + +class NeedsTwistedTestCase(TestCase): + + def setUp(self): + super(NeedsTwistedTestCase, self).setUp() + if defer is None or Failure is None: + self.skipTest("Need Twisted to run") + + +class TestNotReentrant(NeedsTwistedTestCase): + + def test_not_reentrant(self): + # A function decorated as not being re-entrant will raise a + # _spinner.ReentryError if it is called while it is running. + calls = [] + @_spinner.not_reentrant + def log_something(): + calls.append(None) + if len(calls) < 5: + log_something() + self.assertThat( + log_something, Raises(MatchesException(_spinner.ReentryError))) + self.assertEqual(1, len(calls)) + + def test_deeper_stack(self): + calls = [] + @_spinner.not_reentrant + def g(): + calls.append(None) + if len(calls) < 5: + f() + @_spinner.not_reentrant + def f(): + calls.append(None) + if len(calls) < 5: + g() + self.assertThat(f, Raises(MatchesException(_spinner.ReentryError))) + self.assertEqual(2, len(calls)) + + +class TestExtractResult(NeedsTwistedTestCase): + + def test_not_fired(self): + # _spinner.extract_result raises _spinner.DeferredNotFired if it's + # given a Deferred that has not fired. + self.assertThat(lambda:_spinner.extract_result(defer.Deferred()), + Raises(MatchesException(_spinner.DeferredNotFired))) + + def test_success(self): + # _spinner.extract_result returns the value of the Deferred if it has + # fired successfully. + marker = object() + d = defer.succeed(marker) + self.assertThat(_spinner.extract_result(d), Equals(marker)) + + def test_failure(self): + # _spinner.extract_result raises the failure's exception if it's given + # a Deferred that is failing. + try: + 1/0 + except ZeroDivisionError: + f = Failure() + d = defer.fail(f) + self.assertThat(lambda:_spinner.extract_result(d), + Raises(MatchesException(ZeroDivisionError))) + + +class TestTrapUnhandledErrors(NeedsTwistedTestCase): + + def test_no_deferreds(self): + marker = object() + result, errors = _spinner.trap_unhandled_errors(lambda: marker) + self.assertEqual([], errors) + self.assertIs(marker, result) + + def test_unhandled_error(self): + failures = [] + def make_deferred_but_dont_handle(): + try: + 1/0 + except ZeroDivisionError: + f = Failure() + failures.append(f) + defer.fail(f) + result, errors = _spinner.trap_unhandled_errors( + make_deferred_but_dont_handle) + self.assertIs(None, result) + self.assertEqual(failures, [error.failResult for error in errors]) + + +class TestRunInReactor(NeedsTwistedTestCase): + + def make_reactor(self): + from twisted.internet import reactor + return reactor + + def make_spinner(self, reactor=None): + if reactor is None: + reactor = self.make_reactor() + return _spinner.Spinner(reactor) + + def make_timeout(self): + return 0.01 + + def test_function_called(self): + # run_in_reactor actually calls the function given to it. + calls = [] + marker = object() + self.make_spinner().run(self.make_timeout(), calls.append, marker) + self.assertThat(calls, Equals([marker])) + + def test_return_value_returned(self): + # run_in_reactor returns the value returned by the function given to + # it. + marker = object() + result = self.make_spinner().run(self.make_timeout(), lambda: marker) + self.assertThat(result, Is(marker)) + + def test_exception_reraised(self): + # If the given function raises an error, run_in_reactor re-raises that + # error. + self.assertThat( + lambda:self.make_spinner().run(self.make_timeout(), lambda: 1/0), + Raises(MatchesException(ZeroDivisionError))) + + def test_keyword_arguments(self): + # run_in_reactor passes keyword arguments on. + calls = [] + function = lambda *a, **kw: calls.extend([a, kw]) + self.make_spinner().run(self.make_timeout(), function, foo=42) + self.assertThat(calls, Equals([(), {'foo': 42}])) + + def test_not_reentrant(self): + # run_in_reactor raises an error if it is called inside another call + # to run_in_reactor. + spinner = self.make_spinner() + self.assertThat(lambda: spinner.run( + self.make_timeout(), spinner.run, self.make_timeout(), + lambda: None), Raises(MatchesException(_spinner.ReentryError))) + + def test_deferred_value_returned(self): + # If the given function returns a Deferred, run_in_reactor returns the + # value in the Deferred at the end of the callback chain. + marker = object() + result = self.make_spinner().run( + self.make_timeout(), lambda: defer.succeed(marker)) + self.assertThat(result, Is(marker)) + + def test_preserve_signal_handler(self): + signals = ['SIGINT', 'SIGTERM', 'SIGCHLD'] + signals = filter( + None, (getattr(signal, name, None) for name in signals)) + for sig in signals: + self.addCleanup(signal.signal, sig, signal.getsignal(sig)) + new_hdlrs = list(lambda *a: None for _ in signals) + for sig, hdlr in zip(signals, new_hdlrs): + signal.signal(sig, hdlr) + spinner = self.make_spinner() + spinner.run(self.make_timeout(), lambda: None) + self.assertEqual(new_hdlrs, map(signal.getsignal, signals)) + + def test_timeout(self): + # If the function takes too long to run, we raise a + # _spinner.TimeoutError. + timeout = self.make_timeout() + self.assertThat( + lambda:self.make_spinner().run(timeout, lambda: defer.Deferred()), + Raises(MatchesException(_spinner.TimeoutError))) + + def test_no_junk_by_default(self): + # If the reactor hasn't spun yet, then there cannot be any junk. + spinner = self.make_spinner() + self.assertThat(spinner.get_junk(), Equals([])) + + def test_clean_do_nothing(self): + # If there's nothing going on in the reactor, then clean does nothing + # and returns an empty list. + spinner = self.make_spinner() + result = spinner._clean() + self.assertThat(result, Equals([])) + + def test_clean_delayed_call(self): + # If there's a delayed call in the reactor, then clean cancels it and + # returns an empty list. + reactor = self.make_reactor() + spinner = self.make_spinner(reactor) + call = reactor.callLater(10, lambda: None) + results = spinner._clean() + self.assertThat(results, Equals([call])) + self.assertThat(call.active(), Equals(False)) + + def test_clean_delayed_call_cancelled(self): + # If there's a delayed call that's just been cancelled, then it's no + # longer there. + reactor = self.make_reactor() + spinner = self.make_spinner(reactor) + call = reactor.callLater(10, lambda: None) + call.cancel() + results = spinner._clean() + self.assertThat(results, Equals([])) + + def test_clean_selectables(self): + # If there's still a selectable (e.g. a listening socket), then + # clean() removes it from the reactor's registry. + # + # Note that the socket is left open. This emulates a bug in trial. + from twisted.internet.protocol import ServerFactory + reactor = self.make_reactor() + spinner = self.make_spinner(reactor) + port = reactor.listenTCP(0, ServerFactory()) + spinner.run(self.make_timeout(), lambda: None) + results = spinner.get_junk() + self.assertThat(results, Equals([port])) + + def test_clean_running_threads(self): + import threading + import time + current_threads = list(threading.enumerate()) + reactor = self.make_reactor() + timeout = self.make_timeout() + spinner = self.make_spinner(reactor) + spinner.run(timeout, reactor.callInThread, time.sleep, timeout / 2.0) + self.assertThat(list(threading.enumerate()), Equals(current_threads)) + + def test_leftover_junk_available(self): + # If 'run' is given a function that leaves the reactor dirty in some + # way, 'run' will clean up the reactor and then store information + # about the junk. This information can be got using get_junk. + from twisted.internet.protocol import ServerFactory + reactor = self.make_reactor() + spinner = self.make_spinner(reactor) + port = spinner.run( + self.make_timeout(), reactor.listenTCP, 0, ServerFactory()) + self.assertThat(spinner.get_junk(), Equals([port])) + + def test_will_not_run_with_previous_junk(self): + # If 'run' is called and there's still junk in the spinner's junk + # list, then the spinner will refuse to run. + from twisted.internet.protocol import ServerFactory + reactor = self.make_reactor() + spinner = self.make_spinner(reactor) + timeout = self.make_timeout() + spinner.run(timeout, reactor.listenTCP, 0, ServerFactory()) + self.assertThat(lambda: spinner.run(timeout, lambda: None), + Raises(MatchesException(_spinner.StaleJunkError))) + + def test_clear_junk_clears_previous_junk(self): + # If 'run' is called and there's still junk in the spinner's junk + # list, then the spinner will refuse to run. + from twisted.internet.protocol import ServerFactory + reactor = self.make_reactor() + spinner = self.make_spinner(reactor) + timeout = self.make_timeout() + port = spinner.run(timeout, reactor.listenTCP, 0, ServerFactory()) + junk = spinner.clear_junk() + self.assertThat(junk, Equals([port])) + self.assertThat(spinner.get_junk(), Equals([])) + + @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only") + def test_sigint_raises_no_result_error(self): + # If we get a SIGINT during a run, we raise _spinner.NoResultError. + SIGINT = getattr(signal, 'SIGINT', None) + if not SIGINT: + self.skipTest("SIGINT not available") + reactor = self.make_reactor() + spinner = self.make_spinner(reactor) + timeout = self.make_timeout() + reactor.callLater(timeout, os.kill, os.getpid(), SIGINT) + self.assertThat(lambda:spinner.run(timeout * 5, defer.Deferred), + Raises(MatchesException(_spinner.NoResultError))) + self.assertEqual([], spinner._clean()) + + @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only") + def test_sigint_raises_no_result_error_second_time(self): + # If we get a SIGINT during a run, we raise _spinner.NoResultError. + # This test is exactly the same as test_sigint_raises_no_result_error, + # and exists to make sure we haven't futzed with state. + self.test_sigint_raises_no_result_error() + + @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only") + def test_fast_sigint_raises_no_result_error(self): + # If we get a SIGINT during a run, we raise _spinner.NoResultError. + SIGINT = getattr(signal, 'SIGINT', None) + if not SIGINT: + self.skipTest("SIGINT not available") + reactor = self.make_reactor() + spinner = self.make_spinner(reactor) + timeout = self.make_timeout() + reactor.callWhenRunning(os.kill, os.getpid(), SIGINT) + self.assertThat(lambda:spinner.run(timeout * 5, defer.Deferred), + Raises(MatchesException(_spinner.NoResultError))) + self.assertEqual([], spinner._clean()) + + @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only") + def test_fast_sigint_raises_no_result_error_second_time(self): + self.test_fast_sigint_raises_no_result_error() + + +def test_suite(): + from unittest import TestLoader + return TestLoader().loadTestsFromName(__name__) |