Commit 671e6a53 authored by Robert Bradshaw's avatar Robert Bradshaw

Better unicode/str handling for user-supplied code.

parent 33aca7ce
...@@ -17,6 +17,16 @@ from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclaration ...@@ -17,6 +17,16 @@ from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclaration
from Cython.Compiler.TreeFragment import parse_from_strings from Cython.Compiler.TreeFragment import parse_from_strings
from Cython.Build.Dependencies import strip_string_literals, cythonize from Cython.Build.Dependencies import strip_string_literals, cythonize
# A utility function to convert user-supplied ASCII strings to unicode.
if sys.version_info[0] < 3:
def to_unicode(s):
if not isinstance(s, unicode):
return s.decode('ascii')
else:
return s
else:
to_unicode = lambda x: x
_code_cache = {} _code_cache = {}
...@@ -28,11 +38,10 @@ class AllSymbols(CythonTransform, SkipDeclarations): ...@@ -28,11 +38,10 @@ class AllSymbols(CythonTransform, SkipDeclarations):
self.names.add(node.name) self.names.add(node.name)
def unbound_symbols(code, context=None): def unbound_symbols(code, context=None):
code = to_unicode(code)
if context is None: if context is None:
context = Context([], default_options) context = Context([], default_options)
from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
if isinstance(code, str):
code = code.decode('ascii')
tree = parse_from_strings('(tree fragment)', code) tree = parse_from_strings('(tree fragment)', code)
for phase in context.create_pipeline(pxd=False): for phase in context.create_pipeline(pxd=False):
if phase is None: if phase is None:
...@@ -90,6 +99,7 @@ def cython_inline(code, ...@@ -90,6 +99,7 @@ def cython_inline(code,
**kwds): **kwds):
if get_type is None: if get_type is None:
get_type = lambda x: 'object' get_type = lambda x: 'object'
code = to_unicode(code)
code, literals = strip_string_literals(code) code, literals = strip_string_literals(code)
code = strip_common_indent(code) code = strip_common_indent(code)
ctx = Context(cython_include_dirs, default_options) ctx = Context(cython_include_dirs, default_options)
......
...@@ -12,7 +12,7 @@ test_kwds = dict(force=True, quiet=True) ...@@ -12,7 +12,7 @@ test_kwds = dict(force=True, quiet=True)
global_value = 100 global_value = 100
class TestStripLiterals(CythonTest): class TestInline(CythonTest):
def test_simple(self): def test_simple(self):
self.assertEquals(inline("return 1+2", **test_kwds), 3) self.assertEquals(inline("return 1+2", **test_kwds), 3)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment