-
Notifications
You must be signed in to change notification settings - Fork 3
Add QUANTILE function to PyDough #378
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
afd45d4
6f60811
f6d8ad4
3f66a30
b7fef56
789b7c6
08fe280
757a7bd
454e24c
a8f3f7b
19451f6
6a9cda5
c5b91a4
55dcf45
d8e5641
c169db2
272492c
22c722e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,8 +87,9 @@ def __init__(self, configs: PyDoughConfigs, dialect: DatabaseDialect): | |
# being derived as the subtree of the last element. | ||
self.stack: list[HybridTree] = [] | ||
# If True, rewrites MEDIAN calls into an average of the 1-2 median rows | ||
# via window functions, otherwise leaves as-is. | ||
self.rewrite_median: bool = dialect not in {DatabaseDialect.ANSI} | ||
# or rewrites QUANTILE calls to select the first qualifying row, | ||
# both derived from window functions, otherwise leaves as-is. | ||
self.rewrite_median_quantile: bool = dialect not in {DatabaseDialect.ANSI} | ||
|
||
@staticmethod | ||
def get_subcollection_join_keys( | ||
|
@@ -264,7 +265,7 @@ def eject_aggregate_inputs(self, hybrid: HybridTree) -> None: | |
rewritten: bool = False | ||
new_args: list[HybridExpr] = [] | ||
for arg in agg_call.args: | ||
if isinstance(arg, HybridRefExpr): | ||
if isinstance(arg, (HybridRefExpr, HybridLiteralExpr)): | ||
new_args.append(arg) | ||
else: | ||
rewritten = True | ||
|
@@ -282,7 +283,9 @@ def eject_aggregate_inputs(self, hybrid: HybridTree) -> None: | |
def run_rewrites(self, hybrid: HybridTree): | ||
""" | ||
Run any rewrite procedures that must occur after de-correlation, such | ||
as converting MEDIAN to an average of the 1-2 median rows. | ||
as converting MEDIAN to an average of the 1-2 median rows, and converting | ||
QUANTILE calls into the equivalent window function calls. | ||
|
||
|
||
Args: | ||
`hybrid`: the bottom of the hybrid tree to rewrite. | ||
|
@@ -294,15 +297,20 @@ def run_rewrites(self, hybrid: HybridTree): | |
self.run_rewrites(child.subtree) | ||
|
||
create_new_calc: bool = True | ||
# Rewrite any MEDIAN calls | ||
if self.rewrite_median: | ||
# Rewrite any MEDIAN and QUANTILE calls | ||
if self.rewrite_median_quantile: | ||
for child in hybrid.children: | ||
for agg_name, agg_call in child.aggs.items(): | ||
if agg_call.operator == pydop.MEDIAN: | ||
child.aggs[agg_name] = self.rewrite_median_call( | ||
child, agg_call, create_new_calc | ||
) | ||
create_new_calc = False | ||
if agg_call.operator == pydop.QUANTILE: | ||
child.aggs[agg_name] = self.rewrite_quantile_call( | ||
child, agg_call, create_new_calc | ||
) | ||
create_new_calc = False | ||
|
||
def qdag_expr_contains_window(self, expr: PyDoughExpressionQDAG) -> bool: | ||
""" | ||
|
@@ -760,6 +768,101 @@ def rewrite_median_call( | |
|
||
return avg_call | ||
|
||
def rewrite_quantile_call(
    self,
    child_connection: HybridConnection,
    expr: HybridFunctionExpr,
    create_new_calc: bool,
) -> HybridFunctionExpr:
    """
    Rewrites a QUANTILE aggregation call into an equivalent expression
    using window functions. This is typically used for dialects that do
    not natively support the `PERCENTILE_DISC` aggregate function.

    The rewritten expression selects the value at the specified quantile
    by:
    - Ranking the rows within each partition (`R`).
    - Calculating the number of rows in each partition (`N`).
    - Keeping only those rows where `R > INTEGER((1.0 - p) * N)`, where
      `p` is the quantile argument.
    - Taking the maximum value among the kept rows.

    Args:
        `child_connection`: the HybridConnection containing the aggregate
        call to QUANTILE.
        `expr`: the HybridFunctionExpr representing the QUANTILE
        aggregation. Must have exactly two arguments: the data expression
        and the quantile `p`.
        `create_new_calc`: if True, injects new expressions into a new
        CALCULATE operation.

    Returns:
        A HybridFunctionExpr representing the rewritten aggregation:
        `MAX(KEEP_IF(data, R > INTEGER((1.0 - p) * N)))`.

    Raises:
        `ValueError`: if the second argument is not a numeric literal
        between 0 and 1 (inclusive).
    """
    assert expr.operator == pydop.QUANTILE
    # Check the arity before indexing into the arguments, so a malformed
    # call trips this invariant instead of raising an IndexError below.
    assert len(expr.args) == 2

    data_expr: HybridExpr = expr.args[0]  # Column
    quantile_arg: HybridExpr = expr.args[1]  # p

    # Valid if the value of p is a number between 0 and 1
    if (
        not isinstance(quantile_arg, HybridLiteralExpr)
        or not isinstance(quantile_arg.typ, NumericType)
        or not isinstance(quantile_arg.literal.value, (int, float))
        or not (0.0 <= float(quantile_arg.literal.value) <= 1.0)
    ):
        raise ValueError(
            "Expected second argument to QUANTILE to be a numeric literal "
            f"between 0 and 1, instead received {quantile_arg!r}"
        )

    # The implementation
    # MAX(KEEP_IF(args[0], R > INTEGER((1.0-args[1]) * N)))
    assert child_connection.subtree.agg_keys is not None
    partition_args: list[HybridExpr] = child_connection.subtree.agg_keys
    # NOTE(review): the two False flags are presumably (ascending,
    # nulls-first), i.e. a descending ranking -- confirm against
    # HybridCollation.
    order_args: list[HybridCollation] = [HybridCollation(data_expr, False, False)]

    # R
    rank: HybridExpr = HybridWindowExpr(
        pydop.RANKING, [], partition_args, order_args, NumericType(), {}
    )
    # N
    rows: HybridExpr = HybridWindowExpr(
        pydop.RELCOUNT, [data_expr], partition_args, [], NumericType(), {}
    )

    # (1.0-args[1])
    # Round away binary floating-point noise: 1.0 - 0.9 evaluates to
    # 0.09999999999999998, which would make INTEGER(0.0999... * 10)
    # truncate to 0 instead of 1 and select the wrong row.
    complement: float = round(1.0 - float(quantile_arg.literal.value), 12)
    sub: HybridExpr = HybridLiteralExpr(Literal(complement, NumericType()))

    # (1.0-args[1]) * N
    product: HybridExpr = HybridFunctionExpr(pydop.MUL, [sub, rows], NumericType())

    # INTEGER((1.0-args[1]) * N)
    cast_integer: HybridExpr = HybridFunctionExpr(
        pydop.INTEGER, [product], NumericType()
    )

    # R > INTEGER((1.0-args[1]) * N)
    # NOTE(review): this comparison is typed with expr.typ rather than a
    # boolean type -- confirm that is intended.
    greater: HybridExpr = HybridFunctionExpr(
        pydop.GRT, [rank, cast_integer], expr.typ
    )

    # KEEP_IF(args[0], R > INTEGER((1.0-args[1]) * N))
    keep_largest: HybridExpr = HybridFunctionExpr(
        pydop.KEEP_IF, [data_expr, greater], data_expr.typ
    )

    # MAX
    max_input_arg = self.inject_expression(
        child_connection.subtree, keep_largest, create_new_calc
    )
    max_call: HybridFunctionExpr = HybridFunctionExpr(
        pydop.MAX, [max_input_arg], expr.typ
    )

    return max_call
|
||
def make_hybrid_correl_expr( | ||
self, | ||
back_expr: BackReferenceExpression, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better phrasing: it is equivalent to the common `PERCENTILE_DISC` SQL aggregation function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done