AST Parser for detecting numerical code (numpy/scipy/torch/tensorflow/jax/numba) #1051
@claude review this
Claude finished @aseembits93's task in 1m 14s
Code Review Complete
Summary
This PR adds an AST parser to detect if a function uses numerical computing libraries (numpy, scipy, torch, tensorflow, jax, numba). The implementation is solid with excellent test coverage (819 lines of tests!). However, I've identified several issues that should be addressed:
Critical Issues
1. Async Functions Not Supported (codeflash/code_utils/code_extractor.py:1264-1299)
The … Current code:
def _find_function_node(tree: ast.Module, name_parts: list[str]) -> ast.FunctionDef | None:
    """Find a function node in the AST given its qualified name parts.
    Note: This function only finds regular (sync) functions, not async functions.
    ...
    """
The docstring admits this limitation, but one of the tests (lines 531-537) expects async functions to work! This is a critical bug.
Impact: Any async function using numerical libraries will be incorrectly classified as non-numerical code.
2. Star Imports Not Properly Handled (codeflash/code_utils/code_extractor.py:1252-1259)
Star imports from numerical libraries are collected but not properly tracked. When code uses a bare name from a star import (e.g., …)
Current behavior:
from numpy import *
def func(x):
    return array(x) # Won't be detected as numerical code
Recommendation: Either fix this or document it clearly in the function's docstring as a known limitation.
High Priority Issues
3. Inconsistent Return Type Annotations (codeflash/code_utils/code_extractor.py:1362)
Line 1362 uses … Minor, but affects code quality standards.
4. Missing Edge Case: Empty Code in Nested Classes (codeflash/code_utils/code_extractor.py:1288-1296)
The function returns …
Test case missing:
code = """
import numpy as np
class Outer:
    class Inner:
        def other_method(self):
            return np.sum([1])
"""
is_numerical_code(code, "Outer.Inner.method") # method doesn't exist
Medium Priority Issues
5. Performance: Short-Circuit Optimization Missing (codeflash/code_utils/code_extractor.py:1190-1198)
In … Suggestion: Add …
6. Test Organization: Class Grouping (tests/test_is_numerical_code.py)
The tests are excellently comprehensive, but some related test classes could be merged: …
This is a minor organizational issue and doesn't affect functionality.
Low Priority / Suggestions
7. Type Hint Consistency (codeflash/code_utils/code_extractor.py:1217)
Line 1217 uses …
8. Docstring Enhancement (codeflash/code_utils/code_extractor.py:1301-1336)
The … Suggested addition:
"""
...existing docstring...
Known Limitations:
- Star imports are not fully resolved (e.g., `from numpy import *` followed by `array(x)`)
- [If bug not fixed] Async functions are not currently supported
"""
Positive Observations
…
Verification of PR Description
The PR description states: "In the case of numpy/scipy it will return False when numba is not present in the environment" This is correct and verified by: …
Recommendations
Must fix before merge: …
Should fix before merge: …
Nice to have: …
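To make issues 1 and 4 concrete, here is a minimal sketch of a qualified-name lookup that accepts both `def` and `async def` and returns `None` when the method does not exist. The name `find_function_node` and its traversal are assumptions for illustration; this is not the PR's `_find_function_node`.

```python
from __future__ import annotations

import ast


def find_function_node(
    tree: ast.Module, name_parts: list[str]
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
    """Hypothetical lookup: resolve e.g. ["Outer", "Inner", "fetch"] to a def node."""
    if not name_parts:
        return None
    body = tree.body
    # Descend through the enclosing classes named by every part except the last.
    for class_name in name_parts[:-1]:
        for stmt in body:
            if isinstance(stmt, ast.ClassDef) and stmt.name == class_name:
                body = stmt.body
                break
        else:
            return None  # an enclosing class is missing
    # The last part must name a sync or async function in the innermost body.
    for stmt in body:
        if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)) and stmt.name == name_parts[-1]:
            return stmt
    return None


code = """
import numpy as np

class Outer:
    class Inner:
        async def fetch(self):
            return np.sum([1])
"""
tree = ast.parse(code)
found = find_function_node(tree, ["Outer", "Inner", "fetch"])
missing = find_function_node(tree, ["Outer", "Inner", "method"])
print(type(found).__name__, missing)  # AsyncFunctionDef None
```

Checking against the tuple `(ast.FunctionDef, ast.AsyncFunctionDef)` is the only change needed for async support in a lookup of this shape; whether async functions are worth supporting at all is a separate design question for the PR.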
⚡️ Codeflash found optimizations for this PR
📄 228% (2.28x) speedup for
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return self._get_root_name(node.value)
⚡️Codeflash found 36% (0.36x) speedup for NumericalUsageChecker._get_root_name in codeflash/code_utils/code_extractor.py
⏱️ Runtime : 992 microseconds → 729 microseconds (best of 215 runs)
📝 Explanation and details
The optimization replaces recursion with iteration when traversing AST attribute chains like np.array or np.linalg.inv.
Key Change:
- Original: Recursively calls _get_root_name(node.value) for each ast.Attribute node, creating new stack frames
- Optimized: Uses a while loop to iteratively follow node.value references until reaching the base ast.Name node
Why This Is Faster:
- Eliminates function call overhead: Each recursive call incurs the cost of creating a new stack frame, passing arguments, and managing return values. The line profiler shows this clearly: the recursive call in the original code takes 54.2% of total time (4.9ms out of 9ms). The iterative version reduces this to a simple variable assignment (23% of 4.3ms total).
- Reduces isinstance() checks: The recursive version checks isinstance(node, ast.Name) at every level of the chain before recursing. The iterative version first loops through all ast.Attribute nodes, then performs the ast.Name check only once at the end. For a chain of depth N, this cuts isinstance checks roughly in half.
- Better CPU cache locality: Iteration keeps execution in a tight loop within the same stack frame, improving instruction cache utilization compared to jumping between recursive call sites.
Performance Profile:
The speedup is most dramatic on deep attribute chains:
- Deep chains (500 levels): 295% faster (116μs → 29.5μs)
- Medium chains (200 levels): 253% faster (126μs → 36μs)
- Overall workload: 36% faster (992μs → 729μs)
For shallow inputs (bare names and chains 2-3 levels deep) the difference is small and within measurement noise: in the generated regression tests, individual basic cases range from 18.6% faster to 41.3% slower (the bare-name case, 511ns → 871ns). The optimization matters most when the function is called repeatedly in hot paths, such as during AST traversal of large codebases, where eliminating per-call overhead compounds across thousands of invocations (5,488 hits in the profiler).
The iterative approach maintains identical semantics: non-Name/Attribute nodes still return None, and return values are identical across all test scenarios.
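The equivalence claimed above can be checked with a small standalone sketch. The two functions below are free-function restatements of the recursive and iterative bodies (the real code is a method on `NumericalUsageChecker`); the function names and the explicit trailing `return None` are assumptions made for illustration.

```python
from __future__ import annotations

import ast


def get_root_name_recursive(node: ast.expr | None) -> str | None:
    # One Python call per level of the attribute chain.
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        return get_root_name_recursive(node.value)
    return None


def get_root_name_iterative(node: ast.expr | None) -> str | None:
    # Strip Attribute wrappers in a loop, then check the base once.
    while isinstance(node, ast.Attribute):
        node = node.value
    if isinstance(node, ast.Name):
        return node.id
    return None


samples = ["np", "np.array", "np.linalg.inv", "(x + 1).attr", "np.array().x", "arr[0].x", "42"]
results = {}
for source in samples:
    expr = ast.parse(source, mode="eval").body
    # Both versions must agree on every input.
    assert get_root_name_recursive(expr) == get_root_name_iterative(expr)
    results[source] = get_root_name_iterative(expr)

print(results["np.linalg.inv"], results["np.array().x"])  # np None
```

Chains whose base is a call, subscript, or other expression return `None` in both versions, because the loop stops at the first non-`Attribute` node and the final `ast.Name` check fails.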
✅ Correctness verification report:
| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 24 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Click to see Generated Regression Tests
import ast
import pytest # used for our unit tests
from codeflash.code_utils.code_extractor import NumericalUsageChecker
# unit tests
# Basic Test Cases
def test_basic_name_node_returns_id():
    # Parse a bare name expression and ensure the root name is the identifier itself.
    expr = ast.parse("np", mode="eval").body # ast.Name
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(expr) # 511ns -> 871ns (41.3% slower)
@pytest.mark.parametrize(
    "source,expected_root",
    [
        ("np.array", "np"), # simple attribute access
        ("np.linalg.inv", "np"), # multi-level attribute: expect the top-most name
        ("pd.DataFrame", "pd"), # different root name
        ("self.value", "self"), # common pythonic root 'self'
    ],
)
def test_basic_attribute_chains_return_root_name(source, expected_root):
    # For various attribute chains ensure the root name is returned.
    expr = ast.parse(source, mode="eval").body # ast.Attribute (possibly nested)
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(expr) # 3.51μs -> 3.57μs (1.87% slower)
# Edge Test Cases
def test_non_name_nodes_return_none():
    # Constants and literals are not name/attribute expressions, so they should return None.
    const_expr = ast.parse("42", mode="eval").body # ast.Constant
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(const_expr) # 571ns -> 592ns (3.55% slower)
    # Binary operations are not Name/Attribute -> None
    binop_expr = ast.parse("x + y", mode="eval").body # ast.BinOp
    codeflash_output = checker._get_root_name(binop_expr) # 311ns -> 321ns (3.12% slower)
def test_attribute_on_non_name_base_returns_none():
    # Attribute on top of an expression whose root is not a Name (e.g., (x+1).attr)
    expr = ast.parse("(x + 1).attr", mode="eval").body # ast.Attribute with value ast.BinOp
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(expr) # 892ns -> 752ns (18.6% faster)
    # Attribute of a call: np.array().x has an ast.Call as .value, so root cannot be extracted -> None
    expr2 = ast.parse("np.array().x", mode="eval").body
    codeflash_output = checker._get_root_name(expr2) # 451ns -> 421ns (7.13% faster)
    # Attribute of a subscript: arr[0].x -> .value is ast.Subscript -> None
    expr3 = ast.parse("arr[0].x", mode="eval").body
    codeflash_output = checker._get_root_name(expr3) # 421ns -> 371ns (13.5% faster)
def test_none_input_returns_none_and_is_harmless():
    # While the function expects ast.expr, calling with None must not raise and should return None
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(None) # 611ns -> 611ns (0.000% faster)
def test_constructed_ast_nodes_work_as_well():
    # Build an Attribute chain programmatically:
    # create node representing root.a.b.c where root is Name('root')
    root_name = ast.Name(id="root", ctx=ast.Load())
    attr1 = ast.Attribute(value=root_name, attr="a", ctx=ast.Load())
    attr2 = ast.Attribute(value=attr1, attr="b", ctx=ast.Load())
    attr3 = ast.Attribute(value=attr2, attr="c", ctx=ast.Load())
    # Ensure our constructed node returns the original root id
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(attr3) # 1.06μs -> 1.13μs (6.18% slower)
def test_idempotence_and_no_side_effects():
    # Calling _get_root_name twice on the same node should yield the same result and not mutate it.
    expr = ast.parse("np.array", mode="eval").body
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(expr)
    first = codeflash_output # 781ns -> 821ns (4.87% slower)
    codeflash_output = checker._get_root_name(expr)
    second = codeflash_output # 440ns -> 420ns (4.76% faster)
# Large Scale Test Cases (stress within limits)
def test_deep_attribute_chain_large_scale():
    # Construct a deep attribute chain via source string to test recursion depth handling.
    # Keep depth under Python recursion limit (default ~1000). We choose 500 here (reasonable stress test).
    depth = 500
    # Construct source like "root.a0.a1.a2...." with depth elements; root should be 'root'
    source = "root" + "".join(f".a{i}" for i in range(depth - 1))
    # Parse expression in eval mode to get AST
    expr = ast.parse(source, mode="eval").body
    # Sanity checks: the parsed top-level node is an Attribute if depth>1 else Name
    checker = NumericalUsageChecker(set())
    # Ensure the function returns 'root' even for a very deep chain
    codeflash_output = checker._get_root_name(expr) # 116μs -> 29.5μs (295% faster)
def test_multiple_distinct_large_chains_quickly():
    # Create several moderately-large chains and ensure each is resolved correctly.
    # We keep each chain length modest (e.g., 200) and the count small to respect test performance.
    chain_len = 200
    checker = NumericalUsageChecker(set())
    for root in ("alpha", "beta", "gamma"):
        # Build source like 'alpha.f0.f1.f2...'
        source = root + "".join(f".f{i}" for i in range(chain_len - 1))
        expr = ast.parse(source, mode="eval").body
        # Each should yield its corresponding root name
        codeflash_output = checker._get_root_name(expr) # 126μs -> 36.0μs (253% faster)
# Mutation-sensitive tests
def test_would_detect_returning_attribute_name_instead_of_root():
    # This test ensures that mutations that return the last attribute name (e.g., "array")
    # instead of the root ("np") will fail.
    expr = ast.parse("np.array", mode="eval").body
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(expr)
    result = codeflash_output # 791ns -> 841ns (5.95% slower)
# Ensure that other valid identifier names are preserved (including underscores and numbers)
@pytest.mark.parametrize(
    "identifier",
    [
        "x1", # numeric char in name
        "_private", # leading underscore
        "CamelCase", # capitalized identifier
        "with_underscores_and_numbers123",
    ],
)
def test_various_identifier_forms(identifier):
    # Build an attribute like "<identifier>.attr" and verify the root id is preserved exactly.
    source = f"{identifier}.attr"
    expr = ast.parse(source, mode="eval").body
    checker = NumericalUsageChecker(set())
    codeflash_output = checker._get_root_name(expr) # 3.06μs -> 3.35μs (8.67% slower)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To test or edit this optimization locally: git merge codeflash/optimize-pr1051-2026-01-14T02.39.18
-        if isinstance(node, ast.Name):
-            return node.id
-        if isinstance(node, ast.Attribute):
-            return self._get_root_name(node.value)
+        while isinstance(node, ast.Attribute):
+            node = node.value
+        if isinstance(node, ast.Name):
+            return node.id
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # import numpy or import numpy as np
                module_root = alias.name.split(".")[0]
                if module_root in NUMERICAL_MODULES:
                    # Use the alias if present, otherwise the module name
                    name = alias.asname if alias.asname else alias.name.split(".")[0]
                    numerical_names.add(name)
                    modules_used.add(module_root)
        elif isinstance(node, ast.ImportFrom) and node.module:
            module_root = node.module.split(".")[0]
            if module_root in NUMERICAL_MODULES:
                # from numpy import array, zeros as z
                for alias in node.names:
                    if alias.name == "*":
                        # Can't track star imports, but mark the module as numerical
                        numerical_names.add(module_root)
                    else:
                        name = alias.asname if alias.asname else alias.name
                        numerical_names.add(name)
                    modules_used.add(module_root)
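For readers who want to see what this loop produces, the following is a self-contained sketch that wraps the same logic in a function. The wrapper name `collect_numerical_imports`, its return shape, and the contents of `NUMERICAL_MODULES` are assumptions inferred from the regression tests in this thread, not confirmed from the PR source.

```python
from __future__ import annotations

import ast

# Assumed module set, inferred from the modules the regression tests import.
NUMERICAL_MODULES = frozenset({"numpy", "scipy", "torch", "tensorflow", "jax", "numba", "math"})


def collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]:
    """Return (names bound by numerical imports, root modules they came from)."""
    numerical_names: set[str] = set()
    modules_used: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                module_root = alias.name.split(".")[0]
                if module_root in NUMERICAL_MODULES:
                    # Use the alias if present, otherwise the root module name.
                    numerical_names.add(alias.asname if alias.asname else module_root)
                    modules_used.add(module_root)
        elif isinstance(node, ast.ImportFrom) and node.module:
            module_root = node.module.split(".")[0]
            if module_root in NUMERICAL_MODULES:
                for alias in node.names:
                    if alias.name == "*":
                        # Star import: only the module name itself is recorded.
                        numerical_names.add(module_root)
                    else:
                        numerical_names.add(alias.asname if alias.asname else alias.name)
                    modules_used.add(module_root)
    return numerical_names, modules_used


examples = [
    "import numpy.random as nr",
    "from numpy import *",
    "from torch.nn import Linear as L",
    "import os",
]
outputs = {src: collect_numerical_imports(ast.parse(src)) for src in examples}
for src, (names, modules) in outputs.items():
    print(f"{src!r}: names={names} modules={modules}")
```

For `from numpy import *` the only tracked name is `numpy`, so a later bare call such as `array(x)` matches nothing in `numerical_names`. That is the star-import limitation raised in the review.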
⚡️Codeflash found 343% (3.43x) speedup for _collect_numerical_imports in codeflash/code_utils/code_extractor.py
⏱️ Runtime : 1.48 milliseconds → 333 microseconds (best of 106 runs)
📝 Explanation and details
The optimized code achieves a 343% speedup by replacing ast.walk() with a targeted manual traversal strategy that only visits AST nodes where imports can appear.
Key Optimization
What changed: Replaced the generic ast.walk(tree) traversal with a manual stack-based traversal that only processes nodes with a body attribute (like functions, classes, control flow structures).
Why it's faster:
- ast.walk() visits every single node in the AST (1,325 nodes in the profiler), including all expressions, literals, operators, and statements
- The manual traversal only visits structural containers that can contain imports (101 while-loop iterations in the profiler), skipping ~92% of unnecessary nodes
- Import statements are statements, so they only appear in statement blocks (the module top level, function/class bodies, control-flow blocks), never inside expressions like x + y or literals
Performance impact from line profiler:
- Original: 76.2% of time (7.6ms) spent in the ast.walk() loop
- Optimized: The while loop and node processing combined take ~41% of time (0.9ms total), saving 6.7ms
How It Works
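The optimized version maintains a stack of nodes to check, only descending into:
- Function/class definitions (ast.FunctionDef, ast.ClassDef, etc.)
- Control flow structures that can contain nested scopes (ast.If, ast.For, ast.Try, etc.), including container nodes found in their orelse lists (else clauses)
This selective traversal finds imports that sit directly in a visited body list without wasting time visiting arithmetic operations, variable references, or other leaf nodes. Imports placed directly in an else branch, an except handler, or a finally block are not in any body list the loop inspects, so unlike ast.walk() the suggested code would not see them.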
Impact Context
Based on function_references, this function is called by is_numerical_code(), which analyzes whether functions use numerical libraries. The optimization is particularly valuable when:
- Analyzing large files with complex ASTs (as shown in test_large_ast_with_many_non_import_nodes)
- Processing many functions in a codebase during static analysis
- The AST contains extensive nested logic unrelated to imports
The test cases confirm the optimization works well across all scenarios: simple imports, star imports, submodule imports, and complex real-world code with many non-import nodes.
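The trade-off between the two traversals can be sketched in a few lines. The code below is a simplified illustration, not the suggested change itself: it tracks only plain `import` statements, uses a shortened container tuple, and the helper names are hypothetical. It counts how many nodes each strategy touches and shows one case where the results differ.

```python
import ast

# Shortened, assumed container list; the suggested change uses a longer tuple.
CONTAINERS = (ast.FunctionDef, ast.ClassDef, ast.If, ast.For, ast.While, ast.With, ast.Try)


def imports_via_walk(tree):
    # Visit every node in the tree and keep the names from plain imports.
    return sorted(
        alias.name
        for node in ast.walk(tree)
        if isinstance(node, ast.Import)
        for alias in node.names
    )


def imports_via_body_stack(tree):
    # Only look at statements in `body` lists; push containers from body/orelse.
    found, visited = [], 0
    stack = [tree]
    while stack:
        current = stack.pop()
        visited += 1
        for node in getattr(current, "body", None) or []:
            if isinstance(node, ast.Import):
                found.extend(alias.name for alias in node.names)
            elif isinstance(node, CONTAINERS):
                stack.append(node)
        for node in getattr(current, "orelse", None) or []:
            if isinstance(node, CONTAINERS):
                stack.append(node)
    return sorted(found), visited


code = """
import os

def f(x):
    import numpy
    return [i * 2 for i in range(x)]

try:
    import numba
except ImportError:
    import math
"""
tree = ast.parse(code)
walk_result = imports_via_walk(tree)
stack_result, visited = imports_via_body_stack(tree)
total_nodes = sum(1 for _ in ast.walk(tree))

print(walk_result)   # ['math', 'numba', 'numpy', 'os']
print(stack_result)  # ['numba', 'numpy', 'os']
print(visited, "containers visited vs", total_nodes, "nodes walked")
```

The stack-based version touches far fewer nodes, but `import math` lives in `Try.handlers`, which is neither a `body` nor an `orelse` list on the `Try` node, so it is never reached. A fallback import in an `except ImportError:` block is a common pattern for optional numerical dependencies, so this difference may be worth a regression test before adopting the faster traversal.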
✅ Correctness verification report:
| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 45 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Click to see Generated Regression Tests
import ast
import pytest
from codeflash.code_utils.code_extractor import _collect_numerical_imports
class TestCollectNumericalImportsBasic:
    """Basic test cases for _collect_numerical_imports function."""
    def test_single_numpy_import(self):
        """Test basic numpy import without alias."""
        code = "import numpy"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_single_numpy_import_with_alias(self):
        """Test numpy import with common alias 'np'."""
        code = "import numpy as np"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_single_torch_import(self):
        """Test basic torch import."""
        code = "import torch"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_single_torch_import_with_alias(self):
        """Test torch import with alias."""
        code = "import torch as t"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_from_numpy_import_single_function(self):
        """Test importing a single function from numpy."""
        code = "from numpy import array"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_from_numpy_import_function_with_alias(self):
        """Test importing from numpy with alias."""
        code = "from numpy import array as arr"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_from_numpy_import_multiple_functions(self):
        """Test importing multiple functions from numpy."""
        code = "from numpy import array, zeros, ones"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_from_numpy_import_mixed_aliases(self):
        """Test importing from numpy with mixed aliases."""
        code = "from numpy import array as arr, zeros, ones as o"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_multiple_imports_same_module(self):
        """Test multiple imports from the same numerical module."""
        code = """
import numpy
from numpy import array as arr
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_multiple_different_numerical_modules(self):
        """Test imports from different numerical modules."""
        code = """
import numpy as np
import torch
from scipy import stats
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_non_numerical_import(self):
        """Test that non-numerical imports are ignored."""
        code = "import os"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_mixed_numerical_and_non_numerical(self):
        """Test mixing numerical and non-numerical imports."""
        code = """
import os
import numpy as np
import sys
import torch
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
class TestCollectNumericalImportsEdge:
    """Edge case tests for _collect_numerical_imports function."""
    def test_empty_code(self):
        """Test with empty code."""
        code = ""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_star_import_from_numpy(self):
        """Test star import from numpy."""
        code = "from numpy import *"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_star_import_from_torch(self):
        """Test star import from torch."""
        code = "from torch import *"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_submodule_import_numpy(self):
        """Test importing submodule of numpy."""
        code = "import numpy.random"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_submodule_with_alias(self):
        """Test importing submodule with alias."""
        code = "import numpy.random as nr"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_from_torch_submodule_import(self):
        """Test importing from torch submodule."""
        code = "from torch.nn import Linear"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_from_torch_submodule_with_alias(self):
        """Test importing from torch submodule with alias."""
        code = "from torch.nn import Linear as L"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_math_module_import(self):
        """Test math module import."""
        code = "import math"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_from_math_import(self):
        """Test importing from math module."""
        code = "from math import sqrt, pi"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_jax_import(self):
        """Test jax module import."""
        code = "import jax"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_tensorflow_import(self):
        """Test tensorflow module import."""
        code = "import tensorflow"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_numba_import(self):
        """Test numba module import."""
        code = "import numba"
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_all_numerical_modules(self):
        """Test importing all numerical modules."""
        code = """
import numpy
import torch
import numba
import jax
import tensorflow
import math
import scipy
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_duplicate_imports_same_name(self):
        """Test duplicate imports with same name."""
        code = """
import numpy
import numpy
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_duplicate_imports_different_aliases(self):
        """Test duplicate imports with different aliases."""
        code = """
import numpy as np
import numpy as numpy_lib
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_from_import_multiple_times_same_module(self):
        """Test multiple from imports from the same module."""
        code = """
from numpy import array
from numpy import zeros
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
def my_function():
    pass
from jax import jit
from math import sqrt
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_large_ast_with_many_non_import_nodes(self):
        """Test AST with many non-import nodes and several imports."""
        code = """
import numpy as np
from torch import tensor
class MyClass:
    def __init__(self):
        self.x = 5
    def method1(self):
        return self.x * 2
    def method2(self):
        for i in range(10):
            if i > 5:
                pass
def function1():
    x = 10
    y = 20
    return x + y
def function2(a, b, c):
    if a > b:
        return a
    elif b > c:
        return b
    else:
        return c
from scipy.stats import norm
from jax import jit
x = [i for i in range(50)]
y = {k: v for k, v in enumerate(range(50))}
z = {i for i in range(50)}
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_all_numerical_modules_with_aliases(self):
        """Test all numerical modules with various aliases."""
        imports = [
            "import numpy as np",
            "import torch as t",
            "import numba as nb",
            "import jax as j",
            "import tensorflow as tf",
            "import math as m",
            "import scipy as sp",
        ]
        code = "\n".join(imports)
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_stress_test_many_star_imports(self):
        """Test many star imports from different modules."""
        imports = [
            "from numpy import *",
            "from torch import *",
            "from scipy import *",
            "from jax import *",
            "from tensorflow import *",
            "from math import *",
            "from numba import *",
        ]
        code = "\n".join(imports)
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
    def test_complex_real_world_imports(self):
        """Test realistic complex import structure."""
        code = """
import numpy as np
import numpy.random as npr
from numpy import array, zeros, ones as np_ones
import torch
from torch.nn import Linear, Conv2d as Conv
from torch.optim import Adam
import scipy.stats as stats
from scipy.integrate import odeint
from jax import jit, vmap
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D
import math
from math import sqrt, sin as sine_function
import numba
from numba import njit
"""
        tree = ast.parse(code)
        numerical_names, modules_used = _collect_numerical_imports(tree)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To test or edit this optimization locally: git merge codeflash/optimize-pr1051-2026-01-14T02.52.17
Click to see suggested changes
-    for node in ast.walk(tree):
-        if isinstance(node, ast.Import):
-            for alias in node.names:
-                # import numpy or import numpy as np
-                module_root = alias.name.split(".")[0]
-                if module_root in NUMERICAL_MODULES:
-                    # Use the alias if present, otherwise the module name
-                    name = alias.asname if alias.asname else alias.name.split(".")[0]
-                    numerical_names.add(name)
-                    modules_used.add(module_root)
-        elif isinstance(node, ast.ImportFrom) and node.module:
-            module_root = node.module.split(".")[0]
-            if module_root in NUMERICAL_MODULES:
-                # from numpy import array, zeros as z
-                for alias in node.names:
-                    if alias.name == "*":
-                        # Can't track star imports, but mark the module as numerical
-                        numerical_names.add(module_root)
-                    else:
-                        name = alias.asname if alias.asname else alias.name
-                        numerical_names.add(name)
-                    modules_used.add(module_root)
+    # Iterate directly through the AST body instead of using ast.walk() for better performance
+    nodes_to_check = [tree]
+    while nodes_to_check:
+        current = nodes_to_check.pop()
+        body = getattr(current, "body", None)
+        if body:
+            for node in body:
+                if isinstance(node, ast.Import):
+                    for alias in node.names:
+                        # import numpy or import numpy as np
+                        module_root = alias.name.split(".")[0]
+                        if module_root in NUMERICAL_MODULES:
+                            # Use the alias if present, otherwise the module name
+                            name = alias.asname if alias.asname else alias.name.split(".")[0]
+                            numerical_names.add(name)
+                            modules_used.add(module_root)
+                elif isinstance(node, ast.ImportFrom) and node.module:
+                    module_root = node.module.split(".")[0]
+                    if module_root in NUMERICAL_MODULES:
+                        # from numpy import array, zeros as z
+                        for alias in node.names:
+                            if alias.name == "*":
+                                # Can't track star imports, but mark the module as numerical
+                                numerical_names.add(module_root)
+                            else:
+                                name = alias.asname if alias.asname else alias.name
+                                numerical_names.add(name)
+                            modules_used.add(module_root)
+                elif isinstance(
+                    node,
+                    (
+                        ast.FunctionDef,
+                        ast.AsyncFunctionDef,
+                        ast.ClassDef,
+                        ast.If,
+                        ast.For,
+                        ast.AsyncFor,
+                        ast.While,
+                        ast.With,
+                        ast.AsyncWith,
+                        ast.Try,
+                        ast.ExceptHandler,
+                    ),
+                ):
+                    nodes_to_check.append(node)
+        orelse = getattr(current, "orelse", None)
+        if orelse:
+            for node in orelse:
+                if isinstance(
+                    node,
+                    (
+                        ast.FunctionDef,
+                        ast.AsyncFunctionDef,
+                        ast.ClassDef,
+                        ast.If,
+                        ast.For,
+                        ast.AsyncFor,
+                        ast.While,
+                        ast.With,
+                        ast.AsyncWith,
+                        ast.Try,
+                    ),
+                ):
+                    nodes_to_check.append(node)
As with the other PR, you can skip anything async here. This sort of thing applies mainly to CPU-bound work, which asyncio isn't used for.
Agreed, I have removed the async visitors here. Ready to review.
⚡️ Codeflash found optimizations for this PR
📄 73% (0.73x) speedup for
For numpy/scipy/math code, it returns False when numba is not present in the environment, because this PR is part of a system that will try to use numba's just-in-time (JIT) compilation to compile numpy/scipy/math functions.
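The environment-dependent behavior described above could look roughly like the sketch below. Everything here is an assumption made for illustration: the helper names, the `JIT_DEPENDENT_MODULES` set, and the choice to treat code that also uses another library (such as torch) as numerical regardless of numba are not taken from the PR.

```python
from __future__ import annotations

import importlib.util

# Assumed: the modules whose functions would be compiled through numba.
JIT_DEPENDENT_MODULES = frozenset({"numpy", "scipy", "math"})


def numba_available() -> bool:
    # find_spec returns None when the top-level package cannot be located.
    return importlib.util.find_spec("numba") is not None


def is_numerical_given_env(modules_used: set[str], has_numba: bool) -> bool:
    """Hypothetical gate: decide the result from the modules a function uses."""
    if not modules_used:
        return False  # no numerical library referenced at all
    if modules_used <= JIT_DEPENDENT_MODULES:
        # Only numpy/scipy/math are involved, so the answer depends on numba.
        return has_numba
    return True  # assumed: other libraries do not depend on numba


print(is_numerical_given_env({"numpy"}, has_numba=False))  # False
print(is_numerical_given_env({"numpy"}, has_numba=True))   # True
print(is_numerical_given_env({"torch"}, has_numba=False))  # True
```

Passing `has_numba` as a parameter instead of calling `numba_available()` inside the function keeps the decision testable on machines with or without numba installed.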