Skip to content

⚡️ Speed up function is_numerical_code by 73% in PR #1051 (detect-numerical-code)#1056

Closed
codeflash-ai[bot] wants to merge 1 commit into
mainfrom
codeflash/optimize-pr1051-2026-01-14T21.27.58
Closed

⚡️ Speed up function is_numerical_code by 73% in PR #1051 (detect-numerical-code)#1056
codeflash-ai[bot] wants to merge 1 commit into
mainfrom
codeflash/optimize-pr1051-2026-01-14T21.27.58

Conversation

@codeflash-ai

@codeflash-ai codeflash-ai Bot commented Jan 14, 2026

Copy link
Copy Markdown
Contributor

⚡️ This pull request contains optimizations for PR #1051

If you approve this dependent PR, these changes will be merged into the original PR branch detect-numerical-code.

This PR will be automatically closed if the original PR is merged.


📄 73% (0.73x) speedup for is_numerical_code in codeflash/code_utils/code_extractor.py

⏱️ Runtime : 30.1 milliseconds 17.5 milliseconds (best of 52 runs)

📝 Explanation and details

The optimized code achieves a 72% speedup (from 30.1ms to 17.5ms) through two key optimizations:

What Changed

1. Single-Pass AST Traversal (Major Optimization)

The original code made multiple passes over the AST:

  • ast.walk(tree) to collect imports (88.4% of _collect_numerical_imports time)
  • Separate iteration through tree.body to find the function

The optimized version combines both operations in _collect_imports_and_find_function, traversing tree.body only once to both collect imports and locate the target function.

2. Early-Exit Visitor Pattern

The NumericalUsageChecker class now implements visit_Name and visit_Attribute methods:

  • Short-circuits traversal once numerical usage is detected (if self.found_numerical: return)
  • Avoids visiting remaining AST nodes after finding the first numerical reference
  • The original implementation had no visitor methods, so it traversed the entire function body even after finding numerical usage

Why This Is Faster

AST traversal cost: The line profiler shows ast.walk(tree) consumed 88.4% of _collect_numerical_imports time in the original code. Eliminating this expensive full-tree walk and replacing it with a targeted single pass through tree.body dramatically reduces overhead.

Short-circuit benefits: Test results show the optimization is most effective on large codebases:

  • test_large_scale_many_lines_and_imports_performance_and_correctness: 84.1% faster (5.87ms → 3.19ms)
  • test_large_code_file_with_many_functions: 106-111% faster (1.6ms → 0.76ms)
  • test_many_function_definitions: 96-98% faster (2.3ms → 1.2ms)

For simple cases, gains are more modest (35-45% faster) due to smaller AST sizes and less traversal overhead.

Impact on Workloads

This optimization is particularly valuable for:

  • Code analysis tools scanning large Python files with many imports and functions
  • Static analysis pipelines that need to classify functions rapidly
  • Hot paths where is_numerical_code is called repeatedly on different functions in the same file (the ast.parse cost remains, but subsequent operations are much faster)

The optimization maintains correctness across all test cases while providing consistent speedups, especially for real-world codebases with hundreds of lines and dozens of imports.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 84 Passed
🌀 Generated Regression Tests 74 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
⚙️ Click to see Existing Unit Tests
Test File::Test Function Original ⏱️ Optimized ⏱️ Speedup
test_is_numerical_code.py::TestBasicNumpyUsage.test_numpy_custom_alias 59.5μs 40.9μs 45.5%✅
test_is_numerical_code.py::TestBasicNumpyUsage.test_numpy_from_import 64.9μs 44.4μs 46.0%✅
test_is_numerical_code.py::TestBasicNumpyUsage.test_numpy_from_import_with_alias 61.8μs 42.5μs 45.5%✅
test_is_numerical_code.py::TestBasicNumpyUsage.test_numpy_with_standard_alias 59.0μs 40.9μs 44.3%✅
test_is_numerical_code.py::TestBasicNumpyUsage.test_numpy_without_alias 58.4μs 40.4μs 44.3%✅
test_is_numerical_code.py::TestClassMethods.test_classmethod_with_torch 71.6μs 50.0μs 43.3%✅
test_is_numerical_code.py::TestClassMethods.test_multiple_decorators 78.4μs 54.1μs 45.1%✅
test_is_numerical_code.py::TestClassMethods.test_regular_method_with_numpy 69.0μs 48.5μs 42.2%✅
test_is_numerical_code.py::TestClassMethods.test_regular_method_without_numerical 95.1μs 65.8μs 44.7%✅
test_is_numerical_code.py::TestClassMethods.test_staticmethod_with_numpy 71.2μs 49.3μs 44.4%✅
test_is_numerical_code.py::TestEdgeCases.test_async_function_with_numpy 30.9μs 31.9μs -3.11%⚠️
test_is_numerical_code.py::TestEdgeCases.test_default_argument_with_numpy 61.0μs 42.5μs 43.4%✅
test_is_numerical_code.py::TestEdgeCases.test_empty_code_string 7.83μs 8.19μs -4.41%⚠️
test_is_numerical_code.py::TestEdgeCases.test_empty_function 40.4μs 28.6μs 41.6%✅
test_is_numerical_code.py::TestEdgeCases.test_nonexistent_function 30.8μs 31.4μs -1.85%⚠️
test_is_numerical_code.py::TestEdgeCases.test_numpy_in_docstring_only 64.9μs 47.8μs 35.7%✅
test_is_numerical_code.py::TestEdgeCases.test_syntax_error_code 30.4μs 30.3μs 0.529%✅
test_is_numerical_code.py::TestEdgeCases.test_type_annotation_with_numpy 72.5μs 52.2μs 38.8%✅
test_is_numerical_code.py::TestFalsePositivePrevention.test_class_named_math 52.4μs 38.1μs 37.5%✅
test_is_numerical_code.py::TestFalsePositivePrevention.test_function_named_numpy 57.4μs 39.8μs 44.5%✅
test_is_numerical_code.py::TestFalsePositivePrevention.test_function_named_torch 56.8μs 39.4μs 44.1%✅
test_is_numerical_code.py::TestFalsePositivePrevention.test_variable_named_np 63.1μs 46.4μs 35.8%✅
test_is_numerical_code.py::TestJaxUsage.test_from_jax_import_numpy 60.5μs 42.1μs 43.4%✅
test_is_numerical_code.py::TestJaxUsage.test_jax_basic 59.5μs 41.7μs 42.6%✅
test_is_numerical_code.py::TestJaxUsage.test_jax_from_import 58.7μs 41.7μs 40.7%✅
test_is_numerical_code.py::TestJaxUsage.test_jax_numpy_alias 60.1μs 41.5μs 44.9%✅
test_is_numerical_code.py::TestMathUsage.test_math_aliased 60.4μs 41.5μs 45.6%✅
test_is_numerical_code.py::TestMathUsage.test_math_basic 58.5μs 40.3μs 45.1%✅
test_is_numerical_code.py::TestMathUsage.test_math_from_import 82.9μs 52.6μs 57.8%✅
test_is_numerical_code.py::TestMultipleLibraries.test_numpy_and_torch 83.7μs 55.9μs 49.9%✅
test_is_numerical_code.py::TestMultipleLibraries.test_scipy_and_numpy 84.6μs 57.2μs 48.0%✅
test_is_numerical_code.py::TestNestedUsage.test_numpy_in_conditional 91.0μs 64.9μs 40.1%✅
test_is_numerical_code.py::TestNestedUsage.test_numpy_in_lambda 78.5μs 54.3μs 44.4%✅
test_is_numerical_code.py::TestNestedUsage.test_numpy_in_list_comprehension 75.1μs 51.8μs 45.1%✅
test_is_numerical_code.py::TestNestedUsage.test_numpy_in_try_except 85.0μs 61.3μs 38.7%✅
test_is_numerical_code.py::TestNoNumericalUsage.test_class_method_without_numerical 61.8μs 44.7μs 38.3%✅
test_is_numerical_code.py::TestNoNumericalUsage.test_list_operations 71.1μs 51.8μs 37.3%✅
test_is_numerical_code.py::TestNoNumericalUsage.test_simple_function 55.7μs 40.1μs 38.9%✅
test_is_numerical_code.py::TestNoNumericalUsage.test_string_manipulation 62.8μs 45.5μs 38.0%✅
test_is_numerical_code.py::TestNoNumericalUsage.test_with_non_numerical_imports 79.1μs 55.7μs 41.8%✅
test_is_numerical_code.py::TestNumbaNotAvailable.test_jax_returns_true_without_numba 58.0μs 40.5μs 43.4%✅
test_is_numerical_code.py::TestNumbaNotAvailable.test_math_from_import_returns_false_without_numba 84.1μs 53.2μs 58.0%✅
test_is_numerical_code.py::TestNumbaNotAvailable.test_math_returns_false_without_numba 58.9μs 40.5μs 45.5%✅
test_is_numerical_code.py::TestNumbaNotAvailable.test_numba_import_returns_true_without_numba 68.7μs 50.1μs 37.0%✅
test_is_numerical_code.py::TestNumbaNotAvailable.test_numpy_and_jax_returns_true_without_numba 83.7μs 57.4μs 45.7%✅
test_is_numerical_code.py::TestNumbaNotAvailable.test_numpy_and_torch_returns_true_without_numba 84.0μs 56.8μs 47.9%✅
test_is_numerical_code.py::TestNumbaNotAvailable.test_numpy_returns_false_without_numba 60.0μs 41.2μs 45.5%✅
test_is_numerical_code.py::TestNumbaNotAvailable.test_numpy_submodule_returns_false_without_numba 60.6μs 42.5μs 42.4%✅
test_is_numerical_code.py::TestNumbaNotAvailable.test_scipy_and_tensorflow_returns_true_without_numba 85.1μs 57.8μs 47.2%✅
test_is_numerical_code.py::TestNumbaNotAvailable.test_scipy_returns_false_without_numba 61.3μs 42.3μs 44.9%✅
test_is_numerical_code.py::TestNumbaNotAvailable.test_tensorflow_returns_true_without_numba 54.7μs 38.0μs 44.1%✅
test_is_numerical_code.py::TestNumbaNotAvailable.test_torch_returns_true_without_numba 63.7μs 42.9μs 48.5%✅
test_is_numerical_code.py::TestNumbaUsage.test_numba_basic 69.3μs 49.9μs 39.0%✅
test_is_numerical_code.py::TestNumbaUsage.test_numba_cuda 58.0μs 40.6μs 42.8%✅
test_is_numerical_code.py::TestNumbaUsage.test_numba_jit_decorator 68.1μs 49.5μs 37.6%✅
test_is_numerical_code.py::TestNumpySubmodules.test_from_numpy_import_submodule 60.0μs 41.7μs 43.6%✅
test_is_numerical_code.py::TestNumpySubmodules.test_from_numpy_linalg_import_function 58.0μs 41.2μs 40.8%✅
test_is_numerical_code.py::TestNumpySubmodules.test_numpy_linalg_aliased 59.0μs 41.3μs 42.9%✅
test_is_numerical_code.py::TestNumpySubmodules.test_numpy_linalg_direct 61.0μs 41.9μs 45.4%✅
test_is_numerical_code.py::TestNumpySubmodules.test_numpy_random_aliased 58.8μs 40.8μs 44.0%✅
test_is_numerical_code.py::TestQualifiedNames.test_class_dot_method 71.6μs 50.4μs 42.1%✅
test_is_numerical_code.py::TestQualifiedNames.test_invalid_qualified_name_too_deep 37.6μs 37.8μs -0.344%⚠️
test_is_numerical_code.py::TestQualifiedNames.test_method_in_wrong_class 156μs 102μs 52.7%✅
test_is_numerical_code.py::TestQualifiedNames.test_simple_function_name 59.3μs 40.7μs 45.7%✅
test_is_numerical_code.py::TestScipyUsage.test_scipy_basic 67.2μs 45.8μs 46.7%✅
test_is_numerical_code.py::TestScipyUsage.test_scipy_optimize_alias 68.4μs 47.6μs 43.7%✅
test_is_numerical_code.py::TestScipyUsage.test_scipy_stats 61.4μs 42.8μs 43.6%✅
test_is_numerical_code.py::TestScipyUsage.test_scipy_stats_from_import 61.3μs 43.3μs 41.5%✅
test_is_numerical_code.py::TestStarImports.test_star_import_bare_name_not_detected 59.4μs 42.5μs 39.9%✅
test_is_numerical_code.py::TestStarImports.test_star_import_math_bare_name_not_detected 59.1μs 42.5μs 39.1%✅
test_is_numerical_code.py::TestStarImports.test_star_import_with_module_reference 65.8μs 45.6μs 44.3%✅
test_is_numerical_code.py::TestTensorflowUsage.test_tensorflow_basic 54.7μs 38.4μs 42.4%✅
test_is_numerical_code.py::TestTensorflowUsage.test_tensorflow_from_import 54.1μs 38.0μs 42.5%✅
test_is_numerical_code.py::TestTensorflowUsage.test_tensorflow_keras_alias 53.7μs 38.3μs 40.4%✅
test_is_numerical_code.py::TestTensorflowUsage.test_tensorflow_keras_layers_alias 58.0μs 40.8μs 42.2%✅
test_is_numerical_code.py::TestTensorflowUsage.test_tensorflow_standard_alias 54.9μs 37.8μs 45.3%✅
test_is_numerical_code.py::TestTorchUsage.test_torch_basic 64.8μs 43.7μs 48.4%✅
test_is_numerical_code.py::TestTorchUsage.test_torch_from_import 59.6μs 42.8μs 39.0%✅
test_is_numerical_code.py::TestTorchUsage.test_torch_from_import_aliased 59.8μs 43.8μs 36.5%✅
test_is_numerical_code.py::TestTorchUsage.test_torch_functional_alias 61.4μs 42.7μs 43.7%✅
test_is_numerical_code.py::TestTorchUsage.test_torch_nn_alias 59.2μs 41.8μs 41.6%✅
test_is_numerical_code.py::TestTorchUsage.test_torch_standard_alias 59.1μs 41.5μs 42.5%✅
test_is_numerical_code.py::TestTorchUsage.test_torch_utils_data 59.4μs 41.7μs 42.4%✅
🌀 Click to see Generated Regression Tests
import ast
from importlib.util import find_spec

# imports
from codeflash.code_utils.code_extractor import is_numerical_code

# function to test
# (Module code reproduced here so the tests run against the real implementation)
has_numba = find_spec("numba") is not None

NUMERICAL_MODULES = frozenset({"numpy", "torch", "numba", "jax", "tensorflow", "math", "scipy"})
# Modules that require numba to be installed for optimization
NUMBA_REQUIRED_MODULES = frozenset({"numpy", "math", "scipy"})


def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]:
    """Collect names that reference numerical computing libraries from imports.

    Returns:
        A tuple of (numerical_names, modules_used) where:
        - numerical_names: set of names/aliases that reference numerical libraries
        - modules_used: set of actual module names (e.g., "numpy", "math") being imported

    """
    numerical_names: set[str] = set()
    modules_used: set[str] = set()

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # import numpy or import numpy as np
                module_root = alias.name.split(".")[0]
                if module_root in NUMERICAL_MODULES:
                    # Use the alias if present, otherwise the module name
                    name = alias.asname if alias.asname else alias.name.split(".")[0]
                    numerical_names.add(name)
                    modules_used.add(module_root)
        elif isinstance(node, ast.ImportFrom) and node.module:
            module_root = node.module.split(".")[0]
            if module_root in NUMERICAL_MODULES:
                # from numpy import array, zeros as z
                for alias in node.names:
                    if alias.name == "*":
                        # Can't track star imports, but mark the module as numerical
                        numerical_names.add(module_root)
                    else:
                        name = alias.asname if alias.asname else alias.name
                        numerical_names.add(name)
                modules_used.add(module_root)

    return numerical_names, modules_used


def _find_function_node(tree: ast.Module, name_parts: list[str]) -> ast.FunctionDef | None:
    """Find a function node in the AST given its qualified name parts.

    Note: This function only finds regular (sync) functions, not async functions.

    Args:
        tree: The parsed AST module
        name_parts: List of name parts, e.g., ["ClassName", "method_name"] or ["function_name"]

    Returns:
        The function node if found, None otherwise

    """
    if not name_parts:
        return None

    if len(name_parts) == 1:
        # Top-level function
        func_name = name_parts[0]
        for node in tree.body:
            if isinstance(node, ast.FunctionDef) and node.name == func_name:
                return node
        return None

    if len(name_parts) == 2:
        # Class method: ClassName.method_name
        class_name, method_name = name_parts
        for node in tree.body:
            if isinstance(node, ast.ClassDef) and node.name == class_name:
                for class_node in node.body:
                    if isinstance(class_node, ast.FunctionDef) and class_node.name == method_name:
                        return class_node
        return None

    return None


# unit tests

# Write your test functions here, e.g.:
# def test_basic_functionality():
#     ...
# The tests below are ordered from basic to more complex and include edge and large-scale scenarios.
# All tests are deterministic and avoid mocking the function under test or its internals.


def test_basic_non_numerical_function_returns_false():
    # Basic: a simple top-level function that does not use any imports or numerical libraries.
    code = """
def add_one(x):
    return x + 1
"""
    # Expect False because there's no numerical usage detected.
    codeflash_output = is_numerical_code(code, "add_one")  # 55.6μs -> 40.8μs (36.3% faster)


def test_syntax_error_in_code_returns_false():
    # Edge: code string with a syntax error should return False and not raise.
    code = "def broken(:\n    pass"
    # The function should gracefully return False for unparsable code.
    codeflash_output = is_numerical_code(code, "broken")  # 26.3μs -> 25.8μs (1.87% faster)


def test_function_name_not_present_returns_false():
    # Edge: function_name not found in the module should return False.
    code = """
def some_other_function():
    pass
"""
    # Request a non-existent function name.
    codeflash_output = is_numerical_code(code, "missing_function")  # 17.8μs -> 17.8μs (0.225% faster)


def test_imports_numpy_but_checking_still_returns_false_due_to_missing_detection():
    # Basic: even if numpy is imported and used in the function, current implementation's checker
    # does not mark found_numerical (NumericalUsageChecker has no logic), so the result is False.
    code = """
import numpy as np

def process(arr):
    return np.sum(arr)
"""
    codeflash_output = is_numerical_code(code, "process")  # 63.2μs -> 43.9μs (43.9% faster)


def test_from_import_star_with_numerical_module_returns_false():
    # Edge: star-import from a numerical module, function uses a symbol coming from the star import.
    code = """
from numpy import *

def f(a):
    return array(a)  # array would come from the star import
"""
    codeflash_output = is_numerical_code(code, "f")  # 61.3μs -> 43.2μs (42.0% faster)


def test_class_method_detection_returns_false():
    # Basic/Edge: class method resolution (ClassName.method_name) is supported by _find_function_node.
    # Here, even though we import torch and use it inside the method, the overall detection still returns False.
    code = """
import torch

class MyModel:
    def forward(self, x):
        return torch.relu(x)
"""
    # The function does exist and is a class method; the function should be found but numerical detection returns False.
    codeflash_output = is_numerical_code(code, "MyModel.forward")  # 70.2μs -> 48.1μs (46.1% faster)


def test_async_function_not_detected_returns_false():
    # Edge: async functions are not handled by _find_function_node (it only checks ast.FunctionDef),
    # so requesting an async function should return False.
    code = """
async def coro(x):
    return x
"""
    codeflash_output = is_numerical_code(code, "coro")  # 22.0μs -> 21.4μs (3.00% faster)


def test_qualified_name_with_too_many_parts_returns_false():
    # Edge: _find_function_node only supports 1 or 2 parts. Providing 3-part qualified name should return False.
    code = """
class A:
    class B:
        def method(self):
            return 1
"""
    codeflash_output = is_numerical_code(code, "A.B.method")  # 26.6μs -> 26.4μs (0.677% faster)


def test_empty_function_name_returns_false():
    # Edge: empty function name results in no match and should return False.
    code = """
def foo():
    return 42
"""
    codeflash_output = is_numerical_code(code, "")  # 19.6μs -> 19.6μs (0.302% faster)


def test_multiple_import_aliases_and_non_numerical_modules_returns_false():
    # Basic: mix of numerical and non-numerical imports with aliases; function uses only non-numerical code.
    code = """
import os as operating_system
import numpy as np
from math import sqrt as msqrt

def helper(x):
    return operating_system.path.exists(str(x))
"""
    # Even though numpy and math are imported, the function does not use them; detection should be False.
    codeflash_output = is_numerical_code(code, "helper")  # 91.9μs -> 65.2μs (41.0% faster)


def test_numba_dependency_branch_unchanged_when_no_numerical_found(monkeypatch):
    # Edge: Demonstrate that changes to has_numba global do not affect result when no numerical usage is found.
    original = globals().get("has_numba", None)
    try:
        # Force has_numba to both True and False and ensure result stays False for a simple non-numerical function.
        globals()["has_numba"] = True
        code = "def a():\n    return 1"
        codeflash_output = is_numerical_code(code, "a")

        globals()["has_numba"] = False
        codeflash_output = is_numerical_code(code, "a")
    finally:
        # Restore original value to avoid side effects for other tests / environment.
        if original is None:
            globals().pop("has_numba", None)
        else:
            globals()["has_numba"] = original


def test_large_scale_many_lines_and_imports_performance_and_correctness():
    # Large scale: generate a large module with many dummy functions and many imports (but < 1000)
    # to ensure the parser and check function run within reasonable resources and remain deterministic.
    num_dummy_funcs = 600  # below 1000 as required
    header = []
    # Add many (non-numerical) imports and some numerical imports to exercise import collection.
    header.append("import os\n")
    header.append("import sys\n")
    header.append("import numpy as np\n")  # numerical import present
    header.append("from math import sqrt\n")  # numerical import present
    # Append many dummy function definitions
    body_lines = []
    for i in range(num_dummy_funcs):
        body_lines.append(f"def dummy_func_{i}():\n    return {i}\n")
    # Add the target function somewhere in the middle that uses a non-numerical operation.
    target_func = """
def target(x):
    # intentionally simple non-numerical function to assert detection remains False
    y = x * 2
    return y
"""
    # Build the full code string
    code = "".join(header) + "".join(body_lines) + target_func
    # This should locate the function quickly and return False (no numerical detection present in the body).
    codeflash_output = is_numerical_code(code, "target")  # 5.87ms -> 3.19ms (84.1% faster)


def test_mangled_source_with_comments_and_strings_does_not_raise_and_returns_false():
    # Edge: code that contains numerical module names inside strings/comments should not be falsely parsed as usage.
    code = """
# This comment mentions numpy and torch but should not count as usage: numpy, torch
def f():
    s = "this string mentions numpy and scipy and math"
    return s
"""
    # The function should be found and the presence of module names in comments/strings should not trigger detection.
    codeflash_output = is_numerical_code(code, "f")  # 59.7μs -> 43.4μs (37.7% faster)


# Additional sanity test: ensure that when the target function is missing but there are multiple definitions,
# the function still returns False and does not raise.
def test_multiple_definitions_but_target_missing_returns_false():
    code = """
def a(): pass
def b(): pass
def c(): pass
"""
    codeflash_output = is_numerical_code(code, "z")  # 23.7μs -> 24.2μs (2.08% slower)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
# imports
from codeflash.code_utils.code_extractor import is_numerical_code

# ============================================================================
# BASIC TEST CASES
# ============================================================================


def test_basic_numpy_usage():
    """Test basic numpy usage detection."""
    code = """
import numpy as np
def process_data(x):
    return np.sum(x)
"""
    # With numba installed, should return True for numpy usage
    codeflash_output = is_numerical_code(code, "process_data")
    result = codeflash_output  # 66.2μs -> 46.1μs (43.6% faster)


def test_basic_torch_usage():
    """Test basic torch usage detection."""
    code = """
import torch
def compute_tensor(x):
    return torch.sum(x)
"""
    # torch doesn't require numba, should return True
    codeflash_output = is_numerical_code(code, "compute_tensor")
    result = codeflash_output  # 62.0μs -> 42.4μs (46.1% faster)


def test_no_numerical_usage():
    """Test function with no numerical library usage."""
    code = """
def simple_func(x):
    return x + 1
"""
    codeflash_output = is_numerical_code(code, "simple_func")
    result = codeflash_output  # 56.7μs -> 41.2μs (37.6% faster)


def test_jax_usage():
    """Test basic jax usage detection."""
    code = """
import jax.numpy as jnp
def compute_with_jax(x):
    return jnp.sum(x)
"""
    # jax doesn't require numba, should return True
    codeflash_output = is_numerical_code(code, "compute_with_jax")
    result = codeflash_output  # 62.7μs -> 43.4μs (44.5% faster)


def test_tensorflow_usage():
    """Test basic tensorflow usage detection."""
    code = """
import tensorflow as tf
def model_forward(x):
    return tf.reduce_sum(x)
"""
    # tensorflow doesn't require numba, should return True
    codeflash_output = is_numerical_code(code, "model_forward")
    result = codeflash_output  # 60.3μs -> 41.1μs (46.9% faster)


def test_function_not_found():
    """Test when function name doesn't exist in code."""
    code = """
def existing_func(x):
    return x + 1
"""
    codeflash_output = is_numerical_code(code, "nonexistent_func")
    result = codeflash_output  # 24.3μs -> 23.8μs (2.32% faster)


def test_invalid_syntax():
    """Test with invalid Python syntax."""
    code = """
def broken_func(x
    return x + 1
"""
    codeflash_output = is_numerical_code(code, "broken_func")
    result = codeflash_output  # 37.2μs -> 36.9μs (0.762% faster)


def test_method_detection():
    """Test detecting numerical usage in class methods."""
    code = """
import numpy as np
class DataProcessor:
    def process(self, x):
        return np.mean(x)
"""
    # Should handle method name notation
    codeflash_output = is_numerical_code(code, "DataProcessor.process")
    result = codeflash_output  # 72.5μs -> 51.1μs (42.0% faster)


def test_staticmethod_detection():
    """Test detecting numerical usage in static methods."""
    code = """
import torch
class TorchHelper:
    @staticmethod
    def compute(x):
        return torch.tensor(x)
"""
    codeflash_output = is_numerical_code(code, "TorchHelper.compute")
    result = codeflash_output  # 71.3μs -> 48.6μs (46.5% faster)


def test_classmethod_detection():
    """Test detecting numerical usage in class methods."""
    code = """
import jax.numpy as jnp
class JaxCompute:
    @classmethod
    def compute(cls, x):
        return jnp.array(x)
"""
    codeflash_output = is_numerical_code(code, "JaxCompute.compute")
    result = codeflash_output  # 75.4μs -> 51.9μs (45.2% faster)


# ============================================================================
# EDGE TEST CASES
# ============================================================================


def test_empty_code_string():
    """Test with empty code string."""
    code = ""
    codeflash_output = is_numerical_code(code, "any_func")
    result = codeflash_output  # 7.83μs -> 8.26μs (5.10% slower)


def test_code_with_no_functions():
    """Test with code containing no functions."""
    code = """
x = 10
y = 20
"""
    codeflash_output = is_numerical_code(code, "nonexistent")
    result = codeflash_output  # 20.1μs -> 20.5μs (1.81% slower)


def test_nested_class_method_not_found():
    """Test with nested class notation that doesn't exist."""
    code = """
class Outer:
    class Inner:
        def method(self):
            pass
"""
    # Only supports one level of nesting (ClassName.method_name)
    codeflash_output = is_numerical_code(code, "Outer.Inner.method")
    result = codeflash_output  # 25.2μs -> 24.9μs (1.37% faster)


def test_import_as_alias():
    """Test numpy imported with different alias."""
    code = """
import numpy as numerical
def func(x):
    return numerical.sum(x)
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 62.9μs -> 44.0μs (43.0% faster)


def test_from_import_function():
    """Test importing specific function from numerical library."""
    code = """
from numpy import sum as np_sum
def func(x):
    return np_sum(x)
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 60.1μs -> 42.7μs (40.9% faster)


def test_from_import_star():
    """Test star import from numerical library."""
    code = """
from numpy import *
def func(x):
    return sum(x)
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 60.0μs -> 42.7μs (40.6% faster)


def test_multiple_imports():
    """Test function with multiple numerical imports."""
    code = """
import numpy as np
import torch
def func(x):
    a = np.array(x)
    b = torch.tensor(a)
    return b
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 92.8μs -> 64.0μs (45.1% faster)


def test_scipy_usage():
    """Test scipy usage detection."""
    code = """
from scipy import optimize
def fit_data(x):
    return optimize.curve_fit(None, x, x)
"""
    codeflash_output = is_numerical_code(code, "fit_data")
    result = codeflash_output  # 67.8μs -> 45.7μs (48.4% faster)


def test_math_module_usage():
    """Test math module usage detection."""
    code = """
import math
def compute(x):
    return math.sqrt(x)
"""
    codeflash_output = is_numerical_code(code, "compute")
    result = codeflash_output  # 59.3μs -> 41.1μs (44.2% faster)


def test_numba_module_usage():
    """Test numba module usage detection."""
    code = """
from numba import jit
@jit
def compute(x):
    return x * 2
"""
    codeflash_output = is_numerical_code(code, "compute")
    result = codeflash_output  # 68.8μs -> 49.1μs (40.1% faster)


def test_numerical_import_not_used():
    """Test when numerical library is imported but not used."""
    code = """
import numpy as np
def func(x):
    return x + 1
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 62.5μs -> 45.1μs (38.7% faster)


def test_non_numerical_function_in_file_with_numerical():
    """Test non-numerical function in file containing numerical imports."""
    code = """
import numpy as np
def numerical_func(x):
    return np.sum(x)

def non_numerical_func(x):
    return x + 1
"""
    codeflash_output = is_numerical_code(code, "non_numerical_func")
    result = codeflash_output  # 84.7μs -> 58.6μs (44.4% faster)


def test_empty_function_body():
    """Test function with empty body (pass)."""
    code = """
import numpy as np
def empty_func():
    pass
"""
    codeflash_output = is_numerical_code(code, "empty_func")
    result = codeflash_output  # 41.1μs -> 29.2μs (40.6% faster)


def test_function_with_numerical_in_comment():
    """Test function that mentions numerical library only in comment."""
    code = """
def func(x):
    # This function uses numpy
    return x + 1
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 55.5μs -> 40.4μs (37.5% faster)


def test_numerical_in_string_literal():
    """Test function that mentions numerical library only in string."""
    code = """
def func(x):
    s = "I use numpy"
    return x + 1
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 67.1μs -> 49.6μs (35.4% faster)


def test_subdomain_import():
    """Test importing from subdomain of numerical library."""
    code = """
import numpy.random as npr
def func(x):
    return npr.random()
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 59.9μs -> 43.1μs (39.0% faster)


def test_qualified_attribute_access():
    """Test accessing attributes through qualified names."""
    code = """
import numpy.linalg as la
def func(x):
    return la.norm(x)
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 61.2μs -> 42.9μs (42.8% faster)


def test_function_with_only_imports():
    """Test function that only has import statements."""
    code = """
def func():
    import numpy as np
    import torch
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 46.3μs -> 32.3μs (43.3% faster)


def test_function_name_case_sensitive():
    """Test that function name matching is case-sensitive."""
    code = """
import numpy as np
def MyFunction(x):
    return np.sum(x)
"""
    codeflash_output = is_numerical_code(code, "myfunction")
    result = codeflash_output  # 30.1μs -> 31.4μs (3.96% slower)


def test_long_function_name():
    """Test with very long function name."""
    code = """
import numpy as np
def very_long_function_name_that_is_still_valid_python(x):
    return np.sum(x)
"""
    codeflash_output = is_numerical_code(code, "very_long_function_name_that_is_still_valid_python")
    result = codeflash_output  # 60.6μs -> 41.6μs (45.7% faster)


def test_special_characters_in_code():
    """Test code with special Unicode characters in strings."""
    code = """
import numpy as np
def func(x):
    print("Hello 世界")
    return np.sum(x)
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 80.8μs -> 57.3μs (41.0% faster)


# ============================================================================
# LARGE SCALE TEST CASES
# ============================================================================


def test_large_code_file_with_many_functions():
    """Test with a large code file containing many functions."""
    # Build code with 100 functions (staying under 1000 steps)
    code_lines = []
    code_lines.append("import numpy as np")
    code_lines.append("import torch")
    code_lines.append("")

    # Add 50 non-numerical functions
    for i in range(50):
        code_lines.append(f"def non_numerical_func_{i}(x):")
        code_lines.append("    return x + 1")
        code_lines.append("")

    # Add 50 numerical functions
    for i in range(50):
        code_lines.append(f"def numerical_func_{i}(x):")
        code_lines.append("    return np.sum(x)")
        code_lines.append("")

    code = "\n".join(code_lines)

    # Test non-numerical function
    codeflash_output = is_numerical_code(code, "non_numerical_func_25")
    result1 = codeflash_output  # 1.64ms -> 792μs (106% faster)

    # Test numerical function
    codeflash_output = is_numerical_code(code, "numerical_func_25")
    result2 = codeflash_output  # 1.60ms -> 758μs (111% faster)


def test_large_code_file_with_many_classes():
    """Test with a large code file containing many classes."""
    code_lines = []
    code_lines.append("import torch")
    code_lines.append("")

    # Add 50 classes with methods (staying under 500 lines)
    for i in range(50):
        code_lines.append(f"class Class_{i}:")
        code_lines.append("    def method(self, x):")
        if i % 2 == 0:
            code_lines.append("        return torch.sum(x)")
        else:
            code_lines.append("        return x + 1")
        code_lines.append("")

    code = "\n".join(code_lines)

    # Test class with numerical method
    codeflash_output = is_numerical_code(code, "Class_10.method")
    result = codeflash_output  # 1.11ms -> 559μs (98.3% faster)

    # Test class without numerical method
    codeflash_output = is_numerical_code(code, "Class_11.method")
    result2 = codeflash_output  # 1.08ms -> 539μs (100% faster)


def test_large_function_body():
    """Test function with large body (many statements)."""
    code_lines = []
    code_lines.append("import numpy as np")
    code_lines.append("def large_func(x):")

    # Add 200 statements (staying under 1000)
    for i in range(200):
        if i == 150:
            code_lines.append("    result = np.sum(x)")
        else:
            code_lines.append(f"    var_{i} = {i}")

    code = "\n".join(code_lines)
    codeflash_output = is_numerical_code(code, "large_func")
    result = codeflash_output  # 1.78ms -> 1.19ms (49.0% faster)


def test_deeply_nested_code():
    """Test function with deeply nested control flow."""
    code = """
