mod_auth_anonymous: add disallow_s2s to the host object if s2s communication is disal...
[prosody.git] / net / server_select.lua
index ca8b11131dc916f94b9cb4f6b96754ce5d7b9629..70825adaaeeec13e842359a900f2e0324bf172e2 100644 (file)
@@ -2,7 +2,7 @@
 -- server.lua by blastbeat of the luadch project
 -- Re-used here under the MIT/X Consortium License
 -- 
--- Modifications (C) 2008-2009 Matthew Wild, Waqas Hussain
+-- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain
 --
 
 -- // wrapping luadch stuff // --
@@ -32,6 +32,7 @@ local STAT_UNIT = 1 -- byte
 local type = use "type"
 local pairs = use "pairs"
 local ipairs = use "ipairs"
+local tonumber = use "tonumber"
 local tostring = use "tostring"
 local collectgarbage = use "collectgarbage"
 
@@ -44,8 +45,9 @@ local coroutine = use "coroutine"
 
 --// lua lib methods //--
 
-local os_time = os.time
 local os_difftime = os.difftime
+local math_min = math.min
+local math_huge = math.huge
 local table_concat = table.concat
 local table_remove = table.remove
 local string_len = string.len
@@ -57,6 +59,7 @@ local coroutine_yield = coroutine.yield
 
 local luasec = use "ssl"
 local luasocket = use "socket" or require "socket"
+local luasocket_gettime = luasocket.gettime
 
 --// extern lib methods //--
 
@@ -72,8 +75,8 @@ local id
 local loop
 local stats
 local idfalse
-local addtimer
 local closeall
+local addsocket
 local addserver
 local getserver
 local wrapserver
@@ -125,6 +128,8 @@ local _timer
 
 local _maxclientsperserver
 
+local _maxsslhandshake
+
 ----------------------------------// DEFINITION //--
 
 _server = { } -- key = port, value = table; list of listening servers
@@ -167,7 +172,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
 
        local connections = 0
 
-       local dispatch, disconnect = listeners.onincoming, listeners.ondisconnect
+       local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect
 
        local accept = socket.accept
 
@@ -196,6 +201,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
                socket:close( )
                _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
                _readlistlen = removesocket( _readlist, socket, _readlistlen )
+               _server[ip..":"..serverport] = nil;
                _socketlist[ socket ] = nil
                handler = nil
                socket = nil
@@ -226,7 +232,10 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
                        end
                        connections = connections + 1
                        out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))
-                       return dispatch( handler )
+                       if dispatch then
+                               return dispatch( handler );
+                       end
+                       return;
                elseif err then -- maybe timeout or something else
                        out_put( "server.lua: error with new client connection: ", tostring(err) )
                        return false
@@ -252,6 +261,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        local dispatch = listeners.onincoming
        local status = listeners.onstatus
        local disconnect = listeners.ondisconnect
+       local drain = listeners.ondrain
 
        local bufferqueue = { } -- buffer array
        local bufferqueuelen = 0        -- end of buffer array
@@ -284,6 +294,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                dispatch = listeners.onincoming
                disconnect = listeners.ondisconnect
                status = listeners.onstatus
+               drain = listeners.ondrain
        end
        handler.getstats = function( )
                return readtraffic, sendtraffic
@@ -341,9 +352,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        _closelist[ handler ] = nil
                        handler = nil
                end
-       if server then
-               server.remove( )
-       end
+               if server then
+                       server.remove( )
+               end
                out_put "server.lua: closed client handler and removed socket from list"
                return true
        end
@@ -379,7 +390,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        handler.socket = function( self )
                return socket
        end
-       handler.pattern = function( self, new )
+       handler.set_mode = function( self, new )
                pattern = new or pattern
                return pattern
        end
@@ -392,6 +403,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                maxreadlen = readlen or maxreadlen
                return bufferlen, maxreadlen, maxsendlen
        end
+       --TODO: Deprecate
        handler.lock_read = function (self, switch)
                if switch == true then
                        local tmp = _readlistlen
@@ -409,6 +421,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                end
                return noread
        end
+       handler.pause = function (self)
+               return self:lock_read(true);
+       end
+       handler.resume = function (self)
+               return self:lock_read(false);
+       end
        handler.lock = function( self, switch )
                handler.lock_read (switch)
                if switch == true then
