Only advertise mechanisms needing channel binding if a channel binding backend is...
[prosody.git] / util / sasl / scram.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2010 Tobias Markmann
3 --
4 --        All rights reserved.
5 --
6 --        Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --                * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --                * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --                * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14 local s_match = string.match;
15 local type = type
16 local string = string
17 local tostring = tostring;
18 local base64 = require "util.encodings".base64;
19 local hmac_sha1 = require "util.hmac".sha1;
20 local sha1 = require "util.hashes".sha1;
21 local generate_uuid = require "util.uuid".generate;
22 local saslprep = require "util.encodings".stringprep.saslprep;
23 local log = require "util.logger".init("sasl");
24 local t_concat = table.concat;
25 local char = string.char;
26 local byte = string.byte;
27
28 module "scram"
29
30 --=========================
31 --SASL SCRAM-SHA-1 according to RFC 5802
32
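--[[
Supported Authentication Backends

plain:
	function(sasl, username, realm)
		return password, state;
	end

scram_{MECH}:
	-- MECH being a standard hash name (like those at IANA's hash registry), lower-cased and with '-' replaced with '_'
	function(sasl, username, realm)
		return stored_key, server_key, iteration_count, salt, state;
	end

For both backends the first argument is the SASL session object. A nil state makes
authentication fail with "not-authorized", a false state with "account-disabled".

Supported Channel Binding Backends

'tls-unique' according to RFC 5929
]]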
46
-- Iteration count used when salting a password obtained from the "plain" backend;
-- 4096 is the minimum RFC 5802 suggests.
local default_i = 4096;
48
-- Render a binary string as a sequence of "\<decimal byte value>" escapes,
-- e.g. bp("AB") == "\\65\\66". Pieces are collected in a buffer and joined once,
-- instead of re-concatenating the growing result on every byte (O(n^2)).
local function bp( b )
	local escaped = {};
	for i = 1, #b do
		escaped[i] = "\\"..byte(b, i);
	end
	return t_concat(escaped);
end
56
-- Lookup table for 4-bit XOR: xor_map[x * 16 + y + 1] == x XOR y for 0 <= x, y <= 15.
-- Lua 5.1 has no bitwise operators, so bytes are XORed one nibble at a time.
local xor_map = {
	0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;
	1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;
	2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;
	3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;
	4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;
	5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;
	6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;
	7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;
	8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;
	9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;
	10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;
	11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;
	12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;
	13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;
	14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;
	15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;
};

-- Scratch buffer shared by every binaryXOR call.
local result = {};

-- XOR two binary strings byte by byte.
-- Returns a string of exactly #a bytes. b must be at least #a bytes long (only its
-- first #a bytes are used); a shorter b makes the arithmetic below fail on nil.
local function binaryXOR( a, b )
	local n = #a;
	for i = 1, n do
		local x, y = byte(a, i), byte(b, i);
		local lowx, lowy = x % 16, y % 16;
		local hix, hiy = (x - lowx) / 16, (y - lowy) / 16;
		local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1];
		result[i] = char(hir * 16 + lowr);
	end
	-- Join only the first n entries: the shared buffer may still hold bytes
	-- left over from an earlier call with a longer input.
	return t_concat(result, "", 1, n);
end
71
-- Hash-independent Hi() from RFC 5802, i.e. PBKDF2 producing a single output block.
--   hmac: function(key, data) returning the raw MAC
--   str:  key for every HMAC invocation (the password, at the call sites here)
--   salt: salt string, concatenated with INT(1) for the first round
--   i:    iteration count
-- Returns U1 XOR U2 XOR ... XOR Ui, where U1 = hmac(str, salt .. INT(1))
-- and each later Un = hmac(str, U(n-1)).
function Hi(hmac, str, salt, i)
	local u = hmac(str, salt .. "\0\0\0\1");
	local acc = u;
	for _ = 1, i - 1 do
		u = hmac(str, u);
		acc = binaryXOR(acc, u);
	end
	return acc;
end
83
-- Decode and normalise the username from a client-first-message (RFC 5802 section 5.1).
-- On the wire ',' is sent as "=2C" and '=' as "=3D"; any other '=' sequence is forbidden.
-- Returns the decoded, SASLprep'ed username; false if a forbidden '=' sequence is
-- present; otherwise whatever saslprep() returns when it rejects the name.
local function validate_username(username)
	-- check for forbidden char sequences
	for eq in username:gmatch("=(.?.?)") do
		if eq ~= "2C" and eq ~= "3D" then
			return false
		end
	end

	-- replace =2C with , and =3D with =
	-- (=2C is decoded first so that "=3D2C" yields "=2C" instead of being decoded twice to ",")
	username = username:gsub("=2C", ",");
	username = username:gsub("=3D", "=");

	-- apply SASLprep
	username = saslprep(username);
	return username;
end
100
-- Turn a hash name such as "SHA-1" into the backend-name form "sha_1"
-- (lower-cased, '-' replaced with '_'). The extra parentheses drop gsub's second
-- return value (the substitution count) so callers receive exactly one string.
local function hashprep(hashname)
	return (hashname:lower():gsub("-", "_"));
end
104
-- Derive SCRAM-SHA-1 credentials from a password.
--   password, salt: strings (salt is handed to Hi() unchanged)
--   iteration_count: number; a warning is logged when it is below 4096
-- Returns true, StoredKey, ServerKey on success, or false plus an error string
-- when an argument has the wrong type.
function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
	local well_typed = type(password) == "string"
		and type(salt) == "string"
		and type(iteration_count) == "number";
	if not well_typed then
		return false, "inappropriate argument types"
	end

	if iteration_count < 4096 then
		log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
	end

	-- SaltedPassword := Hi(password, salt, i); the keys are then derived from it.
	local salted = Hi(hmac_sha1, password, salt, iteration_count);
	local client_key = hmac_sha1(salted, "Client Key");
	local stored_key = sha1(client_key);
	local server_key = hmac_sha1(salted, "Server Key");
	return true, stored_key, server_key
end
117
-- Build the SASL step function for one SCRAM variant (RFC 5802).
--   hash_name: lower-case hash name such as "sha-1"; hashprep() turns it into the
--              name of the "scram_<hash>" credential backend.
--   H_f:       the hash function H().
--   HMAC_f:    the matching HMAC(key, data) function.
-- The returned function(self, message) handles the client-first-message on its first
-- call (returning "challenge", server-first-message) and the client-final-message on
-- the next (returning "success", server-final-message). On error it returns
-- "failure", condition[, text].
local function scram_gen(hash_name, H_f, HMAC_f)
	local function scram_hash(self, message)
		if not self.state then self["state"] = {} end
		-- channel binding is only possible when the profile supplies cb backends
		local support_channel_binding = false;
		if self.profile.cb then support_channel_binding = true; end

		if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
		if not self.state.name then
			-- we are processing client_first_message
			local client_first_message = message;
			-- client data goes in as an argument, never as the format string
			log("debug", "client_first_message: %s", client_first_message);
			-- TODO: fail if authzid is provided, since we don't support them yet
			self.state["client_first_message"] = client_first_message;
			self.state["gs2_cbind_flag"], self.state["gs2_cbind_name"], self.state["authzid"], self.state["name"], self.state["clientnonce"]
				= client_first_message:match("^(%a)=?([%a%-]*),(.*),n=(.*),r=([^,]*).*");

			-- a failed match leaves every field nil; stop before anything indexes them
			if not self.state.gs2_cbind_flag or not self.state.name or not self.state.clientnonce then
				return "failure", "malformed-request", "Missing an attribute(n or r) in SASL message.";
			end

			-- the pattern captures exactly one letter as the flag
			local gs2_cbind_flag = self.state.gs2_cbind_flag;

			-- check for invalid gs2_flag_type start
			if gs2_cbind_flag ~= "y" and gs2_cbind_flag ~= "n" and gs2_cbind_flag ~= "p" then
				return "failure", "malformed-request", "The GS2 header has to start with 'y', 'n', or 'p'."
			end

			if gs2_cbind_flag == "p" then
				-- the client requires channel binding
				if not support_channel_binding then
					return "failure", "malformed-request";
				end
				-- check whether we support the proposed channel binding type
				if not self.profile.cb[self.state.gs2_cbind_name] then
					return "failure", "malformed-request", "Proposed channel binding type isn't supported.";
				end
			else
				-- "y": the client supports channel binding but believes we don't.
				-- RFC 5802 requires failing if we actually do support it.
				if gs2_cbind_flag == "y" and support_channel_binding then
					return "failure", "malformed-request";
				end
				-- No channel binding in use ("n" is always acceptable). The pattern
				-- captured "" as the name, which is truthy, so clear it to stop the
				-- final step from attempting a channel binding check.
				self.state.gs2_cbind_name = nil;
			end

			self.state.name = validate_username(self.state.name);
			if not self.state.name then
				log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
				return "failure", "malformed-request", "Invalid username.";
			end

			self.state["servernonce"] = generate_uuid();

			-- retrieve credentials
			local scram_backend = "scram_"..hashprep(hash_name);
			if self.profile.plain then
				local password, state = self.profile.plain(self, self.state.name, self.realm)
				if state == nil then return "failure", "not-authorized"
				elseif state == false then return "failure", "account-disabled" end

				password = saslprep(password);
				if not password then
					log("debug", "Password violates SASLprep.");
					return "failure", "not-authorized", "Invalid password."
				end

				self.state.salt = generate_uuid();
				self.state.iteration_count = default_i;

				-- derive the keys with this mechanism's own H/HMAC instead of always SHA-1
				local salted_password = Hi(HMAC_f, password, self.state.salt, self.state.iteration_count);
				self.state.stored_key = H_f(HMAC_f(salted_password, "Client Key"));
				self.state.server_key = HMAC_f(salted_password, "Server Key");
			elseif self.profile[scram_backend] then
				local stored_key, server_key, iteration_count, salt, state = self.profile[scram_backend](self, self.state.name, self.realm);
				if state == nil then return "failure", "not-authorized"
				elseif state == false then return "failure", "account-disabled" end

				self.state.stored_key = stored_key;
				self.state.server_key = server_key;
				self.state.iteration_count = iteration_count;
				self.state.salt = salt
			else
				-- without either backend there is no salt or key to build a challenge from
				log("error", "No 'plain' or '%s' authentication backend available.", scram_backend);
				return "failure", "temporary-auth-failure";
			end

			local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count;
			self.state["server_first_message"] = server_first_message;
			return "challenge", server_first_message
		else
			-- we are processing client_final_message
			local client_final_message = message;
			log("debug", "client_final_message: %s", client_final_message);
			self.state["channelbinding"], self.state["nonce"], self.state["proof"] = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)");

			if not self.state.proof or not self.state.nonce or not self.state.channelbinding then
				return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
			end

			if self.state.gs2_cbind_name then
				-- channel binding was negotiated, so check if the value is valid
				local client_gs2_header = base64.decode(self.state.channelbinding)
				local our_client_gs2_header = "p="..self.state.gs2_cbind_name..","..self.state["authzid"]..","..self.profile.cb[self.state.gs2_cbind_name](self);

				if client_gs2_header ~= our_client_gs2_header then
					return "failure", "malformed-request", "Invalid channel binding value.";
				end
			end

			if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then
				return "failure", "malformed-request", "Wrong nonce in client-final-message.";
			end

			-- client-final-message-without-proof; nil when the proof is empty or is not
			-- introduced by ",p=" (the attribute match above is looser than this)
			local client_final_message_without_proof = s_match(client_final_message, "(.+),p=.+");
			if not client_final_message_without_proof then
				return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
			end

			local ServerKey = self.state.server_key;
			local StoredKey = self.state.stored_key;

			local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. client_final_message_without_proof
			local ClientSignature = HMAC_f(StoredKey, AuthMessage)

			-- The proof comes from the client. Unless it decodes to exactly one HMAC
			-- output, binaryXOR would read past its end (error) or ignore extra bytes.
			local ClientProof = base64.decode(self.state.proof)
			if not ClientProof or #ClientProof ~= #ClientSignature then
				return "failure", "malformed-request", "Invalid client proof.";
			end
			local ClientKey = binaryXOR(ClientSignature, ClientProof)
			local ServerSignature = HMAC_f(ServerKey, AuthMessage)

			if StoredKey == H_f(ClientKey) then
				local server_final_message = "v="..base64.encode(ServerSignature);
				self["username"] = self.state.name;
				return "success", server_final_message;
			else
				return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";
			end
		end
	end
	return scram_hash;
end
246
-- Module entry point: registers SCRAM-SHA-1 and its SCRAM-SHA-1-PLUS variant.
-- registerMechanism is supplied by the caller and is invoked here as
-- registerMechanism(mechanism_name, backend_names, step_function[, extra]), where
-- extra (only given for -PLUS) is presumably the list of channel binding types it needs.
function init(registerMechanism)
	local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
		local mech = "SCRAM-"..hash_name;
		local backend = "scram_"..(hashprep(hash_name));
		local lowered = hash_name:lower();

		registerMechanism(mech, {"plain", backend}, scram_gen(lowered, hash, hmac_hash));

		-- register channel binding equivalent, tied to "tls-unique"
		registerMechanism(mech.."-PLUS", {"plain", backend}, scram_gen(lowered, hash, hmac_hash), {"tls-unique"});
	end

	registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
end

return _M;