diff options
Diffstat (limited to 'lib/dnspython/dns/resolver.py')
-rw-r--r-- | lib/dnspython/dns/resolver.py | 412 |
1 files changed, 400 insertions, 12 deletions
diff --git a/lib/dnspython/dns/resolver.py b/lib/dnspython/dns/resolver.py index 30977f3a8b..90f95e8ed0 100644 --- a/lib/dnspython/dns/resolver.py +++ b/lib/dnspython/dns/resolver.py @@ -1,4 +1,4 @@ -# Copyright (C) 2003-2007, 2009, 2010 Nominum, Inc. +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose with or without fee is hereby granted, @@ -23,12 +23,15 @@ import sys import time import dns.exception +import dns.ipv4 +import dns.ipv6 import dns.message import dns.name import dns.query import dns.rcode import dns.rdataclass import dns.rdatatype +import dns.reversename if sys.platform == 'win32': import _winreg @@ -93,8 +96,11 @@ class Answer(object): @type rrset: dns.rrset.RRset object @ivar expiration: The time when the answer expires @type expiration: float (seconds since the epoch) + @ivar canonical_name: The canonical name of the query name + @type canonical_name: dns.name.Name object """ - def __init__(self, qname, rdtype, rdclass, response): + def __init__(self, qname, rdtype, rdclass, response, + raise_on_no_answer=True): self.qname = qname self.rdtype = rdtype self.rdclass = rdclass @@ -122,11 +128,31 @@ class Answer(object): break continue except KeyError: - raise NoAnswer - raise NoAnswer - if rrset is None: + if raise_on_no_answer: + raise NoAnswer + if raise_on_no_answer: + raise NoAnswer + if rrset is None and raise_on_no_answer: raise NoAnswer + self.canonical_name = qname self.rrset = rrset + if rrset is None: + while 1: + # Look for a SOA RR whose owner name is a superdomain + # of qname. + try: + srrset = response.find_rrset(response.authority, qname, + rdclass, dns.rdatatype.SOA) + if min_ttl == -1 or srrset.ttl < min_ttl: + min_ttl = srrset.ttl + if srrset[0].minimum < min_ttl: + min_ttl = srrset[0].minimum + break + except KeyError: + try: + qname = qname.parent() + except dns.name.NoParent: + break self.expiration = time.time() + min_ttl def __getattr__(self, attr): @@ -244,6 +270,127 @@ class Cache(object): self.data = {} self.next_cleaning = time.time() + self.cleaning_interval +class LRUCacheNode(object): + """LRUCache node. + """ + def __init__(self, key, value): + self.key = key + self.value = value + self.prev = self + self.next = self + + def link_before(self, node): + self.prev = node.prev + self.next = node + node.prev.next = self + node.prev = self + + def link_after(self, node): + self.prev = node + self.next = node.next + node.next.prev = self + node.next = self + + def unlink(self): + self.next.prev = self.prev + self.prev.next = self.next + +class LRUCache(object): + """Bounded least-recently-used DNS answer cache. + + This cache is better than the simple cache (above) if you're + running a web crawler or other process that does a lot of + resolutions. The LRUCache has a maximum number of nodes, and when + it is full, the least-recently used node is removed to make space + for a new one. + + @ivar data: A dictionary of cached data + @type data: dict + @ivar sentinel: sentinel node for circular doubly linked list of nodes + @type sentinel: LRUCacheNode object + @ivar max_size: The maximum number of nodes + @type max_size: int + """ + + def __init__(self, max_size=100000): + """Initialize a DNS cache. + + @param max_size: The maximum number of nodes to cache; the default is + 100000. Must be > 1. + @type max_size: int + """ + self.data = {} + self.set_max_size(max_size) + self.sentinel = LRUCacheNode(None, None) + + def set_max_size(self, max_size): + if max_size < 1: + max_size = 1 + self.max_size = max_size + + def get(self, key): + """Get the answer associated with I{key}. Returns None if + no answer is cached for the key. + @param key: the key + @type key: (dns.name.Name, int, int) tuple whose values are the + query name, rdtype, and rdclass. + @rtype: dns.resolver.Answer object or None + """ + node = self.data.get(key) + if node is None: + return None + # Unlink because we're either going to move the node to the front + # of the LRU list or we're going to free it. + node.unlink() + if node.value.expiration <= time.time(): + del self.data[node.key] + return None + node.link_after(self.sentinel) + return node.value + + def put(self, key, value): + """Associate key and value in the cache. + @param key: the key + @type key: (dns.name.Name, int, int) tuple whose values are the + query name, rdtype, and rdclass. + @param value: The answer being cached + @type value: dns.resolver.Answer object + """ + node = self.data.get(key) + if not node is None: + node.unlink() + del self.data[node.key] + while len(self.data) >= self.max_size: + node = self.sentinel.prev + node.unlink() + del self.data[node.key] + node = LRUCacheNode(key, value) + node.link_after(self.sentinel) + self.data[key] = node + + def flush(self, key=None): + """Flush the cache. + + If I{key} is specified, only that item is flushed. Otherwise + the entire cache is flushed. + + @param key: the key to flush + @type key: (dns.name.Name, int, int) tuple or None + """ + if not key is None: + node = self.data.get(key) + if not node is None: + node.unlink() + del self.data[node.key] + else: + node = self.sentinel.next + while node != self.sentinel: + next = node.next + node.prev = None + node.next = None + node = next + self.data = {} + class Resolver(object): """DNS stub resolver @@ -546,7 +693,7 @@ class Resolver(object): return min(self.lifetime - duration, self.timeout) def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None): + tcp=False, source=None, raise_on_no_answer=True): """Query nameservers to find the answer to the question. The I{qname}, I{rdtype}, and I{rdclass} parameters may be objects @@ -564,10 +711,14 @@ class Resolver(object): @type tcp: bool @param source: bind to this IP address (defaults to machine default IP). @type source: IP address in dotted quad notation + @param raise_on_no_answer: raise NoAnswer if there's no answer + (defaults is True). + @type raise_on_no_answer: bool @rtype: dns.resolver.Answer instance @raises Timeout: no answers could be found in the specified lifetime @raises NXDOMAIN: the query name does not exist - @raises NoAnswer: the response did not contain an answer + @raises NoAnswer: the response did not contain an answer and + raise_on_no_answer is True. @raises NoNameservers: no non-broken nameservers are available to answer the question.""" @@ -597,8 +748,11 @@ class Resolver(object): for qname in qnames_to_try: if self.cache: answer = self.cache.get((qname, rdtype, rdclass)) - if answer: - return answer + if not answer is None: + if answer.rrset is None and raise_on_no_answer: + raise NoAnswer + else: + return answer request = dns.message.make_query(qname, rdtype, rdclass) if not self.keyname is None: request.use_tsig(self.keyring, self.keyname, @@ -678,7 +832,8 @@ class Resolver(object): break if all_nxdomain: raise NXDOMAIN - answer = Answer(qname, rdtype, rdclass, response) + answer = Answer(qname, rdtype, rdclass, response, + raise_on_no_answer) if self.cache: self.cache.put((qname, rdtype, rdclass), answer) return answer @@ -731,14 +886,15 @@ def get_default_resolver(): return default_resolver def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None): + tcp=False, source=None, raise_on_no_answer=True): """Query nameservers to find the answer to the question. This is a convenience function that uses the default resolver object to make the query. @see: L{dns.resolver.Resolver.query} for more information on the parameters.""" - return get_default_resolver().query(qname, rdtype, rdclass, tcp, source) + return get_default_resolver().query(qname, rdtype, rdclass, tcp, source, + raise_on_no_answer) def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None): """Find the name of the zone which contains the specified name. @@ -771,3 +927,235 @@ def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None): name = name.parent() except dns.name.NoParent: raise NoRootSOA + +# +# Support for overriding the system resolver for all python code in the +# running process. +# + +_protocols_for_socktype = { + socket.SOCK_DGRAM : [socket.SOL_UDP], + socket.SOCK_STREAM : [socket.SOL_TCP], + } + +_resolver = None +_original_getaddrinfo = socket.getaddrinfo +_original_getnameinfo = socket.getnameinfo +_original_getfqdn = socket.getfqdn +_original_gethostbyname = socket.gethostbyname +_original_gethostbyname_ex = socket.gethostbyname_ex +_original_gethostbyaddr = socket.gethostbyaddr + +def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0, + proto=0, flags=0): + if flags & (socket.AI_ADDRCONFIG|socket.AI_V4MAPPED) != 0: + raise NotImplementedError + if host is None and service is None: + raise socket.gaierror(socket.EAI_NONAME) + v6addrs = [] + v4addrs = [] + canonical_name = None + try: + # Is host None or a V6 address literal? + if host is None: + canonical_name = 'localhost' + if flags & socket.AI_PASSIVE != 0: + v6addrs.append('::') + v4addrs.append('0.0.0.0') + else: + v6addrs.append('::1') + v4addrs.append('127.0.0.1') + else: + parts = host.split('%') + if len(parts) == 2: + ahost = parts[0] + else: + ahost = host + addr = dns.ipv6.inet_aton(ahost) + v6addrs.append(host) + canonical_name = host + except: + try: + # Is it a V4 address literal? + addr = dns.ipv4.inet_aton(host) + v4addrs.append(host) + canonical_name = host + except: + if flags & socket.AI_NUMERICHOST == 0: + try: + qname = None + if family == socket.AF_INET6 or family == socket.AF_UNSPEC: + v6 = _resolver.query(host, dns.rdatatype.AAAA, + raise_on_no_answer=False) + # Note that setting host ensures we query the same name + # for A as we did for AAAA. + host = v6.qname + canonical_name = v6.canonical_name.to_text(True) + if v6.rrset is not None: + for rdata in v6.rrset: + v6addrs.append(rdata.address) + if family == socket.AF_INET or family == socket.AF_UNSPEC: + v4 = _resolver.query(host, dns.rdatatype.A, + raise_on_no_answer=False) + host = v4.qname + canonical_name = v4.canonical_name.to_text(True) + if v4.rrset is not None: + for rdata in v4.rrset: + v4addrs.append(rdata.address) + except dns.resolver.NXDOMAIN: + raise socket.gaierror(socket.EAI_NONAME) + except: + raise socket.gaierror(socket.EAI_SYSTEM) + port = None + try: + # Is it a port literal? + if service is None: + port = 0 + else: + port = int(service) + except: + if flags & socket.AI_NUMERICSERV == 0: + try: + port = socket.getservbyname(service) + except: + pass + if port is None: + raise socket.gaierror(socket.EAI_NONAME) + tuples = [] + if socktype == 0: + socktypes = [socket.SOCK_DGRAM, socket.SOCK_STREAM] + else: + socktypes = [socktype] + if flags & socket.AI_CANONNAME != 0: + cname = canonical_name + else: + cname = '' + if family == socket.AF_INET6 or family == socket.AF_UNSPEC: + for addr in v6addrs: + for socktype in socktypes: + for proto in _protocols_for_socktype[socktype]: + tuples.append((socket.AF_INET6, socktype, proto, + cname, (addr, port, 0, 0))) + if family == socket.AF_INET or family == socket.AF_UNSPEC: + for addr in v4addrs: + for socktype in socktypes: + for proto in _protocols_for_socktype[socktype]: + tuples.append((socket.AF_INET, socktype, proto, + cname, (addr, port))) + if len(tuples) == 0: + raise socket.gaierror(socket.EAI_NONAME) + return tuples + +def _getnameinfo(sockaddr, flags=0): + host = sockaddr[0] + port = sockaddr[1] + if len(sockaddr) == 4: + scope = sockaddr[3] + family = socket.AF_INET6 + else: + scope = None + family = socket.AF_INET + tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM, + socket.SOL_TCP, 0) + if len(tuples) > 1: + raise socket.error('sockaddr resolved to multiple addresses') + addr = tuples[0][4][0] + if flags & socket.NI_DGRAM: + pname = 'udp' + else: + pname = 'tcp' + qname = dns.reversename.from_address(addr) + if flags & socket.NI_NUMERICHOST == 0: + try: + answer = _resolver.query(qname, 'PTR') + hostname = answer.rrset[0].target.to_text(True) + except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): + if flags & socket.NI_NAMEREQD: + raise socket.gaierror(socket.EAI_NONAME) + hostname = addr + if scope is not None: + hostname += '%' + str(scope) + else: + hostname = addr + if scope is not None: + hostname += '%' + str(scope) + if flags & socket.NI_NUMERICSERV: + service = str(port) + else: + service = socket.getservbyport(port, pname) + return (hostname, service) + +def _getfqdn(name=None): + if name is None: + name = socket.gethostname() + return _getnameinfo(_getaddrinfo(name, 80)[0][4])[0] + +def _gethostbyname(name): + return _gethostbyname_ex(name)[2][0] + +def _gethostbyname_ex(name): + aliases = [] + addresses = [] + tuples = _getaddrinfo(name, 0, socket.AF_INET, socket.SOCK_STREAM, + socket.SOL_TCP, socket.AI_CANONNAME) + canonical = tuples[0][3] + for item in tuples: + addresses.append(item[4][0]) + # XXX we just ignore aliases + return (canonical, aliases, addresses) + +def _gethostbyaddr(ip): + try: + addr = dns.ipv6.inet_aton(ip) + sockaddr = (ip, 80, 0, 0) + family = socket.AF_INET6 + except: + sockaddr = (ip, 80) + family = socket.AF_INET + (name, port) = _getnameinfo(sockaddr, socket.NI_NAMEREQD) + aliases = [] + addresses = [] + tuples = _getaddrinfo(name, 0, family, socket.SOCK_STREAM, socket.SOL_TCP, + socket.AI_CANONNAME) + canonical = tuples[0][3] + for item in tuples: + addresses.append(item[4][0]) + # XXX we just ignore aliases + return (canonical, aliases, addresses) + +def override_system_resolver(resolver=None): + """Override the system resolver routines in the socket module with + versions which use dnspython's resolver. + + This can be useful in testing situations where you want to control + the resolution behavior of python code without having to change + the system's resolver settings (e.g. /etc/resolv.conf). + + The resolver to use may be specified; if it's not, the default + resolver will be used. + + @param resolver: the resolver to use + @type resolver: dns.resolver.Resolver object or None + """ + if resolver is None: + resolver = get_default_resolver() + global _resolver + _resolver = resolver + socket.getaddrinfo = _getaddrinfo + socket.getnameinfo = _getnameinfo + socket.getfqdn = _getfqdn + socket.gethostbyname = _gethostbyname + socket.gethostbyname_ex = _gethostbyname_ex + socket.gethostbyaddr = _gethostbyaddr + +def restore_system_resolver(): + """Undo the effects of override_system_resolver(). + """ + global _resolver + _resolver = None + socket.getaddrinfo = _original_getaddrinfo + socket.getnameinfo = _original_getnameinfo + socket.getfqdn = _original_getfqdn + socket.gethostbyname = _original_gethostbyname + socket.gethostbyname_ex = _original_gethostbyname_ex + socket.gethostbyaddr = _original_gethostbyaddr |