Merge 0.7->0.8
[prosody.git] / tools / migration / migrator / prosody_sql.lua
1
2 local assert = assert;
3 local have_DBI, DBI = pcall(require,"DBI");
4 local print = print;
5 local type = type;
6 local next = next;
7 local pairs = pairs;
8 local t_sort = table.sort;
9 local json = require "util.json";
10 local mtools = require "migrator.mtools";
11 local tostring = tostring;
12 local tonumber = tonumber;
13
-- Refuse to load when LuaDBI could not be required; level 0 keeps the
-- message free of a file:line prefix.
if not have_DBI then
	error("LuaDBI (required for SQL support) was not found, please see http://prosody.im/doc/depends#luadbi", 0);
end

-- Declare the module. In Lua 5.1, module() swaps this chunk's environment
-- for the module table, which is why every global used below was captured
-- in a local beforehand.
module("prosody_sql");
19
--- Create the `prosody` table and its lookup index on the given connection.
-- Table creation is best-effort: if the CREATE TABLE statement cannot be
-- prepared, executed or committed (for example because the table already
-- exists), the function returns quietly without touching the index.
-- Once the table has just been created, a failure to execute or commit the
-- index statement raises an error through assert().
-- @param connection LuaDBI-style connection; uses :prepare(sql) and :commit()
-- @param params table; only params.driver is read ("PostgreSQL" and "MySQL"
--               get dialect-specific SQL, any other value uses the SQL as-is)
local function create_table(connection, params)
	local driver = params.driver;

	-- The statements are written with backtick quoting; PostgreSQL needs
	-- double quotes instead. Parentheses drop gsub's second return value.
	local function for_driver(sql)
		if driver == "PostgreSQL" then
			return (sql:gsub("`", "\""));
		end
		return sql;
	end

	local create_stmt = connection:prepare(for_driver("CREATE TABLE `prosody` (`host` TEXT, `user` TEXT, `store` TEXT, `key` TEXT, `type` TEXT, `value` TEXT);"));
	if not create_stmt then
		return;
	end
	local created = create_stmt:execute();
	create_stmt:close(); -- release the handle instead of leaving it to the GC
	-- Commit is attempted even when execute failed, as it always has been.
	local committed = connection:commit();
	if not (created and committed) then
		return;
	end

	local index_sql = for_driver("CREATE INDEX `prosody_index` ON `prosody` (`host`, `user`, `store`, `key`)");
	if driver == "MySQL" then
		-- Append a (20) prefix length to every indexed column name.
		index_sql = index_sql:gsub("`([,)])", "`(20)%1");
	end
	local index_stmt = connection:prepare(index_sql);
	if index_stmt then
		assert(index_stmt:execute());
		index_stmt:close();
		assert(connection:commit());
	end
end
46
--- Turn a Lua value into the (type tag, text) pair kept in the `type` and
-- `value` columns.
-- Strings, booleans and numbers are tagged with their Lua type name and
-- stored via tostring(); tables are tagged "json" and stored as the output
-- of json.encode.
-- @param value the value to serialize
-- @return type tag and serialized text on success; nil and an error message
--         when json.encode fails or the value has any other type
local function serialize(value)
	local kind = type(value);
	if kind == "table" then
		local encoded, encode_err = json.encode(value);
		if not encoded then
			return nil, encode_err;
		end
		return "json", encoded;
	end
	if kind == "string" or kind == "boolean" or kind == "number" then
		return kind, tostring(value);
	end
	return nil, "Unhandled value type: "..kind;
end
--- Rebuild a Lua value from a stored (type tag, text) pair.
-- @param t type tag: "string", "boolean", "number" or "json"
-- @param value the stored text
-- @return the decoded value. Nothing is returned for an unrecognised tag or
--         for a "boolean" whose text is neither "true" nor "false"; "json"
--         returns whatever json.decode returns.
local function deserialize(t, value)
	if t == "string" then
		return value;
	end
	if t == "number" then
		return tonumber(value);
	end
	if t == "json" then
		return json.decode(value);
	end
	if t == "boolean" then
		if value == "true" then
			return true;
		end
		if value == "false" then
			return false;
		end
	end
end
68
--- Build a single user record from the rows grouped for that user.
-- @param item array of stores, each store being an array of row tables with
--             the fields host, user, store, key, type and value. The user
--             and host are taken from the first row of the first store.
-- @return table { user = ..., host = ..., stores = { [store name] = data } }
--
-- Each row contributes data[key] = deserialize(type, value). A row whose key
-- is the empty string and whose value decodes to a table has that table's
-- entries merged into data instead.
local function decode_user(item)
	local first_row = item[1][1];
	local userdata = {
		user = first_row.user;
		host = first_row.host;
		stores = {};
	};
	for store_index = 1, #item do -- loop over stores
		local store = item[store_index];
		if #store > 0 then -- a store without rows is not recorded at all
			local result = {};
			for row_index = 1, #store do -- loop over store data
				local row = store[row_index];
				local k = row.key;
				local v = deserialize(row.type, row.value);
				-- Compare against nil explicitly: a stored boolean false is
				-- a legitimate value and must not be dropped.
				if k and v ~= nil then
					if k ~= "" then
						result[k] = v;
					elseif type(v) == "table" then
						for a, b in pairs(v) do
							result[a] = b;
						end
					end
				end
			end
			-- The store name comes from its first row; assign once per store.
			userdata.stores[store[1].store] = result;
		end
	end
	return userdata;
