net.dns: Support for parsing PTR records
[prosody.git] / net / dns.lua
index 04b2cf22d8185cd1b9553755521fcb8df6ec3d04..29d9cf36d227a52df80b9e4ae75a730e094b4e2c 100644 (file)
@@ -16,6 +16,8 @@
 
 local socket = require "socket";
 local ztact = require "util.ztact";
+local timer = require "util.timer";
+
 local _, windows = pcall(require, "util.windows");
 local is_windows = (_ and windows) or os.getenv("WINDIR");
 
@@ -27,6 +29,7 @@ local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack
 
 local get, set = ztact.get, ztact.set;
 
+local default_timeout = 15;
 
 -------------------------------------------------- module dns
 module('dns')
@@ -115,6 +118,7 @@ end
 local resolver = {};
 resolver.__index = resolver;
 
+resolver.timeout = default_timeout;
 
 local SRV_tostring;
 
@@ -183,7 +187,7 @@ end
 
 
 function dns.random(...)    -- - - - - - - - - - - - - - - - - - -  dns.random
-       math.randomseed(10000*socket.gettime());
+       math.randomseed(math.floor(10000*socket.gettime()));
        dns.random = math.random;
        return dns.random(...);
 end
@@ -434,6 +438,9 @@ function resolver:SRV(rr)    -- - - - - - - - - - - - - - - - - - - - - -  SRV
          rr.srv.target   = self:name();
 end
 
+function resolver:PTR(rr)
+       rr.ptr = self:name();
+end
 
 function SRV_tostring(rr)    -- - - - - - - - - - - - - - - - - - SRV_tostring
        local s = rr.srv;
@@ -532,14 +539,19 @@ function resolver:adddefaultnameservers()    -- - - - -  adddefaultnameservers
                if not self.server or #self.server == 0 then
                        -- TODO log warning about no nameservers, adding opendns servers as fallback
                        self:addnameserver("208.67.222.222");
-                       self:addnameserver("208.67.220.220") ;  
+                       self:addnameserver("208.67.220.220");
                end
        else -- posix
                local resolv_conf = io.open("/etc/resolv.conf");
                if resolv_conf then
                        for line in resolv_conf:lines() do
-                               local address = line:gsub("#.*$", ""):match('^%s*nameserver%s+(%d+%.%d+%.%d+%.%d+)%s*$');
-                               if address then self:addnameserver(address) end
+                               line = line:gsub("#.*$", "")
+                                       :match('^%s*nameserver%s+(.*)%s*$');
+                               if line then
+                                       line:gsub("%f[%d.](%d+%.%d+%.%d+%.%d+)%f[^%d.]", function (address)
+                                               self:addnameserver(address)
+                                       end);
+                               end
                        end
                end
                if not self.server or #self.server == 0 then
@@ -594,17 +606,18 @@ end
 
 function resolver:remember(rr, type)    -- - - - - - - - - - - - - -  remember
        --print ('remember', type, rr.class, rr.type, rr.name)
+       local qname, qtype, qclass = standardize(rr.name, rr.type, rr.class);
 
        if type ~= '*' then
-               type = rr.type;
-               local all = get(self.cache, rr.class, '*', rr.name);
+               type = qtype;
+               local all = get(self.cache, qclass, '*', qname);
                --print('remember all', all);
                if all then append(all, rr); end
        end
 
        self.cache = self.cache or setmetatable({}, cache_metatable);
-       local rrs = get(self.cache, rr.class, type, rr.name) or
-               set(self.cache, rr.class, type, rr.name, setmetatable({}, rrs_metatable));
+       local rrs = get(self.cache, qclass, type, qname) or
+               set(self.cache, qclass, type, qname, setmetatable({}, rrs_metatable));
        append(rrs, rr);
 
        if type == 'MX' then self.unsorted[rrs] = true; end
@@ -672,7 +685,28 @@ function resolver:query(qname, qtype, qclass)    -- - - - - - - - - - -- query
                --set(self.yielded, co, qclass, qtype, qname, true);
        end
 
-       self:getsocket (o.server):send (o.packet)
+       local conn = self:getsocket(o.server)
+       conn:send (o.packet)
+       
+       if timer and self.timeout then
+               local num_servers = #self.server;
+               local i = 1;
+               timer.add_task(self.timeout, function ()
+                       if get(self.wanted, qclass, qtype, qname, co) then
+                               if i < num_servers then
+                                       i = i + 1;
+                                       self:servfail(conn);
+                                       o.server = self.best_server;
+                                       conn = self:getsocket(o.server);
+                                       conn:send(o.packet);
+                                       return self.timeout;
+                               else
+                                       -- Tried everything, failed
+                                       self:cancel(qclass, qtype, qname, co, true);
+                               end
+                       end
+               end)
+       end
 end
 
 function resolver:servfail(sock)
@@ -714,6 +748,10 @@ function resolver:servfail(sock)
        end
 end
 
+function resolver:settimeout(seconds)
+       self.timeout = seconds;
+end
+
 function resolver:receive(rset)    -- - - - - - - - - - - - - - - - -  receive
        --print('receive');  print(self.socket);
        self.time = socket.gettime();
@@ -726,26 +764,26 @@ function resolver:receive(rset)    -- - - - - - - - - - - - - - - - -  receive
                        local packet = sock:receive();
                        if packet then
                                response = self:decode(packet);
-                               if response then
+                               if response and self.active[response.header.id]
+                                       and self.active[response.header.id][response.question.raw] then
                                        --print('received response');
                                        --self.print(response);
 
-                                       for i,section in pairs({ 'answer', 'authority', 'additional' }) do
-                                               for j,rr in pairs(response[section]) do
+                                       for j,rr in pairs(response.answer) do
+                                               if rr.name:sub(-#response.question[1].name, -1) == response.question[1].name then
                                                        self:remember(rr, response.question[1].type)
                                                end
                                        end
 
                                        -- retire the query
                                        local queries = self.active[response.header.id];
-                                       if queries[response.question.raw] then
-                                               queries[response.question.raw] = nil;
-                                       end
+                                       queries[response.question.raw] = nil;
+                                       
                                        if not next(queries) then self.active[response.header.id] = nil; end
                                        if not next(self.active) then self:closeall(); end
 
                                        -- was the query on the wanted list?
-                                       local q = response.question;
+                                       local q = response.question[1];
                                        local cos = get(self.wanted, q.class, q.type, q.name);
                                        if cos then
                                                for co in pairs(cos) do
@@ -763,26 +801,23 @@ function resolver:receive(rset)    -- - - - - - - - - - - - - - - - -  receive
 end
 
 
-function resolver:feed(sock, packet)
+function resolver:feed(sock, packet, force)
        --print('receive'); print(self.socket);
        self.time = socket.gettime();
 
-       local response = self:decode(packet);
-       if response then
+       local response = self:decode(packet, force);
+       if response and self.active[response.header.id]
+               and self.active[response.header.id][response.question.raw] then
                --print('received response');
                --self.print(response);
 
-               for i,section in pairs({ 'answer', 'authority', 'additional' }) do
-                       for j,rr in pairs(response[section]) do
-                               self:remember(rr, response.question[1].type);
-                       end
+               for j,rr in pairs(response.answer) do
+                       self:remember(rr, response.question[1].type);
                end
 
                -- retire the query
                local queries = self.active[response.header.id];
-               if queries[response.question.raw] then
-                       queries[response.question.raw] = nil;
-               end
+               queries[response.question.raw] = nil;
                if not next(queries) then self.active[response.header.id] = nil; end
                if not next(self.active) then self:closeall(); end
 
@@ -798,15 +833,18 @@ function resolver:feed(sock, packet)
                                set(self.wanted, q.class, q.type, q.name, nil);
                        end
                end
-       end 
+       end
 
        return response;
 end
 
-function resolver:cancel(data)
-       local cos = get(self.wanted, unpack(data, 1, 3));
+function resolver:cancel(qclass, qtype, qname, co, call_handler)
+       local cos = get(self.wanted, qclass, qtype, qname);
        if cos then
-               cos[data[4]] = nil;
+               if call_handler then
+                       coroutine.resume(co);
+               end
+               cos[co] = nil;
        end
 end
 
@@ -848,7 +886,13 @@ end
 
 function resolver:lookup(qname, qtype, qclass)    -- - - - - - - - - -  lookup
        self:query (qname, qtype, qclass)
-       while self:pulse() do socket.select(self.socket, nil, 4); end
+       while self:pulse() do
+           local recvt = {}
+           for i, s in ipairs(self.socket) do
+              recvt[i] = s
+           end
+           socket.select(recvt, nil, 4)
+        end
        --print(self.cache);
        return self:peek(qname, qtype, qclass);
 end
@@ -915,11 +959,6 @@ end
 -- module api ------------------------------------------------------ module api
 
 
-local function resolve(func, ...)    -- - - - - - - - - - - - - - resolver_get
-       return func(dns._resolver, ...);
-end
-
-
 function dns.resolver ()    -- - - - - - - - - - - - - - - - - - - - - resolver
        -- this function seems to be redundant with resolver.new ()
 
@@ -930,37 +969,39 @@ function dns.resolver ()    -- - - - - - - - - - - - - - - - - - - - - resolver
        return r;
 end
 
+local _resolver = dns.resolver();
+dns._resolver = _resolver;
 
 function dns.lookup(...)    -- - - - - - - - - - - - - - - - - - - - -  lookup
-       return resolve(resolver.lookup, ...);
+       return _resolver:lookup(...);
 end
 
-
 function dns.purge(...)    -- - - - - - - - - - - - - - - - - - - - - -  purge
-       return resolve(resolver.purge, ...);
+       return _resolver:purge(...);
 end
 
 function dns.peek(...)    -- - - - - - - - - - - - - - - - - - - - - - -  peek
-       return resolve(resolver.peek, ...);
+       return _resolver:peek(...);
 end
 
-
 function dns.query(...)    -- - - - - - - - - - - - - - - - - - - - - -  query
-       return resolve(resolver.query, ...);
+       return _resolver:query(...);
 end
 
-function dns.feed(...)    -- - - - - - - - - - - - - - - - - - - - - -  feed
-       return resolve(resolver.feed, ...);
+function dns.feed(...)    -- - - - - - - - - - - - - - - - - - - - - -  feed
+       return _resolver:feed(...);
 end
 
-function dns.cancel(...)   -- - - - - - - - - - - - - - - - - - - - - -  cancel
-       return resolve(resolver.cancel, ...);
+function dns.cancel(...)  -- - - - - - - - - - - - - - - - - - - - - -  cancel
+       return _resolver:cancel(...);
 end
 
-function dns:socket_wrapper_set(...)    -- - - - - - - - -  socket_wrapper_set
-       return resolve(resolver.socket_wrapper_set, ...);
+function dns.settimeout(...)
+       return _resolver:settimeout(...);
 end
 
-dns._resolver = dns.resolver();
+function dns.socket_wrapper_set(...)    -- - - - - - - - -  socket_wrapper_set
+       return _resolver:socket_wrapper_set(...);
+end
 
 return dns;