7a0e47b85b5d831f3785eb38b36750ae88b72d0c
[prosody.git] / util / sasl.lua
1
2 local base64 = require "base64"
3 local md5 = require "md5"
4 local crypto = require "crypto"
5 local log = require "util.logger".init("sasl");
6 local tostring = tostring;
7 local st = require "util.stanza";
8 local generate_uuid = require "util.uuid".generate;
9 local s_match = string.match;
10 local gmatch = string.gmatch
11 local math = require "math"
12 local type = type
13 local error = error
14 local print = print
15
16 module "sasl"
17
--- Create a server-side SASL PLAIN mechanism handler.
-- The client message (base64 text of the <auth/> or <response/> element) has
-- the form `[authzid] NUL authcid NUL password`.
-- @param onAuth    function(authcid, password); authentication succeeds only when it returns true
-- @param onSuccess function(authcid), called after <success/> has been written
-- @param onFail    function(condition), called with an error condition string on protocol errors
-- @param onWrite   function(stanza), used to send the <success/> / <failure/> reply
-- @return handler table exposing `feed(self, stanza)`
local function new_plain(onAuth, onSuccess, onFail, onWrite)
	local object = { mechanism = "PLAIN", onAuth = onAuth, onSuccess = onSuccess, onFail = onFail,
			onWrite = onWrite }
	object.feed = function(self, stanza)
		-- Stop after reporting a protocol error; do not go on to decode/authenticate.
		if stanza.name ~= "response" and stanza.name ~= "auth" then
			self.onFail("invalid-stanza-tag");
			return;
		end
		if stanza.attr.xmlns ~= "urn:ietf:params:xml:ns:xmpp-sasl" then
			self.onFail("invalid-stanza-namespace");
			return;
		end

		-- The payload must be a text child; decoding may fail on bad input.
		local payload = stanza[1];
		local response = type(payload) == "string" and base64.decode(payload);
		if not response then
			self.onFail("malformed-request");
			return;
		end

		-- authcid and password may contain any character except NUL (including '&').
		-- The authzid (first field) is accepted but currently ignored.
		local authentication, password = s_match(response, "^[^%z]*%z([^%z]+)%z([^%z]+)$");
		if not authentication then
			self.onFail("malformed-request");
			return;
		end

		if self.onAuth(authentication, password) == true then
			self.onWrite(st.stanza("success", {xmlns = "urn:ietf:params:xml:ns:xmpp-sasl"}));
			self.onSuccess(authentication);
		else
			-- Invalid credentials are "not-authorized"; "temporary-auth-failure"
			-- is reserved for server-side temporary errors.
			self.onWrite(st.stanza("failure", {xmlns = "urn:ietf:params:xml:ns:xmpp-sasl"}):tag("not-authorized"));
		end
	end
	return object
end
39
40
--[[
Sample DIGEST-MD5 exchange (base64-decoded payloads), kept for reference.

SERVER challenge:
nonce="3145176401",qop="auth",charset=utf-8,algorithm=md5-sess

CLIENT response:
username="tobiasfar",nonce="3145176401",cnonce="pJiW7hzeZLvOSAf7gBzwTzLWe4obYOVDlnNESzQCzGg=",nc=00000001,digest-uri="xmpp/jabber.org",qop=auth,response=99a93ba75235136e6403c3a2ba37089d,charset=utf-8

CLIENT response (second sample, answering a different challenge nonce):
username="tobias",nonce="4406697386",cnonce="wUnT7vYrOB0V8D/lKd5bhpaNCk+hLJwc8T4CBCqp7WM=",nc=00000001,digest-uri="xmpp/luaetta.ath.cx",qop=auth,response=d202b8a1bdf8204816fb23c5f87b6b63,charset=utf-8

SERVER (response-auth, sent after the client response is verified):
rspauth=ab66d28c260e97da577ce3aac46a8991
]]--
--- Create a server-side SASL DIGEST-MD5 mechanism handler.
-- Creating the handler immediately writes the initial <challenge/> via onWrite.
-- NOTE: the mechanism is incomplete. The client's digest-response is parsed
-- and sanity-checked, but the digest itself is not verified yet, so onAuth and
-- onSuccess are never called.
-- @param onAuth    authentication callback (not used yet)
-- @param onSuccess success callback (not used yet)
-- @param onFail    function(condition), called with an error condition string
-- @param onWrite   function(stanza), used to send challenges
-- @return handler table exposing `feed(self, stanza)`
local function new_digest_md5(onAuth, onSuccess, onFail, onWrite)
	-- Raw (binary) MD5 digest.
	local function H(s)
		return md5.sum(s)
	end

	local function KD(k, s)
		return H(k..":"..s)
	end

	-- Lowercase hex encoding of a binary string, which is how RFC 2831 defines
	-- HEX(n). The previous md5.sumhexa(n) hashed n a second time instead.
	local function HEX(n)
		return (n:gsub(".", function(c) return ("%02x"):format(c:byte()) end))
	end

	local function HMAC(k, s)
		return crypto.hmac.digest("md5", s, k, true)
	end

	-- Build a directive string from the known fields, in a fixed order.
	local function serialize(message)
		if type(message) ~= "table" then error("serialize needs an argument of type table.") end
		local data = ""
		if message["nonce"] then data = data..[[nonce="]]..message.nonce..[[",]] end
		if message["qop"] then data = data..[[qop="]]..message.qop..[[",]] end
		if message["charset"] then data = data..[[charset=]]..message.charset.."," end
		if message["algorithm"] then data = data..[[algorithm=]]..message.algorithm.."," end
		-- Previously read message.algorith (typo), which raised on concatenating nil.
		if message["rspauth"] then data = data..[[rspauth=]]..message.rspauth.."," end
		data = data:gsub(",$", "")
		return data
	end

	-- Parse `key=value` / `key="value"` directives separated by commas.
	-- Quoted values may contain any character except '"' (base64 cnonces contain
	-- '+', '/' and '='). NOTE(review): backslash escapes inside quotes are not handled.
	local function parse(data)
		local message = {}
		local pos = 1
		while true do
			local key, value_start = data:match("^%s*([%w%-]+)%s*=%s*()", pos)
			if not key then break end
			local value, value_end
			if data:sub(value_start, value_start) == '"' then
				value, value_end = data:match('^"([^"]*)"()', value_start)
			else
				value, value_end = data:match("^([^,]*)()", value_start)
			end
			if not value then break end -- unterminated quoted value
			message[key] = value
			-- Skip the separator; pos always advances past at least "key=".
			pos = data:match("^%s*,?()", value_end)
		end
		return message
	end

	-- Sanity-check a parsed digest-response. Returns an error condition string
	-- for onFail, or nil when the response is acceptable so far.
	local function check_response(self, response)
		-- The nonce count directive is named "nc"; a repeated count indicates a replay.
		local nc = response["nc"]
		if nc and self.nonce_count[nc] then return "not-authorized" end
		-- username is REQUIRED by RFC 2831
		if not response["username"] then return "malformed-request" end
		-- Must echo back the nonce we issued (also rejects a missing nonce).
		if response["nonce"] ~= self.nonce then return "malformed-request" end
		if not response["cnonce"] then return "malformed-request" end
		-- Malformed client input is reported, not raised as a Lua error.
		local digest_uri = response["digest-uri"]
		if not digest_uri or not s_match(digest_uri, "(%w+)/(.*)$") then
			return "malformed-request"
		end
		return nil
	end

	local object = { mechanism = "DIGEST-MD5", onAuth = onAuth, onSuccess = onSuccess, onFail = onFail,
			onWrite = onWrite }

	--TODO: something better than math.random would be nice, maybe OpenSSL's random number generator
	object.nonce = math.random(0, 9)
	for _ = 1, 9 do object.nonce = object.nonce..math.random(0, 9) end
	object.step = 1
	object.nonce_count = {}
	local challenge = base64.encode(serialize({ nonce = object.nonce,
			qop = "auth",
			charset = "utf-8",
			algorithm = "md5-sess" }));
	object.onWrite(st.stanza("challenge", {xmlns = "urn:ietf:params:xml:ns:xmpp-sasl"}):text(challenge))

	object.feed = function(self, stanza)
		-- Stop after reporting any failure instead of continuing to process.
		if stanza.name ~= "response" and stanza.name ~= "auth" then
			self.onFail("invalid-stanza-tag");
			return;
		end
		if stanza.attr.xmlns ~= "urn:ietf:params:xml:ns:xmpp-sasl" then
			self.onFail("invalid-stanza-namespace");
			return;
		end
		-- The initial challenge was already written when the handler was created.
		if stanza.name == "auth" then return end

		self.step = self.step + 1
		if self.step ~= 2 then return end -- later steps are not implemented yet

		local payload = stanza[1]
		local decoded = type(payload) == "string" and base64.decode(payload)
		if not decoded then
			self.onFail("malformed-request");
			return;
		end
		local response = parse(decoded)

		local condition = check_response(self, response)
		if condition then
			self.onFail(condition);
			return;
		end

		if response["nc"] then self.nonce_count[response["nc"]] = true end
		if not response["qop"] then response["qop"] = "auth" end

		-- TODO: verify response["response"] (RFC 2831, section 2.1.2.1):
		--   A1 = H(username ":" realm ":" passwd) ":" nonce ":" cnonce [":" authzid]
		--   A2 = "AUTHENTICATE:" digest-uri                      (qop=auth)
		--   response = HEX(KD(HEX(H(A1)), nonce ":" nc ":" cnonce ":" qop ":" HEX(H(A2))))
		-- then send rspauth (computed with A2 = ":" digest-uri) and call onSuccess.
		-- NOTE(review): as used by PLAIN, onAuth only checks a (username, password)
		-- pair, so it cannot supply the password this computation needs.
	end
	return object
end
170
-- Map from SASL mechanism name to the constructor that implements it.
local mechanism_constructors = {
	["PLAIN"] = new_plain,
	["DIGEST-MD5"] = new_digest_md5,
}

--- Create a SASL handler for the named mechanism.
-- For an unknown mechanism this logs a debug message, calls
-- onFail("unsupported-mechanism") and returns nil.
function new(mechanism, onAuth, onSuccess, onFail, onWrite)
	local constructor = mechanism_constructors[mechanism]
	if not constructor then
		log("debug", "Unsupported SASL mechanism: "..tostring(mechanism));
		onFail("unsupported-mechanism")
		return nil
	end
	local handler = constructor(onAuth, onSuccess, onFail, onWrite)
	return handler
end

return _M;