net.server: Much improve SSL/TLS error reporting, do our best to understand and hide...
[prosody.git] / net / server.lua
1 -- \r
2 -- server.lua by blastbeat of the luadch project\r
3 -- Re-used here under the MIT/X Consortium License\r
4 -- \r
5 -- Modifications (C) 2008-2009 Matthew Wild, Waqas Hussain\r
6 --\r
7 \r
8 -- // wrapping luadch stuff // --\r
9 \r
10 local use = function( what )\r
11     return _G[ what ]\r
12 end\r
13 local clean = function( tbl )\r
14     for i, k in pairs( tbl ) do\r
15         tbl[ i ] = nil\r
16     end\r
17 end\r
18 \r
19 local log, table_concat = require ("util.logger").init("socket"), table.concat;\r
20 local out_put = function (...) return log("debug", table_concat{...}); end\r
21 local out_error = function (...) return log("warn", table_concat{...}); end\r
22 local mem_free = collectgarbage\r
23 \r
24 ----------------------------------// DECLARATION //--\r
25 \r
26 --// constants //--\r
27 \r
28 local STAT_UNIT = 1    -- byte\r
29 \r
30 --// lua functions //--\r
31 \r
32 local type = use "type"\r
33 local pairs = use "pairs"\r
34 local ipairs = use "ipairs"\r
35 local tostring = use "tostring"\r
36 local collectgarbage = use "collectgarbage"\r
37 \r
38 --// lua libs //--\r
39 \r
40 local os = use "os"\r
41 local table = use "table"\r
42 local string = use "string"\r
43 local coroutine = use "coroutine"\r
44 \r
45 --// lua lib methods //--\r
46 \r
47 local os_time = os.time\r
48 local os_difftime = os.difftime\r
49 local table_concat = table.concat\r
50 local table_remove = table.remove\r
51 local string_len = string.len\r
52 local string_sub = string.sub\r
53 local coroutine_wrap = coroutine.wrap\r
54 local coroutine_yield = coroutine.yield\r
55 \r
56 --// extern libs //--\r
57 \r
58 local luasec = select( 2, pcall( require, "ssl" ) )\r
59 local luasocket = require "socket"\r
60 \r
61 --// extern lib methods //--\r
62 \r
63 local ssl_wrap = ( luasec and luasec.wrap )\r
64 local socket_bind = luasocket.bind\r
65 local socket_sleep = luasocket.sleep\r
66 local socket_select = luasocket.select\r
67 local ssl_newcontext = ( luasec and luasec.newcontext )\r
68 \r
69 --// functions //--\r
70 \r
71 local id\r
72 local loop\r
73 local stats\r
74 local idfalse\r
75 local addtimer\r
76 local closeall\r
77 local addserver\r
78 local getserver\r
79 local wrapserver\r
80 local getsettings\r
81 local closesocket\r
82 local removesocket\r
83 local removeserver\r
84 local changetimeout\r
85 local wrapconnection\r
86 local changesettings\r
87 \r
88 --// tables //--\r
89 \r
90 local _server\r
91 local _readlist\r
92 local _timerlist\r
93 local _sendlist\r
94 local _socketlist\r
95 local _closelist\r
96 local _readtimes\r
97 local _writetimes\r
98 \r
99 --// simple data types //--\r
100 \r
101 local _\r
102 local _readlistlen\r
103 local _sendlistlen\r
104 local _timerlistlen\r
105 \r
106 local _sendtraffic\r
107 local _readtraffic\r
108 \r
109 local _selecttimeout\r
110 local _sleeptime\r
111 \r
112 local _starttime\r
113 local _currenttime\r
114 \r
115 local _maxsendlen\r
116 local _maxreadlen\r
117 \r
118 local _checkinterval\r
119 local _sendtimeout\r
120 local _readtimeout\r
121 \r
122 local _cleanqueue\r
123 \r
124 local _timer\r
125 \r
126 local _maxclientsperserver\r
127 \r
128 ----------------------------------// DEFINITION //--\r
129 \r
130 _server = { }    -- key = port, value = table; list of listening servers\r
131 _readlist = { }    -- array with sockets to read from\r
132 _sendlist = { }    -- arrary with sockets to write to\r
133 _timerlist = { }    -- array of timer functions\r
134 _socketlist = { }    -- key = socket, value = wrapped socket (handlers)\r
135 _readtimes = { }   -- key = handler, value = timestamp of last data reading\r
136 _writetimes = { }   -- key = handler, value = timestamp of last data writing/sending\r
137 _closelist = { }    -- handlers to close\r
138 \r
139 _readlistlen = 0    -- length of readlist\r
140 _sendlistlen = 0    -- length of sendlist\r
141 _timerlistlen = 0    -- lenght of timerlist\r
142 \r
143 _sendtraffic = 0    -- some stats\r
144 _readtraffic = 0\r
145 \r
146 _selecttimeout = 1    -- timeout of socket.select\r
147 _sleeptime = 0    -- time to wait at the end of every loop\r
148 \r
149 _maxsendlen = 51000 * 1024    -- max len of send buffer\r
150 _maxreadlen = 25000 * 1024    -- max len of read buffer\r
151 \r
152 _checkinterval = 1200000    -- interval in secs to check idle clients\r
153 _sendtimeout = 60000    -- allowed send idle time in secs\r
154 _readtimeout = 6 * 60 * 60    -- allowed read idle time in secs\r
155 \r
156 _cleanqueue = false    -- clean bufferqueue after using\r
157 \r
158 _maxclientsperserver = 1000\r
159 \r
160 ----------------------------------// PRIVATE //--\r
161 \r
162 wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections, startssl )    -- this function wraps a server\r
163 \r
164     maxconnections = maxconnections or _maxclientsperserver\r
165 \r
166     local connections = 0\r
167 \r
168     local dispatch, disconnect = listeners.incoming or listeners.listener, listeners.disconnect\r
169 \r
170     local err\r
171 \r
172     local ssl = false\r
173 \r
174     if sslctx then\r
175         ssl = true\r
176         if not ssl_newcontext then\r
177             out_error "luasec not found"\r
178             ssl = false\r
179         end\r
180         if type( sslctx ) ~= "table" then\r
181             out_error "server.lua: wrong server sslctx"\r
182             ssl = false\r
183         end\r
184         local ctx;\r
185         ctx, err = ssl_newcontext( sslctx )\r
186         if not ctx then\r
187             err = err or "wrong sslctx parameters"\r
188             local file;\r
189             file = err:match("^error loading (.-) %(");\r
190             if file then\r
191                 if file == "private key" then\r
192                         file = sslctx.key or "your private key";\r
193                 elseif file == "certificate" then\r
194                         file = sslctx.certificate or "your certificate file";\r
195                 end\r
196                 local reason = err:match("%((.+)%)$") or "some reason";\r
197                 if reason == "Permission denied" then\r
198                         reason = "Check that the permissions allow Prosody to read this file.";\r
199                 elseif reason == "No such file or directory" then\r
200                         reason = "Check that the path is correct, and the file exists.";\r
201                 elseif reason == "system lib" then\r
202                         reason = "Previous error (see logs), or other system error.";\r
203                 else\r
204                         reason = "Reason: "..tostring(reason or "unknown"):lower();\r
205                 end\r
206                 log("error", "SSL/TLS: Failed to load %s: %s", file, reason);\r
207             else\r
208                 log("error", "SSL/TLS: Error initialising for port %d: %s", serverport, err );\r
209             end\r
210             ssl = false\r
211         end\r
212         sslctx = ctx;\r
213     end\r
214     if not ssl then\r
215       sslctx = false;\r
216       if startssl then\r
217          log("error", "Failed to listen on port %d due to SSL/TLS to SSL/TLS initialisation errors (see logs)", serverport )\r
218          return nil, "Cannot start ssl,  see log for details"\r
219        end\r
220     end\r
221 \r
222     local accept = socket.accept\r
223 \r
224     --// public methods of the object //--\r
225 \r
226     local handler = { }\r
227 \r
228     handler.shutdown = function( ) end\r
229 \r
230     handler.ssl = function( )\r
231         return ssl\r
232     end\r
233     handler.remove = function( )\r
234         connections = connections - 1\r
235     end\r
236     handler.close = function( )\r
237         for _, handler in pairs( _socketlist ) do\r
238             if handler.serverport == serverport then\r
239                 handler.disconnect( handler, "server closed" )\r
240                 handler.close( true )\r
241             end\r
242         end\r
243         socket:close( )\r
244         _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )\r
245         _readlistlen = removesocket( _readlist, socket, _readlistlen )\r
246         _socketlist[ socket ] = nil\r
247         handler = nil\r
248         socket = nil\r
249         mem_free( )\r
250         out_put "server.lua: closed server handler and removed sockets from list"\r
251     end\r
252     handler.ip = function( )\r
253         return ip\r
254     end\r
255     handler.serverport = function( )\r
256         return serverport\r
257     end\r
258     handler.socket = function( )\r
259         return socket\r
260     end\r
261     handler.readbuffer = function( )\r
262         if connections > maxconnections then\r
263             out_put( "server.lua: refused new client connection: server full" )\r
264             return false\r
265         end\r
266         local client, err = accept( socket )    -- try to accept\r
267         if client then\r
268             local ip, clientport = client:getpeername( )\r
269             client:settimeout( 0 )\r
270             local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, startssl )    -- wrap new client socket\r
271             if err then    -- error while wrapping ssl socket\r
272                 return false\r
273             end\r
274             connections = connections + 1\r
275             out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))\r
276             return dispatch( handler )\r
277         elseif err then    -- maybe timeout or something else\r
278             out_put( "server.lua: error with new client connection: ", tostring(err) )\r
279             return false\r
280         end\r
281     end\r
282     return handler\r
283 end\r
284 \r
285 wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, startssl )    -- this function wraps a client to a handler object\r
286 \r
287     socket:settimeout( 0 )\r
288 \r
289     --// local import of socket methods //--\r
290 \r
291     local send\r
292     local receive\r
293     local shutdown\r
294 \r
295     --// private closures of the object //--\r
296 \r
297     local ssl\r
298 \r
299     local dispatch = listeners.incoming or listeners.listener\r
300     local disconnect = listeners.disconnect\r
301 \r
302     local bufferqueue = { }    -- buffer array\r
303     local bufferqueuelen = 0    -- end of buffer array\r
304 \r
305     local toclose\r
306     local fatalerror\r
307     local needtls\r
308 \r
309     local bufferlen = 0\r
310 \r
311     local noread = false\r
312     local nosend = false\r
313 \r
314     local sendtraffic, readtraffic = 0, 0\r
315 \r
316     local maxsendlen = _maxsendlen\r
317     local maxreadlen = _maxreadlen\r
318 \r
319     --// public methods of the object //--\r
320 \r
321     local handler = bufferqueue    -- saves a table ^_^\r
322 \r
323     handler.dispatch = function( )\r
324         return dispatch\r
325     end\r
326     handler.disconnect = function( )\r
327         return disconnect\r
328     end\r
329     handler.setlistener = function( listeners )\r
330         dispatch = listeners.incoming\r
331         disconnect = listeners.disconnect\r
332     end\r
333     handler.getstats = function( )\r
334         return readtraffic, sendtraffic\r
335     end\r
336     handler.ssl = function( )\r
337         return ssl\r
338     end\r
339     handler.send = function( _, data, i, j )\r
340         return send( socket, data, i, j )\r
341     end\r
342     handler.receive = function( pattern, prefix )\r
343         return receive( socket, pattern, prefix )\r
344     end\r
345     handler.shutdown = function( pattern )\r
346         return shutdown( socket, pattern )\r
347     end\r
348     handler.close = function( forced )\r
349         if not handler then return true; end\r
350         _readlistlen = removesocket( _readlist, socket, _readlistlen )\r
351         _readtimes[ handler ] = nil\r
352         if bufferqueuelen ~= 0 then\r
353             if not ( forced or fatalerror ) then\r
354                 handler.sendbuffer( )\r
355                 if bufferqueuelen ~= 0 then   -- try again...\r
356                     if handler then\r
357                         handler.write = nil    -- ... but no further writing allowed\r
358                     end\r
359                     toclose = true\r
360                     return false\r
361                 end\r
362             else\r
363                 send( socket, table_concat( bufferqueue, "", 1, bufferqueuelen ), 1, bufferlen )    -- forced send\r
364             end\r
365         end\r
366         _ = shutdown and shutdown( socket )\r
367         socket:close( )\r
368         _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )\r
369         _socketlist[ socket ] = nil\r
370         if handler then\r
371             _writetimes[ handler ] = nil\r
372             _closelist[ handler ] = nil\r
373             handler = nil\r
374         end\r
375         socket = nil\r
376         mem_free( )\r
377         if server then\r
378                 server.remove( )\r
379         end\r
380         out_put "server.lua: closed client handler and removed socket from list"\r
381         return true\r
382     end\r
383     handler.ip = function( )\r
384         return ip\r
385     end\r
386     handler.serverport = function( )\r
387         return serverport\r
388     end\r
389     handler.clientport = function( )\r
390         return clientport\r
391     end\r
392     local write = function( data )\r
393         bufferlen = bufferlen + string_len( data )\r
394         if bufferlen > maxsendlen then\r
395             _closelist[ handler ] = "send buffer exceeded"   -- cannot close the client at the moment, have to wait to the end of the cycle\r
396             handler.write = idfalse    -- dont write anymore\r
397             return false\r
398         elseif socket and not _sendlist[ socket ] then\r
399             _sendlistlen = _sendlistlen + 1\r
400             _sendlist[ _sendlistlen ] = socket\r
401             _sendlist[ socket ] = _sendlistlen\r
402         end\r
403         bufferqueuelen = bufferqueuelen + 1\r
404         bufferqueue[ bufferqueuelen ] = data\r
405         if handler then\r
406                 _writetimes[ handler ] = _writetimes[ handler ] or _currenttime\r
407         end\r
408         return true\r
409     end\r
410     handler.write = write\r
411     handler.bufferqueue = function( )\r
412         return bufferqueue\r
413     end\r
414     handler.socket = function( )\r
415         return socket\r
416     end\r
417     handler.pattern = function( new )\r
418         pattern = new or pattern\r
419         return pattern\r
420     end\r
421     handler.setsend = function ( newsend )\r
422         send = newsend or send\r
423         return send\r
424     end\r
425     handler.bufferlen = function( readlen, sendlen )\r
426         maxsendlen = sendlen or maxsendlen\r
427         maxreadlen = readlen or maxreadlen\r
428         return maxreadlen, maxsendlen\r
429     end\r
430     handler.lock = function( switch )\r
431         if switch == true then\r
432             handler.write = idfalse\r
433             local tmp = _sendlistlen\r
434             _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )\r
435             _writetimes[ handler ] = nil\r
436             if _sendlistlen ~= tmp then\r
437                 nosend = true\r
438             end\r
439             tmp = _readlistlen\r
440             _readlistlen = removesocket( _readlist, socket, _readlistlen )\r
441             _readtimes[ handler ] = nil\r
442             if _readlistlen ~= tmp then\r
443                 noread = true\r
444             end\r
445         elseif switch == false then\r
446             handler.write = write\r
447             if noread then\r
448                 noread = false\r
449                 _readlistlen = _readlistlen + 1\r
450                 _readlist[ socket ] = _readlistlen\r
451                 _readlist[ _readlistlen ] = socket\r
452                 _readtimes[ handler ] = _currenttime\r
453             end\r
454             if nosend then\r
455                 nosend = false\r
456                 write( "" )\r
457             end\r
458         end\r
459         return noread, nosend\r
460     end\r
461     local _readbuffer = function( )    -- this function reads data\r
462         local buffer, err, part = receive( socket, pattern )    -- receive buffer with "pattern"\r
463         if not err or ( err == "timeout" or err == "wantread" ) then    -- received something\r
464             local buffer = buffer or part or ""\r
465             local len = string_len( buffer )\r
466             if len > maxreadlen then\r
467                 disconnect( handler, "receive buffer exceeded" )\r
468                 handler.close( true )\r
469                 return false\r
470             end\r
471             local count = len * STAT_UNIT\r
472             readtraffic = readtraffic + count\r
473             _readtraffic = _readtraffic + count\r
474             _readtimes[ handler ] = _currenttime\r
475             --out_put( "server.lua: read data '", buffer, "', error: ", err )\r
476             return dispatch( handler, buffer, err )\r
477         else    -- connections was closed or fatal error\r
478             out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " error: ", tostring(err) )\r
479             fatalerror = true\r
480             disconnect( handler, err )\r
481             _ = handler and handler.close( )\r
482             return false\r
483         end\r
484     end\r
485     local _sendbuffer = function( )    -- this function sends data\r
486         local buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )\r
487         local succ, err, byte = send( socket, buffer, 1, bufferlen )\r
488         local count = ( succ or byte or 0 ) * STAT_UNIT\r
489         sendtraffic = sendtraffic + count\r
490         _sendtraffic = _sendtraffic + count\r
491         _ = _cleanqueue and clean( bufferqueue )\r
492         --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )\r
493         if succ then    -- sending succesful\r
494             bufferqueuelen = 0\r
495             bufferlen = 0\r
496             _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )    -- delete socket from writelist\r
497             _ = needtls and handler.starttls(true)\r
498             _writetimes[ handler ] = nil\r
499             _ = toclose and handler.close( )\r
500             return true\r
501         elseif byte and ( err == "timeout" or err == "wantwrite" ) then    -- want write\r
502             buffer = string_sub( buffer, byte + 1, bufferlen )    -- new buffer\r
503             bufferqueue[ 1 ] = buffer    -- insert new buffer in queue\r
504             bufferqueuelen = 1\r
505             bufferlen = bufferlen - byte\r
506             _writetimes[ handler ] = _currenttime\r
507             return true\r
508         else    -- connection was closed during sending or fatal error\r
509             out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " error: ", tostring(err) )\r
510             fatalerror = true\r
511             disconnect( handler, err )\r
512             _ = handler and handler.close( )\r
513             return false\r
514         end\r
515     end\r
516 \r
517     if sslctx then    -- ssl?\r
518         ssl = true\r
519         local wrote\r
520         local read\r
521         local handshake = coroutine_wrap( function( client )    -- create handshake coroutine\r
522                 local err\r
523                 for i = 1, 10 do    -- 10 handshake attemps\r
524                     _sendlistlen = ( wrote and removesocket( _sendlist, socket, _sendlistlen ) ) or _sendlistlen\r
525                     _readlistlen = ( read and removesocket( _readlist, socket, _readlistlen ) ) or _readlistlen\r
526                     read, wrote = nil, nil\r
527                     _, err = client:dohandshake( )\r
528                     if not err then\r
529                         out_put( "server.lua: ssl handshake done" )\r
530                         handler.readbuffer = _readbuffer    -- when handshake is done, replace the handshake function with regular functions\r
531                         handler.sendbuffer = _sendbuffer\r
532                         -- return dispatch( handler )\r
533                         return true\r
534                     else\r
535                         out_put( "server.lua: error during ssl handshake: ", tostring(err) )\r
536                         if err == "wantwrite" and not wrote then\r
537                             _sendlistlen = _sendlistlen + 1\r
538                             _sendlist[ _sendlistlen ] = client\r
539                             wrote = true\r
540                         elseif err == "wantread" and not read then\r
541                                 _readlistlen = _readlistlen + 1\r
542                                 _readlist [ _readlistlen ] = client\r
543                                 read = true\r
544                         else\r
545                                 break;\r
546                         end\r
547                         --coroutine_yield( handler, nil, err )    -- handshake not finished\r
548                         coroutine_yield( )\r
549                     end\r
550                 end\r
551                 disconnect( handler, "ssl handshake failed" )\r
552                 _ = handler and handler.close( true )    -- forced disconnect\r
553                 return false    -- handshake failed\r
554             end\r
555         )\r
556         if startssl then    -- ssl now?\r
557             --out_put("server.lua: ", "starting ssl handshake")\r
558             local err\r
559             socket, err = ssl_wrap( socket, sslctx )    -- wrap socket\r
560             if err then\r
561                 out_put( "server.lua: ssl error: ", tostring(err) )\r
562                 mem_free( )\r
563                 return nil, nil, err    -- fatal error\r
564             end\r
565             socket:settimeout( 0 )\r
566             handler.readbuffer = handshake\r
567             handler.sendbuffer = handshake\r
568             handshake( socket ) -- do handshake\r
569             if not socket then\r
570                 return nil, nil, "ssl handshake failed";\r
571             end\r
572         else\r
573             -- We're not automatically doing SSL, so we're not secure (yet)\r
574             ssl = false\r
575             handler.starttls = function( now )\r
576                 if not now then\r
577                     --out_put "server.lua: we need to do tls, but delaying until later"\r
578                     needtls = true\r
579                     return\r
580                 end\r
581                 --out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )\r
582                 local oldsocket, err = socket\r
583                 socket, err = ssl_wrap( socket, sslctx )    -- wrap socket\r
584                 --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) )\r
585                 if err then\r
586                     out_put( "server.lua: error while starting tls on client: ", tostring(err) )\r
587                     return nil, err    -- fatal error\r
588                 end\r
589 \r
590                 socket:settimeout( 0 )\r
591 \r
592                 -- add the new socket to our system\r
593 \r
594                 send = socket.send\r
595                 receive = socket.receive\r
596                 shutdown = id\r
597 \r
598                 _socketlist[ socket ] = handler\r
599                 _readlistlen = _readlistlen + 1\r
600                 _readlist[ _readlistlen ] = socket\r
601                 _readlist[ socket ] = _readlistlen\r
602 \r
603                 -- remove traces of the old socket\r
604 \r
605                 _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )\r
606                 _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )\r
607                 _socketlist[ oldsocket ] = nil\r
608 \r
609                 handler.starttls = nil\r
610                 needtls = nil\r
611                 \r
612                 -- Secure now\r
613                 ssl = true\r
614 \r
615                 handler.readbuffer = handshake\r
616                 handler.sendbuffer = handshake\r
617                 handshake( socket )    -- do handshake\r
618             end\r
619             handler.readbuffer = _readbuffer\r
620             handler.sendbuffer = _sendbuffer\r
621         end\r
622     else    -- normal connection\r
623         ssl = false\r
624         handler.readbuffer = _readbuffer\r
625         handler.sendbuffer = _sendbuffer\r
626     end\r
627 \r
628     send = socket.send\r
629     receive = socket.receive\r
630     shutdown = ( ssl and id ) or socket.shutdown\r
631 \r
632     _socketlist[ socket ] = handler\r
633     _readlistlen = _readlistlen + 1\r
634     _readlist[ _readlistlen ] = socket\r
635     _readlist[ socket ] = _readlistlen\r
636 \r
637     return handler, socket\r
638 end\r
639 \r
640 id = function( )\r
641 end\r
642 \r
643 idfalse = function( )\r
644     return false\r
645 end\r
646 \r
647 removesocket = function( list, socket, len )    -- this function removes sockets from a list ( copied from copas )\r
648     local pos = list[ socket ]\r
649     if pos then\r
650         list[ socket ] = nil\r
651         local last = list[ len ]\r
652         list[ len ] = nil\r
653         if last ~= socket then\r
654             list[ last ] = pos\r
655             list[ pos ] = last\r
656         end\r
657         return len - 1\r
658     end\r
659     return len\r
660 end\r
661 \r
662 closesocket = function( socket )\r
663     _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )\r
664     _readlistlen = removesocket( _readlist, socket, _readlistlen )\r
665     _socketlist[ socket ] = nil\r
666     socket:close( )\r
667     mem_free( )\r
668 end\r
669 \r
670 ----------------------------------// PUBLIC //--\r
671 \r
672 addserver = function( listeners, port, addr, pattern, sslctx, maxconnections, startssl )    -- this function provides a way for other scripts to reg a server\r
673     local err\r
674     --out_put("server.lua: autossl on ", port, " is ", startssl)\r
675     if type( listeners ) ~= "table" then\r
676         err = "invalid listener table"\r
677     end\r
678     if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then\r
679         err = "invalid port"\r
680     elseif _server[ port ] then\r
681         err =  "listeners on port '" .. port .. "' already exist"\r
682     elseif sslctx and not luasec then\r
683         err = "luasec not found"\r
684     end\r
685     if err then\r
686         out_error( "server.lua, port ", port, ": ", err )\r
687         return nil, err\r
688     end\r
689     addr = addr or "*"\r
690     local server, err = socket_bind( addr, port )\r
691     if err then\r
692         out_error( "server.lua, port ", port, ": ", err )\r
693         return nil, err\r
694     end\r
695     local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, maxconnections, startssl )    -- wrap new server socket\r
696     if not handler then\r
697         server:close( )\r
698         return nil, err\r
699     end\r
700     server:settimeout( 0 )\r
701     _readlistlen = _readlistlen + 1\r
702     _readlist[ _readlistlen ] = server\r
703     _server[ port ] = handler\r
704     _socketlist[ server ] = handler\r
705     out_put( "server.lua: new server listener on '", addr, ":", port, "'" )\r
706     return handler\r
707 end\r
708 \r
709 getserver = function ( port )\r
710         return _server[ port ];\r
711 end\r
712 \r
713 removeserver = function( port )\r
714     local handler = _server[ port ]\r
715     if not handler then\r
716         return nil, "no server found on port '" .. tostring( port ) "'"\r
717     end\r
718     handler.close( )\r
719     _server[ port ] = nil\r
720     return true\r
721 end\r
722 \r
723 closeall = function( )\r
724     for _, handler in pairs( _socketlist ) do\r
725         handler.close( )\r
726         _socketlist[ _ ] = nil\r
727     end\r
728     _readlistlen = 0\r
729     _sendlistlen = 0\r
730     _timerlistlen = 0\r
731     _server = { }\r
732     _readlist = { }\r
733     _sendlist = { }\r
734     _timerlist = { }\r
735     _socketlist = { }\r
736     mem_free( )\r
737 end\r
738 \r
739 getsettings = function( )\r
740     return  _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver\r
741 end\r
742 \r
743 changesettings = function( new )\r
744     if type( new ) ~= "table" then\r
745         return nil, "invalid settings table"\r
746     end\r
747     _selecttimeout = tonumber( new.timeout ) or _selecttimeout\r
748     _sleeptime = tonumber( new.sleeptime ) or _sleeptime\r
749     _maxsendlen = tonumber( new.maxsendlen ) or _maxsendlen\r
750     _maxreadlen = tonumber( new.maxreadlen ) or _maxreadlen\r
751     _checkinterval = tonumber( new.checkinterval ) or _checkinterval\r
752     _sendtimeout = tonumber( new.sendtimeout ) or _sendtimeout\r
753     _readtimeout = tonumber( new.readtimeout ) or _readtimeout\r
754     _cleanqueue = new.cleanqueue\r
755     _maxclientsperserver = new._maxclientsperserver or _maxclientsperserver\r
756     return true\r
757 end\r
758 \r
759 addtimer = function( listener )\r
760     if type( listener ) ~= "function" then\r
761         return nil, "invalid listener function"\r
762     end\r
763     _timerlistlen = _timerlistlen + 1\r
764     _timerlist[ _timerlistlen ] = listener\r
765     return true\r
766 end\r
767 \r
768 stats = function( )\r
769     return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen\r
770 end\r
771 \r
772 local dontstop = true; -- thinking about tomorrow, ...\r
773 \r
774 setquitting = function (quit)\r
775         dontstop = not quit;\r
776         return;\r
777 end\r
778 \r
779 loop = function( )    -- this is the main loop of the program\r
780     while dontstop do\r
781         local read, write, err = socket_select( _readlist, _sendlist, _selecttimeout )\r
782         for i, socket in ipairs( write ) do    -- send data waiting in writequeues\r
783             local handler = _socketlist[ socket ]\r
784             if handler then\r
785                 handler.sendbuffer( )\r
786             else\r
787                 closesocket( socket )\r
788                 out_put "server.lua: found no handler and closed socket (writelist)"    -- this should not happen\r
789             end\r
790         end\r
791         for i, socket in ipairs( read ) do    -- receive data\r
792             local handler = _socketlist[ socket ]\r
793             if handler then\r
794                 handler.readbuffer( )\r
795             else\r
796                 closesocket( socket )\r
797                 out_put "server.lua: found no handler and closed socket (readlist)"    -- this can happen\r
798             end\r
799         end\r
800         for handler, err in pairs( _closelist ) do\r
801             handler.disconnect( )( handler, err )\r
802             handler.close( true )    -- forced disconnect\r
803         end\r
804         clean( _closelist )\r
805         _currenttime = os_time( )\r
806         if os_difftime( _currenttime - _timer ) >= 1 then\r
807             for i = 1, _timerlistlen do\r
808                 _timerlist[ i ]( )    -- fire timers\r
809             end\r
810             _timer = _currenttime\r
811         end\r
812         socket_sleep( _sleeptime )    -- wait some time\r
813         --collectgarbage( )\r
814     end\r
815     return "quitting"\r
816 end\r
817 \r
818 --// EXPERIMENTAL //--\r
819 \r
820 local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, startssl )\r
821     local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, startssl )\r
822     _socketlist[ socket ] = handler\r
823     _sendlistlen = _sendlistlen + 1\r
824     _sendlist[ _sendlistlen ] = socket\r
825     _sendlist[ socket ] = _sendlistlen\r
826     return handler, socket\r
827 end\r
828 \r
829 local addclient = function( address, port, listeners, pattern, sslctx, startssl )\r
830     local client, err = luasocket.tcp( )\r
831     if err then\r
832         return nil, err\r
833     end\r
834     client:settimeout( 0 )\r
835     _, err = client:connect( address, port )\r
836     if err then    -- try again\r
837         local handler = wrapclient( client, address, port, listeners )\r
838     else\r
839         wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx, startssl )\r
840     end\r
841 end\r
842 \r
843 --// EXPERIMENTAL //--\r
844 \r
845 ----------------------------------// BEGIN //--\r
846 \r
847 use "setmetatable" ( _socketlist, { __mode = "k" } )\r
848 use "setmetatable" ( _readtimes, { __mode = "k" } )\r
849 use "setmetatable" ( _writetimes, { __mode = "k" } )\r
850 \r
851 _timer = os_time( )\r
852 _starttime = os_time( )\r
853 \r
854 addtimer( function( )\r
855         local difftime = os_difftime( _currenttime - _starttime )\r
856         if difftime > _checkinterval then\r
857             _starttime = _currenttime\r
858             for handler, timestamp in pairs( _writetimes ) do\r
859                 if os_difftime( _currenttime - timestamp ) > _sendtimeout then\r
860                     --_writetimes[ handler ] = nil\r
861                     handler.disconnect( )( handler, "send timeout" )\r
862                     handler.close( true )    -- forced disconnect\r
863                 end\r
864             end\r
865             for handler, timestamp in pairs( _readtimes ) do\r
866                 if os_difftime( _currenttime - timestamp ) > _readtimeout then\r
867                     --_readtimes[ handler ] = nil\r
868                     handler.disconnect( )( handler, "read timeout" )\r
869                     handler.close( )    -- forced disconnect?\r
870                 end\r
871             end\r
872         end\r
873     end\r
874 )\r
875 \r
876 ----------------------------------// PUBLIC INTERFACE //--\r
877 \r
878 return {\r
879 \r
880     addclient = addclient,\r
881     wrapclient = wrapclient,\r
882     \r
883     loop = loop,\r
884     stats = stats,\r
885     closeall = closeall,\r
886     addtimer = addtimer,\r
887     addserver = addserver,\r
888     getserver = getserver,\r
889     getsettings = getsettings,\r
890     setquitting = setquitting,\r
891     removeserver = removeserver,\r
892     changesettings = changesettings,\r
893 }\r