from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString
from Cython.Compiler.Errors import CompileError
import PyrexTypes
# Naming (extern_c_macro) and Symtab (BufferAux) are used below; import them
# explicitly rather than relying on the star imports above re-exporting them.
import Naming
import Symtab
from sets import Set as set


class PureCFuncNode(Node):
    """A node carrying a C function whose body is given as raw C source.

    analyse_types() declares the function in the given scope under ``cname``;
    generate_function_definitions() emits ``c_code`` verbatim as its body.
    """

    def __init__(self, pos, cname, type, c_code, visibility='private'):
        self.pos = pos
        self.cname = cname
        self.type = type
        self.c_code = c_code
        self.visibility = visibility

    def analyse_types(self, env):
        # The Python-level name is an unreachable placeholder; generated
        # code refers to the function by its cname only.
        self.entry = env.declare_cfunction(
            "<pure c function:%s>" % self.cname,
            self.type, self.pos, cname=self.cname,
            defining=True, visibility=self.visibility)

    def generate_function_definitions(self, env, code, transforms):
        # Raw C functions cannot have optional (Cython-level) arguments.
        assert self.type.optional_arg_count == 0
        visibility = self.entry.visibility
        if visibility != 'private':
            storage_class = "%s " % Naming.extern_c_macro
        else:
            storage_class = "static "
        arg_decls = [arg.declaration_code() for arg in self.type.args]
        sig = self.type.return_type.declaration_code(
            self.type.function_header_code(self.cname, ", ".join(arg_decls)))
        code.putln("")
        code.putln("%s%s {" % (storage_class, sig))
        code.put(self.c_code)
        code.putln("}")

    def generate_execution_code(self, code):
        # Contributes only a function definition, no execution code.
        pass


# Signature shared by every type string checker: char *f(char *ts).
# A checker returns the position in ts after what it consumed, or NULL
# (after a Python exception has been raised) to signal a mismatch.
tschecker_functype = PyrexTypes.CFuncType(
    PyrexTypes.c_char_ptr_type,
    [PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type,
                             (0, 0, None), cname="ts")],
    exception_value = "NULL"
)

# Prefix for the cnames of all generated type string checking functions.
tsprefix = "__Pyx_tsc"


class BufferTransform(CythonTransform):
    """
    Run after type analysis. Takes care of the buffer functionality.

    Expects to be run on the full module. If you need to process a fragment
    one should look into refactoring this transform.
    """
    # Abbreviations:
    # "ts" means typestring and/or typestring checking stuff

    # Scope used to analyse generated code; set by handle_scope().
    scope = None

    #
    # Entry point
    #

    def __call__(self, node):
        """Transform a whole ModuleNode; returns it unchanged if the
        ``__cython__`` pseudo-module is not available in the context."""
        assert isinstance(node, ModuleNode)

        try:
            cymod = self.context.modules[u'__cython__']
        except KeyError:
            # No buffer fun for this module
            return node

        self.bufstruct_type = cymod.entries[u'Py_buffer'].type
        self.tscheckers = {}
        self.ts_funcs = []
        self.ts_item_checkers = {}
        # These cnames name functions stored in ts_funcs, so they must be
        # reset along with it; otherwise a reused transform would call
        # utility functions it never emits into this module.
        self.ts_consume_whitespace_cname = None
        self.ts_check_endian_cname = None
        self.module_scope = node.scope
        self.module_pos = node.pos
        result = super(BufferTransform, self).__call__(node)
        # Register ts stuff (check_endian uses __BYTE_ORDER from endian.h)
        if "endian.h" not in node.scope.include_files:
            node.scope.include_files.append("endian.h")
        result.body.stats += self.ts_funcs
        return result

    #
    # Basic operations for transforms
    #
    def handle_scope(self, node, scope):
        """Declare auxiliary variables for every buffer variable in ``scope``
        and make ``scope`` the current analysis scope."""
        # For all buffers, insert extra variables in the scope.
        # The variables are also accessible from the buffer_info
        # on the buffer entry
        # Materialise the list first: declare_var() below adds entries.
        bufvars = [(name, entry) for name, entry
                   in scope.entries.iteritems()
                   if entry.type.buffer_options is not None]

        for name, entry in bufvars:
            bufopts = entry.type.buffer_options

            # Get or make a type string checker
            tschecker = self.tschecker(bufopts.dtype)

            # Declare auxiliary vars
            bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
                                        self.bufstruct_type, node.pos)

            temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name),
                                         entry.type, node.pos)

            stridevars = []
            shapevars = []
            for idx in range(bufopts.ndim):
                # stride
                varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
                var = scope.declare_var(varname, PyrexTypes.c_int_type,
                                        node.pos, is_cdef=True)
                stridevars.append(var)
                # shape
                varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
                var = scope.declare_var(varname, PyrexTypes.c_uint_type,
                                        node.pos, is_cdef=True)
                shapevars.append(var)
            entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars,
                                                shapevars, tschecker)
            entry.buffer_aux.temp_var = temp_var
        self.scope = scope

    # Notes: The cast to <char*> gets around Cython not supporting const types.
    # Strides/shape (ASSIGN_AUX) are only read from BUFINFO when a buffer was
    # actually acquired, i.e. when the new value is not None.
    acquire_buffer_fragment = TreeFragment(u"""
        TMP = LHS
        if TMP is not None:
            __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
        TMP = RHS
        if TMP is not None:
            __cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
            TSCHECKER(<char*>BUFINFO.format)
            ASSIGN_AUX
        LHS = TMP
    """)

    fetch_strides = TreeFragment(u"""
        TARGET = BUFINFO.strides[IDX]
    """)

    fetch_shape = TreeFragment(u"""
        TARGET = BUFINFO.shape[IDX]
    """)

    def reacquire_buffer(self, node):
        """Expand ``bufvar = rhs`` into release / acquire / typestring check /
        stride+shape fetch statements; returns the analysed statement list."""
        bufaux = node.lhs.entry.buffer_aux
        bufinfo_name = bufaux.buffer_info_var.name
        auxass = []
        for idx, entry in enumerate(bufaux.stridevars):
            entry.used = True
            auxass.append(self.fetch_strides.substitute({
                u"TARGET": NameNode(node.pos, name=entry.name),
                u"BUFINFO": NameNode(node.pos, name=bufinfo_name),
                u"IDX": IntNode(node.pos, value=EncodedString(idx)),
            }))

        for idx, entry in enumerate(bufaux.shapevars):
            entry.used = True
            auxass.append(self.fetch_shape.substitute({
                u"TARGET": NameNode(node.pos, name=entry.name),
                u"BUFINFO": NameNode(node.pos, name=bufinfo_name),
                u"IDX": IntNode(node.pos, value=EncodedString(idx)),
            }))

        bufaux.buffer_info_var.used = True
        acq = self.acquire_buffer_fragment.substitute({
            u"TMP": NameNode(pos=node.pos, name=bufaux.temp_var.name),
            u"LHS": node.lhs,
            u"RHS": node.rhs,
            u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
            u"BUFINFO": NameNode(pos=node.pos, name=bufinfo_name),
            u"TSCHECKER": NameNode(node.pos, name=bufaux.tschecker.name),
        }, pos=node.pos)
        # Note: The below should probably be refactored into something
        # like fragment.substitute(..., context=self.context), with
        # TreeFragment getting context.pipeline_until_now() and
        # applying it on the fragment.
        acq.analyse_declarations(self.scope)
        acq.analyse_expressions(self.scope)
        return acq.stats

    def assign_into_buffer(self, node):
        """Rewrite ``buf[...] = rhs`` as an assignment through the buffer
        pointer expression built by buffer_index()."""
        result = SingleAssignmentNode(node.pos,
                                      rhs=self.visit(node.rhs),
                                      lhs=self.buffer_index(node.lhs))
        result.analyse_expressions(self.scope)
        return result

    def buffer_index(self, node):
        """Build ``((dtype*)(BUFINFO.buf + i0*s0 + i1*s1 + ...))[0]``.

        ``node`` is an IndexNode whose base is a NameNode with buffer_aux
        set. The returned IndexNode is not analysed yet.
        """
        pos = node.pos
        bufaux = node.base.entry.buffer_aux
        assert bufaux is not None
        # indices * strides...
        to_sum = [IntBinopNode(pos, operator='*',
                               operand1=index,
                               operand2=NameNode(node.pos, name=stride.name))
                  for index, stride in zip(node.indices, bufaux.stridevars)]

        # then sum them with the buffer pointer
        expr = AttributeNode(pos,
                             obj=NameNode(pos, name=bufaux.buffer_info_var.name),
                             attribute=EncodedString("buf"))
        for term in to_sum:
            expr = AddNode(pos, operator='+', operand1=expr, operand2=term)

        dtype = node.base.entry.type.buffer_options.dtype
        casted = TypecastNode(pos, operand=expr,
                              type=PyrexTypes.c_ptr_type(dtype))
        return IndexNode(pos, base=casted, index=IntNode(pos, value='0'))

    #
    # Transforms
    #

    def visit_ModuleNode(self, node):
        self.handle_scope(node, node.scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        # Restore the enclosing scope afterwards so that module-level code
        # following this function is not analysed in its local scope.
        outer_scope = self.scope
        self.handle_scope(node, node.local_scope)
        self.visitchildren(node)
        self.scope = outer_scope
        return node

    def visit_SingleAssignmentNode(self, node):
        # On assignments, two buffer-related things can happen:
        # a) A buffer variable is assigned to (reacquisition)
        # b) Buffer access assignment: arr[...] = ...
        # Since we don't allow nested buffers, these don't overlap.
        # NOTE(review): visitchildren() below also passes node.lhs through
        # visit_IndexNode, which may already have rewritten a buffer-access
        # lhs (its base then is no longer a NameNode) -- confirm whether the
        # assign_into_buffer branch is reachable.
        self.visitchildren(node)
        # Only acquire buffers on vars (not attributes) for now.
        if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
            # Is buffer variable
            return self.reacquire_buffer(node)
        elif (isinstance(node.lhs, IndexNode) and
              isinstance(node.lhs.base, NameNode) and
              node.lhs.base.entry.buffer_aux is not None):
            return self.assign_into_buffer(node)
        else:
            return node

    def visit_IndexNode(self, node):
        # Intended for rvalue buffer accesses (see the NOTE in
        # visit_SingleAssignmentNode about lvalues reaching here too).
        if node.is_buffer_access:
            assert node.index is None
            assert node.indices is not None
            result = self.buffer_index(node)
            result.analyse_expressions(self.scope)
            return result
        else:
            return node

    #
    # Utils for creating type string checkers
    #

    def new_ts_func(self, name, code):
        """Create, declare and register a module-level checker function
        named ``tsprefix + "_" + name`` with C body ``code``."""
        cname = "%s_%s" % (tsprefix, name)
        funcnode = PureCFuncNode(self.module_pos, cname, tschecker_functype, code)
        funcnode.analyse_types(self.module_scope)
        self.ts_funcs.append(funcnode)
        return funcnode

    def mangle_dtype_name(self, dtype):
        # Turn the C declaration of dtype into an identifier fragment by
        # replacing spaces with underscores.
        # NOTE(review): this does not by itself separate user-defined names
        # from builtins (consider "typedef float unsigned_int"); the
        # "simple_" checkers built from it could collide -- TODO.
        return dtype.declaration_code("").replace(" ", "_")

    def get_ts_check_item(self, dtype):
        """Return the cname of a function that consumes one (unnamed) item
        of ``dtype`` from a typestring."""
        funcnode = self.ts_item_checkers.get(dtype)
        if funcnode is None:
            char = dtype.typestring
            if char is not None and len(char) == 1:
                # A single typestring character is known for this type:
                # compare against it directly.
                funcnode = self.new_ts_func(
                    "natitem_%s" % self.mangle_dtype_name(dtype), """\
if (*ts != '%s') {
  PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
  return NULL;
} else return ts + 1;
""" % char)
            else:
                # Must deduce sign and length; rely on int vs. float to be
                # correctly declared
                ctype = dtype.declaration_code("")
                code = """\
int ok;
switch (*ts) {
"""
                if dtype.is_int:
                    int_types = [('b', 'char'), ('h', 'short'), ('i', 'int'),
                                 ('l', 'long'), ('q', 'long long')]
                    # Lower-case code: signed type of that size;
                    # upper-case code: unsigned type of that size.
                    code += "".join(["""\
  case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;
  case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;
""" % (fmt, ctype, against, ctype,
       fmt.upper(), ctype, "unsigned " + against, ctype)
                        for fmt, against in int_types])
                elif dtype.is_float:
                    float_types = [('f', 'float'), ('d', 'double'),
                                   ('g', 'long double')]
                    code += "".join(["""\
  case '%s': ok = (sizeof(%s) == sizeof(%s)); break;
""" % (fmt, ctype, against) for fmt, against in float_types])
                code += """\
  default: ok = 0;
}
if (!ok) {
  PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%s')", ts);
  return NULL;
} else return ts + 1;
"""
                funcnode = self.new_ts_func(
                    "tdefitem_%s" % self.mangle_dtype_name(dtype), code)

            self.ts_item_checkers[dtype] = funcnode
        return funcnode.entry.cname

    # cnames of the shared helper checkers; created lazily by ensure_ts_utils
    ts_consume_whitespace_cname = None
    ts_check_endian_cname = None

    def ensure_ts_utils(self):
        # Makes sure that the typechecker utils are in scope
        # (and constructs them if not)
        if self.ts_consume_whitespace_cname is None:
            # Skips '\n' (10), '\r' (13) and ' '.
            self.ts_consume_whitespace_cname = self.new_ts_func(
                "consume_whitespace", """\
while (1) {
  switch (*ts) {
    case 10:
    case 13:
    case ' ':
      ++ts;
      break;
    default:
      return ts;
  }
}
""").entry.cname
        if self.ts_check_endian_cname is None:
            self.ts_check_endian_cname = self.new_ts_func("check_endian", """\
int ok = 1;
switch (*ts) {
  case '@':
  case '=':
    ++ts; break;
  case '<':
    if (__BYTE_ORDER == __LITTLE_ENDIAN) ++ts;
    else ok = 0;
    break;
  case '>':
  case '!':
    if (__BYTE_ORDER == __BIG_ENDIAN) ++ts;
    else ok = 0;
    break;
}
if (!ok) {
  PyErr_Format(PyExc_TypeError, "Data has wrong endianness (rejecting on '%s')", ts);
  return NULL;
}
return ts;
""").entry.cname

    def create_ts_check_simple(self, dtype):
        """Create a checker accepting a typestring made of an optional byte
        order character plus exactly one item of ``dtype``.

        Requires ensure_ts_utils() to have been called first.
        """
        return self.new_ts_func("simple_%s" % self.mangle_dtype_name(dtype), """\
ts = %(consume_whitespace)s(ts);
ts = %(check_endian)s(ts);
if (!ts) return NULL;
ts = %(consume_whitespace)s(ts);
ts = %(check_item)s(ts);
if (!ts) return NULL;
ts = %(consume_whitespace)s(ts);
if (*ts != 0) {
  PyErr_Format(PyExc_TypeError, "Data too long (rejecting on '%%s')", ts);
  return NULL;
}
return ts;
""" % {
            'consume_whitespace': self.ts_consume_whitespace_cname,
            'check_endian': self.ts_check_endian_cname,
            'check_item': self.get_ts_check_item(dtype),
        })

    def tschecker(self, dtype):
        # Creates a type string checker function for the given type.
        # Each checker is created as a PureCFuncNode (declared as a function
        # entry in the module scope) and cached in the self.tscheckers dict.
        # The function entry is returned.
        #
        # TODO: __eq__ and __hash__ for types
        self.ensure_ts_utils()
        funcnode = self.tscheckers.get(dtype)
        if funcnode is None:
            if dtype.is_struct_or_union:
                assert False
            elif dtype.is_int or dtype.is_float:
                # This includes simple typedef-ed types
                funcnode = self.create_ts_check_simple(dtype)
            else:
                assert False
            self.tscheckers[dtype] = funcnode
        return funcnode.entry


# TODO:
# - buf must be NULL before getting new buffer