util.sasl_cyrus: Protect the call to cyrussasl.server_new properly.
[prosody.git] / util / sasl_cyrus.lua
index 610257ca124f8252b6955f0e815b699764ec166f..f20aff51a8965d340efdacc1fa6f9fe3bea245ed 100644 (file)
@@ -35,8 +35,11 @@ local initialized = false;
 
 local function init(service_name)
        if not initialized then
-               if pcall(cyrussasl.server_init, service_name) then
+               local st, errmsg = pcall(cyrussasl.server_init, service_name);
+               if st then
                        initialized = true;
+               else
+                       log("error", "Failed to initialize CyrusSASL: %s", errmsg);
                end
        end
 end
@@ -49,11 +52,24 @@ function new(realm, service_name)
 
        sasl_i.realm = realm;
        sasl_i.service_name = service_name;
-       sasl_i.cyrus = cyrussasl.server_new(service_name, nil, nil, nil, nil)
-       if sasl_i.cyrus == 0 then
-               log("error", "got NULL return value from server_new")
+
+       local st, ret = pcall(cyrussasl.server_new, service_name, nil, realm, nil, nil)
+       if st then
+               sasl_i.cyrus = ret;
+       else
+               log("error", "server_new failed: %s", ret);
                return nil;
        end
+
+       if cyrussasl.set_canon_cb then
+               local c14n_cb = function (user)
+                       local node = s_match(user, "^([^@]+)");
+                       log("debug", "Canonicalizing username %s to %s", user, node)
+                       return node
+               end
+               cyrussasl.set_canon_cb(sasl_i.cyrus, c14n_cb);
+       end
+
        cyrussasl.setssf(sasl_i.cyrus, 0, 0xffffffff)
        local s = setmetatable(sasl_i, method);
        return s;
@@ -84,6 +100,7 @@ end
 -- select a mechanism to use
 function method:select(mechanism)
        self.mechanism = mechanism;
+       if not self.mechs then self:mechanisms(); end
        return self.mechs[mechanism];
 end
 
@@ -110,7 +127,7 @@ function method:process(message)
             "undefined-condition",
             "SASL mechanism not available"
        elseif (err == -13) then -- SASL_BADAUTH
-          return "failure", "not-authorized"
+          return "failure", "not-authorized", cyrussasl.get_message( self.cyrus )
        else
           log("debug", "Got SASL error condition %d", err)
           return "failure",