Add QUANTILE function to PyDough #378

Merged
merged 18 commits into from
Jul 3, 2025
Changes from 17 commits
33 changes: 33 additions & 0 deletions documentation/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Below is the list of every function/operator currently supported in PyDough as a
* [MEDIAN](#median)
* [MIN](#min)
* [MAX](#max)
* [QUANTILE](#quantile)
* [ANYTHING](#anything)
* [COUNT](#count)
* [NDISTINCT](#ndistinct)
Expand Down Expand Up @@ -891,6 +892,38 @@ The `MAX` function returns the largest value from the set of values it is called
Suppliers.CALCULATE(most_expensive_part_supplied = MAX(supply_records.supply_cost))
```

<!-- TOC --><a name="quantile"></a>

### QUANTILE

The `QUANTILE` function returns the value at a specified quantile from the set of values it is called on, using the `PERCENTILE_DISC` definition. Specifically:

- `QUANTILE(x, p)` returns the **smallest value of `x` such that at least a proportion `p` of the non-null values are less than or equal to it**. This matches the behavior of the SQL standard `PERCENTILE_DISC` aggregate function.
- The quantile value `p` must be a numeric literal between 0 and 1 (inclusive), where `0` returns the minimum, `1` returns the maximum, and `0.5` returns the 50th percentile.
- **NULL records are ignored** in the computation.

```py
# Returns the value at the 90th percentile of supply costs for each supplier
Suppliers.CALCULATE(ninetieth_percentile_cost = QUANTILE(supply_records.supply_cost, 0.9))

# Returns the median (50th percentile, discrete) supply cost for each supplier
Suppliers.CALCULATE(median_cost = QUANTILE(supply_records.supply_cost, 0.5))
```

- The first argument is the plural set of values to aggregate.
- The second argument is the quantile to compute, as a numeric literal between 0 and 1.
- If the quantile argument is not a valid number between 0 and 1, an error is raised.

> [!NOTE]
> `QUANTILE(X, P)` is equivalent to the common `PERCENTILE_DISC` SQL aggregation function.
> The implementation uses the SQL standard `PERCENTILE_DISC` aggregate function where available.
Contributor

Better phrasing: it is equivalent to the common PERCENTILE_DISC SQL aggregation function.

Contributor Author

Done


| **Input** | **Quantile** | **Output** |
|-----------|-------------|------------|
| `[1, 2, 3, 4, 5]` | `0.0` | `1` |
| `[1, 2, 3, 4, 5]` | `0.5` | `3` |
| `[1, 2, 3, 4, 5]` | `1.0` | `5` |
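The discrete-percentile definition used in this section can be sketched in plain Python. This is only an illustration of the `PERCENTILE_DISC` semantics described above, not PyDough's implementation; the helper name `percentile_disc` is hypothetical.

```python
import math


def percentile_disc(values, p):
    """Smallest non-null value whose cumulative share of rows is at least p.

    Hypothetical helper illustrating the definition; not part of PyDough.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"p must be between 0 and 1, got {p!r}")
    # NULL (None) records are ignored, as stated in the documentation.
    ordered = sorted(v for v in values if v is not None)
    if not ordered:
        return None
    # 1-based position of the first row whose cumulative share reaches p.
    # p = 0 would give position 0, so clamp it to the first row (the minimum).
    position = max(1, math.ceil(p * len(ordered)))
    return ordered[position - 1]


print(percentile_disc([1, 2, 3, 4, 5], 0.0))
print(percentile_disc([1, 2, 3, 4, 5], 0.5))
print(percentile_disc([1, 2, 3, 4, 5], 1.0))
```

The three calls reproduce the rows of the table above: `1`, `3`, and `5`.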

<!-- TOC --><a name="anything"></a>

### ANYTHING
Expand Down
115 changes: 109 additions & 6 deletions pydough/conversion/hybrid_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ def __init__(self, configs: PyDoughConfigs, dialect: DatabaseDialect):
# being derived as the subtree of the last element.
self.stack: list[HybridTree] = []
# If True, rewrites MEDIAN calls into an average of the 1-2 median rows
- # via window functions, otherwise leaves as-is.
- self.rewrite_median: bool = dialect not in {DatabaseDialect.ANSI}
+ # or rewrites QUANTILE calls to select the first qualifying row,
+ # both derived from window functions, otherwise leaves as-is.
+ self.rewrite_median_quantile: bool = dialect not in {DatabaseDialect.ANSI}

@staticmethod
def get_subcollection_join_keys(
Expand Down Expand Up @@ -264,7 +265,7 @@ def eject_aggregate_inputs(self, hybrid: HybridTree) -> None:
rewritten: bool = False
new_args: list[HybridExpr] = []
for arg in agg_call.args:
- if isinstance(arg, HybridRefExpr):
+ if isinstance(arg, (HybridRefExpr, HybridLiteralExpr)):
new_args.append(arg)
else:
rewritten = True
Expand All @@ -282,7 +283,9 @@ def eject_aggregate_inputs(self, hybrid: HybridTree) -> None:
def run_rewrites(self, hybrid: HybridTree):
"""
Run any rewrite procedures that must occur after de-correlation, such
- as converting MEDIAN to an average of the 1-2 median rows.
+ as converting MEDIAN to an average of the 1-2 median rows. Also converts
+ QUANTILE calls into the appropriate window function calls.


Args:
`hybrid`: the bottom of the hybrid tree to rewrite.
Expand All @@ -294,15 +297,20 @@ def run_rewrites(self, hybrid: HybridTree):
self.run_rewrites(child.subtree)

create_new_calc: bool = True
- # Rewrite any MEDIAN calls
- if self.rewrite_median:
+ # Rewrite any MEDIAN and QUANTILE calls
+ if self.rewrite_median_quantile:
for child in hybrid.children:
for agg_name, agg_call in child.aggs.items():
if agg_call.operator == pydop.MEDIAN:
child.aggs[agg_name] = self.rewrite_median_call(
child, agg_call, create_new_calc
)
create_new_calc = False
if agg_call.operator == pydop.QUANTILE:
child.aggs[agg_name] = self.rewrite_quantile_call(
child, agg_call, create_new_calc
)
create_new_calc = False

def qdag_expr_contains_window(self, expr: PyDoughExpressionQDAG) -> bool:
"""
Expand Down Expand Up @@ -760,6 +768,101 @@ def rewrite_median_call(

return avg_call

    def rewrite_quantile_call(
        self,
        child_connection: HybridConnection,
        expr: HybridFunctionExpr,
        create_new_calc: bool,
    ) -> HybridFunctionExpr:
        """
        Rewrites a `QUANTILE` aggregation call into an equivalent expression
        using window functions. This is typically used for dialects that do
        not natively support the `PERCENTILE_DISC` aggregate function.
Contributor

The newlines here are a bit off. Also, wrap functions like QUANTILE or PERCENTILE_DISC in backticks (`) for tooltip formatting.

Contributor

Please address this before merging


        The rewritten expression selects the value at the specified quantile by:
        - Ranking the rows within each partition.
        - Calculating the number of rows (N) in each partition.
        - Keeping only those rows where the rank is greater than
          `INTEGER((1.0 - p) * N)`, where p is the quantile argument.
        - Taking the maximum value among the kept rows.

        Args:
            child_connection: The HybridConnection containing the aggregate call
            to QUANTILE.
            expr: The HybridFunctionExpr representing the QUANTILE aggregation.
            create_new_calc: If True, injects new expressions into a new CALCULATE
            operation.

        Returns:
            A HybridFunctionExpr representing the rewritten aggregation using
            window functions.
        """
        assert expr.operator == pydop.QUANTILE

        # Validate that p is a numeric literal between 0 and 1
        if (
            not isinstance(expr.args[1], HybridLiteralExpr)
            or not isinstance(expr.args[1].typ, NumericType)
            or not isinstance(expr.args[1].literal.value, (int, float))
            or not (0.0 <= float(expr.args[1].literal.value) <= 1.0)
        ):
            raise ValueError(
                f"Expected second argument to QUANTILE to be a numeric literal between 0 and 1, instead received {expr.args[1]!r}"
            )

        assert len(expr.args) == 2
        # The implementation:
        # MAX(KEEP_IF(args[0], R > INTEGER((1.0-args[1]) * N)))
        data_expr: HybridExpr = expr.args[0]  # Column

        assert child_connection.subtree.agg_keys is not None
        partition_args: list[HybridExpr] = child_connection.subtree.agg_keys
        order_args: list[HybridCollation] = [HybridCollation(data_expr, False, False)]

        # R
        rank: HybridExpr = HybridWindowExpr(
            pydop.RANKING, [], partition_args, order_args, NumericType(), {}
        )
        # N
        rows: HybridExpr = HybridWindowExpr(
            pydop.RELCOUNT, [data_expr], partition_args, [], NumericType(), {}
        )

        # (1.0-args[1])
        sub: HybridExpr = HybridLiteralExpr(
            Literal(1.0 - float(expr.args[1].literal.value), NumericType())
        )

        # (1.0-args[1]) * N
        product: HybridExpr = HybridFunctionExpr(pydop.MUL, [sub, rows], NumericType())

        # INTEGER((1.0-args[1]) * N)
        cast_integer: HybridExpr = HybridFunctionExpr(
            pydop.INTEGER, [product], NumericType()
        )

        # R > INTEGER((1.0-args[1]) * N)
        greater: HybridExpr = HybridFunctionExpr(
            pydop.GRT, [rank, cast_integer], expr.typ
        )

        # KEEP_IF(args[0], R > INTEGER((1.0-args[1]) * N))
        keep_largest: HybridExpr = HybridFunctionExpr(
            pydop.KEEP_IF, [data_expr, greater], data_expr.typ
        )

        # MAX
        max_input_arg = self.inject_expression(
            child_connection.subtree, keep_largest, create_new_calc
        )
        max_call: HybridFunctionExpr = HybridFunctionExpr(
            pydop.MAX, [max_input_arg], expr.typ
        )

        return max_call
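
The formula in the comments above, `MAX(KEEP_IF(x, R > INTEGER((1.0 - p) * N)))`, can be tried out directly in SQL. The following is a minimal sketch using Python's built-in `sqlite3` module, not the SQL that PyDough generates. The `supply` table, its columns, and the helper `quantile_via_window` are hypothetical, and it assumes `ROW_NUMBER` over a descending sort stands in for `RANKING`, `COUNT(...) OVER` for `RELCOUNT`, and a `CASE` expression for `KEEP_IF`.

```python
import sqlite3


def quantile_via_window(rows, p):
    """Return {supplier: quantile} using the ranking-based formula.

    `rows` is a list of (supplier, cost) pairs. Hypothetical helper for
    illustration only; requires SQLite 3.25+ for window functions.
    """
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE supply (supplier TEXT, cost REAL)")
    conn.executemany("INSERT INTO supply VALUES (?, ?)", rows)
    query = """
        SELECT supplier,
               -- MAX(KEEP_IF(cost, R > INTEGER((1.0 - p) * N)))
               MAX(CASE WHEN r > CAST(? * n AS INTEGER) THEN cost END)
        FROM (
            SELECT supplier,
                   cost,
                   -- R: rank within the partition, largest cost first
                   ROW_NUMBER() OVER (
                       PARTITION BY supplier ORDER BY cost DESC
                   ) AS r,
                   -- N: number of non-null costs in the partition
                   COUNT(cost) OVER (PARTITION BY supplier) AS n
            FROM supply
        )
        GROUP BY supplier
    """
    # As in the rewrite, (1.0 - p) is folded into a single literal up front.
    result = dict(conn.execute(query, (1.0 - float(p),)).fetchall())
    conn.close()
    return result


rows = [("a", v) for v in [1, 2, 3, 4, 5]] + [("b", 10), ("b", 20)]
print(quantile_via_window(rows, 0.5))
```

For supplier `a`, `p = 0.5` gives a threshold of `INTEGER(2.5) = 2`, so ranks 3 to 5 (costs 3, 2, 1) are kept and the maximum is 3. In this sketch `p = 0` sets the threshold to N, so no row passes the filter and the result is NULL (`None` in Python).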

def make_hybrid_correl_expr(
self,
back_expr: BackReferenceExpression,
Expand Down
2 changes: 2 additions & 0 deletions pydough/pydough_operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"PREV",
"PyDoughExpressionOperator",
"PyDoughOperator",
"QUANTILE",
"QUARTER",
"RANKING",
"RELAVG",
Expand Down Expand Up @@ -166,6 +167,7 @@
POWER,
PRESENT,
PREV,
QUANTILE,
QUARTER,
RANKING,
RELAVG,
Expand Down
5 changes: 3 additions & 2 deletions pydough/pydough_operators/expression_operators/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ These functions can be called on plural data to aggregate it into a singular exp
 - `SUM`: returns the result of adding all of the non-null values of a plural expression.
 - `AVG`: returns the result of taking the average of the non-null values of a plural expression.
 - `MEDIAN`: returns the result of taking the median of the non-null values of a plural expression.
-- `MIN`: returns the largest out of the non-null values of a plural expression.
-- `MAX`: returns the smallest out of the non-null values of a plural expression.
+- `MIN`: returns the smallest out of the non-null values of a plural expression.
+- `MAX`: returns the largest out of the non-null values of a plural expression.
+- `QUANTILE`: returns the value at a specified quantile from the set of values.
 - `ANYTHING`: returns an arbitrary entry from the values of a plural expression.
 - `COUNT`: counts how many non-null values exist in a plural expression (special: see collection aggregations).
 - `NDISTINCT`: counts how many unique values exist in a plural expression (special: see collection aggregations).
Expand Down
2 changes: 2 additions & 0 deletions pydough/pydough_operators/expression_operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"PRESENT",
"PREV",
"PyDoughExpressionOperator",
"QUANTILE",
"QUARTER",
"RANKING",
"RELAVG",
Expand Down Expand Up @@ -158,6 +159,7 @@
POWER,
PRESENT,
PREV,
QUANTILE,
QUARTER,
RANKING,
RELAVG,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
"POWER",
"PRESENT",
"PREV",
"QUANTILE",
"QUARTER",
"RANKING",
"RELAVG",
Expand Down Expand Up @@ -160,6 +161,9 @@
MEDIAN = ExpressionFunctionOperator(
    "MEDIAN", True, RequireNumArgs(1), ConstantType(NumericType())
)
QUANTILE = ExpressionFunctionOperator(
    "QUANTILE", True, RequireNumArgs(2), ConstantType(NumericType())
)
POWER = ExpressionFunctionOperator(
    "POWER", False, RequireNumArgs(2), ConstantType(NumericType())
)
Expand Down
52 changes: 52 additions & 0 deletions pydough/sqlglot/transform_bindings/base_transform_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ def convert_call_to_sqlglot(
return self.convert_smallest_or_largest(args, types, True)
case pydop.COUNT:
return self.convert_count(args, types)
case pydop.QUANTILE:
return self.convert_quantile(args, types)
case _:
raise NotImplementedError(
f"Operator '{operator.function_name}' is unsupported with this database dialect."
Expand Down Expand Up @@ -1687,3 +1689,53 @@ def convert_count(
return sqlglot_expressions.Count(this=args[0])
else:
raise ValueError(f"COUNT expects 0 or 1 argument, got {len(args)}")

    def convert_quantile(
        self, args: list[SQLGlotExpression], types: list[PyDoughType]
    ) -> SQLGlotExpression:
        """
        Converts a PyDough `QUANTILE(X, p)` function call to a SQLGlot expression
        representing the SQL standard `PERCENTILE_DISC` aggregate function.

        This produces an expression equivalent to:
            PERCENTILE_DISC(p) WITHIN GROUP (ORDER BY X)

        Args:
            args: A list of two SQLGlot expressions, where args[0] is the column
            or expression to order by (X), and args[1] is the quantile value (p)
            between 0 and 1.
            types: The PyDough types of the arguments.

        Returns:
            A SQLGlotExpression representing the
            PERCENTILE_DISC(p) WITHIN GROUP (ORDER BY X) aggregate function.
        """

        assert len(args) == 2

        # Validate that the second argument is a number between 0 and 1 (inclusive)
        if (
            not isinstance(args[1], sqlglot_expressions.Literal)
            or args[1].is_string
            or not (0.0 <= float(args[1].this) <= 1.0)
        ):
            raise ValueError(
                f"Expected second argument to QUANTILE to be a numeric literal between 0 and 1, got {args[1]}"
            )

        percentile_disc_function: SQLGlotExpression = (
            sqlglot_expressions.PercentileDisc(this=args[1])
        )

        ordered_column: SQLGlotExpression = sqlglot_expressions.Ordered(this=args[0])

        order: SQLGlotExpression = sqlglot_expressions.Order(
            expressions=[ordered_column]
        )

        within_group_clause: SQLGlotExpression = sqlglot_expressions.WithinGroup(
            this=percentile_disc_function, expression=order
        )

        return within_group_clause
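
To make the target of `convert_quantile` concrete, here is a small stand-alone sketch that validates `p` and renders the SQL text the docstring describes. It deliberately uses plain string formatting rather than SQLGlot, and the helper name `render_percentile_disc` is hypothetical; it only illustrates the shape of the output, not how PyDough builds it.

```python
def render_percentile_disc(column: str, p: float) -> str:
    """Render `PERCENTILE_DISC(p) WITHIN GROUP (ORDER BY column)` as text.

    Hypothetical helper for illustration; PyDough builds a SQLGlot
    expression tree instead of formatting a string.
    """
    # Mirror the validation above: p must be a number between 0 and 1.
    if not isinstance(p, (int, float)) or not (0.0 <= float(p) <= 1.0):
        raise ValueError(
            "Expected second argument to QUANTILE to be a numeric literal "
            f"between 0 and 1, got {p!r}"
        )
    return f"PERCENTILE_DISC({p}) WITHIN GROUP (ORDER BY {column})"


print(render_percentile_disc("supply_cost", 0.9))
```

Passing a value such as `1.5` or a string raises the `ValueError`, which is the same failure mode the validation in `convert_quantile` is meant to produce.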