7b2eae54b07eca616e8b03cc13a3ce8fbab54db9
[prosody.git] / util / async.lua
1 local log = require "util.logger".init("util.async");
2
--- Resume a runner thread that suspended itself by yielding "wait".
-- Called by the done()/exit() callbacks of waiter() and guarder() once the
-- thing the thread was waiting for has happened.
-- @param thread coroutine previously created by runner_create_thread()
-- @return false if the thread was not suspended (nothing was resumed), or if
--   its runner could not be located after the thread raised an error;
--   true otherwise
local function runner_continue(thread)
        -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure)
        if coroutine.status(thread) ~= "suspended" then -- This should suffice
                return false;
        end
        local ok, state, runner = coroutine.resume(thread);
        if not ok then
                local err = state;
                -- The thread died, so it could not hand us its runner. Recover it
                -- from the dead thread's stack: the bottom-most frame is the
                -- thread's main function, whose first local is the runner it was
                -- started with.
                local level = 0;
                while debug.getinfo(thread, level, "") do level = level + 1; end
                if level > 0 then
                        local _;
                        _, runner = debug.getlocal(thread, level - 1, 1);
                end
                if type(runner) ~= "table" then
                        log("warn", "unexpected async state: unable to locate runner during error handling");
                        return false;
                end
                -- The dead thread can never be resumed again. Without this reset the
                -- runner would stay 'waiting' forever and merely queue new input.
                runner.state, runner.thread = "ready", nil;
                local error_handler = runner.watchers.error;
                if error_handler then error_handler(runner, debug.traceback(thread, err)); end
                -- Process any input that was queued while the thread was waiting.
                runner:run();
        elseif state == "ready" then
                -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'.
                -- We also have to :run(), because the queue might have further items that will not be
                -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer).
                runner.state = "ready";
                runner:run();
        end
        return true;
end
24
--- Create a wait/done pair for suspending the current async thread.
-- Must be called from inside a (non-main) coroutine, normally one driven by a
-- runner.
-- @param num number of done() calls required before wait() may return
--   (default 1)
-- @return wait: suspends the calling thread by yielding "wait", unless done()
--   has already been called `num` times
-- @return done: records one completion; the final one resumes the thread if it
--   is waiting. Raises an error if called more than `num` times.
local function waiter(num)
        -- Lua 5.1 returns nil from the main thread; Lua 5.2+ returns the main
        -- thread plus `true`. Both must be rejected.
        local thread, is_main = coroutine.running();
        if not thread or is_main then
                error("Not running in an async context, see https://prosody.im/doc/developers/async");
        end
        num = num or 1;
        local waiting;
        return function ()
                if num == 0 then return; end -- already done
                waiting = true;
                coroutine.yield("wait");
        end, function ()
                num = num - 1;
                if num == 0 and waiting then
                        runner_continue(thread);
                elseif num < 0 then
                        error("done() called too many times");
                end
        end;
end
45
--- Create a guard function that serializes async threads sharing an id.
-- The returned function must be called from inside a (non-main) coroutine.
-- The first thread to claim an id proceeds immediately. Later claimants are
-- appended to that id's wait list and suspend (yielding "wait") until the
-- current holder releases the guard, which resumes them one at a time in
-- arrival order.
-- @return function(id, func):
--   * id: key identifying the guarded resource; nil maps to a private default key
--   * func: if given, it is called while holding the guard, and the guard is
--     released after it returns. If func raises, the guard is NOT released
--     (func is not wrapped in pcall).
--   * without func, a release function is returned; the caller must invoke it
--     exactly once.
local function guarder()
        local guards = {};
        -- Stand-in key for a nil id (a nil table key would raise an error).
        local default_id = {};
        return function (id, func)
                if id == nil then id = default_id; end
                -- Lua 5.1 returns nil from the main thread; Lua 5.2+ returns the
                -- main thread plus `true`. Both must be rejected.
                local thread, is_main = coroutine.running();
                if not thread or is_main then
                        error("Not running in an async context, see https://prosody.im/doc/developers/async");
                end
                local guard = guards[id];
                if not guard then
                        guard = {};
                        guards[id] = guard;
                        log("debug", "New guard!");
                else
                        table.insert(guard, thread);
                        log("debug", "Guarded. %d threads waiting.", #guard)
                        coroutine.yield("wait");
                end
                -- Hand the guard to the next queued thread, or retire it if none wait.
                local function exit()
                        local next_waiting = table.remove(guard, 1);
                        if next_waiting then
                                log("debug", "guard: Executing next waiting thread (%d left)", #guard)
                                runner_continue(next_waiting);
                        else
                                log("debug", "Guard off duty.")
                                guards[id] = nil;
                        end
                end
                if func then
                        func();
                        exit();
                        return;
                end
                return exit;
        end;
end
81
-- Metatable shared by every runner object. Its __index points back at the
-- table itself, so functions defined on runner_mt (run, enqueue) act as
-- methods of each runner.
local runner_mt = {};
runner_mt["__index"] = runner_mt;
84
--- Create and prime the coroutine that executes a runner's function.
-- Each resume(thread, input) calls func(input); the thread then yields
-- "ready" together with its runner and waits for the next input.
-- @param func function to call for every input
-- @param self the runner that owns the thread
-- @return the primed coroutine, suspended and ready for its first input
local function runner_create_thread(func, self)
        -- NB: the runner must stay the FIRST parameter/local of the thread's main
        -- function: after an error, runner_continue() recovers it with
        -- debug.getlocal(thread, <bottom level>, 1).
        local function thread_body(owner)
                repeat
                        func(coroutine.yield("ready", owner));
                until false;
        end
        local thread = coroutine.create(thread_body);
        -- Run up to the first yield so the thread immediately accepts input.
        assert(coroutine.resume(thread, self));
        return thread;
end
94
-- Watchers table shared by all runners created without one.
local empty_watchers = {};

--- Create a runner that feeds queued input items to `func` on its own coroutine.
-- @param func function called once per input item
-- @param watchers optional table of callbacks keyed by state name ("ready",
--   "waiting", "error"), each called as handler(runner, err)
-- @param data optional value stored on the runner as `data`
-- @return the new runner object
local function runner(func, watchers, data)
        local obj = {
                func = func;
                thread = false; -- created on demand by :run()
                state = "ready";
                notified_state = "ready";
                queue = {};
                watchers = watchers or empty_watchers;
                data = data;
        };
        return setmetatable(obj, runner_mt);
end
101
--- Queue an optional input and process the queue on the runner's thread.
-- If the runner is not 'ready' (it is running re-entrantly or waiting on an
-- async operation), the input is only queued.
-- Processing stops when the queue is empty or an item suspends the thread
-- ("wait"). When the resulting state differs from the last notified one,
-- watchers[state] is called as handler(self, nil). If an item raises an error,
-- watchers.error is called as handler(self, traceback), and the remaining items
-- are then processed on a fresh thread.
-- @param input optional item to append to the queue (nil: just process)
-- @return true, the state ("ready", "waiting", "running", or "error" if the
--   last processed item failed), and the number of items left queued
function runner_mt:run(input)
        if input ~= nil then
                table.insert(self.queue, input);
        end
        if self.state ~= "ready" then
                return true, self.state, #self.queue;
        end

        local q, thread = self.queue, self.thread;
        if not thread or coroutine.status(thread) == "dead" then
                thread = runner_create_thread(self.func, self);
                self.thread = thread;
        end

        local n, state, err = #q, self.state, nil;
        self.state = "running";
        -- Stop at the first error: the thread is dead afterwards, and resuming it
        -- again would silently drop the remaining items ("cannot resume dead
        -- coroutine") and overwrite the real traceback.
        while n > 0 and state == "ready" and not err do
                local consumed;
                for i = 1, n do
                        local ok, new_state = coroutine.resume(thread, q[i]);
                        if not ok then
                                consumed, state, err = i, "ready", debug.traceback(thread, new_state);
                                self.thread = nil;
                                break;
                        elseif new_state == "wait" then
                                -- The suspended thread now owns this item.
                                consumed, state = i, "waiting";
                                break;
                        end
                end
                if not consumed then consumed = n; end
                -- Shift the unconsumed items to the front. The queue may have grown
                -- while items ran (re-entrant :run()/:enqueue()).
                n = #q;
                for i = 1, n do
                        q[i] = q[consumed + i];
                end
                n = #q;
        end
        self.state = state;
        if err or state ~= self.notified_state then
                if err then
                        state = "error"
                else
                        self.notified_state = state;
                end
                local handler = self.watchers[state];
                if handler then handler(self, err); end
        end
        if err and #q > 0 then
                -- self.thread was cleared above, so this starts a fresh thread.
                return self:run();
        end
        return true, state, n;
end
153
--- Append an item to the runner's queue without processing it.
-- The item is picked up the next time the queue is processed by :run().
function runner_mt:enqueue(input)
        local queue = self.queue;
        queue[#queue + 1] = input;
end
157
-- Public API of this module.
return {
        waiter = waiter;
        guarder = guarder;
        runner = runner;
};