Merge
[prosody.git] / net / server_select.lua
1 -- 
2 -- server.lua by blastbeat of the luadch project
3 -- Re-used here under the MIT/X Consortium License
4 -- 
5 -- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain
6 --
7
8 -- // wrapping luadch stuff // --
9
-- Fetch a value from the global environment by name.
-- @param what  name of the global to look up
-- @return the global's current value, or nil when no such global exists
local function use( what )
	local value = _G[ what ]
	return value
end
13
14 local log, table_concat = require ("util.logger").init("socket"), table.concat;
-- Concatenate all arguments into one message and log it at "debug" level.
local function out_put(...)
	local message = table_concat({ ... });
	return log("debug", message);
end
-- Concatenate all arguments into one message and log it at "warn" level.
local function out_error(...)
	local message = table_concat({ ... });
	return log("warn", message);
end
17
18 ----------------------------------// DECLARATION //--
19
20 --// constants //--
21
22 local STAT_UNIT = 1 -- byte
23
24 --// lua functions //--
25
26 local type = use "type"
27 local pairs = use "pairs"
28 local ipairs = use "ipairs"
29 local tonumber = use "tonumber"
30 local tostring = use "tostring"
31
32 --// lua libs //--
33
34 local os = use "os"
35 local table = use "table"
36 local string = use "string"
37 local coroutine = use "coroutine"
38
39 --// lua lib methods //--
40
41 local os_difftime = os.difftime
42 local math_min = math.min
43 local math_huge = math.huge
44 local table_concat = table.concat
45 local string_sub = string.sub
46 local coroutine_wrap = coroutine.wrap
47 local coroutine_yield = coroutine.yield
48
49 --// extern libs //--
50
51 local luasec = use "ssl"
52 local luasocket = use "socket" or require "socket"
53 local luasocket_gettime = luasocket.gettime
54
55 --// extern lib methods //--
56
57 local ssl_wrap = ( luasec and luasec.wrap )
58 local socket_bind = luasocket.bind
59 local socket_sleep = luasocket.sleep
60 local socket_select = luasocket.select
61
62 --// functions //--
63
64 local id
65 local loop
66 local stats
67 local idfalse
68 local closeall
69 local addsocket
70 local addserver
71 local addtimer
72 local getserver
73 local wrapserver
74 local getsettings
75 local closesocket
76 local removesocket
77 local removeserver
78 local wrapconnection
79 local changesettings
80
81 --// tables //--
82
83 local _server
84 local _readlist
85 local _timerlist
86 local _sendlist
87 local _socketlist
88 local _closelist
89 local _readtimes
90 local _writetimes
91
92 --// simple data types //--
93
94 local _
95 local _readlistlen
96 local _sendlistlen
97 local _timerlistlen
98
99 local _sendtraffic
100 local _readtraffic
101
102 local _selecttimeout
103 local _sleeptime
104 local _tcpbacklog
105
106 local _starttime
107 local _currenttime
108
109 local _maxsendlen
110 local _maxreadlen
111
112 local _checkinterval
113 local _sendtimeout
114 local _readtimeout
115
116 local _timer
117
118 local _maxselectlen
119 local _maxfd
120
121 local _maxsslhandshake
122
123 ----------------------------------// DEFINITION //--
124
-- Bookkeeping tables.
_server = { }      -- key = "addr:port" string, value = server handler; listening servers
_socketlist = { }  -- key = socket, value = wrapped socket (handler)
_readlist = { }    -- array of sockets to read from (also maps socket -> array index)
_sendlist = { }    -- array of sockets to write to (also maps socket -> array index)
_timerlist = { }   -- array of timer functions
_readtimes = { }   -- key = handler, value = timestamp of last data reading
_writetimes = { }  -- key = handler, value = timestamp of last data writing/sending
_closelist = { }   -- handlers to close

-- Lengths of the read, send and timer lists.
_readlistlen, _sendlistlen, _timerlistlen = 0, 0, 0

-- Traffic statistics.
_sendtraffic, _readtraffic = 0, 0

_selecttimeout = 1 -- timeout of socket.select
_sleeptime = 0     -- time to wait at the end of every loop
_tcpbacklog = 128  -- some kind of hint to the OS

_maxsendlen = 51000 * 1024 -- max length of send buffer
_maxreadlen = 25000 * 1024 -- max length of read buffer

_checkinterval = 1200000   -- interval in secs to check idle clients
_sendtimeout = 60000       -- allowed send idle time in secs
_readtimeout = 6 * 60 * 60 -- allowed read idle time in secs

-- Check the directory separator to determine whether this is Windows.
local is_windows = package.config:sub(1,1) == "\\"

-- Max fd number: limited to 1024 by default to prevent a glibc buffer overflow,
-- but not limited on Windows.
if is_windows then
	_maxfd = math.huge
else
	_maxfd = luasocket._SETSIZE or 1024
end
_maxselectlen = luasocket._SETSIZE or 1024 -- but this limit still applies on Windows

_maxsslhandshake = 30 -- max handshake round-trips
157
158 ----------------------------------// PRIVATE //--
159
-- Wraps a listening server socket in a handler object.
--
-- @param listeners   table of callbacks; `onconnect` is read here, and the table
--                    is passed on to wrapconnection for every accepted client
-- @param socket      the bound, listening socket
-- @param ip          address the server is bound to (used for the `_server` key
--                    and for re-binding after a hard pause)
-- @param serverport  port the server is bound to
-- @param pattern     read pattern handed to every accepted connection
-- @param sslctx      optional SSL context handed to every accepted connection
-- @return the server handler, or nil, "fd-too-large" when the socket's file
--         descriptor is not below `_maxfd` (the socket is closed in that case)
wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx )

	if socket:getfd() >= _maxfd then
		out_error("server.lua: Disallowed FD number: "..socket:getfd())
		socket:close()
		return nil, "fd-too-large"
	end

	local connections = 0 -- number of currently open client connections

	local dispatch = listeners.onconnect

	local accept = socket.accept

	--// public methods of the object //--

	local handler = { }

	handler.shutdown = function( ) end

	handler.ssl = function( )
		return sslctx ~= nil
	end
	handler.sslctx = function( )
		return sslctx
	end
	-- Called by a client handler when it closes: one connection slot is free
	-- again, so a paused server may resume accepting.
	handler.remove = function( )
		connections = connections - 1
		if handler then -- nil once the server itself has been closed
			handler.resume( )
		end
	end
	-- Closes the listening socket and unregisters the server everywhere.
	handler.close = function()
		-- socket is nil after a hard pause (or an earlier close); in that case
		-- there is nothing left to close or to remove from the lists.
		if socket then
			socket:close( )
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			_socketlist[ socket ] = nil
		end
		_server[ip..":"..serverport] = nil;
		handler = nil
		socket = nil
		--mem_free( )
		out_put "server.lua: closed server handler and removed sockets from list"
	end
	-- Stops accepting new connections. With a true `hard` argument the listening
	-- socket is closed as well and has to be re-bound by resume().
	-- Note: `hard` is the FIRST argument, so this must be called with a dot.
	handler.pause = function( hard )
		if not handler.paused then
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			if hard then
				_socketlist[ socket ] = nil
				socket:close( )
				socket = nil;
			end
			handler.paused = true;
		end
	end
	-- Starts accepting connections again after pause(); re-binds the listening
	-- socket first if it was closed by a hard pause.
	handler.resume = function( )
		if handler.paused then
			if not socket then
				local newsocket, err = socket_bind( ip, serverport, _tcpbacklog );
				if not newsocket then
					-- Stay paused rather than indexing a nil socket below.
					out_error( "server.lua: unable to re-bind server on '[", tostring(ip), "]:", tostring(serverport), "': ", tostring(err) )
					return
				end
				socket = newsocket
				socket:settimeout( 0 )
			end
			_readlistlen = addsocket(_readlist, socket, _readlistlen)
			_socketlist[ socket ] = handler
			handler.paused = false;
		end
	end
	handler.ip = function( )
		return ip
	end
	handler.serverport = function( )
		return serverport
	end
	handler.socket = function( )
		return socket
	end
	-- "Reading" from a listening socket means accepting a pending connection.
	handler.readbuffer = function( )
		if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then
			handler.pause( )
			out_put( "server.lua: refused new client connection: server full" )
			return false
		end
		local client, err = accept( socket )    -- try to accept
		if client then
			local clientip, clientport = client:getpeername( )
			-- wrap new client socket
			local clienthandler, _, wraperr = wrapconnection( handler, listeners, client, clientip, serverport, clientport, pattern, sslctx )
			if wraperr then -- error while wrapping ssl socket
				return false
			end
			connections = connections + 1
			out_put( "server.lua: accepted new client connection from ", tostring(clientip), ":", tostring(clientport), " to ", tostring(serverport))
			if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes
				return dispatch( clienthandler );
			end
			return;
		elseif err then -- maybe timeout or something else
			out_put( "server.lua: error with new client connection: ", tostring(err) )
			return false
		end
	end
	return handler
end
260
-- Wraps a connected client socket in a handler object.
--
-- The handler exposes buffered writing, reading/dispatching to the listener
-- callbacks, read/write locking and (when LuaSec is available) STARTTLS.
--
-- @param server      server handler that accepted this connection, or nil; its
--                    pause() / remove() are called on fd exhaustion / close
-- @param listeners   table of callbacks: onincoming, onstatus, ondisconnect,
--                    ondrain (and onconnect for auto-started SSL)
-- @param socket      the connected client socket
-- @param ip          peer address (used for log messages and handler.ip())
-- @param serverport  local port the connection arrived on
-- @param clientport  peer port
-- @param pattern     receive pattern passed to socket receive calls
-- @param sslctx      optional SSL context; when set and LuaSec is loaded the
--                    TLS handshake is started immediately
-- @return handler, socket on success;
--         nil, nil, err when the fd is not below `_maxfd` or when the
--         auto-started TLS negotiation reports `false`
wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx )

	if socket:getfd() >= _maxfd then
		out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
		socket:close( ) -- Should we send some kind of error here?
		if server then
			server.pause( )
		end
		return nil, nil, "fd-too-large"
	end
	socket:settimeout( 0 )

	--// local import of socket methods //--

	-- Assigned near the end of this function and re-assigned by starttls().
	local send
	local receive
	local shutdown

	--// private closures of the object //--

	local ssl -- set to true once starttls() has wrapped the socket

	local dispatch = listeners.onincoming
	local status = listeners.onstatus
	local disconnect = listeners.ondisconnect
	local drain = listeners.ondrain

	local bufferqueue = { } -- buffer array
	local bufferqueuelen = 0        -- end of buffer array

	local toclose    -- close() was requested while unsent data was still queued
	local fatalerror -- set when a read or write failed fatally
	local needtls    -- starttls() was requested while unsent data was still queued

	local bufferlen = 0 -- total number of queued bytes

	local noread = false
	local nosend = false

	local sendtraffic, readtraffic = 0, 0

	local maxsendlen = _maxsendlen
	local maxreadlen = _maxreadlen

	--// public methods of the object //--

	-- The handler object and the send queue are the same table: queued chunks
	-- live under the integer keys, methods under string keys.
	local handler = bufferqueue -- saves a table ^_^

	handler.dispatch = function( )
		return dispatch
	end
	handler.disconnect = function( )
		return disconnect
	end
	-- Replaces the listener callbacks used by this connection.
	handler.setlistener = function( self, listeners )
		dispatch = listeners.onincoming
		disconnect = listeners.ondisconnect
		status = listeners.onstatus
		drain = listeners.ondrain
	end
	-- Returns bytes read, bytes sent (scaled by STAT_UNIT) for this connection.
	handler.getstats = function( )
		return readtraffic, sendtraffic
	end
	handler.ssl = function( )
		return ssl
	end
	handler.sslctx = function ( )
		return sslctx
	end
	-- Direct, unbuffered pass-throughs to the underlying socket methods.
	handler.send = function( _, data, i, j )
		return send( socket, data, i, j )
	end
	-- Note: no `self` parameter here; the first argument is the pattern.
	handler.receive = function( pattern, prefix )
		return receive( socket, pattern, prefix )
	end
	-- Note: no `self` parameter here; the first argument is the pattern.
	handler.shutdown = function( pattern )
		return shutdown( socket, pattern )
	end
	handler.setoption = function (self, option, value)
		if socket.setoption then
			return socket:setoption(option, value);
		end
		return false, "setoption not implemented";
	end
	-- Closes the connection immediately, dropping any data still queued.
	handler.force_close = function ( self, err )
		if bufferqueuelen ~= 0 then
			out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport))
			bufferqueuelen = 0;
		end
		return self:close(err);
	end
	-- Closes the connection. If data is still queued, one send attempt is made;
	-- when data remains afterwards the close is deferred (returns false) until
	-- the queue has drained. Returns true once the connection is closed.
	handler.close = function( self, err )
		if not handler then return true; end -- already closed
		_readlistlen = removesocket( _readlist, socket, _readlistlen )
		_readtimes[ handler ] = nil
		if bufferqueuelen ~= 0 then
			handler.sendbuffer() -- Try now to send any outstanding data
			if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later
				if handler then -- sendbuffer() may have closed us via force_close()
					handler.write = nil -- ... but no further writing allowed
				end
				toclose = true
				return false
			end
		end
		if socket then
			_ = shutdown and shutdown( socket )
			socket:close( )
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_socketlist[ socket ] = nil
			socket = nil
		else
			out_put "server.lua: socket already closed"
		end
		if handler then
			_writetimes[ handler ] = nil
			_closelist[ handler ] = nil
			local _handler = handler;
			handler = nil -- cleared before the callback so re-entrant close() is a no-op
			if disconnect then
				disconnect(_handler, err or false);
				disconnect = nil
			end
		end
		if server then
			server.remove( )
		end
		out_put "server.lua: closed client handler and removed socket from list"
		return true
	end
	handler.ip = function( )
		return ip
	end
	handler.serverport = function( )
		return serverport
	end
	handler.clientport = function( )
		return clientport
	end
	handler.port = handler.clientport -- COMPAT server_event
	-- Queues `data` for sending and registers the socket in the send list.
	-- Returns true when queued, false when the send buffer limit was exceeded
	-- (the handler is then scheduled for closing via `_closelist`).
	local write = function( self, data )
		bufferlen = bufferlen + #data
		if bufferlen > maxsendlen then
			_closelist[ handler ] = "send buffer exceeded"   -- cannot close the client at the moment, have to wait to the end of the cycle
			handler.write = idfalse -- dont write anymore
			return false
		elseif socket and not _sendlist[ socket ] then
			_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
		end
		bufferqueuelen = bufferqueuelen + 1
		bufferqueue[ bufferqueuelen ] = data
		if handler then
			-- keep the timestamp of the oldest pending write
			_writetimes[ handler ] = _writetimes[ handler ] or _currenttime
		end
		return true
	end
	handler.write = write
	handler.bufferqueue = function( self )
		return bufferqueue
	end
	handler.socket = function( self )
		return socket
	end
	-- Sets (when `new` is given) and returns the receive pattern.
	handler.set_mode = function( self, new )
		pattern = new or pattern
		return pattern
	end
	-- Sets (when `newsend` is given) and returns the low-level send function.
	handler.set_send = function ( self, newsend )
		send = newsend or send
		return send
	end
	-- Optionally updates the read/send buffer limits; returns the number of
	-- queued bytes, the read limit and the send limit.
	handler.bufferlen = function( self, readlen, sendlen )
		maxsendlen = sendlen or maxsendlen
		maxreadlen = readlen or maxreadlen
		return bufferlen, maxreadlen, maxsendlen
	end
	--TODO: Deprecate
	-- switch == true: stop reading from the socket; switch == false: resume
	-- reading if it was stopped by this method. Returns the current lock state.
	handler.lock_read = function (self, switch)
		if switch == true then
			local tmp = _readlistlen
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			_readtimes[ handler ] = nil
			if _readlistlen ~= tmp then -- only locked if the socket really was in the read list
				noread = true
			end
		elseif switch == false then
			if noread then
				noread = false
				_readlistlen = addsocket(_readlist, socket, _readlistlen)
				_readtimes[ handler ] = _currenttime
			end
		end
		return noread
	end
	handler.pause = function (self)
		return self:lock_read(true);
	end
	handler.resume = function (self)
		return self:lock_read(false);
	end
	-- Locks (switch == true) or unlocks (switch == false) both reading and
	-- writing. Returns the read-lock and send-lock states.
	handler.lock = function( self, switch )
		-- lock_read takes the handler as its first argument; passing only the
		-- switch would shift it into `self` and leave the read state untouched.
		handler.lock_read( self, switch )
		if switch == true then
			handler.write = idfalse
			local tmp = _sendlistlen
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_writetimes[ handler ] = nil
			if _sendlistlen ~= tmp then
				nosend = true
			end
		elseif switch == false then
			handler.write = write
			if nosend then
				nosend = false
				-- An empty write puts the socket back into the send list; write()
				-- expects (self, data), so the handler must be passed explicitly.
				write( self, "" )
			end
		end
		return noread, nosend
	end
	local _readbuffer = function( ) -- this function reads data
		local buffer, err, part = receive( socket, pattern )    -- receive buffer with "pattern"
		if not err or (err == "wantread" or err == "timeout") then -- received something
			local buffer = buffer or part or ""
			local len = #buffer
			if len > maxreadlen then
				handler:close( "receive buffer exceeded" )
				return false
			end
			local count = len * STAT_UNIT
			readtraffic = readtraffic + count
			_readtraffic = _readtraffic + count
			_readtimes[ handler ] = _currenttime
			--out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err )
			return dispatch( handler, buffer, err )
		else    -- connections was closed or fatal error
			out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
			fatalerror = true
			_ = handler and handler:force_close( err )
			return false
		end
	end
	local _sendbuffer = function( ) -- this function sends data
		local succ, err, byte, buffer, count;
		if socket then
			-- flatten the queue into one string and try to send all of it
			buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
			succ, err, byte = send( socket, buffer, 1, bufferlen )
			count = ( succ or byte or 0 ) * STAT_UNIT
			sendtraffic = sendtraffic + count
			_sendtraffic = _sendtraffic + count
			for i = bufferqueuelen,1,-1 do
				bufferqueue[ i ] = nil
			end
			--out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )
		else
			succ, err, count = false, "unexpected close", 0;
		end
		if succ then    -- sending succesful
			bufferqueuelen = 0
			bufferlen = 0
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
			_writetimes[ handler ] = nil
			if drain then
				drain(handler)
			end
			-- run the actions that were deferred until the queue was empty
			_ = needtls and handler:starttls(nil)
			_ = toclose and handler:force_close( )
			return true
		elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
			buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer
			bufferqueue[ 1 ] = buffer        -- insert new buffer in queue
			bufferqueuelen = 1
			bufferlen = bufferlen - byte
			_writetimes[ handler ] = _currenttime
			return true
		else    -- connection was closed during sending or fatal error
			out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
			fatalerror = true
			_ = handler and handler:force_close( err )
			return false
		end
	end

	-- Set the sslctx
	-- Also (re)creates the handshake coroutine, which starttls() installs as the
	-- handler's readbuffer/sendbuffer until the handshake has finished.
	local handshake;
	function handler.set_sslctx(self, new_sslctx)
		sslctx = new_sslctx;
		local read, wrote
		handshake = coroutine_wrap( function( client ) -- create handshake coroutine
				local err
				for i = 1, _maxsslhandshake do
					-- undo the list registration made for the previous round-trip
					_sendlistlen = ( wrote and removesocket( _sendlist, client, _sendlistlen ) ) or _sendlistlen
					_readlistlen = ( read and removesocket( _readlist, client, _readlistlen ) ) or _readlistlen
					read, wrote = nil, nil
					_, err = client:dohandshake( )
					if not err then
						out_put( "server.lua: ssl handshake done" )
						handler.readbuffer = _readbuffer        -- when handshake is done, replace the handshake function with regular functions
						handler.sendbuffer = _sendbuffer
						_ = status and status( handler, "ssl-handshake-complete" )
						if self.autostart_ssl and listeners.onconnect then
							listeners.onconnect(self);
						end
						_readlistlen = addsocket(_readlist, client, _readlistlen)
						return true
					else
						if err == "wantwrite" then
							_sendlistlen = addsocket(_sendlist, client, _sendlistlen)
							wrote = true
						elseif err == "wantread" then
							_readlistlen = addsocket(_readlist, client, _readlistlen)
							read = true
						else
							break;
						end
						err = nil;
						coroutine_yield( ) -- handshake not finished
					end
				end
				out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") )
				_ = handler and handler:force_close("ssl handshake failed")
				return false, err -- handshake failed
			end
		)
	end
	if luasec then
		-- Upgrades the connection to TLS. If unsent data is queued the upgrade
		-- is deferred until the queue drains. Returns the result of the first
		-- handshake step, or nil, err when wrapping the socket fails.
		handler.starttls = function( self, _sslctx)
			if _sslctx then
				handler:set_sslctx(_sslctx);
			end
			if bufferqueuelen > 0 then
				out_put "server.lua: we need to do tls, but delaying until send buffer empty"
				needtls = true
				return
			end
			out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
			local oldsocket, err = socket
			socket, err = ssl_wrap( socket, sslctx )        -- wrap socket
			if not socket then
				-- NOTE(review): at this point `socket` has already been overwritten
				-- with nil, so the plain socket is only reachable via `oldsocket`.
				out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
				return nil, err -- fatal error
			end

			socket:settimeout( 0 )

			-- add the new socket to our system
			send = socket.send
			receive = socket.receive
			shutdown = id
			_socketlist[ socket ] = handler
			_readlistlen = addsocket(_readlist, socket, _readlistlen)

			-- remove traces of the old socket
			_readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
			_sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
			_socketlist[ oldsocket ] = nil

			handler.starttls = nil
			needtls = nil

			-- Secure now (if handshake fails connection will close)
			ssl = true

			handler.readbuffer = handshake
			handler.sendbuffer = handshake
			return handshake( socket ) -- do handshake
		end
	end

	handler.readbuffer = _readbuffer
	handler.sendbuffer = _sendbuffer
	send = socket.send
	receive = socket.receive
	shutdown = ( ssl and id ) or socket.shutdown

	_socketlist[ socket ] = handler
	_readlistlen = addsocket(_readlist, socket, _readlistlen)

	if sslctx and luasec then
		out_put "server.lua: auto-starting ssl negotiation..."
		handler.autostart_ssl = true;
		local ok, err = handler:starttls(sslctx);
		if ok == false then
			return nil, nil, err
		end
	end

	return handler, socket
