5abc8afc9f1ee83c12dd7c5f2199e557314f9aca
[prosody.git] / net / dns.lua
1 -- Prosody IM
2 -- This file is included with Prosody IM. It has modifications,
3 -- which are hereby placed in the public domain.
4
5 -- public domain 20080404 lua@ztact.com
6
7
8 -- todo: quick (default) header generation
9 -- todo: nxdomain, error handling
10 -- todo: cache results of encodeName
11
12
13 -- reference: http://tools.ietf.org/html/rfc1035
14 -- reference: http://tools.ietf.org/html/rfc1876 (LOC)
15
16
17 require 'socket'
18 local ztact = require 'util.ztact'
19 local require = require
20 local _, windows = pcall(require, "util.windows");
21 local is_windows = (_ and windows) or os.getenv("WINDIR");
22
23 local coroutine, io, math, socket, string, table =
24       coroutine, io, math, socket, string, table
25
26 local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack =
27       ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack
28
29 local get, set = ztact.get, ztact.set
30
31
32 -------------------------------------------------- module dns
33 module ('dns')
34 local dns = _M;
35
36
37 -- dns type & class codes ------------------------------ dns type & class codes
38
39
40 local append = table.insert
41
42
43 local function highbyte (i)    -- - - - - - - - - - - - - - - - - - -  highbyte
44   return (i-(i%0x100))/0x100
45   end
46
47
48 local function augment (t)    -- - - - - - - - - - - - - - - - - - - -  augment
49   local a = {}
50   for i,s in pairs (t) do  a[i] = s  a[s] = s  a[string.lower (s)] = s  end
51   return a
52   end
53
54
55 local function encode (t)    -- - - - - - - - - - - - - - - - - - - - -  encode
56   local code = {}
57   for i,s in pairs (t) do
58     local word = string.char (highbyte (i), i %0x100)
59     code[i] = word
60     code[s] = word
61     code[string.lower (s)] = word
62     end
63   return code
64   end
65
66
67 dns.types = {
68   'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS',
69   'PTR', 'HINFO', 'MINFO', 'MX', 'TXT',
70   [ 28] = 'AAAA', [ 29] = 'LOC',   [ 33] = 'SRV',
71   [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' }
72
73
74 dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' }
75
76
77 dns.type      = augment (dns.types)
78 dns.class     = augment (dns.classes)
79 dns.typecode  = encode  (dns.types)
80 dns.classcode = encode  (dns.classes)
81
82
83
84 local function standardize (qname, qtype, qclass)    -- - - - - - - standardize
85   if string.byte (qname, -1) ~= 0x2E then  qname = qname..'.'  end
86   qname = string.lower (qname)
87   return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN']
88   end
89
90
91 local function prune (rrs, time, soft)    -- - - - - - - - - - - - - - -  prune
92
93   time = time or socket.gettime ()
94   for i,rr in pairs (rrs) do
95
96     if rr.tod then
97       -- rr.tod = rr.tod - 50    -- accelerated decripitude
98       rr.ttl = math.floor (rr.tod - time)
99       if rr.ttl <= 0 then
100         table.remove(rrs, i);
101         return prune(rrs, time, soft); -- Re-iterate
102       end
103
104     elseif soft == 'soft' then    -- What is this?  I forget!
105       assert (rr.ttl == 0)
106       rrs[i] = nil
107       end  end  end
108
109
110 -- metatables & co. ------------------------------------------ metatables & co.
111
112
113 local resolver = {}
114 resolver.__index = resolver
115
116
117 local SRV_tostring
118
119
120 local rr_metatable = {}    -- - - - - - - - - - - - - - - - - - -  rr_metatable
121 function rr_metatable.__tostring (rr)
122   local s0 = string.format (
123     '%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name )
124   local s1 = ''
125   if rr.type == 'A' then  s1 = ' '..rr.a
126   elseif rr.type == 'MX' then
127     s1 = string.format (' %2i %s', rr.pref, rr.mx)
128   elseif rr.type == 'CNAME' then  s1 = ' '..rr.cname
129   elseif rr.type == 'LOC'   then  s1 = ' '..resolver.LOC_tostring (rr)
130   elseif rr.type == 'NS'    then  s1 = ' '..rr.ns
131   elseif rr.type == 'SRV'   then  s1 = ' '..SRV_tostring (rr)
132   elseif rr.type == 'TXT'   then  s1 = ' '..rr.txt
133   else  s1 = ' <UNKNOWN RDATA TYPE>'  end
134   return s0..s1
135   end
136
137
138 local rrs_metatable = {}    -- - - - - - - - - - - - - - - - - -  rrs_metatable
139 function rrs_metatable.__tostring (rrs)
140   local t = {}
141   for i,rr in pairs (rrs) do  append (t, tostring (rr)..'\n')  end
142   return table.concat (t)
143   end
144
145
146 local cache_metatable = {}    -- - - - - - - - - - - - - - - -  cache_metatable
147 function cache_metatable.__tostring (cache)
148   local time = socket.gettime ()
149   local t = {}
150   for class,types in pairs (cache) do
151     for type,names in pairs (types) do
152       for name,rrs in pairs (names) do
153         prune (rrs, time)
154         append (t, tostring (rrs))  end  end  end
155   return table.concat (t)
156   end
157
158
159 function resolver:new ()    -- - - - - - - - - - - - - - - - - - - - - resolver
160   local r = { active = {}, cache = {}, unsorted = {} }
161   setmetatable (r, resolver)
162   setmetatable (r.cache, cache_metatable)
163   setmetatable (r.unsorted, { __mode = 'kv' })
164   return r
165   end
166
167
168 -- packet layer -------------------------------------------------- packet layer
169
170
171 function dns.random (...)    -- - - - - - - - - - - - - - - - - - -  dns.random
172   math.randomseed (10000*socket.gettime ())
173   dns.random = math.random
174   return dns.random (...)
175   end
176
177
178 local function encodeHeader (o)    -- - - - - - - - - - - - - - -  encodeHeader
179
180   o = o or {}
181
182   o.id = o.id or                -- 16b  (random) id
183     dns.random (0, 0xffff)
184
185   o.rd = o.rd or 1              --  1b  1 recursion desired
186   o.tc = o.tc or 0              --  1b  1 truncated response
187   o.aa = o.aa or 0              --  1b  1 authoritative response
188   o.opcode = o.opcode or 0      --  4b  0 query
189                                 --      1 inverse query
190                                 --      2 server status request
191                                 --      3-15 reserved
192   o.qr = o.qr or 0              --  1b  0 query, 1 response
193
194   o.rcode = o.rcode or 0        --  4b  0 no error
195                                 --      1 format error
196                                 --      2 server failure
197                                 --      3 name error
198                                 --      4 not implemented
199                                 --      5 refused
200                                 --      6-15 reserved
201   o.z  = o.z  or 0              --  3b  0 resvered
202   o.ra = o.ra or 0              --  1b  1 recursion available
203
204   o.qdcount = o.qdcount or 1    -- 16b  number of question RRs
205   o.ancount = o.ancount or 0    -- 16b  number of answers RRs
206   o.nscount = o.nscount or 0    -- 16b  number of nameservers RRs
207   o.arcount = o.arcount or 0    -- 16b  number of additional RRs
208
209   -- string.char() rounds, so prevent roundup with -0.4999
210   local header = string.char (
211     highbyte (o.id),  o.id %0x100,
212     o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr,
213     o.rcode + 16*o.z + 128*o.ra,
214     highbyte (o.qdcount),  o.qdcount %0x100,
215     highbyte (o.ancount),  o.ancount %0x100,
216     highbyte (o.nscount),  o.nscount %0x100,
217     highbyte (o.arcount),  o.arcount %0x100 )
218
219   return header, o.id
220   end
221
222
223 local function encodeName (name)    -- - - - - - - - - - - - - - - - encodeName
224   local t = {}
225   for part in string.gmatch (name, '[^.]+') do
226     append (t, string.char (string.len (part)))
227     append (t, part)
228     end
229   append (t, string.char (0))
230   return table.concat (t)
231   end
232
233
234 local function encodeQuestion (qname, qtype, qclass)    -- - - - encodeQuestion
235   qname  = encodeName (qname)
236   qtype  = dns.typecode[qtype or 'a']
237   qclass = dns.classcode[qclass or 'in']
238   return qname..qtype..qclass;
239   end
240
241
242 function resolver:byte (len)    -- - - - - - - - - - - - - - - - - - - - - byte
243   len = len or 1
244   local offset = self.offset
245   local last = offset + len - 1
246   if last > #self.packet then
247     error (string.format ('out of bounds: %i>%i', last, #self.packet))  end
248   self.offset = offset + len
249   return string.byte (self.packet, offset, last)
250   end
251
252
253 function resolver:word ()    -- - - - - - - - - - - - - - - - - - - - - -  word
254   local b1, b2 = self:byte (2)
255   return 0x100*b1 + b2
256   end
257
258
259 function resolver:dword ()    -- - - - - - - - - - - - - - - - - - - - -  dword
260   local b1, b2, b3, b4 = self:byte (4)
261   --print ('dword', b1, b2, b3, b4)
262   return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4
263   end
264
265
266 function resolver:sub (len)    -- - - - - - - - - - - - - - - - - - - - - - sub
267   len = len or 1
268   local s = string.sub (self.packet, self.offset, self.offset + len - 1)
269   self.offset = self.offset + len
270   return s
271   end
272
273
274 function resolver:header (force)    -- - - - - - - - - - - - - - - - - - header
275
276   local id = self:word ()
277   --print (string.format (':header  id  %x', id))
278   if not self.active[id] and not force then  return nil  end
279
280   local h = { id = id }
281
282   local b1, b2 = self:byte (2)
283
284   h.rd      = b1 %2
285   h.tc      = b1 /2%2
286   h.aa      = b1 /4%2
287   h.opcode  = b1 /8%16
288   h.qr      = b1 /128
289
290   h.rcode   = b2 %16
291   h.z       = b2 /16%8
292   h.ra      = b2 /128
293
294   h.qdcount = self:word ()
295   h.ancount = self:word ()
296   h.nscount = self:word ()
297   h.arcount = self:word ()
298
299   for k,v in pairs (h) do  h[k] = v-v%1  end
300
301   return h
302   end
303
304
305 function resolver:name ()    -- - - - - - - - - - - - - - - - - - - - - -  name
306   local remember, pointers = nil, 0
307   local len = self:byte ()
308   local n = {}
309   while len > 0 do
310     if len >= 0xc0 then    -- name is "compressed"
311       pointers = pointers + 1
312       if pointers >= 20 then  error ('dns error: 20 pointers')  end
313       local offset = ((len-0xc0)*0x100) + self:byte ()
314       remember = remember or self.offset
315       self.offset = offset + 1    -- +1 for lua
316     else    -- name is not compressed
317       append (n, self:sub (len)..'.')
318       end
319     len = self:byte ()
320     end
321   self.offset = remember or self.offset
322   return table.concat (n)
323   end
324
325
326 function resolver:question ()    -- - - - - - - - - - - - - - - - - -  question
327   local q = {}
328   q.name  = self:name ()
329   q.type  = dns.type[self:word ()]
330   q.class = dns.class[self:word ()]
331   return q
332   end
333
334
335 function resolver:A (rr)    -- - - - - - - - - - - - - - - - - - - - - - - -  A
336   local b1, b2, b3, b4 = self:byte (4)
337   rr.a = string.format ('%i.%i.%i.%i', b1, b2, b3, b4)
338   end
339
340
341 function resolver:CNAME (rr)    -- - - - - - - - - - - - - - - - - - - -  CNAME
342   rr.cname = self:name ()
343   end
344
345
346 function resolver:MX (rr)    -- - - - - - - - - - - - - - - - - - - - - - -  MX
347   rr.pref = self:word ()
348   rr.mx   = self:name ()
349   end
350
351
352 function resolver:LOC_nibble_power ()    -- - - - - - - - - -  LOC_nibble_power
353   local b = self:byte ()
354   --print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
355   return ((b-(b%0x10))/0x10) * (10^(b%0x10))
356   end
357
358
359 function resolver:LOC (rr)    -- - - - - - - - - - - - - - - - - - - - - -  LOC
360   rr.version = self:byte ()
361   if rr.version == 0 then
362     rr.loc           = rr.loc or {}
363     rr.loc.size      = self:LOC_nibble_power ()
364     rr.loc.horiz_pre = self:LOC_nibble_power ()
365     rr.loc.vert_pre  = self:LOC_nibble_power ()
366     rr.loc.latitude  = self:dword ()
367     rr.loc.longitude = self:dword ()
368     rr.loc.altitude  = self:dword ()
369     end  end
370
371
372 local function LOC_tostring_degrees (f, pos, neg)    -- - - - - - - - - - - - -
373   f = f - 0x80000000
374   if f < 0 then  pos = neg  f = -f  end
375   local deg, min, msec
376   msec = f%60000
377   f    = (f-msec)/60000
378   min  = f%60
379   deg = (f-min)/60
380   return string.format ('%3d %2d %2.3f %s', deg, min, msec/1000, pos)
381   end
382
383
384 function resolver.LOC_tostring (rr)    -- - - - - - - - - - - - -  LOC_tostring
385
386   local t = {}
387
388   --[[
389   for k,name in pairs { 'size', 'horiz_pre', 'vert_pre',
390                         'latitude', 'longitude', 'altitude' } do
391     append (t, string.format ('%4s%-10s: %12.0f\n', '', name, rr.loc[name]))
392     end
393   --]]
394
395   append ( t, string.format (
396     '%s    %s    %.2fm %.2fm %.2fm %.2fm',
397     LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'),
398     LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'),
399     (rr.loc.altitude - 10000000) / 100,
400     rr.loc.size / 100,
401     rr.loc.horiz_pre / 100,
402     rr.loc.vert_pre / 100 ) )
403
404   return table.concat (t)
405   end
406
407
408 function resolver:NS (rr)    -- - - - - - - - - - - - - - - - - - - - - - -  NS
409   rr.ns = self:name ()
410   end
411
412
413 function resolver:SOA (rr)    -- - - - - - - - - - - - - - - - - - - - - -  SOA
414   end
415
416
417 function resolver:SRV (rr)    -- - - - - - - - - - - - - - - - - - - - - -  SRV
418   rr.srv = {}
419   rr.srv.priority = self:word ()
420   rr.srv.weight   = self:word ()
421   rr.srv.port     = self:word ()
422   rr.srv.target   = self:name ()
423   end
424
425
426 function SRV_tostring (rr)    -- - - - - - - - - - - - - - - - - - SRV_tostring
427   local s = rr.srv
428   return string.format ( '%5d %5d %5d %s',
429                          s.priority, s.weight, s.port, s.target )
430   end
431
432
433 function resolver:TXT (rr)    -- - - - - - - - - - - - - - - - - - - - - -  TXT
434   rr.txt = self:sub (rr.rdlength)
435   end
436
437
438 function resolver:rr ()    -- - - - - - - - - - - - - - - - - - - - - - - -  rr
439   local rr = {}
440   setmetatable (rr, rr_metatable)
441   rr.name     = self:name (self)
442   rr.type     = dns.type[self:word ()] or rr.type
443   rr.class    = dns.class[self:word ()] or rr.class
444   rr.ttl      = 0x10000*self:word () + self:word ()
445   rr.rdlength = self:word ()
446
447   if rr.ttl <= 0 then rr.tod = self.time + 30;
448   else  rr.tod = self.time + rr.ttl  end
449
450   local remember = self.offset
451   local rr_parser = self[dns.type[rr.type]]
452   if rr_parser then  rr_parser (self, rr)  end
453   self.offset = remember
454   rr.rdata = self:sub (rr.rdlength)
455   return rr
456   end
457
458
459 function resolver:rrs (count)    -- - - - - - - - - - - - - - - - - - - - - rrs
460   local rrs = {}
461   for i = 1,count do  append (rrs, self:rr ())  end
462   return rrs
463   end
464
465
466 function resolver:decode (packet, force)    -- - - - - - - - - - - - - - decode
467
468   self.packet, self.offset = packet, 1
469   local header = self:header (force)
470   if not header then  return nil  end
471   local response = { header = header }
472
473   response.question = {}
474   local offset = self.offset
475   for i = 1,response.header.qdcount do
476     append (response.question, self:question ())  end
477   response.question.raw = string.sub (self.packet, offset, self.offset - 1)
478
479   if not force then
480     if not self.active[response.header.id] or
481        not self.active[response.header.id][response.question.raw] then
482       return nil  end  end
483
484   response.answer     = self:rrs (response.header.ancount)
485   response.authority  = self:rrs (response.header.nscount)
486   response.additional = self:rrs (response.header.arcount)
487
488   return response
489   end
490
491
492 -- socket layer -------------------------------------------------- socket layer
493
494
495 resolver.delays = { 1, 3  }
496
497
498 function resolver:addnameserver (address)    -- - - - - - - - - - addnameserver
499   self.server = self.server or {}
500   append (self.server, address)
501   end
502
503
504 function resolver:setnameserver (address)    -- - - - - - - - - - setnameserver
505   self.server = {}
506   self:addnameserver (address)
507   end
508
509
510 function resolver:adddefaultnameservers ()    -- - - - -  adddefaultnameservers
511   if is_windows then
512     if windows then
513       for _, server in ipairs(windows.get_nameservers()) do
514         self:addnameserver(server)
515       end
516     end
517     if not self.server or #self.server == 0 then
518       -- TODO log warning about no nameservers, adding opendns servers as fallback
519       self:addnameserver("208.67.222.222")
520       self:addnameserver("208.67.220.220")      
521     end
522   else -- posix
523     local resolv_conf = io.open("/etc/resolv.conf");
524     if resolv_conf then
525       for line in resolv_conf:lines() do
526         local address = line:gsub("#.*$", ""):match('^%s*nameserver%s+(%d+%.%d+%.%d+%.%d+)%s*$')
527         if address then self:addnameserver (address)  end
528       end
529     end
530     if not self.server or #self.server == 0 then
531       -- TODO log warning about no nameservers, adding localhost as the default nameserver
532       self:addnameserver("127.0.0.1");
533     end
534   end
535 end
536
537
538 function resolver:getsocket (servernum)    -- - - - - - - - - - - - - getsocket
539
540   self.socket = self.socket or {}
541   self.socketset = self.socketset or {}
542
543   local sock = self.socket[servernum]
544   if sock then  return sock  end
545
546   sock = socket.udp ()
547   if self.socket_wrapper then  sock = self.socket_wrapper (sock, self)  end
548   sock:settimeout (0)
549   -- todo: attempt to use a random port, fallback to 0
550   sock:setsockname ('*', 0)
551   sock:setpeername (self.server[servernum], 53)
552   self.socket[servernum] = sock
553   self.socketset[sock] = servernum
554   return sock
555   end
556
557 function resolver:voidsocket (sock)
558   if self.socket[sock] then
559     self.socketset[self.socket[sock]] = nil
560     self.socket[sock] = nil
561   elseif self.socketset[sock] then
562     self.socket[self.socketset[sock]] = nil
563     self.socketset[sock] = nil
564   end
565 end
566
567 function resolver:socket_wrapper_set (func)  -- - - - - - - socket_wrapper_set
568   self.socket_wrapper = func
569   end
570
571
572 function resolver:closeall ()    -- - - - - - - - - - - - - - - - - -  closeall
573   for i,sock in ipairs (self.socket) do
574     self.socket[i] = nil;
575     self.socketset[sock] = nil;
576     sock:close();
577     end
578   end
579
580
581 function resolver:remember (rr, type)    -- - - - - - - - - - - - - -  remember
582
583   --print ('remember', type, rr.class, rr.type, rr.name)
584
585   if type ~= '*' then
586     type = rr.type
587     local all = get (self.cache, rr.class, '*', rr.name)
588     --print ('remember all', all)
589     if all then  append (all, rr)  end
590     end
591
592   self.cache = self.cache or setmetatable ({}, cache_metatable)
593   local rrs = get (self.cache, rr.class, type, rr.name) or
594     set (self.cache, rr.class, type, rr.name, setmetatable ({}, rrs_metatable))
595   append (rrs, rr)
596
597   if type == 'MX' then  self.unsorted[rrs] = true  end
598   end
599
600
601 local function comp_mx (a, b)    -- - - - - - - - - - - - - - - - - - - comp_mx
602   return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref)
603   end
604
605
606 function resolver:peek (qname, qtype, qclass)    -- - - - - - - - - - - -  peek
607   qname, qtype, qclass = standardize (qname, qtype, qclass)
608   local rrs = get (self.cache, qclass, qtype, qname)
609   if not rrs then  return nil  end
610   if prune (rrs, socket.gettime ()) and qtype == '*' or not next (rrs) then
611     set (self.cache, qclass, qtype, qname, nil)  return nil  end
612   if self.unsorted[rrs] then  table.sort (rrs, comp_mx)  end
613   return rrs
614   end
615
616
617 function resolver:purge (soft)    -- - - - - - - - - - - - - - - - - - -  purge
618   if soft == 'soft' then
619     self.time = socket.gettime ()
620     for class,types in pairs (self.cache or {}) do
621       for type,names in pairs (types) do
622         for name,rrs in pairs (names) do
623           prune (rrs, self.time, 'soft')
624           end  end  end
625   else  self.cache = {}  end
626   end
627
628
629 function resolver:query (qname, qtype, qclass)    -- - - - - - - - - - -- query
630
631   qname, qtype, qclass = standardize (qname, qtype, qclass)
632
633   if not self.server then self:adddefaultnameservers ()  end
634
635   local question = encodeQuestion (qname, qtype, qclass)
636   local peek = self:peek (qname, qtype, qclass)
637   if peek then  return peek  end
638
639   local header, id = encodeHeader ()
640   --print ('query  id', id, qclass, qtype, qname)
641   local o = { packet = header..question,
642               server = self.best_server,
643               delay  = 1,
644               retry  = socket.gettime () + self.delays[1] }
645
646   -- remember the query
647   self.active[id] = self.active[id] or {}
648   self.active[id][question] = o
649
650   -- remember which coroutine wants the answer
651   local co = coroutine.running ()
652   if co then
653     set (self.wanted, qclass, qtype, qname, co, true)
654     --set (self.yielded, co, qclass, qtype, qname, true)
655   end
656   
657   self:getsocket (o.server):send (o.packet)
658
659 end
660
661 function resolver:servfail(sock)
662   -- Resend all queries for this server
663   
664   local num = self.socketset[sock]
665   
666   -- Socket is dead now
667   self:voidsocket(sock);
668   
669   -- Find all requests to the down server, and retry on the next server
670   self.time = socket.gettime ()
671   for id,queries in pairs (self.active) do
672     for question,o in pairs (queries) do
673       if o.server == num then -- This request was to the broken server
674         o.server = o.server + 1 -- Use next server
675         if o.server > #self.server then
676           o.server = 1
677         end
678
679         o.retries = (o.retries or 0) + 1;
680         if o.retries >= #self.server then
681           --print ('timeout')
682           queries[question] = nil
683         else
684           local _a = self:getsocket(o.server);
685           if _a then _a:send (o.packet) end
686         end
687       end
688     end
689   end
690    
691    if num == self.best_server then
692         self.best_server = self.best_server + 1
693         if self.best_server > #self.server then
694                 -- Exhausted all servers, try first again
695                 self.best_server = 1
696         end
697    end
698 end
699
700 function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
701
702   --print 'receive'  print (self.socket)
703   self.time = socket.gettime ()
704   rset = rset or self.socket
705
706   local response
707   for i,sock in pairs (rset) do
708
709     if self.socketset[sock] then
710     local packet = sock:receive ()
711     if packet then
712
713     response = self:decode (packet)
714     if response then
715     --print 'received response'
716     --self.print (response)
717
718     for i,section in pairs { 'answer', 'authority', 'additional' } do
719       for j,rr in pairs (response[section]) do
720         self:remember (rr, response.question[1].type)  end  end
721
722     -- retire the query
723     local queries = self.active[response.header.id]
724     if queries[response.question.raw] then
725       queries[response.question.raw] = nil  end
726     if not next (queries) then  self.active[response.header.id] = nil  end
727     if not next (self.active) then  self:closeall ()  end
728
729     -- was the query on the wanted list?
730     local q = response.question
731     local cos = get (self.wanted, q.class, q.type, q.name)
732     if cos then
733       for co in pairs (cos) do
734         set (self.yielded, co, q.class, q.type, q.name, nil)
735         if coroutine.status(co) == "suspended" then  coroutine.resume (co)  end
736         end
737       set (self.wanted, q.class, q.type, q.name, nil)
738       end  end  end  end  end
739
740   return response
741   end
742
743
744 function resolver:feed(sock, packet)
745   --print 'receive'  print (self.socket)
746   self.time = socket.gettime ()
747
748   local response = self:decode (packet)
749   if response then
750     --print 'received response'
751     --self.print (response)
752
753     for i,section in pairs { 'answer', 'authority', 'additional' } do
754       for j,rr in pairs (response[section]) do
755         self:remember (rr, response.question[1].type)
756       end
757     end
758
759     -- retire the query
760     local queries = self.active[response.header.id]
761     if queries[response.question.raw] then
762       queries[response.question.raw] = nil
763     end
764     if not next (queries) then  self.active[response.header.id] = nil  end
765     if not next (self.active) then  self:closeall ()  end
766
767     -- was the query on the wanted list?
768     local q = response.question[1]
769     if q then
770       local cos = get (self.wanted, q.class, q.type, q.name)
771       if cos then
772         for co in pairs (cos) do
773           set (self.yielded, co, q.class, q.type, q.name, nil)
774           if coroutine.status(co) == "suspended" then coroutine.resume (co)  end
775         end
776         set (self.wanted, q.class, q.type, q.name, nil)
777       end
778     end
779   end 
780
781   return response
782 end
783
784 function resolver:cancel(data)
785         local cos = get (self.wanted, unpack(data, 1, 3))
786         if cos then
787                 cos[data[4]] = nil;
788         end
789 end
790
791 function resolver:pulse ()    -- - - - - - - - - - - - - - - - - - - - -  pulse
792
793   --print ':pulse'
794   while self:receive() do end
795   if not next (self.active) then  return nil  end
796
797   self.time = socket.gettime ()
798   for id,queries in pairs (self.active) do
799     for question,o in pairs (queries) do
800       if self.time >= o.retry then
801
802         o.server = o.server + 1
803         if o.server > #self.server then
804           o.server = 1
805           o.delay = o.delay + 1
806           end
807
808         if o.delay > #self.delays then
809           --print ('timeout')
810           queries[question] = nil
811           if not next (queries) then  self.active[id] = nil  end
812           if not next (self.active) then  return nil  end
813         else
814           --print ('retry', o.server, o.delay)
815           local _a = self.socket[o.server];
816           if _a then _a:send (o.packet) end
817           o.retry = self.time + self.delays[o.delay]
818           end  end  end  end
819
820   if next (self.active) then  return true  end
821   return nil
822   end
823
824
825 function resolver:lookup (qname, qtype, qclass)    -- - - - - - - - - -  lookup
826   self:query (qname, qtype, qclass)
827   while self:pulse () do  socket.select (self.socket, nil, 4)  end
828   --print (self.cache)
829   return self:peek (qname, qtype, qclass)
830   end
831
832 function resolver:lookupex (handler, qname, qtype, qclass)    -- - - - - - - - - -  lookup
833   return self:peek (qname, qtype, qclass) or self:query (qname, qtype, qclass)
834   end
835
836
837 --print ---------------------------------------------------------------- print
838
839
840 local hints = {    -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
841   qr = { [0]='query', 'response' },
842   opcode = { [0]='query', 'inverse query', 'server status request' },
843   aa = { [0]='non-authoritative', 'authoritative' },
844   tc = { [0]='complete', 'truncated' },
845   rd = { [0]='recursion not desired', 'recursion desired' },
846   ra = { [0]='recursion not available', 'recursion available' },
847   z  = { [0]='(reserved)' },
848   rcode = { [0]='no error', 'format error', 'server failure', 'name error',
849             'not implemented' },
850
851   type = dns.type,
852   class = dns.class, }
853
854
855 local function hint (p, s)    -- - - - - - - - - - - - - - - - - - - - - - hint
856   return (hints[s] and hints[s][p[s]]) or ''  end
857
858
859 function resolver.print (response)    -- - - - - - - - - - - - - resolver.print
860
861   for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z',
862                      'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do
863     print ( string.format ('%-30s', 'header.'..s),
864             response.header[s], hint (response.header, s) )
865     end
866
867   for i,question in ipairs (response.question) do
868     print (string.format ('question[%i].name         ', i), question.name)
869     print (string.format ('question[%i].type         ', i), question.type)
870     print (string.format ('question[%i].class        ', i), question.class)
871     end
872
873   local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 }
874   local tmp
875   for s,s in pairs {'answer', 'authority', 'additional'} do
876     for i,rr in pairs (response[s]) do
877       for j,t in pairs { 'name', 'type', 'class', 'ttl', 'rdlength' } do
878         tmp = string.format ('%s[%i].%s', s, i, t)
879         print (string.format ('%-30s', tmp), rr[t], hint (rr, t))
880         end
881       for j,t in pairs (rr) do
882         if not common[j] then
883           tmp = string.format ('%s[%i].%s', s, i, j)
884           print (string.format ('%-30s  %s', tostring(tmp), tostring(t)))
885           end  end  end  end  end
886
887
888 -- module api ------------------------------------------------------ module api
889
890
891 local function resolve (func, ...)    -- - - - - - - - - - - - - - resolver_get
892   return func (dns._resolver, ...)
893   end
894
895
896 function dns.resolver ()    -- - - - - - - - - - - - - - - - - - - - - resolver
897
898   -- this function seems to be redundant with resolver.new ()
899
900   local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {}, 
901               best_server = 1 }
902   setmetatable (r, resolver)
903   setmetatable (r.cache, cache_metatable)
904   setmetatable (r.unsorted, { __mode = 'kv' })
905   return r
906   end
907
908
909 function dns.lookup (...)    -- - - - - - - - - - - - - - - - - - - - -  lookup
910   return resolve (resolver.lookup, ...)  end
911
912
913 function dns.purge (...)    -- - - - - - - - - - - - - - - - - - - - - -  purge
914   return resolve (resolver.purge, ...)  end
915
916 function dns.peek (...)    -- - - - - - - - - - - - - - - - - - - - - - -  peek
917   return resolve (resolver.peek, ...)  end
918
919
920 function dns.query (...)    -- - - - - - - - - - - - - - - - - - - - - -  query
921   return resolve (resolver.query, ...)  end
922
923 function dns.feed (...)    -- - - - - - - - - - - - - - - - - - - - - -  feed
924   return resolve (resolver.feed, ...)  end
925
926 function dns.cancel(...)   -- - - - - - - - - - - - - - - - - - - - - -  cancel
927   return resolve(resolver.cancel, ...) end
928
929 function dns:socket_wrapper_set (...)    -- - - - - - - - -  socket_wrapper_set
930   return resolve (resolver.socket_wrapper_set, ...)  end
931
932 dns._resolver = dns.resolver ()
933
934 return dns