Store stage in SASL object.
[prosody.git] / util / sasl.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2009 Tobias Markmann
3 --
4 --    All rights reserved.
5 --
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14
15 local md5 = require "util.hashes".md5;
16 local log = require "util.logger".init("sasl");
17 local tostring = tostring;
18 local st = require "util.stanza";
19 local generate_uuid = require "util.uuid".generate;
20 local pairs, ipairs = pairs, ipairs;
21 local t_insert, t_concat = table.insert, table.concat;
22 local to_byte, to_char = string.byte, string.char;
23 local to_unicode = require "util.encodings".idna.to_unicode;
24 local s_match = string.match;
25 local gmatch = string.gmatch
26 local string = string
27 local math = require "math"
28 local type = type
29 local error = error
30 local print = print
31 local setmetatable = setmetatable;
32 local assert = assert;
33
34 require "util.iterators"
35 local keys = keys
36
37 local array = require "util.array"
38 module "sasl"
39
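--[[
Authentication Backend Prototypes:

A profile (the second argument to new()) maps backend names to functions.
For PLAIN, a falsy `state` result is reported to the client as
"account-disabled". The DIGEST-MD5 prototypes below are not yet consulted:
the DIGEST-MD5 handler currently asks self.credentials_handler instead.

plain:
	function(username, realm)
		return password, state;
	end

plain_test:
	function(username, realm, password)
		return true or false, state;
	end

digest-md5:
	function(username, realm, encoding)
		return digesthash, state;
	end

digest-md5-test:
	function(username, realm, encoding, digesthash)
		return true or false, state;
	end
]]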
63
-- `method` is the metatable/prototype shared by every object from new();
-- `mechanisms` maps a mechanism name to its handler(self, message);
-- `backend_mechanism` maps a profile backend name to the list of mechanism
-- names that backend enables.
local method, mechanisms, backend_mechanism = {}, {}, {};
method.__index = method;
68
69 -- register a new SASL mechanims
70 local function registerMechanism(name, backends, f)
71         assert(type(name) == "string", "Parameter name MUST be a string.");
72         assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table.");
73         assert(type(f) == "function", "Parameter f MUST be a function.");
74         mechanisms[name] = f
75         for _, backend_name in ipairs(backends) do
76                 if backend_mechanism[backend_name] == nil then backend_mechanism[backend_name] = {}; end
77                 t_insert(backend_mechanism[backend_name], name);
78         end
79 end
80
-- Create a new SASL object which can be used to authenticate a client.
--   realm   - stored as self.realm and passed to the profile backends.
--   profile - table mapping backend names to backend functions.
-- Returns the new object; its methods come from the `method` prototype.
function new(realm, profile)
	-- Must be local: under module() a bare assignment would store the object
	-- as a field of the shared module table.
	local sasl_i = { profile = profile, realm = realm };
	return setmetatable(sasl_i, method);
end
87
-- Get the list of SASL mechanisms this object can offer: every registered
-- mechanism enabled by at least one backend present in self.profile.
-- The same set is also stored on the object as self.possible_mechanisms
-- (a mechanism-name -> true map).
function method:mechanisms()
	local supported = {};
	for backend in pairs(self.profile) do
		local provided = backend_mechanism[backend];
		if provided then
			for _, mechanism in ipairs(provided) do
				supported[mechanism] = true;
			end
		end
	end
	self["possible_mechanisms"] = supported;
	return array.collect(keys(supported));
end
102
-- Choose the mechanism that subsequent process() calls are routed to.
-- Returns true when `mechanism` names a registered mechanism; otherwise
-- returns false and leaves self.mech_i nil.
function method:select(mechanism)
	local handler = mechanisms[mechanism];
	self.mech_i = handler;
	return handler ~= nil;
end
111
-- Feed a client message to the selected mechanism.
-- A missing or empty message is rejected as malformed without calling the
-- mechanism; otherwise the handler's results are returned unchanged.
function method:process(message)
	local empty = (message == nil) or (message == "");
	if empty then
		return "failure", "malformed-request";
	end
	return self.mech_i(self, message);
end
117
--=========================
--SASL PLAIN according to RFC 4616
-- Handle a PLAIN message of the form  [authzid] NUL authcid NUL passwd.
-- Returns "success", or "failure" followed by a SASL condition:
--   "malformed-request" when the message does not have that shape,
--   "account-disabled"  when the backend's state result is falsy,
--   "not-authorized"    when the password does not match.
-- self.username is set to the authcid whenever the message parses.
local function sasl_mechanism_plain(self, message)
	-- Anchored so that exactly two NUL separators are accepted, with a
	-- non-empty authcid and password and an optional authzid.
	-- TODO: the authzid (first capture) is not checked yet.
	local _, authentication, password = s_match(message, "^([^%z]*)%z([^%z]+)%z([^%z]+)$");
	if authentication == nil then
		return "failure", "malformed-request";
	end

	local correct, state = false, false;
	if self.profile.plain then
		-- "plain" backend: returns the stored password to compare against.
		local correct_password;
		correct_password, state = self.profile.plain(authentication, self.realm);
		correct = (correct_password == password);
	elseif self.profile.plain_test then
		-- "plain_test" backend: verifies the password itself.
		correct, state = self.profile.plain_test(authentication, self.realm, password);
	end

	self.username = authentication;
	if not state then
		return "failure", "account-disabled";
	end

	if correct then
		return "success";
	end
	return "failure", "not-authorized";
end
-- PLAIN is offered whenever the profile supplies either a "plain" or a
-- "plain_test" backend.
local plain_backends = { "plain", "plain_test" };
registerMechanism("PLAIN", plain_backends, sasl_mechanism_plain);
151
--=========================
--SASL DIGEST-MD5 according to RFC 2831
-- Multi-step DIGEST-MD5 handler. Its progress is kept on the SASL object:
-- self.step counts calls, self.nonce holds the server nonce and
-- self.nonce_count records nc values that have already been accepted.
--   step 1: returns "challenge" plus the server's digest-challenge.
--   step 2: checks the client's digest-response; on a match sets
--           self.authenticated and returns "challenge" plus rspauth.
--   step 3: returns "success" if step 2 authenticated the client.
-- Failures return "failure", a SASL condition and, for some checks, a
-- human-readable description.
-- NOTE(review): this handler is not passed to registerMechanism, and it
-- still fetches credentials via self.credentials_handler rather than
-- self.profile; confirm the intended backend before registering it.
local function sasl_mechanism_digest_md5(self, message)
	--TODO complete support for authzid

	-- Serialise a table of directives into a comma-separated string.
	-- realm, nonce and qop are quoted; charset, algorithm and rspauth are
	-- not. Any other keys are ignored.
	local function serialize(message)
		local data = ""

		if type(message) ~= "table" then error("serialize needs an argument of type table.") end

		-- testing all possible values
		if message["realm"] then data = data..[[realm="]]..message.realm..[[",]] end
		if message["nonce"] then data = data..[[nonce="]]..message.nonce..[[",]] end
		if message["qop"] then data = data..[[qop="]]..message.qop..[[",]] end
		if message["charset"] then data = data..[[charset=]]..message.charset.."," end
		if message["algorithm"] then data = data..[[algorithm=]]..message.algorithm.."," end
		if message["rspauth"] then data = data..[[rspauth=]]..message.rspauth.."," end
		data = data:gsub(",$", "")
		return data
	end

	-- Convert a UTF-8 string to ISO-8859-1 when every multi-byte sequence is
	-- a two-byte one with lead byte 0xC0-0xC3 (i.e. a code point below
	-- U+0100); otherwise return the string unchanged.
	local function utf8tolatin1ifpossible(passwd)
		local i = 1;
		while i <= #passwd do
			local passwd_i = to_byte(passwd:sub(i, i));
			if passwd_i > 0x7F then
				if passwd_i < 0xC0 or passwd_i > 0xC3 then
					return passwd;
				end
				i = i + 1;
				passwd_i = to_byte(passwd:sub(i, i));
				-- A lead byte at the very end has no continuation byte
				-- (to_byte("") is nil); treat it as not convertible.
				if passwd_i == nil or passwd_i < 0x80 or passwd_i > 0xBF then
					return passwd;
				end
			end
			i = i + 1;
		end

		local p = {};
		i = 1;
		while (i <= #passwd) do
			local passwd_i = to_byte(passwd:sub(i, i));
			if passwd_i > 0x7F then
				i = i + 1;
				local passwd_i_1 = to_byte(passwd:sub(i, i));
				-- Low 2 bits of the lead byte followed by the low 6 bits of
				-- the continuation byte give the Latin-1 code point.
				t_insert(p, to_char(passwd_i%4*64 + passwd_i_1%64));
			else
				t_insert(p, to_char(passwd_i));
			end
			i = i + 1;
		end
		return t_concat(p);
	end

	-- Parse a digest-response into a table of directive name -> value.
	local function parse(data)
		local message = {}
		for k, v in gmatch(data, [[([%w%-]+)="?([^",]*)"?,?]]) do -- FIXME The hacky regex makes me shudder
			message[k] = v;
		end
		return message;
	end

	-- First call on this object: create the nonce and per-session state.
	if not self.nonce then
		self.nonce = generate_uuid();
		self.step = 0;
		self.nonce_count = {};
	end

	self.step = self.step + 1;
	if (self.step == 1) then
		local challenge = serialize({	nonce = self.nonce,
						qop = "auth",
						charset = "utf-8",
						algorithm = "md5-sess",
						realm = self.realm});
		return "challenge", challenge;
	elseif (self.step == 2) then
		local response = parse(message);
		-- check for replay attack
		if response["nc"] then
			if self.nonce_count[response["nc"]] then return "failure", "not-authorized" end
		end

		-- check for username, it's REQUIRED by RFC 2831
		if not response["username"] then
			return "failure", "malformed-request";
		end
		self["username"] = response["username"];

		-- check for nonce, ...
		if not response["nonce"] then
			return "failure", "malformed-request";
		else
			-- check if it's the right nonce
			if response["nonce"] ~= tostring(self.nonce) then return "failure", "malformed-request" end
		end

		if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
		-- nc is concatenated into the digest below; without this check a
		-- missing nc would raise an error instead of failing cleanly.
		if not response["nc"] then return "failure", "malformed-request", "Missing entry for nc in SASL message." end
		if not response["qop"] then response["qop"] = "auth" end

		if response["realm"] == nil or response["realm"] == "" then
			response["realm"] = "";
		elseif response["realm"] ~= self.realm then
			return "failure", "not-authorized", "Incorrect realm value";
		end

		-- Without a charset directive the backend is given a UTF-8 -> Latin-1
		-- converter; "utf-8" is the only explicit charset accepted.
		local decoder;
		if response["charset"] == nil then
			decoder = utf8tolatin1ifpossible;
		elseif response["charset"] ~= "utf-8" then
			return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding with isn't supported by sasl.lua. Supported encodings are latin or utf-8.";
		end

		local domain = "";
		local protocol = "";
		if response["digest-uri"] then
			protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$");
			if protocol == nil or domain == nil then return "failure", "malformed-request" end
		else
			return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
		end

		--TODO maybe realm support
		self.username = response["username"];
		-- Y starts A1 below (presumably H(username:realm:password) per
		-- RFC 2831 -- confirm against the credentials handler). A nil Y is
		-- reported as not-authorized, false as account-disabled.
		local _, Y = self.credentials_handler("DIGEST-MD5", response["username"], self.realm, response["realm"], decoder);
		if Y == nil then return "failure", "not-authorized"
		elseif Y == false then return "failure", "account-disabled" end
		local A1 = "";
		if response.authzid then
			if response.authzid == self.username.."@"..self.realm then
				-- COMPAT
				log("warn", "Client is violating XMPP RFC. See section 6.1 of RFC 3920.");
				A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid;
			else
				A1 = "?";
			end
		else
			A1 = Y..":"..response["nonce"]..":"..response["cnonce"];
		end
		local A2 = "AUTHENTICATE:"..protocol.."/"..domain;

		local HA1 = md5(A1, true);
		local HA2 = md5(A2, true);

		local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2;
		local response_value = md5(KD, true);

		if response_value == response["response"] then
			-- Remember this nc so the replay check above rejects it later.
			self.nonce_count[response["nc"]] = true;

			-- calculate rspauth: same HA1, A2 without "AUTHENTICATE"
			HA2 = md5(":"..protocol.."/"..domain, true);

			KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2
			local rspauth = md5(KD, true);
			self.authenticated = true;
			return "challenge", serialize({rspauth = rspauth});
		else
			return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
		end
	elseif self.step == 3 then
		if self.authenticated ~= nil then return "success"
		else return "failure", "malformed-request" end
	else
		-- No further messages are expected after step 3.
		return "failure", "malformed-request";
	end
end
332
-- Return the module table (set up by module "sasl") to require().
return _M