0.3->0.4
[prosody.git] / net / dns.lua
index 01ee133c7ef0d0e8e9921e1937bd5caef555877f..0b37c0c6008c75a2a72764f7a6a942fa959e4acc 100644 (file)
@@ -1,4 +1,6 @@
-
+-- Prosody IM v0.4
+-- This file is included with Prosody IM. It has modifications,
+-- which are hereby placed in the public domain.
 
 -- public domain 20080404 lua@ztact.com
 
 
 require 'socket'
 local ztact = require 'util.ztact'
-
+local require = require
 
 local coroutine, io, math, socket, string, table =
       coroutine, io, math, socket, string, table
 
-local ipairs, next, pairs, print, setmetatable, tostring =
-      ipairs, next, pairs, print, setmetatable, tostring
+local ipairs, next, pairs, print, setmetatable, tostring, assert, error =
+      ipairs, next, pairs, print, setmetatable, tostring, assert, error
 
 local get, set = ztact.get, ztact.set
 
@@ -130,7 +132,7 @@ function rr_metatable.__tostring (rr)
 
 local rrs_metatable = {}    -- - - - - - - - - - - - - - - - - -  rrs_metatable
 function rrs_metatable.__tostring (rrs)
-  t = {}
+  local t = {}
   for i,rr in pairs (rrs) do  append (t, tostring (rr)..'\n')  end
   return table.concat (t)
   end
@@ -251,7 +253,7 @@ function resolver:word ()    -- - - - - - - - - - - - - - - - - - - - - -  word
 
 function resolver:dword ()    -- - - - - - - - - - - - - - - - - - - - -  dword
   local b1, b2, b3, b4 = self:byte (4)
-  -- print ('dword', b1, b2, b3, b4)
+  --print ('dword', b1, b2, b3, b4)
   return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4
   end
 
@@ -267,7 +269,7 @@ function resolver:sub (len)    -- - - - - - - - - - - - - - - - - - - - - - sub
 function resolver:header (force)    -- - - - - - - - - - - - - - - - - - header
 
   local id = self:word ()
-  -- print (string.format (':header  id  %x', id))
+  --print (string.format (':header  id  %x', id))
   if not self.active[id] and not force then  return nil  end
 
   local h = { id = id }
@@ -320,7 +322,7 @@ function resolver:question ()    -- - - - - - - - - - - - - - - - - -  question
   local q = {}
   q.name  = self:name ()
   q.type  = dns.type[self:word ()]
-  q.class = dns.type[self:word ()]
+  q.class = dns.class[self:word ()]
   return q
   end
 
@@ -344,7 +346,7 @@ function resolver:MX (rr)    -- - - - - - - - - - - - - - - - - - - - - - -  MX
 
 function resolver:LOC_nibble_power ()    -- - - - - - - - - -  LOC_nibble_power
   local b = self:byte ()
-  -- print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
+  --print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
   return ((b-(b%0x10))/0x10) * (10^(b%0x10))
   end
 
@@ -502,11 +504,16 @@ function resolver:setnameserver (address)    -- - - - - - - - - - setnameserver
 
 function resolver:adddefaultnameservers ()    -- - - - -  adddefaultnameservers
   local resolv_conf = io.open("/etc/resolv.conf");
-  if not resolv_conf then return nil; end
-  for line in resolv_conf:lines() do
-    address = string.match (line, 'nameserver%s+(%d+%.%d+%.%d+%.%d+)')
-    if address then  self:addnameserver (address)  end
-    end  end
+  if resolv_conf then
+         for line in resolv_conf:lines() do
+               local address = string.match (line, 'nameserver%s+(%d+%.%d+%.%d+%.%d+)')
+               if address then  self:addnameserver (address)  end
+         end
+  else -- FIXME correct for windows, using opendns nameservers for now
+       self:addnameserver ("208.67.222.222")
+       self:addnameserver ("208.67.220.220")
+  end
+end
 
 
 function resolver:getsocket (servernum)    -- - - - - - - - - - - - - getsocket
@@ -542,12 +549,12 @@ function resolver:closeall ()    -- - - - - - - - - - - - - - - - - -  closeall
 
 function resolver:remember (rr, type)    -- - - - - - - - - - - - - -  remember
 
-  -- print ('remember', type, rr.class, rr.type, rr.name)
+  --print ('remember', type, rr.class, rr.type, rr.name)
 
   if type ~= '*' then
     type = rr.type
     local all = get (self.cache, rr.class, '*', rr.name)
-    -- print ('remember all', all)
+    --print ('remember all', all)
     if all then  append (all, rr)  end
     end
 
@@ -582,7 +589,7 @@ function resolver:purge (soft)    -- - - - - - - - - - - - - - - - - - -  purge
     for class,types in pairs (self.cache or {}) do
       for type,names in pairs (types) do
         for name,rrs in pairs (names) do
-          prune (rrs, time, 'soft')
+          prune (rrs, self.time, 'soft')
           end  end  end
   else  self.cache = {}  end
   end
@@ -592,14 +599,14 @@ function resolver:query (qname, qtype, qclass)    -- - - - - - - - - - -- query
 
   qname, qtype, qclass = standardize (qname, qtype, qclass)
 
-  if not self.server then  self:adddefaultnameservers ()  end
+  if not self.server then self:adddefaultnameservers ()  end
 
-  local question = question or encodeQuestion (qname, qtype, qclass)
+  local question = encodeQuestion (qname, qtype, qclass)
   local peek = self:peek (qname, qtype, qclass)
   if peek then  return peek  end
 
   local header, id = encodeHeader ()
-  -- print ('query  id', id, qclass, qtype, qname)
+  --print ('query  id', id, qclass, qtype, qname)
   local o = { packet = header..question,
               server = 1,
               delay  = 1,
@@ -614,13 +621,15 @@ function resolver:query (qname, qtype, qclass)    -- - - - - - - - - - -- query
   local co = coroutine.running ()
   if co then
     set (self.wanted, qclass, qtype, qname, co, true)
-    set (self.yielded, co, qclass, qtype, qname, true)
-    end  end
+    --set (self.yielded, co, qclass, qtype, qname, true)
+  end
+end
+
 
 
 function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
 
-  -- print 'receive'  print (self.socket)
+  --print 'receive'  print (self.socket)
   self.time = socket.gettime ()
   rset = rset or self.socket
 
@@ -633,8 +642,8 @@ function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
 
     response = self:decode (packet)
     if response then
-    -- print 'received response'
-    -- self.print (response)
+    --print 'received response'
+    --self.print (response)
 
     for i,section in pairs { 'answer', 'authority', 'additional' } do
       for j,rr in pairs (response[section]) do
@@ -653,7 +662,7 @@ function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
     if cos then
       for co in pairs (cos) do
         set (self.yielded, co, q.class, q.type, q.name, nil)
-       if not self.yielded[co] then  coroutine.resume (co)  end
+       if coroutine.status(co) == "suspended" then  coroutine.resume (co)  end
         end
       set (self.wanted, q.class, q.type, q.name, nil)
       end  end  end  end  end
@@ -662,10 +671,51 @@ function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
   end
 
 
+function resolver:feed(sock, packet)
+  --print 'receive'  print (self.socket)
+  self.time = socket.gettime ()
+
+  local response = self:decode (packet)
+  if response then
+    --print 'received response'
+    --self.print (response)
+
+    for i,section in pairs { 'answer', 'authority', 'additional' } do
+      for j,rr in pairs (response[section]) do
+        self:remember (rr, response.question[1].type)
+      end
+    end
+
+    -- retire the query
+    local queries = self.active[response.header.id]
+    if queries[response.question.raw] then
+      queries[response.question.raw] = nil
+    end
+    if not next (queries) then  self.active[response.header.id] = nil  end
+    if not next (self.active) then  self:closeall ()  end
+
+    -- was the query on the wanted list?
+    local q = response.question[1]
+    if q then
+      local cos = get (self.wanted, q.class, q.type, q.name)
+      if cos then
+        for co in pairs (cos) do
+          set (self.yielded, co, q.class, q.type, q.name, nil)
+          if coroutine.status(co) == "suspended" then coroutine.resume (co)  end
+        end
+        set (self.wanted, q.class, q.type, q.name, nil)
+      end
+    end
+  end 
+
+  return response
+end
+
+
 function resolver:pulse ()    -- - - - - - - - - - - - - - - - - - - - -  pulse
 
-  -- print ':pulse'
-  while self:receive () do end
+  --print ':pulse'
+  while self:receive() do end
   if not next (self.active) then  return nil  end
 
   self.time = socket.gettime ()
@@ -680,13 +730,14 @@ function resolver:pulse ()    -- - - - - - - - - - - - - - - - - - - - -  pulse
           end
 
         if o.delay > #self.delays then
-          print ('timeout')
+          --print ('timeout')
           queries[question] = nil
           if not next (queries) then  self.active[id] = nil  end
           if not next (self.active) then  return nil  end
         else
-          -- print ('retry', o.server, o.delay)
-          self.socket[o.server]:send (o.packet)
+          --print ('retry', o.server, o.delay)
+          local _a = self.socket[o.server];
+          if _a then _a:send (o.packet) end
           o.retry = self.time + self.delays[o.delay]
           end  end  end  end
 
@@ -698,12 +749,16 @@ function resolver:pulse ()    -- - - - - - - - - - - - - - - - - - - - -  pulse
 function resolver:lookup (qname, qtype, qclass)    -- - - - - - - - - -  lookup
   self:query (qname, qtype, qclass)
   while self:pulse () do  socket.select (self.socket, nil, 4)  end
-  -- print (self.cache)
+  --print (self.cache)
   return self:peek (qname, qtype, qclass)
   end
 
+function resolver:lookupex (handler, qname, qtype, qclass)    -- - - - - - - - - -  lookup
+  return self:peek (qname, qtype, qclass) or self:query (qname, qtype, qclass)
+  end
+
 
--- print ---------------------------------------------------------------- print
+--print ---------------------------------------------------------------- print
 
 
 local hints = {    -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
@@ -750,7 +805,7 @@ function resolver.print (response)    -- - - - - - - - - - - - - resolver.print
       for j,t in pairs (rr) do
         if not common[j] then
           tmp = string.format ('%s[%i].%s', s, i, j)
-          print (string.format ('%-30s  %s', tmp, t))
+          print (string.format ('%-30s  %s', tostring(tmp), tostring(t)))
           end  end  end  end  end
 
 
@@ -767,7 +822,7 @@ function dns.resolver ()    -- - - - - - - - - - - - - - - - - - - - - resolver
 
   -- this function seems to be redundant with resolver.new ()
 
-  r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {} }
+  local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {} }
   setmetatable (r, resolver)
   setmetatable (r.cache, cache_metatable)
   setmetatable (r.unsorted, { __mode = 'kv' })
@@ -789,6 +844,9 @@ function dns.peek (...)    -- - - - - - - - - - - - - - - - - - - - - - -  peek
 function dns.query (...)    -- - - - - - - - - - - - - - - - - - - - - -  query
   return resolve (resolver.query, ...)  end
 
+function dns.feed (...)    -- - - - - - - - - - - - - - - - - - - - - -  feed
+  return resolve (resolver.feed, ...)  end
+
 
 function dns:socket_wrapper_set (...)    -- - - - - - - - -  socket_wrapper_set
   return resolve (resolver.socket_wrapper_set, ...)  end