net.server_select/event: Switch sender mode to *a when reading, to make sure we get...
[prosody.git] / net / server_select.lua
1 -- 
2 -- server.lua by blastbeat of the luadch project
3 -- Re-used here under the MIT/X Consortium License
4 -- 
5 -- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain
6 --
7
8 -- // wrapping luadch stuff // --
9
-- Look up a value in the global environment by name; yields nil when no such global exists.
local function use( what )
	local value = _G[ what ]
	return value
end
13
local log = require ("util.logger").init("socket")
local table_concat = table.concat
-- Debug-level log line; every argument is concatenated, so each must be a string or number.
local function out_put (...) return log("debug", table_concat{...}); end
-- Warning-level log line with the same argument rules as out_put.
local function out_error (...) return log("warn", table_concat{...}); end
17
----------------------------------// DECLARATION //--

--// constants //--

local STAT_UNIT = 1 -- traffic statistics are counted in bytes

--// lua functions (captured from _G) //--

local type, pairs, ipairs = use "type", use "pairs", use "ipairs"
local tonumber, tostring = use "tonumber", use "tostring"

--// lua libs //--

local os, table = use "os", use "table"
local string, coroutine = use "string", use "coroutine"

--// lua lib methods //--

local os_difftime = os.difftime
local math_min, math_huge = math.min, math.huge
local table_concat, string_sub = table.concat, string.sub
local coroutine_wrap, coroutine_yield = coroutine.wrap, coroutine.yield

--// extern libs //--

local luasec = use "ssl" -- only set when a global named "ssl" exists; nil otherwise
local luasocket = use "socket" or require "socket"
local luasocket_gettime = luasocket.gettime

--// extern lib methods //--

local ssl_wrap = ( luasec and luasec.wrap ) -- nil/false when LuaSec is unavailable
local socket_bind, socket_sleep, socket_select = luasocket.bind, luasocket.sleep, luasocket.select
61
--// functions (forward declarations; assigned further down) //--

local id, idfalse, loop, stats
local addsocket, removesocket, closesocket, closeall
local addserver, getserver, removeserver, wrapserver
local addtimer, getsettings, changesettings, wrapconnection

--// tables //--

local _server, _socketlist, _closelist
local _readlist, _sendlist, _timerlist
local _readtimes, _writetimes

--// simple data types //--

local _ -- scratch variable for discarded expression results

local _readlistlen, _sendlistlen, _timerlistlen

local _sendtraffic, _readtraffic

local _selecttimeout, _sleeptime, _tcpbacklog

local _starttime, _currenttime

local _maxsendlen, _maxreadlen

local _checkinterval, _sendtimeout, _readtimeout

local _timer

local _maxselectlen, _maxfd

local _maxsslhandshake
122
----------------------------------// DEFINITION //--

-- Socket lists hold each socket twice: as a value at index 1..len and as a key
-- mapping back to that index (see addsocket/removesocket).
_server = { }      -- key = "addr:port", value = listening server handler
_readlist = { }    -- sockets to read from
_sendlist = { }    -- sockets with output waiting to be written
_timerlist = { }   -- timer functions
_socketlist = { }  -- key = socket, value = the handler wrapping it
_readtimes = { }   -- key = handler, value = time data was last read
_writetimes = { }  -- key = handler, value = time of the last write activity
_closelist = { }   -- handlers waiting to be closed

_readlistlen, _sendlistlen, _timerlistlen = 0, 0, 0 -- current lengths of the three lists above

_sendtraffic, _readtraffic = 0, 0 -- global traffic statistics

_selecttimeout = 1 -- timeout passed to socket.select
_sleeptime = 0 -- time to sleep at the end of every loop iteration
_tcpbacklog = 128 -- listen backlog passed to socket.bind (a hint to the OS)

_maxsendlen = 51000 * 1024 -- maximum size of a connection's send buffer
_maxreadlen = 25000 * 1024 -- maximum size of a single read

_checkinterval = 1200000 -- interval (secs) between idle-client checks
_sendtimeout = 60000 -- allowed send idle time (secs)
_readtimeout = 6 * 60 * 60 -- allowed read idle time (secs)

-- Windows uses "\" as its directory separator; that is how the platform is detected here.
local is_windows = package.config:sub(1,1) == "\\"
-- Highest allowed fd number: capped at _SETSIZE (default 1024) to avoid glibc fd_set overflows; unlimited on Windows.
_maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024
-- The number of sockets select() may watch is still capped, on Windows too.
_maxselectlen = luasocket._SETSIZE or 1024

_maxsslhandshake = 30 -- maximum number of TLS handshake round-trips
157
158 ----------------------------------// PRIVATE //--
159
--- Wrap a listening socket in a server handler object.
-- The caller (addserver) registers the socket in _readlist/_socketlist/_server.
-- @param listeners  callback table; onconnect is invoked for each accepted plain connection
-- @param socket     bound, listening LuaSocket server socket
-- @param ip         address the socket is bound to
-- @param serverport port the socket is bound to
-- @param pattern    receive pattern handed to accepted connections
-- @param sslctx     optional LuaSec context for accepted connections
-- @return the handler, or nil and "fd-too-large" (the socket is closed in that case)
wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd

	if socket:getfd() >= _maxfd then
		out_error("server.lua: Disallowed FD number: "..socket:getfd())
		socket:close()
		return nil, "fd-too-large"
	end

	local connections = 0

	local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect

	local accept = socket.accept

	--// public methods of the object //--

	local handler = { }

	handler.shutdown = function( ) end

	handler.ssl = function( )
		return sslctx ~= nil
	end
	handler.sslctx = function( )
		return sslctx
	end
	-- Called by a client handler when it closes: frees a slot and resumes
	-- accepting in case the server was paused because it was full.
	handler.remove = function( )
		connections = connections - 1
		if handler then
			handler.resume( )
		end
	end
	-- Close the listening socket and unregister the server. Safe to call
	-- more than once and after a hard pause (when no socket is held).
	handler.close = function()
		if not handler then
			return -- already closed
		end
		if socket then
			socket:close( )
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			_socketlist[ socket ] = nil
		end
		_server[ip..":"..serverport] = nil;
		handler = nil
		socket = nil
		out_put "server.lua: closed server handler and removed sockets from list"
	end
	-- Stop accepting connections. A hard pause also closes the listening
	-- socket; resume() re-binds it.
	handler.pause = function( hard )
		if handler and not handler.paused then
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			if hard then
				_socketlist[ socket ] = nil
				socket:close( )
				socket = nil;
			end
			handler.paused = true;
		end
	end
	-- Start accepting again. If the socket was released by a hard pause it is
	-- re-bound; when that fails the server stays paused and nil, err is returned.
	handler.resume = function( )
		if handler and handler.paused then
			if not socket then
				local newsocket, err = socket_bind( ip, serverport, _tcpbacklog );
				if not newsocket then
					out_error( "server.lua: failed to re-bind server [", tostring(ip), "]:", tostring(serverport), ": ", tostring(err) )
					return nil, err
				end
				newsocket:settimeout( 0 )
				socket = newsocket
			end
			_readlistlen = addsocket(_readlist, socket, _readlistlen)
			_socketlist[ socket ] = handler
			handler.paused = false;
		end
	end
	handler.ip = function( )
		return ip
	end
	handler.serverport = function( )
		return serverport
	end
	handler.socket = function( )
		return socket
	end
	-- Invoked when the listening socket is readable: accept and wrap one client.
	handler.readbuffer = function( )
		if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then
			handler.pause( )
			out_put( "server.lua: refused new client connection: server full" )
			return false
		end
		local client, err = accept( socket )	-- try to accept
		if client then
			local ip, clientport = client:getpeername( )
			local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket
			if err then -- error while wrapping ssl socket
				return false
			end
			connections = connections + 1
			out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))
			if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes
				return dispatch( handler );
			end
			return;
		elseif err then -- maybe timeout or something else
			out_put( "server.lua: error with new client connection: ", tostring(err) )
			return false
		end
	end
	return handler
end
260
--- Wrap a connected client socket in a connection handler object.
-- The returned handler table doubles as the send-buffer array (bufferqueue).
-- @param server      handler of the server that accepted the socket, or nil
-- @param listeners   callbacks: onincoming, onstatus, ondisconnect, ondrain
--                    (onconnect is used after an automatic TLS handshake)
-- @param socket      connected LuaSocket client socket
-- @param ip, serverport, clientport  address details exposed through the handler
-- @param pattern     receive pattern passed to receive()
-- @param sslctx      optional LuaSec context; when given (and LuaSec is loaded)
--                    TLS negotiation starts immediately
-- @return handler, socket on success; nil, nil, err on failure
wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object

	-- sockets whose fd is >= _maxfd cannot be watched, so refuse them and stop accepting
	if socket:getfd() >= _maxfd then
		out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
		socket:close( ) -- Should we send some kind of error here?
		if server then
			server.pause( )
		end
		return nil, nil, "fd-too-large"
	end
	socket:settimeout( 0 )

	--// local import of socket methods //--

	local send
	local receive
	local shutdown

	--// private closures of the object //--

	local ssl

	local dispatch = listeners.onincoming
	local status = listeners.onstatus
	local disconnect = listeners.ondisconnect
	local drain = listeners.ondrain

	local bufferqueue = { } -- buffer array
	local bufferqueuelen = 0	-- end of buffer array

	local toclose
	local fatalerror
	local needtls

	local bufferlen = 0

	local noread = false
	local nosend = false

	local sendtraffic, readtraffic = 0, 0

	local maxsendlen = _maxsendlen
	local maxreadlen = _maxreadlen

	--// public methods of the object //--

	local handler = bufferqueue -- saves a table ^_^

	handler.dispatch = function( )
		return dispatch
	end
	handler.disconnect = function( )
		return disconnect
	end
	handler.setlistener = function( self, listeners )
		dispatch = listeners.onincoming
		disconnect = listeners.ondisconnect
		status = listeners.onstatus
		drain = listeners.ondrain
	end
	handler.getstats = function( )
		return readtraffic, sendtraffic
	end
	handler.ssl = function( )
		return ssl
	end
	handler.sslctx = function ( )
		return sslctx
	end
	handler.send = function( _, data, i, j )
		return send( socket, data, i, j )
	end
	handler.receive = function( pattern, prefix )
		return receive( socket, pattern, prefix )
	end
	handler.shutdown = function( pattern )
		return shutdown( socket, pattern )
	end
	handler.setoption = function (self, option, value)
		if socket.setoption then
			return socket:setoption(option, value);
		end
		return false, "setoption not implemented";
	end
	-- Close immediately, discarding any data still queued for sending.
	handler.force_close = function ( self, err )
		if bufferqueuelen ~= 0 then
			out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport))
			bufferqueuelen = 0;
		end
		return self:close(err);
	end
	-- Close the connection. If queued output cannot be flushed right away the
	-- close is deferred (toclose) and false is returned; _sendbuffer finishes
	-- it once the buffer drains. Returns true once the connection is closed.
	handler.close = function( self, err )
		if not handler then return true; end
		_readlistlen = removesocket( _readlist, socket, _readlistlen )
		_readtimes[ handler ] = nil
		if bufferqueuelen ~= 0 then
			handler.sendbuffer() -- Try now to send any outstanding data
			if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later
				if handler then
					-- ... but no further writing allowed; use idfalse (as the other
					-- write lock-outs do) so late writers get false instead of a crash
					handler.write = idfalse
				end
				toclose = true
				return false
			end
		end
		if socket then
			_ = shutdown and shutdown( socket )
			socket:close( )
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_socketlist[ socket ] = nil
			socket = nil
		else
			out_put "server.lua: socket already closed"
		end
		if handler then
			_writetimes[ handler ] = nil
			_closelist[ handler ] = nil
			local _handler = handler;
			handler = nil
			if disconnect then
				disconnect(_handler, err or false);
				disconnect = nil
			end
		end
		if server then
			server.remove( )
		end
		out_put "server.lua: closed client handler and removed socket from list"
		return true
	end
	handler.ip = function( )
		return ip
	end
	handler.serverport = function( )
		return serverport
	end
	handler.clientport = function( )
		return clientport
	end
	handler.port = handler.clientport -- COMPAT server_event
	-- Queue data for sending; schedules the connection for closing when the
	-- send buffer limit (maxsendlen) is exceeded.
	local write = function( self, data )
		bufferlen = bufferlen + #data
		if bufferlen > maxsendlen then
			_closelist[ handler ] = "send buffer exceeded"	 -- cannot close the client at the moment, have to wait to the end of the cycle
			handler.write = idfalse -- dont write anymore
			return false
		elseif socket and not _sendlist[ socket ] then
			_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
		end
		bufferqueuelen = bufferqueuelen + 1
		bufferqueue[ bufferqueuelen ] = data
		if handler then
			_writetimes[ handler ] = _writetimes[ handler ] or _currenttime
		end
		return true
	end
	handler.write = write
	handler.bufferqueue = function( self )
		return bufferqueue
	end
	handler.socket = function( self )
		return socket
	end
	handler.set_mode = function( self, new )
		pattern = new or pattern
		return pattern
	end
	handler.set_send = function ( self, newsend )
		send = newsend or send
		return send
	end
	handler.bufferlen = function( self, readlen, sendlen )
		maxsendlen = sendlen or maxsendlen
		maxreadlen = readlen or maxreadlen
		return bufferlen, maxreadlen, maxsendlen
	end
	--TODO: Deprecate
	-- true: stop watching the socket for reads; false: watch it again.
	-- Returns whether reading is currently locked.
	handler.lock_read = function (self, switch)
		if switch == true then
			local tmp = _readlistlen
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			_readtimes[ handler ] = nil
			if _readlistlen ~= tmp then
				noread = true
			end
		elseif switch == false then
			if noread then
				noread = false
				_readlistlen = addsocket(_readlist, socket, _readlistlen)
				_readtimes[ handler ] = _currenttime
			end
		end
		return noread
	end
	handler.pause = function (self)
		return self:lock_read(true);
	end
	handler.resume = function (self)
		return self:lock_read(false);
	end
	-- Lock (true) or unlock (false) both reading and writing.
	handler.lock = function( self, switch )
		-- lock_read takes (self, switch); passing only switch made this a no-op
		handler.lock_read( self, switch )
		if switch == true then
			handler.write = idfalse
			local tmp = _sendlistlen
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_writetimes[ handler ] = nil
			if _sendlistlen ~= tmp then
				nosend = true
			end
		elseif switch == false then
			handler.write = write
			if nosend then
				nosend = false
				-- write takes (self, data); the empty write re-adds the socket to the send list
				write( self, "" )
			end
		end
		return noread, nosend
	end
	local _readbuffer = function( ) -- this function reads data
		local buffer, err, part = receive( socket, pattern )	-- receive buffer with "pattern"
		if not err or (err == "wantread" or err == "timeout") then -- received something
			local buffer = buffer or part or ""
			local len = #buffer
			if len > maxreadlen then
				handler:close( "receive buffer exceeded" )
				return false
			end
			local count = len * STAT_UNIT
			readtraffic = readtraffic + count
			_readtraffic = _readtraffic + count
			_readtimes[ handler ] = _currenttime
			--out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err )
			return dispatch( handler, buffer, err )
		else	-- connections was closed or fatal error
			out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
			fatalerror = true
			_ = handler and handler:force_close( err )
			return false
		end
	end
	local _sendbuffer = function( ) -- this function sends data
		local succ, err, byte, buffer, count;
		if socket then
			buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
			succ, err, byte = send( socket, buffer, 1, bufferlen )
			count = ( succ or byte or 0 ) * STAT_UNIT
			sendtraffic = sendtraffic + count
			_sendtraffic = _sendtraffic + count
			for i = bufferqueuelen,1,-1 do
				bufferqueue[ i ] = nil
			end
			--out_put( "server.lua: sent '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )
		else
			succ, err, count = false, "unexpected close", 0;
		end
		if succ then	-- sending successful
			bufferqueuelen = 0
			bufferlen = 0
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
			_writetimes[ handler ] = nil
			if drain then
				drain(handler)
			end
			_ = needtls and handler:starttls(nil)
			_ = toclose and handler:force_close( )
			return true
		elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
			buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer
			bufferqueue[ 1 ] = buffer	 -- insert new buffer in queue
			bufferqueuelen = 1
			bufferlen = bufferlen - byte
			_writetimes[ handler ] = _currenttime
			return true
		else	-- connection was closed during sending or fatal error
			out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
			fatalerror = true
			_ = handler and handler:force_close( err )
			return false
		end
	end

	-- Set the sslctx and (re)create the handshake coroutine. Each resume runs
	-- one dohandshake() step; it gives up after _maxsslhandshake round-trips.
	local handshake;
	function handler.set_sslctx(self, new_sslctx)
		sslctx = new_sslctx;
		local read, wrote
		handshake = coroutine_wrap( function( client ) -- create handshake coroutine
				local err
				for i = 1, _maxsslhandshake do
					_sendlistlen = ( wrote and removesocket( _sendlist, client, _sendlistlen ) ) or _sendlistlen
					_readlistlen = ( read and removesocket( _readlist, client, _readlistlen ) ) or _readlistlen
					read, wrote = nil, nil
					_, err = client:dohandshake( )
					if not err then
						out_put( "server.lua: ssl handshake done" )
						handler.readbuffer = _readbuffer	-- when handshake is done, replace the handshake function with regular functions
						handler.sendbuffer = _sendbuffer
						_ = status and status( handler, "ssl-handshake-complete" )
						if self.autostart_ssl and listeners.onconnect then
							listeners.onconnect(self);
						end
						_readlistlen = addsocket(_readlist, client, _readlistlen)
						return true
					else
						if err == "wantwrite" then
							_sendlistlen = addsocket(_sendlist, client, _sendlistlen)
							wrote = true
						elseif err == "wantread" then
							_readlistlen = addsocket(_readlist, client, _readlistlen)
							read = true
						else
							break;
						end
						err = nil;
						coroutine_yield( ) -- handshake not finished
					end
				end
				out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") )
				_ = handler and handler:force_close("ssl handshake failed")
				return false, err -- handshake failed
			end
		)
	end
	if luasec then
		-- Upgrade the connection to TLS. Returns the first handshake step's
		-- result: true (done), false, err (failed) or nothing (in progress).
		-- Returns nil, err when the socket cannot be wrapped; the plain socket
		-- is kept in that case so the handler can still be closed cleanly.
		handler.starttls = function( self, _sslctx)
			if _sslctx then
				handler:set_sslctx(_sslctx);
			end
			if bufferqueuelen > 0 then
				out_put "server.lua: we need to do tls, but delaying until send buffer empty"
				needtls = true
				return
			end
			out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
			local oldsocket = socket
			local sslsocket, err = ssl_wrap( oldsocket, sslctx )	-- wrap socket
			if not sslsocket then
				out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
				return nil, err or "unknown error" -- fatal error
			end
			socket = sslsocket

			socket:settimeout( 0 )

			-- add the new socket to our system
			send = socket.send
			receive = socket.receive
			shutdown = id
			_socketlist[ socket ] = handler
			_readlistlen = addsocket(_readlist, socket, _readlistlen)

			-- remove traces of the old socket
			_readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
			_sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
			_socketlist[ oldsocket ] = nil

			handler.starttls = nil
			needtls = nil

			-- Secure now (if handshake fails connection will close)
			ssl = true

			handler.readbuffer = handshake
			handler.sendbuffer = handshake
			return handshake( socket ) -- do handshake
		end
	end

	handler.readbuffer = _readbuffer
	handler.sendbuffer = _sendbuffer
	send = socket.send
	receive = socket.receive
	shutdown = ( ssl and id ) or socket.shutdown

	_socketlist[ socket ] = handler
	_readlistlen = addsocket(_readlist, socket, _readlistlen)

	if sslctx and luasec then
		out_put "server.lua: auto-starting ssl negotiation..."
		handler.autostart_ssl = true;
		local ok, err = handler:starttls(sslctx);
		if ok == false then
			-- handshake failed outright; the handshake coroutine already closed the handler
			return nil, nil, err
		elseif ok == nil and err then
			-- ssl_wrap() failed before any handshake; an in-progress handshake returns no values at all
			handler:force_close( err )
			return nil, nil, err
		end
	end

	return handler, socket
end
649
-- No-op function: accepts anything, returns nothing (used e.g. as a dummy shutdown).
function id( ) end
652
-- Always returns false; installed as handler.write to refuse further writes.
function idfalse( )
	local refused = false
	return refused
end
656
-- Append socket to a socket list unless already present. The list stores the
-- socket at index len+1 and maps socket -> index. Returns the new length.
addsocket = function( list, socket, len )
	if list[ socket ] then
		return len -- already listed
	end
	local slot = len + 1
	list[ slot ] = socket
	list[ socket ] = slot
	return slot
end
665
-- Remove socket from a socket list in O(1) by moving the last entry into its
-- slot (order is not preserved). Returns the new length; unchanged when the
-- socket is not in the list. (Technique copied from copas.)
removesocket = function( list, socket, len )
	local pos = list[ socket ]
	if not pos then
		return len
	end
	list[ socket ] = nil
	local tail = list[ len ]
	list[ len ] = nil
	if tail ~= socket then
		-- fill the hole with the former last element
		list[ tail ] = pos
		list[ pos ] = tail
	end
	return len - 1
end
680
-- Drop a raw socket from every registry, then close it.
closesocket = function( socket )
	_readlistlen = removesocket( _readlist, socket, _readlistlen )
	_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
	_socketlist[ socket ] = nil
	socket:close( )
end
688
-- Apply flow control between two handlers: reading from sender is locked while
-- receiver has buffersize or more bytes queued, and unlocked once its buffer
-- drops below that. The sender is switched to the "*a" receive mode.
local function link(sender, receiver, buffersize)
	local sender_locked;

	local original_sendbuffer = receiver.sendbuffer;
	receiver.sendbuffer = function()
		original_sendbuffer();
		if not sender_locked then return end
		if receiver.bufferlen() < buffersize then
			sender:lock_read(false); -- receiver has drained; let the sender read again
			sender_locked = nil;
		end
	end

	local original_readbuffer = sender.readbuffer;
	sender.readbuffer = function()
		original_readbuffer();
		if sender_locked then return end
		if receiver.bufferlen() >= buffersize then
			sender_locked = true;
			sender:lock_read(true);
		end
	end

	sender:set_mode("*a");
end
710
711 ----------------------------------// PUBLIC //--
712
--- Bind a listening socket on addr:port and register a server handler for it.
-- @param addr      interface address; defaults to "*"
-- @param port      port number in the range 0..65535
-- @param listeners callback table (onconnect, onincoming, ondisconnect, ...)
-- @param pattern   receive pattern for accepted connections
-- @param sslctx    optional LuaSec context; accepted connections then start TLS
-- @return the server handler, or nil and an error message
addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
	-- Default the address before it is used in the registry key and log
	-- messages; concatenating a nil addr would raise instead of reporting.
	addr = addr or "*"
	local err
	-- the first failing check determines the reported error
	if type( listeners ) ~= "table" then
		err = "invalid listener table"
	elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
		err = "invalid port"
	elseif _server[ addr..":"..port ] then
		err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist"
	elseif sslctx and not luasec then
		err = "luasec not found"
	end
	if err then
		-- tostring: out_error concatenates its arguments, which fails on nil/tables
		out_error( "server.lua, [", tostring( addr ), "]:", tostring( port ), ": ", err )
		return nil, err
	end
	local server
	server, err = socket_bind( addr, port, _tcpbacklog )
	if not server then
		out_error( "server.lua, [", addr, "]:", port, ": ", tostring( err ) )
		return nil, err
	end
	local handler
	handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket
	if not handler then
		server:close( )
		return nil, err
	end
	server:settimeout( 0 )
	_readlistlen = addsocket(_readlist, server, _readlistlen)
	_server[ addr..":"..port ] = handler
	_socketlist[ server ] = handler
	out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" )
	return handler
end
747
--- Look up the handler registered for a listening server.
-- @param addr address the server was registered under
-- @param port port the server was registered under
-- @return the handler stored under "addr:port", or nil
getserver = function ( addr, port )
	local key = addr .. ":" .. port
	return _server[ key ]
end
751
--- Close and unregister a listening server.
-- @param addr address the server was registered under
-- @param port port the server was registered under
-- @return true on success, or nil and an error message if nothing is registered
removeserver = function( addr, port )
	local key = addr .. ":" .. port
	local handler = _server[ key ]
	if handler then
		handler:close( )
		_server[ key ] = nil
		return true
	end
	return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'"
end
761
--- Close every registered handler and reset all bookkeeping state.
-- Calls :close() on each handler in _socketlist, zeroes the list lengths and
-- replaces the server/read/send/timer/socket tables with fresh empty ones.
closeall = function( )
	for socket, handler in pairs( _socketlist ) do
		handler:close( )
		_socketlist[ socket ] = nil -- clearing existing keys during pairs() is allowed
	end
	_readlistlen = 0
	_sendlistlen = 0
	_timerlistlen = 0
	_server = { }
	_readlist = { }
	_sendlist = { }
	_timerlist = { }
	-- Recreate with the same weak-key mode the module installs at load time;
	-- a plain { } would keep closed sockets and their handlers alive.
	_socketlist = use "setmetatable"( { }, { __mode = "k" } )
	--mem_free( )
end
777
--- Snapshot the current server settings.
-- @return a new table; its keys are the ones changesettings() accepts
getsettings = function( )
	return {
		-- timing
		select_timeout = _selecttimeout;
		select_sleep_time = _sleeptime;
		select_idle_check_interval = _checkinterval;
		send_timeout = _sendtimeout;
		read_timeout = _readtimeout;
		-- buffers and sockets
		tcp_backlog = _tcpbacklog;
		max_send_buffer_size = _maxsendlen;
		max_receive_buffer_size = _maxreadlen;
		-- limits
		max_connections = _maxselectlen;
		max_ssl_handshake_roundtrips = _maxsslhandshake;
		highest_allowed_fd = _maxfd;
	}
end
793
--- Update server settings from a table.
-- Keys match those returned by getsettings(). Every value is passed through
-- tonumber(); a missing or non-numeric value keeps the current setting.
-- @param new table of settings
-- @return true, or nil and an error message when `new` is not a table
changesettings = function( new )
	if type( new ) ~= "table" then
		return nil, "invalid settings table"
	end
	_selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
	_sleeptime = tonumber( new.select_sleep_time ) or _sleeptime
	_maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
	_maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
	_checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
	_tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog
	_sendtimeout = tonumber( new.send_timeout ) or _sendtimeout
	_readtimeout = tonumber( new.read_timeout ) or _readtimeout
	-- These three used to be stored without conversion, so a string value
	-- (e.g. from a config file) was kept as a string. They are presumably used
	-- as numeric limits elsewhere in the file, so they are coerced like the rest.
	_maxselectlen = tonumber( new.max_connections ) or _maxselectlen
	_maxsslhandshake = tonumber( new.max_ssl_handshake_roundtrips ) or _maxsslhandshake
	_maxfd = tonumber( new.highest_allowed_fd ) or _maxfd
	return true
end
811
--- Append a function to the list of timers fired by the main loop.
-- @param listener function called with the current time
-- @return true, or nil and an error message when listener is not a function
addtimer = function( listener )
	if type( listener ) == "function" then
		local slot = _timerlistlen + 1
		_timerlistlen = slot
		_timerlist[ slot ] = listener
		return true
	end
	return nil, "invalid listener function"
end
820
--- Report traffic counters and list sizes.
-- @return read traffic, send traffic, read-list length, send-list length,
--   timer-list length
stats = function( )
	local readtraffic, sendtraffic = _readtraffic, _sendtraffic
	return readtraffic, sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
end
824
-- Loop control flag read by loop(): falsy while the loop may run, true once
-- setquitting(true) was called, or "once" during a single-step loop() call.
local quitting;

--- Request the main loop to stop (truthy argument) or allow it to run again.
-- @param quit any value; stored as a strict boolean
local function setquitting(quit)
	if quit then
		quitting = true;
	else
		quitting = false;
	end
end
830
--- Main event loop of the select() backend.
-- Each iteration performs these steps:
-- * waits in select(), bounded by _selecttimeout and the time left until the
--   timers are due;
-- * services writable and then readable sockets, and finishes queued closes;
-- * enforces send/read timeouts every _checkinterval seconds;
-- * fires the timers once min(requested delay, 1s) has passed since their
--   last run.
-- @param once when truthy, run a single iteration and return
-- @return "quitting" if the loop stopped because quitting was requested;
--   nothing after a completed single step
loop = function(once) -- this is the main loop of the program
	if quitting then return "quitting"; end
	if once then quitting = "once"; end
	-- Delay (in seconds, counted from _timer) that the timers asked for at their
	-- last run; math_huge until some timer requests a specific delay.
	local next_timer_time = math_huge;
	repeat
		-- Time left until the timers are due. The requested delay is measured from
		-- _timer and never decremented in place, so elapsed time is not subtracted
		-- repeatedly. Clamp at 0: LuaSocket treats a negative select() timeout as
		-- "block indefinitely".
		local timer_wait = next_timer_time - ( luasocket_gettime( ) - _timer )
		if timer_wait < 0 then timer_wait = 0 end
		local read, write = socket_select( _readlist, _sendlist, math_min( _selecttimeout, timer_wait ) )
		for _, socket in ipairs( write ) do -- send data waiting in writequeues
			local handler = _socketlist[ socket ]
			if handler then
				handler.sendbuffer( )
			else
				closesocket( socket )
				out_put "server.lua: found no handler and closed socket (writelist)"	-- this should not happen
			end
		end
		for _, socket in ipairs( read ) do -- receive data
			local handler = _socketlist[ socket ]
			if handler then
				handler.readbuffer( )
			else
				closesocket( socket )
				out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
			end
		end
		for handler, err in pairs( _closelist ) do
			handler.disconnect( )( handler, err )
			handler:force_close()	-- forced disconnect
			_closelist[ handler ] = nil;
		end
		_currenttime = luasocket_gettime( )

		-- Check for socket timeouts. Plain subtraction instead of os.difftime:
		-- difftime needs two integer time_t arguments on Lua 5.3+ and truncates to
		-- whole seconds on 5.1, while these are fractional gettime() values.
		if _currenttime - _starttime > _checkinterval then
			_starttime = _currenttime
			for handler, timestamp in pairs( _writetimes ) do
				if _currenttime - timestamp > _sendtimeout then
					--_writetimes[ handler ] = nil
					handler.disconnect( )( handler, "send timeout" )
					handler:force_close()	-- forced disconnect
				end
			end
			for handler, timestamp in pairs( _readtimes ) do
				if _currenttime - timestamp > _readtimeout then
					--_readtimes[ handler ] = nil
					handler.disconnect( )( handler, "read timeout" )
					-- NOTE(review): close() here vs force_close() for send timeouts; confirm intended
					handler:close( )	-- forced disconnect?
				end
			end
		end

		-- Fire timers once min(requested delay, 1s) has passed since their last run;
		-- each timer may return the delay it wants until it is called again.
		if _currenttime - _timer >= math_min( next_timer_time, 1 ) then
			next_timer_time = math_huge;
			for i = 1, _timerlistlen do
				local t = _timerlist[ i ]( _currenttime ) -- fire timers
				if t then next_timer_time = math_min( next_timer_time, t ); end
			end
			_timer = _currenttime
		end

		-- optional pause between iterations (_sleeptime seconds)
		socket_sleep( _sleeptime )
	until quitting;
	if once and quitting == "once" then quitting = nil; return; end
	return "quitting"
end
900
--- Run exactly one iteration of the main loop.
-- @return whatever loop(true) returns
local step = function()
	return loop(true);
end
904
--- Name of this network backend.
-- @return the string "select"
local function get_backend()
	local backend_name = "select";
	return backend_name;
end
908
909 --// EXPERIMENTAL //--
910
--- Wrap an outgoing (client) socket in a connection handler and register it.
-- For non-SSL sockets the socket is also added to the send list. If
-- listeners.onconnect exists, it is called once, the first time the socket
-- becomes writable, before any queued data is sent.
-- @return handler, socket on success; nil, err when wrapconnection fails
local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
	local handler, wrapped, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
	if not handler then
		return nil, err
	end
	_socketlist[ wrapped ] = handler
	if sslctx then
		return handler, wrapped
	end
	_sendlistlen = addsocket(_sendlist, wrapped, _sendlistlen)
	if listeners.onconnect then
		-- One-shot hook: restore the real sendbuffer, notify, then flush.
		local original_sendbuffer = handler.sendbuffer;
		handler.sendbuffer = function ()
			handler.sendbuffer = original_sendbuffer;
			listeners.onconnect(handler);
			return original_sendbuffer(); -- Send any queued outgoing data
		end
	end
	return handler, wrapped
end
929
--- Open a non-blocking TCP connection and wrap it as a client handler.
-- @param address remote host
-- @param port remote port
-- @param listeners listener table, passed on to wrapclient
-- @param pattern read pattern, passed on to wrapclient
-- @param sslctx optional SSL context, passed on to wrapclient
-- @return handler, socket on success (from wrapclient); nil, err on failure
local addclient = function( address, port, listeners, pattern, sslctx )
	local client, err = luasocket.tcp( )
	if err then
		return nil, err
	end
	client:settimeout( 0 )
	local ok
	ok, err = client:connect( address, port )
	-- With a zero timeout LuaSocket reports a connect that is still in progress
	-- as "timeout"; any other error is a real failure.
	if not ok and err ~= "timeout" then
		client:close( )
		return nil, err
	end
	-- Pending and already-established connections both go through wrapclient, so
	-- the handler is registered in _socketlist (and, for non-SSL sockets, queued
	-- for writability so onconnect can fire). pattern/sslctx are passed through.
	return wrapclient( client, address, port, listeners, pattern, sslctx )
end
943
944 --// EXPERIMENTAL //--
945
946 ----------------------------------// BEGIN //--
947
-- Make the bookkeeping tables weak-keyed, so an entry disappears once its key
-- (a socket in _socketlist, a handler in _readtimes/_writetimes) is collected.
do
	local setmetatable = use "setmetatable"
	local function make_weak_keyed( t )
		setmetatable( t, { __mode = "k" } )
	end
	make_weak_keyed( _socketlist )
	make_weak_keyed( _readtimes )
	make_weak_keyed( _writetimes )
end

-- Start both clocks now: _timer paces timer firing, _starttime paces timeout checks.
_timer = luasocket_gettime( )
_starttime = luasocket_gettime( )
954
--- Replace the logger used by this module.
-- @param new_logger new logging function; nil or false leaves the current one
-- @return the logger that was active before the call
local function setlogger(new_logger)
	local previous = log;
	log = new_logger or log;
	return previous;
end
962
963 ----------------------------------// PUBLIC INTERFACE //--
964
-- Module table: the functions other code may call on this backend.
return {
	_addtimer = addtimer,

	addclient = addclient,
	addserver = addserver,
	changesettings = changesettings,
	closeall = closeall,
	get_backend = get_backend,
	getserver = getserver,
	getsettings = getsettings,
	link = link,
	loop = loop,
	removeserver = removeserver,
	setlogger = setlogger,
	setquitting = setquitting,
	stats = stats,
	step = step,
	wrapclient = wrapclient,
}