util.sasl: Convert spaces to tabs
[prosody.git] / util / sasl.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2009 Tobias Markmann
3 -- 
4 --    All rights reserved.
5 --    
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --    
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --    
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14
15 local md5 = require "util.hashes".md5;
16 local log = require "util.logger".init("sasl");
17 local tostring = tostring;
18 local st = require "util.stanza";
19 local generate_uuid = require "util.uuid".generate;
20 local t_insert, t_concat = table.insert, table.concat;
21 local to_byte, to_char = string.byte, string.char;
22 local to_unicode = require "util.encodings".idna.to_unicode;
23 local s_match = string.match;
24 local gmatch = string.gmatch
25 local string = string
26 local math = require "math"
27 local type = type
28 local error = error
29 local print = print
30
31 module "sasl"
32
-- PLAIN SASL mechanism (RFC 4616).
--
-- The client sends a single message "[authzid] NUL authcid NUL passwd".
-- The authzid is accepted but currently ignored.
--
-- @param realm            passed to password_handler as both its 2nd and 3rd argument
-- @param password_handler function(username, realm, realm, "PLAIN") returning
--                         (password_encoding, correct_password):
--                           correct_password == nil   -> "not-authorized"
--                           correct_password == false -> "account-disabled"
--                           password_encoding, if non-nil, is applied to the
--                           claimed password before comparing
-- @return object with mechanism/realm/password_handler fields and
--         object:feed(message) -> "success" | "failure", condition
local function new_plain(realm, password_handler)
	local object = { mechanism = "PLAIN", realm = realm, password_handler = password_handler }

	function object.feed(self, message)
		if message == "" or message == nil then return "failure", "malformed-request" end

		-- Split on the two NUL separators. Only NUL is excluded from the fields,
		-- so usernames and passwords may contain any other byte (including '&').
		-- %z is used instead of "\0" because Lua 5.1 patterns cannot contain
		-- embedded zeros. The authzid (first capture) is not used yet.
		local _, authentication, password = s_match(message, "^([^%z]*)%z([^%z]+)%z([^%z]+)")
		if authentication == nil or password == nil then return "failure", "malformed-request" end

		local password_encoding, correct_password = self.password_handler(authentication, self.realm, self.realm, "PLAIN")

		if correct_password == nil then return "failure", "not-authorized"
		elseif correct_password == false then return "failure", "account-disabled" end

		local claimed_password = password
		if password_encoding ~= nil then claimed_password = password_encoding(password) end

		self.username = authentication
		if claimed_password == correct_password then
			return "success"
		else
			return "failure", "not-authorized"
		end
	end

	return object
end
63
64
-- DIGEST-MD5 SASL mechanism, implementing RFC 2831.
--
-- The returned object is driven by successive object:feed(message) calls:
--   step 1: returns "challenge", <digest-challenge> (message is ignored)
--   step 2: validates the client's digest-response; on success returns
--           "challenge", "rspauth=..."
--   step 3: returns "success" if step 2 succeeded
-- Failures are returned as "failure", <condition>[, <text>].
--
-- @param realm            realm advertised in the challenge; a non-empty realm
--                         in the response must equal it
-- @param password_handler function(username, domain, realm, "DIGEST-MD5", decoder)
--                         returning (password_encoding, Y). password_encoding is
--                         ignored here; Y is the first component of A1
--                         (presumably the precomputed H(user:realm:passwd) —
--                         confirm with the handler implementation). Y == nil means
--                         not-authorized, Y == false means account-disabled.
--                         decoder is utf8tolatin1ifpossible when the client sent
--                         no charset, otherwise nil.
local function new_digest_md5(realm, password_handler)
	--TODO complete support for authzid

	-- Serialize challenge directives into "k=v,..." form, in a fixed order.
	-- nonce, qop and realm values are quoted; the others are not.
	local function serialize(message)
		local data = ""

		if type(message) ~= "table" then error("serialize needs an argument of type table.") end

		-- testing all possible values
		if message["nonce"] then data = data..[[nonce="]]..message.nonce..[[",]] end
		if message["qop"] then data = data..[[qop="]]..message.qop..[[",]] end
		if message["charset"] then data = data..[[charset=]]..message.charset.."," end
		if message["algorithm"] then data = data..[[algorithm=]]..message.algorithm.."," end
		if message["realm"] then data = data..[[realm="]]..message.realm..[[",]] end
		if message["rspauth"] then data = data..[[rspauth=]]..message.rspauth.."," end
		data = data:gsub(",$", "")
		return data
	end

	-- Convert a UTF-8 string to ISO 8859-1 if every character fits, i.e. the
	-- string consists only of ASCII bytes and two-byte sequences whose lead
	-- byte is 0xC0-0xC3. Otherwise return the input unchanged.
	local function utf8tolatin1ifpossible(passwd)
		local i = 1;
		while i <= #passwd do
			local passwd_i = to_byte(passwd, i);
			if passwd_i > 0x7F then
				if passwd_i < 0xC0 or passwd_i > 0xC3 then
					return passwd;
				end
				i = i + 1;
				passwd_i = to_byte(passwd, i);
				-- A lead byte at the very end leaves passwd_i nil; treat the
				-- truncated sequence as not convertible instead of erroring.
				if passwd_i == nil or passwd_i < 0x80 or passwd_i > 0xBF then
					return passwd;
				end
			end
			i = i + 1;
		end

		local p = {};
		i = 1;
		while i <= #passwd do
			local passwd_i = to_byte(passwd, i);
			if passwd_i > 0x7F then
				i = i + 1;
				local passwd_i_1 = to_byte(passwd, i);
				-- 110000xx 10yyyyyy -> code point xx * 64 + yyyyyy
				t_insert(p, to_char(passwd_i%4*64 + passwd_i_1%64));
			else
				t_insert(p, to_char(passwd_i));
			end
			i = i + 1;
		end
		return t_concat(p);
	end

	-- Parse a digest-response into a table of directive -> value.
	-- Values containing commas or double quotes are not supported.
	local function parse(data)
		local message = {}
		for k, v in gmatch(data, [[([%w%-]+)="?([^",]*)"?,?]]) do -- FIXME The hacky regex makes me shudder
			message[k] = v;
		end
		return message;
	end

	local object = { mechanism = "DIGEST-MD5", realm = realm, password_handler = password_handler };

	object.nonce = generate_uuid();
	object.step = 0;
	object.nonce_count = {};

	function object.feed(self, message)
		self.step = self.step + 1;
		if (self.step == 1) then
			local challenge = serialize({	nonce = object.nonce,
							qop = "auth",
							charset = "utf-8",
							algorithm = "md5-sess",
							realm = self.realm});
			return "challenge", challenge;
		elseif (self.step == 2) then
			-- A missing message parses to an empty table and fails the
			-- username check below instead of raising inside gmatch.
			local response = parse(message or "");
			-- check for replay attack
			-- NOTE(review): nothing in this function records nc values in
			-- nonce_count, so this only fires if a caller populates it.
			if response["nc"] then
				if self.nonce_count[response["nc"]] then return "failure", "not-authorized" end
			end

			-- check for username, it's REQUIRED by RFC 2831
			if not response["username"] then
				return "failure", "malformed-request";
			end
			self["username"] = response["username"];

			-- check for nonce, and that it is the one we issued
			if not response["nonce"] then
				return "failure", "malformed-request";
			elseif response["nonce"] ~= tostring(self.nonce) then
				return "failure", "malformed-request";
			end

			if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
			-- nc is required by RFC 2831 and is part of KD below; without this
			-- check a missing nc would raise "attempt to concatenate nil".
			if not response["nc"] then return "failure", "malformed-request", "Missing entry for nc in SASL message." end
			if not response["qop"] then response["qop"] = "auth" end

			if response["realm"] == nil or response["realm"] == "" then
				response["realm"] = "";
			elseif response["realm"] ~= self.realm then
				return "failure", "not-authorized", "Incorrect realm value";
			end

			local decoder;
			if response["charset"] == nil then
				decoder = utf8tolatin1ifpossible;
			elseif response["charset"] ~= "utf-8" then
				return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding with isn't supported by sasl.lua. Supported encodings are latin or utf-8.";
			end

			local domain = "";
			local protocol = "";
			if response["digest-uri"] then
				protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$");
				if protocol == nil or domain == nil then return "failure", "malformed-request" end
			else
				return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
			end

			--TODO maybe realm support
			local password_encoding, Y = self.password_handler(response["username"], to_unicode(domain), response["realm"], "DIGEST-MD5", decoder);
			if Y == nil then return "failure", "not-authorized"
			elseif Y == false then return "failure", "account-disabled" end
			local A1 = "";
			if response.authzid then
				if response.authzid == self.username.."@"..self.realm then
					-- COMPAT
					log("warn", "Client is violating XMPP RFC. See section 6.1 of RFC 3920.");
					A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid;
				else
					-- Unsupported authzid: an unmatchable A1 makes the check fail.
					A1 = "?";
				end
			else
				A1 = Y..":"..response["nonce"]..":"..response["cnonce"];
			end
			local A2 = "AUTHENTICATE:"..protocol.."/"..domain;

			local HA1 = md5(A1, true);
			local HA2 = md5(A2, true);

			local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2;
			local response_value = md5(KD, true);

			if response_value == response["response"] then
				-- calculate rspauth (same digest, A2 without "AUTHENTICATE")
				A2 = ":"..protocol.."/"..domain;

				HA1 = md5(A1, true);
				HA2 = md5(A2, true);

				KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2
				local rspauth = md5(KD, true);
				self.authenticated = true;
				return "challenge", serialize({rspauth = rspauth});
			else
				return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
			end
		elseif self.step == 3 then
			if self.authenticated ~= nil then return "success"
			else return "failure", "malformed-request" end
		end
	end

	return object;
end
247
-- ANONYMOUS SASL mechanism.
-- No credentials are checked: feed() always reports success, and the object
-- carries a freshly generated UUID as its username.
local function new_anonymous(realm, password_handler)
	return {
		mechanism = "ANONYMOUS",
		realm = realm,
		password_handler = password_handler,
		username = generate_uuid(),
		feed = function(self, message)
			return "success"
		end,
	}
end
256
257
-- Create a SASL mechanism object for the named mechanism.
-- Supported names are "PLAIN", "DIGEST-MD5" and "ANONYMOUS"; any other value
-- is logged at debug level and yields nil.
function new(mechanism, realm, password_handler)
	local constructors = {
		["PLAIN"] = new_plain,
		["DIGEST-MD5"] = new_digest_md5,
		["ANONYMOUS"] = new_anonymous,
	}
	local constructor = constructors[mechanism]
	if not constructor then
		log("debug", "Unsupported SASL mechanism: "..tostring(mechanism));
		return nil
	end
	-- Parenthesised to return exactly one value, like the original.
	return (constructor(realm, password_handler))
end
269
270 return _M;