Commit 3f1a75bd authored by Stefan Behnel's avatar Stefan Behnel

fix 'self' argument type in generic builtin method overrides, test for None at...

fix 'self' argument type in generic builtin method overrides, test for None at call time (ticket #571)
parent 73e6ab4a
...@@ -112,20 +112,20 @@ builtin_types_table = [ ...@@ -112,20 +112,20 @@ builtin_types_table = [
("tuple", "PyTuple_Type", []), ("tuple", "PyTuple_Type", []),
("list", "PyList_Type", [("insert", "OzO", "i", "PyList_Insert")]), ("list", "PyList_Type", [("insert", "TzO", "i", "PyList_Insert")]),
("dict", "PyDict_Type", [("items", "O", "O", "PyDict_Items"), ("dict", "PyDict_Type", [("items", "T", "O", "PyDict_Items"),
("keys", "O", "O", "PyDict_Keys"), ("keys", "T", "O", "PyDict_Keys"),
("values","O", "O", "PyDict_Values"), ("values","T", "O", "PyDict_Values"),
("copy", "O", "O", "PyDict_Copy")]), ("copy", "T", "O", "PyDict_Copy")]),
("slice", "PySlice_Type", []), ("slice", "PySlice_Type", []),
# ("file", "PyFile_Type", []), # not in Py3 # ("file", "PyFile_Type", []), # not in Py3
("set", "PySet_Type", [("clear", "O", "i", "PySet_Clear"), ("set", "PySet_Type", [("clear", "T", "i", "PySet_Clear"),
("discard", "OO", "i", "PySet_Discard"), ("discard", "TO", "i", "PySet_Discard"),
("add", "OO", "i", "PySet_Add"), ("add", "TO", "i", "PySet_Add"),
("pop", "O", "O", "PySet_Pop")]), ("pop", "T", "O", "PySet_Pop")]),
("frozenset", "PyFrozenSet_Type", []), ("frozenset", "PyFrozenSet_Type", []),
] ]
...@@ -470,7 +470,11 @@ def init_builtin_types(): ...@@ -470,7 +470,11 @@ def init_builtin_types():
builtin_types[name] = the_type builtin_types[name] = the_type
for name, args, ret, cname in funcs: for name, args, ret, cname in funcs:
sig = Signature(args, ret) sig = Signature(args, ret)
the_type.scope.declare_cfunction(name, sig.function_type(), None, cname) # override 'self' type (first argument)
self_arg = PyrexTypes.CFuncTypeArg("", the_type, None)
self_arg.not_none = True
method_type = sig.function_type(self_arg)
the_type.scope.declare_cfunction(name, method_type, None, cname)
def init_builtin_structs(): def init_builtin_structs():
for name, cname, attribute_types in builtin_structs_table: for name, cname, attribute_types in builtin_structs_table:
......
...@@ -2836,7 +2836,11 @@ class SimpleCallNode(CallNode): ...@@ -2836,7 +2836,11 @@ class SimpleCallNode(CallNode):
arg.analyse_types(env) arg.analyse_types(env)
if self.self and func_type.args: if self.self and func_type.args:
# Coerce 'self' to the type expected by the method. # Coerce 'self' to the type expected by the method.
expected_type = func_type.args[0].type self_arg = func_type.args[0]
if self_arg.not_none: # C methods must do the None test for self at *call* time
self.self = self.self.as_none_safe_node(
"'NoneType' object has no attribute '%s'" % self.function.entry.name)
expected_type = self_arg.type
self.coerced_self = CloneNode(self.self).coerce_to( self.coerced_self = CloneNode(self.self).coerce_to(
expected_type, env) expected_type, env)
# Insert coerced 'self' argument into argument list. # Insert coerced 'self' argument into argument list.
......
...@@ -100,10 +100,14 @@ class Signature(object): ...@@ -100,10 +100,14 @@ class Signature(object):
def exception_value(self): def exception_value(self):
return self.error_value_map.get(self.ret_format) return self.error_value_map.get(self.ret_format)
def function_type(self): def function_type(self, self_arg_override=None):
# Construct a C function type descriptor for this signature # Construct a C function type descriptor for this signature
args = [] args = []
for i in xrange(self.num_fixed_args()): for i in xrange(self.num_fixed_args()):
if self_arg_override is not None and self.is_self_arg(i):
assert isinstance(self_arg_override, PyrexTypes.CFuncTypeArg)
args.append(self_arg_override)
else:
arg_type = self.fixed_arg_type(i) arg_type = self.fixed_arg_type(i)
args.append(PyrexTypes.CFuncTypeArg("", arg_type, None)) args.append(PyrexTypes.CFuncTypeArg("", arg_type, None))
ret_type = self.return_type() ret_type = self.return_type()
......
...@@ -56,6 +56,13 @@ def test_set_clear(): ...@@ -56,6 +56,13 @@ def test_set_clear():
s1.clear() s1.clear()
return s1 return s1
def test_set_clear_None():
"""
>>> test_set_clear_None()
"""
cdef set s1 = None
s1.clear()
def test_set_list_comp(): def test_set_list_comp():
""" """
>>> type(test_set_list_comp()) is _set >>> type(test_set_list_comp()) is _set
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment