Commit 76c0f321 authored by Jason Madden's avatar Jason Madden

Unify the blocking implementation of Semaphore.wait and acquire.

parent ccbea2e3
......@@ -10,6 +10,7 @@ cdef class Semaphore:
cpdef unlink(self, object callback)
cpdef _start_notify(self)
cdef _notify_links(self)
cdef _do_wait(self, object timeout)
cpdef int wait(self, object timeout=*) except -1000
cpdef bint acquire(self, int blocking=*, object timeout=*) except -1000
cpdef __enter__(self)
......
......@@ -143,6 +143,32 @@ class Semaphore(object):
self._links = None
# TODO: Cancel a notifier if there are no links?
def _do_wait(self, timeout):
"""
Wait for up to *timeout* seconds to expire. If timeout
elapses, return the exception. Otherwise, return None.
Raises timeout if a different timer expires.
"""
switch = getcurrent().switch
self.rawlink(switch)
try:
# As a tiny efficiency optimization, avoid allocating a timer
# if not needed.
timer = Timeout.start_new(timeout) if timeout is not None else None
try:
try:
result = get_hub().switch()
assert result is self, 'Invalid switch into Semaphore.wait/acquire(): %r' % (result, )
except Timeout as ex:
if ex is not timer:
raise
return ex
finally:
if timer is not None:
timer.cancel()
finally:
self.unlink(switch)
def wait(self, timeout=None):
"""
wait(timeout=None) -> int
......@@ -161,22 +187,7 @@ class Semaphore(object):
if self.counter > 0:
return self.counter
switch = getcurrent().switch
self.rawlink(switch)
try:
timer = Timeout.start_new(timeout)
try:
try:
result = get_hub().switch()
assert result is self, 'Invalid switch into Semaphore.wait(): %r' % (result, )
except Timeout:
ex = sys.exc_info()[1]
if ex is not timer:
raise
finally:
timer.cancel()
finally:
self.unlink(switch)
self._do_wait(timeout) # return value irrelevant, whether we got it or got a timeout
return self.counter
def acquire(self, blocking=True, timeout=None):
......@@ -197,7 +208,8 @@ class Semaphore(object):
If ``blocking`` is True and ``timeout`` is None (the default), then
(so long as this semaphore was initialized with a size greater than 0)
this will always return True. If a timeout was given, and it expired before
the semaphore was acquired, False will be returned.
the semaphore was acquired, False will be returned. (Note that this can still
raise a ``Timeout`` exception, if some other caller had already started a timer.)
"""
if self.counter > 0:
self.counter -= 1
......@@ -206,25 +218,13 @@ class Semaphore(object):
if not blocking:
return False
switch = getcurrent().switch
self.rawlink(switch)
try:
# As a tiny efficiency optimization, avoid allocating a timer
# if not needed.
timer = Timeout.start_new(timeout) if timeout is not None else None
try:
try:
result = get_hub().switch()
assert result is self, 'Invalid switch into Semaphore.acquire(): %r' % (result, )
except Timeout as ex:
if ex is timer:
return False
raise
finally:
if timer is not None:
timer.cancel()
finally:
self.unlink(switch)
timeout = self._do_wait(timeout)
if timeout is not None:
# Our timer expired.
return False
# Neither our timer no another one expired, so we blocked until
# awoke. Therefore, the counter is ours
self.counter -= 1
assert self.counter >= 0
return True
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment