certmanager: Fix previous commit
[prosody.git] / net / dns.lua
index 23a453aa82e1c49be9e269655240f7238f43e463..2cb677f6e052f53a14cf6ea848a53eb394241fc8 100644 (file)
@@ -14,6 +14,7 @@
 
 local socket = require "socket";
 local timer = require "util.timer";
+local new_ip = require "util.ip".new_ip;
 
 local _, windows = pcall(require, "util.windows");
 local is_windows = (_ and windows) or os.getenv("WINDIR");
@@ -134,17 +135,19 @@ end
 
 local function prune(rrs, time, soft)    -- - - - - - - - - - - - - - -  prune
        time = time or socket.gettime();
-       for i,rr in pairs(rrs) do
+       for i,rr in ipairs(rrs) do
                if rr.tod then
                        -- rr.tod = rr.tod - 50    -- accelerated decripitude
                        rr.ttl = math.floor(rr.tod - time);
                        if rr.ttl <= 0 then
+                               rrs[rr[rr.type:lower()]] = nil;
                                table.remove(rrs, i);
                                return prune(rrs, time, soft); -- Re-iterate
                        end
                elseif soft == 'soft' then    -- What is this?  I forget!
                        assert(rr.ttl == 0);
-                       rrs[i] = nil;
+                       rrs[rr[rr.type:lower()]] = nil;
+                       table.remove(rrs, i);
                end
        end
 end
@@ -158,8 +161,6 @@ resolver.__index = resolver;
 
 resolver.timeout = default_timeout;
 
-local SRV_tostring;
-
 local function default_rr_tostring(rr)
        local rr_val = rr.type and rr[rr.type:lower()];
        if type(rr_val) ~= "string" then
@@ -170,8 +171,13 @@ end
 
 local special_tostrings = {
        LOC = resolver.LOC_tostring;
-       MX = function (rr) return string.format('%2i %s', rr.pref, rr.mx); end;
-       SRV = SRV_tostring;
+       MX  = function (rr)
+               return string.format('%2i %s', rr.pref, rr.mx);
+       end;
+       SRV = function (rr)
+               local s = rr.srv;
+               return string.format('%5d %5d %5d %s', s.priority, s.weight, s.port, s.target);
+       end;
 };
 
 local rr_metatable = {};   -- - - - - - - - - - - - - - - - - - -  rr_metatable
@@ -184,7 +190,7 @@ end
 local rrs_metatable = {};    -- - - - - - - - - - - - - - - - - -  rrs_metatable
 function rrs_metatable.__tostring(rrs)
        local t = {};
-       for i,rr in pairs(rrs) do
+       for i,rr in ipairs(rrs) do
                append(t, tostring(rr)..'\n');
        end
        return table.concat(t);
@@ -220,7 +226,7 @@ end
 
 
 function dns.random(...)    -- - - - - - - - - - - - - - - - - - -  dns.random
-       math.randomseed(math.floor(10000*socket.gettime()));
+       math.randomseed(math.floor(10000*socket.gettime()) % 0x100000000);
        dns.random = math.random;
        return dns.random(...);
 end
@@ -355,6 +361,7 @@ function resolver:name()    -- - - - - - - - - - - - - - - - - - - - - -  name
        local remember, pointers = nil, 0;
        local len = self:byte();
        local n = {};
+       if len == 0 then return "." end -- Root label
        while len > 0 do
                if len >= 0xc0 then    -- name is "compressed"
                        pointers = pointers + 1;
@@ -386,6 +393,25 @@ function resolver:A(rr)    -- - - - - - - - - - - - - - - - - - - - - - - -  A
        rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4);
 end
 
+function resolver:AAAA(rr)
+       local addr = {};
+       for i = 1, rr.rdlength, 2 do
+               local b1, b2 = self:byte(2);
+               table.insert(addr, ("%02x%02x"):format(b1, b2));
+       end
+       addr = table.concat(addr, ":"):gsub("%f[%x]0+(%x)","%1");
+       local zeros = {};
+       for item in addr:gmatch(":[0:]+:") do
+               table.insert(zeros, item)
+       end
+       if #zeros == 0 then
+               rr.aaaa = addr;
+               return
+       elseif #zeros > 1 then
+               table.sort(zeros, function(a, b) return #a > #b end);
+       end
+       rr.aaaa = addr:gsub(zeros[1], "::", 1):gsub("^0::", "::"):gsub("::0$", "::");
+end
 
 function resolver:CNAME(rr)    -- - - - - - - - - - - - - - - - - - - -  CNAME
        rr.cname = self:name();
@@ -475,14 +501,8 @@ function resolver:PTR(rr)
        rr.ptr = self:name();
 end
 
-function SRV_tostring(rr)    -- - - - - - - - - - - - - - - - - - SRV_tostring
-       local s = rr.srv;
-       return string.format( '%5d %5d %5d %s', s.priority, s.weight, s.port, s.target );
-end
-
-
 function resolver:TXT(rr)    -- - - - - - - - - - - - - - - - - - - - - -  TXT
-       rr.txt = self:sub (rr.rdlength);
+       rr.txt = self:sub (self:byte());
 end
 
 
@@ -532,6 +552,7 @@ function resolver:decode(packet, force)    -- - - - - - - - - - - - - - decode
 
        if not force then
                if not self.active[response.header.id] or not self.active[response.header.id][response.question.raw] then
+                       self.active[response.header.id] = nil;
                        return nil;
                end
        end
@@ -579,11 +600,12 @@ function resolver:adddefaultnameservers()    -- - - - -  adddefaultnameservers
                if resolv_conf then
                        for line in resolv_conf:lines() do
                                line = line:gsub("#.*$", "")
-                                       :match('^%s*nameserver%s+(.*)%s*$');
+                                       :match('^%s*nameserver%s+([%x:%.]*)%s*$');
                                if line then
-                                       line:gsub("%f[%d.](%d+%.%d+%.%d+%.%d+)%f[^%d.]", function (address)
-                                               self:addnameserver(address)
-                                       end);
+                                       local ip = new_ip(line);
+                                       if ip then
+                                               self:addnameserver(ip.addr);
+                                       end
                                end
                        end
                end
@@ -602,14 +624,27 @@ function resolver:getsocket(servernum)    -- - - - - - - - - - - - - getsocket
        local sock = self.socket[servernum];
        if sock then return sock; end
 
-       sock = socket.udp();
-       if self.socket_wrapper then sock = self.socket_wrapper(sock, self); end
+       local ok, err;
+       local peer = self.server[servernum];
+       if peer:find(":") then
+               sock, err = socket.udp6();
+       else
+               sock, err = socket.udp();
+       end
+       if sock and self.socket_wrapper then sock, err = self.socket_wrapper(sock, self); end
+       if not sock then
+               return nil, err;
+       end
        sock:settimeout(0);
        -- todo: attempt to use a random port, fallback to 0
-       sock:setsockname('*', 0);
-       sock:setpeername(self.server[servernum], 53);
        self.socket[servernum] = sock;
        self.socketset[sock] = servernum;
+       -- set{sock,peer}name can fail, eg because of local routing table
+       -- if so, try the next server
+       ok, err = sock:setsockname('*', 0);
+       if not ok then return self:servfail(sock, err); end
+       ok, err = sock:setpeername(peer, 53);
+       if not ok then return self:servfail(sock, err); end
        return sock;
 end
 
@@ -621,6 +656,7 @@ function resolver:voidsocket(sock)
                self.socket[self.socketset[sock]] = nil;
                self.socketset[sock] = nil;
        end
+       sock:close();
 end
 
 function resolver:socket_wrapper_set(func)  -- - - - - - - socket_wrapper_set
@@ -651,7 +687,10 @@ function resolver:remember(rr, type)    -- - - - - - - - - - - - - -  remember
        self.cache = self.cache or setmetatable({}, cache_metatable);
        local rrs = get(self.cache, qclass, type, qname) or
                set(self.cache, qclass, type, qname, setmetatable({}, rrs_metatable));
-       append(rrs, rr);
+       if not rrs[rr[qtype:lower()]] then
+               rrs[rr[qtype:lower()]] = true;
+               append(rrs, rr);
+       end
 
        if type == 'MX' then self.unsorted[rrs] = true; end
 end
@@ -685,13 +724,21 @@ function resolver:purge(soft)    -- - - - - - - - - - - - - - - - - - -  purge
                                end
                        end
                end
-       else self.cache = {}; end
+       else self.cache = setmetatable({}, cache_metatable); end
 end
 
 
 function resolver:query(qname, qtype, qclass)    -- - - - - - - - - - -- query
        qname, qtype, qclass = standardize(qname, qtype, qclass)
 
+       local co = coroutine.running();
+       local q = get(self.wanted, qclass, qtype, qname);
+       if co and q then
+               -- We are already waiting for a reply to an identical query.
+               set(self.wanted, qclass, qtype, qname, co, true);
+               return true;
+       end
+
        if not self.server then self:adddefaultnameservers(); end
 
        local question = encodeQuestion(qname, qtype, qclass);
@@ -712,15 +759,16 @@ function resolver:query(qname, qtype, qclass)    -- - - - - - - - - - -- query
        self.active[id][question] = o;
 
        -- remember which coroutine wants the answer
-       local co = coroutine.running();
        if co then
                set(self.wanted, qclass, qtype, qname, co, true);
-               --set(self.yielded, co, qclass, qtype, qname, true);
        end
 
-       local conn = self:getsocket(o.server)
+       local conn, err = self:getsocket(o.server)
+       if not conn then
+               return nil, err;
+       end
        conn:send (o.packet)
-       
+
        if timer and self.timeout then
                local num_servers = #self.server;
                local i = 1;
@@ -730,25 +778,27 @@ function resolver:query(qname, qtype, qclass)    -- - - - - - - - - - -- query
                                        i = i + 1;
                                        self:servfail(conn);
                                        o.server = self.best_server;
-                                       conn = self:getsocket(o.server);
-                                       conn:send(o.packet);
-                                       return self.timeout;
-                               else
-                                       -- Tried everything, failed
-                                       self:cancel(qclass, qtype, qname, co, true);
+                                       conn, err = self:getsocket(o.server);
+                                       if conn then
+                                               conn:send(o.packet);
+                                               return self.timeout;
+                                       end
                                end
+                               -- Tried everything, failed
+                               self:cancel(qclass, qtype, qname);
                        end
                end)
        end
+       return true;
 end
 
-function resolver:servfail(sock)
+function resolver:servfail(sock, err)
        -- Resend all queries for this server
 
        local num = self.socketset[sock]
 
        -- Socket is dead now
-       self:voidsocket(sock);
+       sock = self:voidsocket(sock);
 
        -- Find all requests to the down server, and retry on the next server
        self.time = socket.gettime();
@@ -765,11 +815,14 @@ function resolver:servfail(sock)
                                        --print('timeout');
                                        queries[question] = nil;
                                else
-                                       local _a = self:getsocket(o.server);
-                                       if _a then _a:send(o.packet); end
+                                       sock, err = self:getsocket(o.server);
+                                       if sock then sock:send(o.packet); end
                                end
                        end
                end
+               if next(queries) == nil then
+                       self.active[id] = nil;
+               end
        end
 
        if num == self.best_server then
@@ -779,6 +832,7 @@ function resolver:servfail(sock)
                        self.best_server = 1;
                end
        end
+       return sock, err;
 end
 
 function resolver:settimeout(seconds)
@@ -811,7 +865,7 @@ function resolver:receive(rset)    -- - - - - - - - - - - - - - - - -  receive
                                        -- retire the query
                                        local queries = self.active[response.header.id];
                                        queries[response.question.raw] = nil;
-                                       
+
                                        if not next(queries) then self.active[response.header.id] = nil; end
                                        if not next(self.active) then self:closeall(); end
 
@@ -820,12 +874,12 @@ function resolver:receive(rset)    -- - - - - - - - - - - - - - - - -  receive
                                        local cos = get(self.wanted, q.class, q.type, q.name);
                                        if cos then
                                                for co in pairs(cos) do
-                                                       set(self.yielded, co, q.class, q.type, q.name, nil);
                                                        if coroutine.status(co) == "suspended" then coroutine.resume(co); end
                                                end
                                                set(self.wanted, q.class, q.type, q.name, nil);
                                        end
                                end
+
                        end
                end
        end
@@ -860,7 +914,6 @@ function resolver:feed(sock, packet, force)
                        local cos = get(self.wanted, q.class, q.type, q.name);
                        if cos then
                                for co in pairs(cos) do
-                                       set(self.yielded, co, q.class, q.type, q.name, nil);
                                        if coroutine.status(co) == "suspended" then coroutine.resume(co); end
                                end
                                set(self.wanted, q.class, q.type, q.name, nil);
@@ -871,13 +924,13 @@ function resolver:feed(sock, packet, force)
        return response;
 end
 
-function resolver:cancel(qclass, qtype, qname, co, call_handler)
+function resolver:cancel(qclass, qtype, qname)
        local cos = get(self.wanted, qclass, qtype, qname);
        if cos then
-               if call_handler then
-                       coroutine.resume(co);
+               for co in pairs(cos) do
+                       if coroutine.status(co) == "suspended" then coroutine.resume(co); end
                end
-               cos[co] = nil;
+               set(self.wanted, qclass, qtype, qname, nil);
        end
 end
 
@@ -998,7 +1051,7 @@ end
 function dns.resolver ()    -- - - - - - - - - - - - - - - - - - - - - resolver
        -- this function seems to be redundant with resolver.new ()
 
-       local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {}, best_server = 1 };
+       local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, best_server = 1 };
        setmetatable (r, resolver);
        setmetatable (r.cache, cache_metatable);
        setmetatable (r.unsorted, { __mode = 'kv' });
@@ -1040,6 +1093,10 @@ function dns.settimeout(...)
        return _resolver:settimeout(...);
 end
 
+function dns.cache()
+       return _resolver.cache;
+end
+
 function dns.socket_wrapper_set(...)    -- - - - - - - - -  socket_wrapper_set
        return _resolver:socket_wrapper_set(...);
 end