We have SRV resolving \o/
authorMatthew Wild <mwild1@gmail.com>
Tue, 18 Nov 2008 22:41:04 +0000 (22:41 +0000)
committerMatthew Wild <mwild1@gmail.com>
Tue, 18 Nov 2008 22:41:04 +0000 (22:41 +0000)
core/s2smanager.lua
net/dns.lua [new file with mode: 0644]
tests/test.lua
util/ztact.lua [new file with mode: 0644]

index 1fc2715d16e24b90ded51ff48f4cd7b05c32083a..d6ad2be157bd87f3f47c8fc8d196ed6e991d1f5f 100644 (file)
@@ -3,7 +3,7 @@ local hosts = hosts;
 local sessions = sessions;
 local socket = require "socket";
 local format = string.format;
-local t_insert = table.insert;
+local t_insert, t_sort = table.insert, table.sort;
 local get_traceback = debug.traceback;
 local tostring, pairs, ipairs, getmetatable, print, newproxy, error, tonumber
     = tostring, pairs, ipairs, getmetatable, print, newproxy, error, tonumber;
@@ -24,17 +24,19 @@ local md5_hash = require "util.hashes".md5;
 
 local dialback_secret = "This is very secret!!! Ha!";
 
-local srvmap = { ["gmail.com"] = "talk.google.com", ["identi.ca"] = "hampton.controlezvous.ca", ["cdr.se"] = "jabber.cdr.se" };
+local dns = require "net.dns";
 
 module "s2smanager"
 
+local function compare_srv_priorities(a,b) return a.priority < b.priority or a.weight < b.weight; end
+
 function send_to_host(from_host, to_host, data)
        if data.name then data = tostring(data); end
        local host = hosts[from_host].s2sout[to_host];
        if host then
                -- We have a connection to this host already
                if host.type == "s2sout_unauthed" then
-                       host.log("debug", "trying to send over unauthed s2sout to "..to_host..", authing it now...");
+                       (host.log or log)("debug", "trying to send over unauthed s2sout to "..to_host..", authing it now...");
                        if not host.notopen and not host.dialback_key then
                                host.log("debug", "dialback had not been initiated");
                                initiate_dialback(host);
@@ -87,11 +89,31 @@ function new_outgoing(from_host, to_host)
                local conn, handler = socket.tcp()
                
                --FIXME: Below parameters (ports/ip) are incorrect (use SRV)
-               to_host = srvmap[to_host] or to_host;
+               
+               local connect_host, connect_port = to_host, 5269;
+               
+               local answer = dns.lookup("_xmpp-server._tcp."..to_host..".", "SRV");
+               
+               if answer then
+                       log("debug", to_host.." has SRV records, handling...");
+                       local srv_hosts = {};
+                       host_session.srv_hosts = srv_hosts;
+                       for _, record in ipairs(answer) do
+                               t_insert(srv_hosts, record.srv);
+                       end
+                       t_sort(srv_hosts, compare_srv_priorities);
+                       
+                       local srv_choice = srv_hosts[1];
+                       if srv_choice then
+                               log("debug", "Best record found");
+                               connect_host, connect_port = srv_choice.target or to_host, srv_choice.port or connect_port;
+                               log("debug", "Best record found, will connect to %s:%d", connect_host, connect_port);
+                       end
+               end
                
                conn:settimeout(0);
-               local success, err = conn:connect(to_host, 5269);
-               if not success then
+               local success, err = conn:connect(connect_host, connect_port);
+               if not success and err ~= "timeout" then
                        log("warn", "s2s connect() failed: %s", err);
                end
                
