mod_tls: Mark session as not secure before negotiating TLS
[prosody.git] / net / server.lua
index 12502412f29a31ce7df2fe493d89de9ab2b50901..54eadbc617b9f0c06861b2d69dca4681ddbdc696 100644 (file)
@@ -1,19 +1,9 @@
--- Prosody IM
--- Copyright (C) 2008-2009 Matthew Wild
--- Copyright (C) 2008-2009 Waqas Hussain
--- 
--- This project is MIT/X11 licensed. Please see the
--- COPYING file in the source package for more information.
---
-
---[[\r
-\r
-        server.lua by blastbeat\r
-\r
-        - this script contains the server loop of the program\r
-        - other scripts can reg a server here\r
-\r
-]]--\r
+-- \r
+-- server.lua by blastbeat of the luadch project\r
+-- Re-used here under the MIT/X Consortium License\r
+-- \r
+-- Modifications (C) 2008-2009 Matthew Wild, Waqas Hussain\r
+--\r
 \r
 -- // wrapping luadch stuff // --\r
 \r
@@ -167,6 +157,7 @@ _cleanqueue = false    -- clean bufferqueue after using
 \r
 _maxclientsperserver = 1000\r
 \r
+_maxsslhandshake = 30 -- max handshake round-trips\r
 ----------------------------------// PRIVATE //--\r
 \r
 wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections, startssl )    -- this function wraps a server\r
@@ -182,22 +173,51 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
     local ssl = false\r
 \r
     if sslctx then\r
+        ssl = true\r
         if not ssl_newcontext then\r
-            return nil, "luasec not found"\r
+            out_error "luasec not found"\r
+            ssl = false\r
         end\r
         if type( sslctx ) ~= "table" then\r
             out_error "server.lua: wrong server sslctx"\r
-            return nil, "wrong server sslctx"\r
+            ssl = false\r
         end\r
-        sslctx, err = ssl_newcontext( sslctx )\r
-        if not sslctx then\r
+        local ctx;\r
+        ctx, err = ssl_newcontext( sslctx )\r
+        if not ctx then\r
             err = err or "wrong sslctx parameters"\r
-            out_error( "server.lua: ", err )\r
-            return nil, err\r
+            local file;\r
+            file = err:match("^error loading (.-) %(");\r
+            if file then\r
+               if file == "private key" then\r
+                       file = sslctx.key or "your private key";\r
+               elseif file == "certificate" then\r
+                       file = sslctx.certificate or "your certificate file";\r
+               end\r
+               local reason = err:match("%((.+)%)$") or "some reason";\r
+               if reason == "Permission denied" then\r
+                       reason = "Check that the permissions allow Prosody to read this file.";\r
+               elseif reason == "No such file or directory" then\r
+                       reason = "Check that the path is correct, and the file exists.";\r
+               elseif reason == "system lib" then\r
+                       reason = "Previous error (see logs), or other system error.";\r
+               else\r
+                       reason = "Reason: "..tostring(reason or "unknown"):lower();\r
+               end\r
+               log("error", "SSL/TLS: Failed to load %s: %s", file, reason);\r
+           else\r
+                log("error", "SSL/TLS: Error initialising for port %d: %s", serverport, err );\r
+            end\r
+            ssl = false\r
         end\r
-        ssl = true\r
-    else\r
-       out_put("server.lua: ", "ssl not enabled on ", serverport);\r
+        sslctx = ctx;\r
+    end\r
+    if not ssl then\r
+      sslctx = false;\r
+      if startssl then\r
+         log("error", "Failed to listen on port %d due to SSL/TLS to SSL/TLS initialisation errors (see logs)", serverport )\r
+         return nil, "Cannot start ssl,  see log for details"\r
+       end\r
     end\r
 \r
     local accept = socket.accept\r
@@ -211,6 +231,9 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
     handler.ssl = function( )\r
         return ssl\r
     end\r
+    handler.sslctx = function( )\r
+        return sslctx\r
+    end\r
     handler.remove = function( )\r
         connections = connections - 1\r
     end\r
@@ -227,7 +250,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
         _socketlist[ socket ] = nil\r
         handler = nil\r
         socket = nil\r
-        mem_free( )\r
+        --mem_free( )\r
         out_put "server.lua: closed server handler and removed sockets from list"\r
     end\r
     handler.ip = function( )\r
@@ -278,6 +301,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
     local ssl\r
 \r
     local dispatch = listeners.incoming or listeners.listener\r
+    local status = listeners.status\r
     local disconnect = listeners.disconnect\r
 \r
     local bufferqueue = { }    -- buffer array\r
@@ -317,6 +341,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
     handler.ssl = function( )\r
         return ssl\r
     end\r
+    handler.sslctx = function ( )\r
+        return sslctx\r
+    end\r
     handler.send = function( _, data, i, j )\r
         return send( socket, data, i, j )\r
     end\r
@@ -344,6 +371,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                 send( socket, table_concat( bufferqueue, "", 1, bufferqueuelen ), 1, bufferlen )    -- forced send\r
             end\r
         end\r
+        if not handler then return true; end\r
         _ = shutdown and shutdown( socket )\r
         socket:close( )\r
         _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )\r
@@ -354,7 +382,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
             handler = nil\r
         end\r
         socket = nil\r
-        mem_free( )\r
+        --mem_free( )\r
        if server then\r
                server.remove( )\r
        end\r
@@ -453,10 +481,10 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
             readtraffic = readtraffic + count\r
             _readtraffic = _readtraffic + count\r
             _readtimes[ handler ] = _currenttime\r
-            --out_put( "server.lua: read data '", buffer, "', error: ", err )\r
+            --out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err )\r
             return dispatch( handler, buffer, err )\r
         else    -- connections was closed or fatal error\r
-            out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " error: ", tostring(err) )\r
+            out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )\r
             fatalerror = true\r
             disconnect( handler, err )\r
            _ = handler and handler.close( )\r
@@ -464,13 +492,19 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
         end\r
     end\r
     local _sendbuffer = function( )    -- this function sends data\r
-        local buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )\r
-        local succ, err, byte = send( socket, buffer, 1, bufferlen )\r
-        local count = ( succ or byte or 0 ) * STAT_UNIT\r
-        sendtraffic = sendtraffic + count\r
-        _sendtraffic = _sendtraffic + count\r
-        _ = _cleanqueue and clean( bufferqueue )\r
-        --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )\r
+       local succ, err, byte, buffer, count;\r
+       local count;\r
+       if socket then\r
+            buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )\r
+            succ, err, byte = send( socket, buffer, 1, bufferlen )\r
+            count = ( succ or byte or 0 ) * STAT_UNIT\r
+            sendtraffic = sendtraffic + count\r
+            _sendtraffic = _sendtraffic + count\r
+            _ = _cleanqueue and clean( bufferqueue )\r
+            --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )\r
+        else\r
+            succ, err, count = false, "closed", 0;\r
+        end\r
         if succ then    -- sending succesful\r
             bufferqueuelen = 0\r
             bufferlen = 0\r
@@ -487,7 +521,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
             _writetimes[ handler ] = _currenttime\r
             return true\r
         else    -- connection was closed during sending or fatal error\r
-            out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " error: ", tostring(err) )\r
+            out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )\r
             fatalerror = true\r
             disconnect( handler, err )\r
             _ = handler and handler.close( )\r
@@ -501,7 +535,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
         local read\r
         local handshake = coroutine_wrap( function( client )    -- create handshake coroutine\r
                 local err\r
-                for i = 1, 10 do    -- 10 handshake attemps\r
+                for i = 1, _maxsslhandshake do\r
                     _sendlistlen = ( wrote and removesocket( _sendlist, socket, _sendlistlen ) ) or _sendlistlen\r
                     _readlistlen = ( read and removesocket( _readlist, socket, _readlistlen ) ) or _readlistlen\r
                     read, wrote = nil, nil\r
@@ -510,7 +544,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
                         out_put( "server.lua: ssl handshake done" )\r
                         handler.readbuffer = _readbuffer    -- when handshake is done, replace the handshake function with regular functions\r
                         handler.sendbuffer = _sendbuffer\r
-                        -- return dispatch( handler )\r
+                        _ = status and status( handler, "ssl-handshake-complete" )\r
                         return true\r
                     else\r
                         out_put( "server.lua: error during ssl handshake: ", tostring(err) )\r
@@ -540,13 +574,14 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
             socket, err = ssl_wrap( socket, sslctx )    -- wrap socket\r
             if err then\r
                 out_put( "server.lua: ssl error: ", tostring(err) )\r
-                mem_free( )\r
+                --mem_free( )\r
                 return nil, nil, err    -- fatal error\r
             end\r
             socket:settimeout( 0 )\r
             handler.readbuffer = handshake\r
             handler.sendbuffer = handshake\r
-            if not socket then   -- do handshake\r
+            handshake( socket ) -- do handshake\r
+            if not socket then\r
                 return nil, nil, "ssl handshake failed";\r
             end\r
         else\r
@@ -644,7 +679,7 @@ closesocket = function( socket )
     _readlistlen = removesocket( _readlist, socket, _readlistlen )\r
     _socketlist[ socket ] = nil\r
     socket:close( )\r
-    mem_free( )\r
+    --mem_free( )\r
 end\r
 \r
 ----------------------------------// PUBLIC //--\r
@@ -693,9 +728,10 @@ end
 removeserver = function( port )\r
     local handler = _server[ port ]\r
     if not handler then\r
-        return nil, "no server found on port '" .. tostring( port ) "'"\r
+        return nil, "no server found on port '" .. tostring( port ) .. "'"\r
     end\r
     handler.close( )\r
+    _server[ port ] = nil\r
     return true\r
 end\r
 \r
@@ -712,11 +748,11 @@ closeall = function( )
     _sendlist = { }\r
     _timerlist = { }\r
     _socketlist = { }\r
-    mem_free( )\r
+    --mem_free( )\r
 end\r
 \r
 getsettings = function( )\r
-    return  _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver\r
+    return  _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver, _maxsslhandshake\r
 end\r
 \r
 changesettings = function( new )\r
@@ -732,6 +768,7 @@ changesettings = function( new )
     _readtimeout = tonumber( new.readtimeout ) or _readtimeout\r
     _cleanqueue = new.cleanqueue\r
     _maxclientsperserver = new._maxclientsperserver or _maxclientsperserver\r
+    _maxsslhandshake = new._maxsslhandshake or _maxsslhandshake\r
     return true\r
 end\r
 \r
@@ -784,7 +821,7 @@ loop = function( )    -- this is the main loop of the program
         _currenttime = os_time( )\r
         if os_difftime( _currenttime - _timer ) >= 1 then\r
             for i = 1, _timerlistlen do\r
-                _timerlist[ i ]( )    -- fire timers\r
+                _timerlist[ i ]( _currenttime )    -- fire timers\r
             end\r
             _timer = _currenttime\r
         end\r