Merge 0.9->0.10
[prosody.git] / net / websocket / frames.lua
1 -- Prosody IM
2 -- Copyright (C) 2012 Florian Zeitz
3 -- Copyright (C) 2014 Daurnimator
4 --
5 -- This project is MIT/X11 licensed. Please see the
6 -- COPYING file in the source package for more information.
7 --
8
9 local softreq = require "util.dependencies".softreq;
10 local log = require "util.logger".init "websocket.frames";
11 local random_bytes = require "util.random".bytes;
12
13 local bit = assert(softreq"bit" or softreq"bit32",
14         "No bit module found. See https://prosody.im/doc/depends#bitop");
15 local band = bit.band;
16 local bor = bit.bor;
17 local bxor = bit.bxor;
18 local lshift = bit.lshift;
19 local rshift = bit.rshift;
20
21 local t_concat = table.concat;
22 local s_byte = string.byte;
23 local s_char= string.char;
24 local s_sub = string.sub;
25 local s_pack = string.pack;
26 local s_unpack = string.unpack;
27
28 if not s_pack and softreq"struct" then
29         s_pack = softreq"struct".pack;
30         s_unpack = softreq"struct".unpack;
31 end
32
33 local function read_uint16be(str, pos)
34         local l1, l2 = s_byte(str, pos, pos+1);
35         return l1*256 + l2;
36 end
37 -- FIXME: this may lose precision
38 local function read_uint64be(str, pos)
39         local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7);
40         local h = lshift(l1, 24) + lshift(l2, 16) + lshift(l3, 8) + l4;
41         local l = lshift(l5, 24) + lshift(l6, 16) + lshift(l7, 8) + l8;
42         return h * 2^32 + l;
43 end
44 local function pack_uint16be(x)
45         return s_char(rshift(x, 8), band(x, 0xFF));
46 end
47 local function get_byte(x, n)
48         return band(rshift(x, n), 0xFF);
49 end
50 local function pack_uint64be(x)
51         local h = band(x / 2^32, 2^32-1);
52         return s_char(get_byte(h, 24), get_byte(h, 16), get_byte(h, 8), band(h, 0xFF),
53                 get_byte(x, 24), get_byte(x, 16), get_byte(x, 8), band(x, 0xFF));
54 end
55
56 if s_pack then
57         function pack_uint16be(x)
58                 return s_pack(">I2", x);
59         end
60         function pack_uint64be(x)
61                 return s_pack(">I8", x);
62         end
63 end
64
65 if s_unpack then
66         function read_uint16be(str, pos)
67                 return s_unpack(">I2", str, pos);
68         end
69         function read_uint64be(str, pos)
70                 return s_unpack(">I8", str, pos);
71         end
72 end
73
74 local function parse_frame_header(frame)
75         if #frame < 2 then return; end
76
77         local byte1, byte2 = s_byte(frame, 1, 2);
78         local result = {
79                 FIN = band(byte1, 0x80) > 0;
80                 RSV1 = band(byte1, 0x40) > 0;
81                 RSV2 = band(byte1, 0x20) > 0;
82                 RSV3 = band(byte1, 0x10) > 0;
83                 opcode = band(byte1, 0x0F);
84
85                 MASK = band(byte2, 0x80) > 0;
86                 length = band(byte2, 0x7F);
87         };
88
89         local length_bytes = 0;
90         if result.length == 126 then
91                 length_bytes = 2;
92         elseif result.length == 127 then
93                 length_bytes = 8;
94         end
95
96         local header_length = 2 + length_bytes + (result.MASK and 4 or 0);
97         if #frame < header_length then return; end
98
99         if length_bytes == 2 then
100                 result.length = read_uint16be(frame, 3);
101         elseif length_bytes == 8 then
102                 result.length = read_uint64be(frame, 3);
103         end
104
105         if result.MASK then
106                 result.key = { s_byte(frame, length_bytes+3, length_bytes+6) };
107         end
108
109         return result, header_length;
110 end
111
112 -- XORs the string `str` with the array of bytes `key`
113 -- TODO: optimize
114 local function apply_mask(str, key, from, to)
115         from = from or 1
116         if from < 0 then from = #str + from + 1 end -- negative indicies
117         to = to or #str
118         if to < 0 then to = #str + to + 1 end -- negative indicies
119         local key_len = #key
120         local counter = 0;
121         local data = {};
122         for i = from, to do
123                 local key_index = counter%key_len + 1;
124                 counter = counter + 1;
125                 data[counter] = s_char(bxor(key[key_index], s_byte(str, i)));
126         end
127         return t_concat(data);
128 end
129
130 local function parse_frame_body(frame, header, pos)
131         if header.MASK then
132                 return apply_mask(frame, header.key, pos, pos + header.length - 1);
133         else
134                 return frame:sub(pos, pos + header.length - 1);
135         end
136 end
137
138 local function parse_frame(frame)
139         local result, pos = parse_frame_header(frame);
140         if result == nil or #frame < (pos + result.length) then return; end
141         result.data = parse_frame_body(frame, result, pos+1);
142         return result, pos + result.length;
143 end
144
145 local function build_frame(desc)
146         local data = desc.data or "";
147
148         assert(desc.opcode and desc.opcode >= 0 and desc.opcode <= 0xF, "Invalid WebSocket opcode");
149         if desc.opcode >= 0x8 then
150                 -- RFC 6455 5.5
151                 assert(#data <= 125, "WebSocket control frames MUST have a payload length of 125 bytes or less.");
152         end
153
154         local b1 = bor(desc.opcode,
155                 desc.FIN and 0x80 or 0,
156                 desc.RSV1 and 0x40 or 0,
157                 desc.RSV2 and 0x20 or 0,
158                 desc.RSV3 and 0x10 or 0);
159
160         local b2 = #data;
161         local length_extra;
162         if b2 <= 125 then -- 7-bit length
163                 length_extra = "";
164         elseif b2 <= 0xFFFF then -- 2-byte length
165                 b2 = 126;
166                 length_extra = pack_uint16be(#data);
167         else -- 8-byte length
168                 b2 = 127;
169                 length_extra = pack_uint64be(#data);
170         end
171
172         local key = ""
173         if desc.MASK then
174                 local key_a = desc.key
175                 if key_a then
176                         key = s_char(unpack(key_a, 1, 4));
177                 else
178                         key = random_bytes(4);
179                         key_a = {key:byte(1,4)};
180                 end
181                 b2 = bor(b2, 0x80);
182                 data = apply_mask(data, key_a);
183         end
184
185         return s_char(b1, b2) .. length_extra .. key .. data
186 end
187
188 local function parse_close(data)
189         local code, message
190         if #data >= 2 then
191                 code = read_uint16be(data, 1);
192                 if #data > 2 then
193                         message = s_sub(data, 3);
194                 end
195         end
196         return code, message
197 end
198
199 local function build_close(code, message, mask)
200         local data = pack_uint16be(code);
201         if message then
202                 assert(#message<=123, "Close reason must be <=123 bytes");
203                 data = data .. message;
204         end
205         return build_frame({
206                 opcode = 0x8;
207                 FIN = true;
208                 MASK = mask;
209                 data = data;
210         });
211 end
212
213 return {
214         parse_header = parse_frame_header;
215         parse_body = parse_frame_body;
216         parse = parse_frame;
217         build = build_frame;
218         parse_close = parse_close;
219         build_close = build_close;
220 };