X-Git-Url: https://git.enpas.org/?a=blobdiff_plain;f=net%2Fdns.lua;h=61fb62e8e2f6855f5f2117bcb74f977b46dfe21e;hb=6d7a9cba5240a142b1a1e397e6248cf13c261cca;hp=7d1cba8e529b003b1841a4af04b0dd997aa01e39;hpb=f9628d7faf30de6b457acac8844be32585125ba8;p=prosody.git diff --git a/net/dns.lua b/net/dns.lua index 7d1cba8e..61fb62e8 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -2,8 +2,6 @@ -- This file is included with Prosody IM. It has modifications, -- which are hereby placed in the public domain. --- public domain 20080404 lua@ztact.com - -- todo: quick (default) header generation -- todo: nxdomain, error handling @@ -15,18 +13,61 @@ local socket = require "socket"; -local ztact = require "util.ztact"; +local timer = require "util.timer"; + local _, windows = pcall(require, "util.windows"); local is_windows = (_ and windows) or os.getenv("WINDIR"); local coroutine, io, math, string, table = coroutine, io, math, string, table; -local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack = - ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack; +local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type= + ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type; +local ztact = { -- public domain 20080404 lua@ztact.com + get = function(parent, ...) + local len = select('#', ...); + for i=1,len do + parent = parent[select(i, ...)]; + if parent == nil then break; end + end + return parent; + end; + set = function(parent, ...) + local len = select('#', ...); + local key, value = select(len-1, ...); + local cutpoint, cutkey; + + for i=1,len-2 do + local key = select (i, ...) + local child = parent[key] + + if value == nil then + if child == nil then + return; + elseif next(child, next(child)) then + cutpoint = nil; cutkey = nil; + elseif cutpoint == nil then + cutpoint = parent; cutkey = key; + end + elseif child == nil then + child = {}; + parent[key] = child; + end + parent = child + end + + if value == nil and cutpoint then + cutpoint[cutkey] = nil; + else + parent[key] = value; + return value; + end + end; +}; local get, set = ztact.get, ztact.set; +local default_timeout = 15; -------------------------------------------------- module dns module('dns') @@ -115,32 +156,28 @@ end local resolver = {}; resolver.__index = resolver; +resolver.timeout = default_timeout; local SRV_tostring; +local function default_rr_tostring(rr) + local rr_val = rr.type and rr[rr.type:lower()]; + if type(rr_val) ~= "string" then + return ""; + end + return rr_val; +end + +local special_tostrings = { + LOC = resolver.LOC_tostring; + MX = function (rr) return string.format('%2i %s', rr.pref, rr.mx); end; + SRV = SRV_tostring; +}; local rr_metatable = {}; -- - - - - - - - - - - - - - - - - - - rr_metatable function rr_metatable.__tostring(rr) - local s0 = string.format('%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name); - local s1 = ''; - if rr.type == 'A' then - s1 = ' '..rr.a; - elseif rr.type == 'MX' then - s1 = string.format(' %2i %s', rr.pref, rr.mx); - elseif rr.type == 'CNAME' then - s1 = ' '..rr.cname; - elseif rr.type == 'LOC' then - s1 = ' '..resolver.LOC_tostring(rr); - elseif rr.type == 'NS' then - s1 = ' '..rr.ns; - elseif rr.type == 'SRV' then - s1 = ' '..SRV_tostring(rr); - elseif rr.type == 'TXT' then - s1 = ' '..rr.txt; - else - s1 = ' '; - end - return s0..s1; + local rr_string = (special_tostrings[rr.type] or default_rr_tostring)(rr); + return string.format('%2s %-5s %6i %-28s %s', rr.class, rr.type, rr.ttl, rr.name, rr_string); end @@ -434,6 +471,9 @@ function resolver:SRV(rr) -- - - - - - - - - - - - - - - - - - - - - - SRV rr.srv.target = self:name(); end +function resolver:PTR(rr) + rr.ptr = self:name(); +end function SRV_tostring(rr) -- - - - - - - - - - - - - - - - - - SRV_tostring local s = rr.srv; @@ -524,7 +564,7 @@ end function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers if is_windows then - if windows then + if windows and windows.get_nameservers then for _, server in ipairs(windows.get_nameservers()) do self:addnameserver(server); end @@ -532,7 +572,7 @@ function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers if not self.server or #self.server == 0 then -- TODO log warning about no nameservers, adding opendns servers as fallback self:addnameserver("208.67.222.222"); - self:addnameserver("208.67.220.220") ; + self:addnameserver("208.67.220.220"); end else -- posix local resolv_conf = io.open("/etc/resolv.conf"); @@ -562,7 +602,11 @@ function resolver:getsocket(servernum) -- - - - - - - - - - - - - getsocket local sock = self.socket[servernum]; if sock then return sock; end - sock = socket.udp(); + local err; + sock, err = socket.udp(); + if not sock then + return nil, err; + end if self.socket_wrapper then sock = self.socket_wrapper(sock, self); end sock:settimeout(0); -- todo: attempt to use a random port, fallback to 0 @@ -667,18 +711,44 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query retry = socket.gettime() + self.delays[1] }; - -- remember the query + -- remember the query self.active[id] = self.active[id] or {}; self.active[id][question] = o; - -- remember which coroutine wants the answer + -- remember which coroutine wants the answer local co = coroutine.running(); if co then set(self.wanted, qclass, qtype, qname, co, true); --set(self.yielded, co, qclass, qtype, qname, true); end - self:getsocket (o.server):send (o.packet) + local conn, err = self:getsocket(o.server) + if not conn then + return nil, err; + end + conn:send (o.packet) + + if timer and self.timeout then + local num_servers = #self.server; + local i = 1; + timer.add_task(self.timeout, function () + if get(self.wanted, qclass, qtype, qname, co) then + if i < num_servers then + i = i + 1; + self:servfail(conn); + o.server = self.best_server; + conn, err = self:getsocket(o.server); + if conn then + conn:send(o.packet); + return self.timeout; + end + end + -- Tried everything, failed + self:cancel(qclass, qtype, qname, co, true); + end + end) + end + return true; end function resolver:servfail(sock) @@ -710,7 +780,7 @@ function resolver:servfail(sock) end end end - + if num == self.best_server then self.best_server = self.best_server + 1; if self.best_server > #self.server then @@ -720,6 +790,10 @@ function resolver:servfail(sock) end end +function resolver:settimeout(seconds) + self.timeout = seconds; +end + function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive --print('receive'); print(self.socket); self.time = socket.gettime(); @@ -769,11 +843,11 @@ function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive end -function resolver:feed(sock, packet) +function resolver:feed(sock, packet, force) --print('receive'); print(self.socket); self.time = socket.gettime(); - local response = self:decode(packet); + local response = self:decode(packet, force); if response and self.active[response.header.id] and self.active[response.header.id][response.question.raw] then --print('received response'); @@ -801,15 +875,18 @@ function resolver:feed(sock, packet) set(self.wanted, q.class, q.type, q.name, nil); end end - end + end return response; end -function resolver:cancel(data) - local cos = get(self.wanted, unpack(data, 1, 3)); +function resolver:cancel(qclass, qtype, qname, co, call_handler) + local cos = get(self.wanted, qclass, qtype, qname); if cos then - cos[data[4]] = nil; + if call_handler then + coroutine.resume(co); + end + cos[co] = nil; end end @@ -851,7 +928,13 @@ end function resolver:lookup(qname, qtype, qclass) -- - - - - - - - - - lookup self:query (qname, qtype, qclass) - while self:pulse() do socket.select(self.socket, nil, 4); end + while self:pulse() do + local recvt = {} + for i, s in ipairs(self.socket) do + recvt[i] = s + end + socket.select(recvt, nil, 4) + end --print(self.cache); return self:peek(qname, qtype, qclass); end @@ -860,6 +943,9 @@ function resolver:lookupex(handler, qname, qtype, qclass) -- - - - - - - - - return self:peek(qname, qtype, qclass) or self:query(qname, qtype, qclass); end +function resolver:tohostname(ip) + return dns.lookup(ip:gsub("(%d+)%.(%d+)%.(%d+)%.(%d+)", "%4.%3.%2.%1.in-addr.arpa."), "PTR"); +end --print ---------------------------------------------------------------- print @@ -935,6 +1021,10 @@ function dns.lookup(...) -- - - - - - - - - - - - - - - - - - - - - lookup return _resolver:lookup(...); end +function dns.tohostname(...) + return _resolver:tohostname(...); +end + function dns.purge(...) -- - - - - - - - - - - - - - - - - - - - - - purge return _resolver:purge(...); end @@ -955,6 +1045,10 @@ function dns.cancel(...) -- - - - - - - - - - - - - - - - - - - - - - cancel return _resolver:cancel(...); end +function dns.settimeout(...) + return _resolver:settimeout(...); +end + function dns.socket_wrapper_set(...) -- - - - - - - - - socket_wrapper_set return _resolver:socket_wrapper_set(...); end