Merge 0.9->0.10
[prosody.git] / plugins / mod_s2s / mod_s2s.lua
index d35fc4891712f370fd9aff2cf8bb9e8a9e6dd41f..8614b8570933085a79dc5dfaeb251e9bf042862e 100644 (file)
@@ -15,7 +15,6 @@ local core_process_stanza = prosody.core_process_stanza;
 local tostring, type = tostring, type;
 local t_insert = table.insert;
 local xpcall, traceback = xpcall, debug.traceback;
-local NULL = {};
 
 local add_task = require "util.timer".add_task;
 local st = require "util.stanza";
@@ -26,7 +25,6 @@ local s2s_new_incoming = require "core.s2smanager".new_incoming;
 local s2s_new_outgoing = require "core.s2smanager".new_outgoing;
 local s2s_destroy_session = require "core.s2smanager".destroy_session;
 local uuid_gen = require "util.uuid".generate;
-local cert_verify_identity = require "util.x509".verify_identity;
 local fire_global_event = prosody.events.fire_event;
 
 local s2sout = module:require("s2sout");
@@ -235,7 +233,7 @@ function make_authenticated(event)
 end
 
 --- Helper to check that a session peer's certificate is valid
-local function check_cert_status(session)
+function check_cert_status(session)
        local host = session.direction == "outgoing" and session.to_host or session.from_host
        local conn = session.conn:socket()
        local cert
@@ -243,39 +241,6 @@ local function check_cert_status(session)
                cert = conn:getpeercertificate()
        end
 
-       if cert then
-               local chain_valid, errors;
-               if conn.getpeerverification then
-                       chain_valid, errors = conn:getpeerverification();
-               elseif conn.getpeerchainvalid then -- COMPAT mw/luasec-hg
-                       chain_valid, errors = conn:getpeerchainvalid();
-                       errors = (not chain_valid) and { { errors } } or nil;
-               else
-                       chain_valid, errors = false, { { "Chain verification not supported by this version of LuaSec" } };
-               end
-               -- Is there any interest in printing out all/the number of errors here?
-               if not chain_valid then
-                       (session.log or log)("debug", "certificate chain validation result: invalid");
-                       for depth, t in pairs(errors or NULL) do
-                               (session.log or log)("debug", "certificate error(s) at depth %d: %s", depth-1, table.concat(t, ", "))
-                       end
-                       session.cert_chain_status = "invalid";
-               else
-                       (session.log or log)("debug", "certificate chain validation result: valid");
-                       session.cert_chain_status = "valid";
-
-                       -- We'll go ahead and verify the asserted identity if the
-                       -- connecting server specified one.
-                       if host then
-                               if cert_verify_identity(host, "xmpp-server", cert) then
-                                       session.cert_identity_status = "valid"
-                               else
-                                       session.cert_identity_status = "invalid"
-                               end
-                               (session.log or log)("debug", "certificate identity validation result: %s", session.cert_identity_status);
-                       end
-               end
-       end
        return module:fire_event("s2s-check-certificate", { host = host, session = session, cert = cert });
 end
 
@@ -382,7 +347,9 @@ function stream_callbacks.streamopened(session, attr)
                        log("debug", "Sending stream features: %s", tostring(features));
                        send(features);
                end
+               session.notopen = nil;
        elseif session.direction == "outgoing" then
+               session.notopen = nil;
                -- If we are just using the connection for verifying dialback keys, we won't try and auth it
                if not attr.id then error("stream response did not give us a streamid!!!"); end
                session.streamid = attr.id;
@@ -416,7 +383,6 @@ function stream_callbacks.streamopened(session, attr)
                        end
                end
        end
-       session.notopen = nil;
 end
 
 function stream_callbacks.streamclosed(session)
@@ -426,6 +392,7 @@ end
 
 function stream_callbacks.error(session, error, data)
        if error == "no-stream" then
+               session.log("debug", "Invalid opening stream header (%s)", (data:gsub("^([^\1]+)\1", "{%1}")));
                session:close("invalid-namespace");
        elseif error == "parse-error" then
                session.log("debug", "Server-to-server XML parse error: %s", tostring(error));
@@ -536,12 +503,29 @@ local function initialize_session(session)
 
        function session.reset_stream()
                session.notopen = true;
+               session.streamid = nil;
                session.stream:reset();
        end
 
        session.stream_attrs = session_stream_attrs;
 
-       local filter = session.filter;
+       local filter = initialize_filters(session);
+       local conn = session.conn;
+       local w = conn.write;
+
+       function session.sends2s(t)
+               log("debug", "sending: %s", t.top_tag and t:top_tag() or t:match("^[^>]*>?"));
+               if t.name then
+                       t = filter("stanzas/out", t);
+               end
+               if t then
+                       t = filter("bytes/out", tostring(t));
+                       if t then
+                               return w(conn, t);
+                       end
+               end
+       end
+
        function session.data(data)
                data = filter("bytes/in", data);
                if data then
@@ -560,6 +544,8 @@ local function initialize_session(session)
                return handlestanza(session, stanza);
        end
 
+       module:fire_event("s2s-created", { session = session });
+
        add_task(connect_timeout, function ()
                if session.type == "s2sin" or session.type == "s2sout" then
                        return; -- Ok, we're connected
@@ -580,22 +566,6 @@ function listener.onconnect(conn)
                session = s2s_new_incoming(conn);
                sessions[conn] = session;
                session.log("debug", "Incoming s2s connection");
-
-               local filter = initialize_filters(session);
-               local w = conn.write;
-               session.sends2s = function (t)
-                       log("debug", "sending: %s", t.top_tag and t:top_tag() or t:match("^([^>]*>?)"));
-                       if t.name then
-                               t = filter("stanzas/out", t);
-                       end
-                       if t then
-                               t = filter("bytes/out", tostring(t));
-                               if t then
-                                       return w(conn, t);
-                               end
-                       end
-               end
-
                initialize_session(session);
        else -- Outgoing session connected
                session:open_stream(session.from_host, session.to_host);
@@ -643,7 +613,6 @@ function listener.onreadtimeout(conn)
 end
 
 function listener.register_outgoing(conn, session)
-       session.direction = "outgoing";
        sessions[conn] = session;
        initialize_session(session);
 end