Commit 9e9b474a authored by Stefan Behnel's avatar Stefan Behnel

clean up and simplify code generation code for fused types

parent 60618f0a
...@@ -260,22 +260,17 @@ class FusedCFuncDefNode(StatListNode): ...@@ -260,22 +260,17 @@ class FusedCFuncDefNode(StatListNode):
Genereate Cython code for instance checks, matching an object to Genereate Cython code for instance checks, matching an object to
specialized types. specialized types.
""" """
if_ = 'if'
for specialized_type in normal_types: for specialized_type in normal_types:
# all_numeric = all_numeric and specialized_type.is_numeric # all_numeric = all_numeric and specialized_type.is_numeric
py_type_name = specialized_type.py_type_name() pyx_code.context.update(
specialized_type_name = specialized_type.specialization_string py_type_name=specialized_type.py_type_name(),
pyx_code.context.update(locals()) specialized_type_name=specialized_type.specialization_string,
)
pyx_code.put_chunk( pyx_code.put_chunk(
u""" u"""
{{if_}} isinstance(arg, {{py_type_name}}): elif isinstance(arg, {{py_type_name}}):
dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}' dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'
""") """)
if_ = 'elif'
if not normal_types:
# we need an 'if' to match the following 'else'
pyx_code.putln("if 0: pass")
def _dtype_name(self, dtype): def _dtype_name(self, dtype):
if dtype.is_typedef: if dtype.is_typedef:
...@@ -398,46 +393,40 @@ class FusedCFuncDefNode(StatListNode): ...@@ -398,46 +393,40 @@ class FusedCFuncDefNode(StatListNode):
to each specialization, which obtains the buffer each time and tries to each specialization, which obtains the buffer each time and tries
to match the format string. to match the format string.
""" """
if buffer_types: # The first thing to find a match in this loop breaks out of the loop
if pyx_code.indenter(u"else:"): if pyx_code.indenter(u"while 1:"):
# The first thing to find a match in this loop breaks out of the loop pyx_code.put_chunk(
if pyx_code.indenter(u"while 1:"): u"""
pyx_code.put_chunk( if ndarray is not None:
u""" if isinstance(arg, ndarray):
if ndarray is not None: dtype = arg.dtype
if isinstance(arg, ndarray): elif __pyx_memoryview_check(arg):
dtype = arg.dtype arg_base = arg.base
elif __pyx_memoryview_check(arg): if isinstance(arg_base, ndarray):
arg_base = arg.base dtype = arg_base.dtype
if isinstance(arg_base, ndarray): else:
dtype = arg_base.dtype dtype = None
else: else:
dtype = None dtype = None
else:
dtype = None itemsize = -1
if dtype is not None:
itemsize = -1 itemsize = dtype.itemsize
if dtype is not None: kind = ord(dtype.kind)
itemsize = dtype.itemsize dtype_signed = kind == 'i'
kind = ord(dtype.kind) """)
dtype_signed = kind == 'i' pyx_code.indent(2)
""") pyx_code.named_insertion_point("numpy_dtype_checks")
pyx_code.indent(2) self._buffer_check_numpy_dtype(pyx_code, buffer_types)
pyx_code.named_insertion_point("numpy_dtype_checks") pyx_code.dedent(2)
self._buffer_check_numpy_dtype(pyx_code, buffer_types)
pyx_code.dedent(2)
for specialized_type in buffer_types:
self._buffer_parse_format_string_check(
pyx_code, decl_code, specialized_type, env)
pyx_code.putln(self.no_match) for specialized_type in buffer_types:
pyx_code.putln("break") self._buffer_parse_format_string_check(
pyx_code.dedent() pyx_code, decl_code, specialized_type, env)
pyx_code.dedent() pyx_code.putln(self.no_match)
else: pyx_code.putln("break")
pyx_code.putln("else: %s" % self.no_match) pyx_code.dedent()
def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types): def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types):
""" """
...@@ -569,7 +558,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -569,7 +558,7 @@ class FusedCFuncDefNode(StatListNode):
# FIXME: use a typed signature - currently fails badly because # FIXME: use a typed signature - currently fails badly because
# default arguments inherit the types we specify here! # default arguments inherit the types we specify here!
dest_sig = [{{for _ in range(n_fused)}}None,{{endfor}}] dest_sig = [None] * {{n_fused}}
if kwargs is None: if kwargs is None:
kwargs = {} kwargs = {}
...@@ -606,10 +595,19 @@ class FusedCFuncDefNode(StatListNode): ...@@ -606,10 +595,19 @@ class FusedCFuncDefNode(StatListNode):
normal_types, buffer_types = self._split_fused_types(arg) normal_types, buffer_types = self._split_fused_types(arg)
self._unpack_argument(pyx_code) self._unpack_argument(pyx_code)
self._fused_instance_checks(normal_types, pyx_code, env)
self._buffer_checks(buffer_types, pyx_code, decl_code, env)
fused_index += 1
# we need an 'if' to allow the following elif/else branches
pyx_code.putln("if 0: pass")
if normal_types:
self._fused_instance_checks(normal_types, pyx_code, env)
if pyx_code.indenter("else:"):
if buffer_types:
self._buffer_checks(buffer_types, pyx_code, decl_code, env)
else:
pyx_code.putln(self.no_match)
pyx_code.dedent()
fused_index += 1
all_buffer_types.update(buffer_types) all_buffer_types.update(buffer_types)
if arg.default: if arg.default:
...@@ -646,7 +644,9 @@ class FusedCFuncDefNode(StatListNode): ...@@ -646,7 +644,9 @@ class FusedCFuncDefNode(StatListNode):
fragment_code = pyx_code.getvalue() fragment_code = pyx_code.getvalue()
# print decl_code.getvalue() # print decl_code.getvalue()
# print fragment_code # print fragment_code
fragment = TreeFragment.TreeFragment(fragment_code, level='module') from .Optimize import ConstantFolding
fragment = TreeFragment.TreeFragment(
fragment_code, level='module', pipeline=[ConstantFolding()])
ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root) ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root)
UtilityCode.declare_declarations_in_scope( UtilityCode.declare_declarations_in_scope(
decl_code.getvalue(), env.global_scope()) decl_code.getvalue(), env.global_scope())
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment