Merge 0.10->trunk
[prosody.git] / util / json.lua
1 -- Prosody IM
2 -- Copyright (C) 2008-2010 Matthew Wild
3 -- Copyright (C) 2008-2010 Waqas Hussain
4 --
5 -- This project is MIT/X11 licensed. Please see the
6 -- COPYING file in the source package for more information.
7 --
8
9 local type = type;
10 local t_insert, t_concat, t_remove, t_sort = table.insert, table.concat, table.remove, table.sort;
11 local s_char = string.char;
12 local tostring, tonumber = tostring, tonumber;
13 local pairs, ipairs = pairs, ipairs;
14 local next = next;
15 local error = error;
16 local getmetatable, setmetatable = getmetatable, setmetatable;
17 local print = print;
18
19 local has_array, array = pcall(require, "util.array");
20 local array_mt = has_array and getmetatable(array()) or {};
21
22 --module("json")
23 local json = {};
24
25 local null = setmetatable({}, { __tostring = function() return "null"; end; });
26 json.null = null;
27
28 local escapes = {
29         ["\""] = "\\\"", ["\\"] = "\\\\", ["\b"] = "\\b",
30         ["\f"] = "\\f", ["\n"] = "\\n", ["\r"] = "\\r", ["\t"] = "\\t"};
31 local unescapes = {
32         ["\""] = "\"", ["\\"] = "\\", ["/"] = "/",
33         b = "\b", f = "\f", n = "\n", r = "\r", t = "\t"};
34 for i=0,31 do
35         local ch = s_char(i);
36         if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end
37 end
38
39 local function codepoint_to_utf8(code)
40         if code < 0x80 then return s_char(code); end
41         local bits0_6 = code % 64;
42         if code < 0x800 then
43                 local bits6_5 = (code - bits0_6) / 64;
44                 return s_char(0x80 + 0x40 + bits6_5, 0x80 + bits0_6);
45         end
46         local bits0_12 = code % 4096;
47         local bits6_6 = (bits0_12 - bits0_6) / 64;
48         local bits12_4 = (code - bits0_12) / 4096;
49         return s_char(0x80 + 0x40 + 0x20 + bits12_4, 0x80 + bits6_6, 0x80 + bits0_6);
50 end
51
52 local valid_types = {
53         number  = true,
54         string  = true,
55         table   = true,
56         boolean = true
57 };
58 local special_keys = {
59         __array = true;
60         __hash  = true;
61 };
62
63 local simplesave, tablesave, arraysave, stringsave;
64
65 function stringsave(o, buffer)
66         -- FIXME do proper utf-8 and binary data detection
67         t_insert(buffer, "\""..(o:gsub(".", escapes)).."\"");
68 end
69
70 function arraysave(o, buffer)
71         t_insert(buffer, "[");
72         if next(o) then
73                 for i,v in ipairs(o) do
74                         simplesave(v, buffer);
75                         t_insert(buffer, ",");
76                 end
77                 t_remove(buffer);
78         end
79         t_insert(buffer, "]");
80 end
81
82 function tablesave(o, buffer)
83         local __array = {};
84         local __hash = {};
85         local hash = {};
86         for i,v in ipairs(o) do
87                 __array[i] = v;
88         end
89         for k,v in pairs(o) do
90                 local ktype, vtype = type(k), type(v);
91                 if valid_types[vtype] or v == null then
92                         if ktype == "string" and not special_keys[k] then
93                                 hash[k] = v;
94                         elseif (valid_types[ktype] or k == null) and __array[k] == nil then
95                                 __hash[k] = v;
96                         end
97                 end
98         end
99         if next(__hash) ~= nil or next(hash) ~= nil or next(__array) == nil then
100                 t_insert(buffer, "{");
101                 local mark = #buffer;
102                 if buffer.ordered then
103                         local keys = {};
104                         for k in pairs(hash) do
105                                 t_insert(keys, k);
106                         end
107                         t_sort(keys);
108                         for _,k in ipairs(keys) do
109                                 stringsave(k, buffer);
110                                 t_insert(buffer, ":");
111                                 simplesave(hash[k], buffer);
112                                 t_insert(buffer, ",");
113                         end
114                 else
115                         for k,v in pairs(hash) do
116                                 stringsave(k, buffer);
117                                 t_insert(buffer, ":");
118                                 simplesave(v, buffer);
119                                 t_insert(buffer, ",");
120                         end
121                 end
122                 if next(__hash) ~= nil then
123                         t_insert(buffer, "\"__hash\":[");
124                         for k,v in pairs(__hash) do
125                                 simplesave(k, buffer);
126                                 t_insert(buffer, ",");
127                                 simplesave(v, buffer);
128                                 t_insert(buffer, ",");
129                         end
130                         t_remove(buffer);
131                         t_insert(buffer, "]");
132                         t_insert(buffer, ",");
133                 end
134                 if next(__array) then
135                         t_insert(buffer, "\"__array\":");
136                         arraysave(__array, buffer);
137                         t_insert(buffer, ",");
138                 end
139                 if mark ~= #buffer then t_remove(buffer); end
140                 t_insert(buffer, "}");
141         else
142                 arraysave(__array, buffer);
143         end
144 end
145
146 function simplesave(o, buffer)
147         local t = type(o);
148         if o == null then
149                 t_insert(buffer, "null");
150         elseif t == "number" then
151                 t_insert(buffer, tostring(o));
152         elseif t == "string" then
153                 stringsave(o, buffer);
154         elseif t == "table" then
155                 local mt = getmetatable(o);
156                 if mt == array_mt then
157                         arraysave(o, buffer);
158                 else
159                         tablesave(o, buffer);
160                 end
161         elseif t == "boolean" then
162                 t_insert(buffer, (o and "true" or "false"));
163         else
164                 t_insert(buffer, "null");
165         end
166 end
167
168 function json.encode(obj)
169         local t = {};
170         simplesave(obj, t);
171         return t_concat(t);
172 end
173 function json.encode_ordered(obj)
174         local t = { ordered = true };
175         simplesave(obj, t);
176         return t_concat(t);
177 end
178 function json.encode_array(obj)
179         local t = {};
180         arraysave(obj, t);
181         return t_concat(t);
182 end
183
184 -----------------------------------
185
186
187 local function _skip_whitespace(json, index)
188         return json:find("[^ \t\r\n]", index) or index; -- no need to check \r\n, we converted those to \t
189 end
190 local function _fixobject(obj)
191         local __array = obj.__array;
192         if __array then
193                 obj.__array = nil;
194                 for i,v in ipairs(__array) do
195                         t_insert(obj, v);
196                 end
197         end
198         local __hash = obj.__hash;
199         if __hash then
200                 obj.__hash = nil;
201                 local k;
202                 for i,v in ipairs(__hash) do
203                         if k ~= nil then
204                                 obj[k] = v; k = nil;
205                         else
206                                 k = v;
207                         end
208                 end
209         end
210         return obj;
211 end
212 local _readvalue, _readstring;
213 local function _readobject(json, index)
214         local o = {};
215         while true do
216                 local key, val;
217                 index = _skip_whitespace(json, index + 1);
218                 if json:byte(index) ~= 0x22 then -- "\""
219                         if json:byte(index) == 0x7d then return o, index + 1; end -- "}"
220                         return nil, "key expected";
221                 end
222                 key, index = _readstring(json, index);
223                 if key == nil then return nil, index; end
224                 index = _skip_whitespace(json, index);
225                 if json:byte(index) ~= 0x3a then return nil, "colon expected"; end -- ":"
226                 val, index = _readvalue(json, index + 1);
227                 if val == nil then return nil, index; end
228                 o[key] = val;
229                 index = _skip_whitespace(json, index);
230                 local b = json:byte(index);
231                 if b == 0x7d then return _fixobject(o), index + 1; end -- "}"
232                 if b ~= 0x2c then return nil, "object eof"; end -- ","
233         end
234 end
235 local function _readarray(json, index)
236         local a = {};
237         local oindex = index;
238         while true do
239                 local val;
240                 val, index = _readvalue(json, index + 1);
241                 if val == nil then
242                         if json:byte(oindex + 1) == 0x5d then return setmetatable(a, array_mt), oindex + 2; end -- "]"
243                         return val, index;
244                 end
245                 t_insert(a, val);
246                 index = _skip_whitespace(json, index);
247                 local b = json:byte(index);
248                 if b == 0x5d then return setmetatable(a, array_mt), index + 1; end -- "]"
249                 if b ~= 0x2c then return nil, "array eof"; end -- ","
250         end
251 end
252 local _unescape_error;
253 local function _unescape_surrogate_func(x)
254         local lead, trail = tonumber(x:sub(3, 6), 16), tonumber(x:sub(9, 12), 16);
255         local codepoint = lead * 0x400 + trail - 0x35FDC00;
256         local a = codepoint % 64;
257         codepoint = (codepoint - a) / 64;
258         local b = codepoint % 64;
259         codepoint = (codepoint - b) / 64;
260         local c = codepoint % 64;
261         codepoint = (codepoint - c) / 64;
262         return s_char(0xF0 + codepoint, 0x80 + c, 0x80 + b, 0x80 + a);
263 end
264 local function _unescape_func(x)
265         x = x:match("%x%x%x%x", 3);
266         if x then
267                 --if x >= 0xD800 and x <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair
268                 return codepoint_to_utf8(tonumber(x, 16));
269         end
270         _unescape_error = true;
271 end
272 function _readstring(json, index)
273         index = index + 1;
274         local endindex = json:find("\"", index, true);
275         if endindex then
276                 local s = json:sub(index, endindex - 1);
277                 --if s:find("[%z-\31]") then return nil, "control char in string"; end
278                 -- FIXME handle control characters
279                 _unescape_error = nil;
280                 --s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func);
281                 -- FIXME handle escapes beyond BMP
282                 s = s:gsub("\\u.?.?.?.?", _unescape_func);
283                 if _unescape_error then return nil, "invalid escape"; end
284                 return s, endindex + 1;
285         end
286         return nil, "string eof";
287 end
288 local function _readnumber(json, index)
289         local m = json:match("[0-9%.%-eE%+]+", index); -- FIXME do strict checking
290         return tonumber(m), index + #m;
291 end
292 local function _readnull(json, index)
293         local a, b, c = json:byte(index + 1, index + 3);
294         if a == 0x75 and b == 0x6c and c == 0x6c then
295                 return null, index + 4;
296         end
297         return nil, "null parse failed";
298 end
299 local function _readtrue(json, index)
300         local a, b, c = json:byte(index + 1, index + 3);
301         if a == 0x72 and b == 0x75 and c == 0x65 then
302                 return true, index + 4;
303         end
304         return nil, "true parse failed";
305 end
306 local function _readfalse(json, index)
307         local a, b, c, d = json:byte(index + 1, index + 4);
308         if a == 0x61 and b == 0x6c and c == 0x73 and d == 0x65 then
309                 return false, index + 5;
310         end
311         return nil, "false parse failed";
312 end
313 function _readvalue(json, index)
314         index = _skip_whitespace(json, index);
315         local b = json:byte(index);
316         -- TODO try table lookup instead of if-else?
317         if b == 0x7B then -- "{"
318                 return _readobject(json, index);
319         elseif b == 0x5B then -- "["
320                 return _readarray(json, index);
321         elseif b == 0x22 then -- "\""
322                 return _readstring(json, index);
323         elseif b ~= nil and b >= 0x30 and b <= 0x39 or b == 0x2d then -- "0"-"9" or "-"
324                 return _readnumber(json, index);
325         elseif b == 0x6e then -- "n"
326                 return _readnull(json, index);
327         elseif b == 0x74 then -- "t"
328                 return _readtrue(json, index);
329         elseif b == 0x66 then -- "f"
330                 return _readfalse(json, index);
331         else
332                 return nil, "value expected";
333         end
334 end
335 local first_escape = {
336         ["\\\""] = "\\u0022";
337         ["\\\\"] = "\\u005c";
338         ["\\/" ] = "\\u002f";
339         ["\\b" ] = "\\u0008";
340         ["\\f" ] = "\\u000C";
341         ["\\n" ] = "\\u000A";
342         ["\\r" ] = "\\u000D";
343         ["\\t" ] = "\\u0009";
344         ["\\u" ] = "\\u";
345 };
346
347 function json.decode(json)
348         json = json:gsub("\\.", first_escape) -- get rid of all escapes except \uXXXX, making string parsing much simpler
349                 --:gsub("[\r\n]", "\t"); -- \r\n\t are equivalent, we care about none of them, and none of them can be in strings
350
351         -- TODO do encoding verification
352
353         local val, index = _readvalue(json, 1);
354         if val == nil then return val, index; end
355         if json:find("[^ \t\r\n]", index) then return nil, "garbage at eof"; end
356
357         return val;
358 end
359
360 function json.test(object)
361         local encoded = json.encode(object);
362         local decoded = json.decode(encoded);
363         local recoded = json.encode(decoded);
364         if encoded ~= recoded then
365                 print("FAILED");
366                 print("encoded:", encoded);
367                 print("recoded:", recoded);
368         else
369                 print(encoded);
370         end
371         return encoded == recoded;
372 end
373
374 return json;