diff --git a/src/fromager/resolver.py b/src/fromager/resolver.py index a192fde3..bdf5424c 100644 --- a/src/fromager/resolver.py +++ b/src/fromager/resolver.py @@ -462,6 +462,17 @@ def get_project_from_pypi( class BaseProvider(ExtrasProvider): + """Base class for Fromager's dependency resolver (resolvelib + extras). + + Subclasses implement ``find_candidates``, ``cache_key``, and + ``provider_description`` to list versions from PyPI, a version map, etc. + + Candidate lists are cached per package in one global dict. + + ``find_matches`` keeps only versions that fit the requirements and + constraints, then picks newest first. + """ + resolver_cache: typing.ClassVar[ResolverCache] = {} provider_description: typing.ClassVar[str] _cooldown_unsupported_warned: typing.ClassVar[set[str]] = set() @@ -516,7 +527,7 @@ def identify(self, requirement_or_candidate: Requirement | Candidate) -> str: @classmethod def clear_cache(cls, identifier: str | None = None) -> None: - """Clear global resolver cache + """Clear global resolver cache. ``None`` clears all caches, an ``identifier`` string clears the cache for an identifier. Raises :exc:`KeyError` for unknown @@ -657,45 +668,55 @@ def get_dependencies(self, candidate: Candidate) -> list[Requirement]: # return candidate.dependencies return [] - def _get_cached_candidates(self, identifier: str) -> list[Candidate]: - """Get list of cached candidates for identifier and provider + def _get_cached_candidates(self, identifier: str) -> list[Candidate] | None: + """Get a copy of cached candidates for identifier and provider. - The method always returns a list. If the cache did not have an entry - before, a new empty list is stored in the cache and returned to the - caller. The caller can mutate the list in place to update the cache. + Returns ``None`` if no entry exists in the cache, or a copy of the + cached list (which may be empty). A copy is returned so callers + cannot accidentally corrupt the cache. """ cls = type(self) + provider_cache = cls.resolver_cache.get(identifier, {}) + candidate_cache = provider_cache.get((cls, self.cache_key)) + if candidate_cache is None: + return None + return list(candidate_cache) + + def _set_cached_candidates( + self, identifier: str, candidates: list[Candidate] + ) -> None: + """Store candidates in the cache for identifier and provider.""" + cls = type(self) provider_cache = cls.resolver_cache.setdefault(identifier, {}) - candidate_cache = provider_cache.setdefault((cls, self.cache_key), []) - return candidate_cache + provider_cache[(cls, self.cache_key)] = list(candidates) def _find_cached_candidates(self, identifier: str) -> Candidates: - """Find candidates with caching""" - cached_candidates: list[Candidate] = [] - if self.use_cache_candidates: - cached_candidates = self._get_cached_candidates(identifier) - if cached_candidates: - logger.debug( - "%s: use %i cached candidates", - identifier, - len(cached_candidates), - ) - return cached_candidates - candidates = list(self.find_candidates(identifier)) - if self.use_cache_candidates: - # mutate list object in-place - cached_candidates[:] = candidates + """Find candidates with caching.""" + if not self.use_cache_candidates: + candidates = list(self.find_candidates(identifier)) logger.debug( - "%s: cache %i unfiltered candidates", + "%s: got %i unfiltered candidates, ignoring cache", identifier, len(candidates), ) - else: + return candidates + + cached_candidates = self._get_cached_candidates(identifier) + if cached_candidates is not None: logger.debug( - "%s: got %i unfiltered candidates, ignoring cache", + "%s: use %i cached candidates", identifier, - len(candidates), + len(cached_candidates), ) + return cached_candidates + + candidates = list(self.find_candidates(identifier)) + self._set_cached_candidates(identifier, candidates) + logger.debug( + "%s: cache %i unfiltered candidates", + identifier, + len(candidates), + ) return candidates def _get_no_match_error_message( diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 690a1288..4f0c4db7 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -11,6 +11,7 @@ from fromager import constraints, resolver from fromager.__main__ import main as fromager +from fromager.candidate import Candidate _hydra_core_simple_response = """ @@ -58,7 +59,9 @@ @pytest.fixture(autouse=True) -def reset_cache() -> None: +def reset_cache() -> typing.Generator[None, None, None]: + resolver.BaseProvider.clear_cache() + yield resolver.BaseProvider.clear_cache() @@ -144,7 +147,7 @@ def test_provider_cache_key_pypi(pypi_hydra_resolver: typing.Any) -> None: provider = pypi_hydra_resolver.provider assert provider.cache_key == "https://pypi.org/simple/" req_cache = provider._get_cached_candidates(req.name) - assert req_cache == [] + assert req_cache is None result = pypi_hydra_resolver.resolve([req]) candidate = result.mapping[req.name] @@ -153,10 +156,8 @@ def test_provider_cache_key_pypi(pypi_hydra_resolver: typing.Any) -> None: resolver_cache = resolver.BaseProvider.resolver_cache assert req.name in resolver_cache assert (resolver.PyPIProvider, provider.cache_key) in resolver_cache[req.name] - # mutated in place - assert provider._get_cached_candidates(req.name) is req_cache + # _get_cached_candidates returns a defensive copy, not the same object assert len(provider._get_cached_candidates(req.name)) == 7 - assert len(req_cache) == 7 def test_provider_cache_key_gitlab(gitlab_decile_resolver: typing.Any) -> None: @@ -1278,3 +1279,94 @@ def test_cli_package_resolver( assert "- PyPI versions: 1.2.2, 1.3.1+local, 1.3.2, 2.0.0a1" in result.stdout assert "- only wheels on PyPI: 1.3.1+local, 2.0.0a1" in result.stdout assert "- missing from Fromager: 1.3.1+local, 2.0.0a1" in result.stdout + + +def _make_candidate(name: str, version: str) -> Candidate: + """Create a minimal Candidate for testing.""" + return Candidate( + name=name, version=Version(version), url="https://example.com", is_sdist=False + ) + + +class _StubProvider(resolver.BaseProvider): + """Minimal BaseProvider subclass for cache tests.""" + + provider_description = "stub" + + @property + def cache_key(self) -> str: + return "stub-key" + + def find_candidates(self, identifier: str) -> list[Candidate]: + return [] + + +class _CallbackProvider(resolver.BaseProvider): + """BaseProvider subclass whose find_candidates delegates to a callback.""" + + provider_description = "callback" + + def __init__( + self, + callback: typing.Callable[[str], list[Candidate]], + **kwargs: typing.Any, + ) -> None: + super().__init__(**kwargs) + self._callback = callback + + @property + def cache_key(self) -> str: + return "callback-key" + + def find_candidates(self, identifier: str) -> list[Candidate]: + return self._callback(identifier) + + +def test_get_cached_candidates_returns_defensive_copy() -> None: + """Mutating the list returned by _get_cached_candidates must not corrupt the cache.""" + provider = _StubProvider() + identifier = "test-pkg" + + # Seed the cache directly + resolver.BaseProvider.resolver_cache[identifier] = { + (type(provider), provider.cache_key): [_make_candidate("test-pkg", "1.0.0")] + } + + first = provider._get_cached_candidates(identifier) + assert first is not None + first.append(_make_candidate("test-pkg", "2.0.0")) + + # The cache should not reflect the caller's mutation + second = provider._get_cached_candidates(identifier) + assert second is not None + assert len(second) == 1, ( + "_get_cached_candidates should return a defensive copy, " + "not a direct reference to the internal cache" + ) + assert second[0].version == Version("1.0.0") + + +def test_empty_candidate_list_is_cached() -> None: + """An empty find_candidates result must be cached, not re-fetched.""" + call_count = 0 + + def counting_find(identifier: str) -> list[Candidate]: + nonlocal call_count + call_count += 1 + return [] + + provider = _CallbackProvider(callback=counting_find) + provider._find_cached_candidates("empty-pkg") + provider._find_cached_candidates("empty-pkg") + assert call_count == 1, ( + f"find_candidates() was called {call_count} times; expected 1. " + "Empty candidate lists must be treated as valid cache entries." + ) + + +def test_find_cached_candidates_cache_disabled() -> None: + """With use_resolver_cache=False, results must bypass the cache entirely.""" + provider = _StubProvider(use_resolver_cache=False) + result = list(provider._find_cached_candidates("uncached-pkg")) + assert result == [] + assert "uncached-pkg" not in resolver.BaseProvider.resolver_cache