Fix for never checking if the first module for a host is already loaded (affects...
[prosody.git] / core / s2smanager.lua
index df877767cc77c8fa95c45ab77f4b6ddb051b26b6..4f2054189fd35d0cf018955c8accad38b4a6941e 100644 (file)
@@ -1,26 +1,16 @@
--- Prosody IM v0.1
--- Copyright (C) 2008 Matthew Wild
--- Copyright (C) 2008 Waqas Hussain
+-- Prosody IM v0.3
+-- Copyright (C) 2008-2009 Matthew Wild
+-- Copyright (C) 2008-2009 Waqas Hussain
 -- 
--- This program is free software; you can redistribute it and/or
--- modify it under the terms of the GNU General Public License
--- as published by the Free Software Foundation; either version 2
--- of the License, or (at your option) any later version.
--- 
--- This program is distributed in the hope that it will be useful,
--- but WITHOUT ANY WARRANTY; without even the implied warranty of
--- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
--- GNU General Public License for more details.
--- 
--- You should have received a copy of the GNU General Public License
--- along with this program; if not, write to the Free Software
--- Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
 --
 
 
 
 local hosts = hosts;
 local sessions = sessions;
+local core_process_stanza = function(a, b) core_process_stanza(a, b); end
 local socket = require "socket";
 local format = string.format;
 local t_insert, t_sort = table.insert, table.sort;
@@ -30,7 +20,7 @@ local tostring, pairs, ipairs, getmetatable, print, newproxy, error, tonumber
 
 local idna_to_ascii = require "util.encodings".idna.to_ascii;
 local connlisteners_get = require "net.connlisteners".get;
-local wraptlsclient = require "net.server".wraptlsclient;
+local wrapclient = require "net.server".wrapclient;
 local modulemanager = require "core.modulemanager";
 local st = require "stanza";
 local stanza = st.stanza;
@@ -43,29 +33,58 @@ local log = logger_init("s2smanager");
 
 local sha256_hash = require "util.hashes".sha256;
 
-local dialback_secret = "This is very secret!!! Ha!";
+local dialback_secret = sha256_hash(tostring{} .. math.random() .. socket.gettime(), true);
 
 local dns = require "net.dns";
 
+incoming_s2s = {};
+local incoming_s2s = incoming_s2s;
+
 module "s2smanager"
 
 local function compare_srv_priorities(a,b) return a.priority < b.priority or a.weight < b.weight; end
 
+local function bounce_sendq(session)
+       local sendq = session.sendq;
+       if sendq then
+               session.log("debug", "sending error replies for "..#sendq.." queued stanzas because of failed outgoing connection to "..tostring(session.to_host));
+               local dummy = {
+                       type = "s2sin";
+                       send = function(s)
+                               (session.log or log)("error", "Replying to to an s2s error reply, please report this! Traceback: %s", get_traceback());
+                       end;
+                       dummy = true;
+               };
+               for i, data in ipairs(sendq) do
+                       local reply = data[2];
+                       local xmlns = reply.attr.xmlns;
+                       if not xmlns or xmlns == "jabber:client" or xmlns == "jabber:server" then
+                               reply.attr.type = "error";
+                               reply:tag("error", {type = "cancel"})
+                                       :tag("remote-server-not-found", {xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas"}):up();
+                               core_process_stanza(dummy, reply);
+                       end
+                       sendq[i] = nil;
+               end
+               session.sendq = nil;
+       end
+end
+
 function send_to_host(from_host, to_host, data)
-       if data.name then data = tostring(data); end
        local host = hosts[from_host].s2sout[to_host];
        if host then
                -- We have a connection to this host already
-               if host.type == "s2sout_unauthed" and ((not data.xmlns) or data.xmlns == "jabber:client" or data.xmlns == "jabber:server") then
-                       (host.log or log)("debug", "trying to send over unauthed s2sout to "..to_host..", authing it now...");
+               if host.type == "s2sout_unauthed" and data.name ~= "db:verify" and ((not data.xmlns) or data.xmlns == "jabber:client" or data.xmlns == "jabber:server") then
+                       (host.log or log)("debug", "trying to send over unauthed s2sout to "..to_host);
                        if not host.notopen and not host.dialback_key then
                                host.log("debug", "dialback had not been initiated");
                                initiate_dialback(host);
                        end
                        
                        -- Queue stanza until we are able to send it
-                       if host.sendq then t_insert(host.sendq, data);
-                       else host.sendq = { data }; end
+                       if host.sendq then t_insert(host.sendq, {tostring(data), st.reply(data)});
+                       else host.sendq = { {tostring(data), st.reply(data)} }; end
+                       host.log("debug", "stanza [%s] queued ", data.name);
                elseif host.type == "local" or host.type == "component" then
                        log("error", "Trying to send a stanza to ourselves??")
                        log("error", "Traceback: %s", get_traceback());
@@ -84,21 +103,23 @@ function send_to_host(from_host, to_host, data)
                log("debug", "opening a new outgoing connection for this stanza");
                local host_session = new_outgoing(from_host, to_host);
                -- Store in buffer
-               host_session.sendq = { data };
+               host_session.sendq = { {tostring(data), st.reply(data)} };
+               if not host_session.conn then destroy_session(host_session); end
        end
 end
 
 local open_sessions = 0;
 
 function new_incoming(conn)
-       local session = { conn = conn, type = "s2sin_unauthed", direction = "incoming" };
+       local session = { conn = conn, type = "s2sin_unauthed", direction = "incoming", hosts = {} };
        if true then
                session.trace = newproxy(true);
-               getmetatable(session.trace).__gc = function () open_sessions = open_sessions - 1; print("s2s session got collected, now "..open_sessions.." s2s sessions are allocated") end;
+               getmetatable(session.trace).__gc = function () open_sessions = open_sessions - 1; end;
        end
        open_sessions = open_sessions + 1;
        local w, log = conn.write, logger_init("s2sin"..tostring(conn):match("[a-f0-9]+$"));
        session.sends2s = function (t) log("debug", "sending: %s", tostring(t)); w(tostring(t)); end
+       incoming_s2s[session] = true;
        return session;
 end
 
@@ -160,10 +181,11 @@ function attempt_connection(host_session, err)
        local success, err = conn:connect(connect_host, connect_port);
        if not success and err ~= "timeout" then
                log("warn", "s2s connect() failed: %s", err);
+               return false;
        end
        
        local cl = connlisteners_get("xmppserver");
-       conn = wraptlsclient(cl, conn, connect_host, connect_port, 0, cl.default_mode or 1, hosts[from_host].ssl_ctx );
+       conn = wrapclient(conn, connect_host, connect_port, cl, cl.default_mode or 1, hosts[from_host].ssl_ctx, false );
        host_session.conn = conn;
        
        -- Register this outgoing connection so that xmppserver_listener knows about it
@@ -239,11 +261,16 @@ function verify_dialback(id, to, from, key)
        return key == generate_dialback(id, to, from);
 end
 
-function make_authenticated(session)
+function make_authenticated(session, host)
        if session.type == "s2sout_unauthed" then
                session.type = "s2sout";
        elseif session.type == "s2sin_unauthed" then
                session.type = "s2sin";
+               if host then
+                       session.hosts[host].authed = true;
+               end
+       elseif session.type == "s2sin" and host then
+               session.hosts[host].authed = true;
        else
                return false;
        end
@@ -269,7 +296,7 @@ function mark_connected(session)
                if sendq then
                        session.log("debug", "sending "..#sendq.." queued stanzas across new outgoing connection to "..session.to_host);
                        for i, data in ipairs(sendq) do
-                               send(data);
+                               send(data[1]);
                                sendq[i] = nil;
                        end
                        session.sendq = nil;
@@ -280,10 +307,12 @@ end
 function destroy_session(session)
        (session.log or log)("info", "Destroying "..tostring(session.direction).." session "..tostring(session.from_host).."->"..tostring(session.to_host));
        
-       -- FIXME: Flush sendq here/report errors to originators
        
        if session.direction == "outgoing" then
                hosts[session.from_host].s2sout[session.to_host] = nil;
+               bounce_sendq(session);
+       elseif session.direction == "incoming" then
+               incoming_s2s[session] = nil;
        end
        
        for k in pairs(session) do