mod_auth_anonymous: add disallow_s2s to the host object if s2s communication is disal...
[prosody.git] / net / server_select.lua
index 0310a9915ead1febc7753d608b1a933daf22b39e..70825adaaeeec13e842359a900f2e0324bf172e2 100644 (file)
@@ -75,7 +75,6 @@ local id
 local loop
 local stats
 local idfalse
-local addtimer
 local closeall
 local addsocket
 local addserver
@@ -149,7 +148,7 @@ _timerlistlen = 0 -- lenght of timerlist
 _sendtraffic = 0 -- some stats
 _readtraffic = 0
 
-_selecttimeout = 3600 -- timeout of socket.select
+_selecttimeout = 1 -- timeout of socket.select
 _sleeptime = 0 -- time to wait at the end of every loop
 
 _maxsendlen = 51000 * 1024 -- max len of send buffer
@@ -173,7 +172,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
 
        local connections = 0
 
-       local dispatch, disconnect = listeners.onconnect or listeners.onincoming, listeners.ondisconnect
+       local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect
 
        local accept = socket.accept
 
@@ -202,6 +201,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
                socket:close( )
                _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
                _readlistlen = removesocket( _readlist, socket, _readlistlen )
+               _server[ip..":"..serverport] = nil;
                _socketlist[ socket ] = nil
                handler = nil
                socket = nil
@@ -232,7 +232,10 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
                        end
                        connections = connections + 1
                        out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))
-                       return dispatch( handler )
+                       if dispatch then
+                               return dispatch( handler );
+                       end
+                       return;
                elseif err then -- maybe timeout or something else
                        out_put( "server.lua: error with new client connection: ", tostring(err) )
                        return false
@@ -463,7 +466,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
                        fatalerror = true
                        disconnect( handler, err )
-               _ = handler and handler:close( )
+                       _ = handler and handler:close( )
                        return false
                end
        end
@@ -511,10 +514,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        -- Set the sslctx
        local handshake;
        function handler.set_sslctx(self, new_sslctx)
-               ssl = true
                sslctx = new_sslctx;
-               local wrote
-               local read
+               local read, wrote
                handshake = coroutine_wrap( function( client ) -- create handshake coroutine
                                local err
                                for i = 1, _maxsslhandshake do
@@ -527,23 +528,26 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                                                handler.readbuffer = _readbuffer        -- when handshake is done, replace the handshake function with regular functions
                                                handler.sendbuffer = _sendbuffer
                                                _ = status and status( handler, "ssl-handshake-complete" )
+                                               if self.autostart_ssl and listeners.onconnect then
+                                                       listeners.onconnect(self);
+                                               end
                                                _readlistlen = addsocket(_readlist, client, _readlistlen)
                                                return true
                                        else
-                                               out_put( "server.lua: error during ssl handshake: ", tostring(err) )
-                                               if err == "wantwrite" and not wrote then
+                                               if err == "wantwrite" then
                                                        _sendlistlen = addsocket(_sendlist, client, _sendlistlen)
                                                        wrote = true
-                                               elseif err == "wantread" and not read then
+                                               elseif err == "wantread" then
                                                        _readlistlen = addsocket(_readlist, client, _readlistlen)
                                                        read = true
                                                else
                                                        break;
                                                end
-                                               --coroutine_yield( handler, nil, err )   -- handshake not finished
-                                               coroutine_yield( )
+                                               err = nil;
+                                               coroutine_yield( ) -- handshake not finished
                                        end
                                end
+                               out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") )
                                disconnect( handler, "ssl handshake failed" )
                                _ = handler and handler:close( true )    -- forced disconnect
                                return false    -- handshake failed
@@ -551,84 +555,64 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                )
        end
        if luasec then
-               if sslctx then -- ssl?
-                       handler:set_sslctx(sslctx);
-                       out_put("server.lua: ", "starting ssl handshake")
-                       local err
-                       socket, err = ssl_wrap( socket, sslctx )        -- wrap socket
-                       if err then
-                               out_put( "server.lua: ssl error: ", tostring(err) )
-                               --mem_free( )
-                               return nil, nil, err    -- fatal error
+               handler.starttls = function( self, _sslctx)
+                       if _sslctx then
+                               handler:set_sslctx(_sslctx);
                        end
-                       socket:settimeout( 0 )
-                       handler.readbuffer = handshake
-                       handler.sendbuffer = handshake
-                       handshake( socket ) -- do handshake
+                       if bufferqueuelen > 0 then
+                               out_put "server.lua: we need to do tls, but delaying until send buffer empty"
+                               needtls = true
+                               return
+                       end
+                       out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
+                       local oldsocket, err = socket
+                       socket, err = ssl_wrap( socket, sslctx )        -- wrap socket
                        if not socket then
-                               return nil, nil, "ssl handshake failed";
+                               out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
+                               return nil, err -- fatal error
                        end
-               else
-                       local sslctx;
-                       handler.starttls = function( self, _sslctx)
-                               if _sslctx then
-                                       sslctx = _sslctx;
-                                       handler:set_sslctx(sslctx);
-                               end
-                               if bufferqueuelen > 0 then
-                                       out_put "server.lua: we need to do tls, but delaying until send buffer empty"
-                                       needtls = true
-                                       return
-                               end
-                               out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
-                               local oldsocket, err = socket
-                               socket, err = ssl_wrap( socket, sslctx )        -- wrap socket
-                               --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) )
-                               if err then
-                                       out_put( "server.lua: error while starting tls on client: ", tostring(err) )
-                                       return nil, err -- fatal error
-                               end
-
-                               socket:settimeout( 0 )
-       
-                               -- add the new socket to our system
-       
-                               send = socket.send
-                               receive = socket.receive
-                               shutdown = id
-
-                               _socketlist[ socket ] = handler
-                               _readlistlen = addsocket(_readlist, socket, _readlistlen)
 
