Commit 76c0f321 authored by Jason Madden's avatar Jason Madden

Unify the blocking implementation of Semaphore.wait and acquire.

parent ccbea2e3
...@@ -10,6 +10,7 @@ cdef class Semaphore: ...@@ -10,6 +10,7 @@ cdef class Semaphore:
cpdef unlink(self, object callback) cpdef unlink(self, object callback)
cpdef _start_notify(self) cpdef _start_notify(self)
cdef _notify_links(self) cdef _notify_links(self)
cdef _do_wait(self, object timeout)
cpdef int wait(self, object timeout=*) except -1000 cpdef int wait(self, object timeout=*) except -1000
cpdef bint acquire(self, int blocking=*, object timeout=*) except -1000 cpdef bint acquire(self, int blocking=*, object timeout=*) except -1000
cpdef __enter__(self) cpdef __enter__(self)
......
...@@ -143,6 +143,32 @@ class Semaphore(object): ...@@ -143,6 +143,32 @@ class Semaphore(object):
self._links = None self._links = None
# TODO: Cancel a notifier if there are no links? # TODO: Cancel a notifier if there are no links?
def _do_wait(self, timeout):
"""
Wait for up to *timeout* seconds to expire. If timeout
elapses, return the exception. Otherwise, return None.
Raises timeout if a different timer expires.
"""
switch = getcurrent().switch
self.rawlink(switch)
try:
# As a tiny efficiency optimization, avoid allocating a timer
# if not needed.
timer = Timeout.start_new(timeout) if timeout is not None else None
try:
try:
result = get_hub().switch()
assert result is self, 'Invalid switch into Semaphore.wait/acquire(): %r' % (result, )
except Timeout as ex:
if ex is not timer:
raise
return ex
finally:
if timer is not None:
timer.cancel()
finally:
self.unlink(switch)
def wait(self, timeout=None): def wait(self, timeout=None):
""" """
wait(timeout=None) -> int wait(timeout=None) -> int
...@@ -161,22 +187,7 @@ class Semaphore(object): ...@@ -161,22 +187,7 @@ class Semaphore(object):
if self.counter > 0: if self.counter > 0:
return self.counter return self.counter
switch = getcurrent().switch self._do_wait(timeout) # return value irrelevant, whether we got it or got a timeout
self.rawlink(switch)
try:
timer = Timeout.start_new(timeout)
try:
try:
result = get_hub().switch()
assert result is self, 'Invalid switch into Semaphore.wait(): %r' % (result, )
except Timeout:
ex = sys.exc_info()[1]
if ex is not timer:
raise
finally:
timer.cancel()
finally:
self.unlink(switch)
return self.counter return self.counter
def acquire(self, blocking=True, timeout=None): def acquire(self, blocking=True, timeout=None):
...@@ -197,7 +208,8 @@ class Semaphore(object): ...@@ -197,7 +208,8 @@ class Semaphore(object):
If ``blocking`` is True and ``timeout`` is None (the default), then If ``blocking`` is True and ``timeout`` is None (the default), then
(so long as this semaphore was initialized with a size greater than 0) (so long as this semaphore was initialized with a size greater than 0)
this will always return True. If a timeout was given, and it expired before this will always return True. If a timeout was given, and it expired before
the semaphore was acquired, False will be returned. the semaphore was acquired, False will be returned. (Note that this can still
raise a ``Timeout`` exception, if some other caller had already started a timer.)
""" """
if self.counter > 0: if self.counter > 0:
self.counter -= 1 self.counter -= 1
...@@ -206,25 +218,13 @@ class Semaphore(object): ...@@ -206,25 +218,13 @@ class Semaphore(object):
if not blocking: if not blocking:
return False return False
switch = getcurrent().switch timeout = self._do_wait(timeout)
self.rawlink(switch) if timeout is not None:
try: # Our timer expired.
# As a tiny efficiency optimization, avoid allocating a timer
# if not needed.
timer = Timeout.start_new(timeout) if timeout is not None else None
try:
try:
result = get_hub().switch()
assert result is self, 'Invalid switch into Semaphore.acquire(): %r' % (result, )
except Timeout as ex:
if ex is timer:
return False return False
raise
finally: # Neither our timer no another one expired, so we blocked until
if timer is not None: # awoke. Therefore, the counter is ours
timer.cancel()
finally:
self.unlink(switch)
self.counter -= 1 self.counter -= 1
assert self.counter >= 0 assert self.counter >= 0
return True return True
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment