summaryrefslogtreecommitdiff
path: root/src/backend/parser/parse_agg.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/backend/parser/parse_agg.c')
-rw-r--r--src/backend/parser/parse_agg.c723
1 files changed, 680 insertions, 43 deletions
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c
index 7b0e66807d4..1e3f2e0ffa2 100644
--- a/src/backend/parser/parse_agg.c
+++ b/src/backend/parser/parse_agg.c
@@ -42,7 +42,9 @@ typedef struct
{
ParseState *pstate;
Query *qry;
+ PlannerInfo *root;
List *groupClauses;
+ List *groupClauseCommonVars;
bool have_non_var_grouping;
List **func_grouped_rels;
int sublevels_up;
@@ -56,11 +58,18 @@ static int check_agg_arguments(ParseState *pstate,
static bool check_agg_arguments_walker(Node *node,
check_agg_arguments_context *context);
static void check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry,
- List *groupClauses, bool have_non_var_grouping,
+ List *groupClauses, List *groupClauseVars,
+ bool have_non_var_grouping,
List **func_grouped_rels);
static bool check_ungrouped_columns_walker(Node *node,
check_ungrouped_columns_context *context);
-
+static void finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry,
+ List *groupClauses, PlannerInfo *root,
+ bool have_non_var_grouping);
+static bool finalize_grouping_exprs_walker(Node *node,
+ check_ungrouped_columns_context *context);
+static void check_agglevels_and_constraints(ParseState *pstate,Node *expr);
+static List *expand_groupingset_node(GroupingSet *gs);
/*
* transformAggregateCall -
@@ -96,10 +105,7 @@ transformAggregateCall(ParseState *pstate, Aggref *agg,
List *tdistinct = NIL;
AttrNumber attno = 1;
int save_next_resno;
- int min_varlevel;
ListCell *lc;
- const char *err;
- bool errkind;
if (AGGKIND_IS_ORDERED_SET(agg->aggkind))
{
@@ -214,15 +220,97 @@ transformAggregateCall(ParseState *pstate, Aggref *agg,
agg->aggorder = torder;
agg->aggdistinct = tdistinct;
+ check_agglevels_and_constraints(pstate, (Node *) agg);
+}
+
+/*
+ * transformGroupingFunc
+ * Transform a GROUPING expression
+ *
+ * GROUPING() behaves very like an aggregate. Processing of levels and nesting
+ * is done as for aggregates. We set p_hasAggs for these expressions too.
+ */
+Node *
+transformGroupingFunc(ParseState *pstate, GroupingFunc *p)
+{
+ ListCell *lc;
+ List *args = p->args;
+ List *result_list = NIL;
+ GroupingFunc *result = makeNode(GroupingFunc);
+
+ if (list_length(args) > 31)
+ ereport(ERROR,
+ (errcode(ERRCODE_TOO_MANY_ARGUMENTS),
+ errmsg("GROUPING must have fewer than 32 arguments"),
+ parser_errposition(pstate, p->location)));
+
+ foreach(lc, args)
+ {
+ Node *current_result;
+
+ current_result = transformExpr(pstate, (Node*) lfirst(lc), pstate->p_expr_kind);
+
+ /* acceptability of expressions is checked later */
+
+ result_list = lappend(result_list, current_result);
+ }
+
+ result->args = result_list;
+ result->location = p->location;
+
+ check_agglevels_and_constraints(pstate, (Node *) result);
+
+ return (Node *) result;
+}
+
+/*
+ * Aggregate functions and grouping operations (which are combined in the spec
+ * as <set function specification>) are very similar with regard to level and
+ * nesting restrictions (though we allow a lot more things than the spec does).
+ * Centralise those restrictions here.
+ */
+static void
+check_agglevels_and_constraints(ParseState *pstate, Node *expr)
+{
+ List *directargs = NIL;
+ List *args = NIL;
+ Expr *filter = NULL;
+ int min_varlevel;
+ int location = -1;
+ Index *p_levelsup;
+ const char *err;
+ bool errkind;
+ bool isAgg = IsA(expr, Aggref);
+
+ if (isAgg)
+ {
+ Aggref *agg = (Aggref *) expr;
+
+ directargs = agg->aggdirectargs;
+ args = agg->args;
+ filter = agg->aggfilter;
+ location = agg->location;
+ p_levelsup = &agg->agglevelsup;
+ }
+ else
+ {
+ GroupingFunc *grp = (GroupingFunc *) expr;
+
+ args = grp->args;
+ location = grp->location;
+ p_levelsup = &grp->agglevelsup;
+ }
+
/*
* Check the arguments to compute the aggregate's level and detect
* improper nesting.
*/
min_varlevel = check_agg_arguments(pstate,
- agg->aggdirectargs,
- agg->args,
- agg->aggfilter);
- agg->agglevelsup = min_varlevel;
+ directargs,
+ args,
+ filter);
+
+ *p_levelsup = min_varlevel;
/* Mark the correct pstate level as having aggregates */
while (min_varlevel-- > 0)
@@ -247,20 +335,32 @@ transformAggregateCall(ParseState *pstate, Aggref *agg,
Assert(false); /* can't happen */
break;
case EXPR_KIND_OTHER:
- /* Accept aggregate here; caller must throw error if wanted */
+ /* Accept aggregate/grouping here; caller must throw error if wanted */
break;
case EXPR_KIND_JOIN_ON:
case EXPR_KIND_JOIN_USING:
- err = _("aggregate functions are not allowed in JOIN conditions");
+ if (isAgg)
+ err = _("aggregate functions are not allowed in JOIN conditions");
+ else
+ err = _("grouping operations are not allowed in JOIN conditions");
+
break;
case EXPR_KIND_FROM_SUBSELECT:
/* Should only be possible in a LATERAL subquery */
Assert(pstate->p_lateral_active);
- /* Aggregate scope rules make it worth being explicit here */
- err = _("aggregate functions are not allowed in FROM clause of their own query level");
+ /* Aggregate/grouping scope rules make it worth being explicit here */
+ if (isAgg)
+ err = _("aggregate functions are not allowed in FROM clause of their own query level");
+ else
+ err = _("grouping operations are not allowed in FROM clause of their own query level");
+
break;
case EXPR_KIND_FROM_FUNCTION:
- err = _("aggregate functions are not allowed in functions in FROM");
+ if (isAgg)
+ err = _("aggregate functions are not allowed in functions in FROM");
+ else
+ err = _("grouping operations are not allowed in functions in FROM");
+
break;
case EXPR_KIND_WHERE:
errkind = true;
@@ -278,10 +378,18 @@ transformAggregateCall(ParseState *pstate, Aggref *agg,
/* okay */
break;
case EXPR_KIND_WINDOW_FRAME_RANGE:
- err = _("aggregate functions are not allowed in window RANGE");
+ if (isAgg)
+ err = _("aggregate functions are not allowed in window RANGE");
+ else
+ err = _("grouping operations are not allowed in window RANGE");
+
break;
case EXPR_KIND_WINDOW_FRAME_ROWS:
- err = _("aggregate functions are not allowed in window ROWS");
+ if (isAgg)
+ err = _("aggregate functions are not allowed in window ROWS");
+ else
+ err = _("grouping operations are not allowed in window ROWS");
+
break;
case EXPR_KIND_SELECT_TARGET:
/* okay */
@@ -312,26 +420,55 @@ transformAggregateCall(ParseState *pstate, Aggref *agg,
break;
case EXPR_KIND_CHECK_CONSTRAINT:
case EXPR_KIND_DOMAIN_CHECK:
- err = _("aggregate functions are not allowed in check constraints");
+ if (isAgg)
+ err = _("aggregate functions are not allowed in check constraints");
+ else
+ err = _("grouping operations are not allowed in check constraints");
+
break;
case EXPR_KIND_COLUMN_DEFAULT:
case EXPR_KIND_FUNCTION_DEFAULT:
- err = _("aggregate functions are not allowed in DEFAULT expressions");
+
+ if (isAgg)
+ err = _("aggregate functions are not allowed in DEFAULT expressions");
+ else
+ err = _("grouping operations are not allowed in DEFAULT expressions");
+
break;
case EXPR_KIND_INDEX_EXPRESSION:
- err = _("aggregate functions are not allowed in index expressions");
+ if (isAgg)
+ err = _("aggregate functions are not allowed in index expressions");
+ else
+ err = _("grouping operations are not allowed in index expressions");
+
break;
case EXPR_KIND_INDEX_PREDICATE:
- err = _("aggregate functions are not allowed in index predicates");
+ if (isAgg)
+ err = _("aggregate functions are not allowed in index predicates");
+ else
+ err = _("grouping operations are not allowed in index predicates");
+
break;
case EXPR_KIND_ALTER_COL_TRANSFORM:
- err = _("aggregate functions are not allowed in transform expressions");
+ if (isAgg)
+ err = _("aggregate functions are not allowed in transform expressions");
+ else
+ err = _("grouping operations are not allowed in transform expressions");
+
break;
case EXPR_KIND_EXECUTE_PARAMETER:
- err = _("aggregate functions are not allowed in EXECUTE parameters");
+ if (isAgg)
+ err = _("aggregate functions are not allowed in EXECUTE parameters");
+ else
+ err = _("grouping operations are not allowed in EXECUTE parameters");
+
break;
case EXPR_KIND_TRIGGER_WHEN:
- err = _("aggregate functions are not allowed in trigger WHEN conditions");
+ if (isAgg)
+ err = _("aggregate functions are not allowed in trigger WHEN conditions");
+ else
+ err = _("grouping operations are not allowed in trigger WHEN conditions");
+
break;
/*
@@ -342,18 +479,28 @@ transformAggregateCall(ParseState *pstate, Aggref *agg,
* which is sane anyway.
*/
}
+
if (err)
ereport(ERROR,
(errcode(ERRCODE_GROUPING_ERROR),
errmsg_internal("%s", err),
- parser_errposition(pstate, agg->location)));
+ parser_errposition(pstate, location)));
+
if (errkind)
+ {
+ if (isAgg)
+ /* translator: %s is name of a SQL construct, eg GROUP BY */
+ err = _("aggregate functions are not allowed in %s");
+ else
+ /* translator: %s is name of a SQL construct, eg GROUP BY */
+ err = _("grouping operations are not allowed in %s");
+
ereport(ERROR,
(errcode(ERRCODE_GROUPING_ERROR),
- /* translator: %s is name of a SQL construct, eg GROUP BY */
- errmsg("aggregate functions are not allowed in %s",
- ParseExprKindName(pstate->p_expr_kind)),
- parser_errposition(pstate, agg->location)));
+ errmsg_internal(err,
+ ParseExprKindName(pstate->p_expr_kind)),
+ parser_errposition(pstate, location)));
+ }
}
/*
@@ -466,7 +613,6 @@ check_agg_arguments(ParseState *pstate,
locate_agg_of_level((Node *) directargs,
context.min_agglevel))));
}
-
return agglevel;
}
@@ -507,6 +653,21 @@ check_agg_arguments_walker(Node *node,
/* no need to examine args of the inner aggregate */
return false;
}
+ if (IsA(node, GroupingFunc))
+ {
+ int agglevelsup = ((GroupingFunc *) node)->agglevelsup;
+
+ /* convert levelsup to frame of reference of original query */
+ agglevelsup -= context->sublevels_up;
+ /* ignore local aggs of subqueries */
+ if (agglevelsup >= 0)
+ {
+ if (context->min_agglevel < 0 ||
+ context->min_agglevel > agglevelsup)
+ context->min_agglevel = agglevelsup;
+ }
+ /* Continue and descend into subtree */
+ }
/* We can throw error on sight for a window function */
if (IsA(node, WindowFunc))
ereport(ERROR,
@@ -527,6 +688,7 @@ check_agg_arguments_walker(Node *node,
context->sublevels_up--;
return result;
}
+
return expression_tree_walker(node,
check_agg_arguments_walker,
(void *) context);
@@ -770,17 +932,66 @@ transformWindowFuncCall(ParseState *pstate, WindowFunc *wfunc,
void
parseCheckAggregates(ParseState *pstate, Query *qry)
{
+ List *gset_common = NIL;
List *groupClauses = NIL;
+ List *groupClauseCommonVars = NIL;
bool have_non_var_grouping;
List *func_grouped_rels = NIL;
ListCell *l;
bool hasJoinRTEs;
bool hasSelfRefRTEs;
- PlannerInfo *root;
+ PlannerInfo *root = NULL;
Node *clause;
/* This should only be called if we found aggregates or grouping */
- Assert(pstate->p_hasAggs || qry->groupClause || qry->havingQual);
+ Assert(pstate->p_hasAggs || qry->groupClause || qry->havingQual || qry->groupingSets);
+
+ /*
+ * If we have grouping sets, expand them and find the intersection of all
+ * sets.
+ */
+ if (qry->groupingSets)
+ {
+ /*
+ * The limit of 4096 is arbitrary and exists simply to avoid resource
+ * issues from pathological constructs.
+ */
+ List *gsets = expand_grouping_sets(qry->groupingSets, 4096);
+
+ if (!gsets)
+ ereport(ERROR,
+ (errcode(ERRCODE_STATEMENT_TOO_COMPLEX),
+ errmsg("too many grouping sets present (max 4096)"),
+ parser_errposition(pstate,
+ qry->groupClause
+ ? exprLocation((Node *) qry->groupClause)
+ : exprLocation((Node *) qry->groupingSets))));
+
+ /*
+ * The intersection will often be empty, so help things along by
+ * seeding the intersect with the smallest set.
+ */
+ gset_common = linitial(gsets);
+
+ if (gset_common)
+ {
+ for_each_cell(l, lnext(list_head(gsets)))
+ {
+ gset_common = list_intersection_int(gset_common, lfirst(l));
+ if (!gset_common)
+ break;
+ }
+ }
+
+ /*
+ * If there was only one grouping set in the expansion, AND if the
+ * groupClause is non-empty (meaning that the grouping set is not empty
+ * either), then we can ditch the grouping set and pretend we just had
+ * a normal GROUP BY.
+ */
+ if (list_length(gsets) == 1 && qry->groupClause)
+ qry->groupingSets = NIL;
+ }
/*
* Scan the range table to see if there are JOIN or self-reference CTE
@@ -800,15 +1011,19 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
/*
* Build a list of the acceptable GROUP BY expressions for use by
* check_ungrouped_columns().
+ *
+ * We get the TLE, not just the expr, because GROUPING wants to know
+ * the sortgroupref.
*/
foreach(l, qry->groupClause)
{
SortGroupClause *grpcl = (SortGroupClause *) lfirst(l);
- Node *expr;
+ TargetEntry *expr;
- expr = get_sortgroupclause_expr(grpcl, qry->targetList);
+ expr = get_sortgroupclause_tle(grpcl, qry->targetList);
if (expr == NULL)
continue; /* probably cannot happen */
+
groupClauses = lcons(expr, groupClauses);
}
@@ -830,21 +1045,28 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
groupClauses = (List *) flatten_join_alias_vars(root,
(Node *) groupClauses);
}
- else
- root = NULL; /* keep compiler quiet */
/*
* Detect whether any of the grouping expressions aren't simple Vars; if
* they're all Vars then we don't have to work so hard in the recursive
* scans. (Note we have to flatten aliases before this.)
+ *
+ * Track Vars that are included in all grouping sets separately in
+ * groupClauseCommonVars, since these are the only ones we can use to check
+ * for functional dependencies.
*/
have_non_var_grouping = false;
foreach(l, groupClauses)
{
- if (!IsA((Node *) lfirst(l), Var))
+ TargetEntry *tle = lfirst(l);
+ if (!IsA(tle->expr, Var))
{
have_non_var_grouping = true;
- break;
+ }
+ else if (!qry->groupingSets ||
+ list_member_int(gset_common, tle->ressortgroupref))
+ {
+ groupClauseCommonVars = lappend(groupClauseCommonVars, tle->expr);
}
}
@@ -855,19 +1077,30 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
* this will also find ungrouped variables that came from ORDER BY and
* WINDOW clauses. For that matter, it's also going to examine the
* grouping expressions themselves --- but they'll all pass the test ...
+ *
+ * We also finalize GROUPING expressions, but for that we need to traverse
+ * the original (unflattened) clause in order to modify nodes.
*/
clause = (Node *) qry->targetList;
+ finalize_grouping_exprs(clause, pstate, qry,
+ groupClauses, root,
+ have_non_var_grouping);
if (hasJoinRTEs)
clause = flatten_join_alias_vars(root, clause);
check_ungrouped_columns(clause, pstate, qry,
- groupClauses, have_non_var_grouping,
+ groupClauses, groupClauseCommonVars,
+ have_non_var_grouping,
&func_grouped_rels);
clause = (Node *) qry->havingQual;
+ finalize_grouping_exprs(clause, pstate, qry,
+ groupClauses, root,
+ have_non_var_grouping);
if (hasJoinRTEs)
clause = flatten_join_alias_vars(root, clause);
check_ungrouped_columns(clause, pstate, qry,
- groupClauses, have_non_var_grouping,
+ groupClauses, groupClauseCommonVars,
+ have_non_var_grouping,
&func_grouped_rels);
/*
@@ -904,14 +1137,17 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
*/
static void
check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry,
- List *groupClauses, bool have_non_var_grouping,
+ List *groupClauses, List *groupClauseCommonVars,
+ bool have_non_var_grouping,
List **func_grouped_rels)
{
check_ungrouped_columns_context context;
context.pstate = pstate;
context.qry = qry;
+ context.root = NULL;
context.groupClauses = groupClauses;
+ context.groupClauseCommonVars = groupClauseCommonVars;
context.have_non_var_grouping = have_non_var_grouping;
context.func_grouped_rels = func_grouped_rels;
context.sublevels_up = 0;
@@ -965,6 +1201,16 @@ check_ungrouped_columns_walker(Node *node,
return false;
}
+ if (IsA(node, GroupingFunc))
+ {
+ GroupingFunc *grp = (GroupingFunc *) node;
+
+ /* handled GroupingFunc separately, no need to recheck at this level */
+
+ if ((int) grp->agglevelsup >= context->sublevels_up)
+ return false;
+ }
+
/*
* If we have any GROUP BY items that are not simple Vars, check to see if
* subexpression as a whole matches any GROUP BY item. We need to do this
@@ -976,7 +1222,9 @@ check_ungrouped_columns_walker(Node *node,
{
foreach(gl, context->groupClauses)
{
- if (equal(node, lfirst(gl)))
+ TargetEntry *tle = lfirst(gl);
+
+ if (equal(node, tle->expr))
return false; /* acceptable, do not descend more */
}
}
@@ -1003,7 +1251,7 @@ check_ungrouped_columns_walker(Node *node,
{
foreach(gl, context->groupClauses)
{
- Var *gvar = (Var *) lfirst(gl);
+ Var *gvar = (Var *) ((TargetEntry *) lfirst(gl))->expr;
if (IsA(gvar, Var) &&
gvar->varno == var->varno &&
@@ -1040,7 +1288,7 @@ check_ungrouped_columns_walker(Node *node,
if (check_functional_grouping(rte->relid,
var->varno,
0,
- context->groupClauses,
+ context->groupClauseCommonVars,
&context->qry->constraintDeps))
{
*context->func_grouped_rels =
@@ -1085,6 +1333,395 @@ check_ungrouped_columns_walker(Node *node,
}
/*
+ * finalize_grouping_exprs -
+ * Scan the given expression tree for GROUPING() and related calls,
+ * and validate and process their arguments.
+ *
+ * This is split out from check_ungrouped_columns above because it needs
+ * to modify the nodes (which it does in-place, not via a mutator) while
+ * check_ungrouped_columns may see only a copy of the original thanks to
+ * flattening of join alias vars. So here, we flatten each individual
+ * GROUPING argument as we see it before comparing it.
+ */
+static void
+finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry,
+ List *groupClauses, PlannerInfo *root,
+ bool have_non_var_grouping)
+{
+ check_ungrouped_columns_context context;
+
+ context.pstate = pstate;
+ context.qry = qry;
+ context.root = root;
+ context.groupClauses = groupClauses;
+ context.groupClauseCommonVars = NIL;
+ context.have_non_var_grouping = have_non_var_grouping;
+ context.func_grouped_rels = NULL;
+ context.sublevels_up = 0;
+ context.in_agg_direct_args = false;
+ finalize_grouping_exprs_walker(node, &context);
+}
+
+static bool
+finalize_grouping_exprs_walker(Node *node,
+ check_ungrouped_columns_context *context)
+{
+ ListCell *gl;
+
+ if (node == NULL)
+ return false;
+ if (IsA(node, Const) ||
+ IsA(node, Param))
+ return false; /* constants are always acceptable */
+
+ if (IsA(node, Aggref))
+ {
+ Aggref *agg = (Aggref *) node;
+
+ if ((int) agg->agglevelsup == context->sublevels_up)
+ {
+ /*
+ * If we find an aggregate call of the original level, do not
+ * recurse into its normal arguments, ORDER BY arguments, or
+ * filter; GROUPING exprs of this level are not allowed there. But
+ * check direct arguments as though they weren't in an aggregate.
+ */
+ bool result;
+
+ Assert(!context->in_agg_direct_args);
+ context->in_agg_direct_args = true;
+ result = finalize_grouping_exprs_walker((Node *) agg->aggdirectargs,
+ context);
+ context->in_agg_direct_args = false;
+ return result;
+ }
+
+ /*
+ * We can skip recursing into aggregates of higher levels altogether,
+ * since they could not possibly contain exprs of concern to us (see
+ * transformAggregateCall). We do need to look at aggregates of lower
+ * levels, however.
+ */
+ if ((int) agg->agglevelsup > context->sublevels_up)
+ return false;
+ }
+
+ if (IsA(node, GroupingFunc))
+ {
+ GroupingFunc *grp = (GroupingFunc *) node;
+
+ /*
+ * We only need to check GroupingFunc nodes at the exact level to which
+ * they belong, since they cannot mix levels in arguments.
+ */
+
+ if ((int) grp->agglevelsup == context->sublevels_up)
+ {
+ ListCell *lc;
+ List *ref_list = NIL;
+
+ foreach(lc, grp->args)
+ {
+ Node *expr = lfirst(lc);
+ Index ref = 0;
+
+ if (context->root)
+ expr = flatten_join_alias_vars(context->root, expr);
+
+ /*
+ * Each expression must match a grouping entry at the current
+ * query level. Unlike the general expression case, we don't
+ * allow functional dependencies or outer references.
+ */
+
+ if (IsA(expr, Var))
+ {
+ Var *var = (Var *) expr;
+
+ if (var->varlevelsup == context->sublevels_up)
+ {
+ foreach(gl, context->groupClauses)
+ {
+ TargetEntry *tle = lfirst(gl);
+ Var *gvar = (Var *) tle->expr;
+
+ if (IsA(gvar, Var) &&
+ gvar->varno == var->varno &&
+ gvar->varattno == var->varattno &&
+ gvar->varlevelsup == 0)
+ {
+ ref = tle->ressortgroupref;
+ break;
+ }
+ }
+ }
+ }
+ else if (context->have_non_var_grouping &&
+ context->sublevels_up == 0)
+ {
+ foreach(gl, context->groupClauses)
+ {
+ TargetEntry *tle = lfirst(gl);
+
+ if (equal(expr, tle->expr))
+ {
+ ref = tle->ressortgroupref;
+ break;
+ }
+ }
+ }
+
+ if (ref == 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_GROUPING_ERROR),
+ errmsg("arguments to GROUPING must be grouping expressions of the associated query level"),
+ parser_errposition(context->pstate,
+ exprLocation(expr))));
+
+ ref_list = lappend_int(ref_list, ref);
+ }
+
+ grp->refs = ref_list;
+ }
+
+ if ((int) grp->agglevelsup > context->sublevels_up)
+ return false;
+ }
+
+ if (IsA(node, Query))
+ {
+ /* Recurse into subselects */
+ bool result;
+
+ context->sublevels_up++;
+ result = query_tree_walker((Query *) node,
+ finalize_grouping_exprs_walker,
+ (void *) context,
+ 0);
+ context->sublevels_up--;
+ return result;
+ }
+ return expression_tree_walker(node, finalize_grouping_exprs_walker,
+ (void *) context);
+}
+
+
+/*
+ * Given a GroupingSet node, expand it and return a list of lists.
+ *
+ * For EMPTY nodes, return a list of one empty list.
+ *
+ * For SIMPLE nodes, return a list of one list, which is the node content.
+ *
+ * For CUBE and ROLLUP nodes, return a list of the expansions.
+ *
+ * For SET nodes, recursively expand contained CUBE and ROLLUP.
+ */
+static List*
+expand_groupingset_node(GroupingSet *gs)
+{
+ List * result = NIL;
+
+ switch (gs->kind)
+ {
+ case GROUPING_SET_EMPTY:
+ result = list_make1(NIL);
+ break;
+
+ case GROUPING_SET_SIMPLE:
+ result = list_make1(gs->content);
+ break;
+
+ case GROUPING_SET_ROLLUP:
+ {
+ List *rollup_val = gs->content;
+ ListCell *lc;
+ int curgroup_size = list_length(gs->content);
+
+ while (curgroup_size > 0)
+ {
+ List *current_result = NIL;
+ int i = curgroup_size;
+
+ foreach(lc, rollup_val)
+ {
+ GroupingSet *gs_current = (GroupingSet *) lfirst(lc);
+
+ Assert(gs_current->kind == GROUPING_SET_SIMPLE);
+
+ current_result
+ = list_concat(current_result,
+ list_copy(gs_current->content));
+
+ /* If we are done with making the current group, break */
+ if (--i == 0)
+ break;
+ }
+
+ result = lappend(result, current_result);
+ --curgroup_size;
+ }
+
+ result = lappend(result, NIL);
+ }
+ break;
+
+ case GROUPING_SET_CUBE:
+ {
+ List *cube_list = gs->content;
+ int number_bits = list_length(cube_list);
+ uint32 num_sets;
+ uint32 i;
+
+ /* parser should cap this much lower */
+ Assert(number_bits < 31);
+
+ num_sets = (1U << number_bits);
+
+ for (i = 0; i < num_sets; i++)
+ {
+ List *current_result = NIL;
+ ListCell *lc;
+ uint32 mask = 1U;
+
+ foreach(lc, cube_list)
+ {
+ GroupingSet *gs_current = (GroupingSet *) lfirst(lc);
+
+ Assert(gs_current->kind == GROUPING_SET_SIMPLE);
+
+ if (mask & i)
+ {
+ current_result
+ = list_concat(current_result,
+ list_copy(gs_current->content));
+ }
+
+ mask <<= 1;
+ }
+
+ result = lappend(result, current_result);
+ }
+ }
+ break;
+
+ case GROUPING_SET_SETS:
+ {
+ ListCell *lc;
+
+ foreach(lc, gs->content)
+ {
+ List *current_result = expand_groupingset_node(lfirst(lc));
+
+ result = list_concat(result, current_result);
+ }
+ }
+ break;
+ }
+
+ return result;
+}
+
+static int
+cmp_list_len_asc(const void *a, const void *b)
+{
+ int la = list_length(*(List*const*)a);
+ int lb = list_length(*(List*const*)b);
+ return (la > lb) ? 1 : (la == lb) ? 0 : -1;
+}
+
+/*
+ * Expand a groupingSets clause to a flat list of grouping sets.
+ * The returned list is sorted by length, shortest sets first.
+ *
+ * This is mainly for the planner, but we use it here too to do
+ * some consistency checks.
+ */
+List *
+expand_grouping_sets(List *groupingSets, int limit)
+{
+ List *expanded_groups = NIL;
+ List *result = NIL;
+ double numsets = 1;
+ ListCell *lc;
+
+ if (groupingSets == NIL)
+ return NIL;
+
+ foreach(lc, groupingSets)
+ {
+ List *current_result = NIL;
+ GroupingSet *gs = lfirst(lc);
+
+ current_result = expand_groupingset_node(gs);
+
+ Assert(current_result != NIL);
+
+ numsets *= list_length(current_result);
+
+ if (limit >= 0 && numsets > limit)
+ return NIL;
+
+ expanded_groups = lappend(expanded_groups, current_result);
+ }
+
+ /*
+ * Do cartesian product between sublists of expanded_groups.
+ * While at it, remove any duplicate elements from individual
+ * grouping sets (we must NOT change the number of sets though)
+ */
+
+ foreach(lc, (List *) linitial(expanded_groups))
+ {
+ result = lappend(result, list_union_int(NIL, (List *) lfirst(lc)));
+ }
+
+ for_each_cell(lc, lnext(list_head(expanded_groups)))
+ {
+ List *p = lfirst(lc);
+ List *new_result = NIL;
+ ListCell *lc2;
+
+ foreach(lc2, result)
+ {
+ List *q = lfirst(lc2);
+ ListCell *lc3;
+
+ foreach(lc3, p)
+ {
+ new_result = lappend(new_result,
+ list_union_int(q, (List *) lfirst(lc3)));
+ }
+ }
+ result = new_result;
+ }
+
+ if (list_length(result) > 1)
+ {
+ int result_len = list_length(result);
+ List **buf = palloc(sizeof(List*) * result_len);
+ List **ptr = buf;
+
+ foreach(lc, result)
+ {
+ *ptr++ = lfirst(lc);
+ }
+
+ qsort(buf, result_len, sizeof(List*), cmp_list_len_asc);
+
+ result = NIL;
+ ptr = buf;
+
+ while (result_len-- > 0)
+ result = lappend(result, *ptr++);
+
+ pfree(buf);
+ }
+
+ return result;
+}
+
+/*
* get_aggregate_argtypes
* Identify the specific datatypes passed to an aggregate call.
*