Merge from waqas
[prosody.git] / util / stanza.lua
1 local t_insert      =  table.insert;
2 local t_remove      =  table.remove;
3 local s_format      = string.format;
4 local tostring      =      tostring;
5 local setmetatable  =  setmetatable;
6 local pairs         =         pairs;
7 local ipairs        =        ipairs;
8 local type          =          type;
9 local next          =          next;
10 local print          =          print;
11 local unpack        =        unpack;
12 local s_gsub        =   string.gsub;
13 local os = os;
14
15 local do_pretty_printing = not os.getenv("WINDIR");
16 local getstyle, getstring = require "util.termcolours".getstyle, require "util.termcolours".getstring;
17
18 local log = require "util.logger".init("stanza");
19
20 module "stanza"
21
22 stanza_mt = {};
23 stanza_mt.__index = stanza_mt;
24
25 function stanza(name, attr)
26         local stanza = { name = name, attr = attr or {}, tags = {}, last_add = {}};
27         return setmetatable(stanza, stanza_mt);
28 end
29
30 function stanza_mt:query(xmlns)
31         return self:tag("query", { xmlns = xmlns });
32 end
33 function stanza_mt:tag(name, attrs)
34         local s = stanza(name, attrs);
35         (self.last_add[#self.last_add] or self):add_direct_child(s);
36         t_insert(self.last_add, s);
37         return self;
38 end
39
40 function stanza_mt:text(text)
41         (self.last_add[#self.last_add] or self):add_direct_child(text);
42         return self; 
43 end
44
45 function stanza_mt:up()
46         t_remove(self.last_add);
47         return self;
48 end
49
50 function stanza_mt:add_direct_child(child)
51         if type(child) == "table" then
52                 t_insert(self.tags, child);
53         end
54         t_insert(self, child);
55 end
56
57 function stanza_mt:add_child(child)
58         (self.last_add[#self.last_add] or self):add_direct_child(child);
59         return self;
60 end
61
62 function stanza_mt:child_with_name(name)
63         for _, child in ipairs(self) do 
64                 if child.name == name then return child; end
65         end
66 end
67
68 function stanza_mt:children()
69         local i = 0;
70         return function (a)
71                         i = i + 1
72                         local v = a[i]
73                         if v then return v; end
74                 end, self, i;
75                                             
76 end
77 function stanza_mt:childtags()
78         local i = 0;
79         return function (a)
80                         i = i + 1
81                         local v = self.tags[i]
82                         if v then return v; end
83                 end, self.tags[1], i;
84                                             
85 end
86
87 do
88         local xml_entities = { ["'"] = "&apos;", ["\""] = "&quot;", ["<"] = "&lt;", [">"] = "&gt;", ["&"] = "&amp;" };
89         function xml_escape(s) return s_gsub(s, "['&<>\"]", xml_entities); end
90 end
91
92 local xml_escape = xml_escape;
93
94 function stanza_mt.__tostring(t)
95         local children_text = "";
96         for n, child in ipairs(t) do
97                 if type(child) == "string" then 
98                         children_text = children_text .. xml_escape(child);
99                 else
100                         children_text = children_text .. tostring(child);
101                 end
102         end
103
104         local attr_string = "";
105         if t.attr then
106                 for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, tostring(v)); end end
107         end
108         return s_format("<%s%s>%s</%s>", t.name, attr_string, children_text, t.name);
109 end
110
111 function stanza_mt.top_tag(t)
112         local attr_string = "";
113         if t.attr then
114                 for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, tostring(v)); end end
115         end
116         return s_format("<%s%s>", t.name, attr_string);
117 end
118
119 function stanza_mt.__add(s1, s2)
120         return s1:add_direct_child(s2);
121 end
122
123
124 do
125         local id = 0;
126         function new_id()
127                 id = id + 1;
128                 return "lx"..id;
129         end
130 end
131
132 function preserialize(stanza)
133         local s = { name = stanza.name, attr = stanza.attr };
134         for _, child in ipairs(stanza) do
135                 if type(child) == "table" then
136                         t_insert(s, preserialize(child));
137                 else
138                         t_insert(s, child);
139                 end
140         end
141         return s;
142 end
143
144 function deserialize(stanza)
145         -- Set metatable
146         if stanza then
147                 setmetatable(stanza, stanza_mt);
148                 for _, child in ipairs(stanza) do
149                         if type(child) == "table" then
150                                 deserialize(child);
151                         end
152                 end
153                 if not stanza.tags then
154                         -- Rebuild tags
155                         local tags = {};
156                         for _, child in ipairs(stanza) do
157                                 if type(child) == "table" then
158                                         t_insert(tags, child);
159                                 end
160                         end
161                         stanza.tags = tags;
162                 end
163         end
164         
165         return stanza;
166 end
167
168 function message(attr, body)
169         if not body then
170                 return stanza("message", attr);
171         else
172                 return stanza("message", attr):tag("body"):text(body);
173         end
174 end
175 function iq(attr)
176         if attr and not attr.id then attr.id = new_id(); end
177         return stanza("iq", attr or { id = new_id() });
178 end
179
180 function reply(orig)
181         return stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) });
182 end
183
184 function error_reply(orig, type, condition, message, clone)
185         local t = reply(orig);
186         t.attr.type = "error";
187         -- TODO use clone
188         t:tag("error", {type = type})
189                 :tag(condition, {xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas"}):up();
190         if (message) then t:tag("text"):text(message):up(); end
191         return t; -- stanza ready for adding app-specific errors
192 end
193
194 function presence(attr)
195         return stanza("presence", attr);
196 end
197
198 if do_pretty_printing then
199         local style_attrk = getstyle("yellow");
200         local style_attrv = getstyle("red");
201         local style_tagname = getstyle("red");
202         local style_punc = getstyle("magenta");
203         
204         local attr_format = " "..getstring(style_attrk, "%s")..getstring(style_punc, "=")..getstring(style_attrv, "'%s'");
205         local top_tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">");
206         --local tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">").."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">");
207         local tag_format = top_tag_format.."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">");
208         function stanza_mt.pretty_print(t)
209                 local children_text = "";
210                 for n, child in ipairs(t) do
211                         if type(child) == "string" then 
212                                 children_text = children_text .. xml_escape(child);
213                         else
214                                 children_text = children_text .. child:pretty_print();
215                         end
216                 end
217
218                 local attr_string = "";
219                 if t.attr then
220                         for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(attr_format, k, tostring(v)); end end
221                 end
222                 return s_format(tag_format, t.name, attr_string, children_text, t.name);
223         end
224         
225         function stanza_mt.pretty_top_tag(t)
226                 local attr_string = "";
227                 if t.attr then
228                         for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(attr_format, k, tostring(v)); end end
229                 end
230                 return s_format(top_tag_format, t.name, attr_string);
231         end
232 else
233         -- Sorry, fresh out of colours for you guys ;)
234         stanza_mt.pretty_print = stanza_mt.__tostring;
235         stanza_mt.pretty_top_tag = stanza_mt.top_tag;
236 end
237
238 return _M;