eb30da78633d61c61333a8fd8b80684f0c37785b
[prosody.git] / net / http / server.lua
1
2 local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat;
3 local parser_new = require "net.http.parser".new;
4 local events = require "util.events".new();
5 local addserver = require "net.server".addserver;
6 local log = require "util.logger".init("http.server");
7 local os_date = os.date;
8 local pairs = pairs;
9 local s_upper = string.upper;
10 local setmetatable = setmetatable;
11 local xpcall = xpcall;
12 local debug = debug;
13 local tostring = tostring;
14 local codes = require "net.http.codes";
15 local _G = _G;
16 local legacy_httpserver = require "net.httpserver";
17
-- Module table handed back to whoever require()s this file.
local _M = {};

-- sessions: connection object -> parser created for it in listener.onconnect
-- listener: callback table given to addserver() by _M.listen_on
-- hosts:    set of hostnames registered through _M.add_host
local sessions, listener, hosts = {}, {}, {};

-- Host used when a request names a host that is not in `hosts`
-- (nil until _M.set_default_host is called).
local default_host;
24
--- Tell whether an event name is a wildcard subscription.
-- An event counts as a wildcard when its final two characters are "/*".
-- @param event event name (string)
-- @return true or false
local function is_wildcard_event(event)
	local tail = event:sub(-2);
	return tail == "/*";
end
--- Check whether a concrete event falls under a wildcard event.
-- Everything in the wildcard name except its last character is treated as a
-- prefix that the event name must start with.
-- @param wildcard_event wildcard event name (string)
-- @param event          concrete event name (string)
-- @return true or false
local function is_wildcard_match(wildcard_event, event)
	local prefix_len = #wildcard_event - 1;
	local wanted = wildcard_event:sub(1, prefix_len);
	local actual = event:sub(1, prefix_len);
	return wanted == actual;
end
31
local event_map = events._event_map;

-- Replace the handler index of the events object with one that understands
-- wildcard subscriptions ("<something>/*").
setmetatable(events._handlers, {
	-- Called the first time the handler list of an event is looked up.
	-- Collects every handler registered for the exact event or for a
	-- matching wildcard, orders them and caches the result (false when
	-- nothing matches).
	__index = function (cache, fired_event)
		if is_wildcard_event(fired_event) then
			return; -- Wildcard events cannot be fired
		end

		local scores, ordered = {}, {};
		for registered_event, registered in pairs(event_map) do
			local wildcard = is_wildcard_event(registered_event);
			local applies = registered_event == fired_event
				or (wildcard and is_wildcard_match(registered_event, fired_event));
			if applies then
				-- Score fields, most significant first: number of "/" in the
				-- registered event name, exact (1) vs. wildcard (0), priority
				local _, depth = registered_event:gsub("/", "%1");
				local exactness = wildcard and 0 or 1;
				for handler, priority in pairs(registered) do
					scores[handler] = { depth, exactness, priority };
					ordered[#ordered + 1] = handler;
				end
			end
		end

		if #ordered == 0 then
			ordered = false;
		else
			-- Highest score first; fields are compared left to right
			table.sort(ordered, function (lhs, rhs)
				local l, r = scores[lhs], scores[rhs];
				for i = 1, #r do
					if r[i] ~= l[i] then
						return r[i] < l[i];
					end
				end
				return false;
			end);
		end

		rawset(cache, fired_event, ordered);
		return ordered;
	end;

	-- Called when a key that is not yet cached is assigned. Assigning nil to
	-- a wildcard event drops the cached lists of every event it matches, so
	-- they are rebuilt on next lookup.
	__newindex = function (cache, event_name, value)
		if value == nil and is_wildcard_event(event_name) then
			for cached_event in pairs(cache) do
				if is_wildcard_match(event_name, cached_event) then
					rawset(cache, cached_event, nil);
				end
			end
		end
		rawset(cache, event_name, value);
	end;
});
78
-- Forward declaration; the function body is assigned later in this file.
local handle_request;

-- Argument slots for the trampoline below. listener.onconnect fills them
-- immediately before running _handle_request under xpcall.
local _1, _2, _3;

--- Zero-argument trampoline: calls handle_request with the staged arguments.
local function _handle_request()
	return handle_request(_1, _2, _3);
end
82
-- Most recent error seen by _traceback_handler; read by listener.onconnect
-- when it builds the 500 response.
local last_err;

--- xpcall message handler: remembers the error and logs it with a traceback.
-- @param err the error value raised inside the protected call
local function _traceback_handler(err)
	last_err = err;
	local message = tostring(err);
	log("error", "Traceback[http]: %s: %s", message, debug.traceback());
end
-- Fallback "http-error" handler, registered with priority -1. It turns the
-- event's status code into a short plain-text body.
events.add_handler("http-error", function (event)
	local status_text = codes[event.code];
	return "Error processing request: "..status_text..". Check your error log for more information.";
end, -1);
88
--- Connection callback: sets up per-connection request parsing and dispatch.
-- A parser is created for the connection and stored in `sessions`. Parsed
-- requests are queued and handled one at a time, in arrival order.
-- @param conn the newly accepted connection object
function listener.onconnect(conn)
	local secure = conn:ssl() and true or nil;
	local pending = {}; -- parsed requests not yet handled, oldest first
	local waiting = false; -- true while a request is being handled

	-- Handle the next queued request, if any. Also used as the finish_cb of
	-- responses, which call it with the response as argument (unused here).
	local function process_next(last_response)
		if sessions[conn] and #pending > 0 then
			-- Take the OLDEST request: requests pipelined on one connection
			-- must be answered in the order they were received.
			local request = t_remove(pending, 1);
			waiting = true;
			-- Stage the arguments for the zero-argument trampoline run by xpcall
			_1, _2, _3 = conn, request, process_next;
			if not xpcall(_handle_request, _traceback_handler) then
				conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err }));
				conn:close();
			end
		else
			waiting = false;
		end
	end

	-- Parser callback: a complete request has been parsed
	local function success_cb(request)
		request.secure = secure;
		t_insert(pending, request);
		if not waiting then
			process_next();
		end
	end

	-- Parser callback: the incoming data could not be parsed
	local function error_cb(err)
		log("debug", "error_cb: %s", err or "<nil>");
		-- FIXME don't close immediately, wait until we process current stuff
		-- FIXME if err, send off a bad-request response
		sessions[conn] = nil;
		conn:close();
	end

	sessions[conn] = parser_new(success_cb, error_cb);
end
127
--- Connection callback: the connection has gone away.
-- If a response is still open on the connection and has an on_destroy hook,
-- the response is marked finished and the hook is run. The connection's
-- parser is then dropped from `sessions`.
-- @param conn the connection that was closed
function listener.ondisconnect(conn)
	local response = conn._http_open_response;
	local has_destroy_hook = response and response.on_destroy;
	if has_destroy_hook then
		response.finished = true;
		response:on_destroy();
	end
	sessions[conn] = nil;
end
136
--- Connection callback: raw data has arrived; feed it to the connection's parser.
-- @param conn the connection the data was read from
-- @param data the received data
function listener.onincoming(conn, data)
	local parser = sessions[conn];
	if not parser then
		-- The parser is discarded (sessions[conn] = nil) by the parse-error
		-- callback and by ondisconnect. Data arriving after that has nowhere
		-- to go; drop it instead of raising on a nil index.
		log("debug", "Discarding data received on a connection with no active parser");
		return;
	end
	parser:feed(data);
end
140
-- Memoising map from an internal header key to the text that precedes its
-- value on the wire: "\r\n", the key with "_" replaced by "-" and the first
-- character of every alphanumeric run upper-cased, then ": ".
-- Example: headerfix.content_length == "\r\nContent-Length: "
local headerfix = setmetatable({}, {
	__index = function (cache, field)
		local dashed = field:gsub("_", "-");
		local pretty = dashed:gsub("%f[%w].", s_upper);
		local prefix = "\r\n"..pretty..": ";
		rawset(cache, field, prefix); -- cache so the next lookup skips this function
		return prefix;
	end
});
148
--- Placeholder: not implemented. Always raises the error "TODO".
-- @param response unused
-- @param listener unused
function _M.hijack_response(response, listener)
	-- Level 1 is error()'s default: position info points at this function.
	error("TODO", 1);
end
--- Dispatch one parsed HTTP request to the registered event handlers.
-- Builds the response object, validates the request, fires the event
-- "<METHOD> <host><path without query>" and turns the handler's return value
-- into a response:
--   nil     -> no handler took the request: 404 with an "http-error" body
--   true    -> the handler takes care of sending the response itself
--   number  -> used as status code; codes >= 400 get an "http-error" body
--   string  -> used as the body
--   table   -> its `body` field is the body, all other fields are copied
--              onto the response
-- @param conn      connection the request arrived on
-- @param request   parsed request; headers, httpversion, method and path are read
-- @param finish_cb stored on the response; send_response calls it after a
--                  keep-alive reply
function handle_request(conn, request, finish_cb)
	-- Re-key the headers with "_" in place of "-" so that they can be read
	-- as plain fields (e.g. headers.content_type)
	local headers = {};
	for k, v in pairs(request.headers) do
		headers[(k:gsub("%-", "_"))] = v;
	end
	request.headers = headers;
	request.conn = conn;

	local date_header = os_date('!%a, %d %b %Y %H:%M:%S GMT'); -- FIXME (inherited; the original note is cut off after "use")

	-- The Connection header is a comma-separated list of case-insensitive
	-- tokens. Normalise it to ",token,token," (lower case, no blanks) so that
	-- single tokens can be found with a plain substring search.
	local conn_header = request.headers.connection;
	local conn_tokens = "";
	if conn_header then
		conn_tokens = ","..conn_header:gsub("[ \t]", ""):lower()..",";
	end
	-- Persistent if the client asks for it, or by default on HTTP/1.1
	-- unless the client asked to close
	local keep_alive = conn_tokens:find(",keep-alive,", 1, true) ~= nil
		or (request.httpversion == "1.1" and conn_tokens:find(",close,", 1, true) == nil);

	local response = {
		request = request;
		status_code = 200;
		headers = { date = date_header, connection = (keep_alive and "Keep-Alive" or "close") };
		conn = conn;
		send = _M.send_response;
		finish_cb = finish_cb;
	};
	-- Lets listener.ondisconnect find the response if the connection drops
	conn._http_open_response = response;

	-- Host header without the port part
	local host = (request.headers.host or ""):match("[^:]+");

	-- Some sanity checking
	local err_code, err;
	if not host then
		err_code, err = 400, "Missing or invalid 'Host' header";
	elseif not request.path then
		err_code, err = 400, "Invalid path";
	elseif not hosts[host] then
		if hosts[default_host] then
			host = default_host;
		else
			err_code, err = 404, "Unknown host: "..host;
		end
	end

	if err then
		response.status_code = err_code;
		response:send(events.fire_event("http-error", { code = err_code, message = err }));
		return;
	end

	-- The query string is not part of the event name
	local event = request.method.." "..host..request.path:match("[^?]*");
	local payload = { request = request, response = response };
	local result = events.fire_event(event, payload);
	if result ~= nil then
		if result ~= true then
			local body = "";
			local result_type = type(result);
			if result_type == "number" then
				response.status_code = result;
				if result >= 400 then
					body = events.fire_event("http-error", { code = result });
				end
			elseif result_type == "string" then
				body = result;
			elseif result_type == "table" then
				body = result.body;
				result.body = nil;
				for k, v in pairs(result) do
					response[k] = v;
				end
			end
			response:send(body);
		end
		return;
	end

	-- if handler not called, return 404
	response.status_code = 404;
	response:send(events.fire_event("http-error", { code = 404 }));
end
--- Serialise a response and write it to its connection.
-- Does nothing if the response is already marked finished. Otherwise marks it
-- finished, detaches it from the connection, writes status line, headers and
-- body, and runs the on_destroy hook if one is set. Finally it either calls
-- the response's finish_cb (when the connection header is "Keep-Alive") or
-- closes the connection.
-- @param response response object (as built by handle_request)
-- @param body     response body; nil is treated as ""
function _M.send_response(response, body)
	if response.finished then
		return;
	end
	response.finished = true;

	local conn = response.conn;
	conn._http_open_response = nil;

	local reason = response.status or codes[response.status_code];
	local status_line = "HTTP/"..response.request.httpversion.." "..reason;

	local headers = response.headers;
	if not body then
		body = "";
	end
	headers.content_length = #body;

	-- Collect the pieces first and join them once
	local wire = { status_line };
	for name, value in pairs(headers) do
		wire[#wire + 1] = headerfix[name]..value;
	end
	wire[#wire + 1] = "\r\n\r\n";
	wire[#wire + 1] = body;
	conn:write(t_concat(wire));

	-- The hook runs at most once: it is cleared after being called
	if response.on_destroy then
		response:on_destroy();
		response.on_destroy = nil;
	end

	if headers.connection ~= "Keep-Alive" then
		response.conn:close();
		return;
	end
	response:finish_cb();
end
--- Register a handler on this module's events object.
-- The arguments are passed to events.add_handler unchanged.
-- @param event    event name
-- @param handler  handler function
-- @param priority handler priority (passed through as given)
function _M.add_handler(event, handler, priority)
	events.add_handler(
		event,
		handler,
		priority
	);
end
--- Unregister a handler from this module's events object.
-- The arguments are passed to events.remove_handler unchanged.
-- @param event   event name the handler was registered for
-- @param handler the handler function to remove
function _M.remove_handler(event, handler)
	events.remove_handler(
		event,
		handler
	);
end
260
--- Start accepting HTTP connections by registering this module's listener.
-- @param port      port to listen on
-- @param interface address to bind to; "*" is used when nil or false
-- @param ssl       passed through to addserver unchanged
function _M.listen_on(port, interface, ssl)
	local bind_address = interface or "*";
	-- "*a" is passed to addserver as its fourth argument
	-- (presumably the read mode; confirm against net.server)
	addserver(bind_address, port, listener, "*a", ssl);
end
--- Mark a hostname as served by this server.
-- Requests whose Host header names it are dispatched under that host.
-- @param host hostname
function _M.add_host(host)
	local known = true;
	hosts[host] = known;
end
--- Forget a hostname previously registered with add_host.
-- @param host hostname
function _M.remove_host(host)
	-- Assigning nil removes the entry from the set
	hosts[host] = nil;
end
--- Choose the host used for requests that name an unknown host.
-- The fallback only applies if this host is itself registered (request
-- handling checks hosts[default_host]).
-- @param host hostname, or nil to clear the default
function _M.set_default_host(host)
	default_host = host;
end
273
-- Expose internals on the module table
_M.listener = listener; -- connection callback table
_M.codes = codes;       -- the required "net.http.codes" module
_M._events = events;    -- events object used for request dispatch

return _M;