portmanager, s2smanager, sessionmanager, stanza_router, storagemanager, usermanager...
[prosody.git] / util / set.lua
index 892f1c9d3f70e14bc75a3d36df9ebdcd98531078..04f5f0f4f297d3d5b9d8bf3d167edebf4cf6d092 100644 (file)
@@ -1,4 +1,12 @@
-local ipairs, pairs, setmetatable, next, tostring = 
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local ipairs, pairs, setmetatable, next, tostring =
       ipairs, pairs, setmetatable, next, tostring;
 local t_concat = table.concat;
 
@@ -15,11 +23,12 @@ function set_mt.__sub(set1, set2)
        return _M.difference(set1, set2);
 end
 function set_mt.__div(set, func)
-       local new_set, new_items = _M.new();
+       local new_set = _M.new();
        local items, new_items = set._items, new_set._items;
        for item in pairs(items) do
-               if func(item) then
-                       new_items[item] = true;
+               local new_item = func(item);
+               if new_item ~= nil then
+                       new_items[new_item] = true;
                end
        end
        return new_set;
@@ -31,13 +40,13 @@ function set_mt.__eq(set1, set2)
                        return false;
                end
        end
-       
+
        for item in pairs(set2) do
                if not set1[item] then
                        return false;
                end
        end
-       
+
        return true;
 end
 function set_mt.__tostring(set)
@@ -56,52 +65,58 @@ end
 function new(list)
        local items = setmetatable({}, items_mt);
        local set = { _items = items };
-       
+
        function set:add(item)
                items[item] = true;
        end
-       
+
        function set:contains(item)
                return items[item];
        end
-       
+
        function set:items()
-               return items;
+               return next, items;
        end
-       
+
        function set:remove(item)
                items[item] = nil;
        end
-       
+
        function set:add_list(list)
-               for _, item in ipairs(list) do
-                       items[item] = true;
+               if list then
+                       for _, item in ipairs(list) do
+                               items[item] = true;
+                       end
                end
        end
-       
+
        function set:include(otherset)
-               for item in pairs(otherset) do
+               for item in otherset do
                        items[item] = true;
                end
        end
 
        function set:exclude(otherset)
-               for item in pairs(otherset) do
+               for item in otherset do
                        items[item] = nil;
                end
        end
-       
+
+       function set:empty()
+               return not next(items);
+       end
+
        if list then
                set:add_list(list);
        end
-       
+
        return setmetatable(set, set_mt);
 end
 
 function union(set1, set2)
        local set = new();
        local items = set._items;
-       
+
        for item in pairs(set1._items) do
                items[item] = true;
        end
@@ -109,14 +124,14 @@ function union(set1, set2)
        for item in pairs(set2._items) do
                items[item] = true;
        end
-       
+
        return set;
 end
 
 function difference(set1, set2)
        local set = new();
        local items = set._items;
-       
+
        for item in pairs(set1._items) do
                items[item] = (not set2._items[item]) or nil;
        end
@@ -127,14 +142,18 @@ end
 function intersection(set1, set2)
        local set = new();
        local items = set._items;
-       
+
        set1, set2 = set1._items, set2._items;
-       
+
        for item in pairs(set1) do
                items[item] = (not not set2[item]) or nil;
        end
-       
+
        return set;
 end
 
+function xor(set1, set2)
+       return union(set1, set2) - intersection(set1, set2);
+end
+
 return _M;