summaryrefslogtreecommitdiff
path: root/lib/testtools/testtools/matchers/_datastructures.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/testtools/testtools/matchers/_datastructures.py')
-rw-r--r--lib/testtools/testtools/matchers/_datastructures.py228
1 files changed, 228 insertions, 0 deletions
diff --git a/lib/testtools/testtools/matchers/_datastructures.py b/lib/testtools/testtools/matchers/_datastructures.py
new file mode 100644
index 0000000000..70de790738
--- /dev/null
+++ b/lib/testtools/testtools/matchers/_datastructures.py
@@ -0,0 +1,228 @@
+# Copyright (c) 2009-2012 testtools developers. See LICENSE for details.
+
+__all__ = [
+ 'ContainsAll',
+ 'MatchesListwise',
+ 'MatchesSetwise',
+ 'MatchesStructure',
+ ]
+
+"""Matchers that operate with knowledge of Python data structures."""
+
+from ..helpers import map_values
+from ._higherorder import (
+ Annotate,
+ MatchesAll,
+ MismatchesAll,
+ )
+from ._impl import Mismatch
+
+
+def ContainsAll(items):
+ """Make a matcher that checks whether a list of things is contained
+ in another thing.
+
+ The matcher effectively checks that the provided sequence is a subset of
+ the matchee.
+ """
+ from ._basic import Contains
+ return MatchesAll(*map(Contains, items), first_only=False)
+
+
+class MatchesListwise(object):
+ """Matches if each matcher matches the corresponding value.
+
+ More easily explained by example than in words:
+
+ >>> from ._basic import Equals
+ >>> MatchesListwise([Equals(1)]).match([1])
+ >>> MatchesListwise([Equals(1), Equals(2)]).match([1, 2])
+ >>> print (MatchesListwise([Equals(1), Equals(2)]).match([2, 1]).describe())
+ Differences: [
+ 1 != 2
+ 2 != 1
+ ]
+ >>> matcher = MatchesListwise([Equals(1), Equals(2)], first_only=True)
+ >>> print (matcher.match([3, 4]).describe())
+ 1 != 3
+ """
+
+ def __init__(self, matchers, first_only=False):
+ """Construct a MatchesListwise matcher.
+
+ :param matchers: A list of matcher that the matched values must match.
+ :param first_only: If True, then only report the first mismatch,
+ otherwise report all of them. Defaults to False.
+ """
+ self.matchers = matchers
+ self.first_only = first_only
+
+ def match(self, values):
+ from ._basic import Equals
+ mismatches = []
+ length_mismatch = Annotate(
+ "Length mismatch", Equals(len(self.matchers))).match(len(values))
+ if length_mismatch:
+ mismatches.append(length_mismatch)
+ for matcher, value in zip(self.matchers, values):
+ mismatch = matcher.match(value)
+ if mismatch:
+ if self.first_only:
+ return mismatch
+ mismatches.append(mismatch)
+ if mismatches:
+ return MismatchesAll(mismatches)
+
+
+class MatchesStructure(object):
+ """Matcher that matches an object structurally.
+
+ 'Structurally' here means that attributes of the object being matched are
+ compared against given matchers.
+
+ `fromExample` allows the creation of a matcher from a prototype object and
+ then modified versions can be created with `update`.
+
+ `byEquality` creates a matcher in much the same way as the constructor,
+ except that the matcher for each of the attributes is assumed to be
+ `Equals`.
+
+ `byMatcher` creates a similar matcher to `byEquality`, but you get to pick
+ the matcher, rather than just using `Equals`.
+ """
+
+ def __init__(self, **kwargs):
+ """Construct a `MatchesStructure`.
+
+ :param kwargs: A mapping of attributes to matchers.
+ """
+ self.kws = kwargs
+
+ @classmethod
+ def byEquality(cls, **kwargs):
+ """Matches an object where the attributes equal the keyword values.
+
+ Similar to the constructor, except that the matcher is assumed to be
+ Equals.
+ """
+ from ._basic import Equals
+ return cls.byMatcher(Equals, **kwargs)
+
+ @classmethod
+ def byMatcher(cls, matcher, **kwargs):
+ """Matches an object where the attributes match the keyword values.
+
+ Similar to the constructor, except that the provided matcher is used
+ to match all of the values.
+ """
+ return cls(**map_values(matcher, kwargs))
+
+ @classmethod
+ def fromExample(cls, example, *attributes):
+ from ._basic import Equals
+ kwargs = {}
+ for attr in attributes:
+ kwargs[attr] = Equals(getattr(example, attr))
+ return cls(**kwargs)
+
+ def update(self, **kws):
+ new_kws = self.kws.copy()
+ for attr, matcher in kws.items():
+ if matcher is None:
+ new_kws.pop(attr, None)
+ else:
+ new_kws[attr] = matcher
+ return type(self)(**new_kws)
+
+ def __str__(self):
+ kws = []
+ for attr, matcher in sorted(self.kws.items()):
+ kws.append("%s=%s" % (attr, matcher))
+ return "%s(%s)" % (self.__class__.__name__, ', '.join(kws))
+
+ def match(self, value):
+ matchers = []
+ values = []
+ for attr, matcher in sorted(self.kws.items()):
+ matchers.append(Annotate(attr, matcher))
+ values.append(getattr(value, attr))
+ return MatchesListwise(matchers).match(values)
+
+
+class MatchesSetwise(object):
+ """Matches if all the matchers match elements of the value being matched.
+
+ That is, each element in the 'observed' set must match exactly one matcher
+ from the set of matchers, with no matchers left over.
+
+ The difference compared to `MatchesListwise` is that the order of the
+ matchings does not matter.
+ """
+
+ def __init__(self, *matchers):
+ self.matchers = matchers
+
+ def match(self, observed):
+ remaining_matchers = set(self.matchers)
+ not_matched = []
+ for value in observed:
+ for matcher in remaining_matchers:
+ if matcher.match(value) is None:
+ remaining_matchers.remove(matcher)
+ break
+ else:
+ not_matched.append(value)
+ if not_matched or remaining_matchers:
+ remaining_matchers = list(remaining_matchers)
+ # There are various cases that all should be reported somewhat
+ # differently.
+
+ # There are two trivial cases:
+ # 1) There are just some matchers left over.
+ # 2) There are just some values left over.
+
+ # Then there are three more interesting cases:
+ # 3) There are the same number of matchers and values left over.
+ # 4) There are more matchers left over than values.
+ # 5) There are more values left over than matchers.
+
+ if len(not_matched) == 0:
+ if len(remaining_matchers) > 1:
+ msg = "There were %s matchers left over: " % (
+ len(remaining_matchers),)
+ else:
+ msg = "There was 1 matcher left over: "
+ msg += ', '.join(map(str, remaining_matchers))
+ return Mismatch(msg)
+ elif len(remaining_matchers) == 0:
+ if len(not_matched) > 1:
+ return Mismatch(
+ "There were %s values left over: %s" % (
+ len(not_matched), not_matched))
+ else:
+ return Mismatch(
+ "There was 1 value left over: %s" % (
+ not_matched, ))
+ else:
+ common_length = min(len(remaining_matchers), len(not_matched))
+ if common_length == 0:
+ raise AssertionError("common_length can't be 0 here")
+ if common_length > 1:
+ msg = "There were %s mismatches" % (common_length,)
+ else:
+ msg = "There was 1 mismatch"
+ if len(remaining_matchers) > len(not_matched):
+ extra_matchers = remaining_matchers[common_length:]
+ msg += " and %s extra matcher" % (len(extra_matchers), )
+ if len(extra_matchers) > 1:
+ msg += "s"
+ msg += ': ' + ', '.join(map(str, extra_matchers))
+ elif len(not_matched) > len(remaining_matchers):
+ extra_values = not_matched[common_length:]
+ msg += " and %s extra value" % (len(extra_values), )
+ if len(extra_values) > 1:
+ msg += "s"
+ msg += ': ' + str(extra_values)
+ return Annotate(
+ msg, MatchesListwise(remaining_matchers[:common_length])
+ ).match(not_matched[:common_length])