net/server_event: pcall require ssl rather than relying on globals
[prosody.git] / net / server_event.lua
index e10606dded0dd079f17ecd176cc40ef585274b4a..7575044a8563778f4b388036d5cd2380596deca4 100644 (file)
@@ -44,7 +44,7 @@ local setmetatable = use "setmetatable"
 local t_insert = table.insert
 local t_concat = table.concat
 
-local ssl = use "ssl"
+local has_luasec, ssl = pcall ( require , "ssl" )
 local socket = use "socket" or require "socket"
 
 local log = require ("util.logger").init("socket")
@@ -136,7 +136,7 @@ do
                                        self:_close()
                                        debug( "new connection failed. id:", self.id, "error:", self.fatalerror )
                                else
-                                       if plainssl and ssl then  -- start ssl session
+                                       if plainssl and has_luasec then  -- start ssl session
                                                self:starttls(self._sslctx, true)
                                        else  -- normal connection
                                                self:_start_session(true)
@@ -438,9 +438,11 @@ do
        end
 
        function interface_mt:setlistener(listener)
-               self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout, self.onreadtimeout, self.onstatus
-                       = listener.onconnect, listener.ondisconnect, listener.onincoming,
-                         listener.ontimeout, listener.onreadtimeout, listener.onstatus;
+               self:ondetach(); -- Notify listener that it is no longer responsible for this connection
+               self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout,
+               self.onreadtimeout, self.onstatus, self.ondetach
+                       = listener.onconnect, listener.ondisconnect, listener.onincoming, listener.ontimeout,
+                         listener.onreadtimeout, listener.onstatus, listener.ondetach;
        end
 
        -- Stub handlers
@@ -460,6 +462,8 @@ do
        end
        function interface_mt:ondrain()
        end
+       function interface_mt:ondetach()
+       end
        function interface_mt:onstatus()
        end
 end
@@ -487,6 +491,7 @@ do
                        ontimeout = listener.ontimeout; -- called when fatal socket timeout occurs
                        onreadtimeout = listener.onreadtimeout; -- called when socket inactivity timeout occurs
                        ondrain = listener.ondrain; -- called when writebuffer is empty
+                       ondetach = listener.ondetach; -- called when disassociating this listener from this connection
                        onstatus = listener.onstatus; -- called for status changes (e.g. of SSL/TLS)
                        eventread = false, eventwrite = false, eventclose = false,
                        eventhandshake = false, eventstarthandshake = false;  -- event handler
@@ -507,7 +512,7 @@ do
                        _sslctx = sslctx; -- parameters
                        _usingssl = false;  -- client is using ssl;
                }
-               if not ssl then interface.starttls = false; end
+               if not has_luasec then interface.starttls = false; end
                interface.id = tostring(interface):match("%x+$");
                interface.writecallback = function( event )  -- called on write events
                        --vdebug( "new client write event, id/ip/port:", interface, ip, port )
@@ -690,7 +695,7 @@ do
                                interface._connections = interface._connections + 1  -- increase connection count
                                local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx )
                                --vdebug( "client id:", clientinterface, "startssl:", startssl )
-                               if ssl and sslctx then
+                               if has_luasec and sslctx then
                                        clientinterface:starttls(sslctx, true)
                                else
                                        clientinterface:_start_session( true )
@@ -711,25 +716,17 @@ do
 end
 
 local addserver = ( function( )
-       return function( addr, port, listener, pattern, sslcfg, startssl )  -- TODO: check arguments
-               --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil")
+       return function( addr, port, listener, pattern, sslctx, startssl )  -- TODO: check arguments
+               --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil")
+               if sslctx and not has_luasec then
+                       debug "fatal error: luasec not found"
+                       return nil, "luasec not found"
+               end
                local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE )  -- create server socket
                if not server then
                        debug( "creating server socket on "..addr.." port "..port.." failed:", err )
                        return nil, err
                end
-               local sslctx
-               if sslcfg then
-                       if not ssl then
-                               debug "fatal error: luasec not found"
-                               return nil, "luasec not found"
-                       end
-                       sslctx, err = sslcfg
-                       if err then
-                               debug( "error while creating new ssl context for server socket:", err )
-                               return nil, err
-                       end
-               end
                local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl )  -- new server handler
                debug( "new server created with id:", tostring(interface))
                return interface
@@ -745,36 +742,21 @@ do
                --function handleclient( client, ip, port, server, pattern, listener, _, sslctx )  -- creates an client interface
        end
 
-       function addclient( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl )
+       function addclient( addr, serverport, listener, pattern, sslctx )
+               if sslctx and not has_luasec then
+                       debug "need luasec, but not available"
+                       return nil, "luasec not found"
+               end
                local client, err = socket.tcp()  -- creating new socket
                if not client then
                        debug( "cannot create socket:", err )
                        return nil, err
                end
                client:settimeout( 0 )  -- set nonblocking
-               if localaddr then
-                       local res, err = client:bind( localaddr, localport, -1 )
-                       if not res then
-                               debug( "cannot bind client:", err )
-                               return nil, err
-                       end
-               end
-               local sslctx
-               if sslcfg then  -- handle ssl/new context
-                       if not ssl then
-                               debug "need luasec, but not available"
-                               return nil, "luasec not found"
-                       end
-                       sslctx, err = sslcfg
-                       if err then
-                               debug( "cannot create new ssl context:", err )
-                               return nil, err
-                       end
-               end
                local res, err = client:connect( addr, serverport )  -- connect
                if res or ( err == "timeout" ) then
                        local ip, port = client:getsockname( )
-                       local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, startssl )
+                       local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx )
                        interface:_start_connection( startssl )
                        debug( "new connection id:", interface.id )
                        return interface, err