X-Git-Url: https://git.enpas.org/?a=blobdiff_plain;f=util%2Fset.lua;h=c136a522d3a272e1b7a08ec851708ad539a3ad89;hb=c4e5c21a7cb6a4082ac4e15079f021fed4c928c0;hp=67f3dcaabfb37bd445bf99b85bb41b24cc2d095c;hpb=cbaa12fe71fafacc60596ba6fe21f54267ca525a;p=prosody.git diff --git a/util/set.lua b/util/set.lua index 67f3dcaa..c136a522 100644 --- a/util/set.lua +++ b/util/set.lua @@ -1,81 +1,169 @@ -local ipairs, pairs = - ipairs, pairs; +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- -module "set" +local ipairs, pairs, setmetatable, next, tostring = + ipairs, pairs, setmetatable, next, tostring; +local t_concat = table.concat; + +local _ENV = nil; + +local set_mt = {}; +function set_mt.__call(set, _, k) + return next(set._items, k); +end + +local items_mt = {}; +function items_mt.__call(items, _, k) + return next(items, k); +end + +local function new(list) + local items = setmetatable({}, items_mt); + local set = { _items = items }; + + -- We access the set through an upvalue in these methods, so ignore 'self' being unused + --luacheck: ignore 212/self -function new(list) - local items = {}; - local set = { items = items }; - function set:add(item) items[item] = true; end - + function set:contains(item) return items[item]; end - + function set:items() - return items; + return next, items; end - + function set:remove(item) items[item] = nil; end - - function set:add_list(list) - for _, item in ipairs(list) do - items[item] = true; + + function set:add_list(item_list) + if item_list then + for _, item in ipairs(item_list) do + items[item] = true; + end end end - + function set:include(otherset) - for item in pairs(otherset) do + for item in otherset do items[item] = true; end end function set:exclude(otherset) - for item in pairs(otherset) do + for item in otherset do items[item] = nil; end end - + + function set:empty() + return not next(items); + end + if list then set:add_list(list); end - - return set; + + return setmetatable(set, set_mt); end -function union(set1, set2) +local function union(set1, set2) local set = new(); - local items = set.items; - - for item in pairs(set1.items) do + local items = set._items; + + for item in pairs(set1._items) do items[item] = true; end - for item in pairs(set2.items) do + for item in pairs(set2._items) do items[item] = true; end - + return set; end -function difference(set1, set2) +local function difference(set1, set2) local set = new(); - local items = set.items; - - for item in pairs(set1.items) do - items[item] = true; + local items = set._items; + + for item in pairs(set1._items) do + items[item] = (not set2._items[item]) or nil; end - for item in pairs(set2.items) do - items[item] = nil; + return set; +end + +local function intersection(set1, set2) + local set = new(); + local items = set._items; + + set1, set2 = set1._items, set2._items; + + for item in pairs(set1) do + items[item] = (not not set2[item]) or nil; end - + return set; end -return _M; +local function xor(set1, set2) + return union(set1, set2) - intersection(set1, set2); +end + +function set_mt.__add(set1, set2) + return union(set1, set2); +end +function set_mt.__sub(set1, set2) + return difference(set1, set2); +end +function set_mt.__div(set, func) + local new_set = new(); + local items, new_items = set._items, new_set._items; + for item in pairs(items) do + local new_item = func(item); + if new_item ~= nil then + new_items[new_item] = true; + end + end + return new_set; +end +function set_mt.__eq(set1, set2) + set1, set2 = set1._items, set2._items; + for item in pairs(set1) do + if not set2[item] then + return false; + end + end + + for item in pairs(set2) do + if not set1[item] then + return false; + end + end + + return true; +end +function set_mt.__tostring(set) + local s, items = { }, set._items; + for item in pairs(items) do + s[#s+1] = tostring(item); + end + return t_concat(s, ", "); +end + +return { + new = new; + union = union; + difference = difference; + intersection = intersection; + xor = xor; +};