Merge with trunk
[prosody.git] / net / server_select.lua
1 -- \r
2 -- server.lua by blastbeat of the luadch project\r
3 -- Re-used here under the MIT/X Consortium License\r
4 -- \r
5 -- Modifications (C) 2008-2009 Matthew Wild, Waqas Hussain\r
6 --\r
7 \r
8 -- // wrapping luadch stuff // --\r
9 \r
10 local use = function( what )\r
11     return _G[ what ]\r
12 end\r
13 local clean = function( tbl )\r
14     for i, k in pairs( tbl ) do\r
15         tbl[ i ] = nil\r
16     end\r
17 end\r
18 \r
19 local log, table_concat = require ("util.logger").init("socket"), table.concat;\r
20 local out_put = function (...) return log("debug", table_concat{...}); end\r
21 local out_error = function (...) return log("warn", table_concat{...}); end\r
22 local mem_free = collectgarbage\r
23 \r
24 ----------------------------------// DECLARATION //--\r
25 \r
26 --// constants //--\r
27 \r
28 local STAT_UNIT = 1    -- byte\r
29 \r
30 --// lua functions //--\r
31 \r
32 local type = use "type"\r
33 local pairs = use "pairs"\r
34 local ipairs = use "ipairs"\r
35 local tostring = use "tostring"\r
36 local collectgarbage = use "collectgarbage"\r
37 \r
38 --// lua libs //--\r
39 \r
40 local os = use "os"\r
41 local table = use "table"\r
42 local string = use "string"\r
43 local coroutine = use "coroutine"\r
44 \r
45 --// lua lib methods //--\r
46 \r
47 local os_time = os.time\r
48 local os_difftime = os.difftime\r
49 local table_concat = table.concat\r
50 local table_remove = table.remove\r
51 local string_len = string.len\r
52 local string_sub = string.sub\r
53 local coroutine_wrap = coroutine.wrap\r
54 local coroutine_yield = coroutine.yield\r
55 \r
56 --// extern libs //--\r
57 \r
58 local luasec = select( 2, pcall( require, "ssl" ) )\r
59 local luasocket = require "socket"\r
60 \r
61 --// extern lib methods //--\r
62 \r
63 local ssl_wrap = ( luasec and luasec.wrap )\r
64 local socket_bind = luasocket.bind\r
65 local socket_sleep = luasocket.sleep\r
66 local socket_select = luasocket.select\r
67 local ssl_newcontext = ( luasec and luasec.newcontext )\r
68 \r
69 --// functions //--\r
70 \r
71 local id\r
72 local loop\r
73 local stats\r
74 local idfalse\r
75 local addtimer\r
76 local closeall\r
77 local addserver\r
78 local getserver\r
79 local wrapserver\r
80 local getsettings\r
81 local closesocket\r
82 local removesocket\r
83 local removeserver\r
84 local changetimeout\r
85 local wrapconnection\r
86 local changesettings\r
87 \r
88 --// tables //--\r
89 \r
90 local _server\r
91 local _readlist\r
92 local _timerlist\r
93 local _sendlist\r
94 local _socketlist\r
95 local _closelist\r
96 local _readtimes\r
97 local _writetimes\r
98 \r
99 --// simple data types //--\r
100 \r
101 local _\r
102 local _readlistlen\r
103 local _sendlistlen\r
104 local _timerlistlen\r
105 \r
106 local _sendtraffic\r
107 local _readtraffic\r
108 \r
109 local _selecttimeout\r
110 local _sleeptime\r
111 \r
112 local _starttime\r
113 local _currenttime\r
114 \r
115 local _maxsendlen\r
116 local _maxreadlen\r
117 \r
118 local _checkinterval\r
119 local _sendtimeout\r
120 local _readtimeout\r
121 \r
122 local _cleanqueue\r
123 \r
124 local _timer\r
125 \r
126 local _maxclientsperserver\r
127 \r
128 ----------------------------------// DEFINITION //--\r
129 \r
130 _server = { }    -- key = port, value = table; list of listening servers\r
131 _readlist = { }    -- array with sockets to read from\r
132 _sendlist = { }    -- arrary with sockets to write to\r
133 _timerlist = { }    -- array of timer functions\r
134 _socketlist = { }    -- key = socket, value = wrapped socket (handlers)\r
135 _readtimes = { }   -- key = handler, value = timestamp of last data reading\r
136 _writetimes = { }   -- key = handler, value = timestamp of last data writing/sending\r
137 _closelist = { }    -- handlers to close\r
138 \r
139 _readlistlen = 0    -- length of readlist\r
140 _sendlistlen = 0    -- length of sendlist\r
141 _timerlistlen = 0    -- lenght of timerlist\r
142 \r
143 _sendtraffic = 0    -- some stats\r
144 _readtraffic = 0\r
145 \r
146 _selecttimeout = 1    -- timeout of socket.select\r
147 _sleeptime = 0    -- time to wait at the end of every loop\r
148 \r
149 _maxsendlen = 51000 * 1024    -- max len of send buffer\r
150 _maxreadlen = 25000 * 1024    -- max len of read buffer\r
151 \r
152 _checkinterval = 1200000    -- interval in secs to check idle clients\r
153 _sendtimeout = 60000    -- allowed send idle time in secs\r
154 _readtimeout = 6 * 60 * 60    -- allowed read idle time in secs\r
155 \r
156 _cleanqueue = false    -- clean bufferqueue after using\r
157 \r
158 _maxclientsperserver = 1000\r
159 \r
160 _maxsslhandshake = 30 -- max handshake round-trips\r
161 ----------------------------------// PRIVATE //--\r
162 \r
163 wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections, startssl )    -- this function wraps a server\r
164 \r
165     maxconnections = maxconnections or _maxclientsperserver\r
166 \r
167     local connections = 0\r
168 \r
169     local dispatch, disconnect = listeners.onincoming, listeners.ondisconnect\r
170 \r
171     local err\r
172 \r
173     local ssl = false\r
174 \r
175     if sslctx then\r
176         ssl = true\r
177         if not ssl_newcontext then\r
178             out_error "luasec not found"\r
179             ssl = false\r
180         end\r
181         if type( sslctx ) ~= "table" then\r
182             out_error "server.lua: wrong server sslctx"\r
183             ssl = false\r
184         end\r
185         local ctx;\r
186         ctx, err = ssl_newcontext( sslctx )\r
187         if not ctx then\r
188             err = err or "wrong sslctx parameters"\r
189             local file;\r
190             file = err:match("^error loading (.-) %(");\r
191             if file then\r
192                 if file == "private key" then\r
193                         file = sslctx.key or "your private key";\r
194                 elseif file == "certificate" then\r
195                         file = sslctx.certificate or "your certificate file";\r
196                 end\r
197                 local reason = err:match("%((.+)%)$") or "some reason";\r
198                 if reason == "Permission denied" then\r
199                         reason = "Check that the permissions allow Prosody to read this file.";\r
200                 elseif reason == "No such file or directory" then\r
201                         reason = "Check that the path is correct, and the file exists.";\r
202                 elseif reason == "system lib" then\r
203                         reason = "Previous error (see logs), or other system error.";\r
204                 else\r
205                         reason = "Reason: "..tostring(reason or "unknown"):lower();\r
206                 end\r
207                 log("error", "SSL/TLS: Failed to load %s: %s", file, reason);\r
208             else\r
209                 log("error", "SSL/TLS: Error initialising for port %d: %s", serverport, err );\r
210             end\r
211             ssl = false\r
212         end\r
213         sslctx = ctx;\r
214     end\r
215     if not ssl then\r
216       sslctx = false;\r
217       if startssl then\r
218          log("error", "Failed to listen on port %d due to SSL/TLS to SSL/TLS initialisation errors (see logs)", serverport )\r
219          return nil, "Cannot start ssl,  see log for details"\r
220        end\r
221     end\r
222 \r
223     local accept = socket.accept\r
224 \r
225     --// public methods of the object //--\r
226 \r
227     local handler = { }\r
228 \r
229     handler.shutdown = function( ) end\r
230 \r
231     handler.ssl = function( )\r
232         return ssl\r
233     end\r
234     handler.sslctx = function( )\r
235         return sslctx\r
236     end\r
237     handler.remove = function( )\r
238         connections = connections - 1\r
239     end\r
240     handler.close = function( )\r
241         for _, handler in pairs( _socketlist ) do\r
242             if handler.serverport == serverport then\r
243                 handler.disconnect( handler, "server closed" )\r
244                 handler:close( true )\r
245             end\r
246         end\r
247         socket:close( )\r
248         _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )\r
249         _readlistlen = removesocket( _readlist, socket, _readlistlen )\r
250         _socketlist[ socket ] = nil\r
251         handler = nil\r
252         socket = nil\r
253         --mem_free( )\r
254         out_put "server.lua: closed server handler and removed sockets from list"\r
255     end\r
256     handler.ip = function( )\r
257         return ip\r
258     end\r
259     handler.serverport = function( )\r
260         return serverport\r
261     end\r
262     handler.socket = function( )\r
263         return socket\r
264     end\r
265     handler.readbuffer = function( )\r
266         if connections > maxconnections then\r
267             out_put( "server.lua: refused new client connection: server full" )\r
268             return false\r
269         end\r
270         local client, err = accept( socket )    -- try to accept\r
271         if client then\r
272             local ip, clientport = client:getpeername( )\r
273             client:settimeout( 0 )\r
274             local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, startssl )    -- wrap new client socket\r
275             if err then    -- error while wrapping ssl socket\r
276                 return false\r
277             end\r
278             connections = connections + 1\r
279             out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))\r
280             return dispatch( handler )\r
281         elseif err then    -- maybe timeout or something else\r
282             out_put( "server.lua: error with new client connection: ", tostring(err) )\r
283             return false\r
284         end\r
285     end\r
286     return handler\r
287 end\r
288 \r
289 wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, startssl )    -- this function wraps a client to a handler object\r
290 \r
291     socket:settimeout( 0 )\r
292 \r
293     --// local import of socket methods //--\r
294 \r
295     local send\r
296     local receive\r
297     local shutdown\r
298 \r
299     --// private closures of the object //--\r
300 \r
301     local ssl\r
302 \r
303     local dispatch = listeners.onincoming\r
304     local status = listeners.status\r
305     local disconnect = listeners.ondisconnect\r
306 \r
307     local bufferqueue = { }    -- buffer array\r
308     local bufferqueuelen = 0    -- end of buffer array\r
309 \r
310     local toclose\r
311     local fatalerror\r
312     local needtls\r
313 \r
314     local bufferlen = 0\r
315 \r
316     local noread = false\r
317     local nosend = false\r
318 \r
319     local sendtraffic, readtraffic = 0, 0\r
320 \r
321     local maxsendlen = _maxsendlen\r
322     local maxreadlen = _maxreadlen\r
323 \r
324     --// public methods of the object //--\r
325 \r
326     local handler = bufferqueue    -- saves a table ^_^\r
327 \r
328     handler.dispatch = function( )\r
329         return dispatch\r
330     end\r
331     handler.disconnect = function( )\r
332         return disconnect\r
333     end\r
334     handler.setlistener = function( self, listeners )\r
335         dispatch = listeners.onincoming\r
336         disconnect = listeners.ondisconnect\r
337     end\r
338     handler.getstats = function( )\r
339         return readtraffic, sendtraffic\r
340     end\r
341     handler.ssl = function( )\r
342         return ssl\r
343     end\r
344     handler.sslctx = function ( )\r
345         return sslctx\r
346     end\r
347     handler.send = function( _, data, i, j )\r
348         return send( socket, data, i, j )\r
349     end\r
350     handler.receive = function( pattern, prefix )\r
351         return receive( socket, pattern, prefix )\r
352     end\r
353     handler.shutdown = function( pattern )\r
354         return shutdown( socket, pattern )\r
355     end\r
356     handler.close = function( forced )\r
357         if not handler then return true; end\r
358         _readlistlen = removesocket( _readlist, socket, _readlistlen )\r
359         _readtimes[ handler ] = nil\r
360         if bufferqueuelen ~= 0 then\r
361             if not ( forced or fatalerror ) then\r
362                 handler.sendbuffer( )\r
363                 if bufferqueuelen ~= 0 then   -- try again...\r
364                     if handler then\r
365                         handler.write = nil    -- ... but no further writing allowed\r
366                     end\r
367                     toclose = true\r
368                     return false\r
369                 end\r
370             else\r
371                 send( socket, table_concat( bufferqueue, "", 1, bufferqueuelen ), 1, bufferlen )    -- forced send\r
372             end\r
373         end\r
374         if socket then\r
375           _ = shutdown and shutdown( socket )\r
376           socket:close( )\r
377           _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )\r
378           _socketlist[ socket ] = nil\r
379           socket = nil\r
380         else\r
381           out_put "server.lua: socket already closed"\r
382         end\r
383         if handler then\r
384             _writetimes[ handler ] = nil\r
385             _closelist[ handler ] = nil\r
386             handler = nil\r
387         end\r
388         if server then\r
389                 server.remove( )\r
390         end\r
391         out_put "server.lua: closed client handler and removed socket from list"\r
392         return true\r
393     end\r
394     handler.ip = function( )\r
395         return ip\r
396     end\r
397     handler.serverport = function( )\r
398         return serverport\r
399     end\r
400     handler.clientport = function( )\r
401         return clientport\r
402     end\r
403     local write = function( self, data )\r
404         bufferlen = bufferlen + string_len( data )\r
405         if bufferlen > maxsendlen then\r
406             _closelist[ handler ] = "send buffer exceeded"   -- cannot close the client at the moment, have to wait to the end of the cycle\r
407             handler.write = idfalse    -- dont write anymore\r
408             return false\r
409         elseif socket and not _sendlist[ socket ] then\r
410             _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)\r
411         end\r
412         bufferqueuelen = bufferqueuelen + 1\r
413         bufferqueue[ bufferqueuelen ] = data\r
414         if handler then\r
415                 _writetimes[ handler ] = _writetimes[ handler ] or _currenttime\r
416         end\r
417         return true\r
418     end\r
419     handler.write = write\r
420     handler.bufferqueue = function( self )\r
421         return bufferqueue\r
422     end\r
423     handler.socket = function( self )\r
424         return socket\r
425     end\r
426     handler.pattern = function( self, new )\r
427         pattern = new or pattern\r
428         return pattern\r
429     end\r
430     handler.setsend = function ( self, newsend )\r
431         send = newsend or send\r
432         return send\r
433     end\r
434     handler.bufferlen = function( self, readlen, sendlen )\r
435         maxsendlen = sendlen or maxsendlen\r
436         maxreadlen = readlen or maxreadlen\r
437         return maxreadlen, maxsendlen\r
438     end\r
439     handler.lock = function( self, switch )\r
440         if switch == true then\r
441             handler.write = idfalse\r
442             local tmp = _sendlistlen\r
443             _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )\r
444             _writetimes[ handler ] = nil\r
445             if _sendlistlen ~= tmp then\r
446                 nosend = true\r
447             end\r
448             tmp = _readlistlen\r
449             _readlistlen = removesocket( _readlist, socket, _readlistlen )\r
450             _readtimes[ handler ] = nil\r
451             if _readlistlen ~= tmp then\r
452                 noread = true\r
453             end\r
454         elseif switch == false then\r
455             handler.write = write\r
456             if noread then\r
457                 noread = false\r
458                 _readlistlen = addsocket(_readlist, socket, _readlistlen)\r
459                 _readtimes[ handler ] = _currenttime\r
460             end\r
461             if nosend then\r
462                 nosend = false\r
463                 write( "" )\r
464             end\r
465         end\r
466         return noread, nosend\r
467     end\r
468     local _readbuffer = function( )    -- this function reads data\r
469         local buffer, err, part = receive( socket, pattern )    -- receive buffer with "pattern"\r
470         if not err or ( err == "timeout" or err == "wantread" ) then    -- received something\r
471             local buffer = buffer or part or ""\r
472             local len = string_len( buffer )\r
473             if len > maxreadlen then\r
474                 disconnect( handler, "receive buffer exceeded" )\r
475                 handler.close( true )\r
476                 return false\r
477             end\r
478             local count = len * STAT_UNIT\r
479             readtraffic = readtraffic + count\r
480             _readtraffic = _readtraffic + count\r
481             _readtimes[ handler ] = _currenttime\r
482             --out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err )\r
483             return dispatch( handler, buffer, err )\r
484         else    -- connections was closed or fatal error\r
485             out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )\r
486             fatalerror = true\r
487             disconnect( handler, err )\r
488             _ = handler and handler.close( )\r
489             return false\r
490         end\r
491     end\r
492     local _sendbuffer = function( )    -- this function sends data\r
493         local succ, err, byte, buffer, count;\r
494         local count;\r
495         if socket then\r
496             buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )\r
497             succ, err, byte = send( socket, buffer, 1, bufferlen )\r
498             count = ( succ or byte or 0 ) * STAT_UNIT\r
499             sendtraffic = sendtraffic + count\r
500             _sendtraffic = _sendtraffic + count\r
501             _ = _cleanqueue and clean( bufferqueue )\r
502             --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )\r
503         else\r
504             succ, err, count = false, "closed", 0;\r
505         end\r
506         if succ then    -- sending succesful\r
507             bufferqueuelen = 0\r
508             bufferlen = 0\r
509             _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )    -- delete socket from writelist\r
510             _ = needtls and handler:starttls(true)\r
511             _writetimes[ handler ] = nil\r
512             _ = toclose and handler.close( )\r
513             return true\r
514         elseif byte and ( err == "timeout" or err == "wantwrite" ) then    -- want write\r
515             buffer = string_sub( buffer, byte + 1, bufferlen )    -- new buffer\r
516             bufferqueue[ 1 ] = buffer    -- insert new buffer in queue\r
517             bufferqueuelen = 1\r
518             bufferlen = bufferlen - byte\r
519             _writetimes[ handler ] = _currenttime\r
520             return true\r
521         else    -- connection was closed during sending or fatal error\r
522             out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )\r
523             fatalerror = true\r
524             disconnect( handler, err )\r
525             _ = handler and handler.close( )\r
526             return false\r
527         end\r
528     end\r
529 \r
530     -- Set the sslctx\r
531     local handshake;\r
532     function handler.set_sslctx(self, new_sslctx)\r
533         ssl = true\r
534         sslctx = new_sslctx;\r
535         local wrote\r
536         local read\r
537         handshake = coroutine_wrap( function( client )    -- create handshake coroutine\r
538                 local err\r
539                 for i = 1, _maxsslhandshake do\r
540                     _sendlistlen = ( wrote and removesocket( _sendlist, client, _sendlistlen ) ) or _sendlistlen\r
541                     _readlistlen = ( read and removesocket( _readlist, client, _readlistlen ) ) or _readlistlen\r
542                     read, wrote = nil, nil\r
543                     _, err = client:dohandshake( )\r
544                     if not err then\r
545                         out_put( "server.lua: ssl handshake done" )\r
546                         handler.readbuffer = _readbuffer    -- when handshake is done, replace the handshake function with regular functions\r
547                         handler.sendbuffer = _sendbuffer\r
548                         _ = status and status( handler, "ssl-handshake-complete" )\r
549                         _readlistlen = addsocket(_readlist, client, _readlistlen)\r
550                         return true\r
551                     else\r
552                        out_put( "server.lua: error during ssl handshake: ", tostring(err) )\r
553                        if err == "wantwrite" and not wrote then\r
554                            _sendlistlen = addsocket(_sendlist, client, _sendlistlen)\r
555                            wrote = true\r
556                        elseif err == "wantread" and not read then\r
557                            _readlistlen = addsocket(_readlist, client, _readlistlen)\r
558                            read = true\r
559                        else\r
560                            break;\r
561                        end\r
562                        --coroutine_yield( handler, nil, err )    -- handshake not finished\r
563                        coroutine_yield( )\r
564                     end\r
565                 end\r
566                 disconnect( handler, "ssl handshake failed" )\r
567                 _ = handler and handler:close( true )    -- forced disconnect\r
568                 return false    -- handshake failed\r
569             end\r
570         )\r
571     end\r
572     if sslctx then    -- ssl?\r
573         handler:set_sslctx(sslctx);\r
574         if startssl then    -- ssl now?\r
575             --out_put("server.lua: ", "starting ssl handshake")\r
576             local err\r
577             socket, err = ssl_wrap( socket, sslctx )    -- wrap socket\r
578             if err then\r
579                 out_put( "server.lua: ssl error: ", tostring(err) )\r
580                 --mem_free( )\r
581                 return nil, nil, err    -- fatal error\r
582             end\r
583             socket:settimeout( 0 )\r
584             handler.readbuffer = handshake\r
585             handler.sendbuffer = handshake\r
586             handshake( socket ) -- do handshake\r
587             if not socket then\r
588                 return nil, nil, "ssl handshake failed";\r
589             end\r
590         else\r
591             -- We're not automatically doing SSL, so we're not secure (yet)\r
592             ssl = false\r
593             handler.starttls = function( self, now )\r
594                 if not now then\r
595                     --out_put "server.lua: we need to do tls, but delaying until later"\r
596                     needtls = true\r
597                     return\r
598                 end\r
599                 --out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )\r
600                 local oldsocket, err = socket\r
601                 socket, err = ssl_wrap( socket, sslctx )    -- wrap socket\r
602                 --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) )\r
603                 if err then\r
604                     out_put( "server.lua: error while starting tls on client: ", tostring(err) )\r
605                     return nil, err    -- fatal error\r
606                 end\r
607 \r
608                 socket:settimeout( 0 )\r
609 \r
610                 -- add the new socket to our system\r
611 \r
612                 send = socket.send\r
613                 receive = socket.receive\r
614                 shutdown = id\r
615 \r
616                 _socketlist[ socket ] = handler\r
617                 _readlistlen = addsocket(_readlist, socket, _readlistlen)\r
618 \r
619                 -- remove traces of the old socket\r
620 \r
621                 _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )\r
622                 _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )\r
623                 _socketlist[ oldsocket ] = nil\r
624 \r
625                 handler.starttls = nil\r
626                 needtls = nil\r
627                 \r
628                 -- Secure now\r
629                 ssl = true\r
630 \r
631                 handler.readbuffer = handshake\r
632                 handler.sendbuffer = handshake\r
633                 handshake( socket )    -- do handshake\r
634             end\r
635             handler.readbuffer = _readbuffer\r
636             handler.sendbuffer = _sendbuffer\r
637         end\r
638     else    -- normal connection\r
639         ssl = false\r
640         handler.readbuffer = _readbuffer\r
641         handler.sendbuffer = _sendbuffer\r
642     end\r
643 \r
644     send = socket.send\r
645     receive = socket.receive\r
646     shutdown = ( ssl and id ) or socket.shutdown\r
647 \r
648     _socketlist[ socket ] = handler\r
649     _readlistlen = addsocket(_readlist, socket, _readlistlen)\r
650 \r
651     return handler, socket\r
652 end\r
653 \r
654 id = function( )\r
655 end\r
656 \r
657 idfalse = function( )\r
658     return false\r
659 end\r
660 \r
661 addsocket = function( list, socket, len )\r
662     if not list[ socket ] then\r
663       len = len + 1\r
664       list[ len ] = socket\r
665       list[ socket ] = len\r
666     end\r
667     return len;\r
668 end\r
669 \r
670 removesocket = function( list, socket, len )    -- this function removes sockets from a list ( copied from copas )\r
671     local pos = list[ socket ]\r
672     if pos then\r
673         list[ socket ] = nil\r
674         local last = list[ len ]\r
675         list[ len ] = nil\r
676         if last ~= socket then\r
677             list[ last ] = pos\r
678             list[ pos ] = last\r
679         end\r
680         return len - 1\r
681     end\r
682     return len\r
683 end\r
684 \r
685 closesocket = function( socket )\r
686     _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )\r
687     _readlistlen = removesocket( _readlist, socket, _readlistlen )\r
688     _socketlist[ socket ] = nil\r
689     socket:close( )\r
690     --mem_free( )\r
691 end\r
692 \r
693 ----------------------------------// PUBLIC //--\r
694 \r
695 addserver = function( addr, port, listeners, pattern, sslctx, startssl )    -- this function provides a way for other scripts to reg a server\r
696     local err\r
697     --out_put("server.lua: autossl on ", port, " is ", startssl)\r
698     if type( listeners ) ~= "table" then\r
699         err = "invalid listener table"\r
700     end\r
701     if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then\r
702         err = "invalid port"\r
703     elseif _server[ port ] then\r
704         err =  "listeners on port '" .. port .. "' already exist"\r
705     elseif sslctx and not luasec then\r
706         err = "luasec not found"\r
707     end\r
708     if err then\r
709         out_error( "server.lua, port ", port, ": ", err )\r
710         return nil, err\r
711     end\r
712     addr = addr or "*"\r
713     local server, err = socket_bind( addr, port )\r
714     if err then\r
715         out_error( "server.lua, port ", port, ": ", err )\r
716         return nil, err\r
717     end\r
718     local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver, startssl )    -- wrap new server socket\r
719     if not handler then\r
720         server:close( )\r
721         return nil, err\r
722     end\r
723     server:settimeout( 0 )\r
724     _readlistlen = addsocket(_readlist, server, _readlistlen)\r
725     _server[ port ] = handler\r
726     _socketlist[ server ] = handler\r
727     out_put( "server.lua: new server listener on '", addr, ":", port, "'" )\r
728     return handler\r
729 end\r
730 \r
731 getserver = function ( port )\r
732         return _server[ port ];\r
733 end\r
734 \r
735 removeserver = function( port )\r
736     local handler = _server[ port ]\r
737     if not handler then\r
738         return nil, "no server found on port '" .. tostring( port ) .. "'"\r
739     end\r
740     handler:close( )\r
741     _server[ port ] = nil\r
742     return true\r
743 end\r
744 \r
745 closeall = function( )\r
746     for _, handler in pairs( _socketlist ) do\r
747         handler:close( )\r
748         _socketlist[ _ ] = nil\r
749     end\r
750     _readlistlen = 0\r
751     _sendlistlen = 0\r
752     _timerlistlen = 0\r
753     _server = { }\r
754     _readlist = { }\r
755     _sendlist = { }\r
756     _timerlist = { }\r
757     _socketlist = { }\r
758     --mem_free( )\r
759 end\r
760 \r
761 getsettings = function( )\r
762     return  _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver, _maxsslhandshake\r
763 end\r
764 \r
765 changesettings = function( new )\r
766     if type( new ) ~= "table" then\r
767         return nil, "invalid settings table"\r
768     end\r
769     _selecttimeout = tonumber( new.timeout ) or _selecttimeout\r
770     _sleeptime = tonumber( new.sleeptime ) or _sleeptime\r
771     _maxsendlen = tonumber( new.maxsendlen ) or _maxsendlen\r
772     _maxreadlen = tonumber( new.maxreadlen ) or _maxreadlen\r
773     _checkinterval = tonumber( new.checkinterval ) or _checkinterval\r
774     _sendtimeout = tonumber( new.sendtimeout ) or _sendtimeout\r
775     _readtimeout = tonumber( new.readtimeout ) or _readtimeout\r
776     _cleanqueue = new.cleanqueue\r
777     _maxclientsperserver = new._maxclientsperserver or _maxclientsperserver\r
778     _maxsslhandshake = new._maxsslhandshake or _maxsslhandshake\r
779     return true\r
780 end\r
781 \r
782 addtimer = function( listener )\r
783     if type( listener ) ~= "function" then\r
784         return nil, "invalid listener function"\r
785     end\r
786     _timerlistlen = _timerlistlen + 1\r
787     _timerlist[ _timerlistlen ] = listener\r
788     return true\r
789 end\r
790 \r
791 stats = function( )\r
792     return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen\r
793 end\r
794 \r
795 local dontstop = true; -- thinking about tomorrow, ...\r
796 \r
797 setquitting = function (quit)\r
798         dontstop = not quit;\r
799         return;\r
800 end\r
801 \r
802 loop = function( )    -- this is the main loop of the program\r
803     while dontstop do\r
804         local read, write, err = socket_select( _readlist, _sendlist, _selecttimeout )\r
805         for i, socket in ipairs( write ) do    -- send data waiting in writequeues\r
806             local handler = _socketlist[ socket ]\r
807             if handler then\r
808                 handler.sendbuffer( )\r
809             else\r
810                 closesocket( socket )\r
811                 out_put "server.lua: found no handler and closed socket (writelist)"    -- this should not happen\r
812             end\r
813         end\r
814         for i, socket in ipairs( read ) do    -- receive data\r
815             local handler = _socketlist[ socket ]\r
816             if handler then\r
817                 handler.readbuffer( )\r
818             else\r
819                 closesocket( socket )\r
820                 out_put "server.lua: found no handler and closed socket (readlist)"    -- this can happen\r
821             end\r
822         end\r
823         for handler, err in pairs( _closelist ) do\r
824             handler.disconnect( )( handler, err )\r
825             handler:close( true )    -- forced disconnect\r
826         end\r
827         clean( _closelist )\r
828         _currenttime = os_time( )\r
829         if os_difftime( _currenttime - _timer ) >= 1 then\r
830             for i = 1, _timerlistlen do\r
831                 _timerlist[ i ]( _currenttime )    -- fire timers\r
832             end\r
833             _timer = _currenttime\r
834         end\r
835         socket_sleep( _sleeptime )    -- wait some time\r
836         --collectgarbage( )\r
837     end\r
838     return "quitting"\r
839 end\r
840 \r
841 --// EXPERIMENTAL //--\r
842 \r
843 local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, startssl )\r
844     local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, startssl )\r
845     _socketlist[ socket ] = handler\r
846     _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)\r
847     return handler, socket\r
848 end\r
849 \r
850 local addclient = function( address, port, listeners, pattern, sslctx, startssl )\r
851     local client, err = luasocket.tcp( )\r
852     if err then\r
853         return nil, err\r
854     end\r
855     client:settimeout( 0 )\r
856     _, err = client:connect( address, port )\r
857     if err then    -- try again\r
858         local handler = wrapclient( client, address, port, listeners )\r
859     else\r
860         wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx, startssl )\r
861     end\r
862 end\r
863 \r
864 --// EXPERIMENTAL //--\r
865 \r
866 ----------------------------------// BEGIN //--\r
867 \r
868 use "setmetatable" ( _socketlist, { __mode = "k" } )\r
869 use "setmetatable" ( _readtimes, { __mode = "k" } )\r
870 use "setmetatable" ( _writetimes, { __mode = "k" } )\r
871 \r
872 _timer = os_time( )\r
873 _starttime = os_time( )\r
874 \r
875 addtimer( function( )\r
876         local difftime = os_difftime( _currenttime - _starttime )\r
877         if difftime > _checkinterval then\r
878             _starttime = _currenttime\r
879             for handler, timestamp in pairs( _writetimes ) do\r
880                 if os_difftime( _currenttime - timestamp ) > _sendtimeout then\r
881                     --_writetimes[ handler ] = nil\r
882                     handler.disconnect( )( handler, "send timeout" )\r
883                     handler:close( true )    -- forced disconnect\r
884                 end\r
885             end\r
886             for handler, timestamp in pairs( _readtimes ) do\r
887                 if os_difftime( _currenttime - timestamp ) > _readtimeout then\r
888                     --_readtimes[ handler ] = nil\r
889                     handler.disconnect( )( handler, "read timeout" )\r
890                     handler:close( )    -- forced disconnect?\r
891                 end\r
892             end\r
893         end\r
894     end\r
895 )\r
896 \r
897 ----------------------------------// PUBLIC INTERFACE //--\r
898 \r
899 return {\r
900 \r
901     addclient = addclient,\r
902     wrapclient = wrapclient,\r
903     \r
904     loop = loop,\r
905     stats = stats,\r
906     closeall = closeall,\r
907     addtimer = addtimer,\r
908     addserver = addserver,\r
909     getserver = getserver,\r
910     getsettings = getsettings,\r
911     setquitting = setquitting,\r
912     removeserver = removeserver,\r
913     changesettings = changesettings,\r
914 }\r