Merge 0.10->trunk
[prosody.git] / net / server_select.lua
index 486e953b6cb0269f3c1eb94b762fdc5be118e912..37d57d29ae4afd5ffeab9960c5686d60f3c1700d 100644 (file)
@@ -31,32 +31,31 @@ local tostring = use "tostring"
 
 --// lua libs //--
 
-local os = use "os"
 local table = use "table"
 local string = use "string"
 local coroutine = use "coroutine"
 
 --// lua lib methods //--
 
-local os_difftime = os.difftime
 local math_min = math.min
 local math_huge = math.huge
 local table_concat = table.concat
+local table_insert = table.insert
 local string_sub = string.sub
 local coroutine_wrap = coroutine.wrap
 local coroutine_yield = coroutine.yield
 
 --// extern libs //--
 
-local luasec = use "ssl"
+local has_luasec, luasec = pcall ( require , "ssl" )
 local luasocket = use "socket" or require "socket"
 local luasocket_gettime = luasocket.gettime
+local getaddrinfo = luasocket.dns.getaddrinfo
 
 --// extern lib methods //--
 
-local ssl_wrap = ( luasec and luasec.wrap )
+local ssl_wrap = ( has_luasec and luasec.wrap )
 local socket_bind = luasocket.bind
-local socket_sleep = luasocket.sleep
 local socket_select = luasocket.select
 
 --// functions //--
@@ -88,6 +87,7 @@ local _socketlist
 local _closelist
 local _readtimes
 local _writetimes
+local _fullservers
 
 --// simple data types //--
 
@@ -100,8 +100,8 @@ local _sendtraffic
 local _readtraffic
 
 local _selecttimeout
-local _sleeptime
 local _tcpbacklog
+local _accepretry
 
 local _starttime
 local _currenttime
@@ -113,8 +113,6 @@ local _checkinterval
 local _sendtimeout
 local _readtimeout
 
-local _timer
-
 local _maxselectlen
 local _maxfd
 
@@ -130,6 +128,7 @@ _socketlist = { } -- key = socket, value = wrapped socket (handlers)
 _readtimes = { } -- key = handler, value = timestamp of last data reading
 _writetimes = { } -- key = handler, value = timestamp of last data writing/sending
 _closelist = { } -- handlers to close
+_fullservers = { } -- servers in a paused state while there are too many clients
 
 _readlistlen = 0 -- length of readlist
 _sendlistlen = 0 -- length of sendlist
@@ -139,8 +138,8 @@ _sendtraffic = 0 -- some stats
 _readtraffic = 0
 
 _selecttimeout = 1 -- timeout of socket.select
-_sleeptime = 0 -- time to wait at the end of every loop
 _tcpbacklog = 128 -- some kind of hint to the OS
+_accepretry = 10 -- seconds to wait until the next attempt of a full server to accept
 
 _maxsendlen = 51000 * 1024 -- max len of send buffer
 _maxreadlen = 25000 * 1024 -- max len of read buffer
@@ -209,6 +208,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
                                socket = nil;
                        end
                        handler.paused = true;
+                       out_put("server.lua: server [", ip, "]:", serverport, " paused")
                end
        end
        handler.resume = function( )
@@ -219,7 +219,9 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
                        end
                        _readlistlen = addsocket(_readlist, socket, _readlistlen)
                        _socketlist[ socket ] = handler
+                       _fullservers[ handler ] = nil
                        handler.paused = false;
+                       out_put("server.lua: server [", ip, "]:", serverport, " resumed")
                end
        end
        handler.ip = function( )
@@ -234,6 +236,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
        handler.readbuffer = function( )
                if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then
                        handler.pause( )
+                       _fullservers[ handler ] = _currenttime
                        out_put( "server.lua: refused new client connection: server full" )
                        return false
                end
@@ -252,6 +255,8 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
                        return;
                elseif err then -- maybe timeout or something else
                        out_put( "server.lua: error with new client connection: ", tostring(err) )
+                       handler.pause( )
+                       _fullservers[ handler ] = _currenttime
                        return false
                end
        end
@@ -264,6 +269,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
                socket:close( ) -- Should we send some kind of error here?
                if server then
+                       _fullservers[ server ] = _currenttime
                        server.pause( )
                end
                return nil, nil, "fd-too-large"
@@ -291,7 +297,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        local bufferqueuelen = 0        -- end of buffer array
 
        local toclose
-       local fatalerror
        local needtls
 
        local bufferlen = 0
@@ -397,6 +402,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                out_put "server.lua: closed client handler and removed socket from list"
                return true
        end
+       handler.server = function ( )
+               return server
+       end
        handler.ip = function( )
                return ip
        end
@@ -503,7 +511,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        return dispatch( handler, buffer, err )
                else    -- connections was closed or fatal error
                        out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
-                       fatalerror = true
                        _ = handler and handler:force_close( err )
                        return false
                end
@@ -543,7 +550,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        return true
                else    -- connection was closed during sending or fatal error
                        out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
-                       fatalerror = true
                        _ = handler and handler:force_close( err )
                        return false
                end
@@ -588,13 +594,14 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                                                coroutine_yield( ) -- handshake not finished
                                        end
                                end
-                               out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") )
-                               _ = handler and handler:force_close("ssl handshake failed")
+                               err = "ssl handshake error: " .. ( err or "handshake too long" );
+                               out_put( "server.lua: ", err );
+                               _ = handler and handler:force_close(err)
                                return false, err -- handshake failed
                        end
                )
        end
-       if luasec then
+       if has_luasec then
                handler.starttls = function( self, _sslctx)
                        if _sslctx then
                                handler:set_sslctx(_sslctx);
@@ -647,7 +654,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        _socketlist[ socket ] = handler
        _readlistlen = addsocket(_readlist, socket, _readlistlen)
 
-       if sslctx and luasec then
+       if sslctx and has_luasec then
                out_put "server.lua: auto-starting ssl negotiation..."
                handler.autostart_ssl = true;
                local ok, err = handler:starttls(sslctx);
@@ -723,22 +730,23 @@ end
 ----------------------------------// PUBLIC //--
 
 addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
+       addr = addr or "*"
        local err
        if type( listeners ) ~= "table" then
                err = "invalid listener table"
-       end
-       if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
+       elseif type ( addr ) ~= "string" then
+               err = "invalid address"
+       elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
                err = "invalid port"
        elseif _server[ addr..":"..port ] then
                err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist"
-       elseif sslctx and not luasec then
+       elseif sslctx and not has_luasec then
                err = "luasec not found"
        end
        if err then
                out_error( "server.lua, [", addr, "]:", port, ": ", err )
                return nil, err
        end
-       addr = addr or "*"
        local server, err = socket_bind( addr, port, _tcpbacklog )
        if err then
                out_error( "server.lua, [", addr, "]:", port, ": ", err )
@@ -790,7 +798,6 @@ end
 getsettings = function( )
        return {
                select_timeout = _selecttimeout;
-               select_sleep_time = _sleeptime;
                tcp_backlog = _tcpbacklog;
                max_send_buffer_size = _maxsendlen;
                max_receive_buffer_size = _maxreadlen;
@@ -800,6 +807,7 @@ getsettings = function( )
                max_connections = _maxselectlen;
                max_ssl_handshake_roundtrips = _maxsslhandshake;
                highest_allowed_fd = _maxfd;
+               accept_retry_interval = _accepretry;
        }
 end
 
@@ -808,13 +816,13 @@ changesettings = function( new )
                return nil, "invalid settings table"
        end
        _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
-       _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime
        _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
        _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
        _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
        _tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog
        _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout
        _readtimeout = tonumber( new.read_timeout ) or _readtimeout
+       _accepretry = tonumber( new.accept_retry_interval ) or _accepretry
        _maxselectlen = new.max_connections or _maxselectlen
        _maxsslhandshake = new.max_ssl_handshake_roundtrips or _maxsslhandshake
        _maxfd = new.highest_allowed_fd or _maxfd
@@ -830,6 +838,49 @@ addtimer = function( listener )
        return true
 end
 
+local add_task do
+       local data = {};
+       local new_data = {};
+
+       function add_task(delay, callback)
+               local current_time = luasocket_gettime();
+               delay = delay + current_time;
+               if delay >= current_time then
+                       table_insert(new_data, {delay, callback});
+               else
+                       local r = callback(current_time);
+                       if r and type(r) == "number" then
+                               return add_task(r, callback);
+                       end
+               end
+       end
+
+       addtimer(function(current_time)
+               if #new_data > 0 then
+                       for _, d in pairs(new_data) do
+                               table_insert(data, d);
+                       end
+                       new_data = {};
+               end
+
+               local next_time = math_huge;
+               for i, d in pairs(data) do
+                       local t, callback = d[1], d[2];
+                       if t <= current_time then
+                               data[i] = nil;
+                               local r = callback(current_time);
+                               if type(r) == "number" then
+                                       add_task(r, callback);
+                                       next_time = math_min(next_time, r);
+                               end
+                       else
+                               next_time = math_min(next_time, t - current_time);
+                       end
+               end
+               return next_time;
+       end);
+end
+
 stats = function( )
        return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
 end
@@ -843,8 +894,15 @@ end
 loop = function(once) -- this is the main loop of the program
        if quitting then return "quitting"; end
        if once then quitting = "once"; end
-       local next_timer_time = math_huge;
+       _currenttime = luasocket_gettime( )
        repeat
+               -- Fire timers
+       local next_timer_time = math_huge;
+               for i = 1, _timerlistlen do
+                       local t = _timerlist[ i ]( _currenttime ) -- fire timers
+                       if t then next_timer_time = math_min(next_timer_time, t); end
+               end
+
                local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
                for i, socket in ipairs( write ) do -- send data waiting in writequeues
                        local handler = _socketlist[ socket ]
@@ -872,17 +930,16 @@ loop = function(once) -- this is the main loop of the program
                _currenttime = luasocket_gettime( )
 
                -- Check for socket timeouts
-               local difftime = os_difftime( _currenttime - _starttime )
-               if difftime > _checkinterval then
+               if _currenttime - _starttime > _checkinterval then
                        _starttime = _currenttime
                        for handler, timestamp in pairs( _writetimes ) do
-                               if os_difftime( _currenttime - timestamp ) > _sendtimeout then
+                               if _currenttime - timestamp > _sendtimeout then
                                        handler.disconnect( )( handler, "send timeout" )
                                        handler:force_close()    -- forced disconnect
                                end
                        end
                        for handler, timestamp in pairs( _readtimes ) do
-                               if os_difftime( _currenttime - timestamp ) > _readtimeout then
+                               if _currenttime - timestamp > _readtimeout then
                                        if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then
                                                handler.disconnect( )( handler, "read timeout" )
                                                handler:close( )        -- forced disconnect?
@@ -893,22 +950,15 @@ loop = function(once) -- this is the main loop of the program
                        end
                end
 
-               -- Fire timers
-               if _currenttime - _timer >= math_min(next_timer_time, 1) then
-                       next_timer_time = math_huge;
-                       for i = 1, _timerlistlen do
-                               local t = _timerlist[ i ]( _currenttime ) -- fire timers
-                               if t then next_timer_time = math_min(next_timer_time, t); end
+               for server, paused_time in pairs( _fullservers ) do
+                       if _currenttime - paused_time > _accepretry then
+                               _fullservers[ server ] = nil;
+                               server.resume();
                        end
-                       _timer = _currenttime
-               else
-                       next_timer_time = next_timer_time - (_currenttime - _timer);
                end
-
-               -- wait some time (0 by default)
-               socket_sleep( _sleeptime )
        until quitting;
        if once and quitting == "once" then quitting = nil; return; end
+       closeall();
        return "quitting"
 end
 
@@ -941,29 +991,53 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx
        return handler, socket
 end
 
-local addclient = function( address, port, listeners, pattern, sslctx )
-       local client, err = luasocket.tcp( )
+local addclient = function( address, port, listeners, pattern, sslctx, typ )
+       local err
+       if type( listeners ) ~= "table" then
+               err = "invalid listener table"
+       elseif type ( address ) ~= "string" then
+               err = "invalid address"
+       elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
+               err = "invalid port"
+       elseif sslctx and not has_luasec then
+               err = "luasec not found"
+       end
+       if getaddrinfo and not typ then
+               local addrinfo, err = getaddrinfo(address)
+               if not addrinfo then return nil, err end
+               if addrinfo[1] and addrinfo[1].family == "inet6" then
+                       typ = "tcp6"
+               end
+       end
+       local create = luasocket[typ or "tcp"]
+       if type( create ) ~= "function"  then
+               err = "invalid socket type"
+       end
+
+       if err then
+               out_error( "server.lua, addclient: ", err )
+               return nil, err
+       end
+
+       local client, err = create( )
        if err then
                return nil, err
        end
        client:settimeout( 0 )
-       _, err = client:connect( address, port )
-       if err then -- try again
+       local ok, err = client:connect( address, port )
+       if ok or err == "timeout" or err == "Operation already in progress" then
                return wrapclient( client, address, port, listeners, pattern, sslctx )
        else
-               return wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx )
+               return nil, err
        end
 end
 
---// EXPERIMENTAL //--
-
 ----------------------------------// BEGIN //--
 
 use "setmetatable" ( _socketlist, { __mode = "k" } )
 use "setmetatable" ( _readtimes, { __mode = "k" } )
 use "setmetatable" ( _writetimes, { __mode = "k" } )
 
-_timer = luasocket_gettime( )
 _starttime = luasocket_gettime( )
 
 local function setlogger(new_logger)
@@ -978,6 +1052,7 @@ end
 
 return {
        _addtimer = addtimer,
+       add_task = add_task;
 
        addclient = addclient,
        wrapclient = wrapclient,