74854619b1ad4d072d15030e26185910da48d38b
[prosody.git] / util / sasl / scram.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2010 Tobias Markmann
3 --
4 --        All rights reserved.
5 --
6 --        Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --                * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --                * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --                * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14 local s_match = string.match;
15 local type = type
16 local string = string
17 local tostring = tostring;
18 local base64 = require "util.encodings".base64;
19 local hmac_sha1 = require "util.hmac".sha1;
20 local sha1 = require "util.hashes".sha1;
21 local generate_uuid = require "util.uuid".generate;
22 local saslprep = require "util.encodings".stringprep.saslprep;
23 local log = require "util.logger".init("sasl");
24 local t_concat = table.concat;
25 local char = string.char;
26 local byte = string.byte;
27
module("scram")

--=========================
-- SASL SCRAM-SHA-1 as specified by RFC 5802

--[[
Supported authentication backends

scram_{MECH}:
	-- MECH is a standard hash name (like those in IANA's hash registry),
	-- lower-cased and with '-' replaced by '_', e.g. scram_sha_1.
	-- It is called with the SASL object as its first argument.
	function(sasl, username, realm)
		return stored_key, server_key, iteration_count, salt, state;
	end

Supported channel binding backends

'tls-unique' as specified by RFC 5929
]]

-- Iteration count used when deriving keys from a plaintext password.
local default_i = 4096
48
-- Debug helper: render every byte of `b` as a backslash followed by its
-- decimal value, e.g. "AB" -> "\65\66".
local function bp( b )
	return (b:gsub(".", function(ch)
		return "\\"..ch:byte()
	end))
end
56
-- xor_map[16 * x + y + 1] == x XOR y for 4-bit values x, y in 0..15.
-- Lua 5.1 has no bitwise operators, so XOR is computed one nibble at a time.
local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;};

--- XOR two byte strings together, byte by byte.
-- @param a string
-- @param b string of at least #a bytes (bytes beyond #a are ignored;
--   a shorter b raises an arithmetic error on nil)
-- @return string of exactly #a bytes
local function binaryXOR( a, b )
	-- A fresh buffer per call: a shared buffer would leak stale trailing
	-- bytes into the result whenever a shorter `a` followed a longer one.
	local result = {};
	for i = 1, #a do
		local x, y = byte(a, i), byte(b, i);
		local lowx, lowy = x % 16, y % 16;
		local hix, hiy = (x - lowx) / 16, (y - lowy) / 16;
		local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1];
		result[i] = char(hir * 16 + lowr);
	end
	return t_concat(result);
end
71
--- Hash-independent Hi() (PBKDF2 with a single output block), RFC 5802 section 2.2:
--   U1 = HMAC(str, salt .. INT(1)),  Uk = HMAC(str, U(k-1)),
--   Hi = U1 XOR U2 XOR ... XOR Ui
-- @param hmac keyed MAC function hmac(key, data)
-- @param str key (normalised password)
-- @param salt salt string
-- @param i iteration count
function Hi(hmac, str, salt, i)
	local previous = hmac(str, salt.."\0\0\0\1");
	local accumulated = previous;
	for _ = 2, i do
		previous = hmac(str, previous);
		accumulated = binaryXOR(accumulated, previous);
	end
	return accumulated;
end
83
--- Decode and normalise a SCRAM username (RFC 5802 section 5.1).
-- On the wire ',' is escaped as "=2C" and '=' as "=3D"; any other '='
-- sequence makes the name invalid.
-- @param username raw value of the "n=" attribute
-- @return the decoded name passed through SASLprep; false if it contains a
--   forbidden '=' sequence; a falsy value if SASLprep rejects it (presumably
--   nil -- that is how the caller treats it)
local function validate_username(username)
	-- check for forbidden char sequences
	for eq in username:gmatch("=(.?.?)") do
		if eq ~= "2C" and eq ~= "3D" then
			return false
		end
	end

	-- Decode both escapes in a single pass so that a decoded '=' can never
	-- be re-read as the start of another escape.
	username = username:gsub("=(%x%x)", { ["2C"] = ",", ["3D"] = "=" });

	-- apply SASLprep
	return (saslprep(username));
end
100
--- Normalise a hash name for use in backend names: lower-case it and turn
-- every '-' into '_' (e.g. "SHA-1" -> "sha_1").
-- Returns only the resulting string; gsub's replacement count is dropped.
local function hashprep(hashname)
	-- '-' is a pattern magic character, so escape it explicitly.
	local prepped = hashname:lower():gsub("%-", "_");
	return prepped;
end
104
--- Derive the SCRAM-SHA-1 StoredKey and ServerKey for a password (RFC 5802 section 3).
-- @param password (string) password to derive keys from
-- @param salt (string) salt fed to Hi()
-- @param iteration_count (number) Hi() iteration count; values below 4096
--   only trigger a warning
-- @return true, stored_key, server_key on success;
--   false, "inappropriate argument types" if an argument has the wrong type
function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
	local well_typed = type(password) == "string"
		and type(salt) == "string"
		and type(iteration_count) == "number";
	if not well_typed then
		return false, "inappropriate argument types"
	end

	if iteration_count < 4096 then
		log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
	end

	-- SaltedPassword = Hi(password, salt, i)
	local salted_password = Hi(hmac_sha1, password, salt, iteration_count);
	-- StoredKey = H(HMAC(SaltedPassword, "Client Key"))
	local client_key = hmac_sha1(salted_password, "Client Key");
	local stored_key = sha1(client_key);
	-- ServerKey = HMAC(SaltedPassword, "Server Key")
	return true, stored_key, hmac_sha1(salted_password, "Server Key")
end
117
--- Build the SASL step function for SCRAM-<hash> (RFC 5802).
-- @param hash_name lower-case hash name, e.g. "sha-1"; selects the
--   "scram_<hash>" credentials backend via hashprep()
-- @param H_f hash function: H_f(data) -> digest string
-- @param HMAC_f keyed MAC: HMAC_f(key, data) -> digest string
-- @return step function scram_hash(self, message), returning
--   "challenge", server_first_message   after a client-first-message,
--   "success", server_final_message     after a valid client-final-message,
--   "failure", condition [, text]       otherwise.
local function scram_gen(hash_name, H_f, HMAC_f)
	local backend_name = "scram_"..hashprep(hash_name);

	-- Step 1: parse the client-first-message, check the GS2 channel binding
	-- flag, fetch credentials and build the server-first-message.
	-- Nothing is stored in self.state unless this step succeeds.
	local function process_client_first(self, client_first_message)
		local state = self.state;
		-- Pass client data as an argument, never as the format string.
		log("debug", "client_first_message: %s", client_first_message);

		-- client-first-message = gs2-header client-first-message-bare
		-- gs2-header = ("p=" cb-name | "n" | "y") "," [authzid] ","
		local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, client_first_message_bare, name, clientnonce
			= s_match(client_first_message, "^(([pny])=?([^,]*),([^,]*),)(m?=?[^,]*,?n=([^,]*),r=([^,]*),?.*)$");
		if not gs2_cbind_flag then
			return "failure", "malformed-request";
		end
		-- TODO: fail if authzid is provided, since we don't support them yet

		log("debug", "Decoded: cbind_flag: %s, cbind_name: %s, authzid: %s, name: %s, clientnonce: %s",
			tostring(gs2_cbind_flag), tostring(gs2_cbind_name), tostring(authzid),
			tostring(name), tostring(clientnonce));

		if gs2_cbind_flag == "p" then
			-- The client requires channel binding of the named type.
			if not (self.profile.cb and self.profile.cb[gs2_cbind_name]) then
				return "failure", "malformed-request", "Proposed channel binding type isn't supported.";
			end
		else
			if gs2_cbind_flag == "y" and self.profile.cb then
				-- "y": the client supports channel binding but thinks we
				-- don't. We do, so fail (RFC 5802 section 6, downgrade protection).
				return "failure", "malformed-request";
			end
			-- "n" / "y": no channel binding data follows in step 2.
			gs2_cbind_name = nil;
		end

		if name == "" or clientnonce == "" then
			return "failure", "malformed-request", "Missing username or client nonce.";
		end

		name = validate_username(name);
		if not name then
			log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
			return "failure", "malformed-request", "Invalid username.";
		end

		local servernonce = generate_uuid();

		-- retrieve credentials
		local stored_key, server_key, iteration_count, salt;
		if self.profile.plain then
			local password, status = self.profile.plain(self, name, self.realm);
			if status == nil then return "failure", "not-authorized"
			elseif status == false then return "failure", "account-disabled" end

			password = saslprep(password);
			if not password then
				log("debug", "Password violates SASLprep.");
				return "failure", "not-authorized", "Invalid password."
			end

			-- NOTE: keys are always derived with SHA-1 here, whatever hash_name is.
			salt = generate_uuid();
			iteration_count = default_i;
			local succ;
			succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count);
			if not succ then
				log("error", "Generating authentication database failed. Reason: %s", stored_key);
				return "failure", "temporary-auth-failure";
			end
		elseif self.profile[backend_name] then
			local status;
			stored_key, server_key, iteration_count, salt, status = self.profile[backend_name](self, name, self.realm);
			if status == nil then return "failure", "not-authorized"
			elseif status == false then return "failure", "account-disabled" end
		else
			log("error", "No credentials backend available for %s", backend_name);
			return "failure", "temporary-auth-failure";
		end

		if not (stored_key and server_key and salt and iteration_count) then
			log("error", "Credentials backend returned incomplete SCRAM data for %s", backend_name);
			return "failure", "temporary-auth-failure";
		end

		local server_first_message = "r="..clientnonce..servernonce..",s="..base64.encode(salt)..",i="..iteration_count;

		state.gs2_header = gs2_header;
		state.gs2_cbind_flag = gs2_cbind_flag;
		state.gs2_cbind_name = gs2_cbind_name;
		state.authzid = authzid;
		state.client_first_message = client_first_message;
		state.client_first_message_bare = client_first_message_bare;
		state.clientnonce = clientnonce;
		state.servernonce = servernonce;
		state.salt = salt;
		state.iteration_count = iteration_count;
		state.stored_key = stored_key;
		state.server_key = server_key;
		state.server_first_message = server_first_message;
		-- state.name marks step 1 as done (see scram_hash); set it last.
		state.name = name;
		return "challenge", server_first_message;
	end

	-- Step 2: verify the client-final-message and build the server-final-message.
	local function process_client_final(self, client_final_message)
		local state = self.state;
		log("debug", "client_final_message: %s", client_final_message);

		-- client-final-message = client-final-message-without-proof ",p=" proof
		local client_final_message_without_proof, channelbinding, nonce, proof
			= s_match(client_final_message, "^(c=([^,]*),r=([^,]*),?.-),p=(.*)$");
		if not proof then
			return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
		end
		state.channelbinding, state.nonce, state.proof = channelbinding, nonce, proof;

		-- c= carries base64(gs2-header [.. channel binding data]) and must
		-- agree with the gs2-header the client sent in step 1.
		local our_client_gs2_header = state.gs2_header;
		if state.gs2_cbind_name then
			our_client_gs2_header = our_client_gs2_header..self.profile.cb[state.gs2_cbind_name](self);
		end
		if base64.decode(channelbinding) ~= our_client_gs2_header then
			return "failure", "malformed-request", "Invalid channel binding value.";
		end

		if nonce ~= state.clientnonce..state.servernonce then
			return "failure", "malformed-request", "Wrong nonce in client-final-message.";
		end

		local ServerKey = state.server_key;
		local StoredKey = state.stored_key;

		local AuthMessage = state.client_first_message_bare..","..state.server_first_message..","..client_final_message_without_proof;
		local ClientSignature = HMAC_f(StoredKey, AuthMessage);
		local ClientProof = base64.decode(proof);
		-- binaryXOR needs a proof exactly as long as the signature.
		if type(ClientProof) ~= "string" or #ClientProof ~= #ClientSignature then
			return "failure", "malformed-request", "Invalid client proof.";
		end
		local ClientKey = binaryXOR(ClientSignature, ClientProof);
		local ServerSignature = HMAC_f(ServerKey, AuthMessage);

		if StoredKey == H_f(ClientKey) then
			self.username = state.name;
			return "success", "v="..base64.encode(ServerSignature);
		end
		return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";
	end

	local function scram_hash(self, message)
		if not self.state then self.state = {} end
		if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
		if not self.state.name then
			return process_client_first(self, message);
		end
		return process_client_final(self, message);
	end

	return scram_hash;
end
244
--- Register the SCRAM mechanisms this module provides.
-- @param registerMechanism callback(mechanism_name, backend_names, step_function)
function init(registerMechanism)
	-- Each hash is offered twice: plain SCRAM-<hash> and the channel-binding
	-- variant SCRAM-<hash>-PLUS. Both accept plaintext passwords or
	-- precomputed scram_<hash> credentials, and each gets its own step function.
	local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
		local suffixes = { "", "-PLUS" };
		for n = 1, #suffixes do
			local backends = { "plain", "scram_"..(hashprep(hash_name)) };
			registerMechanism("SCRAM-"..hash_name..suffixes[n], backends, scram_gen(hash_name:lower(), hash, hmac_hash));
		end
	end

	registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
end
255
-- module() records the module table in _M; hand it back to require().
return _M