Commit 820f2980 authored by gsamain's avatar gsamain

Support typecast operators

parent a45374f2
...@@ -10444,12 +10444,14 @@ class TypecastNode(ExprNode): ...@@ -10444,12 +10444,14 @@ class TypecastNode(ExprNode):
# base_type CBaseTypeNode # base_type CBaseTypeNode
# declarator CDeclaratorNode # declarator CDeclaratorNode
# typecheck boolean # typecheck boolean
# overloaded boolean
# #
# If used from a transform, one can if wanted specify the attribute # If used from a transform, one can if wanted specify the attribute
# "type" directly and leave base_type and declarator to None # "type" directly and leave base_type and declarator to None
subexprs = ['operand'] subexprs = ['operand']
base_type = declarator = type = None base_type = declarator = type = None
overloaded = False
def type_dependencies(self, env): def type_dependencies(self, env):
return () return ()
...@@ -10515,6 +10517,10 @@ class TypecastNode(ExprNode): ...@@ -10515,6 +10517,10 @@ class TypecastNode(ExprNode):
elif self.operand.type.is_fused: elif self.operand.type.is_fused:
self.operand = self.operand.coerce_to(self.type, env) self.operand = self.operand.coerce_to(self.type, env)
#self.type = self.operand.type #self.type = self.operand.type
elif self.operand.type.is_cpp_class:
operator = 'operator ' + self.type.declaration_code('')
entry = self.operand.type.scope.lookup_here(operator)
self.overloaded = entry is not None
if self.type.is_ptr and self.type.base_type.is_cfunction and self.type.base_type.nogil: if self.type.is_ptr and self.type.base_type.is_cfunction and self.type.base_type.nogil:
op_type = self.operand.type op_type = self.operand.type
if op_type.is_ptr: if op_type.is_ptr:
...@@ -10561,6 +10567,8 @@ class TypecastNode(ExprNode): ...@@ -10561,6 +10567,8 @@ class TypecastNode(ExprNode):
real_part, real_part,
imag_part) imag_part)
else: else:
if self.overloaded and self.operand.type.is_cyp_class:
operand_result = '(*%s)' % operand_result
return self.type.cast_code(operand_result) return self.type.cast_code(operand_result)
def get_constant_c_result_code(self): def get_constant_c_result_code(self):
......
...@@ -953,6 +953,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -953,6 +953,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
elif attr.type.is_cfunction: elif attr.type.is_cfunction:
code.put("virtual ") code.put("virtual ")
has_virtual_methods = True has_virtual_methods = True
if 'operator ' in attr.name:
code.putln("%s();" % attr.cname)
continue
elif attr.type.is_cyp_class: elif attr.type.is_cyp_class:
cname = "%s = NULL" % cname cname = "%s = NULL" % cname
code.putln("%s;" % attr.type.declaration_code(cname)) code.putln("%s;" % attr.type.declaration_code(cname))
......
...@@ -2704,6 +2704,9 @@ class CFuncDefNode(FuncDefNode): ...@@ -2704,6 +2704,9 @@ class CFuncDefNode(FuncDefNode):
dll_linkage = None dll_linkage = None
modifiers = code.build_function_modifiers(self.entry.func_modifiers) modifiers = code.build_function_modifiers(self.entry.func_modifiers)
if 'operator ' in entity:
header = entity
else:
header = self.return_type.declaration_code(entity, dll_linkage=dll_linkage) header = self.return_type.declaration_code(entity, dll_linkage=dll_linkage)
#print (storage_class, modifiers, header) #print (storage_class, modifiers, header)
needs_proto = self.is_c_class_method needs_proto = self.is_c_class_method
......
...@@ -2550,6 +2550,41 @@ class CppClassScope(Scope): ...@@ -2550,6 +2550,41 @@ class CppClassScope(Scope):
return_type=type.return_type) return_type=type.return_type)
type.original_alloc_type = type.args[0] type.original_alloc_type = type.args[0]
else:
operator = self.operator_table.get(name, None)
if operator:
name = 'operator'+operator
elif name.startswith('__') and name.endswith('__'):
stripped_name = name[2:-2]
signed = 1
longness = 0
ctypename = None
exploded_name = stripped_name.split('_')
for index, token in enumerate(exploded_name):
# Basically, it is the same code than Parsing.p_sign_and_longness
if token == 'unsigned':
signed = 0
elif token == 'signed':
signed = 2
elif token == 'short':
longness = -1
elif token == 'long':
longness += 1
else:
ctypename = '_'.join(exploded_name[index:])
break
known_type = PyrexTypes.simple_c_type(signed, longness, ctypename)
if not known_type:
if stripped_name == "bool":
# This one is hardcoded because it is declared as an int
# in PyrexTypes
name = 'operator bool'
type.args = []
else:
known_type = self.lookup_type(stripped_name)
if known_type:
name = 'operator ' + known_type.declaration_code('')
type.args = []
if name in ('<init>', '<del>') and type.nogil: if name in ('<init>', '<del>') and type.nogil:
for base in self.type.base_classes: for base in self.type.base_classes:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment