d6462031a85de7be38c374c8381c4e148989a72b
[prosody.git] / net / dns.lua
1 -- Prosody IM
2 -- This file is included with Prosody IM. It has modifications,
3 -- which are hereby placed in the public domain.
4
5 -- public domain 20080404 lua@ztact.com
6
7
8 -- todo: quick (default) header generation
9 -- todo: nxdomain, error handling
10 -- todo: cache results of encodeName
11
12
13 -- reference: http://tools.ietf.org/html/rfc1035
14 -- reference: http://tools.ietf.org/html/rfc1876 (LOC)
15
16
17 require 'socket'
18 local ztact = require 'util.ztact'
19 local require = require
20 local os = os;
21
22 local coroutine, io, math, socket, string, table =
23       coroutine, io, math, socket, string, table
24
25 local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack =
26       ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack
27
28 local get, set = ztact.get, ztact.set
29
30
31 -------------------------------------------------- module dns
32 module ('dns')
33 local dns = _M;
34
35
36 -- dns type & class codes ------------------------------ dns type & class codes
37
38
39 local append = table.insert
40
41
42 local function highbyte (i)    -- - - - - - - - - - - - - - - - - - -  highbyte
43   return (i-(i%0x100))/0x100
44   end
45
46
47 local function augment (t)    -- - - - - - - - - - - - - - - - - - - -  augment
48   local a = {}
49   for i,s in pairs (t) do  a[i] = s  a[s] = s  a[string.lower (s)] = s  end
50   return a
51   end
52
53
54 local function encode (t)    -- - - - - - - - - - - - - - - - - - - - -  encode
55   local code = {}
56   for i,s in pairs (t) do
57     local word = string.char (highbyte (i), i %0x100)
58     code[i] = word
59     code[s] = word
60     code[string.lower (s)] = word
61     end
62   return code
63   end
64
65
66 dns.types = {
67   'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS',
68   'PTR', 'HINFO', 'MINFO', 'MX', 'TXT',
69   [ 28] = 'AAAA', [ 29] = 'LOC',   [ 33] = 'SRV',
70   [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' }
71
72
73 dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' }
74
75
76 dns.type      = augment (dns.types)
77 dns.class     = augment (dns.classes)
78 dns.typecode  = encode  (dns.types)
79 dns.classcode = encode  (dns.classes)
80
81
82
83 local function standardize (qname, qtype, qclass)    -- - - - - - - standardize
84   if string.byte (qname, -1) ~= 0x2E then  qname = qname..'.'  end
85   qname = string.lower (qname)
86   return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN']
87   end
88
89
90 local function prune (rrs, time, soft)    -- - - - - - - - - - - - - - -  prune
91
92   time = time or socket.gettime ()
93   for i,rr in pairs (rrs) do
94
95     if rr.tod then
96       -- rr.tod = rr.tod - 50    -- accelerated decripitude
97       rr.ttl = math.floor (rr.tod - time)
98       if rr.ttl <= 0 then
99         table.remove(rrs, i);
100         return prune(rrs, time, soft); -- Re-iterate
101       end
102
103     elseif soft == 'soft' then    -- What is this?  I forget!
104       assert (rr.ttl == 0)
105       rrs[i] = nil
106       end  end  end
107
108
109 -- metatables & co. ------------------------------------------ metatables & co.
110
111
112 local resolver = {}
113 resolver.__index = resolver
114
115
116 local SRV_tostring
117
118
119 local rr_metatable = {}    -- - - - - - - - - - - - - - - - - - -  rr_metatable
120 function rr_metatable.__tostring (rr)
121   local s0 = string.format (
122     '%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name )
123   local s1 = ''
124   if rr.type == 'A' then  s1 = ' '..rr.a
125   elseif rr.type == 'MX' then
126     s1 = string.format (' %2i %s', rr.pref, rr.mx)
127   elseif rr.type == 'CNAME' then  s1 = ' '..rr.cname
128   elseif rr.type == 'LOC'   then  s1 = ' '..resolver.LOC_tostring (rr)
129   elseif rr.type == 'NS'    then  s1 = ' '..rr.ns
130   elseif rr.type == 'SRV'   then  s1 = ' '..SRV_tostring (rr)
131   elseif rr.type == 'TXT'   then  s1 = ' '..rr.txt
132   else  s1 = ' <UNKNOWN RDATA TYPE>'  end
133   return s0..s1
134   end
135
136
137 local rrs_metatable = {}    -- - - - - - - - - - - - - - - - - -  rrs_metatable
138 function rrs_metatable.__tostring (rrs)
139   local t = {}
140   for i,rr in pairs (rrs) do  append (t, tostring (rr)..'\n')  end
141   return table.concat (t)
142   end
143
144
145 local cache_metatable = {}    -- - - - - - - - - - - - - - - -  cache_metatable
146 function cache_metatable.__tostring (cache)
147   local time = socket.gettime ()
148   local t = {}
149   for class,types in pairs (cache) do
150     for type,names in pairs (types) do
151       for name,rrs in pairs (names) do
152         prune (rrs, time)
153         append (t, tostring (rrs))  end  end  end
154   return table.concat (t)
155   end
156
157
158 function resolver:new ()    -- - - - - - - - - - - - - - - - - - - - - resolver
159   local r = { active = {}, cache = {}, unsorted = {} }
160   setmetatable (r, resolver)
161   setmetatable (r.cache, cache_metatable)
162   setmetatable (r.unsorted, { __mode = 'kv' })
163   return r
164   end
165
166
167 -- packet layer -------------------------------------------------- packet layer
168
169
170 function dns.random (...)    -- - - - - - - - - - - - - - - - - - -  dns.random
171   math.randomseed (10000*socket.gettime ())
172   dns.random = math.random
173   return dns.random (...)
174   end
175
176
177 local function encodeHeader (o)    -- - - - - - - - - - - - - - -  encodeHeader
178
179   o = o or {}
180
181   o.id = o.id or                -- 16b  (random) id
182     dns.random (0, 0xffff)
183
184   o.rd = o.rd or 1              --  1b  1 recursion desired
185   o.tc = o.tc or 0              --  1b  1 truncated response
186   o.aa = o.aa or 0              --  1b  1 authoritative response
187   o.opcode = o.opcode or 0      --  4b  0 query
188                                 --      1 inverse query
189                                 --      2 server status request
190                                 --      3-15 reserved
191   o.qr = o.qr or 0              --  1b  0 query, 1 response
192
193   o.rcode = o.rcode or 0        --  4b  0 no error
194                                 --      1 format error
195                                 --      2 server failure
196                                 --      3 name error
197                                 --      4 not implemented
198                                 --      5 refused
199                                 --      6-15 reserved
200   o.z  = o.z  or 0              --  3b  0 resvered
201   o.ra = o.ra or 0              --  1b  1 recursion available
202
203   o.qdcount = o.qdcount or 1    -- 16b  number of question RRs
204   o.ancount = o.ancount or 0    -- 16b  number of answers RRs
205   o.nscount = o.nscount or 0    -- 16b  number of nameservers RRs
206   o.arcount = o.arcount or 0    -- 16b  number of additional RRs
207
208   -- string.char() rounds, so prevent roundup with -0.4999
209   local header = string.char (
210     highbyte (o.id),  o.id %0x100,
211     o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr,
212     o.rcode + 16*o.z + 128*o.ra,
213     highbyte (o.qdcount),  o.qdcount %0x100,
214     highbyte (o.ancount),  o.ancount %0x100,
215     highbyte (o.nscount),  o.nscount %0x100,
216     highbyte (o.arcount),  o.arcount %0x100 )
217
218   return header, o.id
219   end
220
221
222 local function encodeName (name)    -- - - - - - - - - - - - - - - - encodeName
223   local t = {}
224   for part in string.gmatch (name, '[^.]+') do
225     append (t, string.char (string.len (part)))
226     append (t, part)
227     end
228   append (t, string.char (0))
229   return table.concat (t)
230   end
231
232
233 local function encodeQuestion (qname, qtype, qclass)    -- - - - encodeQuestion
234   qname  = encodeName (qname)
235   qtype  = dns.typecode[qtype or 'a']
236   qclass = dns.classcode[qclass or 'in']
237   return qname..qtype..qclass;
238   end
239
240
241 function resolver:byte (len)    -- - - - - - - - - - - - - - - - - - - - - byte
242   len = len or 1
243   local offset = self.offset
244   local last = offset + len - 1
245   if last > #self.packet then
246     error (string.format ('out of bounds: %i>%i', last, #self.packet))  end
247   self.offset = offset + len
248   return string.byte (self.packet, offset, last)
249   end
250
251
252 function resolver:word ()    -- - - - - - - - - - - - - - - - - - - - - -  word
253   local b1, b2 = self:byte (2)
254   return 0x100*b1 + b2
255   end
256
257
258 function resolver:dword ()    -- - - - - - - - - - - - - - - - - - - - -  dword
259   local b1, b2, b3, b4 = self:byte (4)
260   --print ('dword', b1, b2, b3, b4)
261   return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4
262   end
263
264
265 function resolver:sub (len)    -- - - - - - - - - - - - - - - - - - - - - - sub
266   len = len or 1
267   local s = string.sub (self.packet, self.offset, self.offset + len - 1)
268   self.offset = self.offset + len
269   return s
270   end
271
272
273 function resolver:header (force)    -- - - - - - - - - - - - - - - - - - header
274
275   local id = self:word ()
276   --print (string.format (':header  id  %x', id))
277   if not self.active[id] and not force then  return nil  end
278
279   local h = { id = id }
280
281   local b1, b2 = self:byte (2)
282
283   h.rd      = b1 %2
284   h.tc      = b1 /2%2
285   h.aa      = b1 /4%2
286   h.opcode  = b1 /8%16
287   h.qr      = b1 /128
288
289   h.rcode   = b2 %16
290   h.z       = b2 /16%8
291   h.ra      = b2 /128
292
293   h.qdcount = self:word ()
294   h.ancount = self:word ()
295   h.nscount = self:word ()
296   h.arcount = self:word ()
297
298   for k,v in pairs (h) do  h[k] = v-v%1  end
299
300   return h
301   end
302
303
304 function resolver:name ()    -- - - - - - - - - - - - - - - - - - - - - -  name
305   local remember, pointers = nil, 0
306   local len = self:byte ()
307   local n = {}
308   while len > 0 do
309     if len >= 0xc0 then    -- name is "compressed"
310       pointers = pointers + 1
311       if pointers >= 20 then  error ('dns error: 20 pointers')  end
312       local offset = ((len-0xc0)*0x100) + self:byte ()
313       remember = remember or self.offset
314       self.offset = offset + 1    -- +1 for lua
315     else    -- name is not compressed
316       append (n, self:sub (len)..'.')
317       end
318     len = self:byte ()
319     end
320   self.offset = remember or self.offset
321   return table.concat (n)
322   end
323
324
325 function resolver:question ()    -- - - - - - - - - - - - - - - - - -  question
326   local q = {}
327   q.name  = self:name ()
328   q.type  = dns.type[self:word ()]
329   q.class = dns.class[self:word ()]
330   return q
331   end
332
333
334 function resolver:A (rr)    -- - - - - - - - - - - - - - - - - - - - - - - -  A
335   local b1, b2, b3, b4 = self:byte (4)
336   rr.a = string.format ('%i.%i.%i.%i', b1, b2, b3, b4)
337   end
338
339
340 function resolver:CNAME (rr)    -- - - - - - - - - - - - - - - - - - - -  CNAME
341   rr.cname = self:name ()
342   end
343
344
345 function resolver:MX (rr)    -- - - - - - - - - - - - - - - - - - - - - - -  MX
346   rr.pref = self:word ()
347   rr.mx   = self:name ()
348   end
349
350
351 function resolver:LOC_nibble_power ()    -- - - - - - - - - -  LOC_nibble_power
352   local b = self:byte ()
353   --print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
354   return ((b-(b%0x10))/0x10) * (10^(b%0x10))
355   end
356
357
358 function resolver:LOC (rr)    -- - - - - - - - - - - - - - - - - - - - - -  LOC
359   rr.version = self:byte ()
360   if rr.version == 0 then
361     rr.loc           = rr.loc or {}
362     rr.loc.size      = self:LOC_nibble_power ()
363     rr.loc.horiz_pre = self:LOC_nibble_power ()
364     rr.loc.vert_pre  = self:LOC_nibble_power ()
365     rr.loc.latitude  = self:dword ()
366     rr.loc.longitude = self:dword ()
367     rr.loc.altitude  = self:dword ()
368     end  end
369
370
371 local function LOC_tostring_degrees (f, pos, neg)    -- - - - - - - - - - - - -
372   f = f - 0x80000000
373   if f < 0 then  pos = neg  f = -f  end
374   local deg, min, msec
375   msec = f%60000
376   f    = (f-msec)/60000
377   min  = f%60
378   deg = (f-min)/60
379   return string.format ('%3d %2d %2.3f %s', deg, min, msec/1000, pos)
380   end
381
382
383 function resolver.LOC_tostring (rr)    -- - - - - - - - - - - - -  LOC_tostring
384
385   local t = {}
386
387   --[[
388   for k,name in pairs { 'size', 'horiz_pre', 'vert_pre',
389                         'latitude', 'longitude', 'altitude' } do
390     append (t, string.format ('%4s%-10s: %12.0f\n', '', name, rr.loc[name]))
391     end
392   --]]
393
394   append ( t, string.format (
395     '%s    %s    %.2fm %.2fm %.2fm %.2fm',
396     LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'),
397     LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'),
398     (rr.loc.altitude - 10000000) / 100,
399     rr.loc.size / 100,
400     rr.loc.horiz_pre / 100,
401     rr.loc.vert_pre / 100 ) )
402
403   return table.concat (t)
404   end
405
406
407 function resolver:NS (rr)    -- - - - - - - - - - - - - - - - - - - - - - -  NS
408   rr.ns = self:name ()
409   end
410
411
412 function resolver:SOA (rr)    -- - - - - - - - - - - - - - - - - - - - - -  SOA
413   end
414
415
416 function resolver:SRV (rr)    -- - - - - - - - - - - - - - - - - - - - - -  SRV
417   rr.srv = {}
418   rr.srv.priority = self:word ()
419   rr.srv.weight   = self:word ()
420   rr.srv.port     = self:word ()
421   rr.srv.target   = self:name ()
422   end
423
424
425 function SRV_tostring (rr)    -- - - - - - - - - - - - - - - - - - SRV_tostring
426   local s = rr.srv
427   return string.format ( '%5d %5d %5d %s',
428                          s.priority, s.weight, s.port, s.target )
429   end
430
431
432 function resolver:TXT (rr)    -- - - - - - - - - - - - - - - - - - - - - -  TXT
433   rr.txt = self:sub (rr.rdlength)
434   end
435
436
437 function resolver:rr ()    -- - - - - - - - - - - - - - - - - - - - - - - -  rr
438   local rr = {}
439   setmetatable (rr, rr_metatable)
440   rr.name     = self:name (self)
441   rr.type     = dns.type[self:word ()] or rr.type
442   rr.class    = dns.class[self:word ()] or rr.class
443   rr.ttl      = 0x10000*self:word () + self:word ()
444   rr.rdlength = self:word ()
445
446   if rr.ttl == 0 then  -- pass
447   else  rr.tod = self.time + rr.ttl  end
448
449   local remember = self.offset
450   local rr_parser = self[dns.type[rr.type]]
451   if rr_parser then  rr_parser (self, rr)  end
452   self.offset = remember
453   rr.rdata = self:sub (rr.rdlength)
454   return rr
455   end
456
457
458 function resolver:rrs (count)    -- - - - - - - - - - - - - - - - - - - - - rrs
459   local rrs = {}
460   for i = 1,count do  append (rrs, self:rr ())  end
461   return rrs
462   end
463
464
465 function resolver:decode (packet, force)    -- - - - - - - - - - - - - - decode
466
467   self.packet, self.offset = packet, 1
468   local header = self:header (force)
469   if not header then  return nil  end
470   local response = { header = header }
471
472   response.question = {}
473   local offset = self.offset
474   for i = 1,response.header.qdcount do
475     append (response.question, self:question ())  end
476   response.question.raw = string.sub (self.packet, offset, self.offset - 1)
477
478   if not force then
479     if not self.active[response.header.id] or
480        not self.active[response.header.id][response.question.raw] then
481       return nil  end  end
482
483   response.answer     = self:rrs (response.header.ancount)
484   response.authority  = self:rrs (response.header.nscount)
485   response.additional = self:rrs (response.header.arcount)
486
487   return response
488   end
489
490
491 -- socket layer -------------------------------------------------- socket layer
492
493
494 resolver.delays = { 1, 3  }
495
496
497 function resolver:addnameserver (address)    -- - - - - - - - - - addnameserver
498   self.server = self.server or {}
499   append (self.server, address)
500   end
501
502
503 function resolver:setnameserver (address)    -- - - - - - - - - - setnameserver
504   self.server = {}
505   self:addnameserver (address)
506   end
507
508
509 function resolver:adddefaultnameservers ()    -- - - - -  adddefaultnameservers
510   local resolv_conf = io.open("/etc/resolv.conf");
511   if resolv_conf then
512           for line in resolv_conf:lines() do
513                 local address = string.match (line, '^%s*nameserver%s+(%d+%.%d+%.%d+%.%d+)%s*$')
514                 if address then self:addnameserver (address)  end
515           end
516   elseif os.getenv("WINDIR") then
517         self:addnameserver ("208.67.222.222")
518         self:addnameserver ("208.67.220.220")   
519   end
520   if not self.server or #self.server == 0 then
521         self:addnameserver("127.0.0.1");
522   end
523 end
524
525
526 function resolver:getsocket (servernum)    -- - - - - - - - - - - - - getsocket
527
528   self.socket = self.socket or {}
529   self.socketset = self.socketset or {}
530
531   local sock = self.socket[servernum]
532   if sock then  return sock  end
533
534   sock = socket.udp ()
535   if self.socket_wrapper then  sock = self.socket_wrapper (sock, self)  end
536   sock:settimeout (0)
537   -- todo: attempt to use a random port, fallback to 0
538   sock:setsockname ('*', 0)
539   sock:setpeername (self.server[servernum], 53)
540   self.socket[servernum] = sock
541   self.socketset[sock] = servernum
542   return sock
543   end
544
545 function resolver:voidsocket (sock)
546   if self.socket[sock] then
547     self.socketset[self.socket[sock]] = nil
548     self.socket[sock] = nil
549   elseif self.socketset[sock] then
550     self.socket[self.socketset[sock]] = nil
551     self.socketset[sock] = nil
552   end
553 end
554
555 function resolver:socket_wrapper_set (func)  -- - - - - - - socket_wrapper_set
556   self.socket_wrapper = func
557   end
558
559
560 function resolver:closeall ()    -- - - - - - - - - - - - - - - - - -  closeall
561   for i,sock in ipairs (self.socket) do  self.socket[i]:close ()  end
562   self.socket = {}
563   end
564
565
566 function resolver:remember (rr, type)    -- - - - - - - - - - - - - -  remember
567
568   --print ('remember', type, rr.class, rr.type, rr.name)
569
570   if type ~= '*' then
571     type = rr.type
572     local all = get (self.cache, rr.class, '*', rr.name)
573     --print ('remember all', all)
574     if all then  append (all, rr)  end
575     end
576
577   self.cache = self.cache or setmetatable ({}, cache_metatable)
578   local rrs = get (self.cache, rr.class, type, rr.name) or
579     set (self.cache, rr.class, type, rr.name, setmetatable ({}, rrs_metatable))
580   append (rrs, rr)
581
582   if type == 'MX' then  self.unsorted[rrs] = true  end
583   end
584
585
586 local function comp_mx (a, b)    -- - - - - - - - - - - - - - - - - - - comp_mx
587   return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref)
588   end
589
590
591 function resolver:peek (qname, qtype, qclass)    -- - - - - - - - - - - -  peek
592   qname, qtype, qclass = standardize (qname, qtype, qclass)
593   local rrs = get (self.cache, qclass, qtype, qname)
594   if not rrs then  return nil  end
595   if prune (rrs, socket.gettime ()) and qtype == '*' or not next (rrs) then
596     set (self.cache, qclass, qtype, qname, nil)  return nil  end
597   if self.unsorted[rrs] then  table.sort (rrs, comp_mx)  end
598   return rrs
599   end
600
601
602 function resolver:purge (soft)    -- - - - - - - - - - - - - - - - - - -  purge
603   if soft == 'soft' then
604     self.time = socket.gettime ()
605     for class,types in pairs (self.cache or {}) do
606       for type,names in pairs (types) do
607         for name,rrs in pairs (names) do
608           prune (rrs, self.time, 'soft')
609           end  end  end
610   else  self.cache = {}  end
611   end
612
613
614 function resolver:query (qname, qtype, qclass)    -- - - - - - - - - - -- query
615
616   qname, qtype, qclass = standardize (qname, qtype, qclass)
617
618   if not self.server then self:adddefaultnameservers ()  end
619
620   local question = encodeQuestion (qname, qtype, qclass)
621   local peek = self:peek (qname, qtype, qclass)
622   if peek then  return peek  end
623
624   local header, id = encodeHeader ()
625   --print ('query  id', id, qclass, qtype, qname)
626   local o = { packet = header..question,
627               server = self.best_server,
628               delay  = 1,
629               retry  = socket.gettime () + self.delays[1] }
630
631   -- remember the query
632   self.active[id] = self.active[id] or {}
633   self.active[id][question] = o
634
635   -- remember which coroutine wants the answer
636   local co = coroutine.running ()
637   if co then
638     set (self.wanted, qclass, qtype, qname, co, true)
639     --set (self.yielded, co, qclass, qtype, qname, true)
640   end
641   
642   self:getsocket (o.server):send (o.packet)
643
644 end
645
646 function resolver:servfail(sock)
647   -- Resend all queries for this server
648   
649   local num = self.socketset[sock]
650   
651   -- Socket is dead now
652   self:voidsocket(sock);
653   
654   -- Find all requests to the down server, and retry on the next server
655   self.time = socket.gettime ()
656   for id,queries in pairs (self.active) do
657     for question,o in pairs (queries) do
658       if o.server == num then -- This request was to the broken server
659         o.server = o.server + 1 -- Use next server
660         if o.server > #self.server then
661           o.server = 1
662         end
663
664         o.retries = (o.retries or 0) + 1;
665         if o.retries >= #self.server then
666           --print ('timeout')
667           queries[question] = nil
668         else
669           local _a = self:getsocket(o.server);
670           if _a then _a:send (o.packet) end
671         end
672       end
673     end
674   end
675    
676    if num == self.best_server then
677         self.best_server = self.best_server + 1
678         if self.best_server > #self.server then
679                 -- Exhausted all servers, try first again
680                 self.best_server = 1
681         end
682    end
683 end
684
685 function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
686
687   --print 'receive'  print (self.socket)
688   self.time = socket.gettime ()
689   rset = rset or self.socket
690
691   local response
692   for i,sock in pairs (rset) do
693
694     if self.socketset[sock] then
695     local packet = sock:receive ()
696     if packet then
697
698     response = self:decode (packet)
699     if response then
700     --print 'received response'
701     --self.print (response)
702
703     for i,section in pairs { 'answer', 'authority', 'additional' } do
704       for j,rr in pairs (response[section]) do
705         self:remember (rr, response.question[1].type)  end  end
706
707     -- retire the query
708     local queries = self.active[response.header.id]
709     if queries[response.question.raw] then
710       queries[response.question.raw] = nil  end
711     if not next (queries) then  self.active[response.header.id] = nil  end
712     if not next (self.active) then  self:closeall ()  end
713
714     -- was the query on the wanted list?
715     local q = response.question
716     local cos = get (self.wanted, q.class, q.type, q.name)
717     if cos then
718       for co in pairs (cos) do
719         set (self.yielded, co, q.class, q.type, q.name, nil)
720         if coroutine.status(co) == "suspended" then  coroutine.resume (co)  end
721         end
722       set (self.wanted, q.class, q.type, q.name, nil)
723       end  end  end  end  end
724
725   return response
726   end
727
728
729 function resolver:feed(sock, packet)
730   --print 'receive'  print (self.socket)
731   self.time = socket.gettime ()
732
733   local response = self:decode (packet)
734   if response then
735     --print 'received response'
736     --self.print (response)
737
738     for i,section in pairs { 'answer', 'authority', 'additional' } do
739       for j,rr in pairs (response[section]) do
740         self:remember (rr, response.question[1].type)
741       end
742     end
743
744     -- retire the query
745     local queries = self.active[response.header.id]
746     if queries[response.question.raw] then
747       queries[response.question.raw] = nil
748     end
749     if not next (queries) then  self.active[response.header.id] = nil  end
750     if not next (self.active) then  self:closeall ()  end
751
752     -- was the query on the wanted list?
753     local q = response.question[1]
754     if q then
755       local cos = get (self.wanted, q.class, q.type, q.name)
756       if cos then
757         for co in pairs (cos) do
758           set (self.yielded, co, q.class, q.type, q.name, nil)
759           if coroutine.status(co) == "suspended" then coroutine.resume (co)  end
760         end
761         set (self.wanted, q.class, q.type, q.name, nil)
762       end
763     end
764   end 
765
766   return response
767 end
768
769 function resolver:cancel(data)
770         local cos = get (self.wanted, unpack(data, 1, 3))
771         if cos then
772                 cos[data[4]] = nil;
773         end
774 end
775
776 function resolver:pulse ()    -- - - - - - - - - - - - - - - - - - - - -  pulse
777
778   --print ':pulse'
779   while self:receive() do end
780   if not next (self.active) then  return nil  end
781
782   self.time = socket.gettime ()
783   for id,queries in pairs (self.active) do
784     for question,o in pairs (queries) do
785       if self.time >= o.retry then
786
787         o.server = o.server + 1
788         if o.server > #self.server then
789           o.server = 1
790           o.delay = o.delay + 1
791           end
792
793         if o.delay > #self.delays then
794           --print ('timeout')
795           queries[question] = nil
796           if not next (queries) then  self.active[id] = nil  end
797           if not next (self.active) then  return nil  end
798         else
799           --print ('retry', o.server, o.delay)
800           local _a = self.socket[o.server];
801           if _a then _a:send (o.packet) end
802           o.retry = self.time + self.delays[o.delay]
803           end  end  end  end
804
805   if next (self.active) then  return true  end
806   return nil
807   end
808
809
810 function resolver:lookup (qname, qtype, qclass)    -- - - - - - - - - -  lookup
811   self:query (qname, qtype, qclass)
812   while self:pulse () do  socket.select (self.socket, nil, 4)  end
813   --print (self.cache)
814   return self:peek (qname, qtype, qclass)
815   end
816
817 function resolver:lookupex (handler, qname, qtype, qclass)    -- - - - - - - - - -  lookup
818   return self:peek (qname, qtype, qclass) or self:query (qname, qtype, qclass)
819   end
820
821
822 --print ---------------------------------------------------------------- print
823
824
825 local hints = {    -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
826   qr = { [0]='query', 'response' },
827   opcode = { [0]='query', 'inverse query', 'server status request' },
828   aa = { [0]='non-authoritative', 'authoritative' },
829   tc = { [0]='complete', 'truncated' },
830   rd = { [0]='recursion not desired', 'recursion desired' },
831   ra = { [0]='recursion not available', 'recursion available' },
832   z  = { [0]='(reserved)' },
833   rcode = { [0]='no error', 'format error', 'server failure', 'name error',
834             'not implemented' },
835
836   type = dns.type,
837   class = dns.class, }
838
839
840 local function hint (p, s)    -- - - - - - - - - - - - - - - - - - - - - - hint
841   return (hints[s] and hints[s][p[s]]) or ''  end
842
843
844 function resolver.print (response)    -- - - - - - - - - - - - - resolver.print
845
846   for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z',
847                      'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do
848     print ( string.format ('%-30s', 'header.'..s),
849             response.header[s], hint (response.header, s) )
850     end
851
852   for i,question in ipairs (response.question) do
853     print (string.format ('question[%i].name         ', i), question.name)
854     print (string.format ('question[%i].type         ', i), question.type)
855     print (string.format ('question[%i].class        ', i), question.class)
856     end
857
858   local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 }
859   local tmp
860   for s,s in pairs {'answer', 'authority', 'additional'} do
861     for i,rr in pairs (response[s]) do
862       for j,t in pairs { 'name', 'type', 'class', 'ttl', 'rdlength' } do
863         tmp = string.format ('%s[%i].%s', s, i, t)
864         print (string.format ('%-30s', tmp), rr[t], hint (rr, t))
865         end
866       for j,t in pairs (rr) do
867         if not common[j] then
868           tmp = string.format ('%s[%i].%s', s, i, j)
869           print (string.format ('%-30s  %s', tostring(tmp), tostring(t)))
870           end  end  end  end  end
871
872
873 -- module api ------------------------------------------------------ module api
874
875
876 local function resolve (func, ...)    -- - - - - - - - - - - - - - resolver_get
877   dns._resolver = dns._resolver or dns.resolver ()
878   return func (dns._resolver, ...)
879   end
880
881
882 function dns.resolver ()    -- - - - - - - - - - - - - - - - - - - - - resolver
883
884   -- this function seems to be redundant with resolver.new ()
885
886   local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {}, 
887               best_server = 1 }
888   setmetatable (r, resolver)
889   setmetatable (r.cache, cache_metatable)
890   setmetatable (r.unsorted, { __mode = 'kv' })
891   return r
892   end
893
894
895 function dns.lookup (...)    -- - - - - - - - - - - - - - - - - - - - -  lookup
896   return resolve (resolver.lookup, ...)  end
897
898
899 function dns.purge (...)    -- - - - - - - - - - - - - - - - - - - - - -  purge
900   return resolve (resolver.purge, ...)  end
901
902 function dns.peek (...)    -- - - - - - - - - - - - - - - - - - - - - - -  peek
903   return resolve (resolver.peek, ...)  end
904
905
906 function dns.query (...)    -- - - - - - - - - - - - - - - - - - - - - -  query
907   return resolve (resolver.query, ...)  end
908
909 function dns.feed (...)    -- - - - - - - - - - - - - - - - - - - - - -  feed
910   return resolve (resolver.feed, ...)  end
911
912 function dns.cancel(...)   -- - - - - - - - - - - - - - - - - - - - - -  cancel
913   return resolve(resolver.cancel, ...) end
914
915 function dns:socket_wrapper_set (...)    -- - - - - - - - -  socket_wrapper_set
916   return resolve (resolver.socket_wrapper_set, ...)  end
917
918
919 return dns