Merge 0.9->0.10 (third time lucky)
[prosody.git] / util / json.lua
1 -- Prosody IM
2 -- Copyright (C) 2008-2010 Matthew Wild
3 -- Copyright (C) 2008-2010 Waqas Hussain
4 --
5 -- This project is MIT/X11 licensed. Please see the
6 -- COPYING file in the source package for more information.
7 --
8
9 local type = type;
10 local t_insert, t_concat, t_remove, t_sort = table.insert, table.concat, table.remove, table.sort;
11 local s_char = string.char;
12 local tostring, tonumber = tostring, tonumber;
13 local pairs, ipairs = pairs, ipairs;
14 local next = next;
15 local error = error;
16 local newproxy, getmetatable, setmetatable = newproxy, getmetatable, setmetatable;
17 local print = print;
18
19 local has_array, array = pcall(require, "util.array");
20 local array_mt = has_array and getmetatable(array()) or {};
21
22 --module("json")
23 local json = {};
24
25 local null = newproxy and newproxy(true) or {};
26 if getmetatable and getmetatable(null) then
27         getmetatable(null).__tostring = function() return "null"; end;
28 end
29 json.null = null;
30
31 local escapes = {
32         ["\""] = "\\\"", ["\\"] = "\\\\", ["\b"] = "\\b",
33         ["\f"] = "\\f", ["\n"] = "\\n", ["\r"] = "\\r", ["\t"] = "\\t"};
34 local unescapes = {
35         ["\""] = "\"", ["\\"] = "\\", ["/"] = "/",
36         b = "\b", f = "\f", n = "\n", r = "\r", t = "\t"};
37 for i=0,31 do
38         local ch = s_char(i);
39         if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end
40 end
41
42 local function codepoint_to_utf8(code)
43         if code < 0x80 then return s_char(code); end
44         local bits0_6 = code % 64;
45         if code < 0x800 then
46                 local bits6_5 = (code - bits0_6) / 64;
47                 return s_char(0x80 + 0x40 + bits6_5, 0x80 + bits0_6);
48         end
49         local bits0_12 = code % 4096;
50         local bits6_6 = (bits0_12 - bits0_6) / 64;
51         local bits12_4 = (code - bits0_12) / 4096;
52         return s_char(0x80 + 0x40 + 0x20 + bits12_4, 0x80 + bits6_6, 0x80 + bits0_6);
53 end
54
55 local valid_types = {
56         number  = true,
57         string  = true,
58         table   = true,
59         boolean = true
60 };
61 local special_keys = {
62         __array = true;
63         __hash  = true;
64 };
65
66 local simplesave, tablesave, arraysave, stringsave;
67
68 function stringsave(o, buffer)
69         -- FIXME do proper utf-8 and binary data detection
70         t_insert(buffer, "\""..(o:gsub(".", escapes)).."\"");
71 end
72
73 function arraysave(o, buffer)
74         t_insert(buffer, "[");
75         if next(o) then
76                 for i,v in ipairs(o) do
77                         simplesave(v, buffer);
78                         t_insert(buffer, ",");
79                 end
80                 t_remove(buffer);
81         end
82         t_insert(buffer, "]");
83 end
84
85 function tablesave(o, buffer)
86         local __array = {};
87         local __hash = {};
88         local hash = {};
89         for i,v in ipairs(o) do
90                 __array[i] = v;
91         end
92         for k,v in pairs(o) do
93                 local ktype, vtype = type(k), type(v);
94                 if valid_types[vtype] or v == null then
95                         if ktype == "string" and not special_keys[k] then
96                                 hash[k] = v;
97                         elseif (valid_types[ktype] or k == null) and __array[k] == nil then
98                                 __hash[k] = v;
99                         end
100                 end
101         end
102         if next(__hash) ~= nil or next(hash) ~= nil or next(__array) == nil then
103                 t_insert(buffer, "{");
104                 local mark = #buffer;
105                 if buffer.ordered then
106                         local keys = {};
107                         for k in pairs(hash) do
108                                 t_insert(keys, k);
109                         end
110                         t_sort(keys);
111                         for _,k in ipairs(keys) do
112                                 stringsave(k, buffer);
113                                 t_insert(buffer, ":");
114                                 simplesave(hash[k], buffer);
115                                 t_insert(buffer, ",");
116                         end
117                 else
118                         for k,v in pairs(hash) do
119                                 stringsave(k, buffer);
120                                 t_insert(buffer, ":");
121                                 simplesave(v, buffer);
122                                 t_insert(buffer, ",");
123                         end
124                 end
125                 if next(__hash) ~= nil then
126                         t_insert(buffer, "\"__hash\":[");
127                         for k,v in pairs(__hash) do
128                                 simplesave(k, buffer);
129                                 t_insert(buffer, ",");
130                                 simplesave(v, buffer);
131                                 t_insert(buffer, ",");
132                         end
133                         t_remove(buffer);
134                         t_insert(buffer, "]");
135                         t_insert(buffer, ",");
136                 end
137                 if next(__array) then
138                         t_insert(buffer, "\"__array\":");
139                         arraysave(__array, buffer);
140                         t_insert(buffer, ",");
141                 end
142                 if mark ~= #buffer then t_remove(buffer); end
143                 t_insert(buffer, "}");
144         else
145                 arraysave(__array, buffer);
146         end
147 end
148
149 function simplesave(o, buffer)
150         local t = type(o);
151         if t == "number" then
152                 t_insert(buffer, tostring(o));
153         elseif t == "string" then
154                 stringsave(o, buffer);
155         elseif t == "table" then
156                 local mt = getmetatable(o);
157                 if mt == array_mt then
158                         arraysave(o, buffer);
159                 else
160                         tablesave(o, buffer);
161                 end
162         elseif t == "boolean" then
163                 t_insert(buffer, (o and "true" or "false"));
164         else
165                 t_insert(buffer, "null");
166         end
167 end
168
169 function json.encode(obj)
170         local t = {};
171         simplesave(obj, t);
172         return t_concat(t);
173 end
174 function json.encode_ordered(obj)
175         local t = { ordered = true };
176         simplesave(obj, t);
177         return t_concat(t);
178 end
179 function json.encode_array(obj)
180         local t = {};
181         arraysave(obj, t);
182         return t_concat(t);
183 end
184
185 -----------------------------------
186
187
188 local function _skip_whitespace(json, index)
189         return json:find("[^ \t\r\n]", index) or index; -- no need to check \r\n, we converted those to \t
190 end
191 local function _fixobject(obj)
192         local __array = obj.__array;
193         if __array then
194                 obj.__array = nil;
195                 for i,v in ipairs(__array) do
196                         t_insert(obj, v);
197                 end
198         end
199         local __hash = obj.__hash;
200         if __hash then
201                 obj.__hash = nil;
202                 local k;
203                 for i,v in ipairs(__hash) do
204                         if k ~= nil then
205                                 obj[k] = v; k = nil;
206                         else
207                                 k = v;
208                         end
209                 end
210         end
211         return obj;
212 end
213 local _readvalue, _readstring;
214 local function _readobject(json, index)
215         local o = {};
216         while true do
217                 local key, val;
218                 index = _skip_whitespace(json, index + 1);
219                 if json:byte(index) ~= 0x22 then -- "\""
220                         if json:byte(index) == 0x7d then return o, index + 1; end -- "}"
221                         return nil, "key expected";
222                 end
223                 key, index = _readstring(json, index);
224                 if key == nil then return nil, index; end
225                 index = _skip_whitespace(json, index);
226                 if json:byte(index) ~= 0x3a then return nil, "colon expected"; end -- ":"
227                 val, index = _readvalue(json, index + 1);
228                 if val == nil then return nil, index; end
229                 o[key] = val;
230                 index = _skip_whitespace(json, index);
231                 local b = json:byte(index);
232                 if b == 0x7d then return _fixobject(o), index + 1; end -- "}"
233                 if b ~= 0x2c then return nil, "object eof"; end -- ","
234         end
235 end
236 local function _readarray(json, index)
237         local a = {};
238         local oindex = index;
239         while true do
240                 local val;
241                 val, index = _readvalue(json, index + 1);
242                 if val == nil then
243                         if json:byte(oindex + 1) == 0x5d then return setmetatable(a, array_mt), oindex + 2; end -- "]"
244                         return val, index;
245                 end
246                 t_insert(a, val);
247                 index = _skip_whitespace(json, index);
248                 local b = json:byte(index);
249                 if b == 0x5d then return setmetatable(a, array_mt), index + 1; end -- "]"
250                 if b ~= 0x2c then return nil, "array eof"; end -- ","
251         end
252 end
253 local _unescape_error;
254 local function _unescape_surrogate_func(x)
255         local lead, trail = tonumber(x:sub(3, 6), 16), tonumber(x:sub(9, 12), 16);
256         local codepoint = lead * 0x400 + trail - 0x35FDC00;
257         local a = codepoint % 64;
258         codepoint = (codepoint - a) / 64;
259         local b = codepoint % 64;
260         codepoint = (codepoint - b) / 64;
261         local c = codepoint % 64;
262         codepoint = (codepoint - c) / 64;
263         return s_char(0xF0 + codepoint, 0x80 + c, 0x80 + b, 0x80 + a);
264 end
265 local function _unescape_func(x)
266         x = x:match("%x%x%x%x", 3);
267         if x then
268                 --if x >= 0xD800 and x <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair
269                 return codepoint_to_utf8(tonumber(x, 16));
270         end
271         _unescape_error = true;
272 end
273 function _readstring(json, index)
274         index = index + 1;
275         local endindex = json:find("\"", index, true);
276         if endindex then
277                 local s = json:sub(index, endindex - 1);
278                 --if s:find("[%z-\31]") then return nil, "control char in string"; end
279                 -- FIXME handle control characters
280                 _unescape_error = nil;
281                 --s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func);
282                 -- FIXME handle escapes beyond BMP
283                 s = s:gsub("\\u.?.?.?.?", _unescape_func);
284                 if _unescape_error then return nil, "invalid escape"; end
285                 return s, endindex + 1;
286         end
287         return nil, "string eof";
288 end
289 local function _readnumber(json, index)
290         local m = json:match("[0-9%.%-eE%+]+", index); -- FIXME do strict checking
291         return tonumber(m), index + #m;
292 end
293 local function _readnull(json, index)
294         local a, b, c = json:byte(index + 1, index + 3);
295         if a == 0x75 and b == 0x6c and c == 0x6c then
296                 return null, index + 4;
297         end
298         return nil, "null parse failed";
299 end
300 local function _readtrue(json, index)
301         local a, b, c = json:byte(index + 1, index + 3);
302         if a == 0x72 and b == 0x75 and c == 0x65 then
303                 return true, index + 4;
304         end
305         return nil, "true parse failed";
306 end
307 local function _readfalse(json, index)
308         local a, b, c, d = json:byte(index + 1, index + 4);
309         if a == 0x61 and b == 0x6c and c == 0x73 and d == 0x65 then
310                 return false, index + 5;
311         end
312         return nil, "false parse failed";
313 end
314 function _readvalue(json, index)
315         index = _skip_whitespace(json, index);
316         local b = json:byte(index);
317         -- TODO try table lookup instead of if-else?
318         if b == 0x7B then -- "{"
319                 return _readobject(json, index);
320         elseif b == 0x5B then -- "["
321                 return _readarray(json, index);
322         elseif b == 0x22 then -- "\""
323                 return _readstring(json, index);
324         elseif b ~= nil and b >= 0x30 and b <= 0x39 or b == 0x2d then -- "0"-"9" or "-"
325                 return _readnumber(json, index);
326         elseif b == 0x6e then -- "n"
327                 return _readnull(json, index);
328         elseif b == 0x74 then -- "t"
329                 return _readtrue(json, index);
330         elseif b == 0x66 then -- "f"
331                 return _readfalse(json, index);
332         else
333                 return nil, "value expected";
334         end
335 end
336 local first_escape = {
337         ["\\\""] = "\\u0022";
338         ["\\\\"] = "\\u005c";
339         ["\\/" ] = "\\u002f";
340         ["\\b" ] = "\\u0008";
341         ["\\f" ] = "\\u000C";
342         ["\\n" ] = "\\u000A";
343         ["\\r" ] = "\\u000D";
344         ["\\t" ] = "\\u0009";
345         ["\\u" ] = "\\u";
346 };
347
348 function json.decode(json)
349         json = json:gsub("\\.", first_escape) -- get rid of all escapes except \uXXXX, making string parsing much simpler
350                 --:gsub("[\r\n]", "\t"); -- \r\n\t are equivalent, we care about none of them, and none of them can be in strings
351
352         -- TODO do encoding verification
353
354         local val, index = _readvalue(json, 1);
355         if val == nil then return val, index; end
356         if json:find("[^ \t\r\n]", index) then return nil, "garbage at eof"; end
357
358         return val;
359 end
360
361 function json.test(object)
362         local encoded = json.encode(object);
363         local decoded = json.decode(encoded);
364         local recoded = json.encode(decoded);
365         if encoded ~= recoded then
366                 print("FAILED");
367                 print("encoded:", encoded);
368                 print("recoded:", recoded);
369         else
370                 print(encoded);
371         end
372         return encoded == recoded;
373 end
374
375 return json;