Commit a4d15b91 authored by Yoshinori Okuji's avatar Yoshinori Okuji

Fix arguments to handleAnswerTIDs. Change the MySQL database to use BIGINT...

Fix arguments to handleAnswerTIDs. Change the MySQL database to use BIGINT rather than BINARY for oid, serial and tid.

git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@165 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 99d73db1
...@@ -238,7 +238,7 @@ class EventHandler(object): ...@@ -238,7 +238,7 @@ class EventHandler(object):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
def handleAnswerTIDs(self, conn, packet, tid_list): def handleAnswerTIDs(self, conn, packet, tid_list):
self.handleUnexpectedPacket(conn, conn, packet, packet) self.handleUnexpectedPacket(conn, packet)
def handleAskTransactionInformation(self, conn, packet, tid): def handleAskTransactionInformation(self, conn, packet, tid):
self.handleUnexpectedPacket(conn, packet) self.handleUnexpectedPacket(conn, packet)
......
...@@ -4,12 +4,19 @@ from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST ...@@ -4,12 +4,19 @@ from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST
import logging import logging
from array import array from array import array
import string import string
from struct import pack, unpack
from neo.storage.database import DatabaseManager from neo.storage.database import DatabaseManager
from neo.exception import DatabaseFailure from neo.exception import DatabaseFailure
from neo.util import dump from neo.util import dump
from neo.protocol import DISCARDED_STATE from neo.protocol import DISCARDED_STATE
def p64(n):
return pack('!Q', n)
def u64(s):
return unpack('!Q', s)[0]
class MySQLDatabaseManager(DatabaseManager): class MySQLDatabaseManager(DatabaseManager):
"""This class manages a database on MySQL.""" """This class manages a database on MySQL."""
...@@ -54,7 +61,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -54,7 +61,7 @@ class MySQLDatabaseManager(DatabaseManager):
conn = self.conn conn = self.conn
try: try:
printable_char_list = [] printable_char_list = []
for c in query.split('\n', 1)[0]: for c in query.split('\n', 1)[0][:70]:
if c not in string.printable or c in '\t\x0b\x0c\r': if c not in string.printable or c in '\t\x0b\x0c\r':
c = '\\x%02x' % ord(c) c = '\\x%02x' % ord(c)
printable_char_list.append(c) printable_char_list.append(c)
...@@ -109,7 +116,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -109,7 +116,7 @@ class MySQLDatabaseManager(DatabaseManager):
# The table "trans" stores information on committed transactions. # The table "trans" stores information on committed transactions.
q("""CREATE TABLE IF NOT EXISTS trans ( q("""CREATE TABLE IF NOT EXISTS trans (
tid BINARY(8) NOT NULL PRIMARY KEY, tid BIGINT UNSIGNED NOT NULL PRIMARY KEY,
oids MEDIUMBLOB NOT NULL, oids MEDIUMBLOB NOT NULL,
user BLOB NOT NULL, user BLOB NOT NULL,
description BLOB NOT NULL, description BLOB NOT NULL,
...@@ -118,8 +125,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -118,8 +125,8 @@ class MySQLDatabaseManager(DatabaseManager):
# The table "obj" stores committed object data. # The table "obj" stores committed object data.
q("""CREATE TABLE IF NOT EXISTS obj ( q("""CREATE TABLE IF NOT EXISTS obj (
oid BINARY(8) NOT NULL, oid BIGINT UNSIGNED NOT NULL,
serial BINARY(8) NOT NULL, serial BIGINT UNSIGNED NOT NULL,
compression TINYINT UNSIGNED NOT NULL, compression TINYINT UNSIGNED NOT NULL,
checksum INT UNSIGNED NOT NULL, checksum INT UNSIGNED NOT NULL,
value MEDIUMBLOB NOT NULL, value MEDIUMBLOB NOT NULL,
...@@ -128,7 +135,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -128,7 +135,7 @@ class MySQLDatabaseManager(DatabaseManager):
# The table "ttrans" stores information on uncommitted transactions. # The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans ( q("""CREATE TABLE IF NOT EXISTS ttrans (
tid BINARY(8) NOT NULL, tid BIGINT UNSIGNED NOT NULL,
oids MEDIUMBLOB NOT NULL, oids MEDIUMBLOB NOT NULL,
user BLOB NOT NULL, user BLOB NOT NULL,
description BLOB NOT NULL, description BLOB NOT NULL,
...@@ -137,8 +144,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -137,8 +144,8 @@ class MySQLDatabaseManager(DatabaseManager):
# The table "tobj" stores uncommitted object data. # The table "tobj" stores uncommitted object data.
q("""CREATE TABLE IF NOT EXISTS tobj ( q("""CREATE TABLE IF NOT EXISTS tobj (
oid BINARY(8) NOT NULL, oid BIGINT UNSIGNED NOT NULL,
serial BINARY(8) NOT NULL, serial BIGINT UNSIGNED NOT NULL,
compression TINYINT UNSIGNED NOT NULL, compression TINYINT UNSIGNED NOT NULL,
checksum INT UNSIGNED NOT NULL, checksum INT UNSIGNED NOT NULL,
value MEDIUMBLOB NOT NULL value MEDIUMBLOB NOT NULL
...@@ -224,6 +231,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -224,6 +231,8 @@ class MySQLDatabaseManager(DatabaseManager):
if loid is None or (tmp_loid is not None and loid < tmp_loid): if loid is None or (tmp_loid is not None and loid < tmp_loid):
loid = tmp_loid loid = tmp_loid
self.commit() self.commit()
if loid is not None:
loid = p64(loid)
return loid return loid
def getLastTID(self, all = True): def getLastTID(self, all = True):
...@@ -247,6 +256,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -247,6 +256,8 @@ class MySQLDatabaseManager(DatabaseManager):
if ltid is None or (tmp_serial is not None and ltid < tmp_serial): if ltid is None or (tmp_serial is not None and ltid < tmp_serial):
ltid = tmp_serial ltid = tmp_serial
self.commit() self.commit()
if ltid is not None:
ltid = p64(ltid)
return ltid return ltid
def getUnfinishedTIDList(self): def getUnfinishedTIDList(self):
...@@ -254,22 +265,21 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -254,22 +265,21 @@ class MySQLDatabaseManager(DatabaseManager):
tid_set = set() tid_set = set()
self.begin() self.begin()
r = q("""SELECT tid FROM ttrans""") r = q("""SELECT tid FROM ttrans""")
tid_set.update((t[0] for t in r)) tid_set.update((p64(t[0]) for t in r))
r = q("""SELECT serial FROM tobj""") r = q("""SELECT serial FROM tobj""")
self.commit() self.commit()
tid_set.update((t[0] for t in r)) tid_set.update((p64(t[0]) for t in r))
return list(tid_set) return list(tid_set)
def objectPresent(self, oid, tid, all = True): def objectPresent(self, oid, tid, all = True):
q = self.query q = self.query
e = self.escape oid = u64(oid)
oid = e(oid) tid = u64(tid)
tid = e(tid)
self.begin() self.begin()
r = q("""SELECT oid FROM obj WHERE oid = '%s' AND serial = '%s'""" \ r = q("""SELECT oid FROM obj WHERE oid = %d AND serial = %d""" \
% (oid, tid)) % (oid, tid))
if not r and all: if not r and all:
r = q("""SELECT oid FROM tobj WHERE oid = '%s' AND serial = '%s'""" \ r = q("""SELECT oid FROM tobj WHERE oid = %d AND serial = %d""" \
% (oid, tid)) % (oid, tid))
self.commit() self.commit()
if r: if r:
...@@ -278,12 +288,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -278,12 +288,11 @@ class MySQLDatabaseManager(DatabaseManager):
def getObject(self, oid, tid = None, before_tid = None): def getObject(self, oid, tid = None, before_tid = None):
q = self.query q = self.query
e = self.escape oid = u64(oid)
oid = e(oid)
if tid is not None: if tid is not None:
tid = e(tid) tid = u64(tid)
r = q("""SELECT serial, compression, checksum, value FROM obj r = q("""SELECT serial, compression, checksum, value FROM obj
WHERE oid = '%s' AND serial = '%s'""" \ WHERE oid = %d AND serial = %d""" \
% (oid, tid)) % (oid, tid))
try: try:
serial, compression, checksum, data = r[0] serial, compression, checksum, data = r[0]
...@@ -291,15 +300,15 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -291,15 +300,15 @@ class MySQLDatabaseManager(DatabaseManager):
except IndexError: except IndexError:
return None return None
elif before_tid is not None: elif before_tid is not None:
before_tid = e(before_tid) before_tid = u64(before_tid)
r = q("""SELECT serial, compression, checksum, value FROM obj r = q("""SELECT serial, compression, checksum, value FROM obj
WHERE oid = '%s' AND serial < '%s' WHERE oid = %d AND serial < %d
ORDER BY serial DESC LIMIT 1""" \ ORDER BY serial DESC LIMIT 1""" \
% (oid, before_tid)) % (oid, before_tid))
try: try:
serial, compression, checksum, data = r[0] serial, compression, checksum, data = r[0]
r = q("""SELECT serial FROM obj r = q("""SELECT serial FROM obj
WHERE oid = '%s' AND serial > '%s' WHERE oid = %d AND serial > %d
ORDER BY serial LIMIT 1""" \ ORDER BY serial LIMIT 1""" \
% (oid, before_tid)) % (oid, before_tid))
try: try:
...@@ -312,7 +321,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -312,7 +321,7 @@ class MySQLDatabaseManager(DatabaseManager):
# XXX I want to express "HAVING serial = MAX(serial)", but # XXX I want to express "HAVING serial = MAX(serial)", but
# MySQL does not use an index for a HAVING clause! # MySQL does not use an index for a HAVING clause!
r = q("""SELECT serial, compression, checksum, value FROM obj r = q("""SELECT serial, compression, checksum, value FROM obj
WHERE oid = '%s' ORDER BY serial DESC LIMIT 1""" \ WHERE oid = %d ORDER BY serial DESC LIMIT 1""" \
% oid) % oid)
try: try:
serial, compression, checksum, data = r[0] serial, compression, checksum, data = r[0]
...@@ -320,6 +329,10 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -320,6 +329,10 @@ class MySQLDatabaseManager(DatabaseManager):
except IndexError: except IndexError:
return None return None
if serial is not None:
serial = p64(serial)
if next_serial is not None:
next_serial = p64(next_serial)
return serial, next_serial, compression, checksum, data return serial, next_serial, compression, checksum, data
def doSetPartitionTable(self, ptid, cell_list, reset): def doSetPartitionTable(self, ptid, cell_list, reset):
...@@ -365,7 +378,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -365,7 +378,7 @@ class MySQLDatabaseManager(DatabaseManager):
def storeTransaction(self, tid, object_list, transaction): def storeTransaction(self, tid, object_list, transaction):
q = self.query q = self.query
e = self.escape e = self.escape
tid = e(tid) tid = u64(tid)
self.begin() self.begin()
try: try:
# XXX it might be more efficient to insert multiple objects # XXX it might be more efficient to insert multiple objects
...@@ -375,9 +388,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -375,9 +388,9 @@ class MySQLDatabaseManager(DatabaseManager):
# tobj has no index, so inserting one by one should not be # tobj has no index, so inserting one by one should not be
# significantly different from inserting many at a time. # significantly different from inserting many at a time.
for oid, compression, checksum, data in object_list: for oid, compression, checksum, data in object_list:
oid = e(oid) oid = u64(oid)
data = e(data) data = e(data)
q("""INSERT INTO tobj VALUES ('%s', '%s', %d, %d, '%s')""" \ q("""INSERT INTO tobj VALUES (%d, %d, %d, %d, '%s')""" \
% (oid, tid, compression, checksum, data)) % (oid, tid, compression, checksum, data))
if transaction is not None: if transaction is not None:
oid_list, user, desc, ext = transaction oid_list, user, desc, ext = transaction
...@@ -385,7 +398,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -385,7 +398,7 @@ class MySQLDatabaseManager(DatabaseManager):
user = e(user) user = e(user)
desc = e(desc) desc = e(desc)
ext = e(ext) ext = e(ext)
q("""INSERT INTO ttrans VALUES ('%s', '%s', '%s', '%s', '%s')""" \ q("""INSERT INTO ttrans VALUES (%d, '%s', '%s', '%s', '%s')""" \
% (tid, oids, user, desc, ext)) % (tid, oids, user, desc, ext))
except: except:
self.rollback() self.rollback()
...@@ -394,16 +407,15 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -394,16 +407,15 @@ class MySQLDatabaseManager(DatabaseManager):
def finishTransaction(self, tid): def finishTransaction(self, tid):
q = self.query q = self.query
e = self.escape tid = u64(tid)
tid = e(tid)
self.begin() self.begin()
try: try:
q("""INSERT INTO obj SELECT * FROM tobj WHERE tobj.serial = '%s'""" \ q("""INSERT INTO obj SELECT * FROM tobj WHERE tobj.serial = %d""" \
% tid) % tid)
q("""DELETE FROM tobj WHERE serial = '%s'""" % tid) q("""DELETE FROM tobj WHERE serial = %d""" % tid)
q("""INSERT INTO trans SELECT * FROM ttrans WHERE ttrans.tid = '%s'""" \ q("""INSERT INTO trans SELECT * FROM ttrans WHERE ttrans.tid = %d""" \
% tid) % tid)
q("""DELETE FROM ttrans WHERE tid = '%s'""" % tid) q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
except: except:
self.rollback() self.rollback()
raise raise
...@@ -411,16 +423,15 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -411,16 +423,15 @@ class MySQLDatabaseManager(DatabaseManager):
def deleteTransaction(self, tid, all = False): def deleteTransaction(self, tid, all = False):
q = self.query q = self.query
e = self.escape tid = u64(tid)
tid = e(tid)
self.begin() self.begin()
try: try:
q("""DELETE FROM tobj WHERE serial = '%s'""" % tid) q("""DELETE FROM tobj WHERE serial = %d""" % tid)
q("""DELETE FROM ttrans WHERE tid = '%s'""" % tid) q("""DELETE FROM ttrans WHERE tid = %d""" % tid)
if all: if all:
# Note that this can be very slow. # Note that this can be very slow.
q("""DELETE FROM obj WHERE serial = '%s'""" % tid) q("""DELETE FROM obj WHERE serial = %d""" % tid)
q("""DELETE FROM trans WHERE tid = '%s'""" % tid) q("""DELETE FROM trans WHERE tid = %d""" % tid)
except: except:
self.rollback() self.rollback()
raise raise
...@@ -428,21 +439,20 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -428,21 +439,20 @@ class MySQLDatabaseManager(DatabaseManager):
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
q = self.query q = self.query
e = self.escape tid = u64(tid)
tid = e(tid)
self.begin() self.begin()
r = q("""SELECT oids, user, description, ext FROM trans r = q("""SELECT oids, user, description, ext FROM trans
WHERE tid = '%s'""" \ WHERE tid = %d""" \
% tid) % tid)
if not r and all: if not r and all:
r = q("""SELECT oids, user, description, ext FROM ttrans r = q("""SELECT oids, user, description, ext FROM ttrans
WHERE tid = '%s'""" \ WHERE tid = %d""" \
% tid) % tid)
self.commit() self.commit()
if r: if r:
oids, user, desc, ext = r[0] oids, user, desc, ext = r[0]
if (len(oids) % 8) != 0 or len(oids) == 0: if (len(oids) % 8) != 0 or len(oids) == 0:
raise DatabaseFailure('invalid oids for tid %s' % dump(tid)) raise DatabaseFailure('invalid oids for tid %x' % tid)
oid_list = [] oid_list = []
for i in xrange(0, len(oids), 8): for i in xrange(0, len(oids), 8):
oid_list.append(oids[i:i+8]) oid_list.append(oids[i:i+8])
...@@ -451,19 +461,19 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -451,19 +461,19 @@ class MySQLDatabaseManager(DatabaseManager):
def getObjectHistory(self, oid, length = 1): def getObjectHistory(self, oid, length = 1):
q = self.query q = self.query
e = self.escape oid = u64(oid)
oid = e(oid) r = q("""SELECT serial, LENGTH(value) FROM obj WHERE oid = %d
r = q("""SELECT serial, LENGTH(value) FROM obj WHERE oid = '%s'
ORDER BY serial DESC LIMIT %d""" \ ORDER BY serial DESC LIMIT %d""" \
% (oid, length)) % (oid, length))
if r: if r:
return r return [(p64(serial), length) for serial, length in r]
return None return None
def getTIDList(self, offset, length, num_partitions, partition_list): def getTIDList(self, offset, length, num_partitions, partition_list):
q = self.query q = self.query
e = self.escape
r = q("""SELECT tid FROM trans WHERE MOD(tid,%d) in (%s) r = q("""SELECT tid FROM trans WHERE MOD(tid,%d) in (%s)
ORDER BY tid DESC LIMIT %d""" \ ORDER BY tid DESC LIMIT %d,%d""" \
% (offset, num_partitions, ','.join(partition_list), length)) % (num_partitions,
return [t[0] for t in r] ','.join([str(p) for p in partition_list]),
offset, length))
return [p64(t[0]) for t in r]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment