net.websocket.frames: Simplify import of bitlib
[prosody.git] / net / websocket / frames.lua
1 -- Prosody IM
2 -- Copyright (C) 2012 Florian Zeitz
3 -- Copyright (C) 2014 Daurnimator
4 --
5 -- This project is MIT/X11 licensed. Please see the
6 -- COPYING file in the source package for more information.
7 --
8
9 local softreq = require "util.dependencies".softreq;
10 local log = require "util.logger".init "websocket.frames";
11 local random_bytes = require "util.random".bytes;
12
13 local bit = softreq"bit" or softreq"bit32";
14 if not bit then log("error", "No bit module found. Either LuaJIT 2, lua-bitop or Lua 5.2 is required"); end
15 local band = bit.band;
16 local bor = bit.bor;
17 local bxor = bit.bxor;
18 local lshift = bit.lshift;
19 local rshift = bit.rshift;
20
21 local t_concat = table.concat;
22 local s_byte = string.byte;
23 local s_char= string.char;
24 local s_sub = string.sub;
25
26 local function read_uint16be(str, pos)
27         local l1, l2 = s_byte(str, pos, pos+1);
28         return l1*256 + l2;
29 end
30 -- FIXME: this may lose precision
31 local function read_uint64be(str, pos)
32         local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7);
33         return lshift(l1, 56) + lshift(l2, 48) + lshift(l3, 40) + lshift(l4, 32)
34                 + lshift(l5, 24) + lshift(l6, 16) + lshift(l7, 8) + l8;
35 end
36 local function pack_uint16be(x)
37         return s_char(rshift(x, 8), band(x, 0xFF));
38 end
39 local function get_byte(x, n)
40         return band(rshift(x, n), 0xFF);
41 end
42 local function pack_uint64be(x)
43         return s_char(rshift(x, 56), get_byte(x, 48), get_byte(x, 40), get_byte(x, 32),
44                 get_byte(x, 24), get_byte(x, 16), get_byte(x, 8), band(x, 0xFF));
45 end
46
47 local function parse_frame_header(frame)
48         if #frame < 2 then return; end
49
50         local byte1, byte2 = s_byte(frame, 1, 2);
51         local result = {
52                 FIN = band(byte1, 0x80) > 0;
53                 RSV1 = band(byte1, 0x40) > 0;
54                 RSV2 = band(byte1, 0x20) > 0;
55                 RSV3 = band(byte1, 0x10) > 0;
56                 opcode = band(byte1, 0x0F);
57
58                 MASK = band(byte2, 0x80) > 0;
59                 length = band(byte2, 0x7F);
60         };
61
62         local length_bytes = 0;
63         if result.length == 126 then
64                 length_bytes = 2;
65         elseif result.length == 127 then
66                 length_bytes = 8;
67         end
68
69         local header_length = 2 + length_bytes + (result.MASK and 4 or 0);
70         if #frame < header_length then return; end
71
72         if length_bytes == 2 then
73                 result.length = read_uint16be(frame, 3);
74         elseif length_bytes == 8 then
75                 result.length = read_uint64be(frame, 3);
76         end
77
78         if result.MASK then
79                 result.key = { s_byte(frame, length_bytes+3, length_bytes+6) };
80         end
81
82         return result, header_length;
83 end
84
85 -- XORs the string `str` with the array of bytes `key`
86 -- TODO: optimize
87 local function apply_mask(str, key, from, to)
88         from = from or 1
89         if from < 0 then from = #str + from + 1 end -- negative indicies
90         to = to or #str
91         if to < 0 then to = #str + to + 1 end -- negative indicies
92         local key_len = #key
93         local counter = 0;
94         local data = {};
95         for i = from, to do
96                 local key_index = counter%key_len + 1;
97                 counter = counter + 1;
98                 data[counter] = s_char(bxor(key[key_index], s_byte(str, i)));
99         end
100         return t_concat(data);
101 end
102
103 local function parse_frame_body(frame, header, pos)
104         if header.MASK then
105                 return apply_mask(frame, header.key, pos, pos + header.length - 1);
106         else
107                 return frame:sub(pos, pos + header.length - 1);
108         end
109 end
110
111 local function parse_frame(frame)
112         local result, pos = parse_frame_header(frame);
113         if result == nil or #frame < (pos + result.length) then return; end
114         result.data = parse_frame_body(frame, result, pos+1);
115         return result, pos + result.length;
116 end
117
118 local function build_frame(desc)
119         local data = desc.data or "";
120
121         assert(desc.opcode and desc.opcode >= 0 and desc.opcode <= 0xF, "Invalid WebSocket opcode");
122         if desc.opcode >= 0x8 then
123                 -- RFC 6455 5.5
124                 assert(#data <= 125, "WebSocket control frames MUST have a payload length of 125 bytes or less.");
125         end
126
127         local b1 = bor(desc.opcode,
128                 desc.FIN and 0x80 or 0,
129                 desc.RSV1 and 0x40 or 0,
130                 desc.RSV2 and 0x20 or 0,
131                 desc.RSV3 and 0x10 or 0);
132
133         local b2 = #data;
134         local length_extra;
135         if b2 <= 125 then -- 7-bit length
136                 length_extra = "";
137         elseif b2 <= 0xFFFF then -- 2-byte length
138                 b2 = 126;
139                 length_extra = pack_uint16be(#data);
140         else -- 8-byte length
141                 b2 = 127;
142                 length_extra = pack_uint64be(#data);
143         end
144
145         local key = ""
146         if desc.MASK then
147                 local key_a = desc.key
148                 if key_a then
149                         key = s_char(unpack(key_a, 1, 4));
150                 else
151                         key = random_bytes(4);
152                         key_a = {key:byte(1,4)};
153                 end
154                 b2 = bor(b2, 0x80);
155                 data = apply_mask(data, key_a);
156         end
157
158         return s_char(b1, b2) .. length_extra .. key .. data
159 end
160
161 local function parse_close(data)
162         local code, message
163         if #data >= 2 then
164                 code = read_uint16be(data, 1);
165                 if #data > 2 then
166                         message = s_sub(data, 3);
167                 end
168         end
169         return code, message
170 end
171
172 local function build_close(code, message, mask)
173         local data = pack_uint16be(code);
174         if message then
175                 assert(#message<=123, "Close reason must be <=123 bytes");
176                 data = data .. message;
177         end
178         return build_frame({
179                 opcode = 0x8;
180                 FIN = true;
181                 MASK = mask;
182                 data = data;
183         });
184 end
185
186 return {
187         parse_header = parse_frame_header;
188         parse_body = parse_frame_body;
189         parse = parse_frame;
190         build = build_frame;
191         parse_close = parse_close;
192         build_close = build_close;
193 };