net.httpclient_listener: Define t_insert
[prosody.git] / net / server_select.lua
index 13a910f8b4ce52a963b4d54f3067d621415c3c56..41f2b9fa50a2da2e70f8734b56c978e7b576abbb 100644 (file)
@@ -511,10 +511,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        -- Set the sslctx
        local handshake;
        function handler.set_sslctx(self, new_sslctx)
-               ssl = true
                sslctx = new_sslctx;
-               local wrote
-               local read
+               local read, wrote
                handshake = coroutine_wrap( function( client ) -- create handshake coroutine
                                local err
                                for i = 1, _maxsslhandshake do
@@ -527,23 +525,26 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                                                handler.readbuffer = _readbuffer        -- when handshake is done, replace the handshake function with regular functions
                                                handler.sendbuffer = _sendbuffer
                                                _ = status and status( handler, "ssl-handshake-complete" )
+                                               if self.autostart_ssl and listeners.onconnect then
+                                                       listeners.onconnect(self);
+                                               end
                                                _readlistlen = addsocket(_readlist, client, _readlistlen)
                                                return true
                                        else
-                                               if err == "wantwrite" and not wrote then
+                                               if err == "wantwrite" then
                                                        _sendlistlen = addsocket(_sendlist, client, _sendlistlen)
                                                        wrote = true
-                                               elseif err == "wantread" and not read then
+                                               elseif err == "wantread" then
                                                        _readlistlen = addsocket(_readlist, client, _readlistlen)
                                                        read = true
                                                else
-                                                       out_put( "server.lua: ssl handshake error: ", tostring(err) )
                                                        break;
                                                end
-                                               --coroutine_yield( handler, nil, err )   -- handshake not finished
-                                               coroutine_yield( )
+                                               err = nil;
+                                               coroutine_yield( ) -- handshake not finished
                                        end
                                end
+                               out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") )
                                disconnect( handler, "ssl handshake failed" )
                                _ = handler and handler:close( true )    -- forced disconnect
                                return false    -- handshake failed
@@ -551,74 +552,56 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                )
        end
        if luasec then
-               if sslctx then -- ssl?
-                       handler:set_sslctx(sslctx);
-                       out_put("server.lua: ", "starting ssl handshake")
-                       local err
-                       socket, err = ssl_wrap( socket, sslctx )        -- wrap socket
-                       if err then
-                               out_put( "server.lua: ssl error: ", tostring(err) )
-                               --mem_free( )
-                               return nil, nil, err    -- fatal error
+               handler.starttls = function( self, _sslctx)
+                       if _sslctx then
+                               handler:set_sslctx(_sslctx);
                        end
-                       socket:settimeout( 0 )
-                       handler.readbuffer = handshake
-                       handler.sendbuffer = handshake
-                       handshake( socket ) -- do handshake
+                       if bufferqueuelen > 0 then
+                               out_put "server.lua: we need to do tls, but delaying until send buffer empty"
+                               needtls = true
+                               return
+                       end
+                       out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
+                       local oldsocket, err = socket
+                       socket, err = ssl_wrap( socket, sslctx )        -- wrap socket
                        if not socket then
-                               return nil, nil, "ssl handshake failed";
+                               out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
+                               return nil, err -- fatal error
                        end
-               else
-                       local sslctx;
-                       handler.starttls = function( self, _sslctx)
-                               if _sslctx then
-                                       sslctx = _sslctx;
-                                       handler:set_sslctx(sslctx);
-                               end
-                               if bufferqueuelen > 0 then
-                                       out_put "server.lua: we need to do tls, but delaying until send buffer empty"
-                                       needtls = true
-                                       return
-                               end
-                               out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
-                               local oldsocket, err = socket
-                               socket, err = ssl_wrap( socket, sslctx )        -- wrap socket
-                               --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) )
-                               if err then
-                                       out_put( "server.lua: error while starting tls on client: ", tostring(err) )
-                                       return nil, err -- fatal error
-                               end
 
-                               socket:settimeout( 0 )
-       
-                               -- add the new socket to our system
-       
-                               send = socket.send
-                               receive = socket.receive
-                               shutdown = id
-
-                               _socketlist[ socket ] = handler
-                               _readlistlen = addsocket(_readlist, socket, _readlistlen)
-
-                               -- remove traces of the old socket
+                       socket:settimeout( 0 )
 
-                               _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
-                               _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
-                               _socketlist[ oldsocket ] = nil
+                       -- add the new socket to our system
+                       send = socket.send
+                       receive = socket.receive
+                       shutdown = id
+                       _socketlist[ socket ] = handler
+                       _readlistlen = addsocket(_readlist, socket, _readlistlen)
+                       
+                       -- remove traces of the old socket
+                       _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
+                       _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
+                       _socketlist[ oldsocket ] = nil
 
-                               handler.starttls = nil
-                               needtls = nil
+                       handler.starttls = nil
+                       needtls = nil
 
-                               -- Secure now
-                               ssl = true
+                       -- Secure now (if handshake fails connection will close)
+                       ssl = true
 
-                               handler.readbuffer = handshake
-                               handler.sendbuffer = handshake
-                               handshake( socket ) -- do handshake
-                       end
-                       handler.readbuffer = _readbuffer
-                       handler.sendbuffer = _sendbuffer
+                       handler.readbuffer = handshake
+                       handler.sendbuffer = handshake
+                       handshake( socket ) -- do handshake
+               end
+               handler.readbuffer = _readbuffer
+               handler.sendbuffer = _sendbuffer
+               
+               if sslctx then
+                       out_put "server.lua: auto-starting ssl negotiation..."
+                       handler.autostart_ssl = true;
+                       handler:starttls(sslctx);
                end
+
        else
                handler.readbuffer = _readbuffer
                handler.sendbuffer = _sendbuffer
@@ -701,19 +684,19 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
        end
        if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
                err = "invalid port"
-       elseif _server[ port ] then
-               err = "listeners on port '" .. port .. "' already exist"
+       elseif _server[ addr..":"..port ] then
+               err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist"
        elseif sslctx and not luasec then
                err = "luasec not found"
        end
        if err then
-               out_error( "server.lua, port ", port, ": ", err )
+               out_error( "server.lua, [", addr, "]:", port, ": ", err )
                return nil, err
        end
        addr = addr or "*"
        local server, err = socket_bind( addr, port )
        if err then
-               out_error( "server.lua, port ", port, ": ", err )
+               out_error( "server.lua, [", addr, "]:", port, ": ", err )
                return nil, err
        end
        local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver ) -- wrap new server socket
@@ -723,23 +706,23 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
        end
        server:settimeout( 0 )
        _readlistlen = addsocket(_readlist, server, _readlistlen)
-       _server[ port ] = handler
+       _server[ addr..":"..port ] = handler
        _socketlist[ server ] = handler
-       out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '", addr, ":", port, "'" )
+       out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" )
        return handler
 end
 
-getserver = function ( port )
-       return _server[ port ];
+getserver = function ( addr, port )
+       return _server[ addr..":"..port ];
 end
 
-removeserver = function( port )
-       local handler = _server[ port ]
+removeserver = function( addr, port )
+       local handler = _server[ addr..":"..port ]
        if not handler then
-               return nil, "no server found on port '" .. tostring( port ) .. "'"
+               return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'"
        end
        handler:close( )
-       _server[ port ] = nil
+       _server[ addr..":"..port ] = nil
        return true
 end
 
@@ -859,16 +842,19 @@ end
 local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
        local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
        _socketlist[ socket ] = handler
-       _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
-       if listeners.onconnect then
-               -- When socket is writeable, call onconnect
-               local _sendbuffer = handler.sendbuffer;
-               handler.sendbuffer = function ()
-                       handler.sendbuffer = _sendbuffer;
-                       listeners.onconnect(handler);
-                       -- If there was data with the incoming packet, handle it now.
-                       if #handler:bufferqueue() > 0 then
-                               return _sendbuffer();
+       if not sslctx then
+               _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+               if listeners.onconnect then
+                       -- When socket is writeable, call onconnect
+                       local _sendbuffer = handler.sendbuffer;
+                       handler.sendbuffer = function ()
+                               _sendlistlen = removesocket( _sendlist, socket, _sendlistlen );
+                               handler.sendbuffer = _sendbuffer;
+                               listeners.onconnect(handler);
+                               -- If there was data with the incoming packet, handle it now.
+                               if #handler:bufferqueue() > 0 then
+                                       return _sendbuffer();
+                               end
                        end
                end
        end