Merge 0.10->trunk
[prosody.git] / tools / migration / migrator / prosody_sql.lua
1
2 local assert = assert;
3 local have_DBI, DBI = pcall(require,"DBI");
4 local print = print;
5 local type = type;
6 local next = next;
7 local pairs = pairs;
8 local t_sort = table.sort;
9 local json = require "util.json";
10 local mtools = require "migrator.mtools";
11 local tostring = tostring;
12 local tonumber = tonumber;
13
-- Fail fast with an actionable message when the LuaDBI binding could not be loaded.
if not have_DBI then
	local message = "LuaDBI (required for SQL support) was not found, please see https://prosody.im/doc/depends#luadbi";
	-- Level 0: raise the message without prepending a source position.
	error(message, 0);
end
17
18 module "prosody_sql"
19
--- Best-effort creation of the `prosody` table and its lookup index.
-- Never raises: if the CREATE TABLE statement fails, nothing further happens
-- except for MySQL, where an existing table whose `value` column is still of
-- type TEXT (COMPAT: 0.8.0) is altered to MEDIUMTEXT.
-- @param connection an open LuaDBI connection handle
-- @param params connection settings; only params.driver is read here
local function create_table(connection, params)
	local driver = params.driver;
	local create_sql = "CREATE TABLE `prosody` (`host` TEXT, `user` TEXT, `store` TEXT, `key` TEXT, `type` TEXT, `value` TEXT);";
	if driver == "PostgreSQL" then
		create_sql = create_sql:gsub("`", "\"");
	elseif driver == "MySQL" then
		create_sql = create_sql:gsub("`value` TEXT", "`value` MEDIUMTEXT");
	end

	local stmt = connection:prepare(create_sql);
	if not stmt then return; end

	local ok = stmt:execute();
	local commit_ok = connection:commit();
	if ok and commit_ok then
		-- Fresh table: add the lookup index. The index only speeds up lookups,
		-- so a failure is reported instead of aborting the whole migration.
		local index_sql = "CREATE INDEX `prosody_index` ON `prosody` (`host`, `user`, `store`, `key`)";
		if driver == "PostgreSQL" then
			index_sql = index_sql:gsub("`", "\"");
		elseif driver == "MySQL" then
			-- MySQL requires a prefix length when indexing TEXT columns: `col` -> `col`(20)
			index_sql = index_sql:gsub("`([,)])", "`(20)%1");
		end
		local index_stmt, err = connection:prepare(index_sql);
		local index_ok, index_commit_ok, commit_err;
		if index_stmt then
			index_ok, err = index_stmt:execute();
			index_commit_ok, commit_err = connection:commit();
		end
		if not (index_ok and index_commit_ok) then
			print("Failed to create index ("..tostring(err or commit_err or "unknown error").."), lookups may not be optimised");
		end
	elseif driver == "MySQL" then -- COMPAT: Upgrade tables from 0.8.0
		-- Creation failed; check whether an existing table still has a TEXT `value` column.
		local check = connection:prepare("SHOW COLUMNS FROM prosody WHERE Field='value' and Type='text'");
		if not check then return; end
		local check_ok = check:execute();
		local check_commit_ok = connection:commit();
		if check_ok and check_commit_ok then
			if check:rowcount() > 0 then
				local alter = connection:prepare("ALTER TABLE prosody MODIFY COLUMN `value` MEDIUMTEXT");
				if alter then
					local alter_ok = alter:execute();
					local alter_commit_ok = connection:commit();
					if alter_ok and alter_commit_ok then
						print("Database table automatically upgraded");
					end
				end
			end
			-- Drain the SHOW COLUMNS result set.
			repeat until not check:fetch();
		end
	end
end
64
--- Convert a Lua value into a (type tag, text) pair for storage.
-- Strings, booleans and numbers are tagged with their Lua type and stored via
-- tostring(); tables are JSON-encoded and tagged "json".
-- @return type tag and text on success; nil and an error message otherwise
local function serialize(value)
	local kind = type(value);
	if kind == "table" then
		local encoded, err = json.encode(value);
		if not encoded then
			return nil, err;
		end
		return "json", encoded;
	end
	if kind == "number" or kind == "boolean" or kind == "string" then
		return kind, tostring(value);
	end
	return nil, "Unhandled value type: "..kind;
end
-- Decoders keyed by the stored type tag; each turns the stored text back into a Lua value.
local value_decoders = {
	["string"] = function(text) return text; end;
	["boolean"] = function(text)
		if text == "true" then
			return true;
		elseif text == "false" then
			return false;
		end
	end;
	["number"] = function(text) return tonumber(text); end;
	["json"] = function(text) return json.decode(text); end;
};

--- Inverse of serialize(): rebuild a Lua value from its type tag and stored text.
-- Returns nothing for unknown type tags or unrecognised boolean text.
local function deserialize(t, value)
	local decode = value_decoders[t];
	if decode then
		return decode(value);
	end
end
86
--- Assemble one user's data from the row groups produced by reader().
-- @param item array of stores; each store is an array of row tables with fields
--   host, user, store, key, type, value (item[1][1] supplies host and user)
-- @return table { user = ..., host = ..., stores = { [store name] = { [key] = value } } }
local function decode_user(item)
	local first_row = item[1][1];
	local userdata = {
		user = first_row.user;
		host = first_row.host;
		stores = {};
	};
	for store_index = 1, #item do -- loop over stores
		local rows = item[store_index];
		local result = {};
		for row_index = 1, #rows do -- loop over store data
			local row = rows[row_index];
			local k = row.key;
			local v = deserialize(row.type, row.value);
			-- Compare against nil so stored boolean false values are not dropped.
			if k and v ~= nil then
				if k ~= "" then
					result[k] = v;
				elseif type(v) == "table" then
					-- A row with an empty key carries extra fields; merge them into the store.
					for a, b in pairs(v) do
						result[a] = b;
					end
				end
			end
		end
		if rows[1] then
			userdata.stores[rows[1].store] = result;
		end
	end
	return userdata;
end
112
--- Open the source database and return an iterator over decoded users.
-- @param input connection settings: driver and database (required), plus
--   username, password, host, port passed to DBI.Connect
-- @return iterator; each call yields one user table built by decode_user(),
--   or passes through the falsy value the grouped row iterator ends with
function reader(input)
	local driver = assert(input.driver, "no input.driver specified");
	local database = assert(input.database, "no input.database specified");
	local dbh = assert(DBI.Connect(driver, database, input.username, input.password, input.host, input.port));
	assert(dbh:ping());

	local select_all = assert(dbh:prepare("SELECT * FROM prosody"));
	assert(select_all:execute());

	local required_fields = {"host", "user", "store", "key", "type", "value"};
	local blank_to_nil = {"host", "user", "store"};
	local next_row, row_state, current_row = select_all:rows(true);

	-- Pull the next raw row from the result set.
	local function pull_row()
		current_row = next_row(row_state, current_row);
		return current_row;
	end

	-- Reject rows missing a field; normalise empty host/user/store to nil.
	local function accept_row(row)
		for n = 1, #required_fields do
			if not row[required_fields[n]] then
				return false; -- TODO log error, missing field
			end
		end
		for n = 1, #blank_to_nil do
			local field = blank_to_nil[n];
			if row[field] == "" then row[field] = nil; end
		end
		return row;
	end

	-- Ordering on (host, user, store), missing values treated as "".
	local function compare_rows(a, b)
		local a_host, b_host = a.host or "", b.host or "";
		if a_host ~= b_host then return a_host > b_host; end
		local a_user, b_user = a.user or "", b.user or "";
		if a_user ~= b_user then return a_user > b_user; end
		return (a.store or "") > (b.store or "");
	end

	local function same_store(a, b)
		return (a.host == b.host and a.user == b.user and a.store == b.store);
	end

	local function same_user(a, b)
		return (a[1].host == b[1].host and a[1].user == b[1].user);
	end

	-- Sorted rows -> merged into stores -> merged into users.
	local sorted_rows = mtools.sorted {
		reader = pull_row;
		filter = accept_row;
		sorter = compare_rows;
	};
	local users = mtools.merged(mtools.merged(sorted_rows, same_store), same_user);

	return function()
		local group = users();
		if group then
			return decode_user(group);
		end
		return group;
	end;
end
156
--- Open the target database and return a sink that writes users into it.
-- The `prosody` table is created if needed and then emptied: every existing
-- row is deleted before any new data is inserted.
-- @param output connection settings: driver and database (required), plus
--   username, password, host, port passed to DBI.Connect
-- @param iter not used by this function
-- @return function(item): call with each user table (host, user, stores) to
--   insert its data; call with nil to commit, release the insert statement and
--   close the connection (returns the result of dbh:close())
function writer(output, iter)
	local dbh = assert(DBI.Connect(
		assert(output.driver, "no output.driver specified"),
		assert(output.database, "no output.database specified"),
		output.username, output.password,
		output.host, output.port
	));
	assert(dbh:ping());
	create_table(dbh, output);

	-- Presumably a sanity check that the table exists and is readable
	-- (create_table() never raises); the result set itself is not used.
	local check = assert(dbh:prepare("SELECT * FROM prosody"));
	assert(check:execute());
	check:close();

	-- Start from an empty table: all existing rows are removed.
	local wipe = assert(dbh:prepare("DELETE FROM prosody"));
	assert(wipe:execute());
	wipe:close();

	local insert_sql = "INSERT INTO `prosody` (`host`,`user`,`store`,`key`,`type`,`value`) VALUES (?,?,?,?,?,?)";
	if output.driver == "PostgreSQL" then
		insert_sql = insert_sql:gsub("`", "\"");
	end
	local insert = assert(dbh:prepare(insert_sql));

	return function(item)
		if not item then -- end of input
			assert(dbh:commit());
			insert:close();
			return dbh:close();
		end
		local host = item.host or "";
		local user = item.user or "";
		for store, data in pairs(item.stores) do
			-- TODO transactions
			local extradata = {};
			for key, value in pairs(data) do
				if type(key) == "string" and key ~= "" then
					local t, serialized = assert(serialize(value));
					assert(insert:execute(host, user, store, key, t, serialized));
				else
					-- Non-string and empty keys are collected into one row stored under key "".
					extradata[key] = value;
				end
			end
			if next(extradata) ~= nil then
				local t, serialized = assert(serialize(extradata));
				assert(insert:execute(host, user, store, "", t, serialized));
			end
		end
	end;
end
198
199
200 return _M;