Commit 7771e204 authored by Robert Bradshaw's avatar Robert Bradshaw

cython.locals(...) decorator for pure python type declarations.

parent 5249c536
...@@ -713,6 +713,15 @@ class StringNode(ConstNode): ...@@ -713,6 +713,15 @@ class StringNode(ConstNode):
def analyse_types(self, env): def analyse_types(self, env):
self.entry = env.add_string_const(self.value) self.entry = env.add_string_const(self.value)
def analyse_as_type(self, env):
from TreeFragment import TreeFragment
pos = (self.pos[0], self.pos[1], self.pos[2]-7)
declaration = TreeFragment(u"sizeof(%s)" % self.value, name=pos[0].filename, initial_pos=pos)
sizeof_node = declaration.root.stats[0].expr
sizeof_node.analyse_types(env)
if isinstance(sizeof_node, SizeofTypeNode):
return sizeof_node.arg_type
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
if dst_type.is_int: if dst_type.is_int:
...@@ -886,6 +895,8 @@ class NameNode(AtomicExprNode): ...@@ -886,6 +895,8 @@ class NameNode(AtomicExprNode):
return None return None
def analyse_as_type(self, env): def analyse_as_type(self, env):
if self.name in PyrexTypes.rank_to_type_name:
return PyrexTypes.simple_c_type(1, 0, self.name)
entry = self.entry entry = self.entry
if not entry: if not entry:
entry = env.lookup(self.name) entry = env.lookup(self.name)
...@@ -2767,6 +2778,9 @@ class DictItemNode(ExprNode): ...@@ -2767,6 +2778,9 @@ class DictItemNode(ExprNode):
def generate_disposal_code(self, code): def generate_disposal_code(self, code):
self.key.generate_disposal_code(code) self.key.generate_disposal_code(code)
self.value.generate_disposal_code(code) self.value.generate_disposal_code(code)
def __iter__(self):
return iter([self.key, self.value])
class ClassNode(ExprNode): class ClassNode(ExprNode):
......
...@@ -1166,11 +1166,16 @@ class CFuncDefNode(FuncDefNode): ...@@ -1166,11 +1166,16 @@ class CFuncDefNode(FuncDefNode):
# overridable whether or not this is a cpdef function # overridable whether or not this is a cpdef function
child_attrs = ["base_type", "declarator", "body", "py_func"] child_attrs = ["base_type", "declarator", "body", "py_func"]
def unqualified_name(self): def unqualified_name(self):
return self.entry.name return self.entry.name
def analyse_declarations(self, env): def analyse_declarations(self, env):
if 'locals' in env.directives:
directive_locals = env.directives['locals']
else:
directive_locals = {}
self.directive_locals = directive_locals
base_type = self.base_type.analyse(env) base_type = self.base_type.analyse(env)
# The 2 here is because we need both function and argument names. # The 2 here is because we need both function and argument names.
name_declarator, type = self.declarator.analyse(base_type, env, nonempty = 2 * (self.body is not None)) name_declarator, type = self.declarator.analyse(base_type, env, nonempty = 2 * (self.body is not None))
...@@ -1442,11 +1447,27 @@ class DefNode(FuncDefNode): ...@@ -1442,11 +1447,27 @@ class DefNode(FuncDefNode):
entry = None entry = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
if 'locals' in env.directives:
directive_locals = env.directives['locals']
else:
directive_locals = {}
self.directive_locals = directive_locals
for arg in self.args: for arg in self.args:
base_type = arg.base_type.analyse(env) base_type = arg.base_type.analyse(env)
name_declarator, type = \ name_declarator, type = \
arg.declarator.analyse(base_type, env) arg.declarator.analyse(base_type, env)
arg.name = name_declarator.name arg.name = name_declarator.name
if arg.name in directive_locals:
type_node = directive_locals[arg.name]
other_type = type_node.analyse_as_type(env)
if other_type is None:
error(type_node.pos, "Not a type")
elif (type is not PyrexTypes.py_object_type
and not type.same_as(other_type)):
error(arg.base_type.pos, "Signature does not agree with previous declaration")
error(type_node.pos, "Previous declaration here")
else:
type = other_type
if name_declarator.cname: if name_declarator.cname:
error(self.pos, error(self.pos,
"Python function argument cannot have C name specification") "Python function argument cannot have C name specification")
......
...@@ -58,13 +58,15 @@ c_line_in_traceback = 1 ...@@ -58,13 +58,15 @@ c_line_in_traceback = 1
option_types = { option_types = {
'boundscheck' : bool, 'boundscheck' : bool,
'nonecheck' : bool, 'nonecheck' : bool,
'embedsignature' : bool 'embedsignature' : bool,
'locals' : dict,
} }
option_defaults = { option_defaults = {
'boundscheck' : True, 'boundscheck' : True,
'nonecheck' : False, 'nonecheck' : False,
'embedsignature' : False, 'embedsignature' : False,
'locals' : {}
} }
def parse_option_value(name, value): def parse_option_value(name, value):
......
...@@ -308,6 +308,14 @@ class InterpretCompilerDirectives(CythonTransform): ...@@ -308,6 +308,14 @@ class InterpretCompilerDirectives(CythonTransform):
newimp.append((pos, name, as_name, kind)) newimp.append((pos, name, as_name, kind))
node.imported_names = newimpo node.imported_names = newimpo
return node return node
def visit_SingleAssignmentNode(self, node):
if (isinstance(node.rhs, ImportNode) and
node.rhs.module_name.value == u'cython'):
self.cython_module_names.add(node.lhs.name)
else:
self.visitchildren(node)
return node
def visit_Node(self, node): def visit_Node(self, node):
self.visitchildren(node) self.visitchildren(node)
...@@ -318,7 +326,7 @@ class InterpretCompilerDirectives(CythonTransform): ...@@ -318,7 +326,7 @@ class InterpretCompilerDirectives(CythonTransform):
# decorator), returns (optionname, value). # decorator), returns (optionname, value).
# Otherwise, returns None # Otherwise, returns None
optname = None optname = None
if isinstance(node, SimpleCallNode): if isinstance(node, CallNode):
if (isinstance(node.function, AttributeNode) and if (isinstance(node.function, AttributeNode) and
isinstance(node.function.obj, NameNode) and isinstance(node.function.obj, NameNode) and
node.function.obj.name in self.cython_module_names): node.function.obj.name in self.cython_module_names):
...@@ -330,12 +338,25 @@ class InterpretCompilerDirectives(CythonTransform): ...@@ -330,12 +338,25 @@ class InterpretCompilerDirectives(CythonTransform):
if optname: if optname:
optiontype = Options.option_types.get(optname) optiontype = Options.option_types.get(optname)
if optiontype: if optiontype:
args = node.args if isinstance(node, SimpleCallNode):
args = node.args
kwds = None
else:
if node.starstar_arg or not isinstance(node.positional_args, TupleNode):
raise PostParseError(dec.function.pos,
'Compile-time keyword arguments must be explicit.' % optname)
args = node.positional_args.args
kwds = node.keyword_args
if optiontype is bool: if optiontype is bool:
if len(args) != 1 or not isinstance(args[0], BoolNode): if kwds is not None or len(args) != 1 or not isinstance(args[0], BoolNode):
raise PostParseError(dec.function.pos, raise PostParseError(dec.function.pos,
'The %s option takes one compile-time boolean argument' % optname) 'The %s option takes one compile-time boolean argument' % optname)
return (optname, args[0].value) return (optname, args[0].value)
elif optiontype is dict:
if len(args) != 0:
raise PostParseError(dec.function.pos,
'The %s option takes no prepositional arguments' % optname)
return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
else: else:
assert False assert False
...@@ -367,7 +388,7 @@ class InterpretCompilerDirectives(CythonTransform): ...@@ -367,7 +388,7 @@ class InterpretCompilerDirectives(CythonTransform):
else: else:
realdecs.append(dec) realdecs.append(dec)
node.decorators = realdecs node.decorators = realdecs
if options: if options:
optdict = {} optdict = {}
options.reverse() # Decorators coming first take precedence options.reverse() # Decorators coming first take precedence
...@@ -499,12 +520,19 @@ property NAME: ...@@ -499,12 +520,19 @@ property NAME:
lenv = node.create_local_scope(self.env_stack[-1]) lenv = node.create_local_scope(self.env_stack[-1])
node.body.analyse_control_flow(lenv) # this will be totally refactored node.body.analyse_control_flow(lenv) # this will be totally refactored
node.declare_arguments(lenv) node.declare_arguments(lenv)
for var, type_node in node.directive_locals.items():
if not lenv.lookup_here(var): # don't redeclare args
type = type_node.analyse_as_type(lenv)
if type:
lenv.declare_var(var, type, type_node.pos)
else:
error(type_node.pos, "Not a type")
node.body.analyse_declarations(lenv) node.body.analyse_declarations(lenv)
self.env_stack.append(lenv) self.env_stack.append(lenv)
self.visitchildren(node) self.visitchildren(node)
self.env_stack.pop() self.env_stack.pop()
return node return node
# Some nodes are no longer needed after declaration # Some nodes are no longer needed after declaration
# analysis and can be dropped. The analysis was performed # analysis and can be dropped. The analysis was performed
# on these nodes in a seperate recursive process from the # on these nodes in a seperate recursive process from the
......
...@@ -1168,6 +1168,7 @@ modifiers_and_name_to_type = { ...@@ -1168,6 +1168,7 @@ modifiers_and_name_to_type = {
(1, 0, "int"): c_int_type, (1, 0, "int"): c_int_type,
(1, 1, "int"): c_long_type, (1, 1, "int"): c_long_type,
(1, 2, "int"): c_longlong_type, (1, 2, "int"): c_longlong_type,
(1, 0, "long"): c_long_type,
(1, 0, "Py_ssize_t"): c_py_ssize_t_type, (1, 0, "Py_ssize_t"): c_py_ssize_t_type,
(1, 0, "float"): c_float_type, (1, 0, "float"): c_float_type,
(1, 0, "double"): c_double_type, (1, 0, "double"): c_double_type,
...@@ -1216,6 +1217,19 @@ def c_ptr_type(base_type): ...@@ -1216,6 +1217,19 @@ def c_ptr_type(base_type):
return c_char_ptr_type return c_char_ptr_type
else: else:
return CPtrType(base_type) return CPtrType(base_type)
def Node_to_type(node, env):
from ExprNodes import NameNode, AttributeNode, StringNode, error
if isinstance(node, StringNode):
node = NameNode(node.pos, name=node.value)
if isinstance(node, NameNode) and node.name in rank_to_type_name:
return simple_c_type(1, 0, node.name)
elif isinstance(node, (AttributeNode, NameNode)):
node.analyze_types(env)
if not node.entry.is_type:
pass
else:
error(node.pos, "Bad type")
def public_decl(base, dll_linkage): def public_decl(base, dll_linkage):
if dll_linkage: if dll_linkage:
......
...@@ -289,8 +289,8 @@ class PyrexScanner(Scanner): ...@@ -289,8 +289,8 @@ class PyrexScanner(Scanner):
resword_dict = build_resword_dict() resword_dict = build_resword_dict()
def __init__(self, file, filename, parent_scanner = None, def __init__(self, file, filename, parent_scanner = None,
scope = None, context = None, source_encoding=None, parse_comments=True): scope = None, context = None, source_encoding=None, parse_comments=True, initial_pos=None):
Scanner.__init__(self, get_lexicon(), file, filename) Scanner.__init__(self, get_lexicon(), file, filename, initial_pos)
if parent_scanner: if parent_scanner:
self.context = parent_scanner.context self.context = parent_scanner.context
self.included_files = parent_scanner.included_files self.included_files = parent_scanner.included_files
......
...@@ -29,7 +29,7 @@ class StringParseContext(Main.Context): ...@@ -29,7 +29,7 @@ class StringParseContext(Main.Context):
raise AssertionError("Not yet supporting any cimports/includes from string code snippets") raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
return ModuleScope(module_name, parent_module = None, context = self) return ModuleScope(module_name, parent_module = None, context = self)
def parse_from_strings(name, code, pxds={}, level=None): def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None):
""" """
Utility method to parse a (unicode) string of code. This is mostly Utility method to parse a (unicode) string of code. This is mostly
used for internal Cython compiler purposes (creating code snippets used for internal Cython compiler purposes (creating code snippets
...@@ -47,7 +47,8 @@ def parse_from_strings(name, code, pxds={}, level=None): ...@@ -47,7 +47,8 @@ def parse_from_strings(name, code, pxds={}, level=None):
encoding = "UTF-8" encoding = "UTF-8"
module_name = name module_name = name
initial_pos = (name, 1, 0) if initial_pos is None:
initial_pos = (name, 1, 0)
code_source = StringSourceDescriptor(name, code) code_source = StringSourceDescriptor(name, code)
context = StringParseContext([], name) context = StringParseContext([], name)
...@@ -56,7 +57,7 @@ def parse_from_strings(name, code, pxds={}, level=None): ...@@ -56,7 +57,7 @@ def parse_from_strings(name, code, pxds={}, level=None):
buf = StringIO(code.encode(encoding)) buf = StringIO(code.encode(encoding))
scanner = PyrexScanner(buf, code_source, source_encoding = encoding, scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
scope = scope, context = context) scope = scope, context = context, initial_pos = initial_pos)
if level is None: if level is None:
tree = Parsing.p_module(scanner, 0, module_name) tree = Parsing.p_module(scanner, 0, module_name)
else: else:
...@@ -181,7 +182,7 @@ def strip_common_indent(lines): ...@@ -181,7 +182,7 @@ def strip_common_indent(lines):
return lines return lines
class TreeFragment(object): class TreeFragment(object):
def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[], level=None): def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[], level=None, initial_pos=None):
if isinstance(code, unicode): if isinstance(code, unicode):
def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n"))) def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))
...@@ -189,8 +190,7 @@ class TreeFragment(object): ...@@ -189,8 +190,7 @@ class TreeFragment(object):
fmt_pxds = {} fmt_pxds = {}
for key, value in pxds.iteritems(): for key, value in pxds.iteritems():
fmt_pxds[key] = fmt(value) fmt_pxds[key] = fmt(value)
mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level, initial_pos=initial_pos)
mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level)
if level is None: if level is None:
t = t.body # Make sure a StatListNode is at the top t = t.body # Make sure a StatListNode is at the top
if not isinstance(t, StatListNode): if not isinstance(t, StatListNode):
......
...@@ -60,7 +60,7 @@ class Scanner: ...@@ -60,7 +60,7 @@ class Scanner:
queue = None # list of tokens to be returned queue = None # list of tokens to be returned
trace = 0 trace = 0
def __init__(self, lexicon, stream, name = ''): def __init__(self, lexicon, stream, name = '', initial_pos = None):
""" """
Scanner(lexicon, stream, name = '') Scanner(lexicon, stream, name = '')
...@@ -84,6 +84,8 @@ class Scanner: ...@@ -84,6 +84,8 @@ class Scanner:
self.cur_line_start = 0 self.cur_line_start = 0
self.cur_char = BOL self.cur_char = BOL
self.input_state = 1 self.input_state = 1
if initial_pos is not None:
self.cur_line, self.cur_line_start = initial_pos[1], -initial_pos[2]
def read(self): def read(self):
""" """
......
def empty_decorator(x):
return x
def locals(**arg_types):
return empty_decorator
def cast(type, arg):
# can/should we emulate anything here?
return arg
py_int = int
py_long = long
py_float = float
# They just have to exist...
int = long = char = bint = uint = ulong = longlong = ulonglong = Py_ssize_t = float = double = None
# Void cython.* directives (for case insensitive operating systems).
from Shadow import *
...@@ -2,5 +2,11 @@ ...@@ -2,5 +2,11 @@
# Cython -- Main Program, generic # Cython -- Main Program, generic
# #
from Cython.Compiler.Main import main if __name__ == '__main__':
main(command_line = 1)
from Cython.Compiler.Main import main
main(command_line = 1)
else:
# Void cython.* directives.
from Cython.Shadow import *
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment