Commit c0b49eab authored by Robert Bradshaw's avatar Robert Bradshaw

Support for [] operator.

parent 15903d3b
...@@ -1914,11 +1914,6 @@ class IndexNode(ExprNode): ...@@ -1914,11 +1914,6 @@ class IndexNode(ExprNode):
else: else:
if self.base.type.is_ptr or self.base.type.is_array: if self.base.type.is_ptr or self.base.type.is_array:
self.type = self.base.type.base_type self.type = self.base.type.base_type
else:
error(self.pos,
"Attempting to index non-array type '%s'" %
self.base.type)
self.type = PyrexTypes.error_type
if self.index.type.is_pyobject: if self.index.type.is_pyobject:
self.index = self.index.coerce_to( self.index = self.index.coerce_to(
PyrexTypes.c_py_ssize_t_type, env) PyrexTypes.c_py_ssize_t_type, env)
...@@ -1926,6 +1921,30 @@ class IndexNode(ExprNode): ...@@ -1926,6 +1921,30 @@ class IndexNode(ExprNode):
error(self.pos, error(self.pos,
"Invalid index type '%s'" % "Invalid index type '%s'" %
self.index.type) self.index.type)
elif self.base.type.is_cpp_class:
function = self.base.type.scope.lookup("operator[]")
if function is None:
error(self.pos, "Indexing '%s' not supported" % self.base.type)
else:
function = PyrexTypes.best_match([self.index], function.all_alternatives(), self.pos)
if function is None:
error(self.pos, "Invalid index type '%s'" % self.index.type)
if function is None:
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
func_type = function.type
if func_type.is_ptr:
func_type = func_type.base_type
self.index = self.index.coerce_to(func_type.args[0].type, env)
self.type = func_type.return_type
if setting and not func_type.return_type.is_reference:
error(self.pos, "Can't set non-reference '%s'" % self.type)
else:
error(self.pos,
"Attempting to index non-array type '%s'" %
self.base.type)
self.type = PyrexTypes.error_type
gil_message = "Indexing Python object" gil_message = "Indexing Python object"
def nogil_check(self, env): def nogil_check(self, env):
...@@ -4738,6 +4757,7 @@ class NumBinopNode(BinopNode): ...@@ -4738,6 +4757,7 @@ class NumBinopNode(BinopNode):
if type2.is_ptr: if type2.is_ptr:
type2 = type2.base_type type2 = type2.base_type
entry = env.lookup(type1.name) entry = env.lookup(type1.name)
# Shouldn't this be type1.scope?
function = entry.type.scope.lookup("operator%s" % self.operator) function = entry.type.scope.lookup("operator%s" % self.operator)
if function is not None: if function is not None:
operands = [self.operand2] operands = [self.operand2]
......
...@@ -2061,6 +2061,7 @@ supported_overloaded_operators = set([ ...@@ -2061,6 +2061,7 @@ supported_overloaded_operators = set([
'+', '-', '*', '/', '%', '+', '-', '*', '/', '%',
'++', '--', '~', '|', '&', '^', '<<', '>>', '++', '--', '~', '|', '&', '^', '<<', '>>',
'==', '!=', '>=', '>', '<=', '<', '==', '!=', '>=', '>', '<=', '<',
'[]',
]) ])
def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag, def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag,
...@@ -2108,6 +2109,7 @@ def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag, ...@@ -2108,6 +2109,7 @@ def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag,
cname = ctx.namespace + "::" + name cname = ctx.namespace + "::" + name
if name == 'operator' and ctx.visibility == 'extern': if name == 'operator' and ctx.visibility == 'extern':
op = s.sy op = s.sy
s.next()
# Handle diphthong operators. # Handle diphthong operators.
if op == '(': if op == '(':
s.expect(')') s.expect(')')
...@@ -2115,7 +2117,6 @@ def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag, ...@@ -2115,7 +2117,6 @@ def p_c_simple_declarator(s, ctx, empty, is_type, cmethod_flag,
elif op == '[': elif op == '[':
s.expect(']') s.expect(']')
op = '[]' op = '[]'
s.next()
if op in ['-', '+', '|', '&'] and s.sy == op: if op in ['-', '+', '|', '&'] and s.sy == op:
op = op*2 op = op*2
s.next() s.next()
......
...@@ -3,6 +3,7 @@ cdef extern from "<vector>" namespace std: ...@@ -3,6 +3,7 @@ cdef extern from "<vector>" namespace std:
cdef cppclass vector[T]: cdef cppclass vector[T]:
void push_back(T) void push_back(T)
size_t size() size_t size()
T operator[](size_t)
def simple_test(double x): def simple_test(double x):
""" """
...@@ -37,3 +38,19 @@ def list_test(L): ...@@ -37,3 +38,19 @@ def list_test(L):
return len(L), v.size() return len(L), v.size()
finally: finally:
del v del v
def index_test(L):
"""
>>> index_test([1,2,4,8])
(1.0, 8.0)
>>> index_test([1.25])
(1.25, 1.25)
"""
cdef vector[double] *v
try:
v = new vector[double]()
for a in L:
v.push_back(a)
return v[0][0], v[0][len(L)-1]
finally:
del v
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment