summaryrefslogtreecommitdiff
path: root/lib/dnspython/dns/resolver.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/dnspython/dns/resolver.py')
-rw-r--r--lib/dnspython/dns/resolver.py412
1 files changed, 400 insertions, 12 deletions
diff --git a/lib/dnspython/dns/resolver.py b/lib/dnspython/dns/resolver.py
index 30977f3a8b..90f95e8ed0 100644
--- a/lib/dnspython/dns/resolver.py
+++ b/lib/dnspython/dns/resolver.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2003-2007, 2009, 2010 Nominum, Inc.
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
@@ -23,12 +23,15 @@ import sys
import time
import dns.exception
+import dns.ipv4
+import dns.ipv6
import dns.message
import dns.name
import dns.query
import dns.rcode
import dns.rdataclass
import dns.rdatatype
+import dns.reversename
if sys.platform == 'win32':
import _winreg
@@ -93,8 +96,11 @@ class Answer(object):
@type rrset: dns.rrset.RRset object
@ivar expiration: The time when the answer expires
@type expiration: float (seconds since the epoch)
+ @ivar canonical_name: The canonical name of the query name
+ @type canonical_name: dns.name.Name object
"""
- def __init__(self, qname, rdtype, rdclass, response):
+ def __init__(self, qname, rdtype, rdclass, response,
+ raise_on_no_answer=True):
self.qname = qname
self.rdtype = rdtype
self.rdclass = rdclass
@@ -122,11 +128,31 @@ class Answer(object):
break
continue
except KeyError:
- raise NoAnswer
- raise NoAnswer
- if rrset is None:
+ if raise_on_no_answer:
+ raise NoAnswer
+ if raise_on_no_answer:
+ raise NoAnswer
+ if rrset is None and raise_on_no_answer:
raise NoAnswer
+ self.canonical_name = qname
self.rrset = rrset
+ if rrset is None:
+ while 1:
+ # Look for a SOA RR whose owner name is a superdomain
+ # of qname.
+ try:
+ srrset = response.find_rrset(response.authority, qname,
+ rdclass, dns.rdatatype.SOA)
+ if min_ttl == -1 or srrset.ttl < min_ttl:
+ min_ttl = srrset.ttl
+ if srrset[0].minimum < min_ttl:
+ min_ttl = srrset[0].minimum
+ break
+ except KeyError:
+ try:
+ qname = qname.parent()
+ except dns.name.NoParent:
+ break
self.expiration = time.time() + min_ttl
def __getattr__(self, attr):
@@ -244,6 +270,127 @@ class Cache(object):
self.data = {}
self.next_cleaning = time.time() + self.cleaning_interval
+class LRUCacheNode(object):
+ """LRUCache node.
+ """
+ def __init__(self, key, value):
+ self.key = key
+ self.value = value
+ self.prev = self
+ self.next = self
+
+ def link_before(self, node):
+ self.prev = node.prev
+ self.next = node
+ node.prev.next = self
+ node.prev = self
+
+ def link_after(self, node):
+ self.prev = node
+ self.next = node.next
+ node.next.prev = self
+ node.next = self
+
+ def unlink(self):
+ self.next.prev = self.prev
+ self.prev.next = self.next
+
+class LRUCache(object):
+ """Bounded least-recently-used DNS answer cache.
+
+ This cache is better than the simple cache (above) if you're
+ running a web crawler or other process that does a lot of
+ resolutions. The LRUCache has a maximum number of nodes, and when
+ it is full, the least-recently used node is removed to make space
+ for a new one.
+
+ @ivar data: A dictionary of cached data
+ @type data: dict
+ @ivar sentinel: sentinel node for circular doubly linked list of nodes
+ @type sentinel: LRUCacheNode object
+ @ivar max_size: The maximum number of nodes
+ @type max_size: int
+ """
+
+ def __init__(self, max_size=100000):
+ """Initialize a DNS cache.
+
+ @param max_size: The maximum number of nodes to cache; the default is
+ 100000. Must be > 1.
+ @type max_size: int
+ """
+ self.data = {}
+ self.set_max_size(max_size)
+ self.sentinel = LRUCacheNode(None, None)
+
+ def set_max_size(self, max_size):
+ if max_size < 1:
+ max_size = 1
+ self.max_size = max_size
+
+ def get(self, key):
+ """Get the answer associated with I{key}. Returns None if
+ no answer is cached for the key.
+ @param key: the key
+ @type key: (dns.name.Name, int, int) tuple whose values are the
+ query name, rdtype, and rdclass.
+ @rtype: dns.resolver.Answer object or None
+ """
+ node = self.data.get(key)
+ if node is None:
+ return None
+ # Unlink because we're either going to move the node to the front
+ # of the LRU list or we're going to free it.
+ node.unlink()
+ if node.value.expiration <= time.time():
+ del self.data[node.key]
+ return None
+ node.link_after(self.sentinel)
+ return node.value
+
+ def put(self, key, value):
+ """Associate key and value in the cache.
+ @param key: the key
+ @type key: (dns.name.Name, int, int) tuple whose values are the
+ query name, rdtype, and rdclass.
+ @param value: The answer being cached
+ @type value: dns.resolver.Answer object
+ """
+ node = self.data.get(key)
+ if not node is None:
+ node.unlink()
+ del self.data[node.key]
+ while len(self.data) >= self.max_size:
+ node = self.sentinel.prev
+ node.unlink()
+ del self.data[node.key]
+ node = LRUCacheNode(key, value)
+ node.link_after(self.sentinel)
+ self.data[key] = node
+
+ def flush(self, key=None):
+ """Flush the cache.
+
+ If I{key} is specified, only that item is flushed. Otherwise
+ the entire cache is flushed.
+
+ @param key: the key to flush
+ @type key: (dns.name.Name, int, int) tuple or None
+ """
+ if not key is None:
+ node = self.data.get(key)
+ if not node is None:
+ node.unlink()
+ del self.data[node.key]
+ else:
+ node = self.sentinel.next
+ while node != self.sentinel:
+ next = node.next
+ node.prev = None
+ node.next = None
+ node = next
+ self.data = {}
+
class Resolver(object):
"""DNS stub resolver
@@ -546,7 +693,7 @@ class Resolver(object):
return min(self.lifetime - duration, self.timeout)
def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
- tcp=False, source=None):
+ tcp=False, source=None, raise_on_no_answer=True):
"""Query nameservers to find the answer to the question.
The I{qname}, I{rdtype}, and I{rdclass} parameters may be objects
@@ -564,10 +711,14 @@ class Resolver(object):
@type tcp: bool
@param source: bind to this IP address (defaults to machine default IP).
@type source: IP address in dotted quad notation
+ @param raise_on_no_answer: raise NoAnswer if there's no answer
+ (defaults is True).
+ @type raise_on_no_answer: bool
@rtype: dns.resolver.Answer instance
@raises Timeout: no answers could be found in the specified lifetime
@raises NXDOMAIN: the query name does not exist
- @raises NoAnswer: the response did not contain an answer
+ @raises NoAnswer: the response did not contain an answer and
+ raise_on_no_answer is True.
@raises NoNameservers: no non-broken nameservers are available to
answer the question."""
@@ -597,8 +748,11 @@ class Resolver(object):
for qname in qnames_to_try:
if self.cache:
answer = self.cache.get((qname, rdtype, rdclass))
- if answer:
- return answer
+ if not answer is None:
+ if answer.rrset is None and raise_on_no_answer:
+ raise NoAnswer
+ else:
+ return answer
request = dns.message.make_query(qname, rdtype, rdclass)
if not self.keyname is None:
request.use_tsig(self.keyring, self.keyname,
@@ -678,7 +832,8 @@ class Resolver(object):
break
if all_nxdomain:
raise NXDOMAIN
- answer = Answer(qname, rdtype, rdclass, response)
+ answer = Answer(qname, rdtype, rdclass, response,
+ raise_on_no_answer)
if self.cache:
self.cache.put((qname, rdtype, rdclass), answer)
return answer
@@ -731,14 +886,15 @@ def get_default_resolver():
return default_resolver
def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
- tcp=False, source=None):
+ tcp=False, source=None, raise_on_no_answer=True):
"""Query nameservers to find the answer to the question.
This is a convenience function that uses the default resolver
object to make the query.
@see: L{dns.resolver.Resolver.query} for more information on the
parameters."""
- return get_default_resolver().query(qname, rdtype, rdclass, tcp, source)
+ return get_default_resolver().query(qname, rdtype, rdclass, tcp, source,
+ raise_on_no_answer)
def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None):
"""Find the name of the zone which contains the specified name.
@@ -771,3 +927,235 @@ def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None):
name = name.parent()
except dns.name.NoParent:
raise NoRootSOA
+
+#
+# Support for overriding the system resolver for all python code in the
+# running process.
+#
+
+_protocols_for_socktype = {
+ socket.SOCK_DGRAM : [socket.SOL_UDP],
+ socket.SOCK_STREAM : [socket.SOL_TCP],
+ }
+
+_resolver = None
+_original_getaddrinfo = socket.getaddrinfo
+_original_getnameinfo = socket.getnameinfo
+_original_getfqdn = socket.getfqdn
+_original_gethostbyname = socket.gethostbyname
+_original_gethostbyname_ex = socket.gethostbyname_ex
+_original_gethostbyaddr = socket.gethostbyaddr
+
+def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0,
+ proto=0, flags=0):
+ if flags & (socket.AI_ADDRCONFIG|socket.AI_V4MAPPED) != 0:
+ raise NotImplementedError
+ if host is None and service is None:
+ raise socket.gaierror(socket.EAI_NONAME)
+ v6addrs = []
+ v4addrs = []
+ canonical_name = None
+ try:
+ # Is host None or a V6 address literal?
+ if host is None:
+ canonical_name = 'localhost'
+ if flags & socket.AI_PASSIVE != 0:
+ v6addrs.append('::')
+ v4addrs.append('0.0.0.0')
+ else:
+ v6addrs.append('::1')
+ v4addrs.append('127.0.0.1')
+ else:
+ parts = host.split('%')
+ if len(parts) == 2:
+ ahost = parts[0]
+ else:
+ ahost = host
+ addr = dns.ipv6.inet_aton(ahost)
+ v6addrs.append(host)
+ canonical_name = host
+ except:
+ try:
+ # Is it a V4 address literal?
+ addr = dns.ipv4.inet_aton(host)
+ v4addrs.append(host)
+ canonical_name = host
+ except:
+ if flags & socket.AI_NUMERICHOST == 0:
+ try:
+ qname = None
+ if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
+ v6 = _resolver.query(host, dns.rdatatype.AAAA,
+ raise_on_no_answer=False)
+ # Note that setting host ensures we query the same name
+ # for A as we did for AAAA.
+ host = v6.qname
+ canonical_name = v6.canonical_name.to_text(True)
+ if v6.rrset is not None:
+ for rdata in v6.rrset:
+ v6addrs.append(rdata.address)
+ if family == socket.AF_INET or family == socket.AF_UNSPEC:
+ v4 = _resolver.query(host, dns.rdatatype.A,
+ raise_on_no_answer=False)
+ host = v4.qname
+ canonical_name = v4.canonical_name.to_text(True)
+ if v4.rrset is not None:
+ for rdata in v4.rrset:
+ v4addrs.append(rdata.address)
+ except dns.resolver.NXDOMAIN:
+ raise socket.gaierror(socket.EAI_NONAME)
+ except:
+ raise socket.gaierror(socket.EAI_SYSTEM)
+ port = None
+ try:
+ # Is it a port literal?
+ if service is None:
+ port = 0
+ else:
+ port = int(service)
+ except:
+ if flags & socket.AI_NUMERICSERV == 0:
+ try:
+ port = socket.getservbyname(service)
+ except:
+ pass
+ if port is None:
+ raise socket.gaierror(socket.EAI_NONAME)
+ tuples = []
+ if socktype == 0:
+ socktypes = [socket.SOCK_DGRAM, socket.SOCK_STREAM]
+ else:
+ socktypes = [socktype]
+ if flags & socket.AI_CANONNAME != 0:
+ cname = canonical_name
+ else:
+ cname = ''
+ if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
+ for addr in v6addrs:
+ for socktype in socktypes:
+ for proto in _protocols_for_socktype[socktype]:
+ tuples.append((socket.AF_INET6, socktype, proto,
+ cname, (addr, port, 0, 0)))
+ if family == socket.AF_INET or family == socket.AF_UNSPEC:
+ for addr in v4addrs:
+ for socktype in socktypes:
+ for proto in _protocols_for_socktype[socktype]:
+ tuples.append((socket.AF_INET, socktype, proto,
+ cname, (addr, port)))
+ if len(tuples) == 0:
+ raise socket.gaierror(socket.EAI_NONAME)
+ return tuples
+
+def _getnameinfo(sockaddr, flags=0):
+ host = sockaddr[0]
+ port = sockaddr[1]
+ if len(sockaddr) == 4:
+ scope = sockaddr[3]
+ family = socket.AF_INET6
+ else:
+ scope = None
+ family = socket.AF_INET
+ tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM,
+ socket.SOL_TCP, 0)
+ if len(tuples) > 1:
+ raise socket.error('sockaddr resolved to multiple addresses')
+ addr = tuples[0][4][0]
+ if flags & socket.NI_DGRAM:
+ pname = 'udp'
+ else:
+ pname = 'tcp'
+ qname = dns.reversename.from_address(addr)
+ if flags & socket.NI_NUMERICHOST == 0:
+ try:
+ answer = _resolver.query(qname, 'PTR')
+ hostname = answer.rrset[0].target.to_text(True)
+ except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
+ if flags & socket.NI_NAMEREQD:
+ raise socket.gaierror(socket.EAI_NONAME)
+ hostname = addr
+ if scope is not None:
+ hostname += '%' + str(scope)
+ else:
+ hostname = addr
+ if scope is not None:
+ hostname += '%' + str(scope)
+ if flags & socket.NI_NUMERICSERV:
+ service = str(port)
+ else:
+ service = socket.getservbyport(port, pname)
+ return (hostname, service)
+
+def _getfqdn(name=None):
+ if name is None:
+ name = socket.gethostname()
+ return _getnameinfo(_getaddrinfo(name, 80)[0][4])[0]
+
+def _gethostbyname(name):
+ return _gethostbyname_ex(name)[2][0]
+
+def _gethostbyname_ex(name):
+ aliases = []
+ addresses = []
+ tuples = _getaddrinfo(name, 0, socket.AF_INET, socket.SOCK_STREAM,
+ socket.SOL_TCP, socket.AI_CANONNAME)
+ canonical = tuples[0][3]
+ for item in tuples:
+ addresses.append(item[4][0])
+ # XXX we just ignore aliases
+ return (canonical, aliases, addresses)
+
+def _gethostbyaddr(ip):
+ try:
+ addr = dns.ipv6.inet_aton(ip)
+ sockaddr = (ip, 80, 0, 0)
+ family = socket.AF_INET6
+ except:
+ sockaddr = (ip, 80)
+ family = socket.AF_INET
+ (name, port) = _getnameinfo(sockaddr, socket.NI_NAMEREQD)
+ aliases = []
+ addresses = []
+ tuples = _getaddrinfo(name, 0, family, socket.SOCK_STREAM, socket.SOL_TCP,
+ socket.AI_CANONNAME)
+ canonical = tuples[0][3]
+ for item in tuples:
+ addresses.append(item[4][0])
+ # XXX we just ignore aliases
+ return (canonical, aliases, addresses)
+
+def override_system_resolver(resolver=None):
+ """Override the system resolver routines in the socket module with
+ versions which use dnspython's resolver.
+
+ This can be useful in testing situations where you want to control
+ the resolution behavior of python code without having to change
+ the system's resolver settings (e.g. /etc/resolv.conf).
+
+ The resolver to use may be specified; if it's not, the default
+ resolver will be used.
+
+ @param resolver: the resolver to use
+ @type resolver: dns.resolver.Resolver object or None
+ """
+ if resolver is None:
+ resolver = get_default_resolver()
+ global _resolver
+ _resolver = resolver
+ socket.getaddrinfo = _getaddrinfo
+ socket.getnameinfo = _getnameinfo
+ socket.getfqdn = _getfqdn
+ socket.gethostbyname = _gethostbyname
+ socket.gethostbyname_ex = _gethostbyname_ex
+ socket.gethostbyaddr = _gethostbyaddr
+
+def restore_system_resolver():
+ """Undo the effects of override_system_resolver().
+ """
+ global _resolver
+ _resolver = None
+ socket.getaddrinfo = _original_getaddrinfo
+ socket.getnameinfo = _original_getnameinfo
+ socket.getfqdn = _original_getfqdn
+ socket.gethostbyname = _original_gethostbyname
+ socket.gethostbyname_ex = _original_gethostbyname_ex
+ socket.gethostbyaddr = _original_gethostbyaddr