X-Git-Url: https://git.enpas.org/?a=blobdiff_plain;f=net%2Fserver_select.lua;h=cfd7f3cd98c0c18489fdcdd8ed9732b1979a3d11;hb=aa53318e7ca435fd45038557eb497b6514a6b995;hp=685cd13e978503f6bdc310399b5acca8e13c771d;hpb=b6dfdd3762b43b153aeb2b9274ee88d30b4fb7e6;p=prosody.git diff --git a/net/server_select.lua b/net/server_select.lua index 685cd13e..cfd7f3cd 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -2,7 +2,7 @@ -- server.lua by blastbeat of the luadch project -- Re-used here under the MIT/X Consortium License -- --- Modifications (C) 2008-2009 Matthew Wild, Waqas Hussain +-- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain -- -- // wrapping luadch stuff // -- @@ -32,6 +32,7 @@ local STAT_UNIT = 1 -- byte local type = use "type" local pairs = use "pairs" local ipairs = use "ipairs" +local tonumber = use "tonumber" local tostring = use "tostring" local collectgarbage = use "collectgarbage" @@ -44,8 +45,9 @@ local coroutine = use "coroutine" --// lua lib methods //-- -local os_time = os.time local os_difftime = os.difftime +local math_min = math.min +local math_huge = math.huge local table_concat = table.concat local table_remove = table.remove local string_len = string.len @@ -57,6 +59,7 @@ local coroutine_yield = coroutine.yield local luasec = use "ssl" local luasocket = use "socket" or require "socket" +local luasocket_gettime = luasocket.gettime --// extern lib methods //-- @@ -74,6 +77,7 @@ local stats local idfalse local addtimer local closeall +local addsocket local addserver local getserver local wrapserver @@ -125,6 +129,8 @@ local _timer local _maxclientsperserver +local _maxsslhandshake + ----------------------------------// DEFINITION //-- _server = { } -- key = port, value = table; list of listening servers @@ -167,7 +173,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco local connections = 0 - local dispatch, disconnect = listeners.onincoming, listeners.ondisconnect + local dispatch, disconnect = listeners.onconnect or listeners.onincoming, listeners.ondisconnect local accept = socket.accept @@ -252,6 +258,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local dispatch = listeners.onincoming local status = listeners.onstatus local disconnect = listeners.ondisconnect + local drain = listeners.ondrain local bufferqueue = { } -- buffer array local bufferqueuelen = 0 -- end of buffer array @@ -284,6 +291,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport dispatch = listeners.onincoming disconnect = listeners.ondisconnect status = listeners.onstatus + drain = listeners.ondrain end handler.getstats = function( ) return readtraffic, sendtraffic @@ -341,9 +349,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _closelist[ handler ] = nil handler = nil end - if server then - server.remove( ) - end + if server then + server.remove( ) + end out_put "server.lua: closed client handler and removed socket from list" return true end @@ -379,7 +387,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.socket = function( self ) return socket end - handler.pattern = function( self, new ) + handler.set_mode = function( self, new ) pattern = new or pattern return pattern end @@ -392,6 +400,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport maxreadlen = readlen or maxreadlen return bufferlen, maxreadlen, maxsendlen end + --TODO: Deprecate handler.lock_read = function (self, switch) if switch == true then local tmp = _readlistlen @@ -409,6 +418,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end return noread end + handler.pause = function (self) + return self:lock_read(true); + end + handler.resume = function (self) + return self:lock_read(false); + end handler.lock = function( self, switch ) handler.lock_read (switch) if switch == true then @@ -430,12 +445,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end local _readbuffer = function( ) -- this function reads data local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" - if not err or (err == "wantread" or err == "timeout") or string_len(part) > 0 then -- received something + if not err or (err == "wantread" or err == "timeout") then -- received something local buffer = buffer or part or "" local len = string_len( buffer ) if len > maxreadlen then disconnect( handler, "receive buffer exceeded" ) - handler.close( true ) + handler:close( true ) return false end local count = len * STAT_UNIT @@ -448,7 +463,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) fatalerror = true disconnect( handler, err ) - _ = handler and handler.close( ) + _ = handler and handler:close( ) return false end end @@ -470,9 +485,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport bufferqueuelen = 0 bufferlen = 0 _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist - _ = needtls and handler:starttls(nil, true) _writetimes[ handler ] = nil - _ = toclose and handler.close( ) + if drain then + drain(handler) + end + _ = needtls and handler:starttls(nil) + _ = toclose and handler:close( ) return true elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer @@ -485,7 +503,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) ) fatalerror = true disconnect( handler, err ) - _ = handler and handler.close( ) + _ = handler and handler:close( ) return false end end @@ -552,13 +570,13 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end else local sslctx; - handler.starttls = function( self, _sslctx, now ) + handler.starttls = function( self, _sslctx) if _sslctx then sslctx = _sslctx; handler:set_sslctx(sslctx); end - if not now then - out_put "server.lua: we need to do tls, but delaying until later" + if bufferqueuelen > 0 then + out_put "server.lua: we need to do tls, but delaying until send buffer empty" needtls = true return end @@ -611,7 +629,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) - return handler, socket end @@ -654,6 +671,27 @@ closesocket = function( socket ) --mem_free( ) end +local function link(sender, receiver, buffersize) + local sender_locked; + local _sendbuffer = receiver.sendbuffer; + function receiver.sendbuffer() + _sendbuffer(); + if sender_locked and receiver.bufferlen() < buffersize then + sender:lock_read(false); -- Unlock now + sender_locked = nil; + end + end + + local _readbuffer = sender.readbuffer; + function sender.readbuffer() + _readbuffer(); + if not sender_locked and receiver.bufferlen() >= buffersize then + sender_locked = true; + sender:lock_read(true); + end + end +end + ----------------------------------// PUBLIC //-- addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server @@ -661,7 +699,7 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function if type( listeners ) ~= "table" then err = "invalid listener table" end - if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then + if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then err = "invalid port" elseif _server[ port ] then err = "listeners on port '" .. port .. "' already exist" @@ -755,16 +793,18 @@ stats = function( ) return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen end -local dontstop = true; -- thinking about tomorrow, ... +local quitting; setquitting = function (quit) - dontstop = not quit; - return; + quitting = not not quit; end -loop = function( ) -- this is the main loop of the program - while dontstop do - local read, write, err = socket_select( _readlist, _sendlist, _selecttimeout ) +loop = function(once) -- this is the main loop of the program + if quitting then return "quitting"; end + if once then quitting = "once"; end + local next_timer_time = math_huge; + repeat + local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) ) for i, socket in ipairs( write ) do -- send data waiting in writequeues local handler = _socketlist[ socket ] if handler then @@ -788,19 +828,28 @@ loop = function( ) -- this is the main loop of the program handler:close( true ) -- forced disconnect end clean( _closelist ) - _currenttime = os_time( ) - if os_difftime( _currenttime - _timer ) >= 1 then + _currenttime = luasocket_gettime( ) + if _currenttime - _timer >= math_min(next_timer_time, 1) then + next_timer_time = math_huge; for i = 1, _timerlistlen do - _timerlist[ i ]( _currenttime ) -- fire timers + local t = _timerlist[ i ]( _currenttime ) -- fire timers + if t then next_timer_time = math_min(next_timer_time, t); end end _timer = _currenttime + else + next_timer_time = next_timer_time - (_currenttime - _timer); end socket_sleep( _sleeptime ) -- wait some time --collectgarbage( ) - end + until quitting; + if once and quitting == "once" then quitting = nil; return; end return "quitting" end +step = function () + return loop(true); +end + local function get_backend() return "select"; end @@ -811,6 +860,18 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) _socketlist[ socket ] = handler _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + if listeners.onconnect then + -- When socket is writeable, call onconnect + local _sendbuffer = handler.sendbuffer; + handler.sendbuffer = function () + handler.sendbuffer = _sendbuffer; + listeners.onconnect(handler); + -- If there was data with the incoming packet, handle it now. + if #handler:bufferqueue() > 0 then + return _sendbuffer(); + end + end + end return handler, socket end @@ -836,8 +897,8 @@ use "setmetatable" ( _socketlist, { __mode = "k" } ) use "setmetatable" ( _readtimes, { __mode = "k" } ) use "setmetatable" ( _writetimes, { __mode = "k" } ) -_timer = os_time( ) -_starttime = os_time( ) +_timer = luasocket_gettime( ) +_starttime = luasocket_gettime( ) addtimer( function( ) local difftime = os_difftime( _currenttime - _starttime ) @@ -877,6 +938,7 @@ return { wrapclient = wrapclient, loop = loop, + link = link, stats = stats, closeall = closeall, addtimer = addtimer,