mod_auth_anonymous: add disallow_s2s to the host object if s2s communication is disal...
[prosody.git] / net / server_select.lua
index bb3b12c9069cd611fac0782c6e3432a63f21adec..70825adaaeeec13e842359a900f2e0324bf172e2 100644 (file)
@@ -75,7 +75,6 @@ local id
 local loop
 local stats
 local idfalse
-local addtimer
 local closeall
 local addsocket
 local addserver
@@ -173,7 +172,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
 
        local connections = 0
 
-       local dispatch, disconnect = listeners.onconnect or listeners.onincoming, listeners.ondisconnect
+       local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect
 
        local accept = socket.accept
 
@@ -202,6 +201,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
                socket:close( )
                _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
                _readlistlen = removesocket( _readlist, socket, _readlistlen )
+               _server[ip..":"..serverport] = nil;
                _socketlist[ socket ] = nil
                handler = nil
                socket = nil
@@ -232,7 +232,10 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
                        end
                        connections = connections + 1
                        out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))
-                       return dispatch( handler )
+                       if dispatch then
+                               return dispatch( handler );
+                       end
+                       return;
                elseif err then -- maybe timeout or something else
                        out_put( "server.lua: error with new client connection: ", tostring(err) )
                        return false
@@ -463,7 +466,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
                        fatalerror = true
                        disconnect( handler, err )
-               _ = handler and handler:close( )
+                       _ = handler and handler:close( )
                        return false
                end
        end
@@ -525,6 +528,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                                                handler.readbuffer = _readbuffer        -- when handshake is done, replace the handshake function with regular functions
                                                handler.sendbuffer = _sendbuffer
                                                _ = status and status( handler, "ssl-handshake-complete" )
+                                               if self.autostart_ssl and listeners.onconnect then
+                                                       listeners.onconnect(self);
+                                               end
                                                _readlistlen = addsocket(_readlist, client, _readlistlen)
                                                return true
                                        else
@@ -535,13 +541,13 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                                                        _readlistlen = addsocket(_readlist, client, _readlistlen)
                                                        read = true
                                                else
-                                                       out_put( "server.lua: ssl handshake error: ", tostring(err) )
                                                        break;
                                                end
-                                               --coroutine_yield( handler, nil, err )   -- handshake not finished
-                                               coroutine_yield( )
+                                               err = nil;
+                                               coroutine_yield( ) -- handshake not finished
                                        end
                                end
+                               out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") )
                                disconnect( handler, "ssl handshake failed" )
                                _ = handler and handler:close( true )    -- forced disconnect
                                return false    -- handshake failed
@@ -549,84 +555,64 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                )
        end
        if luasec then
-               if sslctx then -- ssl?
-                       handler:set_sslctx(sslctx);
-                       out_put("server.lua: ", "starting ssl handshake")
-                       local err
-                       socket, err = ssl_wrap( socket, sslctx )        -- wrap socket
-                       if err then
-                               out_put( "server.lua: ssl error: ", tostring(err) )
-                               --mem_free( )
-                               return nil, nil, err    -- fatal error
+               handler.starttls = function( self, _sslctx)
+                       if _sslctx then
+                               handler:set_sslctx(_sslctx);
                        end
-                       socket:settimeout( 0 )
-                       handler.readbuffer = handshake
-                       handler.sendbuffer = handshake
-                       handshake( socket ) -- do handshake
+                       if bufferqueuelen > 0 then
+                               out_put "server.lua: we need to do tls, but delaying until send buffer empty"
+                               needtls = true
+                               return
+                       end
+                       out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
+                       local oldsocket, err = socket
+                       socket, err = ssl_wrap( socket, sslctx )        -- wrap socket
                        if not socket then
-                               return nil, nil, "ssl handshake failed";
+                               out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
+                               return nil, err -- fatal error
                        end
-               else
-                       local sslctx;
-                       handler.starttls = function( self, _sslctx)
-                               if _sslctx then
-                                       sslctx = _sslctx;
-                                       handler:set_sslctx(sslctx);
-                               end
-                               if bufferqueuelen > 0 then
-                                       out_put "server.lua: we need to do tls, but delaying until send buffer empty"
-                                       needtls = true
-                                       return
-                               end
-                               out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
-                               local oldsocket, err = socket
-                               socket, err = ssl_wrap( socket, sslctx )        -- wrap socket
-                               --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) )
-                               if err then
-                                       out_put( "server.lua: error while starting tls on client: ", tostring(err) )
-                                       return nil, err -- fatal error
-                               end
-
-                               socket:settimeout( 0 )
-       
-                               -- add the new socket to our system
-       
-                               send = socket.send
-                               receive = socket.receive
-                               shutdown = id
-
-                               _socketlist[ socket ] = handler
-                               _readlistlen = addsocket(_readlist, socket, _readlistlen)
 
-                               -- remove traces of the old socket
+                       socket:settimeout( 0 )
 
-                               _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
-                               _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
-                               _socketlist[ oldsocket ] = nil
+                       -- add the new socket to our system
+                       send = socket.send
+                       receive = socket.receive
+                       shutdown = id
+                       _socketlist[ socket ] = handler
+                       _readlistlen = addsocket(_readlist, socket, _readlistlen)
+                       
+                       -- remove traces of the old socket
+                       _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
+                       _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
+                       _socketlist[ oldsocket ] = nil
 
-                               handler.starttls = nil
-                               needtls = nil
+                       handler.starttls = nil
+                       needtls = nil
 
-                               -- Secure now
-                               ssl = true
+                       -- Secure now (if handshake fails connection will close)
+                       ssl = true
 
-                               handler.readbuffer = handshake
-                               handler.sendbuffer = handshake
-                               handshake( socket ) -- do handshake
-                       end
-                       handler.readbuffer = _readbuffer
-                       handler.sendbuffer = _sendbuffer
+                       handler.readbuffer = handshake
+                       handler.sendbuffer = handshake
+                       handshake( socket ) -- do handshake
                end
-       else
-               handler.readbuffer = _readbuffer
-               handler.sendbuffer = _sendbuffer
        end
+
+       handler.readbuffer = _readbuffer
+       handler.sendbuffer = _sendbuffer
        send = socket.send
        receive = socket.receive
        shutdown = ( ssl and id ) or socket.shutdown
 
        _socketlist[ socket ] = handler
        _readlistlen = addsocket(_readlist, socket, _readlistlen)
+
+       if sslctx and luasec then
+               out_put "server.lua: auto-starting ssl negotiation..."
+               handler.autostart_ssl = true;
+               handler:starttls(sslctx);
+       end
+
        return handler, socket
 end
 
@@ -793,7 +779,7 @@ end
 
 local quitting;
 
-setquitting = function (quit)
+local function setquitting(quit)
        quitting = not not quit;
 end
 
@@ -844,7 +830,7 @@ loop = function(once) -- this is the main loop of the program
        return "quitting"
 end
 
-step = function ()
+local function step()
        return loop(true);
 end
 
@@ -857,16 +843,19 @@ end
 local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
        local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
        _socketlist[ socket ] = handler
-       _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
-       if listeners.onconnect then
-               -- When socket is writeable, call onconnect
-               local _sendbuffer = handler.sendbuffer;
-               handler.sendbuffer = function ()
-                       handler.sendbuffer = _sendbuffer;
-                       listeners.onconnect(handler);
-                       -- If there was data with the incoming packet, handle it now.
-                       if #handler:bufferqueue() > 0 then
-                               return _sendbuffer();
+       if not sslctx then
+               _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+               if listeners.onconnect then
+                       -- When socket is writeable, call onconnect
+                       local _sendbuffer = handler.sendbuffer;
+                       handler.sendbuffer = function ()
+                               _sendlistlen = removesocket( _sendlist, socket, _sendlistlen );
+                               handler.sendbuffer = _sendbuffer;
+                               listeners.onconnect(handler);
+                               -- If there was data with the incoming packet, handle it now.
+                               if #handler:bufferqueue() > 0 then
+                                       return _sendbuffer();
+                               end
                        end
                end
        end
@@ -931,6 +920,7 @@ end
 ----------------------------------// PUBLIC INTERFACE //--
 
 return {
+       _addtimer = addtimer,
 
        addclient = addclient,
        wrapclient = wrapclient,
@@ -940,7 +930,6 @@ return {
        step = step,
        stats = stats,
        closeall = closeall,
-       addtimer = addtimer,
        addserver = addserver,
        getserver = getserver,
        setlogger = setlogger,