X-Git-Url: https://git.enpas.org/?a=blobdiff_plain;ds=sidebyside;f=net%2Fserver_select.lua;h=122d774e19a9e756f1010cf5a0164785e2e6e61d;hb=d59d4cf7949da5214153f7cec67fdb00cd77040f;hp=51ae4e66a1e1ccbbf990ef6356c2cb9b58303dda;hpb=2394ef36042c71f29272ac84d5bb83467839636e;p=prosody.git diff --git a/net/server_select.lua b/net/server_select.lua index 51ae4e66..122d774e 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -19,7 +19,6 @@ end local log, table_concat = require ("util.logger").init("socket"), table.concat; local out_put = function (...) return log("debug", table_concat{...}); end local out_error = function (...) return log("warn", table_concat{...}); end -local mem_free = collectgarbage ----------------------------------// DECLARATION //-- @@ -32,8 +31,8 @@ local STAT_UNIT = 1 -- byte local type = use "type" local pairs = use "pairs" local ipairs = use "ipairs" +local tonumber = use "tonumber" local tostring = use "tostring" -local collectgarbage = use "collectgarbage" --// lua libs //-- @@ -44,10 +43,10 @@ local coroutine = use "coroutine" --// lua lib methods //-- -local os_time = os.time local os_difftime = os.difftime +local math_min = math.min +local math_huge = math.huge local table_concat = table.concat -local table_remove = table.remove local string_len = string.len local string_sub = string.sub local coroutine_wrap = coroutine.wrap @@ -57,6 +56,7 @@ local coroutine_yield = coroutine.yield local luasec = use "ssl" local luasocket = use "socket" or require "socket" +local luasocket_gettime = luasocket.gettime --// extern lib methods //-- @@ -64,7 +64,6 @@ local ssl_wrap = ( luasec and luasec.wrap ) local socket_bind = luasocket.bind local socket_sleep = luasocket.sleep local socket_select = luasocket.select -local ssl_newcontext = ( luasec and luasec.newcontext ) --// functions //-- @@ -72,16 +71,16 @@ local id local loop local stats local idfalse -local addtimer local closeall +local addsocket local addserver +local addtimer local getserver local wrapserver local getsettings local closesocket local removesocket local removeserver -local changetimeout local wrapconnection local changesettings @@ -125,6 +124,8 @@ local _timer local _maxclientsperserver +local _maxsslhandshake + ----------------------------------// DEFINITION //-- _server = { } -- key = port, value = table; list of listening servers @@ -167,7 +168,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco local connections = 0 - local dispatch, disconnect = listeners.onconnect or listeners.onincoming, listeners.ondisconnect + local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect local accept = socket.accept @@ -185,23 +186,40 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco end handler.remove = function( ) connections = connections - 1 - end - handler.close = function( ) - for _, handler in pairs( _socketlist ) do - if handler.serverport == serverport then - handler.disconnect( handler, "server closed" ) - handler:close( true ) - end + if handler then + handler.resume( ) end + end + handler.close = function() socket:close( ) _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _server[ip..":"..serverport] = nil; _socketlist[ socket ] = nil handler = nil socket = nil --mem_free( ) out_put "server.lua: closed server handler and removed sockets from list" end + handler.pause = function() + if not handler.paused then + socket:close( ) + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _socketlist[ socket ] = nil + socket = nil; + handler.paused = true; + end + end + handler.resume = function() + if handler.paused then + socket = socket_bind( ip, serverport ); + socket:settimeout( 0 ) + _readlistlen = addsocket(_readlist, socket, _readlistlen) + _socketlist[ socket ] = handler + handler.paused = false; + end + end handler.ip = function( ) return ip end @@ -213,6 +231,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco end handler.readbuffer = function( ) if connections > maxconnections then + handler.pause( ) out_put( "server.lua: refused new client connection: server full" ) return false end @@ -226,7 +245,10 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco end connections = connections + 1 out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) - return dispatch( handler ) + if dispatch then + return dispatch( handler ); + end + return; elseif err then -- maybe timeout or something else out_put( "server.lua: error with new client connection: ", tostring(err) ) return false @@ -311,22 +333,28 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end return false, "setoption not implemented"; end - handler.close = function( self, forced ) + handler.force_close = function ( self, err ) + if bufferqueuelen ~= 0 then + out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport)) + for i = bufferqueuelen, 1, -1 do + bufferqueue[i] = nil; + end + bufferqueuelen = 0; + end + return self:close(err); + end + handler.close = function( self, err ) if not handler then return true; end _readlistlen = removesocket( _readlist, socket, _readlistlen ) _readtimes[ handler ] = nil if bufferqueuelen ~= 0 then - if not ( forced or fatalerror ) then - handler.sendbuffer( ) - if bufferqueuelen ~= 0 then -- try again... - if handler then - handler.write = nil -- ... but no further writing allowed - end - toclose = true - return false + handler.sendbuffer() -- Try now to send any outstanding data + if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later + if handler then + handler.write = nil -- ... but no further writing allowed end - else - send( socket, table_concat( bufferqueue, "", 1, bufferqueuelen ), 1, bufferlen ) -- forced send + toclose = true + return false end end if socket then @@ -341,7 +369,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport if handler then _writetimes[ handler ] = nil _closelist[ handler ] = nil + local _handler = handler; handler = nil + if disconnect then + disconnect(_handler, err or false); + disconnect = nil + end end if server then server.remove( ) @@ -443,8 +476,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local buffer = buffer or part or "" local len = string_len( buffer ) if len > maxreadlen then - disconnect( handler, "receive buffer exceeded" ) - handler:close( true ) + handler:close( "receive buffer exceeded" ) return false end local count = len * STAT_UNIT @@ -456,14 +488,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport else -- connections was closed or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) fatalerror = true - disconnect( handler, err ) - _ = handler and handler:close( ) + _ = handler and handler:force_close( err ) return false end end local _sendbuffer = function( ) -- this function sends data local succ, err, byte, buffer, count; - local count; if socket then buffer = table_concat( bufferqueue, "", 1, bufferqueuelen ) succ, err, byte = send( socket, buffer, 1, bufferlen ) @@ -473,7 +503,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _ = _cleanqueue and clean( bufferqueue ) --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) ) else - succ, err, count = false, "closed", 0; + succ, err, count = false, "unexpected close", 0; end if succ then -- sending succesful bufferqueuelen = 0 @@ -484,7 +514,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport drain(handler) end _ = needtls and handler:starttls(nil) - _ = toclose and handler:close( ) + _ = toclose and handler:force_close( ) return true elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer @@ -496,8 +526,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport else -- connection was closed during sending or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) ) fatalerror = true - disconnect( handler, err ) - _ = handler and handler:close( ) + _ = handler and handler:force_close( err ) return false end end @@ -505,10 +534,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport -- Set the sslctx local handshake; function handler.set_sslctx(self, new_sslctx) - ssl = true sslctx = new_sslctx; - local wrote - local read + local read, wrote handshake = coroutine_wrap( function( client ) -- create handshake coroutine local err for i = 1, _maxsslhandshake do @@ -521,108 +548,93 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions handler.sendbuffer = _sendbuffer _ = status and status( handler, "ssl-handshake-complete" ) + if self.autostart_ssl and listeners.onconnect then + listeners.onconnect(self); + end _readlistlen = addsocket(_readlist, client, _readlistlen) return true else - out_put( "server.lua: error during ssl handshake: ", tostring(err) ) - if err == "wantwrite" and not wrote then + if err == "wantwrite" then _sendlistlen = addsocket(_sendlist, client, _sendlistlen) wrote = true - elseif err == "wantread" and not read then + elseif err == "wantread" then _readlistlen = addsocket(_readlist, client, _readlistlen) read = true else break; end - --coroutine_yield( handler, nil, err ) -- handshake not finished - coroutine_yield( ) + err = nil; + coroutine_yield( ) -- handshake not finished end end - disconnect( handler, "ssl handshake failed" ) - _ = handler and handler:close( true ) -- forced disconnect - return false -- handshake failed + out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") ) + _ = handler and handler:force_close("ssl handshake failed") + return false, err -- handshake failed end ) end if luasec then - if sslctx then -- ssl? - handler:set_sslctx(sslctx); - out_put("server.lua: ", "starting ssl handshake") - local err - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket - if err then - out_put( "server.lua: ssl error: ", tostring(err) ) - --mem_free( ) - return nil, nil, err -- fatal error + handler.starttls = function( self, _sslctx) + if _sslctx then + handler:set_sslctx(_sslctx); end - socket:settimeout( 0 ) - handler.readbuffer = handshake - handler.sendbuffer = handshake - handshake( socket ) -- do handshake + if bufferqueuelen > 0 then + out_put "server.lua: we need to do tls, but delaying until send buffer empty" + needtls = true + return + end + out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) + local oldsocket, err = socket + socket, err = ssl_wrap( socket, sslctx ) -- wrap socket if not socket then - return nil, nil, "ssl handshake failed"; + out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) + return nil, err -- fatal error end - else - local sslctx; - handler.starttls = function( self, _sslctx) - if _sslctx then - sslctx = _sslctx; - handler:set_sslctx(sslctx); - end - if bufferqueuelen > 0 then - out_put "server.lua: we need to do tls, but delaying until send buffer empty" - needtls = true - return - end - out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) - local oldsocket, err = socket - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket - --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) ) - if err then - out_put( "server.lua: error while starting tls on client: ", tostring(err) ) - return nil, err -- fatal error - end - - socket:settimeout( 0 ) - - -- add the new socket to our system - - send = socket.send - receive = socket.receive - shutdown = id - - _socketlist[ socket ] = handler - _readlistlen = addsocket(_readlist, socket, _readlistlen) - -- remove traces of the old socket + socket:settimeout( 0 ) - _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) - _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) - _socketlist[ oldsocket ] = nil + -- add the new socket to our system + send = socket.send + receive = socket.receive + shutdown = id + _socketlist[ socket ] = handler + _readlistlen = addsocket(_readlist, socket, _readlistlen) + + -- remove traces of the old socket + _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) + _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) + _socketlist[ oldsocket ] = nil - handler.starttls = nil - needtls = nil + handler.starttls = nil + needtls = nil - -- Secure now - ssl = true + -- Secure now (if handshake fails connection will close) + ssl = true - handler.readbuffer = handshake - handler.sendbuffer = handshake - handshake( socket ) -- do handshake - end - handler.readbuffer = _readbuffer - handler.sendbuffer = _sendbuffer + handler.readbuffer = handshake + handler.sendbuffer = handshake + return handshake( socket ) -- do handshake end - else - handler.readbuffer = _readbuffer - handler.sendbuffer = _sendbuffer end + + handler.readbuffer = _readbuffer + handler.sendbuffer = _sendbuffer send = socket.send receive = socket.receive shutdown = ( ssl and id ) or socket.shutdown _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) + + if sslctx and luasec then + out_put "server.lua: auto-starting ssl negotiation..." + handler.autostart_ssl = true; + local ok, err = handler:starttls(sslctx); + if ok == false then + return nil, nil, err + end + end + return handler, socket end @@ -666,7 +678,6 @@ closesocket = function( socket ) end local function link(sender, receiver, buffersize) - sender:set_mode(buffersize); local sender_locked; local _sendbuffer = receiver.sendbuffer; function receiver.sendbuffer() @@ -696,19 +707,19 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function end if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then err = "invalid port" - elseif _server[ port ] then - err = "listeners on port '" .. port .. "' already exist" + elseif _server[ addr..":"..port ] then + err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist" elseif sslctx and not luasec then err = "luasec not found" end if err then - out_error( "server.lua, port ", port, ": ", err ) + out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end addr = addr or "*" local server, err = socket_bind( addr, port ) if err then - out_error( "server.lua, port ", port, ": ", err ) + out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver ) -- wrap new server socket @@ -718,23 +729,23 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function end server:settimeout( 0 ) _readlistlen = addsocket(_readlist, server, _readlistlen) - _server[ port ] = handler + _server[ addr..":"..port ] = handler _socketlist[ server ] = handler - out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '", addr, ":", port, "'" ) + out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" ) return handler end -getserver = function ( port ) - return _server[ port ]; +getserver = function ( addr, port ) + return _server[ addr..":"..port ]; end -removeserver = function( port ) - local handler = _server[ port ] +removeserver = function( addr, port ) + local handler = _server[ addr..":"..port ] if not handler then - return nil, "no server found on port '" .. tostring( port ) .. "'" + return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'" end handler:close( ) - _server[ port ] = nil + _server[ addr..":"..port ] = nil return true end @@ -762,16 +773,16 @@ changesettings = function( new ) if type( new ) ~= "table" then return nil, "invalid settings table" end - _selecttimeout = tonumber( new.timeout ) or _selecttimeout - _sleeptime = tonumber( new.sleeptime ) or _sleeptime - _maxsendlen = tonumber( new.maxsendlen ) or _maxsendlen - _maxreadlen = tonumber( new.maxreadlen ) or _maxreadlen - _checkinterval = tonumber( new.checkinterval ) or _checkinterval - _sendtimeout = tonumber( new.sendtimeout ) or _sendtimeout - _readtimeout = tonumber( new.readtimeout ) or _readtimeout - _cleanqueue = new.cleanqueue - _maxclientsperserver = new._maxclientsperserver or _maxclientsperserver - _maxsslhandshake = new._maxsslhandshake or _maxsslhandshake + _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout + _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime + _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen + _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen + _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval + _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout + _readtimeout = tonumber( new.read_timeout ) or _readtimeout + _cleanqueue = new.select_clean_queue + _maxclientsperserver = new.max_connections or _maxclientsperserver + _maxsslhandshake = new.max_ssl_handshake_roundtrips or _maxsslhandshake return true end @@ -788,16 +799,18 @@ stats = function( ) return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen end -local dontstop = true; -- thinking about tomorrow, ... +local quitting; -setquitting = function (quit) - dontstop = not quit; - return; +local function setquitting(quit) + quitting = not not quit; end -loop = function( ) -- this is the main loop of the program - while dontstop do - local read, write, err = socket_select( _readlist, _sendlist, _selecttimeout ) +loop = function(once) -- this is the main loop of the program + if quitting then return "quitting"; end + if once then quitting = "once"; end + local next_timer_time = math_huge; + repeat + local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) ) for i, socket in ipairs( write ) do -- send data waiting in writequeues local handler = _socketlist[ socket ] if handler then @@ -818,22 +831,31 @@ loop = function( ) -- this is the main loop of the program end for handler, err in pairs( _closelist ) do handler.disconnect( )( handler, err ) - handler:close( true ) -- forced disconnect + handler:force_close() -- forced disconnect end clean( _closelist ) - _currenttime = os_time( ) - if os_difftime( _currenttime - _timer ) >= 1 then + _currenttime = luasocket_gettime( ) + if _currenttime - _timer >= math_min(next_timer_time, 1) then + next_timer_time = math_huge; for i = 1, _timerlistlen do - _timerlist[ i ]( _currenttime ) -- fire timers + local t = _timerlist[ i ]( _currenttime ) -- fire timers + if t then next_timer_time = math_min(next_timer_time, t); end end _timer = _currenttime + else + next_timer_time = next_timer_time - (_currenttime - _timer); end socket_sleep( _sleeptime ) -- wait some time --collectgarbage( ) - end + until quitting; + if once and quitting == "once" then quitting = nil; return; end return "quitting" end +local function step() + return loop(true); +end + local function get_backend() return "select"; end @@ -843,16 +865,19 @@ end local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) _socketlist[ socket ] = handler - _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) - if listeners.onconnect then - -- When socket is writeable, call onconnect - local _sendbuffer = handler.sendbuffer; - handler.sendbuffer = function () - listeners.onconnect(handler); - handler.sendbuffer = _sendbuffer; - -- If there was data with the incoming packet, handle it now. - if #handler:bufferqueue() > 0 then - return _sendbuffer(); + if not sslctx then + _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + if listeners.onconnect then + -- When socket is writeable, call onconnect + local _sendbuffer = handler.sendbuffer; + handler.sendbuffer = function () + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ); + handler.sendbuffer = _sendbuffer; + listeners.onconnect(handler); + -- If there was data with the incoming packet, handle it now. + if #handler:bufferqueue() > 0 then + return _sendbuffer(); + end end end end @@ -881,8 +906,8 @@ use "setmetatable" ( _socketlist, { __mode = "k" } ) use "setmetatable" ( _readtimes, { __mode = "k" } ) use "setmetatable" ( _writetimes, { __mode = "k" } ) -_timer = os_time( ) -_starttime = os_time( ) +_timer = luasocket_gettime( ) +_starttime = luasocket_gettime( ) addtimer( function( ) local difftime = os_difftime( _currenttime - _starttime ) @@ -892,7 +917,7 @@ addtimer( function( ) if os_difftime( _currenttime - timestamp ) > _sendtimeout then --_writetimes[ handler ] = nil handler.disconnect( )( handler, "send timeout" ) - handler:close( true ) -- forced disconnect + handler:force_close() -- forced disconnect end end for handler, timestamp in pairs( _readtimes ) do @@ -917,15 +942,16 @@ end ----------------------------------// PUBLIC INTERFACE //-- return { + _addtimer = addtimer, addclient = addclient, wrapclient = wrapclient, loop = loop, link = link, + step = step, stats = stats, closeall = closeall, - addtimer = addtimer, addserver = addserver, getserver = getserver, setlogger = setlogger,