storagemanager: Fix saving data in map shim when no prior data exists
[prosody.git] / net / server_event.lua
index 5333099776674de86a591e58b582052d77dc6b03..70a6dc37881af22fd246201c959ae486b2c84642 100644 (file)
@@ -11,6 +11,7 @@
                        -- when using luasec, there are 4 cases of timeout errors: wantread or wantwrite during reading or writing
 
 --]]
+-- luacheck: ignore 212/self 431/err 211/ret
 
 local SCRIPT_NAME           = "server_event.lua"
 local SCRIPT_VERSION        = "0.05"
@@ -32,27 +33,32 @@ local cfg = {
        DEBUG                 = true,  -- show debug messages
 }
 
-local function use(x) return rawget(_G, x); end
-local ipairs = use "ipairs"
-local string = use "string"
-local select = use "select"
-local require = use "require"
-local tostring = use "tostring"
-local coroutine = use "coroutine"
-local setmetatable = use "setmetatable"
+local pairs = pairs
+local select = select
+local require = require
+local tostring = tostring
+local setmetatable = setmetatable
 
 local t_insert = table.insert
 local t_concat = table.concat
+local s_sub = string.sub
 
-local ssl = use "ssl"
-local socket = use "socket" or require "socket"
+local coroutine_wrap = coroutine.wrap
+local coroutine_yield = coroutine.yield
+
+local has_luasec, ssl = pcall ( require , "ssl" )
+local socket = require "socket"
+local levent = require "luaevent.core"
+
+local socket_gettime = socket.gettime
+local getaddrinfo = socket.dns.getaddrinfo
 
 local log = require ("util.logger").init("socket")
 
 local function debug(...)
        return log("debug", ("%s "):rep(select('#', ...)), ...)
 end
-local vdebug = debug;
+-- local vdebug = debug;
 
 local bitor = ( function( ) -- thx Rici Lake
        local hasbit = function( x, p )
@@ -72,58 +78,21 @@ local bitor = ( function( ) -- thx Rici Lake
        end
 end )( )
 
-local event = require "luaevent.core"
-local base = event.new( )
-local EV_READ = event.EV_READ
-local EV_WRITE = event.EV_WRITE
-local EV_TIMEOUT = event.EV_TIMEOUT
-local EV_SIGNAL = event.EV_SIGNAL
+local base = levent.new( )
+local addevent = base.addevent
+local EV_READ = levent.EV_READ
+local EV_WRITE = levent.EV_WRITE
+local EV_TIMEOUT = levent.EV_TIMEOUT
+local EV_SIGNAL = levent.EV_SIGNAL
 
 local EV_READWRITE = bitor( EV_READ, EV_WRITE )
 
-local interfacelist = ( function( )  -- holds the interfaces for sockets
-       local array = { }
-       local len = 0
-       return function( method, arg )
-               if "add" == method then
-                       len = len + 1
-                       array[ len ] = arg
-                       arg:_position( len )
-                       return len
-               elseif "delete" == method then
-                       if len <= 0 then
-                               return nil, "array is already empty"
-                       end
-                       local position = arg:_position()  -- get position in array
-                       if position ~= len then
-                               local interface = array[ len ]  -- get last interface
-                               array[ position ] = interface  -- copy it into free position
-                               array[ len ] = nil  -- free last position
-                               interface:_position( position )  -- set new position in array
-                       else  -- free last position
-                               array[ len ] = nil
-                       end
-                       len = len - 1
-                       return len
-               else
-                       return array
-               end
-       end
-end )( )
+local interfacelist = { }
 
 -- Client interface methods
-local interface_mt
-do
-       interface_mt = {}; interface_mt.__index = interface_mt;
-
-       local addevent = base.addevent
-       local coroutine_wrap, coroutine_yield = coroutine.wrap,coroutine.yield
+local interface_mt = {}; interface_mt.__index = interface_mt;
 
        -- Private methods
-       function interface_mt:_position(new_position)
-                       self.position = new_position or self.position
-                       return self.position;
-       end
        function interface_mt:_close()
                return self:_destroy();
        end
@@ -136,7 +105,7 @@ do
                                        self:_close()
                                        debug( "new connection failed. id:", self.id, "error:", self.fatalerror )
                                else
-                                       if plainssl and ssl then  -- start ssl session
+                       if plainssl and has_luasec then  -- start ssl session
                                                self:starttls(self._sslctx, true)
                                        else  -- normal connection
                                                self:_start_session(true)
@@ -188,8 +157,7 @@ do
                                return false
                        end
                        self.conn:settimeout( 0 )  -- set non blocking
-                       local handshakecallback = coroutine_wrap(
-                               function( event )
+       local handshakecallback = coroutine_wrap(function( event )
                                        local _, err
                                        local attempt = 0
                                        local maxattempt = cfg.MAX_HANDSHAKE_ATTEMPTS
@@ -265,7 +233,7 @@ do
                                self.eventread, self.eventclose = nil, nil
                                self.interface, self.readcallback = nil, nil
                        end
-                       interfacelist( "delete", self )
+       interfacelist[ self ] = nil
                        return true
        end
 
@@ -289,8 +257,9 @@ do
 
        function interface_mt:resume()
                self:_lock(self.nointerface, false, self.nowriting);
-               if not self.eventread then
+               if self.readcallback and not self.eventread then
                        self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT );  -- register callback
+                       return true;
                end
        end
 
@@ -397,7 +366,7 @@ do
                return self._pattern;
        end
 
-       function interface_mt:set_send(new_send)
+function interface_mt:set_send(new_send) -- luacheck: ignore 212
                -- No-op, we always use the underlying connection's send
        end
 
@@ -438,9 +407,11 @@ do
        end
 
        function interface_mt:setlistener(listener)
-               self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout, self.onreadtimeout, self.onstatus
-                       = listener.onconnect, listener.ondisconnect, listener.onincoming,
-                         listener.ontimeout, listener.onreadtimeout, listener.onstatus;
+               self:ondetach(); -- Notify listener that it is no longer responsible for this connection
+       self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout,
+       self.onreadtimeout, self.onstatus, self.ondetach
+               = listener.onconnect, listener.ondisconnect, listener.onincoming, listener.ontimeout,
+                 listener.onreadtimeout, listener.onstatus, listener.ondetach;
        end
 
        -- Stub handlers
@@ -452,26 +423,22 @@ do
        end
        function interface_mt:ontimeout()
        end
-       function interface_mt:onreadtimeout()
-               self.fatalerror = "timeout during receiving"
-               debug( "connection failed:", self.fatalerror )
-               self:_close()
-               self.eventread = nil
-       end
+function interface_mt:onreadtimeout()
+       self.fatalerror = "timeout during receiving"
+       debug( "connection failed:", self.fatalerror )
+       self:_close()
+       self.eventread = nil
+end
        function interface_mt:ondrain()
        end
+       function interface_mt:ondetach()
+       end
        function interface_mt:onstatus()
        end
-end
 
 -- End of client interface methods
 
-local handleclient;
-do
-       local string_sub = string.sub  -- caching table lookups
-       local addevent = base.addevent
-       local socket_gettime = socket.gettime
-       function handleclient( client, ip, port, server, pattern, listener, sslctx )  -- creates an client interface
+local function handleclient( client, ip, port, server, pattern, listener, sslctx )  -- creates an client interface
                --vdebug("creating client interfacce...")
                local interface = {
                        type = "client";
@@ -485,8 +452,9 @@ do
                        ondisconnect = listener.ondisconnect;  -- will be called when client disconnects
                        onincoming = listener.onincoming;  -- will be called when client sends data
                        ontimeout = listener.ontimeout; -- called when fatal socket timeout occurs
-                       onreadtimeout = listener.onreadtimeout; -- called when socket inactivity timeout occurs
+               onreadtimeout = listener.onreadtimeout; -- called when socket inactivity timeout occurs
                        ondrain = listener.ondrain; -- called when writebuffer is empty
+                       ondetach = listener.ondetach; -- called when disassociating this listener from this connection
                        onstatus = listener.onstatus; -- called for status changes (e.g. of SSL/TLS)
                        eventread = false, eventwrite = false, eventclose = false,
                        eventhandshake = false, eventstarthandshake = false;  -- event handler
@@ -507,7 +475,7 @@ do
                        _sslctx = sslctx; -- parameters
                        _usingssl = false;  -- client is using ssl;
                }
-               if not ssl then interface.starttls = false; end
+       if not has_luasec then interface.starttls = false; end
                interface.id = tostring(interface):match("%x+$");
                interface.writecallback = function( event )  -- called on write events
                        --vdebug( "new client write event, id/ip/port:", interface, ip, port )
@@ -553,7 +521,7 @@ do
                                        return -1
                                elseif byte and (err == "timeout" or err == "wantwrite") then  -- want write again
                                        --vdebug( "writebuffer is not empty:", err )
-                                       interface.writebuffer[1] = string_sub( interface.writebuffer[1], byte + 1, interface.writebufferlen )  -- new buffer
+                               interface.writebuffer[1] = s_sub( interface.writebuffer[1], byte + 1, interface.writebufferlen )  -- new buffer
                                        interface.writebufferlen = interface.writebufferlen - byte
                                        if "wantread" == err then  -- happens only with luasec
                                                local callback = function( )
@@ -584,72 +552,70 @@ do
                                interface.eventread = nil
                                return -1
                        end
-                       if EV_TIMEOUT == event and interface:onreadtimeout() ~= true then
-                               return -1 -- took too long to get some data from client -> disconnect
-                       end
-                       if interface._usingssl then  -- handle luasec
-                               if interface.eventwritetimeout then  -- ok, in the past writecallback was regged
-                                       local ret = interface.writecallback( )  -- call it
-                                       --vdebug( "tried to write in readcallback, result:", tostring(ret) )
-                               end
-                               if interface.eventreadtimeout then
-                                       interface.eventreadtimeout:close( )
-                                       interface.eventreadtimeout = nil
-                               end
-                       end
-                       local buffer, err, part = interface.conn:receive( interface._pattern )  -- receive buffer with "pattern"
-                       --vdebug( "read data:", tostring(buffer), "error:", tostring(err), "part:", tostring(part) )
-                       buffer = buffer or part
-                       if buffer and #buffer > cfg.MAX_READ_LENGTH then  -- check buffer length
-                               interface.fatalerror = "receive buffer exceeded"
-                               debug( "fatal error:", interface.fatalerror )
-                               interface:_close()
-                               interface.eventread = nil
-                               return -1
-                       end
-                       if err and ( err ~= "timeout" and err ~= "wantread" ) then
-                               if "wantwrite" == err then -- need to read on write event
-                                       if not interface.eventwrite then  -- register new write event if needed
-                                               interface.eventwrite = addevent( base, interface.conn, EV_WRITE, interface.writecallback, cfg.WRITE_TIMEOUT )
+               if EV_TIMEOUT == event and interface:onreadtimeout() ~= true then
+                       return -1 -- took too long to get some data from client -> disconnect
+               end
+                               if interface._usingssl then  -- handle luasec
+                                       if interface.eventwritetimeout then  -- ok, in the past writecallback was regged
+                                               local ret = interface.writecallback( )  -- call it
+                                               --vdebug( "tried to write in readcallback, result:", tostring(ret) )
                                        end
-                                       interface.eventreadtimeout = addevent( base, nil, EV_TIMEOUT,
-                                               function( )
-                                                       interface:_close()
-                                               end, cfg.READ_TIMEOUT
-                                       )
-                                       debug( "wantwrite during read attempt, reg it in writecallback but dont know what really happens next..." )
-                                       -- to be honest i dont know what happens next, if it is allowed to first read, the write etc...
-                               else  -- connection was closed or fatal error
-                                       interface.fatalerror = err
-                                       debug( "connection failed in read event:", interface.fatalerror )
+                                       if interface.eventreadtimeout then
+                                               interface.eventreadtimeout:close( )
+                                               interface.eventreadtimeout = nil
+                                       end
+                               end
+                               local buffer, err, part = interface.conn:receive( interface._pattern )  -- receive buffer with "pattern"
+                               --vdebug( "read data:", tostring(buffer), "error:", tostring(err), "part:", tostring(part) )
+                               buffer = buffer or part
+                               if buffer and #buffer > cfg.MAX_READ_LENGTH then  -- check buffer length
+                                       interface.fatalerror = "receive buffer exceeded"
+                                       debug( "fatal error:", interface.fatalerror )
                                        interface:_close()
                                        interface.eventread = nil
                                        return -1
                                end
-                       else
-                               interface.onincoming( interface, buffer, err )  -- send new data to listener
-                       end
-                       if interface.noreading then
-                               interface.eventread = nil;
-                               return -1;
+                               if err and ( err ~= "timeout" and err ~= "wantread" ) then
+                                       if "wantwrite" == err then -- need to read on write event
+                                               if not interface.eventwrite then  -- register new write event if needed
+                                                       interface.eventwrite = addevent( base, interface.conn, EV_WRITE, interface.writecallback, cfg.WRITE_TIMEOUT )
+                                               end
+                                               interface.eventreadtimeout = addevent( base, nil, EV_TIMEOUT,
+                                                       function( )
+                                                               interface:_close()
+                                                       end, cfg.READ_TIMEOUT
+                                               )
+                                               debug( "wantwrite during read attempt, reg it in writecallback but dont know what really happens next..." )
+                                               -- to be honest i dont know what happens next, if it is allowed to first read, the write etc...
+                                       else  -- connection was closed or fatal error
+                                               interface.fatalerror = err
+                                               debug( "connection failed in read event:", interface.fatalerror )
+                                               interface:_close()
+                                               interface.eventread = nil
+                                               return -1
+                                       end
+                               else
+                                       interface.onincoming( interface, buffer, err )  -- send new data to listener
+                               end
+                               if interface.noreading then
+                                       interface.eventread = nil;
+                                       return -1;
+                               end
+                               return EV_READ, cfg.READ_TIMEOUT
                        end
-                       return EV_READ, cfg.READ_TIMEOUT
-               end
 
                client:settimeout( 0 )  -- set non blocking
                setmetatable(interface, interface_mt)
-               interfacelist( "add", interface )  -- add to interfacelist
+       interfacelist[ interface ] = true  -- add to interfacelist
                return interface
        end
-end
 
-local handleserver
-do
-       function handleserver( server, addr, port, pattern, listener, sslctx )  -- creates an server interface
+local function handleserver( server, addr, port, pattern, listener, sslctx )  -- creates an server interface
                debug "creating server interface..."
                local interface = {
                        _connections = 0;
 
+               type = "server";
                        conn = server;
                        onconnect = listener.onconnect;  -- will be called when new client connected
                        eventread = false;  -- read event handler
@@ -690,7 +656,7 @@ do
                                interface._connections = interface._connections + 1  -- increase connection count
                                local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx )
                                --vdebug( "client id:", clientinterface, "startssl:", startssl )
-                               if ssl and sslctx then
+                       if has_luasec and sslctx then
                                        clientinterface:starttls(sslctx, true)
                                else
                                        clientinterface:_start_session( true )
@@ -704,78 +670,63 @@ do
 
                server:settimeout( 0 )
                setmetatable(interface, interface_mt)
-               interfacelist( "add", interface )
+       interfacelist[ interface ] = true
                interface:_start_session()
                return interface
        end
-end
 
-local addserver = ( function( )
-       return function( addr, port, listener, pattern, sslcfg, startssl )  -- TODO: check arguments
-               --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil")
+local function addserver( addr, port, listener, pattern, sslctx, startssl )  -- TODO: check arguments
+       --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil")
+       if sslctx and not has_luasec then
+               debug "fatal error: luasec not found"
+               return nil, "luasec not found"
+end
                local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE )  -- create server socket
                if not server then
                        debug( "creating server socket on "..addr.." port "..port.." failed:", err )
                        return nil, err
                end
-               local sslctx
-               if sslcfg then
-                       if not ssl then
-                               debug "fatal error: luasec not found"
-                               return nil, "luasec not found"
-                       end
-                       sslctx, err = sslcfg
-                       if err then
-                               debug( "error while creating new ssl context for server socket:", err )
-                               return nil, err
-                       end
-               end
                local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl )  -- new server handler
                debug( "new server created with id:", tostring(interface))
                return interface
        end
-end )( )
 
-local addclient, wrapclient
-do
-       function wrapclient( client, ip, port, listeners, pattern, sslctx )
+local function wrapclient( client, ip, port, listeners, pattern, sslctx )
                local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx )
                interface:_start_connection(sslctx)
                return interface, client
                --function handleclient( client, ip, port, server, pattern, listener, _, sslctx )  -- creates an client interface
        end
 
-       function addclient( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl )
-               local client, err = socket.tcp()  -- creating new socket
-               if not client then
-                       debug( "cannot create socket:", err )
-                       return nil, err
+local function addclient( addr, serverport, listener, pattern, sslctx, typ )
+       if sslctx and not has_luasec then
+               debug "need luasec, but not available"
+               return nil, "luasec not found"
                end
-               client:settimeout( 0 )  -- set nonblocking
-               if localaddr then
-                       local res, err = client:bind( localaddr, localport, -1 )
-                       if not res then
-                               debug( "cannot bind client:", err )
-                               return nil, err
+       if not typ then
+               local addrinfo, err = getaddrinfo(addr)
+               if not addrinfo then return nil, err end
+               if addrinfo[1] and addrinfo[1].family == "inet6" then
+                       typ = "tcp6"
+               else
+                       typ = "tcp"
                        end
                end
-               local sslctx
-               if sslcfg then  -- handle ssl/new context
-                       if not ssl then
-                               debug "need luasec, but not available"
-                               return nil, "luasec not found"
+       local create = socket[typ]
+       if type( create ) ~= "function"  then
+               return nil, "invalid socket type"
                        end
-                       sslctx, err = sslcfg
-                       if err then
-                               debug( "cannot create new ssl context:", err )
+       local client, err = create()  -- creating new socket
+       if not client then
+               debug( "cannot create socket:", err )
                                return nil, err
                        end
-               end
+       client:settimeout( 0 )  -- set nonblocking
                local res, err = client:connect( addr, serverport )  -- connect
                if res or ( err == "timeout" ) then
                        local ip, port = client:getsockname( )
-                       local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, startssl )
-                       interface:_start_connection( startssl )
+               local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx )
+               interface:_start_connection( sslctx )
                        debug( "new connection id:", interface.id )
                        return interface, err
                else
@@ -783,23 +734,18 @@ do
                        return nil, err
                end
        end
-end
 
-
-local loop = function( )  -- starts the event loop
+local function loop( )  -- starts the event loop
        base:loop( )
        return "quitting";
 end
 
-local newevent = ( function( )
-       local add = base.addevent
-       return function( ... )
-               return add( base, ... )
+local function newevent( ... )
+       return addevent( base, ... )
        end
-end )( )
 
-local closeallservers = function( arg )
-       for _, item in ipairs( interfacelist( ) ) do
+local function closeallservers ( arg )
+       for item in pairs( interfacelist ) do
                if item.type == "server" then
                        item:close( arg )
                end
@@ -822,7 +768,7 @@ end
 -- being garbage-collected
 local signal_events = {}; -- [signal_num] -> event object
 local function hook_signal(signal_num, handler)
-       local function _handler(event)
+       local function _handler()
                local ret = handler();
                if ret ~= false then -- Continue handling this signal?
                        return EV_SIGNAL; -- Yes
@@ -850,15 +796,15 @@ local function link(sender, receiver, buffersize)
                        sender:pause();
                end
        end
+       sender:set_mode("*a");
 end
 
 return {
-
        cfg = cfg,
        base = base,
        loop = loop,
        link = link,
-       event = event,
+       event = levent,
        event_base = base,
        addevent = newevent,
        addserver = addserver,