Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 48 additions & 27 deletions src/fromager/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,17 @@ def get_project_from_pypi(


class BaseProvider(ExtrasProvider):
"""Base class for Fromager's dependency resolver (resolvelib + extras).

Subclasses implement ``find_candidates``, ``cache_key``, and
``provider_description`` to list versions from PyPI, a version map, etc.

Candidate lists are cached per package in one global dict.

``find_matches`` keeps only versions that fit the requirements and
constraints, then picks newest first.
"""

resolver_cache: typing.ClassVar[ResolverCache] = {}
provider_description: typing.ClassVar[str]
_cooldown_unsupported_warned: typing.ClassVar[set[str]] = set()
Expand Down Expand Up @@ -516,7 +527,7 @@ def identify(self, requirement_or_candidate: Requirement | Candidate) -> str:

@classmethod
def clear_cache(cls, identifier: str | None = None) -> None:
"""Clear global resolver cache
"""Clear global resolver cache.

``None`` clears all caches, an ``identifier`` string clears the
cache for an identifier. Raises :exc:`KeyError` for unknown
Expand Down Expand Up @@ -657,45 +668,55 @@ def get_dependencies(self, candidate: Candidate) -> list[Requirement]:
# return candidate.dependencies
return []

def _get_cached_candidates(self, identifier: str) -> list[Candidate]:
"""Get list of cached candidates for identifier and provider
def _get_cached_candidates(self, identifier: str) -> list[Candidate] | None:
"""Get a copy of cached candidates for identifier and provider.

The method always returns a list. If the cache did not have an entry
before, a new empty list is stored in the cache and returned to the
caller. The caller can mutate the list in place to update the cache.
Returns ``None`` if no entry exists in the cache, or a copy of the
cached list (which may be empty). A copy is returned so callers
cannot accidentally corrupt the cache.
"""
cls = type(self)
provider_cache = cls.resolver_cache.get(identifier, {})
candidate_cache = provider_cache.get((cls, self.cache_key))
if candidate_cache is None:
return None
return list(candidate_cache)

def _set_cached_candidates(
self, identifier: str, candidates: list[Candidate]
) -> None:
"""Store candidates in the cache for identifier and provider."""
cls = type(self)
provider_cache = cls.resolver_cache.setdefault(identifier, {})
candidate_cache = provider_cache.setdefault((cls, self.cache_key), [])
return candidate_cache
provider_cache[(cls, self.cache_key)] = list(candidates)

def _find_cached_candidates(self, identifier: str) -> Candidates:
"""Find candidates with caching"""
cached_candidates: list[Candidate] = []
if self.use_cache_candidates:
cached_candidates = self._get_cached_candidates(identifier)
if cached_candidates:
logger.debug(
"%s: use %i cached candidates",
identifier,
len(cached_candidates),
)
return cached_candidates
candidates = list(self.find_candidates(identifier))
if self.use_cache_candidates:
# mutate list object in-place
cached_candidates[:] = candidates
"""Find candidates with caching."""
if not self.use_cache_candidates:
candidates = list(self.find_candidates(identifier))
logger.debug(
"%s: cache %i unfiltered candidates",
"%s: got %i unfiltered candidates, ignoring cache",
identifier,
len(candidates),
)
Comment thread
LalatenduMohanty marked this conversation as resolved.
else:
return candidates

cached_candidates = self._get_cached_candidates(identifier)
if cached_candidates is not None:
logger.debug(
"%s: got %i unfiltered candidates, ignoring cache",
"%s: use %i cached candidates",
identifier,
len(candidates),
len(cached_candidates),
)
return cached_candidates

candidates = list(self.find_candidates(identifier))
self._set_cached_candidates(identifier, candidates)
logger.debug(
"%s: cache %i unfiltered candidates",
identifier,
len(candidates),
)
return candidates

def _get_no_match_error_message(
Expand Down
102 changes: 97 additions & 5 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from fromager import constraints, resolver
from fromager.__main__ import main as fromager
from fromager.candidate import Candidate

_hydra_core_simple_response = """
<!DOCTYPE html>
Expand Down Expand Up @@ -58,7 +59,9 @@


@pytest.fixture(autouse=True)
def reset_cache() -> None:
def reset_cache() -> typing.Generator[None, None, None]:
resolver.BaseProvider.clear_cache()
yield
resolver.BaseProvider.clear_cache()


Expand Down Expand Up @@ -144,7 +147,7 @@ def test_provider_cache_key_pypi(pypi_hydra_resolver: typing.Any) -> None:
provider = pypi_hydra_resolver.provider
assert provider.cache_key == "https://pypi.org/simple/"
req_cache = provider._get_cached_candidates(req.name)
assert req_cache == []
assert req_cache is None

result = pypi_hydra_resolver.resolve([req])
candidate = result.mapping[req.name]
Expand All @@ -153,10 +156,8 @@ def test_provider_cache_key_pypi(pypi_hydra_resolver: typing.Any) -> None:
resolver_cache = resolver.BaseProvider.resolver_cache
assert req.name in resolver_cache
assert (resolver.PyPIProvider, provider.cache_key) in resolver_cache[req.name]
# mutated in place
assert provider._get_cached_candidates(req.name) is req_cache
# _get_cached_candidates returns a defensive copy, not the same object
assert len(provider._get_cached_candidates(req.name)) == 7
assert len(req_cache) == 7


def test_provider_cache_key_gitlab(gitlab_decile_resolver: typing.Any) -> None:
Expand Down Expand Up @@ -1278,3 +1279,94 @@ def test_cli_package_resolver(
assert "- PyPI versions: 1.2.2, 1.3.1+local, 1.3.2, 2.0.0a1" in result.stdout
assert "- only wheels on PyPI: 1.3.1+local, 2.0.0a1" in result.stdout
assert "- missing from Fromager: 1.3.1+local, 2.0.0a1" in result.stdout


def _make_candidate(name: str, version: str) -> Candidate:
    """Build a bare-bones, non-sdist ``Candidate`` for the cache tests.

    ``version`` is parsed into a ``Version``; the URL is a fixed placeholder.
    """
    parsed_version = Version(version)
    return Candidate(
        name=name,
        version=parsed_version,
        url="https://example.com",
        is_sdist=False,
    )


class _StubProvider(resolver.BaseProvider):
    """Smallest possible ``BaseProvider`` subclass, used by the cache tests.

    It has a constant cache key and never finds any candidates.
    """

    provider_description = "stub"

    @property
    def cache_key(self) -> str:
        """Return the constant key under which this provider caches."""
        return "stub-key"

    def find_candidates(self, identifier: str) -> list[Candidate]:
        """Return an empty candidate list for every identifier."""
        nothing_found: list[Candidate] = []
        return nothing_found


class _CallbackProvider(resolver.BaseProvider):
    """``BaseProvider`` subclass that forwards ``find_candidates`` to a callable.

    Tests pass in a callback so they can observe how often, and with which
    identifier, the provider is asked for candidates.
    """

    provider_description = "callback"

    def __init__(
        self,
        callback: typing.Callable[[str], list[Candidate]],
        **kwargs: typing.Any,
    ) -> None:
        """Initialise the base provider with ``kwargs`` and keep ``callback``."""
        super().__init__(**kwargs)
        self._callback = callback

    @property
    def cache_key(self) -> str:
        """Return the constant key under which this provider caches."""
        return "callback-key"

    def find_candidates(self, identifier: str) -> list[Candidate]:
        """Return whatever the stored callback yields for ``identifier``."""
        lookup = self._callback
        return lookup(identifier)


def test_get_cached_candidates_returns_defensive_copy() -> None:
    """Appending to the returned list must leave the cached list untouched."""
    identifier = "test-pkg"
    provider = _StubProvider()
    cache_slot = (type(provider), provider.cache_key)

    # Plant a single 1.0.0 entry straight into the global cache.
    resolver.BaseProvider.resolver_cache[identifier] = {
        cache_slot: [_make_candidate("test-pkg", "1.0.0")]
    }

    handed_out = provider._get_cached_candidates(identifier)
    assert handed_out is not None
    handed_out.append(_make_candidate("test-pkg", "2.0.0"))

    # A second read must still see only the planted entry.
    reread = provider._get_cached_candidates(identifier)
    assert reread is not None
    assert len(reread) == 1, (
        "_get_cached_candidates should return a defensive copy, "
        "not a direct reference to the internal cache"
    )
    assert reread[0].version == Version("1.0.0")


def test_empty_candidate_list_is_cached() -> None:
    """A lookup that yields no candidates is cached like any other result."""
    seen_identifiers: list[str] = []

    def counting_find(identifier: str) -> list[Candidate]:
        # Record every invocation so the test can count them afterwards.
        seen_identifiers.append(identifier)
        return []

    provider = _CallbackProvider(callback=counting_find)
    for _ in range(2):
        provider._find_cached_candidates("empty-pkg")

    call_count = len(seen_identifiers)
    assert call_count == 1, (
        f"find_candidates() was called {call_count} times; expected 1. "
        "Empty candidate lists must be treated as valid cache entries."
    )


def test_find_cached_candidates_cache_disabled() -> None:
    """A provider built with ``use_resolver_cache=False`` leaves the cache empty."""
    identifier = "uncached-pkg"
    provider = _StubProvider(use_resolver_cache=False)

    found = list(provider._find_cached_candidates(identifier))

    assert found == []
    # Nothing may have been written to the global cache for this identifier.
    assert identifier not in resolver.BaseProvider.resolver_cache
Loading