Add check for forbidden char sequences in validate_username().
[prosody.git] / util / sasl / scram.lua
index b7507f3e14bd276ca6ce902149f6abdc5cf6d25a..c3bc9600d465dddd6084e99d972c2898f726e8a4 100644 (file)
@@ -17,6 +17,7 @@ local base64 = require "util.encodings".base64;
 local xor = require "bit".bxor
 local hmac_sha1 = require "util.hmac".sha1;
 local sha1 = require "util.hashes".sha1;
+local generate_uuid = require "util.uuid".generate;
 
 module "plain"
 
@@ -59,7 +60,8 @@ end
 
 local function validate_username(username)
        -- check for forbidden char sequences
-       
+       for eq in s:gmatch("=(.?.?)") do
+               if eq ~= "2D" and eq ~= "3D" then return false end end return true;
        -- replace =2D with , and =3D with =
        
        -- apply SASLprep
@@ -71,15 +73,52 @@ local function scram_sha_1(self, message)
        
        if not self.state.name then
                -- we are processing client_first_message
-               self.state["name"] = string.match(client_first_message, "n=(.+),r=")
-               self.state["clientnonce"] = string.match(client_first_message, "r=([^,]+)")
+               local client_first_message = message;
+               self.state["name"] = client_first_message:match("n=(.+),r=")
+               self.state["clientnonce"] = client_first_message:match("r=([^,]+)")
                
                self.state.name = validate_username(self.state.name);
-               if not self.state.name then
+               if not self.state.name or not self.state.clientnonce then
                        return "failure", "malformed-request";
                end
+               self.state["servernonce"] = generate_uuid();
+               self.state["salt"] = generate_uuid();
+               
+               local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..default_i;
+               return "challenge", server_first_message
        else
                -- we are processing client_final_message
+               local client_final_message = message;
+               
+               self.state["proof"] = client_final_message:match("p=(.+)");
+               self.state["nonce"] = client_final_message:match("r=(.+),p=");
+               self.state["channelbinding"] = client_final_message:match("c=(.+),r=");
+               if not self.state.proof or not self.state.nonce or not self.state.channelbinding then
+                       return "failure", "malformed-request";
+               end
+               
+               local password;
+               if self.profile.plain then
+                       password, state = self.profile.plain(self.state.name, self.realm)
+                       if state == nil then return "failure", "not-authorized"
+                       elseif state == false then return "failure", "account-disabled" end
+               end
+               
+               local SaltedPassword = Hi(hmac_sha1, password, self.state.salt, default_i)
+               local ClientKey = hmac_sha1(SaltedPassword, "Client Key")
+               local ServerKey = hmac_sha1(SaltedPassword, "Server Key")
+               local StoredKey = sha1(ClientKey)
+               local AuthMessage = "n=" .. s_match(client_first_message,"n=(.+)") .. "," .. server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+")
+               local ClientSignature = hmac_sha1(StoredKey, AuthMessage)
+               local ClientProof     = binaryXOR(ClientKey, ClientSignature)
+               local ServerSignature = hmac_sha1(ServerKey, AuthMessage)
+               
+               if base64.encode(ClientProof) == self.state.proof then
+                       local server_final_message = "v="..base64.encode(ServerSignature);
+                       return "success", server_final_message;
+               else
+                       return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";
+               end
        end
 end