Merge 0.7->0.8
[prosody.git] / net / dns.lua
index c0de97fd426db6f3573332cda332072028bedc7c..c905f56c5d612565d416a2fcd85321cc64f58b7a 100644 (file)
@@ -2,8 +2,6 @@
 -- This file is included with Prosody IM. It has modifications,
 -- which are hereby placed in the public domain.
 
--- public domain 20080404 lua@ztact.com
-
 
 -- todo: quick (default) header generation
 -- todo: nxdomain, error handling
 
 
 local socket = require "socket";
-local ztact = require "util.ztact";
+local timer = require "util.timer";
+
 local _, windows = pcall(require, "util.windows");
 local is_windows = (_ and windows) or os.getenv("WINDIR");
 
 local coroutine, io, math, string, table =
       coroutine, io, math, string, table;
 
-local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack =
-      ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack;
+local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type=
+      ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type;
 
+local ztact = { -- public domain 20080404 lua@ztact.com
+       get = function(parent, ...)
+               local len = select('#', ...);
+               for i=1,len do
+                       parent = parent[select(i, ...)];
+                       if parent == nil then break; end
+               end
+               return parent;
+       end;
+       set = function(parent, ...)
+               local len = select('#', ...);
+               local key, value = select(len-1, ...);
+               local cutpoint, cutkey;
+
+               for i=1,len-2 do
+                       local key = select (i, ...)
+                       local child = parent[key]
+
+                       if value == nil then
+                               if child == nil then
+                                       return;
+                               elseif next(child, next(child)) then
+                                       cutpoint = nil; cutkey = nil;
+                               elseif cutpoint == nil then
+                                       cutpoint = parent; cutkey = key;
+                               end
+                       elseif child == nil then
+                               child = {};
+                               parent[key] = child;
+                       end
+                       parent = child
+               end
+
+               if value == nil and cutpoint then
+                       cutpoint[cutkey] = nil;
+               else
+                       parent[key] = value;
+                       return value;
+               end
+       end;
+};
 local get, set = ztact.get, ztact.set;
 
+local default_timeout = 15;
 
 -------------------------------------------------- module dns
 module('dns')
@@ -115,32 +156,31 @@ end
 local resolver = {};
 resolver.__index = resolver;
 
+resolver.timeout = default_timeout;
 
-local SRV_tostring;
-
+local function default_rr_tostring(rr)
+       local rr_val = rr.type and rr[rr.type:lower()];
+       if type(rr_val) ~= "string" then
+               return "<UNKNOWN RDATA TYPE>";
+       end
+       return rr_val;
+end
+
+local special_tostrings = {
+       LOC = resolver.LOC_tostring;
+       MX  = function (rr)
+               return string.format('%2i %s', rr.pref, rr.mx);
+       end;
+       SRV = function (rr)
+               local s = rr.srv;
+               return string.format('%5d %5d %5d %s', s.priority, s.weight, s.port, s.target);
+       end;
+};
 
 local rr_metatable = {};   -- - - - - - - - - - - - - - - - - - -  rr_metatable
 function rr_metatable.__tostring(rr)
-       local s0 = string.format('%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name);
-       local s1 = '';
-       if rr.type == 'A' then
-               s1 = ' '..rr.a;
-       elseif rr.type == 'MX' then
-               s1 = string.format(' %2i %s', rr.pref, rr.mx);
-       elseif rr.type == 'CNAME' then
-               s1 = ' '..rr.cname;
-       elseif rr.type == 'LOC' then
-               s1 = ' '..resolver.LOC_tostring(rr);
-       elseif rr.type == 'NS' then
-               s1 = ' '..rr.ns;
-       elseif rr.type == 'SRV' then
-               s1 = ' '..SRV_tostring(rr);
-       elseif rr.type == 'TXT' then
-               s1 = ' '..rr.txt;
-       else
-               s1 = ' <UNKNOWN RDATA TYPE>';
-       end
-       return s0..s1;
+       local rr_string = (special_tostrings[rr.type] or default_rr_tostring)(rr);
+       return string.format('%2s %-5s %6i %-28s %s', rr.class, rr.type, rr.ttl, rr.name, rr_string);
 end
 
 
@@ -434,13 +474,10 @@ function resolver:SRV(rr)    -- - - - - - - - - - - - - - - - - - - - - -  SRV
          rr.srv.target   = self:name();
 end
 
-
-function SRV_tostring(rr)    -- - - - - - - - - - - - - - - - - - SRV_tostring
-       local s = rr.srv;
-       return string.format( '%5d %5d %5d %s', s.priority, s.weight, s.port, s.target );
+function resolver:PTR(rr)
+       rr.ptr = self:name();
 end
 
-
 function resolver:TXT(rr)    -- - - - - - - - - - - - - - - - - - - - - -  TXT
        rr.txt = self:sub (rr.rdlength);
 end
@@ -524,7 +561,7 @@ end
 
 function resolver:adddefaultnameservers()    -- - - - -  adddefaultnameservers
        if is_windows then
-               if windows then
+               if windows and windows.get_nameservers then
                        for _, server in ipairs(windows.get_nameservers()) do
                                self:addnameserver(server);
                        end
@@ -562,7 +599,11 @@ function resolver:getsocket(servernum)    -- - - - - - - - - - - - - getsocket
        local sock = self.socket[servernum];
        if sock then return sock; end
 
-       sock = socket.udp();
+       local err;
+       sock, err = socket.udp();
+       if not sock then
+               return nil, err;
+       end
        if self.socket_wrapper then sock = self.socket_wrapper(sock, self); end
        sock:settimeout(0);
        -- todo: attempt to use a random port, fallback to 0
@@ -667,18 +708,44 @@ function resolver:query(qname, qtype, qclass)    -- - - - - - - - - - -- query
                retry  = socket.gettime() + self.delays[1]
        };
 
-  -- remember the query
+       -- remember the query
        self.active[id] = self.active[id] or {};
        self.active[id][question] = o;
 
-  -- remember which coroutine wants the answer
+       -- remember which coroutine wants the answer
        local co = coroutine.running();
        if co then
                set(self.wanted, qclass, qtype, qname, co, true);
                --set(self.yielded, co, qclass, qtype, qname, true);
        end
 
-       self:getsocket (o.server):send (o.packet)
+       local conn, err = self:getsocket(o.server)
+       if not conn then
+               return nil, err;
+       end
+       conn:send (o.packet)
+       
+       if timer and self.timeout then
+               local num_servers = #self.server;
+               local i = 1;
+               timer.add_task(self.timeout, function ()
+                       if get(self.wanted, qclass, qtype, qname, co) then
+                               if i < num_servers then
+                                       i = i + 1;
+                                       self:servfail(conn);
+                                       o.server = self.best_server;
+                                       conn, err = self:getsocket(o.server);
+                                       if conn then
+                                               conn:send(o.packet);
+                                               return self.timeout;
+                                       end
+                               end
+                               -- Tried everything, failed
+                               self:cancel(qclass, qtype, qname, co, true);
+                       end
+               end)
+       end
+       return true;
 end
 
 function resolver:servfail(sock)
@@ -710,7 +777,7 @@ function resolver:servfail(sock)
                        end
                end
        end
-   
+
        if num == self.best_server then
                self.best_server = self.best_server + 1;
                if self.best_server > #self.server then
@@ -720,6 +787,10 @@ function resolver:servfail(sock)
        end
 end
 
+function resolver:settimeout(seconds)
+       self.timeout = seconds;
+end
+
 function resolver:receive(rset)    -- - - - - - - - - - - - - - - - -  receive
        --print('receive');  print(self.socket);
        self.time = socket.gettime();
@@ -769,11 +840,11 @@ function resolver:receive(rset)    -- - - - - - - - - - - - - - - - -  receive
 end
 
 
-function resolver:feed(sock, packet)
+function resolver:feed(sock, packet, force)
        --print('receive'); print(self.socket);
        self.time = socket.gettime();
 
-       local response = self:decode(packet);
+       local response = self:decode(packet, force);
        if response and self.active[response.header.id]
                and self.active[response.header.id][response.question.raw] then
                --print('received response');
@@ -806,10 +877,13 @@ function resolver:feed(sock, packet)
        return response;
 end
 
-function resolver:cancel(data)
-       local cos = get(self.wanted, unpack(data, 1, 3));
+function resolver:cancel(qclass, qtype, qname, co, call_handler)
+       local cos = get(self.wanted, qclass, qtype, qname);
        if cos then
-               cos[data[4]] = nil;
+               if call_handler then
+                       coroutine.resume(co);
+               end
+               cos[co] = nil;
        end
 end
 
@@ -852,12 +926,12 @@ end
 function resolver:lookup(qname, qtype, qclass)    -- - - - - - - - - -  lookup
        self:query (qname, qtype, qclass)
        while self:pulse() do
-           local recvt = {}
-           for i, s in ipairs(self.socket) do
-              recvt[i] = s
-           end
-           socket.select(recvt, nil, 4)
-        end
+               local recvt = {}
+               for i, s in ipairs(self.socket) do
+                       recvt[i] = s
+               end
+               socket.select(recvt, nil, 4)
+       end
        --print(self.cache);
        return self:peek(qname, qtype, qclass);
 end
@@ -866,6 +940,9 @@ function resolver:lookupex(handler, qname, qtype, qclass)    -- - - - - - - - -
        return self:peek(qname, qtype, qclass) or self:query(qname, qtype, qclass);
 end
 
+function resolver:tohostname(ip)
+       return dns.lookup(ip:gsub("(%d+)%.(%d+)%.(%d+)%.(%d+)", "%4.%3.%2.%1.in-addr.arpa."), "PTR");
+end
 
 --print ---------------------------------------------------------------- print
 
@@ -941,6 +1018,10 @@ function dns.lookup(...)    -- - - - - - - - - - - - - - - - - - - - -  lookup
        return _resolver:lookup(...);
 end
 
+function dns.tohostname(...)
+       return _resolver:tohostname(...);
+end
+
 function dns.purge(...)    -- - - - - - - - - - - - - - - - - - - - - -  purge
        return _resolver:purge(...);
 end
@@ -961,6 +1042,10 @@ function dns.cancel(...)  -- - - - - - - - - - - - - - - - - - - - - -  cancel
        return _resolver:cancel(...);
 end
 
+function dns.settimeout(...)
+       return _resolver:settimeout(...);
+end
+
 function dns.socket_wrapper_set(...)    -- - - - - - - - -  socket_wrapper_set
        return _resolver:socket_wrapper_set(...);
 end