Merge 0.10->trunk
[prosody.git] / util / set.lua
1 -- Prosody IM
2 -- Copyright (C) 2008-2010 Matthew Wild
3 -- Copyright (C) 2008-2010 Waqas Hussain
4 --
5 -- This project is MIT/X11 licensed. Please see the
6 -- COPYING file in the source package for more information.
7 --
8
9 local ipairs, pairs, setmetatable, next, tostring =
10       ipairs, pairs, setmetatable, next, tostring;
11 local t_concat = table.concat;
12
13 local _ENV = nil;
14
15 local set_mt = {};
16 function set_mt.__call(set, _, k)
17         return next(set._items, k);
18 end
19
20 local items_mt = {};
21 function items_mt.__call(items, _, k)
22         return next(items, k);
23 end
24
25 local function new(list)
26         local items = setmetatable({}, items_mt);
27         local set = { _items = items };
28
29         -- We access the set through an upvalue in these methods, so ignore 'self' being unused
30         --luacheck: ignore 212/self
31
32         function set:add(item)
33                 items[item] = true;
34         end
35
36         function set:contains(item)
37                 return items[item];
38         end
39
40         function set:items()
41                 return next, items;
42         end
43
44         function set:remove(item)
45                 items[item] = nil;
46         end
47
48         function set:add_list(item_list)
49                 if item_list then
50                         for _, item in ipairs(item_list) do
51                                 items[item] = true;
52                         end
53                 end
54         end
55
56         function set:include(otherset)
57                 for item in otherset do
58                         items[item] = true;
59                 end
60         end
61
62         function set:exclude(otherset)
63                 for item in otherset do
64                         items[item] = nil;
65                 end
66         end
67
68         function set:empty()
69                 return not next(items);
70         end
71
72         if list then
73                 set:add_list(list);
74         end
75
76         return setmetatable(set, set_mt);
77 end
78
79 local function union(set1, set2)
80         local set = new();
81         local items = set._items;
82
83         for item in pairs(set1._items) do
84                 items[item] = true;
85         end
86
87         for item in pairs(set2._items) do
88                 items[item] = true;
89         end
90
91         return set;
92 end
93
94 local function difference(set1, set2)
95         local set = new();
96         local items = set._items;
97
98         for item in pairs(set1._items) do
99                 items[item] = (not set2._items[item]) or nil;
100         end
101
102         return set;
103 end
104
105 local function intersection(set1, set2)
106         local set = new();
107         local items = set._items;
108
109         set1, set2 = set1._items, set2._items;
110
111         for item in pairs(set1) do
112                 items[item] = (not not set2[item]) or nil;
113         end
114
115         return set;
116 end
117
118 local function xor(set1, set2)
119         return union(set1, set2) - intersection(set1, set2);
120 end
121
122 function set_mt.__add(set1, set2)
123         return union(set1, set2);
124 end
125 function set_mt.__sub(set1, set2)
126         return difference(set1, set2);
127 end
128 function set_mt.__div(set, func)
129         local new_set = new();
130         local items, new_items = set._items, new_set._items;
131         for item in pairs(items) do
132                 local new_item = func(item);
133                 if new_item ~= nil then
134                         new_items[new_item] = true;
135                 end
136         end
137         return new_set;
138 end
139 function set_mt.__eq(set1, set2)
140         set1, set2 = set1._items, set2._items;
141         for item in pairs(set1) do
142                 if not set2[item] then
143                         return false;
144                 end
145         end
146
147         for item in pairs(set2) do
148                 if not set1[item] then
149                         return false;
150                 end
151         end
152
153         return true;
154 end
155 function set_mt.__tostring(set)
156         local s, items = { }, set._items;
157         for item in pairs(items) do
158                 s[#s+1] = tostring(item);
159         end
160         return t_concat(s, ", ");
161 end
162
163 return {
164         new = new;
165         union = union;
166         difference = difference;
167         intersection = intersection;
168         xor = xor;
169 };