summaryrefslogtreecommitdiff
path: root/lib/testtools/testtools/matchers.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/testtools/testtools/matchers.py')
-rw-r--r--lib/testtools/testtools/matchers.py521
1 files changed, 498 insertions, 23 deletions
diff --git a/lib/testtools/testtools/matchers.py b/lib/testtools/testtools/matchers.py
index 06b348c6d9..6ee33f0fd8 100644
--- a/lib/testtools/testtools/matchers.py
+++ b/lib/testtools/testtools/matchers.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2009-2010 Jonathan M. Lange. See LICENSE for details.
+# Copyright (c) 2009-2011 testtools developers. See LICENSE for details.
"""Matchers, a way to express complex assertions outside the testcase.
@@ -12,14 +12,25 @@ $ python -c 'import testtools.matchers; print testtools.matchers.__all__'
__metaclass__ = type
__all__ = [
+ 'AfterPreprocessing',
+ 'AllMatch',
'Annotate',
+ 'Contains',
'DocTestMatches',
+ 'EndsWith',
'Equals',
+ 'GreaterThan',
'Is',
+ 'IsInstance',
+ 'KeysEqual',
'LessThan',
'MatchesAll',
'MatchesAny',
'MatchesException',
+ 'MatchesListwise',
+ 'MatchesRegex',
+ 'MatchesSetwise',
+ 'MatchesStructure',
'NotEquals',
'Not',
'Raises',
@@ -30,9 +41,16 @@ __all__ = [
import doctest
import operator
from pprint import pformat
+import re
import sys
+import types
-from testtools.compat import classtypes, _error_repr, isbaseexception
+from testtools.compat import (
+ classtypes,
+ _error_repr,
+ isbaseexception,
+ istext,
+ )
class Matcher(object):
@@ -113,6 +131,69 @@ class Mismatch(object):
id(self), self.__dict__)
+class MismatchDecorator(object):
+ """Decorate a ``Mismatch``.
+
+ Forwards all messages to the original mismatch object. Probably the best
+ way to use this is inherit from this class and then provide your own
+ custom decoration logic.
+ """
+
+ def __init__(self, original):
+ """Construct a `MismatchDecorator`.
+
+ :param original: A `Mismatch` object to decorate.
+ """
+ self.original = original
+
+ def __repr__(self):
+ return '<testtools.matchers.MismatchDecorator(%r)>' % (self.original,)
+
+ def describe(self):
+ return self.original.describe()
+
+ def get_details(self):
+ return self.original.get_details()
+
+
+class _NonManglingOutputChecker(doctest.OutputChecker):
+ """Doctest checker that works with unicode rather than mangling strings
+
+ This is needed because current Python versions have tried to fix string
+ encoding related problems, but regressed the default behaviour with unicode
+ inputs in the process.
+
+ In Python 2.6 and 2.7 `OutputChecker.output_difference` is was changed to
+ return a bytestring encoded as per `sys.stdout.encoding`, or utf-8 if that
+ can't be determined. Worse, that encoding process happens in the innocent
+ looking `_indent` global function. Because the `DocTestMismatch.describe`
+ result may well not be destined for printing to stdout, this is no good
+ for us. To get a unicode return as before, the method is monkey patched if
+ `doctest._encoding` exists.
+
+ Python 3 has a different problem. For some reason both inputs are encoded
+ to ascii with 'backslashreplace', making an escaped string matches its
+ unescaped form. Overriding the offending `OutputChecker._toAscii` method
+ is sufficient to revert this.
+ """
+
+ def _toAscii(self, s):
+ """Return `s` unchanged rather than mangling it to ascii"""
+ return s
+
+ # Only do this overriding hackery if doctest has a broken _input function
+ if getattr(doctest, "_encoding", None) is not None:
+ from types import FunctionType as __F
+ __f = doctest.OutputChecker.output_difference.im_func
+ __g = dict(__f.func_globals)
+ def _indent(s, indent=4, _pattern=re.compile("^(?!$)", re.MULTILINE)):
+ """Prepend non-empty lines in `s` with `indent` number of spaces"""
+ return _pattern.sub(indent*" ", s)
+ __g["_indent"] = _indent
+ output_difference = __F(__f.func_code, __g, "output_difference")
+ del __F, __f, __g, _indent
+
+
class DocTestMatches(object):
"""See if a string matches a doctest example."""
@@ -127,7 +208,7 @@ class DocTestMatches(object):
example += '\n'
self.want = example # required variable name by doctest.
self.flags = flags
- self._checker = doctest.OutputChecker()
+ self._checker = _NonManglingOutputChecker()
def __str__(self):
if self.flags:
@@ -137,7 +218,7 @@ class DocTestMatches(object):
return 'DocTestMatches(%r%s)' % (self.want, flagstr)
def _with_nl(self, actual):
- result = str(actual)
+ result = self.want.__class__(actual)
if not result.endswith('\n'):
result += '\n'
return result
@@ -163,14 +244,28 @@ class DocTestMismatch(Mismatch):
return self.matcher._describe_difference(self.with_nl)
+class DoesNotContain(Mismatch):
+
+ def __init__(self, matchee, needle):
+ """Create a DoesNotContain Mismatch.
+
+ :param matchee: the object that did not contain needle.
+ :param needle: the needle that 'matchee' was expected to contain.
+ """
+ self.matchee = matchee
+ self.needle = needle
+
+ def describe(self):
+ return "%r not in %r" % (self.needle, self.matchee)
+
+
class DoesNotStartWith(Mismatch):
def __init__(self, matchee, expected):
"""Create a DoesNotStartWith Mismatch.
:param matchee: the string that did not match.
- :param expected: the string that `matchee` was expected to start
- with.
+ :param expected: the string that 'matchee' was expected to start with.
"""
self.matchee = matchee
self.expected = expected
@@ -186,7 +281,7 @@ class DoesNotEndWith(Mismatch):
"""Create a DoesNotEndWith Mismatch.
:param matchee: the string that did not match.
- :param expected: the string that `matchee` was expected to end with.
+ :param expected: the string that 'matchee' was expected to end with.
"""
self.matchee = matchee
self.expected = expected
@@ -222,13 +317,20 @@ class _BinaryMismatch(Mismatch):
self._mismatch_string = mismatch_string
self.other = other
+ def _format(self, thing):
+ # Blocks of text with newlines are formatted as triple-quote
+ # strings. Everything else is pretty-printed.
+ if istext(thing) and '\n' in thing:
+ return '"""\\\n%s"""' % (thing,)
+ return pformat(thing)
+
def describe(self):
left = repr(self.expected)
right = repr(self.other)
if len(left) + len(right) > 70:
return "%s:\nreference = %s\nactual = %s\n" % (
- self._mismatch_string, pformat(self.expected),
- pformat(self.other))
+ self._mismatch_string, self._format(self.expected),
+ self._format(self.other))
else:
return "%s %s %s" % (left, self._mismatch_string,right)
@@ -243,8 +345,8 @@ class Equals(_BinaryComparison):
class NotEquals(_BinaryComparison):
"""Matches if the items are not equal.
- In most cases, this is equivalent to `Not(Equals(foo))`. The difference
- only matters when testing `__ne__` implementations.
+ In most cases, this is equivalent to ``Not(Equals(foo))``. The difference
+ only matters when testing ``__ne__`` implementations.
"""
comparator = operator.ne
@@ -258,11 +360,54 @@ class Is(_BinaryComparison):
mismatch_string = 'is not'
+class IsInstance(object):
+ """Matcher that wraps isinstance."""
+
+ def __init__(self, *types):
+ self.types = tuple(types)
+
+ def __str__(self):
+ return "%s(%s)" % (self.__class__.__name__,
+ ', '.join(type.__name__ for type in self.types))
+
+ def match(self, other):
+ if isinstance(other, self.types):
+ return None
+ return NotAnInstance(other, self.types)
+
+
+class NotAnInstance(Mismatch):
+
+ def __init__(self, matchee, types):
+ """Create a NotAnInstance Mismatch.
+
+ :param matchee: the thing which is not an instance of any of types.
+ :param types: A tuple of the types which were expected.
+ """
+ self.matchee = matchee
+ self.types = types
+
+ def describe(self):
+ if len(self.types) == 1:
+ typestr = self.types[0].__name__
+ else:
+ typestr = 'any of (%s)' % ', '.join(type.__name__ for type in
+ self.types)
+ return "'%s' is not an instance of %s" % (self.matchee, typestr)
+
+
class LessThan(_BinaryComparison):
"""Matches if the item is less than the matchers reference object."""
comparator = operator.__lt__
- mismatch_string = 'is >='
+ mismatch_string = 'is not >'
+
+
+class GreaterThan(_BinaryComparison):
+ """Matches if the item is greater than the matchers reference object."""
+
+ comparator = operator.__gt__
+ mismatch_string = 'is not <'
class MatchesAny(object):
@@ -351,17 +496,26 @@ class MatchedUnexpectedly(Mismatch):
class MatchesException(Matcher):
"""Match an exc_info tuple against an exception instance or type."""
- def __init__(self, exception):
+ def __init__(self, exception, value_re=None):
"""Create a MatchesException that will match exc_info's for exception.
:param exception: Either an exception instance or type.
If an instance is given, the type and arguments of the exception
are checked. If a type is given only the type of the exception is
- checked.
+ checked. If a tuple is given, then as with isinstance, any of the
+ types in the tuple matching is sufficient to match.
+ :param value_re: If 'exception' is a type, and the matchee exception
+ is of the right type, then match against this. If value_re is a
+ string, then assume value_re is a regular expression and match
+ the str() of the exception against it. Otherwise, assume value_re
+ is a matcher, and match the exception against it.
"""
Matcher.__init__(self)
self.expected = exception
- self._is_instance = type(self.expected) not in classtypes()
+ if istext(value_re):
+ value_re = AfterPreproccessing(str, MatchesRegex(value_re), False)
+ self.value_re = value_re
+ self._is_instance = type(self.expected) not in classtypes() + (tuple,)
def match(self, other):
if type(other) != tuple:
@@ -371,9 +525,12 @@ class MatchesException(Matcher):
expected_class = expected_class.__class__
if not issubclass(other[0], expected_class):
return Mismatch('%r is not a %r' % (other[0], expected_class))
- if self._is_instance and other[1].args != self.expected.args:
- return Mismatch('%s has different arguments to %s.' % (
- _error_repr(other[1]), _error_repr(self.expected)))
+ if self._is_instance:
+ if other[1].args != self.expected.args:
+ return Mismatch('%s has different arguments to %s.' % (
+ _error_repr(other[1]), _error_repr(self.expected)))
+ elif self.value_re is not None:
+ return self.value_re.match(other[1])
def __str__(self):
if self._is_instance:
@@ -381,6 +538,29 @@ class MatchesException(Matcher):
return "MatchesException(%s)" % repr(self.expected)
+class Contains(Matcher):
+ """Checks whether something is container in another thing."""
+
+ def __init__(self, needle):
+ """Create a Contains Matcher.
+
+ :param needle: the thing that needs to be contained by matchees.
+ """
+ self.needle = needle
+
+ def __str__(self):
+ return "Contains(%r)" % (self.needle,)
+
+ def match(self, matchee):
+ try:
+ if self.needle not in matchee:
+ return DoesNotContain(matchee, self.needle)
+ except TypeError:
+ # e.g. 1 in 2 will raise TypeError
+ return DoesNotContain(matchee, self.needle)
+ return None
+
+
class StartsWith(Matcher):
"""Checks whether one string starts with another."""
@@ -425,7 +605,7 @@ class KeysEqual(Matcher):
def __init__(self, *expected):
"""Create a `KeysEqual` Matcher.
- :param *expected: The keys the dict is expected to have. If a dict,
+ :param expected: The keys the dict is expected to have. If a dict,
then we use the keys of that dict, if a collection, we assume it
is a collection of expected keys.
"""
@@ -457,6 +637,13 @@ class Annotate(object):
self.annotation = annotation
self.matcher = matcher
+ @classmethod
+ def if_message(cls, annotation, matcher):
+ """Annotate ``matcher`` only if ``annotation`` is non-empty."""
+ if not annotation:
+ return matcher
+ return cls(annotation, matcher)
+
def __str__(self):
return 'Annotate(%r, %s)' % (self.annotation, self.matcher)
@@ -466,15 +653,16 @@ class Annotate(object):
return AnnotatedMismatch(self.annotation, mismatch)
-class AnnotatedMismatch(Mismatch):
+class AnnotatedMismatch(MismatchDecorator):
"""A mismatch annotated with a descriptive string."""
def __init__(self, annotation, mismatch):
+ super(AnnotatedMismatch, self).__init__(mismatch)
self.annotation = annotation
self.mismatch = mismatch
def describe(self):
- return '%s: %s' % (self.mismatch.describe(), self.annotation)
+ return '%s: %s' % (self.original.describe(), self.annotation)
class Raises(Matcher):
@@ -502,16 +690,19 @@ class Raises(Matcher):
# Catch all exceptions: Raises() should be able to match a
# KeyboardInterrupt or SystemExit.
except:
+ exc_info = sys.exc_info()
if self.exception_matcher:
- mismatch = self.exception_matcher.match(sys.exc_info())
+ mismatch = self.exception_matcher.match(exc_info)
if not mismatch:
+ del exc_info
return
else:
mismatch = None
# The exception did not match, or no explicit matching logic was
# performed. If the exception is a non-user exception (that is, not
# a subclass of Exception on Python 2.5+) then propogate it.
- if isbaseexception(sys.exc_info()[1]):
+ if isbaseexception(exc_info[1]):
+ del exc_info
raise
return mismatch
@@ -523,8 +714,292 @@ def raises(exception):
"""Make a matcher that checks that a callable raises an exception.
This is a convenience function, exactly equivalent to::
+
return Raises(MatchesException(exception))
See `Raises` and `MatchesException` for more information.
"""
return Raises(MatchesException(exception))
+
+
+class MatchesListwise(object):
+ """Matches if each matcher matches the corresponding value.
+
+ More easily explained by example than in words:
+
+ >>> MatchesListwise([Equals(1)]).match([1])
+ >>> MatchesListwise([Equals(1), Equals(2)]).match([1, 2])
+ >>> print (MatchesListwise([Equals(1), Equals(2)]).match([2, 1]).describe())
+ Differences: [
+ 1 != 2
+ 2 != 1
+ ]
+ """
+
+ def __init__(self, matchers):
+ self.matchers = matchers
+
+ def match(self, values):
+ mismatches = []
+ length_mismatch = Annotate(
+ "Length mismatch", Equals(len(self.matchers))).match(len(values))
+ if length_mismatch:
+ mismatches.append(length_mismatch)
+ for matcher, value in zip(self.matchers, values):
+ mismatch = matcher.match(value)
+ if mismatch:
+ mismatches.append(mismatch)
+ if mismatches:
+ return MismatchesAll(mismatches)
+
+
+class MatchesStructure(object):
+ """Matcher that matches an object structurally.
+
+ 'Structurally' here means that attributes of the object being matched are
+ compared against given matchers.
+
+ `fromExample` allows the creation of a matcher from a prototype object and
+ then modified versions can be created with `update`.
+
+ `byEquality` creates a matcher in much the same way as the constructor,
+ except that the matcher for each of the attributes is assumed to be
+ `Equals`.
+
+ `byMatcher` creates a similar matcher to `byEquality`, but you get to pick
+ the matcher, rather than just using `Equals`.
+ """
+
+ def __init__(self, **kwargs):
+ """Construct a `MatchesStructure`.
+
+ :param kwargs: A mapping of attributes to matchers.
+ """
+ self.kws = kwargs
+
+ @classmethod
+ def byEquality(cls, **kwargs):
+ """Matches an object where the attributes equal the keyword values.
+
+ Similar to the constructor, except that the matcher is assumed to be
+ Equals.
+ """
+ return cls.byMatcher(Equals, **kwargs)
+
+ @classmethod
+ def byMatcher(cls, matcher, **kwargs):
+ """Matches an object where the attributes match the keyword values.
+
+ Similar to the constructor, except that the provided matcher is used
+ to match all of the values.
+ """
+ return cls(
+ **dict((name, matcher(value)) for name, value in kwargs.items()))
+
+ @classmethod
+ def fromExample(cls, example, *attributes):
+ kwargs = {}
+ for attr in attributes:
+ kwargs[attr] = Equals(getattr(example, attr))
+ return cls(**kwargs)
+
+ def update(self, **kws):
+ new_kws = self.kws.copy()
+ for attr, matcher in kws.items():
+ if matcher is None:
+ new_kws.pop(attr, None)
+ else:
+ new_kws[attr] = matcher
+ return type(self)(**new_kws)
+
+ def __str__(self):
+ kws = []
+ for attr, matcher in sorted(self.kws.items()):
+ kws.append("%s=%s" % (attr, matcher))
+ return "%s(%s)" % (self.__class__.__name__, ', '.join(kws))
+
+ def match(self, value):
+ matchers = []
+ values = []
+ for attr, matcher in sorted(self.kws.items()):
+ matchers.append(Annotate(attr, matcher))
+ values.append(getattr(value, attr))
+ return MatchesListwise(matchers).match(values)
+
+
+class MatchesRegex(object):
+ """Matches if the matchee is matched by a regular expression."""
+
+ def __init__(self, pattern, flags=0):
+ self.pattern = pattern
+ self.flags = flags
+
+ def __str__(self):
+ args = ['%r' % self.pattern]
+ flag_arg = []
+ # dir() sorts the attributes for us, so we don't need to do it again.
+ for flag in dir(re):
+ if len(flag) == 1:
+ if self.flags & getattr(re, flag):
+ flag_arg.append('re.%s' % flag)
+ if flag_arg:
+ args.append('|'.join(flag_arg))
+ return '%s(%s)' % (self.__class__.__name__, ', '.join(args))
+
+ def match(self, value):
+ if not re.match(self.pattern, value, self.flags):
+ return Mismatch("%r does not match /%s/" % (
+ value, self.pattern))
+
+
+class MatchesSetwise(object):
+ """Matches if all the matchers match elements of the value being matched.
+
+ That is, each element in the 'observed' set must match exactly one matcher
+ from the set of matchers, with no matchers left over.
+
+ The difference compared to `MatchesListwise` is that the order of the
+ matchings does not matter.
+ """
+
+ def __init__(self, *matchers):
+ self.matchers = matchers
+
+ def match(self, observed):
+ remaining_matchers = set(self.matchers)
+ not_matched = []
+ for value in observed:
+ for matcher in remaining_matchers:
+ if matcher.match(value) is None:
+ remaining_matchers.remove(matcher)
+ break
+ else:
+ not_matched.append(value)
+ if not_matched or remaining_matchers:
+ remaining_matchers = list(remaining_matchers)
+ # There are various cases that all should be reported somewhat
+ # differently.
+
+ # There are two trivial cases:
+ # 1) There are just some matchers left over.
+ # 2) There are just some values left over.
+
+ # Then there are three more interesting cases:
+ # 3) There are the same number of matchers and values left over.
+ # 4) There are more matchers left over than values.
+ # 5) There are more values left over than matchers.
+
+ if len(not_matched) == 0:
+ if len(remaining_matchers) > 1:
+ msg = "There were %s matchers left over: " % (
+ len(remaining_matchers),)
+ else:
+ msg = "There was 1 matcher left over: "
+ msg += ', '.join(map(str, remaining_matchers))
+ return Mismatch(msg)
+ elif len(remaining_matchers) == 0:
+ if len(not_matched) > 1:
+ return Mismatch(
+ "There were %s values left over: %s" % (
+ len(not_matched), not_matched))
+ else:
+ return Mismatch(
+ "There was 1 value left over: %s" % (
+ not_matched, ))
+ else:
+ common_length = min(len(remaining_matchers), len(not_matched))
+ if common_length == 0:
+ raise AssertionError("common_length can't be 0 here")
+ if common_length > 1:
+ msg = "There were %s mismatches" % (common_length,)
+ else:
+ msg = "There was 1 mismatch"
+ if len(remaining_matchers) > len(not_matched):
+ extra_matchers = remaining_matchers[common_length:]
+ msg += " and %s extra matcher" % (len(extra_matchers), )
+ if len(extra_matchers) > 1:
+ msg += "s"
+ msg += ': ' + ', '.join(map(str, extra_matchers))
+ elif len(not_matched) > len(remaining_matchers):
+ extra_values = not_matched[common_length:]
+ msg += " and %s extra value" % (len(extra_values), )
+ if len(extra_values) > 1:
+ msg += "s"
+ msg += ': ' + str(extra_values)
+ return Annotate(
+ msg, MatchesListwise(remaining_matchers[:common_length])
+ ).match(not_matched[:common_length])
+
+
+class AfterPreprocessing(object):
+ """Matches if the value matches after passing through a function.
+
+ This can be used to aid in creating trivial matchers as functions, for
+ example::
+
+ def PathHasFileContent(content):
+ def _read(path):
+ return open(path).read()
+ return AfterPreprocessing(_read, Equals(content))
+ """
+
+ def __init__(self, preprocessor, matcher, annotate=True):
+ """Create an AfterPreprocessing matcher.
+
+ :param preprocessor: A function called with the matchee before
+ matching.
+ :param matcher: What to match the preprocessed matchee against.
+ :param annotate: Whether or not to annotate the matcher with
+ something explaining how we transformed the matchee. Defaults
+ to True.
+ """
+ self.preprocessor = preprocessor
+ self.matcher = matcher
+ self.annotate = annotate
+
+ def _str_preprocessor(self):
+ if isinstance(self.preprocessor, types.FunctionType):
+ return '<function %s>' % self.preprocessor.__name__
+ return str(self.preprocessor)
+
+ def __str__(self):
+ return "AfterPreprocessing(%s, %s)" % (
+ self._str_preprocessor(), self.matcher)
+
+ def match(self, value):
+ after = self.preprocessor(value)
+ if self.annotate:
+ matcher = Annotate(
+ "after %s on %r" % (self._str_preprocessor(), value),
+ self.matcher)
+ else:
+ matcher = self.matcher
+ return matcher.match(after)
+
+# This is the old, deprecated. spelling of the name, kept for backwards
+# compatibility.
+AfterPreproccessing = AfterPreprocessing
+
+
+class AllMatch(object):
+ """Matches if all provided values match the given matcher."""
+
+ def __init__(self, matcher):
+ self.matcher = matcher
+
+ def __str__(self):
+ return 'AllMatch(%s)' % (self.matcher,)
+
+ def match(self, values):
+ mismatches = []
+ for value in values:
+ mismatch = self.matcher.match(value)
+ if mismatch:
+ mismatches.append(mismatch)
+ if mismatches:
+ return MismatchesAll(mismatches)
+
+
+# Signal that this is part of the testing framework, and that code from this
+# should not normally appear in tracebacks.
+__unittest = True