net.server_select: When an SSL handshake is connected, if there is pending data to...
[prosody.git] / net / server_select.lua
index 7eb330a8a86a22e1da35cd91710971288ff7bd75..33c18a606140b316ecc55045104bac1750e683a7 100644 (file)
@@ -1,7 +1,7 @@
--- 
+--
 -- server.lua by blastbeat of the luadch project
 -- Re-used here under the MIT/X Consortium License
--- 
+--
 -- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain
 --
 
@@ -145,12 +145,12 @@ _tcpbacklog = 128 -- some kind of hint to the OS
 _maxsendlen = 51000 * 1024 -- max len of send buffer
 _maxreadlen = 25000 * 1024 -- max len of read buffer
 
-_checkinterval = 1200000 -- interval in secs to check idle clients
+_checkinterval = 30 -- interval in secs to check idle clients
 _sendtimeout = 60000 -- allowed send idle time in secs
 _readtimeout = 6 * 60 * 60 -- allowed read idle time in secs
 
 local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows
-_maxfd = luasocket._SETSIZE or (is_windows and math.huge) or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows
+_maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows
 _maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows
 
 _maxsslhandshake = 30 -- max handshake round-trips
@@ -263,7 +263,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        if socket:getfd() >= _maxfd then
                out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
                socket:close( ) -- Should we send some kind of error here?
-               server.pause( )
+               if server then
+                       server.pause( )
+               end
                return nil, nil, "fd-too-large"
        end
        socket:settimeout( 0 )
@@ -282,6 +284,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        local status = listeners.onstatus
        local disconnect = listeners.ondisconnect
        local drain = listeners.ondrain
+       local onreadtimeout = listeners.onreadtimeout;
+       local detach = listeners.ondetach
 
        local bufferqueue = { } -- buffer array
        local bufferqueuelen = 0        -- end of buffer array
@@ -310,11 +314,18 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        handler.disconnect = function( )
                return disconnect
        end
+       handler.onreadtimeout = onreadtimeout;
+
        handler.setlistener = function( self, listeners )
+               if detach then
+                       detach(self) -- Notify listener that it is no longer responsible for this connection
+               end
                dispatch = listeners.onincoming
                disconnect = listeners.ondisconnect
                status = listeners.onstatus
                drain = listeners.ondrain
+               handler.onreadtimeout = listeners.onreadtimeout
+               detach = listeners.ondetach
        end
        handler.getstats = function( )
                return readtraffic, sendtraffic
@@ -395,6 +406,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        handler.clientport = function( )
                return clientport
        end
+       handler.port = handler.clientport -- COMPAT server_event
        local write = function( self, data )
                bufferlen = bufferlen + #data
                if bufferlen > maxsendlen then
@@ -556,6 +568,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                                                _ = status and status( handler, "ssl-handshake-complete" )
                                                if self.autostart_ssl and listeners.onconnect then
                                                        listeners.onconnect(self);
+                                                       if bufferqueuelen ~= 0 then
+                                                               _sendlistlen = addsocket(_sendlist, client, _sendlistlen)
+                                                       end
                                                end
                                                _readlistlen = addsocket(_readlist, client, _readlistlen)
                                                return true
@@ -605,7 +620,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        shutdown = id
                        _socketlist[ socket ] = handler
                        _readlistlen = addsocket(_readlist, socket, _readlistlen)
-                       
+
                        -- remove traces of the old socket
                        _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
                        _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
@@ -693,7 +708,7 @@ local function link(sender, receiver, buffersize)
                        sender_locked = nil;
                end
        end
-       
+
        local _readbuffer = sender.readbuffer;
        function sender.readbuffer()
                _readbuffer();
@@ -702,6 +717,7 @@ local function link(sender, receiver, buffersize)
                        sender:lock_read(true);
                end
        end
+       sender:set_mode("*a");
 end
 
 ----------------------------------// PUBLIC //--
@@ -861,16 +877,16 @@ loop = function(once) -- this is the main loop of the program
                        _starttime = _currenttime
                        for handler, timestamp in pairs( _writetimes ) do
                                if os_difftime( _currenttime - timestamp ) > _sendtimeout then
-                                       --_writetimes[ handler ] = nil
                                        handler.disconnect( )( handler, "send timeout" )
                                        handler:force_close()    -- forced disconnect
                                end
                        end
                        for handler, timestamp in pairs( _readtimes ) do
                                if os_difftime( _currenttime - timestamp ) > _readtimeout then
-                                       --_readtimes[ handler ] = nil
-                                       handler.disconnect( )( handler, "read timeout" )
-                                       handler:close( )        -- forced disconnect?
+                                       if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then
+                                               handler.disconnect( )( handler, "read timeout" )
+                                               handler:close( )        -- forced disconnect?
+                                       end
                                end
                        end
                end
@@ -914,13 +930,9 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx
                        -- When socket is writeable, call onconnect
                        local _sendbuffer = handler.sendbuffer;
                        handler.sendbuffer = function ()
-                               _sendlistlen = removesocket( _sendlist, socket, _sendlistlen );
                                handler.sendbuffer = _sendbuffer;
                                listeners.onconnect(handler);
-                               -- If there was data with the incoming packet, handle it now.
-                               if #handler:bufferqueue() > 0 then
-                                       return _sendbuffer();
-                               end
+                               return _sendbuffer(); -- Send any queued outgoing data
                        end
                end
        end
@@ -935,9 +947,9 @@ local addclient = function( address, port, listeners, pattern, sslctx )
        client:settimeout( 0 )
        _, err = client:connect( address, port )
        if err then -- try again
-               local handler = wrapclient( client, address, port, listeners )
+               return wrapclient( client, address, port, listeners, pattern, sslctx )
        else
-               wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx )
+               return wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx )
        end
 end
 
@@ -967,7 +979,7 @@ return {
 
        addclient = addclient,
        wrapclient = wrapclient,
-       
+
        loop = loop,
        link = link,
        step = step,