Commit 033c5cdd authored by Mark's avatar Mark

Merge pull request #115 from markflorisson88/_fused_dispatch_rebased

fused runtime dispatch for buffers
parents 52138792 5e76c087
......@@ -14,6 +14,7 @@ import re
import sys
from string import Template
import operator
import textwrap
import Naming
import Options
......@@ -376,7 +377,7 @@ class UtilityCode(UtilityCodeBase):
self.cleanup(writer, output.module_pos)
def sub_tempita(s, context, file, name):
def sub_tempita(s, context, file=None, name=None):
"Run tempita on string s with given context."
if not s:
return None
......@@ -1940,6 +1941,75 @@ class PyrexCodeWriter(object):
def dedent(self):
self.level -= 1
class PyxCodeWriter(object):
"""
Can be used for writing out some Cython code. To use the indenter
functionality, the Cython.Compiler.Importer module will have to be used
to load the code to support python 2.4
"""
def __init__(self, buffer=None, indent_level=0, context=None, encoding='ascii'):
self.buffer = buffer or StringIOTree()
self.level = indent_level
self.context = context
self.encoding = encoding
def indent(self, levels=1):
self.level += levels
return True
def dedent(self, levels=1):
self.level -= levels
def indenter(self, line):
"""
Instead of
with pyx_code.indenter("for i in range(10):"):
pyx_code.putln("print i")
write
if pyx_code.indenter("for i in range(10);"):
pyx_code.putln("print i")
pyx_code.dedent()
"""
self.putln(line)
self.indent()
return True
def getvalue(self):
result = self.buffer.getvalue()
if not isinstance(result, unicode):
result = result.decode(self.encoding)
return result
def putln(self, line, context=None):
context = context or self.context
if context:
line = sub_tempita(line, context)
self._putln(line)
def _putln(self, line):
self.buffer.write("%s%s\n" % (self.level * " ", line))
def put_chunk(self, chunk, context=None):
context = context or self.context
if context:
chunk = sub_tempita(chunk, context)
chunk = textwrap.dedent(chunk)
for line in chunk.splitlines():
self._putln(line)
def insertion_point(self):
return PyxCodeWriter(self.buffer.insertion_point(), self.level,
self.context)
def named_insertion_point(self, name):
setattr(self, name, self.insertion_point())
class ClosureTempAllocator(object):
def __init__(self, klass):
......
......@@ -6222,7 +6222,7 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
env.use_utility_code(fused_function_utility_code)
else:
env.use_utility_code(binding_cfunc_utility_code)
self.analyse_default_args(env)
self.analyse_default_args(env)
#TODO(craig,haoyu) This should be moved to a better place
self.set_mod_name(env)
......@@ -6245,7 +6245,7 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
else:
arg.default = DefaultLiteralArgNode(arg.pos, arg.default)
default_args.append(arg)
if nonliteral_objects or nonliteral_objects:
if nonliteral_objects or nonliteral_other:
module_scope = env.global_scope()
cname = module_scope.next_id(Naming.defaults_struct_prefix)
scope = Symtab.StructOrUnionScope(cname)
......
This diff is collapsed.
This diff is collapsed.
......@@ -1488,7 +1488,8 @@ if VALUE is not None:
return node
node = Nodes.FusedCFuncDefNode(node, env)
from Cython.Compiler import FusedNode
node = FusedNode.FusedCFuncDefNode(node, env)
self.fused_function = node
self.visitchildren(node)
......@@ -1498,6 +1499,7 @@ if VALUE is not None:
# Create PyCFunction nodes for each specialization
node.stats.insert(0, node.py_func)
node.py_func = self.visit(node.py_func)
node.update_fused_defnode_entry(env)
pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func,
True)
pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
......
......@@ -790,10 +790,9 @@ class BufferType(BaseType):
def __str__(self):
# avoid ', ', as fused functions split the signature string on ', '
cast_str = ''
if self.cast:
cast_str = ',cast=True'
else:
cast_str = ''
return "%s[%s,ndim=%d%s]" % (self.base, self.dtype, self.ndim,
cast_str)
......@@ -2661,6 +2660,7 @@ def specialize_entry(entry, cname):
"""
Specialize an entry of a copied fused function or method
"""
entry.is_fused_specialized = True
entry.name = get_fused_cname(cname, entry.name)
if entry.is_cmethod:
......
......@@ -120,6 +120,8 @@ class Entry(object):
# error_on_uninitialized Have Control Flow issue an error when this entry is
# used uninitialized
# cf_used boolean Entry is used
# is_fused_specialized boolean Whether this entry of a cdef or def function
# is a specialization
# TODO: utility_code and utility_code_definition serves the same purpose...
......@@ -179,6 +181,7 @@ class Entry(object):
prev_entry = None
might_overflow = 0
fused_cfunction = None
is_fused_specialized = False
utility_code_definition = None
in_with_gil_block = 0
from_cython_utility_code = None
......@@ -360,11 +363,11 @@ class Scope(object):
# Return the module-level scope containing this scope.
return self.outer_scope.builtin_scope()
def declare(self, name, cname, type, pos, visibility, shadow = 0):
def declare(self, name, cname, type, pos, visibility, shadow = 0, is_type = 0):
# Create new entry, and add to dictionary if
# name is not None. Reports a warning if already
# declared.
if type.is_buffer and not isinstance(self, LocalScope):
if type.is_buffer and not isinstance(self, LocalScope): # and not is_type:
error(pos, 'Buffer types only allowed as function local variables')
if not self.in_cinclude and cname and re.match("^_[_A-Z]+$", cname):
# See http://www.gnu.org/software/libc/manual/html_node/Reserved-Names.html#Reserved-Names
......@@ -415,7 +418,8 @@ class Scope(object):
# Add an entry for a type definition.
if not cname:
cname = name
entry = self.declare(name, cname, type, pos, visibility, shadow)
entry = self.declare(name, cname, type, pos, visibility, shadow,
is_type=True)
entry.is_type = 1
entry.api = api
if defining:
......
......@@ -231,6 +231,12 @@ class TreeFragment(object):
substitutions = nodes,
temps = self.temps + temps, pos = pos)
class SetPosTransform(VisitorTransform):
def __init__(self, pos):
super(SetPosTransform, self).__init__()
self.pos = pos
def visit_Node(self, node):
node.pos = self.pos
self.visitchildren(node)
return node
......@@ -167,3 +167,11 @@ class CythonUtilityCode(Code.UtilityCodeBase):
dep.declare_in_scope(dest_scope)
return original_scope
def declare_declarations_in_scope(declaration_string, env, private_type=True,
*args, **kwargs):
"""
Declare some declarations given as Cython code in declaration_string
in scope env.
"""
CythonUtilityCode(declaration_string, *args, **kwargs).declare_in_scope(env)
......@@ -740,8 +740,6 @@ __pyx_FusedFunction_callfunction(PyObject *func, PyObject *args, PyObject *kw)
int static_specialized = (cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD &&
!((__pyx_FusedFunctionObject *) func)->__signatures__);
//PyObject_Print(args, stdout, Py_PRINT_RAW);
if (cyfunc->flags & __Pyx_CYFUNCTION_CCLASS && !static_specialized) {
Py_ssize_t argc;
PyObject *new_args;
......@@ -827,8 +825,9 @@ __pyx_FusedFunction_call(PyObject *func, PyObject *args, PyObject *kw)
}
if (binding_func->__signatures__) {
PyObject *tup = PyTuple_Pack(3, binding_func->__signatures__, args,
kw == NULL ? Py_None : kw);
PyObject *tup = PyTuple_Pack(4, binding_func->__signatures__, args,
kw == NULL ? Py_None : kw,
binding_func->func.defaults_tuple);
if (!tup)
goto __pyx_err;
......
......@@ -978,7 +978,7 @@ cdef memoryview_fromslice({{memviewslice_name}} *memviewslice,
result.from_slice = memviewslice[0]
__PYX_INC_MEMVIEW(memviewslice, 1)
result.from_object = <object> memviewslice.memview.obj
result.from_object = (<memoryview> memviewslice.memview).base
result.typeinfo = memviewslice.memview.typeinfo
result.view = memviewslice.memview.view
......
......@@ -373,7 +373,19 @@ class PyxArgs(object):
setup_args={} #None
##pyxargs=None
def _have_importers():
has_py_importer = False
has_pyx_importer = False
for importer in sys.meta_path:
if isinstance(importer, PyxImporter):
if isinstance(importer, PyImporter):
has_py_importer = True
else:
has_pyx_importer = True
return has_py_importer, has_pyx_importer
def install(pyximport=True, pyimport=False, build_dir=None, build_in_temp=True,
setup_args={}, reload_support=False,
load_py_module_on_import_failure=False):
......@@ -426,23 +438,32 @@ def install(pyximport=True, pyimport=False, build_dir=None, build_in_temp=True,
pyxargs.reload_support = reload_support
pyxargs.load_py_module_on_import_failure = load_py_module_on_import_failure
has_py_importer = False
has_pyx_importer = False
for importer in sys.meta_path:
if isinstance(importer, PyxImporter):
if isinstance(importer, PyImporter):
has_py_importer = True
else:
has_pyx_importer = True
has_py_importer, has_pyx_importer = _have_importers()
py_importer, pyx_importer = None, None
if pyimport and not has_py_importer:
importer = PyImporter(pyxbuild_dir=build_dir)
sys.meta_path.insert(0, importer)
py_importer = PyImporter(pyxbuild_dir=build_dir)
sys.meta_path.insert(0, py_importer)
if pyximport and not has_pyx_importer:
importer = PyxImporter(pyxbuild_dir=build_dir)
sys.meta_path.append(importer)
pyx_importer = PyxImporter(pyxbuild_dir=build_dir)
sys.meta_path.append(pyx_importer)
return py_importer, pyx_importer
def uninstall(py_importer, pyx_importer):
"""
Uninstall an import hook.
"""
try:
sys.meta_path.remove(py_importer)
except ValueError:
pass
try:
sys.meta_path.remove(pyx_importer)
except ValueError:
pass
# MAIN
......
......@@ -1251,6 +1251,8 @@ def refactor_for_py3(distdir, cy3_dir):
recursive-include Cython *.py *.pyx *.pxd
recursive-include Cython/Debugger/Tests *
recursive-include Cython/Utility *
recursive-exclude pyximport test
include pyximport/*.py
include runtests.py
include cython.py
''')
......
......@@ -112,6 +112,7 @@ def compile_cython_modules(profile=False, compile_more=False, cython_with_refnan
"Cython.Compiler.FlowControl",
"Cython.Compiler.Code",
"Cython.Runtime.refnanny",
# "Cython.Compiler.FusedNode",
]
if compile_more:
compiled_modules.extend([
......
......@@ -107,3 +107,27 @@ def test_defaults_fused(cython.floating arg1, cython.floating arg2 = counter2())
(2.0,)
"""
print arg1, arg2
funcs = []
for i in range(10):
def defaults_fused(cython.floating a, cython.floating b = i):
return a, b
funcs.append(defaults_fused)
def test_dynamic_defaults_fused():
"""
>>> test_dynamic_defaults_fused()
i 0 func result (1.0, 0.0) defaults (0,)
i 1 func result (1.0, 1.0) defaults (1,)
i 2 func result (1.0, 2.0) defaults (2,)
i 3 func result (1.0, 3.0) defaults (3,)
i 4 func result (1.0, 4.0) defaults (4,)
i 5 func result (1.0, 5.0) defaults (5,)
i 6 func result (1.0, 6.0) defaults (6,)
i 7 func result (1.0, 7.0) defaults (7,)
i 8 func result (1.0, 8.0) defaults (8,)
i 9 func result (1.0, 9.0) defaults (9,)
"""
for i, f in enumerate(funcs):
print "i", i, "func result", f(1.0), "defaults", get_defaults(f)
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment