Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions playwright/_impl/_impl_to_api_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import contextvars
import inspect
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import greenlet

from playwright._impl._errors import Error
from playwright._impl._map import Map

Expand Down Expand Up @@ -118,11 +122,47 @@ def to_impl(
raise Error("Maximum argument depth exceeded")

def wrap_handler(self, handler: Callable[..., Any]) -> Callable[..., None]:
# Capture the caller's context at registration time so contextvars
# set in user code are available when the event handler runs, even
# though events are dispatched from a different greenlet/task.
# See: https://github.com/microsoft/playwright-python/issues/1816
context = contextvars.copy_context()
is_coroutine = inspect.iscoroutinefunction(handler)

def wrapper_func(*args: Any) -> Any:
arg_count = len(inspect.signature(handler).parameters)
return handler(
*list(map(lambda a: self.from_maybe_impl(a), args))[:arg_count]
)
mapped_args = list(map(lambda a: self.from_maybe_impl(a), args))[:arg_count]
if is_coroutine:
# Async-mode coroutine handler: propagate context to the
# handler's awaits by spawning an inner Task inside our
# captured context (Tasks copy the active context at
# construction).
async def _coro_wrapper() -> Any:
loop = asyncio.get_running_loop()
inner = context.run(lambda: loop.create_task(handler(*mapped_args)))
return await inner

return _coro_wrapper()
# Sync handler. Two cases:
# * Async mode: no greenlet is involved in event dispatch
# (asyncio Task), so we use Context.run to run the handler
# in the captured context.
# * Sync mode: events are dispatched inside a fresh
# EventGreenlet whose default gr_context is empty. We can't
# use Context.run here because handlers like route.fulfill
# internally use greenlet.switch, and Context.run does not
# compose with greenlet switches. Instead we set the
# greenlet's gr_context to our captured context for the
# duration of the handler, then restore it.
current = greenlet.getcurrent()
if current.parent is None:
return context.run(handler, *mapped_args)
saved_context = current.gr_context
current.gr_context = context
try:
return handler(*mapped_args)
finally:
current.gr_context = saved_context

if inspect.ismethod(handler):
wrapper = getattr(handler.__self__, IMPL_ATTR + handler.__name__, None)
Expand Down
38 changes: 38 additions & 0 deletions tests/async/test_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,3 +1605,41 @@ async def test_page_should_ignore_deprecated_is_hidden_and_visible_timeout(
await page.set_content("<div>foo</div>")
assert await page.is_hidden("div", timeout=10) is False
assert await page.is_visible("div", timeout=10) is True


async def test_should_propagate_contextvars_to_event_handlers(
page: Page, server: Server
) -> None:
import contextvars

shared_var: "contextvars.ContextVar[str]" = contextvars.ContextVar("shared_var")
shared_var.set("expected value")

sync_seen: List[Optional[str]] = []
async_seen: List[Optional[str]] = []

def on_request_sync(request: Any) -> None:
try:
sync_seen.append(shared_var.get())
except LookupError:
sync_seen.append(None)

async def on_request_async(request: Any) -> None:
try:
async_seen.append(shared_var.get())
except LookupError:
async_seen.append(None)
await asyncio.sleep(0)
try:
async_seen.append(shared_var.get())
except LookupError:
async_seen.append(None)

page.on("request", on_request_sync)
page.on("request", on_request_async)
await page.goto(server.EMPTY_PAGE)
await asyncio.sleep(0.1)
assert sync_seen
assert all(v == "expected value" for v in sync_seen)
assert async_seen
assert all(v == "expected value" for v in async_seen)
22 changes: 22 additions & 0 deletions tests/sync/test_page_event_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,25 @@ def gather_response(request: Request) -> dict:
url = server.PREFIX + f"/fetch?{i}"
expected.append({"url": url, "text": f"url:{url}"})
assert received == expected


def test_should_propagate_contextvars_to_event_handlers(
page: Page, server: Server
) -> None:
import contextvars

shared_var: "contextvars.ContextVar[str]" = contextvars.ContextVar("shared_var")
shared_var.set("expected value")

seen: list = []

def on_request(request: Request) -> None:
try:
seen.append(shared_var.get())
except LookupError:
seen.append(None)

page.on("request", on_request)
page.goto(server.EMPTY_PAGE)
assert seen
assert all(v == "expected value" for v in seen)
Loading