X-Git-Url: https://git.enpas.org/?a=blobdiff_plain;f=net%2Fserver_select.lua;h=86c9daef24ae7e8fd63a130a28177b95c9ddde05;hb=aed7114ec03098b6a7d4eaafbf97a98d7e1f054b;hp=41f2b9fa50a2da2e70f8734b56c978e7b576abbb;hpb=627bc19a8ff0afb9c4e3af170ab0b0f9a66daf17;p=prosody.git diff --git a/net/server_select.lua b/net/server_select.lua index 41f2b9fa..86c9daef 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -10,16 +10,10 @@ local use = function( what ) return _G[ what ] end -local clean = function( tbl ) - for i, k in pairs( tbl ) do - tbl[ i ] = nil - end -end local log, table_concat = require ("util.logger").init("socket"), table.concat; local out_put = function (...) return log("debug", table_concat{...}); end local out_error = function (...) return log("warn", table_concat{...}); end -local mem_free = collectgarbage ----------------------------------// DECLARATION //-- @@ -34,7 +28,6 @@ local pairs = use "pairs" local ipairs = use "ipairs" local tonumber = use "tonumber" local tostring = use "tostring" -local collectgarbage = use "collectgarbage" --// lua libs //-- @@ -49,8 +42,6 @@ local os_difftime = os.difftime local math_min = math.min local math_huge = math.huge local table_concat = table.concat -local table_remove = table.remove -local string_len = string.len local string_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield @@ -67,7 +58,6 @@ local ssl_wrap = ( luasec and luasec.wrap ) local socket_bind = luasocket.bind local socket_sleep = luasocket.sleep local socket_select = luasocket.select -local ssl_newcontext = ( luasec and luasec.newcontext ) --// functions //-- @@ -75,17 +65,16 @@ local id local loop local stats local idfalse -local addtimer local closeall local addsocket local addserver +local addtimer local getserver local wrapserver local getsettings local closesocket local removesocket local removeserver -local changetimeout local wrapconnection local changesettings @@ -123,11 +112,10 @@ local _checkinterval local _sendtimeout local _readtimeout -local _cleanqueue - local _timer -local _maxclientsperserver +local _maxselectlen +local _maxfd local _maxsslhandshake @@ -159,21 +147,24 @@ _checkinterval = 1200000 -- interval in secs to check idle clients _sendtimeout = 60000 -- allowed send idle time in secs _readtimeout = 6 * 60 * 60 -- allowed read idle time in secs -_cleanqueue = false -- clean bufferqueue after using - -_maxclientsperserver = 1000 +_maxfd = luasocket._SETSIZE or 1024 -- We should ignore this on Windows. Perhaps by simply setting it to math.huge or something. +_maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows _maxsslhandshake = 30 -- max handshake round-trips ----------------------------------// PRIVATE //-- -wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections ) -- this function wraps a server +wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd - maxconnections = maxconnections or _maxclientsperserver + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) + socket:close() + return nil, "fd-too-large" + end local connections = 0 - local dispatch, disconnect = listeners.onconnect or listeners.onincoming, listeners.ondisconnect + local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect local accept = socket.accept @@ -191,23 +182,43 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco end handler.remove = function( ) connections = connections - 1 - end - handler.close = function( ) - for _, handler in pairs( _socketlist ) do - if handler.serverport == serverport then - handler.disconnect( handler, "server closed" ) - handler:close( true ) - end + if handler then + handler.resume( ) end + end + handler.close = function() socket:close( ) _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _server[ip..":"..serverport] = nil; _socketlist[ socket ] = nil handler = nil socket = nil --mem_free( ) out_put "server.lua: closed server handler and removed sockets from list" end + handler.pause = function( hard ) + if not handler.paused then + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + if hard then + _socketlist[ socket ] = nil + socket:close( ) + socket = nil; + end + handler.paused = true; + end + end + handler.resume = function( ) + if handler.paused then + if not socket then + socket = socket_bind( ip, serverport ); + socket:settimeout( 0 ) + end + _readlistlen = addsocket(_readlist, socket, _readlistlen) + _socketlist[ socket ] = handler + handler.paused = false; + end + end handler.ip = function( ) return ip end @@ -218,21 +229,24 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco return socket end handler.readbuffer = function( ) - if connections > maxconnections then + if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then + handler.pause( ) out_put( "server.lua: refused new client connection: server full" ) return false end local client, err = accept( socket ) -- try to accept if client then local ip, clientport = client:getpeername( ) - client:settimeout( 0 ) local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket if err then -- error while wrapping ssl socket return false end connections = connections + 1 out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) - return dispatch( handler ) + if dispatch then + return dispatch( handler ); + end + return; elseif err then -- maybe timeout or something else out_put( "server.lua: error with new client connection: ", tostring(err) ) return false @@ -243,6 +257,12 @@ end wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent + socket:close( ) -- Should we send some kind of error here? + server.pause( ) + return nil, nil, "fd-too-large" + end socket:settimeout( 0 ) --// local import of socket methods //-- @@ -317,22 +337,25 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end return false, "setoption not implemented"; end - handler.close = function( self, forced ) + handler.force_close = function ( self, err ) + if bufferqueuelen ~= 0 then + out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport)) + bufferqueuelen = 0; + end + return self:close(err); + end + handler.close = function( self, err ) if not handler then return true; end _readlistlen = removesocket( _readlist, socket, _readlistlen ) _readtimes[ handler ] = nil if bufferqueuelen ~= 0 then - if not ( forced or fatalerror ) then - handler.sendbuffer( ) - if bufferqueuelen ~= 0 then -- try again... - if handler then - handler.write = nil -- ... but no further writing allowed - end - toclose = true - return false + handler.sendbuffer() -- Try now to send any outstanding data + if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later + if handler then + handler.write = nil -- ... but no further writing allowed end - else - send( socket, table_concat( bufferqueue, "", 1, bufferqueuelen ), 1, bufferlen ) -- forced send + toclose = true + return false end end if socket then @@ -347,7 +370,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport if handler then _writetimes[ handler ] = nil _closelist[ handler ] = nil + local _handler = handler; handler = nil + if disconnect then + disconnect(_handler, err or false); + disconnect = nil + end end if server then server.remove( ) @@ -365,7 +393,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return clientport end local write = function( self, data ) - bufferlen = bufferlen + string_len( data ) + bufferlen = bufferlen + #data if bufferlen > maxsendlen then _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle handler.write = idfalse -- dont write anymore @@ -447,10 +475,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" if not err or (err == "wantread" or err == "timeout") then -- received something local buffer = buffer or part or "" - local len = string_len( buffer ) + local len = #buffer if len > maxreadlen then - disconnect( handler, "receive buffer exceeded" ) - handler:close( true ) + handler:close( "receive buffer exceeded" ) return false end local count = len * STAT_UNIT @@ -462,24 +489,24 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport else -- connections was closed or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) fatalerror = true - disconnect( handler, err ) - _ = handler and handler:close( ) + _ = handler and handler:force_close( err ) return false end end local _sendbuffer = function( ) -- this function sends data local succ, err, byte, buffer, count; - local count; if socket then buffer = table_concat( bufferqueue, "", 1, bufferqueuelen ) succ, err, byte = send( socket, buffer, 1, bufferlen ) count = ( succ or byte or 0 ) * STAT_UNIT sendtraffic = sendtraffic + count _sendtraffic = _sendtraffic + count - _ = _cleanqueue and clean( bufferqueue ) + for i = bufferqueuelen,1,-1 do + bufferqueue[ i ] = nil + end --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) ) else - succ, err, count = false, "closed", 0; + succ, err, count = false, "unexpected close", 0; end if succ then -- sending succesful bufferqueuelen = 0 @@ -490,7 +517,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport drain(handler) end _ = needtls and handler:starttls(nil) - _ = toclose and handler:close( ) + _ = toclose and handler:force_close( ) return true elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer @@ -502,8 +529,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport else -- connection was closed during sending or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) ) fatalerror = true - disconnect( handler, err ) - _ = handler and handler:close( ) + _ = handler and handler:force_close( err ) return false end end @@ -525,9 +551,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions handler.sendbuffer = _sendbuffer _ = status and status( handler, "ssl-handshake-complete" ) - if self.autostart_ssl and listeners.onconnect then - listeners.onconnect(self); - end _readlistlen = addsocket(_readlist, client, _readlistlen) return true else @@ -545,9 +568,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end end out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") ) - disconnect( handler, "ssl handshake failed" ) - _ = handler and handler:close( true ) -- forced disconnect - return false -- handshake failed + _ = handler and handler:force_close("ssl handshake failed") + return false, err -- handshake failed end ) end @@ -591,27 +613,28 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.readbuffer = handshake handler.sendbuffer = handshake - handshake( socket ) -- do handshake + return handshake( socket ) -- do handshake end - handler.readbuffer = _readbuffer - handler.sendbuffer = _sendbuffer - - if sslctx then - out_put "server.lua: auto-starting ssl negotiation..." - handler.autostart_ssl = true; - handler:starttls(sslctx); - end - - else - handler.readbuffer = _readbuffer - handler.sendbuffer = _sendbuffer end + + handler.readbuffer = _readbuffer + handler.sendbuffer = _sendbuffer send = socket.send receive = socket.receive shutdown = ( ssl and id ) or socket.shutdown _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) + + if sslctx and luasec then + out_put "server.lua: auto-starting ssl negotiation..." + handler.autostart_ssl = true; + local ok, err = handler:starttls(sslctx); + if ok == false then + return nil, nil, err + end + end + return handler, socket end @@ -699,7 +722,7 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver ) -- wrap new server socket + local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket if not handler then server:close( ) return nil, err @@ -743,23 +766,34 @@ closeall = function( ) end getsettings = function( ) - return _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver, _maxsslhandshake + return { + select_timeout = _selecttimeout; + select_sleep_time = _sleeptime; + max_send_buffer_size = _maxsendlen; + max_receive_buffer_size = _maxreadlen; + select_idle_check_interval = _checkinterval; + send_timeout = _sendtimeout; + read_timeout = _readtimeout; + max_connections = _maxselectlen; + max_ssl_handshake_roundtrips = _maxsslhandshake; + highest_allowed_fd = _maxfd; + } end changesettings = function( new ) if type( new ) ~= "table" then return nil, "invalid settings table" end - _selecttimeout = tonumber( new.timeout ) or _selecttimeout - _sleeptime = tonumber( new.sleeptime ) or _sleeptime - _maxsendlen = tonumber( new.maxsendlen ) or _maxsendlen - _maxreadlen = tonumber( new.maxreadlen ) or _maxreadlen - _checkinterval = tonumber( new.checkinterval ) or _checkinterval - _sendtimeout = tonumber( new.sendtimeout ) or _sendtimeout - _readtimeout = tonumber( new.readtimeout ) or _readtimeout - _cleanqueue = new.cleanqueue - _maxclientsperserver = new._maxclientsperserver or _maxclientsperserver - _maxsslhandshake = new._maxsslhandshake or _maxsslhandshake + _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout + _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime + _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen + _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen + _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval + _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout + _readtimeout = tonumber( new.read_timeout ) or _readtimeout + _maxselectlen = new.max_connections or _maxselectlen + _maxsslhandshake = new.max_ssl_handshake_roundtrips or _maxsslhandshake + _maxfd = new.highest_allowed_fd or _maxfd return true end @@ -778,7 +812,7 @@ end local quitting; -setquitting = function (quit) +local function setquitting(quit) quitting = not not quit; end @@ -808,10 +842,30 @@ loop = function(once) -- this is the main loop of the program end for handler, err in pairs( _closelist ) do handler.disconnect( )( handler, err ) - handler:close( true ) -- forced disconnect + handler:force_close() -- forced disconnect + _closelist[ handler ] = nil; end - clean( _closelist ) _currenttime = luasocket_gettime( ) + + local difftime = os_difftime( _currenttime - _starttime ) + if difftime > _checkinterval then + _starttime = _currenttime + for handler, timestamp in pairs( _writetimes ) do + if os_difftime( _currenttime - timestamp ) > _sendtimeout then + --_writetimes[ handler ] = nil + handler.disconnect( )( handler, "send timeout" ) + handler:force_close() -- forced disconnect + end + end + for handler, timestamp in pairs( _readtimes ) do + if os_difftime( _currenttime - timestamp ) > _readtimeout then + --_readtimes[ handler ] = nil + handler.disconnect( )( handler, "read timeout" ) + handler:close( ) -- forced disconnect? + end + end + end + if _currenttime - _timer >= math_min(next_timer_time, 1) then next_timer_time = math_huge; for i = 1, _timerlistlen do @@ -829,7 +883,7 @@ loop = function(once) -- this is the main loop of the program return "quitting" end -step = function () +local function step() return loop(true); end @@ -840,7 +894,8 @@ end --// EXPERIMENTAL //-- local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) - local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + if not handler then return nil, err end _socketlist[ socket ] = handler if not sslctx then _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) @@ -886,28 +941,6 @@ use "setmetatable" ( _writetimes, { __mode = "k" } ) _timer = luasocket_gettime( ) _starttime = luasocket_gettime( ) -addtimer( function( ) - local difftime = os_difftime( _currenttime - _starttime ) - if difftime > _checkinterval then - _starttime = _currenttime - for handler, timestamp in pairs( _writetimes ) do - if os_difftime( _currenttime - timestamp ) > _sendtimeout then - --_writetimes[ handler ] = nil - handler.disconnect( )( handler, "send timeout" ) - handler:close( true ) -- forced disconnect - end - end - for handler, timestamp in pairs( _readtimes ) do - if os_difftime( _currenttime - timestamp ) > _readtimeout then - --_readtimes[ handler ] = nil - handler.disconnect( )( handler, "read timeout" ) - handler:close( ) -- forced disconnect? - end - end - end - end -) - local function setlogger(new_logger) local old_logger = log; if new_logger then @@ -919,6 +952,7 @@ end ----------------------------------// PUBLIC INTERFACE //-- return { + _addtimer = addtimer, addclient = addclient, wrapclient = wrapclient, @@ -928,7 +962,6 @@ return { step = step, stats = stats, closeall = closeall, - addtimer = addtimer, addserver = addserver, getserver = getserver, setlogger = setlogger,