diff --git a/docs/changelog.rst b/docs/changelog.rst index 5ac72f4..c704870 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,6 +7,7 @@ Changelog Unreleased ========== - Autofix for :ref:`ASYNC910 ` / :ref:`ASYNC911 ` no longer inserts checkpoints inside ``except`` clauses (which would trigger :ref:`ASYNC120 `); instead the checkpoint is added at the top of the function or of the enclosing loop. `(issue #403) `_ +- :ref:`ASYNC910 ` and :ref:`ASYNC911 ` now accept ``__aenter__`` / ``__aexit__`` methods when the partner method provides the checkpoint, or when only one of the two is defined on a class that inherits from another class (charitably assuming the partner is inherited and contains a checkpoint). `(issue #441) `_ 25.7.1 ====== diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index 67f9d51..123fba7 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -465,6 +465,21 @@ def __init__(self, *args: Any, **kwargs: Any): # used to transfer new body between visit_FunctionDef and leave_FunctionDef self.new_body: cst.BaseSuite | None = None + # Tracks whether the current scope is a class body and, if so, which of + # `__aenter__`/`__aexit__` are directly defined on it (values: True if + # that method contains a checkpoint-like construct, False otherwise, + # missing key if not defined). Used to exempt async context manager + # methods from ASYNC910/911 when their partner method provides the + # checkpoint, or when the partner is inherited from a base class. + self.async_cm_class: dict[str, bool] | None = None + # Whether the enclosing class has an explicit base class (other than + # implicit `object`). We only assume a missing partner is inherited if + # the class actually inherits from something. + self.async_cm_class_has_bases = False + # Set on entry to an exempt `__aenter__`/`__aexit__` so that + # `error_91x` skips emitting ASYNC910/911. + self.exempt_async_cm_method = False + def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool: if code is None: code = "ASYNC911" if self.has_yield else "ASYNC910" @@ -532,6 +547,60 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: self.suppress_imported_as.append("suppress") return + # Async context manager methods may legitimately skip checkpointing if the + # partner method provides the checkpoint, or if the partner is inherited + # from a base class (which we charitably assume contains a checkpoint). + # See https://github.com/python-trio/flake8-async/issues/441. + def visit_ClassDef(self, node: cst.ClassDef) -> None: + self.save_state(node, "async_cm_class", "async_cm_class_has_bases") + defined: dict[str, bool] = {} + checkpointy = ( + m.Await() + | m.With(asynchronous=m.Asynchronous()) + | m.For(asynchronous=m.Asynchronous()) + ) + if isinstance(node.body, cst.IndentedBlock): + for stmt in node.body.body: + if ( + isinstance(stmt, cst.FunctionDef) + and stmt.asynchronous is not None + and stmt.name.value in ("__aenter__", "__aexit__") + ): + defined[stmt.name.value] = bool(m.findall(stmt, checkpointy)) + self.async_cm_class = defined + # Keyword args like `metaclass=` are in `node.keywords`, not `bases`. + self.async_cm_class_has_bases = bool(node.bases) + + def leave_ClassDef( + self, original_node: cst.ClassDef, updated_node: cst.ClassDef + ) -> cst.ClassDef: + self.restore_state(original_node) + return updated_node + + def _is_exempt_async_cm_method(self, node: cst.FunctionDef) -> bool: + if self.async_cm_class is None: + return False + name = node.name.value + if name not in ("__aenter__", "__aexit__"): + return False + if name not in self.async_cm_class: + return False + # A method that contains any checkpoint must always checkpoint: we + # still check it normally so conditional checkpoints are flagged. + if self.async_cm_class[name]: + return False + partner = "__aexit__" if name == "__aenter__" else "__aenter__" + if partner in self.async_cm_class: + # Partner is defined on the class; if it checkpoints, we're fine. + if self.async_cm_class[partner]: + return True + # Neither method checkpoints -- to avoid double-flagging (and a + # redundant autofix), we report and fix only `__aenter__`. + return name == "__aexit__" + # Partner is not defined on this class; only assume it is inherited + # (and contains a checkpoint) if the class inherits from something. + return self.async_cm_class_has_bases + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # `await` in default values happen in parent scope # we also know we don't ever modify parameters so we can ignore the return value @@ -543,6 +612,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: if func_has_decorator(node, "overload", "fixture") or func_empty_body(node): return False # subnodes can be ignored + is_exempt_cm = self._is_exempt_async_cm_method(node) + self.save_state( node, "has_yield", @@ -557,6 +628,9 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: "suppress_imported_as", # a copy is saved, but state is not reset "except_depth", "add_checkpoint_at_function_start", + "async_cm_class", + "async_cm_class_has_bases", + "exempt_async_cm_method", copy=True, ) self.uncheckpointed_statements = set() @@ -567,6 +641,10 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: self.taskgroup_has_start_soon = {} self.except_depth = 0 self.add_checkpoint_at_function_start = False + # Class-level context does not apply to nested scopes. + self.async_cm_class = None + self.async_cm_class_has_bases = False + self.exempt_async_cm_method = is_exempt_cm self.async_function = ( node.asynchronous is not None @@ -747,6 +825,12 @@ def error_91x( ) -> bool: assert not isinstance(statement, ArtificialStatement), statement + # Exempt `__aenter__`/`__aexit__` when the partner method contains a + # checkpoint, or when the partner is missing and charitably assumed + # inherited. + if self.exempt_async_cm_method: + return False + if isinstance(node, cst.FunctionDef): msg = "exit" else: diff --git a/tests/autofix_files/async910.py b/tests/autofix_files/async910.py index 0d67a69..2b5baa4 100644 --- a/tests/autofix_files/async910.py +++ b/tests/autofix_files/async910.py @@ -636,3 +636,123 @@ async def foo_nested_empty_async(): async def bar(): ... await foo() + + +# Issue #441: async context manager methods may legitimately skip checkpointing +# if the partner method provides the checkpoint, or if the partner is inherited. +class ACM: # a dummy base to opt into the charitable-inheritance assumption + pass + + +class CtxWithSetup: # safe: __aenter__ checkpoints, __aexit__ can be fast + async def __aenter__(self): + await foo() + + async def __aexit__(self, exc_type, exc, tb): + print("fast exit") + + +class CtxWithTeardown: # safe: __aexit__ checkpoints, __aenter__ can be fast + async def __aenter__(self): + print("fast setup") + + async def __aexit__(self, exc_type, exc, tb): + await foo() + + +class CtxWithBothCheckpoint: # safe: both checkpoint + async def __aenter__(self): + await foo() + + async def __aexit__(self, exc_type, exc, tb): + await foo() + + +# fmt: off +class CtxNeitherCheckpoint: + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + await trio.lowlevel.checkpoint() + + async def __aexit__(self, *a): # only __aenter__ is flagged to avoid redundancy + print("teardown") +# fmt: on + + +# A method that contains any checkpoint is still required to always checkpoint. +class CtxAenterConditionalAexitFast(ACM): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + if _: + await foo() + await trio.lowlevel.checkpoint() + + async def __aexit__(self, *a): + print("fast exit") + + +# Only one method defined: charitably assume the other is inherited with a +# checkpoint -- but only when the class inherits from something. +class CtxOnlyAenterInherited(ACM): # safe: __aexit__ assumed inherited + async def __aenter__(self): + print("setup") + + +class CtxOnlyAexitInherited(ACM): # safe: __aenter__ assumed inherited + async def __aexit__(self, *a): + print("teardown") + + +# fmt: off +class CtxOnlyAenter: # no base class -> don't assume inheritance + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + await trio.lowlevel.checkpoint() + + +class CtxOnlyAexit: # no base class -> don't assume inheritance + async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line) + print("teardown") + await trio.lowlevel.checkpoint() +# fmt: on + + +class CtxOnlyAenterWithCheckpoint: # safe + async def __aenter__(self): + await foo() + + +class CtxOnlyAexitWithCheckpoint: # safe + async def __aexit__(self, *a): + await foo() + + +# keyword-only bases (like `metaclass=`) don't count as inheriting. +class Meta(type): + pass + + +class CtxMetaclassOnly(metaclass=Meta): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + await trio.lowlevel.checkpoint() + + +# a nested function named `__aenter__` inside another function is not a method +def not_a_class(): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + await trio.lowlevel.checkpoint() + + +# class nested inside a function still gets the exemption when it inherits +def factory(): + class NestedCtx(ACM): # safe + async def __aenter__(self): + print("setup") + + +# nested class; outer class has nothing relevant +class Outer: + class Inner(ACM): # safe: charitable inheritance for __aexit__ + async def __aenter__(self): + print("setup") diff --git a/tests/autofix_files/async910.py.diff b/tests/autofix_files/async910.py.diff index c765401..9c55839 100644 --- a/tests/autofix_files/async910.py.diff +++ b/tests/autofix_files/async910.py.diff @@ -223,3 +223,48 @@ async def foo_nested_empty_async(): +@@ x,6 x,7 @@ + class CtxNeitherCheckpoint: + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") ++ await trio.lowlevel.checkpoint() + + async def __aexit__(self, *a): # only __aenter__ is flagged to avoid redundancy + print("teardown") +@@ x,6 x,7 @@ + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + if _: + await foo() ++ await trio.lowlevel.checkpoint() + + async def __aexit__(self, *a): + print("fast exit") +@@ x,11 x,13 @@ + class CtxOnlyAenter: # no base class -> don't assume inheritance + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") ++ await trio.lowlevel.checkpoint() + + + class CtxOnlyAexit: # no base class -> don't assume inheritance + async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line) + print("teardown") ++ await trio.lowlevel.checkpoint() + # fmt: on + + +@@ x,12 x,14 @@ + class CtxMetaclassOnly(metaclass=Meta): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") ++ await trio.lowlevel.checkpoint() + + + # a nested function named `__aenter__` inside another function is not a method + def not_a_class(): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") ++ await trio.lowlevel.checkpoint() + + + # class nested inside a function still gets the exemption when it inherits diff --git a/tests/eval_files/async910.py b/tests/eval_files/async910.py index d370155..2f2850d 100644 --- a/tests/eval_files/async910.py +++ b/tests/eval_files/async910.py @@ -606,3 +606,117 @@ async def foo_nested_empty_async(): async def bar(): ... await foo() + + +# Issue #441: async context manager methods may legitimately skip checkpointing +# if the partner method provides the checkpoint, or if the partner is inherited. +class ACM: # a dummy base to opt into the charitable-inheritance assumption + pass + + +class CtxWithSetup: # safe: __aenter__ checkpoints, __aexit__ can be fast + async def __aenter__(self): + await foo() + + async def __aexit__(self, exc_type, exc, tb): + print("fast exit") + + +class CtxWithTeardown: # safe: __aexit__ checkpoints, __aenter__ can be fast + async def __aenter__(self): + print("fast setup") + + async def __aexit__(self, exc_type, exc, tb): + await foo() + + +class CtxWithBothCheckpoint: # safe: both checkpoint + async def __aenter__(self): + await foo() + + async def __aexit__(self, exc_type, exc, tb): + await foo() + + +# fmt: off +class CtxNeitherCheckpoint: + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + + async def __aexit__(self, *a): # only __aenter__ is flagged to avoid redundancy + print("teardown") +# fmt: on + + +# A method that contains any checkpoint is still required to always checkpoint. +class CtxAenterConditionalAexitFast(ACM): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + if _: + await foo() + + async def __aexit__(self, *a): + print("fast exit") + + +# Only one method defined: charitably assume the other is inherited with a +# checkpoint -- but only when the class inherits from something. +class CtxOnlyAenterInherited(ACM): # safe: __aexit__ assumed inherited + async def __aenter__(self): + print("setup") + + +class CtxOnlyAexitInherited(ACM): # safe: __aenter__ assumed inherited + async def __aexit__(self, *a): + print("teardown") + + +# fmt: off +class CtxOnlyAenter: # no base class -> don't assume inheritance + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + + +class CtxOnlyAexit: # no base class -> don't assume inheritance + async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line) + print("teardown") +# fmt: on + + +class CtxOnlyAenterWithCheckpoint: # safe + async def __aenter__(self): + await foo() + + +class CtxOnlyAexitWithCheckpoint: # safe + async def __aexit__(self, *a): + await foo() + + +# keyword-only bases (like `metaclass=`) don't count as inheriting. +class Meta(type): + pass + + +class CtxMetaclassOnly(metaclass=Meta): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + + +# a nested function named `__aenter__` inside another function is not a method +def not_a_class(): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + + +# class nested inside a function still gets the exemption when it inherits +def factory(): + class NestedCtx(ACM): # safe + async def __aenter__(self): + print("setup") + + +# nested class; outer class has nothing relevant +class Outer: + class Inner(ACM): # safe: charitable inheritance for __aexit__ + async def __aenter__(self): + print("setup")