Merge 0.9->0.10
[prosody.git] / net / server_event.lua
index d505825dda68636187239957b33d02f32c94e652..70a6dc37881af22fd246201c959ae486b2c84642 100644 (file)
@@ -11,6 +11,7 @@
                        -- when using luasec, there are 4 cases of timeout errors: wantread or wantwrite during reading or writing
 
 --]]
+-- luacheck: ignore 212/self 431/err 211/ret
 
 local SCRIPT_NAME           = "server_event.lua"
 local SCRIPT_VERSION        = "0.05"
@@ -32,27 +33,32 @@ local cfg = {
        DEBUG                 = true,  -- show debug messages
 }
 
-local function use(x) return rawget(_G, x); end
-local ipairs = use "ipairs"
-local string = use "string"
-local select = use "select"
-local require = use "require"
-local tostring = use "tostring"
-local coroutine = use "coroutine"
-local setmetatable = use "setmetatable"
+local pairs = pairs
+local select = select
+local require = require
+local tostring = tostring
+local setmetatable = setmetatable
 
 local t_insert = table.insert
 local t_concat = table.concat
+local s_sub = string.sub
 
-local ssl = use "ssl"
-local socket = use "socket" or require "socket"
+local coroutine_wrap = coroutine.wrap
+local coroutine_yield = coroutine.yield
+
+local has_luasec, ssl = pcall ( require , "ssl" )
+local socket = require "socket"
+local levent = require "luaevent.core"
+
+local socket_gettime = socket.gettime
+local getaddrinfo = socket.dns.getaddrinfo
 
 local log = require ("util.logger").init("socket")
 
 local function debug(...)
        return log("debug", ("%s "):rep(select('#', ...)), ...)
 end
-local vdebug = debug;
+-- local vdebug = debug;
 
 local bitor = ( function( ) -- thx Rici Lake
        local hasbit = function( x, p )
@@ -72,62 +78,25 @@ local bitor = ( function( ) -- thx Rici Lake
        end
 end )( )
 
-local event = require "luaevent.core"
-local base = event.new( )
-local EV_READ = event.EV_READ
-local EV_WRITE = event.EV_WRITE
-local EV_TIMEOUT = event.EV_TIMEOUT
-local EV_SIGNAL = event.EV_SIGNAL
+local base = levent.new( )
+local addevent = base.addevent
+local EV_READ = levent.EV_READ
+local EV_WRITE = levent.EV_WRITE
+local EV_TIMEOUT = levent.EV_TIMEOUT
+local EV_SIGNAL = levent.EV_SIGNAL
 
 local EV_READWRITE = bitor( EV_READ, EV_WRITE )
 
-local interfacelist = ( function( )  -- holds the interfaces for sockets
-       local array = { }
-       local len = 0
-       return function( method, arg )
-               if "add" == method then
-                       len = len + 1
-                       array[ len ] = arg
-                       arg:_position( len )
-                       return len
-               elseif "delete" == method then
-                       if len <= 0 then
-                               return nil, "array is already empty"
-                       end
-                       local position = arg:_position()  -- get position in array
-                       if position ~= len then
-                               local interface = array[ len ]  -- get last interface
-                               array[ position ] = interface  -- copy it into free position
-                               array[ len ] = nil  -- free last position
-                               interface:_position( position )  -- set new position in array
-                       else  -- free last position
-                               array[ len ] = nil
-                       end
-                       len = len - 1
-                       return len
-               else
-                       return array
-               end
-       end
-end )( )
+local interfacelist = { }
 
 -- Client interface methods
-local interface_mt
-do
-       interface_mt = {}; interface_mt.__index = interface_mt;
-       
-       local addevent = base.addevent
-       local coroutine_wrap, coroutine_yield = coroutine.wrap,coroutine.yield
-       
+local interface_mt = {}; interface_mt.__index = interface_mt;
+
        -- Private methods
-       function interface_mt:_position(new_position)
-                       self.position = new_position or self.position
-                       return self.position;
-       end
        function interface_mt:_close()
                return self:_destroy();
        end
-       
+
        function interface_mt:_start_connection(plainssl) -- should be called from addclient
                        local callback = function( event )
                                if EV_TIMEOUT == event then  -- timeout during connection
@@ -136,7 +105,7 @@ do
                                        self:_close()
                                        debug( "new connection failed. id:", self.id, "error:", self.fatalerror )
                                else
-                                       if plainssl and ssl then  -- start ssl session
+                       if plainssl and has_luasec then  -- start ssl session
                                                self:starttls(self._sslctx, true)
                                        else  -- normal connection
                                                self:_start_session(true)
@@ -188,8 +157,7 @@ do
                                return false
                        end
                        self.conn:settimeout( 0 )  -- set non blocking
-                       local handshakecallback = coroutine_wrap(
-                               function( event )
+       local handshakecallback = coroutine_wrap(function( event )
                                        local _, err
                                        local attempt = 0
                                        local maxattempt = cfg.MAX_HANDSHAKE_ATTEMPTS
@@ -265,15 +233,15 @@ do
                                self.eventread, self.eventclose = nil, nil
                                self.interface, self.readcallback = nil, nil
                        end
-                       interfacelist( "delete", self )
+       interfacelist[ self ] = nil
                        return true
        end
-       
+
        function interface_mt:_lock(nointerface, noreading, nowriting)  -- lock or unlock this interface or events
                        self.nointerface, self.noreading, self.nowriting = nointerface, noreading, nowriting
                        return nointerface, noreading, nowriting
        end
-       
+
        --TODO: Deprecate
        function interface_mt:lock_read(switch)
                if switch then
@@ -301,7 +269,7 @@ do
                end
                return self._connections
        end
-       
+
        -- Public methods
        function interface_mt:write(data)
                if self.nowriting then return nil, "locked" end
@@ -344,27 +312,27 @@ do
                        return true
                end
        end
-       
+
        function interface_mt:socket()
                return self.conn
        end
-       
+
        function interface_mt:server()
                return self._server or self;
        end
-       
+
        function interface_mt:port()
                return self._port
        end
-       
+
        function interface_mt:serverport()
                return self._serverport
        end
-       
+
        function interface_mt:ip()
                return self._ip
        end
-       
+
        function interface_mt:ssl()
                return self._usingssl
        end
@@ -373,15 +341,15 @@ do
        function interface_mt:type()
                return self._type or "client"
        end
-       
+
        function interface_mt:connections()
                return self._connections
        end
-       
+
        function interface_mt:address()
                return self.addr
        end
-       
+
        function interface_mt:set_sslctx(sslctx)
                self._sslctx = sslctx;
                if sslctx then
@@ -397,11 +365,11 @@ do
                end
                return self._pattern;
        end
-       
-       function interface_mt:set_send(new_send)
+
+function interface_mt:set_send(new_send) -- luacheck: ignore 212
                -- No-op, we always use the underlying connection's send
        end
-       
+
        function interface_mt:starttls(sslctx, call_onconnect)
                debug( "try to start ssl at client id:", self.id )
                local err
@@ -430,22 +398,22 @@ do
                self.starttls = false;
                return true
        end
-       
+
        function interface_mt:setoption(option, value)
                if self.conn.setoption then
                        return self.conn:setoption(option, value);
                end
                return false, "setoption not implemented";
        end
-       
+
        function interface_mt:setlistener(listener)
                self:ondetach(); -- Notify listener that it is no longer responsible for this connection
-               self.onconnect, self.ondisconnect, self.onincoming,
-               self.ontimeout, self.onstatus, self.ondetach
-                       = listener.onconnect, listener.ondisconnect, listener.onincoming,
-                       listener.ontimeout, listener.onstatus, listener.ondetach;
+       self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout,
+       self.onreadtimeout, self.onstatus, self.ondetach
+               = listener.onconnect, listener.ondisconnect, listener.onincoming, listener.ontimeout,
+                 listener.onreadtimeout, listener.onstatus, listener.ondetach;
        end
-       
+
        -- Stub handlers
        function interface_mt:onconnect()
        end
@@ -455,22 +423,22 @@ do
        end
        function interface_mt:ontimeout()
        end
+function interface_mt:onreadtimeout()
+       self.fatalerror = "timeout during receiving"
+       debug( "connection failed:", self.fatalerror )
+       self:_close()
+       self.eventread = nil
+end
        function interface_mt:ondrain()
        end
        function interface_mt:ondetach()
        end
        function interface_mt:onstatus()
        end
-end
 
 -- End of client interface methods
 
-local handleclient;
-do
-       local string_sub = string.sub  -- caching table lookups
-       local addevent = base.addevent
-       local socket_gettime = socket.gettime
-       function handleclient( client, ip, port, server, pattern, listener, sslctx )  -- creates an client interface
+local function handleclient( client, ip, port, server, pattern, listener, sslctx )  -- creates an client interface
                --vdebug("creating client interfacce...")
                local interface = {
                        type = "client";
@@ -484,6 +452,7 @@ do
                        ondisconnect = listener.ondisconnect;  -- will be called when client disconnects
                        onincoming = listener.onincoming;  -- will be called when client sends data
                        ontimeout = listener.ontimeout; -- called when fatal socket timeout occurs
+               onreadtimeout = listener.onreadtimeout; -- called when socket inactivity timeout occurs
                        ondrain = listener.ondrain; -- called when writebuffer is empty
                        ondetach = listener.ondetach; -- called when disassociating this listener from this connection
                        onstatus = listener.onstatus; -- called for status changes (e.g. of SSL/TLS)
@@ -499,14 +468,14 @@ do
                        noreading = false, nowriting = false;  -- locks of the read/writecallback
                        startsslcallback = false;  -- starting handshake callback
                        position = false;  -- position of client in interfacelist
-                       
+
                        -- Properties
                        _ip = ip, _port = port, _server = server, _pattern = pattern,
                        _serverport = (server and server:port() or nil),
                        _sslctx = sslctx; -- parameters
                        _usingssl = false;  -- client is using ssl;
                }
-               if not ssl then interface.starttls = false; end
+       if not has_luasec then interface.starttls = false; end
                interface.id = tostring(interface):match("%x+$");
                interface.writecallback = function( event )  -- called on write events
                        --vdebug( "new client write event, id/ip/port:", interface, ip, port )
@@ -552,7 +521,7 @@ do
                                        return -1
                                elseif byte and (err == "timeout" or err == "wantwrite") then  -- want write again
                                        --vdebug( "writebuffer is not empty:", err )
-                                       interface.writebuffer[1] = string_sub( interface.writebuffer[1], byte + 1, interface.writebufferlen )  -- new buffer
+                               interface.writebuffer[1] = s_sub( interface.writebuffer[1], byte + 1, interface.writebufferlen )  -- new buffer
                                        interface.writebufferlen = interface.writebufferlen - byte
                                        if "wantread" == err then  -- happens only with luasec
                                                local callback = function( )
@@ -575,7 +544,7 @@ do
                                end
                        end
                end
-               
+
                interface.readcallback = function( event )  -- called on read events
                        --vdebug( "new client read event, id/ip/port:", tostring(interface.id), tostring(ip), tostring(port) )
                        if interface.noreading or interface.fatalerror then  -- leave this event
@@ -583,13 +552,9 @@ do
                                interface.eventread = nil
                                return -1
                        end
-                       if EV_TIMEOUT == event then  -- took too long to get some data from client -> disconnect
-                               interface.fatalerror = "timeout during receiving"
-                               debug( "connection failed:", interface.fatalerror )
-                               interface:_close()
-                               interface.eventread = nil
-                               return -1
-                       else -- can read
+               if EV_TIMEOUT == event and interface:onreadtimeout() ~= true then
+                       return -1 -- took too long to get some data from client -> disconnect
+               end
                                if interface._usingssl then  -- handle luasec
                                        if interface.eventwritetimeout then  -- ok, in the past writecallback was regged
                                                local ret = interface.writecallback( )  -- call it
@@ -638,22 +603,19 @@ do
                                end
                                return EV_READ, cfg.READ_TIMEOUT
                        end
-               end
 
                client:settimeout( 0 )  -- set non blocking
                setmetatable(interface, interface_mt)
-               interfacelist( "add", interface )  -- add to interfacelist
+       interfacelist[ interface ] = true  -- add to interfacelist
                return interface
        end
-end
 
-local handleserver
-do
-       function handleserver( server, addr, port, pattern, listener, sslctx )  -- creates an server interface
+local function handleserver( server, addr, port, pattern, listener, sslctx )  -- creates an server interface
                debug "creating server interface..."
                local interface = {
                        _connections = 0;
-                       
+
+               type = "server";
                        conn = server;
                        onconnect = listener.onconnect;  -- will be called when new client connected
                        eventread = false;  -- read event handler
@@ -661,7 +623,7 @@ do
                        readcallback = false; -- read event callback
                        fatalerror = false; -- error message
                        nointerface = true;  -- lock/unlock parameter
-                       
+
                        _ip = addr, _port = port, _pattern = pattern,
                        _sslctx = sslctx;
                }
@@ -694,92 +656,77 @@ do
                                interface._connections = interface._connections + 1  -- increase connection count
                                local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx )
                                --vdebug( "client id:", clientinterface, "startssl:", startssl )
-                               if ssl and sslctx then
+                       if has_luasec and sslctx then
                                        clientinterface:starttls(sslctx, true)
                                else
                                        clientinterface:_start_session( true )
                                end
                                debug( "accepted incoming client connection from:", client_ip or "<unknown IP>", client_port or "<unknown port>", "to", port or "<unknown port>");
-                               
+
                                client, err = server:accept()    -- try to accept again
                        end
                        return EV_READ
                end
-               
+
                server:settimeout( 0 )
                setmetatable(interface, interface_mt)
-               interfacelist( "add", interface )
+       interfacelist[ interface ] = true
                interface:_start_session()
                return interface
        end
-end
 
-local addserver = ( function( )
-       return function( addr, port, listener, pattern, sslcfg, startssl )  -- TODO: check arguments
-               --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil")
+local function addserver( addr, port, listener, pattern, sslctx, startssl )  -- TODO: check arguments
+       --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil")
+       if sslctx and not has_luasec then
+               debug "fatal error: luasec not found"
+               return nil, "luasec not found"
+end
                local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE )  -- create server socket
                if not server then
                        debug( "creating server socket on "..addr.." port "..port.." failed:", err )
                        return nil, err
                end
-               local sslctx
-               if sslcfg then
-                       if not ssl then
-                               debug "fatal error: luasec not found"
-                               return nil, "luasec not found"
-                       end
-                       sslctx, err = sslcfg
-                       if err then
-                               debug( "error while creating new ssl context for server socket:", err )
-                               return nil, err
-                       end
-               end
                local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl )  -- new server handler
                debug( "new server created with id:", tostring(interface))
                return interface
        end
-end )( )
 
-local addclient, wrapclient
-do
-       function wrapclient( client, ip, port, listeners, pattern, sslctx )
+local function wrapclient( client, ip, port, listeners, pattern, sslctx )
                local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx )
                interface:_start_connection(sslctx)
                return interface, client
                --function handleclient( client, ip, port, server, pattern, listener, _, sslctx )  -- creates an client interface
        end
-       
-       function addclient( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl )
-               local client, err = socket.tcp()  -- creating new socket
-               if not client then
-                       debug( "cannot create socket:", err )
-                       return nil, err
+
+local function addclient( addr, serverport, listener, pattern, sslctx, typ )
+       if sslctx and not has_luasec then
+               debug "need luasec, but not available"
+               return nil, "luasec not found"
                end
-               client:settimeout( 0 )  -- set nonblocking
-               if localaddr then
-                       local res, err = client:bind( localaddr, localport, -1 )
-                       if not res then
-                               debug( "cannot bind client:", err )
-                               return nil, err
+       if not typ then
+               local addrinfo, err = getaddrinfo(addr)
+               if not addrinfo then return nil, err end
+               if addrinfo[1] and addrinfo[1].family == "inet6" then
+                       typ = "tcp6"
+               else
+                       typ = "tcp"
                        end
                end
-               local sslctx
-               if sslcfg then  -- handle ssl/new context
-                       if not ssl then
-                               debug "need luasec, but not available"
-                               return nil, "luasec not found"
+       local create = socket[typ]
+       if type( create ) ~= "function"  then
+               return nil, "invalid socket type"
                        end
-                       sslctx, err = sslcfg
-                       if err then
-                               debug( "cannot create new ssl context:", err )
+       local client, err = create()  -- creating new socket
+       if not client then
+               debug( "cannot create socket:", err )
                                return nil, err
                        end
-               end
+       client:settimeout( 0 )  -- set nonblocking
                local res, err = client:connect( addr, serverport )  -- connect
                if res or ( err == "timeout" ) then
                        local ip, port = client:getsockname( )
-                       local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, startssl )
-                       interface:_start_connection( startssl )
+               local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx )
+               interface:_start_connection( sslctx )
                        debug( "new connection id:", interface.id )
                        return interface, err
                else
@@ -787,23 +734,18 @@ do
                        return nil, err
                end
        end
-end
 
-
-local loop = function( )  -- starts the event loop
+local function loop( )  -- starts the event loop
        base:loop( )
        return "quitting";
 end
 
-local newevent = ( function( )
-       local add = base.addevent
-       return function( ... )
-               return add( base, ... )
+local function newevent( ... )
+       return addevent( base, ... )
        end
-end )( )
 
-local closeallservers = function( arg )
-       for _, item in ipairs( interfacelist( ) ) do
+local function closeallservers ( arg )
+       for item in pairs( interfacelist ) do
                if item.type == "server" then
                        item:close( arg )
                end
@@ -826,7 +768,7 @@ end
 -- being garbage-collected
 local signal_events = {}; -- [signal_num] -> event object
 local function hook_signal(signal_num, handler)
-       local function _handler(event)
+       local function _handler()
                local ret = handler();
                if ret ~= false then -- Continue handling this signal?
                        return EV_SIGNAL; -- Yes
@@ -839,14 +781,14 @@ end
 
 local function link(sender, receiver, buffersize)
        local sender_locked;
-       
+
        function receiver:ondrain()
                if sender_locked then
                        sender:resume();
                        sender_locked = nil;
                end
        end
-       
+
        function sender:onincoming(data)
                receiver:write(data);
                if receiver.writebufferlen >= buffersize then
@@ -858,12 +800,11 @@ local function link(sender, receiver, buffersize)
 end
 
 return {
-
        cfg = cfg,
        base = base,
        loop = loop,
        link = link,
-       event = event,
+       event = levent,
        event_base = base,
        addevent = newevent,
        addserver = addserver,