Merge Tobias SCRAM-PLUS work
[prosody.git] / util / sasl / scram.lua
index 31c078a0513a3fbb4880d8f460a4c757640aea52..cf938dba0888b5e995d1418f726860ba974f7fed 100644 (file)
@@ -14,6 +14,7 @@
 local s_match = string.match;
 local type = type
 local string = string
+local tostring = tostring;
 local base64 = require "util.encodings".base64;
 local hmac_sha1 = require "util.hashes".hmac_sha1;
 local sha1 = require "util.hashes".sha1;
@@ -39,6 +40,10 @@ scram_{MECH}:
        function(username, realm)
                return stored_key, server_key, iteration_count, salt, state;
        end
+
+Supported Channel Binding Backends
+
+'tls-unique' according to RFC 5929
 ]]
 
 local default_i = 4096
@@ -108,6 +113,8 @@ end
 local function scram_gen(hash_name, H_f, HMAC_f)
        local function scram_hash(self, message)
                if not self.state then self["state"] = {} end
+               local support_channel_binding = false;
+               if self.profile.cb then support_channel_binding = true; end
 
                if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
                if not self.state.name then
@@ -116,12 +123,29 @@ local function scram_gen(hash_name, H_f, HMAC_f)
 
                        -- TODO: fail if authzid is provided, since we don't support them yet
                        self.state["client_first_message"] = client_first_message;
-                       self.state["gs2_cbind_flag"], self.state["authzid"], self.state["name"], self.state["clientnonce"]
-                               = client_first_message:match("^(%a),(.*),n=(.*),r=([^,]*).*");
+                       self.state["gs2_cbind_flag"], self.state["gs2_cbind_name"], self.state["authzid"], self.state["name"], self.state["clientnonce"]
+                               = client_first_message:match("^(%a)=?([%a%-]*),(.*),n=(.*),r=([^,]*).*");
+
+                       -- check for invalid gs2_flag_type start
+                       local gs2_flag_type = string.sub(self.state.gs2_cbind_flag, 0, 1)
+                       if gs2_flag_type ~=  "y" and gs2_flag_type ~=  "n" and gs2_flag_type ~=  "p" then
+                               return "failure", "malformed-request", "The GS2 header has to start with 'y', 'n', or 'p'."
+                       end
 
-                       -- we don't do any channel binding yet
-                       if self.state.gs2_cbind_flag ~= "n" and self.state.gs2_cbind_flag ~= "y" then
-                               return "failure", "malformed-request";
+                       if support_channel_binding then
+                               if string.sub(self.state.gs2_cbind_flag, 0, 1) == "y" then
+                                       return "failure", "malformed-request";
+                               end
+                               
+                               -- check whether we support the proposed channel binding type
+                               if not self.profile.cb[self.state.gs2_cbind_name] then
+                                       return "failure", "malformed-request", "Proposed channel binding type isn't supported.";
+                               end
+                       else
+                               -- we don't support channelbinding, 
+                               if self.state.gs2_cbind_flag ~= "n" and self.state.gs2_cbind_flag ~= "y" then
+                                       return "failure", "malformed-request";
+                               end
                        end
 
                        if not self.state.name or not self.state.clientnonce then
@@ -181,6 +205,16 @@ local function scram_gen(hash_name, H_f, HMAC_f)
                                return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
                        end
 
+                       if self.state.gs2_cbind_name then
+                               -- we support channelbinding, so check if the value is valid
+                               local client_gs2_header = base64.decode(self.state.channelbinding)
+                               local our_client_gs2_header = "p="..self.state.gs2_cbind_name..","..self.state["authzid"]..","..self.profile.cb[self.state.gs2_cbind_name](self);
+
+                               if client_gs2_header ~= our_client_gs2_header then
+                                       return "failure", "malformed-request", "Invalid channel binding value.";
+                               end
+                       end
+
                        if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then
                                return "failure", "malformed-request", "Wrong nonce in client-final-message.";
                        end
@@ -208,6 +242,9 @@ end
 function init(registerMechanism)
        local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
                registerMechanism("SCRAM-"..hash_name, {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash));
+               
+               -- register channel binding equivalent
+               registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
        end
 
        registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);