Merge Tobias->trunk
[prosody.git] / net / server_select.lua
index 49cbe25ddc12a86b519f7ddb4bb41ba5fd7b8e18..0310a9915ead1febc7753d608b1a933daf22b39e 100644 (file)
@@ -2,7 +2,7 @@
 -- server.lua by blastbeat of the luadch project
 -- Re-used here under the MIT/X Consortium License
 -- 
--- Modifications (C) 2008-2009 Matthew Wild, Waqas Hussain
+-- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain
 --
 
 -- // wrapping luadch stuff // --
@@ -32,6 +32,7 @@ local STAT_UNIT = 1 -- byte
 local type = use "type"
 local pairs = use "pairs"
 local ipairs = use "ipairs"
+local tonumber = use "tonumber"
 local tostring = use "tostring"
 local collectgarbage = use "collectgarbage"
 
@@ -44,8 +45,9 @@ local coroutine = use "coroutine"
 
 --// lua lib methods //--
 
-local os_time = os.time
 local os_difftime = os.difftime
+local math_min = math.min
+local math_huge = math.huge
 local table_concat = table.concat
 local table_remove = table.remove
 local string_len = string.len
@@ -57,6 +59,7 @@ local coroutine_yield = coroutine.yield
 
 local luasec = use "ssl"
 local luasocket = use "socket" or require "socket"
+local luasocket_gettime = luasocket.gettime
 
 --// extern lib methods //--
 
@@ -74,6 +77,7 @@ local stats
 local idfalse
 local addtimer
 local closeall
+local addsocket
 local addserver
 local getserver
 local wrapserver
@@ -125,6 +129,8 @@ local _timer
 
 local _maxclientsperserver
 
+local _maxsslhandshake
+
 ----------------------------------// DEFINITION //--
 
 _server = { } -- key = port, value = table; list of listening servers
@@ -143,7 +149,7 @@ _timerlistlen = 0 -- lenght of timerlist
 _sendtraffic = 0 -- some stats
 _readtraffic = 0
 
-_selecttimeout = 1 -- timeout of socket.select
+_selecttimeout = 3600 -- timeout of socket.select
 _sleeptime = 0 -- time to wait at the end of every loop
 
 _maxsendlen = 51000 * 1024 -- max len of send buffer
@@ -167,7 +173,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
 
        local connections = 0
 
-       local dispatch, disconnect = listeners.onincoming, listeners.ondisconnect
+       local dispatch, disconnect = listeners.onconnect or listeners.onincoming, listeners.ondisconnect
 
        local accept = socket.accept
 
@@ -252,6 +258,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        local dispatch = listeners.onincoming
        local status = listeners.onstatus
        local disconnect = listeners.ondisconnect
+       local drain = listeners.ondrain
 
        local bufferqueue = { } -- buffer array
        local bufferqueuelen = 0        -- end of buffer array
@@ -284,6 +291,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                dispatch = listeners.onincoming
                disconnect = listeners.ondisconnect
                status = listeners.onstatus
+               drain = listeners.ondrain
        end
        handler.getstats = function( )
                return readtraffic, sendtraffic
@@ -341,9 +349,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        _closelist[ handler ] = nil
                        handler = nil
                end
-       if server then
-               server.remove( )
-       end
+               if server then
+                       server.remove( )
+               end
                out_put "server.lua: closed client handler and removed socket from list"
                return true
        end
@@ -379,7 +387,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        handler.socket = function( self )
                return socket
        end
-       handler.pattern = function( self, new )
+       handler.set_mode = function( self, new )
                pattern = new or pattern
                return pattern
        end
@@ -392,6 +400,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                maxreadlen = readlen or maxreadlen
                return bufferlen, maxreadlen, maxsendlen
        end
+       --TODO: Deprecate
        handler.lock_read = function (self, switch)
                if switch == true then
                        local tmp = _readlistlen
@@ -409,6 +418,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                end
                return noread
        end
+       handler.pause = function (self)
+               return self:lock_read(true);
+       end
+       handler.resume = function (self)
+               return self:lock_read(false);
+       end
        handler.lock = function( self, switch )
                handler.lock_read (switch)
                if switch == true then
@@ -430,7 +445,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        end
        local _readbuffer = function( ) -- this function reads data
                local buffer, err, part = receive( socket, pattern )    -- receive buffer with "pattern"
-               if not err or (err == "wantread" or err == "timeout") or (part and string_len(part) > 0) then -- received something
+               if not err or (err == "wantread" or err == "timeout") then -- received something
                        local buffer = buffer or part or ""
                        local len = string_len( buffer )
                        if len > maxreadlen then
@@ -470,8 +485,11 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        bufferqueuelen = 0
                        bufferlen = 0
                        _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
-                       _ = needtls and handler:starttls(nil, true)
                        _writetimes[ handler ] = nil
+                       if drain then
+                               drain(handler)
+                       end
+                       _ = needtls and handler:starttls(nil)
                        _ = toclose and handler:close( )
                        return true
                elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
@@ -552,13 +570,13 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        end
                else
                        local sslctx;
-                       handler.starttls = function( self, _sslctx, now )
+                       handler.starttls = function( self, _sslctx)
                                if _sslctx then
                                        sslctx = _sslctx;
                                        handler:set_sslctx(sslctx);
                                end
-                               if not now then
-                                       out_put "server.lua: we need to do tls, but delaying until later"
+                               if bufferqueuelen > 0 then
+                                       out_put "server.lua: we need to do tls, but delaying until send buffer empty"
                                        needtls = true
                                        return
                                end
@@ -611,7 +629,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
 
        _socketlist[ socket ] = handler
        _readlistlen = addsocket(_readlist, socket, _readlistlen)
-
        return handler, socket
 end
 
@@ -654,6 +671,27 @@ closesocket = function( socket )
        --mem_free( )
 end
 
+local function link(sender, receiver, buffersize)
+       local sender_locked;
+       local _sendbuffer = receiver.sendbuffer;
+       function receiver.sendbuffer()
+               _sendbuffer();
+               if sender_locked and receiver.bufferlen() < buffersize then
+                       sender:lock_read(false); -- Unlock now
+                       sender_locked = nil;
+               end
+       end
+       
+       local _readbuffer = sender.readbuffer;
+       function sender.readbuffer()
+               _readbuffer();
+               if not sender_locked and receiver.bufferlen() >= buffersize then
+                       sender_locked = true;
+                       sender:lock_read(true);
+               end
+       end
+end
+
 ----------------------------------// PUBLIC //--
 
 addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
@@ -661,7 +699,7 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
        if type( listeners ) ~= "table" then
                err = "invalid listener table"
        end
-       if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then
+       if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
                err = "invalid port"
        elseif _server[ port ] then
                err = "listeners on port '" .. port .. "' already exist"
@@ -755,16 +793,18 @@ stats = function( )
        return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
 end
 
-local dontstop = true; -- thinking about tomorrow, ...
+local quitting;
 
 setquitting = function (quit)
-       dontstop = not quit;
-       return;
+       quitting = not not quit;
 end
 
-loop = function( ) -- this is the main loop of the program
-       while dontstop do
-               local read, write, err = socket_select( _readlist, _sendlist, _selecttimeout )
+loop = function(once) -- this is the main loop of the program
+       if quitting then return "quitting"; end
+       if once then quitting = "once"; end
+       local next_timer_time = math_huge;
+       repeat
+               local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
                for i, socket in ipairs( write ) do -- send data waiting in writequeues
                        local handler = _socketlist[ socket ]
                        if handler then
@@ -788,19 +828,28 @@ loop = function( ) -- this is the main loop of the program
                        handler:close( true )    -- forced disconnect
                end
                clean( _closelist )
-               _currenttime = os_time( )
-               if os_difftime( _currenttime - _timer ) >= 1 then
+               _currenttime = luasocket_gettime( )
+               if _currenttime - _timer >= math_min(next_timer_time, 1) then
+                       next_timer_time = math_huge;
                        for i = 1, _timerlistlen do
-                               _timerlist[ i ]( _currenttime ) -- fire timers
+                               local t = _timerlist[ i ]( _currenttime ) -- fire timers
+                               if t then next_timer_time = math_min(next_timer_time, t); end
                        end
                        _timer = _currenttime
+               else
+                       next_timer_time = next_timer_time - (_currenttime - _timer);
                end
                socket_sleep( _sleeptime ) -- wait some time
                --collectgarbage( )
-       end
+       until quitting;
+       if once and quitting == "once" then quitting = nil; return; end
        return "quitting"
 end
 
+step = function ()
+       return loop(true);
+end
+
 local function get_backend()
        return "select";
 end
@@ -811,6 +860,18 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx
        local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
        _socketlist[ socket ] = handler
        _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+       if listeners.onconnect then
+               -- When socket is writeable, call onconnect
+               local _sendbuffer = handler.sendbuffer;
+               handler.sendbuffer = function ()
+                       handler.sendbuffer = _sendbuffer;
+                       listeners.onconnect(handler);
+                       -- If there was data with the incoming packet, handle it now.
+                       if #handler:bufferqueue() > 0 then
+                               return _sendbuffer();
+                       end
+               end
+       end
        return handler, socket
 end
 
@@ -836,8 +897,8 @@ use "setmetatable" ( _socketlist, { __mode = "k" } )
 use "setmetatable" ( _readtimes, { __mode = "k" } )
 use "setmetatable" ( _writetimes, { __mode = "k" } )
 
-_timer = os_time( )
-_starttime = os_time( )
+_timer = luasocket_gettime( )
+_starttime = luasocket_gettime( )
 
 addtimer( function( )
                local difftime = os_difftime( _currenttime - _starttime )
@@ -877,6 +938,7 @@ return {
        wrapclient = wrapclient,
        
        loop = loop,
+       link = link,
        stats = stats,
        closeall = closeall,
        addtimer = addtimer,