util.sasl: Adding clean_clone() method.
[prosody.git] / util / sasl.lua
1 -- sasl.lua v0.4
2 -- Copyright (C) 2008-2009 Tobias Markmann
3 --
4 --    All rights reserved.
5 --
6 --    Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 --
8 --        * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 --        * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
10 --        * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 --
12 --    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13
14
15 local md5 = require "util.hashes".md5;
16 local log = require "util.logger".init("sasl");
17 local st = require "util.stanza";
18 local set = require "util.set";
19 local array = require "util.array";
20 local to_unicode = require "util.encodings".idna.to_unicode;
21
22 local tostring = tostring;
23 local pairs, ipairs = pairs, ipairs;
24 local t_insert, t_concat = table.insert, table.concat;
25 local s_match = string.match;
26 local type = type
27 local error = error
28 local setmetatable = setmetatable;
29 local assert = assert;
30 local require = require;
31
32 require "util.iterators"
33 local keys = keys
34
35 local array = require "util.array"
36 module "sasl"
37
38 --[[
39 Authentication Backend Prototypes:
40
41 state = false : disabled
42 state = true : enabled
43 state = nil : non-existant
44
45 plain:
46         function(username, realm)
47                 return password, state;
48         end
49
50 plain-test:
51         function(username, realm, password)
52                 return true or false, state;
53         end
54
55 digest-md5:
56         function(username, domain, realm, encoding) -- domain and realm are usually the same; for some broken
57                                                                                                 -- implementations it's not
58                 return digesthash, state;
59         end
60
61 digest-md5-test:
62         function(username, domain, realm, encoding, digesthash)
63                 return true or false, state;
64         end
65 ]]
66
67 local method = {};
68 method.__index = method;
69 local mechanisms = {};
70 local backend_mechanism = {};
71
72 -- register a new SASL mechanims
73 local function registerMechanism(name, backends, f)
74         assert(type(name) == "string", "Parameter name MUST be a string.");
75         assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table.");
76         assert(type(f) == "function", "Parameter f MUST be a function.");
77         mechanisms[name] = f
78         for _, backend_name in ipairs(backends) do
79                 if backend_mechanism[backend_name] == nil then backend_mechanism[backend_name] = {}; end
80                 t_insert(backend_mechanism[backend_name], name);
81         end
82 end
83
84 -- create a new SASL object which can be used to authenticate clients
85 function new(realm, profile, forbidden)
86         sasl_i = {profile = profile};
87         sasl_i.realm = realm;
88         s = setmetatable(sasl_i, method);
89         s:forbidden(sasl_i, forbidden)
90         return s;
91 end
92
93 -- get a fresh clone with the same realm, profiles and forbidden mechanisms
94 function method:clean_clone()
95         return new(self.realm, self.profile, self:forbidden())
96 end
97
98 -- set the forbidden mechanisms
99 function method:forbidden( restrict )
100         if restrict then
101                 -- set forbidden
102                 self.restrict = set.new(restrict);
103         else
104                 -- get forbidden
105                 return array.collect(self.restrict:items());
106         end
107 end
108
109 -- get a list of possible SASL mechanims to use
110 function method:mechanisms()
111         local mechanisms = {}
112         for backend, f in pairs(self.profile) do
113                 if backend_mechanism[backend] then
114                         for _, mechanism in ipairs(backend_mechanism[backend]) do
115                                 if not sasl_i.restrict:contains(mechanism) then
116                                         mechanisms[mechanism] = true;
117                                 end
118                         end
119                 end
120         end
121         self["possible_mechanisms"] = mechanisms;
122         return array.collect(keys(mechanisms));
123 end
124
125 -- select a mechanism to use
126 function method:select(mechanism)
127         if self.mech_i then
128                 return false;
129         end
130         
131         self.mech_i = mechanisms[mechanism]
132         if self.mech_i == nil then 
133                 return false;
134         end
135         return true;
136 end
137
138 -- feed new messages to process into the library
139 function method:process(message)
140         --if message == "" or message == nil then return "failure", "malformed-request" end
141         return self.mech_i(self, message);
142 end
143
144 -- load the mechanisms
145 load_mechs = {"plain", "digest-md5", "anonymous", "scram"}
146 for _, mech in ipairs(load_mechs) do
147         local name = "util.sasl."..mech;
148         local m = require(name);
149         m.init(registerMechanism)
150 end
151
152 return _M;