Merge 0.9->0.10
[prosody.git] / util / set.lua
1 -- Prosody IM
2 -- Copyright (C) 2008-2010 Matthew Wild
3 -- Copyright (C) 2008-2010 Waqas Hussain
4 --
5 -- This project is MIT/X11 licensed. Please see the
6 -- COPYING file in the source package for more information.
7 --
8
9 local ipairs, pairs, setmetatable, next, tostring =
10       ipairs, pairs, setmetatable, next, tostring;
11 local t_concat = table.concat;
12
13 module "set"
14
15 local set_mt = {};
16 function set_mt.__call(set, _, k)
17         return next(set._items, k);
18 end
19 function set_mt.__add(set1, set2)
20         return _M.union(set1, set2);
21 end
22 function set_mt.__sub(set1, set2)
23         return _M.difference(set1, set2);
24 end
25 function set_mt.__div(set, func)
26         local new_set = _M.new();
27         local items, new_items = set._items, new_set._items;
28         for item in pairs(items) do
29                 local new_item = func(item);
30                 if new_item ~= nil then
31                         new_items[new_item] = true;
32                 end
33         end
34         return new_set;
35 end
36 function set_mt.__eq(set1, set2)
37         local set1, set2 = set1._items, set2._items;
38         for item in pairs(set1) do
39                 if not set2[item] then
40                         return false;
41                 end
42         end
43
44         for item in pairs(set2) do
45                 if not set1[item] then
46                         return false;
47                 end
48         end
49
50         return true;
51 end
52 function set_mt.__tostring(set)
53         local s, items = { }, set._items;
54         for item in pairs(items) do
55                 s[#s+1] = tostring(item);
56         end
57         return t_concat(s, ", ");
58 end
59
60 local items_mt = {};
61 function items_mt.__call(items, _, k)
62         return next(items, k);
63 end
64
65 function new(list)
66         local items = setmetatable({}, items_mt);
67         local set = { _items = items };
68
69         function set:add(item)
70                 items[item] = true;
71         end
72
73         function set:contains(item)
74                 return items[item];
75         end
76
77         function set:items()
78                 return next, items;
79         end
80
81         function set:remove(item)
82                 items[item] = nil;
83         end
84
85         function set:add_list(list)
86                 if list then
87                         for _, item in ipairs(list) do
88                                 items[item] = true;
89                         end
90                 end
91         end
92
93         function set:include(otherset)
94                 for item in otherset do
95                         items[item] = true;
96                 end
97         end
98
99         function set:exclude(otherset)
100                 for item in otherset do
101                         items[item] = nil;
102                 end
103         end
104
105         function set:empty()
106                 return not next(items);
107         end
108
109         if list then
110                 set:add_list(list);
111         end
112
113         return setmetatable(set, set_mt);
114 end
115
116 function union(set1, set2)
117         local set = new();
118         local items = set._items;
119
120         for item in pairs(set1._items) do
121                 items[item] = true;
122         end
123
124         for item in pairs(set2._items) do
125                 items[item] = true;
126         end
127
128         return set;
129 end
130
131 function difference(set1, set2)
132         local set = new();
133         local items = set._items;
134
135         for item in pairs(set1._items) do
136                 items[item] = (not set2._items[item]) or nil;
137         end
138
139         return set;
140 end
141
142 function intersection(set1, set2)
143         local set = new();
144         local items = set._items;
145
146         set1, set2 = set1._items, set2._items;
147
148         for item in pairs(set1) do
149                 items[item] = (not not set2[item]) or nil;
150         end
151
152         return set;
153 end
154
155 function xor(set1, set2)
156         return union(set1, set2) - intersection(set1, set2);
157 end
158
159 return _M;