mod_presence: Don't depend on sessions array existing for a user when handling outgoi...
[prosody.git] / net / server_event.lua
index c35c6f53d09a3aa3ceba7b54413d1a99eacad883..450bd341517487b489cf946ba2e828158e82f13a 100644 (file)
@@ -138,7 +138,7 @@ do
                        local callback = function( event )\r
                                if EV_TIMEOUT == event then  -- timout during connection\r
                                        self.fatalerror = "connection timeout"\r
-                                       self.listener.ontimeout( self )  -- call timeout listener\r
+                                       self:ontimeout()  -- call timeout listener\r
                                        self:_close()\r
                                        debug( "new connection failed. id:", self.id, "error:", self.fatalerror )\r
                                else\r
@@ -212,9 +212,9 @@ do
                                                                self.receive = self.conn.receive\r
                                                                local onsomething\r
                                                                if "onconnect" == arg then  -- trigger listener\r
-                                                                       onsomething = self.listener.onconnect\r
+                                                                       onsomething = self.onconnect\r
                                                                else\r
-                                                                       onsomething = self.listener.onsslconnection\r
+                                                                       onsomething = self.onsslconnection\r
                                                                end\r
                                                                self:_start_session( onsomething )\r
                                                                debug( "ssl handshake done" )\r
@@ -263,7 +263,7 @@ do
                                _ = self.eventreadtimeout and self.eventreadtimeout:close( )\r
                                _ = self.ondisconnect and self:ondisconnect( self.fatalerror )  -- call ondisconnect listener (wont be the case if handshake failed on connect)\r
                                _ = self.conn and self.conn:close( ) -- close connection, must also be called outside of any socket registered events!\r
-                               self._server:counter(-1);\r
+                               _ = self._server and self._server:counter(-1);\r
                                self.eventread, self.eventwrite = nil, nil\r
                                self.eventstarthandshake, self.eventhandshake, self.eventclose = nil, nil, nil\r
                                self.readcallback, self.writecallback = nil, nil\r
@@ -283,14 +283,15 @@ do
 \r
        function interface_mt:counter(c)\r
                if c then\r
-                       self._connections = self._connections - c\r
+                       self._connections = self._connections + c\r
                end\r
                return self._connections\r
        end\r
        \r
        -- Public methods\r
        function interface_mt:write(data)\r
-               vdebug( "try to send data to client, id/data:", self.id, data )\r
+               if self.nowriting then return nil, "locked" end\r
+               --vdebug( "try to send data to client, id/data:", self.id, data )\r
                data = tostring( data )\r
                local len = string_len( data )\r
                local total = len + self.writebufferlen\r
@@ -308,6 +309,7 @@ do
                return true\r
        end\r
        function interface_mt:close(now)\r
+               if self.nointerface then return nil, "locked"; end\r
                debug( "try to close client connection with id:", self.id )\r
                if self.type == "client" then\r
                        self.fatalerror = "client to close"\r
@@ -355,7 +357,7 @@ do
        end\r
        \r
        function interface_mt:ssl()\r
-               return self.usingssl\r
+               return self._usingssl\r
        end\r
 \r
        function interface_mt:type()\r
@@ -370,22 +372,34 @@ do
                return self.addr\r
        end\r
        \r
-                       \r
+       function interface_mt:set_sslctx(sslctx)\r
+               self._sslctx = sslctx;\r
+               if sslctx then\r
+                       self.starttls = nil; -- use starttls() of interface_mt\r
+               else\r
+                       self.starttls = false; -- prevent starttls()\r
+               end\r
+       end\r
+       \r
+       function interface_mt:set_send(new_send)\r
+               -- No-op, we always use the underlying connection's send\r
+       end\r
        \r
        function interface_mt:starttls()\r
                debug( "try to start ssl at client id:", self.id )\r
                local err\r
                if not self._sslctx then  -- no ssl available\r
                        err = "no ssl context available"\r
-               elseif self.usingssl then  -- startssl was already called\r
+               elseif self._usingssl then  -- startssl was already called\r
                        err = "ssl already active"\r
                end\r
                if err then\r
                        debug( "error:", err )\r
                        return nil, err      \r
                end\r
-               self.usingssl = true\r
+               self._usingssl = true\r
                self.startsslcallback = function( )  -- we have to start the handshake outside of a read/write event\r
+                       self.startsslcallback = nil\r
                        self:_start_ssl();\r
                        self.eventstarthandshake = nil\r
                        return -1\r
@@ -400,7 +414,14 @@ do
                return true\r
        end\r
        \r
-       function interface_mt.onconnect()\r
+       -- Stub handlers\r
+       function interface_mt:onconnect()\r
+       end\r
+       function interface_mt:onincoming()\r
+       end\r
+       function interface_mt:ondisconnect()\r
+       end\r
+       function interface_mt:ontimeout()\r
        end\r
 end                    \r
 \r
@@ -427,6 +448,7 @@ do
                        onconnect = listener.onconnect;  -- will be called when client disconnects\r
                        ondisconnect = listener.ondisconnect;  -- will be called when client disconnects\r
                        onincoming = listener.onincoming;  -- will be called when client sends data\r
+                       ontimeout = listener.ontimeout; -- called when fatal socket timeout occurs\r
                        eventread = false, eventwrite = false, eventclose = false,\r
                        eventhandshake = false, eventstarthandshake = false;  -- event handler\r
                        eventconnect = false, eventsession = false;  -- more event handler...\r
@@ -445,6 +467,9 @@ do
                        _sslctx = sslctx; -- parameters\r
                        _usingssl = false;  -- client is using ssl;\r
                }\r
+               if not sslctx then\r
+                       interface.starttls = false -- don't allow TLS\r
+               end\r
                interface.id = tostring(interface):match("%x+$");\r
                interface.writecallback = function( event )  -- called on write events\r
                        --vdebug( "new client write event, id/ip/port:", interface, ip, port )\r
@@ -460,7 +485,7 @@ do
                                interface.eventwrite = false\r
                                return -1\r
                        else  -- can write :)\r
-                               if interface.usingssl then  -- handle luasec\r
+                               if interface._usingssl then  -- handle luasec\r
                                        if interface.eventreadtimeout then  -- we have to read first\r
                                                local ret = interface.readcallback( )  -- call readcallback\r
                                                --vdebug( "tried to read in writecallback, result:", ret )\r
@@ -470,7 +495,7 @@ do
                                                interface.eventwritetimeout = false\r
                                        end\r
                                end\r
-                               local succ, err, byte = interface.send( interface.conn, interface.writebuffer, 1, interface.writebufferlen )\r
+                               local succ, err, byte = interface.conn:send( interface.writebuffer, 1, interface.writebufferlen )\r
                                --vdebug( "write data:", interface.writebuffer, "error:", err, "part:", byte )\r
                                if succ then  -- writing succesful\r
                                        interface.writebuffer = ""\r
@@ -511,11 +536,11 @@ do
                                end\r
                        end\r
                end\r
-               local usingssl, receive = interface._usingssl, interface.receive;\r
+               \r
                interface.readcallback = function( event )  -- called on read events\r
-                       --vdebug( "new client read event, id/ip/port:", interface, ip, port )\r
+                       --vdebug( "new client read event, id/ip/port:", tostring(interface.id), tostring(ip), tostring(port) )\r
                        if interface.noreading or interface.fatalerror then  -- leave this event\r
-                               --vdebug( "leaving this event because:", interface.noreading or interface.fatalerror )\r
+                               --vdebug( "leaving this event because:", tostring(interface.noreading or interface.fatalerror) )\r
                                interface.eventread = nil\r
                                return -1\r
                        end\r
@@ -526,18 +551,18 @@ do
                                interface.eventread = nil\r
                                return -1\r
                        else -- can read\r
-                               if usingssl then  -- handle luasec\r
+                               if interface._usingssl then  -- handle luasec\r
                                        if interface.eventwritetimeout then  -- ok, in the past writecallback was regged\r
                                                local ret = interface.writecallback( )  -- call it\r
