050446ec3f3aca668bd70933e7d5802f6651b4cf
[prosody.git] / util / set.lua
1 -- Prosody IM
2 -- Copyright (C) 2008-2010 Matthew Wild
3 -- Copyright (C) 2008-2010 Waqas Hussain
4 -- 
5 -- This project is MIT/X11 licensed. Please see the
6 -- COPYING file in the source package for more information.
7 --
8
9 local ipairs, pairs, setmetatable, next, tostring =
10       ipairs, pairs, setmetatable, next, tostring;
11 local t_concat = table.concat;
12
13 module "set"
14
15 local set_mt = {};
16 function set_mt.__call(set, _, k)
17         return next(set._items, k);
18 end
19 function set_mt.__add(set1, set2)
20         return _M.union(set1, set2);
21 end
22 function set_mt.__sub(set1, set2)
23         return _M.difference(set1, set2);
24 end
25 function set_mt.__div(set, func)
26         local new_set, new_items = _M.new();
27         local items, new_items = set._items, new_set._items;
28         for item in pairs(items) do
29                 if func(item) then
30                         new_items[item] = true;
31                 end
32         end
33         return new_set;
34 end
35 function set_mt.__eq(set1, set2)
36         local set1, set2 = set1._items, set2._items;
37         for item in pairs(set1) do
38                 if not set2[item] then
39                         return false;
40                 end
41         end
42         
43         for item in pairs(set2) do
44                 if not set1[item] then
45                         return false;
46                 end
47         end
48         
49         return true;
50 end
51 function set_mt.__tostring(set)
52         local s, items = { }, set._items;
53         for item in pairs(items) do
54                 s[#s+1] = tostring(item);
55         end
56         return t_concat(s, ", ");
57 end
58
59 local items_mt = {};
60 function items_mt.__call(items, _, k)
61         return next(items, k);
62 end
63
64 function new(list)
65         local items = setmetatable({}, items_mt);
66         local set = { _items = items };
67         
68         function set:add(item)
69                 items[item] = true;
70         end
71         
72         function set:contains(item)
73                 return items[item];
74         end
75         
76         function set:items()
77                 return items;
78         end
79         
80         function set:remove(item)
81                 items[item] = nil;
82         end
83         
84         function set:add_list(list)
85                 if list then
86                         for _, item in ipairs(list) do
87                                 items[item] = true;
88                         end
89                 end
90         end
91         
92         function set:include(otherset)
93                 for item in pairs(otherset) do
94                         items[item] = true;
95                 end
96         end
97
98         function set:exclude(otherset)
99                 for item in pairs(otherset) do
100                         items[item] = nil;
101                 end
102         end
103         
104         function set:empty()
105                 return not next(items);
106         end
107         
108         if list then
109                 set:add_list(list);
110         end
111         
112         return setmetatable(set, set_mt);
113 end
114
115 function union(set1, set2)
116         local set = new();
117         local items = set._items;
118         
119         for item in pairs(set1._items) do
120                 items[item] = true;
121         end
122
123         for item in pairs(set2._items) do
124                 items[item] = true;
125         end
126         
127         return set;
128 end
129
130 function difference(set1, set2)
131         local set = new();
132         local items = set._items;
133         
134         for item in pairs(set1._items) do
135                 items[item] = (not set2._items[item]) or nil;
136         end
137
138         return set;
139 end
140
141 function intersection(set1, set2)
142         local set = new();
143         local items = set._items;
144         
145         set1, set2 = set1._items, set2._items;
146         
147         for item in pairs(set1) do
148                 items[item] = (not not set2[item]) or nil;
149         end
150         
151         return set;
152 end
153
154 function xor(set1, set2)
155         return union(set1, set2) - intersection(set1, set2);
156 end
157
158 return _M;