X-Git-Url: https://git.enpas.org/?a=blobdiff_plain;f=net%2Fdns.lua;h=c0875b5a589240e13c4ba554627b0d51d597815b;hb=eecf63b9e5aff94dc3d3e88e3b5dfa853d92fc1a;hp=c0de97fd426db6f3573332cda332072028bedc7c;hpb=543fb7e01d3706b68fa37693747007260d0a1323;p=prosody.git diff --git a/net/dns.lua b/net/dns.lua index c0de97fd..c0875b5a 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -2,8 +2,6 @@ -- This file is included with Prosody IM. It has modifications, -- which are hereby placed in the public domain. --- public domain 20080404 lua@ztact.com - -- todo: quick (default) header generation -- todo: nxdomain, error handling @@ -15,18 +13,61 @@ local socket = require "socket"; -local ztact = require "util.ztact"; +local timer = require "util.timer"; + local _, windows = pcall(require, "util.windows"); local is_windows = (_ and windows) or os.getenv("WINDIR"); local coroutine, io, math, string, table = coroutine, io, math, string, table; -local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack = - ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack; +local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select = + ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select; +local ztact = { -- public domain 20080404 lua@ztact.com + get = function(parent, ...) + local len = select('#', ...); + for i=1,len do + parent = parent[select(i, ...)]; + if parent == nil then break; end + end + return parent; + end; + set = function(parent, ...) + local len = select('#', ...); + local key, value = select(len-1, ...); + local cutpoint, cutkey; + + for i=1,len-2 do + local key = select (i, ...) + local child = parent[key] + + if value == nil then + if child == nil then + return; + elseif next(child, next(child)) then + cutpoint = nil; cutkey = nil; + elseif cutpoint == nil then + cutpoint = parent; cutkey = key; + end + elseif child == nil then + child = {}; + parent[key] = child; + end + parent = child + end + + if value == nil and cutpoint then + cutpoint[cutkey] = nil; + else + parent[key] = value; + return value; + end + end; +}; local get, set = ztact.get, ztact.set; +local default_timeout = 15; -------------------------------------------------- module dns module('dns') @@ -115,6 +156,7 @@ end local resolver = {}; resolver.__index = resolver; +resolver.timeout = default_timeout; local SRV_tostring; @@ -434,6 +476,9 @@ function resolver:SRV(rr) -- - - - - - - - - - - - - - - - - - - - - - SRV rr.srv.target = self:name(); end +function resolver:PTR(rr) + rr.ptr = self:name(); +end function SRV_tostring(rr) -- - - - - - - - - - - - - - - - - - SRV_tostring local s = rr.srv; @@ -524,7 +569,7 @@ end function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers if is_windows then - if windows then + if windows and windows.get_nameservers then for _, server in ipairs(windows.get_nameservers()) do self:addnameserver(server); end @@ -667,18 +712,39 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query retry = socket.gettime() + self.delays[1] }; - -- remember the query + -- remember the query self.active[id] = self.active[id] or {}; self.active[id][question] = o; - -- remember which coroutine wants the answer + -- remember which coroutine wants the answer local co = coroutine.running(); if co then set(self.wanted, qclass, qtype, qname, co, true); --set(self.yielded, co, qclass, qtype, qname, true); end - self:getsocket (o.server):send (o.packet) + local conn = self:getsocket(o.server) + conn:send (o.packet) + + if timer and self.timeout then + local num_servers = #self.server; + local i = 1; + timer.add_task(self.timeout, function () + if get(self.wanted, qclass, qtype, qname, co) then + if i < num_servers then + i = i + 1; + self:servfail(conn); + o.server = self.best_server; + conn = self:getsocket(o.server); + conn:send(o.packet); + return self.timeout; + else + -- Tried everything, failed + self:cancel(qclass, qtype, qname, co, true); + end + end + end) + end end function resolver:servfail(sock) @@ -710,7 +776,7 @@ function resolver:servfail(sock) end end end - + if num == self.best_server then self.best_server = self.best_server + 1; if self.best_server > #self.server then @@ -720,6 +786,10 @@ function resolver:servfail(sock) end end +function resolver:settimeout(seconds) + self.timeout = seconds; +end + function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive --print('receive'); print(self.socket); self.time = socket.gettime(); @@ -769,11 +839,11 @@ function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive end -function resolver:feed(sock, packet) +function resolver:feed(sock, packet, force) --print('receive'); print(self.socket); self.time = socket.gettime(); - local response = self:decode(packet); + local response = self:decode(packet, force); if response and self.active[response.header.id] and self.active[response.header.id][response.question.raw] then --print('received response'); @@ -806,10 +876,13 @@ function resolver:feed(sock, packet) return response; end -function resolver:cancel(data) - local cos = get(self.wanted, unpack(data, 1, 3)); +function resolver:cancel(qclass, qtype, qname, co, call_handler) + local cos = get(self.wanted, qclass, qtype, qname); if cos then - cos[data[4]] = nil; + if call_handler then + coroutine.resume(co); + end + cos[co] = nil; end end @@ -852,12 +925,12 @@ end function resolver:lookup(qname, qtype, qclass) -- - - - - - - - - - lookup self:query (qname, qtype, qclass) while self:pulse() do - local recvt = {} - for i, s in ipairs(self.socket) do - recvt[i] = s - end - socket.select(recvt, nil, 4) - end + local recvt = {} + for i, s in ipairs(self.socket) do + recvt[i] = s + end + socket.select(recvt, nil, 4) + end --print(self.cache); return self:peek(qname, qtype, qclass); end @@ -961,6 +1034,10 @@ function dns.cancel(...) -- - - - - - - - - - - - - - - - - - - - - - cancel return _resolver:cancel(...); end +function dns.settimeout(...) + return _resolver:settimeout(...); +end + function dns.socket_wrapper_set(...) -- - - - - - - - - socket_wrapper_set return _resolver:socket_wrapper_set(...); end