Commit 033c5cdd authored by Mark's avatar Mark

Merge pull request #115 from markflorisson88/_fused_dispatch_rebased

fused runtime dispatch for buffers
parents 52138792 5e76c087
...@@ -14,6 +14,7 @@ import re ...@@ -14,6 +14,7 @@ import re
import sys import sys
from string import Template from string import Template
import operator import operator
import textwrap
import Naming import Naming
import Options import Options
...@@ -376,7 +377,7 @@ class UtilityCode(UtilityCodeBase): ...@@ -376,7 +377,7 @@ class UtilityCode(UtilityCodeBase):
self.cleanup(writer, output.module_pos) self.cleanup(writer, output.module_pos)
def sub_tempita(s, context, file, name): def sub_tempita(s, context, file=None, name=None):
"Run tempita on string s with given context." "Run tempita on string s with given context."
if not s: if not s:
return None return None
...@@ -1940,6 +1941,75 @@ class PyrexCodeWriter(object): ...@@ -1940,6 +1941,75 @@ class PyrexCodeWriter(object):
def dedent(self): def dedent(self):
self.level -= 1 self.level -= 1
class PyxCodeWriter(object):
"""
Can be used for writing out some Cython code. To use the indenter
functionality, the Cython.Compiler.Importer module will have to be used
to load the code to support python 2.4
"""
def __init__(self, buffer=None, indent_level=0, context=None, encoding='ascii'):
self.buffer = buffer or StringIOTree()
self.level = indent_level
self.context = context
self.encoding = encoding
def indent(self, levels=1):
self.level += levels
return True
def dedent(self, levels=1):
self.level -= levels
def indenter(self, line):
"""
Instead of
with pyx_code.indenter("for i in range(10):"):
pyx_code.putln("print i")
write
if pyx_code.indenter("for i in range(10);"):
pyx_code.putln("print i")
pyx_code.dedent()
"""
self.putln(line)
self.indent()
return True
def getvalue(self):
result = self.buffer.getvalue()
if not isinstance(result, unicode):
result = result.decode(self.encoding)
return result
def putln(self, line, context=None):
context = context or self.context
if context:
line = sub_tempita(line, context)
self._putln(line)
def _putln(self, line):
self.buffer.write("%s%s\n" % (self.level * " ", line))
def put_chunk(self, chunk, context=None):
context = context or self.context
if context:
chunk = sub_tempita(chunk, context)
chunk = textwrap.dedent(chunk)
for line in chunk.splitlines():
self._putln(line)
def insertion_point(self):
return PyxCodeWriter(self.buffer.insertion_point(), self.level,
self.context)
def named_insertion_point(self, name):
setattr(self, name, self.insertion_point())
class ClosureTempAllocator(object): class ClosureTempAllocator(object):
def __init__(self, klass): def __init__(self, klass):
......
...@@ -6245,7 +6245,7 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6245,7 +6245,7 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
else: else:
arg.default = DefaultLiteralArgNode(arg.pos, arg.default) arg.default = DefaultLiteralArgNode(arg.pos, arg.default)
default_args.append(arg) default_args.append(arg)
if nonliteral_objects or nonliteral_objects: if nonliteral_objects or nonliteral_other:
module_scope = env.global_scope() module_scope = env.global_scope()
cname = module_scope.next_id(Naming.defaults_struct_prefix) cname = module_scope.next_id(Naming.defaults_struct_prefix)
scope = Symtab.StructOrUnionScope(cname) scope = Symtab.StructOrUnionScope(cname)
......
This diff is collapsed.
This diff is collapsed.
...@@ -1488,7 +1488,8 @@ if VALUE is not None: ...@@ -1488,7 +1488,8 @@ if VALUE is not None:
return node return node
node = Nodes.FusedCFuncDefNode(node, env) from Cython.Compiler import FusedNode
node = FusedNode.FusedCFuncDefNode(node, env)
self.fused_function = node self.fused_function = node
self.visitchildren(node) self.visitchildren(node)
...@@ -1498,6 +1499,7 @@ if VALUE is not None: ...@@ -1498,6 +1499,7 @@ if VALUE is not None:
# Create PyCFunction nodes for each specialization # Create PyCFunction nodes for each specialization
node.stats.insert(0, node.py_func) node.stats.insert(0, node.py_func)
node.py_func = self.visit(node.py_func) node.py_func = self.visit(node.py_func)
node.update_fused_defnode_entry(env)
pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func, pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func,
True) True)
pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env)) pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
......
...@@ -790,10 +790,9 @@ class BufferType(BaseType): ...@@ -790,10 +790,9 @@ class BufferType(BaseType):
def __str__(self): def __str__(self):
# avoid ', ', as fused functions split the signature string on ', ' # avoid ', ', as fused functions split the signature string on ', '
cast_str = ''
if self.cast: if self.cast:
cast_str = ',cast=True' cast_str = ',cast=True'
else:
cast_str = ''
return "%s[%s,ndim=%d%s]" % (self.base, self.dtype, self.ndim, return "%s[%s,ndim=%d%s]" % (self.base, self.dtype, self.ndim,
cast_str) cast_str)
...@@ -2661,6 +2660,7 @@ def specialize_entry(entry, cname): ...@@ -2661,6 +2660,7 @@ def specialize_entry(entry, cname):
""" """
Specialize an entry of a copied fused function or method Specialize an entry of a copied fused function or method
""" """
entry.is_fused_specialized = True
entry.name = get_fused_cname(cname, entry.name) entry.name = get_fused_cname(cname, entry.name)
if entry.is_cmethod: if entry.is_cmethod:
......
...@@ -120,6 +120,8 @@ class Entry(object): ...@@ -120,6 +120,8 @@ class Entry(object):
# error_on_uninitialized Have Control Flow issue an error when this entry is # error_on_uninitialized Have Control Flow issue an error when this entry is
# used uninitialized # used uninitialized
# cf_used boolean Entry is used # cf_used boolean Entry is used
# is_fused_specialized boolean Whether this entry of a cdef or def function
# is a specialization
# TODO: utility_code and utility_code_definition serves the same purpose... # TODO: utility_code and utility_code_definition serves the same purpose...
...@@ -179,6 +181,7 @@ class Entry(object): ...@@ -179,6 +181,7 @@ class Entry(object):
prev_entry = None prev_entry = None
might_overflow = 0 might_overflow = 0
fused_cfunction = None fused_cfunction = None
is_fused_specialized = False
utility_code_definition = None utility_code_definition = None
in_with_gil_block = 0 in_with_gil_block = 0
from_cython_utility_code = None from_cython_utility_code = None
...@@ -360,11 +363,11 @@ class Scope(object): ...@@ -360,11 +363,11 @@ class Scope(object):
# Return the module-level scope containing this scope. # Return the module-level scope containing this scope.
return self.outer_scope.builtin_scope() return self.outer_scope.builtin_scope()
def declare(self, name, cname, type, pos, visibility, shadow = 0): def declare(self, name, cname, type, pos, visibility, shadow = 0, is_type = 0):
# Create new entry, and add to dictionary if # Create new entry, and add to dictionary if
# name is not None. Reports a warning if already # name is not None. Reports a warning if already
# declared. # declared.
if type.is_buffer and not isinstance(self, LocalScope): if type.is_buffer and not isinstance(self, LocalScope): # and not is_type:
error(pos, 'Buffer types only allowed as function local variables') error(pos, 'Buffer types only allowed as function local variables')
if not self.in_cinclude and cname and re.match("^_[_A-Z]+$", cname): if not self.in_cinclude and cname and re.match("^_[_A-Z]+$", cname):
# See http://www.gnu.org/software/libc/manual/html_node/Reserved-Names.html#Reserved-Names # See http://www.gnu.org/software/libc/manual/html_node/Reserved-Names.html#Reserved-Names
...@@ -415,7 +418,8 @@ class Scope(object): ...@@ -415,7 +418,8 @@ class Scope(object):
# Add an entry for a type definition. # Add an entry for a type definition.
if not cname: if not cname:
cname = name cname = name
entry = self.declare(name, cname, type, pos, visibility, shadow) entry = self.declare(name, cname, type, pos, visibility, shadow,
is_type=True)
entry.is_type = 1 entry.is_type = 1
entry.api = api entry.api = api
if defining: if defining:
......
...@@ -231,6 +231,12 @@ class TreeFragment(object): ...@@ -231,6 +231,12 @@ class TreeFragment(object):
substitutions = nodes, substitutions = nodes,
temps = self.temps + temps, pos = pos) temps = self.temps + temps, pos = pos)
class SetPosTransform(VisitorTransform):
def __init__(self, pos):
super(SetPosTransform, self).__init__()
self.pos = pos
def visit_Node(self, node):
node.pos = self.pos
self.visitchildren(node)
return node
...@@ -167,3 +167,11 @@ class CythonUtilityCode(Code.UtilityCodeBase): ...@@ -167,3 +167,11 @@ class CythonUtilityCode(Code.UtilityCodeBase):
dep.declare_in_scope(dest_scope) dep.declare_in_scope(dest_scope)
return original_scope return original_scope
def declare_declarations_in_scope(declaration_string, env, private_type=True,
*args, **kwargs):
"""
Declare some declarations given as Cython code in declaration_string
in scope env.
"""
CythonUtilityCode(declaration_string, *args, **kwargs).declare_in_scope(env)
...@@ -740,8 +740,6 @@ __pyx_FusedFunction_callfunction(PyObject *func, PyObject *args, PyObject *kw) ...@@ -740,8 +740,6 @@ __pyx_FusedFunction_callfunction(PyObject *func, PyObject *args, PyObject *kw)
int static_specialized = (cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD && int static_specialized = (cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD &&
!((__pyx_FusedFunctionObject *) func)->__signatures__); !((__pyx_FusedFunctionObject *) func)->__signatures__);
//PyObject_Print(args, stdout, Py_PRINT_RAW);
if (cyfunc->flags & __Pyx_CYFUNCTION_CCLASS && !static_specialized) { if (cyfunc->flags & __Pyx_CYFUNCTION_CCLASS && !static_specialized) {
Py_ssize_t argc; Py_ssize_t argc;
PyObject *new_args; PyObject *new_args;
...@@ -827,8 +825,9 @@ __pyx_FusedFunction_call(PyObject *func, PyObject *args, PyObject *kw) ...@@ -827,8 +825,9 @@ __pyx_FusedFunction_call(PyObject *func, PyObject *args, PyObject *kw)
} }
if (binding_func->__signatures__) { if (binding_func->__signatures__) {
PyObject *tup = PyTuple_Pack(3, binding_func->__signatures__, args, PyObject *tup = PyTuple_Pack(4, binding_func->__signatures__, args,
kw == NULL ? Py_None : kw); kw == NULL ? Py_None : kw,
binding_func->func.defaults_tuple);
if (!tup) if (!tup)
goto __pyx_err; goto __pyx_err;
......
...@@ -978,7 +978,7 @@ cdef memoryview_fromslice({{memviewslice_name}} *memviewslice, ...@@ -978,7 +978,7 @@ cdef memoryview_fromslice({{memviewslice_name}} *memviewslice,
result.from_slice = memviewslice[0] result.from_slice = memviewslice[0]
__PYX_INC_MEMVIEW(memviewslice, 1) __PYX_INC_MEMVIEW(memviewslice, 1)
result.from_object = <object> memviewslice.memview.obj result.from_object = (<memoryview> memviewslice.memview).base
result.typeinfo = memviewslice.memview.typeinfo result.typeinfo = memviewslice.memview.typeinfo
result.view = memviewslice.memview.view result.view = memviewslice.memview.view
......
...@@ -374,6 +374,18 @@ class PyxArgs(object): ...@@ -374,6 +374,18 @@ class PyxArgs(object):
##pyxargs=None ##pyxargs=None
def _have_importers():
has_py_importer = False
has_pyx_importer = False
for importer in sys.meta_path:
if isinstance(importer, PyxImporter):
if isinstance(importer, PyImporter):
has_py_importer = True
else:
has_pyx_importer = True
return has_py_importer, has_pyx_importer
def install(pyximport=True, pyimport=False, build_dir=None, build_in_temp=True, def install(pyximport=True, pyimport=False, build_dir=None, build_in_temp=True,
setup_args={}, reload_support=False, setup_args={}, reload_support=False,
load_py_module_on_import_failure=False): load_py_module_on_import_failure=False):
...@@ -426,23 +438,32 @@ def install(pyximport=True, pyimport=False, build_dir=None, build_in_temp=True, ...@@ -426,23 +438,32 @@ def install(pyximport=True, pyimport=False, build_dir=None, build_in_temp=True,
pyxargs.reload_support = reload_support pyxargs.reload_support = reload_support
pyxargs.load_py_module_on_import_failure = load_py_module_on_import_failure pyxargs.load_py_module_on_import_failure = load_py_module_on_import_failure
has_py_importer = False has_py_importer, has_pyx_importer = _have_importers()
has_pyx_importer = False py_importer, pyx_importer = None, None
for importer in sys.meta_path:
if isinstance(importer, PyxImporter):
if isinstance(importer, PyImporter):
has_py_importer = True
else:
has_pyx_importer = True
if pyimport and not has_py_importer: if pyimport and not has_py_importer:
importer = PyImporter(pyxbuild_dir=build_dir) py_importer = PyImporter(pyxbuild_dir=build_dir)
sys.meta_path.insert(0, importer) sys.meta_path.insert(0, py_importer)
if pyximport and not has_pyx_importer: if pyximport and not has_pyx_importer:
importer = PyxImporter(pyxbuild_dir=build_dir) pyx_importer = PyxImporter(pyxbuild_dir=build_dir)
sys.meta_path.append(importer) sys.meta_path.append(pyx_importer)
return py_importer, pyx_importer
def uninstall(py_importer, pyx_importer):
"""
Uninstall an import hook.
"""
try:
sys.meta_path.remove(py_importer)
except ValueError:
pass
try:
sys.meta_path.remove(pyx_importer)
except ValueError:
pass
# MAIN # MAIN
......
...@@ -1251,6 +1251,8 @@ def refactor_for_py3(distdir, cy3_dir): ...@@ -1251,6 +1251,8 @@ def refactor_for_py3(distdir, cy3_dir):
recursive-include Cython *.py *.pyx *.pxd recursive-include Cython *.py *.pyx *.pxd
recursive-include Cython/Debugger/Tests * recursive-include Cython/Debugger/Tests *
recursive-include Cython/Utility * recursive-include Cython/Utility *
recursive-exclude pyximport test
include pyximport/*.py
include runtests.py include runtests.py
include cython.py include cython.py
''') ''')
......
...@@ -112,6 +112,7 @@ def compile_cython_modules(profile=False, compile_more=False, cython_with_refnan ...@@ -112,6 +112,7 @@ def compile_cython_modules(profile=False, compile_more=False, cython_with_refnan
"Cython.Compiler.FlowControl", "Cython.Compiler.FlowControl",
"Cython.Compiler.Code", "Cython.Compiler.Code",
"Cython.Runtime.refnanny", "Cython.Runtime.refnanny",
# "Cython.Compiler.FusedNode",
] ]
if compile_more: if compile_more:
compiled_modules.extend([ compiled_modules.extend([
......
...@@ -107,3 +107,27 @@ def test_defaults_fused(cython.floating arg1, cython.floating arg2 = counter2()) ...@@ -107,3 +107,27 @@ def test_defaults_fused(cython.floating arg1, cython.floating arg2 = counter2())
(2.0,) (2.0,)
""" """
print arg1, arg2 print arg1, arg2
funcs = []
for i in range(10):
def defaults_fused(cython.floating a, cython.floating b = i):
return a, b
funcs.append(defaults_fused)
def test_dynamic_defaults_fused():
"""
>>> test_dynamic_defaults_fused()
i 0 func result (1.0, 0.0) defaults (0,)
i 1 func result (1.0, 1.0) defaults (1,)
i 2 func result (1.0, 2.0) defaults (2,)
i 3 func result (1.0, 3.0) defaults (3,)
i 4 func result (1.0, 4.0) defaults (4,)
i 5 func result (1.0, 5.0) defaults (5,)
i 6 func result (1.0, 6.0) defaults (6,)
i 7 func result (1.0, 7.0) defaults (7,)
i 8 func result (1.0, 8.0) defaults (8,)
i 9 func result (1.0, 9.0) defaults (9,)
"""
for i, f in enumerate(funcs):
print "i", i, "func result", f(1.0), "defaults", get_defaults(f)
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment