Merge 0.9->0.10
[prosody.git] / util / sql.lua
1
2 local setmetatable, getmetatable = setmetatable, getmetatable;
3 local ipairs, unpack, select = ipairs, unpack, select;
4 local tonumber, tostring = tonumber, tostring;
5 local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
6 local t_concat = table.concat;
7 local s_char = string.char;
8 local log = require "util.logger".init("sql");
9
10 local DBI = require "DBI";
11 -- This loads all available drivers while globals are unlocked
12 -- LuaDBI should be fixed to not set globals.
13 DBI.Drivers();
14 local build_url = require "socket.url".build;
15
16 module("sql")
17
18 local column_mt = {};
19 local table_mt = {};
20 local query_mt = {};
21 --local op_mt = {};
22 local index_mt = {};
23
24 function is_column(x) return getmetatable(x)==column_mt; end
25 function is_index(x) return getmetatable(x)==index_mt; end
26 function is_table(x) return getmetatable(x)==table_mt; end
27 function is_query(x) return getmetatable(x)==query_mt; end
28 --function is_op(x) return getmetatable(x)==op_mt; end
29 --function expr(...) return setmetatable({...}, op_mt); end
30 function Integer(n) return "Integer()" end
31 function String(n) return "String()" end
32
33 --[[local ops = {
34         __add = function(a, b) return "("..a.."+"..b..")" end;
35         __sub = function(a, b) return "("..a.."-"..b..")" end;
36         __mul = function(a, b) return "("..a.."*"..b..")" end;
37         __div = function(a, b) return "("..a.."/"..b..")" end;
38         __mod = function(a, b) return "("..a.."%"..b..")" end;
39         __pow = function(a, b) return "POW("..a..","..b..")" end;
40         __unm = function(a) return "NOT("..a..")" end;
41         __len = function(a) return "COUNT("..a..")" end;
42         __eq = function(a, b) return "("..a.."=="..b..")" end;
43         __lt = function(a, b) return "("..a.."<"..b..")" end;
44         __le = function(a, b) return "("..a.."<="..b..")" end;
45 };
46
47 local functions = {
48
49 };
50
51 local cmap = {
52         [Integer] = Integer();
53         [String] = String();
54 };]]
55
56 function Column(definition)
57         return setmetatable(definition, column_mt);
58 end
59 function Table(definition)
60         local c = {}
61         for i,col in ipairs(definition) do
62                 if is_column(col) then
63                         c[i], c[col.name] = col, col;
64                 elseif is_index(col) then
65                         col.table = definition.name;
66                 end
67         end
68         return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
69 end
70 function Index(definition)
71         return setmetatable(definition, index_mt);
72 end
73
74 function table_mt:__tostring()
75         local s = { 'name="'..self.__table__.name..'"' }
76         for i,col in ipairs(self.__table__) do
77                 s[#s+1] = tostring(col);
78         end
79         return 'Table{ '..t_concat(s, ", ")..' }'
80 end
81 table_mt.__index = {};
82 function table_mt.__index:create(engine)
83         return engine:_create_table(self);
84 end
85 function table_mt:__call(...)
86         -- TODO
87 end
88 function column_mt:__tostring()
89         return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
90 end
91 function index_mt:__tostring()
92         local s = 'Index{ name="'..self.name..'"';
93         for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end
94         return s..' }';
95 --      return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
96 end
97 --
98
99 local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
100 local function parse_url(url)
101         local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
102         assert(scheme, "Invalid URL format");
103         local username, password, host, port;
104         local authpart, hostpart = secondpart:match("([^@]+)@([^@+])");
105         if not authpart then hostpart = secondpart; end
106         if authpart then
107                 username, password = authpart:match("([^:]*):(.*)");
108                 username = username or authpart;
109                 password = password and urldecode(password);
110         end
111         if hostpart then
112                 host, port = hostpart:match("([^:]*):(.*)");
113                 host = host or hostpart;
114                 port = port and assert(tonumber(port), "Invalid URL format");
115         end
116         return {
117                 scheme = scheme:lower();
118                 username = username; password = password;
119                 host = host; port = port;
120                 database = #database > 0 and database or nil;
121         };
122 end
123
124 --[[local session = {};
125
126 function session.query(...)
127         local rets = {...};
128         local query = setmetatable({ __rets = rets, __filters }, query_mt);
129         return query;
130 end
131 --
132
133 local function db2uri(params)
134         return build_url{
135                 scheme = params.driver,
136                 user = params.username,
137                 password = params.password,
138                 host = params.host,
139                 port = params.port,
140                 path = params.database,
141         };
142 end]]
143
144 local engine = {};
145 function engine:connect()
146         if self.conn then return true; end
147
148         local params = self.params;
149         assert(params.driver, "no driver")
150         local dbh, err = DBI.Connect(
151                 params.driver, params.database,
152                 params.username, params.password,
153                 params.host, params.port
154         );
155         if not dbh then return nil, err; end
156         dbh:autocommit(false); -- don't commit automatically
157         self.conn = dbh;
158         self.prepared = {};
159         self:set_encoding();
160         return true;
161 end
162 function engine:execute(sql, ...)
163         local success, err = self:connect();
164         if not success then return success, err; end
165         local prepared = self.prepared;
166
167         local stmt = prepared[sql];
168         if not stmt then
169                 local err;
170                 stmt, err = self.conn:prepare(sql);
171                 if not stmt then return stmt, err; end
172                 prepared[sql] = stmt;
173         end
174
175         local success, err = stmt:execute(...);
176         if not success then return success, err; end
177         return stmt;
178 end
179
180 local result_mt = { __index = {
181         affected = function(self) return self.__stmt:affected(); end;
182         rowcount = function(self) return self.__stmt:rowcount(); end;
183 } };
184
185 function engine:execute_query(sql, ...)
186         if self.params.driver == "PostgreSQL" then
187                 sql = sql:gsub("`", "\"");
188         end
189         local stmt = assert(self.conn:prepare(sql));
190         assert(stmt:execute(...));
191         return stmt:rows();
192 end
193 function engine:execute_update(sql, ...)
194         if self.params.driver == "PostgreSQL" then
195                 sql = sql:gsub("`", "\"");
196         end
197         local prepared = self.prepared;
198         local stmt = prepared[sql];
199         if not stmt then
200                 stmt = assert(self.conn:prepare(sql));
201                 prepared[sql] = stmt;
202         end
203         assert(stmt:execute(...));
204         return setmetatable({ __stmt = stmt }, result_mt);
205 end
206 engine.insert = engine.execute_update;
207 engine.select = engine.execute_query;
208 engine.delete = engine.execute_update;
209 engine.update = engine.execute_update;
210 function engine:_transaction(func, ...)
211         if not self.conn then
212                 local a,b = self:connect();
213                 if not a then return a,b; end
214         end
215         --assert(not self.__transaction, "Recursive transactions not allowed");
216         local args, n_args = {...}, select("#", ...);
217         local function f() return func(unpack(args, 1, n_args)); end
218         self.__transaction = true;
219         local success, a, b, c = xpcall(f, debug_traceback);
220         self.__transaction = nil;
221         if success then
222                 log("debug", "SQL transaction success [%s]", tostring(func));
223                 local ok, err = self.conn:commit();
224                 if not ok then return ok, err; end -- commit failed
225                 return success, a, b, c;
226         else
227                 log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
228                 if self.conn then self.conn:rollback(); end
229                 return success, a;
230         end
231 end
232 function engine:transaction(...)
233         local a,b = self:_transaction(...);
234         if not a then
235                 local conn = self.conn;
236                 if not conn or not conn:ping() then
237                         self.conn = nil;
238                         a,b = self:_transaction(...);
239                 end
240         end
241         return a,b;
242 end
243 function engine:_create_index(index)
244         local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
245         for i=1,#index do
246                 sql = sql.."`"..index[i].."`";
247                 if i ~= #index then sql = sql..", "; end
248         end
249         sql = sql..");"
250         if self.params.driver == "PostgreSQL" then
251                 sql = sql:gsub("`", "\"");
252         elseif self.params.driver == "MySQL" then
253                 sql = sql:gsub("`([,)])", "`(20)%1");
254         end
255         if index.unique then
256                 sql = sql:gsub("^CREATE", "CREATE UNIQUE");
257         end
258         --print(sql);
259         return self:execute(sql);
260 end
261 function engine:_create_table(table)
262         local sql = "CREATE TABLE `"..table.name.."` (";
263         for i,col in ipairs(table.c) do
264                 local col_type = col.type;
265                 if col_type == "MEDIUMTEXT" and self.params.driver ~= "MySQL" then
266                         col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
267                 end
268                 if col.auto_increment == true and self.params.driver == "PostgreSQL" then
269                         col_type = "BIGSERIAL";
270                 end
271                 sql = sql.."`"..col.name.."` "..col_type;
272                 if col.nullable == false then sql = sql.." NOT NULL"; end
273                 if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
274                 if col.auto_increment == true then
275                         if self.params.driver == "MySQL" then
276                                 sql = sql.." AUTO_INCREMENT";
277                         elseif self.params.driver == "SQLite3" then
278                                 sql = sql.." AUTOINCREMENT";
279                         end
280                 end
281                 if i ~= #table.c then sql = sql..", "; end
282         end
283         sql = sql.. ");"
284         if self.params.driver == "PostgreSQL" then
285                 sql = sql:gsub("`", "\"");
286         elseif self.params.driver == "MySQL" then
287                 sql = sql:gsub(";$", " CHARACTER SET 'utf8' COLLATE 'utf8_bin';");
288         end
289         local success,err = self:execute(sql);
290         if not success then return success,err; end
291         for i,v in ipairs(table.__table__) do
292                 if is_index(v) then
293                         self:_create_index(v);
294                 end
295         end
296         return success;
297 end
298 function engine:set_encoding() -- to UTF-8
299         local driver = self.params.driver;
300         if driver == "SQLite3" then
301                 return self:transaction(function()
302                         if self:select"PRAGMA encoding;"()[1] == "UTF-8" then
303                                 self.charset = "utf8";
304                         end
305                 end);
306         end
307         local set_names_query = "SET NAMES '%s';"
308         local charset = "utf8";
309         if driver == "MySQL" then
310                 set_names_query = set_names_query:gsub(";$", " COLLATE 'utf8_bin';");
311                 local ok, charsets = self:transaction(function()
312                         return self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;";
313                 end);
314                 local row = ok and charsets();
315                 charset = row and row[1] or charset;
316         end
317         self.charset = charset;
318         return self:transaction(function() return self:execute(set_names_query:format(charset)); end);
319 end
320 local engine_mt = { __index = engine };
321
322 local function db2uri(params)
323         return build_url{
324                 scheme = params.driver,
325                 user = params.username,
326                 password = params.password,
327                 host = params.host,
328                 port = params.port,
329                 path = params.database,
330         };
331 end
332 local engine_cache = {}; -- TODO make weak valued
333 function create_engine(self, params)
334         local url = db2uri(params);
335         if not engine_cache[url] then
336                 local engine = setmetatable({ url = url, params = params }, engine_mt);
337                 engine_cache[url] = engine;
338         end
339         return engine_cache[url];
340 end
341
342
343 --[[Users = Table {
344         name="users";
345         Column { name="user_id", type=String(), primary_key=true };
346 };
347 print(Users)
348 print(Users.c.user_id)]]
349
350 --local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase');
351 --[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" };
352
353 local i = 0;
354 for row in assert(engine:execute("select * from sqlite_master")):rows(true) do
355         i = i+1;
356         print(i);
357         for k,v in pairs(row) do
358                 print("",k,v);
359         end
360 end
361 print("---")
362
363 Prosody = Table {
364         name="prosody";
365         Column { name="host", type="TEXT", nullable=false };
366         Column { name="user", type="TEXT", nullable=false };
367         Column { name="store", type="TEXT", nullable=false };
368         Column { name="key", type="TEXT", nullable=false };
369         Column { name="type", type="TEXT", nullable=false };
370         Column { name="value", type="TEXT", nullable=false };
371         Index { name="prosody_index", "host", "user", "store", "key" };
372 };
373 --print(Prosody);
374 assert(engine:transaction(function()
375         assert(Prosody:create(engine));
376 end));
377
378 for row in assert(engine:execute("select user from prosody")):rows(true) do
379         print("username:", row['username'])
380 end
381 --result.close();]]
382
383 return _M;