util.roster: Initial skeleton commit
[prosody.git] / net / server_select.lua
index e3619d3028e37f77f3b3daab7fcfe2a370140530..298e560aa00cc6a5ab2eebde12b7b53634641a08 100644 (file)
@@ -2,7 +2,7 @@
 -- server.lua by blastbeat of the luadch project
 -- Re-used here under the MIT/X Consortium License
 -- 
--- Modifications (C) 2008-2009 Matthew Wild, Waqas Hussain
+-- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain
 --
 
 -- // wrapping luadch stuff // --
@@ -252,6 +252,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        local dispatch = listeners.onincoming
        local status = listeners.onstatus
        local disconnect = listeners.ondisconnect
+       local drain = listeners.ondrain
 
        local bufferqueue = { } -- buffer array
        local bufferqueuelen = 0        -- end of buffer array
@@ -284,6 +285,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                dispatch = listeners.onincoming
                disconnect = listeners.ondisconnect
                status = listeners.onstatus
+               drain = listeners.ondrain
        end
        handler.getstats = function( )
                return readtraffic, sendtraffic
@@ -341,9 +343,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        _closelist[ handler ] = nil
                        handler = nil
                end
-       if server then
-               server.remove( )
-       end
+               if server then
+                       server.remove( )
+               end
                out_put "server.lua: closed client handler and removed socket from list"
                return true
        end
@@ -379,7 +381,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        handler.socket = function( self )
                return socket
        end
-       handler.pattern = function( self, new )
+       handler.set_mode = function( self, new )
                pattern = new or pattern
                return pattern
        end
@@ -392,6 +394,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                maxreadlen = readlen or maxreadlen
                return bufferlen, maxreadlen, maxsendlen
        end
+       --TODO: Deprecate
        handler.lock_read = function (self, switch)
                if switch == true then
                        local tmp = _readlistlen
@@ -409,6 +412,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                end
                return noread
        end
+       handler.pause = function (self)
+               return self:lock_read(true);
+       end
+       handler.resume = function (self)
+               return self:lock_read(false);
+       end
        handler.lock = function( self, switch )
                handler.lock_read (switch)
                if switch == true then
@@ -430,12 +439,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
        end
        local _readbuffer = function( ) -- this function reads data
                local buffer, err, part = receive( socket, pattern )    -- receive buffer with "pattern"
-               if not err or (err == "wantread" or err == "timeout") or string_len(part) > 0 then -- received something
+               if not err or (err == "wantread" or err == "timeout") then -- received something
                        local buffer = buffer or part or ""
                        local len = string_len( buffer )
                        if len > maxreadlen then
                                disconnect( handler, "receive buffer exceeded" )
-                               handler.close( true )
+                               handler:close( true )
                                return false
                        end
                        local count = len * STAT_UNIT
@@ -448,7 +457,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
                        fatalerror = true
                        disconnect( handler, err )
-               _ = handler and handler.close( )
+               _ = handler and handler:close( )
                        return false
                end
        end
@@ -470,9 +479,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        bufferqueuelen = 0
                        bufferlen = 0
                        _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
-                       _ = needtls and handler:starttls(nil, true)
                        _writetimes[ handler ] = nil
-               _ = toclose and handler.close( )
+                       if drain then
+                               drain(handler)
+                       end
+                       _ = needtls and handler:starttls(nil, true)
+                       _ = toclose and handler:close( )
                        return true
                elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
                        buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer
@@ -485,7 +497,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                        out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
                        fatalerror = true
                        disconnect( handler, err )
-                       _ = handler and handler.close( )
+                       _ = handler and handler:close( )
                        return false
                end
        end
@@ -611,7 +623,16 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
 
        _socketlist[ socket ] = handler
        _readlistlen = addsocket(_readlist, socket, _readlistlen)
-
+       if listeners.onconnect then
+               _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+               handler.sendbuffer = function ()
+                       listeners.onconnect(handler);
+                       handler.sendbuffer = _sendbuffer;
+                       if bufferqueuelen > 0 then
+                               return _sendbuffer();
+                       end
+               end
+       end
        return handler, socket
 end
 
@@ -654,6 +675,28 @@ closesocket = function( socket )
        --mem_free( )
 end
 
+local function link(sender, receiver, buffersize)
+       sender:set_mode(buffersize);
+       local sender_locked;
+       local _sendbuffer = receiver.sendbuffer;
+       function receiver.sendbuffer()
+               _sendbuffer();
+               if sender_locked and receiver.bufferlen() < buffersize then
+                       sender:lock_read(false); -- Unlock now
+                       sender_locked = nil;
+               end
+       end
+       
+       local _readbuffer = sender.readbuffer;
+       function sender.readbuffer()
+               _readbuffer();
+               if not sender_locked and receiver.bufferlen() >= buffersize then
+                       sender_locked = true;
+                       sender:lock_read(true);
+               end
+       end
+end
+
 ----------------------------------// PUBLIC //--
 
 addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
@@ -661,7 +704,7 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
        if type( listeners ) ~= "table" then
                err = "invalid listener table"
        end
-       if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then
+       if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
                err = "invalid port"
        elseif _server[ port ] then
                err = "listeners on port '" .. port .. "' already exist"
@@ -877,6 +920,7 @@ return {
        wrapclient = wrapclient,
        
        loop = loop,
+       link = link,
        stats = stats,
        closeall = closeall,
        addtimer = addtimer,