Merge with waqas
[prosody.git] / net / dns.lua
index 7364161e367e18f50848dd2477dc4276f485495e..48c082185645d560273183f6f53d46da4c431abe 100644 (file)
@@ -1,4 +1,6 @@
-
+-- Prosody IM
+-- This file is included with Prosody IM. It has modifications,
+-- which are hereby placed in the public domain.
 
 -- public domain 20080404 lua@ztact.com
 
 
 -- public domain 20080404 lua@ztact.com
 
 
 require 'socket'
 local ztact = require 'util.ztact'
 
 require 'socket'
 local ztact = require 'util.ztact'
-
+local require = require
 
 local coroutine, io, math, socket, string, table =
       coroutine, io, math, socket, string, table
 
 
 local coroutine, io, math, socket, string, table =
       coroutine, io, math, socket, string, table
 
-local ipairs, next, pairs, print, setmetatable, tostring, assert, error =
-      ipairs, next, pairs, print, setmetatable, tostring, assert, error
+local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack =
+      ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack
 
 local get, set = ztact.get, ztact.set
 
 
 local get, set = ztact.get, ztact.set
 
@@ -251,7 +253,7 @@ function resolver:word ()    -- - - - - - - - - - - - - - - - - - - - - -  word
 
 function resolver:dword ()    -- - - - - - - - - - - - - - - - - - - - -  dword
   local b1, b2, b3, b4 = self:byte (4)
 
 function resolver:dword ()    -- - - - - - - - - - - - - - - - - - - - -  dword
   local b1, b2, b3, b4 = self:byte (4)
-  -- print ('dword', b1, b2, b3, b4)
+  --print ('dword', b1, b2, b3, b4)
   return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4
   end
 
   return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4
   end
 
@@ -267,7 +269,7 @@ function resolver:sub (len)    -- - - - - - - - - - - - - - - - - - - - - - sub
 function resolver:header (force)    -- - - - - - - - - - - - - - - - - - header
 
   local id = self:word ()
 function resolver:header (force)    -- - - - - - - - - - - - - - - - - - header
 
   local id = self:word ()
-  -- print (string.format (':header  id  %x', id))
+  --print (string.format (':header  id  %x', id))
   if not self.active[id] and not force then  return nil  end
 
   local h = { id = id }
   if not self.active[id] and not force then  return nil  end
 
   local h = { id = id }
@@ -320,7 +322,7 @@ function resolver:question ()    -- - - - - - - - - - - - - - - - - -  question
   local q = {}
   q.name  = self:name ()
   q.type  = dns.type[self:word ()]
   local q = {}
   q.name  = self:name ()
   q.type  = dns.type[self:word ()]
-  q.class = dns.type[self:word ()]
+  q.class = dns.class[self:word ()]
   return q
   end
 
   return q
   end
 
@@ -344,7 +346,7 @@ function resolver:MX (rr)    -- - - - - - - - - - - - - - - - - - - - - - -  MX
 
 function resolver:LOC_nibble_power ()    -- - - - - - - - - -  LOC_nibble_power
   local b = self:byte ()
 
 function resolver:LOC_nibble_power ()    -- - - - - - - - - -  LOC_nibble_power
   local b = self:byte ()
-  -- print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
+  --print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
   return ((b-(b%0x10))/0x10) * (10^(b%0x10))
   end
 
   return ((b-(b%0x10))/0x10) * (10^(b%0x10))
   end
 
@@ -547,12 +549,12 @@ function resolver:closeall ()    -- - - - - - - - - - - - - - - - - -  closeall
 
 function resolver:remember (rr, type)    -- - - - - - - - - - - - - -  remember
 
 
 function resolver:remember (rr, type)    -- - - - - - - - - - - - - -  remember
 
-  -- print ('remember', type, rr.class, rr.type, rr.name)
+  --print ('remember', type, rr.class, rr.type, rr.name)
 
   if type ~= '*' then
     type = rr.type
     local all = get (self.cache, rr.class, '*', rr.name)
 
   if type ~= '*' then
     type = rr.type
     local all = get (self.cache, rr.class, '*', rr.name)
-    -- print ('remember all', all)
+    --print ('remember all', all)
     if all then  append (all, rr)  end
     end
 
     if all then  append (all, rr)  end
     end
 
@@ -597,14 +599,14 @@ function resolver:query (qname, qtype, qclass)    -- - - - - - - - - - -- query
 
   qname, qtype, qclass = standardize (qname, qtype, qclass)
 
 
   qname, qtype, qclass = standardize (qname, qtype, qclass)
 
-  if not self.server then  self:adddefaultnameservers ()  end
+  if not self.server then self:adddefaultnameservers ()  end
 
   local question = encodeQuestion (qname, qtype, qclass)
   local peek = self:peek (qname, qtype, qclass)
   if peek then  return peek  end
 
   local header, id = encodeHeader ()
 
   local question = encodeQuestion (qname, qtype, qclass)
   local peek = self:peek (qname, qtype, qclass)
   if peek then  return peek  end
 
   local header, id = encodeHeader ()
-  -- print ('query  id', id, qclass, qtype, qname)
+  --print ('query  id', id, qclass, qtype, qname)
   local o = { packet = header..question,
               server = 1,
               delay  = 1,
   local o = { packet = header..question,
               server = 1,
               delay  = 1,
@@ -619,13 +621,15 @@ function resolver:query (qname, qtype, qclass)    -- - - - - - - - - - -- query
   local co = coroutine.running ()
   if co then
     set (self.wanted, qclass, qtype, qname, co, true)
   local co = coroutine.running ()
   if co then
     set (self.wanted, qclass, qtype, qname, co, true)
-    set (self.yielded, co, qclass, qtype, qname, true)
-    end  end
+    --set (self.yielded, co, qclass, qtype, qname, true)
+  end
+end
+
 
 
 function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
 
 
 
 function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
 
-  -- print 'receive'  print (self.socket)
+  --print 'receive'  print (self.socket)
   self.time = socket.gettime ()
   rset = rset or self.socket
 
   self.time = socket.gettime ()
   rset = rset or self.socket
 
@@ -638,8 +642,8 @@ function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
 
     response = self:decode (packet)
     if response then
 
     response = self:decode (packet)
     if response then
-    -- print 'received response'
-    -- self.print (response)
+    --print 'received response'
+    --self.print (response)
 
     for i,section in pairs { 'answer', 'authority', 'additional' } do
       for j,rr in pairs (response[section]) do
 
     for i,section in pairs { 'answer', 'authority', 'additional' } do
       for j,rr in pairs (response[section]) do
@@ -658,7 +662,7 @@ function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
     if cos then
       for co in pairs (cos) do
         set (self.yielded, co, q.class, q.type, q.name, nil)
     if cos then
       for co in pairs (cos) do
         set (self.yielded, co, q.class, q.type, q.name, nil)
-       if not self.yielded[co] then  coroutine.resume (co)  end
+       if coroutine.status(co) == "suspended" then  coroutine.resume (co)  end
         end
       set (self.wanted, q.class, q.type, q.name, nil)
       end  end  end  end  end
         end
       set (self.wanted, q.class, q.type, q.name, nil)
       end  end  end  end  end
@@ -667,10 +671,57 @@ function resolver:receive (rset)    -- - - - - - - - - - - - - - - - -  receive
   end
 
 
   end
 
 
+function resolver:feed(sock, packet)
+  --print 'receive'  print (self.socket)
+  self.time = socket.gettime ()
+
+  local response = self:decode (packet)
+  if response then
+    --print 'received response'
+    --self.print (response)
+
+    for i,section in pairs { 'answer', 'authority', 'additional' } do
+      for j,rr in pairs (response[section]) do
+        self:remember (rr, response.question[1].type)
+      end
+    end
+
+    -- retire the query
+    local queries = self.active[response.header.id]
+    if queries[response.question.raw] then
+      queries[response.question.raw] = nil
+    end
+    if not next (queries) then  self.active[response.header.id] = nil  end
+    if not next (self.active) then  self:closeall ()  end
+
+    -- was the query on the wanted list?
+    local q = response.question[1]
+    if q then
+      local cos = get (self.wanted, q.class, q.type, q.name)
+      if cos then
+        for co in pairs (cos) do
+          set (self.yielded, co, q.class, q.type, q.name, nil)
+          if coroutine.status(co) == "suspended" then coroutine.resume (co)  end
+        end
+        set (self.wanted, q.class, q.type, q.name, nil)
+      end
+    end
+  end 
+
+  return response
+end
+
+function resolver:cancel(data)
+       local cos = get (self.wanted, unpack(data, 1, 3))
+       if cos then
+               cos[data[4]] = nil;
+       end
+end
+
 function resolver:pulse ()    -- - - - - - - - - - - - - - - - - - - - -  pulse
 
 function resolver:pulse ()    -- - - - - - - - - - - - - - - - - - - - -  pulse
 
-  -- print ':pulse'
-  while self:receive () do end
+  --print ':pulse'
+  while self:receive() do end
   if not next (self.active) then  return nil  end
 
   self.time = socket.gettime ()
   if not next (self.active) then  return nil  end
 
   self.time = socket.gettime ()
@@ -685,13 +736,14 @@ function resolver:pulse ()    -- - - - - - - - - - - - - - - - - - - - -  pulse
           end
 
         if o.delay > #self.delays then
           end
 
         if o.delay > #self.delays then
-          print ('timeout')
+          --print ('timeout')
           queries[question] = nil
           if not next (queries) then  self.active[id] = nil  end
           if not next (self.active) then  return nil  end
         else
           queries[question] = nil
           if not next (queries) then  self.active[id] = nil  end
           if not next (self.active) then  return nil  end
         else
-          -- print ('retry', o.server, o.delay)
-          self.socket[o.server]:send (o.packet)
+          --print ('retry', o.server, o.delay)
+          local _a = self.socket[o.server];
+          if _a then _a:send (o.packet) end
           o.retry = self.time + self.delays[o.delay]
           end  end  end  end
 
           o.retry = self.time + self.delays[o.delay]
           end  end  end  end
 
@@ -703,12 +755,16 @@ function resolver:pulse ()    -- - - - - - - - - - - - - - - - - - - - -  pulse
 function resolver:lookup (qname, qtype, qclass)    -- - - - - - - - - -  lookup
   self:query (qname, qtype, qclass)
   while self:pulse () do  socket.select (self.socket, nil, 4)  end
 function resolver:lookup (qname, qtype, qclass)    -- - - - - - - - - -  lookup
   self:query (qname, qtype, qclass)
   while self:pulse () do  socket.select (self.socket, nil, 4)  end
-  -- print (self.cache)
+  --print (self.cache)
   return self:peek (qname, qtype, qclass)
   end
 
   return self:peek (qname, qtype, qclass)
   end
 
+function resolver:lookupex (handler, qname, qtype, qclass)    -- - - - - - - - - -  lookup
+  return self:peek (qname, qtype, qclass) or self:query (qname, qtype, qclass)
+  end
+
 
 
--- print ---------------------------------------------------------------- print
+--print ---------------------------------------------------------------- print
 
 
 local hints = {    -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
 
 
 local hints = {    -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
@@ -755,7 +811,7 @@ function resolver.print (response)    -- - - - - - - - - - - - - resolver.print
       for j,t in pairs (rr) do
         if not common[j] then
           tmp = string.format ('%s[%i].%s', s, i, j)
       for j,t in pairs (rr) do
         if not common[j] then
           tmp = string.format ('%s[%i].%s', s, i, j)
-          print (string.format ('%-30s  %s', tmp, t))
+          print (string.format ('%-30s  %s', tostring(tmp), tostring(t)))
           end  end  end  end  end
 
 
           end  end  end  end  end
 
 
@@ -794,6 +850,11 @@ function dns.peek (...)    -- - - - - - - - - - - - - - - - - - - - - - -  peek
 function dns.query (...)    -- - - - - - - - - - - - - - - - - - - - - -  query
   return resolve (resolver.query, ...)  end
 
 function dns.query (...)    -- - - - - - - - - - - - - - - - - - - - - -  query
   return resolve (resolver.query, ...)  end
 
+function dns.feed (...)    -- - - - - - - - - - - - - - - - - - - - - -  feed
+  return resolve (resolver.feed, ...)  end
+
+function dns.cancel(...)   -- - - - - - - - - - - - - - - - - - - - - -  cancel
+  return resolve(resolver.cancel, ...) end
 
 function dns:socket_wrapper_set (...)    -- - - - - - - - -  socket_wrapper_set
   return resolve (resolver.socket_wrapper_set, ...)  end
 
 function dns:socket_wrapper_set (...)    -- - - - - - - - -  socket_wrapper_set
   return resolve (resolver.socket_wrapper_set, ...)  end