Commit b0339316 authored by Stefan Behnel's avatar Stefan Behnel

fix compile time constants in array size declarations (e.g. int a[enum_val+1])

parent 02a1df31
...@@ -230,15 +230,23 @@ class ExprNode(Node): ...@@ -230,15 +230,23 @@ class ExprNode(Node):
# C type of the result_code expression). # C type of the result_code expression).
return self.result_ctype or self.type return self.result_ctype or self.type
def get_constant_result_code(self): def get_constant_c_result_code(self):
# Return the constant value of this node as a result code # Return the constant value of this node as a result code
# string, or None if the node is not constant. # string, or None if the node is not constant. This method
# can be called when the constant result code is required
# before the code generation phase.
#
# The return value is a string that can represent a simple C
# value, a constant C name or a constant C expression. If the
# node type depends on Python code, this must return None.
return None return None
def calculate_constant_result(self): def calculate_constant_result(self):
# Calculate the constant result of this expression and store # Calculate the constant compile time result value of this
# it in ``self.constant_result``. Does nothing by default, # expression and store it in ``self.constant_result``. Does
# thus leaving ``self.constant_result`` unknown. # nothing by default, thus leaving ``self.constant_result``
# unknown. If valid, the result can be an arbitrary Python
# value.
# #
# This must only be called when it is assured that all # This must only be called when it is assured that all
# sub-expressions have a valid constant_result value. The # sub-expressions have a valid constant_result value. The
...@@ -619,7 +627,7 @@ class ConstNode(AtomicExprNode): ...@@ -619,7 +627,7 @@ class ConstNode(AtomicExprNode):
def check_const(self): def check_const(self):
pass pass
def get_constant_result_code(self): def get_constant_c_result_code(self):
return self.calculate_result_code() return self.calculate_result_code()
def calculate_result_code(self): def calculate_result_code(self):
...@@ -648,7 +656,7 @@ class NullNode(ConstNode): ...@@ -648,7 +656,7 @@ class NullNode(ConstNode):
value = "NULL" value = "NULL"
constant_result = 0 constant_result = 0
def get_constant_result_code(self): def get_constant_c_result_code(self):
return self.value return self.value
...@@ -695,9 +703,9 @@ class IntNode(ConstNode): ...@@ -695,9 +703,9 @@ class IntNode(ConstNode):
if self.type.is_pyobject: if self.type.is_pyobject:
self.result_code = code.get_py_num(self.value, self.longness) self.result_code = code.get_py_num(self.value, self.longness)
else: else:
self.result_code = self.get_constant_result_code() self.result_code = self.get_constant_c_result_code()
def get_constant_result_code(self): def get_constant_c_result_code(self):
return str(self.value) + self.unsigned + self.longness return str(self.value) + self.unsigned + self.longness
def calculate_result_code(self): def calculate_result_code(self):
...@@ -784,7 +792,7 @@ class StringNode(ConstNode): ...@@ -784,7 +792,7 @@ class StringNode(ConstNode):
else: else:
self.result_code = code.get_string_const(self.value) self.result_code = code.get_string_const(self.value)
def get_constant_result_code(self): def get_constant_c_result_code(self):
return None # FIXME return None # FIXME
def calculate_result_code(self): def calculate_result_code(self):
...@@ -825,7 +833,7 @@ class IdentifierStringNode(ConstNode): ...@@ -825,7 +833,7 @@ class IdentifierStringNode(ConstNode):
else: else:
self.result_code = code.get_string_const(self.value) self.result_code = code.get_string_const(self.value)
def get_constant_result_code(self): def get_constant_c_result_code(self):
return None return None
def calculate_result_code(self): def calculate_result_code(self):
...@@ -915,6 +923,11 @@ class NameNode(AtomicExprNode): ...@@ -915,6 +923,11 @@ class NameNode(AtomicExprNode):
return denv.lookup(self.name) return denv.lookup(self.name)
except KeyError: except KeyError:
error(self.pos, "Compile-time name '%s' not defined" % self.name) error(self.pos, "Compile-time name '%s' not defined" % self.name)
def get_constant_c_result_code(self):
if not self.entry or self.entry.type.is_pyobject:
return None
return self.entry.cname
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
# If coercing to a generic pyobject and this is a builtin # If coercing to a generic pyobject and this is a builtin
...@@ -4164,6 +4177,14 @@ class NumBinopNode(BinopNode): ...@@ -4164,6 +4177,14 @@ class NumBinopNode(BinopNode):
return PyrexTypes.widest_numeric_type(type1, type2) return PyrexTypes.widest_numeric_type(type1, type2)
else: else:
return None return None
def get_constant_c_result_code(self):
value1 = self.operand1.get_constant_c_result_code()
value2 = self.operand2.get_constant_c_result_code()
if value1 and value2:
return "(%s %s %s)" % (value1, self.operator, value2)
else:
return None
def c_types_okay(self, type1, type2): def c_types_okay(self, type1, type2):
#print "NumBinopNode.c_types_okay:", type1, type2 ### #print "NumBinopNode.c_types_okay:", type1, type2 ###
......
...@@ -455,12 +455,13 @@ class CArrayDeclaratorNode(CDeclaratorNode): ...@@ -455,12 +455,13 @@ class CArrayDeclaratorNode(CDeclaratorNode):
self.dimension.analyse_const_expression(env) self.dimension.analyse_const_expression(env)
if not self.dimension.type.is_int: if not self.dimension.type.is_int:
error(self.dimension.pos, "Array dimension not integer") error(self.dimension.pos, "Array dimension not integer")
size = self.dimension.get_constant_result_code() size = self.dimension.get_constant_c_result_code()
try: if size is not None:
size = int(size) try:
except (ValueError, TypeError): size = int(size)
# runtime constant? except ValueError:
pass # runtime constant?
pass
else: else:
size = None size = None
if not base_type.is_complete(): if not base_type.is_complete():
...@@ -541,7 +542,7 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -541,7 +542,7 @@ class CFuncDeclaratorNode(CDeclaratorNode):
else: else:
if self.exception_value: if self.exception_value:
self.exception_value.analyse_const_expression(env) self.exception_value.analyse_const_expression(env)
exc_val = self.exception_value.get_constant_result_code() exc_val = self.exception_value.get_constant_c_result_code()
if self.exception_check == '+': if self.exception_check == '+':
exc_val_type = self.exception_value.type exc_val_type = self.exception_value.type
if not exc_val_type.is_error and \ if not exc_val_type.is_error and \
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment