2 local setmetatable, getmetatable = setmetatable, getmetatable;
3 local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113
4 local tonumber, tostring = tonumber, tostring;
6 local assert, pcall, xpcall, debug_traceback = assert, pcall, xpcall, debug.traceback;
7 local t_concat = table.concat;
8 local s_char = string.char;
9 local log = require "util.logger".init("sql");
11 local DBI = require "DBI";
12 -- This loads all available drivers while globals are unlocked
13 -- LuaDBI should be fixed to not set globals.
15 local build_url = require "socket.url".build;
-- Type predicates. Each one reports whether the given value carries the
-- metatable that marks it as the corresponding kind of SQL object.
local function is_column(x)
	return column_mt == getmetatable(x);
end
local function is_index(x)
	return index_mt == getmetatable(x);
end
local function is_table(x)
	return table_mt == getmetatable(x);
end
local function is_query(x)
	return query_mt == getmetatable(x);
end
-- Column type markers. Each returns the textual form of its own constructor call.
local function Integer()
	return "Integer()";
end
local function String()
	return "String()";
end
-- Marks a column definition as a Column by attaching column_mt to it.
-- The definition table is modified in place and returned; no copy is made.
local function Column(definition)
	setmetatable(definition, column_mt);
	return definition;
end
-- Builds a Table object from a definition: a table with a `name` field whose
-- array part holds Column and Index objects.
-- Returns a new object (metatable table_mt) with the fields:
--   __table__  the original definition
--   c          the columns, as a sequence in definition order and also keyed by name
--   name       the table name
-- Every Index found in the definition gets its `table` field set to the table name.
local function Table(definition)
	local c = {};
	for _, item in ipairs(definition) do
		if is_column(item) then
			-- Append instead of reusing the item's position in the definition,
			-- so `c` stays a hole-free sequence even when an Index is listed
			-- before a Column. A hole would make ipairs(c) stop early and #c
			-- unreliable, silently dropping columns when the table is created.
			c[#c+1] = item;
			c[item.name] = item;
		elseif is_index(item) then
			item.table = definition.name;
		end
	end
	return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
end
-- Marks an index definition as an Index by attaching index_mt to it.
-- The definition table is modified in place and returned.
local function Index(definition)
	setmetatable(definition, index_mt);
	return definition;
end
-- Renders a Table for debugging: its name followed by the string form of
-- every entry in the original definition.
function table_mt:__tostring()
	local definition = self.__table__;
	local parts = { 'name="'..definition.name..'"' };
	for position, entry in ipairs(definition) do
		-- slot 1 holds the name, so entry N goes to slot N+1
		parts[position + 1] = tostring(entry);
	end
	return 'Table{ '..t_concat(parts, ", ")..' }';
end
-- Methods available on Table objects.
table_mt.__index = {
	-- Creates this table in the database by delegating to the engine's
	-- _create_table(); whatever that returns is passed back to the caller.
	create = function (self, engine)
		return engine:_create_table(self);
	end;
};
-- Call handler for Table objects.
-- NOTE(review): no body is visible for this function in the reviewed excerpt;
-- it is treated here as an unimplemented stub - confirm against the full file.
function table_mt:__call(...)
	-- TODO
end
-- Renders a Column for debugging, showing its name and type.
function column_mt:__tostring()
	local name, sqltype = self.name, self.type;
	return 'Column{ name="'..name..'", type="'..sqltype..'" }';
end
-- Renders an Index for debugging: its name followed by each indexed column
-- name in double quotes.
function index_mt:__tostring()
	local pieces = { 'Index{ name="'..self.name..'"' };
	local count = #self;
	for i = 1, count do
		-- prefix every backslash and double quote with a backslash
		local escaped = self[i]:gsub("[\\\"]", "\\%1");
		pieces[i + 1] = ', "'..escaped..'"';
	end
	pieces[count + 2] = ' }';
	return t_concat(pieces);
end
-- Decodes percent-escapes ("%" followed by two hex digits) in a string.
-- A nil or false argument is returned unchanged.
local function urldecode(s)
	if not s then return s; end
	local decoded = s:gsub("%%(%x%x)", function (hex)
		return s_char(tonumber(hex, 16));
	end);
	return decoded;
end
-- Parses a database URL of the form
--   scheme://[username[:password]@][host[:port]]/[database]
-- and returns a table with the fields scheme (lower-cased), username,
-- password (percent-decoded), host, port (a number) and database.
-- Parts that are absent from the URL are nil in the result.
-- Raises "Invalid URL format" if the URL has no "scheme://" prefix or if the
-- port is not numeric.
local function parse_url(url)
	local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
	assert(scheme, "Invalid URL format");
	local username, password, host, port;
	-- Split "credentials@host". The host capture has to be "[^@]+": the
	-- previous "[^@+]" matched exactly one character, so only the first letter
	-- of the host survived and any port was lost.
	local authpart, hostpart = secondpart:match("([^@]+)@([^@]+)");
	if not authpart then hostpart = secondpart; end
	if authpart then
		username, password = authpart:match("([^:]*):(.*)");
		username = username or authpart; -- no ":" means the whole part is the username
		password = password and urldecode(password);
	end
	if hostpart then
		host, port = hostpart:match("([^:]*):(.*)");
		host = host or hostpart; -- no ":" means there is no port
		port = port and assert(tonumber(port), "Invalid URL format");
	end
	return {
		scheme = scheme:lower();
		username = username; password = password;
		host = host; port = port;
		database = #database > 0 and database or nil;
	};
end
-- Opens the database connection for this engine unless one is already open.
-- On success the handle is stored in self.conn and the prepared-statement
-- cache (self.prepared) is reset, then set_encoding() and the onconnect()
-- hook are run.
-- Returns true on success, otherwise a falsy value followed by an error.
function engine:connect()
	if self.conn then return true; end

	local params = self.params;
	assert(params.driver, "no driver")
	log("debug", "Connecting to [%s] %s...", params.driver, params.database);
	local called, dbh, connect_err = pcall(DBI.Connect,
		params.driver, params.database,
		params.username, params.password,
		params.host, params.port
	);
	-- DBI.Connect raised: pcall's second result is the error message
	if not called then return called, dbh; end
	-- DBI.Connect returned without a handle: its second result is the error
	if not dbh then return nil, connect_err; end
	dbh:autocommit(false); -- don't commit automatically
	self.conn = dbh;
	self.prepared = {};

	local encoding_ok, encoding_err = self:set_encoding();
	if not encoding_ok then
		return encoding_ok, encoding_err;
	end

	local hook_ok, hook_err = self:onconnect();
	if hook_ok == false then
		-- only an explicit false aborts; a hook returning nothing counts as success
		return hook_ok, hook_err;
	end
	return true;
end
-- Hook called by connect() after the connection has been set up.
-- This default implementation does nothing and returns nothing.
function engine:onconnect()
	-- Override from create_engine()
end
-- Adapts a statement written with backtick-quoted identifiers to the
-- configured driver: for PostgreSQL every backtick becomes a double quote,
-- for any other driver the statement is returned unchanged.
function engine:prepquery(sql)
	if self.params.driver ~= "PostgreSQL" then
		return sql;
	end
	local translated = sql:gsub("`", "\"");
	return translated;
end
-- Executes a statement, connecting first if necessary.
-- Prepared statements are cached in self.prepared, keyed by the SQL text.
--   sql  statement text (used as given; no identifier-quote translation here)
--   ...  values bound to the statement's placeholders
-- Returns the executed statement handle, or a falsy value plus an error if
-- connecting, preparing or executing fails.
function engine:execute(sql, ...)
	local connected, connect_err = self:connect();
	if not connected then return connected, connect_err; end

	local cache = self.prepared;
	local stmt = cache[sql];
	if not stmt then
		local prepare_err;
		stmt, prepare_err = self.conn:prepare(sql);
		if not stmt then return stmt, prepare_err; end
		cache[sql] = stmt;
	end

	local executed, exec_err = stmt:execute(...);
	if not executed then return executed, exec_err; end
	return stmt;
end
-- Metatable for the result objects returned by execute_update(). A result
-- keeps its statement in the __stmt field and forwards these calls to it.
local result_methods = {};
function result_methods:affected()
	return self.__stmt:affected();
end
function result_methods:rowcount()
	return self.__stmt:rowcount();
end
local result_mt = { __index = result_methods };
-- Logs an SQL statement at debug level, with each "?" placeholder replaced by
-- the corresponding bound parameter (string parameters are quoted with %q).
--   where  short label for the log line, e.g. "select" or "create"
--   sql    statement text
--   ...    bound parameters, in placeholder order
local function debugquery(where, sql, ...)
	local params = {...};
	local index = 0;
	-- gsub returns the new string AND the number of substitutions. Capture only
	-- the string so the count is not handed to log() as a stray extra argument.
	local rendered = sql:gsub("%?", function ()
		index = index + 1;
		local value = params[index];
		if type(value) == "string" then
			value = ("%q"):format(value);
		end
		return tostring(value);
	end);
	log("debug", "[%s] %s", where, rendered);
end
-- Prepares and executes a query on the current connection and returns the
-- statement's rows() result, which callers in this file iterate with `for`.
-- Raises an error if preparing or executing fails.
-- Uses self.conn directly, so the connection must already be open.
function engine:execute_query(sql, ...)
	local query = self:prepquery(sql);
	local stmt = assert(self.conn:prepare(query));
	assert(stmt:execute(...));
	return stmt:rows();
end
-- Prepares (or fetches from the self.prepared cache) and executes a
-- data-modifying statement on the current connection.
-- Returns a result object offering affected() and rowcount().
-- Raises an error if preparing or executing fails.
function engine:execute_update(sql, ...)
	local query = self:prepquery(sql);
	local cache = self.prepared;
	local stmt = cache[query];
	if not stmt then
		stmt = assert(self.conn:prepare(query));
		cache[query] = stmt;
	end
	assert(stmt:execute(...));
	local result = { __stmt = stmt };
	return setmetatable(result, result_mt);
end
-- Default (non-debug) statement methods: select runs a query, while insert,
-- update and delete all go through execute_update.
engine.select = engine.execute_query;
engine.insert = engine.execute_update;
engine.update = engine.execute_update;
engine.delete = engine.execute_update;
-- Wraps statement method `f` so that every call is first logged through
-- debugquery() under the label `name`, then forwarded unchanged.
local function debugwrap(name, f)
	local function logged(self, sql, ...)
		debugquery(name, sql, ...);
		return f(self, sql, ...);
	end
	return logged;
end
-- Turns query logging on or off.
-- Records the flag on this engine (self._debug) and rebinds insert/select/
-- delete/update to either logging wrappers or the plain implementations.
-- NOTE(review): the rebinding is done on the shared `engine` table rather than
-- on `self`, so it appears to affect every engine object - confirm this is intended.
function engine:debug(enable)
	self._debug = enable;
	local update, query = engine.execute_update, engine.execute_query;
	local bindings = {
		{ "insert", update };
		{ "select", query };
		{ "delete", update };
		{ "update", update };
	};
	for _, binding in ipairs(bindings) do
		local name, impl = binding[1], binding[2];
		if enable then
			engine[name] = debugwrap(name, impl);
		else
			engine[name] = impl;
		end
	end
end
-- Message handler passed to xpcall by _transaction(): logs the error together
-- with a traceback and hands the original error value back.
local function handleerr(err)
	local trace = debug_traceback(err, 3);
	log("error", "Error in SQL transaction: %s", trace);
	return err;
end
-- Runs func(...) as one transaction, connecting first if needed.
-- If func returns normally the connection is committed and the call returns
-- true followed by up to three of func's results (or a falsy value plus an
-- error if the commit fails). If func raises, the error is logged, the
-- connection is rolled back and the call returns false plus the error.
function engine:_transaction(func, ...)
	if not self.conn then
		local connected, connect_err = self:connect();
		if not connected then return connected, connect_err; end
	end
	--assert(not self.__transaction, "Recursive transactions not allowed");
	local n_args = select("#", ...);
	local args = {...};
	-- explicit count so trailing nil arguments are forwarded too
	local function run() return func(unpack(args, 1, n_args)); end
	log("debug", "SQL transaction begin [%s]", tostring(func));
	self.__transaction = true;
	local success, a, b, c = xpcall(run, handleerr);
	self.__transaction = nil;
	if not success then
		log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
		if self.conn then self.conn:rollback(); end
		return success, a;
	end
	log("debug", "SQL transaction success [%s]", tostring(func));
	local committed, commit_err = self.conn:commit();
	if not committed then return committed, commit_err; end -- commit failed
	return success, a, b, c;
end
-- Runs a transaction via _transaction(). If it fails and the connection is
-- missing or does not answer a ping, the connection is dropped and the
-- transaction is attempted once more.
-- Returns the first two results of the last _transaction() call.
function engine:transaction(...)
	local ok, ret = self:_transaction(...);
	if ok then
		return ok, ret;
	end
	local conn = self.conn;
	if not (conn and conn:ping()) then
		self.conn = nil; -- makes _transaction() reconnect
		ok, ret = self:_transaction(...);
	end
	return ok, ret;
end
-- Builds and executes the CREATE INDEX statement for an Index object.
-- The index's array part lists the column names; index.table names the table.
-- Returns whatever execute() returns.
function engine:_create_index(index)
	local header = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
	local columns = {};
	for i = 1, #index do
		columns[i] = "`"..index[i].."`";
	end
	local sql = header..t_concat(columns, ", ")..");";

	local driver = self.params.driver;
	if driver == "PostgreSQL" then
		sql = sql:gsub("`", "\"");
	elseif driver == "MySQL" then
		-- appends "(20)" after every quoted column name in the column list
		sql = sql:gsub("`([,)])", "`(20)%1");
	end
	if index.unique then
		sql = sql:gsub("^CREATE", "CREATE UNIQUE");
	end
	if self._debug then
		debugquery("create", sql);
	end
	return self:execute(sql);
end
-- Builds and executes the CREATE TABLE statement for a Table object, then
-- creates each Index listed in the table's definition.
-- Returns the result of executing CREATE TABLE on success, or a falsy value
-- plus an error if that statement fails.
-- NOTE(review): the results of the _create_index() calls are discarded, so an
-- index that fails to be created is not reported to the caller.
function engine:_create_table(table)
	local sql = "CREATE TABLE `"..table.name.."` (";
	local driver = self.params.driver;
	local columns = table.c;
	for i, col in ipairs(columns) do
		local auto_increment = col.auto_increment == true;
		local col_type = col.type;
		if auto_increment and driver == "PostgreSQL" then
			col_type = "BIGSERIAL";
		elseif col_type == "MEDIUMTEXT" and driver ~= "MySQL" then
			col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
		end

		local clause = "`"..col.name.."` "..col_type;
		if col.nullable == false then clause = clause.." NOT NULL"; end
		if col.primary_key == true then clause = clause.." PRIMARY KEY"; end
		if auto_increment then
			if driver == "MySQL" then
				clause = clause.." AUTO_INCREMENT";
			elseif driver == "SQLite3" then
				clause = clause.." AUTOINCREMENT";
			end
		end
		if i ~= #columns then clause = clause..", "; end
		sql = sql..clause;
	end
	sql = sql..");";

	if driver == "PostgreSQL" then
		sql = sql:gsub("`", "\"");
	elseif driver == "MySQL" then
		sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset));
	end
	if self._debug then
		debugquery("create", sql);
	end

	local success, err = self:execute(sql);
	if not success then return success, err; end
	for _, entry in ipairs(table.__table__) do
		if is_index(entry) then
			self:_create_index(entry);
		end
	end
	return success;
end
-- Configures the connection to use UTF-8 and records the charset name in
-- self.charset.
--   SQLite3: reads "PRAGMA encoding" and sets self.charset to "utf8" when the
--            database reports UTF-8; returns the transaction's result.
--   MySQL:   picks the utf8* character set with the largest MAXLEN (falling
--            back to "utf8"), issues SET NAMES with the matching _bin
--            collation, then verifies character_set_client.
--   others:  issues SET NAMES 'utf8'.
-- Returns true on success, otherwise a falsy value plus an error.
function engine:set_encoding() -- to UTF-8
	local driver = self.params.driver;
	if driver == "SQLite3" then
		return self:transaction(function()
			for encoding in self:select"PRAGMA encoding;" do
				if encoding[1] == "UTF-8" then
					self.charset = "utf8";
				end
			end
		end);
	end
	local set_names_query = "SET NAMES '%s';"
	local charset = "utf8";
	if driver == "MySQL" then
		-- Best effort: if this lookup fails, charset keeps its "utf8" default
		self:transaction(function()
			for row in self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;" do
				charset = row and row[1] or charset;
			end
		end);
		set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin"));
	end
	self.charset = charset;
	log("debug", "Using encoding '%s' for database connection", charset);
	local ok, err = self:transaction(function() return self:execute(set_names_query:format(charset)); end);
	if not ok then
		return ok, err;
	end

	if driver == "MySQL" then
		local checked, actual_charset = self:transaction(function ()
			return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'";
		end);
		if not checked then
			-- On failure the second result is the error, not a row iterator;
			-- looping over it would raise instead of reporting the problem.
			log("error", "Unable to query MySQL connection encoding: %s", tostring(actual_charset));
			return false, "Failed to detect connection encoding";
		end
		local charset_ok = true;
		for row in actual_charset do
			if row[2] ~= charset then
				log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset);
				charset_ok = false;
			end
		end
		if not charset_ok then
			return false, "Failed to set connection encoding";
		end
	end

	return true;
end
350 local engine_mt = { __index = engine };
-- Builds a URI string from connection parameters using socket.url's build():
-- the driver becomes the scheme and the database becomes the path.
local function db2uri(params)
	local parts = {
		scheme = params.driver,
		user = params.username,
		password = params.password,
		host = params.host,
		port = params.port,
		path = params.database,
	};
	return build_url(parts);
end
-- Creates a new engine object for the given connection parameters.
--   self       not used by this function
--   params     connection parameters (driver, database, username, password, host, port)
--   onconnect  stored on the new object, where it takes precedence over the
--              default engine:onconnect; may be nil
-- The returned object also carries a `url` field built by db2uri().
-- No connection is opened here.
local function create_engine(self, params, onconnect)
	local new_engine = {
		url = db2uri(params);
		params = params;
		onconnect = onconnect;
	};
	return setmetatable(new_engine, engine_mt);
end
368 is_column = is_column;
377 create_engine = create_engine;