Convert spaces->tabs
[prosody.git] / net / server.lua
1 --[[\r
2 \r
3                 server.lua by blastbeat of the luadch project\r
4                 \r
5                 re-used here under the MIT/X Consortium License\r
6 \r
7                 - this script contains the server loop of the program\r
8                 - other scripts can reg a server here\r
9 \r
10 ]]--\r
11 \r
12 ----------------------------------// DECLARATION //--\r
13 \r
14 --// constants //--\r
15 \r
16 local STAT_UNIT = 1 / ( 1024 * 1024 )    -- mb\r
17 \r
18 --// lua functions //--\r
19 \r
20 local function use( what ) return _G[ what ] end\r
21 \r
22 local type = use "type"\r
23 local pairs = use "pairs"\r
24 local ipairs = use "ipairs"\r
25 local tostring = use "tostring"\r
26 local collectgarbage = use "collectgarbage"\r
27 \r
28 --// lua libs //--\r
29 \r
30 local table = use "table"\r
31 local coroutine = use "coroutine"\r
32 \r
33 --// lua lib methods //--\r
34 \r
35 local table_concat = table.concat\r
36 local table_remove = table.remove\r
37 local string_sub = use'string'.sub\r
38 local coroutine_wrap = coroutine.wrap\r
39 local coroutine_yield = coroutine.yield\r
40 local print = print;\r
41 local out_put = function () end --print;\r
42 local out_error = print;\r
43 \r
44 --// extern libs //--\r
45 \r
46 local luasec = require "ssl"\r
47 local luasocket = require "socket"\r
48 \r
49 --// extern lib methods //--\r
50 \r
51 local ssl_wrap = ( luasec and luasec.wrap )\r
52 local socket_bind = luasocket.bind\r
53 local socket_select = luasocket.select\r
54 local ssl_newcontext = ( luasec and luasec.newcontext )\r
55 \r
56 --// functions //--\r
57 \r
58 local loop\r
59 local stats\r
60 local addtimer\r
61 local closeall\r
62 local addserver\r
63 local firetimer\r
64 local closesocket\r
65 local removesocket\r
66 local wrapserver\r
67 local wraptcpclient\r
68 local wrapsslclient\r
69 \r
70 --// tables //--\r
71 \r
72 local listener\r
73 local readlist\r
74 local writelist\r
75 local socketlist\r
76 local timelistener\r
77 \r
78 --// simple data types //--\r
79 \r
80 local _\r
81 local readlen = 0    -- length of readlist\r
82 local writelen = 0    -- lenght of writelist\r
83 \r
84 local sendstat= 0\r
85 local receivestat = 0\r
86 \r
87 ----------------------------------// DEFINITION //--\r
88 \r
89 listener = { }    -- key = port, value = table\r
90 readlist = { }    -- array with sockets to read from\r
91 writelist = { }    -- arrary with sockets to write to\r
92 socketlist = { }    -- key = socket, value = wrapped socket\r
93 timelistener = { }\r
94 \r
95 stats = function( )\r
96         return receivestat, sendstat\r
97 end\r
98 \r
99 wrapserver = function( listener, socket, ip, serverport, mode, sslctx )    -- this function wraps a server\r
100 \r
101         local dispatch, disconnect = listener.listener, listener.disconnect    -- dangerous\r
102 \r
103         local wrapclient, err\r
104 \r
105         if sslctx then\r
106                 if not ssl_newcontext then\r
107                         return nil, "luasec not found"\r
108 --        elseif not cfg_get "use_ssl" then\r
109 --            return nil, "ssl is deactivated"\r
110                 end\r
111                 if type( sslctx ) ~= "table" then\r
112                         out_error "server.lua: wrong server sslctx"\r
113                         return nil, "wrong server sslctx"\r
114                 end\r
115                 sslctx, err = ssl_newcontext( sslctx )\r
116                 if not sslctx then\r
117                         err = err or "wrong sslctx parameters"\r
118                         out_error( "server.lua: ", err )\r
119                         return nil, err\r
120                 end\r
121                 wrapclient = wrapsslclient\r
122         else\r
123                 wrapclient = wraptcpclient\r
124         end\r
125 \r
126         local accept = socket.accept\r
127         local close = socket.close\r
128 \r
129         --// public methods of the object //--    \r
130 \r
131         local handler = { }\r
132 \r
133         handler.shutdown = function( ) end\r
134 \r
135         --[[handler.listener = function( data, err )\r
136                 return ondata( handler, data, err )\r
137         end]]\r
138         handler.ssl = function( )\r
139                 return sslctx and true or false\r
140         end\r
141         handler.close = function( closed )\r
142                 _ = not closed and close( socket )\r
143                 writelen = removesocket( writelist, socket, writelen )\r
144                 readlen = removesocket( readlist, socket, readlen )\r
145                 socketlist[ socket ] = nil\r
146                 handler = nil\r
147         end\r
148         handler.ip = function( )\r
149                 return ip\r
150         end\r
151         handler.serverport = function( )\r
152                 return serverport\r
153         end\r
154         handler.socket = function( )\r
155                 return socket\r
156         end\r
157         handler.receivedata = function( )\r
158                 local client, err = accept( socket )    -- try to accept\r
159                 if client then\r
160                         local ip, clientport = client:getpeername( )\r
161                         client:settimeout( 0 )\r
162                         local handler, client, err = wrapclient( listener, client, ip, serverport, clientport, mode, sslctx )    -- wrap new client socket\r
163                         if err then    -- error while wrapping ssl socket\r
164                                 return false\r
165                         end\r
166                         out_put( "server.lua: accepted new client connection from ", ip, ":", clientport )\r
167                         return dispatch( handler )\r
168                 elseif err then    -- maybe timeout or something else\r
169                         out_put( "server.lua: error with new client connection: ", err )\r
170                         return false\r
171                 end\r
172         end\r
173         return handler\r
174 end\r
175 \r
176 wrapsslclient = function( listener, socket, ip, serverport, clientport, mode, sslctx )    -- this function wraps a ssl cleint\r
177 \r
178         local dispatch, disconnect = listener.listener, listener.disconnect\r
179 \r
180         --// transform socket to ssl object //--\r
181 \r
182         local err\r
183         socket, err = ssl_wrap( socket, sslctx )    -- wrap socket\r
184         if err then\r
185                 out_put( "server.lua: ssl error: ", err )\r
186                 return nil, nil, err    -- fatal error\r
187         end\r
188         socket:settimeout( 0 )\r
189 \r
190         --// private closures of the object //--\r
191 \r
192         local writequeue = { }    -- buffer for messages to send\r
193 \r
194         local eol   -- end of buffer\r
195 \r
196         local sstat, rstat = 0, 0\r
197 \r
198         --// local import of socket methods //--\r
199 \r
200         local send = socket.send\r
201         local receive = socket.receive\r
202         local close = socket.close\r
203         --local shutdown = socket.shutdown\r
204 \r
205         --// public methods of the object //--\r
206 \r
207         local handler = { }\r
208 \r
209         handler.getstats = function( )\r
210                 return rstat, sstat\r
211         end\r
212 \r
213         handler.listener = function( data, err )\r
214                 return listener( handler, data, err )\r
215         end\r
216         handler.ssl = function( )\r
217                 return true\r
218         end\r
219         handler.send = function( _, data, i, j )\r
220                         return send( socket, data, i, j )\r
221         end\r
222         handler.receive = function( pattern, prefix )\r
223                         return receive( socket, pattern, prefix )\r
224         end\r
225         handler.shutdown = function( pattern )\r
226                 --return shutdown( socket, pattern )\r
227         end\r
228         handler.close = function( closed )\r
229                 close( socket )\r
230                 writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen\r
231                 readlen = removesocket( readlist, socket, readlen )\r
232                 socketlist[ socket ] = nil\r
233                 out_put "server.lua: closed handler and removed socket from list"\r
234         end\r
235         handler.ip = function( )\r
236                 return ip\r
237         end\r
238         handler.serverport = function( )\r
239                 return serverport\r
240         end\r
241         handler.clientport = function( ) \r
242                 return clientport\r
243         end\r
244 \r
245         handler.write = function( data )\r
246                 if not eol then\r
247                         writelen = writelen + 1\r
248                         writelist[ writelen ] = socket\r
249                         eol = 0\r
250                 end\r
251                 eol = eol + 1\r
252                 writequeue[ eol ] = data\r
253         end\r
254         handler.writequeue = function( )\r
255                 return writequeue\r
256         end\r
257         handler.socket = function( )\r
258                 return socket\r
259         end\r
260         handler.mode = function( )\r
261                 return mode\r
262         end\r
263         handler._receivedata = function( )\r
264                 local data, err, part = receive( socket, mode )    -- receive data in "mode"\r
265                 if not err or ( err == "timeout" or err == "wantread" ) then    -- received something\r
266                         local data = data or part or ""\r
267                         local count = #data * STAT_UNIT\r
268                         rstat = rstat + count\r
269                         receivestat = receivestat + count\r
270                         out_put( "server.lua: read data '", data, "', error: ", err )\r
271                         return dispatch( handler, data, err )\r
272                 else    -- connections was closed or fatal error\r
273                         out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
274                         handler.close( )\r
275                         disconnect( handler, err )\r
276                         writequeue = nil\r
277                         handler = nil\r
278                         return false\r
279                 end\r
280         end\r
281         handler._dispatchdata = function( )    -- this function writes data to handlers\r
282                 local buffer = table_concat( writequeue, "", 1, eol )\r
283                 local succ, err, byte = send( socket, buffer )\r
284                 local count = ( succ or 0 ) * STAT_UNIT\r
285                 sstat = sstat + count\r
286                 sendstat = sendstat + count\r
287                 out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport )\r
288                 if succ then    -- sending succesful\r
289                         --writequeue = { }\r
290                         eol = nil\r
291                         writelen = removesocket( writelist, socket, writelen )    -- delete socket from writelist\r
292                         return true\r
293                 elseif byte and ( err == "timeout" or err == "wantwrite" ) then    -- want write\r
294                         buffer = string_sub( buffer, byte + 1, -1 )    -- new buffer\r
295                         writequeue[ 1 ] = buffer    -- insert new buffer in queue\r
296                         eol = 1\r
297                         return true\r
298                 else    -- connection was closed during sending or fatal error\r
299                         out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
300                         handler.close( )\r
301                         disconnect( handler, err )\r
302                         writequeue = nil\r
303                         handler = nil\r
304                         return false\r
305                 end\r
306         end\r
307 \r
308         -- // COMPAT // --\r
309 \r
310         handler.getIp = handler.ip\r
311         handler.getPort = handler.clientport\r
312 \r
313         --// handshake //--\r
314 \r
315         local wrote\r
316 \r
317         handler.handshake = coroutine_wrap( function( client )\r
318                         local err\r
319                         for i = 1, 10 do    -- 10 handshake attemps\r
320                                 _, err = client:dohandshake( )\r
321                                 if not err then\r
322                                         out_put( "server.lua: ssl handshake done" )\r
323                                         writelen = ( wrote and removesocket( writelist, socket, writelen ) ) or writelen\r
324                                         handler.receivedata = handler._receivedata    -- when handshake is done, replace the handshake function with regular functions\r
325                                         handler.dispatchdata = handler._dispatchdata\r
326                                         return dispatch( handler )\r
327                                 else\r
328                                         out_put( "server.lua: error during ssl handshake: ", err )\r
329                                         if err == "wantwrite" then\r
330                                                 if wrote == nil then\r
331                                                         writelen = writelen + 1\r
332                                                         writelist[ writelen ] = client\r
333                                                         wrote = true\r
334                                                 end\r
335                                         end\r
336                                         coroutine_yield( handler, nil, err )    -- handshake not finished\r
337                                 end\r
338                         end\r
339                         _ = err ~= "closed" and close( socket )\r
340                         handler.close( )\r
341                         disconnect( handler, err )\r
342                         writequeue = nil\r
343                         handler = nil\r
344                         return false    -- handshake failed\r
345                 end\r
346         )\r
347         handler.receivedata = handler.handshake\r
348         handler.dispatchdata = handler.handshake\r
349 \r
350         handler.handshake( socket )    -- do handshake\r
351 \r
352         socketlist[ socket ] = handler\r
353         readlen = readlen + 1\r
354         readlist[ readlen ] = socket\r
355 \r
356         return handler, socket\r
357 end\r
358 \r
359 wraptcpclient = function( listener, socket, ip, serverport, clientport, mode )    -- this function wraps a socket\r
360 \r
361         local dispatch, disconnect = listener.listener, listener.disconnect\r
362 \r
363         --// private closures of the object //--\r
364 \r
365         local writequeue = { }    -- list for messages to send\r
366 \r
367         local eol\r
368 \r
369         local rstat, sstat = 0, 0\r
370 \r
371         --// local import of socket methods //--\r
372 \r
373         local send = socket.send\r
374         local receive = socket.receive\r
375         local close = socket.close\r
376         local shutdown = socket.shutdown\r
377 \r
378         --// public methods of the object //--\r
379 \r
380         local handler = { }\r
381 \r
382         handler.getstats = function( )\r
383                 return rstat, sstat\r
384         end\r
385 \r
386         handler.listener = function( data, err )\r
387                 return listener( handler, data, err )\r
388         end\r
389         handler.ssl = function( )\r
390                 return false\r
391         end\r
392         handler.send = function( _, data, i, j )\r
393                         return send( socket, data, i, j )\r
394         end\r
395         handler.receive = function( pattern, prefix )\r
396                         return receive( socket, pattern, prefix )\r
397         end\r
398         handler.shutdown = function( pattern )\r
399                 return shutdown( socket, pattern )\r
400         end\r
401         handler.close = function( closed )\r
402                 _ = not closed and shutdown( socket )\r
403                 _ = not closed and close( socket )\r
404                 writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen\r
405                 readlen = removesocket( readlist, socket, readlen )\r
406                 socketlist[ socket ] = nil\r
407                 out_put "server.lua: closed handler and removed socket from list"\r
408         end\r
409         handler.ip = function( )\r
410                 return ip\r
411         end\r
412         handler.serverport = function( )\r
413                 return serverport\r
414         end\r
415         handler.clientport = function( ) \r
416                 return clientport\r
417         end\r
418         handler.write = function( data )\r
419                 if not eol then\r
420                         writelen = writelen + 1\r
421                         writelist[ writelen ] = socket\r
422                         eol = 0\r
423                 end\r
424                 eol = eol + 1\r
425                 writequeue[ eol ] = data\r
426         end\r
427         handler.writequeue = function( )\r
428                 return writequeue\r
429         end\r
430         handler.socket = function( )\r
431                 return socket\r
432         end\r
433         handler.mode = function( )\r
434                 return mode\r
435         end\r
436         handler.receivedata = function( )\r
437                 local data, err, part = receive( socket, mode )    -- receive data in "mode"\r
438                 if not err or ( err == "timeout" or err == "wantread" ) then    -- received something\r
439                         local data = data or part or ""\r
440                         local count = #data * STAT_UNIT\r
441                         rstat = rstat + count\r
442                         receivestat = receivestat + count\r
443                         out_put( "server.lua: read data '", data, "', error: ", err )\r
444                         return dispatch( handler, data, err )\r
445                 else    -- connections was closed or fatal error\r
446                         out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
447                         handler.close( )\r
448                         disconnect( handler, err )\r
449                         writequeue = nil\r
450                         handler = nil\r
451                         return false\r
452                 end\r
453         end\r
454         handler.dispatchdata = function( )    -- this function writes data to handlers\r
455                 local buffer = table_concat( writequeue, "", 1, eol )\r
456                 local succ, err, byte = send( socket, buffer )\r
457                 local count = ( succ or 0 ) * STAT_UNIT\r
458                 sstat = sstat + count\r
459                 sendstat = sendstat + count\r
460                 out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport )\r
461                 if succ then    -- sending succesful\r
462                         --writequeue = { }\r
463                         eol = nil\r
464                         writelen = removesocket( writelist, socket, writelen )    -- delete socket from writelist\r
465                         return true\r
466                 elseif byte and ( err == "timeout" or err == "wantwrite" ) then    -- want write\r
467                         buffer = string_sub( buffer, byte + 1, -1 )    -- new buffer\r
468                         writequeue[ 1 ] = buffer    -- insert new buffer in queue\r
469                         eol = 1\r
470                         return true\r
471                 else    -- connection was closed during sending or fatal error\r
472                         out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
473                         handler.close( )\r
474                         disconnect( handler, err )\r
475                         writequeue = nil\r
476                         handler = nil\r
477                         return false\r
478                 end\r
479         end\r
480 \r
481         -- // COMPAT // --\r
482 \r
483         handler.getIp = handler.ip\r
484         handler.getPort = handler.clientport\r
485 \r
486         socketlist[ socket ] = handler\r
487         readlen = readlen + 1\r
488         readlist[ readlen ] = socket\r
489 \r
490         return handler, socket\r
491 end\r
492 \r
493 addtimer = function( listener )\r
494         timelistener[ #timelistener + 1 ] = listener\r
495 end\r
496 \r
497 firetimer = function( listener )\r
498         for i, listener in ipairs( timelistener ) do\r
499                 listener( )\r
500         end\r
501 end\r
502 \r
503 addserver = function( listeners, port, addr, mode, sslctx )    -- this function provides a way for other scripts to reg a server\r
504         local err\r
505         if type( listeners ) ~= "table" then\r
506                 err = "invalid listener table"\r
507         else\r
508                 for name, func in pairs( listeners ) do\r
509                         if type( func ) ~= "function" then\r
510                                 err = "invalid listener function"\r
511                                 break\r
512                         end\r
513                 end\r
514         end\r
515         if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then\r
516                 err = "invalid port"\r
517         elseif listener[ port ] then\r
518                 err=  "listeners on port '" .. port .. "' already exist"\r
519         elseif sslctx and not luasec then\r
520                 err = "luasec not found"\r
521         end\r
522         if err then\r
523                 out_error( "server.lua: ", err )\r
524                 return nil, err\r
525         end\r
526         addr = addr or "*"\r
527         local server, err = socket_bind( addr, port )\r
528         if err then\r
529                 out_error( "server.lua: ", err )\r
530                 return nil, err\r
531         end\r
532         local handler, err = wrapserver( listeners, server, addr, port, mode, sslctx )    -- wrap new server socket\r
533         if not handler then\r
534                 server:close( )\r
535                 return nil, err\r
536         end\r
537         server:settimeout( 0 )\r
538         readlen = readlen + 1\r
539         readlist[ readlen ] = server\r
540         listener[ port ] = listeners\r
541         socketlist[ server ] = handler\r
542         out_put( "server.lua: new server listener on ", addr, ":", port )\r
543         return true\r
544 end\r
545 \r
546 removesocket = function( tbl, socket, len )    -- this function removes sockets from a list\r
547         for i, target in ipairs( tbl ) do\r
548                 if target == socket then\r
549                         len = len - 1\r
550                         table_remove( tbl, i )\r
551                         return len\r
552                 end\r
553         end\r
554         return len\r
555 end\r
556 \r
557 closeall = function( )\r
558         for sock, handler in pairs( socketlist ) do\r
559                 handler.shutdown( )\r
560                 handler.close( )\r
561                 socketlist[ sock ] = nil\r
562         end\r
563         writelist, readlist, socketlist = { }, { }, { }\r
564 end\r
565 \r
566 closesocket = function( socket )\r
567         writelen = removesocket( writelist, socket, writelen )\r
568         readlen = removesocket( readlist, socket, readlen )\r
569         socketlist[ socket ] = nil\r
570         socket:close( )\r
571 end\r
572 \r
573 loop = function( )    -- this is the main loop of the program\r
574         --signal_set( "hub", "run" )\r
575         repeat\r
576                 local read, write, err = socket_select( readlist, writelist, 1 )    -- 1 sec timeout, nice for timers\r
577                 for i, socket in ipairs( write ) do    -- send data waiting in writequeues\r
578                         local handler = socketlist[ socket ]\r
579                         if handler then\r
580                                 handler.dispatchdata( )\r
581                         else\r
582                                 closesocket( socket )\r
583                                 out_put "server.lua: found no handler and closed socket (writelist)"    -- this should not happen\r
584                         end\r
585                 end\r
586                 for i, socket in ipairs( read ) do    -- receive data\r
587                         local handler = socketlist[ socket ]\r
588                         if handler then\r
589                                 handler.receivedata( )\r
590                         else\r
591                                 closesocket( socket )\r
592                                 out_put "server.lua: found no handler and closed socket (readlist)"    -- this can happen\r
593                         end\r
594                 end\r
595                 firetimer( )\r
596                 --collectgarbage "collect"\r
597         until false --signal_get "hub" ~= "run"\r
598         return --signal_get "hub"\r
599 end\r
600 \r
601 ----------------------------------// BEGIN //--\r
602 \r
603 ----------------------------------// PUBLIC INTERFACE //--\r
604 \r
605 return {\r
606 \r
607         add = addserver,\r
608         loop = loop,\r
609         stats = stats,\r
610         closeall = closeall,\r
611         addtimer = addtimer,\r
612 \r
613 }\r