net.server_select: Don't remove the socket from sendlist when we might have data...
[prosody.git] / net / server_select.lua
1 -- 
2 -- server.lua by blastbeat of the luadch project
3 -- Re-used here under the MIT/X Consortium License
4 -- 
5 -- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain
6 --
7
8 -- // wrapping luadch stuff // --
9
--- Look up a global by name (kept from the original luadch code).
-- @param what name of the global
-- @return the value of `_G[what]`, or nil when it is not set
local function use( what )
	local value = _G[ what ]
	return value
end
13
-- Logger for this module; both helpers join their arguments into one message.
local log = require ("util.logger").init("socket")
local table_concat = table.concat
-- Debug-level log line.
local function out_put( ... )
	return log( "debug", table_concat{ ... } )
end
-- Warning-level log line.
local function out_error( ... )
	return log( "warn", table_concat{ ... } )
end
17
18 ----------------------------------// DECLARATION //--
19
20 --// constants //--
21
local STAT_UNIT = 1 -- traffic counters are accounted in bytes
23
24 --// lua functions //--
25
26 local type = use "type"
27 local pairs = use "pairs"
28 local ipairs = use "ipairs"
29 local tonumber = use "tonumber"
30 local tostring = use "tostring"
31
32 --// lua libs //--
33
34 local os = use "os"
35 local table = use "table"
36 local string = use "string"
37 local coroutine = use "coroutine"
38
39 --// lua lib methods //--
40
41 local os_difftime = os.difftime
42 local math_min = math.min
43 local math_huge = math.huge
44 local table_concat = table.concat
45 local string_sub = string.sub
46 local coroutine_wrap = coroutine.wrap
47 local coroutine_yield = coroutine.yield
48
49 --// extern libs //--
50
51 local luasec = use "ssl"
52 local luasocket = use "socket" or require "socket"
53 local luasocket_gettime = luasocket.gettime
54
55 --// extern lib methods //--
56
57 local ssl_wrap = ( luasec and luasec.wrap )
58 local socket_bind = luasocket.bind
59 local socket_sleep = luasocket.sleep
60 local socket_select = luasocket.select
61
62 --// functions //--
63
-- Forward declarations; each is assigned further down in the file.
local id, idfalse
local loop, stats
local addsocket, removesocket, closesocket
local addserver, getserver, removeserver, wrapserver
local wrapconnection
local addtimer, closeall
local getsettings, changesettings

--// tables //--

local _server, _readlist, _sendlist, _timerlist
local _socketlist, _closelist
local _readtimes, _writetimes

--// simple data types //--

local _
local _readlistlen, _sendlistlen, _timerlistlen
local _sendtraffic, _readtraffic
local _selecttimeout, _sleeptime, _tcpbacklog
local _starttime, _currenttime
local _maxsendlen, _maxreadlen
local _checkinterval, _sendtimeout, _readtimeout
local _timer
local _maxselectlen, _maxfd
local _maxsslhandshake
122
123 ----------------------------------// DEFINITION //--
124
-- Registries shared by every handler in this module.
_server = { }     -- listening server handlers, keyed by "addr:port"
_readlist = { }   -- sockets to watch for readability (array plus socket -> index map)
_sendlist = { }   -- sockets to watch for writability (array plus socket -> index map)
_timerlist = { }  -- timer functions
_socketlist = { } -- socket -> the handler wrapping it
_readtimes = { }  -- handler -> timestamp of its last data read
_writetimes = { } -- handler -> timestamp of its last data writing/sending
_closelist = { }  -- handler -> reason, for handlers to close at the end of a cycle

-- Current lengths of _readlist, _sendlist and _timerlist.
_readlistlen, _sendlistlen, _timerlistlen = 0, 0, 0

-- Module-wide traffic statistics.
_sendtraffic, _readtraffic = 0, 0

_selecttimeout = 1 -- timeout passed to socket.select
_sleeptime = 0     -- time to wait at the end of every loop
_tcpbacklog = 128  -- backlog hint passed to socket.bind

_maxsendlen = 51000 * 1024 -- max bytes queued in one connection's send buffer
_maxreadlen = 25000 * 1024 -- max bytes accepted from a single read

_checkinterval = 1200000   -- interval in secs to check idle clients
_sendtimeout = 60000       -- allowed send idle time in secs
_readtimeout = 6 * 60 * 60 -- allowed read idle time in secs
151
-- Windows is recognised by its directory separator ("\") in package.config.
local is_windows = package.config:sub(1,1) == "\\"
-- Highest fd number accepted. Outside Windows it is capped at luasocket._SETSIZE
-- (presumably FD_SETSIZE) or 1024, so select() cannot overflow glibc's fd_set.
if is_windows then
	_maxfd = math.huge
else
	_maxfd = luasocket._SETSIZE or 1024
end
-- How many sockets may be handed to select(); this cap applies on Windows too.
_maxselectlen = luasocket._SETSIZE or 1024

_maxsslhandshake = 30 -- max TLS handshake round-trips
157
158 ----------------------------------// PRIVATE //--
159
--- Wrap a bound listening socket in a server handler.
-- Each accepted client is wrapped with wrapconnection(). Plain connections are
-- announced through listeners.onconnect right away. SSL connections are
-- announced once their handshake completes.
-- @param listeners listener table (onconnect, ondisconnect, ...)
-- @param socket the bound listening socket
-- @param ip address the socket is bound to (also used to rebind on resume)
-- @param serverport port the socket is bound to
-- @param pattern receive pattern handed to accepted connections
-- @param sslctx optional SSL context for accepted connections
-- @return the server handler, or nil, "fd-too-large" (the socket is closed)
wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx )

	if socket:getfd() >= _maxfd then
		out_error("server.lua: Disallowed FD number: "..socket:getfd())
		socket:close()
		return nil, "fd-too-large"
	end

	local connections = 0

	local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect

	local accept = socket.accept

	--// public methods of the object //--

	local handler = { }

	handler.shutdown = function( ) end

	handler.ssl = function( )
		return sslctx ~= nil
	end
	handler.sslctx = function( )
		return sslctx
	end
	-- Called when one of this server's connections closes. A paused server
	-- (e.g. one that was full) gets a chance to accept again.
	handler.remove = function( )
		connections = connections - 1
		if handler then
			handler.resume( )
		end
	end
	-- Stop listening and unregister. The socket may already be gone, either
	-- after a hard pause or after a previous close. In that case only the
	-- bookkeeping is cleared.
	handler.close = function()
		if socket then
			socket:close( )
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			_socketlist[ socket ] = nil
		end
		_server[ip..":"..serverport] = nil;
		handler = nil
		socket = nil
		out_put "server.lua: closed server handler and removed sockets from list"
	end
	-- Stop accepting clients. A hard pause also closes the listening socket;
	-- resume() binds a fresh one.
	handler.pause = function( hard )
		if not handler.paused then
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			if hard then
				_socketlist[ socket ] = nil
				socket:close( )
				socket = nil;
			end
			handler.paused = true;
		end
	end
	-- Start accepting again, rebinding the socket if a hard pause closed it.
	-- @return nil and the bind error if rebinding fails (the server stays paused)
	handler.resume = function( )
		if handler.paused then
			if not socket then
				local newsocket, err = socket_bind( ip, serverport, _tcpbacklog )
				if not newsocket then
					out_error( "server.lua: failed to rebind [", tostring( ip ), "]:", tostring( serverport ), ": ", tostring( err ) )
					return nil, err
				end
				socket = newsocket
				socket:settimeout( 0 )
			end
			_readlistlen = addsocket(_readlist, socket, _readlistlen)
			_socketlist[ socket ] = handler
			handler.paused = false;
		end
	end
	handler.ip = function( )
		return ip
	end
	handler.serverport = function( )
		return serverport
	end
	handler.socket = function( )
		return socket
	end
	-- Accept one pending client. This runs when the listening socket is readable.
	handler.readbuffer = function( )
		if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then
			handler.pause( )
			out_put( "server.lua: refused new client connection: server full" )
			return false
		end
		local client, err = accept( socket )	-- try to accept
		if client then
			local clientip, clientport = client:getpeername( )
			local clienthandler, _, wraperr = wrapconnection( handler, listeners, client, clientip, serverport, clientport, pattern, sslctx ) -- wrap new client socket
			if wraperr then -- error while wrapping ssl socket
				return false
			end
			connections = connections + 1
			out_put( "server.lua: accepted new client connection from ", tostring(clientip), ":", tostring(clientport), " to ", tostring(serverport))
			if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes
				return dispatch( clienthandler );
			end
			return;
		elseif err then -- maybe timeout or something else
			out_put( "server.lua: error with new client connection: ", tostring(err) )
			return false
		end
	end
	return handler
end
260
--- Wrap a connected client socket in a connection handler.
-- The returned handler buffers outgoing data and registers the socket with the
-- select() bookkeeping tables. When `sslctx` is given and LuaSec is available,
-- it immediately starts TLS negotiation.
-- @param server owning server handler, or nil; told when this connection closes
-- @param listeners listener table (onincoming, ondisconnect, onstatus, ondrain, onconnect)
-- @param socket the connected client socket
-- @param ip peer address
-- @param serverport local port
-- @param clientport peer port
-- @param pattern receive pattern passed to the socket's receive()
-- @param sslctx optional SSL context
-- @return handler and socket on success, or nil, nil, err on failure
wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx )

	if socket:getfd() >= _maxfd then
		out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
		socket:close( ) -- Should we send some kind of error here?
		if server then
			server.pause( )
		end
		return nil, nil, "fd-too-large"
	end
	socket:settimeout( 0 )

	--// local import of socket methods //--

	local send
	local receive
	local shutdown

	--// private closures of the object //--

	local ssl -- true once TLS has been started on this connection

	local dispatch = listeners.onincoming
	local status = listeners.onstatus
	local disconnect = listeners.ondisconnect
	local drain = listeners.ondrain

	local bufferqueue = { } -- queued outgoing chunks, at indices 1..bufferqueuelen
	local bufferqueuelen = 0 -- number of queued chunks

	local toclose   -- set when close() had to wait for the send buffer to drain
	local fatalerror
	local needtls   -- set when starttls() had to wait for the send buffer to drain

	local bufferlen = 0 -- total bytes queued in bufferqueue

	local noread = false
	local nosend = false

	local sendtraffic, readtraffic = 0, 0

	local maxsendlen = _maxsendlen
	local maxreadlen = _maxreadlen

	--// public methods of the object //--

	-- The handler table doubles as the buffer queue: queued data lives at
	-- integer indices, methods at string keys.
	local handler = bufferqueue -- saves a table ^_^

	handler.dispatch = function( )
		return dispatch
	end
	handler.disconnect = function( )
		return disconnect
	end
	-- Swap in a new set of listener callbacks.
	handler.setlistener = function( self, listeners )
		dispatch = listeners.onincoming
		disconnect = listeners.ondisconnect
		status = listeners.onstatus
		drain = listeners.ondrain
	end
	handler.getstats = function( )
		return readtraffic, sendtraffic
	end
	handler.ssl = function( )
		return ssl
	end
	handler.sslctx = function ( )
		return sslctx
	end
	handler.send = function( _, data, i, j )
		return send( socket, data, i, j )
	end
	-- NOTE: unlike send(), receive() and shutdown() take no self argument.
	handler.receive = function( pattern, prefix )
		return receive( socket, pattern, prefix )
	end
	handler.shutdown = function( pattern )
		return shutdown( socket, pattern )
	end
	handler.setoption = function (self, option, value)
		if socket.setoption then
			return socket:setoption(option, value);
		end
		return false, "setoption not implemented";
	end
	-- Close right away, discarding anything still queued for sending.
	handler.force_close = function ( self, err )
		if bufferqueuelen ~= 0 then
			out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport))
			bufferqueuelen = 0;
		end
		return self:close(err);
	end
	-- Close the connection. If queued data cannot be flushed right now, further
	-- writes are refused and the close is deferred (returns false). Otherwise
	-- the socket is closed, the disconnect listener fires, and true is returned.
	handler.close = function( self, err )
		if not handler then return true; end
		_readlistlen = removesocket( _readlist, socket, _readlistlen )
		_readtimes[ handler ] = nil
		if bufferqueuelen ~= 0 then
			handler.sendbuffer() -- Try now to send any outstanding data
			if not handler then
				-- The send attempt closed us re-entrantly: either a fatal write
				-- error or a previously deferred close completing. All cleanup,
				-- including server.remove(), has already run, so don't repeat it.
				return true
			end
			if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later
				handler.write = nil -- ... but no further writing allowed
				toclose = true
				return false
			end
		end
		if socket then
			_ = shutdown and shutdown( socket )
			socket:close( )
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_socketlist[ socket ] = nil
			socket = nil
		else
			out_put "server.lua: socket already closed"
		end
		if handler then
			_writetimes[ handler ] = nil
			_closelist[ handler ] = nil
			local _handler = handler;
			handler = nil
			if disconnect then
				disconnect(_handler, err or false);
				disconnect = nil
			end
		end
		if server then
			server.remove( )
		end
		out_put "server.lua: closed client handler and removed socket from list"
		return true
	end
	handler.ip = function( )
		return ip
	end
	handler.serverport = function( )
		return serverport
	end
	handler.clientport = function( )
		return clientport
	end
	-- Queue data and make sure select() watches the socket for writability.
	-- If the queue would exceed maxsendlen, the connection is scheduled for
	-- closing and further writes are refused.
	local write = function( self, data )
		bufferlen = bufferlen + #data
		if bufferlen > maxsendlen then
			_closelist[ handler ] = "send buffer exceeded"	-- cannot close the client at the moment, have to wait to the end of the cycle
			handler.write = idfalse -- dont write anymore
			return false
		elseif socket and not _sendlist[ socket ] then
			_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
		end
		bufferqueuelen = bufferqueuelen + 1
		bufferqueue[ bufferqueuelen ] = data
		if handler then
			_writetimes[ handler ] = _writetimes[ handler ] or _currenttime
		end
		return true
	end
	handler.write = write
	handler.bufferqueue = function( self )
		return bufferqueue
	end
	handler.socket = function( self )
		return socket
	end
	handler.set_mode = function( self, new )
		pattern = new or pattern
		return pattern
	end
	handler.set_send = function ( self, newsend )
		send = newsend or send
		return send
	end
	-- Optionally update the buffer limits. Returns the bytes currently queued
	-- and both limits.
	handler.bufferlen = function( self, readlen, sendlen )
		maxsendlen = sendlen or maxsendlen
		maxreadlen = readlen or maxreadlen
		return bufferlen, maxreadlen, maxsendlen
	end
	--TODO: Deprecate
	-- lock_read(true) stops watching the socket for reads; lock_read(false)
	-- undoes that. Returns whether reading is currently locked.
	handler.lock_read = function (self, switch)
		if switch == true then
			local tmp = _readlistlen
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			_readtimes[ handler ] = nil
			if _readlistlen ~= tmp then
				noread = true
			end
		elseif switch == false then
			if noread then
				noread = false
				_readlistlen = addsocket(_readlist, socket, _readlistlen)
				_readtimes[ handler ] = _currenttime
			end
		end
		return noread
	end
	handler.pause = function (self)
		return self:lock_read(true);
	end
	handler.resume = function (self)
		return self:lock_read(false);
	end
	-- Lock (true) or unlock (false) both reading and writing.
	handler.lock = function( self, switch )
		handler:lock_read( switch ) -- method call: lock_read's first parameter is self
		if switch == true then
			handler.write = idfalse
			local tmp = _sendlistlen
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_writetimes[ handler ] = nil
			if _sendlistlen ~= tmp then
				nosend = true
			end
		elseif switch == false then
			handler.write = write
			if nosend then
				nosend = false
				write( handler, "" ) -- re-register the socket with the send list
			end
		end
		return noread, nosend
	end
	-- Read once from the socket and pass the data to the onincoming listener.
	local _readbuffer = function( ) -- this function reads data
		local buffer, err, part = receive( socket, pattern )	-- receive buffer with "pattern"
		if not err or (err == "wantread" or err == "timeout") then -- received something
			local buffer = buffer or part or ""
			local len = #buffer
			if len > maxreadlen then
				handler:close( "receive buffer exceeded" )
				return false
			end
			local count = len * STAT_UNIT
			readtraffic = readtraffic + count
			_readtraffic = _readtraffic + count
			_readtimes[ handler ] = _currenttime
			return dispatch( handler, buffer, err )
		else	-- connections was closed or fatal error
			out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
			fatalerror = true
			_ = handler and handler:force_close( err )
			return false
		end
	end
	-- Send as much queued data as the socket accepts. After a complete flush,
	-- the socket leaves the send list, ondrain fires, and any deferred
	-- starttls or close is carried out.
	local _sendbuffer = function( ) -- this function sends data
		local succ, err, byte, buffer, count;
		if socket then
			buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
			succ, err, byte = send( socket, buffer, 1, bufferlen )
			count = ( succ or byte or 0 ) * STAT_UNIT
			sendtraffic = sendtraffic + count
			_sendtraffic = _sendtraffic + count
			for i = bufferqueuelen,1,-1 do
				bufferqueue[ i ] = nil
			end
		else
			succ, err, count = false, "unexpected close", 0;
		end
		if succ then	-- sending successful
			bufferqueuelen = 0
			bufferlen = 0
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
			_writetimes[ handler ] = nil
			if drain then
				drain(handler)
			end
			_ = needtls and handler:starttls(nil)
			_ = toclose and handler:force_close( )
			return true
		elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
			buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer
			bufferqueue[ 1 ] = buffer	 -- insert new buffer in queue
			bufferqueuelen = 1
			bufferlen = bufferlen - byte
			_writetimes[ handler ] = _currenttime
			return true
		else	-- connection was closed during sending or fatal error
			out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
			fatalerror = true
			_ = handler and handler:force_close( err )
			return false
		end
	end

	-- Set the sslctx
	local handshake;
	-- Install a new SSL context and (re)create the coroutine that drives
	-- client:dohandshake() across select() wakeups, for at most
	-- _maxsslhandshake rounds.
	function handler.set_sslctx(self, new_sslctx)
		sslctx = new_sslctx;
		local read, wrote
		handshake = coroutine_wrap( function( client ) -- create handshake coroutine
				local err
				for i = 1, _maxsslhandshake do
					-- Undo the select() registration made for the previous round.
					_sendlistlen = ( wrote and removesocket( _sendlist, client, _sendlistlen ) ) or _sendlistlen
					_readlistlen = ( read and removesocket( _readlist, client, _readlistlen ) ) or _readlistlen
					read, wrote = nil, nil
					_, err = client:dohandshake( )
					if not err then
						out_put( "server.lua: ssl handshake done" )
						handler.readbuffer = _readbuffer	-- when handshake is done, replace the handshake function with regular functions
						handler.sendbuffer = _sendbuffer
						_ = status and status( handler, "ssl-handshake-complete" )
						if self.autostart_ssl and listeners.onconnect then
							listeners.onconnect(self);
						end
						if handler then -- a callback above may have closed the connection
							-- The loop above may have dropped the socket from the send
							-- list while data was queued, either during the handshake
							-- or by onconnect. Re-register it so that data is flushed.
							if bufferqueuelen ~= 0 then
								_sendlistlen = addsocket(_sendlist, client, _sendlistlen)
							end
							_readlistlen = addsocket(_readlist, client, _readlistlen)
						end
						return true
					else
						if err == "wantwrite" then
							_sendlistlen = addsocket(_sendlist, client, _sendlistlen)
							wrote = true
						elseif err == "wantread" then
							_readlistlen = addsocket(_readlist, client, _readlistlen)
							read = true
						else
							break;
						end
						err = nil;
						coroutine_yield( ) -- handshake not finished
					end
				end
				out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") )
				_ = handler and handler:force_close("ssl handshake failed")
				return false, err -- handshake failed
			end
		)
	end
	if luasec then
		-- Upgrade the connection to TLS.
		-- While data is queued, the upgrade is deferred and nothing is returned.
		-- If the socket cannot be wrapped, returns nil, err and leaves the plain
		-- socket in place. Otherwise returns the result of the first handshake step.
		handler.starttls = function( self, _sslctx)
			if _sslctx then
				handler:set_sslctx(_sslctx);
			end
			if bufferqueuelen > 0 then
				out_put "server.lua: we need to do tls, but delaying until send buffer empty"
				needtls = true
				return
			end
			out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
			local oldsocket = socket
			local newsocket, err = ssl_wrap( oldsocket, sslctx )	-- wrap socket
			if not newsocket then
				-- Keep `socket` pointing at the still-registered plain socket so
				-- the connection can still be closed properly.
				out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
				return nil, err -- fatal error
			end
			socket = newsocket
			socket:settimeout( 0 )

			-- add the new socket to our system
			send = socket.send
			receive = socket.receive
			shutdown = id
			_socketlist[ socket ] = handler
			_readlistlen = addsocket(_readlist, socket, _readlistlen)

			-- remove traces of the old socket
			_readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
			_sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
			_socketlist[ oldsocket ] = nil

			handler.starttls = nil
			needtls = nil

			-- Secure now (if handshake fails connection will close)
			ssl = true

			handler.readbuffer = handshake
			handler.sendbuffer = handshake
			return handshake( socket ) -- do handshake
		end
	end

	handler.readbuffer = _readbuffer
	handler.sendbuffer = _sendbuffer
	send = socket.send
	receive = socket.receive
	shutdown = ( ssl and id ) or socket.shutdown

	_socketlist[ socket ] = handler
	_readlistlen = addsocket(_readlist, socket, _readlistlen)

	if sslctx and luasec then
		out_put "server.lua: auto-starting ssl negotiation..."
		handler.autostart_ssl = true;
		local ok, err = handler:starttls(sslctx);
		if ok == false then
			-- The handshake failed and the handler has already force-closed
			-- itself. Always report an error so the caller notices.
			return nil, nil, err or "ssl handshake failed"
		elseif ok == nil and err then
			-- The socket could not be wrapped at all. Tear the plain connection
			-- down rather than leave it registered without TLS.
			handler:force_close( err )
			return nil, nil, err
		end
	end

	return handler, socket
end
648
-- No-op placeholder: ignores its arguments and returns no values.
id = function( )
	return
end
651
-- Always yields false; installed as handler.write to refuse further writes.
idfalse = function( )
	local refused = false
	return refused
end
655
--- Append a socket to an indexed list unless it is already present.
-- The list is both an array (1..len) and a reverse map socket -> position.
-- @return the (possibly unchanged) list length
addsocket = function( list, socket, len )
	if list[ socket ] then
		return len
	end
	local slot = len + 1
	list[ slot ] = socket
	list[ socket ] = slot
	return slot
end
664
--- Remove a socket from an indexed list in O(1) by moving the last entry into
-- its slot (approach copied from copas).
-- @return the new list length, unchanged if the socket was not present
removesocket = function( list, socket, len )
	local slot = list[ socket ]
	if not slot then
		return len
	end
	list[ socket ] = nil
	local tail = list[ len ]
	list[ len ] = nil
	if tail ~= socket then
		list[ tail ] = slot
		list[ slot ] = tail
	end
	return len - 1
end
679
--- Unregister a socket from both select lists and the socket table, then close it.
closesocket = function( sock )
	_readlistlen = removesocket( _readlist, sock, _readlistlen )
	_sendlistlen = removesocket( _sendlist, sock, _sendlistlen )
	_socketlist[ sock ] = nil
	sock:close( )
