Merge 0.10->trunk
[prosody.git] / net / dns.lua
1 -- Prosody IM
2 -- This file is included with Prosody IM. It has modifications,
3 -- which are hereby placed in the public domain.
4
5
6 -- todo: quick (default) header generation
7 -- todo: nxdomain, error handling
8 -- todo: cache results of encodeName
9
10
11 -- reference: http://tools.ietf.org/html/rfc1035
12 -- reference: http://tools.ietf.org/html/rfc1876 (LOC)
13
14
15 local socket = require "socket";
16 local timer = require "util.timer";
17 local new_ip = require "util.ip".new_ip;
18
19 local _, windows = pcall(require, "util.windows");
20 local is_windows = (_ and windows) or os.getenv("WINDIR");
21
22 local coroutine, io, math, string, table =
23       coroutine, io, math, string, table;
24
25 local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type=
26       ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type;
27
-- Small helpers for nested tables (public domain 20080404 lua@ztact.com).
local ztact = {
	-- get(parent, k1, ..., kn): return parent[k1]...[kn], or nil as soon as
	-- any level along the path is missing.
	get = function(parent, ...)
		local len = select('#', ...);
		for i=1,len do
			parent = parent[select(i, ...)];
			if parent == nil then break; end
		end
		return parent;
	end;
	-- set(parent, k1, ..., kn, key, value): store value at
	-- parent[k1]...[kn][key], creating intermediate tables as needed, and
	-- return value. With value == nil the entry is removed, and any chain of
	-- tables that the removal leaves empty is pruned as well.
	set = function(parent, ...)
		local len = select('#', ...);
		local key, value = select(len-1, ...);
		-- Topmost table/key from which the path is a chain of tables with at
		-- most one entry each (candidates for pruning when deleting).
		local cutpoint, cutkey;

		for i=1,len-2 do
			local pathkey = select(i, ...);
			local child = parent[pathkey];

			if value == nil then
				if child == nil then
					return; -- nothing stored along this path
				elseif next(child, next(child)) then
					-- child holds 2+ entries, so it (and everything above it)
					-- survives the deletion
					cutpoint = nil; cutkey = nil;
				elseif cutpoint == nil then
					cutpoint = parent; cutkey = pathkey;
				end
			elseif child == nil then
				child = {};
				parent[pathkey] = child;
			end
			parent = child;
		end

		if value == nil and cutpoint then
			-- Only prune when removing `key` really empties the leaf table;
			-- otherwise a sibling entry of `key` would be lost with it.
			local first = next(parent);
			if first == nil or (first == key and next(parent, key) == nil) then
				cutpoint[cutkey] = nil;
				return;
			end
		end
		parent[key] = value;
		return value;
	end;
};
local get, set = ztact.get, ztact.set;
70
-- Default per-query timeout in seconds; installed as resolver.timeout below
-- and overridable per instance with resolver:settimeout().
local default_timeout = 15;

-------------------------------------------------- module dns
-- Lua 5.1 module(): creates the module table 'dns' (also registered as a
-- global and in package.loaded) and makes it the environment of the rest of
-- this chunk, so the global assignments below become module fields. Without
-- package.seeall the standard globals are no longer visible from here on,
-- which is why they are captured as locals at the top of the file.
-- NOTE(review): module() is deprecated since Lua 5.2.
module('dns')
local dns = _M;
76
77
78 -- dns type & class codes ------------------------------ dns type & class codes
79
80
81 local append = table.insert
82
83
-- Return the most significant byte of a 16-bit value, i.e. floor(i / 256).
local function highbyte(i)    -- - - - - - - - - - - - - - - - - - - -  highbyte
	local low = i % 0x100;
	return (i - low) / 0x100;
end
87
88
-- Build a lookup table from a code -> mnemonic table: the result maps the
-- code, the mnemonic and its lower-case form all to the canonical mnemonic.
local function augment (t)    -- - - - - - - - - - - - - - - - - - - - -  augment
	local lookup = {};
	for code, mnemonic in pairs(t) do
		local lowered = string.lower(mnemonic);
		lookup[code] = mnemonic;
		lookup[mnemonic] = mnemonic;
		lookup[lowered] = mnemonic;
	end
	return lookup;
end
98
99
-- Build a wire-encoding table from a code -> mnemonic table: the code, the
-- mnemonic and its lower-case form all map to the code as a 2-byte
-- big-endian string.
local function encode (t)    -- - - - - - - - - - - - - - - - - - - - - -  encode
	local wire = {};
	for code, mnemonic in pairs(t) do
		local bytes = string.char(highbyte(code), code % 0x100);
		wire[code] = bytes;
		wire[mnemonic] = bytes;
		wire[string.lower(mnemonic)] = bytes;
	end
	return wire;
end
110
111
-- RR TYPE codes -> mnemonics (RFC 1035 §3.2.2/§3.2.3 plus AAAA, LOC, SRV).
-- The positional entries are codes 1..16 in order.
dns.types = {
	'A',     'NS',    'MD',  'MF',     --  1 -  4
	'CNAME', 'SOA',   'MB',  'MG',     --  5 -  8
	'MR',    'NULL',  'WKS', 'PTR',    --  9 - 12
	'HINFO', 'MINFO', 'MX',  'TXT',    -- 13 - 16
	[28]  = 'AAAA', [29]  = 'LOC',   [33]  = 'SRV',
	[252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*',
};

-- CLASS codes -> mnemonics (RFC 1035 §3.2.4/§3.2.5), codes 1..4 then 255.
dns.classes = {
	'IN', 'CS', 'CH', 'HS',
	[255] = '*',
};

-- Lookups keyed by code, mnemonic or lower-case mnemonic:
--   dns.type / dns.class         -> canonical mnemonic
--   dns.typecode / dns.classcode -> 2-byte big-endian wire encoding
dns.type      = augment(dns.types);
dns.typecode  = encode(dns.types);
dns.class     = augment(dns.classes);
dns.classcode = encode(dns.classes);
126
127
128
-- Normalise query parameters: make qname fully qualified (trailing '.') and
-- lower-case, and map qtype/qclass (default 'A'/'IN') to canonical mnemonics.
local function standardize(qname, qtype, qclass)    -- - - - - - - standardize
	local DOT = 0x2E;
	if string.byte(qname, -1) ~= DOT then
		qname = qname .. '.';
	end
	local canonical_type = dns.type[qtype or 'A'];
	local canonical_class = dns.class[qclass or 'IN'];
	return string.lower(qname), canonical_type, canonical_class;
end
134
135
-- Remove dead records from an rrs list in place.
-- Records with an expiry time (rr.tod) get rr.ttl refreshed relative to
-- `time` (default: now) and are dropped once it reaches 0. With
-- soft == 'soft', records without rr.tod are dropped too (they must have
-- ttl 0). Dropping a record also clears its de-duplication marker
-- rrs[<parsed RDATA>].
local function prune(rrs, time, soft)    -- - - - - - - - - - - - - - -  prune
	time = time or socket.gettime();
	-- Walk backwards so table.remove() never shifts a not-yet-visited record
	-- into the slot just examined (a forward ipairs walk skips it).
	for i = #rrs, 1, -1 do
		local rr = rrs[i];
		local drop = false;
		if rr.tod then
			-- rr.tod = rr.tod - 50    -- accelerated decrepitude (debug aid)
			rr.ttl = math.floor(rr.tod - time);
			drop = rr.ttl <= 0;
		elseif soft == 'soft' then
			assert(rr.ttl == 0);
			drop = true;
		end
		if drop then
			-- Types without an RDATA parser have no such field; indexing a
			-- table with a nil key on assignment would raise.
			local key = rr.type and rr[rr.type:lower()];
			if key ~= nil then rrs[key] = nil; end
			table.remove(rrs, i);
		end
	end
end
154
155
156 -- metatables & co. ------------------------------------------ metatables & co.
157
158
-- Prototype shared by all resolver instances: methods and defaults such as
-- `timeout` are found here through __index.
local resolver = { timeout = default_timeout };
resolver.__index = resolver;
163
-- Fallback RDATA renderer: the record's parsed field named after its type
-- (e.g. rr.a, rr.cname) when that is a string, else a placeholder.
local function default_rr_tostring(rr)
	local value = nil;
	if rr.type then
		value = rr[rr.type:lower()];
	end
	if type(value) == "string" then
		return value;
	end
	return "<UNKNOWN RDATA TYPE>";
end
171
-- RDATA renderers for record types whose parsed data is not a plain string.
local special_tostrings = {
	-- resolver.LOC_tostring is assigned further down the file, so it must be
	-- looked up at call time; reading it here would capture nil.
	LOC = function (rr)
		if type(rr.loc) ~= "table" then
			-- LOC RDATA of an unsupported version is left unparsed
			return default_rr_tostring(rr);
		end
		return resolver.LOC_tostring(rr);
	end;
	MX  = function (rr)
		return string.format('%2i %s', rr.pref, rr.mx);
	end;
	SRV = function (rr)
		local s = rr.srv;
		return string.format('%5d %5d %5d %s', s.priority, s.weight, s.port, s.target);
	end;
};
182
-- Metatable for parsed records: tostring(rr) yields one line
-- "<class> <type> <ttl> <name> <rdata>", rendering RDATA with the
-- type-specific formatter from special_tostrings when one exists.
local rr_metatable = {
	__tostring = function(rr)
		local render = special_tostrings[rr.type] or default_rr_tostring;
		local rdata_text = render(rr);
		return string.format('%2s %-5s %6i %-28s %s',
			rr.class, rr.type, rr.ttl, rr.name, rdata_text);
	end;
};
188
189
-- Metatable for record lists: tostring(rrs) is one line per record.
local rrs_metatable = {
	__tostring = function(rrs)
		local lines = {};
		for idx, record in ipairs(rrs) do
			lines[#lines + 1] = tostring(record) .. '\n';
		end
		return table.concat(lines);
	end;
};
198
199
-- Metatable for the cache (cache[class][type][name] = rrs): tostring(cache)
-- prunes expired records as a side effect, then lists all remaining ones.
local cache_metatable = {
	__tostring = function(cache)
		local now = socket.gettime();
		local out = {};
		for class_key, by_type in pairs(cache) do
			for type_key, by_name in pairs(by_type) do
				for name_key, rrs in pairs(by_name) do
					prune(rrs, now);
					out[#out + 1] = tostring(rrs);
				end
			end
		end
		return table.concat(out);
	end;
};
214
215
-- Create a resolver with empty state.
-- `wanted` (coroutines waiting for answers) and `best_server` (index of
-- the preferred name server) are read by query() and servfail(), so they
-- must exist from the start: get(nil, ...) and nil + 1 both raise.
function resolver:new()    -- - - - - - - - - - - - - - - - - - - - - resolver
	local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, best_server = 1 };
	setmetatable(r, resolver);
	setmetatable(r.cache, cache_metatable);
	-- weak table: the "needs sorting" flags must not keep rrs lists alive
	setmetatable(r.unsorted, { __mode = 'kv' });
	return r;
end
223
224
225 -- packet layer -------------------------------------------------- packet layer
226
227
-- Seed math.random from the clock on first use, then replace dns.random
-- with math.random itself and forward this first call to it.
function dns.random(...)    -- - - - - - - - - - - - - - - - - - -  dns.random
	local seed = math.floor(10000*socket.gettime()) % 0x100000000;
	math.randomseed(seed);
	dns.random = math.random;
	return dns.random(...);
end
233
234
-- Build the 12-byte DNS message header (RFC 1035 §4.1.1).
-- `o` (optional) supplies header fields; missing ones are filled in on `o`
-- itself: a random 16-bit id, RD = 1, QDCOUNT = 1, everything else 0.
-- Returns the packed header string and the id used.
local function encodeHeader(o)    -- - - - - - - - - - - - - - -  encodeHeader
	o = o or {};
	o.id = o.id or dns.random(0, 0xffff);

	-- Flags (bit width): RD(1) recursion desired, TC(1) truncated,
	-- AA(1) authoritative, OPCODE(4) 0 query / 1 inverse / 2 status /
	-- 3-15 reserved, QR(1) 0 query / 1 response, RCODE(4) 0 ok / 1 format
	-- error / 2 server failure / 3 name error / 4 not implemented /
	-- 5 refused / 6-15 reserved, Z(3) reserved, RA(1) recursion available.
	o.rd     = o.rd     or 1;
	o.tc     = o.tc     or 0;
	o.aa     = o.aa     or 0;
	o.opcode = o.opcode or 0;
	o.qr     = o.qr     or 0;
	o.rcode  = o.rcode  or 0;
	o.z      = o.z      or 0;
	o.ra     = o.ra     or 0;

	-- Section counts (16 bits each): questions, answers, authority, additional.
	o.qdcount = o.qdcount or 1;
	o.ancount = o.ancount or 0;
	o.nscount = o.nscount or 0;
	o.arcount = o.arcount or 0;

	-- Byte 3 = QR|OPCODE|AA|TC|RD, byte 4 = RA|Z|RCODE (most significant first).
	local flags_hi = o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr;
	local flags_lo = o.rcode + 16*o.z + 128*o.ra;

	local header = string.char(
		highbyte(o.id),      o.id % 0x100,
		flags_hi,            flags_lo,
		highbyte(o.qdcount), o.qdcount % 0x100,
		highbyte(o.ancount), o.ancount % 0x100,
		highbyte(o.nscount), o.nscount % 0x100,
		highbyte(o.arcount), o.arcount % 0x100
	);
	return header, o.id;
end
276
277
-- Encode a dotted domain name in wire format: each label prefixed by its
-- length byte, terminated by a zero byte. Empty labels are skipped.
local function encodeName(name)    -- - - - - - - - - - - - - - - - encodeName
	local pieces = {};
	for label in string.gmatch(name, '[^.]+') do
		pieces[#pieces + 1] = string.char(#label) .. label;
	end
	pieces[#pieces + 1] = string.char(0);
	return table.concat(pieces);
end
287
288
-- Encode one question entry: QNAME, then QTYPE and QCLASS as 2-byte codes
-- (defaults 'a' / 'in').
local function encodeQuestion(qname, qtype, qclass)    -- - - - encodeQuestion
	local name_wire  = encodeName(qname);
	local type_wire  = dns.typecode[qtype or 'a'];
	local class_wire = dns.classcode[qclass or 'in'];
	return name_wire .. type_wire .. class_wire;
end
295
296
-- Consume `len` (default 1) bytes of self.packet at self.offset and return
-- them as numbers; raises if that would read past the end of the packet.
function resolver:byte(len)    -- - - - - - - - - - - - - - - - - - - - - byte
	local count = len or 1;
	local first = self.offset;
	local last = first + count - 1;
	local size = #self.packet;
	if last > size then
		error(string.format('out of bounds: %i>%i', last, size));
	end
	self.offset = first + count;
	return string.byte(self.packet, first, last);
end
307
308
-- Consume a big-endian 16-bit unsigned integer.
function resolver:word()    -- - - - - - - - - - - - - - - - - - - - - -  word
	local hi, lo = self:byte(2);
	return hi * 0x100 + lo;
end
313
314
-- Consume a big-endian 32-bit unsigned integer.
function resolver:dword ()    -- - - - - - - - - - - - - - - - - - - - -  dword
	local b1, b2, b3, b4 = self:byte(4);
	local high = 0x100 * b1 + b2;
	local low = 0x100 * b3 + b4;
	return 0x10000 * high + low;
end
320
321
-- Consume `len` (default 1) bytes and return them as a string (shorter if
-- the packet ends first; no bounds error, unlike byte()).
function resolver:sub(len)    -- - - - - - - - - - - - - - - - - - - - - - sub
	local count = len or 1;
	local start = self.offset;
	self.offset = start + count;
	return string.sub(self.packet, start, start + count - 1);
end
328
329
-- Parse the 12-byte message header at self.offset into a table of integer
-- fields (id, rd, tc, aa, opcode, qr, rcode, z, ra, *count). Returns nil
-- (after consuming only the id) when the id is not an active query, unless
-- `force` is set.
function resolver:header(force)    -- - - - - - - - - - - - - - - - - - header
	local id = self:word();
	if not (force or self.active[id]) then return nil; end

	local flags1, flags2 = self:byte(2);
	local h = {
		id     = id,
		rd     = flags1 % 2,
		tc     = flags1 / 2 % 2,
		aa     = flags1 / 4 % 2,
		opcode = flags1 / 8 % 16,
		qr     = flags1 / 128,
		rcode  = flags2 % 16,
		z      = flags2 / 16 % 8,
		ra     = flags2 / 128,
	};
	h.qdcount = self:word();
	h.ancount = self:word();
	h.nscount = self:word();
	h.arcount = self:word();

	-- The divisions above leave fractions; truncate every field to an integer.
	for field, value in pairs(h) do
		h[field] = value - value % 1;
	end
	return h;
end
358
359
-- Read a domain name at self.offset and return it as a dotted string with a
-- trailing '.' ("." for the root name).
-- A length byte >= 0xc0 is a compression pointer: together with the next
-- byte its low 14 bits give an earlier packet offset to continue reading
-- from. When done, self.offset is left just past the first pointer followed
-- (or past the terminating zero byte if there was none).
-- Raises 'dns error: 20 pointers' on meeting the 20th pointer, which bounds
-- pointer loops in malformed or hostile packets.
function resolver:name()    -- - - - - - - - - - - - - - - - - - - - - -  name
	local remember, pointers = nil, 0;
	local len = self:byte();
	local n = {};
	if len == 0 then return "." end -- Root label
	while len > 0 do
		if len >= 0xc0 then    -- name is "compressed"
			pointers = pointers + 1;
			if pointers >= 20 then error('dns error: 20 pointers'); end;
			-- 14-bit, 0-based offset from the start of the packet
			local offset = ((len-0xc0)*0x100) + self:byte();
			-- only the first pointer decides where parsing resumes afterwards
			remember = remember or self.offset;
			self.offset = offset + 1;    -- +1 for lua
		else    -- name is not compressed
			append(n, self:sub(len)..'.');
		end
		len = self:byte();
	end
	self.offset = remember or self.offset;
	return table.concat(n);
end
380
381
-- Parse one question entry into { name, type, class } (mnemonics; nil for
-- unknown codes).
function resolver:question()    -- - - - - - - - - - - - - - - - - -  question
	local qname = self:name();
	local qtype = dns.type[self:word()];
	local qclass = dns.class[self:word()];
	return { name = qname, type = qtype, class = qclass };
end
389
390
-- Parse A RDATA into rr.a as a dotted-quad string.
function resolver:A(rr)    -- - - - - - - - - - - - - - - - - - - - - - - -  A
	rr.a = string.format('%i.%i.%i.%i', self:byte(4));
end
395
-- Parse AAAA RDATA into rr.aaaa as text: 16-bit groups in hex without
-- leading zeros, with one run of all-zero groups shortened to "::". The run
-- chosen is the longest one found by the ":[0:]+:" scan (on ties, whichever
-- table.sort leaves first).
-- Reads rr.rdlength bytes two at a time (16 for a well-formed record).
-- NOTE(review): a lone zero group is also replaced by "::", which RFC 5952
-- §4.2.2 advises against.
function resolver:AAAA(rr)
	local addr = {};
	for i = 1, rr.rdlength, 2 do
		local b1, b2 = self:byte(2);
		table.insert(addr, ("%02x%02x"):format(b1, b2));
	end
	-- strip each group's leading zeros ("0db8" -> "db8", "0000" -> "0")
	addr = table.concat(addr, ":"):gsub("%f[%x]0+(%x)","%1");
	local zeros = {};
	for item in addr:gmatch(":[0:]+:") do
		table.insert(zeros, item)
	end
	if #zeros == 0 then
		rr.aaaa = addr;
		return
	elseif #zeros > 1 then
		table.sort(zeros, function(a, b) return #a > #b end);
	end
	-- zeros[1] contains only ':' and '0' (no pattern magic), so it is safe as
	-- a gsub pattern; the last two substitutions tidy runs at either end
	rr.aaaa = addr:gsub(zeros[1], "::", 1):gsub("^0::", "::"):gsub("::0$", "::");
end
415
-- Parse CNAME RDATA: the canonical name this owner name is an alias for.
function resolver:CNAME(rr)    -- - - - - - - - - - - - - - - - - - - -  CNAME
	local canonical = self:name();
	rr.cname = canonical;
end
419
420
-- Parse MX RDATA: 16-bit preference (rr.pref) then exchange host (rr.mx).
function resolver:MX(rr)    -- - - - - - - - - - - - - - - - - - - - - - -  MX
	local preference = self:word();
	local exchange = self:name();
	rr.pref, rr.mx = preference, exchange;
end
425
426
-- Decode an RFC 1876 size/precision byte: high nibble = base, low nibble =
-- power of ten; the value is base * 10^power.
function resolver:LOC_nibble_power()    -- - - - - - - - - -  LOC_nibble_power
	local packed = self:byte();
	local power = packed % 0x10;
	local base = (packed - power) / 0x10;
	return base * (10 ^ power);
end
432
433
-- Parse LOC RDATA (RFC 1876) into rr.version and, for version 0 only,
-- rr.loc = { size, horiz_pre, vert_pre, latitude, longitude, altitude }
-- (the last three left as raw unsigned 32-bit values).
function resolver:LOC(rr)    -- - - - - - - - - - - - - - - - - - - - - -  LOC
	rr.version = self:byte();
	if rr.version ~= 0 then return; end

	local loc = rr.loc or {};
	rr.loc = loc;
	loc.size      = self:LOC_nibble_power();
	loc.horiz_pre = self:LOC_nibble_power();
	loc.vert_pre  = self:LOC_nibble_power();
	loc.latitude  = self:dword();
	loc.longitude = self:dword();
	loc.altitude  = self:dword();
end
446
447
-- Format a raw RFC 1876 latitude/longitude (thousandths of an arc second,
-- offset by 2^31) as "deg min sec <hemisphere>", choosing `pos` or `neg`.
local function LOC_tostring_degrees(f, pos, neg)    -- - - - - - - - - - - - -
	local offset = f - 0x80000000;
	local hemisphere = pos;
	if offset < 0 then
		hemisphere, offset = neg, -offset;
	end
	local msec = offset % 60000;
	local total_minutes = (offset - msec) / 60000;
	local min = total_minutes % 60;
	local deg = (total_minutes - min) / 60;
	return string.format('%3d %2d %2.3f %s', deg, min, msec/1000, hemisphere);
end
458
459
-- Render a parsed LOC record (rr.loc) as
-- "<lat> <N|S>    <long> <E|W>    <alt>m <size>m <horiz_pre>m <vert_pre>m".
-- Altitude is stored in centimetres above a base 100 000 m below the
-- reference spheroid (RFC 1876), hence the 10 000 000 offset.
function resolver.LOC_tostring(rr)    -- - - - - - - - - - - - -  LOC_tostring
	local loc = rr.loc;
	return string.format('%s    %s    %.2fm %.2fm %.2fm %.2fm',
		LOC_tostring_degrees(loc.latitude, 'N', 'S'),
		LOC_tostring_degrees(loc.longitude, 'E', 'W'),
		(loc.altitude - 10000000) / 100,
		loc.size / 100,
		loc.horiz_pre / 100,
		loc.vert_pre / 100);
end
481
482
-- Parse NS RDATA: the authoritative name server's host name.
function resolver:NS(rr)    -- - - - - - - - - - - - - - - - - - - - - - -  NS
	local server = self:name();
	rr.ns = server;
end
486
487
-- Parse SOA RDATA (RFC 1035 §3.3.13) into rr.soa: primary name server
-- (mname), responsible mailbox (rname), then the 32-bit serial, refresh,
-- retry, expire and minimum fields.
function resolver:SOA(rr)    -- - - - - - - - - - - - - - - - - - - - - -  SOA
	local soa = {};
	soa.mname   = self:name();
	soa.rname   = self:name();
	soa.serial  = self:dword();
	soa.refresh = self:dword();
	soa.retry   = self:dword();
	soa.expire  = self:dword();
	soa.minimum = self:dword();
	rr.soa = soa;
end
490
491
-- Parse SRV RDATA (RFC 2782) into rr.srv = { priority, weight, port, target }.
function resolver:SRV(rr)    -- - - - - - - - - - - - - - - - - - - - - -  SRV
	local priority = self:word();
	local weight   = self:word();
	local port     = self:word();
	local target   = self:name();
	rr.srv = { priority = priority, weight = weight, port = port, target = target };
end
499
-- Parse PTR RDATA: the domain name the record points to.
function resolver:PTR(rr)
	local target = self:name();
	rr.ptr = target;
end
503
-- Parse TXT RDATA into rr.txt. The RDATA is one or more <character-string>s
-- (RFC 1035 §3.3.14), each a length byte followed by that many bytes; all of
-- them are concatenated, as long records (e.g. SPF/DKIM) are split this way.
function resolver:TXT(rr)    -- - - - - - - - - - - - - - - - - - - - - -  TXT
	local stop = self.offset + rr.rdlength;
	local parts = {};
	while self.offset < stop do
		append(parts, self:sub(self:byte()));
	end
	rr.txt = table.concat(parts);
end
507
508
-- Parse one resource record at self.offset. The result carries name, type,
-- class, ttl, rdlength, tod (absolute expiry: self.time + ttl, or + 30 when
-- the TTL is 0), rdata (raw bytes) and whatever fields the type-specific
-- parser (self[<type>]) sets, e.g. rr.a or rr.mx.
function resolver:rr()    -- - - - - - - - - - - - - - - - - - - - - - - -  rr
	local rr = setmetatable({}, rr_metatable);
	rr.name  = self:name();
	rr.type  = dns.type[self:word()];
	rr.class = dns.class[self:word()];
	local ttl_high = self:word();
	local ttl_low  = self:word();
	rr.ttl = 0x10000 * ttl_high + ttl_low;
	rr.rdlength = self:word();

	local lifetime = rr.ttl;
	if lifetime <= 0 then lifetime = 30; end
	rr.tod = self.time + lifetime;

	-- Let the type-specific parser read the RDATA, then rewind so the raw
	-- bytes are captured and the offset lands exactly after this record.
	local rdata_start = self.offset;
	local parser = self[dns.type[rr.type]];
	if parser then parser(self, rr); end
	self.offset = rdata_start;
	rr.rdata = self:sub(rr.rdlength);
	return rr;
end
531
532
-- Parse `count` consecutive resource records into a list.
function resolver:rrs (count)    -- - - - - - - - - - - - - - - - - - - - - rrs
	local list = {};
	for n = 1, count do
		list[n] = self:rr();
	end
	return list;
end
538
539
-- Decode a response packet into
-- { header, question (list, with .raw = its wire bytes), answer, authority,
--   additional }.
-- Returns nil when the header id is not an active query (unless `force`).
-- Also returns nil, unless `force`, when the question section does not
-- match any question sent under that id; every pending query under the id
-- is then forgotten.
function resolver:decode(packet, force)    -- - - - - - - - - - - - - - decode
	self.packet, self.offset = packet, 1;
	local header = self:header(force);
	if not header then return nil; end

	local questions = {};
	local qstart = self.offset;
	for n = 1, header.qdcount do
		questions[n] = self:question();
	end
	questions.raw = string.sub(self.packet, qstart, self.offset - 1);

	if not force then
		local pending = self.active[header.id];
		if not (pending and pending[questions.raw]) then
			self.active[header.id] = nil;
			return nil;
		end
	end

	local response = { header = header, question = questions };
	response.answer     = self:rrs(header.ancount);
	response.authority  = self:rrs(header.nscount);
	response.additional = self:rrs(header.arcount);
	return response;
end
566
567
568 -- socket layer -------------------------------------------------- socket layer
569
570
-- Retry delays in seconds; query() records its first retry time as
-- now + delays[1].
resolver.delays = {
	[1] = 1,
	[2] = 3,
};
572
573
-- Append a name server address to self.server, creating the list on demand.
function resolver:addnameserver(address)    -- - - - - - - - - - addnameserver
	local servers = self.server or {};
	self.server = servers;
	servers[#servers + 1] = address;
end
578
579
-- Replace the configured name servers with the single given address.
resolver.setnameserver = function(self, address)    -- - - - - - setnameserver
	self.server = {};   -- drop every previously configured server
	self:addnameserver(address);
end
584
585
-- Populate self.server from the system configuration.
-- Windows: util.windows.get_nameservers() if available, else the OpenDNS
-- resolvers. POSIX: "nameserver" lines of /etc/resolv.conf that parse as IP
-- addresses, else 127.0.0.1.
function resolver:adddefaultnameservers()    -- - - - -  adddefaultnameservers
	if is_windows then
		if windows and windows.get_nameservers then
			for _, server in ipairs(windows.get_nameservers()) do
				self:addnameserver(server);
			end
		end
		if not self.server or #self.server == 0 then
			-- TODO log warning about no nameservers, adding opendns servers as fallback
			self:addnameserver("208.67.222.222");
			self:addnameserver("208.67.220.220");
		end
	else -- posix
		local resolv_conf = io.open("/etc/resolv.conf");
		if resolv_conf then
			for line in resolv_conf:lines() do
				line = line:gsub("#.*$", "")
					:match('^%s*nameserver%s+([%x:%.]*)%s*$');
				if line then
					local ip = new_ip(line);
					if ip then
						self:addnameserver(ip.addr);
					end
				end
			end
			-- file:lines() does not close the handle at EOF; release it now
			-- instead of leaving it to the garbage collector.
			resolv_conf:close();
		end
		if not self.server or #self.server == 0 then
			-- TODO log warning about no nameservers, adding localhost as the default nameserver
			self:addnameserver("127.0.0.1");
		end
	end
end
618
619
-- Return the UDP socket for server number `servernum`, creating it on first
-- use. A new socket is UDPv6 if the server address contains ':' (else UDPv4).
-- It is passed through self.socket_wrapper if set, made non-blocking, bound
-- to an OS-chosen local port, and connected to the server's port 53.
-- Returns nil, err if the socket cannot be created. If binding or
-- connecting fails, returns whatever self:servfail() returns.
function resolver:getsocket(servernum)    -- - - - - - - - - - - - - getsocket
	self.socket = self.socket or {};
	self.socketset = self.socketset or {};

	local sock = self.socket[servernum];
	if sock then return sock; end

	local ok, err;
	local peer = self.server[servernum];
	if peer:find(":") then
		sock, err = socket.udp6();
	else
		sock, err = socket.udp();
	end
	if sock and self.socket_wrapper then sock, err = self.socket_wrapper(sock, self); end
	if not sock then
		return nil, err;
	end
	sock:settimeout(0);
	-- todo: attempt to use a random port, fallback to 0
	-- Register before binding so that servfail() below can find the socket's
	-- server number and void it.
	self.socket[servernum] = sock;
	self.socketset[sock] = servernum;
	-- set{sock,peer}name can fail, eg because of local routing table
	-- if so, try the next server
	ok, err = sock:setsockname('*', 0);
	if not ok then return self:servfail(sock, err); end
	ok, err = sock:setpeername(peer, 53);
	if not ok then return self:servfail(sock, err); end
	return sock;
end
650
-- Forget and close a socket. `sock` may be a server number (key of
-- self.socket) or a socket object (key of self.socketset); both mapping
-- entries are removed. A socket object that was never registered is still
-- closed.
function resolver:voidsocket(sock)
	local conn = sock;
	if self.socket[sock] then
		-- Given a server number: resolve it to its socket object so that the
		-- close below acts on the socket rather than on the number.
		conn = self.socket[sock];
		self.socketset[conn] = nil;
		self.socket[sock] = nil;
	elseif self.socketset[sock] then
		self.socket[self.socketset[sock]] = nil;
		self.socketset[sock] = nil;
	end
	-- An unknown server number has no socket to close.
	if type(conn) ~= "number" then
		conn:close();
	end
end
661
-- Install a hook for new sockets: getsocket() calls func(sock, resolver) and
-- uses its results (sock or nil, err) in place of the raw socket.
resolver.socket_wrapper_set = function(self, func)  -- - - - - socket_wrapper_set
	self.socket_wrapper = func;
end
665
666
-- Close every socket opened by getsocket() and clear both socket maps.
-- pairs() is used because voidsocket() leaves holes in self.socket, at
-- which ipairs() would stop early and leak the later sockets.
function resolver:closeall ()    -- - - - - - - - - - - - - - - - - -  closeall
	local sockets = self.socket;
	if not sockets then return; end -- no socket has been opened yet
	for i, sock in pairs(sockets) do
		-- clearing existing fields during pairs() traversal is allowed
		sockets[i] = nil;
		self.socketset[sock] = nil;
		sock:close();
	end
end
674
675
-- Cache a received record under its class/type/name, or under '*' when the
-- question `type` was '*'. For other question types, the record is also
-- appended to an existing '*' entry for the name.
-- Records are de-duplicated by their parsed RDATA field (e.g. rr.a for A).
-- A record lacking that field (type without an RDATA parser) cannot be
-- de-duplicated and is not cached, rather than raising on a nil table key.
-- MX lists are flagged for lazy sorting by peek().
function resolver:remember(rr, type)    -- - - - - - - - - - - - - -  remember
	--print ('remember', type, rr.class, rr.type, rr.name)
	local qname, qtype, qclass = standardize(rr.name, rr.type, rr.class);
	-- initialise before the first get() below, which would index nil
	self.cache = self.cache or setmetatable({}, cache_metatable);

	if type ~= '*' then
		type = qtype;
		local all = get(self.cache, qclass, '*', qname);
		--print('remember all', all);
		if all then append(all, rr); end
	end

	local key = rr[qtype:lower()];
	if key == nil then return; end

	local rrs = get(self.cache, qclass, type, qname) or
		set(self.cache, qclass, type, qname, setmetatable({}, rrs_metatable));
	if not rrs[key] then
		rrs[key] = true;
		append(rrs, rr);
	end

	if type == 'MX' then self.unsorted[rrs] = true; end
end
697
698
-- Sort order for MX records: lower preference first, ties broken by host
-- name. Strict ordering (false for equal records), as table.sort requires.
local function comp_mx(a, b)    -- - - - - - - - - - - - - - - - - - - comp_mx
	if a.pref ~= b.pref then
		return a.pref < b.pref;
	end
	return a.mx < b.mx;
end
702
703
-- Look up cached records for qname/qtype/qclass (normalised by standardize).
-- If there is no entry for that type, follow a cached CNAME for the name, at
-- most n hops (n defaults to 3). Expired records are pruned first; an entry
-- left empty is removed from the cache and nil returned. Lists flagged in
-- self.unsorted are sorted with comp_mx on first read.
-- Returns the rrs list, or nil.
function resolver:peek (qname, qtype, qclass, n)    -- - - - - - - - - - - -  peek
	qname, qtype, qclass = standardize(qname, qtype, qclass);
	local rrs = get(self.cache, qclass, qtype, qname);
	if not rrs then
		if n then if n <= 0 then return end else n = 3 end
		rrs = get(self.cache, qclass, "CNAME", qname);
		if not (rrs and rrs[1]) then return end
		return self:peek(rrs[1].cname, qtype, qclass, n - 1);
	end
	-- NOTE(review): prune() returns no value, so the left operand of `or` is
	-- always falsy and this reduces to "is rrs empty after pruning?"; prune is
	-- still called for its side effect, but the `qtype == '*'` test is dead.
	if prune(rrs, socket.gettime()) and qtype == '*' or not next(rrs) then
		set(self.cache, qclass, qtype, qname, nil);
		return nil;
	end
	if self.unsorted[rrs] then table.sort (rrs, comp_mx); self.unsorted[rrs] = nil; end
	return rrs;
end
720
721
-- Without 'soft': discard the whole cache. With 'soft': keep the cache but
-- prune expired and zero-TTL records from every entry.
function resolver:purge(soft)    -- - - - - - - - - - - - - - - - - - -  purge
	if soft ~= 'soft' then
		self.cache = setmetatable({}, cache_metatable);
		return;
	end

	self.time = socket.gettime();
	for class_key, by_type in pairs(self.cache or {}) do
		for type_key, by_name in pairs(by_type) do
			for name_key, rrs in pairs(by_name) do
				prune(rrs, self.time, 'soft');
			end
		end
	end
end
734
735
-- Query qname/qtype/qclass (defaults 'A'/'IN', see standardize).
-- Returns the cached rrs list if peek() has a fresh answer. Otherwise it
-- sends the query and returns true (the answer is processed later by
-- receive()), or returns nil, err if no socket could be obtained.
-- When called from a coroutine, the coroutine is recorded in self.wanted.
-- If an identical query is already outstanding, it is only added as another
-- waiter.
-- NOTE(review): relies on self.wanted and self.best_server having been set
-- up by whichever constructor created this resolver.
function resolver:query(qname, qtype, qclass)    -- - - - - - - - - - -- query
	qname, qtype, qclass = standardize(qname, qtype, qclass)

	local co = coroutine.running();
	local q = get(self.wanted, qclass, qtype, qname);
	if co and q then
		-- We are already waiting for a reply to an identical query.
		set(self.wanted, qclass, qtype, qname, co, true);
		return true;
	end

	if not self.server then self:adddefaultnameservers(); end

	-- NOTE: the question is encoded before the cache lookup, so this work is
	-- wasted on a cache hit.
	local question = encodeQuestion(qname, qtype, qclass);
	local peek = self:peek (qname, qtype, qclass);
	if peek then return peek; end

	local header, id = encodeHeader();
	--print ('query  id', id, qclass, qtype, qname)
	local o = {
		packet = header..question,
		server = self.best_server,
		delay  = 1,
		retry  = socket.gettime() + self.delays[1]
	};

	-- remember the query, keyed by transaction id and encoded question so that
	-- decode() can match a response against it
	self.active[id] = self.active[id] or {};
	self.active[id][question] = o;

	-- remember which coroutine wants the answer
	if co then
		set(self.wanted, qclass, qtype, qname, co, true);
	end

	local conn, err = self:getsocket(o.server)
	if not conn then
		return nil, err;
	end
	conn:send (o.packet)

	-- Failover: each time self.timeout elapses while the answer is still
	-- wanted, void the current server's socket and resend via
	-- self.best_server. There is one attempt per configured server; after
	-- that the query is given up via self:cancel(). Returning self.timeout
	-- presumably re-arms the util.timer task (confirm timer semantics).
	-- NOTE(review): outside a coroutine co is nil, so the get(...) lookup
	-- below ends on a nil key and is always nil: no failover happens then.
	if timer and self.timeout then
		local num_servers = #self.server;
		local i = 1;
		timer.add_task(self.timeout, function ()
			if get(self.wanted, qclass, qtype, qname, co) then
				if i < num_servers then
					i = i + 1;
					self:servfail(conn);
					o.server = self.best_server;
					conn, err = self:getsocket(o.server);
					if conn then
						conn:send(o.packet);
						return self.timeout;
					end
				end
				-- Tried everything, failed
				self:cancel(qclass, qtype, qname);
			end
		end)
	end
	return true;
end
799
-- Handle a dead socket.
-- The socket is forgotten and closed. Every outstanding query that used its
-- server moves to the next server (wrapping to 1) and is resent there. A
-- query is dropped once it has been retried #self.server times. If the dead
-- server was self.best_server, best_server advances as well (wrapping to 1).
-- Returns the (sock, err) of the last resend's getsocket() call. If nothing
-- was resent, returns nil and the `err` argument.
-- Note that getsocket() may itself call servfail() again when binding or
-- connecting the replacement socket fails.
function resolver:servfail(sock, err)
	-- Resend all queries for this server

	-- NOTE(review): nil for an unregistered socket, in which case no query
	-- matches below and best_server is left alone.
	local num = self.socketset[sock]

	-- Socket is dead now
	-- (voidsocket returns nothing, so `sock` becomes nil here)
	sock = self:voidsocket(sock);

	-- Find all requests to the down server, and retry on the next server
	self.time = socket.gettime();
	for id,queries in pairs(self.active) do
		for question,o in pairs(queries) do
			if o.server == num then -- This request was to the broken server
				o.server = o.server + 1 -- Use next server
				if o.server > #self.server then
					o.server = 1;
				end

				o.retries = (o.retries or 0) + 1;
				if o.retries >= #self.server then
					--print('timeout');
					queries[question] = nil;
				else
					sock, err = self:getsocket(o.server);
					if sock then sock:send(o.packet); end
				end
			end
		end
		if next(queries) == nil then
			self.active[id] = nil;
		end
	end

	if num == self.best_server then
		self.best_server = self.best_server + 1;
		if self.best_server > #self.server then
			-- Exhausted all servers, try first again
			self.best_server = 1;
		end
	end
	return sock, err;
end
842
--- Store a new timeout value on this resolver (as self.timeout).
-- NOTE(review): presumably seconds, going by the parameter name. The query
-- retry callback appears to return self.timeout as its reschedule interval;
-- confirm the units there.
-- @param seconds the value to store in self.timeout
function resolver:settimeout(seconds)
	self["timeout"] = seconds;
end
846
--- Read and process one packet from each readable resolver socket.
-- For each socket in rset that belongs to this resolver (is in
-- self.socketset), one packet is read and decoded. If it answers an active
-- query:
--   * answer records whose name ends with the question name are cached;
--   * the query is retired (all sockets are closed once nothing is active);
--   * suspended coroutines waiting on that question are resumed.
-- @param rset optional table of sockets to read from (defaults to self.socket)
-- @return the result of the last self:decode() call, or nil if no packet
--   was read
function resolver:receive(rset)    -- - - - - - - - - - - - - - - - -  receive
	--print('receive');  print(self.socket);
	self.time = socket.gettime();
	rset = rset or self.socket;

	local response;
	for i,sock in pairs(rset) do

		if self.socketset[sock] then
			local packet = sock:receive();
			if packet then
				response = self:decode(packet);
				if response and self.active[response.header.id]
					and self.active[response.header.id][response.question.raw] then
					--print('received response');
					--self.print(response);

					-- The question may be absent from a response; read it once
					-- and guard every use instead of indexing it blindly.
					local q = response.question[1];

					-- Only cache answer records whose name ends with the
					-- question name.
					if q then
						for j,rr in pairs(response.answer) do
							if rr.name:sub(-#q.name, -1) == q.name then
								self:remember(rr, q.type)
							end
						end
					end

					-- retire the query
					local queries = self.active[response.header.id];
					queries[response.question.raw] = nil;

					if not next(queries) then self.active[response.header.id] = nil; end
					if not next(self.active) then self:closeall(); end

					-- was the query on the wanted list?
					if q then
						local cos = get(self.wanted, q.class, q.type, q.name);
						if cos then
							for co in pairs(cos) do
								if coroutine.status(co) == "suspended" then coroutine.resume(co); end
							end
							set(self.wanted, q.class, q.type, q.name, nil);
						end
					end
				end

			end
		end
	end

	return response;
end
894
895
--- Process one already-received DNS packet.
-- If the decoded response answers an active query:
--   * every answer record is cached;
--   * the query is retired (all sockets are closed once nothing is active);
--   * suspended coroutines waiting on that question are resumed.
-- NOTE(review): unlike resolver:receive, no check is made that an answer
-- record's name ends with the question name before caching it. Confirm
-- whether that difference is intended.
-- @param sock the socket the packet arrived on (not used here)
-- @param packet raw packet data, passed to self:decode
-- @param force passed through to self:decode
-- @return whatever self:decode returned
function resolver:feed(sock, packet, force)
	--print('receive'); print(self.socket);
	self.time = socket.gettime();

	local response = self:decode(packet, force);
	if response and self.active[response.header.id]
		and self.active[response.header.id][response.question.raw] then
		--print('received response');
		--self.print(response);

		-- The question may be absent; check it once and reuse it below rather
		-- than indexing response.question[1] unguarded in the answer loop.
		local q = response.question[1];

		if q then
			for j,rr in pairs(response.answer) do
				self:remember(rr, q.type);
			end
		end

		-- retire the query
		local queries = self.active[response.header.id];
		queries[response.question.raw] = nil;
		if not next(queries) then self.active[response.header.id] = nil; end
		if not next(self.active) then self:closeall(); end

		-- was the query on the wanted list?
		if q then
			local cos = get(self.wanted, q.class, q.type, q.name);
			if cos then
				for co in pairs(cos) do
					if coroutine.status(co) == "suspended" then coroutine.resume(co); end
				end
				set(self.wanted, q.class, q.type, q.name, nil);
			end
		end
	end

	return response;
end
931
--- Wake every coroutine waiting on a question, then forget the waiters.
-- Suspended coroutines registered in self.wanted under (qclass, qtype, qname)
-- are resumed, and the entry is then cleared. Does nothing if nobody waits.
function resolver:cancel(qclass, qtype, qname)
	local waiters = get(self.wanted, qclass, qtype, qname);
	if not waiters then return; end

	for co in pairs(waiters) do
		if coroutine.status(co) == "suspended" then
			coroutine.resume(co);
		end
	end
	set(self.wanted, qclass, qtype, qname, nil);
end
941
--- Run one step of the blocking resolver loop.
-- Drains pending responses via self:receive(). It then resends every active
-- query whose retry time has passed to the next server; after each full lap
-- of the server list it moves to the next entry in self.delays. A query that
-- runs past the end of self.delays is dropped.
-- NOTE(review): dropped queries are not passed to self:cancel(), so waiters
-- in self.wanted are not resumed here.
-- @return true while queries remain active, nil once none are left
function resolver:pulse()    -- - - - - - - - - - - - - - - - - - - - -  pulse
	--print(':pulse');
	repeat until not self:receive();
	if not next(self.active) then return nil; end

	self.time = socket.gettime();
	for id, pending in pairs(self.active) do
		for question, o in pairs(pending) do
			if o.retry <= self.time then
				-- Rotate to the next server; after a full lap, step to the
				-- next delay.
				local srv = o.server + 1;
				if srv > #self.server then
					srv = 1;
					o.delay = o.delay + 1;
				end
				o.server = srv;

				if o.delay <= #self.delays then
					--print('retry', o.server, o.delay);
					local s = self.socket[o.server];
					if s then s:send(o.packet); end
					o.retry = self.time + self.delays[o.delay];
				else
					--print('timeout');
					pending[question] = nil;
					if not next(pending) then self.active[id] = nil; end
					if not next(self.active) then return nil; end
				end
			end
		end
	end

	return next(self.active) and true or nil;
end
976
977
--- Resolve a question synchronously.
-- Sends the query, then alternates pulse() with socket.select() on the
-- resolver's sockets (timeout argument 4) until pulse() reports nothing
-- active. Finally it returns what peek() finds for the question.
function resolver:lookup(qname, qtype, qclass)    -- - - - - - - - - -  lookup
	self:query(qname, qtype, qclass);
	while self:pulse() do
		-- Snapshot the resolver's sockets into a fresh array for select().
		local readable = {};
		for n, s in ipairs(self.socket) do
			readable[n] = s;
		end
		socket.select(readable, nil, 4);
	end
	--print(self.cache);
	return self:peek(qname, qtype, qclass);
end
990
--- Return a cached answer if there is one, otherwise start a query.
-- NOTE(review): the handler argument is accepted but never used.
-- @return the first value of self:peek() if truthy, else the first value
--   returned by self:query()
function resolver:lookupex(handler, qname, qtype, qclass)    -- - - - - - - - - -  lookup
	local cached = self:peek(qname, qtype, qclass);
	if cached then
		return cached;
	end
	return (self:query(qname, qtype, qclass));
end
994
--- Reverse-resolve a dotted-quad IPv4 address with a PTR lookup.
-- The four octets are reversed under "in-addr.arpa.". The lookup goes
-- through this resolver's own lookup(), not the module's default resolver,
-- so it also works for separately created resolver instances.
-- NOTE(review): input without four dot-separated digit groups (e.g. IPv6) is
-- passed through unchanged. The pattern is unanchored, so it can also match
-- inside a longer string.
function resolver:tohostname(ip)
	-- gsub also returns a match count; keep only the rewritten string.
	local ptr_name = ip:gsub("(%d+)%.(%d+)%.(%d+)%.(%d+)", "%4.%3.%2.%1.in-addr.arpa.");
	return self:lookup(ptr_name, "PTR");
end
998
999 --print ---------------------------------------------------------------- print
1000
1001
-- Human-readable names for header field values, looked up by hint() when
-- printing a response. Numeric values index each sub-table, starting at 0.
-- type/class reuse the module's dns.type and dns.class tables.
local hints = {    -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
	qr = { [0]='query', 'response' },
	opcode = { [0]='query', 'inverse query', 'server status request' },
	aa = { [0]='non-authoritative', 'authoritative' },
	tc = { [0]='complete', 'truncated' },
	rd = { [0]='recursion not desired', 'recursion desired' },
	ra = { [0]='recursion not available', 'recursion available' },
	z  = { [0]='(reserved)' },
	-- RCODE values 0-5 as defined in RFC 1035 section 4.1.1 (5 = Refused).
	rcode = { [0]='no error', 'format error', 'server failure', 'name error', 'not implemented', 'refused' },

	type = dns.type,
	class = dns.class
};
1015
1016
--- Look up a readable name for field s of p.
-- @param p table holding the field (a response header or resource record)
-- @param s field name; also the key into hints
-- @return the matching hint string, or '' if there is none
local function hint(p, s)    -- - - - - - - - - - - - - - - - - - - - - - hint
	local names = hints[s];
	if not names then
		return '';
	end
	return names[p[s]] or '';
end
1020
1021
--- Dump a decoded response to stdout, for debugging.
-- Prints every header field, each question, and each record in the answer,
-- authority and additional sections, annotating known values via hint().
-- The fixed field lists and the record sections are walked with ipairs so the
-- output order is deterministic (pairs() order is unspecified). Extra record
-- fields are still listed via pairs(), in unspecified order.
function resolver.print(response)    -- - - - - - - - - - - - - resolver.print
	local header_fields = { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z',
		'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' };
	for _, s in ipairs(header_fields) do
		print( string.format('%-30s', 'header.'..s), response.header[s], hint(response.header, s) );
	end

	for i,question in ipairs(response.question) do
		print(string.format ('question[%i].name         ', i), question.name);
		print(string.format ('question[%i].type         ', i), question.type);
		print(string.format ('question[%i].class        ', i), question.class);
	end

	-- Fields handled by the fixed list below (rdata is deliberately skipped);
	-- any other key in a record is printed generically afterwards.
	local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 };
	local rr_fields = { 'name', 'type', 'class', 'ttl', 'rdlength' };
	for _, section in ipairs({'answer', 'authority', 'additional'}) do
		for i,rr in ipairs(response[section] or {}) do
			for _, t in ipairs(rr_fields) do
				local label = string.format('%s[%i].%s', section, i, t);
				print(string.format('%-30s', label), rr[t], hint(rr, t));
			end
			for j,v in pairs(rr) do
				if not common[j] then
					-- tostring() so non-string keys cannot break %s.
					local label = string.format('%s[%i].%s', section, i, tostring(j));
					print(string.format('%-30s  %s', label, tostring(v)));
				end
			end
		end
	end
end
1051
1052
1053 -- module api ------------------------------------------------------ module api
1054
1055
--- Create a new resolver instance with empty state.
-- The instance uses resolver as its metatable. Its cache table uses
-- cache_metatable, and its unsorted table has weak keys and values.
-- best_server starts at the first configured server.
function dns.resolver ()    -- - - - - - - - - - - - - - - - - - - - - resolver
	-- this function seems to be redundant with resolver.new ()
	local instance = {
		active = {},
		cache = {},
		unsorted = {},
		wanted = {},
		best_server = 1,
	};
	setmetatable(instance, resolver);
	setmetatable(instance.cache, cache_metatable);
	setmetatable(instance.unsorted, { __mode = 'kv' });
	return instance;
end
1065
-- The module-level default resolver, also exposed as dns._resolver. Each
-- dns.* function below forwards its arguments to the same-named method on it.
local _resolver = dns.resolver();
dns._resolver = _resolver;

-- Blocking lookup on the default resolver (see resolver:lookup).
dns.lookup = function(...)
	return _resolver:lookup(...);
end

-- Reverse (PTR) lookup on the default resolver (see resolver:tohostname).
dns.tohostname = function(...)
	return _resolver:tohostname(...);
end

-- Forwards to the default resolver's purge method.
dns.purge = function(...)
	return _resolver:purge(...);
end

-- Forwards to the default resolver's peek method.
dns.peek = function(...)
	return _resolver:peek(...);
end

-- Forwards to the default resolver's query method.
dns.query = function(...)
	return _resolver:query(...);
end

-- Feed a received packet to the default resolver (see resolver:feed).
dns.feed = function(...)
	return _resolver:feed(...);
end

-- Wake waiters for a question on the default resolver (see resolver:cancel).
dns.cancel = function(...)
	return _resolver:cancel(...);
end

-- Set the default resolver's timeout (see resolver:settimeout).
dns.settimeout = function(...)
	return _resolver:settimeout(...);
end

-- Return the default resolver's cache table itself (not a copy).
dns.cache = function()
	return _resolver.cache;
end

-- Forwards to the default resolver's socket_wrapper_set method.
dns.socket_wrapper_set = function(...)
	return _resolver:socket_wrapper_set(...);
end

return dns;