moduleapi: Have modules internally store a reference to shared tables they use, to...
[prosody.git] / core / moduleapi.lua
1 -- Prosody IM
2 -- Copyright (C) 2008-2012 Matthew Wild
3 -- Copyright (C) 2008-2012 Waqas Hussain
4 -- 
5 -- This project is MIT/X11 licensed. Please see the
6 -- COPYING file in the source package for more information.
7 --
8
9 local config = require "core.configmanager";
10 local modulemanager = require "modulemanager";
11 local array = require "util.array";
12 local set = require "util.set";
13 local logger = require "util.logger";
14 local pluginloader = require "util.pluginloader";
15
16 local multitable_new = require "util.multitable".new;
17
18 local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat;
19 local error, setmetatable, setfenv, type = error, setmetatable, setfenv, type;
20 local ipairs, pairs, select, unpack = ipairs, pairs, select, unpack;
21 local tonumber, tostring = tonumber, tostring;
22
23 local prosody = prosody;
24 local hosts = prosody.hosts;
25 local core_post_stanza = prosody.core_post_stanza;
26
27 -- Registry of shared module data
28 local shared_data = setmetatable({}, { __mode = "v" });
29
30 local NULL = {};
31
32 local api = {};
33
34 -- Returns the name of the current module
35 function api:get_name()
36         return self.name;
37 end
38
39 -- Returns the host that the current module is serving
40 function api:get_host()
41         return self.host;
42 end
43
44 function api:get_host_type()
45         return hosts[self.host].type;
46 end
47
48 function api:set_global()
49         self.host = "*";
50         -- Update the logger
51         local _log = logger.init("mod_"..self.name);
52         self.log = function (self, ...) return _log(...); end;
53         self._log = _log;
54         self.global = true;
55 end
56
57 function api:add_feature(xmlns)
58         self:add_item("feature", xmlns);
59 end
60 function api:add_identity(category, type, name)
61         self:add_item("identity", {category = category, type = type, name = name});
62 end
63 function api:add_extension(data)
64         self:add_item("extension", data);
65 end
66
67 function api:fire_event(...)
68         return (hosts[self.host] or prosody).events.fire_event(...);
69 end
70
71 function api:hook_object_event(object, event, handler, priority)
72         self.event_handlers[handler] = { name = event, priority = priority, object = object };
73         return object.add_handler(event, handler, priority);
74 end
75
76 function api:hook(event, handler, priority)
77         return self:hook_object_event((hosts[self.host] or prosody).events, event, handler, priority);
78 end
79
80 function api:hook_global(event, handler, priority)
81         return self:hook_object_event(prosody.events, event, handler, priority);
82 end
83
84 function api:hook_stanza(xmlns, name, handler, priority)
85         if not handler and type(name) == "function" then
86                 -- If only 2 options then they specified no xmlns
87                 xmlns, name, handler, priority = nil, xmlns, name, handler;
88         elseif not (handler and name) then
89                 self:log("warn", "Error: Insufficient parameters to module:hook_stanza()");
90                 return;
91         end
92         return self:hook("stanza/"..(xmlns and (xmlns..":") or "")..name, function (data) return handler(data.origin, data.stanza, data); end, priority);
93 end
94
95 function api:require(lib)
96         local f, n = pluginloader.load_code(self.name, lib..".lib.lua");
97         if not f then
98                 f, n = pluginloader.load_code(lib, lib..".lib.lua");
99         end
100         if not f then error("Failed to load plugin library '"..lib.."', error: "..n); end -- FIXME better error message
101         setfenv(f, self.environment);
102         return f();
103 end
104
105 function api:depends(name)
106         if not self.dependencies then
107                 self.dependencies = {};
108                 self:hook("module-reloaded", function (event)
109                         if self.dependencies[event.module] then
110                                 self:log("info", "Auto-reloading due to reload of %s:%s", event.host, event.module);
111                                 modulemanager.reload(self.host, self.name);
112                                 return;
113                         end
114                 end);
115                 self:hook("module-unloaded", function (event)
116                         if self.dependencies[event.module] then
117                                 self:log("info", "Auto-unloading due to unload of %s:%s", event.host, event.module);
118                                 modulemanager.unload(self.host, self.name);
119                         end
120                 end);
121         end
122         local mod = modulemanager.get_module(self.host, name) or modulemanager.get_module("*", name);
123         if not mod then
124                 local err;
125                 mod, err = modulemanager.load(self.host, name);
126                 if not mod then
127                         return error(("Unable to load required module, mod_%s: %s"):format(name, ((err or "unknown error"):gsub("%-", " ")) ));
128                 end
129         end
130         self.dependencies[name] = true;
131         return mod;
132 end
133
134 -- Returns one or more shared tables at the specified virtual paths
135 -- Intentionally does not allow the table at a path to be _set_, it
136 -- is auto-created if it does not exist.
137 function api:shared(...)
138         if not self.shared_data then self.shared_data = {}; end
139         local paths = { n = select("#", ...), ... };
140         local data_array = {};
141         local default_path_components = { self.host, self.name };
142         for i = 1, paths.n do
143                 local path = paths[i];
144                 if path:sub(1,1) ~= "/" then -- Prepend default components
145                         local n_components = select(2, path:gsub("/", "%1"));
146                         path = (n_components<#default_path_components and "/" or "")..t_concat(default_path_components, "/", 1, #default_path_components-n_components).."/"..path;
147                 end
148                 local shared = shared_data[path];
149                 if not shared then
150                         shared = {};
151                         shared_data[path] = shared;
152                 end
153                 t_insert(data_array, shared);
154                 self.shared_data[path] = shared;
155         end
156         return unpack(data_array);
157 end
158
159 function api:get_option(name, default_value)
160         local value = config.get(self.host, self.name, name);
161         if value == nil then
162                 value = config.get(self.host, "core", name);
163                 if value == nil then
164                         value = default_value;
165                 end
166         end
167         return value;
168 end
169
170 function api:get_option_string(name, default_value)
171         local value = self:get_option(name, default_value);
172         if type(value) == "table" then
173                 if #value > 1 then
174                         self:log("error", "Config option '%s' does not take a list, using just the first item", name);
175                 end
176                 value = value[1];
177         end
178         if value == nil then
179                 return nil;
180         end
181         return tostring(value);
182 end
183
184 function api:get_option_number(name, ...)
185         local value = self:get_option(name, ...);
186         if type(value) == "table" then
187                 if #value > 1 then
188                         self:log("error", "Config option '%s' does not take a list, using just the first item", name);
189                 end
190                 value = value[1];
191         end
192         local ret = tonumber(value);
193         if value ~= nil and ret == nil then
194                 self:log("error", "Config option '%s' not understood, expecting a number", name);
195         end
196         return ret;
197 end
198
199 function api:get_option_boolean(name, ...)
200         local value = self:get_option(name, ...);
201         if type(value) == "table" then
202                 if #value > 1 then
203                         self:log("error", "Config option '%s' does not take a list, using just the first item", name);
204                 end
205                 value = value[1];
206         end
207         if value == nil then
208                 return nil;
209         end
210         local ret = value == true or value == "true" or value == 1 or nil;
211         if ret == nil then
212                 ret = (value == false or value == "false" or value == 0);
213                 if ret then
214                         ret = false;
215                 else
216                         ret = nil;
217                 end
218         end
219         if ret == nil then
220                 self:log("error", "Config option '%s' not understood, expecting true/false", name);
221         end
222         return ret;
223 end
224
225 function api:get_option_array(name, ...)
226         local value = self:get_option(name, ...);
227
228         if value == nil then
229                 return nil;
230         end
231         
232         if type(value) ~= "table" then
233                 return array{ value }; -- Assume any non-list is a single-item list
234         end
235         
236         return array():append(value); -- Clone
237 end
238
239 function api:get_option_set(name, ...)
240         local value = self:get_option_array(name, ...);
241         
242         if value == nil then
243                 return nil;
244         end
245         
246         return set.new(value);
247 end
248
249 local module_items = multitable_new();
250 function api:add_item(key, value)
251         self.items = self.items or {};
252         self.items[key] = self.items[key] or {};
253         t_insert(self.items[key], value);
254         self:fire_event("item-added/"..key, {source = self, item = value});
255 end
256 function api:remove_item(key, value)
257         local t = self.items and self.items[key] or NULL;
258         for i = #t,1,-1 do
259                 if t[i] == value then
260                         t_remove(self.items[key], i);
261                         self:fire_event("item-removed/"..key, {source = self, item = value});
262                         return value;
263                 end
264         end
265 end
266
267 function api:get_host_items(key)
268         local result = {};
269         for mod_name, module in pairs(modulemanager.get_modules(self.host)) do
270                 module = module.module;
271                 if module.items then
272                         for _, item in ipairs(module.items[key] or NULL) do
273                                 t_insert(result, item);
274                         end
275                 end
276         end
277         for mod_name, module in pairs(modulemanager.get_modules("*")) do
278                 module = module.module;
279                 if module.items then
280                         for _, item in ipairs(module.items[key] or NULL) do
281                                 t_insert(result, item);
282                         end
283                 end
284         end
285         return result;
286 end
287
288 function api:handle_items(type, added_cb, removed_cb, existing)
289         self:hook("item-added/"..type, added_cb);
290         self:hook("item-removed/"..type, removed_cb);
291         if existing ~= false then
292                 for _, item in ipairs(self:get_host_items(type)) do
293                         added_cb({ item = item });
294                 end
295         end
296 end
297
298 function api:provides(name, item)
299         if not item then item = self.environment; end
300         if not item.name then
301                 local item_name = module.name;
302                 -- Strip a provider prefix to find the item name
303                 -- (e.g. "auth_foo" -> "foo" for an auth provider)
304                 if item_name:find(name.."_", 1, true) == 1 then
305                         item_name = item_name:sub(#name+2);
306                 end
307                 item.name = item_name;
308         end
309         self:add_item(name, item);
310 end
311
312 function api:send(stanza)
313         return core_post_stanza(hosts[self.host], stanza);
314 end
315
316 return api;