Commit 94739085 authored by Arnaud Fontaine's avatar Arnaud Fontaine

py3: _mysql.string_literal() returns bytes().

And _mysql/mysqldb API (_mysql.connection.query()) converts the query string to
bytes() (additionally, cursor.execute(QUERY, ARGS) calls query() after
converting everything to bytes() too).
parent e8300329
This diff is collapsed.
...@@ -27,6 +27,8 @@ from __future__ import absolute_import ...@@ -27,6 +27,8 @@ from __future__ import absolute_import
# #
############################################################################## ##############################################################################
from Products.ERP5Type.Utils import str2bytes
from Shared.DC.ZRDB.Results import Results from Shared.DC.ZRDB.Results import Results
from Products.CMFActivity.ActivityTool import Message from Products.CMFActivity.ActivityTool import Message
import sys import sys
...@@ -85,7 +87,7 @@ class SQLDict(SQLBase): ...@@ -85,7 +87,7 @@ class SQLDict(SQLBase):
uid = line.uid uid = line.uid
original_uid = path_and_method_id_dict.get(key) original_uid = path_and_method_id_dict.get(key)
if original_uid is None: if original_uid is None:
sql_method_id = " AND method_id = %s AND group_method_id = %s" % ( sql_method_id = b" AND method_id = %s AND group_method_id = %s" % (
quote(method_id), quote(line.group_method_id)) quote(method_id), quote(line.group_method_id))
m = Message.load(line.message, uid=uid, line=line) m = Message.load(line.message, uid=uid, line=line)
merge_parent = m.activity_kw.get('merge_parent') merge_parent = m.activity_kw.get('merge_parent')
...@@ -102,11 +104,11 @@ class SQLDict(SQLBase): ...@@ -102,11 +104,11 @@ class SQLDict(SQLBase):
uid_list = [] uid_list = []
if path_list: if path_list:
# Select parent messages. # Select parent messages.
result = Results(db.query("SELECT * FROM message" result = Results(db.query(b"SELECT * FROM message"
" WHERE processing_node IN (0, %s) AND path IN (%s)%s" b" WHERE processing_node IN (0, %d) AND path IN (%s)%s"
" ORDER BY path LIMIT 1 FOR UPDATE" % ( b" ORDER BY path LIMIT 1 FOR UPDATE" % (
processing_node, processing_node,
','.join(map(quote, path_list)), b','.join(map(quote, path_list)),
sql_method_id, sql_method_id,
), 0)) ), 0))
if result: # found a parent if result: # found a parent
...@@ -119,11 +121,11 @@ class SQLDict(SQLBase): ...@@ -119,11 +121,11 @@ class SQLDict(SQLBase):
m = Message.load(line.message, uid=uid, line=line) m = Message.load(line.message, uid=uid, line=line)
# return unreserved similar children # return unreserved similar children
path = line.path path = line.path
result = db.query("SELECT uid FROM message" result = db.query(b"SELECT uid FROM message"
" WHERE processing_node = 0 AND (path = %s OR path LIKE %s)" b" WHERE processing_node = 0 AND (path = %s OR path LIKE %s)"
"%s FOR UPDATE" % ( b"%s FOR UPDATE" % (
quote(path), quote(path.replace('_', r'\_') + '/%'), quote(path), quote(path.replace('_', r'\_') + '/%'),
sql_method_id, str2bytes(sql_method_id),
), 0)[1] ), 0)[1]
reserve_uid_list = [x for x, in result] reserve_uid_list = [x for x, in result]
uid_list += reserve_uid_list uid_list += reserve_uid_list
...@@ -132,8 +134,8 @@ class SQLDict(SQLBase): ...@@ -132,8 +134,8 @@ class SQLDict(SQLBase):
reserve_uid_list.append(uid) reserve_uid_list.append(uid)
else: else:
# Select duplicates. # Select duplicates.
result = db.query("SELECT uid FROM message" result = db.query(b"SELECT uid FROM message"
" WHERE processing_node = 0 AND path = %s%s FOR UPDATE" % ( b" WHERE processing_node = 0 AND path = %s%s FOR UPDATE" % (
quote(path), sql_method_id, quote(path), sql_method_id,
), 0)[1] ), 0)[1]
reserve_uid_list = uid_list = [x for x, in result] reserve_uid_list = uid_list = [x for x, in result]
......
...@@ -1413,7 +1413,7 @@ class ActivityTool (BaseTool): ...@@ -1413,7 +1413,7 @@ class ActivityTool (BaseTool):
path = None if obj is None else '/'.join(obj.getPhysicalPath()) path = None if obj is None else '/'.join(obj.getPhysicalPath())
db = self.getSQLConnection() db = self.getSQLConnection()
quote = db.string_literal quote = db.string_literal
return bool(db.query("(%s)" % ") UNION ALL (".join( return bool(db.query(b"(%s)" % b") UNION ALL (".join(
activity.hasActivitySQL(quote, path=path, **kw) activity.hasActivitySQL(quote, path=path, **kw)
for activity in six.itervalues(activity_dict)))[1]) for activity in six.itervalues(activity_dict)))[1])
......
...@@ -111,6 +111,7 @@ from Shared.DC.ZRDB.TM import TM ...@@ -111,6 +111,7 @@ from Shared.DC.ZRDB.TM import TM
from DateTime import DateTime from DateTime import DateTime
from zLOG import LOG, ERROR, WARNING from zLOG import LOG, ERROR, WARNING
from ZODB.POSException import ConflictError from ZODB.POSException import ConflictError
from Products.ERP5Type.Utils import str2bytes
hosed_connection = ( hosed_connection = (
CR.SERVER_GONE_ERROR, CR.SERVER_GONE_ERROR,
...@@ -203,7 +204,7 @@ def ord_or_None(s): ...@@ -203,7 +204,7 @@ def ord_or_None(s):
return ord(s) return ord(s)
match_select = re.compile( match_select = re.compile(
r'(?:SET\s+STATEMENT\s+(.+?)\s+FOR\s+)?SELECT\s+(.+)', br'(?:SET\s+STATEMENT\s+(.+?)\s+FOR\s+)?SELECT\s+(.+)',
re.IGNORECASE | re.DOTALL, re.IGNORECASE | re.DOTALL,
).match ).match
...@@ -418,12 +419,14 @@ class DB(TM): ...@@ -418,12 +419,14 @@ class DB(TM):
"""Execute 'query_string' and return at most 'max_rows'.""" """Execute 'query_string' and return at most 'max_rows'."""
self._use_TM and self._register() self._use_TM and self._register()
desc = None desc = None
if not isinstance(query_string, bytes):
query_string = str2bytes(query_string)
# XXX deal with a typical mistake that the user appends # XXX deal with a typical mistake that the user appends
# an unnecessary and rather harmful semicolon at the end. # an unnecessary and rather harmful semicolon at the end.
# Unfortunately, MySQLdb does not want to be graceful. # Unfortunately, MySQLdb does not want to be graceful.
if query_string[-1:] == ';': if query_string[-1:] == b';':
query_string = query_string[:-1] query_string = query_string[:-1]
for qs in query_string.split('\0'): for qs in query_string.split(b'\0'):
qs = qs.strip() qs = qs.strip()
if qs: if qs:
select_match = match_select(qs) select_match = match_select(qs)
...@@ -432,12 +435,12 @@ class DB(TM): ...@@ -432,12 +435,12 @@ class DB(TM):
if query_timeout is not None: if query_timeout is not None:
statement, select = select_match.groups() statement, select = select_match.groups()
if statement: if statement:
statement += ", max_statement_time=%f" % query_timeout statement += b", max_statement_time=%f" % query_timeout
else: else:
statement = "max_statement_time=%f" % query_timeout statement = b"max_statement_time=%f" % query_timeout
qs = "SET STATEMENT %s FOR SELECT %s" % (statement, select) qs = b"SET STATEMENT %s FOR SELECT %s" % (statement, select)
if max_rows: if max_rows:
qs = "%s LIMIT %d" % (qs, max_rows) qs = b"%s LIMIT %d" % (qs, max_rows)
c = self._query(qs) c = self._query(qs)
if c: if c:
if desc is not None is not c.describe(): if desc is not None is not c.describe():
...@@ -640,7 +643,7 @@ class DeferredDB(DB): ...@@ -640,7 +643,7 @@ class DeferredDB(DB):
def query(self, query_string, max_rows=1000): def query(self, query_string, max_rows=1000):
self._register() self._register()
for qs in query_string.split('\0'): for qs in query_string.split(b'\0'):
qs = qs.strip() qs = qs.strip()
if qs: if qs:
if match_select(qs): if match_select(qs):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment