aeaa7416065f1bf646d5c13b6ce9d511b2317d84
[prosody.git] / net / http / server.lua
1
2 local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat;
3 local parser_new = require "net.http.parser".new;
4 local events = require "util.events".new();
5 local addserver = require "net.server".addserver;
6 local log = require "util.logger".init("http.server");
7 local os_date = os.date;
8 local pairs = pairs;
9 local s_upper = string.upper;
10 local setmetatable = setmetatable;
11 local xpcall = xpcall;
12 local traceback = debug.traceback;
13 local tostring = tostring;
14 local cache = require "util.cache";
15 local codes = require "net.http.codes";
16
-- Public module table; returned at the end of the file.
local _M = {};

-- Parser state per connection (conn -> parser), the net.server listener
-- callbacks, and the set of hostnames this server answers for.
local sessions, listener, hosts = {}, {}, {};
-- Host used when a request's Host header matches none of `hosts` (nil = none).
local default_host;
23
--- True when `event` is a wildcard registration, i.e. its last two
-- characters are "/*".
local function is_wildcard_event(event)
	local tail = event:sub(-2);
	return tail == "/*";
end
--- True when `event` falls under the wildcard registration `wildcard_event`:
-- everything before the trailing "*" must be a prefix of `event`.
local function is_wildcard_match(wildcard_event, event)
	local prefix_len = #wildcard_event - 1;
	return event:sub(1, prefix_len) == wildcard_event:sub(1, prefix_len);
end
30
local _handlers = events._handlers;

-- Eviction callback: drop the cached handler array for an event name,
-- bypassing the handlers table's metamethods.
local function forget_cached_event(event_name)
	rawset(_handlers, event_name, nil);
end

-- Bounded (10000 entries) record of event names whose cached handler arrays
-- were built without an exact registration; evicting one forgets its cache.
local recent_wildcard_events = cache.new(10000, forget_cached_event);
35
local event_map = events._event_map;

-- Returns true when handler score `a` should be fired before score `b`.
-- A score is { number of "/" in the registered event, 1 for an exact
-- registration or 0 for a wildcard, priority }; entries are compared in
-- order and the higher value wins.
local function score_before(a, b)
	for i = 1, #a do
		if a[i] ~= b[i] then
			return a[i] > b[i];
		end
	end
	return false;
end

setmetatable(events._handlers, {
	-- Called when firing an event that has no cached handler array (it may
	-- still match one or more wildcard handlers).
	__index = function (handlers, curr_event)
		if is_wildcard_event(curr_event) then return; end -- Wildcard events cannot be fired
		-- Collect every handler whose registered event matches, rank them,
		-- cache the array in handlers[curr_event] and return it.
		local handler_scores = {};
		local handlers_array = {};
		for event, handlers_set in pairs(event_map) do
			local wildcard = is_wildcard_event(event);
			if event == curr_event or wildcard and is_wildcard_match(event, curr_event) then
				-- Deeper registrations (more "/") rank higher.
				local depth = select(2, event:gsub("/", "%1"));
				for handler, priority in pairs(handlers_set) do
					local score = { depth, wildcard and 0 or 1, priority };
					local existing = handler_scores[handler];
					if existing == nil then
						handler_scores[handler] = score;
						handlers_array[#handlers_array + 1] = handler;
					elseif score_before(score, existing) then
						-- The same function is registered on more than one matching
						-- event: keep a single entry, ranked by its best match, so it
						-- is not fired twice and its rank does not depend on the
						-- iteration order of pairs().
						handler_scores[handler] = score;
					end
				end
			end
		end
		if #handlers_array > 0 then
			table.sort(handlers_array, function (a, b)
				return score_before(handler_scores[a], handler_scores[b]);
			end);
		else
			handlers_array = false; -- cache "no handlers" as well
		end
		rawset(handlers, curr_event, handlers_array);
		if not event_map[curr_event] then -- Only wildcard handlers match, if any
			-- Track it so the bounded cache can evict this entry later.
			recent_wildcard_events:set(curr_event, true);
		end
		return handlers_array;
	end;
	-- Only called for keys not currently present in the table. When a
	-- wildcard event is cleared (assigned nil), drop every cached array the
	-- wildcard could match so those arrays are rebuilt on next use.
	__newindex = function (handlers, curr_event, handlers_array)
		if handlers_array == nil
		and is_wildcard_event(curr_event) then
			-- Invalidate the indexes of all matching events
			for event in pairs(handlers) do
				if is_wildcard_match(curr_event, event) then
					handlers[event] = nil; -- existing key: plain assignment, no recursion
				end
			end
		end
		rawset(handlers, curr_event, handlers_array);
	end;
});
86
local handle_request;
-- Arguments staged for the next protected handle_request call (presumably
-- because Lua 5.1's xpcall cannot forward arguments to the called function).
local _1, _2, _3;
--- Trampoline passed to xpcall: calls handle_request with the staged
-- arguments. The staged values are cleared first so this module does not
-- keep the last connection, request and callback reachable once the
-- request has been handled.
local function _handle_request()
	local conn, request, finish_cb = _1, _2, _3;
	_1, _2, _3 = nil, nil, nil;
	return handle_request(conn, request, finish_cb);
end
90
-- Error value most recently caught by _traceback_handler; read after a
-- failed xpcall so it can be passed to the "http-error" event.
local last_err;
--- xpcall message handler: remember the error, then log it with a traceback.
local function _traceback_handler(err)
	last_err = err;
	local trace = traceback(tostring(err), 2);
	log("error", "Traceback[httpserver]: %s", trace);
end
-- Fallback "http-error" handler at priority -1, so handlers registered with
-- a higher priority are tried first. Produces a generic body naming the
-- status text for the event's code.
events.add_handler("http-error", function (err_event)
	local status_text = codes[err_event.code];
	return "Error processing request: "..status_text..". Check your error log for more information.";
end, -1);
96
--- Set up per-connection state for a newly accepted HTTP connection.
-- Creates the request parser for `conn` (stored in `sessions`) and a queue
-- of parsed requests that are handled one at a time.
-- @param conn the new connection
function listener.onconnect(conn)
	local secure = conn:ssl() and true or nil;
	local pending = {};
	local waiting = false;
	-- Handle queued requests in order. `waiting` prevents re-entry, e.g.
	-- when a response finished synchronously calls finish_cb (this function).
	local function process_next()
		if waiting then return; end -- log("debug", "can't process_next, waiting");
		waiting = true;
		while sessions[conn] and #pending > 0 do
			-- Take the OLDEST queued request: if several are ever queued,
			-- they must be answered in the order they arrived.
			local request = t_remove(pending, 1);
			--log("debug", "process_next: %s", request.path);
			--handle_request(conn, request, process_next);
			_1, _2, _3 = conn, request, process_next;
			if not xpcall(_handle_request, _traceback_handler) then
				-- The handler raised: reply with a bare 500 and drop the connection.
				conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err }));
				conn:close();
			end
		end
		--log("debug", "ready for more");
		waiting = false;
	end
	-- Parser callback: a complete request has been parsed.
	local function success_cb(request)
		--log("debug", "success_cb: %s", request.path);
		if waiting then
			log("error", "http connection handler is not reentrant: %s", request.path);
			assert(false, "http connection handler is not reentrant");
		end
		request.secure = secure;
		t_insert(pending, request);
		process_next();
	end
	-- Parser callback: the input could not be parsed.
	local function error_cb(err)
		log("debug", "error_cb: %s", err or "<nil>");
		-- FIXME don't close immediately, wait until we process current stuff
		-- FIXME if err, send off a bad-request response
		sessions[conn] = nil;
		conn:close();
	end
	sessions[conn] = parser_new(success_cb, error_cb);
end
136
--- Connection closed: if a response is still open on it, mark that response
-- finished and run its on_destroy hook, then forget the connection's parser.
function listener.ondisconnect(conn)
	local response = conn._http_open_response;
	local destroy = response and response.on_destroy;
	if destroy then
		response.finished = true;
		response:on_destroy();
	end
	sessions[conn] = nil;
end
145
--- Connection detached from this listener: forget its parser.
function listener.ondetach(conn)
	sessions[conn] = nil;
end

--- Feed newly received bytes to the connection's HTTP parser.
-- The parser entry is cleared by error_cb (before it closes the connection),
-- ondetach and ondisconnect. If data is ever delivered after that, it is
-- dropped here instead of raising "attempt to index a nil value".
function listener.onincoming(conn, data)
	local parser = sessions[conn];
	if parser then
		parser:feed(data);
	end
end
153
-- Memoising lookup from a header key such as "content_length" to the text
-- that introduces it on the wire ("\r\nContent-Length: "). Underscores
-- become dashes and the first letter of each word is upper-cased.
local headerfix = setmetatable({}, {
	__index = function (memo, key)
		local dashed = key:gsub("_", "-");
		local titled = dashed:gsub("%f[%w].", s_upper);
		local prefix = "\r\n"..titled..": ";
		memo[key] = prefix;
		return prefix;
	end
});
161
--- Placeholder for handing a response's connection over to another
-- listener. Not implemented: always raises "TODO".
_M.hijack_response = function (response, listener)
	error("TODO");
end
--- Route one parsed HTTP request to the registered event handlers.
-- Header names get "-" replaced by "_", the response object is built (with
-- Date and Connection headers), and "<method> <host><path without query>"
-- is fired on the internal event bus. The handler's return value decides
-- the reply:
--   nil     -> 404 with the "http-error" body
--   true    -> nothing is sent here; the handler responds itself
--   number  -> used as the status code; codes >= 400 get the "http-error" body
--   string  -> sent as the body
--   table   -> fields copied onto the response ("headers" merged per key)
-- @param conn the client connection
-- @param request parsed request (method, path, httpversion, headers)
-- @param finish_cb stored on the response; finish_response calls it for
--        persistent connections
function handle_request(conn, request, finish_cb)
	--log("debug", "handler: %s", request.path);
	local response;

	-- Set an error status and send the matching "http-error" body.
	local function send_error(code, message)
		response.status_code = code;
		response:send(events.fire_event("http-error", { code = code, message = message }));
	end

	local normalised = {};
	for name, value in pairs(request.headers) do
		local key = name:gsub("-", "_");
		normalised[key] = value;
	end
	request.headers = normalised;
	request.conn = conn;

	local version = request.httpversion;

	-- Comma-wrapped, whitespace-free, lower-cased Connection tokens, so a
	-- plain find can match a whole token.
	local tokens = "";
	local connection = normalised.connection;
	if connection then
		local squeezed = connection:gsub("[ \t]", "");
		tokens = ","..squeezed:lower()..",";
	end

	-- Keep the connection open if the client asked for keep-alive, or if it
	-- speaks HTTP/1.1 and did not ask to close.
	local persistent = tokens:find(",keep-alive,", 1, true);
	if not persistent then
		persistent = version == "1.1" and not tokens:find(",close,", 1, true);
	end

	local connection_reply;
	if persistent then
		connection_reply = "Keep-Alive";
	elseif version == "1.1" then
		connection_reply = "close";
	end

	response = {
		request = request;
		status_code = 200;
		headers = {
			date = os_date('!%a, %d %b %Y %H:%M:%S GMT');
			connection = connection_reply;
		};
		persistent = persistent;
		conn = conn;
		send = _M.send_response;
		done = _M.finish_response;
		finish_cb = finish_cb;
	};
	conn._http_open_response = response;

	-- Host header up to any ":port" (nil when absent or empty).
	local host = (normalised.host or ""):match("[^:]+");

	if not request.path then
		send_error(400, "Invalid path");
		return;
	end
	if not hosts[host] then
		if hosts[default_host] then
			host = default_host;
		elseif host then
			send_error(404, "Unknown host: "..host);
			return;
		else
			send_error(400, "Missing or invalid 'Host' header");
			return;
		end
	end

	local path = request.path:match("[^?]*");
	local event = request.method.." "..host..path;
	log("debug", "Firing event: %s", event);
	local result = events.fire_event(event, { request = request, response = response });

	if result == nil then
		-- No handler claimed the request
		send_error(404);
		return;
	end
	if result == true then
		return;
	end

	local body;
	local kind = type(result);
	if kind == "number" then
		if result >= 400 then
			send_error(result);
			return;
		end
		response.status_code = result;
	elseif kind == "string" then
		body = result;
	elseif kind == "table" then
		for key, value in pairs(result) do
			if key == "headers" then
				for name, header_value in pairs(value) do
					response.headers[name] = header_value;
				end
			else
				response[key] = value;
			end
		end
	end
	response:send(body);
end
--- Build the head of an HTTP response.
-- Returns an array of strings: the status line, one "\r\nName: value" entry
-- per header, and a final "\r\n\r\n". Callers may append the body before
-- joining. The status text is response.status if set, otherwise
-- codes[response.status_code].
local function prepare_header(response)
	local status_text = response.status or codes[response.status_code];
	local head = { "HTTP/"..response.request.httpversion.." "..status_text };
	for name, value in pairs(response.headers) do
		head[#head + 1] = headerfix[name]..value;
	end
	head[#head + 1] = "\r\n\r\n";
	return head;
end
_M["prepare_header"] = prepare_header; -- expose the response-head builder
--- Write a complete response (head plus body) and finish it.
-- Does nothing if the response is already finished.
-- @param response response object built by handle_request
-- @param body optional body; falls back to response.body, then ""
function _M.send_response(response, body)
	if not response.finished then
		local payload = body or response.body or "";
		response.headers.content_length = #payload;
		local chunks = prepare_header(response);
		chunks[#chunks + 1] = payload;
		response.conn:write(t_concat(chunks));
		response:done();
	end
end
--- Finish a response exactly once.
-- Marks it finished, detaches it from its connection, and runs (then
-- clears) its on_destroy hook. Persistent connections then resume the
-- request loop via finish_cb; other connections are closed.
function _M.finish_response(response)
	if response.finished then return; end
	response.finished = true;
	response.conn._http_open_response = nil;
	local destroy = response.on_destroy;
	if destroy then
		destroy(response);
		response.on_destroy = nil;
	end
	if not response.persistent then
		response.conn:close();
	else
		response:finish_cb();
	end
end
--- Register `handler` for an event named "<method> <host><path>". A name
-- ending in "/*" matches every path under that prefix.
_M.add_handler = function (event, handler, priority)
	events.add_handler(event, handler, priority);
end

--- Unregister a handler previously added with add_handler.
_M.remove_handler = function (event, handler)
	events.remove_handler(event, handler);
end
295
--- Start accepting HTTP connections.
-- @param port port to listen on
-- @param interface address to bind; "*" is used when omitted
-- @param ssl optional TLS settings, passed to addserver unchanged
-- @return whatever net.server's addserver returns
function _M.listen_on(port, interface, ssl)
	local bind_address = interface or "*";
	return addserver(bind_address, port, listener, "*a", ssl);
end

--- Start answering requests whose Host header names `host`.
_M.add_host = function (host)
	hosts[host] = true;
end

--- Stop answering requests for `host`.
_M.remove_host = function (host)
	hosts[host] = nil;
end

--- Choose the host used when a request's Host header matches no added host
-- (only takes effect if that host has itself been added).
_M.set_default_host = function (host)
	default_host = host;
end

--- Fire an event on the server's internal event bus and return its result.
_M.fire_event = function (event, ...)
	return events.fire_event(event, ...);
end
311
-- Expose the internal event bus, the status-code table and the listener.
_M._events = events;
_M.codes = codes;
_M.listener = listener;

return _M;