diff --git a/net/dns.lua b/net/dns.lua
new file mode 100644 (file)
index 0000000..a75c1bf
--- /dev/null
@@ -0,0 +1,795 @@
+
+
+-- public domain 20080404 lua@ztact.com
+
+
+-- todo: quick (default) header generation
+-- todo: nxdomain, error handling
+-- todo: cache results of encodeName
+
+
+-- reference: http://tools.ietf.org/html/rfc1035
+-- reference: http://tools.ietf.org/html/rfc1876 (LOC)
+
+
+require 'socket'
+local ztact = require 'util.ztact'
+
+
+local coroutine, io, math, socket, string, table =
+      coroutine, io, math, socket, string, table
+
+local ipairs, next, pairs, print, setmetatable, tostring =
+      ipairs, next, pairs, print, setmetatable, tostring
+
+local get, set = ztact.get, ztact.set
+
+
+-------------------------------------------------- module dns
+module ('dns')
+local dns = _M;
+
+
+-- dns type & class codes ------------------------------ dns type & class codes
+
+
+local append = table.insert
+
+
+local function highbyte (i)    -- - - - - - - - - - - - - - - - - - -  highbyte
+  return (i-(i%0x100))/0x100
+  end
+
+
+local function augment (t)    -- - - - - - - - - - - - - - - - - - - -  augment
+  local a = {}
+  for i,s in pairs (t) do  a[i] = s  a[s] = s  a[string.lower (s)] = s  end
+  return a
+  end
+
+
+local function encode (t)    -- - - - - - - - - - - - - - - - - - - - -  encode
+  local code = {}
+  for i,s in pairs (t) do
+    local word = string.char (highbyte (i), i %0x100)
+    code[i] = word
+    code[s] = word
+    code[string.lower (s)] = word
+    end
+  return code
+  end
+
+
+dns.types = {
+  'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS',
+  'PTR', 'HINFO', 'MINFO', 'MX', 'TXT',
+  [ 28] = 'AAAA', [ 29] = 'LOC',   [ 33] = 'SRV',
+  [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' }
+
+
+dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' }
+
+
+dns.type      = augment (dns.types)
+dns.class     = augment (dns.classes)
+dns.typecode  = encode  (dns.types)
+dns.classcode = encode  (dns.classes)
+
+
+
+local function standardize (qname, qtype, qclass)    -- - - - - - - standardize
+  if string.byte (qname, -1) ~= 0x2E then  qname = qname..'.'  end
+  qname = string.lower (qname)
+  return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN']
+  end
+
+
+local function prune (rrs, time, soft)    -- - - - - - - - - - - - - - -  prune
+
+  time = time or socket.gettime ()
+  for i,rr in pairs (rrs) do
+
+    if rr.tod then
+      -- rr.tod = rr.tod - 50    -- accelerated decripitude
+      rr.ttl = math.floor (rr.tod - time)
+      if rr.ttl <= 0 then  rrs[i] = nil  end
+
+    elseif soft == 'soft' then    -- What is this?  I forget!
+      assert (rr.ttl == 0)
+      rrs[i] = nil
+      end  end  end
+
+
+-- metatables & co. ------------------------------------------ metatables & co.
+
+
+local resolver = {}
+resolver.__index = resolver
+
+
+local SRV_tostring
+
+
+local rr_metatable = {}    -- - - - - - - - - - - - - - - - - - -  rr_metatable
+function rr_metatable.__tostring (rr)
+  local s0 = string.format (
+    '%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name )
+  local s1 = ''
+  if rr.type == 'A' then  s1 = ' '..rr.a
+  elseif rr.type == 'MX' then
+    s1 = string.format (' %2i %s', rr.pref, rr.mx)
+  elseif rr.type == 'CNAME' then  s1 = ' '..rr.cname
+  elseif rr.type == 'LOC'   then  s1 = ' '..resolver.LOC_tostring (rr)
+  elseif rr.type == 'NS'    then  s1 = ' '..rr.ns
+  elseif rr.type == 'SRV'   then  s1 = ' '..SRV_tostring (rr)
+  elseif rr.type == 'TXT'   then  s1 = ' '..rr.txt
+  else  s1 = ' <UNKNOWN RDATA TYPE>'  end
+  return s0..s1
+  end
+
+
+local rrs_metatable = {}    -- - - - - - - - - - - - - - - - - -  rrs_metatable
+function rrs_metatable.__tostring (rrs)
+  t = {}
+  for i,rr in pairs (rrs) do  append (t, tostring (rr)..'\n')  end
+  return table.concat (t)
+  end
+
+
+local cache_metatable = {}    -- - - - - - - - - - - - - - - -  cache_metatable
+function cache_metatable.__tostring (cache)
+  local time = socket.gettime ()
+  local t = {}
+  for class,types in pairs (cache) do
+    for type,names in pairs (types) do
+      for name,rrs in pairs (names) do
+        prune (rrs, time)
+        append (t, tostring (rrs))  end  end  end
+  return table.concat (t)
+  end
+
+
+function resolver:new ()    -- - - - - - - - - - - - - - - - - - - - - resolver
+  local r = { active = {}, cache = {}, unsorted = {} }
+  setmetatable (r, resolver)
+  setmetatable (r.cache, cache_metatable)
+  setmetatable (r.unsorted, { __mode = 'kv' })
+  return r
+  end
+
+
+-- packet layer -------------------------------------------------- packet layer
+
+
+function dns.random (...)    -- - - - - - - - - - - - - - - - - - -  dns.random
+  math.randomseed (10000*socket.gettime ())
+  dns.random = math.random
+  return dns.random (...)
+  end
+
+
+local function encodeHeader (o)    -- - - - - - - - - - - - - - -  encodeHeader
+
+  o = o or {}
+
+  o.id = o.id or               -- 16b  (random) id
+    dns.random (0, 0xffff)
+
+  o.rd = o.rd or 1             --  1b  1 recursion desired
+  o.tc = o.tc or 0             --  1b  1 truncated response
+  o.aa = o.aa or 0             --  1b  1 authoritative response
+  o.opcode = o.opcode or 0     --  4b  0 query
+                               --      1 inverse query
+                               --      2 server status request
+                               --      3-15 reserved
+  o.qr = o.qr or 0             --  1b  0 query, 1 response
+
+  o.rcode = o.rcode or 0       --  4b  0 no error
+                               --      1 format error
+                               --      2 server failure
+                               --      3 name error
+                               --      4 not implemented
+                               --      5 refused
+                               --      6-15 reserved
+  o.z  = o.z  or 0             --  3b  0 resvered
+  o.ra = o.ra or 0             --  1b  1 recursion available
+
+  o.qdcount = o.qdcount or 1   -- 16b  number of question RRs
+  o.ancount = o.ancount or 0   -- 16b  number of answers RRs
+  o.nscount = o.nscount or 0   -- 16b  number of nameservers RRs
+  o.arcount = o.arcount or 0   -- 16b  number of additional RRs
+
+  -- string.char() rounds, so prevent roundup with -0.4999
+  local header = string.char (
+    highbyte (o.id),  o.id %0x100,
+    o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr,
+    o.rcode + 16*o.z + 128*o.ra,
+    highbyte (o.qdcount),  o.qdcount %0x100,
+    highbyte (o.ancount),  o.ancount %0x100,
+    highbyte (o.nscount),  o.nscount %0x100,
+    highbyte (o.arcount),  o.arcount %0x100 )
+
+  return header, o.id
+  end
+
+
+local function encodeName (name)    -- - - - - - - - - - - - - - - - encodeName
+  local t = {}
+  for part in string.gmatch (name, '[^.]+') do
+    append (t, string.char (string.len (part)))
+    append (t, part)
+    end
+  append (t, string.char (0))
+  return table.concat (t)
+  end
+
+
+local function encodeQuestion (qname, qtype, qclass)    -- - - - encodeQuestion
+  qname  = encodeName (qname)
+  qtype  = dns.typecode[qtype or 'a']
+  qclass = dns.classcode[qclass or 'in']
+  return qname..qtype..qclass;
+  end
+
+
+function resolver:byte (len)    -- - - - - - - - - - - - - - - - - - - - - byte
+  len = len or 1
+  local offset = self.offset
+  local last = offset + len - 1
+  if last > #self.packet then
+    error (string.format ('out of bounds: %i>%i', last, #self.packet))  end
+  self.offset = offset + len
+  return string.byte (self.packet, offset, last)
+  end
+
+
+function resolver:word ()    -- - - - - - - - - - - - - - - - - - - - - -  word
+  local b1, b2 = self:byte (2)
+  return 0x100*b1 + b2
+  end
+
+
+function resolver:dword ()    -- - - - - - - - - - - - - - - - - - - - -  dword
+  local b1, b2, b3, b4 = self:byte (4)
+  -- print ('dword', b1, b2, b3, b4)
+  return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4
+  end
+
+
+function resolver:sub (len)    -- - - - - - - - - - - - - - - - - - - - - - sub
+  len = len or 1
+  local s = string.sub (self.packet, self.offset, self.offset + len - 1)
+  self.offset = self.offset + len
+  return s
+  end
+
+
+function resolver:header (force)    -- - - - - - - - - - - - - - - - - - header
+
+  local id = self:word ()
+  -- print (string.format (':header  id  %x', id))
+  if not self.active[id] and not force then  return nil  end
+
+  local h = { id = id }
+
+  local b1, b2 = self:byte (2)
+
+  h.rd      = b1 %2
+  h.tc      = b1 /2%2
+  h.aa      = b1 /4%2
+  h.opcode  = b1 /8%16
+  h.qr      = b1 /128
+
+  h.rcode   = b2 %16
+  h.z       = b2 /16%8
+  h.ra      = b2 /128
+
+  h.qdcount = self:word ()
+  h.ancount = self:word ()
+  h.nscount = self:word ()
+  h.arcount = self:word ()
+
+  for k,v in pairs (h) do  h[k] = v-v%1  end
+
+  return h
+  end
+
+
+function resolver:name ()    -- - - - - - - - - - - - - - - - - - - - - -  name
+  local remember, pointers = nil, 0
+  local len = self:byte ()
+  local n = {}
+  while len > 0 do
+    if len >= 0xc0 then    -- name is "compressed"
+      pointers = pointers + 1
+      if pointers >= 20 then  error ('dns error: 20 pointers')  end
+      local offset = ((len-0xc0)*0x100) + self:byte ()
+      remember = remember or self.offset
+      self.offset = offset + 1    -- +1 for lua
+    else    -- name is not compressed
+      append (n, self:sub (len)..'.')
+      end
+    len = self:byte ()
+    end
+  self.offset = remember or self.offset
+  return table.concat (n)
+  end
+
+
+function resolver:question ()    -- - - - - - - - - - - - - - - - - -  question
+  local q = {}
+  q.name  = self:name ()
+  q.type  = dns.type[self:word ()]
+  q.class = dns.type[self:word ()]
+  return q
+  end
+
+
+function resolver:A (rr)    -- - - - - - - - - - - - - - - - - - - - - - - -  A
+  local b1, b2, b3, b4 = self:byte (4)
+  rr.a = string.format ('%i.%i.%i.%i', b1, b2, b3, b4)
+  end
+
+
+function resolver:CNAME (rr)    -- - - - - - - - - - - - - - - - - - - -  CNAME
+  rr.cname = self:name ()
+  end
+
+
+function resolver:MX (rr)    -- - - - - - - - - - - - - - - - - - - - - - -  MX
+  rr.pref = self:word ()
+  rr.mx   = self:name ()
+  end
+
+
+function resolver:LOC_nibble_power ()    -- - - - - - - - - -  LOC_nibble_power
+  local b = self:byte ()
+  -- print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
+  return ((b-(b%0x10))/0x10) * (10^(b%0x10))
+  end
+
+
+function resolver:LOC (rr)    -- - - - - - - - - - - - - - - - - - - - - -  LOC
+  rr.version = self:byte ()
+  if rr.version == 0 then
+    rr.loc           = rr.loc or {}
+    rr.loc.size      = self:LOC_nibble_power ()
+    rr.loc.horiz_pre = self:LOC_nibble_power ()
+    rr.loc.vert_pre  = self:LOC_nibble_power ()
+    rr.loc.latitude  = self:dword ()
+    rr.loc.longitude = self:dword ()
+    rr.loc.altitude  = self:dword ()
+    end  end
+
+
+local function LOC_tostring_degrees (f, pos, neg)    -- - - - - - - - - - - - -
+  f = f - 0x80000000
+  if f < 0 then  pos = neg  f = -f  end
+  local deg, min, msec
+  msec = f%60000
+  f    = (f-msec)/60000
+  min  = f%60
+  deg = (f-min)/60
+  return string.format ('%3d %2d %2.3f %s', deg, min, msec/1000, pos)
+  end
+
+
+function resolver.LOC_tostring (rr)    -- - - - - - - - - - - - -  LOC_tostring
+
+  local t = {}
+
+  --[[
+  for k,name in pairs { 'size', 'horiz_pre', 'vert_pre',
+                       'latitude', 'longitude', 'altitude' } do
+    append (t, string.format ('%4s%-10s: %12.0f\n', '', name, rr.loc[name]))
+    end
+  --]]
+
+  append ( t, string.format (
+    '%s    %s    %.2fm %.2fm %.2fm %.2fm',
+    LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'),
+    LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'),
+    (rr.loc.altitude - 10000000) / 100,
+    rr.loc.size / 100,
+    rr.loc.horiz_pre / 100,
+    rr.loc.vert_pre / 100 ) )
+
+  return table.concat (t)
+  end
+
+
+function resolver:NS (rr)    -- - - - - - - - - - - - - - - - - - - - - - -  NS
+  rr.ns = self:name ()
+  end
+
+
+function resolver:SOA (rr)    -- - - - - - - - - - - - - - - - - - - - - -  SOA
+  end
+
+
+function resolver:SRV (rr)    -- - - - - - - - - - - - - - - - - - - - - -  SRV
+  rr.srv = {}
+  rr.srv.priority = self:word ()
+  rr.srv.weight   = self:word ()
+  rr.srv.port     = self:word ()
+  rr.srv.target   = self:name ()
+  end
+
+
+function SRV_tostring (rr)    -- - - - - - - - - - - - - - - - - - SRV_tostring
+  local s = rr.srv
+  return string.format ( '%5d %5d %5d %s',
+                         s.priority, s.weight, s.port, s.target )
+  end
+
+
+function resolver:TXT (rr)    -- - - - - - - - - - - - - - - - - - - - - -  TXT
+  rr.txt = self:sub (rr.rdlength)
+  end
+
+
+function resolver:rr ()    -- - - - - - - - - - - - - - - - - - - - - - - -  rr
+  local rr = {}
+  setmetatable (rr, rr_metatable)
+  rr.name     = self:name (self)
+  rr.type     = dns.type[self:word ()] or rr.type
+  rr.class    = dns.class[self:word ()] or rr.class
+  rr.ttl      = 0x10000*self:word () + self:word ()
+  rr.rdlength = self:word ()
+
+  if rr.ttl == 0 then  -- pass
+  else  rr.tod = self.time + rr.ttl  end
+
+  local remember = self.offset
+  local rr_parser = self[dns.type[rr.type]]
+  if rr_parser then  rr_parser (self, rr)  end
+  self.offset = remember
+  rr.rdata = self:sub (rr.rdlength)
+  return rr
+  end
+
+
+function resolver:rrs (count)    -- - - - - - - - - - - - - - - - - - - - - rrs
+  local rrs = {}
+  for i = 1,count do  append (rrs, self:rr ())  end
+  return rrs
+  end
+
+
+function resolver:decode (packet, force)    -- - - - - - - - - - - - - - decode
+
+  self.packet, self.offset = packet, 1
+  local header = self:header (force)
+  if not header then  return nil  end
+  local response = { header = header }
+
+  response.question = {}
+  local offset = self.offset
+  for i = 1,response.header.qdcount do
+    append (response.question, self:question ())  end
+  response.question.raw = string.sub (self.packet, offset, self.offset - 1)
+
+  if not force then
+    if not self.active[response.header.id] or
+       not self.active[response.header.id][response.question.raw] then
+      return nil  end  end
+
+  response.answer     = self:rrs (response.header.ancount)
+  response.authority  = self:rrs (response.header.nscount)
+  response.additional = self:rrs (response.header.arcount)
+
+  return response
+  end
+
+
+-- socket layer -------------------------------------------------- socket layer
+
+
+resolver.delays = { 1, 3, 11, 45 }
+
+
+function resolver:addnameserver (address)    -- - - - - - - - - - addnameserver
+  self.server = self.server or {}
+  append (self.server, address)
+  end
+
+
+function resolver:setnameserver (address)    -- - - - - - - - - - setnameserver
+  self.server = {}
+  self:addnameserver (address)
+  end
+
+
+function resolver:adddefaultnameservers ()    -- - - - -  adddefaultnameservers
+  for line in io.lines ('/etc/resolv.conf') do
+    address = string.match (line, 'nameserver%s+(%d+%.%d+%.%d+%.%d+)')
+    if address then  self:addnameserver (address)  end
+    end  end
+
+
+function resolver:getsocket (servernum)    -- - - - - - - - - - - - - getsocket
+
+  self.socket = self.socket or {}
+  self.socketset = self.socketset or {}
+
+  local sock = self.socket[servernum]
+  if sock then  return sock  end
+
+  sock = socket.udp ()
+  if self.socket_wrapper then  sock = self.socket_wrapper (sock)  end
+  sock:settimeout (0)
+  -- todo: attempt to use a random port, fallback to 0
+  sock:setsockname ('*', 0)
+  sock:setpeername (self.server[servernum], 53)
+  self.socket[servernum] = sock
+  self.socketset[sock] = sock
+  return sock
+  end
+
+
+function resolver:socket_wrapper_set (func)  -- - - - - - - socket_wrapper_set
+  self.socket_wrapper = func
+  end
+
+
+function resolver:closeall ()    -- - - - - - - - - - - - - - - - - -  closeall
+  for i,sock in ipairs (self.socket) do  self.socket[i]:close ()  end
+  self.socket = {}
+  end
+
+
+function resolver:remember (rr, type)    -- - - - - - - - - - - - - -  remember
+
+  -- print ('remember', type, rr.class, rr.type, rr.name)
+
+  if type ~= '*' then
+    type = rr.type
+    local all = get (self.cache, rr.class, '*', rr.name)
+    -- print ('remember all', all)
+    if all then  append (all, rr)  end
+    end
+
+  self.cache = self.cache or setmetatable ({}, cache_metatable)
+  local rrs = get (self.cache, rr.class, type, rr.name) or
+    set (self.cache, rr.class, type, rr.name, setmetatable ({}, rrs_metatable))
+  append (rrs, rr)
+
+  if type == 'MX' then  self.unsorted[rrs] = true  end
+  end
+
+
+local function comp_mx (a, b)    -- - - - - - - - - - - - - - - - - - - comp_mx
+  return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref)
+  end
+
+
+function resolver:peek (qname, qtype, qclass)    -- - - - - - - - - - - -  peek
+  qname, qtype, qclass = standardize (qname, qtype, qclass)
+  local rrs = get (self.cache, qclass, qtype, qname)
+  if not rrs then  return nil  end
+  if prune (rrs, socket.gettime ()) and qtype == '*' or not next (rrs) then
+    set (self.cache, qclass, qtype, qname, nil)  return nil  end
+  if self.unsorted[rrs] then  table.sort (rrs, comp_mx)  end
+  return rrs
+  end
+
+
+function resolver:purge (soft)    -- - - - - - - - - - - - - - - - - - -  purge
+  if soft == 'soft' then
+    self.time = socket.gettime ()
+    for class,types in pairs (self.cache or {}) do
+      for type,names in pairs (types) do
+        for name,rrs in pairs (names) do
+          prune (rrs, time, 'soft')
+          end  end  end
+  else  self.cache = {}  end
+  end
+
+
+function resolver:query (qname, qtype, qclass)    -- - - - - - - - - - -- query
+
+  qname, qtype, qclass = standardize (qname, qtype, qclass)
+
+  if not self.server then  self:adddefaultnameservers ()  end
+
+  local question = question or encodeQuestion (qname, qtype, qclass)
+  local peek = self:peek (qname, qtype, qclass)
+  if peek then  return peek  end
+
+  local header, id = encodeHeader ()
+  -- print ('query  id', id, qclass, qtype, qname)
+  local o = { packet = header..question,
+              server = 1,
+              delay  = 1,
+              retry  = socket.gettime () + self.delays[1] }
+  self:getsocket (o.server):send (o.packet)
+
+  -- remember the query
+  self.active[id] = self.active[id] or {}
+  self.active[id][question] = o
+
+  -- remember which coroutine wants the answer
+  local co = coroutine.running ()
+  if co then
+    set (self.wanted, qclass, qtype, qname, co, true)
+    set (self.yielded, co, qclass, qtype, qname, true)
+    end  end
+
+
+function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
+
+  -- print 'receive'  print (self.socket)
+  self.time = socket.gettime ()
+  rset = rset or self.socket
+
+  local response
+  for i,sock in pairs (rset) do
+
+    if self.socketset[sock] then
+    local packet = sock:receive ()
+    if packet then
+
+    response = self:decode (packet)
+    if response then
+    -- print 'received response'
+    -- self.print (response)
+
+    for i,section in pairs { 'answer', 'authority', 'additional' } do
+      for j,rr in pairs (response[section]) do
+        self:remember (rr, response.question[1].type)  end  end
+
+    -- retire the query
+    local queries = self.active[response.header.id]
+    if queries[response.question.raw] then
+      queries[response.question.raw] = nil  end
+    if not next (queries) then  self.active[response.header.id] = nil  end
+    if not next (self.active) then  self:closeall ()  end
+
+    -- was the query on the wanted list?
+    local q = response.question
+    local cos = get (self.wanted, q.class, q.type, q.name)
+    if cos then
+      for co in pairs (cos) do
+        set (self.yielded, co, q.class, q.type, q.name, nil)
+       if not self.yielded[co] then  coroutine.resume (co)  end
+        end
+      set (self.wanted, q.class, q.type, q.name, nil)
+      end  end  end  end  end
+
+  return response
+  end
+
+
+function resolver:pulse ()    -- - - - - - - - - - - - - - - - - - - - -  pulse
+
+  -- print ':pulse'
+  while self:receive () do end
+  if not next (self.active) then  return nil  end
+
+  self.time = socket.gettime ()
+  for id,queries in pairs (self.active) do
+    for question,o in pairs (queries) do
+      if self.time >= o.retry then
+
+        o.server = o.server + 1
+        if o.server > #self.server then
+          o.server = 1
+          o.delay = o.delay + 1
+          end
+
+        if o.delay > #self.delays then
+          print ('timeout')
+          queries[question] = nil
+          if not next (queries) then  self.active[id] = nil  end
+          if not next (self.active) then  return nil  end
+        else
+          -- print ('retry', o.server, o.delay)
+          self.socket[o.server]:send (o.packet)
+          o.retry = self.time + self.delays[o.delay]
+          end  end  end  end
+
+  if next (self.active) then  return true  end
+  return nil
+  end
+
+
+function resolver:lookup (qname, qtype, qclass)    -- - - - - - - - - -  lookup
+  self:query (qname, qtype, qclass)
+  while self:pulse () do  socket.select (self.socket, nil, 4)  end
+  -- print (self.cache)
+  return self:peek (qname, qtype, qclass)
+  end
+
+
+-- print ---------------------------------------------------------------- print
+
+
+local hints = {    -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
+  qr = { [0]='query', 'response' },
+  opcode = { [0]='query', 'inverse query', 'server status request' },
+  aa = { [0]='non-authoritative', 'authoritative' },
+  tc = { [0]='complete', 'truncated' },
+  rd = { [0]='recursion not desired', 'recursion desired' },
+  ra = { [0]='recursion not available', 'recursion available' },
+  z  = { [0]='(reserved)' },
+  rcode = { [0]='no error', 'format error', 'server failure', 'name error',
+            'not implemented' },
+
+  type = dns.type,
+  class = dns.class, }
+
+
+local function hint (p, s)    -- - - - - - - - - - - - - - - - - - - - - - hint
+  return (hints[s] and hints[s][p[s]]) or ''  end
+
+
+function resolver.print (response)    -- - - - - - - - - - - - - resolver.print
+
+  for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z',
+                    'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do
+    print ( string.format ('%-30s', 'header.'..s),
+            response.header[s], hint (response.header, s) )
+    end
+
+  for i,question in ipairs (response.question) do
+    print (string.format ('question[%i].name         ', i), question.name)
+    print (string.format ('question[%i].type         ', i), question.type)
+    print (string.format ('question[%i].class        ', i), question.class)
+    end
+
+  local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 }
+  local tmp
+  for s,s in pairs {'answer', 'authority', 'additional'} do
+    for i,rr in pairs (response[s]) do
+      for j,t in pairs { 'name', 'type', 'class', 'ttl', 'rdlength' } do
+        tmp = string.format ('%s[%i].%s', s, i, t)
+        print (string.format ('%-30s', tmp), rr[t], hint (rr, t))
+        end
+      for j,t in pairs (rr) do
+        if not common[j] then
+          tmp = string.format ('%s[%i].%s', s, i, j)
+          print (string.format ('%-30s  %s', tmp, t))
+          end  end  end  end  end
+
+
+-- module api ------------------------------------------------------ module api
+
+
+local function resolve (func, ...)    -- - - - - - - - - - - - - - resolver_get
+  dns._resolver = dns._resolver or dns.resolver ()
+  return func (dns._resolver, ...)
+  end
+
+
+function dns.resolver ()    -- - - - - - - - - - - - - - - - - - - - - resolver
+
+  -- this function seems to be redundant with resolver.new ()
+
+  r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {} }
+  setmetatable (r, resolver)
+  setmetatable (r.cache, cache_metatable)
+  setmetatable (r.unsorted, { __mode = 'kv' })
+  return r
+  end
+
+
+function dns.lookup (...)    -- - - - - - - - - - - - - - - - - - - - -  lookup
+  return resolve (resolver.lookup, ...)  end
+
+
+function dns.purge (...)    -- - - - - - - - - - - - - - - - - - - - - -  purge
+  return resolve (resolver.purge, ...)  end
+
+function dns.peek (...)    -- - - - - - - - - - - - - - - - - - - - - - -  peek
+  return resolve (resolver.peek, ...)  end
+
+
+function dns.query (...)    -- - - - - - - - - - - - - - - - - - - - - -  query
+  return resolve (resolver.query, ...)  end
+
+
+function dns:socket_wrapper_set (...)    -- - - - - - - - -  socket_wrapper_set
+  return resolve (resolver.socket_wrapper_set, ...)  end
+
+
+return dns
index c028e8595c00186385400c283034efa95d356820..aa0275d4966a513a8da67faf98d149f507918151 100644 (file)
@@ -83,3 +83,4 @@ end
 
 dotest "util.jid"
 dotest "core.stanza_router"
+dotest "core.s2smanager"
diff --git a/util/ztact.lua b/util/ztact.lua
new file mode 100644 (file)
index 0000000..15bcffa
--- /dev/null
@@ -0,0 +1,364 @@
+
+
+-- public domain 20080410 lua@ztact.com
+
+
+pcall (require, 'lfs')      -- lfs may not be installed/necessary.
+pcall (require, 'pozix')    -- pozix may not be installed/necessary.
+
+
+local getfenv, ipairs, next, pairs, pcall, require, select, tostring, type =
+      getfenv, ipairs, next, pairs, pcall, require, select, tostring, type
+local unpack, xpcall =
+      unpack, xpcall
+
+local io, lfs, os, string, table, pozix = io, lfs, os, string, table, pozix
+
+local assert, print = assert, print
+
+local error            = error
+
+
+module ((...) or 'ztact')    ------------------------------------- module ztact
+
+
+-- dir -------------------------------------------------------------------- dir
+
+
+function dir (path)    -- - - - - - - - - - - - - - - - - - - - - - - - - - dir
+  local it = lfs.dir (path)
+  return function ()
+    repeat
+      local dir = it ()
+      if dir ~= '.' and dir ~= '..' then  return dir  end
+    until not dir
+    end  end
+
+
+function is_file (path)    -- - - - - - - - - - - - - - - - - -  is_file (path)
+  local mode = lfs.attributes (path, 'mode')
+  return mode == 'file' and path
+  end
+
+
+-- network byte ordering -------------------------------- network byte ordering
+
+
+function htons (word)    -- - - - - - - - - - - - - - - - - - - - - - - - htons
+  return (word-word%0x100)/0x100, word%0x100
+  end
+
+
+-- pcall2 -------------------------------------------------------------- pcall2
+
+
+getfenv ().pcall = pcall    -- store the original pcall as ztact.pcall
+
+
+local argc, argv, errorhandler, pcall2_f
+
+
+local function _pcall2 ()    -- - - - - - - - - - - - - - - - - - - - - _pcall2
+  local tmpv = argv
+  argv = nil
+  return pcall2_f (unpack (tmpv, 1, argc))
+  end
+
+
+function seterrorhandler (func)    -- - - - - - - - - - - - - - seterrorhandler
+  errorhandler = func
+  end
+
+
+function pcall2 (f, ...)    -- - - - - - - - - - - - - - - - - - - - - - pcall2
+
+  pcall2_f = f
+  argc = select ('#', ...)
+  argv = { ... }
+
+  if not errorhandler then
+    local debug = require ('debug')
+    errorhandler = debug.traceback
+    end
+
+  return xpcall (_pcall2, errorhandler)
+  end
+
+
+function append (t, ...)    -- - - - - - - - - - - - - - - - - - - - - - append
+  local insert = table.insert
+  for i,v in ipairs {...} do
+    insert (t, v)
+    end  end
+
+
+function print_r (d, indent)    -- - - - - - - - - - - - - - - - - - -  print_r
+  local rep = string.rep ('  ', indent or 0)
+  if type (d) == 'table' then
+    for k,v in pairs (d) do
+      if type (v) == 'table' then
+        io.write (rep, k, '\n')
+        print_r (v, (indent or 0) + 1)
+      else  io.write (rep, k, ' = ', tostring (v), '\n')  end
+      end
+  else  io.write (d, '\n')  end
+  end
+
+
+function tohex (s)    -- - - - - - - - - - - - - - - - - - - - - - - - -  tohex
+  return string.format (string.rep ('%02x ', #s), string.byte (s, 1, #s))
+  end
+
+
+function tostring_r (d, indent, tab0)    -- - - - - - - - - - - - -  tostring_r
+
+  tab1 = tab0 or {}
+  local rep = string.rep ('  ', indent or 0)
+  if type (d) == 'table' then
+    for k,v in pairs (d) do
+      if type (v) == 'table' then
+        append (tab1, rep, k, '\n')
+        tostring_r (v, (indent or 0) + 1, tab1)
+      else  append (tab1, rep, k, ' = ', tostring (v), '\n')  end
+      end
+  else  append (tab1, d, '\n')  end
+
+  if not tab0 then  return table.concat (tab1)  end
+  end
+
+
+-- queue manipulation -------------------------------------- queue manipulation
+
+
+-- Possible queue states.  1 (i.e. queue.p[1]) is head of queue.
+--
+-- 1..2
+-- 3..4  1..2
+-- 3..4  1..2  5..6
+-- 1..2        5..6
+--             1..2
+
+
+local function print_queue (queue, ...)    -- - - - - - - - - - - - print_queue
+  for i=1,10 do  io.write ((queue[i]   or '.')..' ')  end
+  io.write ('\t')
+  for i=1,6  do  io.write ((queue.p[i] or '.')..' ')  end
+  print (...)
+  end
+
+
+function dequeue (queue)    -- - - - - - - - - - - - - - - - - - - - -  dequeue
+
+  local p = queue.p
+  if not p and queue[1] then  queue.p = { 1, #queue }  p = queue.p  end
+
+  if not p[1] then  return nil  end
+
+  local element = queue[p[1]]
+  queue[p[1]] = nil
+
+  if p[1] < p[2] then  p[1] = p[1] + 1
+
+  elseif p[4] then  p[1], p[2], p[3], p[4]  =  p[3], p[4], nil, nil
+
+  elseif p[5] then  p[1], p[2], p[5], p[6]  =  p[5], p[6], nil, nil
+
+  else  p[1], p[2]  =  nil, nil  end
+
+  print_queue (queue, '  de '..element)
+  return element
+  end
+
+
+function enqueue (queue, element)    -- - - - - - - - - - - - - - - - - enqueue
+
+  local p = queue.p
+  if not p then  queue.p = {}  p = queue.p  end
+
+  if p[5] then    -- p3..p4 p1..p2 p5..p6
+    p[6] = p[6]+1
+    queue[p[6]] = element
+
+  elseif p[3] then    -- p3..p4 p1..p2
+
+    if p[4]+1 < p[1] then
+      p[4] = p[4] + 1
+      queue[p[4]] = element
+
+    else
+      p[5] = p[2]+1
+      p[6], queue[p[5]] = p[5], element
+      end
+
+  elseif p[1] then    -- p1..p2
+    if p[1] == 1 then
+      p[2] = p[2] + 1
+      queue[p[2]] = element
+
+    else
+        p[3], p[4], queue[1] = 1, 1, element
+        end
+
+  else    -- empty queue
+    p[1], p[2], queue[1] = 1, 1, element
+    end
+
+  print_queue (queue, '     '..element)
+  end
+
+
+local function test_queue ()
+  t = {}
+  enqueue (t, 1)
+  enqueue (t, 2)
+  enqueue (t, 3)
+  enqueue (t, 4)
+  enqueue (t, 5)
+  dequeue (t)
+  dequeue (t)
+  enqueue (t, 6)
+  enqueue (t, 7)
+  enqueue (t, 8)
+  enqueue (t, 9)
+  dequeue (t)
+  dequeue (t)
+  dequeue (t)
+  dequeue (t)
+  enqueue (t, 'a')
+  dequeue (t)
+  enqueue (t, 'b')
+  enqueue (t, 'c')
+  dequeue (t)
+  dequeue (t)
+  dequeue (t)
+  dequeue (t)
+  dequeue (t)
+  enqueue (t, 'd')
+  dequeue (t)
+  dequeue (t)
+  dequeue (t)
+  end
+
+
+-- test_queue ()
+
+
+function queue_len (queue)
+  end
+
+
+function queue_peek (queue)
+  end
+
+
+-- tree manipulation ---------------------------------------- tree manipulation
+
+
+function set (parent, ...)    --- - - - - - - - - - - - - - - - - - - - - - set
+
+  -- print ('set', ...)
+
+  local len = select ('#', ...)
+  local key, value = select (len-1, ...)
+  local cutpoint, cutkey
+
+  for i=1,len-2 do
+
+    local key = select (i, ...)
+    local child = parent[key]
+
+    if value == nil then
+      if child == nil then  return
+      elseif next (child, next (child)) then  cutpoint = nil  cutkey = nil
+      elseif cutpoint == nil then  cutpoint = parent  cutkey = key  end
+
+    elseif child == nil then  child = {}  parent[key] = child  end
+
+    parent = child
+    end
+
+  if value == nil and cutpoint then  cutpoint[cutkey] = nil
+  else  parent[key] = value  return value  end
+  end
+
+
+function get (parent, ...)    --- - - - - - - - - - - - - - - - - - - - - - get
+  local len = select ('#', ...)
+  for i=1,len do
+    parent = parent[select (i, ...)]
+    if parent == nil then  break  end
+    end
+  return parent
+  end
+
+
+-- misc ------------------------------------------------------------------ misc
+
+
+function find (path, ...)    --------------------------------------------- find
+
+  local dirs, operators = { path }, {...}
+  for operator in ivalues (operators) do
+    if not operator (path) then  break  end  end
+
+  while next (dirs) do
+    local parent = table.remove (dirs)
+    for child in assert (pozix.opendir (parent)) do
+      if  child  and  child ~= '.'  and  child ~= '..'  then
+        local path = parent..'/'..child
+       if pozix.stat (path, 'is_dir') then  table.insert (dirs, path)  end
+        for operator in ivalues (operators) do
+          if not operator (path) then  break  end  end
+        end  end  end  end
+
+
+function ivalues (t)    ----------------------------------------------- ivalues
+  local i = 0
+  return function ()  if t[i+1] then  i = i + 1  return t[i]  end  end
+  end
+
+
+function lson_encode (mixed, f, indent, indents)    --------------- lson_encode
+
+
+  local capture
+  if not f then
+    capture = {}
+    f = function (s)  append (capture, s)  end
+    end
+
+  indent = indent or 0
+  indents = indents or {}
+  indents[indent] = indents[indent] or string.rep (' ', 2*indent)
+
+  local type = type (mixed)
+
+  if type == 'number' then f (mixed)
+
+  else if type == 'string' then f (string.format ('%q', mixed))
+
+  else if type == 'table' then
+    f ('{')
+    for k,v in pairs (mixed) do
+      f ('\n')
+      f (indents[indent])
+      f ('[')  f (lson_encode (k))  f ('] = ')
+      lson_encode (v, f, indent+1, indents)
+      f (',')
+      end 
+    f (' }')
+    end  end  end
+
+  if capture then  return table.concat (capture)  end
+  end
+
+
+function timestamp (time)    ---------------------------------------- timestamp
+  return os.date ('%Y%m%d.%H%M%S', time)
+  end
+
+
+function values (t)    ------------------------------------------------- values
+  local k, v
+  return function ()  k, v = next (t, k)  return v  end
+  end