end
649
-- No-op: accepts nothing and returns nothing.
function id( )
	return
end
652
-- Always returns false; installed in place of a handler's write() to reject writes.
function idfalse( )
	return false
end
656
-- Appends `socket` to `list` unless it is already a member.
-- `list` holds sockets under the integer keys 1..len and, for each socket,
-- its position under the socket itself as key.
-- @param list    the socket list to modify
-- @param socket  the socket to add
-- @param len     current number of sockets in the list
-- @return the new number of sockets in the list
function addsocket( list, socket, len )
	if list[ socket ] then
		return len; -- already registered, nothing to do
	end
	local slot = len + 1
	list[ slot ] = socket
	list[ socket ] = slot
	return slot;
end
665
-- Removes `socket` from `list` ( copied from copas ).
-- The last entry of the array is moved into the freed position so that the
-- array part stays free of holes.
-- @param list    the socket list to modify
-- @param socket  the socket to remove
-- @param len     current number of sockets in the list
-- @return the new number of sockets in the list (unchanged if not a member)
function removesocket( list, socket, len )
	local index = list[ socket ]
	if not index then
		return len
	end
	list[ socket ] = nil
	local tail = list[ len ]
	list[ len ] = nil
	if tail ~= socket then
		-- the removed socket was not the last one: fill the gap with the tail
		list[ tail ] = index
		list[ index ] = tail
	end
	return len - 1
end
680
-- Unregisters a raw socket from the read list, the send list and the socket
-- table, then closes it.
function closesocket( socket )
	_readlistlen = removesocket( _readlist, socket, _readlistlen )
	_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
	_socketlist[ socket ] = nil
	socket:close( )
	--mem_free( )
end
688
-- Couples two handlers so that `sender` stops being read while `receiver`
-- has `buffersize` or more bytes waiting in its send buffer.
-- Both handlers get their buffer functions wrapped in place.
-- @param sender      handler whose readbuffer() is wrapped
-- @param receiver    handler whose sendbuffer() is wrapped
-- @param buffersize  threshold for receiver.bufferlen()
local function link(sender, receiver, buffersize)
	local read_paused;
	local original_send = receiver.sendbuffer;
	local original_read = sender.readbuffer;

	-- After each read: pause the sender once the receiver is backed up.
	sender.readbuffer = function ()
		original_read();
		if read_paused then
			return;
		end
		if receiver.bufferlen() >= buffersize then
			read_paused = true;
			sender:lock_read(true);
		end
	end

	-- After each send: let the sender read again once the backlog has shrunk.
	receiver.sendbuffer = function ()
		original_send();
		if not read_paused then
			return;
		end
		if receiver.bufferlen() < buffersize then
			sender:lock_read(false); -- Unlock now
			read_paused = nil;
		end
	end
