Merge 0.10->trunk
[prosody.git] / util / json.lua
1 -- Prosody IM
2 -- Copyright (C) 2008-2010 Matthew Wild
3 -- Copyright (C) 2008-2010 Waqas Hussain
4 --
5 -- This project is MIT/X11 licensed. Please see the
6 -- COPYING file in the source package for more information.
7 --
8
9 local type = type;
10 local t_insert, t_concat, t_remove, t_sort = table.insert, table.concat, table.remove, table.sort;
11 local s_char = string.char;
12 local tostring, tonumber = tostring, tonumber;
13 local pairs, ipairs = pairs, ipairs;
14 local next = next;
15 local getmetatable, setmetatable = getmetatable, setmetatable;
16 local print = print;
17
18 local has_array, array = pcall(require, "util.array");
19 local array_mt = has_array and getmetatable(array()) or {};
20
21 --module("json")
22 local module = {};
23
24 local null = setmetatable({}, { __tostring = function() return "null"; end; });
25 module.null = null;
26
27 local escapes = {
28         ["\""] = "\\\"", ["\\"] = "\\\\", ["\b"] = "\\b",
29         ["\f"] = "\\f", ["\n"] = "\\n", ["\r"] = "\\r", ["\t"] = "\\t"};
30 local unescapes = {
31         ["\""] = "\"", ["\\"] = "\\", ["/"] = "/",
32         b = "\b", f = "\f", n = "\n", r = "\r", t = "\t"};
33 for i=0,31 do
34         local ch = s_char(i);
35         if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end
36 end
37
38 local function codepoint_to_utf8(code)
39         if code < 0x80 then return s_char(code); end
40         local bits0_6 = code % 64;
41         if code < 0x800 then
42                 local bits6_5 = (code - bits0_6) / 64;
43                 return s_char(0x80 + 0x40 + bits6_5, 0x80 + bits0_6);
44         end
45         local bits0_12 = code % 4096;
46         local bits6_6 = (bits0_12 - bits0_6) / 64;
47         local bits12_4 = (code - bits0_12) / 4096;
48         return s_char(0x80 + 0x40 + 0x20 + bits12_4, 0x80 + bits6_6, 0x80 + bits0_6);
49 end
50
51 local valid_types = {
52         number  = true,
53         string  = true,
54         table   = true,
55         boolean = true
56 };
57 local special_keys = {
58         __array = true;
59         __hash  = true;
60 };
61
62 local simplesave, tablesave, arraysave, stringsave;
63
64 function stringsave(o, buffer)
65         -- FIXME do proper utf-8 and binary data detection
66         t_insert(buffer, "\""..(o:gsub(".", escapes)).."\"");
67 end
68
69 function arraysave(o, buffer)
70         t_insert(buffer, "[");
71         if next(o) then
72                 for _, v in ipairs(o) do
73                         simplesave(v, buffer);
74                         t_insert(buffer, ",");
75                 end
76                 t_remove(buffer);
77         end
78         t_insert(buffer, "]");
79 end
80
81 function tablesave(o, buffer)
82         local __array = {};
83         local __hash = {};
84         local hash = {};
85         for i,v in ipairs(o) do
86                 __array[i] = v;
87         end
88         for k,v in pairs(o) do
89                 local ktype, vtype = type(k), type(v);
90                 if valid_types[vtype] or v == null then
91                         if ktype == "string" and not special_keys[k] then
92                                 hash[k] = v;
93                         elseif (valid_types[ktype] or k == null) and __array[k] == nil then
94                                 __hash[k] = v;
95                         end
96                 end
97         end
98         if next(__hash) ~= nil or next(hash) ~= nil or next(__array) == nil then
99                 t_insert(buffer, "{");
100                 local mark = #buffer;
101                 if buffer.ordered then
102                         local keys = {};
103                         for k in pairs(hash) do
104                                 t_insert(keys, k);
105                         end
106                         t_sort(keys);
107                         for _,k in ipairs(keys) do
108                                 stringsave(k, buffer);
109                                 t_insert(buffer, ":");
110                                 simplesave(hash[k], buffer);
111                                 t_insert(buffer, ",");
112                         end
113                 else
114                         for k,v in pairs(hash) do
115                                 stringsave(k, buffer);
116                                 t_insert(buffer, ":");
117                                 simplesave(v, buffer);
118                                 t_insert(buffer, ",");
119                         end
120                 end
121                 if next(__hash) ~= nil then
122                         t_insert(buffer, "\"__hash\":[");
123                         for k,v in pairs(__hash) do
124                                 simplesave(k, buffer);
125                                 t_insert(buffer, ",");
126                                 simplesave(v, buffer);
127                                 t_insert(buffer, ",");
128                         end
129                         t_remove(buffer);
130                         t_insert(buffer, "]");
131                         t_insert(buffer, ",");
132                 end
133                 if next(__array) then
134                         t_insert(buffer, "\"__array\":");
135                         arraysave(__array, buffer);
136                         t_insert(buffer, ",");
137                 end
138                 if mark ~= #buffer then t_remove(buffer); end
139                 t_insert(buffer, "}");
140         else
141                 arraysave(__array, buffer);
142         end
143 end
144
145 function simplesave(o, buffer)
146         local t = type(o);
147         if o == null then
148                 t_insert(buffer, "null");
149         elseif t == "number" then
150                 t_insert(buffer, tostring(o));
151         elseif t == "string" then
152                 stringsave(o, buffer);
153         elseif t == "table" then
154                 local mt = getmetatable(o);
155                 if mt == array_mt then
156                         arraysave(o, buffer);
157                 else
158                         tablesave(o, buffer);
159                 end
160         elseif t == "boolean" then
161                 t_insert(buffer, (o and "true" or "false"));
162         else
163                 t_insert(buffer, "null");
164         end
165 end
166
167 function module.encode(obj)
168         local t = {};
169         simplesave(obj, t);
170         return t_concat(t);
171 end
172 function module.encode_ordered(obj)
173         local t = { ordered = true };
174         simplesave(obj, t);
175         return t_concat(t);
176 end
177 function module.encode_array(obj)
178         local t = {};
179         arraysave(obj, t);
180         return t_concat(t);
181 end
182
183 -----------------------------------
184
185
186 local function _skip_whitespace(json, index)
187         return json:find("[^ \t\r\n]", index) or index; -- no need to check \r\n, we converted those to \t
188 end
189 local function _fixobject(obj)
190         local __array = obj.__array;
191         if __array then
192                 obj.__array = nil;
193                 for _, v in ipairs(__array) do
194                         t_insert(obj, v);
195                 end
196         end
197         local __hash = obj.__hash;
198         if __hash then
199                 obj.__hash = nil;
200                 local k;
201                 for _, v in ipairs(__hash) do
202                         if k ~= nil then
203                                 obj[k] = v; k = nil;
204                         else
205                                 k = v;
206                         end
207                 end
208         end
209         return obj;
210 end
211 local _readvalue, _readstring;
212 local function _readobject(json, index)
213         local o = {};
214         while true do
215                 local key, val;
216                 index = _skip_whitespace(json, index + 1);
217                 if json:byte(index) ~= 0x22 then -- "\""
218                         if json:byte(index) == 0x7d then return o, index + 1; end -- "}"
219                         return nil, "key expected";
220                 end
221                 key, index = _readstring(json, index);
222                 if key == nil then return nil, index; end
223                 index = _skip_whitespace(json, index);
224                 if json:byte(index) ~= 0x3a then return nil, "colon expected"; end -- ":"
225                 val, index = _readvalue(json, index + 1);
226                 if val == nil then return nil, index; end
227                 o[key] = val;
228                 index = _skip_whitespace(json, index);
229                 local b = json:byte(index);
230                 if b == 0x7d then return _fixobject(o), index + 1; end -- "}"
231                 if b ~= 0x2c then return nil, "object eof"; end -- ","
232         end
233 end
234 local function _readarray(json, index)
235         local a = {};
236         local oindex = index;
237         while true do
238                 local val;
239                 val, index = _readvalue(json, index + 1);
240                 if val == nil then
241                         if json:byte(oindex + 1) == 0x5d then return setmetatable(a, array_mt), oindex + 2; end -- "]"
242                         return val, index;
243                 end
244                 t_insert(a, val);
245                 index = _skip_whitespace(json, index);
246                 local b = json:byte(index);
247                 if b == 0x5d then return setmetatable(a, array_mt), index + 1; end -- "]"
248                 if b ~= 0x2c then return nil, "array eof"; end -- ","
249         end
250 end
251 local _unescape_error;
252 local function _unescape_surrogate_func(x)
253         local lead, trail = tonumber(x:sub(3, 6), 16), tonumber(x:sub(9, 12), 16);
254         local codepoint = lead * 0x400 + trail - 0x35FDC00;
255         local a = codepoint % 64;
256         codepoint = (codepoint - a) / 64;
257         local b = codepoint % 64;
258         codepoint = (codepoint - b) / 64;
259         local c = codepoint % 64;
260         codepoint = (codepoint - c) / 64;
261         return s_char(0xF0 + codepoint, 0x80 + c, 0x80 + b, 0x80 + a);
262 end
263 local function _unescape_func(x)
264         x = x:match("%x%x%x%x", 3);
265         if x then
266                 --if x >= 0xD800 and x <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair
267                 return codepoint_to_utf8(tonumber(x, 16));
268         end
269         _unescape_error = true;
270 end
271 function _readstring(json, index)
272         index = index + 1;
273         local endindex = json:find("\"", index, true);
274         if endindex then
275                 local s = json:sub(index, endindex - 1);
276                 --if s:find("[%z-\31]") then return nil, "control char in string"; end
277                 -- FIXME handle control characters
278                 _unescape_error = nil;
279                 --s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func);
280                 -- FIXME handle escapes beyond BMP
281                 s = s:gsub("\\u.?.?.?.?", _unescape_func);
282                 if _unescape_error then return nil, "invalid escape"; end
283                 return s, endindex + 1;
284         end
285         return nil, "string eof";
286 end
287 local function _readnumber(json, index)
288         local m = json:match("[0-9%.%-eE%+]+", index); -- FIXME do strict checking
289         return tonumber(m), index + #m;
290 end
291 local function _readnull(json, index)
292         local a, b, c = json:byte(index + 1, index + 3);
293         if a == 0x75 and b == 0x6c and c == 0x6c then
294                 return null, index + 4;
295         end
296         return nil, "null parse failed";
297 end
298 local function _readtrue(json, index)
299         local a, b, c = json:byte(index + 1, index + 3);
300         if a == 0x72 and b == 0x75 and c == 0x65 then
301                 return true, index + 4;
302         end
303         return nil, "true parse failed";
304 end
305 local function _readfalse(json, index)
306         local a, b, c, d = json:byte(index + 1, index + 4);
307         if a == 0x61 and b == 0x6c and c == 0x73 and d == 0x65 then
308                 return false, index + 5;
309         end
310         return nil, "false parse failed";
311 end
312 function _readvalue(json, index)
313         index = _skip_whitespace(json, index);
314         local b = json:byte(index);
315         -- TODO try table lookup instead of if-else?
316         if b == 0x7B then -- "{"
317                 return _readobject(json, index);
318         elseif b == 0x5B then -- "["
319                 return _readarray(json, index);
320         elseif b == 0x22 then -- "\""
321                 return _readstring(json, index);
322         elseif b ~= nil and b >= 0x30 and b <= 0x39 or b == 0x2d then -- "0"-"9" or "-"
323                 return _readnumber(json, index);
324         elseif b == 0x6e then -- "n"
325                 return _readnull(json, index);
326         elseif b == 0x74 then -- "t"
327                 return _readtrue(json, index);
328         elseif b == 0x66 then -- "f"
329                 return _readfalse(json, index);
330         else
331                 return nil, "value expected";
332         end
333 end
334 local first_escape = {
335         ["\\\""] = "\\u0022";
336         ["\\\\"] = "\\u005c";
337         ["\\/" ] = "\\u002f";
338         ["\\b" ] = "\\u0008";
339         ["\\f" ] = "\\u000C";
340         ["\\n" ] = "\\u000A";
341         ["\\r" ] = "\\u000D";
342         ["\\t" ] = "\\u0009";
343         ["\\u" ] = "\\u";
344 };
345
346 function module.decode(json)
347         json = json:gsub("\\.", first_escape) -- get rid of all escapes except \uXXXX, making string parsing much simpler
348                 --:gsub("[\r\n]", "\t"); -- \r\n\t are equivalent, we care about none of them, and none of them can be in strings
349
350         -- TODO do encoding verification
351
352         local val, index = _readvalue(json, 1);
353         if val == nil then return val, index; end
354         if json:find("[^ \t\r\n]", index) then return nil, "garbage at eof"; end
355
356         return val;
357 end
358
359 function module.test(object)
360         local encoded = module.encode(object);
361         local decoded = module.decode(encoded);
362         local recoded = module.encode(decoded);
363         if encoded ~= recoded then
364                 print("FAILED");
365                 print("encoded:", encoded);
366                 print("recoded:", recoded);
367         else
368                 print(encoded);
369         end
370         return encoded == recoded;
371 end
372
373 return module;