summaryrefslogtreecommitdiff
path: root/lib/testtools/testtools/matchers/_dict.py
blob: ff05199e6c14b4a33422f43ed9c7a852988af991 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
# Copyright (c) 2009-2012 testtools developers. See LICENSE for details.

# Names exported by ``from <this module> import *``.
# NOTE(review): MatchesDict, ContainsDict, ContainedByDict and MatchesAllDict
# look public but are not listed here -- confirm they are re-exported
# elsewhere (e.g. the package ``__init__``) and that this omission is intended.
__all__ = [
    'KeysEqual',
    ]

from ..helpers import (
    dict_subtract,
    filter_values,
    map_values,
    )
from ._higherorder import (
    AnnotatedMismatch,
    PrefixedMismatch,
    MismatchesAll,
    )
from ._impl import Matcher, Mismatch


def LabelledMismatches(mismatches, details=None):
    """A collection of mismatches, each labelled.

    :param mismatches: A dict mapping a label to a mismatch.
    :param details: Accepted so this can be used where a mismatch
        constructor taking ``details`` is expected; it is not used.
    :return: A ``MismatchesAll`` built with ``wrap=False`` over one
        ``PrefixedMismatch(label, mismatch)`` per entry, ordered by label.
    """
    # Build a concrete list instead of passing a generator expression.
    # MismatchesAll keeps the sequence it is given and (presumably) iterates
    # it each time it is described; a generator is exhausted after one pass,
    # so a second describe() would otherwise silently produce no entries.
    labelled = [
        PrefixedMismatch(label, mismatch)
        for (label, mismatch) in sorted(mismatches.items())]
    return MismatchesAll(labelled, wrap=False)


class MatchesAllDict(Matcher):
    """Match only when every matcher in a labelled dict matches.

    Behaves much like ``MatchesAll``, except the matchers are supplied as a
    dict, and each mismatch is reported under the key of the matcher that
    produced it.
    """

    def __init__(self, matchers):
        super(MatchesAllDict, self).__init__()
        # Dict mapping a label to the matcher checked under that label.
        self.matchers = matchers

    def __str__(self):
        rendered = _format_matcher_dict(self.matchers)
        return 'MatchesAllDict(%s)' % (rendered,)

    def match(self, observed):
        """Run every matcher against ``observed``.

        :return: ``None`` if no matcher reports a mismatch, otherwise a
            ``LabelledMismatches`` of the failing labels.
        """
        results = {}
        for label, matcher in self.matchers.items():
            results[label] = matcher.match(observed)
        return _dict_to_mismatch(results, result_mismatch=LabelledMismatches)


class DictMismatches(Mismatch):
    """A mismatch that holds a dict of child mismatches, keyed by label."""

    def __init__(self, mismatches, details=None):
        super(DictMismatches, self).__init__(None, details=details)
        # Dict mapping a key to the child mismatch for that key.
        self.mismatches = mismatches

    def describe(self):
        """Render the children as a dict-like block, sorted by key.

        The first line is ``{`` and the last is ``}``.  Between them, each
        child appears on its own line as ``  <repr(key)>: <description>,``.
        """
        ordered = sorted(self.mismatches.items())
        entry_lines = [
            '  %r: %s,' % (key, child.describe()) for key, child in ordered]
        return '\n'.join(['{'] + entry_lines + ['}'])


def _dict_to_mismatch(data, to_mismatch=None,
                      result_mismatch=DictMismatches):
    """Collapse a dict of match results into one mismatch, or ``None``.

    :param data: A dict of results.  If ``to_mismatch`` is given, its values
        are raw values to be converted first.
    :param to_mismatch: Optional callable.  When truthy, ``data`` is first
        passed through ``map_values(to_mismatch, data)``.
    :param result_mismatch: Callable used to wrap the remaining entries.
    :return: ``result_mismatch(remaining)``, where ``remaining`` is
        ``filter_values(bool, data)``.  Returns ``None`` when that result is
        empty or otherwise falsy.
    """
    converted = map_values(to_mismatch, data) if to_mismatch else data
    remaining = filter_values(bool, converted)
    if not remaining:
        return None
    return result_mismatch(remaining)


class _MatchCommonKeys(Matcher):
    """Match the keys an observed dict shares with a dict of matchers.

    Only keys present both in the matcher dict and in the observed dict are
    checked.  A key found on just one side is ignored.  The match succeeds
    when the value under every shared key satisfies its matcher.

    Thus::

      >>> structure = {'a': Equals('x'), 'b': Equals('y')}
      >>> print(_MatchCommonKeys(structure).match({'a': 'x', 'c': 'z'}))
      None
    """

    def __init__(self, dict_of_matchers):
        super(_MatchCommonKeys, self).__init__()
        # Dict mapping a key to the matcher for that key's observed value.
        self._matchers = dict_of_matchers

    def _compare_dicts(self, expected, observed):
        """Return ``{key: mismatch}`` for every shared key whose value fails."""
        shared_keys = set(expected.keys()) & set(observed.keys())
        failures = {}
        for shared in shared_keys:
            outcome = expected[shared].match(observed[shared])
            if outcome:
                failures[shared] = outcome
        return failures

    def match(self, observed):
        failures = self._compare_dicts(self._matchers, observed)
        return DictMismatches(failures) if failures else None


class _SubDictOf(Matcher):
    """Match if every key of the observed dict is a key of ``super_dict``."""

    def __init__(self, super_dict, format_value=repr):
        super(_SubDictOf, self).__init__()
        self.super_dict = super_dict
        # Callable that renders each excess value for the mismatch text.
        self.format_value = format_value

    def match(self, observed):
        """Report whatever ``dict_subtract(observed, super_dict)`` leaves over.

        Each leftover value becomes ``Mismatch(self.format_value(value))``.
        """
        def as_mismatch(value):
            return Mismatch(self.format_value(value))

        leftover = dict_subtract(observed, self.super_dict)
        return _dict_to_mismatch(leftover, as_mismatch)


class _SuperDictOf(Matcher):
    """Match if every key of ``sub_dict`` is a key of the observed dict."""

    def __init__(self, sub_dict, format_value=repr):
        super(_SuperDictOf, self).__init__()
        self.sub_dict = sub_dict
        # Callable that renders each missing value for the mismatch text.
        self.format_value = format_value

    def match(self, super_dict):
        # Swap the roles: the stored dict must be a sub-dict of the one being
        # matched, which is exactly what _SubDictOf checks in reverse.
        reversed_check = _SubDictOf(super_dict, self.format_value)
        return reversed_check.match(self.sub_dict)


def _format_matcher_dict(matchers):
    return '{%s}' % (
        ', '.join(sorted('%r: %s' % (k, v) for k, v in matchers.items())))


class _CombinedMatcher(Matcher):
    """Several labelled matchers combined into a single matcher.

    Subclasses set ``matcher_factories`` to a dict mapping a label to a
    factory.  Each factory takes the single ``expected`` value and returns a
    matcher.  The combined matcher matches only when every matcher built
    from those factories matches.

    Not **entirely** dissimilar from ``MatchesAll``.
    """

    # Label -> factory(expected) -> matcher; subclasses override this.
    matcher_factories = {}

    def __init__(self, expected):
        super(_CombinedMatcher, self).__init__()
        self._expected = expected

    def format_expected(self, expected):
        """Return the text used for ``expected`` in ``str(self)``."""
        return repr(expected)

    def __str__(self):
        class_name = self.__class__.__name__
        shown = self.format_expected(self._expected)
        return '%s(%s)' % (class_name, shown)

    def match(self, observed):
        # Matchers are built fresh from the factories on every call.
        built = {}
        for label, factory in self.matcher_factories.items():
            built[label] = factory(self._expected)
        return MatchesAllDict(built).match(observed)


class MatchesDict(_CombinedMatcher):
    """Match a dictionary exactly, by its keys.

    The 'expected' value is a dict mapping keys (often strings) to matchers.
    A matching dictionary has exactly the same keys, and each of its values
    satisfies the matcher stored under the same key.  Violations are
    reported under the labels 'Extra', 'Missing' and 'Differences'.
    """

    matcher_factories = {
        # Keys in the observed dict that the expected dict lacks.
        'Extra': _SubDictOf,
        # Keys in the expected dict that the observed dict lacks.
        'Missing': lambda m: _SuperDictOf(m, format_value=str),
        # Shared keys whose values fail their matcher.
        'Differences': _MatchCommonKeys,
        }

    def format_expected(self, expected):
        return _format_matcher_dict(expected)


class ContainsDict(_CombinedMatcher):
    """Match a dictionary that contains a specified sub-dictionary.

    The 'expected' value is a dict mapping keys (often strings) to matchers.
    A matching dictionary has **at least** these keys, and each of those
    values satisfies the matcher stored under the same key.  Additional keys
    are allowed.

    In other words, any matching dictionary must contain the dictionary
    given to the constructor.

    This does not require a strict sub-dictionary: equal dictionaries match.
    """

    matcher_factories = {
        # Keys in the expected dict that the observed dict lacks.
        'Missing': lambda m: _SuperDictOf(m, format_value=str),
        # Shared keys whose values fail their matcher.
        'Differences': _MatchCommonKeys,
        }

    def format_expected(self, expected):
        return _format_matcher_dict(expected)


class ContainedByDict(_CombinedMatcher):
    """Match a dictionary for which the expected dict is a super-dictionary.

    The 'expected' value is a dict mapping keys (often strings) to matchers.
    A matching dictionary has **only** keys from this dict, and each of its
    values satisfies the matcher stored under the same key.  Having fewer
    keys is allowed.

    In other words, any matching dictionary must be contained by the
    dictionary given to the constructor.

    This does not require a strict super-dictionary: equal dictionaries match.
    """

    matcher_factories = {
        # Keys in the observed dict that the expected dict lacks.
        'Extra': _SubDictOf,
        # Shared keys whose values fail their matcher.
        'Differences': _MatchCommonKeys,
        }

    def format_expected(self, expected):
        return _format_matcher_dict(expected)


class KeysEqual(Matcher):
    """Checks whether a dict has particular keys."""

    def __init__(self, *expected):
        """Create a `KeysEqual` Matcher.

        :param expected: The keys the dict is expected to have.  If exactly
            one argument is given and it has a ``keys()`` method (e.g. a
            dict), the keys of that object are used.  Otherwise each
            positional argument is one expected key.
        """
        super(KeysEqual, self).__init__()
        # ``*expected`` always arrives as a tuple, and tuples have no
        # ``keys()``.  The "pass a dict" case must therefore inspect the sole
        # argument, not the tuple itself.  Previously it never triggered, and
        # KeysEqual({'a': 1}) expected the dict object as its only key.
        if len(expected) == 1:
            try:
                expected = expected[0].keys()
            except AttributeError:
                pass
        self.expected = list(expected)

    def __str__(self):
        return "KeysEqual(%s)" % ', '.join(map(repr, self.expected))

    def match(self, matchee):
        """Compare the sorted keys of ``matchee`` with the expected keys.

        :return: ``None`` if the sorted key lists are equal.  Otherwise, an
            ``AnnotatedMismatch`` labelled 'Keys not equal'.
        """
        # Imported here rather than at module level; presumably this avoids
        # a circular import between the matcher modules -- confirm.
        from ._basic import _BinaryMismatch, Equals
        expected = sorted(self.expected)
        matched = Equals(expected).match(sorted(matchee.keys()))
        if matched:
            return AnnotatedMismatch(
                'Keys not equal',
                _BinaryMismatch(expected, 'does not match', matchee))
        return None