Merge Tobias SCRAM-PLUS work
authorKim Alvefur <zash@zash.se>
Sat, 21 Sep 2013 22:44:20 +0000 (00:44 +0200)
committerKim Alvefur <zash@zash.se>
Sat, 21 Sep 2013 22:44:20 +0000 (00:44 +0200)
1  2 
plugins/mod_saslauth.lua
plugins/storage/ejabberdstore.lib.lua
util/sasl.lua
util/sasl/scram.lua

index 1bf6fb96f1bd2332333ab4afa1d3367f1ccf04d3,422bc18760195bc1b70537040ae8b0a56a3e1677..f24eacf87b0925e2cd80c2394eee1df3a00dafe5
@@@ -241,14 -245,24 +241,24 @@@ module:hook("stream-features", function
                if secure_auth_only and not origin.secure then
                        return;
                end
 -              origin.sasl_handler = usermanager_get_sasl_handler(module.host);
 +              origin.sasl_handler = usermanager_get_sasl_handler(module.host, origin);
+               if origin.secure then
+                       -- check wether LuaSec has the nifty binding to the function needed for tls-unique
+                       -- FIXME: would be nice to have this check only once and not for every socket
+                       if origin.conn:socket().getpeerfinished then
+                               origin.sasl_handler:add_cb_handler("tls-unique", function(self)
+                                       return self.userdata:getpeerfinished();
+                               end);
+                               origin.sasl_handler["userdata"] = origin.conn:socket();
+                       end
+               end
 -              features:tag("mechanisms", mechanisms_attr);
 +              local mechanisms = st.stanza("mechanisms", mechanisms_attr);
                for mechanism in pairs(origin.sasl_handler:mechanisms()) do
                        if mechanism ~= "PLAIN" or origin.secure or allow_unencrypted_plain_auth then
 -                              features:tag("mechanism"):text(mechanism):up();
 +                              mechanisms:tag("mechanism"):text(mechanism):up();
                        end
                end
 -              features:up();
 +              if mechanisms[1] then features:add_child(mechanisms); end
        else
                features:tag("bind", bind_attr):tag("required"):up():up();
                features:tag("session", xmpp_session_attr):tag("optional"):up():up();
diff --cc plugins/storage/ejabberdstore.lib.lua
index 7e8592a8fbc2a1f99ae8308f647eb8c8e8ca8b4d,7e8592a8fbc2a1f99ae8308f647eb8c8e8ca8b4d..0000000000000000000000000000000000000000
deleted file mode 100644,100644
+++ /dev/null
@@@ -1,190 -1,190 +1,0 @@@
--\r
--local handlers = {};\r
--\r
--handlers.accounts = {\r
--      get = function(self, user)\r
--              local select = self:query("select password from users where username=?", user);\r
--              local row = select and select:fetch();\r
--              if row then return { password = row[1] }; end\r
--      end;\r
--      set = function(self, user, data)\r
--              if data and data.password then\r
--                      return self:modify("update users set password=? where username=?", data.password, user)\r
--                              or self:modify("insert into users (username, password) values (?, ?)", user, data.password);\r
--              else\r
--                      return self:modify("delete from users where username=?", user);\r
--              end\r
--      end;\r
--};\r
--handlers.vcard = {\r
--      get = function(self, user)\r
--              local select = self:query("select vcard from vcard where username=?", user);\r
--              local row = select and select:fetch();\r
--              if row then return parse_xml(row[1]); end\r
--      end;\r
--      set = function(self, user, data)\r
--              if data then\r
--                      data = unparse_xml(data);\r
--                      return self:modify("update vcard set vcard=? where username=?", data, user)\r
--                              or self:modify("insert into vcard (username, vcard) values (?, ?)", user, data);\r
--              else\r
--                      return self:modify("delete from vcard where username=?", user);\r
--              end\r
--      end;\r
--};\r
--handlers.private = {\r
--      get = function(self, user)\r
--              local select = self:query("select namespace,data from private_storage where username=?", user);\r
--              if select then\r
--                      local data = {};\r
--                      for row in select:rows() do\r
--                              data[row[1]] = parse_xml(row[2]);\r
--                      end\r
--                      return data;\r
--              end\r
--      end;\r
--      set = function(self, user, data)\r
--              if data then\r
--                      self:modify("delete from private_storage where username=?", user);\r
--                      for namespace,text in pairs(data) do\r
--                              self:modify("insert into private_storage (username, namespace, data) values (?, ?, ?)", user, namespace, unparse_xml(text));\r
--                      end\r
--                      return true;\r
--              else\r
--                      return self:modify("delete from private_storage where username=?", user);\r
--              end\r
--      end;\r
--      -- TODO map_set, map_get\r
--};\r
--local subscription_map = { N = "none", B = "both", F = "from", T = "to" };\r
--local subscription_map_reverse = { none = "N", both = "B", from = "F", to = "T" };\r
--handlers.roster = {\r
--      get = function(self, user)\r
--              local select = self:query("select jid,nick,subscription,ask,server,subscribe,type from rosterusers where username=?", user);\r
--              if select then\r
--                      local roster = { pending = {} };\r
--                      for row in select:rows() do\r
--                              local jid,nick,subscription,ask,server,subscribe,typ = unpack(row);\r
--                              local item = { groups = {} };\r
--                              if nick == "" then nick = nil; end\r
--                              item.nick = nick;\r
--                              item.subscription = subscription_map[subscription];\r
--                              if ask == "N" then ask = nil;\r
--                              elseif ask == "O" then ask = "subscribe"\r
--                              elseif ask == "I" then roster.pending[jid] = true; ask = nil;\r
--                              elseif ask == "B" then roster.pending[jid] = true; ask = "subscribe";\r
--                              else module:log("debug", "bad roster_item.ask: %s", ask); ask = nil; end\r
--                              item.ask = ask;\r
--                              roster[jid] = item;\r
--                      end\r
--                      \r
--                      select = self:query("select jid,grp from rostergroups where username=?", user);\r
--                      if select then\r
--                              for row in select:rows() do\r
--                                      local jid,grp = unpack(rows);\r
--                                      if roster[jid] then roster[jid].groups[grp] = true; end\r
--                              end\r
--                      end\r
--                      select = self:query("select version from roster_version where username=?", user);\r
--                      local row = select and select:fetch();\r
--                      if row then\r
--                              roster[false] = { version = row[1]; };\r
--                      end\r
--                      return roster;\r
--              end\r
--      end;\r
--      set = function(self, user, data)\r
--              if data and next(data) ~= nil then\r
--                      self:modify("delete from rosterusers where username=?", user);\r
--                      self:modify("delete from rostergroups where username=?", user);\r
--                      self:modify("delete from roster_version where username=?", user);\r
--                      local done = {};\r
--                      local pending = data.pending or {};\r
--                      for jid,item in pairs(data) do\r
--                              if jid and jid ~= "pending" then\r
--                                      local subscription = subscription_map_reverse[item.subscription];\r
--                                      local ask;\r
--                                      if pending[jid] then\r
--                                              if item.ask then ask = "B"; else ask = "I"; end\r
--                                      else\r
--                                              if item.ask then ask = "O"; else ask = "N"; end\r
--                                      end\r
--                                      local r = self:modify("insert into rosterusers (username,jid,nick,subscription,ask,askmessage,server,subscribe) values (?, ?, ?, ?, ?, '', '', '')", user, jid, item.nick or "", subscription, ask);\r
--                                      if not r then module:log("debug", "--- :( %s", tostring(r)); end\r
--                                      done[jid] = true;\r
--                                      for group in pairs(item.groups) do\r
--                                              self:modify("insert into rostergroups (username,jid,grp) values (?, ?, ?)", user, jid, group);\r
--                                      end\r
--                              end\r
--                      end\r
--                      for jid in pairs(pending) do\r
--                              if not done[jid] then\r
--                                      self:modify("insert into rosterusers (username,jid,nick,subscription,ask,askmessage,server,subscribe) values (?, ?, ?, ?, ?. ''. ''. '')", user, jid, "", "N", "I");\r
--                              end\r
--                      end\r
--                      local version = data[false] and data[false].version;\r
--                      if version then\r
--                              self:modify("insert into roster_version (username,version) values (?, ?)", user, version);\r
--                      end\r
--                      return true;\r
--              else\r
--                      self:modify("delete from rosterusers where username=?", user);\r
--                      self:modify("delete from rostergroups where username=?", user);\r
--                      self:modify("delete from roster_version where username=?", user);\r
--              end\r
--      end;\r
--};\r
--\r
-------------------------------\r
--local driver = {};\r
--driver.__index = driver;\r
--\r
--function driver:prepare(sql)\r
--      module:log("debug", "query: %s", sql);\r
--      local err;\r
--      if not self.sqlcache then self.sqlcache = {}; end\r
--      local r = self.sqlcache[sql];\r
--      if r then return r; end\r
--      r, err = self.database:prepare(sql);\r
--      if not r then error("Unable to prepare SQL statement: "..err); end\r
--      self.sqlcache[sql] = r;\r
--      return r;\r
--end\r
--\r
--function driver:query(sql, ...)\r
--      local stmt = self:prepare(sql);\r
--      if stmt:execute(...) then return stmt; end\r
--end\r
--function driver:modify(sql, ...)\r
--      local stmt = self:query(sql, ...);\r
--      if stmt and stmt:affected() > 0 then return stmt; end\r
--end\r
--\r
--function driver:open(host, datastore, typ)\r
--      local cache_key = host.." "..datastore;\r
--      if self.ds_cache[cache_key] then return self.ds_cache[cache_key]; end\r
--      local instance = setmetatable({}, self);\r
--      instance.host = host;\r
--      instance.datastore = datastore;\r
--      local handler = handlers[datastore];\r
--      if not handler then return nil; end\r
--      for key,val in pairs(handler) do\r
--              instance[key] = val;\r
--      end\r
--      if instance.init then instance:init(); end\r
--      self.ds_cache[cache_key] = instance;\r
--      return instance;\r
--end\r
--\r
-------------------------------\r
--local _M = {};\r
--\r
--function _M.new(dbtype, dbname, ...)\r
--      local instance = setmetatable({}, driver);\r
--      instance.__index = instance;\r
--      instance.database = get_database(dbtype, dbname, ...);\r
--      instance.ds_cache = {};\r
--      return instance;\r
--end\r
--\r
--return _M;\r
diff --cc util/sasl.lua
Simple merge
index 31c078a0513a3fbb4880d8f460a4c757640aea52,071de505142886ba2b4864c9314518f102c5aef9..cf938dba0888b5e995d1418f726860ba974f7fed
  local s_match = string.match;
  local type = type
  local string = string
+ local tostring = tostring;
  local base64 = require "util.encodings".base64;
 -local hmac_sha1 = require "util.hmac".sha1;
 +local hmac_sha1 = require "util.hashes".hmac_sha1;
  local sha1 = require "util.hashes".sha1;
 +local Hi = require "util.hashes".scram_Hi_sha1;
  local generate_uuid = require "util.uuid".generate;
  local saslprep = require "util.encodings".stringprep.saslprep;
 +local nodeprep = require "util.encodings".stringprep.nodeprep;
  local log = require "util.logger".init("sasl");
  local t_concat = table.concat;
  local char = string.char;
@@@ -108,20 -118,39 +113,39 @@@ en
  local function scram_gen(hash_name, H_f, HMAC_f)
        local function scram_hash(self, message)
                if not self.state then self["state"] = {} end
 -              
+               local support_channel_binding = false;
+               if self.profile.cb then support_channel_binding = true; end
 +
                if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
                if not self.state.name then
                        -- we are processing client_first_message
                        local client_first_message = message;
 -                      log("debug", client_first_message);
 +
                        -- TODO: fail if authzid is provided, since we don't support them yet
                        self.state["client_first_message"] = client_first_message;
-                       self.state["gs2_cbind_flag"], self.state["authzid"], self.state["name"], self.state["clientnonce"]
-                               = client_first_message:match("^(%a),(.*),n=(.*),r=([^,]*).*");
+                       self.state["gs2_cbind_flag"], self.state["gs2_cbind_name"], self.state["authzid"], self.state["name"], self.state["clientnonce"]
+                               = client_first_message:match("^(%a)=?([%a%-]*),(.*),n=(.*),r=([^,]*).*");
+                       -- check for invalid gs2_flag_type start
+                       local gs2_flag_type = string.sub(self.state.gs2_cbind_flag, 0, 1)
+                       if gs2_flag_type ~=  "y" and gs2_flag_type ~=  "n" and gs2_flag_type ~=  "p" then
+                               return "failure", "malformed-request", "The GS2 header has to start with 'y', 'n', or 'p'."
+                       end
  
-                       -- we don't do any channel binding yet
-                       if self.state.gs2_cbind_flag ~= "n" and self.state.gs2_cbind_flag ~= "y" then
-                               return "failure", "malformed-request";
+                       if support_channel_binding then
+                               if string.sub(self.state.gs2_cbind_flag, 0, 1) == "y" then
+                                       return "failure", "malformed-request";
+                               end
+                               
+                               -- check whether we support the proposed channel binding type
+                               if not self.profile.cb[self.state.gs2_cbind_name] then
+                                       return "failure", "malformed-request", "Proposed channel binding type isn't supported.";
+                               end
+                       else
+                               -- we don't support channelbinding, 
+                               if self.state.gs2_cbind_flag ~= "n" and self.state.gs2_cbind_flag ~= "y" then
+                                       return "failure", "malformed-request";
+                               end
                        end
  
                        if not self.state.name or not self.state.clientnonce then