util.sasl: Improved a log message.
[prosody.git] / util / sasl.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2009 Tobias Markmann
3 --
4 --    All rights reserved.
5 --
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14
15 local md5 = require "util.hashes".md5;
16 local log = require "util.logger".init("sasl");
17 local tostring = tostring;
18 local st = require "util.stanza";
19 local generate_uuid = require "util.uuid".generate;
20 local t_insert, t_concat = table.insert, table.concat;
21 local to_byte, to_char = string.byte, string.char;
22 local to_unicode = require "util.encodings".idna.to_unicode;
23 local s_match = string.match;
24 local gmatch = string.gmatch
25 local string = string
26 local math = require "math"
27 local type = type
28 local error = error
29 local print = print
30
31 module "sasl"
32
-- PLAIN mechanism (RFC 4616).
--
-- Credentials handler:
--   Arguments: ("PLAIN", user, host, password)
--   Returns: true (success) | false (fail) | nil (user unknown)
--   A nil result is reported to the client as "account-disabled", false as
--   "not-authorized".
local function new_plain(realm, credentials_handler)
	local object = { mechanism = "PLAIN", realm = realm, credentials_handler = credentials_handler}

	-- Process the client's single message: [authzid] NUL authcid NUL passwd.
	-- Returns "success", or "failure" followed by a SASL error condition.
	function object.feed(self, message)
		if message == "" or message == nil then return "failure", "malformed-request" end

		-- The pattern is anchored so that exactly three NUL-separated fields are
		-- accepted: an optional authzid (currently ignored), a non-empty authcid
		-- and a non-empty password; any extra field makes the message malformed.
		-- %z is used because Lua 5.1 patterns cannot contain a literal NUL.
		local _, authentication, password = s_match(message, "^([^%z]*)%z([^%z]+)%z([^%z]+)$")

		if authentication == nil or password == nil then return "failure", "malformed-request" end
		self.username = authentication
		local auth_success = self.credentials_handler("PLAIN", self.username, self.realm, password)

		if auth_success then
			return "success"
		elseif auth_success == nil then
			return "failure", "account-disabled"
		else
			return "failure", "not-authorized"
		end
	end
	return object
end
59
-- DIGEST-MD5 mechanism, implementing RFC 2831.
--
-- credentials_handler:
--   Arguments: (mechanism, node, domain, realm, decoder)
--     mechanism: always "DIGEST-MD5"
--     node:      the username sent by the client
--     domain:    the realm this object was created with
--     realm:     the realm sent by the client ("" if it sent none)
--     decoder:   nil when the client declared charset=utf-8; otherwise a
--                function converting a UTF-8 string to ISO 8859-1 when every
--                character fits (ISO 8859-1 is the RFC 2831 default charset)
--   Returns: password encoding (currently ignored), Y
--     Y == nil   -> failure "not-authorized"
--     Y == false -> failure "account-disabled"
--     otherwise Y is used verbatim as the first component of A1, so it is
--     presumably H(username ":" realm ":" password) as defined in RFC 2831
--     section 2.1.2.1 rather than the plaintext -- NOTE(review): confirm.
local function new_digest_md5(realm, credentials_handler)
	--TODO complete support for authzid

	-- Serialize a challenge table into a comma-separated directive list.
	-- Directives are emitted in a fixed order; realm, nonce and qop values are
	-- quoted, charset, algorithm and rspauth are not.
	local function serialize(message)
		if type(message) ~= "table" then error("serialize needs an argument of type table.") end

		local parts = {};
		if message["realm"] then t_insert(parts, 'realm="'..message.realm..'"') end
		if message["nonce"] then t_insert(parts, 'nonce="'..message.nonce..'"') end
		if message["qop"] then t_insert(parts, 'qop="'..message.qop..'"') end
		if message["charset"] then t_insert(parts, "charset="..message.charset) end
		if message["algorithm"] then t_insert(parts, "algorithm="..message.algorithm) end
		if message["rspauth"] then t_insert(parts, "rspauth="..message.rspauth) end
		return t_concat(parts, ",");
	end

	-- Convert a UTF-8 string to ISO 8859-1 when every character is at most
	-- U+00FF (ASCII, or a two-byte sequence with lead byte 0xC0-0xC3).
	-- Otherwise, including for malformed or truncated sequences, the input is
	-- returned unchanged.
	local function utf8tolatin1ifpossible(passwd)
		-- First pass: validate that the conversion is possible.
		local i = 1;
		while i <= #passwd do
			local passwd_i = to_byte(passwd, i);
			if passwd_i > 0x7F then
				if passwd_i < 0xC0 or passwd_i > 0xC3 then
					return passwd;
				end
				i = i + 1;
				passwd_i = to_byte(passwd, i);
				-- nil when the string ends right after a lead byte
				if passwd_i == nil or passwd_i < 0x80 or passwd_i > 0xBF then
					return passwd;
				end
			end
			i = i + 1;
		end

		-- Second pass: 110000xx 10yyyyyy -> xxyyyyyy
		local latin1 = {};
		i = 1;
		while i <= #passwd do
			local lead = to_byte(passwd, i);
			if lead > 0x7F then
				i = i + 1;
				local cont = to_byte(passwd, i);
				t_insert(latin1, to_char(lead % 4 * 64 + cont % 64));
			else
				t_insert(latin1, to_char(lead));
			end
			i = i + 1;
		end
		return t_concat(latin1);
	end

	-- Parse a client response into a key -> value table. A nil message yields
	-- an empty table, which is then rejected as missing a username.
	local function parse(data)
		local message = {}
		for k, v in gmatch(data or "", [[([%w%-]+)="?([^",]*)"?,?]]) do -- FIXME The hacky regex makes me shudder
			message[k] = v;
		end
		return message;
	end

	local object = { mechanism = "DIGEST-MD5", realm = realm, credentials_handler = credentials_handler};

	object.nonce = generate_uuid();
	object.step = 0;
	object.nonce_count = {};

	-- Three-step exchange:
	--   1. return the initial challenge;
	--   2. verify the client's digest-response and return rspauth;
	--   3. report success if step 2 authenticated the client.
	function object.feed(self, message)
		self.step = self.step + 1;
		if (self.step == 1) then
			local challenge = serialize({	nonce = object.nonce,
											qop = "auth",
											charset = "utf-8",
											algorithm = "md5-sess",
											realm = self.realm});
			return "challenge", challenge;
		elseif (self.step == 2) then
			local response = parse(message);
			-- check for replay attack: a nonce-count value may only be used once
			if response["nc"] then
				if self.nonce_count[response["nc"]] then return "failure", "not-authorized" end
			end

			-- check for username, it's REQUIRED by RFC 2831
			if not response["username"] then
				return "failure", "malformed-request";
			end
			self["username"] = response["username"];

			-- check for nonce, ...
			if not response["nonce"] then
				return "failure", "malformed-request";
			else
				-- check if it's the right nonce
				if response["nonce"] ~= tostring(self.nonce) then return "failure", "malformed-request" end
			end

			if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
			-- nc is REQUIRED whenever qop is used (and qop is always used, it
			-- defaults to "auth" below). It is concatenated into KD, so a missing
			-- value must be rejected here rather than raising a Lua error later.
			if not response["nc"] then return "failure", "malformed-request", "Missing entry for nc in SASL message." end
			-- remember this nonce-count so the replay check above rejects reuse
			self.nonce_count[response["nc"]] = true;
			if not response["qop"] then response["qop"] = "auth" end

			if response["realm"] == nil or response["realm"] == "" then
				response["realm"] = "";
			elseif response["realm"] ~= self.realm then
				return "failure", "not-authorized", "Incorrect realm value";
			end

			-- Without a charset directive the client uses ISO 8859-1, so the
			-- handler is given a converter; with charset=utf-8 no conversion.
			local decoder;
			if response["charset"] == nil then
				decoder = utf8tolatin1ifpossible;
			elseif response["charset"] ~= "utf-8" then
				return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding which isn't supported by sasl.lua. Supported encodings are latin or utf-8.";
			end

			local domain = "";
			local protocol = "";
			if response["digest-uri"] then
				protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$");
				if protocol == nil or domain == nil then return "failure", "malformed-request" end
			else
				return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
			end

			--TODO maybe realm support
			self.username = response["username"];
			-- the first return value (password encoding) is not used
			local _, Y = self.credentials_handler("DIGEST-MD5", response["username"], self.realm, response["realm"], decoder);
			if Y == nil then return "failure", "not-authorized"
			elseif Y == false then return "failure", "account-disabled" end
			local A1 = "";
			if response.authzid then
				if response.authzid == self.username or response.authzid == self.username.."@"..self.realm then
					-- COMPAT
					log("warn", "Client is violating RFC 3920 (section 6.1, point 7).");
					A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid;
				else
					return "failure", "invalid-authzid";
				end
			else
				A1 = Y..":"..response["nonce"]..":"..response["cnonce"];
			end
			local A2 = "AUTHENTICATE:"..protocol.."/"..domain;

			local HA1 = md5(A1, true);
			local HA2 = md5(A2, true);

			local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2;
			local response_value = md5(KD, true);

			if response_value == response["response"] then
				-- calculate rspauth: same digest, but A2 has an empty method
				-- (RFC 2831 section 2.1.3); HA1 does not change.
				A2 = ":"..protocol.."/"..domain;
				HA2 = md5(A2, true);

				KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2;
				local rspauth = md5(KD, true);
				self.authenticated = true;
				return "challenge", serialize({rspauth = rspauth});
			else
				return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
			end
		elseif self.step == 3 then
			if self.authenticated ~= nil then return "success"
			else return "failure", "malformed-request" end
		end
	end
	return object;
end
245
-- ANONYMOUS mechanism.
--
-- Credentials handler: Can be nil. If specified, should take the mechanism as
-- the only argument, and return true for OK, or false for not-OK (TODO).
-- NOTE: the handler is only stored on the object and is not consulted yet:
-- every message is accepted, and the session gets a freshly generated
-- username.
local function new_anonymous(realm, credentials_handler)
	local anon = {
		mechanism = "ANONYMOUS",
		realm = realm,
		credentials_handler = credentials_handler,
		username = generate_uuid(),
	}

	-- Whatever the client sends, authentication succeeds.
	anon.feed = function(self, message)
		return "success"
	end

	return anon
end
256
257
-- Create a SASL handler object for the named mechanism.
-- Supported mechanisms: "PLAIN", "DIGEST-MD5" and "ANONYMOUS".
-- Returns the handler object, or nil (after a debug log entry) when the
-- mechanism is not supported.
function new(mechanism, realm, credentials_handler)
	local constructors = {
		["PLAIN"] = new_plain,
		["DIGEST-MD5"] = new_digest_md5,
		["ANONYMOUS"] = new_anonymous,
	}

	local constructor = constructors[mechanism]
	if not constructor then
		log("debug", "Unsupported SASL mechanism: "..tostring(mechanism));
		return nil
	end

	local handler = constructor(realm, credentials_handler)
	return handler
end
269
270 return _M;