8448b7ca340b3c76904b720c62fc91645711c8b4
[prosody.git] / net / dns.lua
1 -- Prosody IM
2 -- This file is included with Prosody IM. It has modifications,
3 -- which are hereby placed in the public domain.
4
5 -- public domain 20080404 lua@ztact.com
6
7
8 -- todo: quick (default) header generation
9 -- todo: nxdomain, error handling
10 -- todo: cache results of encodeName
11
12
13 -- reference: http://tools.ietf.org/html/rfc1035
14 -- reference: http://tools.ietf.org/html/rfc1876 (LOC)
15
16
17 require 'socket'
18 local ztact = require 'util.ztact'
19 local require = require
20 local os = os;
21
22 local coroutine, io, math, socket, string, table =
23       coroutine, io, math, socket, string, table
24
25 local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack =
26       ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack
27
28 local get, set = ztact.get, ztact.set
29
30
31 -------------------------------------------------- module dns
32 module ('dns')
33 local dns = _M;
34
35
36 -- dns type & class codes ------------------------------ dns type & class codes
37
38
39 local append = table.insert
40
41
42 local function highbyte (i)    -- - - - - - - - - - - - - - - - - - -  highbyte
43   return (i-(i%0x100))/0x100
44   end
45
46
47 local function augment (t)    -- - - - - - - - - - - - - - - - - - - -  augment
48   local a = {}
49   for i,s in pairs (t) do  a[i] = s  a[s] = s  a[string.lower (s)] = s  end
50   return a
51   end
52
53
54 local function encode (t)    -- - - - - - - - - - - - - - - - - - - - -  encode
55   local code = {}
56   for i,s in pairs (t) do
57     local word = string.char (highbyte (i), i %0x100)
58     code[i] = word
59     code[s] = word
60     code[string.lower (s)] = word
61     end
62   return code
63   end
64
65
66 dns.types = {
67   'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS',
68   'PTR', 'HINFO', 'MINFO', 'MX', 'TXT',
69   [ 28] = 'AAAA', [ 29] = 'LOC',   [ 33] = 'SRV',
70   [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' }
71
72
73 dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' }
74
75
76 dns.type      = augment (dns.types)
77 dns.class     = augment (dns.classes)
78 dns.typecode  = encode  (dns.types)
79 dns.classcode = encode  (dns.classes)
80
81
82
83 local function standardize (qname, qtype, qclass)    -- - - - - - - standardize
84   if string.byte (qname, -1) ~= 0x2E then  qname = qname..'.'  end
85   qname = string.lower (qname)
86   return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN']
87   end
88
89
90 local function prune (rrs, time, soft)    -- - - - - - - - - - - - - - -  prune
91
92   time = time or socket.gettime ()
93   for i,rr in pairs (rrs) do
94
95     if rr.tod then
96       -- rr.tod = rr.tod - 50    -- accelerated decripitude
97       rr.ttl = math.floor (rr.tod - time)
98       if rr.ttl <= 0 then  rrs[i] = nil  end
99
100     elseif soft == 'soft' then    -- What is this?  I forget!
101       assert (rr.ttl == 0)
102       rrs[i] = nil
103       end  end  end
104
105
106 -- metatables & co. ------------------------------------------ metatables & co.
107
108
109 local resolver = {}
110 resolver.__index = resolver
111
112
113 local SRV_tostring
114
115
116 local rr_metatable = {}    -- - - - - - - - - - - - - - - - - - -  rr_metatable
117 function rr_metatable.__tostring (rr)
118   local s0 = string.format (
119     '%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name )
120   local s1 = ''
121   if rr.type == 'A' then  s1 = ' '..rr.a
122   elseif rr.type == 'MX' then
123     s1 = string.format (' %2i %s', rr.pref, rr.mx)
124   elseif rr.type == 'CNAME' then  s1 = ' '..rr.cname
125   elseif rr.type == 'LOC'   then  s1 = ' '..resolver.LOC_tostring (rr)
126   elseif rr.type == 'NS'    then  s1 = ' '..rr.ns
127   elseif rr.type == 'SRV'   then  s1 = ' '..SRV_tostring (rr)
128   elseif rr.type == 'TXT'   then  s1 = ' '..rr.txt
129   else  s1 = ' <UNKNOWN RDATA TYPE>'  end
130   return s0..s1
131   end
132
133
134 local rrs_metatable = {}    -- - - - - - - - - - - - - - - - - -  rrs_metatable
135 function rrs_metatable.__tostring (rrs)
136   local t = {}
137   for i,rr in pairs (rrs) do  append (t, tostring (rr)..'\n')  end
138   return table.concat (t)
139   end
140
141
142 local cache_metatable = {}    -- - - - - - - - - - - - - - - -  cache_metatable
143 function cache_metatable.__tostring (cache)
144   local time = socket.gettime ()
145   local t = {}
146   for class,types in pairs (cache) do
147     for type,names in pairs (types) do
148       for name,rrs in pairs (names) do
149         prune (rrs, time)
150         append (t, tostring (rrs))  end  end  end
151   return table.concat (t)
152   end
153
154
155 function resolver:new ()    -- - - - - - - - - - - - - - - - - - - - - resolver
156   local r = { active = {}, cache = {}, unsorted = {} }
157   setmetatable (r, resolver)
158   setmetatable (r.cache, cache_metatable)
159   setmetatable (r.unsorted, { __mode = 'kv' })
160   return r
161   end
162
163
164 -- packet layer -------------------------------------------------- packet layer
165
166
167 function dns.random (...)    -- - - - - - - - - - - - - - - - - - -  dns.random
168   math.randomseed (10000*socket.gettime ())
169   dns.random = math.random
170   return dns.random (...)
171   end
172
173
174 local function encodeHeader (o)    -- - - - - - - - - - - - - - -  encodeHeader
175
176   o = o or {}
177
178   o.id = o.id or                -- 16b  (random) id
179     dns.random (0, 0xffff)
180
181   o.rd = o.rd or 1              --  1b  1 recursion desired
182   o.tc = o.tc or 0              --  1b  1 truncated response
183   o.aa = o.aa or 0              --  1b  1 authoritative response
184   o.opcode = o.opcode or 0      --  4b  0 query
185                                 --      1 inverse query
186                                 --      2 server status request
187                                 --      3-15 reserved
188   o.qr = o.qr or 0              --  1b  0 query, 1 response
189
190   o.rcode = o.rcode or 0        --  4b  0 no error
191                                 --      1 format error
192                                 --      2 server failure
193                                 --      3 name error
194                                 --      4 not implemented
195                                 --      5 refused
196                                 --      6-15 reserved
197   o.z  = o.z  or 0              --  3b  0 resvered
198   o.ra = o.ra or 0              --  1b  1 recursion available
199
200   o.qdcount = o.qdcount or 1    -- 16b  number of question RRs
201   o.ancount = o.ancount or 0    -- 16b  number of answers RRs
202   o.nscount = o.nscount or 0    -- 16b  number of nameservers RRs
203   o.arcount = o.arcount or 0    -- 16b  number of additional RRs
204
205   -- string.char() rounds, so prevent roundup with -0.4999
206   local header = string.char (
207     highbyte (o.id),  o.id %0x100,
208     o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr,
209     o.rcode + 16*o.z + 128*o.ra,
210     highbyte (o.qdcount),  o.qdcount %0x100,
211     highbyte (o.ancount),  o.ancount %0x100,
212     highbyte (o.nscount),  o.nscount %0x100,
213     highbyte (o.arcount),  o.arcount %0x100 )
214
215   return header, o.id
216   end
217
218
219 local function encodeName (name)    -- - - - - - - - - - - - - - - - encodeName
220   local t = {}
221   for part in string.gmatch (name, '[^.]+') do
222     append (t, string.char (string.len (part)))
223     append (t, part)
224     end
225   append (t, string.char (0))
226   return table.concat (t)
227   end
228
229
230 local function encodeQuestion (qname, qtype, qclass)    -- - - - encodeQuestion
231   qname  = encodeName (qname)
232   qtype  = dns.typecode[qtype or 'a']
233   qclass = dns.classcode[qclass or 'in']
234   return qname..qtype..qclass;
235   end
236
237
238 function resolver:byte (len)    -- - - - - - - - - - - - - - - - - - - - - byte
239   len = len or 1
240   local offset = self.offset
241   local last = offset + len - 1
242   if last > #self.packet then
243     error (string.format ('out of bounds: %i>%i', last, #self.packet))  end
244   self.offset = offset + len
245   return string.byte (self.packet, offset, last)
246   end
247
248
249 function resolver:word ()    -- - - - - - - - - - - - - - - - - - - - - -  word
250   local b1, b2 = self:byte (2)
251   return 0x100*b1 + b2
252   end
253
254
255 function resolver:dword ()    -- - - - - - - - - - - - - - - - - - - - -  dword
256   local b1, b2, b3, b4 = self:byte (4)
257   --print ('dword', b1, b2, b3, b4)
258   return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4
259   end
260
261
262 function resolver:sub (len)    -- - - - - - - - - - - - - - - - - - - - - - sub
263   len = len or 1
264   local s = string.sub (self.packet, self.offset, self.offset + len - 1)
265   self.offset = self.offset + len
266   return s
267   end
268
269
270 function resolver:header (force)    -- - - - - - - - - - - - - - - - - - header
271
272   local id = self:word ()
273   --print (string.format (':header  id  %x', id))
274   if not self.active[id] and not force then  return nil  end
275
276   local h = { id = id }
277
278   local b1, b2 = self:byte (2)
279
280   h.rd      = b1 %2
281   h.tc      = b1 /2%2
282   h.aa      = b1 /4%2
283   h.opcode  = b1 /8%16
284   h.qr      = b1 /128
285
286   h.rcode   = b2 %16
287   h.z       = b2 /16%8
288   h.ra      = b2 /128
289
290   h.qdcount = self:word ()
291   h.ancount = self:word ()
292   h.nscount = self:word ()
293   h.arcount = self:word ()
294
295   for k,v in pairs (h) do  h[k] = v-v%1  end
296
297   return h
298   end
299
300
301 function resolver:name ()    -- - - - - - - - - - - - - - - - - - - - - -  name
302   local remember, pointers = nil, 0
303   local len = self:byte ()
304   local n = {}
305   while len > 0 do
306     if len >= 0xc0 then    -- name is "compressed"
307       pointers = pointers + 1
308       if pointers >= 20 then  error ('dns error: 20 pointers')  end
309       local offset = ((len-0xc0)*0x100) + self:byte ()
310       remember = remember or self.offset
311       self.offset = offset + 1    -- +1 for lua
312     else    -- name is not compressed
313       append (n, self:sub (len)..'.')
314       end
315     len = self:byte ()
316     end
317   self.offset = remember or self.offset
318   return table.concat (n)
319   end
320
321
322 function resolver:question ()    -- - - - - - - - - - - - - - - - - -  question
323   local q = {}
324   q.name  = self:name ()
325   q.type  = dns.type[self:word ()]
326   q.class = dns.class[self:word ()]
327   return q
328   end
329
330
331 function resolver:A (rr)    -- - - - - - - - - - - - - - - - - - - - - - - -  A
332   local b1, b2, b3, b4 = self:byte (4)
333   rr.a = string.format ('%i.%i.%i.%i', b1, b2, b3, b4)
334   end
335
336
337 function resolver:CNAME (rr)    -- - - - - - - - - - - - - - - - - - - -  CNAME
338   rr.cname = self:name ()
339   end
340
341
342 function resolver:MX (rr)    -- - - - - - - - - - - - - - - - - - - - - - -  MX
343   rr.pref = self:word ()
344   rr.mx   = self:name ()
345   end
346
347
348 function resolver:LOC_nibble_power ()    -- - - - - - - - - -  LOC_nibble_power
349   local b = self:byte ()
350   --print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
351   return ((b-(b%0x10))/0x10) * (10^(b%0x10))
352   end
353
354
355 function resolver:LOC (rr)    -- - - - - - - - - - - - - - - - - - - - - -  LOC
356   rr.version = self:byte ()
357   if rr.version == 0 then
358     rr.loc           = rr.loc or {}
359     rr.loc.size      = self:LOC_nibble_power ()
360     rr.loc.horiz_pre = self:LOC_nibble_power ()
361     rr.loc.vert_pre  = self:LOC_nibble_power ()
362     rr.loc.latitude  = self:dword ()
363     rr.loc.longitude = self:dword ()
364     rr.loc.altitude  = self:dword ()
365     end  end
366
367
368 local function LOC_tostring_degrees (f, pos, neg)    -- - - - - - - - - - - - -
369   f = f - 0x80000000
370   if f < 0 then  pos = neg  f = -f  end
371   local deg, min, msec
372   msec = f%60000
373   f    = (f-msec)/60000
374   min  = f%60
375   deg = (f-min)/60
376   return string.format ('%3d %2d %2.3f %s', deg, min, msec/1000, pos)
377   end
378
379
380 function resolver.LOC_tostring (rr)    -- - - - - - - - - - - - -  LOC_tostring
381
382   local t = {}
383
384   --[[
385   for k,name in pairs { 'size', 'horiz_pre', 'vert_pre',
386                         'latitude', 'longitude', 'altitude' } do
387     append (t, string.format ('%4s%-10s: %12.0f\n', '', name, rr.loc[name]))
388     end
389   --]]
390
391   append ( t, string.format (
392     '%s    %s    %.2fm %.2fm %.2fm %.2fm',
393     LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'),
394     LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'),
395     (rr.loc.altitude - 10000000) / 100,
396     rr.loc.size / 100,
397     rr.loc.horiz_pre / 100,
398     rr.loc.vert_pre / 100 ) )
399
400   return table.concat (t)
401   end
402
403
404 function resolver:NS (rr)    -- - - - - - - - - - - - - - - - - - - - - - -  NS
405   rr.ns = self:name ()
406   end
407
408
409 function resolver:SOA (rr)    -- - - - - - - - - - - - - - - - - - - - - -  SOA
410   end
411
412
413 function resolver:SRV (rr)    -- - - - - - - - - - - - - - - - - - - - - -  SRV
414   rr.srv = {}
415   rr.srv.priority = self:word ()
416   rr.srv.weight   = self:word ()
417   rr.srv.port     = self:word ()
418   rr.srv.target   = self:name ()
419   end
420
421
422 function SRV_tostring (rr)    -- - - - - - - - - - - - - - - - - - SRV_tostring
423   local s = rr.srv
424   return string.format ( '%5d %5d %5d %s',
425                          s.priority, s.weight, s.port, s.target )
426   end
427
428
429 function resolver:TXT (rr)    -- - - - - - - - - - - - - - - - - - - - - -  TXT
430   rr.txt = self:sub (rr.rdlength)
431   end
432
433
434 function resolver:rr ()    -- - - - - - - - - - - - - - - - - - - - - - - -  rr
435   local rr = {}
436   setmetatable (rr, rr_metatable)
437   rr.name     = self:name (self)
438   rr.type     = dns.type[self:word ()] or rr.type
439   rr.class    = dns.class[self:word ()] or rr.class
440   rr.ttl      = 0x10000*self:word () + self:word ()
441   rr.rdlength = self:word ()
442
443   if rr.ttl == 0 then  -- pass
444   else  rr.tod = self.time + rr.ttl  end
445
446   local remember = self.offset
447   local rr_parser = self[dns.type[rr.type]]
448   if rr_parser then  rr_parser (self, rr)  end
449   self.offset = remember
450   rr.rdata = self:sub (rr.rdlength)
451   return rr
452   end
453
454
455 function resolver:rrs (count)    -- - - - - - - - - - - - - - - - - - - - - rrs
456   local rrs = {}
457   for i = 1,count do  append (rrs, self:rr ())  end
458   return rrs
459   end
460
461
462 function resolver:decode (packet, force)    -- - - - - - - - - - - - - - decode
463
464   self.packet, self.offset = packet, 1
465   local header = self:header (force)
466   if not header then  return nil  end
467   local response = { header = header }
468
469   response.question = {}
470   local offset = self.offset
471   for i = 1,response.header.qdcount do
472     append (response.question, self:question ())  end
473   response.question.raw = string.sub (self.packet, offset, self.offset - 1)
474
475   if not force then
476     if not self.active[response.header.id] or
477        not self.active[response.header.id][response.question.raw] then
478       return nil  end  end
479
480   response.answer     = self:rrs (response.header.ancount)
481   response.authority  = self:rrs (response.header.nscount)
482   response.additional = self:rrs (response.header.arcount)
483
484   return response
485   end
486
487
488 -- socket layer -------------------------------------------------- socket layer
489
490
491 resolver.delays = { 1, 3  }
492
493
494 function resolver:addnameserver (address)    -- - - - - - - - - - addnameserver
495   self.server = self.server or {}
496   append (self.server, address)
497   end
498
499
500 function resolver:setnameserver (address)    -- - - - - - - - - - setnameserver
501   self.server = {}
502   self:addnameserver (address)
503   end
504
505
506 function resolver:adddefaultnameservers ()    -- - - - -  adddefaultnameservers
507   local resolv_conf = io.open("/etc/resolv.conf");
508   if resolv_conf then
509           for line in resolv_conf:lines() do
510                 local address = string.match (line, '^%s*nameserver%s+(%d+%.%d+%.%d+%.%d+)%s*$')
511                 if address then self:addnameserver (address)  end
512           end
513   elseif os.getenv("WINDIR") then
514         self:addnameserver ("208.67.222.222")
515         self:addnameserver ("208.67.220.220")   
516   end
517   if not self.server or #self.server == 0 then
518         self:addnameserver("127.0.0.1");
519   end
520 end
521
522
523 function resolver:getsocket (servernum)    -- - - - - - - - - - - - - getsocket
524
525   self.socket = self.socket or {}
526   self.socketset = self.socketset or {}
527
528   local sock = self.socket[servernum]
529   if sock then  return sock  end
530
531   sock = socket.udp ()
532   if self.socket_wrapper then  sock = self.socket_wrapper (sock, self)  end
533   sock:settimeout (0)
534   -- todo: attempt to use a random port, fallback to 0
535   sock:setsockname ('*', 0)
536   sock:setpeername (self.server[servernum], 53)
537   self.socket[servernum] = sock
538   self.socketset[sock] = servernum
539   return sock
540   end
541
542 function resolver:voidsocket (sock)
543   if self.socket[sock] then
544     self.socketset[self.socket[sock]] = nil
545     self.socket[sock] = nil
546   elseif self.socketset[sock] then
547     self.socket[self.socketset[sock]] = nil
548     self.socketset[sock] = nil
549   end
550 end
551
552 function resolver:socket_wrapper_set (func)  -- - - - - - - socket_wrapper_set
553   self.socket_wrapper = func
554   end
555
556
557 function resolver:closeall ()    -- - - - - - - - - - - - - - - - - -  closeall
558   for i,sock in ipairs (self.socket) do  self.socket[i]:close ()  end
559   self.socket = {}
560   end
561
562
563 function resolver:remember (rr, type)    -- - - - - - - - - - - - - -  remember
564
565   --print ('remember', type, rr.class, rr.type, rr.name)
566
567   if type ~= '*' then
568     type = rr.type
569     local all = get (self.cache, rr.class, '*', rr.name)
570     --print ('remember all', all)
571     if all then  append (all, rr)  end
572     end
573
574   self.cache = self.cache or setmetatable ({}, cache_metatable)
575   local rrs = get (self.cache, rr.class, type, rr.name) or
576     set (self.cache, rr.class, type, rr.name, setmetatable ({}, rrs_metatable))
577   append (rrs, rr)
578
579   if type == 'MX' then  self.unsorted[rrs] = true  end
580   end
581
582
583 local function comp_mx (a, b)    -- - - - - - - - - - - - - - - - - - - comp_mx
584   return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref)
585   end
586
587
588 function resolver:peek (qname, qtype, qclass)    -- - - - - - - - - - - -  peek
589   qname, qtype, qclass = standardize (qname, qtype, qclass)
590   local rrs = get (self.cache, qclass, qtype, qname)
591   if not rrs then  return nil  end
592   if prune (rrs, socket.gettime ()) and qtype == '*' or not next (rrs) then
593     set (self.cache, qclass, qtype, qname, nil)  return nil  end
594   if self.unsorted[rrs] then  table.sort (rrs, comp_mx)  end
595   return rrs
596   end
597
598
599 function resolver:purge (soft)    -- - - - - - - - - - - - - - - - - - -  purge
600   if soft == 'soft' then
601     self.time = socket.gettime ()
602     for class,types in pairs (self.cache or {}) do
603       for type,names in pairs (types) do
604         for name,rrs in pairs (names) do
605           prune (rrs, self.time, 'soft')
606           end  end  end
607   else  self.cache = {}  end
608   end
609
610
611 function resolver:query (qname, qtype, qclass)    -- - - - - - - - - - -- query
612
613   qname, qtype, qclass = standardize (qname, qtype, qclass)
614
615   if not self.server then self:adddefaultnameservers ()  end
616
617   local question = encodeQuestion (qname, qtype, qclass)
618   local peek = self:peek (qname, qtype, qclass)
619   if peek then  return peek  end
620
621   local header, id = encodeHeader ()
622   --print ('query  id', id, qclass, qtype, qname)
623   local o = { packet = header..question,
624               server = self.best_server,
625               delay  = 1,
626               retry  = socket.gettime () + self.delays[1] }
627
628   -- remember the query
629   self.active[id] = self.active[id] or {}
630   self.active[id][question] = o
631
632   -- remember which coroutine wants the answer
633   local co = coroutine.running ()
634   if co then
635     set (self.wanted, qclass, qtype, qname, co, true)
636     --set (self.yielded, co, qclass, qtype, qname, true)
637   end
638   
639   self:getsocket (o.server):send (o.packet)
640
641 end
642
643 function resolver:servfail(sock)
644   -- Resend all queries for this server
645   
646   local num = self.socketset[sock]
647   
648   -- Socket is dead now
649   self:voidsocket(sock);
650   
651   -- Find all requests to the down server, and retry on the next server
652   self.time = socket.gettime ()
653   for id,queries in pairs (self.active) do
654     for question,o in pairs (queries) do
655       if o.server == num then -- This request was to the broken server
656         o.server = o.server + 1 -- Use next server
657         if o.server > #self.server then
658           o.server = 1
659         end
660
661         o.retries = (o.retries or 0) + 1;
662         if o.retries >= #self.server then
663           --print ('timeout')
664           queries[question] = nil
665         else
666           local _a = self:getsocket(o.server);
667           if _a then _a:send (o.packet) end
668         end
669       end
670     end
671   end
672    
673    if num == self.best_server then
674         self.best_server = self.best_server + 1
675         if self.best_server > #self.server then
676                 -- Exhausted all servers, try first again
677                 self.best_server = 1
678         end
679    end
680 end
681
682 function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
683
684   --print 'receive'  print (self.socket)
685   self.time = socket.gettime ()
686   rset = rset or self.socket
687
688   local response
689   for i,sock in pairs (rset) do
690
691     if self.socketset[sock] then
692     local packet = sock:receive ()
693     if packet then
694
695     response = self:decode (packet)
696     if response then
697     --print 'received response'
698     --self.print (response)
699
700     for i,section in pairs { 'answer', 'authority', 'additional' } do
701       for j,rr in pairs (response[section]) do
702         self:remember (rr, response.question[1].type)  end  end
703
704     -- retire the query
705     local queries = self.active[response.header.id]
706     if queries[response.question.raw] then
707       queries[response.question.raw] = nil  end
708     if not next (queries) then  self.active[response.header.id] = nil  end
709     if not next (self.active) then  self:closeall ()  end
710
711     -- was the query on the wanted list?
712     local q = response.question
713     local cos = get (self.wanted, q.class, q.type, q.name)
714     if cos then
715       for co in pairs (cos) do
716         set (self.yielded, co, q.class, q.type, q.name, nil)
717         if coroutine.status(co) == "suspended" then  coroutine.resume (co)  end
718         end
719       set (self.wanted, q.class, q.type, q.name, nil)
720       end  end  end  end  end
721
722   return response
723   end
724
725
726 function resolver:feed(sock, packet)
727   --print 'receive'  print (self.socket)
728   self.time = socket.gettime ()
729
730   local response = self:decode (packet)
731   if response then
732     --print 'received response'
733     --self.print (response)
734
735     for i,section in pairs { 'answer', 'authority', 'additional' } do
736       for j,rr in pairs (response[section]) do
737         self:remember (rr, response.question[1].type)
738       end
739     end
740
741     -- retire the query
742     local queries = self.active[response.header.id]
743     if queries[response.question.raw] then
744       queries[response.question.raw] = nil
745     end
746     if not next (queries) then  self.active[response.header.id] = nil  end
747     if not next (self.active) then  self:closeall ()  end
748
749     -- was the query on the wanted list?
750     local q = response.question[1]
751     if q then
752       local cos = get (self.wanted, q.class, q.type, q.name)
753       if cos then
754         for co in pairs (cos) do
755           set (self.yielded, co, q.class, q.type, q.name, nil)
756           if coroutine.status(co) == "suspended" then coroutine.resume (co)  end
757         end
758         set (self.wanted, q.class, q.type, q.name, nil)
759       end
760     end
761   end 
762
763   return response
764 end
765
766 function resolver:cancel(data)
767         local cos = get (self.wanted, unpack(data, 1, 3))
768         if cos then
769                 cos[data[4]] = nil;
770         end
771 end
772
773 function resolver:pulse ()    -- - - - - - - - - - - - - - - - - - - - -  pulse
774
775   --print ':pulse'
776   while self:receive() do end
777   if not next (self.active) then  return nil  end
778
779   self.time = socket.gettime ()
780   for id,queries in pairs (self.active) do
781     for question,o in pairs (queries) do
782       if self.time >= o.retry then
783
784         o.server = o.server + 1
785         if o.server > #self.server then
786           o.server = 1
787           o.delay = o.delay + 1
788           end
789
790         if o.delay > #self.delays then
791           --print ('timeout')
792           queries[question] = nil
793           if not next (queries) then  self.active[id] = nil  end
794           if not next (self.active) then  return nil  end
795         else
796           --print ('retry', o.server, o.delay)
797           local _a = self.socket[o.server];
798           if _a then _a:send (o.packet) end
799           o.retry = self.time + self.delays[o.delay]
800           end  end  end  end
801
802   if next (self.active) then  return true  end
803   return nil
804   end
805
806
807 function resolver:lookup (qname, qtype, qclass)    -- - - - - - - - - -  lookup
808   self:query (qname, qtype, qclass)
809   while self:pulse () do  socket.select (self.socket, nil, 4)  end
810   --print (self.cache)
811   return self:peek (qname, qtype, qclass)
812   end
813
814 function resolver:lookupex (handler, qname, qtype, qclass)    -- - - - - - - - - -  lookup
815   return self:peek (qname, qtype, qclass) or self:query (qname, qtype, qclass)
816   end
817
818
819 --print ---------------------------------------------------------------- print
820
821
822 local hints = {    -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
823   qr = { [0]='query', 'response' },
824   opcode = { [0]='query', 'inverse query', 'server status request' },
825   aa = { [0]='non-authoritative', 'authoritative' },
826   tc = { [0]='complete', 'truncated' },
827   rd = { [0]='recursion not desired', 'recursion desired' },
828   ra = { [0]='recursion not available', 'recursion available' },
829   z  = { [0]='(reserved)' },
830   rcode = { [0]='no error', 'format error', 'server failure', 'name error',
831             'not implemented' },
832
833   type = dns.type,
834   class = dns.class, }
835
836
837 local function hint (p, s)    -- - - - - - - - - - - - - - - - - - - - - - hint
838   return (hints[s] and hints[s][p[s]]) or ''  end
839
840
841 function resolver.print (response)    -- - - - - - - - - - - - - resolver.print
842
843   for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z',
844                      'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do
845     print ( string.format ('%-30s', 'header.'..s),
846             response.header[s], hint (response.header, s) )
847     end
848
849   for i,question in ipairs (response.question) do
850     print (string.format ('question[%i].name         ', i), question.name)
851     print (string.format ('question[%i].type         ', i), question.type)
852     print (string.format ('question[%i].class        ', i), question.class)
853     end
854
855   local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 }
856   local tmp
857   for s,s in pairs {'answer', 'authority', 'additional'} do
858     for i,rr in pairs (response[s]) do
859       for j,t in pairs { 'name', 'type', 'class', 'ttl', 'rdlength' } do
860         tmp = string.format ('%s[%i].%s', s, i, t)
861         print (string.format ('%-30s', tmp), rr[t], hint (rr, t))
862         end
863       for j,t in pairs (rr) do
864         if not common[j] then
865           tmp = string.format ('%s[%i].%s', s, i, j)
866           print (string.format ('%-30s  %s', tostring(tmp), tostring(t)))
867           end  end  end  end  end
868
869
870 -- module api ------------------------------------------------------ module api
871
872
873 local function resolve (func, ...)    -- - - - - - - - - - - - - - resolver_get
874   dns._resolver = dns._resolver or dns.resolver ()
875   return func (dns._resolver, ...)
876   end
877
878
879 function dns.resolver ()    -- - - - - - - - - - - - - - - - - - - - - resolver
880
881   -- this function seems to be redundant with resolver.new ()
882
883   local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {}, 
884               best_server = 1 }
885   setmetatable (r, resolver)
886   setmetatable (r.cache, cache_metatable)
887   setmetatable (r.unsorted, { __mode = 'kv' })
888   return r
889   end
890
891
892 function dns.lookup (...)    -- - - - - - - - - - - - - - - - - - - - -  lookup
893   return resolve (resolver.lookup, ...)  end
894
895
896 function dns.purge (...)    -- - - - - - - - - - - - - - - - - - - - - -  purge
897   return resolve (resolver.purge, ...)  end
898
899 function dns.peek (...)    -- - - - - - - - - - - - - - - - - - - - - - -  peek
900   return resolve (resolver.peek, ...)  end
901
902
903 function dns.query (...)    -- - - - - - - - - - - - - - - - - - - - - -  query
904   return resolve (resolver.query, ...)  end
905
906 function dns.feed (...)    -- - - - - - - - - - - - - - - - - - - - - -  feed
907   return resolve (resolver.feed, ...)  end
908
909 function dns.cancel(...)   -- - - - - - - - - - - - - - - - - - - - - -  cancel
910   return resolve(resolver.cancel, ...) end
911
912 function dns:socket_wrapper_set (...)    -- - - - - - - - -  socket_wrapper_set
913   return resolve (resolver.socket_wrapper_set, ...)  end
914
915
916 return dns