Commit 814300ff authored by Xavier Thompson's avatar Xavier Thompson

Ensure callable cypclass objects are properly locked when called

parent 92b6c650
......@@ -693,7 +693,7 @@ class CypclassLockTransform(Visitor.EnvTransform):
def visit_AttributeNode(self, node):
if node.obj.type and node.obj.type.is_cyp_class:
if node.is_called:
if node.is_called and node.type.is_cfunction:
if not node.type.is_static_method:
node.obj = self.lockcheck_written_or_read(node.obj, reading=node.type.is_const_method)
else:
......@@ -706,7 +706,13 @@ class CypclassLockTransform(Visitor.EnvTransform):
for i, arg in enumerate(node.args or ()): # provide an empty tuple fallback in case node.args is None
if arg.type.is_cyp_class:
node.args[i] = self.lockcheck_written_or_read(arg, reading=arg.type.is_const)
# TODO: lock callable objects
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def visit_CoerceFromCallable(self, node):
if node.arg.type.is_cyp_class:
node.arg = self.lockcheck_written_or_read(node.arg, reading=node.type.is_const_method)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
......
......@@ -5962,6 +5962,7 @@ class SimpleCallNode(CallNode):
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
self.function = CoerceFromCallable(self.function)
elif hasattr(self.function, 'entry'):
overloaded_entry = self.function.entry
elif self.function.is_subscript and self.function.is_fused_index:
......@@ -13918,6 +13919,37 @@ class CoerceToBooleanNode(CoercionNode):
return self
class CoerceFromCallable(CoercionNode):
# This node is used to wrap a callable cpp-typed node when it is called as part of a SimpleCallNode.
# The cpp type of the callable is replaced by the type of the operator() method during expression analysis;
# this node allows the underlying callable node to keep its original type.
def __init__(self, arg):
CoercionNode.__init__(self, arg)
self.type = arg.type
def generate_result_code(self, code):
self.arg.generate_result_code(code)
def result(self):
return self.arg.result()
def is_simple(self):
return self.arg.is_simple()
def may_be_none(self):
return self.arg.may_be_none()
def generate_evaluation_code(self, code):
self.arg.generate_evaluation_code(code)
def generate_disposal_code(self, code):
self.arg.generate_disposal_code(code)
def free_temps(self, code):
self.arg.free_temps(code)
class CoerceToComplexNode(CoercionNode):
def __init__(self, arg, dst_type, env):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment