Fail if mechanism has already been selected.
[prosody.git] / util / sasl.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2009 Tobias Markmann
3 --
4 --    All rights reserved.
5 --
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14
15 local md5 = require "util.hashes".md5;
16 local log = require "util.logger".init("sasl");
17 local tostring = tostring;
18 local st = require "util.stanza";
19 local generate_uuid = require "util.uuid".generate;
20 local pairs, ipairs = pairs, ipairs;
21 local t_insert, t_concat = table.insert, table.concat;
22 local to_byte, to_char = string.byte, string.char;
23 local to_unicode = require "util.encodings".idna.to_unicode;
24 local s_match = string.match;
25 local gmatch = string.gmatch
26 local string = string
27 local math = require "math"
28 local type = type
29 local error = error
30 local print = print
31 local setmetatable = setmetatable;
32 local assert = assert;
33
34 require "util.iterators"
35 local keys = keys
36
37 local array = require "util.array"
38 module "sasl"
39
40 --[[
41 Authentication Backend Prototypes:
42
43 plain:
44         function(username, realm)
45                 return password, state;
46         end
47
48 plain_test:
49         function(username, realm, password)
50                 return true or false, state;
51         end
52
53 digest-md5:
54         function(username, realm, encoding)
55                 return digesthash, state;
56         end
57
58 digest-md5-test:
59         function(username, realm, encoding, digesthash)
60                 return true or false, state;
61         end
62 ]]
63
-- Prototype shared by every session object returned by new(); it is used
-- directly as their metatable, so method lookups fall through to it.
local method = {};
method.__index = method;

-- Registry state filled by registerMechanism():
--   mechanisms        : mechanism name -> handler function(self, message)
--   backend_mechanism : profile backend key -> list of mechanism names it enables
local mechanisms, backend_mechanism = {}, {};
68
-- Register a new SASL mechanism.
--   name     - mechanism name, e.g. "PLAIN"; used as the key for select()
--   backends - a single profile backend key (string) or a list of them; the
--              mechanism is offered by method:mechanisms() whenever the
--              session profile contains one of these keys
--   f        - handler, invoked by method:process() as f(self, message)
local function registerMechanism(name, backends, f)
        assert(type(name) == "string", "Parameter name MUST be a string.");
        assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table.");
        assert(type(f) == "function", "Parameter f MUST be a function.");

        -- The assert above allows a lone string, but ipairs() needs a list:
        -- it errors on a string in Lua 5.1/5.2 and yields nothing in 5.3+.
        if type(backends) == "string" then
                backends = { backends };
        end

        mechanisms[name] = f;
        for _, backend_name in ipairs(backends) do
                local names = backend_mechanism[backend_name];
                if names == nil then
                        names = {};
                        backend_mechanism[backend_name] = names;
                end
                names[#names + 1] = name;
        end
end
80
-- Create a new SASL session object for authenticating one client.
--   realm   - stored as self.realm and passed to the profile backends
--   profile - table of backend functions keyed by backend name (for example
--             "plain" or "plain_test"); its keys decide which registered
--             mechanisms method:mechanisms() offers
-- Returns a table whose metatable is `method`.
function new(realm, profile)
        -- Must be local: the previous bare assignment created a module-level
        -- variable holding the most recently created session, overwritten on
        -- every call.
        local sasl_i = { profile = profile, realm = realm };
        return setmetatable(sasl_i, method);
end
87
-- Get the list of SASL mechanisms usable with this session's profile.
-- A registered mechanism is offered when the profile contains at least one of
-- the backend keys it was registered with (see registerMechanism). The set of
-- offered names is also cached in self.possible_mechanisms. The returned
-- array is built from a hash set, so its order should not be relied upon.
function method:mechanisms()
        local offered = {};
        for backend in pairs(self.profile) do
                local names = backend_mechanism[backend];
                if names then
                        for _, name in ipairs(names) do
                                offered[name] = true;
                        end
                end
        end
        self.possible_mechanisms = offered;
        return array.collect(keys(offered));
end
102
-- Choose the mechanism this session will use, by registered name.
-- Returns true when a handler was found and stored in self.mech_i.
-- Returns false when a mechanism was already chosen (the earlier choice is
-- kept), or when no mechanism of that name is registered.
function method:select(mechanism)
        local already_chosen = self.mech_i;
        if already_chosen then return false; end

        local handler = mechanisms[mechanism];
        self.mech_i = handler;
        return handler ~= nil;
end
115
-- Feed one client message to the selected mechanism's handler.
-- A nil or empty message is answered with "failure", "malformed-request";
-- anything else is handed to self.mech_i(self, message), whose results are
-- returned unchanged.
-- NOTE(review): nothing checks that select() succeeded first; without a
-- selected mechanism the call below raises a Lua error.
function method:process(message)
        local missing = (message == nil) or (message == "");
        if not missing then
                return self.mech_i(self, message);
        end
        return "failure", "malformed-request";
end
121
--=========================
--SASL PLAIN according to RFC 4616

-- Split a PLAIN message "[authzid] NUL authcid NUL passwd" into its parts.
-- Returns authzid, authcid, passwd; authzid may be "". Returns nil unless the
-- message has exactly two NUL separators and authcid and passwd are non-empty.
-- A plain find is used instead of patterns because Lua 5.1 patterns need %z
-- for NUL while later versions deprecate %z in favour of a literal "\0".
local function split_plain_message(message)
        local first = message:find("\0", 1, true);
        if not first then return nil; end
        local second = message:find("\0", first + 1, true);
        if not second or message:find("\0", second + 1, true) then return nil; end

        local authzid = message:sub(1, first - 1);
        local authcid = message:sub(first + 1, second - 1);
        local passwd = message:sub(second + 1);
        if authcid == "" or passwd == "" then return nil; end
        return authzid, authcid, passwd;
end

-- PLAIN handler, invoked as handler(self, message).
-- Uses self.profile.plain (returns the stored password and a state) when
-- present, otherwise self.profile.plain_test (returns a verdict and a state).
-- Sets self.username to the authcid. Returns "success", or "failure" plus one
-- of "malformed-request", "account-disabled" (falsy state, or no PLAIN
-- backend in the profile) or "not-authorized".
local function sasl_mechanism_plain(self, message)
        -- The authzid is parsed but not used yet.
        local _, authentication, password = split_plain_message(message);
        if authentication == nil then
                return "failure", "malformed-request";
        end

        local correct, state = false, false;
        if self.profile.plain then
                local correct_password;
                correct_password, state = self.profile.plain(authentication, self.realm);
                correct = (correct_password == password);
        elseif self.profile.plain_test then
                correct, state = self.profile.plain_test(authentication, self.realm, password);
        end

        self.username = authentication;
        if not state then
                return "failure", "account-disabled";
        end

        if correct then
                return "success";
        else
                return "failure", "not-authorized";
        end
end
-- Offer PLAIN whenever a session profile has a "plain" (password lookup) or a
-- "plain_test" (password check) backend.
do
        local plain_backends = { "plain", "plain_test" };
        registerMechanism("PLAIN", plain_backends, sasl_mechanism_plain);
end
155
--=========================
--SASL DIGEST-MD5 according to RFC 2831
--
-- NOTE(review): DIGEST-MD5 is not passed to registerMechanism() here, so
-- select() cannot choose it. The handler still calls self.credentials_handler,
-- which new() does not set, and method:process() rejects empty messages even
-- though step 1 ignores the message. It looks only partly ported to the
-- profile-based backends; confirm before enabling it.

-- Build a DIGEST-MD5 challenge string from a table of fields: realm, nonce
-- and qop are quoted, charset, algorithm and rspauth are emitted bare.
local function digest_serialize(message)
        local data = ""

        if type(message) ~= "table" then error("serialize needs an argument of type table.") end

        -- testing all possible values
        if message["realm"] then data = data..[[realm="]]..message.realm..[[",]] end
        if message["nonce"] then data = data..[[nonce="]]..message.nonce..[[",]] end
        if message["qop"] then data = data..[[qop="]]..message.qop..[[",]] end
        if message["charset"] then data = data..[[charset=]]..message.charset.."," end
        if message["algorithm"] then data = data..[[algorithm=]]..message.algorithm.."," end
        if message["rspauth"] then data = data..[[rspauth=]]..message.rspauth.."," end
        data = data:gsub(",$", "")
        return data
end

-- Convert a UTF-8 string to ISO-8859-1 when every non-ASCII character is a
-- two-byte sequence with lead byte 0xC0-0xC3; otherwise return it unchanged.
local function utf8tolatin1ifpossible(passwd)
        local i = 1;
        while i <= #passwd do
                local passwd_i = to_byte(passwd, i);
                if passwd_i > 0x7F then
                        if passwd_i < 0xC0 or passwd_i > 0xC3 then
                                return passwd;
                        end
                        i = i + 1;
                        passwd_i = to_byte(passwd, i);
                        -- A lead byte at the very end has no continuation byte;
                        -- to_byte then returns nil, which used to crash here.
                        if passwd_i == nil or passwd_i < 0x80 or passwd_i > 0xBF then
                                return passwd;
                        end
                end
                i = i + 1;
        end

        local p = {};
        i = 1;
        while (i <= #passwd) do
                local passwd_i = to_byte(passwd, i);
                if passwd_i > 0x7F then
                        i = i + 1;
                        local passwd_i_1 = to_byte(passwd, i);
                        -- low 2 bits of the lead byte + low 6 bits of the continuation
                        t_insert(p, to_char(passwd_i%4*64 + passwd_i_1%64));
                else
                        t_insert(p, to_char(passwd_i));
                end
                i = i + 1;
        end
        return t_concat(p);
end

-- Parse a client response of comma-separated key=value / key="value" pairs
-- into a table. Quoted values containing commas or quotes are not handled.
local function digest_parse(data)
        local message = {}
        for k, v in gmatch(data, [[([%w%-]+)="?([^",]*)"?,?]]) do -- FIXME The hacky regex makes me shudder
                message[k] = v;
        end
        return message;
end

-- DIGEST-MD5 handler, invoked as handler(self, message); keeps its progress
-- in self.step / self.nonce / self.nonce_count.
--   step 1: returns "challenge" with nonce, qop, charset, algorithm and realm
--   step 2: checks the client response; on a match returns "challenge" with
--           rspauth and sets self.authenticated
--   step 3: returns "success" if step 2 authenticated the client
-- Failures return "failure", a condition, and sometimes a text message.
local function sasl_mechanism_digest_md5(self, message)
        --TODO complete support for authzid

        if not self.nonce then
                self.nonce = generate_uuid();
                self.step = 0;
                self.nonce_count = {};
        end

        self.step = self.step + 1;
        if (self.step == 1) then
                -- Previously read `object.nonce`; `object` is undefined, so this
                -- step always raised an error.
                local challenge = digest_serialize({
                        nonce = self.nonce,
                        qop = "auth",
                        charset = "utf-8",
                        algorithm = "md5-sess",
                        realm = self.realm,
                });
                return "challenge", challenge;
        elseif (self.step == 2) then
                local response = digest_parse(message);

                -- nc is required (RFC 2831 2.1.2) and is concatenated into KD
                -- below, where a missing value used to raise an error.
                if not response["nc"] then
                        return "failure", "malformed-request", "Missing entry for nc in SASL message.";
                end
                -- check for replay attack, then remember this nonce count
                if self.nonce_count[response["nc"]] then return "failure", "not-authorized" end
                self.nonce_count[response["nc"]] = true;

                -- check for username, it's REQUIRED by RFC 2831
                if not response["username"] then
                        return "failure", "malformed-request";
                end
                self["username"] = response["username"];

                -- check for nonce, ...
                if not response["nonce"] then
                        return "failure", "malformed-request";
                else
                        -- check if it's the right nonce
                        if response["nonce"] ~= tostring(self.nonce) then return "failure", "malformed-request" end
                end

                if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
                if not response["qop"] then response["qop"] = "auth" end

                if response["realm"] == nil or response["realm"] == "" then
                        response["realm"] = "";
                elseif response["realm"] ~= self.realm then
                        return "failure", "not-authorized", "Incorrect realm value";
                end

                -- Without a charset the client may have sent Latin-1 credentials;
                -- the decoder is passed on so the backend can match that form.
                local decoder;
                if response["charset"] == nil then
                        decoder = utf8tolatin1ifpossible;
                elseif response["charset"] ~= "utf-8" then
                        return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding with isn't supported by sasl.lua. Supported encodings are latin or utf-8.";
                end

                local domain = "";
                local protocol = "";
                if response["digest-uri"] then
                        protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$");
                        if protocol == nil or domain == nil then return "failure", "malformed-request" end
                else
                        return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
                end

                --TODO maybe realm support
                self.username = response["username"];
                -- The first return value (password encoding) is currently unused.
                local _, Y = self.credentials_handler("DIGEST-MD5", response["username"], self.realm, response["realm"], decoder);
                if Y == nil then return "failure", "not-authorized"
                elseif Y == false then return "failure", "account-disabled" end

                local A1 = "";
                if response.authzid then
                        if response.authzid == self.username.."@"..self.realm then
                                -- COMPAT
                                log("warn", "Client is violating XMPP RFC. See section 6.1 of RFC 3920.");
                                A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid;
                        else
                                -- Any other authzid is unsupported: a bogus A1 makes
                                -- the response comparison below fail.
                                A1 = "?";
                        end
                else
                        A1 = Y..":"..response["nonce"]..":"..response["cnonce"];
                end
                local A2 = "AUTHENTICATE:"..protocol.."/"..domain;

                local HA1 = md5(A1, true);
                local HA2 = md5(A2, true);

                local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2;
                local response_value = md5(KD, true);

                if response_value == response["response"] then
                        -- calculate rspauth: same as above but A2 has no "AUTHENTICATE"
                        A2 = ":"..protocol.."/"..domain;
                        HA2 = md5(A2, true);

                        KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2
                        local rspauth = md5(KD, true);
                        self.authenticated = true;
                        return "challenge", digest_serialize({rspauth = rspauth});
                else
                        return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
                end
        elseif self.step == 3 then
                if self.authenticated ~= nil then return "success"
                else return "failure", "malformed-request" end
        end

        -- Any message after step 3 is out of sequence; previously nothing was
        -- returned, leaving the caller with a nil status.
        return "failure", "malformed-request";
end
336
-- Hand back the module table that `module "sasl"` (Lua 5.1 module()) filled
-- with this file's non-local definitions (new, method functions reachable via
-- new()), so require() returns it to the caller.
337 return _M;