Merge 0.9->0.10
[prosody.git] / util / sasl / scram.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2010 Tobias Markmann
3 --
4 --        All rights reserved.
5 --
6 --        Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --                * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --                * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --                * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14 local s_match = string.match;
15 local type = type
16 local tostring = tostring;
17 local base64 = require "util.encodings".base64;
18 local hmac_sha1 = require "util.hashes".hmac_sha1;
19 local sha1 = require "util.hashes".sha1;
20 local Hi = require "util.hashes".scram_Hi_sha1;
21 local generate_uuid = require "util.uuid".generate;
22 local saslprep = require "util.encodings".stringprep.saslprep;
23 local nodeprep = require "util.encodings".stringprep.nodeprep;
24 local log = require "util.logger".init("sasl");
25 local t_concat = table.concat;
26 local char = string.char;
27 local byte = string.byte;
28
-- Lua 5.1 module(): global function definitions below (getAuthenticationDatabaseSHA1,
-- init) become fields of this module's table, which is returned at the end as _M.
module "sasl.scram"

--=========================
--SASL SCRAM-SHA-1 according to RFC 5802

--[[
Supported Authentication Backends (looked up in the SASL profile; 'plain' is tried first)

plain:
	function(self, username, realm)
		return password, state;
	end
	-- the plaintext password is SASLprep'd, then salted and hashed on every
	-- authentication attempt (random salt, default_i iterations)

scram_{MECH}:
	-- MECH being a standard hash name (like those at IANA's hash registry),
	-- lower-cased and with '-' replaced with '_' (e.g. scram_sha_1)
	function(self, username, realm)
		return stored_key, server_key, iteration_count, salt, state;
	end

For both backends, state == nil fails with "not-authorized" (unknown user) and
state == false fails with "account-disabled".

Supported Channel Binding Backends (profile.cb[name](self) returns the binding data)

'tls-unique' according to RFC 5929
]]

-- Iteration count used when SCRAM credentials are derived on the fly from a
-- plaintext password; 4096 is the minimum suggested by RFC 5802.
local default_i = 4096
49
-- Debug helper: render a binary string as backslash-prefixed decimal byte
-- values, e.g. "AB" -> "\65\66". Not referenced elsewhere in this file.
local function bp( b )
	-- gsub visits each byte once, avoiding the quadratic cost of growing a
	-- string with repeated `..` inside a loop. Parentheses drop gsub's count.
	return (b:gsub(".", function(c) return "\\"..c:byte() end))
end
57
-- xor_map[hi * 16 + lo + 1] == hi XOR lo for 4-bit values hi, lo in 0..15.
-- A lookup table is used because Lua 5.1 has no bitwise operators.
local xor_map = {
	0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;
	1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;
	2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;
	3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;
	4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;
	5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;
	6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;
	7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;
	8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;
	9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;
	10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;
	11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;
	12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;
	13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;
	14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;
	15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;
};

-- XOR two binary strings byte by byte.
-- Returns a string of #a bytes. b must be at least #a bytes long: a shorter b
-- makes byte() return nil and the arithmetic below raise an error.
local function binaryXOR( a, b )
	-- Fresh buffer per call: a shared module-level buffer keeps bytes from an
	-- earlier, longer call, and t_concat would then append them to the result.
	local result = {};
	for i = 1, #a do
		local x, y = byte(a, i), byte(b, i);
		-- split each byte into high and low nibbles and XOR them via xor_map
		local lowx, lowy = x % 16, y % 16;
		local hix, hiy = (x - lowx) / 16, (y - lowy) / 16;
		local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1];
		result[i] = char(hir * 16 + lowr);
	end
	return t_concat(result);
end
72
-- Decode and normalise a SCRAM username (the "n=" attribute of client-first-message).
--
-- username:  the raw, still-escaped username from the client.
-- _nodeprep: false to skip nodeprep, a function to use instead of nodeprep,
--            or nil to apply the default nodeprep.
-- Returns the prepared username; false if it contains an '=' escape other
-- than "=2C"/"=3D" or ends up empty; nil if SASLprep or nodeprep rejects it.
local function validate_username(username, _nodeprep)
	-- Only "=2C" (',') and "=3D" ('=') are legal escape sequences.
	for escape in username:gmatch("=(.?.?)") do
		if not (escape == "2C" or escape == "3D") then
			return false
		end
	end

	-- Decode "=2C" before "=3D" so that "=3D2C" becomes "=2C", not ",".
	local decoded = (username:gsub("=2C", ",")):gsub("=3D", "=");

	local prepped = saslprep(decoded);
	if prepped and _nodeprep ~= false then
		local prep_fn = _nodeprep or nodeprep;
		prepped = prep_fn(prepped);
	end

	return prepped and #prepped > 0 and prepped;
