Commit f0713aa3 authored by Julien Muchembled

storage: fix store of multiple values that only differ by the compression flag

This fixes the case of an application that would store 2 values X & Y
where NEO internally compresses X into a value identical to Y.
parent 8204e541
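For illustration (not part of the commit): if an application stores a value Y, and a value X whose compressed form is exactly Y, both rows carry the same checksum of the bytes actually written and differ only in the compression flag, so a uniqueness rule on the hash alone rejects the second store. Below is a minimal standalone sketch of that collision against a simplified table in sqlite3; the schema and helper names are illustrative, not NEO's code.

# Sketch only: simplified "data" table, checksum taken over the stored bytes.
import hashlib
import sqlite3
import zlib

X = b'x' * 20                        # compressible application value
Y = zlib.compress(X)                 # second application value, equal to compress(X)
checksum = hashlib.sha1(Y).digest()  # checksum of the bytes actually written

def both_rows_fit(unique_clause):
    """Try to store Y as-is (compression=0) and X compressed (compression=1)."""
    db = sqlite3.connect(":memory:")
    db.execute("CREATE TABLE data (hash BLOB, compression INTEGER, value BLOB,"
               " UNIQUE (%s))" % unique_clause)
    try:
        db.execute("INSERT INTO data VALUES (?,?,?)",
                   (sqlite3.Binary(checksum), 0, sqlite3.Binary(Y)))
        db.execute("INSERT INTO data VALUES (?,?,?)",
                   (sqlite3.Binary(checksum), 1, sqlite3.Binary(Y)))
    except sqlite3.IntegrityError:
        return False
    return True

assert not both_rows_fit("hash")            # old schema: second store rejected
assert both_rows_fit("hash, compression")   # new schema: both values coexist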
@@ -195,9 +195,10 @@ class MySQLDatabaseManager(DatabaseManager):
         # but 'UNIQUE' constraint would not work as expected.
         q("""CREATE TABLE IF NOT EXISTS data (
                  id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
-                 hash BINARY(20) NOT NULL UNIQUE,
+                 hash BINARY(20) NOT NULL,
                  compression TINYINT UNSIGNED NULL,
-                 value LONGBLOB NULL
+                 value LONGBLOB NULL,
+                 UNIQUE (hash, compression)
             ) ENGINE=""" + engine)
         # The table "ttrans" stores information on uncommitted transactions.
@@ -443,9 +444,10 @@ class MySQLDatabaseManager(DatabaseManager):
                 (checksum, compression, e(data)))
         except IntegrityError, (code, _):
             if code == DUP_ENTRY:
-                (r, c, d), = self.query("SELECT id, compression, value"
-                    " FROM data WHERE hash='%s'" % checksum)
-                if c == compression and d == data:
+                (r, d), = self.query("SELECT id, value FROM data"
+                    " WHERE hash='%s' AND compression=%s"
+                    % (checksum, compression))
+                if d == data:
                     return r
             raise
         return self.conn.insert_id()
...
@@ -25,11 +25,15 @@ from neo.lib import logging, util
 from neo.lib.exception import DatabaseFailure
 from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH

-def unique_constraint_message(table, column):
+def unique_constraint_message(table, *columns):
     c = sqlite3.connect(":memory:")
-    c.execute("CREATE TABLE %s (%s UNIQUE)" % (table, column))
+    values = '?' * len(columns)
+    insert = "INSERT INTO %s VALUES(%s)" % (table, ', '.join(values))
+    x = "%s (%s)" % (table, ', '.join(columns))
+    c.execute("CREATE TABLE " + x)
+    c.execute("CREATE UNIQUE INDEX i ON " + x)
     try:
-        c.executemany("INSERT INTO %s VALUES(?)" % table, 'xx')
+        c.executemany(insert, (values, values))
     except sqlite3.IntegrityError, e:
         return e.args[0]
     assert False
@@ -155,9 +159,12 @@ class SQLiteDatabaseManager(DatabaseManager):
         # The table "data" stores object data.
         q("""CREATE TABLE IF NOT EXISTS data (
                  id INTEGER PRIMARY KEY AUTOINCREMENT,
-                 hash BLOB NOT NULL UNIQUE,
-                 compression INTEGER,
-                 value BLOB)
+                 hash BLOB NOT NULL,
+                 compression INTEGER NOT NULL,
+                 value BLOB NULL)
+            """)
+        q("""CREATE UNIQUE INDEX IF NOT EXISTS _data_i1 ON
+                 data(hash, compression)
             """)
         # The table "ttrans" stores information on uncommitted transactions.
@@ -369,16 +376,17 @@ class SQLiteDatabaseManager(DatabaseManager):
                 % ",".join(map(str, data_id_list)))

     def storeData(self, checksum, data, compression,
-            _dup_hash=unique_constraint_message("data", "hash")):
+            _dup=unique_constraint_message("data", "hash", "compression")):
         H = buffer(checksum)
         try:
             return self.query("INSERT INTO data VALUES (NULL,?,?,?)",
                 (H, compression, buffer(data))).lastrowid
         except sqlite3.IntegrityError, e:
-            if e.args[0] == _dup_hash:
-                (r, c, d), = self.query("SELECT id, compression, value"
-                    " FROM data WHERE hash=?", (H,))
-                if c == compression and str(d) == data:
+            if e.args[0] == _dup:
+                (r, d), = self.query("SELECT id, value FROM data"
+                    " WHERE hash=? AND compression=?",
+                    (H, compression))
+                if str(d) == data:
                     return r
             raise
...
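Side note on the unique_constraint_message / storeData pattern above: unlike MySQLdb, which exposes a numeric DUP_ENTRY error code, sqlite3.IntegrityError only carries a message string whose wording varies across SQLite versions, so the exact duplicate-key message for data(hash, compression) is computed once at import time and compared in storeData. A rough standalone illustration of the idea follows; the names (unique_violation_message, store, DUP) and the simplified schema are assumptions made for this sketch, not NEO's code.

import sqlite3

def unique_violation_message(table, *columns):
    # Recreate the same table name and unique columns in a throwaway
    # in-memory database, force a duplicate insert, and capture the message.
    c = sqlite3.connect(":memory:")
    x = "%s (%s)" % (table, ", ".join(columns))
    c.execute("CREATE TABLE " + x)
    c.execute("CREATE UNIQUE INDEX i ON " + x)
    row = (0,) * len(columns)
    sql = "INSERT INTO %s VALUES (%s)" % (table, ", ".join("?" * len(columns)))
    try:
        c.executemany(sql, (row, row))
    except sqlite3.IntegrityError as e:
        return e.args[0]
    raise AssertionError("duplicate insert unexpectedly succeeded")

DUP = unique_violation_message("data", "hash", "compression")

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE data (id INTEGER PRIMARY KEY AUTOINCREMENT,"
             " hash BLOB NOT NULL, compression INTEGER NOT NULL, value BLOB)")
conn.execute("CREATE UNIQUE INDEX _data_i1 ON data(hash, compression)")

def store(h, compression, value):
    # Insert, and treat only the expected (hash, compression) conflict as
    # "already stored"; any other IntegrityError is re-raised.  The real
    # code additionally compares the stored value before reusing the row.
    try:
        return conn.execute("INSERT INTO data VALUES (NULL,?,?,?)",
                            (h, compression, value)).lastrowid
    except sqlite3.IntegrityError as e:
        if e.args[0] == DUP:
            (r,), = conn.execute("SELECT id FROM data"
                                 " WHERE hash=? AND compression=?",
                                 (h, compression))
            return r
        raise

assert store(b"h", 0, b"v") == store(b"h", 0, b"v")   # second call reuses the row
assert store(b"h", 1, b"v") != store(b"h", 0, b"v")   # different compression flag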
@@ -309,10 +309,10 @@ class StorageApplication(ServerNode, neo.storage.app.Application):
     def getDataLockInfo(self):
         dm = self.dm
-        checksum_dict = dict(dm.query("SELECT id, hash FROM data"))
-        assert set(dm._uncommitted_data).issubset(checksum_dict)
+        index = tuple(dm.query("SELECT id, hash, compression FROM data"))
+        assert set(dm._uncommitted_data).issubset(x[0] for x in index)
         get = dm._uncommitted_data.get
-        return {str(v): get(k, 0) for k, v in checksum_dict.iteritems()}
+        return {(str(h), c): get(i, 0) for i, h, c in index}

 class ClientApplication(Node, neo.client.app.Application):
...
@@ -19,6 +19,7 @@ import threading
 import transaction
 import unittest
 from thread import get_ident
+from zlib import compress
 from persistent import Persistent
 from ZODB import POSException
 from neo.storage.transactions import TransactionManager, \
@@ -47,17 +48,22 @@ class Test(NEOThreadedTest):
             cluster.start()
             storage = cluster.getZODBStorage()
             data_info = {}
-            for data in 'foo', '', 'foo':
-                checksum = makeChecksum(data)
+            compressible = 'x' * 20
+            compressed = compress(compressible)
+            for data in 'foo', '', 'foo', compressed, compressible:
+                if data is compressible:
+                    key = makeChecksum(compressed), 1
+                else:
+                    key = makeChecksum(data), 0
                 oid = storage.new_oid()
                 txn = transaction.Transaction()
                 storage.tpc_begin(txn)
                 r1 = storage.store(oid, None, data, '', txn)
                 r2 = storage.tpc_vote(txn)
-                data_info[checksum] = 1
+                data_info[key] = 1
                 self.assertEqual(data_info, cluster.storage.getDataLockInfo())
                 serial = storage.tpc_finish(txn)
-                data_info[checksum] = 0
+                data_info[key] = 0
                 self.assertEqual(data_info, cluster.storage.getDataLockInfo())
                 self.assertEqual((data, serial), storage.load(oid, ''))
                 storage._cache.clear()
@@ -99,14 +105,14 @@ class Test(NEOThreadedTest):
             data_info = {}
             data = 'foo'
-            checksum = makeChecksum(data)
+            key = makeChecksum(data), 0
             oid = storage.new_oid()
             txn = transaction.Transaction()
             storage.tpc_begin(txn)
             r1 = storage.store(oid, None, data, '', txn)
             r2 = storage.tpc_vote(txn)
             tid = storage.tpc_finish(txn)
-            data_info[checksum] = 0
+            data_info[key] = 0
             storage.sync()

             txn = [transaction.Transaction() for x in xrange(3)]
@@ -117,21 +123,21 @@ class Test(NEOThreadedTest):
             tid = None
             for t in txn:
                 storage.tpc_vote(t)
-            data_info[checksum] = 3
+            data_info[key] = 3
             self.assertEqual(data_info, cluster.storage.getDataLockInfo())

             storage.tpc_abort(txn[1])
             storage.sync()
-            data_info[checksum] -= 1
+            data_info[key] -= 1
             self.assertEqual(data_info, cluster.storage.getDataLockInfo())

             tid1 = storage.tpc_finish(txn[2])
-            data_info[checksum] -= 1
+            data_info[key] -= 1
             self.assertEqual(data_info, cluster.storage.getDataLockInfo())

             storage.tpc_abort(txn[0])
             storage.sync()
-            data_info[checksum] -= 1
+            data_info[key] -= 1
             self.assertEqual(data_info, cluster.storage.getDataLockInfo())
         finally:
             cluster.stop()
...