util.sql: Import pcall (fixes #677)
[prosody.git] / util / sql.lua
index 2d5e17743a6a90f448499804330741924b5caad5..843e7ddac74714911e048badf3479e1c7e3f331d 100644 (file)
@@ -1,8 +1,8 @@
 
 local setmetatable, getmetatable = setmetatable, getmetatable;
-local ipairs, unpack, select = ipairs, unpack, select;
+local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113
 local tonumber, tostring = tonumber, tostring;
-local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
+local assert, pcall, xpcall, debug_traceback = assert, pcall, xpcall, debug.traceback;
 local t_concat = table.concat;
 local s_char = string.char;
 local log = require "util.logger".init("sql");
@@ -25,8 +25,8 @@ local function is_column(x) return getmetatable(x)==column_mt; end
 local function is_index(x) return getmetatable(x)==index_mt; end
 local function is_table(x) return getmetatable(x)==table_mt; end
 local function is_query(x) return getmetatable(x)==query_mt; end
-local function Integer(n) return "Integer()" end
-local function String(n) return "String()" end
+local function Integer() return "Integer()" end
+local function String() return "String()" end
 
 local function Column(definition)
        return setmetatable(definition, column_mt);
@@ -102,11 +102,12 @@ function engine:connect()
        local params = self.params;
        assert(params.driver, "no driver")
        log("debug", "Connecting to [%s] %s...", params.driver, params.database);
-       local dbh, err = DBI.Connect(
+       local ok, dbh, err = pcall(DBI.Connect,
                params.driver, params.database,
                params.username, params.password,
                params.host, params.port
        );
+       if not ok then return ok, dbh; end
        if not dbh then return nil, err; end
        dbh:autocommit(false); -- don't commit automatically
        self.conn = dbh;
@@ -124,6 +125,14 @@ end
 function engine:onconnect()
        -- Override from create_engine()
 end
+
+function engine:prepquery(sql)
+       if self.params.driver == "PostgreSQL" then
+               sql = sql:gsub("`", "\"");
+       end
+       return sql;
+end
+
 function engine:execute(sql, ...)
        local success, err = self:connect();
        if not success then return success, err; end
@@ -147,18 +156,19 @@ local result_mt = { __index = {
        rowcount = function(self) return self.__stmt:rowcount(); end;
 } };
 
+local function debugquery(where, sql, ...)
+       local i = 0; local a = {...}
+       log("debug", "[%s] %s", where, sql:gsub("%?", function () i = i + 1; local v = a[i]; if type(v) == "string" then v = ("%q"):format(v); end return tostring(v); end));
+end
+
 function engine:execute_query(sql, ...)
-       if self.params.driver == "PostgreSQL" then
-               sql = sql:gsub("`", "\"");
-       end
+       sql = self:prepquery(sql);
        local stmt = assert(self.conn:prepare(sql));
        assert(stmt:execute(...));
        return stmt:rows();
 end
 function engine:execute_update(sql, ...)
-       if self.params.driver == "PostgreSQL" then
-               sql = sql:gsub("`", "\"");
-       end
+       sql = self:prepquery(sql);
        local prepared = self.prepared;
        local stmt = prepared[sql];
        if not stmt then
@@ -172,6 +182,30 @@ engine.insert = engine.execute_update;
 engine.select = engine.execute_query;
 engine.delete = engine.execute_update;
 engine.update = engine.execute_update;
+local function debugwrap(name, f)
+       return function (self, sql, ...)
+               debugquery(name, sql, ...)
+               return f(self, sql, ...)
+       end
+end
+function engine:debug(enable)
+       self._debug = enable;
+       if enable then
+               engine.insert = debugwrap("insert", engine.execute_update);
+               engine.select = debugwrap("select", engine.execute_query);
+               engine.delete = debugwrap("delete", engine.execute_update);
+               engine.update = debugwrap("update", engine.execute_update);
+       else
+               engine.insert = engine.execute_update;
+               engine.select = engine.execute_query;
+               engine.delete = engine.execute_update;
+               engine.update = engine.execute_update;
+       end
+end
+local function handleerr(err)
+       log("error", "Error in SQL transaction: %s", debug_traceback(err, 3));
+       return err;
+end
 function engine:_transaction(func, ...)
        if not self.conn then
                local ok, err = self:connect();
@@ -182,7 +216,7 @@ function engine:_transaction(func, ...)
        local function f() return func(unpack(args, 1, n_args)); end
        log("debug", "SQL transaction begin [%s]", tostring(func));
        self.__transaction = true;
-       local success, a, b, c = xpcall(f, debug_traceback);
+       local success, a, b, c = xpcall(f, handleerr);
        self.__transaction = nil;
        if success then
                log("debug", "SQL transaction success [%s]", tostring(func));
@@ -221,7 +255,9 @@ function engine:_create_index(index)
        if index.unique then
                sql = sql:gsub("^CREATE", "CREATE UNIQUE");
        end
-       --print(sql);
+       if self._debug then
+               debugquery("create", sql);
+       end
        return self:execute(sql);
 end
 function engine:_create_table(table)
@@ -252,6 +288,9 @@ function engine:_create_table(table)
        elseif self.params.driver == "MySQL" then
                sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset));
        end
+       if self._debug then
+               debugquery("create", sql);
+       end
        local success,err = self:execute(sql);
        if not success then return success,err; end
        for i,v in ipairs(table.__table__) do
@@ -265,19 +304,21 @@ function engine:set_encoding() -- to UTF-8
        local driver = self.params.driver;
        if driver == "SQLite3" then
                return self:transaction(function()
-                       if self:select"PRAGMA encoding;"()[1] == "UTF-8" then
-                               self.charset = "utf8";
+                       for encoding in self:select"PRAGMA encoding;" do
+                               if encoding[1] == "UTF-8" then
+                                       self.charset = "utf8";
+                               end
                        end
                end);
        end
        local set_names_query = "SET NAMES '%s';"
        local charset = "utf8";
        if driver == "MySQL" then
-               local ok, charsets = self:transaction(function()
-                       return self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;";
+               self:transaction(function()
+                       for row in self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;" do
+                               charset = row and row[1] or charset;
+                       end
                end);
-               local row = ok and charsets();
-               charset = row and row[1] or charset;
                set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin"));
        end
        self.charset = charset;
@@ -291,12 +332,16 @@ function engine:set_encoding() -- to UTF-8
                local ok, actual_charset = self:transaction(function ()
                        return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'";
                end);
+               local charset_ok = true;
                for row in actual_charset do
                        if row[2] ~= charset then
                                log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset);
-                               return false, "Failed to set connection encoding";
+                               charset_ok = false;
                        end
                end
+               if not charset_ok then
+                       return false, "Failed to set connection encoding";
+               end
        end
 
        return true;