Skip to content

Commit

Permalink
Merge pull request #1613 from henryiii/henryiii/fix/mainif
Browse files Browse the repository at this point in the history
fix(setup.py): look inside if name == main block
  • Loading branch information
joerick authored Sep 18, 2023
2 parents 80a54b0 + f34ae77 commit f0feaff
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 2 deletions.
50 changes: 48 additions & 2 deletions cibuildwheel/projectfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,43 @@
from ._compat import tomllib


def get_parent(node: ast.AST | None, depth: int = 1) -> ast.AST | None:
for _ in range(depth):
node = getattr(node, "parent", None)
return node


def is_main(parent: ast.AST | None) -> bool:
if parent is None:
return False

# This would be much nicer with 3.10's pattern matching!
if not isinstance(parent, ast.If):
return False
if not isinstance(parent.test, ast.Compare):
return False

try:
(op,) = parent.test.ops
(comp,) = parent.test.comparators
except ValueError:
return False

if not isinstance(op, ast.Eq):
return False

values = {comp, parent.test.left}

mains = {x for x in values if isinstance(x, ast.Constant) and x.value == "__main__"}
if len(mains) != 1:
return False
consts = {x for x in values if isinstance(x, ast.Name) and x.id == "__name__"}
if len(consts) != 1:
return False

return True


class Analyzer(ast.NodeVisitor):
def __init__(self) -> None:
self.requires_python: str | None = None
Expand All @@ -19,13 +56,22 @@ def visit(self, node: ast.AST) -> None:
super().visit(node)

def visit_keyword(self, node: ast.keyword) -> None:
# Must not be nested except for if __name__ == "__main__"

self.generic_visit(node)
# Must not be nested in an if or other structure
# This will be Module -> Expr -> Call -> keyword
parent = get_parent(node, 4)
unnested = parent is None

# This will be Module -> If -> Expr -> Call -> keyword
name_main_unnested = (
parent is not None and get_parent(parent) is None and is_main(get_parent(node, 3))
)

if (
node.arg == "python_requires"
and not hasattr(node.parent.parent.parent, "parent") # type: ignore[attr-defined]
and isinstance(node.value, ast.Constant)
and (unnested or name_main_unnested)
):
self.requires_python = node.value.value

Expand Down
66 changes: 66 additions & 0 deletions unit_test/projectfiles_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,72 @@ def test_read_setup_py_simple(tmp_path):
assert get_requires_python_str(tmp_path) == "1.23"


def test_read_setup_py_if_main(tmp_path):
with open(tmp_path / "setup.py", "w") as f:
f.write(
dedent(
"""
from setuptools import setup
if __name__ == "__main__":
setup(
name = "hello",
other = 23,
example = ["item", "other"],
python_requires = "1.23",
)
"""
)
)

assert setup_py_python_requires(tmp_path.joinpath("setup.py").read_text()) == "1.23"
assert get_requires_python_str(tmp_path) == "1.23"


def test_read_setup_py_if_main_reversed(tmp_path):
with open(tmp_path / "setup.py", "w") as f:
f.write(
dedent(
"""
from setuptools import setup
if "__main__" == __name__:
setup(
name = "hello",
other = 23,
example = ["item", "other"],
python_requires = "1.23",
)
"""
)
)

assert setup_py_python_requires(tmp_path.joinpath("setup.py").read_text()) == "1.23"
assert get_requires_python_str(tmp_path) == "1.23"


def test_read_setup_py_if_invalid(tmp_path):
with open(tmp_path / "setup.py", "w") as f:
f.write(
dedent(
"""
from setuptools import setup
if True:
setup(
name = "hello",
other = 23,
example = ["item", "other"],
python_requires = "1.23",
)
"""
)
)

assert not setup_py_python_requires(tmp_path.joinpath("setup.py").read_text())
assert not get_requires_python_str(tmp_path)


def test_read_setup_py_full(tmp_path):
with open(tmp_path / "setup.py", "w", encoding="utf8") as f:
f.write(
Expand Down

0 comments on commit f0feaff

Please sign in to comment.