Merge 0.9->0.10
[prosody.git] / util / sasl / scram.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2010 Tobias Markmann
3 --
4 --        All rights reserved.
5 --
6 --        Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --                * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --                * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --                * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14 local s_match = string.match;
15 local type = type
16 local base64 = require "util.encodings".base64;
17 local hmac_sha1 = require "util.hashes".hmac_sha1;
18 local sha1 = require "util.hashes".sha1;
19 local Hi = require "util.hashes".scram_Hi_sha1;
20 local generate_uuid = require "util.uuid".generate;
21 local saslprep = require "util.encodings".stringprep.saslprep;
22 local nodeprep = require "util.encodings".stringprep.nodeprep;
23 local log = require "util.logger".init("sasl");
24 local t_concat = table.concat;
25 local char = string.char;
26 local byte = string.byte;
27
28 module "sasl.scram"
29
30 --=========================
31 --SASL SCRAM-SHA-1 according to RFC 5802
32
33 --[[
34 Supported Authentication Backends
35
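plain:
	function(sasl, username, realm)
		return password, state;
	end

scram_{MECH}:
	-- MECH being a standard hash name (like those in IANA's hash registry), lowercased, with '-' replaced by '_' (e.g. scram_sha_1)
	function(sasl, username, realm)
		return stored_key, server_key, iteration_count, salt, state;
	end

For both backends, a state of nil fails with "not-authorized" and a state of false fails with "account-disabled".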
41
42 Supported Channel Binding Backends
43
44 'tls-unique' according to RFC 5929
45 ]]
46
-- Iteration count used when keys are derived from a plaintext password
-- (RFC 5802 suggests 4096 as the minimum).
local default_i = 4096

-- Lookup table for XOR of two 4-bit values: xor_map[p * 16 + q + 1] == p XOR q.
-- Built arithmetically because Lua 5.1 has no bitwise operators.
local xor_map = {};
for p = 0, 15 do
	for q = 0, 15 do
		local acc, weight, a, b = 0, 1, p, q;
		while weight <= 8 do
			local abit, bbit = a % 2, b % 2;
			if abit ~= bbit then
				acc = acc + weight;
			end
			a, b, weight = (a - abit) / 2, (b - bbit) / 2, weight * 2;
		end
		xor_map[p * 16 + q + 1] = acc;
	end
end
50
--- XOR two byte strings together.
-- Only the first min(#a, #b) bytes are combined, so the result is as long as the
-- shorter operand; callers that need equal-length inputs must check that first.
-- @param a first byte string
-- @param b second byte string
-- @return string of XORed bytes
local function binaryXOR( a, b )
	local len = #a;
	if #b < len then
		-- Reading past the end of b would yield nil and raise an arithmetic error.
		len = #b;
	end
	-- Fresh buffer per call: a shared module-level buffer would leak trailing
	-- bytes from an earlier, longer input into this result.
	local result = {};
	for i = 1, len do
		local x, y = byte(a, i), byte(b, i);
		local lowx, lowy = x % 16, y % 16;
		local hix, hiy = (x - lowx) / 16, (y - lowy) / 16;
		-- xor_map[p * 16 + q + 1] holds p XOR q for 4-bit values p and q.
		local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1];
		result[i] = char(hir * 16 + lowr);
	end
	return t_concat(result);
end
61         return t_concat(result);
62 end
63
--- Decode and normalise a SCRAM username (an RFC 5802 saslname).
-- The only escapes allowed are "=2C" (for ",") and "=3D" (for "="); any other
-- "=" sequence makes the name invalid. The decoded name goes through SASLprep and
-- then, unless _nodeprep is false, through nodeprep or the supplied function.
-- @param username raw saslname taken from the client-first-message
-- @param _nodeprep nil (use nodeprep), false (skip it), or a replacement prep function
-- @return the prepared username, or a false/nil value if it is invalid or empty
local function validate_username(username, _nodeprep)
	for escape in username:gmatch("=(.?.?)") do
		if not (escape == "2C" or escape == "3D") then
			return false
		end
	end

	-- Decode "=2C" before "=3D" so that an "=" produced by decoding "=3D"
	-- cannot combine with following text into a new "=2C" escape.
	local decoded = username:gsub("=2C", ","):gsub("=3D", "=");

	local prepped = saslprep(decoded);
	if prepped and _nodeprep ~= false then
		local prep = _nodeprep or nodeprep;
		prepped = prep(prepped);
	end

	if not prepped then
		return prepped;
	end
	if #prepped == 0 then
		return false;
	end
	return prepped;
end
83         return username and #username>0 and username;
84 end
85
--- Normalise a hash name into the suffix used for SCRAM credential backends.
-- Lowercases the name and replaces every '-' with '_' (e.g. "SHA-1" -> "sha_1").
-- @param hashname hash name string
-- @return the normalised name, as a single value
local function hashprep(hashname)
	-- Parenthesised so gsub's second result (the substitution count) is dropped
	-- instead of leaking into callers' argument lists or table constructors.
	return (hashname:lower():gsub("%-", "_"));
end
89
--- Derive SCRAM-SHA-1 credentials from a plaintext password.
-- SaltedPassword is computed with util.hashes.scram_Hi_sha1(password, salt, i);
-- StoredKey = SHA1(HMAC(SaltedPassword, "Client Key")) and
-- ServerKey = HMAC(SaltedPassword, "Server Key").
-- Logs a warning when the iteration count is below 4096.
-- @param password string
-- @param salt string
-- @param iteration_count number
-- @return true, stored_key, server_key on success;
--   false, "inappropriate argument types" if an argument has the wrong type
function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
	local args_ok = type(password) == "string"
		and type(salt) == "string"
		and type(iteration_count) == "number";
	if not args_ok then
		return false, "inappropriate argument types"
	end

	if iteration_count < 4096 then
		log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
	end

	local salted = Hi(password, salt, iteration_count);
	local server_key = hmac_sha1(salted, "Server Key");
	local client_key = hmac_sha1(salted, "Client Key");
	return true, sha1(client_key), server_key
end
102
--- Build the SCRAM server-side step function for one hash algorithm (RFC 5802).
-- The returned function is called once per client message. The first call
-- handles client-first-message and returns a challenge; the second handles
-- client-final-message. Progress is kept in self.state.
-- @param hash_name hash name as passed by init() (e.g. "sha-1"); selects the
--   "scram_<hashprep(hash_name)>" credentials backend
-- @param H_f hash function H(), used to check ClientKey against StoredKey
-- @param HMAC_f keyed MAC function HMAC(key, message)
-- @return function(self, message) returning "challenge", data; "success", data;
--   or "failure", condition[, text]
local function scram_gen(hash_name, H_f, HMAC_f)
	-- Profile callback that supplies precomputed credentials, e.g. "scram_sha_1".
	local scram_backend = "scram_"..hashprep(hash_name);

	-- Handle client-first-message: parse the GS2 header, negotiate channel
	-- binding, fetch credentials and build server-first-message.
	local function process_client_first(self, client_first_message)
		local cb_backends = self.profile.cb;

		-- TODO: fail if authzid is provided, since we don't support them yet
		self.state.client_first_message = client_first_message;
		self.state.gs2_cbind_flag, self.state.gs2_cbind_name, self.state.authzid, self.state.name, self.state.clientnonce
			= client_first_message:match("^([ynp])=?([%a%-]*),(.*),n=(.*),r=([^,]*).*");

		local gs2_cbind_flag = self.state.gs2_cbind_flag;
		if not gs2_cbind_flag then
			return "failure", "malformed-request";
		end

		if cb_backends and gs2_cbind_flag == "y" then
			-- "y": the client supports channel binding but thinks the server does
			-- not. Since we do, treat it as a possible downgrade and refuse.
			return "failure", "malformed-request";
		end

		if gs2_cbind_flag == "p" then
			-- RFC 5802 section 6: the server MUST fail when it does not support the
			-- requested channel binding type, including when it supports none at all.
			-- Accepting it would let the client believe the exchange was bound.
			if not (cb_backends and cb_backends[self.state.gs2_cbind_name]) then
				return "failure", "malformed-request", "Proposed channel binding type isn't supported.";
			end
		else
			-- "n" or "y": this exchange is not channel-bound
			self.state.gs2_cbind_name = nil;
		end

		-- string.match yields either all captures or none, so in practice only an
		-- empty client nonce can trigger this check.
		if not self.state.name or not self.state.clientnonce or self.state.clientnonce == "" then
			return "failure", "malformed-request", "Missing username or nonce in client-first-message.";
		end

		self.state.name = validate_username(self.state.name, self.profile.nodeprep);
		if not self.state.name then
			log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
			return "failure", "malformed-request", "Invalid username.";
		end

		self.state.servernonce = generate_uuid();

		-- retrieve credentials
		if self.profile.plain then
			local password, state = self.profile.plain(self, self.state.name, self.realm)
			if state == nil then return "failure", "not-authorized"
			elseif state == false then return "failure", "account-disabled" end

			password = saslprep(password);
			if not password then
				log("debug", "Password violates SASLprep.");
				return "failure", "not-authorized", "Invalid password."
			end

			self.state.salt = generate_uuid();
			self.state.iteration_count = default_i;

			-- NOTE(review): this always derives SHA-1 keys, which only matches
			-- H_f/HMAC_f when this generator is used for SCRAM-SHA-1.
			local succ;
			succ, self.state.stored_key, self.state.server_key = getAuthenticationDatabaseSHA1(password, self.state.salt, self.state.iteration_count);
			if not succ then
				log("error", "Generating authentication database failed. Reason: %s", self.state.stored_key);
				return "failure", "temporary-auth-failure";
			end
		elseif self.profile[scram_backend] then
			local stored_key, server_key, iteration_count, salt, state = self.profile[scram_backend](self, self.state.name, self.realm);
			if state == nil then return "failure", "not-authorized"
			elseif state == false then return "failure", "account-disabled" end

			self.state.stored_key = stored_key;
			self.state.server_key = server_key;
			self.state.iteration_count = iteration_count;
			self.state.salt = salt;
		else
			-- Without credentials there is no salt, and building the reply would
			-- raise a Lua error; fail cleanly instead.
			log("error", "No 'plain' or '%s' authentication backend available.", scram_backend);
			return "failure", "temporary-auth-failure";
		end

		local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count;
		self.state.server_first_message = server_first_message;
		return "challenge", server_first_message
	end

	-- Handle client-final-message: check channel binding, nonce and proof, and
	-- answer with server-final-message on success.
	local function process_client_final(self, client_final_message)
		self.state.channelbinding, self.state.nonce, self.state.proof = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)");

		if not self.state.proof or not self.state.nonce or not self.state.channelbinding then
			return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
		end

		if self.state.gs2_cbind_name then
			-- we support channel binding, so check that the value is valid
			local client_gs2_header = base64.decode(self.state.channelbinding)
			local our_client_gs2_header = "p="..self.state.gs2_cbind_name..","..self.state.authzid..","..self.profile.cb[self.state.gs2_cbind_name](self);

			if client_gs2_header ~= our_client_gs2_header then
				return "failure", "malformed-request", "Invalid channel binding value.";
			end
		end

		if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then
			return "failure", "malformed-request", "Wrong nonce in client-final-message.";
		end

		-- client-final-message-without-proof. The attribute match above also
		-- accepts e.g. ",xp=" or an empty proof, so this can still be nil.
		local client_final_without_proof = s_match(client_final_message, "(.+),p=.+");
		if not client_final_without_proof then
			return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
		end

		local ServerKey = self.state.server_key;
		local StoredKey = self.state.stored_key;

		local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. client_final_without_proof;
		local ClientSignature = HMAC_f(StoredKey, AuthMessage);

		-- A proof whose length differs from the signature cannot be valid; reject
		-- it instead of XORing mismatched lengths.
		local ClientProof = base64.decode(self.state.proof);
		if not ClientProof or #ClientProof ~= #ClientSignature then
			return "failure", "malformed-request", "Invalid client proof.";
		end

		local ClientKey = binaryXOR(ClientSignature, ClientProof);
		local ServerSignature = HMAC_f(ServerKey, AuthMessage);

		if StoredKey == H_f(ClientKey) then
			local server_final_message = "v="..base64.encode(ServerSignature);
			self.username = self.state.name;
			return "success", server_final_message;
		end
		return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";
	end

	local function scram_hash(self, message)
		if not self.state then self.state = {} end
		if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end

		-- No username recorded yet means this is the client's first message.
		if not self.state.name then
			return process_client_first(self, message);
		end
		return process_client_final(self, message);
	end
	return scram_hash;
end
236
--- Register this module's SCRAM mechanisms with the SASL framework.
-- For each hash, both "SCRAM-<HASH>" and the channel-binding variant
-- "SCRAM-<HASH>-PLUS" are registered. Each registration accepts either the
-- "plain" backend or the "scram_<hash>" backend.
-- @param registerMechanism callback(name, backends, handler[, cb_backends])
function init(registerMechanism)
	local function register_scram(hash_name, hash, hmac_hash)
		local mech = "SCRAM-"..hash_name;
		local backend = "scram_"..hashprep(hash_name);
		local lowered = hash_name:lower();

		registerMechanism(mech, {"plain", backend}, scram_gen(lowered, hash, hmac_hash));

		-- channel binding variant; only 'tls-unique' is offered
		registerMechanism(mech.."-PLUS", {"plain", backend}, scram_gen(lowered, hash, hmac_hash), {"tls-unique"});
	end

	register_scram("SHA-1", sha1, hmac_sha1);
end
247
248 return _M;