end
709
710 ----------------------------------// PUBLIC //--
711
-- Registers a new listening server; this is the way for other scripts to add one.
--
-- @param addr       address to bind to; defaults to "*" when nil
-- @param port       port number, 0..65535
-- @param listeners  table of listener callbacks
-- @param pattern    read pattern handed to accepted connections
-- @param sslctx     optional SSL context (requires LuaSec to be loaded)
-- @return the server handler on success, or nil and an error message
addserver = function( addr, port, listeners, pattern, sslctx )
	-- Default the address first: it is used below to build the `_server` key and
	-- the log messages, where a nil would raise an error.
	addr = addr or "*"
	local err
	if type( listeners ) ~= "table" then
		err = "invalid listener table"
	end
	if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
		err = "invalid port"
	elseif _server[ addr..":"..port ] then
		err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist"
	elseif sslctx and not luasec then
		err = "luasec not found"
	end
	if err then
		-- tostring(): an invalid port may be nil or a non-string value
		out_error( "server.lua, [", tostring( addr ), "]:", tostring( port ), ": ", err )
		return nil, err
	end
	local server, binderr = socket_bind( addr, port, _tcpbacklog )
	if not server then
		out_error( "server.lua, [", addr, "]:", port, ": ", tostring( binderr ) )
		return nil, binderr
	end
	local handler, wraperr = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket
	if not handler then
		server:close( )
		return nil, wraperr
	end
	server:settimeout( 0 )
	_readlistlen = addsocket(_readlist, server, _readlistlen)
	_server[ addr..":"..port ] = handler
	_socketlist[ server ] = handler
	out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" )
	return handler
end
746
-- Look up the handler stored in _server for an address/port pair.
-- A nil addr is looked up as "*", the same default addserver applies before
-- it stores a handler. Returns nil when nothing is registered under the key.
getserver = function ( addr, port )
	return _server[ ( addr or "*" )..":"..tostring( port ) ];
end
750
-- Close the server registered for an address/port pair and drop it from _server.
-- A nil addr is treated as "*", the same default addserver applies.
-- Returns true on success, or nil plus a message when no such server exists.
removeserver = function( addr, port )
	addr = addr or "*"
	-- Build the key once so lookup and removal cannot disagree; tostring()
	-- keeps a nil port from throwing before the "not found" message is built.
	local key = addr..":"..tostring( port )
	local handler = _server[ key ]
	if not handler then
		return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'"
	end
	handler:close( )
	_server[ key ] = nil
	return true
end
760
-- Close every handler known to _socketlist, then reset all bookkeeping
-- tables and their length counters to the empty state.
closeall = function( )
	for sock, handler in pairs( _socketlist ) do
		handler:close( )
		-- clearing an existing key while traversing with pairs is permitted
		_socketlist[ sock ] = nil
	end
	_readlistlen, _sendlistlen, _timerlistlen = 0, 0, 0
	_server = { }
	_readlist, _sendlist, _timerlist = { }, { }, { }
	_socketlist = { }
	--mem_free( )
end
776
-- Return a fresh table holding the current value of every tunable setting,
-- keyed by the same names changesettings accepts.
getsettings = function( )
	local settings = { }
	settings.select_timeout = _selecttimeout
	settings.select_sleep_time = _sleeptime
	settings.tcp_backlog = _tcpbacklog
	settings.max_send_buffer_size = _maxsendlen
	settings.max_receive_buffer_size = _maxreadlen
	settings.select_idle_check_interval = _checkinterval
	settings.send_timeout = _sendtimeout
	settings.read_timeout = _readtimeout
	settings.max_connections = _maxselectlen
	settings.max_ssl_handshake_roundtrips = _maxsslhandshake
	settings.highest_allowed_fd = _maxfd
	return settings
end
792
-- Update the tunable settings from a table using the key names returned by
-- getsettings. Keys that are missing, or whose value tonumber() cannot
-- convert, leave the current value untouched.
-- Returns true, or nil plus a message when `new` is not a table.
changesettings = function( new )
	if type( new ) ~= "table" then
		return nil, "invalid settings table"
	end
	_selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
	_sleeptime = tonumber( new.select_sleep_time ) or _sleeptime
	_maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
	_maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
	_checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
	_tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog
	_sendtimeout = tonumber( new.send_timeout ) or _sendtimeout
	_readtimeout = tonumber( new.read_timeout ) or _readtimeout
	-- These three were previously stored without conversion, unlike every
	-- setting above, so a numeric string was kept as a string.
	_maxselectlen = tonumber( new.max_connections ) or _maxselectlen
	_maxsslhandshake = tonumber( new.max_ssl_handshake_roundtrips ) or _maxsslhandshake
	_maxfd = tonumber( new.highest_allowed_fd ) or _maxfd
	return true
end
810
-- Append a timer callback to _timerlist.
-- Returns true, or nil plus a message when `listener` is not a function.
addtimer = function( listener )
	if type( listener ) == "function" then
		local slot = _timerlistlen + 1
		_timerlistlen = slot
		_timerlist[ slot ] = listener
		return true
	end
	return nil, "invalid listener function"
end
819
-- Report the module's counters as five return values, in this order:
-- _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen.
stats = function( )
	local traffic_in, traffic_out = _readtraffic, _sendtraffic
	local readers, writers, timers = _readlistlen, _sendlistlen, _timerlistlen
	return traffic_in, traffic_out, readers, writers, timers
end
823
-- Stop flag read by the main loop; nil until set.
local quitting;

-- Store `quit` as a strict boolean in the stop flag.
local function setquitting(quit)
	if quit then
		quitting = true;
	else
		quitting = false;
	end
end
829
-- Main event loop. Each pass selects on the read and send lists, services
-- writable then readable sockets, flushes _closelist, checks send/read
-- timeouts, fires timers and finally sleeps for _sleeptime.
-- With `once` set a single pass is made and nothing is returned, unless the
-- stop flag was changed during that pass. Otherwise the loop runs until the
-- stop flag is set and returns "quitting".
loop = function(once) -- this is the main loop of the program
	if quitting then return "quitting"; end
	if once then quitting = "once"; end
	local next_timer_time = math_huge;
	repeat
		-- never wait longer than the nearest timer asked for
		local wait = math_min(_selecttimeout, next_timer_time)
		local readable, writable = socket_select( _readlist, _sendlist, wait )

		for _, sock in ipairs( writable ) do -- send data waiting in writequeues
			local handler = _socketlist[ sock ]
			if not handler then
				closesocket( sock )
				out_put "server.lua: found no handler and closed socket (writelist)"    -- this should not happen
			else
				handler.sendbuffer( )
			end
		end

		for _, sock in ipairs( readable ) do -- receive data
			local handler = _socketlist[ sock ]
			if not handler then
				closesocket( sock )
				out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
			else
				handler.readbuffer( )
			end
		end

		-- connections queued for closing: notify the listener, then close
		for handler, reason in pairs( _closelist ) do
			handler.disconnect( )( handler, reason )
			handler:force_close()    -- forced disconnect
			_closelist[ handler ] = nil;
		end

		_currenttime = luasocket_gettime( )

		-- Check for socket timeouts, at most once per _checkinterval
		local since_check = os_difftime( _currenttime - _starttime )
		if since_check > _checkinterval then
			_starttime = _currenttime
			for handler, last_write in pairs( _writetimes ) do
				local idle = os_difftime( _currenttime - last_write )
				if idle > _sendtimeout then
					--_writetimes[ handler ] = nil
					handler.disconnect( )( handler, "send timeout" )
					handler:force_close()    -- forced disconnect
				end
			end
			for handler, last_read in pairs( _readtimes ) do
				local idle = os_difftime( _currenttime - last_read )
				if idle > _readtimeout then
					--_readtimes[ handler ] = nil
					handler.disconnect( )( handler, "read timeout" )
					handler:close( )        -- forced disconnect?
				end
			end
		end

		-- Fire timers; they run at least once a second, sooner on request
		local since_timers = _currenttime - _timer
		if since_timers >= math_min(next_timer_time, 1) then
			next_timer_time = math_huge;
			for i = 1, _timerlistlen do
				local wanted = _timerlist[ i ]( _currenttime ) -- fire timers
				if wanted then next_timer_time = math_min(next_timer_time, wanted); end
			end
			_timer = _currenttime
		else
			next_timer_time = next_timer_time - since_timers;
		end

		-- wait some time (0 by default)
		socket_sleep( _sleeptime )
	until quitting;
	if once and quitting == "once" then quitting = nil; return; end
	return "quitting"
end
899
-- Run one pass of the main loop and hand back whatever loop returns.
local step = function()
	return loop(true);
end
903
-- Name the backend this module implements; always "select".
local get_backend = function()
	return "select";
end
907
908 --// EXPERIMENTAL //--
909
-- Wrap an outgoing socket with wrapconnection and register the result in
-- _socketlist. Without an sslctx the socket is also added to the send list,
-- and when the listeners define onconnect it is called the first time the
-- handler's sendbuffer runs.
-- Returns the handler and the socket reported by wrapconnection, or nil plus
-- an error message.
local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
	local handler, wrapped, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
	if not handler then return nil, err end
	_socketlist[ wrapped ] = handler
	if sslctx then
		return handler, wrapped
	end
	_sendlistlen = addsocket(_sendlist, wrapped, _sendlistlen)
	if listeners.onconnect then
		-- When socket is writeable, call onconnect
		local real_sendbuffer = handler.sendbuffer;
		handler.sendbuffer = function ()
			-- one-shot hook: put the real sendbuffer back before notifying
			handler.sendbuffer = real_sendbuffer;
			listeners.onconnect(handler);
			return real_sendbuffer(); -- Send any queued outgoing data
		end
	end
	return handler, wrapped
end
928
-- Create a TCP socket, start a non-blocking connect to address:port and hand
-- the socket to wrapclient, which registers it with the select loop.
-- Returns whatever wrapclient returns (handler and socket on success), or
-- nil plus an error message when the socket cannot be created or wrapped.
local addclient = function( address, port, listeners, pattern, sslctx )
	local client, err = luasocket.tcp( )
	if not client then
		return nil, err
	end
	client:settimeout( 0 )
	-- The connect result is intentionally not inspected: the original wrapped
	-- the socket whether or not connect reported an error. It also stored the
	-- result in the global `_` and registered the connection in only one of
	-- its two branches; both outcomes now go through wrapclient.
	client:connect( address, port )
	-- pattern and sslctx are forwarded; one branch used to drop them
	return wrapclient( client, address, port, listeners, pattern, sslctx )
end
942
943 --// EXPERIMENTAL //--
944
945 ----------------------------------// BEGIN //--
946
do
	-- Give the socket and timestamp tables weak keys; each table gets its own
	-- metatable.
	local setmetatable = use "setmetatable"
	setmetatable( _socketlist, { __mode = "k" } )
	setmetatable( _readtimes, { __mode = "k" } )
	setmetatable( _writetimes, { __mode = "k" } )
end

-- Reference times used by the loop's timer and timeout checks
_timer = luasocket_gettime( )
_starttime = luasocket_gettime( )
953
-- Replace the logging function used by this module.
-- A nil or false new_logger leaves the current logger in place.
-- Returns the logger that was active before the call.
local function setlogger(new_logger)
	local previous = log;
	log = new_logger or previous;
	return previous;
end
961
962 ----------------------------------// PUBLIC INTERFACE //--
963
return {
	-- main loop control
	loop = loop,
	step = step,
	setquitting = setquitting,

	-- listening servers
	addserver = addserver,
	getserver = getserver,
	removeserver = removeserver,

	-- outgoing connections
	addclient = addclient,
	wrapclient = wrapclient,
	link = link,

	-- timers (underscore-prefixed export)
	_addtimer = addtimer,

	-- settings, statistics and housekeeping
	getsettings = getsettings,
	changesettings = changesettings,
	stats = stats,
	closeall = closeall,
	setlogger = setlogger,
	get_backend = get_backend,
}