net.server_event: Check the buffer *length*, not the buffer itself (Fixes 100% cpu...
[prosody.git] / net / server_event.lua
index 0331e793a7fa447ef339e073ffc2e9645cf79d94..882d10ed95d0e45d3436ed663c73c84258bd4ddb 100644 (file)
@@ -6,7 +6,6 @@
                        notes:
                        -- when using luaevent, never register 2 or more EV_READ at one socket, same for EV_WRITE
                        -- you cant even register a new EV_READ/EV_WRITE callback inside another one
-                       -- never call eventcallback:close( ) from inside eventcallback
                        -- to do some of the above, use timeout events or something what will called from outside
                        -- dont let garbagecollect eventcallbacks, as long they are running
                        -- when using luasec, there are 4 cases of timeout errors: wantread or wantwrite during reading or writing
@@ -24,6 +23,7 @@ local cfg = {
        HANDSHAKE_TIMEOUT     = 60,  -- timeout in seconds per handshake attempt
        MAX_READ_LENGTH       = 1024 * 1024 * 1024 * 1024,  -- max bytes allowed to read from sockets
        MAX_SEND_LENGTH       = 1024 * 1024 * 1024 * 1024,  -- max bytes size of write buffer (for writing on sockets)
+       ACCEPT_QUEUE          = 128,  -- might influence the length of the pending sockets queue
        ACCEPT_DELAY          = 10,  -- seconds to wait until the next attempt of a full server to accept
        READ_TIMEOUT          = 60 * 60 * 6,  -- timeout in seconds for read data from socket
        WRITE_TIMEOUT         = 180,  -- timeout in seconds for write data on socket
@@ -33,8 +33,6 @@ local cfg = {
 }
 
 local function use(x) return rawget(_G, x); end
-local print = use "print"
-local pcall = use "pcall"
 local ipairs = use "ipairs"
 local string = use "string"
 local select = use "select"
@@ -43,6 +41,9 @@ local tostring = use "tostring"
 local coroutine = use "coroutine"
 local setmetatable = use "setmetatable"
 
+local t_insert = table.insert
+local t_concat = table.concat
+
 local ssl = use "ssl"
 local socket = use "socket" or require "socket"
 
@@ -117,21 +118,14 @@ do
        
        local addevent = base.addevent
        local coroutine_wrap, coroutine_yield = coroutine.wrap,coroutine.yield
-       local string_len = string.len
        
        -- Private methods
        function interface_mt:_position(new_position)
                        self.position = new_position or self.position
                        return self.position;
        end
-       function interface_mt:_close() -- regs event to start self:_destroy()
-                       local callback = function( )
-                               self:_destroy();
-                               self.eventclose = nil
-                               return -1
-                       end
-                       self.eventclose = addevent( base, nil, EV_TIMEOUT, callback, 0 )
-                       return true
+       function interface_mt:_close()
+               return self:_destroy();
        end
        
        function interface_mt:_start_connection(plainssl) -- should be called from addclient
@@ -143,9 +137,9 @@ do
                                        debug( "new connection failed. id:", self.id, "error:", self.fatalerror )
                                else
                                        if plainssl and ssl then  -- start ssl session
-                                               self:starttls()
+                                               self:starttls(self._sslctx, true)
                                        else  -- normal connection
-                                               self:_start_session( self.listener.onconnect )
+                                               self:_start_session(true)
                                        end
                                        debug( "new connection established. id:", self.id )
                                end
@@ -155,13 +149,15 @@ do
                        self.eventconnect = addevent( base, self.conn, EV_WRITE, callback, cfg.CONNECT_TIMEOUT )
                        return true
        end
-       function interface_mt:_start_session(onconnect) -- new session, for example after startssl
+       function interface_mt:_start_session(call_onconnect) -- new session, for example after startssl
                if self.type == "client" then
                        local callback = function( )
                                self:_lock( false,  false, false )
                                --vdebug( "start listening on client socket with id:", self.id )
                                self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT );  -- register callback
-                               self:onconnect()
+                               if call_onconnect then
+                                       self:onconnect()
+                               end
                                self.eventsession = nil
                                return -1
                        end
@@ -173,7 +169,7 @@ do
                end
                return true
        end
-       function interface_mt:_start_ssl(arg) -- old socket will be destroyed, therefore we have to close read/write events first
+       function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed, therefore we have to close read/write events first
                        --vdebug( "starting ssl session with client id:", self.id )
                        local _
                        _ = self.eventread and self.eventread:close( )  -- close events; this must be called outside of the event callbacks!
@@ -184,7 +180,7 @@ do
                        if err then
                                self.fatalerror = err
                                self.conn = nil  -- cannot be used anymore
-                               if "onconnect" == arg then
+                               if call_onconnect then
                                        self.ondisconnect = nil  -- dont call this when client isnt really connected
                                end
                                self:_close()
@@ -210,29 +206,25 @@ do
                                                                self:_lock( false, false, false )  -- unlock the interface; sending, closing etc allowed
                                                                self.send = self.conn.send  -- caching table lookups with new client object
                                                                self.receive = self.conn.receive
-                                                               local onsomething
-                                                               if "onconnect" == arg then  -- trigger listener
-                                                                       onsomething = self.onconnect
-                                                               else
-                                                                       onsomething = self.onsslconnection
+                                                               if not call_onconnect then  -- trigger listener
+                                                                       self:onstatus("ssl-handshake-complete");
                                                                end
-                                                               self:_start_session( onsomething )
+                                                               self:_start_session( call_onconnect )
                                                                debug( "ssl handshake done" )
-                                                               self:onstatus("ssl-handshake-complete");
                                                                self.eventhandshake = nil
                                                                return -1
                                                        end
-                                                       debug( "error during ssl handshake:", err )
                                                        if err == "wantwrite" then
                                                                event = EV_WRITE
                                                        elseif err == "wantread" then
                                                                event = EV_READ
                                                        else
+                                                               debug( "ssl handshake error:", err )
                                                                self.fatalerror = err
                                                        end
                                                end
                                                if self.fatalerror then
-                                                       if "onconnect" == arg then
+                                                       if call_onconnect then
                                                                self.ondisconnect = nil  -- dont call this when client isnt really connected
                                                        end
                                                        self:_close()
@@ -250,10 +242,10 @@ do
                        return true
        end
        function interface_mt:_destroy()  -- close this interface + events and call last listener
-                       debug( "closing client with id:", self.id )
+                       debug( "closing client with id:", self.id, self.fatalerror )
                        self:_lock( true, true, true )  -- first of all, lock the interface to avoid further actions
                        local _
-                       _ = self.eventread and self.eventread:close( )  -- close events; this must be called outside of the event callbacks!
+                       _ = self.eventread and self.eventread:close( )
                        if self.type == "client" then
                                _ = self.eventwrite and self.eventwrite:close( )
                                _ = self.eventhandshake and self.eventhandshake:close( )
@@ -263,7 +255,7 @@ do
                                _ = self.eventwritetimeout and self.eventwritetimeout:close( )
                                _ = self.eventreadtimeout and self.eventreadtimeout:close( )
                                _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror)  -- call ondisconnect listener (wont be the case if handshake failed on connect)
-                               _ = self.conn and self.conn:close( ) -- close connection, must also be called outside of any socket registered events!
+                               _ = self.conn and self.conn:close( ) -- close connection
                                _ = self._server and self._server:counter(-1);
                                self.eventread, self.eventwrite = nil, nil
                                self.eventstarthandshake, self.eventhandshake, self.eventclose = nil, nil, nil
@@ -296,7 +288,11 @@ do
        end
 
        function interface_mt:resume()
-               return self:_lock(self.nointerface, false, self.nowriting);
+               self:_lock(self.nointerface, false, self.nowriting);
+               if self.readcallback and not self.eventread then
+                       self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT );  -- register callback
+                       return true;
+               end
        end
 
        function interface_mt:counter(c)
@@ -311,14 +307,14 @@ do
                if self.nowriting then return nil, "locked" end
                --vdebug( "try to send data to client, id/data:", self.id, data )
                data = tostring( data )
-               local len = string_len( data )
+               local len = #data
                local total = len + self.writebufferlen
                if total > cfg.MAX_SEND_LENGTH then  -- check buffer length
                        local err = "send buffer exceeded"
                        debug( "error:", err )  -- to much, check your app
                        return nil, err
                end
-               self.writebuffer = self.writebuffer .. data -- new buffer
+               t_insert(self.writebuffer, data) -- new buffer
                self.writebufferlen = total
                if not self.eventwrite then  -- register new write event
                        --vdebug( "register new write event" )
@@ -326,42 +322,33 @@ do
                end
                return true
        end
-       function interface_mt:close(now)
+       function interface_mt:close()
                if self.nointerface then return nil, "locked"; end
                debug( "try to close client connection with id:", self.id )
                if self.type == "client" then
                        self.fatalerror = "client to close"
-                       if ( not self.eventwrite ) or now then  -- try to close immediately
-                               self:_lock( true, true, true )
-                               self:_close()
-                               return true
-                       else  -- wait for incomplete write request
+                       if self.eventwrite then -- wait for incomplete write request
                                self:_lock( true, true, false )
                                debug "closing delayed until writebuffer is empty"
                                return nil, "writebuffer not empty, waiting"
+                       else -- close now
+                               self:_lock( true, true, true )
+                               self:_close()
+                               return true
                        end
                else
-                       debug( "try to close server with id:", self.id, "args:", now )
+                       debug( "try to close server with id:", tostring(self.id))
                        self.fatalerror = "server to close"
                        self:_lock( true )
-                       local count = 0
-                       for _, item in ipairs( interfacelist( ) ) do
-                               if ( item.type ~= "server" ) and ( item._server == self ) then  -- client/server match
-                                       if item:close( now ) then  -- writebuffer was empty
-                                               count = count + 1
-                                       end
-                               end
-                       end
-                       local timeout = 0  -- dont wait for unfinished writebuffers of clients...
-                       if not now then
-                               timeout = cfg.WRITE_TIMEOUT  -- ...or wait for it
-                       end
-                       self:_close( timeout )  -- add new event to remove the server interface
-                       debug( "seconds remained until server is closed:", timeout )
-                       return count  -- returns finished clients with empty writebuffer
+                       self:_close( 0 )
+                       return true
                end
        end
        
+       function interface_mt:socket()
+               return self.conn
+       end
+       
        function interface_mt:server()
                return self._server or self;
        end
@@ -381,6 +368,7 @@ do
        function interface_mt:ssl()
                return self._usingssl
        end
+       interface_mt.clientport = interface_mt.port -- COMPAT server_select
 
        function interface_mt:type()
                return self._type or "client"
@@ -414,7 +402,7 @@ do
                -- No-op, we always use the underlying connection's send
        end
        
-       function interface_mt:starttls(sslctx)
+       function interface_mt:starttls(sslctx, call_onconnect)
                debug( "try to start ssl at client id:", self.id )
                local err
                self._sslctx = sslctx;
@@ -428,7 +416,7 @@ do
                self._usingssl = true
                self.startsslcallback = function( )  -- we have to start the handshake outside of a read/write event
                        self.startsslcallback = nil
-                       self:_start_ssl();
+                       self:_start_ssl(call_onconnect);
                        self.eventstarthandshake = nil
                        return -1
                end
@@ -451,13 +439,15 @@ do
        end
        
        function interface_mt:setlistener(listener)
-               self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout, self.onstatus
-                       = listener.onconnect, listener.ondisconnect, listener.onincoming, listener.ontimeout, listener.onstatus;
+               self:ondetach(); -- Notify listener that it is no longer responsible for this connection
+               self.onconnect, self.ondisconnect, self.onincoming,
+               self.ontimeout, self.onstatus, self.ondetach
+                       = listener.onconnect, listener.ondisconnect, listener.onincoming,
+                       listener.ontimeout, listener.onstatus, listener.ondetach;
        end
        
        -- Stub handlers
        function interface_mt:onconnect()
-               return self:onincoming(nil);
        end
        function interface_mt:onincoming()
        end
@@ -467,8 +457,9 @@ do
        end
        function interface_mt:ondrain()
        end
+       function interface_mt:ondetach()
+       end
        function interface_mt:onstatus()
-               debug("server.lua: Dummy onstatus()")
        end
 end
 
@@ -477,18 +468,15 @@ end
 local handleclient;
 do
        local string_sub = string.sub  -- caching table lookups
-       local string_len = string.len
        local addevent = base.addevent
-       local coroutine_wrap = coroutine.wrap
        local socket_gettime = socket.gettime
-       local coroutine_yield = coroutine.yield
-       function handleclient( client, ip, port, server, pattern, listener, _, sslctx )  -- creates an client interface
+       function handleclient( client, ip, port, server, pattern, listener, sslctx )  -- creates an client interface
                --vdebug("creating client interfacce...")
                local interface = {
                        type = "client";
                        conn = client;
                        currenttime = socket_gettime( );  -- safe the origin
-                       writebuffer = "";  -- writebuffer
+                       writebuffer = {};  -- writebuffer
                        writebufferlen = 0;  -- length of writebuffer
                        send = client.send;  -- caching table lookups
                        receive = client.receive;
@@ -496,6 +484,8 @@ do
                        ondisconnect = listener.ondisconnect;  -- will be called when client disconnects
                        onincoming = listener.onincoming;  -- will be called when client sends data
                        ontimeout = listener.ontimeout; -- called when fatal socket timeout occurs
+                       ondrain = listener.ondrain; -- called when writebuffer is empty
+                       ondetach = listener.ondetach; -- called when disassociating this listener from this connection
                        onstatus = listener.onstatus; -- called for status changes (e.g. of SSL/TLS)
                        eventread = false, eventwrite = false, eventclose = false,
                        eventhandshake = false, eventstarthandshake = false;  -- event handler
@@ -542,10 +532,11 @@ do
                                                interface.eventwritetimeout = false
                                        end
                                end
-                               local succ, err, byte = interface.conn:send( interface.writebuffer, 1, interface.writebufferlen )
+                               interface.writebuffer = { t_concat(interface.writebuffer) }
+                               local succ, err, byte = interface.conn:send( interface.writebuffer[1], 1, interface.writebufferlen )
                                --vdebug( "write data:", interface.writebuffer, "error:", err, "part:", byte )
                                if succ then  -- writing succesful
-                                       interface.writebuffer = ""
+                                       interface.writebuffer[1] = nil
                                        interface.writebufferlen = 0
                                        interface:ondrain();
                                        if interface.fatalerror then
@@ -554,14 +545,17 @@ do
                                        elseif interface.startsslcallback then  -- start ssl connection if needed
                                                debug "starting ssl handshake after writing"
                                                interface.eventstarthandshake = addevent( base, nil, EV_TIMEOUT, interface.startsslcallback, 0 )
+                                       elseif interface.writebufferlen ~= 0 then
+                                               -- data possibly written from ondrain
+                                               return EV_WRITE, cfg.WRITE_TIMEOUT
                                        elseif interface.eventreadtimeout then
-                                               return EV_WRITE, EV_TIMEOUT
+                                               return EV_WRITE, cfg.WRITE_TIMEOUT
                                        end
                                        interface.eventwrite = nil
                                        return -1
                                elseif byte and (err == "timeout" or err == "wantwrite") then  -- want write again
                                        --vdebug( "writebuffer is not empty:", err )
-                                       interface.writebuffer = string_sub( interface.writebuffer, byte + 1, interface.writebufferlen )  -- new buffer
+                                       interface.writebuffer[1] = string_sub( interface.writebuffer[1], byte + 1, interface.writebufferlen )  -- new buffer
                                        interface.writebufferlen = interface.writebufferlen - byte
                                        if "wantread" == err then  -- happens only with luasec
                                                local callback = function( )
@@ -611,16 +605,14 @@ do
                                end
                                local buffer, err, part = interface.conn:receive( interface._pattern )  -- receive buffer with "pattern"
                                --vdebug( "read data:", tostring(buffer), "error:", tostring(err), "part:", tostring(part) )
-                               buffer = buffer or part or ""
-                               local len = string_len( buffer )
-                               if len > cfg.MAX_READ_LENGTH then  -- check buffer length
+                               buffer = buffer or part
+                               if buffer and #buffer > cfg.MAX_READ_LENGTH then  -- check buffer length
                                        interface.fatalerror = "receive buffer exceeded"
                                        debug( "fatal error:", interface.fatalerror )
                                        interface:_close()
                                        interface.eventread = nil
                                        return -1
                                end
-                               interface.onincoming( interface, buffer, err )  -- send new data to listener
                                if err and ( err ~= "timeout" and err ~= "wantread" ) then
                                        if "wantwrite" == err then -- need to read on write event
                                                if not interface.eventwrite then  -- register new write event if needed
@@ -640,6 +632,12 @@ do
                                                interface.eventread = nil
                                                return -1
                                        end
+                               else
+                                       interface.onincoming( interface, buffer, err )  -- send new data to listener
+                               end
+                               if interface.noreading then
+                                       interface.eventread = nil;
+                                       return -1;
                                end
                                return EV_READ, cfg.READ_TIMEOUT
                        end
@@ -697,12 +695,12 @@ do
                                end
                                local client_ip, client_port = client:getpeername( )
                                interface._connections = interface._connections + 1  -- increase connection count
-                               local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, nil, sslctx )
+                               local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx )
                                --vdebug( "client id:", clientinterface, "startssl:", startssl )
                                if ssl and sslctx then
