util.sql: Return failure if set_encoding() fails
[prosody.git] / util / sql.lua
index 4e63bed76dff477748a936c389a4a1a60e318111..d0da930279fcf7ef1a17029d12c373100a98d0b6 100644 (file)
@@ -25,34 +25,9 @@ function is_column(x) return getmetatable(x)==column_mt; end
 function is_index(x) return getmetatable(x)==index_mt; end
 function is_table(x) return getmetatable(x)==table_mt; end
 function is_query(x) return getmetatable(x)==query_mt; end
---function is_op(x) return getmetatable(x)==op_mt; end
---function expr(...) return setmetatable({...}, op_mt); end
 function Integer(n) return "Integer()" end
 function String(n) return "String()" end
 
---[[local ops = {
-       __add = function(a, b) return "("..a.."+"..b..")" end;
-       __sub = function(a, b) return "("..a.."-"..b..")" end;
-       __mul = function(a, b) return "("..a.."*"..b..")" end;
-       __div = function(a, b) return "("..a.."/"..b..")" end;
-       __mod = function(a, b) return "("..a.."%"..b..")" end;
-       __pow = function(a, b) return "POW("..a..","..b..")" end;
-       __unm = function(a) return "NOT("..a..")" end;
-       __len = function(a) return "COUNT("..a..")" end;
-       __eq = function(a, b) return "("..a.."=="..b..")" end;
-       __lt = function(a, b) return "("..a.."<"..b..")" end;
-       __le = function(a, b) return "("..a.."<="..b..")" end;
-};
-
-local functions = {
-
-};
-
-local cmap = {
-       [Integer] = Integer();
-       [String] = String();
-};]]
-
 function Column(definition)
        return setmetatable(definition, column_mt);
 end
@@ -94,7 +69,6 @@ function index_mt:__tostring()
        return s..' }';
 --     return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
 end
---
 
 local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
 local function parse_url(url)
@@ -121,26 +95,6 @@ local function parse_url(url)
        };
 end
 
---[[local session = {};
-
-function session.query(...)
-       local rets = {...};
-       local query = setmetatable({ __rets = rets, __filters }, query_mt);
-       return query;
-end
---
-
-local function db2uri(params)
-       return build_url{
-               scheme = params.driver,
-               user = params.username,
-               password = params.password,
-               host = params.host,
-               port = params.port,
-               path = params.database,
-       };
-end]]
-
 local engine = {};
 function engine:connect()
        if self.conn then return true; end
@@ -156,8 +110,19 @@ function engine:connect()
        dbh:autocommit(false); -- don't commit automatically
        self.conn = dbh;
        self.prepared = {};
+       local ok, err = self:set_encoding();
+       if not ok then
+               return ok, err;
+       end
+       local ok, err = self:onconnect();
+       if ok == false then
+               return ok, err;
+       end
        return true;
 end
+function engine:onconnect()
+       -- Override from create_engine()
+end
 function engine:execute(sql, ...)
        local success, err = self:connect();
        if not success then return success, err; end
@@ -208,8 +173,8 @@ engine.delete = engine.execute_update;
 engine.update = engine.execute_update;
 function engine:_transaction(func, ...)
        if not self.conn then
-               local a,b = self:connect();
-               if not a then return a,b; end
+               local ok, err = self:connect();
+               if not ok then return ok, err; end
        end
        --assert(not self.__transaction, "Recursive transactions not allowed");
        local args, n_args = {...}, select("#", ...);
@@ -229,15 +194,15 @@ function engine:_transaction(func, ...)
        end
 end
 function engine:transaction(...)
-       local a,b = self:_transaction(...);
-       if not a then
+       local ok, ret = self:_transaction(...);
+       if not ok then
                local conn = self.conn;
                if not conn or not conn:ping() then
                        self.conn = nil;
-                       a,b = self:_transaction(...);
+                       ok, ret = self:_transaction(...);
                end
        end
-       return a,b;
+       return ok, ret;
 end
 function engine:_create_index(index)
        local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
@@ -260,13 +225,18 @@ end
 function engine:_create_table(table)
        local sql = "CREATE TABLE `"..table.name.."` (";
        for i,col in ipairs(table.c) do
-               sql = sql.."`"..col.name.."` "..col.type;
+               local col_type = col.type;
+               if col_type == "MEDIUMTEXT" and self.params.driver ~= "MySQL" then
+                       col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
+               end
+               if col.auto_increment == true and self.params.driver == "PostgreSQL" then
+                       col_type = "BIGSERIAL";
+               end
+               sql = sql.."`"..col.name.."` "..col_type;
                if col.nullable == false then sql = sql.." NOT NULL"; end
                if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
                if col.auto_increment == true then
-                       if self.params.driver == "PostgreSQL" then
-                               sql = sql.." SERIAL";
-                       elseif self.params.driver == "MySQL" then
+                       if self.params.driver == "MySQL" then
                                sql = sql.." AUTO_INCREMENT";
                        elseif self.params.driver == "SQLite3" then
                                sql = sql.." AUTOINCREMENT";
@@ -278,7 +248,7 @@ function engine:_create_table(table)
        if self.params.driver == "PostgreSQL" then
                sql = sql:gsub("`", "\"");
        elseif self.params.driver == "MySQL" then
-               sql = sql:gsub(";$", " CHARACTER SET 'utf8' COLLATE 'utf8_bin';");
+               sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset));
        end
        local success,err = self:execute(sql);
        if not success then return success,err; end
@@ -290,24 +260,30 @@ function engine:_create_table(table)
        return success;
 end
 function engine:set_encoding() -- to UTF-8
-       if self.params.driver == "SQLite3" then return end
        local driver = self.params.driver;
+       if driver == "SQLite3" then
+               return self:transaction(function()
+                       if self:select"PRAGMA encoding;"()[1] == "UTF-8" then
+                               self.charset = "utf8";
+                       end
+               end);
+       end
        local set_names_query = "SET NAMES '%s';"
        local charset = "utf8";
        if driver == "MySQL" then
-               set_names_query = set_names_query:gsub(";$", " COLLATE 'utf8_bin';");
                local ok, charsets = self:transaction(function()
-                       return self:select"SELECT `CHARACTER_SET_NAME` FROM `CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;";
+                       return self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;";
                end);
                local row = ok and charsets();
                charset = row and row[1] or charset;
+               set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin"));
        end
        self.charset = charset;
-       return self:transaction(function() return engine:execute(set_names_query:format(charset)); end);
+       return self:transaction(function() return self:execute(set_names_query:format(charset)); end);
 end
 local engine_mt = { __index = engine };
 
-local function db2uri(params)
+function db2uri(params)
        return build_url{
                scheme = params.driver,
                user = params.username,
@@ -317,55 +293,9 @@ local function db2uri(params)
                path = params.database,
        };
 end
-local engine_cache = {}; -- TODO make weak valued
-function create_engine(self, params)
-       local url = db2uri(params);
-       if not engine_cache[url] then
-               local engine = setmetatable({ url = url, params = params }, engine_mt);
-               engine_cache[url] = engine;
-       end
-       return engine_cache[url];
-end
-
-
---[[Users = Table {
-       name="users";
-       Column { name="user_id", type=String(), primary_key=true };
-};
-print(Users)
-print(Users.c.user_id)]]
-
---local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase');
---[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" };
-
-local i = 0;
-for row in assert(engine:execute("select * from sqlite_master")):rows(true) do
-       i = i+1;
-       print(i);
-       for k,v in pairs(row) do
-               print("",k,v);
-       end
-end
-print("---")
-
-Prosody = Table {
-       name="prosody";
-       Column { name="host", type="TEXT", nullable=false };
-       Column { name="user", type="TEXT", nullable=false };
-       Column { name="store", type="TEXT", nullable=false };
-       Column { name="key", type="TEXT", nullable=false };
-       Column { name="type", type="TEXT", nullable=false };
-       Column { name="value", type="TEXT", nullable=false };
-       Index { name="prosody_index", "host", "user", "store", "key" };
-};
---print(Prosody);
-assert(engine:transaction(function()
-       assert(Prosody:create(engine));
-end));
 
-for row in assert(engine:execute("select user from prosody")):rows(true) do
-       print("username:", row['username'])
+function create_engine(self, params, onconnect)
+       return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt);
 end
---result.close();]]
 
 return _M;