Moved server module to net/
[prosody.git] / net / server.lua
1 --[[\r
2 \r
3         server.lua by blastbeat of the luadch project\r
4         \r
5         re-used here under the MIT/X Consortium License\r
6 \r
7         - this script contains the server loop of the program\r
8         - other scripts can reg a server here\r
9 \r
10 ]]--\r
11 \r
12 ----------------------------------// DECLARATION //--\r
13 \r
14 --// constants //--\r
15 \r
16 local STAT_UNIT = 1 / ( 1024 * 1024 )    -- mb\r
17 \r
18 --// lua functions //--\r
19 \r
20 local function use( what ) return _G[ what ] end\r
21 \r
22 local type = use "type"\r
23 local pairs = use "pairs"\r
24 local ipairs = use "ipairs"\r
25 local tostring = use "tostring"\r
26 local collectgarbage = use "collectgarbage"\r
27 \r
28 --// lua libs //--\r
29 \r
30 local table = use "table"\r
31 local coroutine = use "coroutine"\r
32 \r
33 --// lua lib methods //--\r
34 \r
35 local table_concat = table.concat\r
36 local table_remove = table.remove\r
37 local string_sub = use'string'.sub\r
38 local coroutine_wrap = coroutine.wrap\r
39 local coroutine_yield = coroutine.yield\r
40 local print = print;\r
41 local out_put = function () end --print;\r
42 local out_error = print;\r
43 \r
44 --// extern libs //--\r
45 \r
46 local luasec = require "ssl"\r
47 local luasocket = require "socket"\r
48 \r
49 --// extern lib methods //--\r
50 \r
51 local ssl_wrap = ( luasec and luasec.wrap )\r
52 local socket_bind = luasocket.bind\r
53 local socket_select = luasocket.select\r
54 local ssl_newcontext = ( luasec and luasec.newcontext )\r
55 \r
56 --// functions //--\r
57 \r
58 local loop\r
59 local stats\r
60 local addtimer\r
61 local closeall\r
62 local addserver\r
63 local firetimer\r
64 local closesocket\r
65 local removesocket\r
66 local wrapserver\r
67 local wraptcpclient\r
68 local wrapsslclient\r
69 \r
70 --// tables //--\r
71 \r
72 local listener\r
73 local readlist\r
74 local writelist\r
75 local socketlist\r
76 local timelistener\r
77 \r
78 --// simple data types //--\r
79 \r
80 local _\r
81 local readlen = 0    -- length of readlist\r
82 local writelen = 0    -- lenght of writelist\r
83 \r
84 local sendstat= 0\r
85 local receivestat = 0\r
86 \r
87 ----------------------------------// DEFINITION //--\r
88 \r
89 listener = { }    -- key = port, value = table\r
90 readlist = { }    -- array with sockets to read from\r
91 writelist = { }    -- arrary with sockets to write to\r
92 socketlist = { }    -- key = socket, value = wrapped socket\r
93 timelistener = { }\r
94 \r
95 stats = function( )\r
96     return receivestat, sendstat\r
97 end\r
98 \r
99 wrapserver = function( listener, socket, ip, serverport, mode, sslctx )    -- this function wraps a server\r
100 \r
101     local dispatch, disconnect = listener.listener, listener.disconnect    -- dangerous\r
102 \r
103     local wrapclient, err\r
104 \r
105     if sslctx then\r
106         if not ssl_newcontext then\r
107             return nil, "luasec not found"\r
108 --        elseif not cfg_get "use_ssl" then\r
109 --            return nil, "ssl is deactivated"\r
110         end\r
111         if type( sslctx ) ~= "table" then\r
112             out_error "server.lua: wrong server sslctx"\r
113             return nil, "wrong server sslctx"\r
114         end\r
115         sslctx, err = ssl_newcontext( sslctx )\r
116         if not sslctx then\r
117             err = err or "wrong sslctx parameters"\r
118             out_error( "server.lua: ", err )\r
119             return nil, err\r
120         end\r
121         wrapclient = wrapsslclient\r
122     else\r
123         wrapclient = wraptcpclient\r
124     end\r
125 \r
126     local accept = socket.accept\r
127     local close = socket.close\r
128 \r
129     --// public methods of the object //--    \r
130 \r
131     local handler = { }\r
132 \r
133     handler.shutdown = function( ) end\r
134 \r
135     --[[handler.listener = function( data, err )\r
136         return ondata( handler, data, err )\r
137     end]]\r
138     handler.ssl = function( )\r
139         return sslctx and true or false\r
140     end\r
141     handler.close = function( closed )\r
142         _ = not closed and close( socket )\r
143         writelen = removesocket( writelist, socket, writelen )\r
144         readlen = removesocket( readlist, socket, readlen )\r
145         socketlist[ socket ] = nil\r
146         handler = nil\r
147     end\r
148     handler.ip = function( )\r
149         return ip\r
150     end\r
151     handler.serverport = function( )\r
152         return serverport\r
153     end\r
154     handler.socket = function( )\r
155         return socket\r
156     end\r
157     handler.receivedata = function( )\r
158         local client, err = accept( socket )    -- try to accept\r
159         if client then\r
160             local ip, clientport = client:getpeername( )\r
161             client:settimeout( 0 )\r
162             local handler, client, err = wrapclient( listener, client, ip, serverport, clientport, mode, sslctx )    -- wrap new client socket\r
163             if err then    -- error while wrapping ssl socket\r
164                 return false\r
165             end\r
166             out_put( "server.lua: accepted new client connection from ", ip, ":", clientport )\r
167             return dispatch( handler )\r
168         elseif err then    -- maybe timeout or something else\r
169             out_put( "server.lua: error with new client connection: ", err )\r
170             return false\r
171         end\r
172     end\r
173     return handler\r
174 end\r
175 \r
176 wrapsslclient = function( listener, socket, ip, serverport, clientport, mode, sslctx )    -- this function wraps a ssl cleint\r
177 \r
178     local dispatch, disconnect = listener.listener, listener.disconnect\r
179 \r
180     --// transform socket to ssl object //--\r
181 \r
182     local err\r
183     socket, err = ssl_wrap( socket, sslctx )    -- wrap socket\r
184     if err then\r
185         out_put( "server.lua: ssl error: ", err )\r
186         return nil, nil, err    -- fatal error\r
187     end\r
188     socket:settimeout( 0 )\r
189 \r
190     --// private closures of the object //--\r
191 \r
192     local writequeue = { }    -- buffer for messages to send\r
193 \r
194     local eol   -- end of buffer\r
195 \r
196     local sstat, rstat = 0, 0\r
197 \r
198     --// local import of socket methods //--\r
199 \r
200     local send = socket.send\r
201     local receive = socket.receive\r
202     local close = socket.close\r
203     --local shutdown = socket.shutdown\r
204 \r
205     --// public methods of the object //--\r
206 \r
207     local handler = { }\r
208 \r
209     handler.getstats = function( )\r
210         return rstat, sstat\r
211     end\r
212 \r
213     handler.listener = function( data, err )\r
214         return listener( handler, data, err )\r
215     end\r
216     handler.ssl = function( )\r
217         return true\r
218     end\r
219     handler.send = function( _, data, i, j )\r
220             return send( socket, data, i, j )\r
221     end\r
222     handler.receive = function( pattern, prefix )\r
223             return receive( socket, pattern, prefix )\r
224     end\r
225     handler.shutdown = function( pattern )\r
226         --return shutdown( socket, pattern )\r
227     end\r
228     handler.close = function( closed )\r
229         close( socket )\r
230         writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen\r
231         readlen = removesocket( readlist, socket, readlen )\r
232         socketlist[ socket ] = nil\r
233         out_put "server.lua: closed handler and removed socket from list"\r
234     end\r
235     handler.ip = function( )\r
236         return ip\r
237     end\r
238     handler.serverport = function( )\r
239         return serverport\r
240     end\r
241     handler.clientport = function( ) \r
242         return clientport\r
243     end\r
244 \r
245     handler.write = function( data )\r
246         if not eol then\r
247             writelen = writelen + 1\r
248             writelist[ writelen ] = socket\r
249             eol = 0\r
250         end\r
251         eol = eol + 1\r
252         writequeue[ eol ] = data\r
253     end\r
254     handler.writequeue = function( )\r
255         return writequeue\r
256     end\r
257     handler.socket = function( )\r
258         return socket\r
259     end\r
260     handler.mode = function( )\r
261         return mode\r
262     end\r
263     handler._receivedata = function( )\r
264         local data, err, part = receive( socket, mode )    -- receive data in "mode"\r
265         if not err or ( err == "timeout" or err == "wantread" ) then    -- received something\r
266             local data = data or part or ""\r
267             local count = #data * STAT_UNIT\r
268             rstat = rstat + count\r
269             receivestat = receivestat + count\r
270             out_put( "server.lua: read data '", data, "', error: ", err )\r
271             return dispatch( handler, data, err )\r
272         else    -- connections was closed or fatal error\r
273             out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
274             handler.close( )\r
275             disconnect( handler, err )\r
276             writequeue = nil\r
277             handler = nil\r
278             return false\r
279         end\r
280     end\r
281     handler._dispatchdata = function( )    -- this function writes data to handlers\r
282         local buffer = table_concat( writequeue, "", 1, eol )\r
283         local succ, err, byte = send( socket, buffer )\r
284         local count = ( succ or 0 ) * STAT_UNIT\r
285         sstat = sstat + count\r
286         sendstat = sendstat + count\r
287         out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport )\r
288         if succ then    -- sending succesful\r
289             --writequeue = { }\r
290             eol = nil\r
291             writelen = removesocket( writelist, socket, writelen )    -- delete socket from writelist\r
292             return true\r
293         elseif byte and ( err == "timeout" or err == "wantwrite" ) then    -- want write\r
294             buffer = string_sub( buffer, byte + 1, -1 )    -- new buffer\r
295             writequeue[ 1 ] = buffer    -- insert new buffer in queue\r
296             eol = 1\r
297             return true\r
298         else    -- connection was closed during sending or fatal error\r
299             out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
300             handler.close( )\r
301             disconnect( handler, err )\r
302             writequeue = nil\r
303             handler = nil\r
304             return false\r
305         end\r
306     end\r
307 \r
308     -- // COMPAT // --\r
309 \r
310     handler.getIp = handler.ip\r
311     handler.getPort = handler.clientport\r
312 \r
313     --// handshake //--\r
314 \r
315     local wrote\r
316 \r
317     handler.handshake = coroutine_wrap( function( client )\r
318             local err\r
319             for i = 1, 10 do    -- 10 handshake attemps\r
320                 _, err = client:dohandshake( )\r
321                 if not err then\r
322                     out_put( "server.lua: ssl handshake done" )\r
323                     writelen = ( wrote and removesocket( writelist, socket, writelen ) ) or writelen\r
324                     handler.receivedata = handler._receivedata    -- when handshake is done, replace the handshake function with regular functions\r
325                     handler.dispatchdata = handler._dispatchdata\r
326                     return dispatch( handler )\r
327                 else\r
328                     out_put( "server.lua: error during ssl handshake: ", err )\r
329                     if err == "wantwrite" then\r
330                         if wrote == nil then\r
331                             writelen = writelen + 1\r
332                             writelist[ writelen ] = client\r
333                             wrote = true\r
334                         end\r
335                     end\r
336                     coroutine_yield( handler, nil, err )    -- handshake not finished\r
337                 end\r
338             end\r
339             _ = err ~= "closed" and close( socket )\r
340             handler.close( )\r
341             disconnect( handler, err )\r
342             writequeue = nil\r
343             handler = nil\r
344             return false    -- handshake failed\r
345         end\r
346     )\r
347     handler.receivedata = handler.handshake\r
348     handler.dispatchdata = handler.handshake\r
349 \r
350     handler.handshake( socket )    -- do handshake\r
351 \r
352     socketlist[ socket ] = handler\r
353     readlen = readlen + 1\r
354     readlist[ readlen ] = socket\r
355 \r
356     return handler, socket\r
357 end\r
358 \r
359 wraptcpclient = function( listener, socket, ip, serverport, clientport, mode )    -- this function wraps a socket\r
360 \r
361     local dispatch, disconnect = listener.listener, listener.disconnect\r
362 \r
363     --// private closures of the object //--\r
364 \r
365     local writequeue = { }    -- list for messages to send\r
366 \r
367     local eol\r
368 \r
369     local rstat, sstat = 0, 0\r
370 \r
371     --// local import of socket methods //--\r
372 \r
373     local send = socket.send\r
374     local receive = socket.receive\r
375     local close = socket.close\r
376     local shutdown = socket.shutdown\r
377 \r
378     --// public methods of the object //--\r
379 \r
380     local handler = { }\r
381 \r
382     handler.getstats = function( )\r
383         return rstat, sstat\r
384     end\r
385 \r
386     handler.listener = function( data, err )\r
387         return listener( handler, data, err )\r
388     end\r
389     handler.ssl = function( )\r
390         return false\r
391     end\r
392     handler.send = function( _, data, i, j )\r
393             return send( socket, data, i, j )\r
394     end\r
395     handler.receive = function( pattern, prefix )\r
396             return receive( socket, pattern, prefix )\r
397     end\r
398     handler.shutdown = function( pattern )\r
399         return shutdown( socket, pattern )\r
400     end\r
401     handler.close = function( closed )\r
402         _ = not closed and shutdown( socket )\r
403         _ = not closed and close( socket )\r
404         writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen\r
405         readlen = removesocket( readlist, socket, readlen )\r
406         socketlist[ socket ] = nil\r
407         out_put "server.lua: closed handler and removed socket from list"\r
408     end\r
409     handler.ip = function( )\r
410         return ip\r
411     end\r
412     handler.serverport = function( )\r
413         return serverport\r
414     end\r
415     handler.clientport = function( ) \r
416         return clientport\r
417     end\r
418     handler.write = function( data )\r
419         if not eol then\r
420             writelen = writelen + 1\r
421             writelist[ writelen ] = socket\r
422             eol = 0\r
423         end\r
424         eol = eol + 1\r
425         writequeue[ eol ] = data\r
426     end\r
427     handler.writequeue = function( )\r
428         return writequeue\r
429     end\r
430     handler.socket = function( )\r
431         return socket\r
432     end\r
433     handler.mode = function( )\r
434         return mode\r
435     end\r
436     handler.receivedata = function( )\r
437         local data, err, part = receive( socket, mode )    -- receive data in "mode"\r
438         if not err or ( err == "timeout" or err == "wantread" ) then    -- received something\r
439             local data = data or part or ""\r
440             local count = #data * STAT_UNIT\r
441             rstat = rstat + count\r
442             receivestat = receivestat + count\r
443             out_put( "server.lua: read data '", data, "', error: ", err )\r
444             return dispatch( handler, data, err )\r
445         else    -- connections was closed or fatal error\r
446             out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
447             handler.close( )\r
448             disconnect( handler, err )\r
449             writequeue = nil\r
450             handler = nil\r
451             return false\r
452         end\r
453     end\r
454     handler.dispatchdata = function( )    -- this function writes data to handlers\r
455         local buffer = table_concat( writequeue, "", 1, eol )\r
456         local succ, err, byte = send( socket, buffer )\r
457         local count = ( succ or 0 ) * STAT_UNIT\r
458         sstat = sstat + count\r
459         sendstat = sendstat + count\r
460         out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport )\r
461         if succ then    -- sending succesful\r
462             --writequeue = { }\r
463             eol = nil\r
464             writelen = removesocket( writelist, socket, writelen )    -- delete socket from writelist\r
465             return true\r
466         elseif byte and ( err == "timeout" or err == "wantwrite" ) then    -- want write\r
467             buffer = string_sub( buffer, byte + 1, -1 )    -- new buffer\r
468             writequeue[ 1 ] = buffer    -- insert new buffer in queue\r
469             eol = 1\r
470             return true\r
471         else    -- connection was closed during sending or fatal error\r
472             out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
473             handler.close( )\r
474             disconnect( handler, err )\r
475             writequeue = nil\r
476             handler = nil\r
477             return false\r
478         end\r
479     end\r
480 \r
481     -- // COMPAT // --\r
482 \r
483     handler.getIp = handler.ip\r
484     handler.getPort = handler.clientport\r
485 \r
486     socketlist[ socket ] = handler\r
487     readlen = readlen + 1\r
488     readlist[ readlen ] = socket\r
489 \r
490     return handler, socket\r
491 end\r
492 \r
493 addtimer = function( listener )\r
494     timelistener[ #timelistener + 1 ] = listener\r
495 end\r
496 \r
497 firetimer = function( listener )\r
498     for i, listener in ipairs( timelistener ) do\r
499         listener( )\r
500     end\r
501 end\r
502 \r
503 addserver = function( listeners, port, addr, mode, sslctx )    -- this function provides a way for other scripts to reg a server\r
504     local err\r
505     if type( listeners ) ~= "table" then\r
506         err = "invalid listener table"\r
507     else\r
508         for name, func in pairs( listeners ) do\r
509             if type( func ) ~= "function" then\r
510                 err = "invalid listener function"\r
511                 break\r
512             end\r
513         end\r
514     end\r
515     if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then\r
516         err = "invalid port"\r
517     elseif listener[ port ] then\r
518         err=  "listeners on port '" .. port .. "' already exist"\r
519     elseif sslctx and not luasec then\r
520         err = "luasec not found"\r
521     end\r
522     if err then\r
523         out_error( "server.lua: ", err )\r
524         return nil, err\r
525     end\r
526     addr = addr or "*"\r
527     local server, err = socket_bind( addr, port )\r
528     if err then\r
529         out_error( "server.lua: ", err )\r
530         return nil, err\r
531     end\r
532     local handler, err = wrapserver( listeners, server, addr, port, mode, sslctx )    -- wrap new server socket\r
533     if not handler then\r
534         server:close( )\r
535         return nil, err\r
536     end\r
537     server:settimeout( 0 )\r
538     readlen = readlen + 1\r
539     readlist[ readlen ] = server\r
540     listener[ port ] = listeners\r
541     socketlist[ server ] = handler\r
542     out_put( "server.lua: new server listener on ", addr, ":", port )\r
543     return true\r
544 end\r
545 \r
546 removesocket = function( tbl, socket, len )    -- this function removes sockets from a list\r
547     for i, target in ipairs( tbl ) do\r
548         if target == socket then\r
549             len = len - 1\r
550             table_remove( tbl, i )\r
551             return len\r
552         end\r
553     end\r
554     return len\r
555 end\r
556 \r
557 closeall = function( )\r
558     for sock, handler in pairs( socketlist ) do\r
559         handler.shutdown( )\r
560         handler.close( )\r
561         socketlist[ sock ] = nil\r
562     end\r
563     writelist, readlist, socketlist = { }, { }, { }\r
564 end\r
565 \r
566 closesocket = function( socket )\r
567     writelen = removesocket( writelist, socket, writelen )\r
568     readlen = removesocket( readlist, socket, readlen )\r
569     socketlist[ socket ] = nil\r
570     socket:close( )\r
571 end\r
572 \r
573 loop = function( )    -- this is the main loop of the program\r
574     --signal_set( "hub", "run" )\r
575     repeat\r
576         local read, write, err = socket_select( readlist, writelist, 1 )    -- 1 sec timeout, nice for timers\r
577         for i, socket in ipairs( write ) do    -- send data waiting in writequeues\r
578             local handler = socketlist[ socket ]\r
579             if handler then\r
580                 handler.dispatchdata( )\r
581             else\r
582                 closesocket( socket )\r
583                 out_put "server.lua: found no handler and closed socket (writelist)"    -- this should not happen\r
584             end\r
585         end\r
586         for i, socket in ipairs( read ) do    -- receive data\r
587             local handler = socketlist[ socket ]\r
588             if handler then\r
589                 handler.receivedata( )\r
590             else\r
591                 closesocket( socket )\r
592                 out_put "server.lua: found no handler and closed socket (readlist)"    -- this can happen\r
593             end\r
594         end\r
595         firetimer( )\r
596         --collectgarbage "collect"\r
597     until false --signal_get "hub" ~= "run"\r
598     return --signal_get "hub"\r
599 end\r
600 \r
601 ----------------------------------// BEGIN //--\r
602 \r
603 ----------------------------------// PUBLIC INTERFACE //--\r
604 \r
605 return {\r
606 \r
607     add = addserver,\r
608     loop = loop,\r
609     stats = stats,\r
610     closeall = closeall,\r
611     addtimer = addtimer,\r
612 \r
613 }\r