MUC: Access prosody.hosts instead of the old global hosts
[prosody.git] / net / server_select.lua
index 025f145d6567e70318e20a00773272e482a4f96c..63a94b7eb8ea8230c29943f8bba891ab1c8c1793 100644 (file)
 local use = function( what )
        return _G[ what ]
 end
-local clean = function( tbl )
-       for i, k in pairs( tbl ) do
-               tbl[ i ] = nil
-       end
-end
 
 local log, table_concat = require ("util.logger").init("socket"), table.concat;
 local out_put = function (...) return log("debug", table_concat{...}); end
 local out_error = function (...) return log("warn", table_concat{...}); end
-local mem_free = collectgarbage
 
 ----------------------------------// DECLARATION //--
 
@@ -34,7 +28,6 @@ local pairs = use "pairs"
 local ipairs = use "ipairs"
 local tonumber = use "tonumber"
 local tostring = use "tostring"
-local collectgarbage = use "collectgarbage"
 
 --// lua libs //--
 
@@ -49,8 +42,6 @@ local os_difftime = os.difftime
 local math_min = math.min
 local math_huge = math.huge
 local table_concat = table.concat
-local table_remove = table.remove
-local string_len = string.len
 local string_sub = string.sub
 local coroutine_wrap = coroutine.wrap
 local coroutine_yield = coroutine.yield
@@ -67,7 +58,6 @@ local ssl_wrap = ( luasec and luasec.wrap )
 local socket_bind = luasocket.bind
 local socket_sleep = luasocket.sleep
 local socket_select = luasocket.select
-local ssl_newcontext = ( luasec and luasec.newcontext )
 
 --// functions //--
 
@@ -78,13 +68,13 @@ local idfalse
 local closeall
 local addsocket
 local addserver
+local addtimer
 local getserver
 local wrapserver
 local getsettings
 local closesocket
 local removesocket
 local removeserver
-local changetimeout
 local wrapconnection
 local changesettings
 
@@ -122,11 +112,10 @@ local _checkinterval
 local _sendtimeout
 local _readtimeout
 
-local _cleanqueue
-
 local _timer
 
-local _maxclientsperserver
+local _maxselectlen
+local _maxfd
 
 local _maxsslhandshake
 
@@ -158,17 +147,20 @@ _checkinterval = 1200000 -- interval in secs to check idle clients
 _sendtimeout = 60000 -- allowed send idle time in secs
 _readtimeout = 6 * 60 * 60 -- allowed read idle time in secs
 
-_cleanqueue = false -- clean bufferqueue after using
-
-_maxclientsperserver = 1000
+_maxfd = luasocket._SETSIZE or 1024 -- We should ignore this on Windows.  Perhaps by simply setting it to math.huge or something.
+_maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows
 
 _maxsslhandshake = 30 -- max handshake round-trips
 
 ----------------------------------// PRIVATE //--
 
-wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections ) -- this function wraps a server
+wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd
 
-       maxconnections = maxconnections or _maxclientsperserver
+       if socket:getfd() >= _maxfd then
+               out_error("server.lua: Disallowed FD number: "..socket:getfd())
+               socket:close()
+               return nil, "fd-too-large"
+       end
 
        local connections = 0
 
@@ -190,6 +182,9 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
        end
        handler.remove = function( )
                connections = connections - 1
+               if handler then
+                       handler.resume( )
+               end
        end
        handler.close = function()
                socket:close( )
@@ -202,6 +197,28 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
                --mem_free( )
                out_put "server.lua: closed server handler and removed sockets from list"
        end
+       handler.pause = function( hard )
+               if not handler.paused then
+                       _readlistlen = removesocket( _readlist, socket, _readlistlen )
+                       if hard then
+                               _socketlist[ socket ] = nil
+                               socket:close( )
+                               socket = nil;
+                       end
+                       handler.paused = true;
+               end
+       end
+       handler.resume = function( )
+               if handler.paused then
+                       if not socket then
+                               socket = socket_bind( ip, serverport );
+                               socket:settimeout( 0 )
+                       end
+                       _readlistlen = addsocket(_readlist, socket, _readlistlen)
+                       _socketlist[ socket ] = handler
+                       handler.paused = false;
+               end
+       end
        handler.ip = function( )
                return ip
        end
@@ -212,14 +229,14 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
                return socket
        end
        handler.readbuffer = function( )
-               if connections > maxconnections then
+               if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then
+                       handler.pause( )
                        out_put( "server.lua: refused new client connection: server full" )
                        return false
                end
                local client, err = accept( socket )    -- try to accept
                if client then
                        local ip, clientport = client:getpeername( )
-                       client:settimeout( 0 )
                        local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket
                        if err then -- error while wrapping ssl socket
                                return false
@@ -240,6 +257,12 @@ end
 
 wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object
 
+       if socket:getfd() >= _maxfd then
+               out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
+               socket:close( ) -- Should we send some kind of error here?
+               server.pause( )
+               return nil, nil, "fd-too-large"
+       end
        socket:settimeout( 0 )
 
        --// local import of socket methods //--
@@ -317,9 +340,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        handler.force_close = function ( self, err )
                if bufferqueuelen ~= 0 then
                        out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport))
-                       for i = bufferqueuelen, 1, -1 do
-                               bufferqueue[i] = nil;
-                       end
                        bufferqueuelen = 0;
                end
                return self:close(err);
@@ -373,7 +393,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                return clientport
        end
        local write = function( self, data )
-               bufferlen = bufferlen + string_len( data )
+               bufferlen = bufferlen + #data
                if bufferlen > maxsendlen then
                        _closelist[ handler ] = "send buffer exceeded"   -- cannot close the client at the moment, have to wait to the end of the cycle
                        handler.write = idfalse -- dont write anymore
@@ -455,7 +475,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                local buffer, err, part = receive( socket, pattern )    -- receive buffer with "pattern"
                if not err or (err == "wantread" or err == "timeout") then -- received something
                        local buffer = buffer or part or ""
-                       local len = string_len( buffer )
+                       local len = #buffer
                        if len > maxreadlen then
                                handler:close( "receive buffer exceeded" )
                                return false
@@ -481,7 +501,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        count = ( succ or byte or 0 ) * STAT_UNIT
                        sendtraffic = sendtraffic + count
                        _sendtraffic = _sendtraffic + count
-                       _ = _cleanqueue and clean( bufferqueue )
+                       for i = bufferqueuelen,1,-1 do
+                               bufferqueue[ i ] = nil
+                       end
                        --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )
                else
                        succ, err, count = false, "unexpected close", 0;
@@ -703,7 +725,7 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
                out_error( "server.lua, [", addr, "]:", port, ": ", err )
                return nil, err
        end
-       local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver ) -- wrap new server socket
+       local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket
        if not handler then
                server:close( )
                return nil, err
@@ -747,23 +769,23 @@ closeall = function( )
 end
 
 getsettings = function( )
-       return  _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver, _maxsslhandshake
+       return  _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, nil, _maxselectlen, _maxsslhandshake, _maxfd
 end
 
 changesettings = function( new )
        if type( new ) ~= "table" then
                return nil, "invalid settings table"
        end
-       _selecttimeout = tonumber( new.timeout ) or _selecttimeout
-       _sleeptime = tonumber( new.sleeptime ) or _sleeptime
-       _maxsendlen = tonumber( new.maxsendlen ) or _maxsendlen
-       _maxreadlen = tonumber( new.maxreadlen ) or _maxreadlen
-       _checkinterval = tonumber( new.checkinterval ) or _checkinterval
-       _sendtimeout = tonumber( new.sendtimeout ) or _sendtimeout
-       _readtimeout = tonumber( new.readtimeout ) or _readtimeout
-       _cleanqueue = new.cleanqueue
-       _maxclientsperserver = new._maxclientsperserver or _maxclientsperserver
-       _maxsslhandshake = new._maxsslhandshake or _maxsslhandshake
+       _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
+       _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime
+       _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
+       _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
+       _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
+       _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout
+       _readtimeout = tonumber( new.read_timeout ) or _readtimeout
+       _maxselectlen = new.max_connections or _maxselectlen
+       _maxsslhandshake = new.max_ssl_handshake_roundtrips or _maxsslhandshake
+       _maxfd = new.highest_allowed_fd or _maxfd
        return true
 end
 
@@ -813,8 +835,8 @@ loop = function(once) -- this is the main loop of the program
                for handler, err in pairs( _closelist ) do
                        handler.disconnect( )( handler, err )
                        handler:force_close()    -- forced disconnect
+                       _closelist[ handler ] = nil;
                end
-               clean( _closelist )
                _currenttime = luasocket_gettime( )
                if _currenttime - _timer >= math_min(next_timer_time, 1) then
                        next_timer_time = math_huge;
@@ -844,7 +866,8 @@ end
 --// EXPERIMENTAL //--
 
 local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
-       local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
+       local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
+       if not handler then return nil, err end
        _socketlist[ socket ] = handler
        if not sslctx then
                _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)