diff options
Diffstat (limited to 'lib/dnspython/dns/query.py')
-rw-r--r-- | lib/dnspython/dns/query.py | 84 |
1 files changed, 74 insertions, 10 deletions
diff --git a/lib/dnspython/dns/query.py b/lib/dnspython/dns/query.py index c023b140af..9dc88a635c 100644 --- a/lib/dnspython/dns/query.py +++ b/lib/dnspython/dns/query.py @@ -45,7 +45,59 @@ def _compute_expiration(timeout): else: return time.time() + timeout -def _wait_for(ir, iw, ix, expiration): +def _poll_for(fd, readable, writable, error, timeout): + """ + @param fd: File descriptor (int). + @param readable: Whether to wait for readability (bool). + @param writable: Whether to wait for writability (bool). + @param expiration: Deadline timeout (expiration time, in seconds (float)). + + @return True on success, False on timeout + """ + event_mask = 0 + if readable: + event_mask |= select.POLLIN + if writable: + event_mask |= select.POLLOUT + if error: + event_mask |= select.POLLERR + + pollable = select.poll() + pollable.register(fd, event_mask) + + if timeout: + event_list = pollable.poll(long(timeout * 1000)) + else: + event_list = pollable.poll() + + return bool(event_list) + +def _select_for(fd, readable, writable, error, timeout): + """ + @param fd: File descriptor (int). + @param readable: Whether to wait for readability (bool). + @param writable: Whether to wait for writability (bool). + @param expiration: Deadline timeout (expiration time, in seconds (float)). + + @return True on success, False on timeout + """ + rset, wset, xset = [], [], [] + + if readable: + rset = [fd] + if writable: + wset = [fd] + if error: + xset = [fd] + + if timeout is None: + (rcount, wcount, xcount) = select.select(rset, wset, xset) + else: + (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout) + + return bool((rcount or wcount or xcount)) + +def _wait_for(fd, readable, writable, error, expiration): done = False while not done: if expiration is None: @@ -55,22 +107,34 @@ def _wait_for(ir, iw, ix, expiration): if timeout <= 0.0: raise dns.exception.Timeout try: - if timeout is None: - (r, w, x) = select.select(ir, iw, ix) - else: - (r, w, x) = select.select(ir, iw, ix, timeout) + if not _polling_backend(fd, readable, writable, error, timeout): + raise dns.exception.Timeout except select.error, e: if e.args[0] != errno.EINTR: raise e done = True - if len(r) == 0 and len(w) == 0 and len(x) == 0: - raise dns.exception.Timeout + +def _set_polling_backend(fn): + """ + Internal API. Do not use. + """ + global _polling_backend + + _polling_backend = fn + +if hasattr(select, 'poll'): + # Prefer poll() on platforms that support it because it has no + # limits on the maximum value of a file descriptor (plus it will + # be more efficient for high values). + _polling_backend = _poll_for +else: + _polling_backend = _select_for def _wait_for_readable(s, expiration): - _wait_for([s], [], [s], expiration) + _wait_for(s, True, False, True, expiration) def _wait_for_writable(s, expiration): - _wait_for([], [s], [s], expiration) + _wait_for(s, False, True, True, expiration) def _addresses_equal(af, a1, a2): # Convert the first value of the tuple, which is a textual format @@ -310,7 +374,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, if isinstance(zone, (str, unicode)): zone = dns.name.from_text(zone) - if isinstance(rdtype, str): + if isinstance(rdtype, (str, unicode)): rdtype = dns.rdatatype.from_text(rdtype) q = dns.message.make_query(zone, rdtype, rdclass) if rdtype == dns.rdatatype.IXFR: |