util.async: Fix logic bug that prevented error watcher being called for runners
[prosody.git] / util / async.lua
1 local log = require "util.logger".init("util.async");
2
--- Recover the runner that owns a runner coroutine.
-- runner_create_thread() passes the runner as the first parameter of the
-- coroutine's main function, so it is read as local #1 of the bottom-most
-- stack frame. This also works on a coroutine that died with an error.
-- Returns nil when the thread has no inspectable stack frame.
local function runner_from_thread(thread)
	local level = 0;
	while debug.getinfo(thread, level, "") do level = level + 1; end
	if level == 0 then
		-- No frames at all: debug.getlocal(thread, -1, 1) would raise.
		return nil;
	end
	local _, runner = debug.getlocal(thread, level - 1, 1);
	return runner;
end

--- Resume a runner coroutine that is blocked in a waiter's wait().
-- On error: the runner is reset to "ready" with no coroutine, its "error"
-- watcher (if any) is called with a traceback, and the queue is re-run.
-- On a "ready" yield: the runner is marked "ready" and its queue is re-run.
-- @param thread the suspended coroutine
-- @return false if the thread is not suspended, or if it errored and its
--   runner could not be located; true otherwise
local function runner_continue(thread)
	-- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure)
	if coroutine.status(thread) ~= "suspended" then -- This should suffice
		return false;
	end
	local ok, state, runner = coroutine.resume(thread);
	if not ok then
		-- The coroutine died instead of yielding, so it did not hand us the
		-- runner; dig it out of the dead thread's stack.
		runner = runner_from_thread(thread);
		if type(runner) ~= "table" then
			log("warn", "unable to locate runner for failed async thread");
			return false;
		end
		-- Same recovery as runner_mt:run() performs on error. Without it the
		-- runner would stay 'waiting' forever, and every later :run() would
		-- only queue its input and never process anything again.
		runner.state, runner.thread = "ready", nil;
		local error_handler = runner.watchers.error;
		if error_handler then error_handler(runner, debug.traceback(thread, state)); end
		-- Items may have been queued while the thread was waiting.
		runner:run();
	elseif state == "ready" then
		-- If state is 'ready', it is our responsibility to update runner.state from 'waiting'.
		-- We also have to :run(), because the queue might have further items that will not be
		-- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer).
		runner.state = "ready";
		runner:run();
	end
	return true;
end
24
--- Create a wait/done pair for suspending the current coroutine.
-- Must be called from inside a coroutine (normally a runner's thread).
-- @param num number of done() calls needed before wait() may proceed (default 1)
-- @return wait: yields "wait" from the coroutine, unless done() has already
--   been called `num` times, in which case it returns immediately
-- @return done: decrements the counter. On reaching zero, if wait() was
--   entered, resumes the coroutine via runner_continue(). Raises if called
--   more than `num` times.
local function waiter(num)
	-- In Lua 5.1, coroutine.running() returns nil on the main thread. In 5.2+
	-- it returns the main thread plus `true`, so both must be checked.
	local thread, is_main = coroutine.running();
	if not thread or is_main then
		error("Not running in an async context, see http://prosody.im/doc/developers/async");
	end
	num = num or 1;
	local waiting;
	return function ()
		if num == 0 then return; end -- already done
		waiting = true;
		coroutine.yield("wait");
	end, function ()
		num = num - 1;
		if num == 0 and waiting then
			runner_continue(thread);
		elseif num < 0 then
			error("done() called too many times");
		end
	end;
end
45
-- Method table for runner objects. runner() installs it as their metatable,
-- and __index points back at it, so methods defined on it resolve on instances.
local runner_mt = {};
runner_mt.__index = runner_mt;

--- Create and prime the coroutine that executes `func` for a runner.
-- The coroutine's main function takes the runner as its first (and only)
-- parameter. runner_continue() relies on this to recover the runner from a
-- coroutine that died with an error.
-- Once primed, the coroutine is suspended, having yielded "ready" and the
-- runner. Each later resume(thread, item) calls func(item) and then yields
-- "ready" and the runner again.
local function runner_create_thread(func, self)
	local function main_loop(owner)
		repeat
			func(coroutine.yield("ready", owner));
		until false
	end
	local thread = coroutine.create(main_loop);
	-- Run up to the first yield so the thread is waiting for its first item.
	local ok, err = coroutine.resume(thread, self);
	assert(ok, err);
	return thread;
end
58
-- Shared fallback used when runner() is called without a watchers table.
local empty_watchers = {};

--- Create a runner that calls `func(item)` for each queued item on its own coroutine.
-- @param func     function invoked with each work item
-- @param watchers optional table of callbacks keyed by state name ("ready",
--   "waiting", "error"); each is called as handler(runner, err)
-- @param data     arbitrary value stored on the runner as `data`
-- @return a runner object (metatable runner_mt) in the "ready" state, with
--   an empty queue and no coroutine yet
local function runner(func, watchers, data)
	local new_runner = {
		func = func;
		thread = false; -- created lazily by :run()
		state = "ready";
		notified_state = "ready";
		queue = {};
		watchers = watchers or empty_watchers;
		data = data;
	};
	return setmetatable(new_runner, runner_mt);
end
65
--- Process queued work items on this runner's coroutine.
-- If `input` is non-nil it is appended to the queue first. If the runner is
-- not "ready" (busy re-entrantly, or blocked in a waiter), the item just stays
-- queued.
-- State changes are reported to self.watchers[state](self, err). An item that
-- raises triggers the "error" watcher with a traceback; the runner then
-- continues with the remaining items on a fresh coroutine.
-- @param input optional work item to enqueue
-- @return true, the resulting state ("ready", "waiting", or "error" when an
--   item raised and nothing remained queued), and the number of items still
--   queued
function runner_mt:run(input)
	if input ~= nil then
		table.insert(self.queue, input);
	end
	if self.state ~= "ready" then
		-- Busy: the item stays queued and is picked up once the runner is ready.
		return true, self.state, #self.queue;
	end

	local q, thread = self.queue, self.thread;
	if not thread or coroutine.status(thread) == "dead" then
		thread = runner_create_thread(self.func, self);
		self.thread = thread;
	end

	local n, state, err = #q, self.state, nil;
	self.state = "running";
	-- Stop as soon as an item errors. `thread` is dead at that point, and
	-- resuming it again would fail every remaining item with "cannot resume
	-- dead coroutine", replacing the real error.
	while n > 0 and state == "ready" and not err do
		local consumed;
		for i = 1, n do
			local ok, new_state = coroutine.resume(thread, q[i]);
			if not ok then
				consumed, state, err = i, "ready", debug.traceback(thread, new_state);
				self.thread = nil;
				break;
			elseif new_state == "wait" then
				-- Item i is now owned by the suspended coroutine.
				consumed, state = i, "waiting";
				break;
			end
		end
		if not consumed then consumed = n; end
		-- Items may have been queued re-entrantly while processing. Include
		-- them when shifting the unconsumed items to the front of the queue.
		if q[n+1] ~= nil then
			n = #q;
		end
		for i = 1, n do
			q[i] = q[consumed+i];
		end
		n = #q;
	end
	self.state = state;
	if err or state ~= self.notified_state then
		if err then
			state = "error"
		else
			self.notified_state = state;
		end
		local handler = self.watchers[state];
		if handler then handler(self, err); end
	end
	if err and n > 0 then
		-- Items remain after the failed one: process them on a new coroutine
		-- (tail call, so repeated failures do not grow the stack).
		return self:run();
	end
	return true, state, n;
end
117
--- Append `input` to the runner's queue without processing it.
-- Unlike :run(input), this never resumes the runner's coroutine. The item is
-- picked up by a later call to :run().
function runner_mt:enqueue(input)
	local queue = self.queue;
	queue[#queue + 1] = input;
end
121
-- Module exports:
-- waiter() creates wait/done pairs for use inside a runner's coroutine.
-- runner() creates a runner object.
local exports = {
	runner = runner;
	waiter = waiter;
};
return exports;