summaryrefslogtreecommitdiff
path: root/lib/testtools/testtools/matchers/_higherorder.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/testtools/testtools/matchers/_higherorder.py')
-rw-r--r--lib/testtools/testtools/matchers/_higherorder.py269
1 files changed, 269 insertions, 0 deletions
diff --git a/lib/testtools/testtools/matchers/_higherorder.py b/lib/testtools/testtools/matchers/_higherorder.py
new file mode 100644
index 0000000000..c31c525d6a
--- /dev/null
+++ b/lib/testtools/testtools/matchers/_higherorder.py
@@ -0,0 +1,269 @@
+# Copyright (c) 2009-2012 testtools developers. See LICENSE for details.
+
+__all__ = [
+ 'AfterPreprocessing',
+ 'AllMatch',
+ 'Annotate',
+ 'MatchesAny',
+ 'MatchesAll',
+ 'Not',
+ ]
+
+import types
+
+from ._impl import (
+ Matcher,
+ Mismatch,
+ MismatchDecorator,
+ )
+
+
+class MatchesAny(object):
+ """Matches if any of the matchers it is created with match."""
+
+ def __init__(self, *matchers):
+ self.matchers = matchers
+
+ def match(self, matchee):
+ results = []
+ for matcher in self.matchers:
+ mismatch = matcher.match(matchee)
+ if mismatch is None:
+ return None
+ results.append(mismatch)
+ return MismatchesAll(results)
+
+ def __str__(self):
+ return "MatchesAny(%s)" % ', '.join([
+ str(matcher) for matcher in self.matchers])
+
+
+class MatchesAll(object):
+ """Matches if all of the matchers it is created with match."""
+
+ def __init__(self, *matchers, **options):
+ """Construct a MatchesAll matcher.
+
+ Just list the component matchers as arguments in the ``*args``
+ style. If you want only the first mismatch to be reported, past in
+ first_only=True as a keyword argument. By default, all mismatches are
+ reported.
+ """
+ self.matchers = matchers
+ self.first_only = options.get('first_only', False)
+
+ def __str__(self):
+ return 'MatchesAll(%s)' % ', '.join(map(str, self.matchers))
+
+ def match(self, matchee):
+ results = []
+ for matcher in self.matchers:
+ mismatch = matcher.match(matchee)
+ if mismatch is not None:
+ if self.first_only:
+ return mismatch
+ results.append(mismatch)
+ if results:
+ return MismatchesAll(results)
+ else:
+ return None
+
+
+class MismatchesAll(Mismatch):
+ """A mismatch with many child mismatches."""
+
+ def __init__(self, mismatches, wrap=True):
+ self.mismatches = mismatches
+ self._wrap = wrap
+
+ def describe(self):
+ descriptions = []
+ if self._wrap:
+ descriptions = ["Differences: ["]
+ for mismatch in self.mismatches:
+ descriptions.append(mismatch.describe())
+ if self._wrap:
+ descriptions.append("]")
+ return '\n'.join(descriptions)
+
+
+class Not(object):
+ """Inverts a matcher."""
+
+ def __init__(self, matcher):
+ self.matcher = matcher
+
+ def __str__(self):
+ return 'Not(%s)' % (self.matcher,)
+
+ def match(self, other):
+ mismatch = self.matcher.match(other)
+ if mismatch is None:
+ return MatchedUnexpectedly(self.matcher, other)
+ else:
+ return None
+
+
+class MatchedUnexpectedly(Mismatch):
+ """A thing matched when it wasn't supposed to."""
+
+ def __init__(self, matcher, other):
+ self.matcher = matcher
+ self.other = other
+
+ def describe(self):
+ return "%r matches %s" % (self.other, self.matcher)
+
+
+class Annotate(object):
+ """Annotates a matcher with a descriptive string.
+
+ Mismatches are then described as '<mismatch>: <annotation>'.
+ """
+
+ def __init__(self, annotation, matcher):
+ self.annotation = annotation
+ self.matcher = matcher
+
+ @classmethod
+ def if_message(cls, annotation, matcher):
+ """Annotate ``matcher`` only if ``annotation`` is non-empty."""
+ if not annotation:
+ return matcher
+ return cls(annotation, matcher)
+
+ def __str__(self):
+ return 'Annotate(%r, %s)' % (self.annotation, self.matcher)
+
+ def match(self, other):
+ mismatch = self.matcher.match(other)
+ if mismatch is not None:
+ return AnnotatedMismatch(self.annotation, mismatch)
+
+
+class PostfixedMismatch(MismatchDecorator):
+ """A mismatch annotated with a descriptive string."""
+
+ def __init__(self, annotation, mismatch):
+ super(PostfixedMismatch, self).__init__(mismatch)
+ self.annotation = annotation
+ self.mismatch = mismatch
+
+ def describe(self):
+ return '%s: %s' % (self.original.describe(), self.annotation)
+
+
+AnnotatedMismatch = PostfixedMismatch
+
+
+class PrefixedMismatch(MismatchDecorator):
+
+ def __init__(self, prefix, mismatch):
+ super(PrefixedMismatch, self).__init__(mismatch)
+ self.prefix = prefix
+
+ def describe(self):
+ return '%s: %s' % (self.prefix, self.original.describe())
+
+
+class AfterPreprocessing(object):
+ """Matches if the value matches after passing through a function.
+
+ This can be used to aid in creating trivial matchers as functions, for
+ example::
+
+ def PathHasFileContent(content):
+ def _read(path):
+ return open(path).read()
+ return AfterPreprocessing(_read, Equals(content))
+ """
+
+ def __init__(self, preprocessor, matcher, annotate=True):
+ """Create an AfterPreprocessing matcher.
+
+ :param preprocessor: A function called with the matchee before
+ matching.
+ :param matcher: What to match the preprocessed matchee against.
+ :param annotate: Whether or not to annotate the matcher with
+ something explaining how we transformed the matchee. Defaults
+ to True.
+ """
+ self.preprocessor = preprocessor
+ self.matcher = matcher
+ self.annotate = annotate
+
+ def _str_preprocessor(self):
+ if isinstance(self.preprocessor, types.FunctionType):
+ return '<function %s>' % self.preprocessor.__name__
+ return str(self.preprocessor)
+
+ def __str__(self):
+ return "AfterPreprocessing(%s, %s)" % (
+ self._str_preprocessor(), self.matcher)
+
+ def match(self, value):
+ after = self.preprocessor(value)
+ if self.annotate:
+ matcher = Annotate(
+ "after %s on %r" % (self._str_preprocessor(), value),
+ self.matcher)
+ else:
+ matcher = self.matcher
+ return matcher.match(after)
+
+
+# This is the old, deprecated. spelling of the name, kept for backwards
+# compatibility.
+AfterPreproccessing = AfterPreprocessing
+
+
+class AllMatch(object):
+ """Matches if all provided values match the given matcher."""
+
+ def __init__(self, matcher):
+ self.matcher = matcher
+
+ def __str__(self):
+ return 'AllMatch(%s)' % (self.matcher,)
+
+ def match(self, values):
+ mismatches = []
+ for value in values:
+ mismatch = self.matcher.match(value)
+ if mismatch:
+ mismatches.append(mismatch)
+ if mismatches:
+ return MismatchesAll(mismatches)
+
+
+class MatchesPredicate(Matcher):
+ """Match if a given function returns True.
+
+ It is reasonably common to want to make a very simple matcher based on a
+ function that you already have that returns True or False given a single
+ argument (i.e. a predicate function). This matcher makes it very easy to
+ do so. e.g.::
+
+ IsEven = MatchesPredicate(lambda x: x % 2 == 0, '%s is not even')
+ self.assertThat(4, IsEven)
+ """
+
+ def __init__(self, predicate, message):
+ """Create a ``MatchesPredicate`` matcher.
+
+ :param predicate: A function that takes a single argument and returns
+ a value that will be interpreted as a boolean.
+ :param message: A message to describe a mismatch. It will be formatted
+ with '%' and be given whatever was passed to ``match()``. Thus, it
+ needs to contain exactly one thing like '%s', '%d' or '%f'.
+ """
+ self.predicate = predicate
+ self.message = message
+
+ def __str__(self):
+ return '%s(%r, %r)' % (
+ self.__class__.__name__, self.predicate, self.message)
+
+ def match(self, x):
+ if not self.predicate(x):
+ return Mismatch(self.message % x)