-                                               --vdebug( "tried to write in readcallback, result:", ret )\r
+                                               --vdebug( "tried to write in readcallback, result:", tostring(ret) )\r
                                        end\r
                                        if interface.eventreadtimeout then\r
                                                interface.eventreadtimeout:close( )\r
                                                interface.eventreadtimeout = nil\r
                                        end\r
                                end\r
-                               local buffer, err, part = receive( client, pattern )  -- receive buffer with "pattern"\r
-                               --vdebug( "read data:", buffer, "error:", err, "part:", part )        \r
+                               local buffer, err, part = interface.conn:receive( pattern )  -- receive buffer with "pattern"\r
+                               --vdebug( "read data:", tostring(buffer), "error:", tostring(err), "part:", tostring(part) )        \r
                                buffer = buffer or part or ""\r
                                local len = string_len( buffer )\r
                                if len > cfg.MAX_READ_LENGTH then  -- check buffer length\r
@@ -547,7 +572,7 @@ do
                                        interface.eventread = nil\r
                                        return -1\r
                                end\r
-                               if err and ( "timeout" ~= err ) then\r
+                               if err and ( err ~= "timeout" and err ~= "wantread" ) then\r
                                        if "wantwrite" == err then -- need to read on write event\r
                                                if not interface.eventwrite then  -- register new write event if needed\r
                                                        interface.eventwrite = addevent( base, interface.conn, EV_WRITE, interface.writecallback, cfg.WRITE_TIMEOUT )\r
@@ -668,26 +693,16 @@ local addserver = ( function( )
        end\r
 end )( )\r
 \r
-local wrapclient = ( function( )\r
-       return function( client, addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl )\r
-               debug( "try to connect to:", addr, serverport, "with parameters:", pattern, localaddr, localport, sslcfg, startssl )\r
-               local sslctx\r
-               if sslcfg then  -- handle ssl/new context\r
-                       if not ssl then\r
-                               debug "need luasec, but not available" \r
-                               return nil, "luasec not found"\r
-                       end\r
-                       sslctx, err = ssl.newcontext( sslcfg )\r
-                       if err then\r
-                               debug( "cannot create new ssl context:", err )\r
-                               return nil, err\r
-                       end\r
-               end\r
+local addclient, wrapclient\r
+do\r
+       function wrapclient( client, ip, port, listeners, pattern, sslctx, startssl )\r
+               local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx )\r
+               interface:_start_session()\r
+               return interface\r
+               --function handleclient( client, ip, port, server, pattern, listener, _, sslctx )  -- creates an client interface\r
        end\r
-end )( )\r
-\r
-local addclient = ( function( )\r
-       return function( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl )\r
+       \r
+       function addclient( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl )\r
                local client, err = socket.tcp()  -- creating new socket\r
                if not client then\r
                        debug( "cannot create socket:", err ) \r
@@ -701,23 +716,35 @@ local addclient = ( function( )
                                return nil, err\r
                        end\r
                end\r
+               local sslctx\r
+               if sslcfg then  -- handle ssl/new context\r
+                       if not ssl then\r
+                               debug "need luasec, but not available" \r
+                               return nil, "luasec not found"\r
+                       end\r
+                       sslctx, err = ssl.newcontext( sslcfg )\r
+                       if err then\r
+                               debug( "cannot create new ssl context:", err )\r
+                               return nil, err\r
+                       end\r
+               end\r
                local res, err = client:connect( addr, serverport )  -- connect\r
                if res or ( err == "timeout" ) then\r
                        local ip, port = client:getsockname( )\r
                        local server = function( )\r
                                return nil, "this is a dummy server interface"\r
                        end\r
-                       local interface = handleclient( client, ip, port, server, pattern, listener, sslctx )\r
+                       local interface = wrapclient( client, ip, serverport, listeners, pattern, sslctx, startssl )\r
                        interface:_start_connection( startssl )\r
-                       debug( "new connection id:", interface )\r
+                       debug( "new connection id:", interface.id )\r
                        return interface, err\r
                else\r
                        debug( "new connection failed:", err )\r
                        return nil, err\r
                end\r
-               return wrapclient( client, addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl )    \r
        end\r
-end )( )\r
+end\r
+\r
 \r
 local loop = function( )  -- starts the event loop\r
        return base:loop( )\r