Merge 0.7->trunk
[prosody.git] / util / sasl_cyrus.lua
index 8023d12135a5496429f83a96f7fbc15bf18d97dc..b5b0e08d0a5adf1662075bf024fc93e85ced0c6a 100644 (file)
@@ -19,38 +19,57 @@ local tostring = tostring;
 local pairs, ipairs = pairs, ipairs;
 local t_insert, t_concat = table.insert, table.concat;
 local s_match = string.match;
+local setmetatable = setmetatable
 
 local keys = keys;
 
 local print = print
+local pcall = pcall
+local s_match, s_gmatch = string.match, string.gmatch
 
 module "sasl_cyrus"
 
 local method = {};
 method.__index = method;
-local mechanisms = {};
-local backend_mechanism = {};
-
--- register a new SASL mechanims
-local function registerMechanism(name, backends, f)
-       assert(type(name) == "string", "Parameter name MUST be a string.");
-       assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table.");
-       assert(type(f) == "function", "Parameter f MUST be a function.");
-       mechanisms[name] = f
-       for _, backend_name in ipairs(backends) do
-               if backend_mechanism[backend_name] == nil then backend_mechanism[backend_name] = {}; end
-               t_insert(backend_mechanism[backend_name], name);
+local initialized = false;
+
+local function init(service_name)
+       if not initialized then
+               local st, errmsg = pcall(cyrussasl.server_init, service_name);
+               if st then
+                       initialized = true;
+               else
+                       log("error", "Failed to initialize Cyrus SASL: %s", errmsg);
+               end
        end
 end
 
 -- create a new SASL object which can be used to authenticate clients
 function new(realm, service_name)
        local sasl_i = {};
+
+       init(service_name);
+
        sasl_i.realm = realm;
        sasl_i.service_name = service_name;
-       sasl_i.cyrus = cyrussasl.server_new(service_name, realm, realm, nil, nil)
-       if sasl_i.cyrus ~= 0, 
-                  "got NULL return value from server_new")
+
+       local st, ret = pcall(cyrussasl.server_new, service_name, nil, realm, nil, nil)
+       if st then
+               sasl_i.cyrus = ret;
+       else
+               log("error", "Creating SASL server connection failed: %s", ret);
+               return nil;
+       end
+
+       if cyrussasl.set_canon_cb then
+               local c14n_cb = function (user)
+                       local node = s_match(user, "^([^@]+)");
+                       log("debug", "Canonicalizing username %s to %s", user, node)
+                       return node
+               end
+               cyrussasl.set_canon_cb(sasl_i.cyrus, c14n_cb);
+       end
+
        cyrussasl.setssf(sasl_i.cyrus, 0, 0xffffffff)
        local s = setmetatable(sasl_i, method);
        return s;
@@ -63,35 +82,37 @@ end
 
 -- set the forbidden mechanisms
 function method:forbidden( restrict )
-       log("debug", "Called method:forbidden. NOT IMPLEMENTED.")
+       log("warn", "Called method:forbidden. NOT IMPLEMENTED.")
        return {}
 end
 
 -- get a list of possible SASL mechanims to use
 function method:mechanisms()
        local mechanisms = {}
-       local cyrus_mechs = cyrussasl.listmech(self.cyrus)
-       for w in s_gmatch(cyrus_mechs, "%a+") do
+       local cyrus_mechs = cyrussasl.listmech(self.cyrus, nil, "", " ", "")
+       for w in s_gmatch(cyrus_mechs, "[^ ]+") do
                mechanisms[w] = true;
        end
-       self.mechanisms = mechanisms
+       self.mechs = mechanisms
        return array.collect(keys(mechanisms));
 end
 
 -- select a mechanism to use
 function method:select(mechanism)
        self.mechanism = mechanism;
-       return not self.mechanisms[mechanisms];
+       if not self.mechs then self:mechanisms(); end
+       return self.mechs[mechanism];
 end
 
 -- feed new messages to process into the library
 function method:process(message)
        local err;
        local data;
+
        if self.mechanism then
-               err, data = cyrussasl.server_start(self.cyrus, self.mechanism, message)
+               err, data = cyrussasl.server_start(self.cyrus, self.mechanism, message or "")
        else
-               err, data = cyrussasl.server_step(self.cyrus, message)
+               err, data = cyrussasl.server_step(self.cyrus, message or "")
        end
 
        self.username = cyrussasl.get_username(self.cyrus)
@@ -102,16 +123,12 @@ function method:process(message)
           return "challenge", data
        elseif (err == -4) then -- SASL_NOMECH
           log("debug", "SASL mechanism not available from remote end")
-          return "failure", 
-            "undefined-condition",
-            "SASL mechanism not available"
+          return "failure", "invalid-mechanism", "SASL mechanism not available"
        elseif (err == -13) then -- SASL_BADAUTH
-          return "failure", "not-authorized"
+          return "failure", "not-authorized", cyrussasl.get_message( self.cyrus )
        else
           log("debug", "Got SASL error condition %d", err)
-          return "failure", 
-            "undefined-condition",
-            cyrussasl.get_message( self.cyrus )
+          return "failure", "undefined-condition", cyrussasl.get_message( self.cyrus )
        end
 end