Commit 8bc85858 authored by Giampaolo Rodola's avatar Giampaolo Rodola

provide a common method to check for RETR_DATA validity, first checking the...

provide a common method to check for RETR_DATA validity, first checking the expected len and then the actual data content; this way we get a failure on len mismatch rather than content mismatch (which is very long and unreadable)
parent 0c5e52f0
...@@ -461,6 +461,10 @@ class TestFTPClass(TestCase): ...@@ -461,6 +461,10 @@ class TestFTPClass(TestCase):
self.client.close() self.client.close()
self.server.stop() self.server.stop()
def check_data(self, received, expected):
self.assertEqual(len(received), len(expected))
self.assertEqual(received, expected)
def test_getwelcome(self): def test_getwelcome(self):
self.assertEqual(self.client.getwelcome(), '220 welcome') self.assertEqual(self.client.getwelcome(), '220 welcome')
...@@ -542,7 +546,7 @@ class TestFTPClass(TestCase): ...@@ -542,7 +546,7 @@ class TestFTPClass(TestCase):
received.append(data.decode('ascii')) received.append(data.decode('ascii'))
received = [] received = []
self.client.retrbinary('retr', callback) self.client.retrbinary('retr', callback)
self.assertEqual(''.join(received), RETR_DATA) self.check_data(''.join(received), RETR_DATA)
def test_retrbinary_rest(self): def test_retrbinary_rest(self):
def callback(data): def callback(data):
...@@ -550,20 +554,17 @@ class TestFTPClass(TestCase): ...@@ -550,20 +554,17 @@ class TestFTPClass(TestCase):
for rest in (0, 10, 20): for rest in (0, 10, 20):
received = [] received = []
self.client.retrbinary('retr', callback, rest=rest) self.client.retrbinary('retr', callback, rest=rest)
self.assertEqual(''.join(received), RETR_DATA[rest:], self.check_data(''.join(received), RETR_DATA[rest:])
msg='rest test case %d %d %d' % (rest,
len(''.join(received)),
len(RETR_DATA[rest:])))
def test_retrlines(self): def test_retrlines(self):
received = [] received = []
self.client.retrlines('retr', received.append) self.client.retrlines('retr', received.append)
self.assertEqual(''.join(received), RETR_DATA.replace('\r\n', '')) self.check_data(''.join(received), RETR_DATA.replace('\r\n', ''))
def test_storbinary(self): def test_storbinary(self):
f = io.BytesIO(RETR_DATA.encode('ascii')) f = io.BytesIO(RETR_DATA.encode('ascii'))
self.client.storbinary('stor', f) self.client.storbinary('stor', f)
self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA) self.check_data(self.server.handler_instance.last_received_data, RETR_DATA)
# test new callback arg # test new callback arg
flag = [] flag = []
f.seek(0) f.seek(0)
...@@ -580,7 +581,7 @@ class TestFTPClass(TestCase): ...@@ -580,7 +581,7 @@ class TestFTPClass(TestCase):
def test_storlines(self): def test_storlines(self):
f = io.BytesIO(RETR_DATA.replace('\r\n', '\n').encode('ascii')) f = io.BytesIO(RETR_DATA.replace('\r\n', '\n').encode('ascii'))
self.client.storlines('stor', f) self.client.storlines('stor', f)
self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA) self.check_data(self.server.handler_instance.last_received_data, RETR_DATA)
# test new callback arg # test new callback arg
flag = [] flag = []
f.seek(0) f.seek(0)
...@@ -781,6 +782,7 @@ class TestIPv6Environment(TestCase): ...@@ -781,6 +782,7 @@ class TestIPv6Environment(TestCase):
received.append(data.decode('ascii')) received.append(data.decode('ascii'))
received = [] received = []
self.client.retrbinary('retr', callback) self.client.retrbinary('retr', callback)
self.assertEqual(len(''.join(received)), len(RETR_DATA))
self.assertEqual(''.join(received), RETR_DATA) self.assertEqual(''.join(received), RETR_DATA)
self.client.set_pasv(True) self.client.set_pasv(True)
retr() retr()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment