X-Git-Url: https://git.enpas.org/?a=blobdiff_plain;f=util%2Fsql.lua;h=6fed1373ae9f027af12aa6cb92ce901dfcbb674e;hb=670c0bb01015bf63176eafc5c605c800c7053e1f;hp=fc1191f96a65f6ce106134e76cbd44f0333616a8;hpb=e00936f32b1958e52e8f8cadd84e5ec60df648eb;p=prosody.git diff --git a/util/sql.lua b/util/sql.lua index fc1191f9..6fed1373 100644 --- a/util/sql.lua +++ b/util/sql.lua @@ -1,8 +1,9 @@ local setmetatable, getmetatable = setmetatable, getmetatable; -local ipairs, unpack, select = ipairs, unpack, select; +local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113 local tonumber, tostring = tonumber, tostring; -local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback; +local type = type; +local assert, pcall, xpcall, debug_traceback = assert, pcall, xpcall, debug.traceback; local t_concat = table.concat; local s_char = string.char; local log = require "util.logger".init("sql"); @@ -13,7 +14,7 @@ local DBI = require "DBI"; DBI.Drivers(); local build_url = require "socket.url".build; -module("sql") +local _ENV = nil; local column_mt = {}; local table_mt = {}; @@ -21,17 +22,17 @@ local query_mt = {}; --local op_mt = {}; local index_mt = {}; -function is_column(x) return getmetatable(x)==column_mt; end -function is_index(x) return getmetatable(x)==index_mt; end -function is_table(x) return getmetatable(x)==table_mt; end -function is_query(x) return getmetatable(x)==query_mt; end -function Integer(n) return "Integer()" end -function String(n) return "String()" end +local function is_column(x) return getmetatable(x)==column_mt; end +local function is_index(x) return getmetatable(x)==index_mt; end +local function is_table(x) return getmetatable(x)==table_mt; end +local function is_query(x) return getmetatable(x)==query_mt; end +local function Integer() return "Integer()" end +local function String() return "String()" end -function Column(definition) +local function Column(definition) return setmetatable(definition, column_mt); end -function Table(definition) +local function Table(definition) local c = {} for i,col in ipairs(definition) do if is_column(col) then @@ -42,7 +43,7 @@ function Table(definition) end return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt); end -function Index(definition) +local function Index(definition) return setmetatable(definition, index_mt); end @@ -101,16 +102,21 @@ function engine:connect() local params = self.params; assert(params.driver, "no driver") - local dbh, err = DBI.Connect( + log("debug", "Connecting to [%s] %s...", params.driver, params.database); + local ok, dbh, err = pcall(DBI.Connect, params.driver, params.database, params.username, params.password, params.host, params.port ); + if not ok then return ok, dbh; end if not dbh then return nil, err; end dbh:autocommit(false); -- don't commit automatically self.conn = dbh; self.prepared = {}; - self:set_encoding(); + local ok, err = self:set_encoding(); + if not ok then + return ok, err; + end local ok, err = self:onconnect(); if ok == false then return ok, err; @@ -120,6 +126,14 @@ end function engine:onconnect() -- Override from create_engine() end + +function engine:prepquery(sql) + if self.params.driver == "PostgreSQL" then + sql = sql:gsub("`", "\""); + end + return sql; +end + function engine:execute(sql, ...) local success, err = self:connect(); if not success then return success, err; end @@ -143,18 +157,19 @@ local result_mt = { __index = { rowcount = function(self) return self.__stmt:rowcount(); end; } }; +local function debugquery(where, sql, ...) + local i = 0; local a = {...} + log("debug", "[%s] %s", where, sql:gsub("%?", function () i = i + 1; local v = a[i]; if type(v) == "string" then v = ("%q"):format(v); end return tostring(v); end)); +end + function engine:execute_query(sql, ...) - if self.params.driver == "PostgreSQL" then - sql = sql:gsub("`", "\""); - end + sql = self:prepquery(sql); local stmt = assert(self.conn:prepare(sql)); assert(stmt:execute(...)); return stmt:rows(); end function engine:execute_update(sql, ...) - if self.params.driver == "PostgreSQL" then - sql = sql:gsub("`", "\""); - end + sql = self:prepquery(sql); local prepared = self.prepared; local stmt = prepared[sql]; if not stmt then @@ -168,6 +183,30 @@ engine.insert = engine.execute_update; engine.select = engine.execute_query; engine.delete = engine.execute_update; engine.update = engine.execute_update; +local function debugwrap(name, f) + return function (self, sql, ...) + debugquery(name, sql, ...) + return f(self, sql, ...) + end +end +function engine:debug(enable) + self._debug = enable; + if enable then + engine.insert = debugwrap("insert", engine.execute_update); + engine.select = debugwrap("select", engine.execute_query); + engine.delete = debugwrap("delete", engine.execute_update); + engine.update = debugwrap("update", engine.execute_update); + else + engine.insert = engine.execute_update; + engine.select = engine.execute_query; + engine.delete = engine.execute_update; + engine.update = engine.execute_update; + end +end +local function handleerr(err) + log("error", "Error in SQL transaction: %s", debug_traceback(err, 3)); + return err; +end function engine:_transaction(func, ...) if not self.conn then local ok, err = self:connect(); @@ -176,8 +215,9 @@ function engine:_transaction(func, ...) --assert(not self.__transaction, "Recursive transactions not allowed"); local args, n_args = {...}, select("#", ...); local function f() return func(unpack(args, 1, n_args)); end + log("debug", "SQL transaction begin [%s]", tostring(func)); self.__transaction = true; - local success, a, b, c = xpcall(f, debug_traceback); + local success, a, b, c = xpcall(f, handleerr); self.__transaction = nil; if success then log("debug", "SQL transaction success [%s]", tostring(func)); @@ -216,7 +256,9 @@ function engine:_create_index(index) if index.unique then sql = sql:gsub("^CREATE", "CREATE UNIQUE"); end - --print(sql); + if self._debug then + debugquery("create", sql); + end return self:execute(sql); end function engine:_create_table(table) @@ -247,6 +289,9 @@ function engine:_create_table(table) elseif self.params.driver == "MySQL" then sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset)); end + if self._debug then + debugquery("create", sql); + end local success,err = self:execute(sql); if not success then return success,err; end for i,v in ipairs(table.__table__) do @@ -260,27 +305,51 @@ function engine:set_encoding() -- to UTF-8 local driver = self.params.driver; if driver == "SQLite3" then return self:transaction(function() - if self:select"PRAGMA encoding;"()[1] == "UTF-8" then - self.charset = "utf8"; + for encoding in self:select"PRAGMA encoding;" do + if encoding[1] == "UTF-8" then + self.charset = "utf8"; + end end end); end local set_names_query = "SET NAMES '%s';" local charset = "utf8"; if driver == "MySQL" then - local ok, charsets = self:transaction(function() - return self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;"; + self:transaction(function() + for row in self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;" do + charset = row and row[1] or charset; + end end); - local row = ok and charsets(); - charset = row and row[1] or charset; set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin")); end self.charset = charset; - return self:transaction(function() return self:execute(set_names_query:format(charset)); end); + log("debug", "Using encoding '%s' for database connection", charset); + local ok, err = self:transaction(function() return self:execute(set_names_query:format(charset)); end); + if not ok then + return ok, err; + end + + if driver == "MySQL" then + local ok, actual_charset = self:transaction(function () + return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'"; + end); + local charset_ok = true; + for row in actual_charset do + if row[2] ~= charset then + log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset); + charset_ok = false; + end + end + if not charset_ok then + return false, "Failed to set connection encoding"; + end + end + + return true; end local engine_mt = { __index = engine }; -function db2uri(params) +local function db2uri(params) return build_url{ scheme = params.driver, user = params.username, @@ -291,8 +360,20 @@ function db2uri(params) }; end -function create_engine(self, params, onconnect) +local function create_engine(self, params, onconnect) return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt); end -return _M; +return { + is_column = is_column; + is_index = is_index; + is_table = is_table; + is_query = is_query; + Integer = Integer; + String = String; + Column = Column; + Table = Table; + Index = Index; + create_engine = create_engine; + db2uri = db2uri; +};