Commit 80cabad1 authored by Raymond Hettinger's avatar Raymond Hettinger

Improve namedtuple's _cast() method with a docstring, new name, and error-checking.

parent e3c12c91
...@@ -393,7 +393,13 @@ Example:: ...@@ -393,7 +393,13 @@ Example::
def __new__(cls, x, y): def __new__(cls, x, y):
return tuple.__new__(cls, (x, y)) return tuple.__new__(cls, (x, y))
_cast = classmethod(tuple.__new__) @classmethod
def _make(cls, iterable):
'Make a new Point object from a sequence or iterable'
result = tuple.__new__(cls, iterable)
if len(result) != 2:
raise TypeError('Expected 2 arguments, got %d' % len(result))
return result
def __repr__(self): def __repr__(self):
return 'Point(x=%r, y=%r)' % self return 'Point(x=%r, y=%r)' % self
...@@ -404,7 +410,7 @@ Example:: ...@@ -404,7 +410,7 @@ Example::
def _replace(self, **kwds): def _replace(self, **kwds):
'Return a new Point object replacing specified fields with new values' 'Return a new Point object replacing specified fields with new values'
return self.__class__._cast(map(kwds.get, ('x', 'y'), self)) return self.__class__._make(map(kwds.get, ('x', 'y'), self))
x = property(itemgetter(0)) x = property(itemgetter(0))
y = property(itemgetter(1)) y = property(itemgetter(1))
...@@ -426,34 +432,28 @@ by the :mod:`csv` or :mod:`sqlite3` modules:: ...@@ -426,34 +432,28 @@ by the :mod:`csv` or :mod:`sqlite3` modules::
EmployeeRecord = namedtuple('EmployeeRecord', 'name, age, title, department, paygrade') EmployeeRecord = namedtuple('EmployeeRecord', 'name, age, title, department, paygrade')
import csv import csv
for emp in map(EmployeeRecord._cast, csv.reader(open("employees.csv", "rb"))): for emp in map(EmployeeRecord._make, csv.reader(open("employees.csv", "rb"))):
print emp.name, emp.title print emp.name, emp.title
import sqlite3 import sqlite3
conn = sqlite3.connect('/companydata') conn = sqlite3.connect('/companydata')
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute('SELECT name, age, title, department, paygrade FROM employees') cursor.execute('SELECT name, age, title, department, paygrade FROM employees')
for emp in map(EmployeeRecord._cast, cursor.fetchall()): for emp in map(EmployeeRecord._make, cursor.fetchall()):
print emp.name, emp.title print emp.name, emp.title
In addition to the methods inherited from tuples, named tuples support In addition to the methods inherited from tuples, named tuples support
three additonal methods and one attribute. three additional methods and one attribute.
.. method:: namedtuple._cast(iterable) .. method:: namedtuple._make(iterable)
Class method returning a new instance taking the positional arguments from the Class method that makes a new instance from an existing sequence or iterable.
*iterable*. Useful for casting existing sequences and iterables to named tuples.
This fast constructor does not check the length of the inputs. To achieve the
same effect with length checking, use the star-operator instead.
:: ::
>>> t = [11, 22] >>> t = [11, 22]
>>> Point._cast(t) # fast conversion >>> Point._make(t)
Point(x=11, y=22) Point(x=11, y=22)
>>> Point(*t) # slow conversion with length checking
Point(x=11, y=22)
.. method:: somenamedtuple._asdict() .. method:: somenamedtuple._asdict()
......
...@@ -54,6 +54,7 @@ def namedtuple(typename, field_names, verbose=False): ...@@ -54,6 +54,7 @@ def namedtuple(typename, field_names, verbose=False):
seen_names.add(name) seen_names.add(name)
# Create and fill-in the class template # Create and fill-in the class template
numfields = len(field_names)
argtxt = repr(field_names).replace("'", "")[1:-1] # tuple repr without parens or quotes argtxt = repr(field_names).replace("'", "")[1:-1] # tuple repr without parens or quotes
reprtxt = ', '.join('%s=%%r' % name for name in field_names) reprtxt = ', '.join('%s=%%r' % name for name in field_names)
dicttxt = ', '.join('%r: t[%d]' % (name, pos) for pos, name in enumerate(field_names)) dicttxt = ', '.join('%r: t[%d]' % (name, pos) for pos, name in enumerate(field_names))
...@@ -63,7 +64,13 @@ def namedtuple(typename, field_names, verbose=False): ...@@ -63,7 +64,13 @@ def namedtuple(typename, field_names, verbose=False):
_fields = %(field_names)r \n _fields = %(field_names)r \n
def __new__(cls, %(argtxt)s): def __new__(cls, %(argtxt)s):
return tuple.__new__(cls, (%(argtxt)s)) \n return tuple.__new__(cls, (%(argtxt)s)) \n
_cast = classmethod(tuple.__new__) \n @classmethod
def _make(cls, iterable):
'Make a new %(typename)s object from a sequence or iterable'
result = tuple.__new__(cls, iterable)
if len(result) != %(numfields)d:
raise TypeError('Expected %(numfields)d arguments, got %%d' %% len(result))
return result \n
def __repr__(self): def __repr__(self):
return '%(typename)s(%(reprtxt)s)' %% self \n return '%(typename)s(%(reprtxt)s)' %% self \n
def _asdict(t): def _asdict(t):
...@@ -71,7 +78,7 @@ def namedtuple(typename, field_names, verbose=False): ...@@ -71,7 +78,7 @@ def namedtuple(typename, field_names, verbose=False):
return {%(dicttxt)s} \n return {%(dicttxt)s} \n
def _replace(self, **kwds): def _replace(self, **kwds):
'Return a new %(typename)s object replacing specified fields with new values' 'Return a new %(typename)s object replacing specified fields with new values'
return self.__class__._cast(map(kwds.get, %(field_names)r, self)) \n\n''' % locals() return self.__class__._make(map(kwds.get, %(field_names)r, self)) \n\n''' % locals()
for i, name in enumerate(field_names): for i, name in enumerate(field_names):
template += ' %s = property(itemgetter(%d))\n' % (name, i) template += ' %s = property(itemgetter(%d))\n' % (name, i)
if verbose: if verbose:
......
...@@ -32,6 +32,9 @@ class TestNamedTuple(unittest.TestCase): ...@@ -32,6 +32,9 @@ class TestNamedTuple(unittest.TestCase):
namedtuple('Point0', 'x1 y2') # Verify that numbers are allowed in names namedtuple('Point0', 'x1 y2') # Verify that numbers are allowed in names
namedtuple('_', 'a b c') # Test leading underscores in a typename namedtuple('_', 'a b c') # Test leading underscores in a typename
self.assertRaises(TypeError, Point._make, [11]) # catch too few args
self.assertRaises(TypeError, Point._make, [11, 22, 33]) # catch too many args
def test_instance(self): def test_instance(self):
Point = namedtuple('Point', 'x y') Point = namedtuple('Point', 'x y')
p = Point(11, 22) p = Point(11, 22)
...@@ -47,7 +50,7 @@ class TestNamedTuple(unittest.TestCase): ...@@ -47,7 +50,7 @@ class TestNamedTuple(unittest.TestCase):
self.assertEqual(repr(p), 'Point(x=11, y=22)') self.assertEqual(repr(p), 'Point(x=11, y=22)')
self.assert_('__dict__' not in dir(p)) # verify instance has no dict self.assert_('__dict__' not in dir(p)) # verify instance has no dict
self.assert_('__weakref__' not in dir(p)) self.assert_('__weakref__' not in dir(p))
self.assertEqual(p, Point._cast([11, 22])) # test _cast classmethod self.assertEqual(p, Point._make([11, 22])) # test _make classmethod
self.assertEqual(p._fields, ('x', 'y')) # test _fields attribute self.assertEqual(p._fields, ('x', 'y')) # test _fields attribute
self.assertEqual(p._replace(x=1), (1, 22)) # test _replace method self.assertEqual(p._replace(x=1), (1, 22)) # test _replace method
self.assertEqual(p._asdict(), dict(x=11, y=22)) # test _asdict method self.assertEqual(p._asdict(), dict(x=11, y=22)) # test _asdict method
...@@ -84,14 +87,14 @@ class TestNamedTuple(unittest.TestCase): ...@@ -84,14 +87,14 @@ class TestNamedTuple(unittest.TestCase):
def test_odd_sizes(self): def test_odd_sizes(self):
Zero = namedtuple('Zero', '') Zero = namedtuple('Zero', '')
self.assertEqual(Zero(), ()) self.assertEqual(Zero(), ())
self.assertEqual(Zero._cast([]), ()) self.assertEqual(Zero._make([]), ())
self.assertEqual(repr(Zero()), 'Zero()') self.assertEqual(repr(Zero()), 'Zero()')
self.assertEqual(Zero()._asdict(), {}) self.assertEqual(Zero()._asdict(), {})
self.assertEqual(Zero()._fields, ()) self.assertEqual(Zero()._fields, ())
Dot = namedtuple('Dot', 'd') Dot = namedtuple('Dot', 'd')
self.assertEqual(Dot(1), (1,)) self.assertEqual(Dot(1), (1,))
self.assertEqual(Dot._cast([1]), (1,)) self.assertEqual(Dot._make([1]), (1,))
self.assertEqual(Dot(1).d, 1) self.assertEqual(Dot(1).d, 1)
self.assertEqual(repr(Dot(1)), 'Dot(d=1)') self.assertEqual(repr(Dot(1)), 'Dot(d=1)')
self.assertEqual(Dot(1)._asdict(), {'d':1}) self.assertEqual(Dot(1)._asdict(), {'d':1})
...@@ -104,7 +107,7 @@ class TestNamedTuple(unittest.TestCase): ...@@ -104,7 +107,7 @@ class TestNamedTuple(unittest.TestCase):
Big = namedtuple('Big', names) Big = namedtuple('Big', names)
b = Big(*range(n)) b = Big(*range(n))
self.assertEqual(b, tuple(range(n))) self.assertEqual(b, tuple(range(n)))
self.assertEqual(Big._cast(range(n)), tuple(range(n))) self.assertEqual(Big._make(range(n)), tuple(range(n)))
for pos, name in enumerate(names): for pos, name in enumerate(names):
self.assertEqual(getattr(b, name), pos) self.assertEqual(getattr(b, name), pos)
repr(b) # make sure repr() doesn't blow-up repr(b) # make sure repr() doesn't blow-up
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment