net.server_select: remove unused one-letter loop variables [luacheck]
[prosody.git] / net / server_select.lua
index 891151ac35fdd9b630cbb5e5150f653e16f17d22..85730e73c4d2d6c216d9ec062ab97926b6cf56a2 100644 (file)
@@ -1,7 +1,7 @@
--- 
+--
 -- server.lua by blastbeat of the luadch project
 -- Re-used here under the MIT/X Consortium License
--- 
+--
 -- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain
 --
 
@@ -38,7 +38,6 @@ local coroutine = use "coroutine"
 
 --// lua lib methods //--
 
-local os_difftime = os.difftime
 local math_min = math.min
 local math_huge = math.huge
 local table_concat = table.concat
@@ -48,13 +47,14 @@ local coroutine_yield = coroutine.yield
 
 --// extern libs //--
 
-local luasec = use "ssl"
+local has_luasec, luasec = pcall ( require , "ssl" )
 local luasocket = use "socket" or require "socket"
 local luasocket_gettime = luasocket.gettime
+local getaddrinfo = luasocket.dns.getaddrinfo
 
 --// extern lib methods //--
 
-local ssl_wrap = ( luasec and luasec.wrap )
+local ssl_wrap = ( has_luasec and luasec.wrap )
 local socket_bind = luasocket.bind
 local socket_sleep = luasocket.sleep
 local socket_select = luasocket.select
@@ -149,7 +149,7 @@ _accepretry = 10 -- seconds to wait until the next attempt of a full server to a
 _maxsendlen = 51000 * 1024 -- max len of send buffer
 _maxreadlen = 25000 * 1024 -- max len of read buffer
 
-_checkinterval = 1200000 -- interval in secs to check idle clients
+_checkinterval = 30 -- interval in secs to check idle clients
 _sendtimeout = 60000 -- allowed send idle time in secs
 _readtimeout = 6 * 60 * 60 -- allowed read idle time in secs
 
@@ -213,6 +213,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
                                socket = nil;
                        end
                        handler.paused = true;
+                       out_put("server.lua: server [", ip, "]:", serverport, " paused")
                end
        end
        handler.resume = function( )
@@ -225,6 +226,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
                        _socketlist[ socket ] = handler
                        _fullservers[ handler ] = nil
                        handler.paused = false;
+                       out_put("server.lua: server [", ip, "]:", serverport, " resumed")
                end
        end
        handler.ip = function( )
@@ -293,6 +295,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        local status = listeners.onstatus
        local disconnect = listeners.ondisconnect
        local drain = listeners.ondrain
+       local onreadtimeout = listeners.onreadtimeout;
        local detach = listeners.ondetach
 
        local bufferqueue = { } -- buffer array
@@ -322,6 +325,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        handler.disconnect = function( )
                return disconnect
        end
+       handler.onreadtimeout = onreadtimeout;
+
        handler.setlistener = function( self, listeners )
                if detach then
                        detach(self) -- Notify listener that it is no longer responsible for this connection
@@ -330,6 +335,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                disconnect = listeners.ondisconnect
                status = listeners.onstatus
                drain = listeners.ondrain
+               handler.onreadtimeout = listeners.onreadtimeout
                detach = listeners.ondetach
        end
        handler.getstats = function( )
@@ -402,6 +408,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                out_put "server.lua: closed client handler and removed socket from list"
                return true
        end
+       handler.server = function ( )
+               return server
+       end
        handler.ip = function( )
                return ip
        end
@@ -573,6 +582,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                                                _ = status and status( handler, "ssl-handshake-complete" )
                                                if self.autostart_ssl and listeners.onconnect then
                                                        listeners.onconnect(self);
+                                                       if bufferqueuelen ~= 0 then
+                                                               _sendlistlen = addsocket(_sendlist, client, _sendlistlen)
+                                                       end
                                                end
                                                _readlistlen = addsocket(_readlist, client, _readlistlen)
                                                return true
@@ -590,13 +602,14 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                                                coroutine_yield( ) -- handshake not finished
                                        end
                                end
-                               out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") )
-                               _ = handler and handler:force_close("ssl handshake failed")
+                               err = "ssl handshake error: " .. ( err or "handshake too long" );
+                               out_put( "server.lua: ", err );
+                               _ = handler and handler:force_close(err)
                                return false, err -- handshake failed
                        end
                )
        end
-       if luasec then
+       if has_luasec then
                handler.starttls = function( self, _sslctx)
                        if _sslctx then
                                handler:set_sslctx(_sslctx);
@@ -622,7 +635,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        shutdown = id
                        _socketlist[ socket ] = handler
                        _readlistlen = addsocket(_readlist, socket, _readlistlen)
-                       
+
                        -- remove traces of the old socket
                        _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
                        _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
@@ -649,7 +662,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        _socketlist[ socket ] = handler
        _readlistlen = addsocket(_readlist, socket, _readlistlen)
 
-       if sslctx and luasec then
+       if sslctx and has_luasec then
                out_put "server.lua: auto-starting ssl negotiation..."
                handler.autostart_ssl = true;
                local ok, err = handler:starttls(sslctx);
