util.set: Add metatable to sets to allow +, -, /, ==, tostring and to double as iterators
[prosody.git] / util / set.lua
1 local ipairs, pairs, setmetatable, next, tostring = 
2       ipairs, pairs, setmetatable, next, tostring;
3 local t_concat = table.concat;
4
5 module "set"
6
7 local set_mt = {};
8 function set_mt.__call(set, _, k)
9         return next(set._items, k);
10 end
11 function set_mt.__add(set1, set2)
12         return _M.union(set1, set2);
13 end
14 function set_mt.__sub(set1, set2)
15         return _M.difference(set1, set2);
16 end
17 function set_mt.__div(set, func)
18         local new_set, new_items = _M.new();
19         local items, new_items = set._items, new_set._items;
20         for item in pairs(items) do
21                 if func(item) then
22                         new_items[item] = true;
23                 end
24         end
25         return new_set;
26 end
27 function set_mt.__eq(set1, set2)
28         local set1, set2 = set1._items, set2._items;
29         for item in pairs(set1) do
30                 if not set2[item] then
31                         return false;
32                 end
33         end
34         
35         for item in pairs(set2) do
36                 if not set1[item] then
37                         return false;
38                 end
39         end
40         
41         return true;
42 end
43 function set_mt.__tostring(set)
44         local s, items = { }, set._items;
45         for item in pairs(items) do
46                 s[#s+1] = tostring(item);
47         end
48         return t_concat(s, ", ");
49 end
50
51 local items_mt = {};
52 function items_mt.__call(items, _, k)
53         return next(items, k);
54 end
55
56 function new(list)
57         local items = setmetatable({}, items_mt);
58         local set = { _items = items };
59         
60         function set:add(item)
61                 items[item] = true;
62         end
63         
64         function set:contains(item)
65                 return items[item];
66         end
67         
68         function set:items()
69                 return items;
70         end
71         
72         function set:remove(item)
73                 items[item] = nil;
74         end
75         
76         function set:add_list(list)
77                 for _, item in ipairs(list) do
78                         items[item] = true;
79                 end
80         end
81         
82         function set:include(otherset)
83                 for item in pairs(otherset) do
84                         items[item] = true;
85                 end
86         end
87
88         function set:exclude(otherset)
89                 for item in pairs(otherset) do
90                         items[item] = nil;
91                 end
92         end
93         
94         if list then
95                 set:add_list(list);
96         end
97         
98         return setmetatable(set, set_mt);
99 end
100
101 function union(set1, set2)
102         local set = new();
103         local items = set._items;
104         
105         for item in pairs(set1._items) do
106                 items[item] = true;
107         end
108
109         for item in pairs(set2._items) do
110                 items[item] = true;
111         end
112         
113         return set;
114 end
115
116 function difference(set1, set2)
117         local set = new();
118         local items = set._items;
119         
120         for item in pairs(set1._items) do
121                 items[item] = (not set2._items[item]) or nil;
122         end
123
124         return set;
125 end
126
127 function intersection(set1, set2)
128         local set = new();
129         local items = set._items;
130         
131         set1, set2 = set1._items, set2._items;
132         
133         for item in pairs(set1) do
134                 items[item] = (not not set2[item]) or nil;
135         end
136         
137         return set;
138 end
139
140 return _M;