util.sasl_cyrus: Protect the call to cyrussasl.server_new properly.
[prosody.git] / core / s2smanager.lua
index af5e91e3585705eb946b68a2ca3759a53848232c..f0b802d87e50c164b0e84e9fd7a24a5d6e302178 100644 (file)
@@ -16,8 +16,10 @@ local socket = require "socket";
 local format = string.format;
 local t_insert, t_sort = table.insert, table.sort;
 local get_traceback = debug.traceback;
-local tostring, pairs, ipairs, getmetatable, newproxy, error, tonumber
-    = tostring, pairs, ipairs, getmetatable, newproxy, error, tonumber;
+local tostring, pairs, ipairs, getmetatable, newproxy, error, tonumber,
+      setmetatable
+    = tostring, pairs, ipairs, getmetatable, newproxy, error, tonumber,
+      setmetatable;
 
 local idna_to_ascii = require "util.encodings".idna.to_ascii;
 local connlisteners_get = require "net.connlisteners".get;
@@ -48,7 +50,9 @@ local incoming_s2s = incoming_s2s;
 
 module "s2smanager"
 
-local function compare_srv_priorities(a,b) return a.priority < b.priority or a.weight < b.weight; end
+function compare_srv_priorities(a,b)
+       return a.priority < b.priority or (a.priority == b.priority and a.weight > b.weight);
+end
 
 local function bounce_sendq(session, reason)
        local sendq = session.sendq;
@@ -181,7 +185,6 @@ function new_outgoing(from_host, to_host, connect)
                                buffer[#buffer+1] = data;
                                log("debug", "Buffered item %d: %s", #buffer, tostring(data));
                        end
-                       
                end
 
                return host_session;
@@ -302,12 +305,17 @@ function try_connect(host_session, connect_host, connect_port)
 end
 
 function make_connect(host_session, connect_host, connect_port)
-       host_session.log("info", "Beginning new connection attempt to %s (%s:%d)", host_session.to_host, connect_host, connect_port);
+       (host_session.log or log)("info", "Beginning new connection attempt to %s (%s:%d)", host_session.to_host, connect_host, connect_port);
        -- Ok, we're going to try to connect
        
        local from_host, to_host = host_session.from_host, host_session.to_host;
        
        local conn, handler = socket.tcp()
+       
+       if not conn then
+               log("warn", "Failed to create outgoing connection, system error: %s", handler);
+               return false, handler;
+       end
 
        conn:settimeout(0);
        local success, err = conn:connect(connect_host, connect_port);
@@ -317,7 +325,7 @@ function make_connect(host_session, connect_host, connect_port)
        end
        
        local cl = connlisteners_get("xmppserver");
-       conn = wrapclient(conn, connect_host, connect_port, cl, cl.default_mode or 1, hosts[from_host].ssl_ctx, false );
+       conn = wrapclient(conn, connect_host, connect_port, cl, cl.default_mode or 1 );
        host_session.conn = conn;
        
        -- Register this outgoing connection so that xmppserver_listener knows about it
@@ -327,7 +335,7 @@ function make_connect(host_session, connect_host, connect_port)
        local w, log = conn.write, host_session.log;
        host_session.sends2s = function (t) log("debug", "sending: %s", (t.top_tag and t:top_tag()) or t:match("^[^>]*>?")); w(conn, tostring(t)); end
        
-       host_session:open_stream();
+       host_session:open_stream(from_host, to_host);
        
        log("debug", "Connection attempt in progress...");
        add_task(connect_timeout, function ()
@@ -361,11 +369,6 @@ function streamopened(session, attr)
                session.secure = true;
        end
        
-       if session.version >= 1.0 and not (attr.to and attr.from) then
-               (session.log or log)("warn", "Remote of stream "..(session.from_host or "(unknown)").."->"..(session.to_host or "(unknown)")
-                       .." failed to specify to (%s) and/or from (%s) hostname as per RFC", tostring(attr.to), tostring(attr.from));
-       end
-       
        if session.direction == "incoming" then
                -- Send a reply stream header
                session.to_host = attr.to and nameprep(attr.to);
@@ -385,7 +388,7 @@ function streamopened(session, attr)
                        local features = st.stanza("stream:features");
                        
                        if session.to_host then
-                               hosts[session.to_host].events.fire_event("s2s-stream-features", { session = session, features = features });
+                               hosts[session.to_host].events.fire_event("s2s-stream-features", { origin = session, features = features });
                        else
                                (session.log or log)("warn", "No 'to' on stream header from %s means we can't offer any features", session.from_host or "unknown host");
                        end
@@ -426,11 +429,8 @@ function streamopened(session, attr)
 end
 
 function streamclosed(session)
-       (session.log or log)("debug", "</stream:stream>");
-       if session.sends2s then
-               session.sends2s("</stream:stream>");
-       end
-       session.notopen = true;
+       (session.log or log)("debug", "Received </stream:stream>");
+       session:close();
 end
 
 function initiate_dialback(session)
@@ -449,6 +449,16 @@ function verify_dialback(id, to, from, key)
 end
 
 function make_authenticated(session, host)
+       if not session.secure then
+               local local_host = session.direction == "incoming" and session.to_host or session.from_host;
+               if config.get(local_host, "core", "s2s_require_encryption") then
+                       session:close({
+                               condition = "policy-violation",
+                               text = "Encrypted server-to-server communication is required but was not "
+                                      ..((session.direction == "outgoing" and "offered") or "used")
+                       });
+               end
+       end
        if session.type == "s2sout_unauthed" then
                session.type = "s2sout";
        elseif session.type == "s2sin_unauthed" then
@@ -494,7 +504,31 @@ function mark_connected(session)
        end
 end
 
+local resting_session = { -- Resting, not dead
+               destroyed = true;
+               open_stream = function (session)
+                       session.log("debug", "Attempt to open stream on resting session");
+               end;
+               close = function (session)
+                       session.log("debug", "Attempt to close already-closed session");
+               end;
+       }; resting_session.__index = resting_session;
+
+function retire_session(session)
+       local log = session.log or log;
+       for k in pairs(session) do
+               if k ~= "trace" and k ~= "log" and k ~= "id" then
+                       session[k] = nil;
+               end
+       end
+
+       function session.send(data) log("debug", "Discarding data sent to resting session: %s", tostring(data)); end
+       function session.data(data) log("debug", "Discarding data received from resting session: %s", tostring(data)); end
+       return setmetatable(session, resting_session);
+end
+
 function destroy_session(session, reason)
+       if session.destroyed then return; end
        (session.log or log)("info", "Destroying "..tostring(session.direction).." session "..tostring(session.from_host).."->"..tostring(session.to_host));
        
        if session.direction == "outgoing" then
@@ -504,11 +538,7 @@ function destroy_session(session, reason)
                incoming_s2s[session] = nil;
        end
        
-       for k in pairs(session) do
-               if k ~= "trace" then
-                       session[k] = nil;
-               end
-       end
+       retire_session(session); -- Clean session until it is GC'd
 end
 
 return _M;