2837148ec4723270b92d11b30ac551299e230184
[prosody.git] / util / sasl / digest-md5.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2010 Tobias Markmann
3 --
4 --    All rights reserved.
5 --
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14 local tostring = tostring;
15 local type = type;
16
17 local s_gmatch = string.gmatch;
18 local s_match = string.match;
19 local t_concat = table.concat;
20 local t_insert = table.insert;
21 local to_byte, to_char = string.byte, string.char;
22
23 local md5 = require "util.hashes".md5;
24 local log = require "util.logger".init("sasl");
25 local generate_uuid = require "util.uuid".generate;
26
-- Declare this file as the "digest-md5" module using the Lua 5.1 module()
-- call; globals defined afterwards (such as init) land in the module table,
-- which is handed back at the end of the file via _M.
-- NOTE(review): module() exists only in Lua 5.1 (deprecated from 5.2 on).
module("digest-md5")
28
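--=========================
--SASL DIGEST-MD5 according to RFC 2831

--[[
Supported authentication backends (looked up on self.profile, in this order)

plain:
	function(username, realm)
		return password, state;
	end

"digest-md5":
	function(username, domain, realm, encoding)
		-- domain is the server's own realm, realm is the value sent by the
		-- client; they are usually the same, but for some broken
		-- implementations they are not. encoding is the client's charset
		-- directive (nil when absent).
		return digesthash, state;
	end
	-- digesthash must be md5("username:realm:password") computed the same
	-- way as the plain path does it (md5() with no second argument).

"digest-md5-test":
	function(username, domain, realm, encoding, digesthash)
		return true or false, state;
	end
	-- Not implemented yet by the mechanism (TODO).

For plain and "digest-md5", state == nil fails with "not-authorized" and
state == false fails with "account-disabled".
]]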
46
--- Serialize DIGEST-MD5 directives into a comma-separated directive list.
-- Only the directives this server emits are written, in a fixed order.
-- realm, nonce and qop are quoted; charset, algorithm and rspauth are bare
-- tokens. Values are inserted verbatim (no escaping).
-- @param message table mapping directive name to value
-- @return directive string ("" when no known directive is set)
local function serialize(message)
	local parts = {};
	if message["realm"] then t_insert(parts, [[realm="]]..message.realm..[["]]) end
	if message["nonce"] then t_insert(parts, [[nonce="]]..message.nonce..[["]]) end
	if message["qop"] then t_insert(parts, [[qop="]]..message.qop..[["]]) end
	if message["charset"] then t_insert(parts, "charset="..message.charset) end
	if message["algorithm"] then t_insert(parts, "algorithm="..message.algorithm) end
	if message["rspauth"] then t_insert(parts, "rspauth="..message.rspauth) end
	return t_concat(parts, ",");
end

--- Convert a UTF-8 string to ISO-8859-1 when every character fits.
-- Only two-byte sequences whose lead byte is 0xC0-0xC3 can be mapped. If any
-- other non-ASCII byte appears, or a lead byte is not followed by a
-- continuation byte (including at the end of the string), the input is
-- returned unchanged.
-- @param passwd string to convert
-- @return the converted string, or passwd itself
local function utf8tolatin1ifpossible(passwd)
	-- First pass: make sure the conversion is possible at all.
	local i = 1;
	while i <= #passwd do
		local b = to_byte(passwd, i);
		if b > 0x7F then
			if b < 0xC0 or b > 0xC3 then
				return passwd;
			end
			i = i + 1;
			local cont = to_byte(passwd, i);
			-- cont is nil when the string ends right after a lead byte
			if cont == nil or cont < 0x80 or cont > 0xBF then
				return passwd;
			end
		end
		i = i + 1;
	end

	-- Second pass: collapse every two-byte sequence into one byte.
	local p = {};
	i = 1;
	while i <= #passwd do
		local b = to_byte(passwd, i);
		if b > 0x7F then
			i = i + 1;
			local cont = to_byte(passwd, i);
			-- low 2 bits of the lead byte give the top 2 bits of the result,
			-- low 6 bits of the continuation byte give the rest
			t_insert(p, to_char(b % 4 * 64 + cont % 64));
		else
			t_insert(p, to_char(b));
		end
		i = i + 1;
	end
	return t_concat(p);
end

--- Parse a DIGEST-MD5 directive list into a table of name -> value.
-- COMPAT: %z in the pattern works around a jwchat bug (sends "charset=utf-8\0").
-- FIXME: quoted values containing ',' or '"' are not handled by this pattern.
local function parse(data)
	local message = {};
	for k, v in s_gmatch(data, [[([%w%-]+)="?([^",%z]*)"?,?]]) do
		message[k] = v;
	end
	return message;
end

--- DIGEST-MD5 (RFC 2831) server mechanism, called once per client message.
-- Progress is kept on self:
--   step 1: send the challenge (nonce, qop, charset, algorithm, realm);
--   step 2: verify the client's digest-response and answer with rspauth;
--   step 3: finish with "success" if step 2 authenticated the client.
-- Any later call fails with "malformed-request".
-- Reads self.realm and self.profile; writes self.nonce, self.step,
-- self.nonce_count, self.username and self.authenticated.
-- @param self SASL session object
-- @param message the client's message (ignored at step 1)
-- @return "challenge", data | "success" | "failure", condition[, text]
local function digest(self, message)
	--TODO complete support for authzid

	if not self.nonce then
		self.nonce = generate_uuid();
		self.step = 0;
		self.nonce_count = {};
	end

	self.step = self.step + 1;
	if self.step == 1 then
		local challenge = serialize({	nonce = self.nonce,
						qop = "auth",
						charset = "utf-8",
						algorithm = "md5-sess",
						realm = self.realm});
		return "challenge", challenge;
	elseif self.step == 2 then
		local response = parse(message);
		-- check for replay attack (nonce_count is filled in below)
		if response["nc"] and self.nonce_count[response["nc"]] then
			return "failure", "not-authorized";
		end

		-- check for username, it's REQUIRED by RFC 2831
		if not response["username"] then
			return "failure", "malformed-request";
		end
		self["username"] = response["username"];

		-- check for nonce, and that it is the one we issued
		if not response["nonce"] then
			return "failure", "malformed-request";
		elseif response["nonce"] ~= tostring(self.nonce) then
			return "failure", "malformed-request";
		end

		if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
		-- nc is part of the response digest below; without it the digest
		-- cannot be computed, so reject instead of erroring on concatenation
		if not response["nc"] then return "failure", "malformed-request", "Missing entry for nc in SASL message." end
		if not response["qop"] then response["qop"] = "auth" end

		if response["realm"] == nil or response["realm"] == "" then
			response["realm"] = "";
		elseif response["realm"] ~= self.realm then
			return "failure", "not-authorized", "Incorrect realm value";
		end

		-- Only an absent charset (Latin-1) or utf-8 is accepted.
		-- NOTE(review): decoder is selected but not yet applied to credentials.
		local decoder;
		if response["charset"] == nil then
			decoder = utf8tolatin1ifpossible;
		elseif response["charset"] ~= "utf-8" then
			return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding which isn't supported by sasl.lua. Supported encodings are latin or utf-8.";
		end

		if not response["digest-uri"] then
			return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
		end
		local protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$");
		if protocol == nil or domain == nil then return "failure", "malformed-request" end

		--TODO maybe realm support
		-- Y is md5("username:realm:password"), the hash A1 is built from.
		local Y, state;
		if self.profile.plain then
			local password;
			password, state = self.profile.plain(response["username"], self.realm);
			if state == nil then return "failure", "not-authorized"
			elseif state == false then return "failure", "account-disabled" end
			if not password then return "failure", "not-authorized" end
			Y = md5(response["username"]..":"..response["realm"]..":"..password);
		elseif self.profile["digest-md5"] then
			Y, state = self.profile["digest-md5"](response["username"], self.realm, response["realm"], response["charset"]);
			if state == nil then return "failure", "not-authorized"
			elseif state == false then return "failure", "account-disabled" end
		end
		-- No usable backend (e.g. only the unimplemented "digest-md5-test"
		-- profile) or the backend produced no hash.
		if not Y then return "failure", "not-authorized" end

		-- Remember this nonce-count so the replay check above rejects reuse.
		self.nonce_count[response["nc"]] = true;

		local A1;
		if response.authzid then
			if response.authzid == self.username
				or (self.realm and response.authzid == self.username.."@"..self.realm) then
				-- COMPAT
				log("warn", "Client is violating RFC 3920 (section 6.1, point 7).");
				A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid;
			else
				return "failure", "invalid-authzid";
			end
		else
			A1 = Y..":"..response["nonce"]..":"..response["cnonce"];
		end

		local HA1 = md5(A1, true);
		-- KD over HA1, the directive values and the hash of the given A2.
		local function kd(A2)
			local HA2 = md5(A2, true);
			local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2;
			return md5(KD, true);
		end

		local response_value = kd("AUTHENTICATE:"..protocol.."/"..domain);
		if response_value == response["response"] then
			-- rspauth uses the same computation with an empty method in A2
			local rspauth = kd(":"..protocol.."/"..domain);
			self.authenticated = true;
			--TODO: considering sending the rspauth in a success node for saving one roundtrip; allowed according to http://tools.ietf.org/html/draft-saintandre-rfc3920bis-09#section-7.3.6
			return "challenge", serialize({rspauth = rspauth});
		else
			return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
		end
	elseif self.step == 3 then
		if self.authenticated ~= nil then return "success"
		else return "failure", "malformed-request" end
	end
	-- The exchange is already over; any further message is unexpected.
	return "failure", "malformed-request";
end
238
--- Module entry point: register the DIGEST-MD5 mechanism with the SASL layer.
-- The mechanism is declared against the "plain" credentials profile and is
-- served by the digest step function defined above.
-- @param registerMechanism callback taking (mechanism name, profile list, handler)
init = function(registerMechanism)
	local profiles = { "plain" };
	registerMechanism("DIGEST-MD5", profiles, digest);
end

return _M;