util.sql: Import type too (fix global access)
[prosody.git] / util / sql.lua
1
2 local setmetatable, getmetatable = setmetatable, getmetatable;
3 local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113
4 local tonumber, tostring = tonumber, tostring;
5 local type = type;
6 local assert, pcall, xpcall, debug_traceback = assert, pcall, xpcall, debug.traceback;
7 local t_concat = table.concat;
8 local s_char = string.char;
9 local log = require "util.logger".init("sql");
10
11 local DBI = require "DBI";
12 -- This loads all available drivers while globals are unlocked
13 -- LuaDBI should be fixed to not set globals.
14 DBI.Drivers();
15 local build_url = require "socket.url".build;
16
17 local _ENV = nil;
18
19 local column_mt = {};
20 local table_mt = {};
21 local query_mt = {};
22 --local op_mt = {};
23 local index_mt = {};
24
25 local function is_column(x) return getmetatable(x)==column_mt; end
26 local function is_index(x) return getmetatable(x)==index_mt; end
27 local function is_table(x) return getmetatable(x)==table_mt; end
28 local function is_query(x) return getmetatable(x)==query_mt; end
29 local function Integer() return "Integer()" end
30 local function String() return "String()" end
31
32 local function Column(definition)
33         return setmetatable(definition, column_mt);
34 end
35 local function Table(definition)
36         local c = {}
37         for i,col in ipairs(definition) do
38                 if is_column(col) then
39                         c[i], c[col.name] = col, col;
40                 elseif is_index(col) then
41                         col.table = definition.name;
42                 end
43         end
44         return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
45 end
46 local function Index(definition)
47         return setmetatable(definition, index_mt);
48 end
49
50 function table_mt:__tostring()
51         local s = { 'name="'..self.__table__.name..'"' }
52         for i,col in ipairs(self.__table__) do
53                 s[#s+1] = tostring(col);
54         end
55         return 'Table{ '..t_concat(s, ", ")..' }'
56 end
57 table_mt.__index = {};
58 function table_mt.__index:create(engine)
59         return engine:_create_table(self);
60 end
61 function table_mt:__call(...)
62         -- TODO
63 end
64 function column_mt:__tostring()
65         return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
66 end
67 function index_mt:__tostring()
68         local s = 'Index{ name="'..self.name..'"';
69         for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end
70         return s..' }';
71 --      return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
72 end
73
74 local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
75 local function parse_url(url)
76         local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
77         assert(scheme, "Invalid URL format");
78         local username, password, host, port;
79         local authpart, hostpart = secondpart:match("([^@]+)@([^@+])");
80         if not authpart then hostpart = secondpart; end
81         if authpart then
82                 username, password = authpart:match("([^:]*):(.*)");
83                 username = username or authpart;
84                 password = password and urldecode(password);
85         end
86         if hostpart then
87                 host, port = hostpart:match("([^:]*):(.*)");
88                 host = host or hostpart;
89                 port = port and assert(tonumber(port), "Invalid URL format");
90         end
91         return {
92                 scheme = scheme:lower();
93                 username = username; password = password;
94                 host = host; port = port;
95                 database = #database > 0 and database or nil;
96         };
97 end
98
99 local engine = {};
100 function engine:connect()
101         if self.conn then return true; end
102
103         local params = self.params;
104         assert(params.driver, "no driver")
105         log("debug", "Connecting to [%s] %s...", params.driver, params.database);
106         local ok, dbh, err = pcall(DBI.Connect,
107                 params.driver, params.database,
108                 params.username, params.password,
109                 params.host, params.port
110         );
111         if not ok then return ok, dbh; end
112         if not dbh then return nil, err; end
113         dbh:autocommit(false); -- don't commit automatically
114         self.conn = dbh;
115         self.prepared = {};
116         local ok, err = self:set_encoding();
117         if not ok then
118                 return ok, err;
119         end
120         local ok, err = self:onconnect();
121         if ok == false then
122                 return ok, err;
123         end
124         return true;
125 end
126 function engine:onconnect()
127         -- Override from create_engine()
128 end
129
130 function engine:prepquery(sql)
131         if self.params.driver == "PostgreSQL" then
132                 sql = sql:gsub("`", "\"");
133         end
134         return sql;
135 end
136
137 function engine:execute(sql, ...)
138         local success, err = self:connect();
139         if not success then return success, err; end
140         local prepared = self.prepared;
141
142         local stmt = prepared[sql];
143         if not stmt then
144                 local err;
145                 stmt, err = self.conn:prepare(sql);
146                 if not stmt then return stmt, err; end
147                 prepared[sql] = stmt;
148         end
149
150         local success, err = stmt:execute(...);
151         if not success then return success, err; end
152         return stmt;
153 end
154
155 local result_mt = { __index = {
156         affected = function(self) return self.__stmt:affected(); end;
157         rowcount = function(self) return self.__stmt:rowcount(); end;
158 } };
159
160 local function debugquery(where, sql, ...)
161         local i = 0; local a = {...}
162         log("debug", "[%s] %s", where, sql:gsub("%?", function () i = i + 1; local v = a[i]; if type(v) == "string" then v = ("%q"):format(v); end return tostring(v); end));
163 end
164
165 function engine:execute_query(sql, ...)
166         sql = self:prepquery(sql);
167         local stmt = assert(self.conn:prepare(sql));
168         assert(stmt:execute(...));
169         return stmt:rows();
170 end
171 function engine:execute_update(sql, ...)
172         sql = self:prepquery(sql);
173         local prepared = self.prepared;
174         local stmt = prepared[sql];
175         if not stmt then
176                 stmt = assert(self.conn:prepare(sql));
177                 prepared[sql] = stmt;
178         end
179         assert(stmt:execute(...));
180         return setmetatable({ __stmt = stmt }, result_mt);
181 end
182 engine.insert = engine.execute_update;
183 engine.select = engine.execute_query;
184 engine.delete = engine.execute_update;
185 engine.update = engine.execute_update;
186 local function debugwrap(name, f)
187         return function (self, sql, ...)
188                 debugquery(name, sql, ...)
189                 return f(self, sql, ...)
190         end
191 end
192 function engine:debug(enable)
193         self._debug = enable;
194         if enable then
195                 engine.insert = debugwrap("insert", engine.execute_update);
196                 engine.select = debugwrap("select", engine.execute_query);
197                 engine.delete = debugwrap("delete", engine.execute_update);
198                 engine.update = debugwrap("update", engine.execute_update);
199         else
200                 engine.insert = engine.execute_update;
201                 engine.select = engine.execute_query;
202                 engine.delete = engine.execute_update;
203                 engine.update = engine.execute_update;
204         end
205 end
206 local function handleerr(err)
207         log("error", "Error in SQL transaction: %s", debug_traceback(err, 3));
208         return err;
209 end
210 function engine:_transaction(func, ...)
211         if not self.conn then
212                 local ok, err = self:connect();
213                 if not ok then return ok, err; end
214         end
215         --assert(not self.__transaction, "Recursive transactions not allowed");
216         local args, n_args = {...}, select("#", ...);
217         local function f() return func(unpack(args, 1, n_args)); end
218         log("debug", "SQL transaction begin [%s]", tostring(func));
219         self.__transaction = true;
220         local success, a, b, c = xpcall(f, handleerr);
221         self.__transaction = nil;
222         if success then
223                 log("debug", "SQL transaction success [%s]", tostring(func));
224                 local ok, err = self.conn:commit();
225                 if not ok then return ok, err; end -- commit failed
226                 return success, a, b, c;
227         else
228                 log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
229                 if self.conn then self.conn:rollback(); end
230                 return success, a;
231         end
232 end
233 function engine:transaction(...)
234         local ok, ret = self:_transaction(...);
235         if not ok then
236                 local conn = self.conn;
237                 if not conn or not conn:ping() then
238                         self.conn = nil;
239                         ok, ret = self:_transaction(...);
240                 end
241         end
242         return ok, ret;
243 end
244 function engine:_create_index(index)
245         local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
246         for i=1,#index do
247                 sql = sql.."`"..index[i].."`";
248                 if i ~= #index then sql = sql..", "; end
249         end
250         sql = sql..");"
251         if self.params.driver == "PostgreSQL" then
252                 sql = sql:gsub("`", "\"");
253         elseif self.params.driver == "MySQL" then
254                 sql = sql:gsub("`([,)])", "`(20)%1");
255         end
256         if index.unique then
257                 sql = sql:gsub("^CREATE", "CREATE UNIQUE");
258         end
259         if self._debug then
260                 debugquery("create", sql);
261         end
262         return self:execute(sql);
263 end
264 function engine:_create_table(table)
265         local sql = "CREATE TABLE `"..table.name.."` (";
266         for i,col in ipairs(table.c) do
267                 local col_type = col.type;
268                 if col_type == "MEDIUMTEXT" and self.params.driver ~= "MySQL" then
269                         col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
270                 end
271                 if col.auto_increment == true and self.params.driver == "PostgreSQL" then
272                         col_type = "BIGSERIAL";
273                 end
274                 sql = sql.."`"..col.name.."` "..col_type;
275                 if col.nullable == false then sql = sql.." NOT NULL"; end
276                 if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
277                 if col.auto_increment == true then
278                         if self.params.driver == "MySQL" then
279                                 sql = sql.." AUTO_INCREMENT";
280                         elseif self.params.driver == "SQLite3" then
281                                 sql = sql.." AUTOINCREMENT";
282                         end
283                 end
284                 if i ~= #table.c then sql = sql..", "; end
285         end
286         sql = sql.. ");"
287         if self.params.driver == "PostgreSQL" then
288                 sql = sql:gsub("`", "\"");
289         elseif self.params.driver == "MySQL" then
290                 sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset));
291         end
292         if self._debug then
293                 debugquery("create", sql);
294         end
295         local success,err = self:execute(sql);
296         if not success then return success,err; end
297         for i,v in ipairs(table.__table__) do
298                 if is_index(v) then
299                         self:_create_index(v);
300                 end
301         end
302         return success;
303 end
304 function engine:set_encoding() -- to UTF-8
305         local driver = self.params.driver;
306         if driver == "SQLite3" then
307                 return self:transaction(function()
308                         for encoding in self:select"PRAGMA encoding;" do
309                                 if encoding[1] == "UTF-8" then
310                                         self.charset = "utf8";
311                                 end
312                         end
313                 end);
314         end
315         local set_names_query = "SET NAMES '%s';"
316         local charset = "utf8";
317         if driver == "MySQL" then
318                 self:transaction(function()
319                         for row in self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;" do
320                                 charset = row and row[1] or charset;
321                         end
322                 end);
323                 set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin"));
324         end
325         self.charset = charset;
326         log("debug", "Using encoding '%s' for database connection", charset);
327         local ok, err = self:transaction(function() return self:execute(set_names_query:format(charset)); end);
328         if not ok then
329                 return ok, err;
330         end
331
332         if driver == "MySQL" then
333                 local ok, actual_charset = self:transaction(function ()
334                         return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'";
335                 end);
336                 local charset_ok = true;
337                 for row in actual_charset do
338                         if row[2] ~= charset then
339                                 log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset);
340                                 charset_ok = false;
341                         end
342                 end
343                 if not charset_ok then
344                         return false, "Failed to set connection encoding";
345                 end
346         end
347
348         return true;
349 end
350 local engine_mt = { __index = engine };
351
352 local function db2uri(params)
353         return build_url{
354                 scheme = params.driver,
355                 user = params.username,
356                 password = params.password,
357                 host = params.host,
358                 port = params.port,
359                 path = params.database,
360         };
361 end
362
363 local function create_engine(self, params, onconnect)
364         return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt);
365 end
366
367 return {
368         is_column = is_column;
369         is_index = is_index;
370         is_table = is_table;
371         is_query = is_query;
372         Integer = Integer;
373         String = String;
374         Column = Column;
375         Table = Table;
376         Index = Index;
377         create_engine = create_engine;
378         db2uri = db2uri;
379 };