Commit 9e9b474a authored by Stefan Behnel's avatar Stefan Behnel

clean up and simplify code generation code for fused types

parent 60618f0a
......@@ -260,22 +260,17 @@ class FusedCFuncDefNode(StatListNode):
Genereate Cython code for instance checks, matching an object to
specialized types.
"""
if_ = 'if'
for specialized_type in normal_types:
# all_numeric = all_numeric and specialized_type.is_numeric
py_type_name = specialized_type.py_type_name()
specialized_type_name = specialized_type.specialization_string
pyx_code.context.update(locals())
pyx_code.context.update(
py_type_name=specialized_type.py_type_name(),
specialized_type_name=specialized_type.specialization_string,
)
pyx_code.put_chunk(
u"""
{{if_}} isinstance(arg, {{py_type_name}}):
elif isinstance(arg, {{py_type_name}}):
dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'
""")
if_ = 'elif'
if not normal_types:
# we need an 'if' to match the following 'else'
pyx_code.putln("if 0: pass")
def _dtype_name(self, dtype):
if dtype.is_typedef:
......@@ -398,46 +393,40 @@ class FusedCFuncDefNode(StatListNode):
to each specialization, which obtains the buffer each time and tries
to match the format string.
"""
if buffer_types:
if pyx_code.indenter(u"else:"):
# The first thing to find a match in this loop breaks out of the loop
if pyx_code.indenter(u"while 1:"):
pyx_code.put_chunk(
u"""
if ndarray is not None:
if isinstance(arg, ndarray):
dtype = arg.dtype
elif __pyx_memoryview_check(arg):
arg_base = arg.base
if isinstance(arg_base, ndarray):
dtype = arg_base.dtype
else:
dtype = None
else:
dtype = None
itemsize = -1
if dtype is not None:
itemsize = dtype.itemsize
kind = ord(dtype.kind)
dtype_signed = kind == 'i'
""")
pyx_code.indent(2)
pyx_code.named_insertion_point("numpy_dtype_checks")
self._buffer_check_numpy_dtype(pyx_code, buffer_types)
pyx_code.dedent(2)
for specialized_type in buffer_types:
self._buffer_parse_format_string_check(
pyx_code, decl_code, specialized_type, env)
# The first thing to find a match in this loop breaks out of the loop
if pyx_code.indenter(u"while 1:"):
pyx_code.put_chunk(
u"""
if ndarray is not None:
if isinstance(arg, ndarray):
dtype = arg.dtype
elif __pyx_memoryview_check(arg):
arg_base = arg.base
if isinstance(arg_base, ndarray):
dtype = arg_base.dtype
else:
dtype = None
else:
dtype = None
itemsize = -1
if dtype is not None:
itemsize = dtype.itemsize
kind = ord(dtype.kind)
dtype_signed = kind == 'i'
""")
pyx_code.indent(2)
pyx_code.named_insertion_point("numpy_dtype_checks")
self._buffer_check_numpy_dtype(pyx_code, buffer_types)
pyx_code.dedent(2)
pyx_code.putln(self.no_match)
pyx_code.putln("break")
pyx_code.dedent()
for specialized_type in buffer_types:
self._buffer_parse_format_string_check(
pyx_code, decl_code, specialized_type, env)
pyx_code.dedent()
else:
pyx_code.putln("else: %s" % self.no_match)
pyx_code.putln(self.no_match)
pyx_code.putln("break")
pyx_code.dedent()
def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types):
"""
......@@ -569,7 +558,7 @@ class FusedCFuncDefNode(StatListNode):
# FIXME: use a typed signature - currently fails badly because
# default arguments inherit the types we specify here!
dest_sig = [{{for _ in range(n_fused)}}None,{{endfor}}]
dest_sig = [None] * {{n_fused}}
if kwargs is None:
kwargs = {}
......@@ -606,10 +595,19 @@ class FusedCFuncDefNode(StatListNode):
normal_types, buffer_types = self._split_fused_types(arg)
self._unpack_argument(pyx_code)
self._fused_instance_checks(normal_types, pyx_code, env)
self._buffer_checks(buffer_types, pyx_code, decl_code, env)
fused_index += 1
# we need an 'if' to allow the following elif/else branches
pyx_code.putln("if 0: pass")
if normal_types:
self._fused_instance_checks(normal_types, pyx_code, env)
if pyx_code.indenter("else:"):
if buffer_types:
self._buffer_checks(buffer_types, pyx_code, decl_code, env)
else:
pyx_code.putln(self.no_match)
pyx_code.dedent()
fused_index += 1
all_buffer_types.update(buffer_types)
if arg.default:
......@@ -646,7 +644,9 @@ class FusedCFuncDefNode(StatListNode):
fragment_code = pyx_code.getvalue()
# print decl_code.getvalue()
# print fragment_code
fragment = TreeFragment.TreeFragment(fragment_code, level='module')
from .Optimize import ConstantFolding
fragment = TreeFragment.TreeFragment(
fragment_code, level='module', pipeline=[ConstantFolding()])
ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root)
UtilityCode.declare_declarations_in_scope(
decl_code.getvalue(), env.global_scope())
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment