X-Git-Url: https://git.enpas.org/?a=blobdiff_plain;f=net%2Fdns.lua;h=689020a47a792e986fc95ad01546c654eca25a1e;hb=8806da0d002f1ad9a3faa9a16a3729d9c7442de1;hp=bd5c260ec2f1a638843c662544139e1b2eb6e920;hpb=cd4302a2db387ed852a06cb997e6424d61ca5419;p=prosody.git diff --git a/net/dns.lua b/net/dns.lua index bd5c260e..689020a4 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -22,8 +22,8 @@ local is_windows = (_ and windows) or os.getenv("WINDIR"); local coroutine, io, math, string, table = coroutine, io, math, string, table; -local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type= - ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type; +local ipairs, next, pairs, print, setmetatable, tostring, assert, error, select, type, unpack= + ipairs, next, pairs, print, setmetatable, tostring, assert, error, select, type, table.unpack or unpack; local ztact = { -- public domain 20080404 lua@ztact.com get = function(parent, ...) @@ -71,8 +71,8 @@ local get, set = ztact.get, ztact.set; local default_timeout = 15; -------------------------------------------------- module dns -module('dns') -local dns = _M; +local _ENV = nil; +local dns = {}; -- dns type & class codes ------------------------------ dns type & class codes @@ -135,17 +135,19 @@ end local function prune(rrs, time, soft) -- - - - - - - - - - - - - - - prune time = time or socket.gettime(); - for i,rr in pairs(rrs) do + for i,rr in ipairs(rrs) do if rr.tod then -- rr.tod = rr.tod - 50 -- accelerated decripitude rr.ttl = math.floor(rr.tod - time); if rr.ttl <= 0 then + rrs[rr[rr.type:lower()]] = nil; table.remove(rrs, i); return prune(rrs, time, soft); -- Re-iterate end elseif soft == 'soft' then -- What is this? I forget! assert(rr.ttl == 0); - rrs[i] = nil; + rrs[rr[rr.type:lower()]] = nil; + table.remove(rrs, i); end end end @@ -188,7 +190,7 @@ end local rrs_metatable = {}; -- - - - - - - - - - - - - - - - - - rrs_metatable function rrs_metatable.__tostring(rrs) local t = {}; - for i,rr in pairs(rrs) do + for i,rr in ipairs(rrs) do append(t, tostring(rr)..'\n'); end return table.concat(t); @@ -211,20 +213,11 @@ function cache_metatable.__tostring(cache) end -function resolver:new() -- - - - - - - - - - - - - - - - - - - - - resolver - local r = { active = {}, cache = {}, unsorted = {} }; - setmetatable(r, resolver); - setmetatable(r.cache, cache_metatable); - setmetatable(r.unsorted, { __mode = 'kv' }); - return r; -end - - -- packet layer -------------------------------------------------- packet layer function dns.random(...) -- - - - - - - - - - - - - - - - - - - dns.random - math.randomseed(math.floor(10000*socket.gettime()) % 0x100000000); + math.randomseed(math.floor(10000*socket.gettime()) % 0x80000000); dns.random = math.random; return dns.random(...); end @@ -598,7 +591,7 @@ function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers if resolv_conf then for line in resolv_conf:lines() do line = line:gsub("#.*$", "") - :match('^%s*nameserver%s+([%x:%.]*)%s*$'); + :match('^%s*nameserver%s+([%x:%.]*%%?%S*)%s*$'); if line then local ip = new_ip(line); if ip then @@ -622,12 +615,12 @@ function resolver:getsocket(servernum) -- - - - - - - - - - - - - getsocket local sock = self.socket[servernum]; if sock then return sock; end - local err; + local ok, err; local peer = self.server[servernum]; if peer:find(":") then sock, err = socket.udp6(); else - sock, err = socket.udp(); + sock, err = (socket.udp4 or socket.udp)(); end if sock and self.socket_wrapper then sock, err = self.socket_wrapper(sock, self); end if not sock then @@ -635,10 +628,14 @@ function resolver:getsocket(servernum) -- - - - - - - - - - - - - getsocket end sock:settimeout(0); -- todo: attempt to use a random port, fallback to 0 - sock:setsockname('*', 0); - sock:setpeername(peer, 53); self.socket[servernum] = sock; self.socketset[sock] = servernum; + -- set{sock,peer}name can fail, eg because of local routing table + -- if so, try the next server + ok, err = sock:setsockname('*', 0); + if not ok then return self:servfail(sock, err); end + ok, err = sock:setpeername(peer, 53); + if not ok then return self:servfail(sock, err); end return sock; end @@ -681,7 +678,10 @@ function resolver:remember(rr, type) -- - - - - - - - - - - - - - remember self.cache = self.cache or setmetatable({}, cache_metatable); local rrs = get(self.cache, qclass, type, qname) or set(self.cache, qclass, type, qname, setmetatable({}, rrs_metatable)); - append(rrs, rr); + if not rrs[rr[qtype:lower()]] then + rrs[rr[qtype:lower()]] = true; + append(rrs, rr); + end if type == 'MX' then self.unsorted[rrs] = true; end end @@ -692,15 +692,20 @@ local function comp_mx(a, b) -- - - - - - - - - - - - - - - - - - - comp_mx end -function resolver:peek (qname, qtype, qclass) -- - - - - - - - - - - - peek +function resolver:peek (qname, qtype, qclass, n) -- - - - - - - - - - - - peek qname, qtype, qclass = standardize(qname, qtype, qclass); local rrs = get(self.cache, qclass, qtype, qname); - if not rrs then return nil; end + if not rrs then + if n then if n <= 0 then return end else n = 3 end + rrs = get(self.cache, qclass, "CNAME", qname); + if not (rrs and rrs[1]) then return end + return self:peek(rrs[1].cname, qtype, qclass, n - 1); + end if prune(rrs, socket.gettime()) and qtype == '*' or not next(rrs) then set(self.cache, qclass, qtype, qname, nil); return nil; end - if self.unsorted[rrs] then table.sort (rrs, comp_mx); end + if self.unsorted[rrs] then table.sort (rrs, comp_mx); self.unsorted[rrs] = nil; end return rrs; end @@ -722,6 +727,14 @@ end function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query qname, qtype, qclass = standardize(qname, qtype, qclass) + local co = coroutine.running(); + local q = get(self.wanted, qclass, qtype, qname); + if co and q then + -- We are already waiting for a reply to an identical query. + set(self.wanted, qclass, qtype, qname, co, true); + return true; + end + if not self.server then self:adddefaultnameservers(); end local question = encodeQuestion(qname, qtype, qclass); @@ -741,19 +754,17 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query self.active[id] = self.active[id] or {}; self.active[id][question] = o; - -- remember which coroutine wants the answer - local co = coroutine.running(); - if co then - set(self.wanted, qclass, qtype, qname, co, true); - --set(self.yielded, co, qclass, qtype, qname, true); - end - local conn, err = self:getsocket(o.server) if not conn then return nil, err; end conn:send (o.packet) + -- remember which coroutine wants the answer + if co then + set(self.wanted, qclass, qtype, qname, co, true); + end + if timer and self.timeout then local num_servers = #self.server; local i = 1; @@ -770,20 +781,20 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query end end -- Tried everything, failed - self:cancel(qclass, qtype, qname, co, true); + self:cancel(qclass, qtype, qname); end end) end return true; end -function resolver:servfail(sock) +function resolver:servfail(sock, err) -- Resend all queries for this server local num = self.socketset[sock] -- Socket is dead now - self:voidsocket(sock); + sock = self:voidsocket(sock); -- Find all requests to the down server, and retry on the next server self.time = socket.gettime(); @@ -800,8 +811,8 @@ function resolver:servfail(sock) --print('timeout'); queries[question] = nil; else - local _a = self:getsocket(o.server); - if _a then _a:send(o.packet); end + sock, err = self:getsocket(o.server); + if sock then sock:send(o.packet); end end end end @@ -817,6 +828,7 @@ function resolver:servfail(sock) self.best_server = 1; end end + return sock, err; end function resolver:settimeout(seconds) @@ -849,7 +861,7 @@ function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive -- retire the query local queries = self.active[response.header.id]; queries[response.question.raw] = nil; - + if not next(queries) then self.active[response.header.id] = nil; end if not next(self.active) then self:closeall(); end @@ -858,13 +870,12 @@ function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive local cos = get(self.wanted, q.class, q.type, q.name); if cos then for co in pairs(cos) do - set(self.yielded, co, q.class, q.type, q.name, nil); if coroutine.status(co) == "suspended" then coroutine.resume(co); end end set(self.wanted, q.class, q.type, q.name, nil); end end - + end end end @@ -899,7 +910,6 @@ function resolver:feed(sock, packet, force) local cos = get(self.wanted, q.class, q.type, q.name); if cos then for co in pairs(cos) do - set(self.yielded, co, q.class, q.type, q.name, nil); if coroutine.status(co) == "suspended" then coroutine.resume(co); end end set(self.wanted, q.class, q.type, q.name, nil); @@ -910,13 +920,13 @@ function resolver:feed(sock, packet, force) return response; end -function resolver:cancel(qclass, qtype, qname, co, call_handler) +function resolver:cancel(qclass, qtype, qname) local cos = get(self.wanted, qclass, qtype, qname); if cos then - if call_handler then - coroutine.resume(co); + for co in pairs(cos) do + if coroutine.status(co) == "suspended" then coroutine.resume(co); end end - cos[co] = nil; + set(self.wanted, qclass, qtype, qname, nil); end end @@ -1035,9 +1045,7 @@ end function dns.resolver () -- - - - - - - - - - - - - - - - - - - - - resolver - -- this function seems to be redundant with resolver.new () - - local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {}, best_server = 1 }; + local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, best_server = 1 }; setmetatable (r, resolver); setmetatable (r.cache, cache_metatable); setmetatable (r.unsorted, { __mode = 'kv' });