net.dns: Make sure argument to math.randomseed does not overflow a 32 bit unsigned...
[prosody.git] / net / http / server.lua
index feb8f7667fca2e83f976eb4594fbe368098d2714..87d82418550489b2e9955abd3a98572ca7008e43 100644 (file)
@@ -12,48 +12,107 @@ local xpcall = xpcall;
 local debug = debug;
 local tostring = tostring;
 local codes = require "net.http.codes";
-local _G = _G;
 
 local _M = {};
 
 local sessions = {};
-local handlers = {};
-
 local listener = {};
+local hosts = {};
+local default_host;
+
+local function is_wildcard_event(event)
+       return event:sub(-2, -1) == "/*";
+end
+local function is_wildcard_match(wildcard_event, event)
+       return wildcard_event:sub(1, -2) == event:sub(1, #wildcard_event-1);
+end
+
+local event_map = events._event_map;
+setmetatable(events._handlers, {
+       __index = function (handlers, curr_event)
+               if is_wildcard_event(curr_event) then return; end -- Wildcard events cannot be fired
+               -- Find all handlers that could match this event, sort them
+               -- and then put the array into handlers[curr_event] (and return it)
+               local matching_handlers_set = {};
+               local handlers_array = {};
+               for event, handlers_set in pairs(event_map) do
+                       if event == curr_event or
+                       is_wildcard_event(event) and is_wildcard_match(event, curr_event) then
+                               for handler, priority in pairs(handlers_set) do
+                                       matching_handlers_set[handler] = { (select(2, event:gsub("/", "%1"))), is_wildcard_event(event) and 0 or 1, priority };
+                                       table.insert(handlers_array, handler);
+                               end
+                       end
+               end
+               if #handlers_array > 0 then
+                       table.sort(handlers_array, function(b, a)
+                               local a_score, b_score = matching_handlers_set[a], matching_handlers_set[b];
+                               for i = 1, #a_score do
+                                       if a_score[i] ~= b_score[i] then -- If equal, compare next score value
+                                               return a_score[i] < b_score[i];
+                                       end
+                               end
+                               return false;
+                       end);
+               else
+                       handlers_array = false;
+               end
+               rawset(handlers, curr_event, handlers_array);
+               return handlers_array;
+       end;
+       __newindex = function (handlers, curr_event, handlers_array)
+               if handlers_array == nil
+               and is_wildcard_event(curr_event) then
+                       -- Invalidate the indexes of all matching events
+                       for event in pairs(handlers) do
+                               if is_wildcard_match(curr_event, event) then
+                                       handlers[event] = nil;
+                               end
+                       end
+               end
+               rawset(handlers, curr_event, handlers_array);
+       end;
+});
 
 local handle_request;
 local _1, _2, _3;
 local function _handle_request() return handle_request(_1, _2, _3); end
-local function _traceback_handler(err) log("error", "Traceback[http]: %s: %s", tostring(err), debug.traceback()); end
+
+local last_err;
+local function _traceback_handler(err) last_err = err; log("error", "Traceback[http]: %s: %s", tostring(err), debug.traceback()); end
+events.add_handler("http-error", function (error)
+       return "Error processing request: "..codes[error.code]..". Check your error log for more information.";
+end, -1);
 
 function listener.onconnect(conn)
        local secure = conn:ssl() and true or nil;
        local pending = {};
        local waiting = false;
-       local function process_next(last_response)
-               --if waiting then log("debug", "can't process_next, waiting"); return; end
-               if sessions[conn] and #pending > 0 then
+       local function process_next()
+               if waiting then log("debug", "can't process_next, waiting"); return; end
+               waiting = true;
+               while sessions[conn] and #pending > 0 do
                        local request = t_remove(pending);
                        --log("debug", "process_next: %s", request.path);
-                       waiting = true;
                        --handle_request(conn, request, process_next);
                        _1, _2, _3 = conn, request, process_next;
                        if not xpcall(_handle_request, _traceback_handler) then
-                               conn:write("HTTP/1.0 503 Internal Server Error\r\n\r\nAn error occured during the processing of this request.");
+                               conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err }));
                                conn:close();
                        end
-               else
-                       --log("debug", "ready for more");
-                       waiting = false;
                end
+               --log("debug", "ready for more");
+               waiting = false;
        end
        local function success_cb(request)
                --log("debug", "success_cb: %s", request.path);
+               if waiting then
+                       log("error", "http connection handler is not reentrant: %s", request.path);
+                       assert(false, "http connection handler is not reentrant");
+               end
                request.secure = secure;
                t_insert(pending, request);
-               if not waiting then
-                       process_next();
-               end
+               process_next();
        end
        local function error_cb(err)
                log("debug", "error_cb: %s", err or "<nil>");
@@ -66,6 +125,11 @@ function listener.onconnect(conn)
 end
 
 function listener.ondisconnect(conn)
+       local open_response = conn._http_open_response;
+       if open_response and open_response.on_destroy then
+               open_response.finished = true;
+               open_response:on_destroy();
+       end
        sessions[conn] = nil;
 end
 
@@ -93,52 +157,88 @@ function handle_request(conn, request, finish_cb)
 
        local date_header = os_date('!%a, %d %b %Y %H:%M:%S GMT'); -- FIXME use
        local conn_header = request.headers.connection;
-       local keep_alive = conn_header == "Keep-Alive" or (request.httpversion == "1.1" and conn_header ~= "close");
+       conn_header = conn_header and ","..conn_header:gsub("[ \t]", ""):lower().."," or ""
+       local httpversion = request.httpversion
+       local persistent = conn_header:find(",keep-alive,", 1, true)
+               or (httpversion == "1.1" and not conn_header:find(",close,", 1, true));
+
+       local response_conn_header;
+       if persistent then
+               response_conn_header = "Keep-Alive";
+       else
+               response_conn_header = httpversion == "1.1" and "close" or nil
+       end
 
        local response = {
                request = request;
                status_code = 200;
-               headers = { date = date_header, connection = (keep_alive and "Keep-Alive" or "close") };
+               headers = { date = date_header, connection = response_conn_header };
+               persistent = persistent;
                conn = conn;
                send = _M.send_response;
                finish_cb = finish_cb;
        };
+       conn._http_open_response = response;
 
-       if not request.headers.host then
-               response.status_code = 400;
-               response.headers.content_type = "text/html";
-               response:send("<html><head>400 Bad Request</head><body>400 Bad Request: No Host header.</body></html>");
-       else
-               -- TODO call handler
-               --response.headers.content_type = "text/plain";
-               --response:send("host="..(request.headers.host or "").."\npath="..request.path.."\n"..(request.body or ""));
-               local host = request.headers.host;
-               if host then
-                       host = host:match("[^:]*"):lower();
-                       local event = request.method.." "..host..request.path:match("[^?]*");
-                       local payload = { request = request, response = response };
-                       --[[repeat
-                               if events.fire_event(event, payload) ~= nil then return; end
-                               event = (event:sub(-1) == "/") and event:sub(1, -1) or event:gsub("[^/]*$", "");
-                               if event:sub(-1) == "/" then
-                                       event = event:sub(1, -1);
-                               else
-                                       event = event:gsub("[^/]*$", "");
-                               end
-                       until not event:find("/", 1, true);]]
-                       --log("debug", "Event: %s", event);
-                       if events.fire_event(event, payload) ~= nil then return; end
-                       -- TODO try adding/stripping / at the end, but this needs to work via an HTTP redirect
+       local host = (request.headers.host or ""):match("[^:]+");
+
+       -- Some sanity checking
+       local err_code, err;
+       if not request.path then
+               err_code, err = 400, "Invalid path";
+       elseif not hosts[host] then
+               if hosts[default_host] then
+                       host = default_host;
+               elseif host then
+                       err_code, err = 404, "Unknown host: "..host;
+               else
+                       err_code, err = 400, "Missing or invalid 'Host' header";
                end
+       end
+       
+       if err then
+               response.status_code = err_code;
+               response:send(events.fire_event("http-error", { code = err_code, message = err }));
+               return;
+       end
 
-               -- if handler not called, fallback to legacy httpserver handlers
-               _M.legacy_handler(request, response);
+       local event = request.method.." "..host..request.path:match("[^?]*");
+       local payload = { request = request, response = response };
+       --log("debug", "Firing event: %s", event);
+       local result = events.fire_event(event, payload);
+       if result ~= nil then
+               if result ~= true then
+                       local body;
+                       local result_type = type(result);
+                       if result_type == "number" then
+                               response.status_code = result;
+                               if result >= 400 then
+                                       body = events.fire_event("http-error", { code = result });
+                               end
+                       elseif result_type == "string" then
+                               body = result;
+                       elseif result_type == "table" then
+                               for k, v in pairs(result) do
+                                       response[k] = v;
+                               end
+                       end
+                       response:send(body);
+               end
+               return;
        end
+
+       -- if handler not called, return 404
+       response.status_code = 404;
+       response:send(events.fire_event("http-error", { code = 404 }));
 end
 function _M.send_response(response, body)
+       if response.finished then return; end
+       response.finished = true;
+       response.conn._http_open_response = nil;
+       
        local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]);
        local headers = response.headers;
-       body = body or "";
+       body = body or response.body or "";
        headers.content_length = #body;
 
        local output = { status_line };
@@ -149,64 +249,16 @@ function _M.send_response(response, body)
        t_insert(output, body);
 
        response.conn:write(t_concat(output));
-       if headers.connection == "Keep-Alive" then
+       if response.on_destroy then
+               response:on_destroy();
+               response.on_destroy = nil;
+       end
+       if response.persistent then
                response:finish_cb();
        else
                response.conn:close();
        end
 end
-function _M.legacy_handler(request, response)
-       log("debug", "Invoking legacy handler");
-       local base = request.path:match("^/([^/?]+)");
-       local legacy_server = _G.httpserver and _G.httpserver.new.http_servers[5280];
-       local handler = legacy_server and legacy_server.handlers[base];
-       if not handler then handler = _G.httpserver and _G.httpserver.set_default_handler.default_handler; end
-       if handler then
-               -- add legacy properties to request object
-               request.url = { path = request.path };
-               request.handler = response.conn;
-               request.id = tostring{}:match("%x+$");
-               local headers = {};
-               for k,v in pairs(request.headers) do
-                       headers[k:gsub("_", "-")] = v;
-               end
-               request.headers = headers;
-               function request:send(resp)
-                       if self.destroyed then return; end
-                       if resp.body or resp.headers then
-                               if resp.headers then
-                                       for k,v in pairs(resp.headers) do response.headers[k] = v; end
-                               end
-                               response:send(resp.body)
-                       else
-                               response:send(resp)
-                       end
-                       self.sent = true;
-                       self:destroy();
-               end
-               function request:destroy()
-                       if self.destroyed then return; end
-                       if not self.sent then return self:send(""); end
-                       self.destroyed = true;
-                       if self.on_destroy then
-                               log("debug", "Request has destroy callback");
-                               self:on_destroy();
-                       else
-                               log("debug", "Request has no destroy callback");
-                       end
-               end
-               local r = handler(request.method, request.body, request);
-               if r ~= true then
-                       request:send(r);
-               end
-       else
-               log("debug", "No handler found");
-               response.status_code = 404;
-               response.headers.content_type = "text/html";
-               response:send("<html><head><title>404 Not Found</title></head><body>404 Not Found: No such page.</body></html>");
-       end
-end
-
 function _M.add_handler(event, handler, priority)
        events.add_handler(event, handler, priority);
 end
@@ -217,7 +269,17 @@ end
 function _M.listen_on(port, interface, ssl)
        addserver(interface or "*", port, listener, "*a", ssl);
 end
+function _M.add_host(host)
+       hosts[host] = true;
+end
+function _M.remove_host(host)
+       hosts[host] = nil;
+end
+function _M.set_default_host(host)
+       default_host = host;
+end
 
 _M.listener = listener;
 _M.codes = codes;
+_M._events = events;
 return _M;