net.dns: remove unused one-letter loop variables [luacheck]
[prosody.git] / net / dns.lua
index bd5c260ec2f1a638843c662544139e1b2eb6e920..4a35fc1be79640660c374d1239910ec8bb253f08 100644 (file)
@@ -22,8 +22,8 @@ local is_windows = (_ and windows) or os.getenv("WINDIR");
 local coroutine, io, math, string, table =
       coroutine, io, math, string, table;
 
-local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type=
-      ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type;
+local ipairs, next, pairs, print, setmetatable, tostring, assert, error, select, type, unpack=
+      ipairs, next, pairs, print, setmetatable, tostring, assert, error, select, type, table.unpack or unpack;
 
 local ztact = { -- public domain 20080404 lua@ztact.com
        get = function(parent, ...)
@@ -71,8 +71,8 @@ local get, set = ztact.get, ztact.set;
 local default_timeout = 15;
 
 -------------------------------------------------- module dns
-module('dns')
-local dns = _M;
+local _ENV = nil;
+local dns = {};
 
 
 -- dns type & class codes ------------------------------ dns type & class codes
@@ -135,17 +135,19 @@ end
 
 local function prune(rrs, time, soft)    -- - - - - - - - - - - - - - -  prune
        time = time or socket.gettime();
-       for i,rr in pairs(rrs) do
+       for i,rr in ipairs(rrs) do
                if rr.tod then
                        -- rr.tod = rr.tod - 50    -- accelerated decripitude
                        rr.ttl = math.floor(rr.tod - time);
                        if rr.ttl <= 0 then
+                               rrs[rr[rr.type:lower()]] = nil;
                                table.remove(rrs, i);
                                return prune(rrs, time, soft); -- Re-iterate
                        end
                elseif soft == 'soft' then    -- What is this?  I forget!
                        assert(rr.ttl == 0);
-                       rrs[i] = nil;
+                       rrs[rr[rr.type:lower()]] = nil;
+                       table.remove(rrs, i);
                end
        end
 end
@@ -188,7 +190,7 @@ end
 local rrs_metatable = {};    -- - - - - - - - - - - - - - - - - -  rrs_metatable
 function rrs_metatable.__tostring(rrs)
        local t = {};
-       for i,rr in pairs(rrs) do
+       for _, rr in ipairs(rrs) do
                append(t, tostring(rr)..'\n');
        end
        return table.concat(t);
@@ -211,20 +213,11 @@ function cache_metatable.__tostring(cache)
 end
 
 
-function resolver:new()    -- - - - - - - - - - - - - - - - - - - - - resolver
-       local r = { active = {}, cache = {}, unsorted = {} };
-       setmetatable(r, resolver);
-       setmetatable(r.cache, cache_metatable);
-       setmetatable(r.unsorted, { __mode = 'kv' });
-       return r;
-end
-
-
 -- packet layer -------------------------------------------------- packet layer
 
 
 function dns.random(...)    -- - - - - - - - - - - - - - - - - - -  dns.random
-       math.randomseed(math.floor(10000*socket.gettime()) % 0x100000000);
+       math.randomseed(math.floor(10000*socket.gettime()) % 0x80000000);
        dns.random = math.random;
        return dns.random(...);
 end
@@ -598,7 +591,7 @@ function resolver:adddefaultnameservers()    -- - - - -  adddefaultnameservers
                if resolv_conf then
                        for line in resolv_conf:lines() do
                                line = line:gsub("#.*$", "")
-                                       :match('^%s*nameserver%s+([%x:%.]*)%s*$');
+                                       :match('^%s*nameserver%s+([%x:%.]*%%?%S*)%s*$');
                                if line then
                                        local ip = new_ip(line);
                                        if ip then
@@ -622,12 +615,12 @@ function resolver:getsocket(servernum)    -- - - - - - - - - - - - - getsocket
        local sock = self.socket[servernum];
        if sock then return sock; end
 
-       local err;
+       local ok, err;
        local peer = self.server[servernum];
        if peer:find(":") then
                sock, err = socket.udp6();
        else
-               sock, err = socket.udp();
+               sock, err = (socket.udp4 or socket.udp)();
        end
        if sock and self.socket_wrapper then sock, err = self.socket_wrapper(sock, self); end
        if not sock then
@@ -635,10 +628,14 @@ function resolver:getsocket(servernum)    -- - - - - - - - - - - - - getsocket
        end
        sock:settimeout(0);
        -- todo: attempt to use a random port, fallback to 0
-       sock:setsockname('*', 0);
-       sock:setpeername(peer, 53);
        self.socket[servernum] = sock;
        self.socketset[sock] = servernum;
+       -- set{sock,peer}name can fail, eg because of local routing table
+       -- if so, try the next server
+       ok, err = sock:setsockname('*', 0);
+       if not ok then return self:servfail(sock, err); end
+       ok, err = sock:setpeername(peer, 53);
+       if not ok then return self:servfail(sock, err); end
        return sock;
 end
 
@@ -681,7 +678,10 @@ function resolver:remember(rr, type)    -- - - - - - - - - - - - - -  remember
        self.cache = self.cache or setmetatable({}, cache_metatable);
        local rrs = get(self.cache, qclass, type, qname) or
                set(self.cache, qclass, type, qname, setmetatable({}, rrs_metatable));
-       append(rrs, rr);
+       if not rrs[rr[qtype:lower()]] then
+               rrs[rr[qtype:lower()]] = true;
+               append(rrs, rr);
+       end
 
        if type == 'MX' then self.unsorted[rrs] = true; end
 end
@@ -692,15 +692,20 @@ local function comp_mx(a, b)    -- - - - - - - - - - - - - - - - - - - comp_mx
 end
 
 
-function resolver:peek (qname, qtype, qclass)    -- - - - - - - - - - - -  peek
+function resolver:peek (qname, qtype, qclass, n)    -- - - - - - - - - - - -  peek
        qname, qtype, qclass = standardize(qname, qtype, qclass);
        local rrs = get(self.cache, qclass, qtype, qname);
-       if not rrs then return nil; end
+       if not rrs then
+               if n then if n <= 0 then return end else n = 3 end
+               rrs = get(self.cache, qclass, "CNAME", qname);
+               if not (rrs and rrs[1]) then return end
+               return self:peek(rrs[1].cname, qtype, qclass, n - 1);
+       end
        if prune(rrs, socket.gettime()) and qtype == '*' or not next(rrs) then
                set(self.cache, qclass, qtype, qname, nil);
                return nil;
        end
-       if self.unsorted[rrs] then table.sort (rrs, comp_mx); end
+       if self.unsorted[rrs] then table.sort (rrs, comp_mx); self.unsorted[rrs] = nil; end
        return rrs;
 end
 
@@ -722,6 +727,14 @@ end
 function resolver:query(qname, qtype, qclass)    -- - - - - - - - - - -- query
        qname, qtype, qclass = standardize(qname, qtype, qclass)
 
+       local co = coroutine.running();
+       local q = get(self.wanted, qclass, qtype, qname);
+       if co and q then
+               -- We are already waiting for a reply to an identical query.
+               set(self.wanted, qclass, qtype, qname, co, true);
+               return true;
+       end
+
        if not self.server then self:adddefaultnameservers(); end
 
        local question = encodeQuestion(qname, qtype, qclass);
@@ -741,19 +754,17 @@ function resolver:query(qname, qtype, qclass)    -- - - - - - - - - - -- query
        self.active[id] = self.active[id] or {};
        self.active[id][question] = o;
 
-       -- remember which coroutine wants the answer
-       local co = coroutine.running();
-       if co then
-               set(self.wanted, qclass, qtype, qname, co, true);
-               --set(self.yielded, co, qclass, qtype, qname, true);
-       end
-
        local conn, err = self:getsocket(o.server)
        if not conn then
                return nil, err;
        end
        conn:send (o.packet)
 
+       -- remember which coroutine wants the answer
+       if co then
+               set(self.wanted, qclass, qtype, qname, co, true);
+       end
+       
        if timer and self.timeout then
                local num_servers = #self.server;
                local i = 1;
@@ -770,20 +781,20 @@ function resolver:query(qname, qtype, qclass)    -- - - - - - - - - - -- query
                                        end
                                end
                                -- Tried everything, failed
-                               self:cancel(qclass, qtype, qname, co, true);
+                               self:cancel(qclass, qtype, qname);
                        end
                end)
        end
        return true;
 end
 
-function resolver:servfail(sock)
+function resolver:servfail(sock, err)
        -- Resend all queries for this server
 
        local num = self.socketset[sock]
 
        -- Socket is dead now
-       self:voidsocket(sock);
+       sock = self:voidsocket(sock);
 
        -- Find all requests to the down server, and retry on the next server
        self.time = socket.gettime();
@@ -800,8 +811,8 @@ function resolver:servfail(sock)
                                        --print('timeout');
                                        queries[question] = nil;
                                else
-                                       local _a = self:getsocket(o.server);
-                                       if _a then _a:send(o.packet); end
+                                       sock, err = self:getsocket(o.server);
+                                       if sock then sock:send(o.packet); end
                                end
                        end
                end
@@ -817,6 +828,7 @@ function resolver:servfail(sock)
                        self.best_server = 1;
                end
        end
+       return sock, err;
 end
 
 function resolver:settimeout(seconds)
@@ -829,7 +841,7 @@ function resolver:receive(rset)    -- - - - - - - - - - - - - - - - -  receive
        rset = rset or self.socket;
 
        local response;
-       for i,sock in pairs(rset) do
+       for _, sock in pairs(rset) do
 
                if self.socketset[sock] then
                        local packet = sock:receive();
@@ -840,7 +852,7 @@ function resolver:receive(rset)    -- - - - - - - - - - - - - - - - -  receive
                                        --print('received response');
                                        --self.print(response);
 
-                                       for j,rr in pairs(response.answer) do
+                                       for _, rr in pairs(response.answer) do
                                                if rr.name:sub(-#response.question[1].name, -1) == response.question[1].name then
                                                        self:remember(rr, response.question[1].type)
                                                end
@@ -849,7 +861,7 @@ function resolver:receive(rset)    -- - - - - - - - - - - - - - - - -  receive
                                        -- retire the query
                                        local queries = self.active[response.header.id];
                                        queries[response.question.raw] = nil;
-
+                                       
                                        if not next(queries) then self.active[response.header.id] = nil; end
                                        if not next(self.active) then self:closeall(); end
 
@@ -858,13 +870,12 @@ function resolver:receive(rset)    -- - - - - - - - - - - - - - - - -  receive
                                        local cos = get(self.wanted, q.class, q.type, q.name);
                                        if cos then
                                                for co in pairs(cos) do
-                                                       set(self.yielded, co, q.class, q.type, q.name, nil);
                                                        if coroutine.status(co) == "suspended" then coroutine.resume(co); end
                                                end
                                                set(self.wanted, q.class, q.type, q.name, nil);
                                        end
                                end
-
+                               
                        end
                end
        end
@@ -883,7 +894,7 @@ function resolver:feed(sock, packet, force)
                --print('received response');
                --self.print(response);
 
-               for j,rr in pairs(response.answer) do
+               for _, rr in pairs(response.answer) do
                        self:remember(rr, response.question[1].type);
                end
 
@@ -899,7 +910,6 @@ function resolver:feed(sock, packet, force)
                        local cos = get(self.wanted, q.class, q.type, q.name);
                        if cos then
                                for co in pairs(cos) do
-                                       set(self.yielded, co, q.class, q.type, q.name, nil);
                                        if coroutine.status(co) == "suspended" then coroutine.resume(co); end
                                end
                                set(self.wanted, q.class, q.type, q.name, nil);
@@ -910,13 +920,13 @@ function resolver:feed(sock, packet, force)
        return response;
 end
 
-function resolver:cancel(qclass, qtype, qname, co, call_handler)
+function resolver:cancel(qclass, qtype, qname)
        local cos = get(self.wanted, qclass, qtype, qname);
        if cos then
-               if call_handler then
-                       coroutine.resume(co);
+               for co in pairs(cos) do
+                       if coroutine.status(co) == "suspended" then coroutine.resume(co); end
                end
-               cos[co] = nil;
+               set(self.wanted, qclass, qtype, qname, nil);
        end
 end
 
@@ -1035,9 +1045,7 @@ end
 
 
 function dns.resolver ()    -- - - - - - - - - - - - - - - - - - - - - resolver
-       -- this function seems to be redundant with resolver.new ()
-
-       local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {}, best_server = 1 };
+       local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, best_server = 1 };
        setmetatable (r, resolver);
        setmetatable (r.cache, cache_metatable);
        setmetatable (r.unsorted, { __mode = 'kv' });