-                                       clientinterface:starttls(sslctx)
+                                       clientinterface:starttls(sslctx, true)
                                else
-                                       clientinterface:_start_session( clientinterface.onconnect )
+                                       clientinterface:_start_session( true )
                                end
                                debug( "accepted incoming client connection from:", client_ip or "<unknown IP>", client_port or "<unknown port>", "to", port or "<unknown port>");
                                
@@ -724,7 +722,7 @@ local addserver = ( function( )
                --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil")
                local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE )  -- create server socket
                if not server then
-                       debug( "creating server socket failed because:", err )
+                       debug( "creating server socket on "..addr.." port "..port.." failed:", err )
                        return nil, err
                end
                local sslctx
@@ -747,9 +745,9 @@ end )( )
 
 local addclient, wrapclient
 do
-       function wrapclient( client, ip, port, listeners, pattern, sslctx, startssl )
+       function wrapclient( client, ip, port, listeners, pattern, sslctx )
                local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx )
-               interface:_start_session()
+               interface:_start_connection(sslctx)
                return interface, client
                --function handleclient( client, ip, port, server, pattern, listener, _, sslctx )  -- creates an client interface
        end
@@ -783,9 +781,6 @@ do
                local res, err = client:connect( addr, serverport )  -- connect
                if res or ( err == "timeout" ) then
                        local ip, port = client:getsockname( )
-                       local server = function( )
-                               return nil, "this is a dummy server interface"
-                       end
                        local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, startssl )
                        interface:_start_connection( startssl )
                        debug( "new connection id:", interface.id )
@@ -826,14 +821,14 @@ local function setquitting(yes)
        end
 end
 
-function get_backend()
+local function get_backend()
        return base:method();
 end
 
 -- We need to hold onto the events to stop them
 -- being garbage-collected
 local signal_events = {}; -- [signal_num] -> event object
-function hook_signal(signal_num, handler)
+local function hook_signal(signal_num, handler)
        local function _handler(event)
                local ret = handler();
                if ret ~= false then -- Continue handling this signal?
@@ -846,7 +841,6 @@ function hook_signal(signal_num, handler)
 end
 
 local function link(sender, receiver, buffersize)
-       sender:set_mode(buffersize);
        local sender_locked;
        
        function receiver:ondrain()
@@ -863,6 +857,7 @@ local function link(sender, receiver, buffersize)
                        sender:pause();
                end
        end
+       sender:set_mode("*a");
 end
 
 return {