Skip to content

AST Parser for detecting numerical code (numpy/scipy/torch/tensorflow/jax/numba)#1051

Merged
aseembits93 merged 6 commits into
mainfrom
detect-numerical-code
Jan 14, 2026
Merged

AST Parser for detecting numerical code (numpy/scipy/torch/tensorflow/jax/numba)#1051
aseembits93 merged 6 commits into
mainfrom
detect-numerical-code

Conversation

@aseembits93

@aseembits93 aseembits93 commented Jan 14, 2026

Copy link
Copy Markdown
Contributor

In the case of numpy/scipy/math it will return False when numba is not present in the environment because this PR is part of a system which will try to use Just-in-Time compilation which is part of numba to compile numpy/scipy/math functions.

@aseembits93

Copy link
Copy Markdown
Contributor Author

@claude review this

@claude

claude Bot commented Jan 14, 2026

Copy link
Copy Markdown
Contributor

Claude finished @aseembits93's task in 1m 14s —— View job


Code Review Complete

  • Read repository guidelines
  • Review code_extractor.py changes
  • Review test file
  • Analyze implementation quality
  • Provide feedback

Summary

This PR adds an AST parser to detect if a function uses numerical computing libraries (numpy, scipy, torch, tensorflow, jax, numba). The implementation is solid with excellent test coverage (819 lines of tests!). However, I've identified several issues that should be addressed:

Critical Issues

1. Async Functions Not Supported (codeflash/code_utils/code_extractor.py:1264-1299)

The _find_function_node function only searches for ast.FunctionDef nodes, completely ignoring ast.AsyncFunctionDef. This means async functions will always return False, even if they use numerical libraries.

Current code:

def _find_function_node(tree: ast.Module, name_parts: list[str]) -> ast.FunctionDef | None:
    """Find a function node in the AST given its qualified name parts.

    Note: This function only finds regular (sync) functions, not async functions.
    ...
    """

The docstring admits this limitation, but one of the tests (line 531-537) expects async functions to work! This is a critical bug.

Impact: Any async function using numerical libraries will be incorrectly classified as non-numerical code.

Fix this →


2. Star Imports Not Properly Handled (codeflash/code_utils/code_extractor.py:1252-1259)

Star imports from numerical libraries are collected but not properly tracked. When code uses a bare name from a star import (e.g., array from from numpy import *), it won't be detected. The tests acknowledge this limitation (lines 559-577), but this could lead to false negatives in production.

Current behavior:

from numpy import *
def func(x):
    return array(x)  # Won't be detected as numerical code

Recommendation: Either fix this or document it clearly in the function's docstring as a known limitation.


High Priority Issues

3. Inconsistent Return Type Annotations (codeflash/code_utils/code_extractor.py:1362)

Line 1362 uses # noqa : SIM103 with a space before the colon, which is non-standard. Should be # noqa: SIM103.

Minor but affects code quality standards.


4. Missing Edge Case: Empty Code in Nested Classes (codeflash/code_utils/code_extractor.py:1288-1296)

The function returns None for nested classes (depth > 2), which is correct, but there's no test coverage for what happens when the class hierarchy exists but the method doesn't. This could cause confusion.

Test case missing:

code = """
import numpy as np
class Outer:
    class Inner:
        def other_method(self):
            return np.sum([1])
"""
is_numerical_code(code, "Outer.Inner.method")  # method doesn't exist

Medium Priority Issues

5. Performance: Short-Circuit Optimization Missing (codeflash/code_utils/code_extractor.py:1190-1198)

In visit_Call, after finding a numerical library usage, you return early. However, in visit_Attribute and other visitor methods, you call self.generic_visit(node) even after self.found_numerical = True. While the early returns prevent further processing, you're still traversing child nodes unnecessarily.

Suggestion: Add return after setting self.found_numerical = True in all visitor methods to avoid unnecessary traversal.


6. Test Organization: Class Grouping (tests/test_is_numerical_code.py)

The tests are excellently comprehensive, but some related test classes could be merged:

  • TestBasicNumpyUsage and TestNumpySubmodules could be one class
  • TestFalsePositivePrevention tests could be distributed to relevant sections

This is a minor organizational issue and doesn't affect functionality.


Low Priority / Suggestions

7. Type Hint Consistency (codeflash/code_utils/code_extractor.py:1217)

Line 1217 uses str | None (modern Python 3.10+ union syntax) while line 1264 uses ast.FunctionDef | None. However, earlier in the file (line 9) uses Optional[...]. Consider standardizing on one style throughout the file for consistency.


8. Docstring Enhancement (codeflash/code_utils/code_extractor.py:1301-1336)

The is_numerical_code docstring is excellent, but it could benefit from mentioning the star import limitation and the async function limitation (if not fixed).

Suggested addition:

"""
...existing docstring...

Known Limitations:
    - Star imports are not fully resolved (e.g., `from numpy import *` followed by `array(x)`)
    - [If bug not fixed] Async functions are not currently supported
"""

Positive Observations

  1. Excellent test coverage - 819 lines of comprehensive tests covering edge cases, false positives, and various import styles
  2. Good separation of concerns - The numba availability check is cleanly separated
  3. Well-documented - The main function has clear docstrings with examples
  4. Performance-conscious - Uses frozenset for constant lookups, early returns in visitors
  5. Proper use of AST visitors - Clean implementation of the visitor pattern

Verification of PR Description

The PR description states: "In the case of numpy/scipy it will return False when numba is not present in the environment"

This is correct and verified by:

  • Line 1362: if not has_numba and modules_used.issubset(NUMBA_REQUIRED_MODULES):
  • Lines 694-819: Comprehensive tests with @patch("codeflash.code_utils.code_extractor.has_numba", False)

Recommendations

Must fix before merge:

  1. Fix async function support (critical bug with failing test)

Should fix before merge:
2. Fix or document star import limitation
3. Fix the noqa comment formatting

Nice to have:
4. Add missing edge case tests
5. Improve performance with better short-circuiting
6. Standardize type hint style
7. Enhance docstring with limitations


@codeflash-ai

codeflash-ai Bot commented Jan 14, 2026

Copy link
Copy Markdown
Contributor

⚡️ Codeflash found optimizations for this PR

📄 228% (2.28x) speedup for NumericalUsageChecker.visit_Attribute in codeflash/code_utils/code_extractor.py

⏱️ Runtime : 1.12 milliseconds 341 microseconds (best of 250 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch detect-numerical-code).

Static Badge

Comment on lines +1219 to +1222
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
return self._get_root_name(node.value)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️Codeflash found 36% (0.36x) speedup for NumericalUsageChecker._get_root_name in codeflash/code_utils/code_extractor.py

⏱️ Runtime : 992 microseconds 729 microseconds (best of 215 runs)

📝 Explanation and details

The optimization replaces recursion with iteration when traversing AST attribute chains like np.array or np.linalg.inv.

Key Change:

  • Original: Recursively calls _get_root_name(node.value) for each ast.Attribute node, creating new stack frames
  • Optimized: Uses a while loop to iteratively follow node.value references until reaching the base ast.Name node

Why This Is Faster:

  1. Eliminates function call overhead: Each recursive call incurs the cost of creating a new stack frame, passing arguments, and managing return values. The line profiler shows this clearly—the recursive call in the original code takes 54.2% of total time (4.9ms out of 9ms). The iterative version reduces this to a simple variable assignment (23% of 4.3ms total).

  2. Reduces isinstance() checks: The recursive version checks isinstance(node, ast.Name) at every level of the chain before recursing. The iterative version first loops through all ast.Attribute nodes, then performs the ast.Name check only once at the end. For a chain of depth N, this cuts isinstance checks roughly in half.

  3. Better CPU cache locality: Iteration keeps execution in a tight loop within the same stack frame, improving instruction cache utilization compared to jumping between recursive call sites.

Performance Profile:

The speedup is most dramatic on deep attribute chains:

  • Deep chains (500 levels): 295% faster (116μs → 29.5μs)
  • Medium chains (200 levels): 253% faster (126μs → 36μs)
  • Overall workload: 36% faster (992μs → 729μs)

For shallow chains (2-3 levels deep), the overhead difference is smaller but still measurable (1-8% improvements on basic cases). The optimization particularly shines when the function is called repeatedly in hot paths, such as during AST traversal of large codebases, where eliminating per-call overhead compounds across thousands of invocations (5,488 hits in the profiler).

The iterative approach maintains identical semantics—handling edge cases (non-Name/Attribute nodes return None) and preserving exact return values across all test scenarios.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 24 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
import ast

import pytest  # used for our unit tests

from codeflash.code_utils.code_extractor import NumericalUsageChecker

# unit tests


# Basic Test Cases
def test_basic_name_node_returns_id():
    # Parse a bare name expression and ensure the root name is the identifier itself.
    expr = ast.parse("np", mode="eval").body  # ast.Name
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(expr)  # 511ns -> 871ns (41.3% slower)


@pytest.mark.parametrize(
    "source,expected_root",
    [
        ("np.array", "np"),  # simple attribute access
        ("np.linalg.inv", "np"),  # multi-level attribute: expect the top-most name
        ("pd.DataFrame", "pd"),  # different root name
        ("self.value", "self"),  # common pythonic root 'self'
    ],
)
def test_basic_attribute_chains_return_root_name(source, expected_root):
    # For various attribute chains ensure the root name is returned.
    expr = ast.parse(source, mode="eval").body  # ast.Attribute (possibly nested)
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(expr)  # 3.51μs -> 3.57μs (1.87% slower)


# Edge Test Cases
def test_non_name_nodes_return_none():
    # Constants and literals are not name/attribute expressions — should return None.
    const_expr = ast.parse("42", mode="eval").body  # ast.Constant
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(const_expr)  # 571ns -> 592ns (3.55% slower)

    # Binary operations are not Name/Attribute -> None
    binop_expr = ast.parse("x + y", mode="eval").body  # ast.BinOp
    codeflash_output = checker._get_root_name(binop_expr)  # 311ns -> 321ns (3.12% slower)


def test_attribute_on_non_name_base_returns_none():
    # Attribute on top of an expression whose root is not a Name (e.g., (x+1).attr)
    expr = ast.parse("(x + 1).attr", mode="eval").body  # ast.Attribute with value ast.BinOp
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(expr)  # 892ns -> 752ns (18.6% faster)

    # Attribute of a call: np.array().x — the .value is ast.Call, so root cannot be extracted -> None
    expr2 = ast.parse("np.array().x", mode="eval").body
    codeflash_output = checker._get_root_name(expr2)  # 451ns -> 421ns (7.13% faster)

    # Attribute of a subscript: arr[0].x -> .value is ast.Subscript -> None
    expr3 = ast.parse("arr[0].x", mode="eval").body
    codeflash_output = checker._get_root_name(expr3)  # 421ns -> 371ns (13.5% faster)


def test_none_input_returns_none_and_is_harmless():
    # While the function expects ast.expr, calling with None must not raise and should return None
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(None)  # 611ns -> 611ns (0.000% faster)


def test_constructed_ast_nodes_work_as_well():
    # Build an Attribute chain programmatically:
    # create node representing root.a.b.c where root is Name('root')
    root_name = ast.Name(id="root", ctx=ast.Load())
    attr1 = ast.Attribute(value=root_name, attr="a", ctx=ast.Load())
    attr2 = ast.Attribute(value=attr1, attr="b", ctx=ast.Load())
    attr3 = ast.Attribute(value=attr2, attr="c", ctx=ast.Load())
    # Ensure our constructed node returns the original root id
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(attr3)  # 1.06μs -> 1.13μs (6.18% slower)


def test_idempotence_and_no_side_effects():
    # Calling _get_root_name twice on the same node should yield the same result and not mutate it.
    expr = ast.parse("np.array", mode="eval").body
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(expr)
    first = codeflash_output  # 781ns -> 821ns (4.87% slower)
    codeflash_output = checker._get_root_name(expr)
    second = codeflash_output  # 440ns -> 420ns (4.76% faster)


# Large Scale Test Cases (stress within limits)
def test_deep_attribute_chain_large_scale():
    # Construct a deep attribute chain via source string to test recursion depth handling.
    # Keep depth under Python recursion limit (default ~1000). We choose 500 here (reasonable stress test).
    depth = 500
    # Construct source like "root.a.a.a...." with depth elements; root should be 'root'
    source = "root" + "".join(f".a{i}" for i in range(depth - 1))
    # Parse expression in eval mode to get AST
    expr = ast.parse(source, mode="eval").body
    # Sanity checks: the parsed top-level node is an Attribute if depth>1 else Name
    checker = NumericalUsageChecker(set())
    # Ensure the function returns 'root' even for a very deep chain
    codeflash_output = checker._get_root_name(expr)  # 116μs -> 29.5μs (295% faster)


def test_multiple_distinct_large_chains_quickly():
    # Create several moderately-large chains and ensure each is resolved correctly.
    # We keep each chain length modest (e.g., 200) and the count small to respect test performance.
    chain_len = 200
    checker = NumericalUsageChecker(set())
    for root in ("alpha", "beta", "gamma"):
        # Build source like 'alpha.x0.x1.x2...'
        source = root + "".join(f".f{i}" for i in range(chain_len - 1))
        expr = ast.parse(source, mode="eval").body
        # Each should yield its corresponding root name
        codeflash_output = checker._get_root_name(expr)  # 126μs -> 36.0μs (253% faster)


# Mutation-sensitive tests
def test_would_detect_returning_attribute_name_instead_of_root():
    # This test ensures that mutations that return the last attribute name (e.g., "array")
    # instead of the root ("np") will fail.
    expr = ast.parse("np.array", mode="eval").body
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(expr)
    result = codeflash_output  # 791ns -> 841ns (5.95% slower)


# Ensure that other valid identifier names are preserved (including underscores and numbers)
@pytest.mark.parametrize(
    "identifier",
    [
        "x1",  # numeric char in name
        "_private",  # leading underscore
        "CamelCase",  # capitalized identifier
        "with_underscores_and_numbers123",
    ],
)
def test_various_identifier_forms(identifier):
    # Build an attribute like "<identifier>.attr" and verify the root id is preserved exactly.
    source = f"{identifier}.attr"
    expr = ast.parse(source, mode="eval").body
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(expr)  # 3.06μs -> 3.35μs (8.67% slower)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally git merge codeflash/optimize-pr1051-2026-01-14T02.39.18

Suggested change
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
return self._get_root_name(node.value)
while isinstance(node, ast.Attribute):
node = node.value
if isinstance(node, ast.Name):
return node.id

Static Badge

Comment on lines +1238 to +1260
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
# import numpy or import numpy as np
module_root = alias.name.split(".")[0]
if module_root in NUMERICAL_MODULES:
# Use the alias if present, otherwise the module name
name = alias.asname if alias.asname else alias.name.split(".")[0]
numerical_names.add(name)
modules_used.add(module_root)
elif isinstance(node, ast.ImportFrom) and node.module:
module_root = node.module.split(".")[0]
if module_root in NUMERICAL_MODULES:
# from numpy import array, zeros as z
for alias in node.names:
if alias.name == "*":
# Can't track star imports, but mark the module as numerical
numerical_names.add(module_root)
else:
name = alias.asname if alias.asname else alias.name
numerical_names.add(name)
modules_used.add(module_root)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️Codeflash found 343% (3.43x) speedup for _collect_numerical_imports in codeflash/code_utils/code_extractor.py

⏱️ Runtime : 1.48 milliseconds 333 microseconds (best of 106 runs)

📝 Explanation and details

The optimized code achieves a 343% speedup by replacing ast.walk() with a targeted manual traversal strategy that only visits AST nodes where imports can appear.

Key Optimization

What changed: Replaced the generic ast.walk(tree) traversal with a manual stack-based traversal that only processes nodes with a body attribute (like functions, classes, control flow structures).

Why it's faster:

  • ast.walk() visits every single node in the AST tree recursively (1,325 nodes in the profiler), including all expressions, literals, operators, and statements
  • The manual traversal only visits structural containers that can contain imports (101 while-loop iterations in the profiler), skipping ~92% of unnecessary nodes
  • Import statements only appear at the top level or inside function/class definitions, never inside expressions like x + y or literals

Performance impact from line profiler:

  • Original: 76.2% of time (7.6ms) spent in the ast.walk() loop
  • Optimized: The while loop and node processing combined take ~41% of time (0.9ms total), saving 6.7ms

How It Works

The optimized version maintains a stack of nodes to check, only descending into:

  • Function/class definitions (ast.FunctionDef, ast.ClassDef, etc.)
  • Control flow structures that can contain nested scopes (ast.If, ast.For, ast.Try, etc.)
  • orelse branches (else clauses)

This selective traversal finds all imports without wasting time visiting arithmetic operations, variable references, or other leaf nodes.

Impact Context

Based on function_references, this function is called by is_numerical_code(), which analyzes whether functions use numerical libraries. The optimization is particularly valuable when:

  • Analyzing large files with complex ASTs (as shown in test_large_ast_with_many_non_import_nodes)
  • Processing many functions in a codebase during static analysis
  • The AST contains extensive nested logic unrelated to imports

The test cases confirm the optimization works well across all scenarios: simple imports, star imports, submodule imports, and complex real-world code with many non-import nodes.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 45 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
import ast

import pytest
from codeflash.code_utils.code_extractor import _collect_numerical_imports


class TestCollectNumericalImportsBasic:
    """Basic test cases for _collect_numerical_imports function."""

    def test_single_numpy_import(self):
        """Test basic numpy import without alias."""
        code = "import numpy"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_single_numpy_import_with_alias(self):
        """Test numpy import with common alias 'np'."""
        code = "import numpy as np"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_single_torch_import(self):
        """Test basic torch import."""
        code = "import torch"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_single_torch_import_with_alias(self):
        """Test torch import with alias."""
        code = "import torch as t"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_from_numpy_import_single_function(self):
        """Test importing a single function from numpy."""
        code = "from numpy import array"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_from_numpy_import_function_with_alias(self):
        """Test importing from numpy with alias."""
        code = "from numpy import array as arr"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_from_numpy_import_multiple_functions(self):
        """Test importing multiple functions from numpy."""
        code = "from numpy import array, zeros, ones"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_from_numpy_import_mixed_aliases(self):
        """Test importing from numpy with mixed aliases."""
        code = "from numpy import array as arr, zeros, ones as o"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_multiple_imports_same_module(self):
        """Test multiple imports from the same numerical module."""
        code = """
import numpy
from numpy import array as arr
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_multiple_different_numerical_modules(self):
        """Test imports from different numerical modules."""
        code = """
import numpy as np
import torch
from scipy import stats
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_non_numerical_import(self):
        """Test that non-numerical imports are ignored."""
        code = "import os"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_mixed_numerical_and_non_numerical(self):
        """Test mixing numerical and non-numerical imports."""
        code = """
import os
import numpy as np
import sys
import torch
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)


class TestCollectNumericalImportsEdge:
    """Edge case tests for _collect_numerical_imports function."""

    def test_empty_code(self):
        """Test with empty code."""
        code = ""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_star_import_from_numpy(self):
        """Test star import from numpy."""
        code = "from numpy import *"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_star_import_from_torch(self):
        """Test star import from torch."""
        code = "from torch import *"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_submodule_import_numpy(self):
        """Test importing submodule of numpy."""
        code = "import numpy.random"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_submodule_with_alias(self):
        """Test importing submodule with alias."""
        code = "import numpy.random as nr"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_from_torch_submodule_import(self):
        """Test importing from torch submodule."""
        code = "from torch.nn import Linear"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_from_torch_submodule_with_alias(self):
        """Test importing from torch submodule with alias."""
        code = "from torch.nn import Linear as L"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_math_module_import(self):
        """Test math module import."""
        code = "import math"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_from_math_import(self):
        """Test importing from math module."""
        code = "from math import sqrt, pi"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_jax_import(self):
        """Test jax module import."""
        code = "import jax"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_tensorflow_import(self):
        """Test tensorflow module import."""
        code = "import tensorflow"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_numba_import(self):
        """Test numba module import."""
        code = "import numba"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_all_numerical_modules(self):
        """Test importing all numerical modules."""
        code = """
import numpy
import torch
import numba
import jax
import tensorflow
import math
import scipy
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_duplicate_imports_same_name(self):
        """Test duplicate imports with same name."""
        code = """
import numpy
import numpy
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_duplicate_imports_different_aliases(self):
        """Test duplicate imports with different aliases."""
        code = """
import numpy as np
import numpy as numpy_lib
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_from_import_multiple_times_same_module(self):
        """Test multiple from imports from the same module."""
        code = """
from numpy import array
from numpy import zeros
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    
def my_function():
    pass

from jax import jit
from math import sqrt
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_large_ast_with_many_non_import_nodes(self):
        """Test AST with many non-import nodes and several imports."""
        code = """
import numpy as np
from torch import tensor

class MyClass:
    def __init__(self):
        self.x = 5
    
    def method1(self):
        return self.x * 2
    
    def method2(self):
        for i in range(10):
            if i > 5:
                pass

def function1():
    x = 10
    y = 20
    return x + y

def function2(a, b, c):
    if a > b:
        return a
    elif b > c:
        return b
    else:
        return c

from scipy.stats import norm
from jax import jit

x = [i for i in range(50)]
y = {k: v for k, v in enumerate(range(50))}
z = {i for i in range(50)}
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_all_numerical_modules_with_aliases(self):
        """Test all numerical modules with various aliases."""
        imports = [
            "import numpy as np",
            "import torch as t",
            "import numba as nb",
            "import jax as j",
            "import tensorflow as tf",
            "import math as m",
            "import scipy as sp",
        ]
        code = "\n".join(imports)
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_stress_test_many_star_imports(self):
        """Test many star imports from different modules."""
        imports = [
            "from numpy import *",
            "from torch import *",
            "from scipy import *",
            "from jax import *",
            "from tensorflow import *",
            "from math import *",
            "from numba import *",
        ]
        code = "\n".join(imports)
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)

    def test_complex_real_world_imports(self):
        """Test realistic complex import structure."""
        code = """
import numpy as np
import numpy.random as npr
from numpy import array, zeros, ones as np_ones
import torch
from torch.nn import Linear, Conv2d as Conv
from torch.optim import Adam
import scipy.stats as stats
from scipy.integrate import odeint
from jax import jit, vmap
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D
import math
from math import sqrt, sin as sine_function
import numba
from numba import njit
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally git merge codeflash/optimize-pr1051-2026-01-14T02.52.17

Click to see suggested changes
Suggested change
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
# import numpy or import numpy as np
module_root = alias.name.split(".")[0]
if module_root in NUMERICAL_MODULES:
# Use the alias if present, otherwise the module name
name = alias.asname if alias.asname else alias.name.split(".")[0]
numerical_names.add(name)
modules_used.add(module_root)
elif isinstance(node, ast.ImportFrom) and node.module:
module_root = node.module.split(".")[0]
if module_root in NUMERICAL_MODULES:
# from numpy import array, zeros as z
for alias in node.names:
if alias.name == "*":
# Can't track star imports, but mark the module as numerical
numerical_names.add(module_root)
else:
name = alias.asname if alias.asname else alias.name
numerical_names.add(name)
modules_used.add(module_root)
# Iterate directly through the AST body instead of using ast.walk() for better performance
nodes_to_check = [tree]
while nodes_to_check:
current = nodes_to_check.pop()
body = getattr(current, "body", None)
if body:
for node in body:
if isinstance(node, ast.Import):
for alias in node.names:
# import numpy or import numpy as np
module_root = alias.name.split(".")[0]
if module_root in NUMERICAL_MODULES:
# Use the alias if present, otherwise the module name
name = alias.asname if alias.asname else alias.name.split(".")[0]
numerical_names.add(name)
modules_used.add(module_root)
elif isinstance(node, ast.ImportFrom) and node.module:
module_root = node.module.split(".")[0]
if module_root in NUMERICAL_MODULES:
# from numpy import array, zeros as z
for alias in node.names:
if alias.name == "*":
# Can't track star imports, but mark the module as numerical
numerical_names.add(module_root)
else:
name = alias.asname if alias.asname else alias.name
numerical_names.add(name)
modules_used.add(module_root)
elif isinstance(
node,
(
ast.FunctionDef,
ast.AsyncFunctionDef,
ast.ClassDef,
ast.If,
ast.For,
ast.AsyncFor,
ast.While,
ast.With,
ast.AsyncWith,
ast.Try,
ast.ExceptHandler,
),
):
nodes_to_check.append(node)
orelse = getattr(current, "orelse", None)
if orelse:
for node in orelse:
if isinstance(
node,
(
ast.FunctionDef,
ast.AsyncFunctionDef,
ast.ClassDef,
ast.If,
ast.For,
ast.AsyncFor,
ast.While,
ast.With,
ast.AsyncWith,
ast.Try,
),
):
nodes_to_check.append(node)

Static Badge

@KRRT7

KRRT7 commented Jan 14, 2026

Copy link
Copy Markdown
Contributor

As with the other PR you can skip anything async for this, these sort of things apply mainly for CPU bound things, which asyncio doesn't do

@aseembits93

Copy link
Copy Markdown
Contributor Author

As with the other PR you can skip anything async for this, these sort of things apply mainly for CPU bound things, which asyncio doesn't do

Agreed, I have removed async visitors here. ready to review

@codeflash-ai

codeflash-ai Bot commented Jan 14, 2026

Copy link
Copy Markdown
Contributor

⚡️ Codeflash found optimizations for this PR

📄 73% (0.73x) speedup for is_numerical_code in codeflash/code_utils/code_extractor.py

⏱️ Runtime : 30.1 milliseconds 17.5 milliseconds (best of 52 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch detect-numerical-code).

Static Badge

@aseembits93 aseembits93 merged commit e27afda into main Jan 14, 2026
22 of 23 checks passed
@aseembits93 aseembits93 deleted the detect-numerical-code branch January 14, 2026 21:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants