util/sasl: Removed unnecessary references to util.encodings.idna
[prosody.git] / util / sasl.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2009 Tobias Markmann
3 -- 
4 --    All rights reserved.
5 --    
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --    
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --    
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14
15 local md5 = require "util.hashes".md5;
16 local log = require "util.logger".init("sasl");
17 local tostring = tostring;
18 local st = require "util.stanza";
19 local generate_uuid = require "util.uuid".generate;
20 local t_insert, t_concat = table.insert, table.concat;
21 local to_byte, to_char = string.byte, string.char;
22 local s_match = string.match;
23 local gmatch = string.gmatch
24 local string = string
25 local math = require "math"
26 local type = type
27 local error = error
28 local print = print
29
30 module "sasl"
31
-- PLAIN mechanism (RFC 4616): single-step username/password authentication.
-- realm: stored on the object and passed to password_handler.
-- password_handler(authcid, realm, "PLAIN") must return two values:
--   password_encoding: nil, or a function applied to the client's password
--                      before it is compared;
--   correct_password:  the value to compare against; nil is reported as
--                      "not-authorized" and false as "account-disabled".
-- The returned object's feed(self, message) returns "success", or "failure"
-- followed by a condition name. self.username is set to the authcid once the
-- message has been parsed, even when the password turns out to be wrong.
local function new_plain(realm, password_handler)
	local object = { mechanism = "PLAIN", realm = realm, password_handler = password_handler}
	function object.feed(self, message)
		if message == "" or message == nil then return "failure", "malformed-request" end

		-- message = [authzid] NUL authcid NUL passwd. The whole message is
		-- anchored so that passwd is taken in full. The previous patterns
		-- stopped at '&', which truncated passwords containing '&', e.g.
		-- "pw&junk" was accepted for an account whose password is "pw".
		-- The authzid (first capture) is not used yet.
		local _, authentication, password = message:match("^([^%z]*)%z([^%z]+)%z([^%z]+)$")

		if authentication == nil or password == nil then return "failure", "malformed-request" end

		local password_encoding, correct_password = self.password_handler(authentication, self.realm, "PLAIN")

		if correct_password == nil then return "failure", "not-authorized"
		elseif correct_password == false then return "failure", "account-disabled" end

		-- Bring the client's password into the same form as correct_password
		local claimed_password = password
		if password_encoding ~= nil then claimed_password = password_encoding(password) end

		self.username = authentication
		if claimed_password == correct_password then
			return "success"
		else
			return "failure", "not-authorized"
		end
	end
	return object
end
62
-- DIGEST-MD5 mechanism (RFC 2831), server side.
-- The returned object's feed(self, message) drives a three-step exchange:
--   step 1: returns "challenge" with nonce, qop, charset, algorithm and realm;
--   step 2: parses the client's digest-response, verifies it, and on success
--           returns "challenge" carrying rspauth;
--   step 3: returns "success" if step 2 authenticated the client.
-- Failures are reported as "failure", <condition>[, <text>].
-- realm: advertised in the challenge. A client response with a different,
--   non-empty realm is rejected.
-- password_handler(username, realm, "DIGEST-MD5", decoder) must return
--   password_encoding, Y. Y == nil is reported as "not-authorized" and
--   Y == false as "account-disabled". Otherwise Y is used as the leading
--   part of A1. password_encoding is not used by this mechanism.
--   NOTE(review): Y is presumably H(username:realm:password) in raw binary
--   form, as RFC 2831 requires; confirm against the handler implementation.
local function new_digest_md5(realm, password_handler)
	--TODO maybe support for authzid

	-- Build a directive list such as `nonce="..",qop="..",charset=..` from
	-- the known fields of `message`; other fields are ignored.
	local function serialize(message)
		local data = ""

		if type(message) ~= "table" then error("serialize needs an argument of type table.") end

		-- testing all possible values
		if message["nonce"] then data = data..[[nonce="]]..message.nonce..[[",]] end
		if message["qop"] then data = data..[[qop="]]..message.qop..[[",]] end
		if message["charset"] then data = data..[[charset=]]..message.charset.."," end
		if message["algorithm"] then data = data..[[algorithm=]]..message.algorithm.."," end
		if message["realm"] then data = data..[[realm="]]..message.realm..[[",]] end
		if message["rspauth"] then data = data..[[rspauth=]]..message.rspauth.."," end
		data = data:gsub(",$", "")
		return data
	end

	-- Convert passwd from UTF-8 to ISO-8859-1 when every character lies in
	-- U+0000..U+00FF (two-byte sequences led by 0xC2 or 0xC3). Otherwise
	-- return passwd unchanged.
	local function utf8tolatin1ifpossible(passwd)
		local i = 1;
		while i <= #passwd do
			local passwd_i = to_byte(passwd, i);
			if passwd_i > 0x7F then
				-- 0xC0/0xC1 only start overlong sequences, which are never
				-- valid UTF-8, so they are not converted
				if passwd_i < 0xC2 or passwd_i > 0xC3 then
					return passwd;
				end
				i = i + 1;
				passwd_i = to_byte(passwd, i);
				-- A lead byte at the very end has no continuation byte
				-- (to_byte returns nil); previously this raised an error
				if passwd_i == nil or passwd_i < 0x80 or passwd_i > 0xBF then
					return passwd;
				end
			end
			i = i + 1;
		end

		local p = {};
		i = 1;
		while (i <= #passwd) do
			local passwd_i = to_byte(passwd, i);
			if passwd_i > 0x7F then
				i = i + 1;
				local passwd_i_1 = to_byte(passwd, i);
				-- low 2 bits of the lead byte and low 6 bits of the continuation byte
				t_insert(p, to_char(passwd_i%4*64 + passwd_i_1%64));
			else
				t_insert(p, to_char(passwd_i));
			end
			i = i + 1;
		end
		return t_concat(p);
	end

	-- Encode an ISO-8859-1 string as UTF-8 (currently unused in this block).
	local function latin1toutf8(str)
		local p = {};
		for ch in gmatch(str, ".") do
			ch = to_byte(ch);
			if (ch < 0x80) then
				t_insert(p, to_char(ch));
			elseif (ch < 0xC0) then
				t_insert(p, to_char(0xC2, ch));
			else
				t_insert(p, to_char(0xC3, ch - 64));
			end
		end
		return t_concat(p);
	end

	-- Parse a digest-response into a directive -> value table.
	-- FIXME: values cannot contain '"' or ',' with this pattern.
	local function parse(data)
		-- Must be local: a bare assignment would write into the module table
		local message = {}
		for k, v in gmatch(data, [[([%w%-]+)="?([^",]*)"?,?]]) do
			message[k] = v
		end
		return message
	end

	local object = { mechanism = "DIGEST-MD5", realm = realm, password_handler = password_handler}

	-- NOTE(review): the nonce comes from util.uuid.generate; confirm it is
	-- backed by a cryptographically strong random source.
	object.nonce = generate_uuid()
	object.step = 0
	-- nc values already seen for this nonce (replay detection)
	object.nonce_count = {}

	function object.feed(self, message)
		self.step = self.step + 1
		if (self.step == 1) then
			local challenge = serialize({	nonce = object.nonce,
											qop = "auth",
											charset = "utf-8",
											algorithm = "md5-sess",
											realm = self.realm});
			return "challenge", challenge
		elseif (self.step == 2) then
			-- parse() needs a string to iterate over
			if message == nil then return "failure", "malformed-request" end
			local response = parse(message)

			-- nc is part of the digest computed below (and RFC 2831 requires
			-- it); a missing nc used to raise a concatenation error there
			if not response["nc"] then return "failure", "malformed-request", "Missing entry for nc in SASL message." end
			-- check for replay attack, and remember this nc
			if self.nonce_count[response["nc"]] then return "failure", "not-authorized" end
			self.nonce_count[response["nc"]] = true

			-- check for username, it's REQUIRED by RFC 2831
			if not response["username"] then
				return "failure", "malformed-request"
			end
			self["username"] = response["username"]

			-- check for nonce, and that it is the one we issued
			if not response["nonce"] then
				return "failure", "malformed-request"
			elseif response["nonce"] ~= tostring(self.nonce) then
				return "failure", "malformed-request"
			end

			if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
			if not response["qop"] then response["qop"] = "auth" end

			if response["realm"] == nil or response["realm"] == "" then
				response["realm"] = self.realm;
			elseif response["realm"] ~= self.realm then
				return "failure", "not-authorized", "Incorrect realm value";
			end

			-- Without a charset directive the password may be in ISO-8859-1,
			-- so the handler is given a UTF-8 -> Latin-1 converter
			local decoder;
			if response["charset"] == nil then
				decoder = utf8tolatin1ifpossible;
			elseif response["charset"] ~= "utf-8" then
				return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding with isn't supported by sasl.lua. Supported encodings are latin or utf-8."
			end

			local domain = ""
			local protocol = ""
			if response["digest-uri"] then
				protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$")
				if protocol == nil or domain == nil then return "failure", "malformed-request" end
			else
				return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
			end

			--TODO maybe realm support
			local password_encoding, Y = self.password_handler(response["username"], response["realm"], "DIGEST-MD5", decoder)
			if Y == nil then return "failure", "not-authorized"
			elseif Y == false then return "failure", "account-disabled" end

			local A1 = Y..":"..response["nonce"]..":"..response["cnonce"]--:authzid
			local A2 = "AUTHENTICATE:"..protocol.."/"..domain;

			local HA1 = md5(A1, true)
			local HA2 = md5(A2, true)

			local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2
			local response_value = md5(KD, true)

			if response_value == response["response"] then
				-- calculate rspauth: same digest, but A2 has no "AUTHENTICATE" prefix
				A2 = ":"..protocol.."/"..domain;

				HA1 = md5(A1, true)
				HA2 = md5(A2, true)

				KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2
				local rspauth = md5(KD, true)
				self.authenticated = true
				return "challenge", serialize({rspauth = rspauth})
			else
				return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
			end
		elseif self.step == 3 then
			if self.authenticated ~= nil then return "success"
			else return "failure", "malformed-request" end
		end
	end
	return object
end
235
-- ANONYMOUS mechanism (XEP-0175).
-- The returned object's feed() ignores the client's message and always
-- reports "success". Its username is a fresh value from util.uuid.generate.
-- realm and password_handler are stored on the object but not consulted.
--TODO: From XEP-0175 "It is RECOMMENDED for the node identifier to be a UUID as specified in RFC 4122 [5]." So util.uuid() should (or have an option to) behave as specified in RFC 4122.
local function new_anonymous(realm, password_handler)
	local session = {
		mechanism = "ANONYMOUS",
		realm = realm,
		password_handler = password_handler,
		username = generate_uuid(),
	}
	-- Anonymous login cannot fail
	session.feed = function(self, message)
		return "success"
	end
	return session
end
245
246
-- Create a SASL session object for `mechanism`.
-- Supported mechanisms: "PLAIN", "DIGEST-MD5" and "ANONYMOUS". realm and
-- password_handler are forwarded unchanged to the mechanism's constructor.
-- Returns the session object, or nil (after a debug log entry) when the
-- mechanism is not supported.
function new(mechanism, realm, password_handler)
	local constructors = {
		["PLAIN"] = new_plain,
		["DIGEST-MD5"] = new_digest_md5,
		["ANONYMOUS"] = new_anonymous,
	}
	local constructor = constructors[mechanism]
	if not constructor then
		log("debug", "Unsupported SASL mechanism: "..tostring(mechanism));
		return nil
	end
	return constructor(realm, password_handler)
end
258
-- Return the module table that module "sasl" created (module() stores it in
-- _M), so that require() of this file yields it.
259 return _M;