end
94
--- Open the source database and return an iterator over its users.
-- @param input table with driver and database (both required) plus optional
--              username, password, host and port, all passed to DBI.Connect
-- @return a function that returns decode_user(group) for each group the
--         mtools pipeline yields, or the falsy value itself once the
--         pipeline yields one
-- Raises (via assert) when a required setting is missing or when connecting,
-- pinging, preparing or executing the SELECT fails.
function reader(input)
	local driver = assert(input.driver, "no input.driver specified");
	local database = assert(input.database, "no input.database specified");
	local dbh = assert(DBI.Connect(
		driver, database,
		input.username, input.password,
		input.host, input.port
	));
	assert(dbh:ping());

	local select_stmt = assert(dbh:prepare("SELECT * FROM prosody"));
	assert(select_stmt:execute());

	local required_fields = {"host", "user", "store", "key", "type", "value"};
	local next_row, state, current = select_stmt:rows(true);

	-- Pull the next row out of the statement's row iterator.
	local function read_row()
		current = next_row(state, current);
		return current;
	end

	-- Reject rows lacking any required field; turn empty host/user/store
	-- strings into nil on the rows that are kept.
	local function check_row(row)
		for index = 1, #required_fields do
			if not row[required_fields[index]] then
				return false; -- TODO log error, missing field
			end
		end
		if row.host == "" then row.host = nil; end
		if row.user == "" then row.user = nil; end
		if row.store == "" then row.store = nil; end
		return row;
	end

	-- Order rows by host, then user, then store, each compared with ">"
	-- (missing fields compare as the empty string).
	local function compare_rows(a, b)
		local a_host, b_host = a.host or "", b.host or "";
		if a_host ~= b_host then
			return a_host > b_host;
		end
		local a_user, b_user = a.user or "", b.user or "";
		if a_user ~= b_user then
			return a_user > b_user;
		end
		return (a.store or "") > (b.store or "");
	end

	-- get SQL rows, sorted
	local sorted_rows = mtools.sorted {
		reader = read_row;
		filter = check_row;
		sorter = compare_rows;
	};

	-- merge rows to get stores
	local stores = mtools.merged(sorted_rows, function(a, b)
		return (a.host == b.host and a.user == b.user and a.store == b.store);
	end);

	-- merge stores to get users
	local users = mtools.merged(stores, function(a, b)
		return (a[1].host == b[1].host and a[1].user == b[1].user);
	end);

	return function()
		local group = users();
		if not group then
			return group;
		end
		return decode_user(group);
	end;
end
138
--- Open the destination database and return a sink that stores users in it.
-- The `prosody` table is created if possible, then emptied with
-- DELETE FROM prosody before anything is written.
-- @param output table with driver and database (both required) plus optional
--               username, password, host and port, all passed to DBI.Connect
-- @param iter accepted for interface compatibility; not used by this function
-- @return function(item): call it with a user table { host, user, stores }
--         to insert its data, and with nil/false at end of input to commit
--         and close the connection (returning the result of dbh:close())
-- Raises (via assert) on a missing setting, on connection or statement
-- failures, and when a value cannot be serialized.
function writer(output, iter)
	local dbh = assert(DBI.Connect(
		assert(output.driver, "no output.driver specified"),
		assert(output.database, "no output.database specified"),
		output.username, output.password,
		output.host, output.port
	));
	assert(dbh:ping());
	create_table(dbh, output);

	-- create_table() is best-effort, so check that the table can be queried.
	local probe = assert(dbh:prepare("SELECT * FROM prosody"));
	assert(probe:execute());
	probe:close(); -- do not leave an open cursor behind for the DELETE/commit

	local purge = assert(dbh:prepare("DELETE FROM prosody"));
	assert(purge:execute());
	purge:close();

	local insert_sql = "INSERT INTO `prosody` (`host`,`user`,`store`,`key`,`type`,`value`) VALUES (?,?,?,?,?,?)";
	if output.driver == "PostgreSQL" then
		insert_sql = insert_sql:gsub("`", "\"");
	end
	local insert = assert(dbh:prepare(insert_sql));

	-- Serialize one value and insert it as a single row.
	local function insert_row(host, user, store, key, value)
		local t, serialized = assert(serialize(value));
		assert(insert:execute(host, user, store, key, t, serialized));
	end

	return function(item)
		if not item then -- end of input
			assert(dbh:commit());
			insert:close(); -- finalize the statement before closing its connection
			return dbh:close();
		end
		local host = item.host or "";
		local user = item.user or "";
		for store, data in pairs(item.stores) do
			-- TODO transactions
			-- Entries whose key is not a non-empty string cannot get a row of
			-- their own; they are gathered and stored together under key "".
			local extradata = {};
			for key, value in pairs(data) do
				if type(key) == "string" and key ~= "" then
					insert_row(host, user, store, key, value);
				else
					extradata[key] = value;
				end
			end
			if next(extradata) ~= nil then
				insert_row(host, user, store, "", extradata);
			end
		end
	end;
end
180
181
182 return _M;