@@ -710,7 +723,7 @@ local function link(sender, receiver, buffersize)
                        sender_locked = nil;
                end
        end
-       
+
        local _readbuffer = sender.readbuffer;
        function sender.readbuffer()
                _readbuffer();
@@ -725,22 +738,23 @@ end
 ----------------------------------// PUBLIC //--
 
 addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
+       addr = addr or "*"
        local err
        if type( listeners ) ~= "table" then
                err = "invalid listener table"
-       end
-       if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
+       elseif type ( addr ) ~= "string" then
+               err = "invalid address"
+       elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
                err = "invalid port"
        elseif _server[ addr..":"..port ] then
                err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist"
-       elseif sslctx and not luasec then
+       elseif sslctx and not has_luasec then
                err = "luasec not found"
        end
        if err then
                out_error( "server.lua, [", addr, "]:", port, ": ", err )
                return nil, err
        end
-       addr = addr or "*"
        local server, err = socket_bind( addr, port, _tcpbacklog )
        if err then
                out_error( "server.lua, [", addr, "]:", port, ": ", err )
@@ -850,7 +864,7 @@ loop = function(once) -- this is the main loop of the program
        local next_timer_time = math_huge;
        repeat
                local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
-               for i, socket in ipairs( write ) do -- send data waiting in writequeues
+               for _, socket in ipairs( write ) do -- send data waiting in writequeues
                        local handler = _socketlist[ socket ]
                        if handler then
                                handler.sendbuffer( )
@@ -859,7 +873,7 @@ loop = function(once) -- this is the main loop of the program
                                out_put "server.lua: found no handler and closed socket (writelist)"    -- this should not happen
                        end
                end
-               for i, socket in ipairs( read ) do -- receive data
+               for _, socket in ipairs( read ) do -- receive data
                        local handler = _socketlist[ socket ]
                        if handler then
                                handler.readbuffer( )
@@ -876,21 +890,22 @@ loop = function(once) -- this is the main loop of the program
                _currenttime = luasocket_gettime( )
 
                -- Check for socket timeouts
-               local difftime = os_difftime( _currenttime - _starttime )
-               if difftime > _checkinterval then
+               if _currenttime - _starttime > _checkinterval then
                        _starttime = _currenttime
                        for handler, timestamp in pairs( _writetimes ) do
-                               if os_difftime( _currenttime - timestamp ) > _sendtimeout then
-                                       --_writetimes[ handler ] = nil
+                               if _currenttime - timestamp > _sendtimeout then
                                        handler.disconnect( )( handler, "send timeout" )
                                        handler:force_close()    -- forced disconnect
                                end
                        end
                        for handler, timestamp in pairs( _readtimes ) do
-                               if os_difftime( _currenttime - timestamp ) > _readtimeout then
-                                       --_readtimes[ handler ] = nil
-                                       handler.disconnect( )( handler, "read timeout" )
-                                       handler:close( )        -- forced disconnect?
+                               if _currenttime - timestamp > _readtimeout then
+                                       if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then
+                                               handler.disconnect( )( handler, "read timeout" )
+                                               handler:close( )        -- forced disconnect?
+                                       else
+                                               _readtimes[ handler ] = _currenttime -- reset timer
+                                       end
                                end
                        end
                end
@@ -918,6 +933,7 @@ loop = function(once) -- this is the main loop of the program
                socket_sleep( _sleeptime )
        until quitting;
        if once and quitting == "once" then quitting = nil; return; end
+       closeall();
        return "quitting"
 end
 
@@ -950,17 +966,46 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx
        return handler, socket
 end
 
-local addclient = function( address, port, listeners, pattern, sslctx )
-       local client, err = luasocket.tcp( )
+local addclient = function( address, port, listeners, pattern, sslctx, typ )
+       local err
+       if type( listeners ) ~= "table" then
+               err = "invalid listener table"
+       elseif type ( address ) ~= "string" then
+               err = "invalid address"
+       elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
+               err = "invalid port"
+       elseif sslctx and not has_luasec then
+               err = "luasec not found"
+       end
+       if not typ then
+               local addrinfo, err = getaddrinfo(address)
+               if not addrinfo then return nil, err end
+               if addrinfo[1] and addrinfo[1].family == "inet6" then
+                       typ = "tcp6"
+               else
+                       typ = "tcp"
+               end
+       end
+       local create = luasocket[typ]
+       if type( create ) ~= "function"  then
+               err = "invalid socket type"
+       end
+
+       if err then
+               out_error( "server.lua, addclient: ", err )
+               return nil, err
+       end
+
+       local client, err = create( )
        if err then
                return nil, err
        end
        client:settimeout( 0 )
-       _, err = client:connect( address, port )
-       if err then -- try again
-               local handler = wrapclient( client, address, port, listeners )
+       local ok, err = client:connect( address, port )
+       if ok or err == "timeout" then
+               return wrapclient( client, address, port, listeners, pattern, sslctx )
        else
-               wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx )
+               return nil, err
        end
 end
 
@@ -990,7 +1035,7 @@ return {
 
        addclient = addclient,
        wrapclient = wrapclient,
-       
+
        loop = loop,
        link = link,
        step = step,