Merge 0.10->trunk
[prosody.git] / util / sql.lua
index 037dbc76b94ed90fd63fd3d9b4a0b4c5575776ee..9981ac3cbfd464f2ff0e193dc5c6cef82f6d2556 100644 (file)
@@ -1,6 +1,6 @@
 
 local setmetatable, getmetatable = setmetatable, getmetatable;
-local ipairs, unpack, select = ipairs, unpack, select;
+local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113
 local tonumber, tostring = tonumber, tostring;
 local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
 local t_concat = table.concat;
@@ -13,7 +13,7 @@ local DBI = require "DBI";
 DBI.Drivers();
 local build_url = require "socket.url".build;
 
-module("sql")
+local _ENV = nil;
 
 local column_mt = {};
 local table_mt = {};
@@ -21,17 +21,17 @@ local query_mt = {};
 --local op_mt = {};
 local index_mt = {};
 
-function is_column(x) return getmetatable(x)==column_mt; end
-function is_index(x) return getmetatable(x)==index_mt; end
-function is_table(x) return getmetatable(x)==table_mt; end
-function is_query(x) return getmetatable(x)==query_mt; end
-function Integer(n) return "Integer()" end
-function String(n) return "String()" end
+local function is_column(x) return getmetatable(x)==column_mt; end
+local function is_index(x) return getmetatable(x)==index_mt; end
+local function is_table(x) return getmetatable(x)==table_mt; end
+local function is_query(x) return getmetatable(x)==query_mt; end
+local function Integer(n) return "Integer()" end
+local function String(n) return "String()" end
 
-function Column(definition)
+local function Column(definition)
        return setmetatable(definition, column_mt);
 end
-function Table(definition)
+local function Table(definition)
        local c = {}
        for i,col in ipairs(definition) do
                if is_column(col) then
@@ -42,7 +42,7 @@ function Table(definition)
        end
        return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
 end
-function Index(definition)
+local function Index(definition)
        return setmetatable(definition, index_mt);
 end
 
@@ -101,7 +101,7 @@ function engine:connect()
 
        local params = self.params;
        assert(params.driver, "no driver")
-       log("error", "Connecting to [%s] %s...", params.driver, params.database);
+       log("debug", "Connecting to [%s] %s...", params.driver, params.database);
        local dbh, err = DBI.Connect(
                params.driver, params.database,
                params.username, params.password,
@@ -147,6 +147,11 @@ local result_mt = { __index = {
        rowcount = function(self) return self.__stmt:rowcount(); end;
 } };
 
+local function debugquery(where, sql, ...)
+       local i = 0; local a = {...}
+       log("debug", "[%s] %s", where, sql:gsub("%?", function () i = i + 1; local v = a[i]; if type(v) == "string" then v = ("%q"):format(v); end return tostring(v); end));
+end
+
 function engine:execute_query(sql, ...)
        if self.params.driver == "PostgreSQL" then
                sql = sql:gsub("`", "\"");
@@ -172,6 +177,26 @@ engine.insert = engine.execute_update;
 engine.select = engine.execute_query;
 engine.delete = engine.execute_update;
 engine.update = engine.execute_update;
+local function debugwrap(name, f)
+       return function (self, sql, ...)
+               debugquery(name, sql, ...)
+               return f(self, sql, ...)
+       end
+end
+function engine:debug(enable)
+       self._debug = enable;
+       if enable then
+               engine.insert = debugwrap("insert", engine.execute_update);
+               engine.select = debugwrap("select", engine.execute_query);
+               engine.delete = debugwrap("delete", engine.execute_update);
+               engine.update = debugwrap("update", engine.execute_update);
+       else
+               engine.insert = engine.execute_update;
+               engine.select = engine.execute_query;
+               engine.delete = engine.execute_update;
+               engine.update = engine.execute_update;
+       end
+end
 function engine:_transaction(func, ...)
        if not self.conn then
                local ok, err = self:connect();
@@ -180,6 +205,7 @@ function engine:_transaction(func, ...)
        --assert(not self.__transaction, "Recursive transactions not allowed");
        local args, n_args = {...}, select("#", ...);
        local function f() return func(unpack(args, 1, n_args)); end
+       log("debug", "SQL transaction begin [%s]", tostring(func));
        self.__transaction = true;
        local success, a, b, c = xpcall(f, debug_traceback);
        self.__transaction = nil;
@@ -220,7 +246,9 @@ function engine:_create_index(index)
        if index.unique then
                sql = sql:gsub("^CREATE", "CREATE UNIQUE");
        end
-       --print(sql);
+       if self._debug then
+               debugquery("create", sql);
+       end
        return self:execute(sql);
 end
 function engine:_create_table(table)
@@ -251,6 +279,9 @@ function engine:_create_table(table)
        elseif self.params.driver == "MySQL" then
                sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset));
        end
+       if self._debug then
+               debugquery("create", sql);
+       end
        local success,err = self:execute(sql);
        if not success then return success,err; end
        for i,v in ipairs(table.__table__) do
@@ -285,12 +316,24 @@ function engine:set_encoding() -- to UTF-8
        if not ok then
                return ok, err;
        end
-       
+
+       if driver == "MySQL" then
+               local ok, actual_charset = self:transaction(function ()
+                       return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'";
+               end);
+               for row in actual_charset do
+                       if row[2] ~= charset then
+                               log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset);
+                               return false, "Failed to set connection encoding";
+                       end
+               end
+       end
+
        return true;
 end
 local engine_mt = { __index = engine };
 
-function db2uri(params)
+local function db2uri(params)
        return build_url{
                scheme = params.driver,
                user = params.username,
@@ -301,8 +344,20 @@ function db2uri(params)
        };
 end
 
-function create_engine(self, params, onconnect)
+local function create_engine(self, params, onconnect)
        return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt);
 end
 
-return _M;
+return {
+       is_column = is_column;
+       is_index = is_index;
+       is_table = is_table;
+       is_query = is_query;
+       Integer = Integer;
+       String = String;
+       Column = Column;
+       Table = Table;
+       Index = Index;
+       create_engine = create_engine;
+       db2uri = db2uri;
+};