Switched to new connection framework, courtesy of the luadch project
[prosody.git] / server.lua
1 --[[\r
2 \r
3         server.lua by blastbeat\r
4 \r
5         - this script contains the server loop of the program\r
6         - other scripts can reg a server here\r
7 \r
8 ]]--\r
9 \r
10 ----------------------------------// DECLARATION //--\r
11 \r
12 --// constants //--\r
13 \r
14 local STAT_UNIT = 1 / ( 1024 * 1024 )    -- mb\r
15 \r
16 --// lua functions //--\r
17 \r
18 local function use( what ) return _G[ what ] end\r
19 \r
20 local type = use "type"\r
21 local pairs = use "pairs"\r
22 local ipairs = use "ipairs"\r
23 local tostring = use "tostring"\r
24 local collectgarbage = use "collectgarbage"\r
25 \r
26 --// lua libs //--\r
27 \r
28 local table = use "table"\r
29 local coroutine = use "coroutine"\r
30 \r
31 --// lua lib methods //--\r
32 \r
33 local table_concat = table.concat\r
34 local table_remove = table.remove\r
35 local string_sub = use'string'.sub\r
36 local coroutine_wrap = coroutine.wrap\r
37 local coroutine_yield = coroutine.yield\r
38 local print = print;\r
39 local out_put = function () end --print;\r
40 local out_error = print;\r
41 \r
42 --// extern libs //--\r
43 \r
44 local luasec = require "ssl"\r
45 local luasocket = require "socket"\r
46 \r
47 --// extern lib methods //--\r
48 \r
49 local ssl_wrap = ( luasec and luasec.wrap )\r
50 local socket_bind = luasocket.bind\r
51 local socket_select = luasocket.select\r
52 local ssl_newcontext = ( luasec and luasec.newcontext )\r
53 \r
54 --// functions //--\r
55 \r
56 local loop\r
57 local stats\r
58 local addtimer\r
59 local closeall\r
60 local addserver\r
61 local firetimer\r
62 local closesocket\r
63 local removesocket\r
64 local wrapserver\r
65 local wraptcpclient\r
66 local wrapsslclient\r
67 \r
68 --// tables //--\r
69 \r
70 local listener\r
71 local readlist\r
72 local writelist\r
73 local socketlist\r
74 local timelistener\r
75 \r
76 --// simple data types //--\r
77 \r
78 local _\r
79 local readlen = 0    -- length of readlist\r
80 local writelen = 0    -- lenght of writelist\r
81 \r
82 local sendstat= 0\r
83 local receivestat = 0\r
84 \r
85 ----------------------------------// DEFINITION //--\r
86 \r
87 listener = { }    -- key = port, value = table\r
88 readlist = { }    -- array with sockets to read from\r
89 writelist = { }    -- arrary with sockets to write to\r
90 socketlist = { }    -- key = socket, value = wrapped socket\r
91 timelistener = { }\r
92 \r
93 stats = function( )\r
94     return receivestat, sendstat\r
95 end\r
96 \r
97 wrapserver = function( listener, socket, ip, serverport, mode, sslctx )    -- this function wraps a server\r
98 \r
99     local dispatch, disconnect = listener.listener, listener.disconnect    -- dangerous\r
100 \r
101     local wrapclient, err\r
102 \r
103     if sslctx then\r
104         if not ssl_newcontext then\r
105             return nil, "luasec not found"\r
106 --        elseif not cfg_get "use_ssl" then\r
107 --            return nil, "ssl is deactivated"\r
108         end\r
109         if type( sslctx ) ~= "table" then\r
110             out_error "server.lua: wrong server sslctx"\r
111             return nil, "wrong server sslctx"\r
112         end\r
113         sslctx, err = ssl_newcontext( sslctx )\r
114         if not sslctx then\r
115             err = err or "wrong sslctx parameters"\r
116             out_error( "server.lua: ", err )\r
117             return nil, err\r
118         end\r
119         wrapclient = wrapsslclient\r
120     else\r
121         wrapclient = wraptcpclient\r
122     end\r
123 \r
124     local accept = socket.accept\r
125     local close = socket.close\r
126 \r
127     --// public methods of the object //--    \r
128 \r
129     local handler = { }\r
130 \r
131     handler.shutdown = function( ) end\r
132 \r
133     --[[handler.listener = function( data, err )\r
134         return ondata( handler, data, err )\r
135     end]]\r
136     handler.ssl = function( )\r
137         return sslctx and true or false\r
138     end\r
139     handler.close = function( closed )\r
140         _ = not closed and close( socket )\r
141         writelen = removesocket( writelist, socket, writelen )\r
142         readlen = removesocket( readlist, socket, readlen )\r
143         socketlist[ socket ] = nil\r
144         handler = nil\r
145     end\r
146     handler.ip = function( )\r
147         return ip\r
148     end\r
149     handler.serverport = function( )\r
150         return serverport\r
151     end\r
152     handler.socket = function( )\r
153         return socket\r
154     end\r
155     handler.receivedata = function( )\r
156         local client, err = accept( socket )    -- try to accept\r
157         if client then\r
158             local ip, clientport = client:getpeername( )\r
159             client:settimeout( 0 )\r
160             local handler, client, err = wrapclient( listener, client, ip, serverport, clientport, mode, sslctx )    -- wrap new client socket\r
161             if err then    -- error while wrapping ssl socket\r
162                 return false\r
163             end\r
164             out_put( "server.lua: accepted new client connection from ", ip, ":", clientport )\r
165             return dispatch( handler )\r
166         elseif err then    -- maybe timeout or something else\r
167             out_put( "server.lua: error with new client connection: ", err )\r
168             return false\r
169         end\r
170     end\r
171     return handler\r
172 end\r
173 \r
174 wrapsslclient = function( listener, socket, ip, serverport, clientport, mode, sslctx )    -- this function wraps a ssl cleint\r
175 \r
176     local dispatch, disconnect = listener.listener, listener.disconnect\r
177 \r
178     --// transform socket to ssl object //--\r
179 \r
180     local err\r
181     socket, err = ssl_wrap( socket, sslctx )    -- wrap socket\r
182     if err then\r
183         out_put( "server.lua: ssl error: ", err )\r
184         return nil, nil, err    -- fatal error\r
185     end\r
186     socket:settimeout( 0 )\r
187 \r
188     --// private closures of the object //--\r
189 \r
190     local writequeue = { }    -- buffer for messages to send\r
191 \r
192     local eol   -- end of buffer\r
193 \r
194     local sstat, rstat = 0, 0\r
195 \r
196     --// local import of socket methods //--\r
197 \r
198     local send = socket.send\r
199     local receive = socket.receive\r
200     local close = socket.close\r
201     --local shutdown = socket.shutdown\r
202 \r
203     --// public methods of the object //--\r
204 \r
205     local handler = { }\r
206 \r
207     handler.getstats = function( )\r
208         return rstat, sstat\r
209     end\r
210 \r
211     handler.listener = function( data, err )\r
212         return listener( handler, data, err )\r
213     end\r
214     handler.ssl = function( )\r
215         return true\r
216     end\r
217     handler.send = function( _, data, i, j )\r
218             return send( socket, data, i, j )\r
219     end\r
220     handler.receive = function( pattern, prefix )\r
221             return receive( socket, pattern, prefix )\r
222     end\r
223     handler.shutdown = function( pattern )\r
224         --return shutdown( socket, pattern )\r
225     end\r
226     handler.close = function( closed )\r
227         close( socket )\r
228         writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen\r
229         readlen = removesocket( readlist, socket, readlen )\r
230         socketlist[ socket ] = nil\r
231         out_put "server.lua: closed handler and removed socket from list"\r
232     end\r
233     handler.ip = function( )\r
234         return ip\r
235     end\r
236     handler.serverport = function( )\r
237         return serverport\r
238     end\r
239     handler.clientport = function( ) \r
240         return clientport\r
241     end\r
242 \r
243     handler.write = function( data )\r
244         if not eol then\r
245             writelen = writelen + 1\r
246             writelist[ writelen ] = socket\r
247             eol = 0\r
248         end\r
249         eol = eol + 1\r
250         writequeue[ eol ] = data\r
251     end\r
252     handler.writequeue = function( )\r
253         return writequeue\r
254     end\r
255     handler.socket = function( )\r
256         return socket\r
257     end\r
258     handler.mode = function( )\r
259         return mode\r
260     end\r
261     handler._receivedata = function( )\r
262         local data, err, part = receive( socket, mode )    -- receive data in "mode"\r
263         if not err or ( err == "timeout" or err == "wantread" ) then    -- received something\r
264             local data = data or part or ""\r
265             local count = #data * STAT_UNIT\r
266             rstat = rstat + count\r
267             receivestat = receivestat + count\r
268             out_put( "server.lua: read data '", data, "', error: ", err )\r
269             return dispatch( handler, data, err )\r
270         else    -- connections was closed or fatal error\r
271             out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
272             handler.close( )\r
273             disconnect( handler, err )\r
274             writequeue = nil\r
275             handler = nil\r
276             return false\r
277         end\r
278     end\r
279     handler._dispatchdata = function( )    -- this function writes data to handlers\r
280         local buffer = table_concat( writequeue, "", 1, eol )\r
281         local succ, err, byte = send( socket, buffer )\r
282         local count = ( succ or 0 ) * STAT_UNIT\r
283         sstat = sstat + count\r
284         sendstat = sendstat + count\r
285         out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport )\r
286         if succ then    -- sending succesful\r
287             --writequeue = { }\r
288             eol = nil\r
289             writelen = removesocket( writelist, socket, writelen )    -- delete socket from writelist\r
290             return true\r
291         elseif byte and ( err == "timeout" or err == "wantwrite" ) then    -- want write\r
292             buffer = string_sub( buffer, byte + 1, -1 )    -- new buffer\r
293             writequeue[ 1 ] = buffer    -- insert new buffer in queue\r
294             eol = 1\r
295             return true\r
296         else    -- connection was closed during sending or fatal error\r
297             out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
298             handler.close( )\r
299             disconnect( handler, err )\r
300             writequeue = nil\r
301             handler = nil\r
302             return false\r
303         end\r
304     end\r
305 \r
306     -- // COMPAT // --\r
307 \r
308     handler.getIp = handler.ip\r
309     handler.getPort = handler.clientport\r
310 \r
311     --// handshake //--\r
312 \r
313     local wrote\r
314 \r
315     handler.handshake = coroutine_wrap( function( client )\r
316             local err\r
317             for i = 1, 10 do    -- 10 handshake attemps\r
318                 _, err = client:dohandshake( )\r
319                 if not err then\r
320                     out_put( "server.lua: ssl handshake done" )\r
321                     writelen = ( wrote and removesocket( writelist, socket, writelen ) ) or writelen\r
322                     handler.receivedata = handler._receivedata    -- when handshake is done, replace the handshake function with regular functions\r
323                     handler.dispatchdata = handler._dispatchdata\r
324                     return dispatch( handler )\r
325                 else\r
326                     out_put( "server.lua: error during ssl handshake: ", err )\r
327                     if err == "wantwrite" then\r
328                         if wrote == nil then\r
329                             writelen = writelen + 1\r
330                             writelist[ writelen ] = client\r
331                             wrote = true\r
332                         end\r
333                     end\r
334                     coroutine_yield( handler, nil, err )    -- handshake not finished\r
335                 end\r
336             end\r
337             _ = err ~= "closed" and close( socket )\r
338             handler.close( )\r
339             disconnect( handler, err )\r
340             writequeue = nil\r
341             handler = nil\r
342             return false    -- handshake failed\r
343         end\r
344     )\r
345     handler.receivedata = handler.handshake\r
346     handler.dispatchdata = handler.handshake\r
347 \r
348     handler.handshake( socket )    -- do handshake\r
349 \r
350     socketlist[ socket ] = handler\r
351     readlen = readlen + 1\r
352     readlist[ readlen ] = socket\r
353 \r
354     return handler, socket\r
355 end\r
356 \r
357 wraptcpclient = function( listener, socket, ip, serverport, clientport, mode )    -- this function wraps a socket\r
358 \r
359     local dispatch, disconnect = listener.listener, listener.disconnect\r
360 \r
361     --// private closures of the object //--\r
362 \r
363     local writequeue = { }    -- list for messages to send\r
364 \r
365     local eol\r
366 \r
367     local rstat, sstat = 0, 0\r
368 \r
369     --// local import of socket methods //--\r
370 \r
371     local send = socket.send\r
372     local receive = socket.receive\r
373     local close = socket.close\r
374     local shutdown = socket.shutdown\r
375 \r
376     --// public methods of the object //--\r
377 \r
378     local handler = { }\r
379 \r
380     handler.getstats = function( )\r
381         return rstat, sstat\r
382     end\r
383 \r
384     handler.listener = function( data, err )\r
385         return listener( handler, data, err )\r
386     end\r
387     handler.ssl = function( )\r
388         return false\r
389     end\r
390     handler.send = function( _, data, i, j )\r
391             return send( socket, data, i, j )\r
392     end\r
393     handler.receive = function( pattern, prefix )\r
394             return receive( socket, pattern, prefix )\r
395     end\r
396     handler.shutdown = function( pattern )\r
397         return shutdown( socket, pattern )\r
398     end\r
399     handler.close = function( closed )\r
400         _ = not closed and shutdown( socket )\r
401         _ = not closed and close( socket )\r
402         writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen\r
403         readlen = removesocket( readlist, socket, readlen )\r
404         socketlist[ socket ] = nil\r
405         out_put "server.lua: closed handler and removed socket from list"\r
406     end\r
407     handler.ip = function( )\r
408         return ip\r
409     end\r
410     handler.serverport = function( )\r
411         return serverport\r
412     end\r
413     handler.clientport = function( ) \r
414         return clientport\r
415     end\r
416     handler.write = function( data )\r
417         if not eol then\r
418             writelen = writelen + 1\r
419             writelist[ writelen ] = socket\r
420             eol = 0\r
421         end\r
422         eol = eol + 1\r
423         writequeue[ eol ] = data\r
424     end\r
425     handler.writequeue = function( )\r
426         return writequeue\r
427     end\r
428     handler.socket = function( )\r
429         return socket\r
430     end\r
431     handler.mode = function( )\r
432         return mode\r
433     end\r
434     handler.receivedata = function( )\r
435         local data, err, part = receive( socket, mode )    -- receive data in "mode"\r
436         if not err or ( err == "timeout" or err == "wantread" ) then    -- received something\r
437             local data = data or part or ""\r
438             local count = #data * STAT_UNIT\r
439             rstat = rstat + count\r
440             receivestat = receivestat + count\r
441             out_put( "server.lua: read data '", data, "', error: ", err )\r
442             return dispatch( handler, data, err )\r
443         else    -- connections was closed or fatal error\r
444             out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
445             handler.close( )\r
446             disconnect( handler, err )\r
447             writequeue = nil\r
448             handler = nil\r
449             return false\r
450         end\r
451     end\r
452     handler.dispatchdata = function( )    -- this function writes data to handlers\r
453         local buffer = table_concat( writequeue, "", 1, eol )\r
454         local succ, err, byte = send( socket, buffer )\r
455         local count = ( succ or 0 ) * STAT_UNIT\r
456         sstat = sstat + count\r
457         sendstat = sendstat + count\r
458         out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport )\r
459         if succ then    -- sending succesful\r
460             --writequeue = { }\r
461             eol = nil\r
462             writelen = removesocket( writelist, socket, writelen )    -- delete socket from writelist\r
463             return true\r
464         elseif byte and ( err == "timeout" or err == "wantwrite" ) then    -- want write\r
465             buffer = string_sub( buffer, byte + 1, -1 )    -- new buffer\r
466             writequeue[ 1 ] = buffer    -- insert new buffer in queue\r
467             eol = 1\r
468             return true\r
469         else    -- connection was closed during sending or fatal error\r
470             out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
471             handler.close( )\r
472             disconnect( handler, err )\r
473             writequeue = nil\r
474             handler = nil\r
475             return false\r
476         end\r
477     end\r
478 \r
479     -- // COMPAT // --\r
480 \r
481     handler.getIp = handler.ip\r
482     handler.getPort = handler.clientport\r
483 \r
484     socketlist[ socket ] = handler\r
485     readlen = readlen + 1\r
486     readlist[ readlen ] = socket\r
487 \r
488     return handler, socket\r
489 end\r
490 \r
491 addtimer = function( listener )\r
492     timelistener[ #timelistener + 1 ] = listener\r
493 end\r
494 \r
495 firetimer = function( listener )\r
496     for i, listener in ipairs( timelistener ) do\r
497         listener( )\r
498     end\r
499 end\r
500 \r
501 addserver = function( listeners, port, addr, mode, sslctx )    -- this function provides a way for other scripts to reg a server\r
502     local err\r
503     if type( listeners ) ~= "table" then\r
504         err = "invalid listener table"\r
505     else\r
506         for name, func in pairs( listeners ) do\r
507             if type( func ) ~= "function" then\r
508                 err = "invalid listener function"\r
509                 break\r
510             end\r
511         end\r
512     end\r
513     if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then\r
514         err = "invalid port"\r
515     elseif listener[ port ] then\r
516         err=  "listeners on port '" .. port .. "' already exist"\r
517     elseif sslctx and not luasec then\r
518         err = "luasec not found"\r
519     end\r
520     if err then\r
521         out_error( "server.lua: ", err )\r
522         return nil, err\r
523     end\r
524     addr = addr or "*"\r
525     local server, err = socket_bind( addr, port )\r
526     if err then\r
527         out_error( "server.lua: ", err )\r
528         return nil, err\r
529     end\r
530     local handler, err = wrapserver( listeners, server, addr, port, mode, sslctx )    -- wrap new server socket\r
531     if not handler then\r
532         server:close( )\r
533         return nil, err\r
534     end\r
535     server:settimeout( 0 )\r
536     readlen = readlen + 1\r
537     readlist[ readlen ] = server\r
538     listener[ port ] = listeners\r
539     socketlist[ server ] = handler\r
540     out_put( "server.lua: new server listener on ", addr, ":", port )\r
541     return true\r
542 end\r
543 \r
544 removesocket = function( tbl, socket, len )    -- this function removes sockets from a list\r
545     for i, target in ipairs( tbl ) do\r
546         if target == socket then\r
547             len = len - 1\r
548             table_remove( tbl, i )\r
549             return len\r
550         end\r
551     end\r
552     return len\r
553 end\r
554 \r
555 closeall = function( )\r
556     for _, handler in pairs( socketlist ) do\r
557         handler.shutdown( )\r
558         handler.close( )\r
559         socketlist[ _ ] = nil\r
560     end\r
561     writelist, readlist, socketlist = { }, { }, { }\r
562 end\r
563 \r
564 closesocket = function( socket )\r
565     writelen = removesocket( writelist, socket, writelen )\r
566     readlen = removesocket( readlist, socket, readlen )\r
567     socketlist[ socket ] = nil\r
568     socket:close( )\r
569 end\r
570 \r
571 loop = function( )    -- this is the main loop of the program\r
572     --signal_set( "hub", "run" )\r
573     repeat\r
574         local read, write, err = socket_select( readlist, writelist, 1 )    -- 1 sec timeout, nice for timers\r
575         for i, socket in ipairs( write ) do    -- send data waiting in writequeues\r
576             local handler = socketlist[ socket ]\r
577             if handler then\r
578                 handler.dispatchdata( )\r
579             else\r
580                 closesocket( socket )\r
581                 out_put "server.lua: found no handler and closed socket (writelist)"    -- this should not happen\r
582             end\r
583         end\r
584         for i, socket in ipairs( read ) do    -- receive data\r
585             local handler = socketlist[ socket ]\r
586             if handler then\r
587                 handler.receivedata( )\r
588             else\r
589                 closesocket( socket )\r
590                 out_put "server.lua: found no handler and closed socket (readlist)"    -- this can happen\r
591             end\r
592         end\r
593         firetimer( )\r
594         --collectgarbage "collect"\r
595     until false --signal_get "hub" ~= "run"\r
596     return --signal_get "hub"\r
597 end\r
598 \r
599 ----------------------------------// BEGIN //--\r
600 \r
601 ----------------------------------// PUBLIC INTERFACE //--\r
602 \r
603 return {\r
604 \r
605     add = addserver,\r
606     loop = loop,\r
607     stats = stats,\r
608     closeall = closeall,\r
609     addtimer = addtimer,\r
610 \r
611 }\r