Allow ampersands in passwords for SASL PLAIN mechanism.
[prosody.git] / util / sasl.lua
index 772e2dd5e9dbfa27b09b60b7b5c76530594e20e3..e7d90704342ffd0f2fad51947803675a0836ae64 100644 (file)
@@ -81,6 +81,7 @@ end
 -- create a new SASL object which can be used to authenticate clients
 function new(realm, profile)
        sasl_i = {profile = profile};
+       sasl_i.realm = realm;
        return setmetatable(sasl_i, method);
 end
 
@@ -92,7 +93,7 @@ function method:mechanisms()
                if backend_mechanism[backend] then
                        for _, mechanism in ipairs(backend_mechanism[backend]) do
                                mechanisms[mechanism] = true;
-                               end
+                       end
                end
        end
        self["possible_mechanisms"] = mechanisms;
@@ -101,38 +102,50 @@ end
 
 -- select a mechanism to use
 function method:select(mechanism)
-
+       self.mech_i = mechanisms[mechanism]
+       if self.mech_i == nil then 
+               return false;
+       end
+       return true;
 end
 
 -- feed new messages to process into the library
 function method:process(message)
-
+       if message == "" or message == nil then return "failure", "malformed-request" end
+       return self.mech_i(self, message);
 end
 
 --=========================
 --SASL PLAIN
-local function sasl_mechanism_plain(realm, credentials_handler)
-       local object = { mechanism = "PLAIN", realm = realm, credentials_handler = credentials_handler}
-       function object.feed(self, message)
-               if message == "" or message == nil then return "failure", "malformed-request" end
-               local response = message
-               local authorization = s_match(response, "([^&%z]+)")
-               local authentication = s_match(response, "%z([^&%z]+)%z")
-               local password = s_match(response, "%z[^&%z]+%z([^&%z]+)")
-
-               if authentication == nil or password == nil then return "failure", "malformed-request" end
-               self.username = authentication
-               local auth_success = self.credentials_handler("PLAIN", self.username, self.realm, password)
-
-               if auth_success then
-                       return "success"
-               elseif auth_success == nil then
-                       return "failure", "account-disabled"
-               else
-                       return "failure", "not-authorized"
-               end
+local function sasl_mechanism_plain(self, message)
+       local response = message
+       local authorization = s_match(response, "([^%z]+)")
+       local authentication = s_match(response, "%z([^%z]+)%z")
+       local password = s_match(response, "%z[^%z]+%z([^%z]+)")
+
+       if authentication == nil or password == nil then
+               return "failure", "malformed-request";
+       end
+
+       local correct, state = false, false;
+       if self.profile.plain then
+               local correct_password;
+               correct_password, state = self.profile.plain(authentication, self.realm);
+               if correct_password == password then correct = true; else correct = false; end
+       elseif self.profile.plain_test then
+               correct, state = self.profile.plain_test(authentication, self.realm, password);
+       end
+
+       self.username = authentication
+       if not state then
+               return "failure", "account-disabled";
+       end
+
+       if correct then
+               return "success";
+       else
+               return "failure", "not-authorized";
        end
-       return object
 end
 registerMechanism("PLAIN", {"plain", "plain_test"}, sasl_mechanism_plain);