Commit 7b4f9991 authored by Arnaud Fontaine's avatar Arnaud Fontaine

py3: _mysql.string_literal() returns bytes().

And _mysql/mysqldb API (_mysql.connection.query()) converts the query string to
bytes() (additionally, cursor.execute(QUERY, ARGS) calls query() after
converting everything to bytes() too).

TODO: Instead of doing this ourself, should we use cursor.execute(QUERY, ARGS)
      which does conversion to bytes()? Performances?
parent 1af5baf2
...@@ -125,23 +125,23 @@ def sqltest_dict(): ...@@ -125,23 +125,23 @@ def sqltest_dict():
def _(name, column=None, op="="): def _(name, column=None, op="="):
if column is None: if column is None:
column = name column = name
column_op = "%s %s " % (column, op) column_op = ("%s %s " % (column, op)).encode()
def render(value, render_string): def render(value, render_string):
if isinstance(value, _SQLTEST_NO_QUOTE_TYPE_SET): if isinstance(value, _SQLTEST_NO_QUOTE_TYPE_SET):
return column_op + str(value) return column_op + str(value).encode()
if isinstance(value, DateTime): if isinstance(value, DateTime):
value = render_datetime(value) value = render_datetime(value)
if isinstance(value, basestring): if isinstance(value, basestring):
return column_op + render_string(value) return column_op + render_string(value)
assert op == "=", value assert op == "=", value
if value is None: # XXX: see comment in SQLBase._getMessageList if value is None: # XXX: see comment in SQLBase._getMessageList
return column + " IS NULL" return column + b" IS NULL"
for x in value: for x in value:
return "%s IN (%s)" % (column, ', '.join(map( return b"%s IN (%s)" % (column, ', '.join(map(
str if isinstance(x, _SQLTEST_NO_QUOTE_TYPE_SET) else str if isinstance(x, _SQLTEST_NO_QUOTE_TYPE_SET) else
render_datetime if isinstance(x, DateTime) else render_datetime if isinstance(x, DateTime) else
render_string, value))) render_string, value)).encode())
return "0" return b"0"
sqltest_dict[name] = render sqltest_dict[name] = render
_('active_process_uid') _('active_process_uid')
_('group_method_id') _('group_method_id')
...@@ -290,18 +290,18 @@ CREATE TABLE %s ( ...@@ -290,18 +290,18 @@ CREATE TABLE %s (
% (self.sql_table, src)) % (self.sql_table, src))
self._insert_max_payload = (db.getMaxAllowedPacket() self._insert_max_payload = (db.getMaxAllowedPacket()
+ len(self._insert_separator) + len(self._insert_separator)
- len(self._insert_template % (self.sql_table, ''))) - len(self._insert_template % (self.sql_table.encode(), b'')))
def _initialize(self, db, column_list): def _initialize(self, db, column_list):
LOG('CMFActivity', ERROR, "Non-empty %r table upgraded." LOG('CMFActivity', ERROR, "Non-empty %r table upgraded."
" The following added columns could not be initialized: %s" " The following added columns could not be initialized: %s"
% (self.sql_table, ", ".join(column_list))) % (self.sql_table, ", ".join(column_list)))
_insert_template = ("INSERT INTO %s (uid," _insert_template = (b"INSERT INTO %s (uid,"
" path, active_process_uid, date, method_id, processing_node," b" path, active_process_uid, date, method_id, processing_node,"
" priority, node, group_method_id, tag, serialization_tag," b" priority, node, group_method_id, tag, serialization_tag,"
" message) VALUES\n(%s)") b" message) VALUES\n(%s)")
_insert_separator = "),\n(" _insert_separator = b"),\n("
def _hasDependency(self, message): def _hasDependency(self, message):
get = message.activity_kw.get get = message.activity_kw.get
...@@ -320,9 +320,9 @@ CREATE TABLE %s ( ...@@ -320,9 +320,9 @@ CREATE TABLE %s (
if reset_uid: if reset_uid:
reset_uid = False reset_uid = False
# Overflow will result into IntegrityError. # Overflow will result into IntegrityError.
db.query("SET @uid := %s" % getrandbits(UID_SAFE_BITSIZE)) db.query(b"SET @uid := %d" % getrandbits(UID_SAFE_BITSIZE))
try: try:
db.query(self._insert_template % (self.sql_table, values)) db.query(self._insert_template % (self.sql_table.encode(), values))
except MySQLdb.IntegrityError as xxx_todo_changeme: except MySQLdb.IntegrityError as xxx_todo_changeme:
(code, _) = xxx_todo_changeme.args (code, _) = xxx_todo_changeme.args
if code != DUP_ENTRY: if code != DUP_ENTRY:
...@@ -342,18 +342,18 @@ CREATE TABLE %s ( ...@@ -342,18 +342,18 @@ CREATE TABLE %s (
if m.is_registered: if m.is_registered:
active_process_uid = m.active_process_uid active_process_uid = m.active_process_uid
date = m.activity_kw.get('at_date') date = m.activity_kw.get('at_date')
row = ','.join(( row = b','.join((
'@uid+%s' % i, b'@uid+%d' % i,
quote('/'.join(m.object_path)), quote('/'.join(m.object_path)),
'NULL' if active_process_uid is None else str(active_process_uid), b'NULL' if active_process_uid is None else str(active_process_uid).encode(),
"UTC_TIMESTAMP(6)" if date is None else quote(render_datetime(date)), b"UTC_TIMESTAMP(6)" if date is None else quote(render_datetime(date)),
quote(m.method_id), quote(m.method_id),
'-1' if hasDependency(m) else '0', b'-1' if hasDependency(m) else b'0',
str(m.activity_kw.get('priority', 1)), str(m.activity_kw.get('priority', 1)).encode(),
str(m.activity_kw.get('node', 0)), str(m.activity_kw.get('node', 0)).encode(),
quote(m.getGroupId()), quote(m.getGroupId()),
quote(m.activity_kw.get('tag', '')), quote(m.activity_kw.get('tag', b'')),
quote(m.activity_kw.get('serialization_tag', '')), quote(m.activity_kw.get('serialization_tag', b'')),
quote(Message.dump(m)))) quote(Message.dump(m))))
i += 1 i += 1
n = sep_len + len(row) n = sep_len + len(row)
...@@ -374,11 +374,11 @@ CREATE TABLE %s ( ...@@ -374,11 +374,11 @@ CREATE TABLE %s (
# value should be ignored, instead of trying to render them # value should be ignored, instead of trying to render them
# (with comparisons with NULL). # (with comparisons with NULL).
q = db.string_literal q = db.string_literal
sql = '\n AND '.join(sqltest_dict[k](v, q) for k, v in kw.items()) sql = b'\n AND '.join(sqltest_dict[k](v, q) for k, v in kw.items())
sql = "SELECT * FROM %s%s\nORDER BY priority, date, uid%s" % ( sql = b"SELECT * FROM %s%s\nORDER BY priority, date, uid%s" % (
self.sql_table, self.sql_table.encode(),
sql and '\nWHERE ' + sql, sql and b'\nWHERE ' + sql,
'' if count is None else '\nLIMIT %d' % count, b'' if count is None else b'\nLIMIT %d' % count,
) )
return sql if src__ else Results(db.query(sql, max_rows=0)) return sql if src__ else Results(db.query(sql, max_rows=0))
......
...@@ -203,7 +203,7 @@ def ord_or_None(s): ...@@ -203,7 +203,7 @@ def ord_or_None(s):
return ord(s) return ord(s)
match_select = re.compile( match_select = re.compile(
r'(?:SET\s+STATEMENT\s+(.+?)\s+FOR\s+)?SELECT\s+(.+)', rb'(?:SET\s+STATEMENT\s+(.+?)\s+FOR\s+)?SELECT\s+(.+)',
re.IGNORECASE | re.DOTALL, re.IGNORECASE | re.DOTALL,
).match ).match
...@@ -418,12 +418,14 @@ class DB(TM): ...@@ -418,12 +418,14 @@ class DB(TM):
"""Execute 'query_string' and return at most 'max_rows'.""" """Execute 'query_string' and return at most 'max_rows'."""
self._use_TM and self._register() self._use_TM and self._register()
desc = None desc = None
if not isinstance(query_string, bytes):
query_string = query_string.encode()
# XXX deal with a typical mistake that the user appends # XXX deal with a typical mistake that the user appends
# an unnecessary and rather harmful semicolon at the end. # an unnecessary and rather harmful semicolon at the end.
# Unfortunately, MySQLdb does not want to be graceful. # Unfortunately, MySQLdb does not want to be graceful.
if query_string[-1:] == ';': if query_string[-1:] == b';':
query_string = query_string[:-1] query_string = query_string[:-1]
for qs in query_string.split('\0'): for qs in query_string.split(b'\0'):
qs = qs.strip() qs = qs.strip()
if qs: if qs:
select_match = match_select(qs) select_match = match_select(qs)
...@@ -432,12 +434,12 @@ class DB(TM): ...@@ -432,12 +434,12 @@ class DB(TM):
if query_timeout is not None: if query_timeout is not None:
statement, select = select_match.groups() statement, select = select_match.groups()
if statement: if statement:
statement += ", max_statement_time=%f" % query_timeout statement += b", max_statement_time=%f" % query_timeout
else: else:
statement = "max_statement_time=%f" % query_timeout statement = b"max_statement_time=%f" % query_timeout
qs = "SET STATEMENT %s FOR SELECT %s" % (statement, select) qs = b"SET STATEMENT %s FOR SELECT %s" % (statement, select)
if max_rows: if max_rows:
qs = "%s LIMIT %d" % (qs, max_rows) qs = b"%s LIMIT %d" % (qs, max_rows)
c = self._query(qs) c = self._query(qs)
if c: if c:
if desc is not None is not c.describe(): if desc is not None is not c.describe():
...@@ -457,7 +459,7 @@ class DB(TM): ...@@ -457,7 +459,7 @@ class DB(TM):
return items, result return items, result
def string_literal(self, s): def string_literal(self, s):
return self.db.string_literal(s).decode('utf-8') return self.db.string_literal(s)
def _begin(self, *ignored): def _begin(self, *ignored):
"""Begin a transaction (when TM is enabled).""" """Begin a transaction (when TM is enabled)."""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment