# -*- Mode: Python; tab-width: 4 -*-
#
# Author: Sam Rushing
#

RCS_ID = '$Id: resolver.py,v 1.10 2002/03/21 15:48:53 htrd Exp $'

# Fast, low-overhead asynchronous name resolver.  Uses 'pre-cooked'
# DNS requests and unpacks only as much of the reply as it needs.
# See RFC 1035 for details.

import string
import asyncore
import socket
import sys
import time
from counter import counter

if RCS_ID.startswith('$Id: '):
    VERSION = RCS_ID.split()[2]
else:
    VERSION = '0.0'

# header
#                                    1  1  1  1  1  1
#      0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                      ID                       |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |QR|   Opcode  |AA|TC|RD|RA|   Z    |   RCODE   |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                    QDCOUNT                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                    ANCOUNT                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                    NSCOUNT                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                    ARCOUNT                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+

# question
#                                    1  1  1  1  1  1
#      0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                                               |
#    /                     QNAME                     /
#    /                                               /
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                     QTYPE                     |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                     QCLASS                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+


def _fast_request (host, ident, qtype):
    """Build a one-question, recursion-desired DNS query for host.

    ident is the 16-bit query id; qtype is the 2-byte QTYPE field.
    QCLASS is always IN.
    """
    # Each label is length-prefixed.  Empty labels (e.g. from a trailing
    # dot) are dropped: a zero-length label would end QNAME early and
    # corrupt the rest of the question.
    qname = ''.join([chr(len(part)) + part
                     for part in host.split('.') if part])
    return (
        chr((ident >> 8) & 0xff) + chr(ident & 0xff)
        # flags 0x0100 (RD=1), QDCOUNT=1, ANCOUNT=NSCOUNT=ARCOUNT=0
        + '\001\000\000\001\000\000\000\000\000\000'
        + qname
        # root label terminating QNAME, then QTYPE, then QCLASS=IN
        + '\000' + qtype + '\000\001'
        )


# build a DNS address request, _quickly_
def fast_address_request (host, id=0):
    """Return a wire-format A/IN query for host using query id `id`."""
    return _fast_request(host, id, '\000\001')


def fast_ptr_request (host, id=0):
    """Return a wire-format PTR/IN query for host using query id `id`."""
    return _fast_request(host, id, '\000\014')


def unpack_name (r, pos):
    """Decode the (possibly compressed) domain name starting at r[pos].

    Returns the labels joined with '.'.  Raises ValueError if compression
    pointers loop, and IndexError if the name runs past the end of r.
    """
    labels = []
    hops = 0
    while 1:
        ll = ord(r[pos])
        if ll & 0xc0:
            # Compression pointer: 14-bit offset made of the low 6 bits of
            # this byte and all 8 bits of the next one.  The parentheses
            # matter: '<<' binds tighter than '&'.
            hops = hops + 1
            if hops > len(r):
                # a hostile reply can make pointers form a cycle
                raise ValueError('DNS name compression loop')
            pos = ((ll & 0x3f) << 8) + ord(r[pos+1])
        elif ll == 0:
            break
        else:
            pos = pos + 1
            labels.append(r[pos:pos+ll])
            pos = pos + ll
    return '.'.join(labels)


def skip_name (r, pos):
    """Return the offset just past the encoded name that starts at r[pos]."""
    while 1:
        ll = ord(r[pos])
        if ll & 0xc0:
            # a two-byte compression pointer always ends the name
            return pos + 2
        elif ll == 0:
            return pos + 1
        else:
            pos = pos + ll + 1


def unpack_ttl (r, pos):
    """Return the big-endian unsigned 32-bit integer at r[pos:pos+4]."""
    ttl = 0
    for ch in r[pos:pos+4]:
        ttl = (ttl << 8) | ord(ch)
    return ttl


# resource record
#                                    1  1  1  1  1  1
#      0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                                               |
#    /                                               /
#    /                      NAME                     /
#    |                                               |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                      TYPE                     |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                     CLASS                     |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                      TTL                      |
#    |                                               |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
#    |                   RDLENGTH                    |
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
#    /                     RDATA                     /
#    /                                               /
#    +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+


def _find_answer (r, type_class):
    """Locate the first answer RR whose TYPE+CLASS bytes equal type_class.

    Returns the offset of that record's TYPE field, or None when no answer
    matches.  Assumes the reply echoes exactly one question (as the
    requests built above always ask exactly one).
    """
    ancount = (ord(r[6]) << 8) + ord(r[7])
    # skip the question: QNAME starts at 12 and is followed by QTYPE, QCLASS
    pos = skip_name(r, 12) + 4
    for i in range(ancount):
        pos = skip_name(r, pos)
        if r[pos:pos+4] == type_class:
            return pos
        # skip TYPE(2), CLASS(2), TTL(4), then RDLENGTH(2) and RDATA
        pos = pos + 8
        rdlength = (ord(r[pos]) << 8) + ord(r[pos+1])
        pos = pos + 2 + rdlength
    return None


def unpack_address_reply (r):
    """Return (ttl, 'a.b.c.d') for the first A/IN answer in r, else (0, None)."""
    pos = _find_answer(r, '\000\001\000\001')
    if pos is None:
        return 0, None
    # TTL follows TYPE+CLASS; RDATA (the 4 address bytes) starts at +10
    return (
        unpack_ttl(r, pos+4),
        '%d.%d.%d.%d' % tuple(map(ord, r[pos+10:pos+14]))
        )


def unpack_ptr_reply (r):
    """Return (ttl, name) for the first PTR/IN answer in r, else (0, None)."""
    pos = _find_answer(r, '\000\014\000\001')
    if pos is None:
        return 0, None
    return unpack_ttl(r, pos+4), unpack_name(r, pos+10)


def _reversed_name (host, suffix):
    """Reverse the dotted parts of host and append suffix.

    e.g. ('1.2.3.4', '.in-addr.arpa') -> '4.3.2.1.in-addr.arpa'
    """
    parts = host.split('.')
    parts.reverse()
    return '.'.join(parts) + suffix


# This is a UDP (datagram) resolver.
#
# It may be useful to implement a TCP resolver.  That would presumably
# give more reliable behavior when things get too busy.  A TCP client
# would have to manage the connection carefully, since the server is
# allowed to close it at will (the RFC recommends closing after 2 minutes
# of idle time).
#
# Note also that a TCP client would have to prefix each request with a
# 2-byte length field (see RFC 1035).


class resolver (asyncore.dispatcher):

    """Asynchronous UDP DNS client driven by the asyncore loop.

    Callbacks are invoked as callback(host, ttl, answer); a failed,
    malformed or timed-out lookup is reported as (host, 0, None).
    """

    id = counter()

    # seconds between sweeps of request_map for stale entries
    REAP_INTERVAL = 180
    # seconds after which an unanswered request is reported as failed
    REQUEST_TIMEOUT = 180

    def __init__ (self, server='127.0.0.1'):
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.server = server
        # query id -> (host, unpack function, callback, time sent)
        self.request_map = {}
        self.last_reap_time = int(time.time())  # reap every few minutes

    def writable (self):
        # queries are sent synchronously with sendto(); never wait to write
        return 0

    def log (self, *args):
        pass

    def handle_close (self):
        self.log_info('closing!')
        self.close()

    def handle_error (self):
        # don't close the connection on error
        nil, t, v, tbinfo = asyncore.compact_traceback()
        self.log_info(
            'Problem with DNS lookup (%s:%s %s)' % (t, v, tbinfo),
            'error'
            )

    def get_id (self):
        """Return the current query id, reduced to 16 bits."""
        return self.id.as_long() % (1 << 16)

    def _deliver (self, callback, host, ttl, answer):
        """Run a user callback, logging instead of propagating its errors."""
        try:
            callback(host, ttl, answer)
        except:
            # deliberately catch everything: a faulty callback must not
            # take down the dispatcher
            nil, t, v, tbinfo = asyncore.compact_traceback()
            self.log_info('%s %s %s' % (t, v, tbinfo), 'error')

    def reap (self):
        """Fail DNS requests that have been outstanding for too long."""
        now = int(time.time())
        if now - self.last_reap_time > self.REAP_INTERVAL:
            self.last_reap_time = now  # update before we forget
            # iterate over a snapshot because entries are deleted below
            for k, (host, unpack, callback, when) in list(self.request_map.items()):
                if now - when > self.REQUEST_TIMEOUT:
                    del self.request_map[k]
                    # a timeout is reported as (0, None)
                    self._deliver(callback, host, 0, None)

    def _query (self, host, qname, make_request, unpack, callback):
        """Send one query for qname and register how to handle its reply.

        host is the name handed back to callback; unpack(reply) must
        return a (ttl, answer) pair.
        """
        self.reap()  # first, get rid of old guys
        ident = self.get_id()
        self.socket.sendto(make_request(qname, ident), (self.server, 53))
        self.request_map[ident] = (host, unpack, callback, int(time.time()))
        self.id.increment()

    def resolve (self, host, callback):
        """Look up the IPv4 address of host; see the class doc for callback."""
        self._query(host, host, fast_address_request,
                    unpack_address_reply, callback)

    def resolve_ptr (self, host, callback):
        """Look up the name of dotted-quad host via in-addr.arpa."""
        self._query(host, _reversed_name(host, '.in-addr.arpa'),
                    fast_ptr_request, unpack_ptr_reply, callback)

    def handle_read (self):
        reply, whence = self.socket.recvfrom(512)
        # for security reasons we may want to double-check
        # that this is the server we sent the request to.
        if len(reply) < 2:
            return  # too short to carry a query id
        ident = (ord(reply[0]) << 8) + ord(reply[1])
        entry = self.request_map.get(ident)
        if entry is None:
            return  # unknown, late or duplicate reply
        del self.request_map[ident]
        host, unpack, callback = entry[:3]
        try:
            ttl, answer = unpack(reply)
        except (IndexError, ValueError, TypeError):
            # Truncated or malformed packet.  The request has already been
            # removed, so report a failed lookup rather than leaving the
            # caller waiting forever.
            nil, t, v, tbinfo = asyncore.compact_traceback()
            self.log_info(
                'Malformed DNS reply for %s (%s:%s %s)' % (host, t, v, tbinfo),
                'error'
                )
            ttl, answer = 0, None
        self._deliver(callback, host, ttl, answer)


class rbl (resolver):

    """Resolver that checks addresses against the MAPS realtime blackhole list."""

    def resolve_maps (self, host, callback):
        """Query the RBL for dotted-quad host; callback gets the reply RCODE."""
        self._query(host, _reversed_name(host, '.rbl.maps.vix.com'),
                    fast_ptr_request, self.check_reply, callback)

    def check_reply (self, r):
        # we only need to check RCODE (low 4 bits of header byte 3)
        rcode = ord(r[3]) & 0xf
        self.log_info('MAPS RBL; RCODE =%02x\n %s' % (rcode, repr(r)))
        return 0, rcode  # (ttl, answer)


class hooked_callback:

    """Callable that runs hook(*args) and then callback(*args)."""

    def __init__ (self, hook, callback):
        self.hook, self.callback = hook, callback

    def __call__ (self, *args):
        self.hook(*args)
        self.callback(*args)


