Commit d1ad6dc6 authored by Denis Bilenko's avatar Denis Bilenko

threadpool.py: forward unhandled errors to hub's handle_error; better manage threads

spawn manager greenlet if necessary to finish off excess threads
parent daa93507
...@@ -22,10 +22,10 @@ class ThreadPool(object): ...@@ -22,10 +22,10 @@ class ThreadPool(object):
hub = get_hub() hub = get_hub()
self.hub = hub self.hub = hub
self._maxsize = 0 self._maxsize = 0
self._init(maxsize) self.manager = None
self.pid = os.getpid() self.pid = os.getpid()
self.fork_watcher = hub.loop.fork(ref=False) self.fork_watcher = hub.loop.fork(ref=False)
self.fork_watcher.start(self._on_fork) self._init(maxsize)
def _set_maxsize(self, maxsize): def _set_maxsize(self, maxsize):
if not isinstance(maxsize, integer_types): if not isinstance(maxsize, integer_types):
...@@ -35,8 +35,7 @@ class ThreadPool(object): ...@@ -35,8 +35,7 @@ class ThreadPool(object):
difference = maxsize - self._maxsize difference = maxsize - self._maxsize
self._semaphore.counter += difference self._semaphore.counter += difference
self._maxsize = maxsize self._maxsize = maxsize
self._remove_threads() self.adjust()
self._add_threads()
# make sure all currently blocking spawn() start unlocking if maxsize increased # make sure all currently blocking spawn() start unlocking if maxsize increased
self._semaphore._start_notify() self._semaphore._start_notify()
...@@ -82,25 +81,39 @@ class ThreadPool(object): ...@@ -82,25 +81,39 @@ class ThreadPool(object):
delay = min(delay * 2, .05) delay = min(delay * 2, .05)
def kill(self): def kill(self):
delay = 0.0005 if self.manager:
while self._size > 0: self.manager.kill()
self._remove_threads(0) self._manage(0)
sleep(delay)
delay = min(delay * 2, .05)
def _add_threads(self): def _adjust(self, maxsize):
while self.task_queue.unfinished_tasks > self._size: if maxsize is None:
if self._size >= self.maxsize: maxsize = self._maxsize
break while self.task_queue.unfinished_tasks > self._size and self._size < maxsize:
self._add_thread() self._add_thread()
while self._size - maxsize > self.task_queue.unfinished_tasks:
self.task_queue.put(None)
if self._size:
self.fork_watcher.start(self._on_fork)
else:
self.fork_watcher.stop()
def _remove_threads(self, maxsize=None): def _manage(self, maxsize=None):
if maxsize is None: if maxsize is None:
maxsize = self._maxsize maxsize = self._maxsize
excess = self._size - maxsize delay = 0.0001
if excess > 0: while True:
while excess > self.task_queue.qsize(): self._adjust(maxsize)
self.task_queue.put(None) if self._size <= maxsize:
return
sleep(delay)
delay = min(delay * 2, .05)
def adjust(self):
if self.manager:
return
if self._adjust(self.maxsize):
return
self.manager = Greenlet.spawn(self._manage)
def _add_thread(self): def _add_thread(self):
with self._lock: with self._lock:
...@@ -121,23 +134,37 @@ class ThreadPool(object): ...@@ -121,23 +134,37 @@ class ThreadPool(object):
try: try:
task_queue = self.task_queue task_queue = self.task_queue
result = AsyncResult() result = AsyncResult()
tr = ThreadResult(result, hub=self.hub) thread_result = ThreadResult(result, hub=self.hub)
self._remove_threads() task_queue.put((func, args, kwargs, thread_result))
task_queue.put((func, args, kwargs, tr)) self.adjust()
self._add_threads() # rawlink() must be the last call
result.rawlink(lambda *args: self._semaphore.release()) result.rawlink(lambda *args: self._semaphore.release())
except: except:
semaphore.release() semaphore.release()
raise raise
return result return result
def _decrease_size(self):
if sys is None:
return
_lock = getattr(self, '_lock', None)
if _lock is not None:
with _lock:
self._size -= 1
def _worker(self): def _worker(self):
need_decrease = True
try: try:
while True: while True:
task_queue = self.task_queue task_queue = self.task_queue
task = task_queue.get() task = task_queue.get()
try: try:
if task is None: if task is None:
need_decrease = False
self._decrease_size()
# we want first to decrease size, then decrease unfinished_tasks
# otherwise, _adjust might think there's one more idle thread that
# needs to be killed
return return
func, args, kwargs, result = task func, args, kwargs, result = task
try: try:
...@@ -146,7 +173,7 @@ class ThreadPool(object): ...@@ -146,7 +173,7 @@ class ThreadPool(object):
exc_info = getattr(sys, 'exc_info', None) exc_info = getattr(sys, 'exc_info', None)
if exc_info is None: if exc_info is None:
return return
result.set_exception(exc_info()[1]) result.handle_error(func, exc_info())
else: else:
if sys is None: if sys is None:
return return
...@@ -156,12 +183,8 @@ class ThreadPool(object): ...@@ -156,12 +183,8 @@ class ThreadPool(object):
return return
task_queue.task_done() task_queue.task_done()
finally: finally:
if sys is None: if need_decrease:
return self._decrease_size()
_lock = getattr(self, '_lock', None)
if _lock is not None:
with _lock:
self._size -= 1
def apply(self, func, args=None, kwds=None): def apply(self, func, args=None, kwds=None):
"""Equivalent of the apply() builtin function. It blocks till the result is ready.""" """Equivalent of the apply() builtin function. It blocks till the result is ready."""
...@@ -221,25 +244,40 @@ class ThreadResult(object): ...@@ -221,25 +244,40 @@ class ThreadResult(object):
def __init__(self, receiver, hub=None): def __init__(self, receiver, hub=None):
if hub is None: if hub is None:
hub = get_hub() hub = get_hub()
self.value = None
self.exception = None
self.receiver = receiver self.receiver = receiver
self.hub = hub
self.value = None
self.context = None
self.exc_info = None
self.async = hub.loop.async() self.async = hub.loop.async()
self.async.start(self._on_async) self.async.start(self._on_async)
def _on_async(self): def _on_async(self):
self.async.stop() self.async.stop()
try:
if self.exc_info is not None:
try:
self.hub.handle_error(self.context, *self.exc_info)
finally:
self.exc_info = None
self.context = None
self.async = None
self.hub = None
if self.receiver is not None: if self.receiver is not None:
self.receiver(self) self.receiver(self)
finally:
self.receiver = None self.receiver = None
self.value = None
def successful(self):
return self.exception is None
def set(self, value): def set(self, value):
self.value = value self.value = value
self.async.send() self.async.send()
def set_exception(self, value): def handle_error(self, context, exc_info):
self.exception = value self.context = context
self.exc_info = exc_info
self.async.send() self.async.send()
# link protocol:
def successful(self):
return True
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment