util.hmac, util.hashes: Implement HMAC functions in C, and move to util.hashes
[prosody.git] / util / sasl / scram.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2010 Tobias Markmann
3 --
4 --        All rights reserved.
5 --
6 --        Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --                * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --                * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --                * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14 local s_match = string.match;
15 local type = type
16 local string = string
17 local base64 = require "util.encodings".base64;
18 local hmac_sha1 = require "util.hashes".hmac_sha1;
19 local sha1 = require "util.hashes".sha1;
20 local generate_uuid = require "util.uuid".generate;
21 local saslprep = require "util.encodings".stringprep.saslprep;
22 local nodeprep = require "util.encodings".stringprep.nodeprep;
23 local log = require "util.logger".init("sasl");
24 local t_concat = table.concat;
25 local char = string.char;
26 local byte = string.byte;
27
28 module "sasl.scram"
29
30 --=========================
31 --SASL SCRAM-SHA-1 according to RFC 5802
32
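--[[
Supported Authentication Backends

plain:
	function(sasl, username, realm)
		return password, state;
	end

scram_{MECH}:
	-- MECH being a standard hash name (like those at IANA's hash registry), lower-cased, with '-' replaced with '_'
	function(sasl, username, realm)
		return stored_key, server_key, iteration_count, salt, state;
	end

For both backends, a nil state fails authentication with "not-authorized" and a false state with "account-disabled".
]]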
42
-- Iteration count used when SCRAM credentials are derived on the fly from a plaintext password.
local default_i = 4096

-- Debug helper: renders a binary string as decimal byte escapes,
-- e.g. the bytes 1 and 255 become the text \1\255. Not used by the mechanism itself.
-- Collects the pieces in a buffer and joins them once instead of concatenating
-- inside the loop (which is quadratic in the input length).
local function bp( b )
	local escaped = {};
	for i = 1, #b do
		escaped[i] = "\\" .. byte(b, i);
	end
	return t_concat(escaped);
end
52
-- XOR lookup table for 4-bit values: xor_map[x * 16 + y + 1] == x XOR y for x, y in 0..15.
-- Bytes are XORed one nibble at a time through this table instead of with bitwise operators.
local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;};

-- XOR two binary strings byte by byte.
-- Returns a string of #a bytes. b must be at least #a bytes long (a shorter b
-- raises a Lua error when a missing byte is used in arithmetic); bytes of b
-- beyond #a are ignored.
local function binaryXOR( a, b )
	-- Per-call buffer: a buffer shared between calls keeps stale trailing bytes
	-- from an earlier, longer input and returns them as part of a shorter result.
	local result = {};
	for i = 1, #a do
		local x, y = byte(a, i), byte(b, i);
		local lowx, lowy = x % 16, y % 16;
		local hix, hiy = (x - lowx) / 16, (y - lowy) / 16;
		local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1];
		result[i] = char(hir * 16 + lowr);
	end
	return t_concat(result);
end
67
-- Hi() from RFC 5802 section 2.2, i.e. PBKDF2 producing a single output block.
-- It is hash-algorithm independent: the HMAC to use is passed in.
--   hmac: function(key, message) returning the raw MAC
--   str:  password, used as the HMAC key
--   salt: salt bytes
--   i:    iteration count
-- Returns U1 XOR U2 XOR ... XOR Ui, where U1 = hmac(str, salt .. INT(1)) and
-- each following U is hmac(str, previous U).
function Hi(hmac, str, salt, i)
	local u_prev = hmac(str, salt.."\0\0\0\1");
	local acc = u_prev;
	for _ = 2, i do
		u_prev = hmac(str, u_prev);
		acc = binaryXOR(acc, u_prev);
	end
	return acc
end
79
-- Decodes and normalises the username sent in the SCRAM "n=" attribute.
--   username:  raw value from the client-first-message
--   _nodeprep: false to skip the extra profile step, a function to use instead
--              of the default nodeprep, or nil for the default nodeprep
-- Returns the normalised username, or false/nil when it contains an escape
-- other than "=2C"/"=3D", fails SASLprep/nodeprep, or ends up empty.
local function validate_username(username, _nodeprep)
	-- '=' is only allowed as the start of the escapes "=2C" and "=3D"
	for escape in username:gmatch("=(.?.?)") do
		if not (escape == "2C" or escape == "3D") then
			return false
		end
	end

	-- "=2C" is decoded before "=3D" so that "=3D2C" becomes "=2C", not ","
	local decoded = username:gsub("=2C", ","):gsub("=3D", "=");

	local prepped = saslprep(decoded);
	if prepped and _nodeprep ~= false then
		prepped = (_nodeprep or nodeprep)(prepped);
	end

	return prepped and #prepped>0 and prepped;
end
101
-- Turns a hash name into the suffix used for credential callback names:
-- lower-cased, with every '-' replaced by '_' (e.g. "SHA-1" -> "sha_1").
-- Returns exactly one value; the parentheses drop gsub's replacement count.
local function hashprep(hashname)
	-- '-' is a magic character in Lua patterns, so it is escaped.
	return (hashname:lower():gsub("%-", "_"));
end
105
-- Derives SCRAM-SHA-1 credentials from a password (RFC 5802 section 3):
--   SaltedPassword = Hi(password, salt, iteration_count)
--   StoredKey      = SHA1(HMAC(SaltedPassword, "Client Key"))
--   ServerKey      = HMAC(SaltedPassword, "Server Key")
-- Returns true, stored_key, server_key; or false, "inappropriate argument types"
-- unless password and salt are strings and iteration_count is a number.
-- Logs a warning for iteration counts below 4096.
function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
	local args_ok = type(password) == "string"
		and type(salt) == "string"
		and type(iteration_count) == "number";
	if not args_ok then
		return false, "inappropriate argument types"
	end

	if iteration_count < 4096 then
		log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
	end

	local salted_password = Hi(hmac_sha1, password, salt, iteration_count);
	local client_key = hmac_sha1(salted_password, "Client Key");
	local server_key = hmac_sha1(salted_password, "Server Key");
	return true, sha1(client_key), server_key
end
118
-- Builds the SASL step function for the SCRAM-<hash_name> mechanism (RFC 5802).
--   hash_name: lower-case hash name, e.g. "sha-1"; selects the "scram_<hash>" profile callback
--   H_f:       hash function, H_f(data) returning the raw digest
--   HMAC_f:    HMAC function, HMAC_f(key, data) returning the raw MAC
-- The returned function(self, message) is called once per client message and returns
-- "challenge", server_first_message / "success", server_final_message /
-- "failure", condition[, text].
local function scram_gen(hash_name, H_f, HMAC_f)
	-- Profile callback that supplies pre-computed credentials, e.g. "scram_sha_1".
	local credentials_callback = "scram_"..hashprep(hash_name);

	-- Handles the client-first-message. Everything is parsed into locals and only
	-- written to self.state once the message has been accepted: self.state.name is
	-- what moves the exchange on to the client-final-message, so a rejected message
	-- must not leave it (or any other partial state) behind.
	local function process_client_first(self, client_first_message)
		-- TODO: fail if authzid is provided, since we don't support them yet
		-- bare_start is the position where client-first-message-bare ("n=...") begins.
		local gs2_cbind_flag, authzid, bare_start, name, clientnonce
			= client_first_message:match("^(%a),(.*),()n=(.*),r=([^,]*).*");

		if not gs2_cbind_flag or #clientnonce == 0 then
			return "failure", "malformed-request", "Missing or empty attribute(n or r) in SASL message.";
		end

		-- we don't do any channel binding yet ("p" means the client requires it)
		if gs2_cbind_flag ~= "n" and gs2_cbind_flag ~= "y" then
			return "failure", "malformed-request", "Channel binding isn't supported at this time.";
		end

		name = validate_username(name, self.profile.nodeprep);
		if not name then
			log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
			return "failure", "malformed-request", "Invalid username.";
		end

		local servernonce = generate_uuid();
		local stored_key, server_key, iteration_count, salt;

		-- retrieve credentials
		if self.profile.plain then
			local password, status = self.profile.plain(self, name, self.realm)
			if status == nil then return "failure", "not-authorized"
			elseif status == false then return "failure", "account-disabled" end

			password = saslprep(password);
			if not password then
				log("debug", "Password violates SASLprep.");
				return "failure", "not-authorized", "Invalid password."
			end

			-- Derive the credentials (RFC 5802 section 3) with this mechanism's own
			-- hash functions rather than always with SHA-1.
			salt = generate_uuid();
			iteration_count = default_i;
			local salted_password = Hi(HMAC_f, password, salt, iteration_count);
			stored_key = H_f(HMAC_f(salted_password, "Client Key"));
			server_key = HMAC_f(salted_password, "Server Key");
		elseif self.profile[credentials_callback] then
			local status;
			stored_key, server_key, iteration_count, salt, status = self.profile[credentials_callback](self, name, self.realm);
			if status == nil then return "failure", "not-authorized"
			elseif status == false then return "failure", "account-disabled" end
		end

		-- No backend available, or a backend returned incomplete data; building
		-- the server-first-message below would otherwise raise a Lua error.
		if not (stored_key and server_key and iteration_count and salt) then
			log("error", "No usable SCRAM credentials available (%s)", hash_name);
			return "failure", "temporary-auth-failure";
		end

		local state = self.state;
		state.client_first_message = client_first_message;
		state.client_first_message_bare = client_first_message:sub(bare_start);
		state.gs2_cbind_flag, state.authzid = gs2_cbind_flag, authzid;
		state.clientnonce, state.servernonce = clientnonce, servernonce;
		state.stored_key, state.server_key = stored_key, server_key;
		state.iteration_count, state.salt = iteration_count, salt;
		state.name = name;

		local server_first_message = "r="..clientnonce..servernonce..",s="..base64.encode(salt)..",i="..iteration_count;
		state.server_first_message = server_first_message;
		return "challenge", server_first_message
	end

	-- Handles the client-final-message: checks the nonce and verifies the client proof.
	local function process_client_final(self, client_final_message)
		local state = self.state;

		-- The proof is the last attribute (RFC 5802 section 7); everything before it is
		-- client-final-message-without-proof, which goes into AuthMessage.
		local without_proof, proof = client_final_message:match("^(.+),p=([^,]+)$");
		local channelbinding, nonce;
		if without_proof then
			channelbinding, nonce = without_proof:match("^c=([^,]*),r=([^,]*)");
		end
		state.channelbinding, state.nonce, state.proof = channelbinding, nonce, proof;

		if not proof or not nonce or not channelbinding then
			return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
		end

		if nonce ~= state.clientnonce..state.servernonce then
			return "failure", "malformed-request", "Wrong nonce in client-final-message.";
		end

		local ServerKey = state.server_key;
		local StoredKey = state.stored_key;

		local AuthMessage = state.client_first_message_bare .. "," .. state.server_first_message .. "," .. without_proof;
		local ClientSignature = HMAC_f(StoredKey, AuthMessage);

		-- binaryXOR reads #ClientSignature bytes of the proof; a short or undecodable
		-- proof from the client would otherwise raise a Lua error.
		local ClientProof = base64.decode(proof);
		if not ClientProof or #ClientProof ~= #ClientSignature then
			return "failure", "malformed-request", "Invalid proof in client-final-message.";
		end
		local ClientKey = binaryXOR(ClientSignature, ClientProof);

		if StoredKey == H_f(ClientKey) then
			local ServerSignature = HMAC_f(ServerKey, AuthMessage);
			self["username"] = state.name;
			return "success", "v="..base64.encode(ServerSignature);
		else
			return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";
		end
	end

	local function scram_hash(self, message)
		if not self.state then self["state"] = {} end

		if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
		if not self.state.name then
			-- we are processing client_first_message
			return process_client_first(self, message);
		end
		-- we are processing client_final_message
		return process_client_final(self, message);
	end
	return scram_hash;
end
218
-- Registers the SCRAM mechanisms this module provides.
-- registerMechanism(name, backends, handler) is supplied by the caller; presumably
-- the mechanism is offered when the auth profile provides any of the listed
-- backend callbacks — confirm against util.sasl.
function init(registerMechanism)
	local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
		local mechanism = "SCRAM-"..hash_name;
		local backends = { "plain", "scram_"..hashprep(hash_name) };
		local handler = scram_gen(hash_name:lower(), hash, hmac_hash);
		registerMechanism(mechanism, backends, handler);
	end

	registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
end

return _M;