Commit 64908ab7 authored by Stefan Behnel's avatar Stefan Behnel

NoneCheckNode to enforce runtime None checks for object references

parent aaeee9d0
...@@ -5091,7 +5091,44 @@ class PyTypeTestNode(CoercionNode): ...@@ -5091,7 +5091,44 @@ class PyTypeTestNode(CoercionNode):
def free_temps(self, code): def free_temps(self, code):
self.arg.free_temps(code) self.arg.free_temps(code)
class NoneCheckNode(CoercionNode):
# This node is used to check that a Python object is not None and
# raises an appropriate exception (as specified by the creating
# transform).
def __init__(self, arg, exception_type_cname, exception_message):
CoercionNode.__init__(self, arg)
self.type = arg.type
self.result_ctype = arg.ctype()
self.exception_type_cname = exception_type_cname
self.exception_message = exception_message
def analyse_types(self, env):
pass
def result_in_temp(self):
return self.arg.result_in_temp()
def calculate_result_code(self):
return self.arg.result()
def generate_result_code(self, code):
code.putln(
"if (unlikely(%s == Py_None)) {" % self.arg.result())
code.putln('PyErr_SetString(%s, "%s"); %s ' % (
self.exception_type_cname,
StringEncoding.escape_byte_string(self.exception_message),
code.error_goto(self.pos)))
code.putln("}")
def generate_post_assignment_code(self, code):
self.arg.generate_post_assignment_code(code)
def free_temps(self, code):
self.arg.free_temps(code)
class CoerceToPyTypeNode(CoercionNode): class CoerceToPyTypeNode(CoercionNode):
# This node is used to convert a C data type # This node is used to convert a C data type
......
...@@ -476,6 +476,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ...@@ -476,6 +476,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
arg_list = arg_tuple.args arg_list = arg_tuple.args
self_arg = function.obj self_arg = function.obj
obj_type = self_arg.type obj_type = self_arg.type
is_unbound_method = False
if obj_type.is_builtin_type: if obj_type.is_builtin_type:
if obj_type is Builtin.type_type and arg_list and \ if obj_type is Builtin.type_type and arg_list and \
arg_list[0].type.is_pyobject: arg_list[0].type.is_pyobject:
...@@ -483,6 +484,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ...@@ -483,6 +484,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
# (ignoring 'type.mro()' here ...) # (ignoring 'type.mro()' here ...)
type_name = function.obj.name type_name = function.obj.name
self_arg = None self_arg = None
is_unbound_method = True
else: else:
type_name = obj_type.name type_name = obj_type.name
else: else:
...@@ -494,9 +496,9 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ...@@ -494,9 +496,9 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
if self_arg is not None: if self_arg is not None:
arg_list = [self_arg] + list(arg_list) arg_list = [self_arg] + list(arg_list)
if kwargs: if kwargs:
return method_handler(node, arg_list, kwargs) return method_handler(node, arg_list, kwargs, is_unbound_method)
else: else:
return method_handler(node, arg_list) return method_handler(node, arg_list, is_unbound_method)
else: else:
return node return node
...@@ -625,7 +627,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ...@@ -625,7 +627,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None), PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
]) ])
def _handle_simple_method_object_append(self, node, args): def _handle_simple_method_object_append(self, node, args, is_unbound_method):
# X.append() is almost always referring to a list # X.append() is almost always referring to a list
if len(args) != 2: if len(args) != 2:
return node return node
...@@ -644,13 +646,14 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ...@@ -644,13 +646,14 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
], ],
exception_value = "-1") exception_value = "-1")
def _handle_simple_method_list_append(self, node, args): def _handle_simple_method_list_append(self, node, args, is_unbound_method):
if len(args) != 2: if len(args) != 2:
error(node.pos, "list.append(x) called with wrong number of args, found %d" % error(node.pos, "list.append(x) called with wrong number of args, found %d" %
len(args)) len(args))
return node return node
return self._substitute_method_call( return self._substitute_method_call(
node, "PyList_Append", self.PyList_Append_func_type, args) node, "PyList_Append", self.PyList_Append_func_type,
'append', is_unbound_method, args)
single_param_func_type = PyrexTypes.CFuncType( single_param_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_int_type, [ PyrexTypes.c_int_type, [
...@@ -658,21 +661,37 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ...@@ -658,21 +661,37 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
], ],
exception_value = "-1") exception_value = "-1")
def _handle_simple_method_list_sort(self, node, args): def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
if len(args) != 1: if len(args) != 1:
return node return node
return self._substitute_method_call( return self._substitute_method_call(
node, "PyList_Sort", self.single_param_func_type, args) node, "PyList_Sort", self.single_param_func_type,
'sort', is_unbound_method, args)
def _handle_simple_method_list_reverse(self, node, args): def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
if len(args) != 1: if len(args) != 1:
error(node.pos, "list.reverse(x) called with wrong number of args, found %d" % error(node.pos, "list.reverse(x) called with wrong number of args, found %d" %
len(args)) len(args))
return node
return self._substitute_method_call( return self._substitute_method_call(
node, "PyList_Reverse", self.single_param_func_type, args) node, "PyList_Reverse", self.single_param_func_type,
'reverse', is_unbound_method, args)
def _substitute_method_call(self, node, name, func_type, args=()): def _substitute_method_call(self, node, name, func_type,
attr_name, is_unbound_method, args=()):
args = list(args) args = list(args)
if args:
self_arg = args[0]
if is_unbound_method:
self_arg = ExprNodes.NoneCheckNode(
self_arg, "PyExc_TypeError",
"descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
attr_name, node.function.obj.name))
else:
self_arg = ExprNodes.NoneCheckNode(
self_arg, "PyExc_AttributeError",
"'NoneType' object has no attribute '%s'" % attr_name)
args[0] = self_arg
# FIXME: args[0] may need a runtime None check (ticket #166) # FIXME: args[0] may need a runtime None check (ticket #166)
return ExprNodes.PythonCapiCallNode( return ExprNodes.PythonCapiCallNode(
node.pos, name, func_type, node.pos, name, func_type,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment