util.sasl.plain: Adding plain_hashed authentication backend support.
[prosody.git] / util / sasl / digest-md5.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2009 Tobias Markmann
3 --
4 --    All rights reserved.
5 --
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
-- Globals used by this file are captured as locals here because the
-- module() call below replaces the chunk's environment (no package.seeall),
-- after which standard globals are no longer reachable by name.
local tostring = tostring;
local type = type;
local error = error;

local s_gmatch = string.gmatch;
local s_match = string.match;
local t_concat = table.concat;
local t_insert = table.insert;
local to_byte, to_char = string.byte, string.char;

local md5 = require "util.hashes".md5;
local log = require "util.logger".init("sasl");
local generate_uuid = require "util.uuid".generate;
26
27 module "digest-md5"
28
29 --=========================
30 --SASL DIGEST-MD5 according to RFC 2831
31
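--[[
Supported Authentication Backends

plain:
	function(username, realm)
		return password, state;
	end

digest-md5:
	function(username, domain, realm, encoding) -- domain and realm are usually the same; for some broken
	                                            -- implementations they are not
		return digesthash, state;
	end

digest-md5-test: (not implemented yet)
	function(username, domain, realm, encoding, digesthash)
		return true or false, state;
	end

For plain and digest-md5, a nil state makes authentication fail with
"not-authorized" and a false state makes it fail with "account-disabled".
]]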
46
--- Serialize a table of DIGEST-MD5 directives into a server challenge string.
-- Only the directives this server emits are written, always in the order
-- realm, nonce, qop, charset, algorithm, rspauth. realm, nonce and qop are
-- written as quoted strings; charset, algorithm and rspauth as bare tokens.
-- @param message table mapping directive name to value
-- @return comma-separated directive list ("" if none of the directives is set)
local function serialize(message)
	if type(message) ~= "table" then error("serialize needs an argument of type table.") end

	local parts = {};
	-- quoted-string directives
	if message["realm"] then t_insert(parts, 'realm="'..message.realm..'"') end
	if message["nonce"] then t_insert(parts, 'nonce="'..message.nonce..'"') end
	if message["qop"] then t_insert(parts, 'qop="'..message.qop..'"') end
	-- token directives
	if message["charset"] then t_insert(parts, "charset="..message.charset) end
	if message["algorithm"] then t_insert(parts, "algorithm="..message.algorithm) end
	if message["rspauth"] then t_insert(parts, "rspauth="..message.rspauth) end
	return t_concat(parts, ",");
end

--- Parse a client DIGEST-MD5 response into a table of directive name -> value.
-- Surrounding quotes are dropped. A quoted value is cut at its first comma
-- or quote character, because the value pattern excludes both (see FIXME).
-- @param data raw response string
-- @return table of directives (a repeated directive keeps its last value)
local function parse(data)
	local message = {};
	-- COMPAT: %z in the pattern to work around jwchat bug (sends "charset=utf-8\0")
	for k, v in s_gmatch(data, [[([%w%-]+)="?([^",%z]*)"?,?]]) do -- FIXME The hacky regex makes me shudder
		message[k] = v;
	end
	return message;
end

--- Convert a UTF-8 string to ISO 8859-1 when every character fits in it.
-- Returns the input unchanged if it contains any non-ASCII byte that is not
-- a complete two-byte sequence with lead byte 0xC0-0xC3.
-- NOTE: not currently called; kept for the charset TODO in digest().
local function utf8tolatin1ifpossible(passwd)
	-- First pass: validate.
	local i = 1;
	while i <= #passwd do
		local lead = to_byte(passwd, i);
		if lead > 0x7F then
			if lead < 0xC0 or lead > 0xC3 then
				return passwd;
			end
			i = i + 1;
			local cont = to_byte(passwd, i); -- nil when the sequence is truncated
			if cont == nil or cont < 0x80 or cont > 0xBF then
				return passwd;
			end
		end
		i = i + 1;
	end

	-- Second pass: decode (input validated above).
	local p = {};
	i = 1;
	while i <= #passwd do
		local lead = to_byte(passwd, i);
		if lead > 0x7F then
			i = i + 1;
			-- low two bits of the lead byte are the top two bits of the Latin-1 byte
			t_insert(p, to_char(lead % 4 * 64 + to_byte(passwd, i) % 64));
		else
			t_insert(p, to_char(lead));
		end
		i = i + 1;
	end
	return t_concat(p);
end

--- Encode an ISO 8859-1 string as UTF-8.
-- NOTE: not currently called; kept for the charset TODO in digest().
local function latin1toutf8(str)
	local p = {};
	for ch in s_gmatch(str, ".") do
		ch = to_byte(ch);
		if ch < 0x80 then
			t_insert(p, to_char(ch));
		elseif ch < 0xC0 then
			t_insert(p, to_char(0xC2, ch));
		else
			t_insert(p, to_char(0xC3, ch - 64));
		end
	end
	return t_concat(p);
end

--- DIGEST-MD5 step handler (RFC 2831), invoked once per SASL exchange step.
-- Step 1 returns the server challenge. Step 2 checks the client's digest
-- response and, if it matches, returns rspauth as a further challenge.
-- Step 3 reports success if step 2 authenticated the client. Any later step
-- fails with malformed-request.
-- @param self SASL session; reads self.realm and self.profile, and stores
--   nonce, step, nonce_count, username and authenticated on it
-- @param message client data for this step (unused in step 1; may be nil)
-- @return "challenge", data | "success" | "failure", condition[, text]
local function digest(self, message)
	--TODO complete support for authzid

	if not self.nonce then
		self.nonce = generate_uuid();
		self.step = 0;
		self.nonce_count = {};
	end

	self.step = self.step + 1;
	if self.step == 1 then
		local challenge = serialize({
			nonce = self.nonce,
			qop = "auth",
			charset = "utf-8",
			algorithm = "md5-sess",
			realm = self.realm,
		});
		return "challenge", challenge;
	elseif self.step == 2 then
		-- A nil message would make s_gmatch raise; parsing "" instead yields
		-- no username and fails cleanly below.
		local response = parse(message or "");

		-- check for replay attack
		if response["nc"] and self.nonce_count[response["nc"]] then
			return "failure", "not-authorized";
		end

		-- check for username, it's REQUIRED by RFC 2831
		if not response["username"] then
			return "failure", "malformed-request";
		end
		self.username = response["username"];

		-- the nonce must be present and must be the one we issued
		if not response["nonce"] or response["nonce"] ~= tostring(self.nonce) then
			return "failure", "malformed-request";
		end

		if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
		-- nc is concatenated into the digest below, so a missing one must be rejected here
		if not response["nc"] then return "failure", "malformed-request", "Missing entry for nc in SASL message." end
		if not response["qop"] then response["qop"] = "auth" end

		if response["realm"] == nil or response["realm"] == "" then
			response["realm"] = "";
		elseif response["realm"] ~= self.realm then
			return "failure", "not-authorized", "Incorrect realm value";
		end

		-- Only "utf-8" is accepted as an explicit charset; without one the
		-- client is using ISO 8859-1.
		-- TODO: the RFC 2831 section 2.1.2.1 conversion of non-ASCII
		-- credentials (utf8tolatin1ifpossible / latin1toutf8) is not applied;
		-- the backend's password is hashed as-is.
		if response["charset"] ~= nil and response["charset"] ~= "utf-8" then
			return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding which isn't supported by sasl.lua. Supported encodings are latin or utf-8.";
		end

		local digest_uri = response["digest-uri"];
		if not digest_uri then
			return "failure", "malformed-request", "Missing entry for digest-uri in SASL message.";
		end
		local protocol, domain = s_match(digest_uri, "(%w+)/(.*)$");
		if protocol == nil or domain == nil then return "failure", "malformed-request" end

		--TODO maybe realm support
		-- Y = H(username:realm:password); md5() is called without the hex
		-- flag here, unlike the HA1/HA2/response digests below.
		local Y, state;
		if self.profile.plain then
			local password;
			password, state = self.profile.plain(response["username"], self.realm);
			if state == nil then return "failure", "not-authorized"
			elseif state == false then return "failure", "account-disabled" end
			if password ~= nil then
				Y = md5(response["username"]..":"..response["realm"]..":"..password);
			end
		elseif self.profile["digest-md5"] then
			Y, state = self.profile["digest-md5"](response["username"], self.realm, response["realm"], response["charset"]);
			if state == nil then return "failure", "not-authorized"
			elseif state == false then return "failure", "account-disabled" end
		elseif self.profile["digest-md5-test"] then
			-- TODO
		end
		if Y == nil then
			-- No backend supplied a usable hash (e.g. only the unimplemented
			-- digest-md5-test profile exists); fail instead of raising an
			-- error when Y is concatenated below.
			log("warn", "DIGEST-MD5: no usable credentials backend for this authentication attempt");
			return "failure", "temporary-auth-failure";
		end

		local A1;
		if response.authzid then
			if response.authzid == self.username
				or (self.realm and response.authzid == self.username.."@"..self.realm) then
				-- COMPAT
				log("warn", "Client is violating RFC 3920 (section 6.1, point 7).");
				A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid;
			else
				return "failure", "invalid-authzid";
			end
		else
			A1 = Y..":"..response["nonce"]..":"..response["cnonce"];
		end

		-- Both the client response and rspauth are
		-- HEX(H(HA1:nonce:nc:cnonce:qop:HEX(H(A2)))); they differ only in A2.
		local HA1 = md5(A1, true);
		local kd_prefix = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":";
		local path = protocol.."/"..domain;
		local response_value = md5(kd_prefix..md5("AUTHENTICATE:"..path, true), true);

		if response_value == response["response"] then
			local rspauth = md5(kd_prefix..md5(":"..path, true), true);
			-- record the nonce-count looked up by the replay check above
			self.nonce_count[response["nc"]] = true;
			self.authenticated = true;
			--TODO: considering sending the rspauth in a success node for saving one roundtrip; allowed according to http://tools.ietf.org/html/draft-saintandre-rfc3920bis-09#section-7.3.6
			return "challenge", serialize({rspauth = rspauth});
		else
			return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
		end
	elseif self.step == 3 then
		if self.authenticated ~= nil then return "success" end
		return "failure", "malformed-request";
	end
	-- Any step after 3: the exchange is already over.
	return "failure", "malformed-request";
end
240
--- Module entry point: registers the DIGEST-MD5 mechanism.
-- @param registerMechanism callback invoked once with the mechanism name,
--   a list of backend profile names ({"plain"}) and the step handler.
init = function(registerMechanism)
	local mechanism_name, backends = "DIGEST-MD5", { "plain" };
	registerMechanism(mechanism_name, backends, digest);
end
244
245 return _M;