Commit 9e9b474a authored by Stefan Behnel's avatar Stefan Behnel

clean up and simplify code generation code for fused types

parent 60618f0a
...@@ -260,22 +260,17 @@ class FusedCFuncDefNode(StatListNode): ...@@ -260,22 +260,17 @@ class FusedCFuncDefNode(StatListNode):
Genereate Cython code for instance checks, matching an object to Genereate Cython code for instance checks, matching an object to
specialized types. specialized types.
""" """
if_ = 'if'
for specialized_type in normal_types: for specialized_type in normal_types:
# all_numeric = all_numeric and specialized_type.is_numeric # all_numeric = all_numeric and specialized_type.is_numeric
py_type_name = specialized_type.py_type_name() pyx_code.context.update(
specialized_type_name = specialized_type.specialization_string py_type_name=specialized_type.py_type_name(),
pyx_code.context.update(locals()) specialized_type_name=specialized_type.specialization_string,
)
pyx_code.put_chunk( pyx_code.put_chunk(
u""" u"""
{{if_}} isinstance(arg, {{py_type_name}}): elif isinstance(arg, {{py_type_name}}):
dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}' dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'
""") """)
if_ = 'elif'
if not normal_types:
# we need an 'if' to match the following 'else'
pyx_code.putln("if 0: pass")
def _dtype_name(self, dtype): def _dtype_name(self, dtype):
if dtype.is_typedef: if dtype.is_typedef:
...@@ -398,8 +393,6 @@ class FusedCFuncDefNode(StatListNode): ...@@ -398,8 +393,6 @@ class FusedCFuncDefNode(StatListNode):
to each specialization, which obtains the buffer each time and tries to each specialization, which obtains the buffer each time and tries
to match the format string. to match the format string.
""" """
if buffer_types:
if pyx_code.indenter(u"else:"):
# The first thing to find a match in this loop breaks out of the loop # The first thing to find a match in this loop breaks out of the loop
if pyx_code.indenter(u"while 1:"): if pyx_code.indenter(u"while 1:"):
pyx_code.put_chunk( pyx_code.put_chunk(
...@@ -435,10 +428,6 @@ class FusedCFuncDefNode(StatListNode): ...@@ -435,10 +428,6 @@ class FusedCFuncDefNode(StatListNode):
pyx_code.putln("break") pyx_code.putln("break")
pyx_code.dedent() pyx_code.dedent()
pyx_code.dedent()
else:
pyx_code.putln("else: %s" % self.no_match)
def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types): def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types):
""" """
If we have any buffer specializations, write out some variable If we have any buffer specializations, write out some variable
...@@ -569,7 +558,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -569,7 +558,7 @@ class FusedCFuncDefNode(StatListNode):
# FIXME: use a typed signature - currently fails badly because # FIXME: use a typed signature - currently fails badly because
# default arguments inherit the types we specify here! # default arguments inherit the types we specify here!
dest_sig = [{{for _ in range(n_fused)}}None,{{endfor}}] dest_sig = [None] * {{n_fused}}
if kwargs is None: if kwargs is None:
kwargs = {} kwargs = {}
...@@ -606,10 +595,19 @@ class FusedCFuncDefNode(StatListNode): ...@@ -606,10 +595,19 @@ class FusedCFuncDefNode(StatListNode):
normal_types, buffer_types = self._split_fused_types(arg) normal_types, buffer_types = self._split_fused_types(arg)
self._unpack_argument(pyx_code) self._unpack_argument(pyx_code)
# we need an 'if' to allow the following elif/else branches
pyx_code.putln("if 0: pass")
if normal_types:
self._fused_instance_checks(normal_types, pyx_code, env) self._fused_instance_checks(normal_types, pyx_code, env)
if pyx_code.indenter("else:"):
if buffer_types:
self._buffer_checks(buffer_types, pyx_code, decl_code, env) self._buffer_checks(buffer_types, pyx_code, decl_code, env)
fused_index += 1 else:
pyx_code.putln(self.no_match)
pyx_code.dedent()
fused_index += 1
all_buffer_types.update(buffer_types) all_buffer_types.update(buffer_types)
if arg.default: if arg.default:
...@@ -646,7 +644,9 @@ class FusedCFuncDefNode(StatListNode): ...@@ -646,7 +644,9 @@ class FusedCFuncDefNode(StatListNode):
fragment_code = pyx_code.getvalue() fragment_code = pyx_code.getvalue()
# print decl_code.getvalue() # print decl_code.getvalue()
# print fragment_code # print fragment_code
fragment = TreeFragment.TreeFragment(fragment_code, level='module') from .Optimize import ConstantFolding
fragment = TreeFragment.TreeFragment(
fragment_code, level='module', pipeline=[ConstantFolding()])
ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root) ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root)
UtilityCode.declare_declarations_in_scope( UtilityCode.declare_declarations_in_scope(
decl_code.getvalue(), env.global_scope()) decl_code.getvalue(), env.global_scope())
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment