Commit 02218626 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Replace filename strings with more generic source descriptors.

This facilitates using the parser and compiler with runtime sources (such as
strings), while still being able to provide context for error messages/C debugging comments.
parent a927b5a7
...@@ -8,6 +8,7 @@ import Options ...@@ -8,6 +8,7 @@ import Options
from Cython.Utils import open_new_file, open_source_file from Cython.Utils import open_new_file, open_source_file
from PyrexTypes import py_object_type, typecast from PyrexTypes import py_object_type, typecast
from TypeSlots import method_coexist from TypeSlots import method_coexist
from Scanning import SourceDescriptor
class CCodeWriter: class CCodeWriter:
# f file output file # f file output file
...@@ -89,21 +90,22 @@ class CCodeWriter: ...@@ -89,21 +90,22 @@ class CCodeWriter:
def get_py_version_hex(self, pyversion): def get_py_version_hex(self, pyversion):
return "0x%02X%02X%02X%02X" % (tuple(pyversion) + (0,0,0,0))[:4] return "0x%02X%02X%02X%02X" % (tuple(pyversion) + (0,0,0,0))[:4]
def file_contents(self, file): def file_contents(self, source_desc):
try: try:
return self.input_file_contents[file] return self.input_file_contents[source_desc]
except KeyError: except KeyError:
F = [line.encode('ASCII', 'replace').replace( F = [line.encode('ASCII', 'replace').replace(
'*/', '*[inserted by cython to avoid comment closer]/') '*/', '*[inserted by cython to avoid comment closer]/')
for line in open_source_file(file)] for line in source_desc.get_lines(decode=True)]
self.input_file_contents[file] = F self.input_file_contents[source_desc] = F
return F return F
def mark_pos(self, pos): def mark_pos(self, pos):
if pos is None: if pos is None:
return return
filename, line, col = pos source_desc, line, col = pos
contents = self.file_contents(filename) assert isinstance(source_desc, SourceDescriptor)
contents = self.file_contents(source_desc)
context = '' context = ''
for i in range(max(0,line-3), min(line+2, len(contents))): for i in range(max(0,line-3), min(line+2, len(contents))):
...@@ -112,7 +114,7 @@ class CCodeWriter: ...@@ -112,7 +114,7 @@ class CCodeWriter:
s = s.rstrip() + ' # <<<<<<<<<<<<<< ' + '\n' s = s.rstrip() + ' # <<<<<<<<<<<<<< ' + '\n'
context += " * " + s context += " * " + s
marker = '"%s":%d\n%s' % (filename.encode('ASCII', 'replace'), line, context) marker = '"%s":%d\n%s' % (str(source_desc).encode('ASCII', 'replace'), line, context)
if self.last_marker != marker: if self.last_marker != marker:
self.marker = marker self.marker = marker
......
...@@ -12,13 +12,17 @@ class PyrexError(Exception): ...@@ -12,13 +12,17 @@ class PyrexError(Exception):
class PyrexWarning(Exception): class PyrexWarning(Exception):
pass pass
def context(position): def context(position):
F = open(position[0]).readlines() source = position[0]
s = ''.join(F[position[1]-6:position[1]]) assert not (isinstance(source, unicode) or isinstance(source, str)), (
"Please replace filename strings with Scanning.FileSourceDescriptor instances %r" % source)
F = list(source.get_lines())
s = ''.join(F[min(0, position[1]-6):position[1]])
s += ' '*(position[2]-1) + '^' s += ' '*(position[2]-1) + '^'
s = '-'*60 + '\n...\n' + s + '\n' + '-'*60 + '\n' s = '-'*60 + '\n...\n' + s + '\n' + '-'*60 + '\n'
return s return s
class CompileError(PyrexError): class CompileError(PyrexError):
def __init__(self, position = None, message = ""): def __init__(self, position = None, message = ""):
......
...@@ -9,7 +9,7 @@ if sys.version_info[:2] < (2, 2): ...@@ -9,7 +9,7 @@ if sys.version_info[:2] < (2, 2):
from time import time from time import time
import Version import Version
from Scanning import PyrexScanner from Scanning import PyrexScanner, FileSourceDescriptor
import Errors import Errors
from Errors import PyrexError, CompileError, error from Errors import PyrexError, CompileError, error
import Parsing import Parsing
...@@ -85,7 +85,8 @@ class Context: ...@@ -85,7 +85,8 @@ class Context:
try: try:
if debug_find_module: if debug_find_module:
print("Context.find_module: Parsing %s" % pxd_pathname) print("Context.find_module: Parsing %s" % pxd_pathname)
pxd_tree = self.parse(pxd_pathname, scope.type_names, pxd = 1, source_desc = FileSourceDescriptor(pxd_pathname)
pxd_tree = self.parse(source_desc, scope.type_names, pxd = 1,
full_module_name = module_name) full_module_name = module_name)
pxd_tree.analyse_declarations(scope) pxd_tree.analyse_declarations(scope)
except CompileError: except CompileError:
...@@ -116,7 +117,10 @@ class Context: ...@@ -116,7 +117,10 @@ class Context:
# None if not found, but does not report an error. # None if not found, but does not report an error.
dirs = self.include_directories dirs = self.include_directories
if pos: if pos:
here_dir = os.path.dirname(pos[0]) file_desc = pos[0]
if not isinstance(file_desc, FileSourceDescriptor):
raise RuntimeError("Only file sources for code supported")
here_dir = os.path.dirname(file_desc.filename)
dirs = [here_dir] + dirs dirs = [here_dir] + dirs
for dir in dirs: for dir in dirs:
path = os.path.join(dir, filename) path = os.path.join(dir, filename)
...@@ -137,19 +141,21 @@ class Context: ...@@ -137,19 +141,21 @@ class Context:
self.modules[name] = scope self.modules[name] = scope
return scope return scope
def parse(self, source_filename, type_names, pxd, full_module_name): def parse(self, source_desc, type_names, pxd, full_module_name):
name = Utils.encode_filename(source_filename) if not isinstance(source_desc, FileSourceDescriptor):
raise RuntimeError("Only file sources for code supported")
source_filename = Utils.encode_filename(source_desc.filename)
# Parse the given source file and return a parse tree. # Parse the given source file and return a parse tree.
try: try:
f = Utils.open_source_file(source_filename, "rU") f = Utils.open_source_file(source_filename, "rU")
try: try:
s = PyrexScanner(f, name, source_encoding = f.encoding, s = PyrexScanner(f, source_desc, source_encoding = f.encoding,
type_names = type_names, context = self) type_names = type_names, context = self)
tree = Parsing.p_module(s, pxd, full_module_name) tree = Parsing.p_module(s, pxd, full_module_name)
finally: finally:
f.close() f.close()
except UnicodeDecodeError, msg: except UnicodeDecodeError, msg:
error((name, 0, 0), "Decoding error, missing or incorrect coding=<encoding-name> at top of source (%s)" % msg) error((source_desc, 0, 0), "Decoding error, missing or incorrect coding=<encoding-name> at top of source (%s)" % msg)
if Errors.num_errors > 0: if Errors.num_errors > 0:
raise CompileError raise CompileError
return tree return tree
...@@ -197,6 +203,7 @@ class Context: ...@@ -197,6 +203,7 @@ class Context:
except EnvironmentError: except EnvironmentError:
pass pass
module_name = full_module_name # self.extract_module_name(source, options) module_name = full_module_name # self.extract_module_name(source, options)
source = FileSourceDescriptor(source)
initial_pos = (source, 1, 0) initial_pos = (source, 1, 0)
scope = self.find_module(module_name, pos = initial_pos, need_pxd = 0) scope = self.find_module(module_name, pos = initial_pos, need_pxd = 0)
errors_occurred = False errors_occurred = False
...@@ -339,6 +346,8 @@ def main(command_line = 0): ...@@ -339,6 +346,8 @@ def main(command_line = 0):
if any_failures: if any_failures:
sys.exit(1) sys.exit(1)
#------------------------------------------------------------------------ #------------------------------------------------------------------------
# #
# Set the default options depending on the platform # Set the default options depending on the platform
......
...@@ -427,8 +427,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -427,8 +427,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("") code.putln("")
code.putln("static char *%s[] = {" % Naming.filenames_cname) code.putln("static char *%s[] = {" % Naming.filenames_cname)
if code.filename_list: if code.filename_list:
for filename in code.filename_list: for source_desc in code.filename_list:
filename = os.path.basename(filename) filename = os.path.basename(str(source_desc))
escaped_filename = filename.replace("\\", "\\\\").replace('"', r'\"') escaped_filename = filename.replace("\\", "\\\\").replace('"', r'\"')
code.putln('"%s",' % code.putln('"%s",' %
escaped_filename) escaped_filename)
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import os, re import os, re
from string import join, replace from string import join, replace
from types import ListType, TupleType from types import ListType, TupleType
from Scanning import PyrexScanner from Scanning import PyrexScanner, FileSourceDescriptor
import Nodes import Nodes
import ExprNodes import ExprNodes
from ModuleNode import ModuleNode from ModuleNode import ModuleNode
...@@ -1182,7 +1182,8 @@ def p_include_statement(s, level): ...@@ -1182,7 +1182,8 @@ def p_include_statement(s, level):
include_file_path = s.context.find_include_file(include_file_name, pos) include_file_path = s.context.find_include_file(include_file_name, pos)
if include_file_path: if include_file_path:
f = Utils.open_source_file(include_file_path, mode="rU") f = Utils.open_source_file(include_file_path, mode="rU")
s2 = PyrexScanner(f, include_file_path, s, source_encoding=f.encoding) source_desc = FileSourceDescriptor(include_file_path)
s2 = PyrexScanner(f, source_desc, s, source_encoding=f.encoding)
try: try:
tree = p_statement_list(s2, level) tree = p_statement_list(s2, level)
finally: finally:
......
...@@ -17,6 +17,8 @@ from Cython.Plex.Errors import UnrecognizedInput ...@@ -17,6 +17,8 @@ from Cython.Plex.Errors import UnrecognizedInput
from Errors import CompileError, error from Errors import CompileError, error
from Lexicon import string_prefixes, make_lexicon from Lexicon import string_prefixes, make_lexicon
from Cython import Utils
plex_version = getattr(Plex, '_version', None) plex_version = getattr(Plex, '_version', None)
#print "Plex version:", plex_version ### #print "Plex version:", plex_version ###
...@@ -203,6 +205,57 @@ def initial_compile_time_env(): ...@@ -203,6 +205,57 @@ def initial_compile_time_env():
#------------------------------------------------------------------ #------------------------------------------------------------------
class SourceDescriptor:
pass
class FileSourceDescriptor(SourceDescriptor):
"""
Represents a code source. A code source is a more generic abstraction
for a "filename" (as sometimes the code doesn't come from a file).
Instances of code sources are passed to Scanner.__init__ as the
optional name argument and will be passed back when asking for
the position()-tuple.
"""
def __init__(self, filename):
self.filename = filename
def get_lines(self, decode=False):
# decode is True when called from Code.py (which reserializes in a standard way to ASCII),
# while decode is False when called from Errors.py.
#
# Note that if changing Errors.py in this respect, raising errors over wrong encoding
# will no longer be able to produce the line where the encoding problem occurs ...
if decode:
return Utils.open_source_file(self.filename)
else:
return open(self.filename)
def __str__(self):
return self.filename
def __repr__(self):
return "<FileSourceDescriptor:%s>" % self
class StringSourceDescriptor(SourceDescriptor):
"""
Instances of this class can be used instead of a filenames if the
code originates from a string object.
"""
def __init__(self, name, code):
self.name = name
self.codelines = [x + "\n" for x in code.split("\n")]
def get_lines(self, decode=False):
return self.codelines
def __str__(self):
return self.name
def __repr__(self):
return "<StringSourceDescriptor:%s>" % self
#------------------------------------------------------------------
class PyrexScanner(Scanner): class PyrexScanner(Scanner):
# context Context Compilation context # context Context Compilation context
# type_names set Identifiers to be treated as type names # type_names set Identifiers to be treated as type names
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment