Merge 0.9->0.10
[prosody.git] / net / http / server.lua
1
2 local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat;
3 local parser_new = require "net.http.parser".new;
4 local events = require "util.events".new();
5 local addserver = require "net.server".addserver;
6 local log = require "util.logger".init("http.server");
7 local os_date = os.date;
8 local pairs = pairs;
9 local s_upper = string.upper;
10 local setmetatable = setmetatable;
11 local xpcall = xpcall;
12 local traceback = debug.traceback;
13 local tostring = tostring;
14 local codes = require "net.http.codes";
15
-- Public module table, returned at the end of the file.
local _M = {};

-- sessions: conn -> HTTP parser created in listener.onconnect
-- listener: net.server callback table (onconnect/onincoming/ondisconnect)
-- hosts:    set of host names this server answers for (name -> true)
local sessions, listener, hosts = {}, {}, {};
-- Host used when the request's Host header is not in `hosts` (may be nil).
local default_host;
22
-- True when `event` is a wildcard event name, i.e. ends in "/*"
-- (e.g. "GET example.com/files/*").
local function is_wildcard_event(event)
	local tail = event:sub(-2);
	return tail == "/*";
end
-- True when `event` starts with `wildcard_event` minus its trailing "*".
-- "GET h/a/*" therefore matches "GET h/a/" and "GET h/a/b/c", not "GET h/a".
local function is_wildcard_match(wildcard_event, event)
	local prefix = wildcard_event:sub(1, -2);
	local head = event:sub(1, #wildcard_event - 1);
	return prefix == head;
end
29
-- Events that are matched only by wildcard handlers have no entry in
-- event_map. Their resolved handler arrays are still cached in
-- events._handlers, and their names are tracked here (oldest first) so that
-- at most max_cached_wildcard_events such entries are kept.
local recent_wildcard_events, max_cached_wildcard_events = {}, 10000;

local event_map = events._event_map;

-- Ordering for handlers that match the same fired event. A score is
-- { number of "/" in the registered event, 1 for an exact event or 0 for a
-- wildcard, handler priority }, compared element by element. Returns true
-- when score a is strictly greater than score b (a's handler runs first).
local function score_before(a, b)
	for i = 1, #a do
		if a[i] ~= b[i] then -- If equal, compare next score value
			return a[i] > b[i];
		end
	end
	return false;
end

setmetatable(events._handlers, {
	-- Called when firing an event that has no cached handler array (it may
	-- match a wildcard handler). Collects every handler registered for the
	-- exact event or for a matching "/*" event, sorts them by score, caches
	-- the array in handlers[curr_event] and returns it (false if none match).
	__index = function (handlers, curr_event)
		if is_wildcard_event(curr_event) then return; end -- Wildcard events cannot be fired
		local scores = {}; -- handler -> best score among the events it matched
		local handlers_array = {};
		for event, handlers_set in pairs(event_map) do
			if event == curr_event or
			is_wildcard_event(event) and is_wildcard_match(event, curr_event) then
				local _, depth = event:gsub("/", "/"); -- number of "/" in the event
				local exactness = is_wildcard_event(event) and 0 or 1;
				for handler, priority in pairs(handlers_set) do
					local score = { depth, exactness, priority };
					local previous = scores[handler];
					if previous == nil then
						-- List each handler only once, even when it is registered
						-- on several events that match this one; otherwise it
						-- could be fired more than once for a single request.
						t_insert(handlers_array, handler);
						scores[handler] = score;
					elseif score_before(score, previous) then
						scores[handler] = score; -- keep the most specific match
					end
				end
			end
		end
		if #handlers_array > 0 then
			table.sort(handlers_array, function (a, b)
				return score_before(scores[a], scores[b]);
			end);
		else
			handlers_array = false; -- cache the miss as well
		end
		rawset(handlers, curr_event, handlers_array);
		if not event_map[curr_event] then -- Only wildcard handlers match, if any
			t_insert(recent_wildcard_events, curr_event);
			if #recent_wildcard_events > max_cached_wildcard_events then
				rawset(handlers, t_remove(recent_wildcard_events, 1), nil);
			end
		end
		return handlers_array;
	end;
	-- Called when assigning to an event name that has no cached entry.
	-- Assigning nil to a wildcard event (presumably done by util.events when
	-- that event's handler set changes) drops the cached arrays of every
	-- event it matches, so they are rebuilt on next use.
	__newindex = function (handlers, curr_event, handlers_array)
		if handlers_array == nil
		and is_wildcard_event(curr_event) then
			-- Invalidate the indexes of all matching events
			for event in pairs(handlers) do
				if is_wildcard_match(curr_event, event) then
					handlers[event] = nil;
				end
			end
		end
		rawset(handlers, curr_event, handlers_array);
	end;
});
85
local handle_request; -- forward declaration; the function is defined further down

-- xpcall in Lua 5.1 cannot pass arguments to the function it calls, so
-- process_next stores them in these upvalues and calls this trampoline.
local _1, _2, _3;
local _handle_request = function ()
	return handle_request(_1, _2, _3);
end
89
-- Most recent error caught by _traceback_handler; process_next passes it
-- to the "http-error" event as private_message when sending a 500.
local last_err;
-- xpcall message handler: remember the error and log it with a stack trace.
local function _traceback_handler(err) last_err = err; log("error", "Traceback[httpserver]: %s", traceback(tostring(err), 2)); end
-- Fallback body for "http-error", registered at priority -1 so handlers with
-- a higher priority take precedence.
events.add_handler("http-error", function (event)
	-- Fall back to the bare code when net.http.codes has no entry for it,
	-- instead of raising a concatenation error while reporting an error.
	local status = codes[event.code] or tostring(event.code);
	return "Error processing request: "..status..". Check your error log for more information.";
end, -1);
95
--- net.server callback for a new connection.
-- Creates the connection's HTTP parser (stored in sessions[conn]) and a queue
-- of parsed requests that are handled one at a time.
function listener.onconnect(conn)
	local secure = conn:ssl() and true or nil;
	local pending = {}; -- parsed requests not yet handled, oldest first
	local waiting = false; -- true while process_next is draining the queue
	-- Handle queued requests until the queue is empty or the connection is
	-- gone. Re-entrant calls (process_next is also the response's finish_cb,
	-- which a handler may trigger) return at once; the running loop picks up
	-- any remaining work.
	local function process_next()
		if waiting then return; end
		waiting = true;
		while sessions[conn] and #pending > 0 do
			-- FIFO: pipelined requests must be answered in the order received
			local request = t_remove(pending, 1);
			_1, _2, _3 = conn, request, process_next;
			local ok = xpcall(_handle_request, _traceback_handler);
			_1, _2, _3 = nil, nil, nil; -- don't keep conn/request alive via upvalues
			if not ok then
				local body = events.fire_event("http-error", { code = 500, private_message = last_err }) or "";
				conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..body);
				conn:close();
			end
		end
		waiting = false;
	end
	-- Parser callback receiving a complete request
	local function success_cb(request)
		if waiting then
			log("error", "http connection handler is not reentrant: %s", request.path);
			assert(false, "http connection handler is not reentrant");
		end
		request.secure = secure;
		t_insert(pending, request);
		process_next();
	end
	-- Parser error callback: drop the session and close the connection
	local function error_cb(err)
		log("debug", "error_cb: %s", err or "<nil>");
		-- FIXME don't close immediately, wait until we process current stuff
		-- FIXME if err, send off a bad-request response
		sessions[conn] = nil;
		conn:close();
	end
	sessions[conn] = parser_new(success_cb, error_cb);
end
135
--- net.server callback: the connection has closed.
-- If a response is still open on it and has an on_destroy hook, the response
-- is marked finished and the hook is run; then the parser is forgotten.
function listener.ondisconnect(conn)
	local response = conn._http_open_response;
	local has_hook = response and response.on_destroy;
	if has_hook then
		response.finished = true;
		response:on_destroy();
	end
	sessions[conn] = nil;
end
144
--- net.server callback: feed received bytes to this connection's HTTP parser.
-- The parser's error callback clears sessions[conn] before closing the
-- connection, so data may still arrive for a connection without a parser;
-- it is dropped instead of raising "attempt to index a nil value".
function listener.onincoming(conn, data)
	local parser = sessions[conn];
	if parser then
		parser:feed(data);
	end
end
148
-- Memoized map from internal header keys to wire header prefixes:
-- "content_length" -> "\r\nContent-Length: ". Each header line is prefixed
-- with CRLF so the fragments can follow the status line directly.
local headerfix = setmetatable({}, {
	__index = function(cache, field)
		local dashed = field:gsub("_", "-");
		local name = dashed:gsub("%f[%w].", s_upper); -- capitalize each word
		local prefix = "\r\n" .. name .. ": ";
		cache[field] = prefix;
		return prefix;
	end
});
156
--- Not implemented: always raises an error.
function _M.hijack_response(response, listener)
	error("TODO", 1);
end
--- Route a parsed request to its HTTP event handler and send the reply.
-- The event fired is "<METHOD> <host><path>" with any query string removed.
-- A handler's return value selects the reply:
--   number -> status code (>= 400 also gets the "http-error" body)
--   string -> response body
--   table  -> fields copied onto the response ("headers" merged per header)
--   true   -> the handler takes care of sending the response itself
--   nil    -> no handler took it: 404
-- @param conn       connection the request arrived on
-- @param request    request object from net.http.parser
-- @param finish_cb  stored on the response; _M.finish_response calls it
--                   for persistent connections
function handle_request(conn, request, finish_cb)
	-- Re-key headers with "_" instead of "-" (content-type -> content_type).
	-- "-" is a pattern magic character, hence the "%-" escape.
	local headers = {};
	for name, value in pairs(request.headers) do
		headers[(name:gsub("%-", "_"))] = value;
	end
	request.headers = headers;
	request.conn = conn;

	-- NOTE(review): the original comment here was a truncated "FIXME use";
	-- presumably about reusing/caching this value — confirm.
	local date_header = os_date('!%a, %d %b %Y %H:%M:%S GMT');
	-- Connection header tokens normalised to ",tok1,tok2," so single tokens
	-- can be found with a plain substring search.
	local conn_header = headers.connection;
	conn_header = conn_header and ","..conn_header:gsub("[ \t]", ""):lower().."," or "";
	local httpversion = request.httpversion;
	-- Keep the connection open if keep-alive was requested, or by default on
	-- HTTP/1.1 unless the client sent "close".
	local persistent = conn_header:find(",keep-alive,", 1, true)
		or (httpversion == "1.1" and not conn_header:find(",close,", 1, true));

	local response_conn_header;
	if persistent then
		response_conn_header = "Keep-Alive";
	else
		response_conn_header = httpversion == "1.1" and "close" or nil;
	end

	local response = {
		request = request;
		status_code = 200;
		headers = { date = date_header, connection = response_conn_header };
		persistent = persistent;
		conn = conn;
		send = _M.send_response;
		done = _M.finish_response;
		finish_cb = finish_cb;
	};
	conn._http_open_response = response;

	-- Host header without any ":port" suffix (nil when absent or empty).
	local host = (headers.host or ""):match("[^:]+");
	-- Host names are case-insensitive: if the exact name is not registered but
	-- its lowercase form is, use that rather than the default host or a 404.
	if host and not hosts[host] and hosts[host:lower()] then
		host = host:lower();
	end

	-- Some sanity checking
	local err_code, err;
	if not request.path then
		err_code, err = 400, "Invalid path";
	elseif not hosts[host] then
		if hosts[default_host] then
			host = default_host;
		elseif host then
			err_code, err = 404, "Unknown host: "..host;
		else
			err_code, err = 400, "Missing or invalid 'Host' header";
		end
	end

	if err then
		response.status_code = err_code;
		response:send(events.fire_event("http-error", { code = err_code, message = err }));
		return;
	end

	local event = request.method.." "..host..request.path:match("[^?]*");
	local payload = { request = request, response = response };
	local result = events.fire_event(event, payload);
	if result ~= nil then
		if result ~= true then
			local body;
			local result_type = type(result);
			if result_type == "number" then
				response.status_code = result;
				if result >= 400 then
					body = events.fire_event("http-error", { code = result });
				end
			elseif result_type == "string" then
				body = result;
			elseif result_type == "table" then
				for k, v in pairs(result) do
					if k ~= "headers" then
						response[k] = v;
					else
						for header_name, header_value in pairs(v) do
							response.headers[header_name] = header_value;
						end
					end
				end
			end
			response:send(body);
		end
		return;
	end

	-- if handler not called, return 404
	response.status_code = 404;
	response:send(events.fire_event("http-error", { code = 404 }));
end
--- Build the status line and header lines of a response.
-- Returns an array of string fragments: the status line, one
-- "\r\nName: value" entry per header, then "\r\n\r\n" (ending the last
-- header line and the header block). Callers append the body and join the
-- fragments with table.concat.
local function prepare_header(response)
	local status = response.status;
	if not status then
		local code = response.status_code;
		-- codes[code] is placed directly after "HTTP/x.y ", so it appears to
		-- hold "NNN Reason" text. For codes it lacks, send "NNN " (an empty
		-- reason phrase is allowed) instead of failing on a nil concatenation.
		status = codes[code] or (tostring(code).." ");
	end
	local output = { "HTTP/"..response.request.httpversion.." "..status };
	for k, v in pairs(response.headers) do
		t_insert(output, headerfix[k]..v);
	end
	t_insert(output, "\r\n\r\n");
	return output;
end
_M.prepare_header = prepare_header;
--- Send a complete response and finish it.
-- The body is the argument, else response.body, else "". Content-Length is
-- set from its length. Does nothing if the response is already finished.
function _M.send_response(response, body)
	if response.finished then
		return;
	end
	local payload = body or response.body or "";
	response.headers.content_length = #payload;
	local fragments = prepare_header(response);
	fragments[#fragments + 1] = payload;
	response.conn:write(t_concat(fragments));
	response:done();
end
--- Mark a response as complete. Calling it again has no effect.
-- Clears the connection's open-response slot, runs and removes any
-- on_destroy hook, then either calls finish_cb (persistent connection)
-- or closes the connection.
function _M.finish_response(response)
	if response.finished then
		return;
	end
	response.finished = true;
	response.conn._http_open_response = nil;
	if response.on_destroy then
		response:on_destroy();
		response.on_destroy = nil;
	end
	if not response.persistent then
		response.conn:close();
		return;
	end
	response:finish_cb();
end
--- Register `handler` for an HTTP event ("METHOD host/path", or a wildcard
-- ending in "/*") with an optional priority.
function _M.add_handler(event, handler, priority)
	local register = events.add_handler;
	register(event, handler, priority);
end
--- Unregister `handler` from an HTTP event.
function _M.remove_handler(event, handler)
	local unregister = events.remove_handler;
	unregister(event, handler);
end
290
--- Start accepting HTTP connections on `port`.
-- `interface` defaults to "*". `ssl` is passed through to net.server.
function _M.listen_on(port, interface, ssl)
	local bind_address = interface or "*";
	addserver(bind_address, port, listener, "*a", ssl);
end
--- Serve requests whose Host header names `host`.
function _M.add_host(host)
	hosts[host] = true;
end
--- Stop serving `host`.
function _M.remove_host(host)
	hosts[host] = nil;
end
--- Choose the host used when a request's Host header matches no served host.
function _M.set_default_host(host)
	default_host = host;
end
--- Fire `event` on the HTTP server's event emitter and return all its results.
function _M.fire_event(event, ...)
	local fire = events.fire_event;
	return fire(event, ...);
end
306
-- Export internals: the underlying event emitter, the status-code table and
-- the net.server listener callbacks.
_M._events = events;
_M.codes = codes;
_M.listener = listener;

return _M;