Merge 0.10->trunk
[prosody.git] / net / server_select.lua
1 --
2 -- server.lua by blastbeat of the luadch project
3 -- Re-used here under the MIT/X Consortium License
4 --
5 -- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain
6 --
7
8 -- // wrapping luadch stuff // --
9
--- Fetch a global by name (compatibility shim kept from luadch).
-- @tparam string name name of the global to look up
-- @return the value stored in `_G[name]` (nil when unset)
local function use( name )
	local value = _G[ name ]
	return value
end
13
local log = require ("util.logger").init("socket");
local table_concat = table.concat;

-- Debug-level log helper: all arguments are joined into one message.
local out_put = function (...)
	local message = table_concat{ ... }
	return log("debug", message);
end

-- Warning-level log helper: all arguments are joined into one message.
local out_error = function (...)
	local message = table_concat{ ... }
	return log("warn", message);
end
17
18 ----------------------------------// DECLARATION //--
19
20 --// constants //--
21
22 local STAT_UNIT = 1 -- byte
23
24 --// lua functions //--
25
26 local type = use "type"
27 local pairs = use "pairs"
28 local ipairs = use "ipairs"
29 local tonumber = use "tonumber"
30 local tostring = use "tostring"
31
32 --// lua libs //--
33
34 local table = use "table"
35 local string = use "string"
36 local coroutine = use "coroutine"
37
38 --// lua lib methods //--
39
40 local math_min = math.min
41 local math_huge = math.huge
42 local table_concat = table.concat
43 local table_insert = table.insert
44 local string_sub = string.sub
45 local coroutine_wrap = coroutine.wrap
46 local coroutine_yield = coroutine.yield
47
48 --// extern libs //--
49
50 local has_luasec, luasec = pcall ( require , "ssl" )
51 local luasocket = use "socket" or require "socket"
52 local luasocket_gettime = luasocket.gettime
53 local getaddrinfo = luasocket.dns.getaddrinfo
54
55 --// extern lib methods //--
56
57 local ssl_wrap = ( has_luasec and luasec.wrap )
58 local socket_bind = luasocket.bind
59 local socket_select = luasocket.select
60
61 --// functions //--
62
-- Forward declarations: these are assigned further down so that the
-- closures defined in between can refer to one another.
local id, loop, stats, idfalse, closeall
local addsocket, addserver, addtimer, getserver, wrapserver
local getsettings, closesocket, removesocket, removeserver
local wrapconnection, changesettings

--// tables //--

-- Shared registries; populated in the DEFINITION section below.
local _server, _readlist, _timerlist, _sendlist
local _socketlist, _closelist, _readtimes, _writetimes

--// simple data types //--

local _ -- throwaway target for `_ = cond and f()` statements
local _readlistlen, _sendlistlen, _timerlistlen -- used lengths of the lists
local _sendtraffic, _readtraffic -- loop-wide traffic counters
local _selecttimeout, _tcpbacklog
local _starttime, _currenttime
local _maxsendlen, _maxreadlen
local _checkinterval, _sendtimeout, _readtimeout
local _maxselectlen, _maxfd
local _maxsslhandshake
118
119 ----------------------------------// DEFINITION //--
120
-- Registries shared by every handler.
_server = { }      -- "ip:port" -> listening server handler
_readlist = { }    -- sockets to select() for reading (array + socket->index map)
_sendlist = { }    -- sockets to select() for writing (array + socket->index map)
_timerlist = { }   -- array of timer functions
_socketlist = { }  -- socket -> handler wrapping it
_readtimes = { }   -- handler -> timestamp of its last read
_writetimes = { }  -- handler -> timestamp of its last pending write
_closelist = { }   -- handlers scheduled to be closed

-- Used lengths of the three arrays above.
_readlistlen, _sendlistlen, _timerlistlen = 0, 0, 0

-- Traffic statistics.
_sendtraffic, _readtraffic = 0, 0

_selecttimeout = 1 -- timeout passed to socket.select
_tcpbacklog = 128  -- listen() backlog, a hint to the OS

_maxsendlen = 51000 * 1024 -- max bytes queued in one connection's send buffer
_maxreadlen = 25000 * 1024 -- max bytes accepted from a single read

_checkinterval = 30        -- interval in secs to check idle clients
_sendtimeout = 60000       -- allowed send idle time in secs
_readtimeout = 6 * 60 * 60 -- allowed read idle time in secs

-- Windows is detected from its "\" directory separator.
local is_windows = package.config:sub(1,1) == "\\"
-- Highest fd number we accept. Limited to FD_SETSIZE (1024 by default) to
-- avoid the glibc fd_set overflow; Windows select() has no such fd bound.
_maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024
-- Maximum number of sockets per select() list; applies on Windows as well.
_maxselectlen = luasocket._SETSIZE or 1024

_maxsslhandshake = 30 -- max TLS handshake round-trips
152
153 ----------------------------------// PRIVATE //--
154
--- Wrap a bound, listening socket in a server handler object.
-- Sockets whose fd is >= _maxfd are closed and rejected.
-- @tparam table listeners callbacks; `onconnect` is dispatched for each new
--   non-TLS client (TLS clients are announced when their handshake completes)
-- @param socket bound LuaSocket server object
-- @param ip address the socket is bound to
-- @param serverport port the socket is bound to
-- @param pattern receive pattern handed to every accepted client
-- @param sslctx optional LuaSec context applied to accepted clients
-- @return the handler, or nil and "fd-too-large"
wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx )

	if socket:getfd() >= _maxfd then
		out_error("server.lua: Disallowed FD number: "..socket:getfd())
		socket:close()
		return nil, "fd-too-large"
	end

	local connections = 0 -- accepted clients still open

	local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect

	local accept = socket.accept

	--// public methods of the object //--

	local handler = { }

	handler.shutdown = function( ) end

	handler.ssl = function( )
		return sslctx ~= nil
	end
	handler.sslctx = function( )
		return sslctx
	end
	-- Called by a client handler when it closes: account for it and resume
	-- accepting in case we paused (e.g. because the server was full).
	handler.remove = function( )
		connections = connections - 1
		if handler then
			handler.resume( )
		end
	end
	-- Close the listening socket and unregister the server. Safe after a
	-- hard pause (socket already closed) and when called more than once.
	handler.close = function()
		if socket then
			socket:close( )
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			_socketlist[ socket ] = nil
		end
		_server[ip..":"..serverport] = nil;
		handler = nil
		socket = nil
		out_put "server.lua: closed server handler and removed sockets from list"
	end
	-- Stop accepting. A hard pause also closes the listening socket; resume()
	-- will then bind a fresh one.
	handler.pause = function( hard )
		if not handler.paused then
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			if hard then
				_socketlist[ socket ] = nil
				socket:close( )
				socket = nil;
			end
			handler.paused = true;
		end
	end
	-- Start accepting again. If a hard pause closed the socket it is re-bound;
	-- when that bind fails the server stays paused and nil, err is returned.
	handler.resume = function( )
		if handler.paused then
			if not socket then
				local newsocket, err = socket_bind( ip, serverport, _tcpbacklog );
				if not newsocket then
					out_error( "server.lua: failed to re-bind paused server [", tostring(ip), "]:", tostring(serverport), ": ", tostring(err) )
					return nil, err
				end
				socket = newsocket
				socket:settimeout( 0 )
			end
			_readlistlen = addsocket(_readlist, socket, _readlistlen)
			_socketlist[ socket ] = handler
			handler.paused = false;
		end
	end
	handler.ip = function( )
		return ip
	end
	handler.serverport = function( )
		return serverport
	end
	handler.socket = function( )
		return socket
	end
	-- Accept one pending client. Returns false when the select() lists are
	-- full (the server is paused), the accept failed, or the client could not
	-- be wrapped; otherwise returns whatever onconnect returns (non-TLS only).
	handler.readbuffer = function( )
		if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then
			handler.pause( )
			out_put( "server.lua: refused new client connection: server full" )
			return false
		end
		local client, err = accept( socket )	-- try to accept
		if client then
			local clientip, clientport = client:getpeername( )
			local client_handler, _, wrap_err = wrapconnection( handler, listeners, client, clientip, serverport, clientport, pattern, sslctx ) -- wrap new client socket
			if wrap_err then -- fd too large, or the TLS handshake failed immediately
				return false
			end
			connections = connections + 1
			out_put( "server.lua: accepted new client connection from ", tostring(clientip), ":", tostring(clientport), " to ", tostring(serverport))
			if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes
				return dispatch( client_handler );
			end
			return;
		elseif err then -- maybe timeout or something else
			out_put( "server.lua: error with new client connection: ", tostring(err) )
			return false
		end
	end
	return handler
end
255
--- Wrap a connected client socket in a connection handler.
-- The handler is registered in `_socketlist` and its socket added to the
-- read list. With `sslctx` (and LuaSec available) TLS is started at once.
-- @param server handler of the accepting server (may be nil); it is paused
--   on fd exhaustion and notified via server.remove() on close
-- @tparam table listeners callbacks: onincoming, onstatus, ondisconnect,
--   ondrain, onreadtimeout, ondetach (and onconnect after autostarted TLS)
-- @param socket LuaSocket client object
-- @param ip peer address
-- @param serverport local port
-- @param clientport peer port
-- @param pattern receive pattern passed to the socket's receive()
-- @param sslctx optional LuaSec context
-- @return handler, socket on success; nil, nil, err when the fd is too
--   large or the immediate TLS handshake fails
wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object

	-- Descriptors >= _maxfd cannot be watched by select(); refuse them and
	-- pause the accepting server.
	if socket:getfd() >= _maxfd then
		out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
		socket:close( ) -- Should we send some kind of error here?
		if server then
			server.pause( )
		end
		return nil, nil, "fd-too-large"
	end
	socket:settimeout( 0 )

	--// local import of socket methods //--

	-- Bound to the socket's methods below; rebound by starttls().
	local send
	local receive
	local shutdown

	--// private closures of the object //--

	local ssl -- set to true once starttls() has wrapped the socket

	local dispatch = listeners.onincoming
	local status = listeners.onstatus
	local disconnect = listeners.ondisconnect
	local drain = listeners.ondrain
	local onreadtimeout = listeners.onreadtimeout;
	local detach = listeners.ondetach

	local bufferqueue = { } -- buffer array
	local bufferqueuelen = 0	-- end of buffer array

	local toclose -- close() was deferred until the send buffer drains
	local needtls -- starttls() was deferred until the send buffer drains

	local bufferlen = 0 -- bytes currently queued in bufferqueue

	local noread = false -- reading suspended via lock_read(true)
	local nosend = false -- sending suspended via lock(true)

	local sendtraffic, readtraffic = 0, 0

	local maxsendlen = _maxsendlen
	local maxreadlen = _maxreadlen

	--// public methods of the object //--

	-- The handler doubles as the send queue: queued chunks live in its
	-- array part, methods in its hash part.
	local handler = bufferqueue -- saves a table ^_^

	handler.dispatch = function( )
		return dispatch
	end
	handler.disconnect = function( )
		return disconnect
	end
	handler.onreadtimeout = onreadtimeout;

	-- Replace the listener callbacks, notifying the old set first.
	handler.setlistener = function( self, listeners )
		if detach then
			detach(self) -- Notify listener that it is no longer responsible for this connection
		end
		dispatch = listeners.onincoming
		disconnect = listeners.ondisconnect
		status = listeners.onstatus
		drain = listeners.ondrain
		handler.onreadtimeout = listeners.onreadtimeout
		detach = listeners.ondetach
	end
	handler.getstats = function( )
		return readtraffic, sendtraffic
	end
	handler.ssl = function( )
		return ssl
	end
	handler.sslctx = function ( )
		return sslctx
	end
	handler.send = function( _, data, i, j )
		return send( socket, data, i, j )
	end
	-- NOTE(review): receive and shutdown take no self parameter (unlike send),
	-- so calling them with ':' passes the handler as `pattern`. Kept as-is
	-- for compatibility with existing callers.
	handler.receive = function( pattern, prefix )
		return receive( socket, pattern, prefix )
	end
	handler.shutdown = function( pattern )
		return shutdown( socket, pattern )
	end
	handler.setoption = function (self, option, value)
		if socket.setoption then
			return socket:setoption(option, value);
		end
		return false, "setoption not implemented";
	end
	-- Close immediately, discarding anything still queued for sending.
	handler.force_close = function ( self, err )
		if bufferqueuelen ~= 0 then
			out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport))
			bufferqueuelen = 0;
		end
		return self:close(err);
	end
	-- Close the connection. If queued data cannot be flushed right now,
	-- further writes are disabled, the close is deferred and false returned.
	-- Otherwise the socket is shut down, closed and unregistered, ondisconnect
	-- gets (handler, err or false), the server is notified, and true returned.
	handler.close = function( self, err )
		if not handler then return true; end
		_readlistlen = removesocket( _readlist, socket, _readlistlen )
		_readtimes[ handler ] = nil
		if bufferqueuelen ~= 0 then
			handler.sendbuffer() -- Try now to send any outstanding data
			if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later
				if handler then
					handler.write = nil -- ... but no further writing allowed
				end
				toclose = true
				return false
			end
		end
		if socket then
			_ = shutdown and shutdown( socket )
			socket:close( )
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_socketlist[ socket ] = nil
			socket = nil
		else
			out_put "server.lua: socket already closed"
		end
		if handler then
			_writetimes[ handler ] = nil
			_closelist[ handler ] = nil
			local _handler = handler;
			handler = nil
			if disconnect then
				disconnect(_handler, err or false);
				disconnect = nil
			end
		end
		if server then
			server.remove( )
		end
		out_put "server.lua: closed client handler and removed socket from list"
		return true
	end
	handler.ip = function( )
		return ip
	end
	handler.serverport = function( )
		return serverport
	end
	handler.clientport = function( )
		return clientport
	end
	handler.port = handler.clientport -- COMPAT server_event
	-- Queue `data` for sending. Returns false (and schedules a close) once
	-- the queued total exceeds maxsendlen.
	local write = function( self, data )
		bufferlen = bufferlen + #data
		if bufferlen > maxsendlen then
			-- After close() `handler` is nil; indexing _closelist with nil would raise.
			if handler then
				_closelist[ handler ] = "send buffer exceeded"	-- cannot close the client at the moment, have to wait to the end of the cycle
				handler.write = idfalse -- dont write anymore
			end
			return false
		elseif socket and not _sendlist[ socket ] then
			_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
		end
		bufferqueuelen = bufferqueuelen + 1
		bufferqueue[ bufferqueuelen ] = data
		if handler then
			_writetimes[ handler ] = _writetimes[ handler ] or _currenttime
		end
		return true
	end
	handler.write = write
	handler.bufferqueue = function( self )
		return bufferqueue
	end
	handler.socket = function( self )
		return socket
	end
	handler.set_mode = function( self, new )
		pattern = new or pattern
		return pattern
	end
	handler.set_send = function ( self, newsend )
		send = newsend or send
		return send
	end
	handler.bufferlen = function( self, readlen, sendlen )
		maxsendlen = sendlen or maxsendlen
		maxreadlen = readlen or maxreadlen
		return bufferlen, maxreadlen, maxsendlen
	end
	--TODO: Deprecate
	-- true: stop polling for reads; false: resume. Returns the noread flag.
	handler.lock_read = function (self, switch)
		if switch == true then
			local tmp = _readlistlen
			_readlistlen = removesocket( _readlist, socket, _readlistlen )
			_readtimes[ handler ] = nil
			if _readlistlen ~= tmp then
				noread = true
			end
		elseif switch == false then
			if noread then
				noread = false
				_readlistlen = addsocket(_readlist, socket, _readlistlen)
				_readtimes[ handler ] = _currenttime
			end
		end
		return noread
	end
	handler.pause = function (self)
		return self:lock_read(true);
	end
	handler.resume = function (self)
		return self:lock_read(false);
	end
	-- Lock (true) or unlock (false) both reading and writing.
	-- Returns the noread and nosend flags.
	handler.lock = function( self, switch )
		-- lock_read takes (self, switch); it must be called as a method.
		self:lock_read( switch )
		if switch == true then
			handler.write = idfalse
			local tmp = _sendlistlen
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
			_writetimes[ handler ] = nil
			if _sendlistlen ~= tmp then
				nosend = true
			end
		elseif switch == false then
			handler.write = write
			if nosend then
				nosend = false
				-- Queue an empty chunk; this puts the socket back on the send list.
				write( self, "" )
			end
		end
		return noread, nosend
	end
	-- Read once and pass the data to onincoming. "timeout"/"wantread" still
	-- deliver any partial data; any other error closes the connection.
	local _readbuffer = function( ) -- this function reads data
		local buffer, err, part = receive( socket, pattern )	-- receive buffer with "pattern"
		if not err or (err == "wantread" or err == "timeout") then -- received something
			local buffer = buffer or part or ""
			local len = #buffer
			if len > maxreadlen then
				handler:close( "receive buffer exceeded" )
				return false
			end
			local count = len * STAT_UNIT
			readtraffic = readtraffic + count
			_readtraffic = _readtraffic + count
			_readtimes[ handler ] = _currenttime
			--out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err )
			return dispatch( handler, buffer, err )
		else	-- connections was closed or fatal error
			out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
			_ = handler and handler:force_close( err )
			return false
		end
	end
	-- Flush the queued chunks in one send(). On a partial write the unsent
	-- tail is kept as the only queued chunk; any other error closes the
	-- connection.
	local _sendbuffer = function( ) -- this function sends data
		local succ, err, byte, buffer, count;
		if socket then
			buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
			succ, err, byte = send( socket, buffer, 1, bufferlen )
			count = ( succ or byte or 0 ) * STAT_UNIT
			sendtraffic = sendtraffic + count
			_sendtraffic = _sendtraffic + count
			for i = bufferqueuelen,1,-1 do
				bufferqueue[ i ] = nil
			end
			--out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )
		else
			succ, err, count = false, "unexpected close", 0;
		end
		if succ then	-- sending succesful
			bufferqueuelen = 0
			bufferlen = 0
			_sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
			_writetimes[ handler ] = nil
			if drain then
				drain(handler)
			end
			-- drain() may have closed the connection, which clears `handler`.
			_ = needtls and handler and handler:starttls(nil)
			_ = toclose and handler and handler:force_close( )
			return true
		elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
			buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer
			bufferqueue[ 1 ] = buffer	 -- insert new buffer in queue
			bufferqueuelen = 1
			bufferlen = bufferlen - byte
			_writetimes[ handler ] = _currenttime
			return true
		else	-- connection was closed during sending or fatal error
			out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
			_ = handler and handler:force_close( err )
			return false
		end
	end

	-- Set the sslctx and (re)create the handshake coroutine. Each resume runs
	-- one dohandshake() step and registers the socket for the direction
	-- LuaSec asked for; after _maxsslhandshake steps or a hard error the
	-- connection is force-closed.
	local handshake;
	function handler.set_sslctx(self, new_sslctx)
		sslctx = new_sslctx;
		local read, wrote
		handshake = coroutine_wrap( function( client ) -- create handshake coroutine
				local err
				for i = 1, _maxsslhandshake do
					_sendlistlen = ( wrote and removesocket( _sendlist, client, _sendlistlen ) ) or _sendlistlen
					_readlistlen = ( read and removesocket( _readlist, client, _readlistlen ) ) or _readlistlen
					read, wrote = nil, nil
					_, err = client:dohandshake( )
					if not err then
						out_put( "server.lua: ssl handshake done" )
						handler.readbuffer = _readbuffer	-- when handshake is done, replace the handshake function with regular functions
						handler.sendbuffer = _sendbuffer
						_ = status and status( handler, "ssl-handshake-complete" )
						if self.autostart_ssl and listeners.onconnect then
							listeners.onconnect(self);
							if bufferqueuelen ~= 0 then
								_sendlistlen = addsocket(_sendlist, client, _sendlistlen)
							end
						end
						_readlistlen = addsocket(_readlist, client, _readlistlen)
						return true
					else
						if err == "wantwrite" then
							_sendlistlen = addsocket(_sendlist, client, _sendlistlen)
							wrote = true
						elseif err == "wantread" then
							_readlistlen = addsocket(_readlist, client, _readlistlen)
							read = true
						else
							break;
						end
						err = nil;
						coroutine_yield( ) -- handshake not finished
					end
				end
				err = "ssl handshake error: " .. ( err or "handshake too long" );
				out_put( "server.lua: ", err );
				_ = handler and handler:force_close(err)
				return false, err -- handshake failed
			end
		)
	end
	if has_luasec then
		-- Wrap the socket with TLS and begin the handshake. Deferred (needtls)
		-- while data is still queued. Returns nil, err when wrapping fails.
		handler.starttls = function( self, _sslctx)
			if _sslctx then
				handler:set_sslctx(_sslctx);
			end
			if bufferqueuelen > 0 then
				out_put "server.lua: we need to do tls, but delaying until send buffer empty"
				needtls = true
				return
			end
			out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
			local oldsocket = socket
			local wrapped, err = ssl_wrap( socket, sslctx )	-- wrap socket
			if not wrapped then
				-- Keep the plain socket so the handler can still be closed cleanly.
				out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
				return nil, err -- fatal error
			end
			socket = wrapped

			socket:settimeout( 0 )

			-- add the new socket to our system
			send = socket.send
			receive = socket.receive
			shutdown = id
			_socketlist[ socket ] = handler
			_readlistlen = addsocket(_readlist, socket, _readlistlen)

			-- remove traces of the old socket
			_readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
			_sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
			_socketlist[ oldsocket ] = nil

			handler.starttls = nil
			needtls = nil

			-- Secure now (if handshake fails connection will close)
			ssl = true

			handler.readbuffer = handshake
			handler.sendbuffer = handshake
			return handshake( socket ) -- do handshake
		end
	end

	handler.readbuffer = _readbuffer
	handler.sendbuffer = _sendbuffer
	send = socket.send
	receive = socket.receive
	shutdown = ( ssl and id ) or socket.shutdown

	_socketlist[ socket ] = handler
	_readlistlen = addsocket(_readlist, socket, _readlistlen)

	if sslctx and has_luasec then
		out_put "server.lua: auto-starting ssl negotiation..."
		handler.autostart_ssl = true;
		local ok, err = handler:starttls(sslctx);
		if ok == false then
			return nil, nil, err
		end
	end

	return handler, socket
end
654
-- No-op placeholder (e.g. used as `shutdown` once a socket is TLS-wrapped).
id = function( ) end
657
-- Stand-in that always refuses (installed as `write` on locked handlers).
idfalse = function( ) return false end
661
--- Append `socket` to an indexed list unless it is already in it.
-- The list is kept two-way: list[i] = socket and list[socket] = i.
-- @param list the list to extend
-- @param socket element to add
-- @tparam number len current length of the array part
-- @treturn number the new (or unchanged) length
addsocket = function( list, socket, len )
	if list[ socket ] then
		return len
	end
	local newlen = len + 1
	list[ newlen ] = socket
	list[ socket ] = newlen
	return newlen
end
670
--- Remove `socket` from a two-way indexed list in O(1) (copied from copas).
-- The last element is moved into the freed slot, so order is not kept.
-- @param list list maintained by addsocket
-- @param socket element to remove
-- @tparam number len current length of the array part
-- @treturn number the new (or unchanged) length
removesocket = function( list, socket, len )
	local slot = list[ socket ]
	if not slot then
		return len
	end
	list[ socket ] = nil
	local tail = list[ len ]
	list[ len ] = nil
	if tail ~= socket then
		-- fill the hole with the former last element
		list[ tail ] = slot
		list[ slot ] = tail
	end
	return len - 1
end
685
--- Unregister a raw socket from both select() lists and the handler
-- registry, then close it.
-- @param socket LuaSocket object to drop
closesocket = function( socket )
	local sendlen = removesocket( _sendlist, socket, _sendlistlen )
	_sendlistlen = sendlen
	local readlen = removesocket( _readlist, socket, _readlistlen )
	_readlistlen = readlen
	_socketlist[ socket ] = nil
	socket:close( )
end
693
--- Couple two connections for flow control: `sender` stops reading while
-- `receiver` has at least `buffersize` bytes queued, and resumes once that
-- drops below the threshold. Wraps sender.readbuffer and
-- receiver.sendbuffer; the wrappers forward arguments and return values
-- unchanged. Also switches `sender` to "*a" read mode.
-- @param sender handler whose reading is throttled
-- @param receiver handler whose send buffer is watched
-- @tparam number buffersize threshold in bytes
local function link(sender, receiver, buffersize)
	local sender_locked;

	-- Run `after`, then return the already-computed results in `...`.
	local function passthrough(after, ...)
		after()
		return ...
	end

	local function maybe_unlock()
		if sender_locked and receiver.bufferlen() < buffersize then
			sender:lock_read(false); -- Unlock now
			sender_locked = nil;
		end
	end

	local function maybe_lock()
		if not sender_locked and receiver.bufferlen() >= buffersize then
			sender_locked = true;
			sender:lock_read(true);
		end
	end

	local _sendbuffer = receiver.sendbuffer;
	function receiver.sendbuffer(...)
		return passthrough(maybe_unlock, _sendbuffer(...))
	end

	local _readbuffer = sender.readbuffer;
	function sender.readbuffer(...)
		return passthrough(maybe_lock, _readbuffer(...))
	end
	sender:set_mode("*a");
end
715
716 ----------------------------------// PUBLIC //--
717
--- Bind a listening socket on addr:port and register it with the loop.
-- @param addr address to bind (defaults to "*")
-- @tparam number port port number in [0, 65535]
-- @tparam table listeners connection callbacks for accepted clients
-- @param pattern receive pattern for accepted clients
-- @param sslctx optional LuaSec context (requires LuaSec)
-- @return the server handler, or nil and an error string
addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
	addr = addr or "*"
	local err
	if type( listeners ) ~= "table" then
		err = "invalid listener table"
	elseif type ( addr ) ~= "string" then
		err = "invalid address"
	elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
		err = "invalid port"
	elseif _server[ addr..":"..port ] then
		err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist"
	elseif sslctx and not has_luasec then
		err = "luasec not found"
	end
	if err then
		-- addr/port may be the invalid values being reported (nil, tables, ...),
		-- which table.concat inside the logger cannot handle; stringify them.
		out_error( "server.lua, [", tostring( addr ), "]:", tostring( port ), ": ", err )
		return nil, err
	end
	local server, err = socket_bind( addr, port, _tcpbacklog )
	if err then
		out_error( "server.lua, [", addr, "]:", port, ": ", err )
		return nil, err
	end
	local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket
	if not handler then
		server:close( )
		return nil, err
	end
	server:settimeout( 0 )
	_readlistlen = addsocket(_readlist, server, _readlistlen)
	_server[ addr..":"..port ] = handler
	_socketlist[ server ] = handler
	out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" )
	return handler
end
753
-- Look up the handler registered for addr:port (nil when none).
function getserver ( addr, port )
	local key = addr .. ":" .. port
	return _server[ key ]
end
757
-- Close and unregister the server listening on addr:port.
-- Returns true on success, or nil plus an error string when no such server exists.
removeserver = function( addr, port )
	-- tostring() keeps a bad (e.g. nil) addr/port from raising inside the
	-- concatenation; such a key simply matches no registered server. For a
	-- string addr and numeric port the key is identical to plain `..`.
	local key = tostring( addr )..":"..tostring( port )
	local handler = _server[ key ]
	if not handler then
		return nil, "no server found on '[" .. tostring( addr ) .. "]:" .. tostring( port ) .. "'"
	end
	handler:close( )
	_server[ key ] = nil
	return true
end
767
-- Close every known handler and reset all server bookkeeping (server map,
-- select lists, timer list, socket->handler map) to empty.
closeall = function( )
	for sock, handler in pairs( _socketlist ) do
		handler:close( )
		_socketlist[ sock ] = nil -- clearing existing keys during pairs() is allowed
	end
	_readlistlen = 0
	_sendlistlen = 0
	_timerlistlen = 0
	_server = { }
	_readlist = { }
	_sendlist = { }
	_timerlist = { }
	-- Recreate the socket->handler map with the same weak-key mode it is given
	-- at load time; a plain table here would make it hold sockets strongly
	-- after a reset.
	_socketlist = use "setmetatable"( { }, { __mode = "k" } )
	--mem_free( )
end
783
-- Snapshot the current tunables as a fresh table keyed by their public
-- setting names (the same names changesettings accepts).
function getsettings( )
	local settings = {
		-- select() / listen tuning
		select_timeout = _selecttimeout,
		tcp_backlog = _tcpbacklog,
		select_idle_check_interval = _checkinterval,
		max_connections = _maxselectlen,
		highest_allowed_fd = _maxfd,
		-- buffer limits
		max_send_buffer_size = _maxsendlen,
		max_receive_buffer_size = _maxreadlen,
		-- timeouts
		send_timeout = _sendtimeout,
		read_timeout = _readtimeout,
		max_ssl_handshake_roundtrips = _maxsslhandshake,
	}
	return settings
end
798
-- Update tunables from a settings table (keys as returned by getsettings).
-- Each value is coerced with tonumber(); an absent or non-numeric value leaves
-- the current setting unchanged.
-- Returns true, or nil plus an error string when `new` is not a table.
changesettings = function( new )
	if type( new ) ~= "table" then
		return nil, "invalid settings table"
	end
	_selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
	_maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
	_maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
	_checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
	_tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog
	_sendtimeout = tonumber( new.send_timeout ) or _sendtimeout
	_readtimeout = tonumber( new.read_timeout ) or _readtimeout
	-- Coerced like the settings above, so a numeric string (e.g. from a config
	-- file) is stored as a number and garbage does not replace a valid value.
	_maxselectlen = tonumber( new.max_connections ) or _maxselectlen
	_maxsslhandshake = tonumber( new.max_ssl_handshake_roundtrips ) or _maxsslhandshake
	_maxfd = tonumber( new.highest_allowed_fd ) or _maxfd
	return true
end
815
-- Append a timer function to the timer list fired by the main loop.
-- Returns true, or nil plus an error string when `listener` is not a function.
function addtimer( listener )
	if type( listener ) == "function" then
		local slot = _timerlistlen + 1
		_timerlistlen = slot
		_timerlist[ slot ] = listener
		return true
	end
	return nil, "invalid listener function"
end
824
local add_task do
	-- tasks registered since the last timer tick; merged into `scheduled` on the next tick
	local incoming = {};
	-- tasks the timer below is tracking, each stored as { due_time, callback }
	local scheduled = {};

	-- Schedule callback(now) to run `delay` seconds from now. A negative delay
	-- runs it immediately. Whenever the callback returns a number, it is
	-- scheduled again that many seconds later.
	function add_task(delay, callback)
		local now = luasocket_gettime();
		local due = delay + now;
		if due >= now then
			incoming[#incoming + 1] = { due, callback };
			return;
		end
		local again = callback(now);
		if type(again) == "number" then
			return add_task(again, callback);
		end
	end

	-- Single timer that runs every due task and returns the time until the
	-- next one (math.huge when nothing is pending).
	addtimer(function(now)
		if #incoming > 0 then
			for _, task in ipairs(incoming) do
				table_insert(scheduled, task);
			end
			incoming = {};
		end

		local wait = math_huge;
		for slot, task in pairs(scheduled) do
			local due, callback = task[1], task[2];
			if due <= now then
				scheduled[slot] = nil;
				local again = callback(now);
				if type(again) == "number" then
					add_task(again, callback);
					wait = math_min(wait, again);
				end
			else
				wait = math_min(wait, due - now);
			end
		end
		return wait;
	end);
end
867
-- Report counters: read traffic, send traffic, then the lengths of the read,
-- send and timer lists.
function stats( )
	local readlen, sendlen, timerlen = _readlistlen, _sendlistlen, _timerlistlen
	return _readtraffic, _sendtraffic, readlen, sendlen, timerlen
end
871
-- Shutdown flag read by loop(): falsy while running, true once a quit is
-- requested, and the string "once" while loop(true) runs a single step.
local quitting;

-- Request (truthy `quit`) or cancel (falsy `quit`) shutdown of the main loop.
local function setquitting(quit)
	if quit then
		quitting = true
	else
		quitting = false
	end
end
877
-- Main event loop. Each pass fires the timers, waits in select() on the read
-- and send lists, dispatches writable sockets and then readable ones, runs the
-- disconnects queued in _closelist, and every _checkinterval seconds enforces
-- the send/read timeouts.
-- once: when truthy, run exactly one pass (this is how step() uses it).
-- Returns "quitting" if a quit was already requested, or after closeall() when
-- the loop ends on a quit request; returns nothing after a plain single step.
loop = function(once) -- this is the main loop of the program
	if quitting then return "quitting"; end
	-- the "once" sentinel makes the repeat..until below exit after one pass
	if once then quitting = "once"; end
	_currenttime = luasocket_gettime( )
	repeat
		-- Fire timers
		-- A timer may return a number (add_task's timer returns the delay until
		-- its next task); the smallest one caps the select() timeout below.
		local next_timer_time = math_huge;
		for i = 1, _timerlistlen do
			local t = _timerlist[ i ]( _currenttime ) -- fire timers
			if t then next_timer_time = math_min(next_timer_time, t); end
		end

		local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
		for i, socket in ipairs( write ) do -- send data waiting in writequeues
			local handler = _socketlist[ socket ]
			if handler then
				handler.sendbuffer( )
			else
				closesocket( socket )
				out_put "server.lua: found no handler and closed socket (writelist)"    -- this should not happen
			end
		end
		for i, socket in ipairs( read ) do -- receive data
			local handler = _socketlist[ socket ]
			if handler then
				handler.readbuffer( )
			else
				closesocket( socket )
				out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
			end
		end
		-- _closelist maps handler -> reason; each queued handler gets its
		-- disconnect callback with that reason, is force-closed, then dequeued.
		for handler, err in pairs( _closelist ) do
			handler.disconnect( )( handler, err )
			handler:force_close()    -- forced disconnect
			_closelist[ handler ] = nil;
		end
		_currenttime = luasocket_gettime( )

		-- Check for socket timeouts
		if _currenttime - _starttime > _checkinterval then
			_starttime = _currenttime
			-- _writetimes/_readtimes map handler -> timestamp; presumably the time
			-- of the last relevant activity (they are maintained elsewhere).
			for handler, timestamp in pairs( _writetimes ) do
				if _currenttime - timestamp > _sendtimeout then
					handler.disconnect( )( handler, "send timeout" )
					handler:force_close()    -- forced disconnect
				end
			end
			for handler, timestamp in pairs( _readtimes ) do
				if _currenttime - timestamp > _readtimeout then
					-- an onreadtimeout hook returning exactly true keeps the
					-- connection open and restarts its read timer
					if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then
						handler.disconnect( )( handler, "read timeout" )
						-- NOTE(review): unlike the send-timeout path this calls close(),
						-- not force_close(); confirm that is intended
						handler:close( )        -- forced disconnect?
					else
						_readtimes[ handler ] = _currenttime -- reset timer
					end
				end
			end
		end
	until quitting;
	-- single step with no quit requested meanwhile: clear the sentinel, keep running state
	if once and quitting == "once" then quitting = nil; return; end
	closeall();
	return "quitting"
end
941
-- Run a single pass of the main loop; forwards loop()'s return values.
local step = function()
	local single_pass = true
	return loop(single_pass);
end
945
-- Name of this network backend implementation.
local get_backend = function()
	local backend_name = "select";
	return backend_name;
end
949
950 --// EXPERIMENTAL //--
951
-- Wrap an already-created client socket in a connection handler and register it.
-- Without sslctx, the socket is also put on the send list, and (if the listeners
-- define onconnect) the first sendbuffer() call fires onconnect before flushing.
-- Returns handler and socket, or nil plus an error from wrapconnection.
local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
	local handler, conn, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
	if not handler then
		return nil, err
	end
	_socketlist[ conn ] = handler
	if sslctx then
		return handler, conn
	end
	_sendlistlen = addsocket( _sendlist, conn, _sendlistlen )
	if not listeners.onconnect then
		return handler, conn
	end
	-- One-shot wrapper: when the socket first becomes writable, restore the
	-- original sendbuffer, call onconnect, then send any queued outgoing data.
	local real_sendbuffer = handler.sendbuffer
	handler.sendbuffer = function ()
		handler.sendbuffer = real_sendbuffer
		listeners.onconnect( handler )
		return real_sendbuffer( )
	end
	return handler, conn
end
970
-- Open a non-blocking outgoing TCP connection to address:port and wrap it.
-- typ selects the luasocket constructor ("tcp", "tcp6", ...). When omitted and
-- getaddrinfo is available, "tcp6" is chosen if the first resolved address is IPv6.
-- Returns what wrapclient returns, or nil plus an error string.
local addclient = function( address, port, listeners, pattern, sslctx, typ )
	local err
	if type( listeners ) ~= "table" then
		err = "invalid listener table"
	elseif type ( address ) ~= "string" then
		err = "invalid address"
	elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
		err = "invalid port"
	elseif sslctx and not has_luasec then
		err = "luasec not found"
	end
	-- Bail out before any DNS lookup: getaddrinfo() would raise on a non-string
	-- address, and the socket-type check below must not overwrite this error.
	if err then
		out_error( "server.lua, addclient: ", err )
		return nil, err
	end
	if getaddrinfo and not typ then
		local addrinfo, lookup_err = getaddrinfo( address )
		if not addrinfo then return nil, lookup_err end
		if addrinfo[1] and addrinfo[1].family == "inet6" then
			typ = "tcp6"
		end
	end
	local create = luasocket[ typ or "tcp" ]
	if type( create ) ~= "function" then
		err = "invalid socket type"
		out_error( "server.lua, addclient: ", err )
		return nil, err
	end

	local client, create_err = create( )
	if create_err then
		return nil, create_err
	end
	client:settimeout( 0 )
	local ok, connect_err = client:connect( address, port )
	-- with a zero timeout an in-progress connect is reported as an error; accept it
	if ok or connect_err == "timeout" or connect_err == "Operation already in progress" then
		return wrapclient( client, address, port, listeners, pattern, sslctx )
	end
	return nil, connect_err
end
1011
1012 ----------------------------------// BEGIN //--
1013
do
	-- Weak keys: an entry in these tables does not by itself keep its key
	-- (a socket in _socketlist, a handler in _readtimes/_writetimes) alive.
	local setmetatable = use "setmetatable"
	setmetatable( _socketlist, { __mode = "k" } )
	setmetatable( _readtimes, { __mode = "k" } )
	setmetatable( _writetimes, { __mode = "k" } )
end

-- reference point for the periodic timeout checks in the main loop
_starttime = luasocket_gettime( )
1019
-- Install a new logger function (ignored when falsy) and return the one that
-- was active before the call.
local function setlogger(new_logger)
	local previous = log;
	log = new_logger or log;
	return previous;
end
1027
1028 ----------------------------------// PUBLIC INTERFACE //--
1029
return {
	-- timers
	_addtimer = addtimer,
	add_task = add_task,

	-- outgoing connections
	addclient = addclient,
	wrapclient = wrapclient,

	-- listening servers
	addserver = addserver,
	getserver = getserver,
	removeserver = removeserver,

	-- main loop control
	loop = loop,
	step = step,
	setquitting = setquitting,
	closeall = closeall,

	-- utilities and configuration
	link = link,
	stats = stats,
	setlogger = setlogger,
	getsettings = getsettings,
	changesettings = changesettings,
	get_backend = get_backend,
}