import torch
def nested_func(x):
    for i in range(10):
        if i > 5:
            for j in range(10):
                if j > 3:
                    while j > 0:
                        try:
                            result = torch.sum(x)
                        except:
                            pass
                        j -= 1
    return result
"""
    codeflash_output = is_numerical_code(code, "nested_func")
    result = codeflash_output  # 189μs -> 134μs (41.0% faster)


def test_many_imports():
    """Test with many import statements."""
    code_lines = []

    # Add 100 imports (50 regular, 50 numerical)
    for i in range(50):
        code_lines.append(f"import module_{i}")

    code_lines.append("import numpy as np")
    code_lines.append("import torch")

    for i in range(50, 100):
        code_lines.append(f"import module_{i}")

    code_lines.append("def func(x):")
    code_lines.append("    return np.sum(x)")

    code = "\n".join(code_lines)
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 405μs -> 248μs (63.2% faster)


def test_many_function_definitions():
    """Test file with 200 function definitions."""
    code_lines = []
    code_lines.append("import numpy as np")
    code_lines.append("")

    # Add 200 function definitions
    for i in range(200):
        code_lines.append(f"def func_{i}(x):")
        if i == 100:
            code_lines.append("    return np.sum(x)")
        else:
            code_lines.append("    return x")
        code_lines.append("")

    code = "\n".join(code_lines)

    # Test the numerical function among many
    codeflash_output = is_numerical_code(code, "func_100")
    result = codeflash_output  # 2.31ms -> 1.18ms (96.0% faster)

    # Test a non-numerical function
    codeflash_output = is_numerical_code(code, "func_50")
    result2 = codeflash_output  # 2.26ms -> 1.14ms (97.9% faster)


def test_performance_with_large_string_code():
    """Test that function handles large code string efficiently."""
    code_lines = []
    code_lines.append("import scipy")
    code_lines.append("")

    # Build a large code string with 500 lines
    for i in range(500):
        code_lines.append(f"# Comment line {i}")

    code_lines.append("def target_func(x):")
    code_lines.append("    return scipy.stats.norm.pdf(x)")

    code = "\n".join(code_lines)

    # Should complete without timeout/hanging
    codeflash_output = is_numerical_code(code, "target_func")
    result = codeflash_output  # 86.5μs -> 64.5μs (34.1% faster)


def test_mixed_imports_and_functions_scaled():
    """Test scaled version with mixed imports and functions."""
    code_lines = []

    # Add multiple import styles
    import_styles = [
        "import numpy as np",
        "import torch",
        "from scipy import optimize",
        "from jax import numpy as jnp",
        "import tensorflow as tf",
        "from numba import jit",
        "import math",
    ]

    for style in import_styles:
        code_lines.append(style)

    code_lines.append("")

    # Add 50 functions that use different libraries
    libraries = ["np", "torch", "optimize", "jnp", "tf", "jit", "math"]
    for i in range(50):
        code_lines.append(f"def func_{i}(x):")
        lib = libraries[i % len(libraries)]
        if lib == "np":
            code_lines.append("    return np.sum(x)")
        elif lib == "torch":
            code_lines.append("    return torch.sum(x)")
        elif lib == "optimize":
            code_lines.append("    return optimize.minimize(lambda x: x**2, x)")
        elif lib == "jnp":
            code_lines.append("    return jnp.sum(x)")
        elif lib == "tf":
            code_lines.append("    return tf.reduce_sum(x)")
        elif lib == "jit":
            code_lines.append("    return jit(lambda x: x)(x)")
        elif lib == "math":
            code_lines.append("    return math.sqrt(x)")
        code_lines.append("")

    code = "\n".join(code_lines)

    # Test several functions across the file
    codeflash_output = is_numerical_code(code, "func_0")
    result1 = codeflash_output  # 1.10ms -> 530μs (107% faster)

    codeflash_output = is_numerical_code(code, "func_25")
    result2 = codeflash_output  # 1.06ms -> 504μs (111% faster)

    codeflash_output = is_numerical_code(code, "func_49")
    result3 = codeflash_output  # 1.06ms -> 495μs (113% faster)


# ============================================================================
# ADDITIONAL COMPREHENSIVE TESTS
# ============================================================================


def test_return_type_is_boolean():
    """Ensure function always returns a boolean type."""
    test_cases = [
        ("import numpy as np\ndef f(x): return np.sum(x)", "f"),
        ("def f(x): return x + 1", "f"),
        ("", "nonexistent"),
        ("def f(x)\n return x", "f"),
    ]

    for code, func_name in test_cases:
        codeflash_output = is_numerical_code(code, func_name)
        result = codeflash_output  # 124μs -> 91.1μs (36.6% faster)


def test_consistent_results():
    """Test that repeated calls with same input give same result."""
    code = """
import numpy as np
def func(x):
    return np.sum(x)
"""

    # Call multiple times
    results = [is_numerical_code(code, "func") for _ in range(5)]  # 61.3μs -> 42.5μs (44.2% faster)


def test_different_indentation_styles():
    """Test with different indentation levels and styles."""
    code = """
import torch
def func(x):
  if True:
    return torch.sum(x)
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 72.5μs -> 51.3μs (41.2% faster)


def test_multiple_statements_on_one_line():
    """Test with multiple statements on single line (though not recommended)."""
    code = """
import numpy as np
def func(x): return np.sum(x); pass
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 62.2μs -> 42.8μs (45.3% faster)


def test_backslash_continuation():
    """Test with line continuation using backslash."""
    code = """
import numpy as np
def func(x):
    result = np.\\
             sum(x)
    return result
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 72.5μs -> 49.8μs (45.5% faster)


def test_parenthesis_continuation():
    """Test with line continuation using parenthesis."""
    code = """
import numpy as np
def func(x):
    result = (
        np.sum(x)
    )
    return result
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 73.6μs -> 51.4μs (43.1% faster)


def test_numeric_function_name():
    """Test with numeric characters in function name."""
    code = """
import torch
def func_v2_0(x):
    return torch.sum(x)
"""
    codeflash_output = is_numerical_code(code, "func_v2_0")
    result = codeflash_output  # 59.5μs -> 41.0μs (45.3% faster)


def test_underscore_function_name():
    """Test with underscores in function name."""
    code = """
import jax.numpy as jnp
def __special_func__(x):
    return jnp.sum(x)
"""
    codeflash_output = is_numerical_code(code, "__special_func__")
    result = codeflash_output  # 61.1μs -> 42.3μs (44.5% faster)


def test_lambda_in_function():
    """Test function containing lambda expressions."""
    code = """
import numpy as np
def func(x):
    f = lambda y: np.sum(y)
    return f(x)
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 86.9μs -> 59.2μs (46.9% faster)


def test_list_comprehension_with_numerical():
    """Test function with list comprehension using numerical library."""
    code = """
import numpy as np
def func(x):
    return [np.sum(xi) for xi in x]
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 74.3μs -> 52.5μs (41.5% faster)


def test_dict_comprehension_with_numerical():
    """Test function with dict comprehension using numerical library."""
    code = """
import torch
def func(x):
    return {i: torch.sum(xi) for i, xi in enumerate(x)}
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 90.8μs -> 62.1μs (46.2% faster)


def test_set_comprehension_with_numerical():
    """Test function with set comprehension using numerical library."""
    code = """
import scipy
def func(x):
    return {scipy.stats.norm.pdf(xi) for xi in x}
"""
    codeflash_output = is_numerical_code(code, "func")
    result = codeflash_output  # 79.7μs -> 54.6μs (45.9% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-pr1051-2026-01-14T21.27.58 and push.

Codeflash Static Badge

The optimized code achieves a **72% speedup** (from 30.1ms to 17.5ms) through two key optimizations:

## What Changed

### 1. Single-Pass AST Traversal (Major Optimization)
The original code made **multiple passes** over the AST:
- `ast.walk(tree)` to collect imports (88.4% of `_collect_numerical_imports` time)
- Separate iteration through `tree.body` to find the function

The optimized version **combines both operations** in `_collect_imports_and_find_function`, traversing `tree.body` only once to both collect imports and locate the target function.

### 2. Early-Exit Visitor Pattern
The `NumericalUsageChecker` class now implements `visit_Name` and `visit_Attribute` methods:
- **Short-circuits traversal** once numerical usage is detected (`if self.found_numerical: return`)
- Avoids visiting remaining AST nodes after finding the first numerical reference
- The original implementation had no visitor methods, so it traversed the entire function body even after finding numerical usage

## Why This Is Faster

**AST traversal cost**: The line profiler shows `ast.walk(tree)` consumed 88.4% of `_collect_numerical_imports` time in the original code. Eliminating this expensive full-tree walk and replacing it with a targeted single pass through `tree.body` dramatically reduces overhead.

**Short-circuit benefits**: Test results show the optimization is most effective on large codebases:
- `test_large_scale_many_lines_and_imports_performance_and_correctness`: **84.1% faster** (5.87ms → 3.19ms)
- `test_large_code_file_with_many_functions`: **106-111% faster** (1.6ms → 0.76ms)
- `test_many_function_definitions`: **96-98% faster** (2.3ms → 1.2ms)

For simple cases, gains are more modest (35-45% faster) due to smaller AST sizes and less traversal overhead.

## Impact on Workloads

This optimization is particularly valuable for:
- **Code analysis tools** scanning large Python files with many imports and functions
- **Static analysis pipelines** that need to classify functions rapidly
- **Hot paths** where `is_numerical_code` is called repeatedly on different functions in the same file (the `ast.parse` cost remains, but subsequent operations are much faster)

The optimization maintains correctness across all test cases while providing consistent speedups, especially for real-world codebases with hundreds of lines and dozens of imports.
@codeflash-ai codeflash-ai Bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Jan 14, 2026
@codeflash-ai codeflash-ai Bot closed this Jan 14, 2026
@codeflash-ai

codeflash-ai Bot commented Jan 14, 2026

Copy link
Copy Markdown
Contributor Author

This PR has been automatically closed because the original PR #1051 by aseembits93 was closed.

Base automatically changed from detect-numerical-code to main January 14, 2026 21:35
@codeflash-ai codeflash-ai Bot deleted the codeflash/optimize-pr1051-2026-01-14T21.27.58 branch January 14, 2026 21:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants