net.dns: Ensure all pending requests get notified of a timeout when looking up a...
[prosody.git] / util / sasl_cyrus.lua
index 7d35b5e4866ababb2bbd88fba6c1cbc64959bbdb..196845876c84141b3c8a10a402ee5a3fae424559 100644 (file)
 
 local cyrussasl = require "cyrussasl";
 local log = require "util.logger".init("sasl_cyrus");
-local array = require "util.array";
 
-local tostring = tostring;
-local pairs, ipairs = pairs, ipairs;
-local t_insert, t_concat = table.insert, table.concat;
-local s_match = string.match;
 local setmetatable = setmetatable
 
-local keys = keys;
-
-local print = print
 local pcall = pcall
 local s_match, s_gmatch = string.match, string.gmatch
 
@@ -86,22 +78,22 @@ local function init(service_name)
 end
 
 -- create a new SASL object which can be used to authenticate clients
-function new(realm, service_name, app_name)
-       local sasl_i = {};
+-- host_fqdn may be nil in which case gethostname() gives the value. 
+--      For GSSAPI, this determines the hostname in the service ticket (after
+--      reverse DNS canonicalization, only if [libdefaults] rdns = true which
+--      is the default).  
+function new(realm, service_name, app_name, host_fqdn)
 
        init(app_name or service_name);
 
-       sasl_i.realm = realm;
-       sasl_i.service_name = service_name;
-
-       local st, ret = pcall(cyrussasl.server_new, service_name, nil, realm, nil, nil)
-       if st then
-               sasl_i.cyrus = ret;
-       else
+       local st, ret = pcall(cyrussasl.server_new, service_name, host_fqdn, realm, nil, nil)
+       if not st then
                log("error", "Creating SASL server connection failed: %s", ret);
                return nil;
        end
 
+       local sasl_i = { realm = realm, service_name = service_name, cyrus = ret };
+
        if cyrussasl.set_canon_cb then
                local c14n_cb = function (user)
                        local node = s_match(user, "^([^@]+)");
@@ -112,37 +104,31 @@ function new(realm, service_name, app_name)
        end
 
        cyrussasl.setssf(sasl_i.cyrus, 0, 0xffffffff)
-       local s = setmetatable(sasl_i, method);
-       return s;
+       local mechanisms = {};
+       local cyrus_mechs = cyrussasl.listmech(sasl_i.cyrus, nil, "", " ", "");
+       for w in s_gmatch(cyrus_mechs, "[^ ]+") do
+               mechanisms[w] = true;
+       end
+       sasl_i.mechs = mechanisms;
+       return setmetatable(sasl_i, method);
 end
 
--- get a fresh clone with the same realm, profiles and forbidden mechanisms
+-- get a fresh clone with the same realm and service name
 function method:clean_clone()
        return new(self.realm, self.service_name)
 end
 
--- set the forbidden mechanisms
-function method:forbidden( restrict )
-       log("warn", "Called method:forbidden. NOT IMPLEMENTED.")
-       return {}
-end
-
 -- get a list of possible SASL mechanims to use
 function method:mechanisms()
-       local mechanisms = {}
-       local cyrus_mechs = cyrussasl.listmech(self.cyrus, nil, "", " ", "")
-       for w in s_gmatch(cyrus_mechs, "[^ ]+") do
-               mechanisms[w] = true;
-       end
-       self.mechs = mechanisms
-       return array.collect(keys(mechanisms));
+       return self.mechs;
 end
 
 -- select a mechanism to use
 function method:select(mechanism)
-       self.mechanism = mechanism;
-       if not self.mechs then self:mechanisms(); end
-       return self.mechs[mechanism];
+       if not self.selected and self.mechs[mechanism] then
+               self.selected = mechanism;
+               return true;
+       end
 end
 
 -- feed new messages to process into the library
@@ -150,8 +136,9 @@ function method:process(message)
        local err;
        local data;
 
-       if self.mechanism then
-               err, data = cyrussasl.server_start(self.cyrus, self.mechanism, message or "")
+       if not self.first_step_done then
+               err, data = cyrussasl.server_start(self.cyrus, self.selected, message or "")
+               self.first_step_done = true;
        else
                err, data = cyrussasl.server_step(self.cyrus, message or "")
        end
@@ -159,17 +146,20 @@ function method:process(message)
        self.username = cyrussasl.get_username(self.cyrus)
 
        if (err == 0) then -- SASL_OK
-          return "success", data
+               if self.require_provisioning and not self.require_provisioning(self.username) then
+                       return "failure", "not-authorized", "User authenticated successfully, but not provisioned for XMPP";
+               end
+               return "success", data
        elseif (err == 1) then -- SASL_CONTINUE
-          return "challenge", data
+               return "challenge", data
        elseif (err == -4) then -- SASL_NOMECH
-          log("debug", "SASL mechanism not available from remote end")
-          return "failure", "invalid-mechanism", "SASL mechanism not available"
+               log("debug", "SASL mechanism not available from remote end")
+               return "failure", "invalid-mechanism", "SASL mechanism not available"
        elseif (err == -13) then -- SASL_BADAUTH
-          return "failure", "not-authorized", sasl_errstring[err];
+               return "failure", "not-authorized", sasl_errstring[err];
        else
-          log("debug", "Got SASL error condition %d: %s", err, sasl_errstring[err]);
-          return "failure", "undefined-condition", sasl_errstring[err];
+               log("debug", "Got SASL error condition %d: %s", err, sasl_errstring[err]);
+               return "failure", "undefined-condition", sasl_errstring[err];
        end
 end