end
687
--- Throttle sender's reading on receiver's send buffer.
-- Reading from sender is locked while receiver holds buffersize bytes or more
-- of unsent data. It is unlocked again once receiver drains below that. Both
-- wrapped functions keep returning whatever the originals returned.
-- @param sender handler whose reads are throttled (needs readbuffer, lock_read)
-- @param receiver handler whose send buffer is watched (needs sendbuffer, bufferlen)
-- @param buffersize byte threshold at which the sender is locked
local function link(sender, receiver, buffersize)
	local sender_locked;

	-- Unlock the sender once the receiver has drained; passes results through.
	local function after_send( ... )
		if sender_locked and receiver.bufferlen() < buffersize then
			sender:lock_read(false); -- Unlock now
			sender_locked = nil;
		end
		return ...
	end

	-- Lock the sender while the receiver is backed up; passes results through.
	local function after_read( ... )
		if not sender_locked and receiver.bufferlen() >= buffersize then
			sender_locked = true;
			sender:lock_read(true);
		end
		return ...
	end

	local _sendbuffer = receiver.sendbuffer;
	function receiver.sendbuffer( ... )
		return after_send( _sendbuffer( ... ) )
	end

	local _readbuffer = sender.readbuffer;
	function sender.readbuffer( ... )
		return after_read( _readbuffer( ... ) )
	end
end
708
709 ----------------------------------// PUBLIC //--
710
--- Bind a listening socket on addr:port and register it with the loop.
-- @param addr interface address; defaults to "*" (all interfaces)
-- @param port port number in 0..65535
-- @param listeners listener table for accepted connections
-- @param pattern receive pattern for accepted connections
-- @param sslctx optional SSL context (requires LuaSec)
-- @return the server handler, or nil and an error message
addserver = function( addr, port, listeners, pattern, sslctx )
	-- Default the address first. It is used below to build the registry key
	-- and error messages, and concatenating a nil address would raise.
	addr = addr or "*"
	local err
	if type( listeners ) ~= "table" then
		err = "invalid listener table"
	end
	if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
		err = "invalid port"
	elseif _server[ addr..":"..port ] then
		err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist"
	elseif sslctx and not luasec then
		err = "luasec not found"
	end
	if err then
		-- tostring(): port may be nil or a non-number at this point.
		out_error( "server.lua, [", addr, "]:", tostring( port ), ": ", err )
		return nil, err
	end
	local server, bind_err = socket_bind( addr, port, _tcpbacklog )
	if bind_err then
		out_error( "server.lua, [", addr, "]:", port, ": ", bind_err )
		return nil, bind_err
	end
	local handler, wrap_err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket
	if not handler then
		server:close( )
		return nil, wrap_err
	end
	server:settimeout( 0 )
	_readlistlen = addsocket(_readlist, server, _readlistlen)
	_server[ addr..":"..port ] = handler
	_socketlist[ server ] = handler
	out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" )
	return handler
end
745
--- Look up the handler registered for a listening server.
-- @param addr bind address used with addserver(); nil means "*", matching
--   addserver()'s default
-- @param port port number
-- @return the server handler, or nil if none is registered
getserver = function ( addr, port )
	-- Default and stringify so a nil argument yields a miss instead of a concat error
	return _server[ ( addr or "*" )..":"..tostring( port ) ];
end
749
--- Close and unregister a listening server.
-- @param addr bind address used with addserver(); nil means "*", matching
--   addserver()'s default
-- @param port port number
-- @return true on success, or nil plus an error message if no server is registered
removeserver = function( addr, port )
	addr = addr or "*"
	-- tostring() keeps a nil/odd port from raising inside the concatenation
	local key = addr..":"..tostring( port )
	local handler = _server[ key ]
	if not handler then
		return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'"
	end
	handler:close( )
	_server[ key ] = nil
	return true
end
759
--- Close every registered socket and reset all server bookkeeping.
-- Calls close() on each handler in _socketlist and removes its entry. It then
-- replaces the server, read, send and timer lists with empty tables and
-- zeroes their length counters.
closeall = function( )
	for sock, handler in pairs( _socketlist ) do
		handler:close( )
		_socketlist[ sock ] = nil -- clearing existing keys during pairs() is allowed
	end
	_readlistlen = 0
	_sendlistlen = 0
	_timerlistlen = 0
	_server = { }
	_readlist = { }
	_sendlist = { }
	_timerlist = { }
	-- Recreate _socketlist weak-keyed, like the table set up at load time;
	-- a plain { } here would make it hold strong references to sockets.
	_socketlist = use "setmetatable" ( { }, { __mode = "k" } )
	--mem_free( )
end
775
--- Snapshot the current server settings.
-- @return a fresh table mapping each public setting name to its current value
getsettings = function( )
	local settings = { }
	settings.select_timeout = _selecttimeout
	settings.select_sleep_time = _sleeptime
	settings.tcp_backlog = _tcpbacklog
	settings.max_send_buffer_size = _maxsendlen
	settings.max_receive_buffer_size = _maxreadlen
	settings.select_idle_check_interval = _checkinterval
	settings.send_timeout = _sendtimeout
	settings.read_timeout = _readtimeout
	settings.max_connections = _maxselectlen
	settings.max_ssl_handshake_roundtrips = _maxsslhandshake
	settings.highest_allowed_fd = _maxfd
	return settings
end
791
--- Update server settings from a table using getsettings()' key names.
-- Each value is passed through tonumber(); a missing or non-numeric value
-- leaves the current setting unchanged.
-- @param new table of settings
-- @return true, or nil plus an error message if `new` is not a table
changesettings = function( new )
	if type( new ) ~= "table" then
		return nil, "invalid settings table"
	end
	_selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
	_sleeptime = tonumber( new.select_sleep_time ) or _sleeptime
	_maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
	_maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
	_checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
	_tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog
	_sendtimeout = tonumber( new.send_timeout ) or _sendtimeout
	_readtimeout = tonumber( new.read_timeout ) or _readtimeout
	-- These three previously skipped tonumber(), so a string from config was
	-- stored verbatim and broke later numeric comparisons.
	_maxselectlen = tonumber( new.max_connections ) or _maxselectlen
	_maxsslhandshake = tonumber( new.max_ssl_handshake_roundtrips ) or _maxsslhandshake
	_maxfd = tonumber( new.highest_allowed_fd ) or _maxfd
	return true
end
809
--- Append a callback to the timer list fired by the main loop.
-- @param listener function called with the current time
-- @return true, or nil plus an error message if `listener` is not a function
addtimer = function( listener )
	if type( listener ) == "function" then
		local slot = _timerlistlen + 1
		_timerlist[ slot ] = listener
		_timerlistlen = slot
		return true
	end
	return nil, "invalid listener function"
end
818
--- Traffic counters and list sizes.
-- @return read traffic, send traffic, read-list length, send-list length, timer count
stats = function( )
	local readtraffic, sendtraffic = _readtraffic, _sendtraffic
	return readtraffic, sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
end
822
-- Stop flag read by loop(): loop() itself sets it to "once" for a single
-- pass; setquitting() stores a plain boolean.
local quitting;

--- Ask the main loop to stop (truthy `quit`) or clear a pending stop request.
local function setquitting(quit)
	if quit then
		quitting = true;
	else
		quitting = false;
	end
end
828
--- Run the select()-based event loop.
-- Each pass, in order:
-- select() on the read/send lists; call sendbuffer() on writable sockets;
-- call readbuffer() on readable ones; disconnect handlers queued in
-- _closelist; every _checkinterval seconds, enforce the send/read timeouts;
-- fire timers; sleep for _sleeptime.
-- @param once when truthy, run a single pass (this is what step() does)
-- @return "quitting" if a stop was requested (immediately, if one is already
--   pending on entry); nothing after a single pass started with `once`
loop = function(once) -- this is the main loop of the program
	if quitting then return "quitting"; end
	if once then quitting = "once"; end
	-- Smallest delay requested by the timers, measured from _timer (when they last fired)
	local next_timer_time = math_huge;
	-- How long select() may block before the next timer is due; recomputed at the end of each pass
	local timer_wait = math_huge;
	repeat
		local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, timer_wait) )
		for i, socket in ipairs( write ) do -- send data waiting in writequeues
			local handler = _socketlist[ socket ]
			if handler then
				handler.sendbuffer( )
			else
				closesocket( socket )
				out_put "server.lua: found no handler and closed socket (writelist)"	-- this should not happen
			end
		end
		for i, socket in ipairs( read ) do -- receive data
			local handler = _socketlist[ socket ]
			if handler then
				handler.readbuffer( )
			else
				closesocket( socket )
				out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
			end
		end
		for handler, err in pairs( _closelist ) do
			handler.disconnect( )( handler, err )
			handler:force_close()	 -- forced disconnect
			_closelist[ handler ] = nil;
		end
		_currenttime = luasocket_gettime( )

		-- Check for socket timeouts
		local difftime = os_difftime( _currenttime - _starttime )
		if difftime > _checkinterval then
			_starttime = _currenttime
			for handler, timestamp in pairs( _writetimes ) do
				if os_difftime( _currenttime - timestamp ) > _sendtimeout then
					--_writetimes[ handler ] = nil
					handler.disconnect( )( handler, "send timeout" )
					handler:force_close()	 -- forced disconnect
				end
			end
			for handler, timestamp in pairs( _readtimes ) do
				if os_difftime( _currenttime - timestamp ) > _readtimeout then
					--_readtimes[ handler ] = nil
					handler.disconnect( )( handler, "read timeout" )
					handler:close( )	-- forced disconnect?
				end
			end
		end

		-- Fire timers: at least once a second, or once the smallest requested delay has elapsed
		local since_timer = _currenttime - _timer
		if since_timer >= math_min(next_timer_time, 1) then
			next_timer_time = math_huge;
			for i = 1, _timerlistlen do
				local t = _timerlist[ i ]( _currenttime ) -- fire timers
				if t then next_timer_time = math_min(next_timer_time, t); end
			end
			_timer = _currenttime
			timer_wait = next_timer_time
		else
			-- Keep next_timer_time relative to _timer and only shrink the select()
			-- wait. Subtracting the elapsed time from next_timer_time itself would
			-- count the same interval again on every pass and fire timers early.
			timer_wait = next_timer_time - since_timer
		end
		-- LuaSocket's select() blocks indefinitely on a negative timeout
		if timer_wait < 0 then timer_wait = 0 end

		-- wait some time (0 by default)
		socket_sleep( _sleeptime )
	until quitting;
	if once and quitting == "once" then quitting = nil; return; end
	return "quitting"
end
898
--- Run exactly one pass of the main loop.
-- @return whatever loop(true) returns
local function step()
	local single_pass = true
	return loop(single_pass)
end
902
--- Name of this network backend implementation.
local get_backend = function()
	local backend_name = "select"
	return backend_name
end
906
907 --// EXPERIMENTAL //--
908
--- Register an outgoing (client) socket with the select loop.
-- Wraps `socket` via wrapconnection() and records the handler in _socketlist.
-- For non-TLS sockets, the socket is also put in the send list. If
-- `listeners.onconnect` is set, the handler's first sendbuffer() call does
-- three things, once: it restores the original sendbuffer, calls onconnect,
-- then flushes any queued data. The loop makes that call when select()
-- reports the socket writable.
-- @return handler and the socket value returned by wrapconnection(), or nil plus an error
local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
	local handler, conn, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
	if not handler then
		return nil, err
	end
	_socketlist[ conn ] = handler
	if sslctx then
		return handler, conn
	end
	_sendlistlen = addsocket( _sendlist, conn, _sendlistlen )
	if listeners.onconnect then
		-- One-shot hook: writability of a connecting socket means the connect finished
		local original_sendbuffer = handler.sendbuffer
		handler.sendbuffer = function ()
			handler.sendbuffer = original_sendbuffer
			listeners.onconnect( handler )
			return original_sendbuffer( ) -- Send any queued outgoing data
		end
	end
	return handler, conn
end
927
--- Open an outgoing TCP connection to address:port and hand it to the select loop.
-- The socket is made non-blocking before connecting.
-- * If connect() reports the connection as still in progress, the socket goes
--   through wrapclient(), so onconnect fires once it becomes writable.
-- * If connect() succeeds immediately, the socket is wrapped via wrapconnection().
-- * If connect() fails outright, the socket is closed instead of being wrapped.
-- @return the values returned by wrapclient()/wrapconnection() (handler, socket),
--   or nil plus an error message
local addclient = function( address, port, listeners, pattern, sslctx )
	local client, err = luasocket.tcp( )
	if not client then
		return nil, err
	end
	client:settimeout( 0 )
	local ok, connect_err = client:connect( address, port )
	if ok then
		return wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx )
	end
	-- With a zero timeout LuaSocket reports a pending connect as "timeout"
	if connect_err == "timeout" or connect_err == "Operation already in progress" then
		return wrapclient( client, address, port, listeners, pattern, sslctx )
	end
	-- Hard failure (e.g. refused): don't wrap a dead socket or leak its descriptor
	client:close( )
	return nil, connect_err
end
941
942 --// EXPERIMENTAL //--
943
944 ----------------------------------// BEGIN //--
945
-- Make the socket->handler and handler->timestamp tables weak-keyed, so an
-- entry here does not by itself keep its key alive.
do
	local setmetatable = use "setmetatable"
	setmetatable( _socketlist, { __mode = "k" } )
	setmetatable( _readtimes, { __mode = "k" } )
	setmetatable( _writetimes, { __mode = "k" } )
end

-- Reference times: _timer for firing timers, _starttime for the periodic timeout checks
_timer = luasocket_gettime( )
_starttime = luasocket_gettime( )
952
--- Swap the logger used by this module.
-- A nil/false argument leaves the current logger in place.
-- @return the logger that was active before the call
local function setlogger(new_logger)
	local previous = log;
	log = new_logger or log;
	return previous;
end
960
961 ----------------------------------// PUBLIC INTERFACE //--
962
return {
	-- event loop control
	loop = loop;
	step = step;
	setquitting = setquitting;

	-- listening servers
	addserver = addserver;
	getserver = getserver;
	removeserver = removeserver;
	closeall = closeall;

	-- outgoing connections (experimental)
	addclient = addclient;
	wrapclient = wrapclient;

	-- flow control between two connections
	link = link;

	-- configuration, logging and introspection
	getsettings = getsettings;
	changesettings = changesettings;
	stats = stats;
	setlogger = setlogger;
	get_backend = get_backend;

	-- timers
	_addtimer = addtimer;
}