util.sql: Remove built-in engine caching. This is the wrong layer to do this, and...
[prosody.git] / util / sql.lua
1
2 local setmetatable, getmetatable = setmetatable, getmetatable;
3 local ipairs, unpack, select = ipairs, unpack, select;
4 local tonumber, tostring = tonumber, tostring;
5 local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
6 local t_concat = table.concat;
7 local s_char = string.char;
8 local log = require "util.logger".init("sql");
9
10 local DBI = require "DBI";
11 -- This loads all available drivers while globals are unlocked
12 -- LuaDBI should be fixed to not set globals.
13 DBI.Drivers();
14 local build_url = require "socket.url".build;
15
16 module("sql")
17
18 local column_mt = {};
19 local table_mt = {};
20 local query_mt = {};
21 --local op_mt = {};
22 local index_mt = {};
23
24 function is_column(x) return getmetatable(x)==column_mt; end
25 function is_index(x) return getmetatable(x)==index_mt; end
26 function is_table(x) return getmetatable(x)==table_mt; end
27 function is_query(x) return getmetatable(x)==query_mt; end
28 function Integer(n) return "Integer()" end
29 function String(n) return "String()" end
30
31 function Column(definition)
32         return setmetatable(definition, column_mt);
33 end
34 function Table(definition)
35         local c = {}
36         for i,col in ipairs(definition) do
37                 if is_column(col) then
38                         c[i], c[col.name] = col, col;
39                 elseif is_index(col) then
40                         col.table = definition.name;
41                 end
42         end
43         return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
44 end
45 function Index(definition)
46         return setmetatable(definition, index_mt);
47 end
48
49 function table_mt:__tostring()
50         local s = { 'name="'..self.__table__.name..'"' }
51         for i,col in ipairs(self.__table__) do
52                 s[#s+1] = tostring(col);
53         end
54         return 'Table{ '..t_concat(s, ", ")..' }'
55 end
56 table_mt.__index = {};
57 function table_mt.__index:create(engine)
58         return engine:_create_table(self);
59 end
60 function table_mt:__call(...)
61         -- TODO
62 end
63 function column_mt:__tostring()
64         return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
65 end
66 function index_mt:__tostring()
67         local s = 'Index{ name="'..self.name..'"';
68         for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end
69         return s..' }';
70 --      return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
71 end
72
73 local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
74 local function parse_url(url)
75         local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
76         assert(scheme, "Invalid URL format");
77         local username, password, host, port;
78         local authpart, hostpart = secondpart:match("([^@]+)@([^@+])");
79         if not authpart then hostpart = secondpart; end
80         if authpart then
81                 username, password = authpart:match("([^:]*):(.*)");
82                 username = username or authpart;
83                 password = password and urldecode(password);
84         end
85         if hostpart then
86                 host, port = hostpart:match("([^:]*):(.*)");
87                 host = host or hostpart;
88                 port = port and assert(tonumber(port), "Invalid URL format");
89         end
90         return {
91                 scheme = scheme:lower();
92                 username = username; password = password;
93                 host = host; port = port;
94                 database = #database > 0 and database or nil;
95         };
96 end
97
98 local engine = {};
99 function engine:connect()
100         if self.conn then return true; end
101
102         local params = self.params;
103         assert(params.driver, "no driver")
104         local dbh, err = DBI.Connect(
105                 params.driver, params.database,
106                 params.username, params.password,
107                 params.host, params.port
108         );
109         if not dbh then return nil, err; end
110         dbh:autocommit(false); -- don't commit automatically
111         self.conn = dbh;
112         self.prepared = {};
113         self:set_encoding();
114         return true;
115 end
116 function engine:execute(sql, ...)
117         local success, err = self:connect();
118         if not success then return success, err; end
119         local prepared = self.prepared;
120
121         local stmt = prepared[sql];
122         if not stmt then
123                 local err;
124                 stmt, err = self.conn:prepare(sql);
125                 if not stmt then return stmt, err; end
126                 prepared[sql] = stmt;
127         end
128
129         local success, err = stmt:execute(...);
130         if not success then return success, err; end
131         return stmt;
132 end
133
134 local result_mt = { __index = {
135         affected = function(self) return self.__stmt:affected(); end;
136         rowcount = function(self) return self.__stmt:rowcount(); end;
137 } };
138
139 function engine:execute_query(sql, ...)
140         if self.params.driver == "PostgreSQL" then
141                 sql = sql:gsub("`", "\"");
142         end
143         local stmt = assert(self.conn:prepare(sql));
144         assert(stmt:execute(...));
145         return stmt:rows();
146 end
147 function engine:execute_update(sql, ...)
148         if self.params.driver == "PostgreSQL" then
149                 sql = sql:gsub("`", "\"");
150         end
151         local prepared = self.prepared;
152         local stmt = prepared[sql];
153         if not stmt then
154                 stmt = assert(self.conn:prepare(sql));
155                 prepared[sql] = stmt;
156         end
157         assert(stmt:execute(...));
158         return setmetatable({ __stmt = stmt }, result_mt);
159 end
160 engine.insert = engine.execute_update;
161 engine.select = engine.execute_query;
162 engine.delete = engine.execute_update;
163 engine.update = engine.execute_update;
164 function engine:_transaction(func, ...)
165         if not self.conn then
166                 local ok, err = self:connect();
167                 if not ok then return ok, err; end
168         end
169         --assert(not self.__transaction, "Recursive transactions not allowed");
170         local args, n_args = {...}, select("#", ...);
171         local function f() return func(unpack(args, 1, n_args)); end
172         self.__transaction = true;
173         local success, a, b, c = xpcall(f, debug_traceback);
174         self.__transaction = nil;
175         if success then
176                 log("debug", "SQL transaction success [%s]", tostring(func));
177                 local ok, err = self.conn:commit();
178                 if not ok then return ok, err; end -- commit failed
179                 return success, a, b, c;
180         else
181                 log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
182                 if self.conn then self.conn:rollback(); end
183                 return success, a;
184         end
185 end
186 function engine:transaction(...)
187         local ok, ret = self:_transaction(...);
188         if not ok then
189                 local conn = self.conn;
190                 if not conn or not conn:ping() then
191                         self.conn = nil;
192                         ok, ret = self:_transaction(...);
193                 end
194         end
195         return ok, ret;
196 end
197 function engine:_create_index(index)
198         local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
199         for i=1,#index do
200                 sql = sql.."`"..index[i].."`";
201                 if i ~= #index then sql = sql..", "; end
202         end
203         sql = sql..");"
204         if self.params.driver == "PostgreSQL" then
205                 sql = sql:gsub("`", "\"");
206         elseif self.params.driver == "MySQL" then
207                 sql = sql:gsub("`([,)])", "`(20)%1");
208         end
209         if index.unique then
210                 sql = sql:gsub("^CREATE", "CREATE UNIQUE");
211         end
212         --print(sql);
213         return self:execute(sql);
214 end
215 function engine:_create_table(table)
216         local sql = "CREATE TABLE `"..table.name.."` (";
217         for i,col in ipairs(table.c) do
218                 local col_type = col.type;
219                 if col_type == "MEDIUMTEXT" and self.params.driver ~= "MySQL" then
220                         col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
221                 end
222                 if col.auto_increment == true and self.params.driver == "PostgreSQL" then
223                         col_type = "BIGSERIAL";
224                 end
225                 sql = sql.."`"..col.name.."` "..col_type;
226                 if col.nullable == false then sql = sql.." NOT NULL"; end
227                 if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
228                 if col.auto_increment == true then
229                         if self.params.driver == "MySQL" then
230                                 sql = sql.." AUTO_INCREMENT";
231                         elseif self.params.driver == "SQLite3" then
232                                 sql = sql.." AUTOINCREMENT";
233                         end
234                 end
235                 if i ~= #table.c then sql = sql..", "; end
236         end
237         sql = sql.. ");"
238         if self.params.driver == "PostgreSQL" then
239                 sql = sql:gsub("`", "\"");
240         elseif self.params.driver == "MySQL" then
241                 sql = sql:gsub(";$", " CHARACTER SET 'utf8' COLLATE 'utf8_bin';");
242         end
243         local success,err = self:execute(sql);
244         if not success then return success,err; end
245         for i,v in ipairs(table.__table__) do
246                 if is_index(v) then
247                         self:_create_index(v);
248                 end
249         end
250         return success;
251 end
252 function engine:set_encoding() -- to UTF-8
253         local driver = self.params.driver;
254         if driver == "SQLite3" then
255                 return self:transaction(function()
256                         if self:select"PRAGMA encoding;"()[1] == "UTF-8" then
257                                 self.charset = "utf8";
258                         end
259                 end);
260         end
261         local set_names_query = "SET NAMES '%s';"
262         local charset = "utf8";
263         if driver == "MySQL" then
264                 set_names_query = set_names_query:gsub(";$", " COLLATE 'utf8_bin';");
265                 local ok, charsets = self:transaction(function()
266                         return self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;";
267                 end);
268                 local row = ok and charsets();
269                 charset = row and row[1] or charset;
270         end
271         self.charset = charset;
272         return self:transaction(function() return self:execute(set_names_query:format(charset)); end);
273 end
274 local engine_mt = { __index = engine };
275
276 function db2uri(params)
277         return build_url{
278                 scheme = params.driver,
279                 user = params.username,
280                 password = params.password,
281                 host = params.host,
282                 port = params.port,
283                 path = params.database,
284         };
285 end
286
287 function create_engine(self, params, onconnect)
288         return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt);
289 end
290
291 return _M;