end
94
-- Map a hash name to the suffix used for credential backends in the SASL
-- profile: lower-cased, with '-' replaced by '_' (e.g. "SHA-1" -> "sha_1").
local function hashprep(hashname)
	-- '%-' matches a literal dash explicitly; the parentheses make the function
	-- return only the string, not gsub's substitution count.
	return (hashname:lower():gsub("%-", "_"));
end
98
-- Derive SCRAM-SHA-1 credentials from a plaintext password (RFC 5802):
--   SaltedPassword = Hi(password, salt, iteration_count)
--   StoredKey      = SHA1(HMAC(SaltedPassword, "Client Key"))
--   ServerKey      = HMAC(SaltedPassword, "Server Key")
--
-- password, salt:  strings; iteration_count: number.
-- Returns true, stored_key, server_key on success, or
-- false, "inappropriate argument types" when an argument has the wrong type.
-- Logs a warning (but still proceeds) when iteration_count < 4096.
function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
	local args_ok = type(password) == "string"
		and type(salt) == "string"
		and type(iteration_count) == "number";
	if not args_ok then
		return false, "inappropriate argument types"
	end

	if iteration_count < 4096 then
		log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
	end

	local salted = Hi(password, salt, iteration_count);
	local client_key = hmac_sha1(salted, "Client Key");
	local stored_key = sha1(client_key);
	local server_key = hmac_sha1(salted, "Server Key");
	return true, stored_key, server_key
end
111
-- Build the SASL step function for one SCRAM variant (RFC 5802).
--
-- hash_name: lower-case hash name (e.g. "sha-1"); hashprep(hash_name) selects
--            the "scram_<hash>" credential backend in the SASL profile.
-- H_f:       hash function H(), used to check StoredKey == H(ClientKey).
-- HMAC_f:    HMAC(key, message), used for the client and server signatures.
--
-- Returns scram_hash(self, message). It is called once per client message and
-- keeps per-exchange data in self.state. Return values are:
--   "challenge", server-first-message
--   "success",   server-final-message  (and sets self.username)
--   "failure",   condition[, text]
local function scram_gen(hash_name, H_f, HMAC_f)
	local function scram_hash(self, message)
		if not self.state then self["state"] = {} end
		local support_channel_binding = false;
		if self.profile.cb then support_channel_binding = true; end

		if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
		if not self.state.name then
			-- we are processing client_first_message
			local client_first_message = message;

			-- TODO: fail if authzid is provided, since we don't support them yet
			self.state["client_first_message"] = client_first_message;
			self.state["gs2_cbind_flag"], self.state["gs2_cbind_name"], self.state["authzid"], self.state["name"], self.state["clientnonce"]
				= client_first_message:match("^([ynp])=?([%a%-]*),(.*),n=(.*),r=([^,]*).*");

			local gs2_cbind_flag = self.state.gs2_cbind_flag;

			if not gs2_cbind_flag then
				return "failure", "malformed-request";
			end

			if support_channel_binding and gs2_cbind_flag == "y" then
				-- "y" -> client does support channel binding
				--        but thinks the server does not.
				return "failure", "malformed-request";
			end

			if gs2_cbind_flag == "n" then
				-- "n" -> client doesn't support channel binding.
				support_channel_binding = false;
			end

			if support_channel_binding and gs2_cbind_flag == "p" then
				-- check whether we support the proposed channel binding type
				if not self.profile.cb[self.state.gs2_cbind_name] then
					return "failure", "malformed-request", "Proposed channel binding type isn't supported.";
				end
			else
				-- no channel binding,
				self.state.gs2_cbind_name = nil;
			end

			if not self.state.name or not self.state.clientnonce then
				return "failure", "malformed-request", "Missing username or nonce in SASL message.";
			end

			self.state.name = validate_username(self.state.name, self.profile.nodeprep);
			if not self.state.name then
				log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
				return "failure", "malformed-request", "Invalid username.";
			end

			self.state["servernonce"] = generate_uuid();

			-- retrieve credentials
			if self.profile.plain then
				local password, state = self.profile.plain(self, self.state.name, self.realm)
				if state == nil then return "failure", "not-authorized"
				elseif state == false then return "failure", "account-disabled" end

				password = saslprep(password);
				if not password then
					log("debug", "Password violates SASLprep.");
					return "failure", "not-authorized", "Invalid password."
				end

				self.state.salt = generate_uuid();
				self.state.iteration_count = default_i;

				-- NOTE(review): this always derives SHA-1 keys, regardless of
				-- H_f/HMAC_f; correct only while SHA-1 is the sole registered hash.
				local succ = false;
				succ, self.state.stored_key, self.state.server_key = getAuthenticationDatabaseSHA1(password, self.state.salt, self.state.iteration_count);
				if not succ then
					-- on failure the second return value is the error reason
					log("error", "Generating authentication database failed. Reason: %s", self.state.stored_key);
					return "failure", "temporary-auth-failure";
				end
			elseif self.profile["scram_"..hashprep(hash_name)] then
				local stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, self.state.name, self.realm);
				if state == nil then return "failure", "not-authorized"
				elseif state == false then return "failure", "account-disabled" end

				self.state.stored_key = stored_key;
				self.state.server_key = server_key;
				self.state.iteration_count = iteration_count;
				self.state.salt = salt
			else
				-- No credential backend: without this branch the
				-- server-first-message below would fail on a nil salt.
				log("error", "No 'plain' or 'scram_%s' credential backend available.", (hashprep(hash_name)));
				return "failure", "temporary-auth-failure";
			end

			local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count;
			self.state["server_first_message"] = server_first_message;
			return "challenge", server_first_message
		else
			-- we are processing client_final_message
			local client_final_message = message;

			self.state["channelbinding"], self.state["nonce"], self.state["proof"] = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)");

			if not self.state.proof or not self.state.nonce or not self.state.channelbinding then
				return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
			end

			if self.state.gs2_cbind_name then
				-- we support channelbinding, so check if the value is valid
				local client_gs2_header = base64.decode(self.state.channelbinding)
				local our_client_gs2_header = "p="..self.state.gs2_cbind_name..","..self.state["authzid"]..","..self.profile.cb[self.state.gs2_cbind_name](self);

				if client_gs2_header ~= our_client_gs2_header then
					return "failure", "malformed-request", "Invalid channel binding value.";
				end
			end

			if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then
				return "failure", "malformed-request", "Wrong nonce in client-final-message.";
			end

			-- Everything in this message is client-controlled. Validate it before use,
			-- so malformed input yields a SASL failure instead of a Lua runtime error.
			-- The parse above accepts e.g. ",xp=..." but the AuthMessage needs a literal ",p=".
			local client_final_message_without_proof = s_match(client_final_message, "(.+),p=.+");
			local ClientProof = base64.decode(self.state.proof);
			if not client_final_message_without_proof or not ClientProof then
				return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
			end

			local ServerKey = self.state.server_key;
			local StoredKey = self.state.stored_key;

			local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. client_final_message_without_proof
			local ClientSignature = HMAC_f(StoredKey, AuthMessage)
			-- binaryXOR needs the proof to be as long as the signature
			if #ClientProof ~= #ClientSignature then
				return "failure", "malformed-request", "Invalid client proof length.";
			end
			local ClientKey = binaryXOR(ClientSignature, ClientProof)
			local ServerSignature = HMAC_f(ServerKey, AuthMessage)

			if StoredKey == H_f(ClientKey) then
				local server_final_message = "v="..base64.encode(ServerSignature);
				self["username"] = self.state.name;
				return "success", server_final_message;
			else
				return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";
			end
		end
	end
	return scram_hash;
end
245
-- Register the SCRAM mechanisms with the SASL framework.
--
-- registerMechanism(name, backends, step_fn[, channel_binding_types]) is
-- supplied by the caller. For each hash this registers "SCRAM-<HASH>" and its
-- channel-binding twin "SCRAM-<HASH>-PLUS" (tls-unique). Both accept the
-- "plain" or "scram_<hash>" credential backends.
function init(registerMechanism)
	local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
		local mechanism = "SCRAM-"..hash_name;
		local backend = "scram_"..(hashprep(hash_name));
		local lowered = hash_name:lower();

		registerMechanism(mechanism, {"plain", backend}, scram_gen(lowered, hash, hmac_hash));

		-- register channel binding equivalent
		registerMechanism(mechanism.."-PLUS", {"plain", backend}, scram_gen(lowered, hash, hmac_hash), {"tls-unique"});
	end

	registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
end
256
-- _M is the module table created by the module "sasl.scram" call near the top.
return _M;