Merge 0.9->0.10
[prosody.git] / util / sasl / scram.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2010 Tobias Markmann
3 --
4 --        All rights reserved.
5 --
6 --        Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --                * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --                * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --                * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14 local s_match = string.match;
15 local type = type
16 local base64 = require "util.encodings".base64;
17 local hmac_sha1 = require "util.hashes".hmac_sha1;
18 local sha1 = require "util.hashes".sha1;
19 local Hi = require "util.hashes".scram_Hi_sha1;
20 local generate_uuid = require "util.uuid".generate;
21 local saslprep = require "util.encodings".stringprep.saslprep;
22 local nodeprep = require "util.encodings".stringprep.nodeprep;
23 local log = require "util.logger".init("sasl");
24 local t_concat = table.concat;
25 local char = string.char;
26 local byte = string.byte;
27
28 local _ENV = nil;
29
30 --=========================
31 --SASL SCRAM-SHA-1 according to RFC 5802
32
33 --[[
34 Supported Authentication Backends
35
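scram_{MECH}:
	-- MECH being a standard hash name (like those in IANA's hash registry), lowercased, with '-' replaced by '_' (e.g. scram_sha_1)
	function(sasl_handler, username, realm)
		return stored_key, server_key, iteration_count, salt, status;
	end
	-- status: true = OK, false = account disabled (fails with account-disabled), nil = fails with not-authorized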
41
42 Supported Channel Binding Backends
43
44 'tls-unique' according to RFC 5929
45 ]]
46
-- Default PBKDF2 iteration count used when keys are derived from a plaintext
-- password; 4096 is the minimum suggested by RFC 5802.
local default_i = 4096

-- Nibble XOR lookup table: xor_map[a * 16 + b + 1] == a XOR b for 0 <= a, b <= 15.
-- Built with plain arithmetic so it stays portable to Lua versions without
-- bitwise operators (5.1 / 5.2).
local xor_map = {};
for a = 0, 15 do
	for b = 0, 15 do
		local r, bit, x, y = 0, 1, a, b;
		for _ = 1, 4 do
			-- bits differ -> set this bit of the result
			if x % 2 ~= y % 2 then r = r + bit; end
			x, y, bit = (x - x % 2) / 2, (y - y % 2) / 2, bit * 2;
		end
		xor_map[a * 16 + b + 1] = r;
	end
end

-- Scratch buffer reused across calls to binaryXOR.
local result = {};

--- XOR two binary strings byte by byte.
-- @param a first operand; its length determines the length of the result
-- @param b second operand; must be at least #a bytes long
-- @return string of #a bytes, or nil if b is shorter than a
local function binaryXOR( a, b )
	if #b < #a then
		-- byte(b, i) would be nil past the end of b
		return nil;
	end
	for i=1, #a do
		local x, y = byte(a, i), byte(b, i);
		local lowx, lowy = x % 16, y % 16;
		local hix, hiy = (x - lowx) / 16, (y - lowy) / 16;
		local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1];
		result[i] = char(hir * 16 + lowr);
	end
	-- Only the first #a slots belong to this call; slots beyond that may still
	-- hold bytes from an earlier, longer call.
	return t_concat(result, "", 1, #a);
end
63
--- Decode and normalise the escaped "n=" username of a client-first-message.
-- Only the "=2C" and "=3D" escapes are accepted; they are decoded to "," and
-- "=" before SASLprep and then nodeprep (or the supplied _nodeprep) run.
-- @param username escaped username as sent by the client
-- @param _nodeprep nil to use nodeprep, false to skip it, or a replacement prep function
-- @return the prepared username, or a falsy value if it is invalid or empty
local function validate_username(username, _nodeprep)
	for escape in username:gmatch("=(.?.?)") do
		local allowed = escape == "2C" or escape == "3D";
		if not allowed then
			return false
		end
	end

	-- "=2C" is decoded before "=3D", so "=3D2C" correctly becomes "=2C"
	local unescaped = username:gsub("=2C", ","):gsub("=3D", "=");

	local prepped = saslprep(unescaped);
	if prepped and _nodeprep ~= false then
		local prep = _nodeprep or nodeprep;
		prepped = prep(prepped);
	end

	if prepped and #prepped > 0 then
		return prepped;
	end
	-- nil/false from a prep step is passed through; an empty name yields false
	return prepped and false;
end
85
--- Turn a hash name into the suffix used in credential backend names,
-- e.g. "SHA-1" -> "sha_1" (as in the "scram_sha_1" profile key).
-- @param hashname hash name such as "SHA-1"
-- @return the lowercased name with every '-' replaced by '_' (exactly one value;
--   gsub's replacement count is deliberately dropped)
local function hashprep(hashname)
	local prepped = hashname:lower():gsub("%-", "_");
	return prepped;
end
89
--- Derive the SCRAM-SHA-1 StoredKey and ServerKey for a password (RFC 5802).
-- @param password SASLprepped password string
-- @param salt salt string
-- @param iteration_count number of Hi() iterations
-- @return true, stored_key, server_key on success;
--   false, error message if an argument has the wrong type
local function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
	local args_ok = type(password) == "string"
		and type(salt) == "string"
		and type(iteration_count) == "number";
	if not args_ok then
		return false, "inappropriate argument types"
	end

	if 4096 > iteration_count then
		log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
	end

	local salted = Hi(password, salt, iteration_count);
	local client_key = hmac_sha1(salted, "Client Key");
	local stored = sha1(client_key);
	local server = hmac_sha1(salted, "Server Key");
	return true, stored, server
end
102
--- Build the SASL step function for one SCRAM-<hash_name> mechanism.
-- @param hash_name hash name such as "sha-1"; selects the "scram_<hash>"
--   credential backend via hashprep
-- @param H_f hash function H() of RFC 5802 (raw digest)
-- @param HMAC_f HMAC(key, message) of RFC 5802 (raw digest)
-- @return function(self, message) implementing the server side of SCRAM
-- NOTE(review): the "plain" backend path derives keys with the SHA-1 specific
-- getAuthenticationDatabaseSHA1, so it is only correct for SCRAM-SHA-1.
local function scram_gen(hash_name, H_f, HMAC_f)
	local profile_name = "scram_" .. hashprep(hash_name);

	-- Step function: first called with the client-first-message (self.state
	-- unset), then with the client-final-message.
	-- Returns "challenge", data | "success", data | "failure", condition[, text].
	local function scram_hash(self, message)
		local support_channel_binding = false;
		if self.profile.cb then support_channel_binding = true; end

		if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
		local state = self.state;
		if not state then
			-- we are processing client_first_message
			local client_first_message = message;

			-- TODO: fail if authzid is provided, since we don't support them yet
			local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, client_first_message_bare, username, clientnonce
				= s_match(client_first_message, "^(([pny])=?([^,]*),([^,]*),)(m?=?[^,]*,?n=([^,]*),r=([^,]*),?.*)$");

			if not gs2_cbind_flag then
				return "failure", "malformed-request";
			end

			if support_channel_binding and gs2_cbind_flag == "y" then
				-- "y" -> client does support channel binding
				--        but thinks the server does not.
				return "failure", "malformed-request";
			end

			if gs2_cbind_flag == "n" then
				-- "n" -> client doesn't support channel binding.
				support_channel_binding = false;
			end

			if support_channel_binding and gs2_cbind_flag == "p" then
				-- check whether we support the proposed channel binding type
				if not self.profile.cb[gs2_cbind_name] then
					return "failure", "malformed-request", "Proposed channel binding type isn't supported.";
				end
			else
				-- no channel binding
				gs2_cbind_name = nil;
			end

			username = validate_username(username, self.profile.nodeprep);
			if not username then
				log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
				return "failure", "malformed-request", "Invalid username.";
			end

			-- retrieve credentials
			local stored_key, server_key, salt, iteration_count;
			if self.profile.plain then
				local password, status = self.profile.plain(self, username, self.realm)
				if status == nil then return "failure", "not-authorized"
				elseif status == false then return "failure", "account-disabled" end

				password = saslprep(password);
				if not password then
					log("debug", "Password violates SASLprep.");
					return "failure", "not-authorized", "Invalid password."
				end

				salt = generate_uuid();
				iteration_count = default_i;

				local succ;
				succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count);
				if not succ then
					log("error", "Generating authentication database failed. Reason: %s", stored_key);
					return "failure", "temporary-auth-failure";
				end
			elseif self.profile[profile_name] then
				local status;
				stored_key, server_key, iteration_count, salt, status = self.profile[profile_name](self, username, self.realm);
				if status == nil then return "failure", "not-authorized"
				elseif status == false then return "failure", "account-disabled" end
			else
				-- No credential backend: fail cleanly instead of passing a nil
				-- salt to base64.encode below.
				log("error", "No 'plain' or '%s' credential backend available.", profile_name);
				return "failure", "temporary-auth-failure";
			end

			local nonce = clientnonce .. generate_uuid();
			local server_first_message = "r="..nonce..",s="..base64.encode(salt)..",i="..iteration_count;
			self.state = {
				gs2_header = gs2_header;
				gs2_cbind_name = gs2_cbind_name;
				username = username;
				nonce = nonce;

				server_key = server_key;
				stored_key = stored_key;
				client_first_message_bare = client_first_message_bare;
				server_first_message = server_first_message;
			}
			return "challenge", server_first_message
		else
			-- we are processing client_final_message
			local client_final_message = message;

			local client_final_message_without_proof, channelbinding, nonce, proof
				= s_match(client_final_message, "(c=([^,]*),r=([^,]*),?.-),p=(.*)$");

			if not proof or not nonce or not channelbinding then
				return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
			end

			local client_gs2_header = base64.decode(channelbinding)
			local our_client_gs2_header = state["gs2_header"]
			if state.gs2_cbind_name then
				-- we support channelbinding, so check if the value is valid
				our_client_gs2_header = our_client_gs2_header .. self.profile.cb[state.gs2_cbind_name](self);
			end
			if client_gs2_header ~= our_client_gs2_header then
				return "failure", "malformed-request", "Invalid channel binding value.";
			end

			if nonce ~= state.nonce then
				return "failure", "malformed-request", "Wrong nonce in client-final-message.";
			end

			local ServerKey = state.server_key;
			local StoredKey = state.stored_key;

			local AuthMessage = state.client_first_message_bare .. "," .. state.server_first_message .. "," .. client_final_message_without_proof
			local ClientSignature = HMAC_f(StoredKey, AuthMessage)
			local ClientProof = base64.decode(proof)
			-- The proof is client-controlled. RFC 5802 makes it exactly as long as
			-- the signature; a shorter one would make binaryXOR raise a Lua error.
			if not ClientProof or #ClientProof ~= #ClientSignature then
				return "failure", "malformed-request", "Invalid client proof length.";
			end
			local ClientKey = binaryXOR(ClientSignature, ClientProof)
			local ServerSignature = HMAC_f(ServerKey, AuthMessage)

			if StoredKey == H_f(ClientKey) then
				local server_final_message = "v="..base64.encode(ServerSignature);
				self["username"] = state.username;
				return "success", server_final_message;
			else
				return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";
			end
		end
	end
	return scram_hash;
end
237
--- Register this module's SCRAM mechanisms with the SASL framework.
-- @param registerMechanism registration callback supplied by the caller
local function init(registerMechanism)
	-- Registers SCRAM-<hash_name> and its channel-binding variant
	-- SCRAM-<hash_name>-PLUS; each accepts the "plain" or "scram_<hash>" backend.
	local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
		local mechanism = "SCRAM-" .. hash_name;
		local backend = "scram_" .. (hashprep(hash_name));
		local lowered = hash_name:lower();

		registerMechanism(mechanism, {"plain", backend}, scram_gen(lowered, hash, hmac_hash));

		-- "-PLUS" variant additionally offers tls-unique channel binding
		registerMechanism(mechanism .. "-PLUS", {"plain", backend}, scram_gen(lowered, hash, hmac_hash), {"tls-unique"});
	end

	registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
end
248
-- Public interface of this module.
local scram = {
	getAuthenticationDatabaseSHA1 = getAuthenticationDatabaseSHA1;
	init = init;
};

return scram;