Commit 7771e204 authored by Robert Bradshaw's avatar Robert Bradshaw

cython.locals(...) decorator for pure python type declarations.

parent 5249c536
......@@ -713,6 +713,15 @@ class StringNode(ConstNode):
def analyse_types(self, env):
self.entry = env.add_string_const(self.value)
def analyse_as_type(self, env):
from TreeFragment import TreeFragment
pos = (self.pos[0], self.pos[1], self.pos[2]-7)
declaration = TreeFragment(u"sizeof(%s)" % self.value, name=pos[0].filename, initial_pos=pos)
sizeof_node = declaration.root.stats[0].expr
sizeof_node.analyse_types(env)
if isinstance(sizeof_node, SizeofTypeNode):
return sizeof_node.arg_type
def coerce_to(self, dst_type, env):
if dst_type.is_int:
......@@ -886,6 +895,8 @@ class NameNode(AtomicExprNode):
return None
def analyse_as_type(self, env):
if self.name in PyrexTypes.rank_to_type_name:
return PyrexTypes.simple_c_type(1, 0, self.name)
entry = self.entry
if not entry:
entry = env.lookup(self.name)
......@@ -2767,6 +2778,9 @@ class DictItemNode(ExprNode):
def generate_disposal_code(self, code):
self.key.generate_disposal_code(code)
self.value.generate_disposal_code(code)
def __iter__(self):
return iter([self.key, self.value])
class ClassNode(ExprNode):
......
......@@ -1166,11 +1166,16 @@ class CFuncDefNode(FuncDefNode):
# overridable whether or not this is a cpdef function
child_attrs = ["base_type", "declarator", "body", "py_func"]
def unqualified_name(self):
return self.entry.name
def analyse_declarations(self, env):
if 'locals' in env.directives:
directive_locals = env.directives['locals']
else:
directive_locals = {}
self.directive_locals = directive_locals
base_type = self.base_type.analyse(env)
# The 2 here is because we need both function and argument names.
name_declarator, type = self.declarator.analyse(base_type, env, nonempty = 2 * (self.body is not None))
......@@ -1442,11 +1447,27 @@ class DefNode(FuncDefNode):
entry = None
def analyse_declarations(self, env):
if 'locals' in env.directives:
directive_locals = env.directives['locals']
else:
directive_locals = {}
self.directive_locals = directive_locals
for arg in self.args:
base_type = arg.base_type.analyse(env)
name_declarator, type = \
arg.declarator.analyse(base_type, env)
arg.name = name_declarator.name
if arg.name in directive_locals:
type_node = directive_locals[arg.name]
other_type = type_node.analyse_as_type(env)
if other_type is None:
error(type_node.pos, "Not a type")
elif (type is not PyrexTypes.py_object_type
and not type.same_as(other_type)):
error(arg.base_type.pos, "Signature does not agree with previous declaration")
error(type_node.pos, "Previous declaration here")
else:
type = other_type
if name_declarator.cname:
error(self.pos,
"Python function argument cannot have C name specification")
......
......@@ -58,13 +58,15 @@ c_line_in_traceback = 1
option_types = {
'boundscheck' : bool,
'nonecheck' : bool,
'embedsignature' : bool
'embedsignature' : bool,
'locals' : dict,
}
option_defaults = {
'boundscheck' : True,
'nonecheck' : False,
'embedsignature' : False,
'locals' : {}
}
def parse_option_value(name, value):
......
......@@ -308,6 +308,14 @@ class InterpretCompilerDirectives(CythonTransform):
newimp.append((pos, name, as_name, kind))
node.imported_names = newimpo
return node
def visit_SingleAssignmentNode(self, node):
if (isinstance(node.rhs, ImportNode) and
node.rhs.module_name.value == u'cython'):
self.cython_module_names.add(node.lhs.name)
else:
self.visitchildren(node)
return node
def visit_Node(self, node):
self.visitchildren(node)
......@@ -318,7 +326,7 @@ class InterpretCompilerDirectives(CythonTransform):
# decorator), returns (optionname, value).
# Otherwise, returns None
optname = None
if isinstance(node, SimpleCallNode):
if isinstance(node, CallNode):
if (isinstance(node.function, AttributeNode) and
isinstance(node.function.obj, NameNode) and
node.function.obj.name in self.cython_module_names):
......@@ -330,12 +338,25 @@ class InterpretCompilerDirectives(CythonTransform):
if optname:
optiontype = Options.option_types.get(optname)
if optiontype:
args = node.args
if isinstance(node, SimpleCallNode):
args = node.args
kwds = None
else:
if node.starstar_arg or not isinstance(node.positional_args, TupleNode):
raise PostParseError(dec.function.pos,
'Compile-time keyword arguments must be explicit.' % optname)
args = node.positional_args.args
kwds = node.keyword_args
if optiontype is bool:
if len(args) != 1 or not isinstance(args[0], BoolNode):
if kwds is not None or len(args) != 1 or not isinstance(args[0], BoolNode):
raise PostParseError(dec.function.pos,
'The %s option takes one compile-time boolean argument' % optname)
return (optname, args[0].value)
elif optiontype is dict:
if len(args) != 0:
raise PostParseError(dec.function.pos,
'The %s option takes no prepositional arguments' % optname)
return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
else:
assert False
......@@ -367,7 +388,7 @@ class InterpretCompilerDirectives(CythonTransform):
else:
realdecs.append(dec)
node.decorators = realdecs
if options:
optdict = {}
options.reverse() # Decorators coming first take precedence
......@@ -499,12 +520,19 @@ property NAME:
lenv = node.create_local_scope(self.env_stack[-1])
node.body.analyse_control_flow(lenv) # this will be totally refactored
node.declare_arguments(lenv)
for var, type_node in node.directive_locals.items():
if not lenv.lookup_here(var): # don't redeclare args
type = type_node.analyse_as_type(lenv)
if type:
lenv.declare_var(var, type, type_node.pos)
else:
error(type_node.pos, "Not a type")
node.body.analyse_declarations(lenv)
self.env_stack.append(lenv)
self.visitchildren(node)
self.env_stack.pop()
return node
# Some nodes are no longer needed after declaration
# analysis and can be dropped. The analysis was performed
# on these nodes in a seperate recursive process from the
......
......@@ -1168,6 +1168,7 @@ modifiers_and_name_to_type = {
(1, 0, "int"): c_int_type,
(1, 1, "int"): c_long_type,
(1, 2, "int"): c_longlong_type,
(1, 0, "long"): c_long_type,
(1, 0, "Py_ssize_t"): c_py_ssize_t_type,
(1, 0, "float"): c_float_type,
(1, 0, "double"): c_double_type,
......@@ -1216,6 +1217,19 @@ def c_ptr_type(base_type):
return c_char_ptr_type
else:
return CPtrType(base_type)
def Node_to_type(node, env):
from ExprNodes import NameNode, AttributeNode, StringNode, error
if isinstance(node, StringNode):
node = NameNode(node.pos, name=node.value)
if isinstance(node, NameNode) and node.name in rank_to_type_name:
return simple_c_type(1, 0, node.name)
elif isinstance(node, (AttributeNode, NameNode)):
node.analyze_types(env)
if not node.entry.is_type:
pass
else:
error(node.pos, "Bad type")
def public_decl(base, dll_linkage):
if dll_linkage:
......
......@@ -289,8 +289,8 @@ class PyrexScanner(Scanner):
resword_dict = build_resword_dict()
def __init__(self, file, filename, parent_scanner = None,
scope = None, context = None, source_encoding=None, parse_comments=True):
Scanner.__init__(self, get_lexicon(), file, filename)
scope = None, context = None, source_encoding=None, parse_comments=True, initial_pos=None):
Scanner.__init__(self, get_lexicon(), file, filename, initial_pos)
if parent_scanner:
self.context = parent_scanner.context
self.included_files = parent_scanner.included_files
......
......@@ -29,7 +29,7 @@ class StringParseContext(Main.Context):
raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
return ModuleScope(module_name, parent_module = None, context = self)
def parse_from_strings(name, code, pxds={}, level=None):
def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None):
"""
Utility method to parse a (unicode) string of code. This is mostly
used for internal Cython compiler purposes (creating code snippets
......@@ -47,7 +47,8 @@ def parse_from_strings(name, code, pxds={}, level=None):
encoding = "UTF-8"
module_name = name
initial_pos = (name, 1, 0)
if initial_pos is None:
initial_pos = (name, 1, 0)
code_source = StringSourceDescriptor(name, code)
context = StringParseContext([], name)
......@@ -56,7 +57,7 @@ def parse_from_strings(name, code, pxds={}, level=None):
buf = StringIO(code.encode(encoding))
scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
scope = scope, context = context)
scope = scope, context = context, initial_pos = initial_pos)
if level is None:
tree = Parsing.p_module(scanner, 0, module_name)
else:
......@@ -181,7 +182,7 @@ def strip_common_indent(lines):
return lines
class TreeFragment(object):
def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[], level=None):
def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[], level=None, initial_pos=None):
if isinstance(code, unicode):
def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))
......@@ -189,8 +190,7 @@ class TreeFragment(object):
fmt_pxds = {}
for key, value in pxds.iteritems():
fmt_pxds[key] = fmt(value)
mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level)
mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level, initial_pos=initial_pos)
if level is None:
t = t.body # Make sure a StatListNode is at the top
if not isinstance(t, StatListNode):
......
......@@ -60,7 +60,7 @@ class Scanner:
queue = None # list of tokens to be returned
trace = 0
def __init__(self, lexicon, stream, name = ''):
def __init__(self, lexicon, stream, name = '', initial_pos = None):
"""
Scanner(lexicon, stream, name = '')
......@@ -84,6 +84,8 @@ class Scanner:
self.cur_line_start = 0
self.cur_char = BOL
self.input_state = 1
if initial_pos is not None:
self.cur_line, self.cur_line_start = initial_pos[1], -initial_pos[2]
def read(self):
"""
......
def empty_decorator(x):
return x
def locals(**arg_types):
return empty_decorator
def cast(type, arg):
# can/should we emulate anything here?
return arg
py_int = int
py_long = long
py_float = float
# They just have to exist...
int = long = char = bint = uint = ulong = longlong = ulonglong = Py_ssize_t = float = double = None
# Void cython.* directives (for case insensitive operating systems).
from Shadow import *
......@@ -2,5 +2,11 @@
# Cython -- Main Program, generic
#
from Cython.Compiler.Main import main
main(command_line = 1)
if __name__ == '__main__':
from Cython.Compiler.Main import main
main(command_line = 1)
else:
# Void cython.* directives.
from Cython.Shadow import *
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment