2 local setmetatable, getmetatable = setmetatable, getmetatable;
3 local ipairs, unpack, select = ipairs, unpack, select;
4 local tonumber, tostring = tonumber, tostring;
5 local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
6 local t_concat = table.concat;
7 local s_char = string.char;
8 local log = require "util.logger".init("sql");
10 local DBI = require "DBI";
11 -- This loads all available drivers while globals are unlocked
12 -- LuaDBI should be fixed to not set globals.
14 local build_url = require "socket.url".build;
--- Schema type predicates.
-- Each returns true when `x` carries the matching schema metatable.
function is_column(value)
	local mt = getmetatable(value);
	return mt == column_mt;
end
function is_index(value)
	local mt = getmetatable(value);
	return mt == index_mt;
end
function is_table(value)
	local mt = getmetatable(value);
	return mt == table_mt;
end
function is_query(value)
	local mt = getmetatable(value);
	return mt == query_mt;
end
--- Placeholder column-type constructors.
-- The argument is currently ignored; a fixed descriptive string is returned.
function Integer(_)
	return "Integer()";
end
function String(_)
	return "String()";
end
--- Mark a definition table as a Column by attaching column_mt.
-- The table is modified in place and returned.
function Column(definition)
	setmetatable(definition, column_mt);
	return definition;
end
--- Build a Table object from a definition list.
-- `definition.name` is the table name. The array part holds Column and Index
-- objects. Columns are collected into `c`, where they are reachable by name
-- and by a dense 1..n position. Each Index is tagged with the owning table's
-- name in its `table` field.
-- @param definition table
-- @return object with fields `__table__` (the definition), `c` and `name`
function Table(definition)
	-- Fresh per-call registry. Previously `c` was never declared (implicit global).
	local c, n_columns = {}, 0;
	for _, col in ipairs(definition) do
		if is_column(col) then
			-- Dense numbering so ipairs/# over `c` see every column even when
			-- Index entries are interleaved in the definition.
			n_columns = n_columns + 1;
			c[n_columns] = col;
			c[col.name] = col;
		elseif is_index(col) then
			col.table = definition.name;
		end
	end
	return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
end
--- Mark a definition table as an Index by attaching index_mt.
-- The array part of `definition` lists the indexed column names.
function Index(definition)
	local index = setmetatable(definition, index_mt);
	return index;
end
--- Render a Table as `Table{ name="…", <entry>, … }`.
-- Each entry is tostring() of a column/index from the definition's array part.
function table_mt:__tostring()
	local def = self.__table__;
	local parts = { 'name="'..def.name..'"' };
	for _, entry in ipairs(def) do
		parts[#parts + 1] = tostring(entry);
	end
	local body = t_concat(parts, ", ");
	return 'Table{ '..body..' }';
end
-- Methods available on Table objects (looked up through __index).
table_mt.__index = {
	--- Create this table in the database behind `engine`.
	create = function(self, engine)
		return engine:_create_table(self);
	end;
};
-- __call metamethod: invoked when a Table object is called like a function.
-- NOTE(review): only this header line is present in this view; no body or
-- closing `end` is visible, so what calling a Table should do cannot be
-- determined here — confirm.
60 function table_mt:__call(...)
--- Render a Column as `Column{ name="…", type="…" }` (values are not escaped).
function column_mt:__tostring()
	local name, col_type = self.name, self.type;
	return 'Column{ name="' .. name .. '", type="' .. col_type .. '" }';
end
--- Render an Index as `Index{ name="…", "col1", "col2", … }`.
-- Column names are double-quoted, with backslashes and double quotes
-- backslash-escaped. A __tostring metamethod must return a string; the
-- previous body built the text but never returned it.
-- @return string
function index_mt:__tostring()
	local parts = { 'Index{ name="'..self.name..'"' };
	for i = 1, #self do
		-- Parenthesised so only gsub's first result (the string) is kept.
		local column = (self[i]:gsub("[\\\"]", "\\%1"));
		parts[#parts + 1] = '"'..column..'"';
	end
	return t_concat(parts, ", ")..' }';
end
--- Decode %XX percent-escapes in `s`; a nil/false argument is returned as-is.
local function urldecode(s)
	if not s then return s; end
	local decoded = s:gsub("%%(%x%x)", function (hex)
		return s_char(tonumber(hex, 16));
	end);
	return decoded;
end
--- Parse a database URL of the form
--   scheme://[user[:password]@]host[:port][/database]
-- @param url string
-- @return table with fields scheme (lower-cased), username, password
--   (percent-decoded), host, port (number) and database (nil when empty)
-- @raise "Invalid URL format" when the scheme part or the port is malformed
local function parse_url(url)
	local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
	assert(scheme, "Invalid URL format");
	local username, password, host, port;
	-- Credentials end at the last "@". The old pattern "([^@+])" captured
	-- only the first character of the host.
	local authpart, hostpart = secondpart:match("^(.+)@([^@]*)$");
	if not authpart then hostpart = secondpart; end
	if authpart then
		username, password = authpart:match("([^:]*):(.*)");
		username = username or authpart;
		password = password and urldecode(password);
	end
	host, port = hostpart:match("([^:]*):(.*)");
	host = host or hostpart;
	port = port and assert(tonumber(port), "Invalid URL format");
	return {
		scheme = scheme:lower();
		username = username; password = password;
		host = host; port = port;
		database = #database > 0 and database or nil;
	};
end
--- Open the database connection described by self.params, unless already open.
-- On success:
--   * the DBI handle is stored in self.conn, with autocommit disabled;
--   * the prepared-statement cache self.prepared is reset;
--   * the connection encoding is configured;
--   * the onconnect hook is run.
-- @return true on success; nil/false plus an error message on failure
function engine:connect()
	if self.conn then return true; end

	-- A setup step fails when it returns false, or nil together with an error.
	-- A bare nil (e.g. the default no-op onconnect hook) counts as success.
	local function failed(ok, why) return ok == false or (ok == nil and why ~= nil); end

	local params = self.params;
	assert(params.driver, "no driver");
	-- Routine progress message, not an error condition.
	log("debug", "Connecting to [%s] %s...", params.driver, params.database);
	local dbh, err = DBI.Connect(
		params.driver, params.database,
		params.username, params.password,
		params.host, params.port
	);
	if not dbh then return nil, err; end
	dbh:autocommit(false); -- don't commit automatically
	-- Store the handle before the setup steps: they run queries through
	-- self:transaction(), which would otherwise try to connect again.
	self.conn = dbh;
	-- execute()/execute_update() cache statements here; start empty per connection.
	self.prepared = {};

	local enc_ok, enc_err = self:set_encoding();
	if failed(enc_ok, enc_err) then
		self.conn = nil; -- don't leave a half-initialised connection behind
		return enc_ok, enc_err;
	end
	local hook_ok, hook_err = self:onconnect();
	if failed(hook_ok, hook_err) then
		self.conn = nil;
		return hook_ok, hook_err;
	end
	return true;
end
--- Hook run after a connection is established; does nothing by default.
-- create_engine() may store a replacement as the engine's own `onconnect` field.
engine.onconnect = function(self)
end
--- Prepare (reusing the per-connection cache) and execute `sql` with bound args.
-- Connects first if necessary.
-- @return the executed statement handle, or nil/false plus an error message
function engine:execute(sql, ...)
	local ok, err = self:connect();
	if not ok then return ok, err; end

	local prepared = self.prepared;
	local stmt = prepared[sql];
	-- Only prepare on a cache miss.
	if not stmt then
		stmt, err = self.conn:prepare(sql);
		if not stmt then return stmt, err; end
		prepared[sql] = stmt;
	end

	local success, exec_err = stmt:execute(...);
	if not success then return success, exec_err; end
	-- Callers such as _create_table() test this value for success.
	return stmt;
end
-- Metatable for execute_update() results. The statement handle is stored in
-- `__stmt`, and its affected()/rowcount() are exposed as methods.
local result_mt = {
	__index = {
		affected = function(result) return result.__stmt:affected(); end,
		rowcount = function(result) return result.__stmt:rowcount(); end,
	},
};
--- Run a query and return an iterator over its result rows.
-- Backtick-quoted identifiers are rewritten to double quotes for PostgreSQL.
-- Raises an error if preparing or executing the statement fails.
-- @return the statement's row iterator; each call yields a row whose columns
--   are indexed 1..n, as set_encoding() expects with `()[1]`
function engine:execute_query(sql, ...)
	if self.params.driver == "PostgreSQL" then
		sql = sql:gsub("`", "\"");
	end
	local stmt = assert(self.conn:prepare(sql));
	assert(stmt:execute(...));
	return stmt:rows();
end
--- Execute a data-modifying statement, reusing a cached prepared statement.
-- Backtick-quoted identifiers are rewritten to double quotes for PostgreSQL.
-- Raises an error if preparing or executing the statement fails.
-- @return result object offering :affected() and :rowcount()
function engine:execute_update(sql, ...)
	if self.params.driver == "PostgreSQL" then
		sql = sql:gsub("`", "\"");
	end
	local prepared = self.prepared;
	local stmt = prepared[sql];
	-- Only prepare on a cache miss.
	if not stmt then
		stmt = assert(self.conn:prepare(sql));
		prepared[sql] = stmt;
	end
	assert(stmt:execute(...));
	return setmetatable({ __stmt = stmt }, result_mt);
end
-- Query-builder style aliases for the raw execution methods.
engine.select = engine.execute_query;
engine.insert, engine.update, engine.delete =
	engine.execute_update, engine.execute_update, engine.execute_update;
--- Run func(...) inside a database transaction.
-- Connects first if needed. func runs under xpcall with a traceback handler.
-- On success the transaction is committed; on error it is rolled back.
-- @return on success: true plus up to three results of func
-- @return when func raised: false plus the error traceback
-- @return when connecting or committing failed: nil/false plus an error
function engine:_transaction(func, ...)
	if not self.conn then
		local ok, err = self:connect();
		if not ok then return ok, err; end
	end
	--assert(not self.__transaction, "Recursive transactions not allowed");
	-- Lua 5.1's xpcall cannot forward arguments, so capture them in a closure.
	local args, n_args = {...}, select("#", ...);
	local function f() return func(unpack(args, 1, n_args)); end
	self.__transaction = true;
	local success, a, b, c = xpcall(f, debug_traceback);
	self.__transaction = nil;
	if success then
		log("debug", "SQL transaction success [%s]", tostring(func));
		local ok, err = self.conn:commit();
		if not ok then return ok, err; end -- commit failed
		return success, a, b, c;
	else
		log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
		if self.conn then self.conn:rollback(); end
		return success, a;
	end
end
--- Run func(...) in a transaction, retrying once on a fresh connection.
-- A retry happens only when the first attempt failed and the connection is
-- missing or no longer answers ping(). The stale handle is dropped first so
-- that _transaction() reconnects.
-- @return ok, result — the first two values returned by _transaction()
function engine:transaction(...)
	local ok, ret = self:_transaction(...);
	if not ok then
		local conn = self.conn;
		if not conn or not conn:ping() then
			-- _transaction() only connects when self.conn is unset.
			self.conn = nil;
			ok, ret = self:_transaction(...);
		end
	end
	return ok, ret;
end
-- Build and run a CREATE INDEX statement for an Index() definition.
-- `index.name` names the index and `index.table` is the owning table name
-- (set by Table()). The array part `index[1..#index]` lists the indexed columns.
-- Returns whatever self:execute(sql) returns.
208 function engine:_create_index(index)
209 local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
-- NOTE(review): no loop header defining `i` is visible here (presumably
-- `for i = 1, #index do`). No statement appending the closing ")" of the
-- column list is visible either, although the MySQL rewrite below expects a
-- ")" after the last column — confirm.
211 sql = sql.."`"..index[i].."`";
212 if i ~= #index then sql = sql..", "; end
-- Identifiers are written with backticks; PostgreSQL gets double quotes instead.
215 if self.params.driver == "PostgreSQL" then
216 sql = sql:gsub("`", "\"");
217 elseif self.params.driver == "MySQL" then
-- MySQL: insert a (20) prefix length after each column name (a backtick
-- followed by "," or ")"), presumably so long text columns can be indexed —
-- confirm.
218 sql = sql:gsub("`([,)])", "`(20)%1");
-- NOTE(review): as visible, UNIQUE is applied to every index; no condition
-- guarding this rewrite is shown — confirm whether it should depend on a flag
-- in the Index definition.
221 sql = sql:gsub("^CREATE", "CREATE UNIQUE");
224 return self:execute(sql);
--- Issue CREATE TABLE for a Table object, then create its indexes.
-- Column types are adapted per driver:
--   * MEDIUMTEXT becomes TEXT outside MySQL;
--   * auto-increment columns become BIGSERIAL on PostgreSQL.
-- Identifiers are written with backticks and become double quotes on PostgreSQL.
-- On MySQL the table gets self.charset with its *_bin collation.
-- @param table Table object built by Table() (note: shadows the `table`
--   library inside this function; t_concat is used instead)
-- @return the executed CREATE TABLE statement, or nil/false plus an error
--   (including an error from creating one of the indexes)
function engine:_create_table(table)
	local driver = self.params.driver;
	local column_defs = {};
	for _, col in ipairs(table.c) do
		local col_type = col.type;
		if col_type == "MEDIUMTEXT" and driver ~= "MySQL" then
			col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
		end
		if col.auto_increment == true and driver == "PostgreSQL" then
			col_type = "BIGSERIAL";
		end
		local def = "`"..col.name.."` "..col_type;
		if col.nullable == false then def = def.." NOT NULL"; end
		if col.primary_key == true then def = def.." PRIMARY KEY"; end
		if col.auto_increment == true then
			if driver == "MySQL" then
				def = def.." AUTO_INCREMENT";
			elseif driver == "SQLite3" then
				def = def.." AUTOINCREMENT";
			end
		end
		column_defs[#column_defs + 1] = def;
	end
	-- The column list must be closed with ");": the MySQL rewrite anchors on ";".
	local sql = "CREATE TABLE `"..table.name.."` ("..t_concat(column_defs, ", ")..");";
	if driver == "PostgreSQL" then
		sql = sql:gsub("`", "\"");
	elseif driver == "MySQL" then
		sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset));
	end
	local success, err = self:execute(sql);
	if not success then return success, err; end
	-- __table__ holds both columns and indexes. Only indexes carry the
	-- `table` field that _create_index() needs.
	for _, entry in ipairs(table.__table__) do
		if is_index(entry) then
			local ok, index_err = self:_create_index(entry);
			if not ok then return ok, index_err; end
		end
	end
	return success;
end
--- Configure the connection for a UTF-8 character set, and record it in self.charset.
--   * SQLite3: sets self.charset = "utf8" when PRAGMA encoding reports "UTF-8".
--   * MySQL: uses the utf8* character set with the largest MAXLEN (falling
--     back to "utf8") and the matching *_bin collation in SET NAMES.
--   * Other drivers: issues SET NAMES 'utf8'.
-- @return true on success, or false/nil plus an error message
function engine:set_encoding() -- to UTF-8
	local driver = self.params.driver;
	if driver == "SQLite3" then
		return self:transaction(function()
			if self:select"PRAGMA encoding;"()[1] == "UTF-8" then
				self.charset = "utf8";
			end
		end);
	end
	local set_names_query = "SET NAMES '%s';"
	local charset = "utf8";
	if driver == "MySQL" then
		local ok, charsets = self:transaction(function()
			return self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;";
		end);
		-- On failure keep the "utf8" default.
		local row = ok and charsets();
		charset = row and row[1] or charset;
		set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin"));
	end
	self.charset = charset;
	log("debug", "Using encoding '%s' for database connection", charset);
	-- execute() reports errors by returning nil, err rather than raising.
	-- assert() turns that into a transaction failure instead of a silent success.
	local ok, err = self:transaction(function() return assert(self:execute(set_names_query:format(charset))); end);
	if not ok then
		return ok, err;
	end
	return true;
end
-- Metatable for engine objects: methods are inherited from `engine`.
local engine_mt = {};
engine_mt.__index = engine;
--- Build a URL string describing database connection parameters.
-- host and port are included so the URL identifies the server.
-- socket.url.build attaches user/password to the authority it builds from
-- `host`, so without a host the credentials would be dropped.
-- NOTE(review): the resulting URL contains the plaintext password; avoid
-- logging it.
-- @param params table with driver, username, password, host, port, database
-- @return URL string from socket.url.build
function db2uri(params)
	return build_url{
		scheme = params.driver,
		user = params.username,
		password = params.password,
		host = params.host,
		port = params.port,
		path = params.database,
	};
end
--- Construct a (not yet connected) engine for the given connection parameters.
-- The first argument is ignored (presumably the module, when called as
-- module:create_engine(...) — confirm).
-- @param params connection parameters (driver, database, username, ...)
-- @param onconnect optional hook stored as the engine's own `onconnect`
-- @return engine object using engine_mt
function create_engine(self, params, onconnect)
	local engine_obj = {
		url = db2uri(params),
		params = params,
		onconnect = onconnect,
	};
	return setmetatable(engine_obj, engine_mt);
end