@@ -430,7 +448,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        end
        local _readbuffer = function( ) -- this function reads data
                local buffer, err, part = receive( socket, pattern )    -- receive buffer with "pattern"
-               if not err or (err == "wantread" or err == "timeout") or string_len(part) > 0 then -- received something
+               if not err or (err == "wantread" or err == "timeout") then -- received something
                        local buffer = buffer or part or ""
                        local len = string_len( buffer )
                        if len > maxreadlen then
@@ -448,7 +466,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
                        fatalerror = true
                        disconnect( handler, err )
-               _ = handler and handler:close( )
+                       _ = handler and handler:close( )
                        return false
                end
        end
@@ -470,9 +488,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        bufferqueuelen = 0
                        bufferlen = 0
                        _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
-                       _ = needtls and handler:starttls(nil, true)
                        _writetimes[ handler ] = nil
-                       _ = toclose and handlerclose( )
+                       if drain then
+                               drain(handler)
+                       end
+                       _ = needtls and handler:starttls(nil)
+                       _ = toclose and handler:close( )
                        return true
                elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
                        buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer
@@ -493,10 +514,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        -- Set the sslctx
        local handshake;
        function handler.set_sslctx(self, new_sslctx)
-               ssl = true
                sslctx = new_sslctx;
-               local wrote
-               local read
+               local read, wrote
                handshake = coroutine_wrap( function( client ) -- create handshake coroutine
                                local err
                                for i = 1, _maxsslhandshake do
@@ -509,23 +528,26 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                                                handler.readbuffer = _readbuffer        -- when handshake is done, replace the handshake function with regular functions
                                                handler.sendbuffer = _sendbuffer
                                                _ = status and status( handler, "ssl-handshake-complete" )
+                                               if self.autostart_ssl and listeners.onconnect then
+                                                       listeners.onconnect(self);
+                                               end
                                                _readlistlen = addsocket(_readlist, client, _readlistlen)
                                                return true
                                        else
-                                               out_put( "server.lua: error during ssl handshake: ", tostring(err) )
-                                               if err == "wantwrite" and not wrote then
+                                               if err == "wantwrite" then
                                                        _sendlistlen = addsocket(_sendlist, client, _sendlistlen)
                                                        wrote = true
-                                               elseif err == "wantread" and not read then
+                                               elseif err == "wantread" then
                                                        _readlistlen = addsocket(_readlist, client, _readlistlen)
                                                        read = true
                                                else
                                                        break;
                                                end
-                                               --coroutine_yield( handler, nil, err )   -- handshake not finished
-                                               coroutine_yield( )
+                                               err = nil;
+                                               coroutine_yield( ) -- handshake not finished
                                        end
                                end
+                               out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") )
                                disconnect( handler, "ssl handshake failed" )
                                _ = handler and handler:close( true )    -- forced disconnect
                                return false    -- handshake failed
@@ -533,78 +555,51 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                )
        end
        if luasec then
-               if sslctx then -- ssl?
-                       handler:set_sslctx(sslctx);
-                       out_put("server.lua: ", "starting ssl handshake")
-                       local err
-                       socket, err = ssl_wrap( socket, sslctx )        -- wrap socket
-                       if err then
-                               out_put( "server.lua: ssl error: ", tostring(err) )
-                               --mem_free( )
-                               return nil, nil, err    -- fatal error
+               handler.starttls = function( self, _sslctx)
+                       if _sslctx then
+                               handler:set_sslctx(_sslctx);
                        end
-                       socket:settimeout( 0 )
-                       handler.readbuffer = handshake
-                       handler.sendbuffer = handshake
-                       handshake( socket ) -- do handshake
+                       if bufferqueuelen > 0 then
+                               out_put "server.lua: we need to do tls, but delaying until send buffer empty"
+                               needtls = true
+                               return
+                       end
+                       out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
+                       local oldsocket, err = socket
+                       socket, err = ssl_wrap( socket, sslctx )        -- wrap socket
                        if not socket then
-                               return nil, nil, "ssl handshake failed";
+                               out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
+                               return nil, err -- fatal error
                        end
-               else
-                       local sslctx;
-                       handler.starttls = function( self, _sslctx, now )
-                               if _sslctx then
-                                       sslctx = _sslctx;
-                                       handler:set_sslctx(sslctx);
-                               end
-                               if not now then
-                                       out_put "server.lua: we need to do tls, but delaying until later"
-                                       needtls = true
-                                       return
-                               end
-                               out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
-                               local oldsocket, err = socket
-                               socket, err = ssl_wrap( socket, sslctx )        -- wrap socket
-                               --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) )
-                               if err then
-                                       out_put( "server.lua: error while starting tls on client: ", tostring(err) )
-                                       return nil, err -- fatal error
-                               end
-
-                               socket:settimeout( 0 )
-       
-                               -- add the new socket to our system
-       
-                               send = socket.send
-                               receive = socket.receive
-                               shutdown = id
-
-                               _socketlist[ socket ] = handler
-                               _readlistlen = addsocket(_readlist, socket, _readlistlen)
 
-                               -- remove traces of the old socket
+                       socket:settimeout( 0 )
 
-                               _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
-                               _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
-                               _socketlist[ oldsocket ] = nil
+                       -- add the new socket to our system
+                       send = socket.send
+                       receive = socket.receive
+                       shutdown = id
+                       _socketlist[ socket ] = handler
+                       _readlistlen = addsocket(_readlist, socket, _readlistlen)
+                       
+                       -- remove traces of the old socket
+                       _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
+                       _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
+                       _socketlist[ oldsocket ] = nil
 
-                               handler.starttls = nil
-                               needtls = nil
+                       handler.starttls = nil
+                       needtls = nil
 
-                               -- Secure now
-                               ssl = true
+                       -- Secure now (if handshake fails connection will close)
+                       ssl = true
 
-                               handler.readbuffer = handshake
-                               handler.sendbuffer = handshake
-                               handshake( socket ) -- do handshake
-                       end
-                       handler.readbuffer = _readbuffer
-                       handler.sendbuffer = _sendbuffer
+                       handler.readbuffer = handshake
+                       handler.sendbuffer = handshake
+                       handshake( socket ) -- do handshake
                end
-       else
-               handler.readbuffer = _readbuffer
-               handler.sendbuffer = _sendbuffer
        end
+
+       handler.readbuffer = _readbuffer
+       handler.sendbuffer = _sendbuffer
        send = socket.send
        receive = socket.receive
        shutdown = ( ssl and id ) or socket.shutdown
@@ -612,6 +607,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        _socketlist[ socket ] = handler
        _readlistlen = addsocket(_readlist, socket, _readlistlen)
 
+       if sslctx and luasec then
+               out_put "server.lua: auto-starting ssl negotiation..."
+               handler.autostart_ssl = true;
+               handler:starttls(sslctx);
+       end
+
        return handler, socket
 end
 
@@ -654,6 +655,27 @@ closesocket = function( socket )
        --mem_free( )
 end
 
+local function link(sender, receiver, buffersize)
+       local sender_locked;
+       local _sendbuffer = receiver.sendbuffer;
+       function receiver.sendbuffer()
+               _sendbuffer();
+               if sender_locked and receiver.bufferlen() < buffersize then
+                       sender:lock_read(false); -- Unlock now
+                       sender_locked = nil;
+               end
+       end
+       
+       local _readbuffer = sender.readbuffer;
+       function sender.readbuffer()
+               _readbuffer();
+               if not sender_locked and receiver.bufferlen() >= buffersize then
+                       sender_locked = true;
+                       sender:lock_read(true);
+               end
+       end
+end
+
 ----------------------------------// PUBLIC //--
 
 addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
@@ -661,21 +683,21 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
        if type( listeners ) ~= "table" then
                err = "invalid listener table"
        end
-       if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then
+       if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
                err = "invalid port"
-       elseif _server[ port ] then
-               err = "listeners on port '" .. port .. "' already exist"
+       elseif _server[ addr..":"..port ] then
+               err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist"
        elseif sslctx and not luasec then
                err = "luasec not found"
        end
        if err then
-               out_error( "server.lua, port ", port, ": ", err )
+               out_error( "server.lua, [", addr, "]:", port, ": ", err )
                return nil, err
        end
        addr = addr or "*"
        local server, err = socket_bind( addr, port )
        if err then
-               out_error( "server.lua, port ", port, ": ", err )
+               out_error( "server.lua, [", addr, "]:", port, ": ", err )
                return nil, err
        end
        local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver ) -- wrap new server socket
@@ -685,23 +707,23 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
        end
        server:settimeout( 0 )
        _readlistlen = addsocket(_readlist, server, _readlistlen)
-       _server[ port ] = handler
+       _server[ addr..":"..port ] = handler
        _socketlist[ server ] = handler
-       out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '", addr, ":", port, "'" )
+       out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" )
        return handler
 end
 
-getserver = function ( port )
-       return _server[ port ];
+getserver = function ( addr, port )
+       return _server[ addr..":"..port ];
 end
 
-removeserver = function( port )
-       local handler = _server[ port ]
+removeserver = function( addr, port )
+       local handler = _server[ addr..":"..port ]
        if not handler then
-               return nil, "no server found on port '" .. tostring( port ) .. "'"
+               return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'"
        end
        handler:close( )
-       _server[ port ] = nil
+       _server[ addr..":"..port ] = nil
        return true
 end
 
@@ -755,16 +777,18 @@ stats = function( )
        return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
 end
 
-local dontstop = true; -- thinking about tomorrow, ...
+local quitting;
 
-setquitting = function (quit)
-       dontstop = not quit;
-       return;
+local function setquitting(quit)
+       quitting = not not quit;
 end
 
-loop = function( ) -- this is the main loop of the program
-       while dontstop do
-               local read, write, err = socket_select( _readlist, _sendlist, _selecttimeout )
+loop = function(once) -- this is the main loop of the program
+       if quitting then return "quitting"; end
+       if once then quitting = "once"; end
+       local next_timer_time = math_huge;
+       repeat
+               local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
                for i, socket in ipairs( write ) do -- send data waiting in writequeues
                        local handler = _socketlist[ socket ]
                        if handler then
@@ -788,19 +812,28 @@ loop = function( ) -- this is the main loop of the program
                        handler:close( true )    -- forced disconnect
                end
                clean( _closelist )
-               _currenttime = os_time( )
-               if os_difftime( _currenttime - _timer ) >= 1 then
+               _currenttime = luasocket_gettime( )
+               if _currenttime - _timer >= math_min(next_timer_time, 1) then
+                       next_timer_time = math_huge;
                        for i = 1, _timerlistlen do
-                               _timerlist[ i ]( _currenttime ) -- fire timers
+                               local t = _timerlist[ i ]( _currenttime ) -- fire timers
+                               if t then next_timer_time = math_min(next_timer_time, t); end
                        end
                        _timer = _currenttime
+               else
+                       next_timer_time = next_timer_time - (_currenttime - _timer);
                end
                socket_sleep( _sleeptime ) -- wait some time
                --collectgarbage( )
-       end
+       until quitting;
+       if once and quitting == "once" then quitting = nil; return; end
        return "quitting"
 end
 
+local function step()
+       return loop(true);
+end
+
 local function get_backend()
        return "select";
 end
@@ -810,7 +843,22 @@ end
 local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
        local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
        _socketlist[ socket ] = handler
-       _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+       if not sslctx then
+               _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+               if listeners.onconnect then
+                       -- When socket is writeable, call onconnect
+                       local _sendbuffer = handler.sendbuffer;
+                       handler.sendbuffer = function ()
+                               _sendlistlen = removesocket( _sendlist, socket, _sendlistlen );
+                               handler.sendbuffer = _sendbuffer;
+                               listeners.onconnect(handler);
+                               -- If there was data with the incoming packet, handle it now.
+                               if #handler:bufferqueue() > 0 then
+                                       return _sendbuffer();
+                               end
+                       end
+               end
+       end
        return handler, socket
 end
 
@@ -836,8 +884,8 @@ use "setmetatable" ( _socketlist, { __mode = "k" } )
 use "setmetatable" ( _readtimes, { __mode = "k" } )
 use "setmetatable" ( _writetimes, { __mode = "k" } )
 
-_timer = os_time( )
-_starttime = os_time( )
+_timer = luasocket_gettime( )
+_starttime = luasocket_gettime( )
 
 addtimer( function( )
                local difftime = os_difftime( _currenttime - _starttime )
@@ -872,14 +920,16 @@ end
 ----------------------------------// PUBLIC INTERFACE //--
 
 return {
+       _addtimer = addtimer,
 
        addclient = addclient,
        wrapclient = wrapclient,
        
        loop = loop,
+       link = link,
+       step = step,
        stats = stats,
        closeall = closeall,
-       addtimer = addtimer,
        addserver = addserver,
        getserver = getserver,
        setlogger = setlogger,