net.http.server: Add response method for reading response body from a file handle
[prosody.git] / net / http / server.lua
1
2 local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat;
3 local parser_new = require "net.http.parser".new;
4 local events = require "util.events".new();
5 local addserver = require "net.server".addserver;
6 local log = require "util.logger".init("http.server");
7 local os_date = os.date;
8 local pairs = pairs;
9 local s_upper = string.upper;
10 local setmetatable = setmetatable;
11 local xpcall = xpcall;
12 local traceback = debug.traceback;
13 local tostring = tostring;
14 local cache = require "util.cache";
15 local codes = require "net.http.codes";
16 local blocksize = require "socket".BLOCKSIZE or 2048;
17
-- Module table returned to callers of require "net.http.server".
local _M = {};

-- Per-server state:
--   sessions   : conn -> HTTP request parser created in listener.onconnect
--   incomplete : conn -> response whose body is still being streamed by send_file
--   listener   : net.server callbacks (onconnect, onincoming, ...)
--   hosts      : set of host names accepted by handle_request
local sessions, incomplete, listener, hosts = {}, {}, {}, {};

-- Host used when the request's Host header is missing or not in `hosts`.
local default_host;
25
-- True when `event` is a wildcard pattern, i.e. ends in "/*".
local function is_wildcard_event(event)
	local tail = event:sub(-2);
	return tail == "/*";
end

-- True when `event` starts with the prefix of `wildcard_event` that precedes
-- its trailing "*" (so "GET host/*" matches "GET host/anything").
local function is_wildcard_match(wildcard_event, event)
	local prefix_len = #wildcard_event - 1;
	local prefix = wildcard_event:sub(1, -2);
	return event:sub(1, prefix_len) == prefix;
end
32
-- Handler lists built purely from wildcard handlers are remembered here (up to
-- 10000 events). The callback drops the cached list from the handlers table
-- for a key leaving the cache (presumably util.cache's eviction hook - confirm).
local _handlers = events._handlers;
local function forget_cached_handlers(evicted_event)
	rawset(_handlers, evicted_event, nil);
end
local recent_wildcard_events = cache.new(10000, forget_cached_handlers);
37
local event_map = events._event_map;

-- Make the handlers table wildcard-aware: looking up an event with no stored
-- entry builds (and caches) the list of exact and wildcard handlers for it.
setmetatable(events._handlers, {
	-- Called when firing an event that doesn't exist (but may match a wildcard handler)
	__index = function (handlers, curr_event)
		-- Wildcard events cannot be fired
		if is_wildcard_event(curr_event) then return; end

		-- Gather every handler registered on curr_event itself or on a wildcard
		-- event covering it, with a sort key {slash count, exact?1:0, priority}.
		local sort_keys = {};
		local matched = {};
		for event, handlers_set in pairs(event_map) do
			local wildcard = is_wildcard_event(event);
			if event == curr_event or (wildcard and is_wildcard_match(event, curr_event)) then
				local _, slash_count = event:gsub("/", "%1");
				local exactness = wildcard and 0 or 1;
				for handler, priority in pairs(handlers_set) do
					sort_keys[handler] = { slash_count, exactness, priority };
					t_insert(matched, handler);
				end
			end
		end

		if #matched == 0 then
			matched = false;
		else
			-- Descending by key: deeper paths first, then exact before wildcard,
			-- then higher priority.
			table.sort(matched, function (first, second)
				local first_key, second_key = sort_keys[first], sort_keys[second];
				for i = 1, #second_key do
					local x, y = first_key[i], second_key[i];
					if x ~= y then -- If equal, compare next key component
						return x > y;
					end
				end
				return false;
			end);
		end

		rawset(handlers, curr_event, matched);
		if not event_map[curr_event] then -- Only wildcard handlers match, if any
			recent_wildcard_events:set(curr_event, true);
		end
		return matched;
	end;

	-- Assigning nil to an absent wildcard event invalidates the cached lists of
	-- every stored event it covers.
	__newindex = function (handlers, curr_event, handlers_array)
		if handlers_array == nil and is_wildcard_event(curr_event) then
			for event in pairs(handlers) do
				if is_wildcard_match(curr_event, event) then
					handlers[event] = nil;
				end
			end
		end
		rawset(handlers, curr_event, handlers_array);
	end;
});
88
-- Defined further down; declared here so the closures below can call it.
local handle_request;

-- Arguments for handle_request are staged in these upvalues because it is run
-- through xpcall(), which (presumably targeting Lua 5.1) cannot forward arguments.
local _1, _2, _3;
local function _handle_request()
	return handle_request(_1, _2, _3);
end

-- Error message from the most recent failed handle_request call.
local last_err;

-- xpcall() error handler: remember the error and log it with a traceback.
local function _traceback_handler(err)
	last_err = err;
	log("error", "Traceback[httpserver]: %s", traceback(tostring(err), 2));
end

-- Default renderer for error pages, registered at priority -1.
local function default_error_page(err_event)
	return "Error processing request: "..codes[err_event.code]..". Check your error log for more information.";
end
events.add_handler("http-error", default_error_page, -1);
98
--- Set up request parsing for a newly connected client.
-- Parsed requests are queued and handled one at a time. handle_request() is
-- given process_next as its finish callback, so the queue resumes once a
-- persistent response is done.
function listener.onconnect(conn)
	local secure = nil;
	if conn:ssl() then secure = true; end

	local queue = {};
	local busy = false;

	local function process_next()
		if busy then return; end -- already draining further up the stack
		busy = true;
		while sessions[conn] and #queue > 0 do
			_1, _2, _3 = conn, t_remove(queue), process_next;
			local ok = xpcall(_handle_request, _traceback_handler);
			if not ok then
				conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err }));
				conn:close();
			end
		end
		busy = false;
	end

	-- Parser callback: one complete request has been read.
	local function success_cb(request)
		if busy then
			log("error", "http connection handler is not reentrant: %s", request.path);
			assert(false, "http connection handler is not reentrant");
		end
		request.secure = secure;
		t_insert(queue, request);
		process_next();
	end

	-- Parser callback: the stream could not be parsed.
	local function error_cb(err)
		log("debug", "error_cb: %s", err or "<nil>");
		-- FIXME don't close immediately, wait until we process current stuff
		-- FIXME if err, send off a bad-request response
		sessions[conn] = nil;
		conn:close();
	end

	sessions[conn] = parser_new(success_cb, error_cb);
end
138
--- Connection closed: mark any still-open response finished and run its
-- on_destroy hook, then forget the connection's streaming and parser state.
function listener.ondisconnect(conn)
	local open_response = conn._http_open_response;
	local has_destroy_hook = open_response and open_response.on_destroy;
	if has_destroy_hook then
		open_response.finished = true;
		open_response:on_destroy();
	end
	incomplete[conn] = nil;
	sessions[conn] = nil;
end

--- Connection handed to another listener: drop our state for it.
function listener.ondetach(conn)
	incomplete[conn] = nil;
	sessions[conn] = nil;
end

--- Feed received bytes to the connection's HTTP parser.
function listener.onincoming(conn, data)
	sessions[conn]:feed(data);
end

--- Send buffer drained: continue streaming a send_file() body, if any.
function listener.ondrain(conn)
	local streaming = incomplete[conn];
	local send_more = streaming and streaming._send_more;
	if send_more then
		send_more();
	end
end
164
-- Memoised map from a normalised header key (e.g. "content_length") to the
-- text that introduces it in the response ("\r\nContent-Length: ").
local headerfix = setmetatable({}, {
	__index = function (memo, key)
		local wire_name = key:gsub("_", "-"):gsub("%f[%w].", s_upper);
		local prefix = "\r\n"..wire_name..": ";
		memo[key] = prefix;
		return prefix;
	end
});
172
--- Placeholder, presumably for handing a response's connection to another
-- listener. Not implemented: always raises "TODO".
function _M.hijack_response(_response, _listener)
	error("TODO");
end
-- Set `code` as the status and send the page rendered by the "http-error" event.
local function send_http_error(response, code, message)
	response.status_code = code;
	response:send(events.fire_event("http-error", { code = code, message = message }));
end

--- Route one parsed request to its handler and make sure a response is sent.
-- Header names are rewritten with "_" instead of "-". The Connection header and
-- HTTP version decide keep-alive. The request is fired as the event
-- "<METHOD> <host><path without query>". The handler result means:
--   nil    - no handler produced a result: 404
--   true   - the handler sends the response itself
--   number - status code (>= 400 also renders an "http-error" page)
--   string - response body
--   table  - fields copied onto the response ("headers" are merged)
-- @param conn the client connection
-- @param request the parsed request
-- @param finish_cb stored on the response; finish_response calls it for
--   persistent connections
function handle_request(conn, request, finish_cb)
	local headers = {};
	for name, value in pairs(request.headers) do
		headers[name:gsub("-", "_")] = value;
	end
	request.headers = headers;
	request.conn = conn;

	local date_header = os_date('!%a, %d %b %Y %H:%M:%S GMT'); -- FIXME use
	local httpversion = request.httpversion;

	-- ",tok1,tok2," so each Connection token can be found with a plain find()
	local conn_tokens = "";
	if headers.connection then
		conn_tokens = ","..headers.connection:gsub("[ \t]", ""):lower()..",";
	end
	local persistent = conn_tokens:find(",keep-alive,", 1, true)
		or (httpversion == "1.1" and not conn_tokens:find(",close,", 1, true));

	local response_conn_header = nil;
	if persistent then
		response_conn_header = "Keep-Alive";
	elseif httpversion == "1.1" then
		response_conn_header = "close";
	end

	local response = {
		request = request;
		status_code = 200;
		headers = { date = date_header, connection = response_conn_header };
		persistent = persistent;
		conn = conn;
		send = _M.send_response;
		send_file = _M.send_file;
		done = _M.finish_response;
		finish_cb = finish_cb;
	};
	conn._http_open_response = response;

	local host = (headers.host or ""):match("[^:]+");

	-- Sanity checks: a path is required, and the host must be known unless a
	-- known default host is configured.
	if not request.path then
		send_http_error(response, 400, "Invalid path");
		return;
	end
	if not hosts[host] then
		if hosts[default_host] then
			host = default_host;
		elseif host then
			send_http_error(response, 404, "Unknown host: "..host);
			return;
		else
			send_http_error(response, 400, "Missing or invalid 'Host' header");
			return;
		end
	end

	local event = request.method.." "..host..request.path:match("[^?]*");
	local payload = { request = request, response = response };
	log("debug", "Firing event: %s", event);
	local result = events.fire_event(event, payload);

	if result == nil then
		-- if handler not called, return 404
		send_http_error(response, 404);
		return;
	end
	if result == true then
		return; -- the handler is responsible for the response
	end

	local body;
	local result_type = type(result);
	if result_type == "number" then
		response.status_code = result;
		if result >= 400 then
			body = events.fire_event("http-error", { code = result });
		end
	elseif result_type == "string" then
		body = result;
	elseif result_type == "table" then
		for key, value in pairs(result) do
			if key == "headers" then
				for header_name, header_value in pairs(value) do
					response.headers[header_name] = header_value;
				end
			else
				response[key] = value;
			end
		end
	end
	response:send(body);
end
--- Build the status line and header lines of `response`.
-- Returns an array: the status line, one "\r\nName: value" entry per header,
-- then the "\r\n\r\n" terminator. Callers may append a body and concatenate.
local function prepare_header(response)
	local reason = response.status or codes[response.status_code];
	local lines = { "HTTP/"..response.request.httpversion.." "..reason };
	for name, value in pairs(response.headers) do
		lines[#lines + 1] = headerfix[name]..value;
	end
	lines[#lines + 1] = "\r\n\r\n";
	return lines;
end
_M.prepare_header = prepare_header;
--- Send a complete response with an in-memory body, then finish it.
-- Content-Length is always set from the body. For a HEAD request only the
-- headers are written, because a response to HEAD must not carry a body.
-- @param response the response object created by handle_request
-- @param body optional body string; defaults to response.body, then ""
function _M.send_response(response, body)
	if response.finished then return; end
	body = body or response.body or "";
	response.headers.content_length = #body;
	local output = prepare_header(response);
	-- RFC 7231 §4.3.2: HEAD gets the same headers (incl. Content-Length) as
	-- GET would, but writing the body would corrupt keep-alive framing.
	if response.request.method ~= "HEAD" then
		t_insert(output, body);
	end
	response.conn:write(t_concat(output));
	response:done();
end
--- Send the response body by streaming it from a file-like object `f`.
-- The headers are written immediately. The body follows one `blocksize` chunk
-- per listener.ondrain() call (via response._send_more). Without a
-- Content-Length header, the body uses chunked transfer coding for HTTP/1.1.
-- For HTTP/1.0, which has no chunked coding, the body is delimited by closing
-- the connection. A HEAD request gets headers only.
-- @param response the response object created by handle_request
-- @param f object with a read(n) method (e.g. an io file handle) and an
--   optional close() method, called once streaming stops
-- @return true once streaming has started (or the HEAD reply was sent);
--   nothing if the response was already finished
function _M.send_file(response, f)
	if response.finished then return; end
	local conn = response.conn;
	local headers = response.headers;

	local chunked = false;
	if not headers.content_length then
		if response.request.httpversion == "1.0" then
			-- RFC 7230 §3.3.1: never send Transfer-Encoding to an HTTP/1.0 client.
			response.persistent = false;
			headers.connection = "close";
		else
			chunked = true;
			headers.transfer_encoding = "chunked";
		end
	end

	local function close_file()
		if f.close then f:close(); end
	end

	if response.request.method == "HEAD" then
		-- RFC 7231 §4.3.2: headers only, never a body, in reply to HEAD.
		conn:write(t_concat(prepare_header(response)));
		close_file();
		response:done();
		return true;
	end

	incomplete[conn] = response;
	response._send_more = function ()
		if response.finished then
			-- Finished elsewhere; stop streaming and release the file.
			incomplete[conn] = nil;
			close_file();
			return;
		end
		local chunk, read_err = f:read(blocksize);
		if chunk then
			if chunked then
				chunk = ("%x\r\n%s\r\n"):format(#chunk, chunk);
			end
			conn:write(chunk);
			return;
		end
		-- End of input. Unregister *before* writing, so a drain triggered by the
		-- final write cannot re-enter and send the terminator (or close f) twice.
		incomplete[conn] = nil;
		close_file();
		if read_err then
			-- Don't terminate a truncated body as if complete: drop the
			-- connection so the client can tell the transfer failed.
			log("error", "Error reading response body: %s", tostring(read_err));
			response.persistent = false;
		elseif chunked then
			conn:write("0\r\n\r\n");
		end
		return response:done();
	end
	conn:write(t_concat(prepare_header(response)));
	return true;
end
--- Mark `response` finished, exactly once.
-- Detaches it from its connection and runs (then clears) its on_destroy hook.
-- Then either calls finish_cb to process the next request (persistent
-- connection) or closes the connection.
function _M.finish_response(response)
	if response.finished then return; end
	response.finished = true;
	response.conn._http_open_response = nil;

	if response.on_destroy then
		response:on_destroy();
		response.on_destroy = nil;
	end

	if not response.persistent then
		response.conn:close();
	else
		response:finish_cb();
	end
end
--- Register `handler` for `event` (e.g. "GET example.com/path"), with an optional priority.
function _M.add_handler(event, handler, priority)
	events.add_handler(event, handler, priority);
end

--- Unregister a handler previously added with add_handler().
function _M.remove_handler(event, handler)
	events.remove_handler(event, handler);
end

--- Start accepting HTTP connections on `port`.
-- @param interface address to bind; "*" when omitted
-- @param ssl optional SSL settings, passed through to addserver()
-- @return whatever addserver() returns
function _M.listen_on(port, interface, ssl)
	local bind_address = interface or "*";
	return addserver(bind_address, port, listener, "*a", ssl);
end

--- Accept requests whose Host header names `host`.
function _M.add_host(host)
	hosts[host] = true;
end

--- Forget `host`; its requests then fall back to the default host, if any.
function _M.remove_host(host)
	hosts[host] = nil;
end

--- Use `host` for requests whose Host header is missing or unknown.
function _M.set_default_host(host)
	default_host = host;
end

--- Fire `event` on the server's event object and return its result.
function _M.fire_event(event, ...)
	return events.fire_event(event, ...);
end
353
-- Export the listener callbacks, the status-code table and the underlying event object.
_M.listener, _M.codes, _M._events = listener, codes, events;

return _M;