net.server_select: Fix global access
[prosody.git] / net / server_select.lua
index c0f8742e1adf2eb425f15bd7ab4dc35107b27fcf..122d774e19a9e756f1010cf5a0164785e2e6e61d 100644 (file)
@@ -19,7 +19,6 @@ end
 local log, table_concat = require ("util.logger").init("socket"), table.concat;
 local out_put = function (...) return log("debug", table_concat{...}); end
 local out_error = function (...) return log("warn", table_concat{...}); end
-local mem_free = collectgarbage
 
 ----------------------------------// DECLARATION //--
 
@@ -34,7 +33,6 @@ local pairs = use "pairs"
 local ipairs = use "ipairs"
 local tonumber = use "tonumber"
 local tostring = use "tostring"
-local collectgarbage = use "collectgarbage"
 
 --// lua libs //--
 
@@ -49,7 +47,6 @@ local os_difftime = os.difftime
 local math_min = math.min
 local math_huge = math.huge
 local table_concat = table.concat
-local table_remove = table.remove
 local string_len = string.len
 local string_sub = string.sub
 local coroutine_wrap = coroutine.wrap
@@ -67,7 +64,6 @@ local ssl_wrap = ( luasec and luasec.wrap )
 local socket_bind = luasocket.bind
 local socket_sleep = luasocket.sleep
 local socket_select = luasocket.select
-local ssl_newcontext = ( luasec and luasec.newcontext )
 
 --// functions //--
 
@@ -78,13 +74,13 @@ local idfalse
 local closeall
 local addsocket
 local addserver
+local addtimer
 local getserver
 local wrapserver
 local getsettings
 local closesocket
 local removesocket
 local removeserver
-local changetimeout
 local wrapconnection
 local changesettings
 
@@ -190,6 +186,9 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
        end
        handler.remove = function( )
                connections = connections - 1
+               if handler then
+                       handler.resume( )
+               end
        end
        handler.close = function()
                socket:close( )
@@ -202,6 +201,25 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
                --mem_free( )
                out_put "server.lua: closed server handler and removed sockets from list"
        end
+       handler.pause = function()
+               if not handler.paused then
+                       socket:close( )
+                       _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+                       _readlistlen = removesocket( _readlist, socket, _readlistlen )
+                       _socketlist[ socket ] = nil
+                       socket = nil;
+                       handler.paused = true;
+               end
+       end
+       handler.resume = function()
+               if handler.paused then
+                       socket = socket_bind( ip, serverport );
+                       socket:settimeout( 0 )
+                       _readlistlen = addsocket(_readlist, socket, _readlistlen)
+                       _socketlist[ socket ] = handler
+                       handler.paused = false;
+               end
+       end
        handler.ip = function( )
                return ip
        end
@@ -213,6 +231,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
        end
        handler.readbuffer = function( )
                if connections > maxconnections then
+                       handler.pause( )
                        out_put( "server.lua: refused new client connection: server full" )
                        return false
                end
@@ -314,17 +333,17 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                end
                return false, "setoption not implemented";
        end
-       handler.force_close = function ( self )
+       handler.force_close = function ( self, err )
                if bufferqueuelen ~= 0 then
-                       out_put("discarding unwritten data for ", tostring(ip), ":", tostring(clientport))
+                       out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport))
                        for i = bufferqueuelen, 1, -1 do
                                bufferqueue[i] = nil;
                        end
                        bufferqueuelen = 0;
                end
-               return self:close();
+               return self:close(err);
        end
-       handler.close = function( self )
+       handler.close = function( self, err )
                if not handler then return true; end
                _readlistlen = removesocket( _readlist, socket, _readlistlen )
                _readtimes[ handler ] = nil
@@ -353,7 +372,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        local _handler = handler;
                        handler = nil
                        if disconnect then
-                               disconnect(_handler, false);
+                               disconnect(_handler, err or false);
+                               disconnect = nil
                        end
                end
                if server then
@@ -456,8 +476,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        local buffer = buffer or part or ""
                        local len = string_len( buffer )
                        if len > maxreadlen then
-                               disconnect( handler, "receive buffer exceeded" )
-                               handler:close( true )
+                               handler:close( "receive buffer exceeded" )
                                return false
                        end
                        local count = len * STAT_UNIT
@@ -469,14 +488,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                else    -- connections was closed or fatal error
                        out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
                        fatalerror = true
-                       disconnect( handler, err )
-                       _ = handler and handler:close( )
+                       _ = handler and handler:force_close( err )
                        return false
                end
        end
        local _sendbuffer = function( ) -- this function sends data
                local succ, err, byte, buffer, count;
-               local count;
                if socket then
                        buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
                        succ, err, byte = send( socket, buffer, 1, bufferlen )
@@ -509,8 +526,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                else    -- connection was closed during sending or fatal error
                        out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
                        fatalerror = true
-                       disconnect( handler, err )
-                       _ = handler and handler:force_close( )
+                       _ = handler and handler:force_close( err )
                        return false
                end
        end
@@ -552,9 +568,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                                        end
                                end
                                out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") )
-                               disconnect( handler, "ssl handshake failed" )
-                               _ = handler and handler:force_close()
-                               return false, err -- handshake failed
+                               _ = handler and handler:force_close("ssl handshake failed")
+               return false, err -- handshake failed
                        end
                )
        end
@@ -758,16 +773,16 @@ changesettings = function( new )
        if type( new ) ~= "table" then
                return nil, "invalid settings table"
        end
-       _selecttimeout = tonumber( new.timeout ) or _selecttimeout
-       _sleeptime = tonumber( new.sleeptime ) or _sleeptime
-       _maxsendlen = tonumber( new.maxsendlen ) or _maxsendlen
-       _maxreadlen = tonumber( new.maxreadlen ) or _maxreadlen
-       _checkinterval = tonumber( new.checkinterval ) or _checkinterval
-       _sendtimeout = tonumber( new.sendtimeout ) or _sendtimeout
-       _readtimeout = tonumber( new.readtimeout ) or _readtimeout
-       _cleanqueue = new.cleanqueue
-       _maxclientsperserver = new._maxclientsperserver or _maxclientsperserver
-       _maxsslhandshake = new._maxsslhandshake or _maxsslhandshake
+       _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
+       _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime
+       _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
+       _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
+       _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
+       _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout
+       _readtimeout = tonumber( new.read_timeout ) or _readtimeout
+       _cleanqueue = new.select_clean_queue
+       _maxclientsperserver = new.max_connections or _maxclientsperserver
+       _maxsslhandshake = new.max_ssl_handshake_roundtrips or _maxsslhandshake
        return true
 end