Merge 0.8->trunk
[prosody.git] / util / set.lua
index 08361d03332044e929bd01ed998f4391c83f8c5c..e4cc2dffbd428b6be84c2ec749ba0cd2156b67b3 100644 (file)
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+-- 
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local ipairs, pairs, setmetatable, next, tostring =
+      ipairs, pairs, setmetatable, next, tostring;
+local t_concat = table.concat;
 
 module "set"
 
+local set_mt = {};
+function set_mt.__call(set, _, k)
+       return next(set._items, k);
+end
+function set_mt.__add(set1, set2)
+       return _M.union(set1, set2);
+end
+function set_mt.__sub(set1, set2)
+       return _M.difference(set1, set2);
+end
+function set_mt.__div(set, func)
+       local new_set, new_items = _M.new();
+       local items, new_items = set._items, new_set._items;
+       for item in pairs(items) do
+               if func(item) then
+                       new_items[item] = true;
+               end
+       end
+       return new_set;
+end
+function set_mt.__eq(set1, set2)
+       local set1, set2 = set1._items, set2._items;
+       for item in pairs(set1) do
+               if not set2[item] then
+                       return false;
+               end
+       end
+       
+       for item in pairs(set2) do
+               if not set1[item] then
+                       return false;
+               end
+       end
+       
+       return true;
+end
+function set_mt.__tostring(set)
+       local s, items = { }, set._items;
+       for item in pairs(items) do
+               s[#s+1] = tostring(item);
+       end
+       return t_concat(s, ", ");
+end
+
+local items_mt = {};
+function items_mt.__call(items, _, k)
+       return next(items, k);
+end
+
 function new(list)
-       local items = {};
-       local set = { items = items };
+       local items = setmetatable({}, items_mt);
+       local set = { _items = items };
        
-       function set:add(set, item)
+       function set:add(item)
                items[item] = true;
        end
        
-       function set:contains(set, item)
-               return items[item]
+       function set:contains(item)
+               return items[item];
        end
        
-       function set:items(set)
+       function set:items()
                return items;
        end
        
-       function set:remove(set, item)
+       function set:remove(item)
                items[item] = nil;
        end
        
-       function set:add_list(set, list)
+       function set:add_list(list)
                for _, item in ipairs(list) do
                        items[item] = true;
                end
        end
        
-       function set:include(set, otherset)
+       function set:include(otherset)
                for item in pairs(otherset) do
                        items[item] = true;
                end
        end
 
-       function set:exclude(set, otherset)
+       function set:exclude(otherset)
                for item in pairs(otherset) do
                        items[item] = nil;
                end
        end
        
-       return set;
+       function set:empty()
+               return not next(items);
+       end
+       
+       if list then
+               set:add_list(list);
+       end
+       
+       return setmetatable(set, set_mt);
 end
 
 function union(set1, set2)
        local set = new();
-       local items = set.items;
+       local items = set._items;
        
-       for item in pairs(set1.items) do
+       for item in pairs(set1._items) do
                items[item] = true;
        end
 
-       for item in pairs(set2.items) do
+       for item in pairs(set2._items) do
                items[item] = true;
        end
        
@@ -59,17 +127,30 @@ end
 
 function difference(set1, set2)
        local set = new();
-       local items = set.items;
+       local items = set._items;
        
-       for item in pairs(set1.items) do
-               items[item] = true;
+       for item in pairs(set1._items) do
+               items[item] = (not set2._items[item]) or nil;
        end
 
-       for item in pairs(set2.items) do
-               items[item] = nil;
+       return set;
+end
+
+function intersection(set1, set2)
+       local set = new();
+       local items = set._items;
+       
+       set1, set2 = set1._items, set2._items;
+       
+       for item in pairs(set1) do
+               items[item] = (not not set2[item]) or nil;
        end
        
        return set;
 end
 
+function xor(set1, set2)
+       return union(set1, set2) - intersection(set1, set2);
+end
+
 return _M;