class caching_resolver (resolver):

    """Resolver that caches answers for as long as their TTL allows.

    Lookups that fail, time out, or carry a zero TTL are not cached.
    """

    def __init__ (self, *args, **kw):
        resolver.__init__(self, *args, **kw)
        # NOTE(review): forward and reverse answers share this dict, keyed
        # by the host string the caller passed in.
        self.cache = {}  # host -> (time cached, ttl, answer)
        self.forward_requests = counter()
        self.reverse_requests = counter()
        self.cache_hits = counter()

    def _fresh_entry (self, host):
        """Return (ttl, answer) if host's cache entry is still valid, else None."""
        entry = self.cache.get(host)
        if entry is None:
            return None
        when, ttl, answer = entry
        if time.time() - when >= ttl:
            del self.cache[host]  # expired
            return None
        return ttl, answer

    def _lookup (self, host, callback, query):
        """Answer from the cache, or issue query() and cache its result."""
        hit = self._fresh_entry(host)
        if hit is None:
            query(self, host, hooked_callback(self.callback_hook, callback))
        else:
            ttl, answer = hit
            self.cache_hits.increment()
            callback(host, ttl, answer)

    def resolve (self, host, callback):
        self.forward_requests.increment()
        self._lookup(host, callback, resolver.resolve)

    def resolve_ptr (self, host, callback):
        self.reverse_requests.increment()
        self._lookup(host, callback, resolver.resolve_ptr)

    def callback_hook (self, host, ttl, answer):
        # failures and timeouts arrive as (0, None); a zero TTL means
        # "do not cache", so only positive TTLs are remembered
        if ttl > 0:
            self.cache[host] = time.time(), ttl, answer

    SERVER_IDENT = 'Caching DNS Resolver (V%s)' % VERSION

    def status (self):
        import status_handler
        import producers
        # NOTE(review): the markup in these fragments had been stripped from
        # the source; <h2>/<br> were restored -- confirm against upstream.
        return producers.simple_producer(
            '<h2>%s</h2>' % self.SERVER_IDENT
            + '<br>Server: %s' % self.server
            + '<br>Cache Entries: %d' % len(self.cache)
            + '<br>Outstanding Requests: %d' % len(self.request_map)
            + '<br>Forward Requests: %s' % self.forward_requests
            + '<br>Reverse Requests: %s' % self.reverse_requests
            + '<br>Cache Hits: %s' % self.cache_hits
            )

#test_reply = """\000\000\205\200\000\001\000\001\000\002\000\002\006squirl\011nightmare\003com\000\000\001\000\001\300\014\000\001\000\001\000\001Q\200\000\004\315\240\260\005\011nightmare\003com\000\000\002\000\001\000\001Q\200\000\002\300\014\3006\000\002\000\001\000\001Q\200\000\015\003ns1\003iag\003net\000\300\014\000\001\000\001\000\001Q\200\000\004\315\240\260\005\300]\000\001\000\001\000\000\350\227\000\004\314\033\322\005"""
# def test_unpacker ():
#     print unpack_address_reply (test_reply)
#
# import time
# class timer:
#     def __init__ (self):
#         self.start = time.time()
#     def end (self):
#         return time.time() - self.start
#
# # I get ~290 unpacks per second for the typical case, compared to ~48
# # using dnslib directly.  Also, that latter number does not include
# # picking the actual data out.
#
# def benchmark_unpacker():
#
#     r = range(1000)
#     t = timer()
#     for i in r:
#         unpack_address_reply (test_reply)
#     print '%.2f unpacks per second' % (1000.0 / t.end())

if __name__ == '__main__':
    usage = 'usage: %s [-r] [-m] [-s server] host [host ...]' % sys.argv[0]

    if '-s' in sys.argv:
        i = sys.argv.index('-s')
        if i + 1 >= len(sys.argv):
            print(usage)
            sys.exit(1)
        server = sys.argv[i+1]
        del sys.argv[i:i+2]
    else:
        server = '127.0.0.1'

    if '-r' in sys.argv:
        reverse = 1
        sys.argv.remove('-r')
    else:
        reverse = 0

    if '-m' in sys.argv:
        maps = 1
        sys.argv.remove('-m')
    else:
        maps = 0

    # checked after option removal: with no hosts nothing would ever
    # close the socket and the poll loop below would never end
    if len(sys.argv) == 1:
        print(usage)
        sys.exit(0)

    if maps:
        r = rbl(server)
    else:
        r = caching_resolver(server)

    count = len(sys.argv) - 1

    def print_it (host, ttl, answer):
        global count
        print('%s: %s' % (host, answer))
        count = count - 1
        if not count:
            r.close()

    for host in sys.argv[1:]:
        if reverse:
            r.resolve_ptr(host, print_it)
        elif maps:
            r.resolve_maps(host, print_it)
        else:
            r.resolve(host, print_it)

    # hooked asyncore.loop()
    while asyncore.socket_map:
        asyncore.poll(30.0)
        print('requests outstanding: %d' % len(r.request_map))