From c27d414d0accd3d693bd110080d6cc9ea35f02af Mon Sep 17 00:00:00 2001 From: Egor Kosaretskii Date: Mon, 20 Apr 2026 16:00:50 +0500 Subject: [PATCH] fix: make local mode health-check timeout configurable via session config (#3362) Allow users to set `local.health_check_timeout` in their session config to override the hard-coded 120s limit, fixing failures when large model archives or slow networks cause container startup to exceed the default. Co-Authored-By: Claude Sonnet 4.6 --- .../src/sagemaker/core/local/entities.py | 24 +++++-- .../tests/unit/local/test_entities.py | 68 +++++++++++++++++++ 2 files changed, 87 insertions(+), 5 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/local/entities.py b/sagemaker-core/src/sagemaker/core/local/entities.py index 88dfb7ed26..fa5e6c88fe 100644 --- a/sagemaker-core/src/sagemaker/core/local/entities.py +++ b/sagemaker-core/src/sagemaker/core/local/entities.py @@ -332,7 +332,11 @@ def start(self, input_data, output_data, transform_resources, **kwargs): self.container.serve(self.primary_container["ModelDataUrl"], environment) serving_port = get_config_value("local.serving_port", self.local_session.config) or 8080 - _wait_for_serving_container(serving_port) + health_check_timeout = ( + get_config_value("local.health_check_timeout", self.local_session.config) + or HEALTH_CHECK_TIMEOUT_LIMIT + ) + _wait_for_serving_container(serving_port, timeout=health_check_timeout) # Get capabilities from Container if needed endpoint_url = "http://%s:%d/execution-parameters" % (get_docker_host(), serving_port) @@ -623,7 +627,11 @@ def serve(self): ) serving_port = get_config_value("local.serving_port", self.local_session.config) or 8080 - _wait_for_serving_container(serving_port) + health_check_timeout = ( + get_config_value("local.health_check_timeout", self.local_session.config) + or HEALTH_CHECK_TIMEOUT_LIMIT + ) + _wait_for_serving_container(serving_port, timeout=health_check_timeout) # the container is running and it passed the 
healthcheck status is now InService self.state = _LocalEndpoint._IN_SERVICE @@ -646,15 +654,21 @@ def describe(self): return response -def _wait_for_serving_container(serving_port): - """Placeholder docstring.""" +def _wait_for_serving_container(serving_port, timeout=HEALTH_CHECK_TIMEOUT_LIMIT): + """Wait for the serving container to become healthy. + + Args: + serving_port (int): The port the serving container is listening on. + timeout (int): Maximum number of seconds to wait for the container to become healthy. + Defaults to ``HEALTH_CHECK_TIMEOUT_LIMIT``. + """ i = 0 http = urllib3.PoolManager() endpoint_url = "http://%s:%d/ping" % (get_docker_host(), serving_port) while True: i += 5 - if i >= HEALTH_CHECK_TIMEOUT_LIMIT: + if i >= timeout: raise RuntimeError("Giving up, endpoint didn't launch correctly") logger.info("Checking if serving container is up, attempt: %s", i) diff --git a/sagemaker-core/tests/unit/local/test_entities.py b/sagemaker-core/tests/unit/local/test_entities.py index 4be2969937..05cd330bd5 100644 --- a/sagemaker-core/tests/unit/local/test_entities.py +++ b/sagemaker-core/tests/unit/local/test_entities.py @@ -510,6 +510,47 @@ def test_endpoint_serve(self, mock_container_class, mock_wait, mock_session_clas assert endpoint.state == endpoint._IN_SERVICE mock_container.serve.assert_called_once() + @patch("sagemaker.core.local.local_session.LocalSession") + @patch("sagemaker.core.local.entities._wait_for_serving_container") + @patch("sagemaker.core.local.entities._SageMakerContainer") + def test_endpoint_serve_custom_health_check_timeout( + self, mock_container_class, mock_wait, mock_session_class + ): + """Test that health_check_timeout from session config is passed to _wait_for_serving_container""" + mock_session = Mock() + mock_client = Mock() + + mock_client.describe_endpoint_config.return_value = { + "EndpointConfigName": "test-config", + "ProductionVariants": [ + { + "VariantName": "AllTraffic", + "ModelName": "test-model", +
"InitialInstanceCount": 1, + "InstanceType": "local", + } + ], + } + + mock_client.describe_model.return_value = { + "PrimaryContainer": { + "Image": "test-image:latest", + "ModelDataUrl": "s3://bucket/model.tar.gz", + "Environment": {}, + } + } + + mock_session.sagemaker_client = mock_client + mock_session.config = {"local": {"health_check_timeout": 600}} + + mock_container = Mock() + mock_container_class.return_value = mock_container + + endpoint = _LocalEndpoint("test-endpoint", "test-config", None, mock_session) + endpoint.serve() + + mock_wait.assert_called_once_with(8080, timeout=600) + @patch("sagemaker.core.local.local_session.LocalSession") def test_endpoint_stop(self, mock_session_class): """Test stopping an endpoint""" @@ -604,6 +645,33 @@ def test_wait_timeout(self, mock_sleep, mock_get_host, mock_perform_request): with pytest.raises(RuntimeError, match="Giving up"): _wait_for_serving_container(8080) + @patch("sagemaker.core.local.entities._perform_request") + @patch("sagemaker.core.local.entities.get_docker_host") + @patch("time.sleep") + def test_wait_custom_timeout(self, mock_sleep, mock_get_host, mock_perform_request): + """Test that custom timeout is respected""" + mock_get_host.return_value = "localhost" + mock_perform_request.return_value = (None, 500) + + with pytest.raises(RuntimeError, match="Giving up"): + _wait_for_serving_container(8080, timeout=10) + + # With timeout=10 and step=5: pass 1 (i=5) performs the only request, pass 2 (i=10) raises because i >= timeout + assert mock_perform_request.call_count == 1 + + @patch("sagemaker.core.local.entities._perform_request") + @patch("sagemaker.core.local.entities.get_docker_host") + @patch("time.sleep") + def test_wait_uses_default_timeout_constant(self, mock_sleep, mock_get_host, mock_perform_request): + """Test that _wait_for_serving_container defaults to HEALTH_CHECK_TIMEOUT_LIMIT""" + mock_get_host.return_value = "localhost" + # Succeed on first attempt so we can verify it was called + mock_perform_request.return_value =
(Mock(), 200) + + _wait_for_serving_container(8080) + + mock_perform_request.assert_called_once() + class TestPerformRequest: """Test cases for _perform_request"""