-                               -- remove traces of the old socket
+                       socket:settimeout( 0 )
 
-                               _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
-                               _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
-                               _socketlist[ oldsocket ] = nil
+                       -- add the new socket to our system
+                       send = socket.send
+                       receive = socket.receive
+                       shutdown = id
+                       _socketlist[ socket ] = handler
+                       _readlistlen = addsocket(_readlist, socket, _readlistlen)
+                       
+                       -- remove traces of the old socket
+                       _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
+                       _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
+                       _socketlist[ oldsocket ] = nil
 
-                               handler.starttls = nil
-                               needtls = nil
+                       handler.starttls = nil
+                       needtls = nil
 
-                               -- Secure now
-                               ssl = true
+                       -- Secure now (if handshake fails connection will close)
+                       ssl = true
 
-                               handler.readbuffer = handshake
-                               handler.sendbuffer = handshake
-                               handshake( socket ) -- do handshake
-                       end
-                       handler.readbuffer = _readbuffer
-                       handler.sendbuffer = _sendbuffer
+                       handler.readbuffer = handshake
+                       handler.sendbuffer = handshake
+                       handshake( socket ) -- do handshake
                end
-       else
-               handler.readbuffer = _readbuffer
-               handler.sendbuffer = _sendbuffer
        end
+
+       handler.readbuffer = _readbuffer
+       handler.sendbuffer = _sendbuffer
        send = socket.send
        receive = socket.receive
        shutdown = ( ssl and id ) or socket.shutdown
 
        _socketlist[ socket ] = handler
        _readlistlen = addsocket(_readlist, socket, _readlistlen)
+
+       if sslctx and luasec then
+               out_put "server.lua: auto-starting ssl negotiation..."
+               handler.autostart_ssl = true;
+               handler:starttls(sslctx);
+       end
+
        return handler, socket
 end
 
@@ -701,19 +685,19 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
        end
        if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
                err = "invalid port"
-       elseif _server[ port ] then
-               err = "listeners on port '" .. port .. "' already exist"
+       elseif _server[ addr..":"..port ] then
+               err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist"
        elseif sslctx and not luasec then
                err = "luasec not found"
        end
        if err then
-               out_error( "server.lua, port ", port, ": ", err )
+               out_error( "server.lua, [", addr, "]:", port, ": ", err )
                return nil, err
        end
        addr = addr or "*"
        local server, err = socket_bind( addr, port )
        if err then
-               out_error( "server.lua, port ", port, ": ", err )
+               out_error( "server.lua, [", addr, "]:", port, ": ", err )
                return nil, err
        end
        local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver ) -- wrap new server socket
@@ -723,23 +707,23 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
        end
        server:settimeout( 0 )
        _readlistlen = addsocket(_readlist, server, _readlistlen)
-       _server[ port ] = handler
+       _server[ addr..":"..port ] = handler
        _socketlist[ server ] = handler
-       out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '", addr, ":", port, "'" )
+       out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" )
        return handler
 end
 
-getserver = function ( port )
-       return _server[ port ];
+getserver = function ( addr, port )
+       return _server[ addr..":"..port ];
 end
 
-removeserver = function( port )
-       local handler = _server[ port ]
+removeserver = function( addr, port )
+       local handler = _server[ addr..":"..port ]
        if not handler then
-               return nil, "no server found on port '" .. tostring( port ) .. "'"
+               return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'"
        end
        handler:close( )
-       _server[ port ] = nil
+       _server[ addr..":"..port ] = nil
        return true
 end
 
@@ -795,7 +779,7 @@ end
 
 local quitting;
 
-setquitting = function (quit)
+local function setquitting(quit)
        quitting = not not quit;
 end
 
@@ -846,7 +830,7 @@ loop = function(once) -- this is the main loop of the program
        return "quitting"
 end
 
-step = function ()
+local function step()
        return loop(true);
 end
 
@@ -859,16 +843,19 @@ end
 local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
        local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
        _socketlist[ socket ] = handler
-       _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
-       if listeners.onconnect then
-               -- When socket is writeable, call onconnect
-               local _sendbuffer = handler.sendbuffer;
-               handler.sendbuffer = function ()
-                       handler.sendbuffer = _sendbuffer;
-                       listeners.onconnect(handler);
-                       -- If there was data with the incoming packet, handle it now.
-                       if #handler:bufferqueue() > 0 then
-                               return _sendbuffer();
+       if not sslctx then
+               _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+               if listeners.onconnect then
+                       -- When socket is writeable, call onconnect
+                       local _sendbuffer = handler.sendbuffer;
+                       handler.sendbuffer = function ()
+                               _sendlistlen = removesocket( _sendlist, socket, _sendlistlen );
+                               handler.sendbuffer = _sendbuffer;
+                               listeners.onconnect(handler);
+                               -- If there was data with the incoming packet, handle it now.
+                               if #handler:bufferqueue() > 0 then
+                                       return _sendbuffer();
+                               end
                        end
                end
        end
@@ -933,15 +920,16 @@ end
 ----------------------------------// PUBLIC INTERFACE //--
 
 return {
+       _addtimer = addtimer,
 
        addclient = addclient,
        wrapclient = wrapclient,
        
        loop = loop,
        link = link,
+       step = step,
        stats = stats,
        closeall = closeall,
-       addtimer = addtimer,
        addserver = addserver,
        getserver = getserver,
        setlogger = setlogger,