core.s2smanager: Don't recurse CNAMEs infinitely :)
[prosody.git] / net / server.lua
index a5c8e24c32f40f0fe246631c32020fad3be6f4fc..971ea5530695f191cfeff18dcc1c9b3c93eae3fe 100644 (file)
@@ -1,11 +1,9 @@
---[[\r
-\r
-        server.lua by blastbeat\r
-\r
-        - this script contains the server loop of the program\r
-        - other scripts can reg a server here\r
-\r
-]]--\r
+-- \r
+-- server.lua by blastbeat of the luadch project\r
+-- Re-used here under the MIT/X Consortium License\r
+-- \r
+-- Modifications (C) 2008-2009 Matthew Wild, Waqas Hussain\r
+--\r
 \r
 -- // wrapping luadch stuff // --\r
 \r
@@ -77,6 +75,7 @@ local idfalse
 local addtimer\r
 local closeall\r
 local addserver\r
+local getserver\r
 local wrapserver\r
 local getsettings\r
 local closesocket\r
@@ -173,22 +172,51 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
     local ssl = false\r
 \r
     if sslctx then\r
+        ssl = true\r
         if not ssl_newcontext then\r
-            return nil, "luasec not found"\r
+            out_error "luasec not found"\r
+            ssl = false\r
         end\r
         if type( sslctx ) ~= "table" then\r
             out_error "server.lua: wrong server sslctx"\r
-            return nil, "wrong server sslctx"\r
+            ssl = false\r
         end\r
-        sslctx, err = ssl_newcontext( sslctx )\r
-        if not sslctx then\r
+        local ctx;\r
+        ctx, err = ssl_newcontext( sslctx )\r
+        if not ctx then\r
             err = err or "wrong sslctx parameters"\r
-            out_error( "server.lua: ", err )\r
-            return nil, err\r
+            local file;\r
+            file = err:match("^error loading (.-) %(");\r
+            if file then\r
+               if file == "private key" then\r
+                       file = sslctx.key or "your private key";\r
+               elseif file == "certificate" then\r
+                       file = sslctx.certificate or "your certificate file";\r
+               end\r
+               local reason = err:match("%((.+)%)$") or "some reason";\r
+               if reason == "Permission denied" then\r
+                       reason = "Check that the permissions allow Prosody to read this file.";\r
+               elseif reason == "No such file or directory" then\r
+                       reason = "Check that the path is correct, and the file exists.";\r
+               elseif reason == "system lib" then\r
+                       reason = "Previous error (see logs), or other system error.";\r
+               else\r
+                       reason = "Reason: "..tostring(reason or "unknown"):lower();\r
+               end\r
+               log("error", "SSL/TLS: Failed to load %s: %s", file, reason);\r
+           else\r
+                log("error", "SSL/TLS: Error initialising for port %d: %s", serverport, err );\r
+            end\r
+            ssl = false\r
         end\r
-        ssl = true\r
-    else\r
-       out_put("server.lua: ", "ssl not enabled on ", serverport);\r
+        sslctx = ctx;\r
+    end\r
+    if not ssl then\r
+      sslctx = false;\r
+      if startssl then\r
+         log("error", "Failed to listen on port %d due to SSL/TLS to SSL/TLS initialisation errors (see logs)", serverport )\r
+         return nil, "Cannot start ssl,  see log for details"\r
+       end\r
     end\r
 \r
     local accept = socket.accept\r
@@ -244,10 +272,10 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
                 return false\r
             end\r
             connections = connections + 1\r
-            out_put( "server.lua: accepted new client connection from ", ip, ":", clientport, " to ",  serverport)\r
+            out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))\r
             return dispatch( handler )\r
         elseif err then    -- maybe timeout or something else\r
-            out_put( "server.lua: error with new client connection: ", err )\r
+            out_put( "server.lua: error with new client connection: ", tostring(err) )\r
             return false\r
         end\r
     end\r
@@ -318,13 +346,16 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
         return shutdown( socket, pattern )\r
     end\r
     handler.close = function( forced )\r
+        if not handler then return true; end\r
         _readlistlen = removesocket( _readlist, socket, _readlistlen )\r
         _readtimes[ handler ] = nil\r
         if bufferqueuelen ~= 0 then\r
             if not ( forced or fatalerror ) then\r
                 handler.sendbuffer( )\r
                 if bufferqueuelen ~= 0 then   -- try again...\r
-                    handler.write = nil    -- ... but no further writing allowed\r
+                    if handler then\r
+                        handler.write = nil    -- ... but no further writing allowed\r
+                    end\r
                     toclose = true\r
                     return false\r
                 end\r
@@ -332,13 +363,16 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                 send( socket, table_concat( bufferqueue, "", 1, bufferqueuelen ), 1, bufferlen )    -- forced send\r
             end\r
         end\r
-        shutdown( socket )\r
+        if not handler then return true; end\r
+        _ = shutdown and shutdown( socket )\r
         socket:close( )\r
         _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )\r
         _socketlist[ socket ] = nil\r
-        _writetimes[ handler ] = nil\r
-        _closelist[ handler ] = nil\r
-        handler = nil\r
+        if handler then\r
+            _writetimes[ handler ] = nil\r
+            _closelist[ handler ] = nil\r
+            handler = nil\r
+        end\r
         socket = nil\r
         mem_free( )\r
        if server then\r
@@ -369,7 +403,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
         end\r
         bufferqueuelen = bufferqueuelen + 1\r
         bufferqueue[ bufferqueuelen ] = data\r
-        _writetimes[ handler ] = _writetimes[ handler ] or _currenttime\r
+        if handler then\r
+               _writetimes[ handler ] = _writetimes[ handler ] or _currenttime\r
+        end\r
         return true\r
     end\r
     handler.write = write\r
@@ -383,6 +419,10 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
         pattern = new or pattern\r
         return pattern\r
     end\r
+    handler.setsend = function ( newsend )\r
+        send = newsend or send\r
+        return send\r
+    end\r
     handler.bufferlen = function( readlen, sendlen )\r
         maxsendlen = sendlen or maxsendlen\r
         maxreadlen = readlen or maxreadlen\r
@@ -436,7 +476,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
             --out_put( "server.lua: read data '", buffer, "', error: ", err )\r
             return dispatch( handler, buffer, err )\r
         else    -- connections was closed or fatal error\r
-            out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
+            out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " error: ", tostring(err) )\r
             fatalerror = true\r
             disconnect( handler, err )\r
            _ = handler and handler.close( )\r
@@ -444,13 +484,19 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
         end\r
     end\r
     local _sendbuffer = function( )    -- this function sends data\r
-        local buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )\r
-        local succ, err, byte = send( socket, buffer, 1, bufferlen )\r
-        local count = ( succ or byte or 0 ) * STAT_UNIT\r
-        sendtraffic = sendtraffic + count\r
-        _sendtraffic = _sendtraffic + count\r
-        _ = _cleanqueue and clean( bufferqueue )\r
-        --out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport )\r
+       local succ, err, byte, buffer, count;\r
+       local count;\r
+       if socket then\r
+            buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )\r
+            succ, err, byte = send( socket, buffer, 1, bufferlen )\r
+            count = ( succ or byte or 0 ) * STAT_UNIT\r
+            sendtraffic = sendtraffic + count\r
+            _sendtraffic = _sendtraffic + count\r
+            _ = _cleanqueue and clean( bufferqueue )\r
+            --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )\r
+        else\r
+            succ, err, count = false, "closed", 0;\r
+        end\r
         if succ then    -- sending succesful\r
             bufferqueuelen = 0\r
             bufferlen = 0\r
@@ -467,10 +513,10 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
             _writetimes[ handler ] = _currenttime\r
             return true\r
         else    -- connection was closed during sending or fatal error\r
-            out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )\r
+            out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " error: ", tostring(err) )\r
             fatalerror = true\r
             disconnect( handler, err )\r
-            handler.close( )\r
+            _ = handler and handler.close( )\r
             return false\r
         end\r
     end\r
@@ -478,30 +524,39 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
     if sslctx then    -- ssl?\r
         ssl = true\r
         local wrote\r
+        local read\r
         local handshake = coroutine_wrap( function( client )    -- create handshake coroutine\r
                 local err\r
                 for i = 1, 10 do    -- 10 handshake attemps\r
+                    _sendlistlen = ( wrote and removesocket( _sendlist, socket, _sendlistlen ) ) or _sendlistlen\r
+                    _readlistlen = ( read and removesocket( _readlist, socket, _readlistlen ) ) or _readlistlen\r
+                    read, wrote = nil, nil\r
                     _, err = client:dohandshake( )\r
                     if not err then\r
-                        --out_put( "server.lua: ssl handshake done" )\r
-                        _sendlistlen = ( wrote and removesocket( _sendlist, socket, _sendlistlen ) ) or _sendlistlen\r
+                        out_put( "server.lua: ssl handshake done" )\r
                         handler.readbuffer = _readbuffer    -- when handshake is done, replace the handshake function with regular functions\r
                         handler.sendbuffer = _sendbuffer\r
-                        --return dispatch( handler )\r
+                        -- return dispatch( handler )\r
                         return true\r
                     else\r
-                        out_put( "server.lua: error during ssl handshake: ", err )\r
+                        out_put( "server.lua: error during ssl handshake: ", tostring(err) )\r
                         if err == "wantwrite" and not wrote then\r
                             _sendlistlen = _sendlistlen + 1\r
                             _sendlist[ _sendlistlen ] = client\r
                             wrote = true\r
+                        elseif err == "wantread" and not read then\r
+                                _readlistlen = _readlistlen + 1\r
+                                _readlist [ _readlistlen ] = client\r
+                                read = true\r
+                        else\r
+                               break;\r
                         end\r
                         --coroutine_yield( handler, nil, err )    -- handshake not finished\r
                         coroutine_yield( )\r
                     end\r
                 end\r
-                disconnect( handler, "max handshake attemps exceeded" )\r
-                handler.close( true )    -- forced disconnect\r
+                disconnect( handler, "ssl handshake failed" )\r
+                _ = handler and handler.close( true )    -- forced disconnect\r
                 return false    -- handshake failed\r
             end\r
         )\r
@@ -510,15 +565,20 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
            local err\r
             socket, err = ssl_wrap( socket, sslctx )    -- wrap socket\r
             if err then\r
-                out_put( "server.lua: ssl error: ", err )\r
+                out_put( "server.lua: ssl error: ", tostring(err) )\r
                 mem_free( )\r
                 return nil, nil, err    -- fatal error\r
             end\r
             socket:settimeout( 0 )\r
             handler.readbuffer = handshake\r
             handler.sendbuffer = handshake\r
-            handshake( socket )    -- do handshake\r
+            handshake( socket ) -- do handshake\r
+            if not socket then\r
+                return nil, nil, "ssl handshake failed";\r
+            end\r
         else\r
+            -- We're not automatically doing SSL, so we're not secure (yet)\r
+            ssl = false\r
             handler.starttls = function( now )\r
                 if not now then\r
                     --out_put "server.lua: we need to do tls, but delaying until later"\r
@@ -530,7 +590,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                 socket, err = ssl_wrap( socket, sslctx )    -- wrap socket\r
                 --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) )\r
                 if err then\r
-                    out_put( "server.lua: error while starting tls on client: ", err )\r
+                    out_put( "server.lua: error while starting tls on client: ", tostring(err) )\r
                     return nil, err    -- fatal error\r
                 end\r
 \r
@@ -555,9 +615,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
 \r
                 handler.starttls = nil\r
                 needtls = nil\r
+                \r
+                -- Secure now\r
+                ssl = true\r
 \r
-                handler.receivedata = handler.handshake\r
-                handler.dispatchdata = handler.handshake\r
+                handler.readbuffer = handshake\r
+                handler.sendbuffer = handshake\r
                 handshake( socket )    -- do handshake\r
             end\r
             handler.readbuffer = _readbuffer\r
@@ -627,13 +690,13 @@ addserver = function( listeners, port, addr, pattern, sslctx, maxconnections, st
         err = "luasec not found"\r
     end\r
     if err then\r
-        out_error( "server.lua: ", err )\r
+        out_error( "server.lua, port ", port, ": ", err )\r
         return nil, err\r
     end\r
     addr = addr or "*"\r
     local server, err = socket_bind( addr, port )\r
     if err then\r
-        out_error( "server.lua: ", err )\r
+        out_error( "server.lua, port ", port, ": ", err )\r
         return nil, err\r
     end\r
     local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, maxconnections, startssl )    -- wrap new server socket\r
@@ -650,12 +713,17 @@ addserver = function( listeners, port, addr, pattern, sslctx, maxconnections, st
     return handler\r
 end\r
 \r
+getserver = function ( port )\r
+       return _server[ port ];\r
+end\r
+\r
 removeserver = function( port )\r
     local handler = _server[ port ]\r
     if not handler then\r
         return nil, "no server found on port '" .. tostring( port ) "'"\r
     end\r
     handler.close( )\r
+    _server[ port ] = nil\r
     return true\r
 end\r
 \r
@@ -708,8 +776,15 @@ stats = function( )
     return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen\r
 end\r
 \r
+local dontstop = true; -- thinking about tomorrow, ...\r
+\r
+setquitting = function (quit)\r
+       dontstop = not quit;\r
+       return;\r
+end\r
+\r
 loop = function( )    -- this is the main loop of the program\r
-    while true do\r
+    while dontstop do\r
         local read, write, err = socket_select( _readlist, _sendlist, _selecttimeout )\r
         for i, socket in ipairs( write ) do    -- send data waiting in writequeues\r
             local handler = _socketlist[ socket ]\r
@@ -744,6 +819,7 @@ loop = function( )    -- this is the main loop of the program
         socket_sleep( _sleeptime )    -- wait some time\r
         --collectgarbage( )\r
     end\r
+    return "quitting"\r
 end\r
 \r
 --// EXPERIMENTAL //--\r
@@ -758,7 +834,7 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx,
 end\r
 \r
 local addclient = function( address, port, listeners, pattern, sslctx, startssl )\r
-    local client, err = socket.tcp( )\r
+    local client, err = luasocket.tcp( )\r
     if err then\r
         return nil, err\r
     end\r
@@ -767,7 +843,7 @@ local addclient = function( address, port, listeners, pattern, sslctx, startssl
     if err then    -- try again\r
         local handler = wrapclient( client, address, port, listeners )\r
     else\r
-        wrapconnection( server, listeners, socket, address, port, "clientport", pattern, sslctx, startssl )\r
+        wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx, startssl )\r
     end\r
 end\r
 \r
@@ -816,8 +892,9 @@ return {
     closeall = closeall,\r
     addtimer = addtimer,\r
     addserver = addserver,\r
+    getserver = getserver,\r
     getsettings = getsettings,\r
+    setquitting = setquitting,\r
     removeserver = removeserver,\r
     changesettings = changesettings,\r
-\r
 }\r