Merge with trunk
[prosody.git] / util / sasl.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2009 Tobias Markmann
3 --
4 --    All rights reserved.
5 --
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14
15 local md5 = require "util.hashes".md5;
16 local log = require "util.logger".init("sasl");
17 local st = require "util.stanza";
18 local set = require "util.set";
19 local array = require "util.array";
20 local to_unicode = require "util.encodings".idna.to_unicode;
21
22 local tostring = tostring;
23 local pairs, ipairs = pairs, ipairs;
24 local t_insert, t_concat = table.insert, table.concat;
25 local s_match = string.match;
26 local type = type
27 local error = error
28 local setmetatable = setmetatable;
29 local assert = assert;
30 local require = require;
31
32 require "util.iterators"
33 local keys = keys
34
35 local array = require "util.array"
36 module "sasl"
37
38 --[[
39 Authentication Backend Prototypes:
40
41 state = false : disabled
42 state = true : enabled
43 state = nil : non-existant
44
45 plain:
46         function(username, realm)
47                 return password, state;
48         end
49
50 plain-test:
51         function(username, realm, password)
52                 return true or false, state;
53         end
54
55 digest-md5:
56         function(username, domain, realm, encoding) -- domain and realm are usually the same; for some broken
57                                                                                                 -- implementations it's not
58                 return digesthash, state;
59         end
60
61 digest-md5-test:
62         function(username, domain, realm, encoding, digesthash)
63                 return true or false, state;
64         end
65 ]]
66
67 local method = {};
68 method.__index = method;
69 local mechanisms = {};
70 local backend_mechanism = {};
71
72 -- register a new SASL mechanims
73 local function registerMechanism(name, backends, f)
74         assert(type(name) == "string", "Parameter name MUST be a string.");
75         assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table.");
76         assert(type(f) == "function", "Parameter f MUST be a function.");
77         mechanisms[name] = f
78         for _, backend_name in ipairs(backends) do
79                 if backend_mechanism[backend_name] == nil then backend_mechanism[backend_name] = {}; end
80                 t_insert(backend_mechanism[backend_name], name);
81         end
82 end
83
84 -- create a new SASL object which can be used to authenticate clients
85 function new(realm, profile, forbidden)
86         sasl_i = {profile = profile};
87         sasl_i.realm = realm;
88         s = setmetatable(sasl_i, method);
89         s:forbidden(sasl_i, forbidden)
90         return s;
91 end
92
93 -- set the forbidden mechanisms
94 function method:forbidden( restrict )
95         if restrict then
96                 -- set forbidden
97                 self.restrict = set.new(restrict);
98         else
99                 -- get forbidden
100                 return array.collect(self.restrict:items());
101         end
102 end
103
104 -- get a list of possible SASL mechanims to use
105 function method:mechanisms()
106         local mechanisms = {}
107         for backend, f in pairs(self.profile) do
108                 if backend_mechanism[backend] then
109                         for _, mechanism in ipairs(backend_mechanism[backend]) do
110                                 if not sasl_i.restrict:contains(mechanism) then
111                                         mechanisms[mechanism] = true;
112                                 end
113                         end
114                 end
115         end
116         self["possible_mechanisms"] = mechanisms;
117         return array.collect(keys(mechanisms));
118 end
119
120 -- select a mechanism to use
121 function method:select(mechanism)
122         if self.mech_i then
123                 return false;
124         end
125         
126         self.mech_i = mechanisms[mechanism]
127         if self.mech_i == nil then 
128                 return false;
129         end
130         return true;
131 end
132
133 -- feed new messages to process into the library
134 function method:process(message)
135         --if message == "" or message == nil then return "failure", "malformed-request" end
136         return self.mech_i(self, message);
137 end
138
139 -- load the mechanisms
140 load_mechs = {"plain", "digest-md5", "anonymous", "scram"}
141 for _, mech in ipairs(load_mechs) do
142         local name = "util.sasl."..mech;
143         local m = require(name);
144         m.init(registerMechanism)
145 end
146
147 return _M;