Commit 8b43f516 authored by Jérome Perrin's avatar Jérome Perrin

upgradeSchema bug fixes

Add some tests and fix bugs:
* https://nexedi.erp5.net/bug_module/20170426-A3962E
* another bug that columns names were not escaped ( in a project we have a custom table with a column named `use` and this breaks `upgradeSchema`  )

/reviewed-on !854
parents dacf12bb 0016855f
......@@ -42,6 +42,7 @@ from Acquisition import aq_base, aq_inner, aq_parent, ImplicitAcquisitionWrapper
from Products.CMFActivity.ActiveObject import ActiveObject
from Products.CMFActivity.ActivityTool import GroupedMessage
from Products.ERP5Type.TransactionalVariable import getTransactionalVariable
from Products.ZMySQLDA.DA import DeferredConnection
from AccessControl.PermissionRole import rolesForPermissionOn
......@@ -1361,20 +1362,45 @@ class CatalogTool (UniqueObject, ZCatalog, CMFCoreCatalogTool, ActiveObject):
security.declareProtected(Permissions.ManagePortal, 'upgradeSchema')
def upgradeSchema(self, sql_catalog_id=None, src__=0):
"""Upgrade all catalog tables, with ALTER or CREATE queries"""
portal = self.getPortalObject()
catalog = self.getSQLCatalog(sql_catalog_id)
connection_id = catalog.z_create_catalog.connection_id
src = []
db = self.getPortalObject()[connection_id]()
with db.lock():
for clear_method in catalog.sql_clear_catalog:
r = catalog[clear_method]._upgradeSchema(
connection_id, create_if_not_exists=1, src__=1)
if r:
src.append(r)
if not src__:
for r in src:
db.query(r)
return src
# group methods by connection
method_list_by_connection_id = defaultdict(list)
for method_id in catalog.sql_clear_catalog:
method = catalog[method_id]
method_list_by_connection_id[method.connection_id].append(method)
# Because we cannot select on deferred connections, _upgradeSchema
# cannot be used on SQL methods using a deferred connection.
# We try to find a "non deferred" connection using the same connection
# string and we'll use it instead.
connection_by_connection_id = {}
for connection_id in method_list_by_connection_id:
connection = portal[connection_id]
connection_string = connection.connection_string
connection_by_connection_id[connection_id] = connection
if isinstance(connection, DeferredConnection):
for other_connection in portal.objectValues(
spec=('Z MySQL Database Connection',)):
if connection_string == other_connection.connection_string:
connection_by_connection_id[connection_id] = other_connection
break
queries_by_connection_id = defaultdict(list)
for connection_id, method_list in method_list_by_connection_id.items():
connection = connection_by_connection_id[connection_id]
db = connection()
with db.lock():
for method in method_list:
query = method._upgradeSchema(connection.getId(), create_if_not_exists=1, src__=1)
if query:
queries_by_connection_id[connection_id].append(query)
if not src__:
for query in queries_by_connection_id[connection_id]:
db.query(query)
return sum(queries_by_connection_id.values(), [])
security.declarePublic('getDocumentValueList')
def getDocumentValueList(self, sql_catalog_id=None,
......
......@@ -34,6 +34,7 @@ import httplib
from AccessControl import getSecurityManager
from AccessControl.SecurityManagement import newSecurityManager
from DateTime import DateTime
from _mysql_exceptions import ProgrammingError
from OFS.ObjectManager import ObjectManager
from Products.ERP5Type.tests.ERP5TypeTestCase import ERP5TypeTestCase
from Products.ERP5Type.tests.utils import LogInterceptor, createZODBPythonScript, todo_erp5, getExtraSqlConnectionStringList
......@@ -3830,7 +3831,135 @@ VALUES
# but a proper page
self.assertIn('<title>Catalog Tool - portal_catalog', ret.getBody())
def test_suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(TestERP5Catalog))
return suite
class CatalogToolUpgradeSchemaTestCase(ERP5TypeTestCase):
"""Tests for "upgrade schema" feature of ERP5 Catalog.
"""
def getBusinessTemplateList(self):
return ("erp5_full_text_mroonga_catalog",)
def afterSetUp(self):
# Add two connections
db1, db2 = getExtraSqlConnectionStringList()[:2]
addConnection = self.portal.manage_addProduct[
"ZMySQLDA"].manage_addZMySQLConnection
addConnection("erp5_test_connection_1", "", db1)
addConnection("erp5_test_connection_2", "", db2)
addConnection("erp5_test_connection_deferred_2", "", db2, deferred=True)
self.catalog_tool = self.portal.portal_catalog
self.catalog = self.catalog_tool.newContent(portal_type="Catalog")
self.catalog.newContent(
portal_type="SQL Method",
connection_id="erp5_test_connection_1",
id="z_create_catalog",
src="CREATE TABLE dummy_catalog (uid int)")
# These will be cleaned up at tear down
self._db1_table_list = ["dummy_catalog"]
self._db2_table_list = []
def beforeTearDown(self):
for table in self._db1_table_list:
self.query_connection_1("DROP TABLE IF EXISTS `%s`" % table)
for table in self._db2_table_list:
self.query_connection_2("DROP TABLE IF EXISTS `%s`" % table)
self.portal.manage_delObjects([
"erp5_test_connection_1",
"erp5_test_connection_2",
"erp5_test_connection_deferred_2"])
self.commit()
def query_connection_1(self, q):
return self.portal.erp5_test_connection_1().query(q)
def query_connection_2(self, q):
return self.portal.erp5_test_connection_2().query(q)
def upgradeSchema(self):
self.assertTrue(
self.catalog_tool.upgradeSchema(
sql_catalog_id=self.catalog.getId(), src__=True))
self.catalog_tool.upgradeSchema(sql_catalog_id=self.catalog.getId())
self.assertFalse(
self.catalog_tool.upgradeSchema(
sql_catalog_id=self.catalog.getId(), src__=True))
def test_upgradeSchema_add_table(self):
self._db1_table_list.append("add_table")
method = self.catalog.newContent(
portal_type="SQL Method",
connection_id="erp5_test_connection_1",
id=self.id(),
src="CREATE TABLE add_table (a int)")
self.catalog.setSqlClearCatalogList([method.getId()])
self.commit()
self.upgradeSchema()
self.commit()
self.query_connection_1("SELECT a from add_table")
def test_upgradeSchema_alter_table(self):
self._db1_table_list.append("altered_table")
self.query_connection_1("CREATE TABLE altered_table (a int)")
self.commit()
method = self.catalog.newContent(
portal_type="SQL Method",
connection_id="erp5_test_connection_1",
id=self.id(),
src="CREATE TABLE altered_table (a int, b int)")
self.catalog.setSqlClearCatalogList([method.getId()])
self.commit()
self.upgradeSchema()
self.commit()
self.query_connection_1("SELECT b from altered_table")
def test_upgradeSchema_multi_connections(self):
# Check that we can upgrade tables on more than one connection,
# like when using an external datawarehouse. This is a reproduction
# for https://nexedi.erp5.net/bug_module/20170426-A3962E
# In this test we use both "normal" and deferred connections,
# which is what happens in default erp5 catalog.
self._db1_table_list.append("table1")
self.query_connection_1("CREATE TABLE table1 (a int)")
self._db2_table_list.extend(("table2", "table_deferred2"))
self.query_connection_2("CREATE TABLE table2 (a int)")
self.query_connection_2("CREATE TABLE table_deferred2 (a int)")
self.commit()
method1 = self.catalog.newContent(
portal_type="SQL Method",
connection_id="erp5_test_connection_1",
src="CREATE TABLE table1 (a int, b int)")
method2 = self.catalog.newContent(
portal_type="SQL Method",
connection_id="erp5_test_connection_2",
src="CREATE TABLE table2 (a int, b int)")
method_deferred2 = self.catalog.newContent(
portal_type="SQL Method",
connection_id="erp5_test_connection_deferred_2",
src="CREATE TABLE table_deferred2 (a int, b int)")
self.catalog.setSqlClearCatalogList(
[method1.getId(),
method2.getId(),
method_deferred2.getId()])
self.commit()
self.upgradeSchema()
self.commit()
self.query_connection_1("SELECT b from table1")
self.query_connection_2("SELECT b from table2")
self.query_connection_2("SELECT b from table_deferred2")
with self.assertRaisesRegexp(ProgrammingError,
r"Table '.*\.table2' doesn't exist"):
self.query_connection_1("SELECT b from table2")
with self.assertRaisesRegexp(ProgrammingError,
r"Table '.*\.table_deferred2' doesn't exist"):
self.query_connection_1("SELECT b from table_deferred2")
with self.assertRaisesRegexp(ProgrammingError,
r"Table '.*\.table1' doesn't exist"):
self.query_connection_2("SELECT b from table1")
......@@ -530,7 +530,7 @@ class DB(TM):
# already done it (in case that it plans to execute the returned query).
with (nested if src__ else self.lock)():
try:
old_list, old_set, old_default = self._getTableSchema(name)
old_list, old_set, old_default = self._getTableSchema("`%s`" % name)
except ProgrammingError, e:
if e[0] != ER.NO_SUCH_TABLE or not create_if_not_exists:
raise
......@@ -538,7 +538,7 @@ class DB(TM):
self.query(create_sql)
return create_sql
name_new = '_%s_new' % name
name_new = '`_%s_new`' % name
self.query('CREATE TEMPORARY TABLE %s %s'
% (name_new, create_sql[m.end():]))
try:
......@@ -559,7 +559,7 @@ class DB(TM):
old_dict[column] = pos, spec
pos += 1
else:
q("DROP COLUMN " + column)
q("DROP COLUMN `%s`" % column)
for key in old_set - new_set:
if "PRIMARY" in key:
......@@ -574,26 +574,26 @@ class DB(TM):
try:
old = old_dict[column]
except KeyError:
q("ADD COLUMN %s %s %s" % (column, spec, where))
q("ADD COLUMN `%s` %s %s" % (column, spec, where))
column_list.append(column)
else:
if old != (pos, spec):
q("MODIFY COLUMN %s %s %s" % (column, spec, where))
q("MODIFY COLUMN `%s` %s %s" % (column, spec, where))
if old[1] != spec:
column_list.append(column)
pos += 1
where = "AFTER " + column
where = "AFTER `%s`" % column
for key in new_set - old_set:
q("ADD " + key)
if src:
src = "ALTER TABLE %s%s" % (name, ','.join("\n " + q
src = "ALTER TABLE `%s`%s" % (name, ','.join("\n " + q
for q in src))
if not src__:
self.query(src)
if column_list and initialize and self.query(
"SELECT 1 FROM " + name, 1)[1]:
"SELECT 1 FROM `%s`" % name, 1)[1]:
initialize(self, column_list)
return src
......
##############################################################################
# coding: utf-8
# Copyright (c) 2019 Nexedi SA and Contributors. All Rights Reserved.
# Jérome Perrin <jerome@nexedi.com>
#
# WARNING: This program as such is intended to be used by professional
# programmers who take the whole responsability of assessing all potential
# consequences resulting from its eventual inadequacies and bugs
# End users who are looking for a ready-to-use solution with commercial
# garantees and support are strongly adviced to contract a Free Software
# Service Company
#
# This program is Free Software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
#
##############################################################################
from textwrap import dedent
from _mysql_exceptions import OperationalError
from Shared.DC.ZRDB.DA import DA
from Products.ERP5Type.tests.ERP5TypeTestCase import ERP5TypeTestCase
class TestTableStructureMigrationTestCase(ERP5TypeTestCase):
def getBusinessTemplateList(self):
return 'erp5_full_text_mroonga_catalog',
def beforeTearDown(self):
self.portal.erp5_sql_connection().query('DROP table if exists X')
self.portal.erp5_sql_connection().query('DROP table if exists `table`')
self.commit()
def query(self, q):
return self.portal.erp5_sql_connection().query(q)
def check_upgrade_schema(self, previous_schema, new_schema, table_name='X'):
self.query(previous_schema)
da = DA(
id=self.id(),
title=self.id(),
connection_id=self.portal.erp5_sql_connection.getId(),
arguments=(),
template=new_schema).__of__(self.portal)
self.assertTrue(da._upgradeSchema(src__=True))
da._upgradeSchema()
self.assertFalse(da._upgradeSchema(src__=True))
self.assertEqual(
new_schema,
self.query('SHOW CREATE TABLE `%s`' % table_name)[1][0][1])
def test_add_column(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL,
`b` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
self.query("SELECT a, b FROM X")
def test_remove_column(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL,
`b` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`b` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
self.query("SELECT b FROM X")
with self.assertRaisesRegexp(OperationalError,
"Unknown column 'a' in 'field list'"):
self.query("SELECT a FROM X")
def test_rename_column(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`b` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
self.query("SELECT b FROM X")
with self.assertRaisesRegexp(OperationalError,
"Unknown column 'a' in 'field list'"):
self.query("SELECT a FROM X")
def test_change_column_type(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`a` varchar(10) COLLATE utf8_unicode_ci DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
# insterting 1 will be casted as string
self.query("INSERT INTO X VALUES (1)")
self.assertEqual(('1',), self.query("SELECT a FROM X")[1][0])
def test_change_column_default(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT 123
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
self.query("INSERT INTO X VALUES ()")
self.assertEqual((123,), self.query("SELECT a FROM X")[1][0])
def test_add_index(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL,
KEY `idx_a` (`a`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
self.query("SELECT * FROM X USE INDEX (`idx_a`)")
def test_remove_index(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL,
KEY `idx_a` (`a`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `X` (
`a` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""))
with self.assertRaisesRegexp(OperationalError,
"Key 'idx_a' doesn't exist in table 'X'"):
self.query("SELECT * FROM X USE INDEX (`idx_a`)")
def test_escape(self):
self.check_upgrade_schema(
dedent(
"""\
CREATE TABLE `table` (
`drop` int(11) DEFAULT NULL,
`alter` int(11) DEFAULT NULL,
KEY `CASE` (`drop`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
dedent(
"""\
CREATE TABLE `table` (
`and` int(11) DEFAULT NULL,
`alter` varchar(255) COLLATE utf8_unicode_ci DEFAULT 'BETWEEN',
KEY `use` (`alter`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci"""),
table_name='table')
self.query(
"SELECT `alter`, `and` FROM `table` USE INDEX (`use`)")
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment