0.3->0.4
[prosody.git] / util / sasl.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2009 Tobias Markmann
3 -- 
4 --    All rights reserved.
5 --    
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --    
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --    
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14
15 local md5 = require "util.hashes".md5;
16 local log = require "util.logger".init("sasl");
17 local tostring = tostring;
18 local st = require "util.stanza";
19 local generate_uuid = require "util.uuid".generate;
20 local t_insert, t_concat = table.insert, table.concat;
21 local to_byte, to_char = string.byte, string.char;
22 local s_match = string.match;
23 local gmatch = string.gmatch
24 local string = string
25 local math = require "math"
26 local type = type
27 local error = error
28 local print = print
29 local idna_ascii = require "util.encodings".idna.to_ascii
30 local idna_unicode = require "util.encodings".idna.to_unicode
31
32 module "sasl"
33
--- Build a handler object for the SASL PLAIN mechanism.
-- realm: kept on the object and handed to password_handler on every attempt.
-- password_handler: called as password_handler(authcid, realm, "PLAIN") and
--   expected to return two values (password_encoding, correct_password):
--     correct_password == nil    -> "failure", "not-authorized"
--     correct_password == false  -> "failure", "account-disabled"
--     password_encoding ~= nil   -> applied to the client's password before
--                                   it is compared with correct_password
-- Returns a table carrying mechanism, realm, password_handler and feed().
local function new_plain(realm, password_handler)
	local object = { mechanism = "PLAIN", realm = realm, password_handler = password_handler }

	--- Process the client's (single) PLAIN message.
	-- Returns "success", or "failure" followed by an error condition name.
	function object.feed(self, message)
		if message == "" or message == nil then return "failure", "malformed-request" end

		-- RFC 4616: message = [authzid] NUL authcid NUL passwd
		-- One anchored pattern splits the three fields. Only NUL is treated as
		-- a separator, so '&' (or any other non-NUL byte) is allowed inside the
		-- authcid and the password, and trailing NUL-separated data is rejected.
		-- The optional authzid is matched but deliberately not captured; it is
		-- not used by this mechanism implementation.
		local authentication, password = s_match(message, "^[^%z]*%z([^%z]+)%z([^%z]+)$")
		if authentication == nil or password == nil then return "failure", "malformed-request" end

		local password_encoding, correct_password = self.password_handler(authentication, self.realm, "PLAIN")

		if correct_password == nil then return "failure", "not-authorized"
		elseif correct_password == false then return "failure", "account-disabled" end

		local claimed_password
		if password_encoding == nil then claimed_password = password
		else claimed_password = password_encoding(password) end

		-- The username is recorded before the comparison, so it is set on the
		-- object even when the password turns out to be wrong.
		self.username = authentication
		if claimed_password == correct_password then
			return "success"
		end
		return "failure", "not-authorized"
	end

	return object
end
64
--- Build a handler object for the SASL DIGEST-MD5 mechanism.
-- realm: advertised in the challenge and required to match the client's realm.
-- password_handler: called as
--   password_handler(username, realm, "DIGEST-MD5", decoder)
--   and expected to return two values (password_encoding, Y):
--     Y == nil    -> "failure", "not-authorized"
--     Y == false  -> "failure", "account-disabled"
--     otherwise Y is concatenated into A1 (presumably the raw
--     MD5(username:realm:password) digest -- confirm against the handler).
-- Returns a table carrying mechanism, realm, password_handler, nonce, step,
-- nonce_count and feed(); feed() walks through the three protocol steps.
local function new_digest_md5(realm, password_handler)
	--TODO maybe support for authzid

	-- Render the known challenge directives of `message` (a table) as a
	-- comma-separated directive list, in a fixed order.
	local function serialize(message)
		if type(message) ~= "table" then error("serialize needs an argument of type table.") end

		local parts = {}
		-- testing all possible values
		if message["nonce"] then t_insert(parts, [[nonce="]]..message.nonce..[["]]) end
		if message["qop"] then t_insert(parts, [[qop="]]..message.qop..[["]]) end
		if message["charset"] then t_insert(parts, [[charset=]]..message.charset) end
		if message["algorithm"] then t_insert(parts, [[algorithm=]]..message.algorithm) end
		if message["realm"] then t_insert(parts, [[realm="]]..message.realm..[["]]) end
		if message["rspauth"] then t_insert(parts, [[rspauth=]]..message.rspauth) end
		return t_concat(parts, ",")
	end

	-- Convert a UTF-8 string to ISO-8859-1 when every character fits into
	-- Latin-1; otherwise return the string unchanged.
	local function utf8tolatin1ifpossible(passwd)
		-- First pass: verify that every non-ASCII byte starts a two-byte
		-- sequence that decodes to a code point below 0x100.
		local i = 1;
		while i <= #passwd do
			local lead = to_byte(passwd, i);
			if lead > 0x7F then
				if lead < 0xC0 or lead > 0xC3 then
					return passwd;
				end
				i = i + 1;
				-- to_byte yields nil past the end of the string, i.e. when the
				-- input is truncated right after a lead byte.
				local continuation = to_byte(passwd, i);
				if continuation == nil or continuation < 0x80 or continuation > 0xBF then
					return passwd;
				end
			end
			i = i + 1;
		end

		-- Second pass: fold each two-byte sequence into a single Latin-1 byte.
		local p = {};
		i = 1;
		while i <= #passwd do
			local lead = to_byte(passwd, i);
			if lead > 0x7F then
				i = i + 1;
				local continuation = to_byte(passwd, i);
				-- low 2 bits of the lead byte + low 6 bits of the continuation
				t_insert(p, to_char(lead%4*64 + continuation%64));
			else
				t_insert(p, to_char(lead));
			end
			i = i + 1;
		end
		return t_concat(p);
	end

	-- Convert an ISO-8859-1 string to UTF-8.
	-- Not referenced anywhere else in this constructor at present.
	local function latin1toutf8(str)
		local p = {};
		for ch in gmatch(str, ".") do
			ch = to_byte(ch);
			if (ch < 0x80) then
				t_insert(p, to_char(ch));
			elseif (ch < 0xC0) then
				t_insert(p, to_char(0xC2, ch));
			else
				t_insert(p, to_char(0xC3, ch - 64));
			end
		end
		return t_concat(p);
	end

	-- Split a directive list (key=value or key="value", comma separated)
	-- into a table keyed by directive name.
	local function parse(data)
		-- Must be local: an undeclared assignment here would write a field
		-- into the module environment.
		local message = {}
		for k, v in gmatch(data, [[([%w%-]+)="?([^",]*)"?,?]]) do -- FIXME The hacky regex makes me shudder
			message[k] = v
		end
		return message
	end

	-- Compute the digest over HA1, the client's nonce/nc/cnonce/qop and
	-- the hash of `A2`. md5(x, true) is used exactly as in the rest of this
	-- file (second argument presumably selects hex output -- confirm in
	-- util.hashes).
	local function compute_digest(HA1, response, A2)
		local HA2 = md5(A2, true)
		local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2
		return md5(KD, true)
	end

	local object = { mechanism = "DIGEST-MD5", realm = realm, password_handler = password_handler}

	--TODO: something better than math.random would be nice, maybe OpenSSL's random number generator
	object.nonce = generate_uuid()
	object.step = 0
	object.nonce_count = {}

	--- Advance the exchange by one step.
	-- Step 1 returns "challenge" plus the initial challenge, step 2 verifies
	-- the client's digest response and returns "challenge" plus rspauth (or
	-- "failure", condition [, text]), step 3 returns "success" if step 2
	-- authenticated the client. Later steps return nothing.
	function object.feed(self, message)
		self.step = self.step + 1
		if (self.step == 1) then
			local challenge = serialize({
				nonce = object.nonce,
				qop = "auth",
				charset = "utf-8",
				algorithm = "md5-sess",
				realm = self.realm,
			});
			return "challenge", challenge
		elseif (self.step == 2) then
			-- An absent response cannot be parsed; treat it as malformed
			-- instead of letting gmatch raise an error.
			if type(message) ~= "string" then return "failure", "malformed-request" end
			local response = parse(message)

			-- check for replay attack
			if response["nc"] then
				if self.nonce_count[response["nc"]] then return "failure", "not-authorized" end
			end

			-- check for username, it's REQUIRED by RFC 2831
			if not response["username"] then
				return "failure", "malformed-request"
			end
			self["username"] = response["username"]

			-- check for nonce, and that it is the one we issued
			if not response["nonce"] then
				return "failure", "malformed-request"
			elseif response["nonce"] ~= tostring(self.nonce) then
				return "failure", "malformed-request"
			end

			if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
			-- nc is concatenated into the digest input below; without it the
			-- concatenation would raise an error, so reject the message here.
			if not response["nc"] then return "failure", "malformed-request", "Missing entry for nc in SASL message." end
			if not response["qop"] then response["qop"] = "auth" end

			if response["realm"] == nil or response["realm"] == "" then
				response["realm"] = self.realm;
			elseif response["realm"] ~= self.realm then
				return "failure", "not-authorized", "Incorrect realm value";
			end

			-- No charset directive: hand the password handler a UTF-8 -> Latin-1
			-- converter. charset=utf-8: no decoder is passed.
			local decoder;
			if response["charset"] == nil then
				decoder = utf8tolatin1ifpossible;
			elseif response["charset"] ~= "utf-8" then
				return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding with isn't supported by sasl.lua. Supported encodings are latin or utf-8."
			end

			if not response["digest-uri"] then
				return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
			end
			local protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$")
			if protocol == nil or domain == nil then return "failure", "malformed-request" end

			--TODO maybe realm support
			local password_encoding, Y = self.password_handler(response["username"], response["realm"], "DIGEST-MD5", decoder)
			if Y == nil then return "failure", "not-authorized"
			elseif Y == false then return "failure", "account-disabled" end

			local A1 = Y..":"..response["nonce"]..":"..response["cnonce"]--:authzid
			local HA1 = md5(A1, true)
			local digest_uri = protocol.."/"..domain

			local response_value = compute_digest(HA1, response, "AUTHENTICATE:"..digest_uri)
			if response_value ~= response["response"] then
				return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
			end

			-- rspauth is the same computation with an empty method in A2
			local rspauth = compute_digest(HA1, response, ":"..digest_uri)
			self.authenticated = true
			return "challenge", serialize({rspauth = rspauth})
		elseif self.step == 3 then
			if self.authenticated ~= nil then return "success"
			else return "failure", "malformed-request" end
		end
	end
	return object
end
237
--- Build a handler object for the SASL ANONYMOUS mechanism.
-- realm and password_handler are only stored on the returned table.
-- The username is a freshly generated id from generate_uuid().
local function new_anonymous(realm, password_handler)
	local handler = {
		mechanism = "ANONYMOUS",
		realm = realm,
		password_handler = password_handler,
	}

	-- Any message, including none at all, completes the exchange.
	handler.feed = function(self, message)
		return "success"
	end

	--TODO: From XEP-0175 "It is RECOMMENDED for the node identifier to be a UUID as specified in RFC 4122 [5]." So util.uuid() should (or have an option to) behave as specified in RFC 4122.
	handler.username = generate_uuid()

	return handler
end
247
248
-- Constructor for each supported SASL mechanism name.
local mechanism_constructors = {
	["PLAIN"] = new_plain,
	["DIGEST-MD5"] = new_digest_md5,
	["ANONYMOUS"] = new_anonymous,
}

--- Create a SASL handler object for the named mechanism.
-- mechanism: "PLAIN", "DIGEST-MD5" or "ANONYMOUS".
-- realm, password_handler: forwarded unchanged to the mechanism constructor.
-- Returns the handler object, or nil (after a debug log entry) when the
-- mechanism is not supported.
function new(mechanism, realm, password_handler)
	local construct = mechanism_constructors[mechanism]
	if construct == nil then
		log("debug", "Unsupported SASL mechanism: "..tostring(mechanism));
		return nil
	end
	return construct(realm, password_handler)
end
260
261 return _M;