util.sasl.scram: Create the state table as late as possible, keep state in locals...
[prosody.git] / util / sasl / scram.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2010 Tobias Markmann
3 --
4 --        All rights reserved.
5 --
6 --        Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --                * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --                * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --                * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14 local s_match = string.match;
15 local type = type
16 local base64 = require "util.encodings".base64;
17 local hmac_sha1 = require "util.hashes".hmac_sha1;
18 local sha1 = require "util.hashes".sha1;
19 local Hi = require "util.hashes".scram_Hi_sha1;
20 local generate_uuid = require "util.uuid".generate;
21 local saslprep = require "util.encodings".stringprep.saslprep;
22 local nodeprep = require "util.encodings".stringprep.nodeprep;
23 local log = require "util.logger".init("sasl");
24 local t_concat = table.concat;
25 local char = string.char;
26 local byte = string.byte;
27
-- Lua 5.1 module(): from here on the chunk's environment is the module table,
-- so functions declared below without `local` (init,
-- getAuthenticationDatabaseSHA1) become fields of it. No package.seeall is
-- passed, so standard-library globals are not visible below this line; that is
-- why the library functions this file uses are captured in locals above.
module "sasl.scram"
29
30 --=========================
31 --SASL SCRAM-SHA-1 according to RFC 5802
32
33 --[[
34 Supported Authentication Backends
35
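scram_{MECH}:
	-- MECH being a standard hash name (like those at IANA's hash registry), lower-cased and with '-' replaced with '_' (e.g. scram_sha_1)
	function(sasl, username, realm)
		return stored_key, server_key, iteration_count, salt, state;
	end
	-- state: nil -> "not-authorized", false -> "account-disabled", anything else -> continue

plain:
	-- preferred over scram_{MECH} when both are available; the password is SASLprep'd and
	-- SCRAM credentials are derived from it with a freshly generated salt
	function(sasl, username, realm)
		return password, state;
	end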
41
42 Supported Channel Binding Backends
43
44 'tls-unique' according to RFC 5929
45 ]]
46
-- Iteration count used when credentials are derived on the fly from a plaintext
-- password (the "plain" backend path in scram_gen); 4096 is the minimum RFC 5802
-- suggests, below which getAuthenticationDatabaseSHA1 logs a warning.
local default_i = 4096
48
-- Lookup table for the XOR of two 4-bit values:
--   xor_map[a * 16 + b + 1] == a XOR b   (0 <= a, b <= 15)
-- Bytes are XORed one nibble at a time through this table, which avoids
-- depending on bitwise operators (absent in Lua 5.1).
local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;};

--- XOR two equal-length binary strings byte by byte.
-- @param a first string
-- @param b second string; must be a string of the same length as a
-- @return the XOR of a and b as a string, or nil if b is not a string of
--         the same length as a
local function binaryXOR( a, b )
	-- A shorter b would otherwise fail with arithmetic on nil; a longer one
	-- would silently have its tail ignored.
	if type(b) ~= "string" or #a ~= #b then
		return nil;
	end
	-- Fresh buffer per call: a shared module-level buffer leaks trailing bytes
	-- of an earlier, longer input into the result of a later, shorter one.
	local result = {};
	for i = 1, #a do
		local x, y = byte(a, i), byte(b, i);
		local lowx, lowy = x % 16, y % 16;
		local hix, hiy = (x - lowx) / 16, (y - lowy) / 16;
		local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1];
		result[i] = char(hir * 16 + lowr);
	end
	return t_concat(result);
end
63
--- Decode and normalise the "n=" username attribute of a SCRAM client-first-message.
-- Any '=' escape other than "=2C" (',') or "=3D" ('=') is rejected; the two
-- allowed escapes are decoded, then SASLprep and (unless disabled) nodeprep
-- are applied.
-- @param username  raw username attribute value
-- @param _nodeprep false to skip nodeprep, a function to use in place of
--                  nodeprep, or nil to use the default nodeprep
-- @return the prepared username; false if it contains an illegal escape or
--         ends up empty; the falsy value from a prep function that rejects it
local function validate_username(username, _nodeprep)
	for escape in username:gmatch("=(.?.?)") do
		local allowed = escape == "2C" or escape == "3D";
		if not allowed then
			return false
		end
	end

	-- "=2C" is decoded before "=3D" so that "=3D2C" yields the literal text "=2C".
	local decoded = username:gsub("=2C", ","):gsub("=3D", "=");

	local prepared = saslprep(decoded);
	if prepared and _nodeprep ~= false then
		local prep = _nodeprep or nodeprep;
		prepared = prep(prepared);
	end

	if not prepared then
		return prepared;
	end
	if #prepared == 0 then
		return false;
	end
	return prepared;
end
85
--- Turn a hash name into the form used in backend names, e.g. "SHA-1" -> "sha_1".
-- @param hashname hash name such as "SHA-1"
-- @return the lower-cased name with every '-' replaced by '_' (a single value)
local function hashprep(hashname)
	-- The parentheses drop gsub's second result (the substitution count), and
	-- '-' is escaped because it is a pattern quantifier.
	return (hashname:lower():gsub("%-", "_"));
end
89
--- Derive SCRAM-SHA-1 credentials from a plaintext password (RFC 5802):
--   SaltedPassword = Hi(password, salt, iteration_count)
--   StoredKey      = SHA1(HMAC(SaltedPassword, "Client Key"))
--   ServerKey      = HMAC(SaltedPassword, "Server Key")
-- @param password        password string
-- @param salt            salt string
-- @param iteration_count Hi() iteration count (number); values below 4096 are
--                        logged as a warning but still accepted
-- @return true, stored_key, server_key on success
-- @return false, "inappropriate argument types" if any argument has the wrong type
function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
	local valid = type(password) == "string"
		and type(salt) == "string"
		and type(iteration_count) == "number";
	if not valid then
		return false, "inappropriate argument types"
	end

	if iteration_count < 4096 then
		log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
	end

	local salted = Hi(password, salt, iteration_count);
	local client_key = hmac_sha1(salted, "Client Key");
	local stored = sha1(client_key);
	local server = hmac_sha1(salted, "Server Key");
	return true, stored, server
end
102
--- Build the SASL step handler for SCRAM-<hash> (RFC 5802).
-- @param hash_name lower-case hash name such as "sha-1"; selects the
--                  "scram_<hash>" credentials backend (see hashprep)
-- @param H_f       hash function H(), used to check H(ClientKey) == StoredKey
-- @param HMAC_f    HMAC function HMAC(key, message)
-- @return function(self, message) returning "challenge", data / "success", data /
--         "failure", condition[, text]; per-exchange data is kept in self.state
local function scram_gen(hash_name, H_f, HMAC_f)
	-- Profile key of the backend returning precomputed credentials, e.g. "scram_sha_1".
	local scram_backend = "scram_"..hashprep(hash_name);

	local function scram_hash(self, message)
		-- Channel binding is only possible when the profile provides cb callbacks.
		local support_channel_binding = false;
		if self.profile.cb then support_channel_binding = true; end

		if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
		local state = self.state;
		if not state then
			-- we are processing client_first_message
			local client_first_message = message;

			-- TODO: fail if authzid is provided, since we don't support them yet
			local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, name, clientnonce
				= client_first_message:match("^(([ynp])=?([%a%-]*),(.*),)n=(.*),r=([^,]*).*");

			if not gs2_cbind_flag then
				return "failure", "malformed-request";
			end

			if support_channel_binding and gs2_cbind_flag == "y" then
				-- "y" -> client does support channel binding
				--        but thinks the server does not.
				return "failure", "malformed-request";
			end

			if gs2_cbind_flag == "n" then
				-- "n" -> client doesn't support channel binding.
				support_channel_binding = false;
			end

			if support_channel_binding and gs2_cbind_flag == "p" then
				-- check whether we support the proposed channel binding type
				if not self.profile.cb[gs2_cbind_name] then
					return "failure", "malformed-request", "Proposed channel binding type isn't supported.";
				end
			else
				-- no channel binding
				gs2_cbind_name = nil;
			end

			name = validate_username(name, self.profile.nodeprep);
			if not name then
				log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
				return "failure", "malformed-request", "Invalid username.";
			end

			-- retrieve credentials ("plain" is preferred when both backends exist)
			local stored_key, server_key, salt, iteration_count;
			if self.profile.plain then
				local password, state = self.profile.plain(self, name, self.realm)
				if state == nil then return "failure", "not-authorized"
				elseif state == false then return "failure", "account-disabled" end

				password = saslprep(password);
				if not password then
					log("debug", "Password violates SASLprep.");
					return "failure", "not-authorized", "Invalid password."
				end

				salt = generate_uuid();
				iteration_count = default_i;

				local succ;
				succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count);
				if not succ then
					log("error", "Generating authentication database failed. Reason: %s", stored_key);
					return "failure", "temporary-auth-failure";
				end
			elseif self.profile[scram_backend] then
				local state;
				stored_key, server_key, iteration_count, salt, state = self.profile[scram_backend](self, name, self.realm);
				if state == nil then return "failure", "not-authorized"
				elseif state == false then return "failure", "account-disabled" end
			end

			-- Without a usable backend (or with incomplete backend data) the
			-- challenge cannot be built; fail cleanly instead of raising an error.
			if not (stored_key and server_key and salt and iteration_count) then
				log("error", "No usable SCRAM credentials available for %s", name);
				return "failure", "temporary-auth-failure";
			end

			local nonce = clientnonce .. generate_uuid();
			local server_first_message = "r="..nonce..",s="..base64.encode(salt)..",i="..iteration_count;
			self.state = {
				gs2_header = gs2_header;
				gs2_cbind_name = gs2_cbind_name;
				name = name;
				nonce = nonce;

				server_key = server_key;
				stored_key = stored_key;
				client_first_message = client_first_message;
				server_first_message = server_first_message;
			}
			return "challenge", server_first_message
		else
			-- we are processing client_final_message
			local client_final_message = message;

			-- c=<cbind>,r=<nonce>[,<extensions>],p=<proof>: the nonce contains no
			-- commas and the proof must be the last attribute and non-empty.
			local channelbinding, nonce, proof = client_final_message:match("^c=([^,]*),r=([^,]*).-,p=([^,]+)$");

			if not proof or not nonce or not channelbinding then
				return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
			end

			local client_gs2_header = base64.decode(channelbinding)
			local our_client_gs2_header = state["gs2_header"]
			if state.gs2_cbind_name then
				-- we support channelbinding, so check if the value is valid
				our_client_gs2_header = our_client_gs2_header .. self.profile.cb[state.gs2_cbind_name](self);
			end
			if client_gs2_header ~= our_client_gs2_header then
				return "failure", "malformed-request", "Invalid channel binding value.";
			end

			if nonce ~= state.nonce then
				return "failure", "malformed-request", "Wrong nonce in client-final-message.";
			end

			local ServerKey = state.server_key;
			local StoredKey = state.stored_key;

			-- AuthMessage = client-first-message-bare "," server-first-message ","
			--               client-final-message-without-proof
			-- The bare message is everything after the GS2 header; searching for
			-- "n=" instead could match inside an authzid.
			local client_first_message_bare = state.client_first_message:sub(#state.gs2_header + 1);
			local AuthMessage = client_first_message_bare .. "," .. state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+")
			local ClientSignature = HMAC_f(StoredKey, AuthMessage)

			-- The proof must decode to exactly one hash output; anything else is
			-- malformed (and previously crashed the XOR below).
			local ClientProof = base64.decode(proof)
			if type(ClientProof) ~= "string" or #ClientProof ~= #ClientSignature then
				return "failure", "malformed-request", "Invalid client proof.";
			end
			local ClientKey = binaryXOR(ClientSignature, ClientProof)
			local ServerSignature = HMAC_f(ServerKey, AuthMessage)

			if StoredKey == H_f(ClientKey) then
				local server_final_message = "v="..base64.encode(ServerSignature);
				self["username"] = state.name;
				return "success", server_final_message;
			else
				return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";
			end
		end
	end
	return scram_hash;
end
235
--- Register the SCRAM mechanisms provided by this module.
-- For each hash, both the plain mechanism and its "-PLUS" channel-binding
-- variant (advertising tls-unique) are registered. Each can use either the
-- "plain" or the "scram_<hash>" credentials backend.
-- @param registerMechanism function(name, backends, handler[, channel_bindings])
function init(registerMechanism)
	local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
		local mechanism = "SCRAM-"..hash_name;
		local scram_backend = "scram_"..(hashprep(hash_name));
		local lowered = hash_name:lower();

		registerMechanism(mechanism, {"plain", scram_backend}, scram_gen(lowered, hash, hmac_hash));

		-- channel binding equivalent
		local plus = mechanism.."-PLUS";
		registerMechanism(plus, {"plain", scram_backend}, scram_gen(lowered, hash, hmac_hash), {"tls-unique"});
	end

	registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
end
246
-- _M is the module table set up by the module() call at the top of the file;
-- returning it makes it the value require() yields.
return _M;