Commit d019bc83 authored by Oren Milman's avatar Oren Milman Committed by INADA Naoki

bpo-31787: Prevent refleaks when calling __init__() more than once (GH-3995)

parent aec7532e
...@@ -2373,6 +2373,20 @@ class CTask_CFuture_Tests(BaseTaskTests, SetMethodsTest, ...@@ -2373,6 +2373,20 @@ class CTask_CFuture_Tests(BaseTaskTests, SetMethodsTest,
Task = getattr(tasks, '_CTask', None) Task = getattr(tasks, '_CTask', None)
Future = getattr(futures, '_CFuture', None) Future = getattr(futures, '_CFuture', None)
@support.refcount_test
def test_refleaks_in_task___init__(self):
gettotalrefcount = support.get_attribute(sys, 'gettotalrefcount')
@asyncio.coroutine
def coro():
pass
task = self.new_task(self.loop, coro())
self.loop.run_until_complete(task)
refs_before = gettotalrefcount()
for i in range(100):
task.__init__(coro(), loop=self.loop)
self.loop.run_until_complete(task)
self.assertAlmostEqual(gettotalrefcount() - refs_before, 0, delta=10)
@unittest.skipUnless(hasattr(futures, '_CFuture') and @unittest.skipUnless(hasattr(futures, '_CFuture') and
hasattr(tasks, '_CTask'), hasattr(tasks, '_CTask'),
......
...@@ -13,6 +13,7 @@ import subprocess ...@@ -13,6 +13,7 @@ import subprocess
import threading import threading
from test.support import unlink from test.support import unlink
import _compression import _compression
import sys
# Skip tests if the bz2 module doesn't exist. # Skip tests if the bz2 module doesn't exist.
...@@ -816,6 +817,16 @@ class BZ2DecompressorTest(BaseTest): ...@@ -816,6 +817,16 @@ class BZ2DecompressorTest(BaseTest):
# Previously, a second call could crash due to internal inconsistency # Previously, a second call could crash due to internal inconsistency
self.assertRaises(Exception, bzd.decompress, self.BAD_DATA * 30) self.assertRaises(Exception, bzd.decompress, self.BAD_DATA * 30)
@support.refcount_test
def test_refleaks_in___init__(self):
gettotalrefcount = support.get_attribute(sys, 'gettotalrefcount')
bzd = BZ2Decompressor()
refs_before = gettotalrefcount()
for i in range(100):
bzd.__init__()
self.assertAlmostEqual(gettotalrefcount() - refs_before, 0, delta=10)
class CompressDecompressTest(BaseTest): class CompressDecompressTest(BaseTest):
def testCompress(self): def testCompress(self):
data = bz2.compress(self.TEXT) data = bz2.compress(self.TEXT)
......
...@@ -1559,6 +1559,15 @@ order (MRO) for bases """ ...@@ -1559,6 +1559,15 @@ order (MRO) for bases """
del cm.x del cm.x
self.assertNotHasAttr(cm, "x") self.assertNotHasAttr(cm, "x")
@support.refcount_test
def test_refleaks_in_classmethod___init__(self):
gettotalrefcount = support.get_attribute(sys, 'gettotalrefcount')
cm = classmethod(None)
refs_before = gettotalrefcount()
for i in range(100):
cm.__init__(None)
self.assertAlmostEqual(gettotalrefcount() - refs_before, 0, delta=10)
@support.impl_detail("the module 'xxsubtype' is internal") @support.impl_detail("the module 'xxsubtype' is internal")
def test_classmethods_in_c(self): def test_classmethods_in_c(self):
# Testing C-based class methods... # Testing C-based class methods...
...@@ -1614,6 +1623,15 @@ order (MRO) for bases """ ...@@ -1614,6 +1623,15 @@ order (MRO) for bases """
del sm.x del sm.x
self.assertNotHasAttr(sm, "x") self.assertNotHasAttr(sm, "x")
@support.refcount_test
def test_refleaks_in_staticmethod___init__(self):
gettotalrefcount = support.get_attribute(sys, 'gettotalrefcount')
sm = staticmethod(None)
refs_before = gettotalrefcount()
for i in range(100):
sm.__init__(None)
self.assertAlmostEqual(gettotalrefcount() - refs_before, 0, delta=10)
@support.impl_detail("the module 'xxsubtype' is internal") @support.impl_detail("the module 'xxsubtype' is internal")
def test_staticmethods_in_c(self): def test_staticmethods_in_c(self):
# Testing C-based static methods... # Testing C-based static methods...
......
...@@ -162,6 +162,15 @@ class HashLibTestCase(unittest.TestCase): ...@@ -162,6 +162,15 @@ class HashLibTestCase(unittest.TestCase):
constructors = self.constructors_to_test.values() constructors = self.constructors_to_test.values()
return itertools.chain.from_iterable(constructors) return itertools.chain.from_iterable(constructors)
@support.refcount_test
def test_refleaks_in_hash___init__(self):
gettotalrefcount = support.get_attribute(sys, 'gettotalrefcount')
sha1_hash = c_hashlib.new('sha1')
refs_before = gettotalrefcount()
for i in range(100):
sha1_hash.__init__('sha1')
self.assertAlmostEqual(gettotalrefcount() - refs_before, 0, delta=10)
def test_hash_array(self): def test_hash_array(self):
a = array.array("b", range(10)) a = array.array("b", range(10))
for cons in self.hash_constructors: for cons in self.hash_constructors:
......
...@@ -4,6 +4,8 @@ import os ...@@ -4,6 +4,8 @@ import os
import pathlib import pathlib
import pickle import pickle
import random import random
import sys
from test import support
import unittest import unittest
from test.support import ( from test.support import (
...@@ -364,6 +366,15 @@ class CompressorDecompressorTestCase(unittest.TestCase): ...@@ -364,6 +366,15 @@ class CompressorDecompressorTestCase(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
pickle.dumps(LZMADecompressor(), proto) pickle.dumps(LZMADecompressor(), proto)
@support.refcount_test
def test_refleaks_in_decompressor___init__(self):
gettotalrefcount = support.get_attribute(sys, 'gettotalrefcount')
lzd = LZMADecompressor()
refs_before = gettotalrefcount()
for i in range(100):
lzd.__init__()
self.assertAlmostEqual(gettotalrefcount() - refs_before, 0, delta=10)
class CompressDecompressFunctionTestCase(unittest.TestCase): class CompressDecompressFunctionTestCase(unittest.TestCase):
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import sys import sys
import unittest import unittest
from test import support
class PropertyBase(Exception): class PropertyBase(Exception):
pass pass
...@@ -173,6 +174,16 @@ class PropertyTests(unittest.TestCase): ...@@ -173,6 +174,16 @@ class PropertyTests(unittest.TestCase):
sub.__class__.spam.__doc__ = 'Spam' sub.__class__.spam.__doc__ = 'Spam'
self.assertEqual(sub.__class__.spam.__doc__, 'Spam') self.assertEqual(sub.__class__.spam.__doc__, 'Spam')
@support.refcount_test
def test_refleaks_in___init__(self):
gettotalrefcount = support.get_attribute(sys, 'gettotalrefcount')
fake_prop = property('fget', 'fset', 'fdel', 'doc')
refs_before = gettotalrefcount()
for i in range(100):
fake_prop.__init__('fget', 'fset', 'fdel', 'doc')
self.assertAlmostEqual(gettotalrefcount() - refs_before, 0, delta=10)
# Issue 5890: subclasses of property do not preserve method __doc__ strings # Issue 5890: subclasses of property do not preserve method __doc__ strings
class PropertySub(property): class PropertySub(property):
"""This is a subclass of property""" """This is a subclass of property"""
......
Fixed refleaks of ``__init__()`` methods in various modules.
(Contributed by Oren Milman)
...@@ -458,6 +458,7 @@ future_schedule_callbacks(FutureObj *fut) ...@@ -458,6 +458,7 @@ future_schedule_callbacks(FutureObj *fut)
return 0; return 0;
} }
static int static int
future_init(FutureObj *fut, PyObject *loop) future_init(FutureObj *fut, PyObject *loop)
{ {
...@@ -465,6 +466,19 @@ future_init(FutureObj *fut, PyObject *loop) ...@@ -465,6 +466,19 @@ future_init(FutureObj *fut, PyObject *loop)
int is_true; int is_true;
_Py_IDENTIFIER(get_debug); _Py_IDENTIFIER(get_debug);
// Same to FutureObj_clear() but not clearing fut->dict
Py_CLEAR(fut->fut_loop);
Py_CLEAR(fut->fut_callback0);
Py_CLEAR(fut->fut_context0);
Py_CLEAR(fut->fut_callbacks);
Py_CLEAR(fut->fut_result);
Py_CLEAR(fut->fut_exception);
Py_CLEAR(fut->fut_source_tb);
fut->fut_state = STATE_PENDING;
fut->fut_log_tb = 0;
fut->fut_blocking = 0;
if (loop == Py_None) { if (loop == Py_None) {
loop = get_event_loop(); loop = get_event_loop();
if (loop == NULL) { if (loop == NULL) {
...@@ -474,7 +488,7 @@ future_init(FutureObj *fut, PyObject *loop) ...@@ -474,7 +488,7 @@ future_init(FutureObj *fut, PyObject *loop)
else { else {
Py_INCREF(loop); Py_INCREF(loop);
} }
Py_XSETREF(fut->fut_loop, loop); fut->fut_loop = loop;
res = _PyObject_CallMethodId(fut->fut_loop, &PyId_get_debug, NULL); res = _PyObject_CallMethodId(fut->fut_loop, &PyId_get_debug, NULL);
if (res == NULL) { if (res == NULL) {
...@@ -486,16 +500,12 @@ future_init(FutureObj *fut, PyObject *loop) ...@@ -486,16 +500,12 @@ future_init(FutureObj *fut, PyObject *loop)
return -1; return -1;
} }
if (is_true) { if (is_true) {
Py_XSETREF(fut->fut_source_tb, _PyObject_CallNoArg(traceback_extract_stack)); fut->fut_source_tb = _PyObject_CallNoArg(traceback_extract_stack);
if (fut->fut_source_tb == NULL) { if (fut->fut_source_tb == NULL) {
return -1; return -1;
} }
} }
fut->fut_callback0 = NULL;
fut->fut_context0 = NULL;
fut->fut_callbacks = NULL;
return 0; return 0;
} }
...@@ -1938,16 +1948,16 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop) ...@@ -1938,16 +1948,16 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop)
return -1; return -1;
} }
self->task_context = PyContext_CopyCurrent(); Py_XSETREF(self->task_context, PyContext_CopyCurrent());
if (self->task_context == NULL) { if (self->task_context == NULL) {
return -1; return -1;
} }
self->task_fut_waiter = NULL; Py_CLEAR(self->task_fut_waiter);
self->task_must_cancel = 0; self->task_must_cancel = 0;
self->task_log_destroy_pending = 1; self->task_log_destroy_pending = 1;
Py_INCREF(coro); Py_INCREF(coro);
self->task_coro = coro; Py_XSETREF(self->task_coro, coro);
if (task_call_step_soon(self, NULL)) { if (task_call_step_soon(self, NULL)) {
return -1; return -1;
......
...@@ -644,7 +644,7 @@ _bz2_BZ2Decompressor___init___impl(BZ2Decompressor *self) ...@@ -644,7 +644,7 @@ _bz2_BZ2Decompressor___init___impl(BZ2Decompressor *self)
self->bzs_avail_in_real = 0; self->bzs_avail_in_real = 0;
self->input_buffer = NULL; self->input_buffer = NULL;
self->input_buffer_size = 0; self->input_buffer_size = 0;
self->unused_data = PyBytes_FromStringAndSize(NULL, 0); Py_XSETREF(self->unused_data, PyBytes_FromStringAndSize(NULL, 0));
if (self->unused_data == NULL) if (self->unused_data == NULL)
goto error; goto error;
......
...@@ -369,8 +369,8 @@ EVP_tp_init(EVPobject *self, PyObject *args, PyObject *kwds) ...@@ -369,8 +369,8 @@ EVP_tp_init(EVPobject *self, PyObject *args, PyObject *kwds)
return -1; return -1;
} }
self->name = name_obj; Py_INCREF(name_obj);
Py_INCREF(self->name); Py_XSETREF(self->name, name_obj);
if (data_obj) { if (data_obj) {
if (view.len >= HASHLIB_GIL_MINSIZE) { if (view.len >= HASHLIB_GIL_MINSIZE) {
......
...@@ -1173,7 +1173,7 @@ _lzma_LZMADecompressor___init___impl(Decompressor *self, int format, ...@@ -1173,7 +1173,7 @@ _lzma_LZMADecompressor___init___impl(Decompressor *self, int format,
self->needs_input = 1; self->needs_input = 1;
self->input_buffer = NULL; self->input_buffer = NULL;
self->input_buffer_size = 0; self->input_buffer_size = 0;
self->unused_data = PyBytes_FromStringAndSize(NULL, 0); Py_XSETREF(self->unused_data, PyBytes_FromStringAndSize(NULL, 0));
if (self->unused_data == NULL) if (self->unused_data == NULL)
goto error; goto error;
......
...@@ -1490,10 +1490,10 @@ property_init_impl(propertyobject *self, PyObject *fget, PyObject *fset, ...@@ -1490,10 +1490,10 @@ property_init_impl(propertyobject *self, PyObject *fget, PyObject *fset,
Py_XINCREF(fdel); Py_XINCREF(fdel);
Py_XINCREF(doc); Py_XINCREF(doc);
self->prop_get = fget; Py_XSETREF(self->prop_get, fget);
self->prop_set = fset; Py_XSETREF(self->prop_set, fset);
self->prop_del = fdel; Py_XSETREF(self->prop_del, fdel);
self->prop_doc = doc; Py_XSETREF(self->prop_doc, doc);
self->getter_doc = 0; self->getter_doc = 0;
/* if no docstring given and the getter has one, use that one */ /* if no docstring given and the getter has one, use that one */
......
...@@ -709,7 +709,7 @@ cm_init(PyObject *self, PyObject *args, PyObject *kwds) ...@@ -709,7 +709,7 @@ cm_init(PyObject *self, PyObject *args, PyObject *kwds)
if (!PyArg_UnpackTuple(args, "classmethod", 1, 1, &callable)) if (!PyArg_UnpackTuple(args, "classmethod", 1, 1, &callable))
return -1; return -1;
Py_INCREF(callable); Py_INCREF(callable);
cm->cm_callable = callable; Py_XSETREF(cm->cm_callable, callable);
return 0; return 0;
} }
...@@ -890,7 +890,7 @@ sm_init(PyObject *self, PyObject *args, PyObject *kwds) ...@@ -890,7 +890,7 @@ sm_init(PyObject *self, PyObject *args, PyObject *kwds)
if (!PyArg_UnpackTuple(args, "staticmethod", 1, 1, &callable)) if (!PyArg_UnpackTuple(args, "staticmethod", 1, 1, &callable))
return -1; return -1;
Py_INCREF(callable); Py_INCREF(callable);
sm->sm_callable = callable; Py_XSETREF(sm->sm_callable, callable);
return 0; return 0;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment