acdf9c6390ceffea97c37b6cb08b315313ae9125
[prosody.git] / util / sql.lua
1
2 local setmetatable, getmetatable = setmetatable, getmetatable;
3 local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113
4 local tonumber, tostring = tonumber, tostring;
5 local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
6 local t_concat = table.concat;
7 local s_char = string.char;
8 local log = require "util.logger".init("sql");
9
10 local DBI = require "DBI";
11 -- This loads all available drivers while globals are unlocked
12 -- LuaDBI should be fixed to not set globals.
13 DBI.Drivers();
14 local build_url = require "socket.url".build;
15
16 local _ENV = nil;
17
18 local column_mt = {};
19 local table_mt = {};
20 local query_mt = {};
21 --local op_mt = {};
22 local index_mt = {};
23
24 local function is_column(x) return getmetatable(x)==column_mt; end
25 local function is_index(x) return getmetatable(x)==index_mt; end
26 local function is_table(x) return getmetatable(x)==table_mt; end
27 local function is_query(x) return getmetatable(x)==query_mt; end
28 local function Integer() return "Integer()" end
29 local function String() return "String()" end
30
31 local function Column(definition)
32         return setmetatable(definition, column_mt);
33 end
34 local function Table(definition)
35         local c = {}
36         for i,col in ipairs(definition) do
37                 if is_column(col) then
38                         c[i], c[col.name] = col, col;
39                 elseif is_index(col) then
40                         col.table = definition.name;
41                 end
42         end
43         return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
44 end
45 local function Index(definition)
46         return setmetatable(definition, index_mt);
47 end
48
49 function table_mt:__tostring()
50         local s = { 'name="'..self.__table__.name..'"' }
51         for i,col in ipairs(self.__table__) do
52                 s[#s+1] = tostring(col);
53         end
54         return 'Table{ '..t_concat(s, ", ")..' }'
55 end
56 table_mt.__index = {};
57 function table_mt.__index:create(engine)
58         return engine:_create_table(self);
59 end
60 function table_mt:__call(...)
61         -- TODO
62 end
63 function column_mt:__tostring()
64         return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
65 end
66 function index_mt:__tostring()
67         local s = 'Index{ name="'..self.name..'"';
68         for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end
69         return s..' }';
70 --      return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
71 end
72
73 local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
74 local function parse_url(url)
75         local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
76         assert(scheme, "Invalid URL format");
77         local username, password, host, port;
78         local authpart, hostpart = secondpart:match("([^@]+)@([^@+])");
79         if not authpart then hostpart = secondpart; end
80         if authpart then
81                 username, password = authpart:match("([^:]*):(.*)");
82                 username = username or authpart;
83                 password = password and urldecode(password);
84         end
85         if hostpart then
86                 host, port = hostpart:match("([^:]*):(.*)");
87                 host = host or hostpart;
88                 port = port and assert(tonumber(port), "Invalid URL format");
89         end
90         return {
91                 scheme = scheme:lower();
92                 username = username; password = password;
93                 host = host; port = port;
94                 database = #database > 0 and database or nil;
95         };
96 end
97
98 local engine = {};
99 function engine:connect()
100         if self.conn then return true; end
101
102         local params = self.params;
103         assert(params.driver, "no driver")
104         log("debug", "Connecting to [%s] %s...", params.driver, params.database);
105         local ok, dbh, err = pcall(DBI.Connect,
106                 params.driver, params.database,
107                 params.username, params.password,
108                 params.host, params.port
109         );
110         if not ok then return ok, dbh; end
111         if not dbh then return nil, err; end
112         dbh:autocommit(false); -- don't commit automatically
113         self.conn = dbh;
114         self.prepared = {};
115         local ok, err = self:set_encoding();
116         if not ok then
117                 return ok, err;
118         end
119         local ok, err = self:onconnect();
120         if ok == false then
121                 return ok, err;
122         end
123         return true;
124 end
125 function engine:onconnect()
126         -- Override from create_engine()
127 end
128
129 function engine:prepquery(sql)
130         if self.params.driver == "PostgreSQL" then
131                 sql = sql:gsub("`", "\"");
132         end
133         return sql;
134 end
135
136 function engine:execute(sql, ...)
137         local success, err = self:connect();
138         if not success then return success, err; end
139         local prepared = self.prepared;
140
141         local stmt = prepared[sql];
142         if not stmt then
143                 local err;
144                 stmt, err = self.conn:prepare(sql);
145                 if not stmt then return stmt, err; end
146                 prepared[sql] = stmt;
147         end
148
149         local success, err = stmt:execute(...);
150         if not success then return success, err; end
151         return stmt;
152 end
153
154 local result_mt = { __index = {
155         affected = function(self) return self.__stmt:affected(); end;
156         rowcount = function(self) return self.__stmt:rowcount(); end;
157 } };
158
159 local function debugquery(where, sql, ...)
160         local i = 0; local a = {...}
161         log("debug", "[%s] %s", where, sql:gsub("%?", function () i = i + 1; local v = a[i]; if type(v) == "string" then v = ("%q"):format(v); end return tostring(v); end));
162 end
163
164 function engine:execute_query(sql, ...)
165         sql = self:prepquery(sql);
166         local stmt = assert(self.conn:prepare(sql));
167         assert(stmt:execute(...));
168         return stmt:rows();
169 end
170 function engine:execute_update(sql, ...)
171         sql = self:prepquery(sql);
172         local prepared = self.prepared;
173         local stmt = prepared[sql];
174         if not stmt then
175                 stmt = assert(self.conn:prepare(sql));
176                 prepared[sql] = stmt;
177         end
178         assert(stmt:execute(...));
179         return setmetatable({ __stmt = stmt }, result_mt);
180 end
181 engine.insert = engine.execute_update;
182 engine.select = engine.execute_query;
183 engine.delete = engine.execute_update;
184 engine.update = engine.execute_update;
185 local function debugwrap(name, f)
186         return function (self, sql, ...)
187                 debugquery(name, sql, ...)
188                 return f(self, sql, ...)
189         end
190 end
191 function engine:debug(enable)
192         self._debug = enable;
193         if enable then
194                 engine.insert = debugwrap("insert", engine.execute_update);
195                 engine.select = debugwrap("select", engine.execute_query);
196                 engine.delete = debugwrap("delete", engine.execute_update);
197                 engine.update = debugwrap("update", engine.execute_update);
198         else
199                 engine.insert = engine.execute_update;
200                 engine.select = engine.execute_query;
201                 engine.delete = engine.execute_update;
202                 engine.update = engine.execute_update;
203         end
204 end
205 local function handleerr(err)
206         log("error", "Error in SQL transaction: %s", debug_traceback(err, 3));
207         return err;
208 end
209 function engine:_transaction(func, ...)
210         if not self.conn then
211                 local ok, err = self:connect();
212                 if not ok then return ok, err; end
213         end
214         --assert(not self.__transaction, "Recursive transactions not allowed");
215         local args, n_args = {...}, select("#", ...);
216         local function f() return func(unpack(args, 1, n_args)); end
217         log("debug", "SQL transaction begin [%s]", tostring(func));
218         self.__transaction = true;
219         local success, a, b, c = xpcall(f, handleerr);
220         self.__transaction = nil;
221         if success then
222                 log("debug", "SQL transaction success [%s]", tostring(func));
223                 local ok, err = self.conn:commit();
224                 if not ok then return ok, err; end -- commit failed
225                 return success, a, b, c;
226         else
227                 log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
228                 if self.conn then self.conn:rollback(); end
229                 return success, a;
230         end
231 end
232 function engine:transaction(...)
233         local ok, ret = self:_transaction(...);
234         if not ok then
235                 local conn = self.conn;
236                 if not conn or not conn:ping() then
237                         self.conn = nil;
238                         ok, ret = self:_transaction(...);
239                 end
240         end
241         return ok, ret;
242 end
243 function engine:_create_index(index)
244         local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
245         for i=1,#index do
246                 sql = sql.."`"..index[i].."`";
247                 if i ~= #index then sql = sql..", "; end
248         end
249         sql = sql..");"
250         if self.params.driver == "PostgreSQL" then
251                 sql = sql:gsub("`", "\"");
252         elseif self.params.driver == "MySQL" then
253                 sql = sql:gsub("`([,)])", "`(20)%1");
254         end
255         if index.unique then
256                 sql = sql:gsub("^CREATE", "CREATE UNIQUE");
257         end
258         if self._debug then
259                 debugquery("create", sql);
260         end
261         return self:execute(sql);
262 end
263 function engine:_create_table(table)
264         local sql = "CREATE TABLE `"..table.name.."` (";
265         for i,col in ipairs(table.c) do
266                 local col_type = col.type;
267                 if col_type == "MEDIUMTEXT" and self.params.driver ~= "MySQL" then
268                         col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
269                 end
270                 if col.auto_increment == true and self.params.driver == "PostgreSQL" then
271                         col_type = "BIGSERIAL";
272                 end
273                 sql = sql.."`"..col.name.."` "..col_type;
274                 if col.nullable == false then sql = sql.." NOT NULL"; end
275                 if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
276                 if col.auto_increment == true then
277                         if self.params.driver == "MySQL" then
278                                 sql = sql.." AUTO_INCREMENT";
279                         elseif self.params.driver == "SQLite3" then
280                                 sql = sql.." AUTOINCREMENT";
281                         end
282                 end
283                 if i ~= #table.c then sql = sql..", "; end
284         end
285         sql = sql.. ");"
286         if self.params.driver == "PostgreSQL" then
287                 sql = sql:gsub("`", "\"");
288         elseif self.params.driver == "MySQL" then
289                 sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset));
290         end
291         if self._debug then
292                 debugquery("create", sql);
293         end
294         local success,err = self:execute(sql);
295         if not success then return success,err; end
296         for i,v in ipairs(table.__table__) do
297                 if is_index(v) then
298                         self:_create_index(v);
299                 end
300         end
301         return success;
302 end
303 function engine:set_encoding() -- to UTF-8
304         local driver = self.params.driver;
305         if driver == "SQLite3" then
306                 return self:transaction(function()
307                         for encoding in self:select"PRAGMA encoding;" do
308                                 if encoding[1] == "UTF-8" then
309                                         self.charset = "utf8";
310                                 end
311                         end
312                 end);
313         end
314         local set_names_query = "SET NAMES '%s';"
315         local charset = "utf8";
316         if driver == "MySQL" then
317                 self:transaction(function()
318                         for row in self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;" do
319                                 charset = row and row[1] or charset;
320                         end
321                 end);
322                 set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin"));
323         end
324         self.charset = charset;
325         log("debug", "Using encoding '%s' for database connection", charset);
326         local ok, err = self:transaction(function() return self:execute(set_names_query:format(charset)); end);
327         if not ok then
328                 return ok, err;
329         end
330
331         if driver == "MySQL" then
332                 local ok, actual_charset = self:transaction(function ()
333                         return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'";
334                 end);
335                 local charset_ok = true;
336                 for row in actual_charset do
337                         if row[2] ~= charset then
338                                 log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset);
339                                 charset_ok = false;
340                         end
341                 end
342                 if not charset_ok then
343                         return false, "Failed to set connection encoding";
344                 end
345         end
346
347         return true;
348 end
349 local engine_mt = { __index = engine };
350
351 local function db2uri(params)
352         return build_url{
353                 scheme = params.driver,
354                 user = params.username,
355                 password = params.password,
356                 host = params.host,
357                 port = params.port,
358                 path = params.database,
359         };
360 end
361
362 local function create_engine(self, params, onconnect)
363         return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt);
364 end
365
366 return {
367         is_column = is_column;
368         is_index = is_index;
369         is_table = is_table;
370         is_query = is_query;
371         Integer = Integer;
372         String = String;
373         Column = Column;
374         Table = Table;
375         Index = Index;
376         create_engine = create_engine;
377         db2uri = db2uri;
378 };