net.server_select: Remove some debugging code.
[prosody.git] / net / dns.lua
1 -- Prosody IM
2 -- This file is included with Prosody IM. It has modifications,
3 -- which are hereby placed in the public domain.
4
5 -- public domain 20080404 lua@ztact.com
6
7
8 -- todo: quick (default) header generation
9 -- todo: nxdomain, error handling
10 -- todo: cache results of encodeName
11
12
13 -- reference: http://tools.ietf.org/html/rfc1035
14 -- reference: http://tools.ietf.org/html/rfc1876 (LOC)
15
16
17 local socket = require "socket";
18 local ztact = require "util.ztact";
19 local _, windows = pcall(require, "util.windows");
20 local is_windows = (_ and windows) or os.getenv("WINDIR");
21
22 local coroutine, io, math, string, table =
23       coroutine, io, math, string, table;
24
25 local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack =
26       ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack;
27
28 local get, set = ztact.get, ztact.set;
29
30
31 -------------------------------------------------- module dns
32 module('dns')
33 local dns = _M;
34
35
36 -- dns type & class codes ------------------------------ dns type & class codes
37
38
39 local append = table.insert
40
41
42 local function highbyte(i)    -- - - - - - - - - - - - - - - - - - -  highbyte
43         return (i-(i%0x100))/0x100;
44 end
45
46
47 local function augment (t)    -- - - - - - - - - - - - - - - - - - - -  augment
48         local a = {};
49         for i,s in pairs(t) do
50                 a[i] = s;
51                 a[s] = s;
52                 a[string.lower(s)] = s;
53         end
54         return a;
55 end
56
57
58 local function encode (t)    -- - - - - - - - - - - - - - - - - - - - -  encode
59         local code = {};
60         for i,s in pairs(t) do
61                 local word = string.char(highbyte(i), i%0x100);
62                 code[i] = word;
63                 code[s] = word;
64                 code[string.lower(s)] = word;
65         end
66         return code;
67 end
68
69
70 dns.types = {
71         'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS',
72         'PTR', 'HINFO', 'MINFO', 'MX', 'TXT',
73         [ 28] = 'AAAA', [ 29] = 'LOC',   [ 33] = 'SRV',
74         [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' };
75
76
77 dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' };
78
79
80 dns.type      = augment (dns.types);
81 dns.class     = augment (dns.classes);
82 dns.typecode  = encode  (dns.types);
83 dns.classcode = encode  (dns.classes);
84
85
86
87 local function standardize(qname, qtype, qclass)    -- - - - - - - standardize
88         if string.byte(qname, -1) ~= 0x2E then qname = qname..'.';  end
89         qname = string.lower(qname);
90         return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN'];
91 end
92
93
94 local function prune(rrs, time, soft)    -- - - - - - - - - - - - - - -  prune
95         time = time or socket.gettime();
96         for i,rr in pairs(rrs) do
97                 if rr.tod then
98                         -- rr.tod = rr.tod - 50    -- accelerated decripitude
99                         rr.ttl = math.floor(rr.tod - time);
100                         if rr.ttl <= 0 then
101                                 table.remove(rrs, i);
102                                 return prune(rrs, time, soft); -- Re-iterate
103                         end
104                 elseif soft == 'soft' then    -- What is this?  I forget!
105                         assert(rr.ttl == 0);
106                         rrs[i] = nil;
107                 end
108         end
109 end
110
111
112 -- metatables & co. ------------------------------------------ metatables & co.
113
114
115 local resolver = {};
116 resolver.__index = resolver;
117
118
119 local SRV_tostring;
120
121
122 local rr_metatable = {};   -- - - - - - - - - - - - - - - - - - -  rr_metatable
123 function rr_metatable.__tostring(rr)
124         local s0 = string.format('%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name);
125         local s1 = '';
126         if rr.type == 'A' then
127                 s1 = ' '..rr.a;
128         elseif rr.type == 'MX' then
129                 s1 = string.format(' %2i %s', rr.pref, rr.mx);
130         elseif rr.type == 'CNAME' then
131                 s1 = ' '..rr.cname;
132         elseif rr.type == 'LOC' then
133                 s1 = ' '..resolver.LOC_tostring(rr);
134         elseif rr.type == 'NS' then
135                 s1 = ' '..rr.ns;
136         elseif rr.type == 'SRV' then
137                 s1 = ' '..SRV_tostring(rr);
138         elseif rr.type == 'TXT' then
139                 s1 = ' '..rr.txt;
140         else
141                 s1 = ' <UNKNOWN RDATA TYPE>';
142         end
143         return s0..s1;
144 end
145
146
147 local rrs_metatable = {};    -- - - - - - - - - - - - - - - - - -  rrs_metatable
148 function rrs_metatable.__tostring(rrs)
149         local t = {};
150         for i,rr in pairs(rrs) do
151                 append(t, tostring(rr)..'\n');
152         end
153         return table.concat(t);
154 end
155
156
157 local cache_metatable = {};    -- - - - - - - - - - - - - - - -  cache_metatable
158 function cache_metatable.__tostring(cache)
159         local time = socket.gettime();
160         local t = {};
161         for class,types in pairs(cache) do
162                 for type,names in pairs(types) do
163                         for name,rrs in pairs(names) do
164                                 prune(rrs, time);
165                                 append(t, tostring(rrs));
166                         end
167                 end
168         end
169         return table.concat(t);
170 end
171
172
173 function resolver:new()    -- - - - - - - - - - - - - - - - - - - - - resolver
174         local r = { active = {}, cache = {}, unsorted = {} };
175         setmetatable(r, resolver);
176         setmetatable(r.cache, cache_metatable);
177         setmetatable(r.unsorted, { __mode = 'kv' });
178         return r;
179 end
180
181
182 -- packet layer -------------------------------------------------- packet layer
183
184
185 function dns.random(...)    -- - - - - - - - - - - - - - - - - - -  dns.random
186         math.randomseed(math.floor(10000*socket.gettime()));
187         dns.random = math.random;
188         return dns.random(...);
189 end
190
191
192 local function encodeHeader(o)    -- - - - - - - - - - - - - - -  encodeHeader
193         o = o or {};
194         o.id = o.id or dns.random(0, 0xffff); -- 16b    (random) id
195
196         o.rd = o.rd or 1;               --  1b  1 recursion desired
197         o.tc = o.tc or 0;               --  1b  1 truncated response
198         o.aa = o.aa or 0;               --  1b  1 authoritative response
199         o.opcode = o.opcode or 0;       --  4b  0 query
200                                 --  1 inverse query
201                                 --      2 server status request
202                                 --      3-15 reserved
203         o.qr = o.qr or 0;               --  1b  0 query, 1 response
204
205         o.rcode = o.rcode or 0; --  4b  0 no error
206                                 --      1 format error
207                                 --      2 server failure
208                                 --      3 name error
209                                 --      4 not implemented
210                                 --      5 refused
211                                 --      6-15 reserved
212         o.z = o.z  or 0;                --  3b  0 resvered
213         o.ra = o.ra or 0;               --  1b  1 recursion available
214
215         o.qdcount = o.qdcount or 1;     -- 16b  number of question RRs
216         o.ancount = o.ancount or 0;     -- 16b  number of answers RRs
217         o.nscount = o.nscount or 0;     -- 16b  number of nameservers RRs
218         o.arcount = o.arcount or 0;     -- 16b  number of additional RRs
219
220         -- string.char() rounds, so prevent roundup with -0.4999
221         local header = string.char(
222                 highbyte(o.id), o.id %0x100,
223                 o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr,
224                 o.rcode + 16*o.z + 128*o.ra,
225                 highbyte(o.qdcount),  o.qdcount %0x100,
226                 highbyte(o.ancount),  o.ancount %0x100,
227                 highbyte(o.nscount),  o.nscount %0x100,
228                 highbyte(o.arcount),  o.arcount %0x100
229         );
230
231         return header, o.id;
232 end
233
234
235 local function encodeName(name)    -- - - - - - - - - - - - - - - - encodeName
236         local t = {};
237         for part in string.gmatch(name, '[^.]+') do
238                 append(t, string.char(string.len(part)));
239                 append(t, part);
240         end
241         append(t, string.char(0));
242         return table.concat(t);
243 end
244
245
246 local function encodeQuestion(qname, qtype, qclass)    -- - - - encodeQuestion
247         qname  = encodeName(qname);
248         qtype  = dns.typecode[qtype or 'a'];
249         qclass = dns.classcode[qclass or 'in'];
250         return qname..qtype..qclass;
251 end
252
253
254 function resolver:byte(len)    -- - - - - - - - - - - - - - - - - - - - - byte
255         len = len or 1;
256         local offset = self.offset;
257         local last = offset + len - 1;
258         if last > #self.packet then
259                 error(string.format('out of bounds: %i>%i', last, #self.packet));
260         end
261         self.offset = offset + len;
262         return string.byte(self.packet, offset, last);
263 end
264
265
266 function resolver:word()    -- - - - - - - - - - - - - - - - - - - - - -  word
267         local b1, b2 = self:byte(2);
268         return 0x100*b1 + b2;
269 end
270
271
272 function resolver:dword ()    -- - - - - - - - - - - - - - - - - - - - -  dword
273         local b1, b2, b3, b4 = self:byte(4);
274         --print('dword', b1, b2, b3, b4);
275         return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4;
276 end
277
278
279 function resolver:sub(len)    -- - - - - - - - - - - - - - - - - - - - - - sub
280         len = len or 1;
281         local s = string.sub(self.packet, self.offset, self.offset + len - 1);
282         self.offset = self.offset + len;
283         return s;
284 end
285
286
287 function resolver:header(force)    -- - - - - - - - - - - - - - - - - - header
288         local id = self:word();
289         --print(string.format(':header  id  %x', id));
290         if not self.active[id] and not force then return nil; end
291
292         local h = { id = id };
293
294         local b1, b2 = self:byte(2);
295
296         h.rd      = b1 %2;
297         h.tc      = b1 /2%2;
298         h.aa      = b1 /4%2;
299         h.opcode  = b1 /8%16;
300         h.qr      = b1 /128;
301
302         h.rcode   = b2 %16;
303         h.z       = b2 /16%8;
304         h.ra      = b2 /128;
305
306         h.qdcount = self:word();
307         h.ancount = self:word();
308         h.nscount = self:word();
309         h.arcount = self:word();
310
311         for k,v in pairs(h) do h[k] = v-v%1; end
312
313         return h;
314 end
315
316
317 function resolver:name()    -- - - - - - - - - - - - - - - - - - - - - -  name
318         local remember, pointers = nil, 0;
319         local len = self:byte();
320         local n = {};
321         while len > 0 do
322                 if len >= 0xc0 then    -- name is "compressed"
323                         pointers = pointers + 1;
324                         if pointers >= 20 then error('dns error: 20 pointers'); end;
325                         local offset = ((len-0xc0)*0x100) + self:byte();
326                         remember = remember or self.offset;
327                         self.offset = offset + 1;    -- +1 for lua
328                 else    -- name is not compressed
329                         append(n, self:sub(len)..'.');
330                 end
331                 len = self:byte();
332         end
333         self.offset = remember or self.offset;
334         return table.concat(n);
335 end
336
337
338 function resolver:question()    -- - - - - - - - - - - - - - - - - -  question
339         local q = {};
340         q.name  = self:name();
341         q.type  = dns.type[self:word()];
342         q.class = dns.class[self:word()];
343         return q;
344 end
345
346
347 function resolver:A(rr)    -- - - - - - - - - - - - - - - - - - - - - - - -  A
348         local b1, b2, b3, b4 = self:byte(4);
349         rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4);
350 end
351
352
353 function resolver:CNAME(rr)    -- - - - - - - - - - - - - - - - - - - -  CNAME
354         rr.cname = self:name();
355 end
356
357
358 function resolver:MX(rr)    -- - - - - - - - - - - - - - - - - - - - - - -  MX
359         rr.pref = self:word();
360         rr.mx   = self:name();
361 end
362
363
364 function resolver:LOC_nibble_power()    -- - - - - - - - - -  LOC_nibble_power
365         local b = self:byte();
366         --print('nibbles', ((b-(b%0x10))/0x10), (b%0x10));
367         return ((b-(b%0x10))/0x10) * (10^(b%0x10));
368 end
369
370
371 function resolver:LOC(rr)    -- - - - - - - - - - - - - - - - - - - - - -  LOC
372         rr.version = self:byte();
373         if rr.version == 0 then
374                 rr.loc           = rr.loc or {};
375                 rr.loc.size      = self:LOC_nibble_power();
376                 rr.loc.horiz_pre = self:LOC_nibble_power();
377                 rr.loc.vert_pre  = self:LOC_nibble_power();
378                 rr.loc.latitude  = self:dword();
379                 rr.loc.longitude = self:dword();
380                 rr.loc.altitude  = self:dword();
381         end
382 end
383
384
385 local function LOC_tostring_degrees(f, pos, neg)    -- - - - - - - - - - - - -
386         f = f - 0x80000000;
387         if f < 0 then pos = neg; f = -f; end
388         local deg, min, msec;
389         msec = f%60000;
390         f    = (f-msec)/60000;
391         min  = f%60;
392         deg = (f-min)/60;
393         return string.format('%3d %2d %2.3f %s', deg, min, msec/1000, pos);
394 end
395
396
397 function resolver.LOC_tostring(rr)    -- - - - - - - - - - - - -  LOC_tostring
398         local t = {};
399
400         --[[
401         for k,name in pairs { 'size', 'horiz_pre', 'vert_pre', 'latitude', 'longitude', 'altitude' } do
402                 append(t, string.format('%4s%-10s: %12.0f\n', '', name, rr.loc[name]));
403         end
404         --]]
405
406         append(t, string.format(
407                 '%s    %s    %.2fm %.2fm %.2fm %.2fm',
408                 LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'),
409                 LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'),
410                 (rr.loc.altitude - 10000000) / 100,
411                 rr.loc.size / 100,
412                 rr.loc.horiz_pre / 100,
413                 rr.loc.vert_pre / 100
414         ));
415
416         return table.concat(t);
417 end
418
419
420 function resolver:NS(rr)    -- - - - - - - - - - - - - - - - - - - - - - -  NS
421         rr.ns = self:name();
422 end
423
424
425 function resolver:SOA(rr)    -- - - - - - - - - - - - - - - - - - - - - -  SOA
426 end
427
428
429 function resolver:SRV(rr)    -- - - - - - - - - - - - - - - - - - - - - -  SRV
430           rr.srv = {};
431           rr.srv.priority = self:word();
432           rr.srv.weight   = self:word();
433           rr.srv.port     = self:word();
434           rr.srv.target   = self:name();
435 end
436
437
438 function SRV_tostring(rr)    -- - - - - - - - - - - - - - - - - - SRV_tostring
439         local s = rr.srv;
440         return string.format( '%5d %5d %5d %s', s.priority, s.weight, s.port, s.target );
441 end
442
443
444 function resolver:TXT(rr)    -- - - - - - - - - - - - - - - - - - - - - -  TXT
445         rr.txt = self:sub (rr.rdlength);
446 end
447
448
449 function resolver:rr()    -- - - - - - - - - - - - - - - - - - - - - - - -  rr
450         local rr = {};
451         setmetatable(rr, rr_metatable);
452         rr.name     = self:name(self);
453         rr.type     = dns.type[self:word()] or rr.type;
454         rr.class    = dns.class[self:word()] or rr.class;
455         rr.ttl      = 0x10000*self:word() + self:word();
456         rr.rdlength = self:word();
457
458         if rr.ttl <= 0 then
459                 rr.tod = self.time + 30;
460         else
461                 rr.tod = self.time + rr.ttl;
462         end
463
464         local remember = self.offset;
465         local rr_parser = self[dns.type[rr.type]];
466         if rr_parser then rr_parser(self, rr); end
467         self.offset = remember;
468         rr.rdata = self:sub(rr.rdlength);
469         return rr;
470 end
471
472
473 function resolver:rrs (count)    -- - - - - - - - - - - - - - - - - - - - - rrs
474         local rrs = {};
475         for i = 1,count do append(rrs, self:rr()); end
476         return rrs;
477 end
478
479
480 function resolver:decode(packet, force)    -- - - - - - - - - - - - - - decode
481         self.packet, self.offset = packet, 1;
482         local header = self:header(force);
483         if not header then return nil; end
484         local response = { header = header };
485
486         response.question = {};
487         local offset = self.offset;
488         for i = 1,response.header.qdcount do
489                 append(response.question, self:question());
490         end
491         response.question.raw = string.sub(self.packet, offset, self.offset - 1);
492
493         if not force then
494                 if not self.active[response.header.id] or not self.active[response.header.id][response.question.raw] then
495                         return nil;
496                 end
497         end
498
499         response.answer     = self:rrs(response.header.ancount);
500         response.authority  = self:rrs(response.header.nscount);
501         response.additional = self:rrs(response.header.arcount);
502
503         return response;
504 end
505
506
507 -- socket layer -------------------------------------------------- socket layer
508
509
510 resolver.delays = { 1, 3 };
511
512
513 function resolver:addnameserver(address)    -- - - - - - - - - - addnameserver
514         self.server = self.server or {};
515         append(self.server, address);
516 end
517
518
519 function resolver:setnameserver(address)    -- - - - - - - - - - setnameserver
520         self.server = {};
521         self:addnameserver(address);
522 end
523
524
525 function resolver:adddefaultnameservers()    -- - - - -  adddefaultnameservers
526         if is_windows then
527                 if windows then
528                         for _, server in ipairs(windows.get_nameservers()) do
529                                 self:addnameserver(server);
530                         end
531                 end
532                 if not self.server or #self.server == 0 then
533                         -- TODO log warning about no nameservers, adding opendns servers as fallback
534                         self:addnameserver("208.67.222.222");
535                         self:addnameserver("208.67.220.220");
536                 end
537         else -- posix
538                 local resolv_conf = io.open("/etc/resolv.conf");
539                 if resolv_conf then
540                         for line in resolv_conf:lines() do
541                                 line = line:gsub("#.*$", "")
542                                         :match('^%s*nameserver%s+(.*)%s*$');
543                                 if line then
544                                         line:gsub("%f[%d.](%d+%.%d+%.%d+%.%d+)%f[^%d.]", function (address)
545                                                 self:addnameserver(address)
546                                         end);
547                                 end
548                         end
549                 end
550                 if not self.server or #self.server == 0 then
551                         -- TODO log warning about no nameservers, adding localhost as the default nameserver
552                         self:addnameserver("127.0.0.1");
553                 end
554         end
555 end
556
557
558 function resolver:getsocket(servernum)    -- - - - - - - - - - - - - getsocket
559         self.socket = self.socket or {};
560         self.socketset = self.socketset or {};
561
562         local sock = self.socket[servernum];
563         if sock then return sock; end
564
565         sock = socket.udp();
566         if self.socket_wrapper then sock = self.socket_wrapper(sock, self); end
567         sock:settimeout(0);
568         -- todo: attempt to use a random port, fallback to 0
569         sock:setsockname('*', 0);
570         sock:setpeername(self.server[servernum], 53);
571         self.socket[servernum] = sock;
572         self.socketset[sock] = servernum;
573         return sock;
574 end
575
576 function resolver:voidsocket(sock)
577         if self.socket[sock] then
578                 self.socketset[self.socket[sock]] = nil;
579                 self.socket[sock] = nil;
580         elseif self.socketset[sock] then
581                 self.socket[self.socketset[sock]] = nil;
582                 self.socketset[sock] = nil;
583         end
584 end
585
586 function resolver:socket_wrapper_set(func)  -- - - - - - - socket_wrapper_set
587         self.socket_wrapper = func;
588 end
589
590
591 function resolver:closeall ()    -- - - - - - - - - - - - - - - - - -  closeall
592         for i,sock in ipairs(self.socket) do
593                 self.socket[i] = nil;
594                 self.socketset[sock] = nil;
595                 sock:close();
596         end
597 end
598
599
600 function resolver:remember(rr, type)    -- - - - - - - - - - - - - -  remember
601         --print ('remember', type, rr.class, rr.type, rr.name)
602         local qname, qtype, qclass = standardize(rr.name, rr.type, rr.class);
603
604         if type ~= '*' then
605                 type = qtype;
606                 local all = get(self.cache, qclass, '*', qname);
607                 --print('remember all', all);
608                 if all then append(all, rr); end
609         end
610
611         self.cache = self.cache or setmetatable({}, cache_metatable);
612         local rrs = get(self.cache, qclass, type, qname) or
613                 set(self.cache, qclass, type, qname, setmetatable({}, rrs_metatable));
614         append(rrs, rr);
615
616         if type == 'MX' then self.unsorted[rrs] = true; end
617 end
618
619
620 local function comp_mx(a, b)    -- - - - - - - - - - - - - - - - - - - comp_mx
621         return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref);
622 end
623
624
625 function resolver:peek (qname, qtype, qclass)    -- - - - - - - - - - - -  peek
626         qname, qtype, qclass = standardize(qname, qtype, qclass);
627         local rrs = get(self.cache, qclass, qtype, qname);
628         if not rrs then return nil; end
629         if prune(rrs, socket.gettime()) and qtype == '*' or not next(rrs) then
630                 set(self.cache, qclass, qtype, qname, nil);
631                 return nil;
632         end
633         if self.unsorted[rrs] then table.sort (rrs, comp_mx); end
634         return rrs;
635 end
636
637
638 function resolver:purge(soft)    -- - - - - - - - - - - - - - - - - - -  purge
639         if soft == 'soft' then
640                 self.time = socket.gettime();
641                 for class,types in pairs(self.cache or {}) do
642                         for type,names in pairs(types) do
643                                 for name,rrs in pairs(names) do
644                                         prune(rrs, self.time, 'soft')
645                                 end
646                         end
647                 end
648         else self.cache = {}; end
649 end
650
651
652 function resolver:query(qname, qtype, qclass)    -- - - - - - - - - - -- query
653         qname, qtype, qclass = standardize(qname, qtype, qclass)
654
655         if not self.server then self:adddefaultnameservers(); end
656
657         local question = encodeQuestion(qname, qtype, qclass);
658         local peek = self:peek (qname, qtype, qclass);
659         if peek then return peek; end
660
661         local header, id = encodeHeader();
662         --print ('query  id', id, qclass, qtype, qname)
663         local o = {
664                 packet = header..question,
665                 server = self.best_server,
666                 delay  = 1,
667                 retry  = socket.gettime() + self.delays[1]
668         };
669
670   -- remember the query
671         self.active[id] = self.active[id] or {};
672         self.active[id][question] = o;
673
674   -- remember which coroutine wants the answer
675         local co = coroutine.running();
676         if co then
677                 set(self.wanted, qclass, qtype, qname, co, true);
678                 --set(self.yielded, co, qclass, qtype, qname, true);
679         end
680
681         self:getsocket (o.server):send (o.packet)
682 end
683
684 function resolver:servfail(sock)
685         -- Resend all queries for this server
686
687         local num = self.socketset[sock]
688
689         -- Socket is dead now
690         self:voidsocket(sock);
691
692         -- Find all requests to the down server, and retry on the next server
693         self.time = socket.gettime();
694         for id,queries in pairs(self.active) do
695                 for question,o in pairs(queries) do
696                         if o.server == num then -- This request was to the broken server
697                                 o.server = o.server + 1 -- Use next server
698                                 if o.server > #self.server then
699                                         o.server = 1;
700                                 end
701
702                                 o.retries = (o.retries or 0) + 1;
703                                 if o.retries >= #self.server then
704                                         --print('timeout');
705                                         queries[question] = nil;
706                                 else
707                                         local _a = self:getsocket(o.server);
708                                         if _a then _a:send(o.packet); end
709                                 end
710                         end
711                 end
712         end
713    
714         if num == self.best_server then
715                 self.best_server = self.best_server + 1;
716                 if self.best_server > #self.server then
717                         -- Exhausted all servers, try first again
718                         self.best_server = 1;
719                 end
720         end
721 end
722
723 function resolver:receive(rset)    -- - - - - - - - - - - - - - - - -  receive
724         --print('receive');  print(self.socket);
725         self.time = socket.gettime();
726         rset = rset or self.socket;
727
728         local response;
729         for i,sock in pairs(rset) do
730
731                 if self.socketset[sock] then
732                         local packet = sock:receive();
733                         if packet then
734                                 response = self:decode(packet);
735                                 if response and self.active[response.header.id]
736                                         and self.active[response.header.id][response.question.raw] then
737                                         --print('received response');
738                                         --self.print(response);
739
740                                         for j,rr in pairs(response.answer) do
741                                                 if rr.name:sub(-#response.question[1].name, -1) == response.question[1].name then
742                                                         self:remember(rr, response.question[1].type)
743                                                 end
744                                         end
745
746                                         -- retire the query
747                                         local queries = self.active[response.header.id];
748                                         queries[response.question.raw] = nil;
749                                         
750                                         if not next(queries) then self.active[response.header.id] = nil; end
751                                         if not next(self.active) then self:closeall(); end
752
753                                         -- was the query on the wanted list?
754                                         local q = response.question[1];
755                                         local cos = get(self.wanted, q.class, q.type, q.name);
756                                         if cos then
757                                                 for co in pairs(cos) do
758                                                         set(self.yielded, co, q.class, q.type, q.name, nil);
759                                                         if coroutine.status(co) == "suspended" then coroutine.resume(co); end
760                                                 end
761                                                 set(self.wanted, q.class, q.type, q.name, nil);
762                                         end
763                                 end
764                         end
765                 end
766         end
767
768         return response;
769 end
770
771
772 function resolver:feed(sock, packet)
773         --print('receive'); print(self.socket);
774         self.time = socket.gettime();
775
776         local response = self:decode(packet);
777         if response and self.active[response.header.id]
778                 and self.active[response.header.id][response.question.raw] then
779                 --print('received response');
780                 --self.print(response);
781
782                 for j,rr in pairs(response.answer) do
783                         self:remember(rr, response.question[1].type);
784                 end
785
786                 -- retire the query
787                 local queries = self.active[response.header.id];
788                 queries[response.question.raw] = nil;
789                 if not next(queries) then self.active[response.header.id] = nil; end
790                 if not next(self.active) then self:closeall(); end
791
792                 -- was the query on the wanted list?
793                 local q = response.question[1];
794                 if q then
795                         local cos = get(self.wanted, q.class, q.type, q.name);
796                         if cos then
797                                 for co in pairs(cos) do
798                                         set(self.yielded, co, q.class, q.type, q.name, nil);
799                                         if coroutine.status(co) == "suspended" then coroutine.resume(co); end
800                                 end
801                                 set(self.wanted, q.class, q.type, q.name, nil);
802                         end
803                 end
804         end
805
806         return response;
807 end
808
809 function resolver:cancel(data)
810         local cos = get(self.wanted, unpack(data, 1, 3));
811         if cos then
812                 cos[data[4]] = nil;
813         end
814 end
815
816 function resolver:pulse()    -- - - - - - - - - - - - - - - - - - - - -  pulse
817         --print(':pulse');
818         while self:receive() do end
819         if not next(self.active) then return nil; end
820
821         self.time = socket.gettime();
822         for id,queries in pairs(self.active) do
823                 for question,o in pairs(queries) do
824                         if self.time >= o.retry then
825
826                                 o.server = o.server + 1;
827                                 if o.server > #self.server then
828                                         o.server = 1;
829                                         o.delay = o.delay + 1;
830                                 end
831
832                                 if o.delay > #self.delays then
833                                         --print('timeout');
834                                         queries[question] = nil;
835                                         if not next(queries) then self.active[id] = nil; end
836                                         if not next(self.active) then return nil; end
837                                 else
838                                         --print('retry', o.server, o.delay);
839                                         local _a = self.socket[o.server];
840                                         if _a then _a:send(o.packet); end
841                                         o.retry = self.time + self.delays[o.delay];
842                                 end
843                         end
844                 end
845         end
846
847         if next(self.active) then return true; end
848         return nil;
849 end
850
851
852 function resolver:lookup(qname, qtype, qclass)    -- - - - - - - - - -  lookup
853         self:query (qname, qtype, qclass)
854         while self:pulse() do socket.select(self.socket, nil, 4); end
855         --print(self.cache);
856         return self:peek(qname, qtype, qclass);
857 end
858
859 function resolver:lookupex(handler, qname, qtype, qclass)    -- - - - - - - - - -  lookup
860         return self:peek(qname, qtype, qclass) or self:query(qname, qtype, qclass);
861 end
862
863
864 --print ---------------------------------------------------------------- print
865
866
867 local hints = {    -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
868         qr = { [0]='query', 'response' },
869         opcode = { [0]='query', 'inverse query', 'server status request' },
870         aa = { [0]='non-authoritative', 'authoritative' },
871         tc = { [0]='complete', 'truncated' },
872         rd = { [0]='recursion not desired', 'recursion desired' },
873         ra = { [0]='recursion not available', 'recursion available' },
874         z  = { [0]='(reserved)' },
875         rcode = { [0]='no error', 'format error', 'server failure', 'name error', 'not implemented' },
876
877         type = dns.type,
878         class = dns.class
879 };
880
881
882 local function hint(p, s)    -- - - - - - - - - - - - - - - - - - - - - - hint
883         return (hints[s] and hints[s][p[s]]) or '';
884 end
885
886
887 function resolver.print(response)    -- - - - - - - - - - - - - resolver.print
888         for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z',
889                                                 'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do
890                 print( string.format('%-30s', 'header.'..s), response.header[s], hint(response.header, s) );
891         end
892
893         for i,question in ipairs(response.question) do
894                 print(string.format ('question[%i].name         ', i), question.name);
895                 print(string.format ('question[%i].type         ', i), question.type);
896                 print(string.format ('question[%i].class        ', i), question.class);
897         end
898
899         local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 };
900         local tmp;
901         for s,s in pairs({'answer', 'authority', 'additional'}) do
902                 for i,rr in pairs(response[s]) do
903                         for j,t in pairs({ 'name', 'type', 'class', 'ttl', 'rdlength' }) do
904                                 tmp = string.format('%s[%i].%s', s, i, t);
905                                 print(string.format('%-30s', tmp), rr[t], hint(rr, t));
906                         end
907                         for j,t in pairs(rr) do
908                                 if not common[j] then
909                                         tmp = string.format('%s[%i].%s', s, i, j);
910                                         print(string.format('%-30s  %s', tostring(tmp), tostring(t)));
911                                 end
912                         end
913                 end
914         end
915 end
916
917
918 -- module api ------------------------------------------------------ module api
919
920
921 function dns.resolver ()    -- - - - - - - - - - - - - - - - - - - - - resolver
922         -- this function seems to be redundant with resolver.new ()
923
924         local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {}, best_server = 1 };
925         setmetatable (r, resolver);
926         setmetatable (r.cache, cache_metatable);
927         setmetatable (r.unsorted, { __mode = 'kv' });
928         return r;
929 end
930
931 local _resolver = dns.resolver();
932 dns._resolver = _resolver;
933
934 function dns.lookup(...)    -- - - - - - - - - - - - - - - - - - - - -  lookup
935         return _resolver:lookup(...);
936 end
937
938 function dns.purge(...)    -- - - - - - - - - - - - - - - - - - - - - -  purge
939         return _resolver:purge(...);
940 end
941
942 function dns.peek(...)    -- - - - - - - - - - - - - - - - - - - - - - -  peek
943         return _resolver:peek(...);
944 end
945
946 function dns.query(...)    -- - - - - - - - - - - - - - - - - - - - - -  query
947         return _resolver:query(...);
948 end
949
950 function dns.feed(...)    -- - - - - - - - - - - - - - - - - - - - - - -  feed
951         return _resolver:feed(...);
952 end
953
954 function dns.cancel(...)  -- - - - - - - - - - - - - - - - - - - - - -  cancel
955         return _resolver:cancel(...);
956 end
957
958 function dns.socket_wrapper_set(...)    -- - - - - - - - -  socket_wrapper_set
959         return _resolver:socket_wrapper_set(...);
960 end
961
962 return dns;