From bdcdaca803f59341b5c8e8a0d9e8e02772af40c3 Mon Sep 17 00:00:00 2001
From: Kazuhiko Shiozaki <kazuhiko@nexedi.com>
Date: Thu, 1 Oct 2015 12:35:04 +0200
Subject: [PATCH] Add auto_extend_select_list argument in buildSQLQuery() and
 use alias in group_by_expression and order_by_expression.

If True, select_list is automatically extended to have columns used in
group_by_list and order_by_list. It is useful when use
select_expression in inner query and use group_by_expression or
order_by_expression in outer query.
---
 product/ERP5/tests/testI18NSearch.py          | 17 ++++---
 product/ZSQLCatalog/ColumnMap.py              | 10 ++++-
 .../Operator/ComparisonOperator.py            |  7 ++-
 product/ZSQLCatalog/Query/EntireQuery.py      |  3 ++
 product/ZSQLCatalog/SQLCatalog.py             |  8 +++-
 product/ZSQLCatalog/SQLExpression.py          | 20 +++++++--
 .../ZSQLCatalog/interfaces/entire_query.py    |  6 +++
 .../ZSQLCatalog/interfaces/query_catalog.py   | 11 ++++-
 .../ZSQLCatalog/interfaces/sql_expression.py  |  6 +++
 product/ZSQLCatalog/tests/testSQLCatalog.py   | 44 +++++++++++++++++--
 10 files changed, 108 insertions(+), 24 deletions(-)

diff --git a/product/ERP5/tests/testI18NSearch.py b/product/ERP5/tests/testI18NSearch.py
index eed405f588..376a50033c 100644
--- a/product/ERP5/tests/testI18NSearch.py
+++ b/product/ERP5/tests/testI18NSearch.py
@@ -80,10 +80,10 @@ class TestI18NSearch(ERP5TypeTestCase):
     self.assertEqual(result[0].getPath(), self.person1.getPath())
 
     # check sort on fulltext column
-    self.assertFalse('ORDER BY\n  MATCH' in self.portal.portal_catalog(SearchableText='Faure', sort_on=(('SearchableText', 'ascending'),), src__=1))
+    self.assertTrue('ORDER BY\n  `full_text`.`SearchableText` ASC' in self.portal.portal_catalog(SearchableText='Faure', sort_on=(('SearchableText', 'ascending'),), src__=1))
 
     # check sort on fulltext search score
-    self.assertTrue('ORDER BY\n  MATCH' in self.portal.portal_catalog(SearchableText='Faure', sort_on=(('SearchableText__score__', 'ascending'),), src__=1))
+    self.assertTrue('ORDER BY\n  full_text_SearchableText__score__ ASC' in self.portal.portal_catalog(SearchableText='Faure', sort_on=(('SearchableText__score__', 'ascending'),), src__=1))
 
   def test_catalog_full_text_title(self):
     # check if 'é' == 'e' collation works
@@ -108,18 +108,23 @@ class TestI18NSearch(ERP5TypeTestCase):
       self.assertEqual(result[0].getPath(), self.person3.getPath())
 
     # check sort on fulltext column
-    self.assertFalse('ORDER BY\n  MATCH' in self.portal.portal_catalog(**{
+    self.assertFalse('ORDER BY\n  catalog_full_text_title__score__ ASC' in self.portal.portal_catalog(**{
       'catalog_full_text.title':'Faure',
       'sort_on':(('catalog_full_text.title', 'ascending'),),
       'src__':1
       }))
 
     # check sort on fulltext search score
-    self.assertFalse('ORDER BY\n  MATCH' in self.portal.portal_catalog(**{
+    self.assertTrue('ORDER BY\n  catalog_full_text_title__score__' in self.portal.portal_catalog(**{
       'catalog_full_text.title':'Faure',
       'sort_on':(('catalog_full_text.title__score__', 'ascending'),),
       'src__':1
       }))
+    self.assertTrue('ORDER BY\n  catalog_full_text_title__score__' in self.portal.portal_catalog(**{
+      'catalog_full_text.title':'Faure',
+      'sort_on':(('title__score__', 'ascending'),),
+      'src__':1
+      }))
 
   @expectedFailure
   def test_full_text_title(self):
@@ -148,10 +153,10 @@ class TestI18NSearch(ERP5TypeTestCase):
     self.assertTrue('MATCH' in self.portal.portal_catalog(destination_title='Faure', src__=1))
 
     # check sort on fulltext column
-    self.assertFalse('ORDER BY\n  MATCH' in self.portal.portal_catalog(title='Faure', sort_on=(('title', 'ascending'),), src__=1))
+    self.assertTrue('ORDER BY\n  `catalog`.`title` ASC' in self.portal.portal_catalog(title='Faure', sort_on=(('title', 'ascending'),), src__=1))
 
     # check sort on fulltext search score
-    self.assertTrue('ORDER BY\n  MATCH' in self.portal.portal_catalog(title='Faure', sort_on=(('title__score__', 'ascending'),), src__=1))
+    self.assertTrue('ORDER BY\n  catalog_full_text_title__score__' in self.portal.portal_catalog(title='Faure', sort_on=(('title__score__', 'ascending'),), src__=1))
 
 def test_suite():
   suite = unittest.TestSuite()
diff --git a/product/ZSQLCatalog/ColumnMap.py b/product/ZSQLCatalog/ColumnMap.py
index e8a2e49252..bc789406ad 100644
--- a/product/ZSQLCatalog/ColumnMap.py
+++ b/product/ZSQLCatalog/ColumnMap.py
@@ -453,7 +453,10 @@ class ColumnMap(object):
   def asSQLColumn(self, raw_column, group=DEFAULT_GROUP_ID):
     if self.catalog_table_name is None or raw_column in self.column_ignore_set or \
        '.' in raw_column or '*' in raw_column:
-      result = raw_column
+      if raw_column.endswith('__score__'):
+        result = raw_column.replace('.', '_')
+      else:
+        result = raw_column
     else:
       if raw_column.endswith('__score__'):
         raw_column = raw_column[:-9]
@@ -464,7 +467,10 @@ class ColumnMap(object):
       if group is DEFAULT_GROUP_ID:
         group, column = self.related_key_dict.get(column, (group, raw_column))
       alias = self.table_alias_dict[(group, self.column_map[(group, column)])]
-      result = '`%s`.`%s%s`' % (alias, column, column_suffix)
+      if column_suffix:
+        result = '%s_%s%s' % (alias, column, column_suffix)
+      else:
+        result = '`%s`.`%s`' % (alias, column)
       if function is not None:
         result = '%s(%s)' % (function, result)
     return result
diff --git a/product/ZSQLCatalog/Operator/ComparisonOperator.py b/product/ZSQLCatalog/Operator/ComparisonOperator.py
index 49e29cca3c..bdaa6bb1ea 100644
--- a/product/ZSQLCatalog/Operator/ComparisonOperator.py
+++ b/product/ZSQLCatalog/Operator/ComparisonOperator.py
@@ -116,16 +116,15 @@ class MatchComparisonOperator(MonovaluedComparisonOperator):
     }
     select_dict = {}
     if not only_group_columns:
-      select_dict['%s__score__' % column.replace('`', '').rsplit('.', 1)[-1]] = match_string
-    # Support sort on the relevance by using (column)__score__ key.
+      select_dict['%s__score__' % column.replace('`', '').replace('.', '_')] = match_string
     order_by_dict = {
-      '`%s__score__`' % '`.`'.join([x.strip('`') for x in column.split('.')]): match_string,
+      '%s__score__' % column.replace('`', '').replace('.', '_'): match_string
     }
     return SQLExpression(
       self,
       select_dict=select_dict,
-      where_expression=match_string,
       order_by_dict=order_by_dict,
+      where_expression=match_string,
       can_merge_select_dict=True,
     )
 
diff --git a/product/ZSQLCatalog/Query/EntireQuery.py b/product/ZSQLCatalog/Query/EntireQuery.py
index 4baf088f8a..474d55c718 100644
--- a/product/ZSQLCatalog/Query/EntireQuery.py
+++ b/product/ZSQLCatalog/Query/EntireQuery.py
@@ -61,6 +61,7 @@ class EntireQuery(object):
                left_join_list=(),
                limit=None,
                catalog_table_name=None,
+               auto_extend_select_list=False,
                extra_column_list=(),
                from_expression=None,
                order_by_override_list=None,
@@ -76,6 +77,7 @@ class EntireQuery(object):
     self.extra_column_list = list(extra_column_list)
     self.from_expression = from_expression
     self.implicit_join = implicit_join
+    self.auto_extend_select_list = auto_extend_select_list
 
   def asSearchTextExpression(self, sql_catalog):
     return self.query.asSearchTextExpression(sql_catalog)
@@ -211,6 +213,7 @@ class EntireQuery(object):
       order_by_list=self.order_by_list,
       group_by_list=self.group_by_list,
       select_dict=self.final_select_dict,
+      auto_extend_select_list=self.auto_extend_select_list,
       limit=self.limit,
       where_expression_operator='and',
       sql_expression_list=self.sql_expression_list)
diff --git a/product/ZSQLCatalog/SQLCatalog.py b/product/ZSQLCatalog/SQLCatalog.py
index aabc00daf3..516e7d79b9 100644
--- a/product/ZSQLCatalog/SQLCatalog.py
+++ b/product/ZSQLCatalog/SQLCatalog.py
@@ -2350,7 +2350,8 @@ class Catalog(Folder,
     return order_by_list
 
   def buildEntireQuery(self, kw, query_table='catalog', ignore_empty_string=1,
-                       limit=None, extra_column_list=()):
+                       limit=None, auto_extend_select_list=False,
+                       extra_column_list=()):
     group_by_list = kw.pop('group_by_list', kw.pop('group_by', kw.pop('group_by_expression', ())))
     if isinstance(group_by_list, basestring):
       group_by_list = [x.strip() for x in group_by_list.split(',')]
@@ -2417,18 +2418,21 @@ class Catalog(Folder,
       implicit_join=implicit_join,
       limit=limit,
       catalog_table_name=query_table,
+      auto_extend_select_list=auto_extend_select_list,
       extra_column_list=extra_column_list,
       from_expression=from_expression)
 
   def buildSQLQuery(self, query_table='catalog', REQUEST=None,
                           ignore_empty_string=1, only_group_columns=False,
-                          limit=None, extra_column_list=(),
+                          limit=None, auto_extend_select_list=False,
+                          extra_column_list=(),
                           **kw):
     return self.buildEntireQuery(
       kw,
       query_table=query_table,
       ignore_empty_string=ignore_empty_string,
       limit=limit,
+      auto_extend_select_list=auto_extend_select_list,
       extra_column_list=extra_column_list,
     ).asSQLExpression(
       self,
diff --git a/product/ZSQLCatalog/SQLExpression.py b/product/ZSQLCatalog/SQLExpression.py
index fcfff1b141..233b02f12e 100644
--- a/product/ZSQLCatalog/SQLExpression.py
+++ b/product/ZSQLCatalog/SQLExpression.py
@@ -94,6 +94,7 @@ class SQLExpression(object):
                where_expression_operator=None,
                sql_expression_list=(),
                select_dict=None,
+               auto_extend_select_list=False,
                limit=None,
                from_expression=None,
                can_merge_select_dict=False):
@@ -120,6 +121,7 @@ class SQLExpression(object):
       sql_expression_list = [x for x in sql_expression_list if x is not None]
     self.sql_expression_list = list(sql_expression_list)
     self.select_dict = defaultDict(select_dict)
+    self.auto_extend_select_list = auto_extend_select_list
     if limit is None:
       self.limit = ()
     elif isinstance(limit, (list, tuple)):
@@ -133,6 +135,17 @@ class SQLExpression(object):
       warnings.warn("Providing a 'from_expression' is deprecated.",
                     DeprecationWarning)
     self.from_expression = from_expression
+    self._select_dict = self._getSelectDict()[0]
+    if self.auto_extend_select_list:
+      select_column_set = {y for x, y in self._select_dict.iteritems()}
+      extend_column_set = set(self.group_by_list).union(
+        {x[0] for x in self.order_by_list})
+      for i in extend_column_set.difference(select_column_set):
+        # '__score__' suffix alias is already added in select_dict by
+        # MatchComparisonOperator.
+        if '__score__' not in i:
+          self._select_dict['%s__ext__' % i.replace('`', '').replace('.', '_')] = i
+    self._reversed_select_dict = {y: x for x, y in self._select_dict.iteritems()}
 
   def getTableAliasDict(self):
     """
@@ -236,9 +249,8 @@ class SQLExpression(object):
     append = result.append
     order_by_dict = self._getOrderByDict()
     for (column, direction, cast) in self.getOrderByList():
-      if column.endswith('__score__') and column not in order_by_dict:
-        continue
       expression = conflictSafeGet(order_by_dict, column, str(column))
+      expression = self._reversed_select_dict.get(expression, expression)
       if cast not in (None, ''):
         expression = 'CAST(%s AS %s)' % (expression, cast)
       if direction is not None:
@@ -304,7 +316,7 @@ class SQLExpression(object):
       If there are nested SQLExpression, it merges (union of sets) them with
       local value.
     """
-    result = set(self.group_by_list)
+    result = {self._reversed_select_dict.get(x, x) for x in self.group_by_list}
     for sql_expression in self.sql_expression_list:
       result.update(sql_expression.getGroupByset())
     return result
@@ -363,7 +375,7 @@ class SQLExpression(object):
       checks that they don't alias different columns with the same name. If
       they do, it raises a ValueError.
     """
-    return self._getSelectDict()[0]
+    return self._select_dict
 
   def getSelectExpression(self):
     """
diff --git a/product/ZSQLCatalog/interfaces/entire_query.py b/product/ZSQLCatalog/interfaces/entire_query.py
index a413d1e0b2..c6ac087cfb 100644
--- a/product/ZSQLCatalog/interfaces/entire_query.py
+++ b/product/ZSQLCatalog/interfaces/entire_query.py
@@ -45,6 +45,7 @@ class IEntireQuery(Interface):
 
   def __init__(query, order_by_list=None, group_by_list=None,
     select_dict=None, limit=None, catalog_table_name=None,
+    auto_extend_select_list=False,
     extra_column_list=None, from_expression=None,
     order_by_override_list=None):
     """
@@ -67,6 +68,11 @@ class IEntireQuery(Interface):
         See SQLExpression.
       catalog_table_name (string)
         Name of the table to use as a catalog.
+      auto_extend_select_list (boolean)
+        If True, select_list is automatically extended to have columns
+        used in group_by_list and order_by_list. It is useful when use
+        select_expression in inner query and use group_by_expression or
+        order_by_expression in outer query.
 
       Deprecated parameters.
       extra_column_list (list of string)
diff --git a/product/ZSQLCatalog/interfaces/query_catalog.py b/product/ZSQLCatalog/interfaces/query_catalog.py
index eb37194195..2a16420efd 100644
--- a/product/ZSQLCatalog/interfaces/query_catalog.py
+++ b/product/ZSQLCatalog/interfaces/query_catalog.py
@@ -64,7 +64,8 @@ class ISearchKeyCatalog(Interface):
     """
 
   def buildEntireQuery(kw, query_table='catalog', ignore_empty_string=1,
-                       limit=None, extra_column_list=None):
+                       limit=None, auto_extend_select_list=False,
+                       extra_column_list=None):
     """
       Construct and return an instance of EntireQuery class from given
       parameters by calling buildQuery.
@@ -95,6 +96,11 @@ class ISearchKeyCatalog(Interface):
          - type cast (see SQL documentation of 'CAST')
         Sort will happen on given parameter name (its column if it's a column
         name, corresponding virtual column otherwise - as for related keys).
+      auto_extend_select_list (boolean)
+        If True, select_list is automatically extended to have columns
+        used in group_by_list and order_by_list. It is useful when use
+        select_expression in inner query and use group_by_expression or
+        order_by_expression in outer query.
       Extra parameters are passed through to buildQuery.
 
       Backward compatibility parameters:
@@ -140,7 +146,8 @@ class ISearchKeyCatalog(Interface):
 
   def buildSQLQuery(query_table='catalog', REQUEST=None,
                     ignore_empty_string=1, only_group_columns=False,
-                    limit=None, extra_column_list=None,
+                    limit=None, auto_extend_select_list=False,
+                    extra_column_list=(),
                     **kw):
     """
       Return an SQLExpression-generated dictionary (see
diff --git a/product/ZSQLCatalog/interfaces/sql_expression.py b/product/ZSQLCatalog/interfaces/sql_expression.py
index 4099cba2dc..5b74b2af12 100644
--- a/product/ZSQLCatalog/interfaces/sql_expression.py
+++ b/product/ZSQLCatalog/interfaces/sql_expression.py
@@ -61,6 +61,7 @@ class ISQLExpression(Interface):
                where_expression_operator=None,
                sql_expression_list=None,
                select_dict=None,
+               auto_extend_select_list=False,
                limit=None,
                from_expression=None):
     """
@@ -100,6 +101,11 @@ class ISQLExpression(Interface):
         Key is column alias.
         Value is column name, or Null. If it is Null, the alias will also be
         used as column name.
+      auto_extend_select_list (boolean)
+        If True, select_list is automatically extended to have columns
+        used in group_by_list and order_by_list. It is useful when use
+        select_expression in inner query and use group_by_expression or
+        order_by_expression in outer query.
       limit (1-tuple, 2-tuple, other)
         First item is the number of lines expected, second one if given is the
         offset of limited result list within the unlimited result list.
diff --git a/product/ZSQLCatalog/tests/testSQLCatalog.py b/product/ZSQLCatalog/tests/testSQLCatalog.py
index 41a5b4c3d5..38a56449cf 100644
--- a/product/ZSQLCatalog/tests/testSQLCatalog.py
+++ b/product/ZSQLCatalog/tests/testSQLCatalog.py
@@ -728,27 +728,27 @@ class TestSQLCatalog(ERP5TypeTestCase):
     order_by_expression = sql_expression.getOrderByExpression()
     self.assertNotEqual(order_by_expression, '')
     # ... and not sort by relevance
-    self.assertFalse('MATCH' in order_by_expression, order_by_expression)
+    self.assertEqual('`foo`.`fulltext`', order_by_expression)
     # order_by_list on fulltext column + '__score__, resulting "ORDER BY" must be non-empty.
     sql_expression = self.asSQLExpression({'fulltext': 'foo',
       'order_by_list': [('fulltext__score__', ), ]})
     order_by_expression = sql_expression.getOrderByExpression()
     self.assertNotEqual(order_by_expression, '')
     # ... and must sort by relevance
-    self.assertTrue('MATCH' in order_by_expression, order_by_expression)
+    self.assertEqual('foo_fulltext__score__', order_by_expression)
     # ordering on fulltext column with sort order specified must preserve
     # sorting by relevance.
     for direction in ('ASC', 'DESC'):
       sql_expression = self.asSQLExpression({'fulltext': 'foo',
         'order_by_list': [('fulltext__score__', direction), ]})
       order_by_expression = sql_expression.getOrderByExpression()
-      self.assertTrue('MATCH' in order_by_expression, (order_by_expression, direction))
+      self.assertEqual('foo_fulltext__score__ %s' % direction, order_by_expression)
     # Providing a None cast should work too
     for direction in ('ASC', 'DESC'):
       sql_expression = self.asSQLExpression({'fulltext': 'foo',
         'order_by_list': [('fulltext__score__', direction, None), ]})
       order_by_expression = sql_expression.getOrderByExpression()
-      self.assertTrue('MATCH' in order_by_expression, (order_by_expression, direction))
+      self.assertEqual('foo_fulltext__score__ %s' % direction, order_by_expression)
 
   def test_logicalOperators(self):
     self.catalog(ReferenceQuery(ReferenceQuery(operator='=', default='AN ORB'),
@@ -759,6 +759,42 @@ class TestSQLCatalog(ERP5TypeTestCase):
         operator='and'),
       {'default': 'AN OR ORB'})
 
+  def test_auto_extend_select_list(self):
+    # by default select_list is not automatically extended by
+    # order_by_list or group_by_list.
+    sql_expression = self.asSQLExpression({
+      'order_by_list': [('default',),]})
+    select_dict = sql_expression.getSelectDict()
+    self.assertEqual({}, select_dict)
+    sql_expression = self.asSQLExpression({
+      'group_by_list': ['default',]})
+    select_dict = sql_expression.getSelectDict()
+    self.assertEqual({}, select_dict)
+    # select_list is extended if auto_extend_select_list is enabled.
+    sql_expression = self.asSQLExpression({
+      'order_by_list': [('default',),]},
+      auto_extend_select_list=True)
+    select_dict = sql_expression.getSelectDict()
+    self.assertEqual({'foo_default__ext__': '`foo`.`default`'}, select_dict)
+    sql_expression = self.asSQLExpression({
+      'group_by_list': ['default',]},
+      auto_extend_select_list=True)
+    select_dict = sql_expression.getSelectDict()
+    self.assertEqual({'foo_default__ext__': '`foo`.`default`'}, select_dict)
+    # fulltext score is automatically added in select_dict even if
+    # auto_extend_select_list is not enabled.
+    sql_expression = self.asSQLExpression({
+      'fulltext': 'foo',
+      'order_by_list': [('fulltext__score__',),]})
+    select_dict = sql_expression.getSelectDict()
+    self.assertEqual(['foo_fulltext__score__'], select_dict.keys())
+    sql_expression = self.asSQLExpression({
+      'fulltext': 'foo',
+      'order_by_list': [('fulltext__score__',),]},
+      auto_extend_select_list=True)
+    select_dict = sql_expression.getSelectDict()
+    self.assertEqual(['foo_fulltext__score__'], select_dict.keys())
+
   def _searchTextInDictQuery(self, column):
     self.catalog(ReferenceQuery(ReferenceQuery(
         ReferenceQuery(operator='>=', date=DateTime('2001/08/11')),
-- 
2.30.9