Apply SASLprep in the SASL PLAIN mechanism to conform more closely to RFC 4616.
[prosody.git] / util / sasl.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2009 Tobias Markmann
3 -- 
4 --    All rights reserved.
5 --    
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --    
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --    
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14
-- Dependencies. module() below (Lua 5.1) replaces the environment of the rest of
-- this chunk with the module table, so every global the code needs afterwards is
-- captured as a local here first.
local md5 = require "util.hashes".md5;
local log = require "util.logger".init("sasl");
local tostring = tostring;
local st = require "util.stanza";
local generate_uuid = require "util.uuid".generate;
local t_insert, t_concat = table.insert, table.concat;
local to_byte, to_char = string.byte, string.char;
local to_unicode = require "util.encodings".idna.to_unicode;
-- was require "utii.encodings" (typo), which made the whole module fail to load
local u_e_saslprep = require "util.encodings".stringprep.saslprep;
local s_match = string.match;
local gmatch = string.gmatch;
local string = string;
local math = require "math";
local type = type;
local error = error;
local print = print;

module "sasl"
33
-- Create a SASL PLAIN (RFC 4616) authentication object.
-- @param realm realm passed to password_handler (as both of its realm arguments)
-- @param password_handler function(username, realm, realm, "PLAIN") returning
--        password_encoding, correct_password. correct_password == nil means the
--        user is not authorized, false means the account is disabled.
--        password_encoding, if not nil, is applied to the client's password before
--        comparing it with correct_password.
-- @return table with mechanism, realm, password_handler and a feed() method
local function new_plain(realm, password_handler)
	local object = { mechanism = "PLAIN", realm = realm, password_handler = password_handler};

	-- Verify the client's PLAIN message.
	-- @param message raw message: [authzid] NUL authcid NUL passwd
	-- @return "success", or "failure" followed by an error condition
	function object.feed(self, message)
		if message == "" or message == nil then return "failure", "malformed-request" end

		-- Exactly three NUL-separated fields: the authzid may be empty (it is currently
		-- ignored); authcid and passwd must be non-empty. RFC 4616 allows any non-NUL
		-- octet in them, including '&'.
		local authentication, password = s_match(message, "^[^%z]*%z([^%z]+)%z([^%z]+)$");
		if authentication == nil or password == nil then return "failure", "malformed-request" end

		-- RFC 4616 section 2: authcid and passwd are prepared with SASLprep, and
		-- verification fails if preparation fails or yields an empty string.
		authentication, password = u_e_saslprep(authentication), u_e_saslprep(password);
		if not authentication or authentication == "" or not password or password == "" then
			return "failure", "malformed-request"
		end

		local password_encoding, correct_password = self.password_handler(authentication, self.realm, self.realm, "PLAIN");

		if correct_password == nil then return "failure", "not-authorized"
		elseif correct_password == false then return "failure", "account-disabled" end

		local claimed_password;
		if password_encoding == nil then
			-- Plain-text comparison: prepare the stored password as well, so that
			-- equivalent spellings compare equal. If it cannot be prepared the result
			-- is nil, which never equals the (non-nil) claimed password.
			claimed_password = password;
			correct_password = u_e_saslprep(correct_password);
		else
			claimed_password = password_encoding(password);
		end

		self.username = authentication;
		if claimed_password == correct_password then
			return "success"
		else
			return "failure", "not-authorized"
		end
	end
	return object
end
66
67
-- implementing RFC 2831
-- Create a SASL DIGEST-MD5 authentication object. feed() is called once per step:
--   step 1: returns the server challenge (nonce, qop, charset, algorithm, realm);
--   step 2: checks the client's digest-response and returns rspauth on success;
--   step 3: returns "success" if step 2 verified the client.
-- @param realm realm advertised in the challenge; a non-empty client realm must equal it
-- @param password_handler function(username, domain, realm, "DIGEST-MD5", decoder)
--        returning password_encoding, Y. Y is the first component of A1
--        (NOTE(review): presumably H({username:realm:passwd}) — confirm with callers).
--        Y == nil means not authorized, false means the account is disabled.
--        decoder is utf8tolatin1ifpossible when the client sent no charset, else nil.
local function new_digest_md5(realm, password_handler)
	--TODO complete support for authzid

	-- Serialize the known challenge directives of `message`, in a fixed order, into a
	-- comma-separated directive list. Values are inserted verbatim (no escaping).
	local function serialize(message)
		if type(message) ~= "table" then error("serialize needs an argument of type table.") end

		local data = ""
		-- testing all possible values
		if message["nonce"] then data = data..[[nonce="]]..message.nonce..[[",]] end
		if message["qop"] then data = data..[[qop="]]..message.qop..[[",]] end
		if message["charset"] then data = data..[[charset=]]..message.charset.."," end
		if message["algorithm"] then data = data..[[algorithm=]]..message.algorithm.."," end
		if message["realm"] then data = data..[[realm="]]..message.realm..[[",]] end
		if message["rspauth"] then data = data..[[rspauth=]]..message.rspauth.."," end
		data = data:gsub(",$", "")
		return data
	end

	-- If passwd is UTF-8 made only of ASCII and 2-byte sequences with lead byte
	-- 0xC0-0xC3 (i.e. U+0000..U+00FF), return its ISO-8859-1 form; otherwise return
	-- passwd unchanged.
	local function utf8tolatin1ifpossible(passwd)
		local i = 1;
		while i <= #passwd do
			local passwd_i = to_byte(passwd, i);
			if passwd_i > 0x7F then
				if passwd_i < 0xC0 or passwd_i > 0xC3 then
					return passwd;
				end
				i = i + 1;
				passwd_i = to_byte(passwd, i);
				-- nil: the string ended after a lead byte (previously a crash)
				if passwd_i == nil or passwd_i < 0x80 or passwd_i > 0xBF then
					return passwd;
				end
			end
			i = i + 1;
		end

		local p = {};
		i = 1;
		while i <= #passwd do
			local passwd_i = to_byte(passwd, i);
			if passwd_i > 0x7F then
				i = i + 1;
				local passwd_i_1 = to_byte(passwd, i);
				-- low 2 bits of the lead byte + low 6 bits of the continuation byte
				t_insert(p, to_char(passwd_i%4*64 + passwd_i_1%64));
			else
				t_insert(p, to_char(passwd_i));
			end
			i = i + 1;
		end
		return t_concat(p);
	end

	-- Parse a client digest-response into a table of directive name -> value.
	local function parse(data)
		local fields = {};
		for k, v in gmatch(data or "", [[([%w%-]+)="?([^",]*)"?,?]]) do -- FIXME The hacky regex makes me shudder
			fields[k] = v;
		end
		return fields;
	end

	local object = { mechanism = "DIGEST-MD5", realm = realm, password_handler = password_handler};

	object.nonce = generate_uuid();
	object.step = 0;
	-- nc values already accepted for this nonce (replay protection)
	object.nonce_count = {};

	function object.feed(self, message)
		self.step = self.step + 1;
		if (self.step == 1) then
			local challenge = serialize({ nonce = object.nonce,
				qop = "auth",
				charset = "utf-8",
				algorithm = "md5-sess",
				realm = self.realm});
			return "challenge", challenge;
		elseif (self.step == 2) then
			local response = parse(message);
			-- check for replay attack
			if response["nc"] and self.nonce_count[response["nc"]] then return "failure", "not-authorized" end

			-- check for username, it's REQUIRED by RFC 2831
			if not response["username"] then
				return "failure", "malformed-request";
			end
			self.username = response["username"];

			-- check for nonce, and that it's the one we issued
			if not response["nonce"] or response["nonce"] ~= tostring(self.nonce) then
				return "failure", "malformed-request";
			end

			if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
			-- nc is part of the digest below; without it the concatenation would crash
			if not response["nc"] then return "failure", "malformed-request", "Missing entry for nc in SASL message." end
			if not response["qop"] then response["qop"] = "auth" end

			if response["realm"] == nil or response["realm"] == "" then
				response["realm"] = "";
			elseif response["realm"] ~= self.realm then
				return "failure", "not-authorized", "Incorrect realm value";
			end

			local decoder;
			if response["charset"] == nil then
				decoder = utf8tolatin1ifpossible;
			elseif response["charset"] ~= "utf-8" then
				return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding with isn't supported by sasl.lua. Supported encodings are latin or utf-8.";
			end

			if not response["digest-uri"] then
				return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
			end
			local protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$");
			if protocol == nil or domain == nil then return "failure", "malformed-request" end

			--TODO maybe realm support
			local _, Y = self.password_handler(response["username"], to_unicode(domain), response["realm"], "DIGEST-MD5", decoder);
			if Y == nil then return "failure", "not-authorized"
			elseif Y == false then return "failure", "account-disabled" end

			local A1;
			if response.authzid then
				if response.authzid == self.username.."@"..self.realm then
					-- COMPAT
					log("warn", "Client is violating XMPP RFC. See section 6.1 of RFC 3920.");
					A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid;
				else
					-- Authorizing as another identity isn't supported. Fail here: deriving
					-- A1 from any value the client can predict would let it compute a
					-- matching response without knowing the password.
					return "failure", "invalid-authzid";
				end
			else
				A1 = Y..":"..response["nonce"]..":"..response["cnonce"];
			end

			local HA1 = md5(A1, true);
			local HA2 = md5("AUTHENTICATE:"..protocol.."/"..domain, true);
			local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2;
			local response_value = md5(KD, true);

			if response_value == response["response"] then
				self.nonce_count[response["nc"]] = true;
				-- calculate rspauth: same digest, but A2 has an empty method part
				HA2 = md5(":"..protocol.."/"..domain, true);
				KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2;
				local rspauth = md5(KD, true);
				self.authenticated = true;
				return "challenge", serialize({rspauth = rspauth});
			else
				return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
			end
		elseif self.step == 3 then
			if self.authenticated ~= nil then return "success"
			else return "failure", "malformed-request" end
		end
	end
	return object;
end
250
-- Create a SASL ANONYMOUS authentication object: every feed() call succeeds, and the
-- session is given a freshly generated username.
-- @param realm stored on the object
-- @param password_handler stored on the object; this mechanism never calls it
local function new_anonymous(realm, password_handler)
	local anon = {
		mechanism = "ANONYMOUS",
		realm = realm,
		password_handler = password_handler,
		username = generate_uuid(),
	};
	anon.feed = function(self, message)
		return "success"
	end
	return anon
end
259
260
-- Constructor for each supported mechanism, keyed by its SASL name.
local mechanism_constructors = {
	["PLAIN"] = new_plain,
	["DIGEST-MD5"] = new_digest_md5,
	["ANONYMOUS"] = new_anonymous,
};

-- Create a SASL object for `mechanism`.
-- @return the mechanism object, or nil (after a debug log line) if unsupported
function new(mechanism, realm, password_handler)
	local constructor = mechanism_constructors[mechanism];
	if not constructor then
		log("debug", "Unsupported SASL mechanism: "..tostring(mechanism));
		return nil
	end
	return constructor(realm, password_handler)
end

return _M;