util.sasl: Add COMPAT comment
[prosody.git] / util / sasl.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2009 Tobias Markmann
3 --
4 --    All rights reserved.
5 --
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14
15 local md5 = require "util.hashes".md5;
16 local log = require "util.logger".init("sasl");
17 local tostring = tostring;
18 local st = require "util.stanza";
19 local generate_uuid = require "util.uuid".generate;
20 local t_insert, t_concat = table.insert, table.concat;
21 local to_byte, to_char = string.byte, string.char;
22 local to_unicode = require "util.encodings".idna.to_unicode;
23 local s_match = string.match;
24 local gmatch = string.gmatch
25 local string = string
26 local math = require "math"
27 local type = type
28 local error = error
29 local print = print
30
31 module "sasl"
32
-- PLAIN mechanism (RFC 4616).
-- Credentials handler:
--   Arguments: ("PLAIN", user, realm, password)
--   Returns: true (success) | false (fail) | nil (user unknown)
local function new_plain(realm, credentials_handler)
	local object = { mechanism = "PLAIN", realm = realm, credentials_handler = credentials_handler}

	-- Process the client's single PLAIN message.
	-- Returns "success", or "failure" plus a condition name:
	--   malformed-request (empty or badly framed message),
	--   account-disabled (handler returned nil),
	--   not-authorized (handler returned false).
	function object.feed(self, message)
		if message == "" or message == nil then return "failure", "malformed-request" end

		-- message = [authzid] NUL authcid NUL passwd: exactly three NUL-separated
		-- fields, with authcid and passwd non-empty. Anchoring the pattern rejects
		-- messages carrying extra fields instead of silently picking the first
		-- plausible pair. The authzid is checked as part of the framing but is
		-- otherwise ignored: the session is authenticated as the authcid.
		local _, authentication, password = s_match(message, "^([^%z]*)%z([^%z]+)%z([^%z]+)$")
		if authentication == nil or password == nil then return "failure", "malformed-request" end

		self.username = authentication
		local auth_success = self.credentials_handler("PLAIN", self.username, self.realm, password)

		if auth_success then
			return "success"
		elseif auth_success == nil then
			return "failure", "account-disabled"
		else
			return "failure", "not-authorized"
		end
	end
	return object
end
59
-- Helpers for the DIGEST-MD5 mechanism (RFC 2831). They do not depend on any
-- session state, so they are defined once rather than per session object.

-- Serialize a directive table into the "key=value,key=value" challenge form.
-- Only the directives this server emits are supported: realm, nonce and qop
-- are quoted; charset, algorithm and rspauth are not. Output order is fixed.
-- Raises an error when message is not a table.
local function serialize(message)
	local data = ""

	if type(message) ~= "table" then error("serialize needs an argument of type table.") end

	-- testing all possible values
	if message["realm"] then data = data..[[realm="]]..message.realm..[[",]] end
	if message["nonce"] then data = data..[[nonce="]]..message.nonce..[[",]] end
	if message["qop"] then data = data..[[qop="]]..message.qop..[[",]] end
	if message["charset"] then data = data..[[charset=]]..message.charset.."," end
	if message["algorithm"] then data = data..[[algorithm=]]..message.algorithm.."," end
	if message["rspauth"] then data = data..[[rspauth=]]..message.rspauth.."," end
	data = data:gsub(",$", "")
	return data
end

-- Convert a UTF-8 string to ISO 8859-1 when every character is representable
-- there (ASCII bytes, or two-byte sequences with a C2/C3 lead byte).
-- Any other input, including truncated or invalid sequences, is returned unchanged.
local function utf8tolatin1ifpossible(passwd)
	local i = 1;
	while i <= #passwd do
		local passwd_i = to_byte(passwd, i);
		if passwd_i > 0x7F then
			-- Only C2 and C3 encode U+0080..U+00FF; C0/C1 would be overlong
			-- (invalid) encodings and must not be decoded.
			if passwd_i < 0xC2 or passwd_i > 0xC3 then
				return passwd;
			end
			i = i + 1;
			passwd_i = to_byte(passwd, i);
			-- nil when the lead byte was the last byte of the string.
			if passwd_i == nil or passwd_i < 0x80 or passwd_i > 0xBF then
				return passwd;
			end
		end
		i = i + 1;
	end

	local p = {};
	i = 1;
	while i <= #passwd do
		local passwd_i = to_byte(passwd, i);
		if passwd_i > 0x7F then
			i = i + 1;
			local passwd_i_1 = to_byte(passwd, i);
			-- Code point = low 2 bits of the lead byte, then low 6 bits of the
			-- continuation byte.
			t_insert(p, to_char(passwd_i%4*64 + passwd_i_1%64));
		else
			t_insert(p, to_char(passwd_i));
		end
		i = i + 1;
	end
	return t_concat(p);
end

-- Parse a client response into a table of directive name -> value.
-- FIXME: loose pattern-based parser; quoted values containing ',' or '"'
-- (including escaped quotes) are not handled.
local function parse(data)
	local message = {}
	-- COMPAT: %z in the pattern to work around jwchat bug (sends "charset=utf-8\0")
	for k, v in gmatch(data, [[([%w%-]+)="?([^",%z]*)"?,?]]) do
		message[k] = v;
	end
	return message;
end

-- DIGEST-MD5 mechanism (RFC 2831).
-- credentials_handler:
--   Arguments: (mechanism, username, realm, client-supplied realm, decoder)
--     decoder is utf8tolatin1ifpossible when the client sent no charset
--     directive, and nil when it sent charset=utf-8.
--   Returns: password encoding (currently unused here), and a value Y that is
--     used verbatim as the first field of A1. Y == nil -> not-authorized,
--     Y == false -> account-disabled.
--     NOTE(review): RFC 2831 defines that field as H(username:realm:passwd);
--     an older comment called it the plaintext password -- confirm against
--     the handler implementation.
local function new_digest_md5(realm, credentials_handler)
	--TODO complete support for authzid

	local object = { mechanism = "DIGEST-MD5", realm = realm, credentials_handler = credentials_handler};

	object.nonce = generate_uuid();
	object.step = 0;
	object.nonce_count = {};

	-- Step 1 returns the server challenge. Step 2 verifies the client response
	-- and returns rspauth as a challenge. Step 3 reports the final result.
	-- Any further call is a protocol error.
	function object.feed(self, message)
		self.step = self.step + 1;
		if (self.step == 1) then
			local challenge = serialize({	nonce = object.nonce,
											qop = "auth",
											charset = "utf-8",
											algorithm = "md5-sess",
											realm = self.realm});
			return "challenge", challenge;
		elseif (self.step == 2) then
			-- A missing message is treated as empty (and rejected below)
			-- instead of raising inside gmatch.
			local response = parse(message or "");
			-- check for replay attack
			if response["nc"] then
				if self.nonce_count[response["nc"]] then return "failure", "not-authorized" end
			end

			-- check for username, it's REQUIRED by RFC 2831
			if not response["username"] then
				return "failure", "malformed-request";
			end
			self["username"] = response["username"];

			-- check for nonce, ...
			if not response["nonce"] then
				return "failure", "malformed-request";
			else
				-- check if it's the right nonce
				if response["nonce"] ~= tostring(self.nonce) then return "failure", "malformed-request" end
			end

			if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
			-- nc is part of the digest computed below; without it the concatenation
			-- would raise a Lua error instead of producing a SASL failure.
			if not response["nc"] then return "failure", "malformed-request", "Missing entry for nc in SASL message." end
			if not response["qop"] then response["qop"] = "auth" end

			if response["realm"] == nil or response["realm"] == "" then
				response["realm"] = "";
			elseif response["realm"] ~= self.realm then
				return "failure", "not-authorized", "Incorrect realm value";
			end

			local decoder;
			if response["charset"] == nil then
				decoder = utf8tolatin1ifpossible;
			elseif response["charset"] ~= "utf-8" then
				return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding which isn't supported by sasl.lua. Supported encodings are latin or utf-8.";
			end

			local domain = "";
			local protocol = "";
			if response["digest-uri"] then
				protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$");
				if protocol == nil or domain == nil then return "failure", "malformed-request" end
			else
				return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
			end

			--TODO maybe realm support
			-- The first return value (password encoding) is not used here.
			local _, Y = self.credentials_handler("DIGEST-MD5", response["username"], self.realm, response["realm"], decoder);
			if Y == nil then return "failure", "not-authorized"
			elseif Y == false then return "failure", "account-disabled" end

			local A1;
			if response.authzid then
				if response.authzid == self.username or response.authzid == self.username.."@"..self.realm then
					-- COMPAT
					log("warn", "Client is violating RFC 3920 (section 6.1, point 7).");
					A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid;
				else
					return "failure", "invalid-authzid";
				end
			else
				A1 = Y..":"..response["nonce"]..":"..response["cnonce"];
			end

			-- Both the expected response and rspauth are md5 of
			-- HA1:nonce:nc:cnonce:qop:HA2, differing only in A2.
			local HA1 = md5(A1, true);
			local function digest(A2)
				local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..md5(A2, true);
				return md5(KD, true);
			end

			if digest("AUTHENTICATE:"..protocol.."/"..domain) == response["response"] then
				-- calculate rspauth (A2 without the "AUTHENTICATE" method)
				local rspauth = digest(":"..protocol.."/"..domain);
				-- Remember the nonce count so the replay check above can reject reuse.
				self.nonce_count[response["nc"]] = true;
				self.authenticated = true;
				return "challenge", serialize({rspauth = rspauth});
			else
				return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
			end
		elseif self.step == 3 then
			if self.authenticated ~= nil then return "success"
			else return "failure", "malformed-request" end
		else
			-- The exchange ends at step 3; previously this returned nothing.
			return "failure", "malformed-request";
		end
	end
	return object;
end
246
-- ANONYMOUS mechanism: the single feed() call always succeeds.
-- credentials_handler may be nil. It is stored on the object but not yet
-- consulted (TODO: call it with the mechanism name and treat a false return
-- as not-OK).
local function new_anonymous(realm, credentials_handler)
	local object = {
		mechanism = "ANONYMOUS",
		realm = realm,
		credentials_handler = credentials_handler,
		-- Each anonymous session is given a freshly generated identifier as its username.
		username = generate_uuid(),
	};
	object.feed = function(self, message)
		return "success";
	end;
	return object;
end
257
258
-- Supported mechanism names mapped to their session constructors.
local mechanism_constructors = {
	["PLAIN"] = new_plain;
	["DIGEST-MD5"] = new_digest_md5;
	["ANONYMOUS"] = new_anonymous;
};

-- Create a SASL session object for `mechanism` with the given realm and
-- credentials handler. Returns nil (after a debug log line) when the
-- mechanism is not supported.
function new(mechanism, realm, credentials_handler)
	local constructor = mechanism_constructors[mechanism];
	if not constructor then
		log("debug", "Unsupported SASL mechanism: "..tostring(mechanism));
		return nil
	end
	local session = constructor(realm, credentials_handler);
	return session
end
270
-- module "sasl" makes _M refer to this module's own table (holding new());
-- returning it makes that table the value produced by require().
return _M;