mod_s2s: Remove compat with namespace issue from Prosody pre-0.6.2
[prosody.git] / plugins / mod_s2s / mod_s2s.lua
index 5531ca3e1b4b90f4c71e70ffe16588b3314048e4..f05e2a95dd09af9906e5913e2959a7bdadc79c82 100644 (file)
@@ -15,7 +15,6 @@ local core_process_stanza = prosody.core_process_stanza;
 local tostring, type = tostring, type;
 local t_insert = table.insert;
 local xpcall, traceback = xpcall, debug.traceback;
-local NULL = {};
 
 local add_task = require "util.timer".add_task;
 local st = require "util.stanza";
@@ -26,7 +25,6 @@ local s2s_new_incoming = require "core.s2smanager".new_incoming;
 local s2s_new_outgoing = require "core.s2smanager".new_outgoing;
 local s2s_destroy_session = require "core.s2smanager".destroy_session;
 local uuid_gen = require "util.uuid".generate;
-local cert_verify_identity = require "util.x509".verify_identity;
 local fire_global_event = prosody.events.fire_event;
 
 local s2sout = module:require("s2sout");
@@ -39,6 +37,8 @@ local secure_domains, insecure_domains =
        module:get_option_set("s2s_secure_domains", {})._items, module:get_option_set("s2s_insecure_domains", {})._items;
 local require_encryption = module:get_option_boolean("s2s_require_encryption", false);
 
+local measure_connections = module:measure("connections", "counter");
+
 local sessions = module:shared("sessions");
 
 local log = module._log;
@@ -49,7 +49,7 @@ local bouncy_stanzas = { message = true, presence = true, iq = true };
 local function bounce_sendq(session, reason)
        local sendq = session.sendq;
        if not sendq then return; end
-       session.log("info", "sending error replies for "..#sendq.." queued stanzas because of failed outgoing connection to "..tostring(session.to_host));
+       session.log("info", "Sending error replies for "..#sendq.." queued stanzas because of failed outgoing connection to "..tostring(session.to_host));
        local dummy = {
                type = "s2sin";
                send = function(s)
@@ -150,11 +150,23 @@ function module.add_host(module)
        module:hook("route/remote", route_to_new_session, -10);
        module:hook("s2s-authenticated", make_authenticated, -1);
        module:hook("s2s-read-timeout", keepalive, -1);
+       module:hook_stanza("http://etherx.jabber.org/streams", "features", function (session, stanza)
+               if session.type == "s2sout" then
+                       -- Stream is authenticated and we are seem to be done with feature negotiation,
+                       -- so the stream is ready for stanzas.  RFC 6120 Section 4.3
+                       mark_connected(session);
+                       return true;
+               elseif not session.dialback_verifying then
+                       session.log("warn", "No SASL EXTERNAL offer and Dialback doesn't seem to be enabled, giving up");
+                       session:close();
+                       return false;
+               end
+       end, -1);
 end
 
 -- Stream is authorised, and ready for normal stanzas
 function mark_connected(session)
-       local sendq, send = session.sendq, session.sends2s;
+       local sendq = session.sendq;
 
        local from, to = session.from_host, session.to_host;
 
@@ -177,6 +189,7 @@ function mark_connected(session)
        if session.direction == "outgoing" then
                if sendq then
                        session.log("debug", "sending %d queued stanzas across new outgoing connection to %s", #sendq, session.to_host);
+                       local send = session.sends2s;
                        for i, data in ipairs(sendq) do
                                send(data[1]);
                                sendq[i] = nil;
@@ -219,13 +232,16 @@ function make_authenticated(event)
        end
        session.log("debug", "connection %s->%s is now authenticated for %s", session.from_host, session.to_host, host);
 
-       mark_connected(session);
+       if (session.type == "s2sout" and session.external_auth ~= "succeeded") or session.type == "s2sin" then
+               -- Stream either used dialback for authentication or is an incoming stream.
+               mark_connected(session);
+       end
 
        return true;
 end
 
 --- Helper to check that a session peer's certificate is valid
-local function check_cert_status(session)
+function check_cert_status(session)
        local host = session.direction == "outgoing" and session.to_host or session.from_host
        local conn = session.conn:socket()
        local cert
@@ -233,39 +249,6 @@ local function check_cert_status(session)
                cert = conn:getpeercertificate()
        end
 
-       if cert then
-               local chain_valid, errors;
-               if conn.getpeerverification then
-                       chain_valid, errors = conn:getpeerverification();
-               elseif conn.getpeerchainvalid then -- COMPAT mw/luasec-hg
-                       chain_valid, errors = conn:getpeerchainvalid();
-                       errors = (not chain_valid) and { { errors } } or nil;
-               else
-                       chain_valid, errors = false, { { "Chain verification not supported by this version of LuaSec" } };
-               end
-               -- Is there any interest in printing out all/the number of errors here?
-               if not chain_valid then
-                       (session.log or log)("debug", "certificate chain validation result: invalid");
-                       for depth, t in pairs(errors or NULL) do
-                               (session.log or log)("debug", "certificate error(s) at depth %d: %s", depth-1, table.concat(t, ", "))
-                       end
-                       session.cert_chain_status = "invalid";
-               else
-                       (session.log or log)("debug", "certificate chain validation result: valid");
-                       session.cert_chain_status = "valid";
-
-                       -- We'll go ahead and verify the asserted identity if the
-                       -- connecting server specified one.
-                       if host then
-                               if cert_verify_identity(host, "xmpp-server", cert) then
-                                       session.cert_identity_status = "valid"
-                               else
-                                       session.cert_identity_status = "invalid"
-                               end
-                               (session.log or log)("debug", "certificate identity validation result: %s", session.cert_identity_status);
-                       end
-               end
-       end
        return module:fire_event("s2s-check-certificate", { host = host, session = session, cert = cert });
 end
 
@@ -276,8 +259,6 @@ local stream_callbacks = { default_ns = "jabber:server", handlestanza =  core_pr
 local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams";
 
 function stream_callbacks.streamopened(session, attr)
-       local send = session.sends2s;
-
        session.version = tonumber(attr.version) or 0;
 
        -- TODO: Rename session.secure to session.encrypted
@@ -360,6 +341,7 @@ function stream_callbacks.streamopened(session, attr)
                end
 
                session:open_stream(session.to_host, session.from_host)
+               session.notopen = nil;
                if session.version >= 1.0 then
                        local features = st.stanza("stream:features");
 
@@ -367,14 +349,24 @@ function stream_callbacks.streamopened(session, attr)
                                hosts[to].events.fire_event("s2s-stream-features", { origin = session, features = features });
                        else
                                (session.log or log)("warn", "No 'to' on stream header from %s means we can't offer any features", from or session.ip or "unknown host");
+                               fire_global_event("s2s-stream-features-legacy", { origin = session, features = features });
                        end
 
-                       log("debug", "Sending stream features: %s", tostring(features));
-                       send(features);
+                       if ( session.type == "s2sin" or session.type == "s2sout" ) or features.tags[1] then
+                               log("debug", "Sending stream features: %s", tostring(features));
+                               session.sends2s(features);
+                       else
+                               (session.log or log)("warn", "No features to offer, giving up");
+                               session:close({ condition = "undefined-condition", text = "No features to offer" });
+                       end
                end
        elseif session.direction == "outgoing" then
-               -- If we are just using the connection for verifying dialback keys, we won't try and auth it
-               if not attr.id then error("stream response did not give us a streamid!!!"); end
+               session.notopen = nil;
+               if not attr.id then
+                       log("error", "Stream response from %s did not give us a stream id!", session.to_host);
+                       session:close({ condition = "undefined-condition", text = "Missing stream ID" });
+                       return;
+               end
                session.streamid = attr.id;
 
                if session.secure and not session.cert_chain_status then
@@ -406,7 +398,6 @@ function stream_callbacks.streamopened(session, attr)
                        end
                end
        end
-       session.notopen = nil;
 end
 
 function stream_callbacks.streamclosed(session)
@@ -416,6 +407,7 @@ end
 
 function stream_callbacks.error(session, error, data)
        if error == "no-stream" then
+               session.log("debug", "Invalid opening stream header (%s)", (data:gsub("^([^\1]+)\1", "{%1}")));
                session:close("invalid-namespace");
        elseif error == "parse-error" then
                session.log("debug", "Server-to-server XML parse error: %s", tostring(error));
@@ -442,9 +434,6 @@ end
 
 local function handleerr(err) log("error", "Traceback[s2s]: %s", traceback(tostring(err), 2)); end
 function stream_callbacks.handlestanza(session, stanza)
-       if stanza.attr.xmlns == "jabber:client" then --COMPAT: Prosody pre-0.6.2 may send jabber:client
-               stanza.attr.xmlns = nil;
-       end
        stanza = session.filter("stanzas/in", stanza);
        if stanza then
                return xpcall(function () return core_process_stanza(session, stanza) end, handleerr);
@@ -510,46 +499,58 @@ local function session_close(session, reason, remote_reason)
        end
 end
 
-function session_open_stream(session, from, to)
-       local attr = {
-               ["xmlns:stream"] = 'http://etherx.jabber.org/streams',
-               xmlns = 'jabber:server',
-               version = session.version and (session.version > 0 and "1.0" or nil),
-               ["xml:lang"] = 'en',
-               id = session.streamid,
-               from = from, to = to,
-       }
+function session_stream_attrs(session, from, to, attr)
        if not from or (hosts[from] and hosts[from].modules.dialback) then
                attr["xmlns:db"] = 'jabber:server:dialback';
        end
-
-       session.sends2s("<?xml version='1.0'?>");
-       session.sends2s(st.stanza("stream:stream", attr):top_tag());
-       return true;
+       if not from then
+               attr.from = '';
+       end
+       if not to then
+               attr.to = '';
+       end
 end
 
 -- Session initialization logic shared by incoming and outgoing
 local function initialize_session(session)
        local stream = new_xmpp_stream(session, stream_callbacks);
+       local log = session.log or log;
        session.stream = stream;
 
        session.notopen = true;
 
        function session.reset_stream()
                session.notopen = true;
+               session.streamid = nil;
                session.stream:reset();
        end
 
-       session.open_stream = session_open_stream;
+       session.stream_attrs = session_stream_attrs;
+
+       local filter = initialize_filters(session);
+       local conn = session.conn;
+       local w = conn.write;
+
+       function session.sends2s(t)
+               log("debug", "sending: %s", t.top_tag and t:top_tag() or t:match("^[^>]*>?"));
+               if t.name then
+                       t = filter("stanzas/out", t);
+               end
+               if t then
+                       t = filter("bytes/out", tostring(t));
+                       if t then
+                               return w(conn, t);
+                       end
+               end
+       end
 
-       local filter = session.filter;
        function session.data(data)
                data = filter("bytes/in", data);
                if data then
                        local ok, err = stream:feed(data);
                        if ok then return; end
-                       (session.log or log)("warn", "Received invalid XML: %s", data);
-                       (session.log or log)("warn", "Problem was: %s", err);
+                       log("warn", "Received invalid XML: %s", data);
+                       log("warn", "Problem was: %s", err);
                        session:close("not-well-formed");
                end
        end
@@ -561,6 +562,8 @@ local function initialize_session(session)
                return handlestanza(session, stanza);
        end
 
+       module:fire_event("s2s-created", { session = session });
+
        add_task(connect_timeout, function ()
                if session.type == "s2sin" or session.type == "s2sout" then
                        return; -- Ok, we're connected
@@ -575,28 +578,13 @@ local function initialize_session(session)
 end
 
 function listener.onconnect(conn)
+       measure_connections(1);
        conn:setoption("keepalive", opt_keepalives);
        local session = sessions[conn];
        if not session then -- New incoming connection
                session = s2s_new_incoming(conn);
                sessions[conn] = session;
                session.log("debug", "Incoming s2s connection");
-
-               local filter = initialize_filters(session);
-               local w = conn.write;
-               session.sends2s = function (t)
-                       log("debug", "sending: %s", t.top_tag and t:top_tag() or t:match("^([^>]*>?)"));
-                       if t.name then
-                               t = filter("stanzas/out", t);
-                       end
-                       if t then
-                               t = filter("bytes/out", tostring(t));
-                               if t then
-                                       return w(conn, t);
-                               end
-                       end
-               end
-
                initialize_session(session);
        else -- Outgoing session connected
                session:open_stream(session.from_host, session.to_host);
@@ -621,7 +609,13 @@ function listener.onstatus(conn, status)
        end
 end
 
+function listener.ontimeout(conn)
+       -- Called instead of onconnect when the connection times out
+       measure_connections(1);
+end
+
 function listener.ondisconnect(conn, err)
+       measure_connections(-1);
        local session = sessions[conn];
        if session then
                sessions[conn] = nil;
@@ -638,17 +632,21 @@ end
 
 function listener.onreadtimeout(conn)
        local session = sessions[conn];
+       local host = session.host or session.to_host;
        if session then
-               return (hosts[session.host] or prosody).events.fire_event("s2s-read-timeout", { session = session });
+               return (hosts[host] or prosody).events.fire_event("s2s-read-timeout", { session = session });
        end
 end
 
 function listener.register_outgoing(conn, session)
-       session.direction = "outgoing";
        sessions[conn] = session;
        initialize_session(session);
 end
 
+function listener.ondetach(conn)
+       sessions[conn] = nil;
+end
+
 function check_auth_policy(event)
        local host, session = event.host, event.session;
        local must_secure = secure_auth;
@@ -679,7 +677,7 @@ module:hook("server-stopping", function(event)
        for _, session in pairs(sessions) do
                session:close{ condition = "system-shutdown", text = reason };
        end
-end,500);
+end, -200);