Merge with backout
[prosody.git] / util / sasl / digest-md5.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2009 Tobias Markmann
3 --
4 --    All rights reserved.
5 --
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
local tostring = tostring;
local type = type;
local error = error;

local s_gmatch = string.gmatch;
local s_match = string.match;
local t_concat = table.concat;
local t_insert = table.insert;
local to_byte, to_char = string.byte, string.char;

local md5 = require "util.hashes".md5;
local log = require "util.logger".init("sasl");
local generate_uuid = require "util.uuid".generate;
26
27 module "digest-md5"
28
29 --=========================
30 --SASL DIGEST-MD5 according to RFC 2831
31
--- Serialize a server challenge table into RFC 2831 "directive=value" form.
-- Only realm, nonce, qop, charset, algorithm and rspauth are emitted (in that
-- order); realm, nonce and qop values are quoted, the others are not.
-- @param message table mapping directive names to string values
-- @return comma-separated challenge string
local function serialize(message)
	if type(message) ~= "table" then error("serialize needs an argument of type table.") end

	local data = ""
	-- testing all possible values
	if message["realm"] then data = data..[[realm="]]..message.realm..[[",]] end
	if message["nonce"] then data = data..[[nonce="]]..message.nonce..[[",]] end
	if message["qop"] then data = data..[[qop="]]..message.qop..[[",]] end
	if message["charset"] then data = data..[[charset=]]..message.charset.."," end
	if message["algorithm"] then data = data..[[algorithm=]]..message.algorithm.."," end
	if message["rspauth"] then data = data..[[rspauth=]]..message.rspauth.."," end
	data = data:gsub(",$", "")
	return data
end

--- Convert a UTF-8 string to ISO 8859-1 when every character fits in Latin-1.
-- Only the two-byte sequences C2/C3 followed by 80..BF (U+0080..U+00FF) are
-- converted. If any other non-ASCII byte, an overlong C0/C1 lead byte, or a
-- sequence truncated at the end of the string is found, the input is returned
-- unchanged.
-- @param passwd string to convert
-- @return the Latin-1 string, or passwd itself when conversion isn't possible
local function utf8tolatin1ifpossible(passwd)
	local len = #passwd;

	-- First pass: make sure every character can be represented in Latin-1.
	local i = 1;
	while i <= len do
		local lead = to_byte(passwd, i);
		if lead > 0x7F then
			-- C0/C1 would be overlong encodings of ASCII; reject them.
			if lead < 0xC2 or lead > 0xC3 then
				return passwd;
			end
			i = i + 1;
			local cont = to_byte(passwd, i); -- nil if the string ends here
			if cont == nil or cont < 0x80 or cont > 0xBF then
				return passwd;
			end
		end
		i = i + 1;
	end

	-- Second pass: fold each validated two-byte sequence into one byte.
	local out = {};
	i = 1;
	while i <= len do
		local lead = to_byte(passwd, i);
		if lead > 0x7F then
			i = i + 1;
			-- (lead & 0x03) << 6 | (cont & 0x3F)
			out[#out + 1] = to_char(lead % 4 * 64 + to_byte(passwd, i) % 64);
		else
			out[#out + 1] = to_char(lead);
		end
		i = i + 1;
	end
	return t_concat(out);
end

--- Parse a client digest-response into a table of directive -> value.
-- Surrounding quotes are stripped; a value stops at a comma, a double quote or
-- a NUL byte, so quoted values containing those are not supported.
-- @param data the raw client message
-- @return table of directives (a repeated directive overwrites earlier ones)
local function parse(data)
	local message = {}
	-- COMPAT: %z in the pattern to work around jwchat bug (sends "charset=utf-8\0")
	for k, v in s_gmatch(data, [[([%w%-]+)="?([^",%z]*)"?,?]]) do -- FIXME The hacky regex makes me shudder
		message[k] = v;
	end
	return message;
end

--- DIGEST-MD5 server step function (RFC 2831), registered by init().
-- Each call advances self.step by one:
--   step 1: returns the initial challenge (nonce, qop=auth, charset=utf-8,
--           algorithm=md5-sess and, when set, self.realm);
--   step 2: checks the client's digest-response and, if the response hash
--           matches, returns the rspauth challenge;
--   step 3: returns "success" if step 2 authenticated the client.
-- State kept on self: nonce, step, nonce_count, username, authenticated.
-- @param self SASL session; self.realm and self.profile are read
-- @param message the client's message (not used in step 1)
-- @return "challenge", data | "success" | "failure", condition [, text]
local function digest(self, message)
	--TODO complete support for authzid

	if not self.nonce then
		self.nonce = generate_uuid();
		self.step = 0;
		self.nonce_count = {};
	end

	self.step = self.step + 1;
	if (self.step == 1) then
		local challenge = serialize({
			nonce = self.nonce,
			qop = "auth",
			charset = "utf-8",
			algorithm = "md5-sess",
			realm = self.realm,
		});
		return "challenge", challenge;
	elseif (self.step == 2) then
		-- A missing message is treated as empty so it fails the checks below
		-- rather than raising an error inside gmatch.
		local response = parse(message or "");

		-- check for replay attack
		if response["nc"] and self.nonce_count[response["nc"]] then
			return "failure", "not-authorized"
		end

		-- check for username, it's REQUIRED by RFC 2831
		if not response["username"] then
			return "failure", "malformed-request";
		end
		self.username = response["username"];

		-- check for nonce, and that it's the one we issued in step 1
		if not response["nonce"] then
			return "failure", "malformed-request";
		elseif response["nonce"] ~= tostring(self.nonce) then
			return "failure", "malformed-request";
		end

		if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
		-- nc is part of the response hash; without it KD below cannot be built.
		if not response["nc"] then return "failure", "malformed-request", "Missing entry for nc in SASL message." end
		if not response["qop"] then response["qop"] = "auth" end

		if response["realm"] == nil or response["realm"] == "" then
			response["realm"] = "";
		elseif response["realm"] ~= self.realm then
			return "failure", "not-authorized", "Incorrect realm value";
		end

		local decoder;
		if response["charset"] == nil then
			decoder = utf8tolatin1ifpossible;
		elseif response["charset"] ~= "utf-8" then
			return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding which isn't supported by sasl.lua. Supported encodings are latin or utf-8.";
		end
		-- NOTE(review): decoder is selected but never applied to the username,
		-- realm or password before hashing. RFC 2831 section 2.1.2.1 describes
		-- when these must be converted to ISO 8859-1; confirm the intended
		-- behaviour before wiring it in, since it changes which passwords match.

		local digest_uri = response["digest-uri"];
		if not digest_uri then
			return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
		end
		-- Require a "serv-type/host..." shape; the value itself is hashed verbatim.
		if not digest_uri:match("(%w+)/(.*)$") then return "failure", "malformed-request" end

		--TODO maybe realm support
		local Y, state;
		if self.profile.plain then
			local password;
			password, state = self.profile.plain(response["username"], self.realm);
			if state == nil then return "failure", "not-authorized"
			elseif state == false then return "failure", "account-disabled" end
			Y = md5(response["username"]..":"..response["realm"]..":"..password);
		elseif self.profile["digest-md5"] then
			Y, state = self.profile["digest-md5"](response["username"], self.realm, response["realm"], response["charset"]);
			if state == nil then return "failure", "not-authorized"
			elseif state == false then return "failure", "account-disabled" end
		elseif self.profile["digest-md5-test"] then
			-- TODO
		end
		-- No provider produced a hash (e.g. the unimplemented digest-md5-test
		-- profile): fail cleanly instead of erroring on the concatenations below.
		if Y == nil then return "failure", "not-authorized" end

		local A1;
		if response.authzid then
			-- self.realm may be unset; then only the bare username is accepted.
			local qualified = self.realm and self.username.."@"..self.realm;
			if response.authzid == self.username or response.authzid == qualified then
				-- COMPAT
				log("warn", "Client is violating RFC 3920 (section 6.1, point 7).");
				A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid;
			else
				return "failure", "invalid-authzid";
			end
		else
			A1 = Y..":"..response["nonce"]..":"..response["cnonce"];
		end

		local HA1 = md5(A1, true);
		-- Everything in KD except the trailing HA2, shared by both hashes below.
		local KD_prefix = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":";

		-- RFC 2831: A2 = "AUTHENTICATE:" digest-uri-value
		local response_value = md5(KD_prefix..md5("AUTHENTICATE:"..digest_uri, true), true);
		if response_value ~= response["response"] then
			return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
		end

		-- Remember this nonce-count so the replay check above rejects reuse.
		self.nonce_count[response["nc"]] = true;
		-- rspauth uses A2 = ":" digest-uri-value (RFC 2831 section 2.1.3, qop=auth).
		local rspauth = md5(KD_prefix..md5(":"..digest_uri, true), true);
		self.authenticated = true;
		--TODO: considering sending the rspauth in a success node for saving one roundtrip; allowed according to http://tools.ietf.org/html/draft-saintandre-rfc3920bis-09#section-7.3.6
		return "challenge", serialize({rspauth = rspauth});
	elseif self.step == 3 then
		if self.authenticated ~= nil then return "success"
		else return "failure", "malformed-request" end
	end
end
225
--- Module entry point: hands the DIGEST-MD5 handler to the SASL core.
-- @param registerMechanism callback invoked as
--        registerMechanism(name, {"plain"}, handler). The list presumably names
--        the credential profiles the handler relies on (digest() reads
--        self.profile.plain) — NOTE(review): confirm against the SASL core.
function init(registerMechanism)
	local mechanism_name = "DIGEST-MD5";
	local profiles = {"plain"};
	registerMechanism(mechanism_name, profiles, digest);
end

return _M;