mod_dialback: Correctly check if a connection was destroyed (thanks iron)
[prosody.git] / net / http / server.lua
index 978a5a820076c8d68c792e26279be059dc504510..69908e4e3c4ab8cfa795d1f86f2869b9ef1032cc 100644 (file)
@@ -12,14 +12,13 @@ local xpcall = xpcall;
 local debug = debug;
 local tostring = tostring;
 local codes = require "net.http.codes";
-local _G = _G;
-local legacy_httpserver = require "net.httpserver";
 
 local _M = {};
 
 local sessions = {};
-
 local listener = {};
+local hosts = {};
+local default_host;
 
 local function is_wildcard_event(event)
        return event:sub(-2, -1) == "/*";
@@ -89,7 +88,7 @@ function listener.onconnect(conn)
        local secure = conn:ssl() and true or nil;
        local pending = {};
        local waiting = false;
-       local function process_next(last_response)
+       local function process_next()
                --if waiting then log("debug", "can't process_next, waiting"); return; end
                if sessions[conn] and #pending > 0 then
                        local request = t_remove(pending);
@@ -169,54 +168,56 @@ function handle_request(conn, request, finish_cb)
        };
        conn._http_open_response = response;
 
-       local err;
-       if not request.headers.host then
-               err = "No 'Host' header";
-       elseif not request.path then
-               err = "Invalid path";
+       local host = (request.headers.host or ""):match("[^:]+");
+
+       -- Some sanity checking
+       local err_code, err;
+       if not request.path then
+               err_code, err = 400, "Invalid path";
+       elseif not hosts[host] then
+               if hosts[default_host] then
+                       host = default_host;
+               elseif host then
+                       err_code, err = 404, "Unknown host: "..host;
+               else
+                       err_code, err = 400, "Missing or invalid 'Host' header";
+               end
        end
        
        if err then
-               response.status_code = 400;
-               response.headers.content_type = "text/html";
-               response:send(events.fire_event("http-error", { code = 400, message = err }));
-       else
-               local host = request.headers.host;
-               if host then
-                       host = host:match("[^:]*"):lower();
-                       local event = request.method.." "..host..request.path:match("[^?]*");
-                       local payload = { request = request, response = response };
-                       --log("debug", "Firing event: %s", event);
-                       local result = events.fire_event(event, payload);
-                       if result ~= nil then
-                               if result ~= true then
-                                       local code, body = 200, "";
-                                       local result_type = type(result);
-                                       if result_type == "number" then
-                                               response.status_code = result;
-                                               if result >= 400 then
-                                                       body = events.fire_event("http-error", { code = result });
-                                               end
-                                       elseif result_type == "string" then
-                                               body = result;
-                                       elseif result_type == "table" then
-                                               body = result.body;
-                                               result.body = nil;
-                                               for k, v in pairs(result) do
-                                                       response[k] = v;
-                                               end
-                                       end
-                                       response:send(body);
+               response.status_code = err_code;
+               response:send(events.fire_event("http-error", { code = err_code, message = err }));
+               return;
+       end
+
+       local event = request.method.." "..host..request.path:match("[^?]*");
+       local payload = { request = request, response = response };
+       --log("debug", "Firing event: %s", event);
+       local result = events.fire_event(event, payload);
+       if result ~= nil then
+               if result ~= true then
+                       local body;
+                       local result_type = type(result);
+                       if result_type == "number" then
+                               response.status_code = result;
+                               if result >= 400 then
+                                       body = events.fire_event("http-error", { code = result });
+                               end
+                       elseif result_type == "string" then
+                               body = result;
+                       elseif result_type == "table" then
+                               for k, v in pairs(result) do
+                                       response[k] = v;
                                end
-                               return;
                        end
+                       response:send(body);
                end
-
-               -- if handler not called, return 404
-               response.status_code = 404;
-               response.headers.content_type = "text/html";
-               response:send(events.fire_event("http-error", { code = 404 }));
+               return;
        end
+
+       -- if handler not called, return 404
+       response.status_code = 404;
+       response:send(events.fire_event("http-error", { code = 404 }));
 end
 function _M.send_response(response, body)
        if response.finished then return; end
@@ -225,7 +226,7 @@ function _M.send_response(response, body)
        
        local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]);
        local headers = response.headers;
-       body = body or "";
+       body = body or response.body or "";
        headers.content_length = #body;
 
        local output = { status_line };
@@ -256,6 +257,15 @@ end
 function _M.listen_on(port, interface, ssl)
        addserver(interface or "*", port, listener, "*a", ssl);
 end
+function _M.add_host(host)
+       hosts[host] = true;
+end
+function _M.remove_host(host)
+       hosts[host] = nil;
+end
+function _M.set_default_host(host)
+       default_host = host;
+end
 
 _M.listener = listener;
 _M.codes = codes;