1429a5c63e7b19d7dbc57dcca980340a55ad36cf
[prosody.git] / util / sasl / digest-md5.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2009 Tobias Markmann
3 --
4 --    All rights reserved.
5 --
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
local error = error;
local tostring = tostring;
local type = type;

local s_gmatch = string.gmatch;
local s_match = string.match;
local t_concat = table.concat;
local t_insert = table.insert;
local to_byte, to_char = string.byte, string.char;

local md5 = require "util.hashes".md5;
local log = require "util.logger".init("sasl");
local generate_uuid = require "util.uuid".generate;

-- module() (Lua 5.1) replaces this chunk's environment with the module table,
-- and without package.seeall standard globals are no longer reachable below
-- this line: every global the code needs must be localized above.
module "digest-md5"
28
29 --=========================
30 --SASL DIGEST-MD5 according to RFC 2831
--- Compute the RFC 2831 digest-response values for qop=auth.
-- Not currently called by digest(), which performs the same computation
-- inline; provided as a reusable helper instead of the former stub that
-- returned undefined names.
-- @param Y md5(username..":"..realm..":"..password), as computed in digest()
-- @param nonce server nonce
-- @param cnonce client nonce
-- @param nc nonce-count string as sent by the client
-- @param qop quality of protection, e.g. "auth"
-- @param digest_uri the client's digest-uri value ("protocol/host[/serv-name]")
-- @param authzid optional authorization identity appended to A1
-- @return the response value (md5(KD, true), the same form digest() compares
--   against the client's response directive), A1 and A2
local function digest_response(Y, nonce, cnonce, nc, qop, digest_uri, authzid)
	local A1 = Y..":"..nonce..":"..cnonce;
	if authzid then
		A1 = A1..":"..authzid;
	end
	local A2 = "AUTHENTICATE:"..digest_uri;
	local KD = md5(A1, true)..":"..nonce..":"..nc..":"..cnonce..":"..qop..":"..md5(A2, true);
	local response = md5(KD, true);
	return response, A1, A2;
end
35
--- Server side of SASL DIGEST-MD5 (RFC 2831), called once per exchange step.
-- Step 1 returns the challenge (nonce, qop=auth, charset=utf-8,
-- algorithm=md5-sess and, when set, self.realm). Step 2 parses and checks the
-- client's digest-response, obtains credentials through self.profile and, if
-- the client's response value matches the computed one, answers with a
-- challenge carrying rspauth. Step 3 reports success if step 2 authenticated
-- the client; any later call fails.
-- @param self session table; reads self.realm and self.profile, and stores
--   nonce, step, nonce_count, username and authenticated on it
-- @param message raw client data for this step (ignored in steps 1 and 3)
-- @return "challenge" plus the challenge string, "success", or "failure" plus
--   a condition name and sometimes a human-readable reason
local function digest(self, message)
	--TODO complete support for authzid

	-- Serialize a challenge table as a comma-separated directive list. Keys are
	-- emitted in a fixed order; realm, nonce and qop values are quoted.
	local function serialize(message)
		if type(message) ~= "table" then error("serialize needs an argument of type table.") end

		local parts = {};
		if message["realm"] then t_insert(parts, [[realm="]]..message.realm..[["]]) end
		if message["nonce"] then t_insert(parts, [[nonce="]]..message.nonce..[["]]) end
		if message["qop"] then t_insert(parts, [[qop="]]..message.qop..[["]]) end
		if message["charset"] then t_insert(parts, [[charset=]]..message.charset) end
		if message["algorithm"] then t_insert(parts, [[algorithm=]]..message.algorithm) end
		if message["rspauth"] then t_insert(parts, [[rspauth=]]..message.rspauth) end
		return t_concat(parts, ",");
	end

	-- Convert a UTF-8 string to ISO 8859-1 when every non-ASCII character is a
	-- two-byte sequence with lead byte 0xC0-0xC3; otherwise return it unchanged.
	-- (Currently not called from this function.)
	local function utf8tolatin1ifpossible(passwd)
		local i = 1;
		while i <= #passwd do
			local passwd_i = to_byte(passwd, i);
			if passwd_i > 0x7F then
				if passwd_i < 0xC0 or passwd_i > 0xC3 then
					return passwd;
				end
				i = i + 1;
				passwd_i = to_byte(passwd, i);
				-- A lead byte at the very end has no continuation byte (nil).
				if passwd_i == nil or passwd_i < 0x80 or passwd_i > 0xBF then
					return passwd;
				end
			end
			i = i + 1;
		end

		local p = {};
		i = 1;
		while i <= #passwd do
			local passwd_i = to_byte(passwd, i);
			if passwd_i > 0x7F then
				i = i + 1;
				local passwd_i_1 = to_byte(passwd, i);
				-- Low 2 bits of the lead byte plus low 6 bits of the continuation
				-- byte give the 8-bit code point.
				t_insert(p, to_char(passwd_i%4*64 + passwd_i_1%64));
			else
				t_insert(p, to_char(passwd_i));
			end
			i = i + 1;
		end
		return t_concat(p);
	end

	-- Encode an ISO 8859-1 string as UTF-8 (currently not called from this function).
	local function latin1toutf8(str)
		local p = {};
		for ch in s_gmatch(str, ".") do
			ch = to_byte(ch);
			if ch < 0x80 then
				t_insert(p, to_char(ch));
			elseif ch < 0xC0 then
				t_insert(p, to_char(0xC2, ch));
			else
				t_insert(p, to_char(0xC3, ch - 64));
			end
		end
		return t_concat(p);
	end

	-- Parse a client directive list into a key -> value table.
	-- FIXME The hacky pattern cannot handle quoted values containing commas or quotes.
	local function parse(data)
		local message = {};
		for k, v in s_gmatch(data, [[([%w%-]+)="?([^",]*)"?,?]]) do
			message[k] = v;
		end
		return message;
	end

	if not self.nonce then
		self.nonce = generate_uuid();
		self.step = 0;
		self.nonce_count = {};
	end

	self.step = self.step + 1;
	if self.step == 1 then
		local challenge = serialize({ nonce = self.nonce,
			qop = "auth",
			charset = "utf-8",
			algorithm = "md5-sess",
			realm = self.realm });
		return "challenge", challenge;
	elseif self.step == 2 then
		-- An empty client response may arrive as nil; treat it as no directives.
		local response = parse(message or "");

		-- check for username, it's REQUIRED by RFC 2831
		if not response["username"] then
			return "failure", "malformed-request";
		end
		self.username = response["username"];

		-- the nonce must be present and be the one we issued in step 1
		if response["nonce"] ~= tostring(self.nonce) then
			return "failure", "malformed-request";
		end

		if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
		-- nc is part of the digest computed below, so it cannot be omitted
		if not response["nc"] then return "failure", "malformed-request", "Missing entry for nc in SASL message." end
		-- check for replay attack: nonce-counts are recorded after a successful check
		if self.nonce_count[response["nc"]] then return "failure", "not-authorized" end
		if not response["qop"] then response["qop"] = "auth" end

		if response["realm"] == nil or response["realm"] == "" then
			response["realm"] = "";
		elseif response["realm"] ~= self.realm then
			return "failure", "not-authorized", "Incorrect realm value";
		end

		-- NOTE(review): RFC 2831 asks for username/password to be converted to
		-- ISO 8859-1 before hashing when charset=utf-8 is given and they are
		-- representable in it; utf8tolatin1ifpossible looks intended for that but
		-- is not applied here -- confirm before wiring it in.
		if response["charset"] ~= nil and response["charset"] ~= "utf-8" then
			return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding which isn't supported by sasl.lua. Supported encodings are latin or utf-8.";
		end

		local digest_uri = response["digest-uri"];
		if not digest_uri then
			return "failure", "malformed-request", "Missing entry for digest-uri in SASL message.";
		end
		local protocol, domain = digest_uri:match("(%w+)/(.*)$");
		if protocol == nil or domain == nil then return "failure", "malformed-request" end

		--TODO maybe realm support
		-- Y is the H({username ":" realm ":" password}) term of A1. It is one
		-- local shared by every profile branch so it never leaks into the module
		-- environment and is still in scope when A1 is built.
		local Y;
		if self.profile.plain then
			local password, state = self.profile.plain(response["username"], self.realm);
			if state == nil then return "failure", "not-authorized"
			elseif state == false then return "failure", "account-disabled" end
			Y = md5(response["username"]..":"..response["realm"]..":"..password);
		elseif self.profile["digest-md5"] then
			local state;
			Y, state = self.profile["digest-md5"](response["username"], self.realm, response["realm"], response["charset"]);
			if state == nil then return "failure", "not-authorized"
			elseif state == false then return "failure", "account-disabled" end
		elseif self.profile["digest-md5-test"] then
			-- TODO
		end
		if Y == nil then
			-- No profile supplied credentials; refuse instead of hashing nil.
			return "failure", "not-authorized";
		end

		local A1;
		if response.authzid then
			-- self.realm may be unset, in which case only the bare username matches
			if response.authzid == self.username
				or (self.realm ~= nil and response.authzid == self.username.."@"..self.realm) then
				-- COMPAT
				log("warn", "Client is violating RFC 3920 (section 6.1, point 7).");
				A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid;
			else
				return "failure", "invalid-authzid";
			end
		else
			A1 = Y..":"..response["nonce"]..":"..response["cnonce"];
		end
		local A2 = "AUTHENTICATE:"..protocol.."/"..domain;

		local HA1 = md5(A1, true);
		local HA2 = md5(A2, true);

		local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2;
		local response_value = md5(KD, true);

		if response_value ~= response["response"] then
			return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";
		end

		-- Remember this nonce-count so the replay check above rejects its reuse.
		self.nonce_count[response["nc"]] = true;

		-- rspauth reuses A1 (hence HA1); its A2 has an empty method part.
		local rspauth_A2 = ":"..protocol.."/"..domain;
		local rspauth_KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..md5(rspauth_A2, true);
		local rspauth = md5(rspauth_KD, true);
		self.authenticated = true;
		--TODO: considering sending the rspauth in a success node for saving one roundtrip; allowed according to http://tools.ietf.org/html/draft-saintandre-rfc3920bis-09#section-7.3.6
		return "challenge", serialize({ rspauth = rspauth });
	elseif self.step == 3 then
		if self.authenticated ~= nil then return "success"
		else return "failure", "malformed-request" end
	end
	-- Any step beyond 3 is unexpected for this mechanism.
	return "failure", "malformed-request";
end
227
--- Register the DIGEST-MD5 mechanism with the SASL framework.
-- @param registerMechanism callback invoked exactly once as
--   registerMechanism(name, profiles, handler) with name "DIGEST-MD5", the
--   profile list { "plain" } and digest as the handler. Returns nothing.
-- NOTE(review): digest also consults a "digest-md5" profile, but only "plain"
-- is listed here -- confirm how registerMechanism uses this list.
function init(registerMechanism)
	local mechanism_name = "DIGEST-MD5";
	local required_profiles = { "plain" };
	registerMechanism(mechanism_name, required_profiles, digest);
end
231
232 return _M;