Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions sagemaker-core/src/sagemaker/core/local/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,11 @@ def start(self, input_data, output_data, transform_resources, **kwargs):
self.container.serve(self.primary_container["ModelDataUrl"], environment)

serving_port = get_config_value("local.serving_port", self.local_session.config) or 8080
_wait_for_serving_container(serving_port)
health_check_timeout = (
get_config_value("local.health_check_timeout", self.local_session.config)
or HEALTH_CHECK_TIMEOUT_LIMIT
)
_wait_for_serving_container(serving_port, timeout=health_check_timeout)

# Get capabilities from Container if needed
endpoint_url = "http://%s:%d/execution-parameters" % (get_docker_host(), serving_port)
Expand Down Expand Up @@ -623,7 +627,11 @@ def serve(self):
)

serving_port = get_config_value("local.serving_port", self.local_session.config) or 8080
_wait_for_serving_container(serving_port)
health_check_timeout = (
get_config_value("local.health_check_timeout", self.local_session.config)
or HEALTH_CHECK_TIMEOUT_LIMIT
)
_wait_for_serving_container(serving_port, timeout=health_check_timeout)
# the container is running and it passed the healthcheck status is now InService
self.state = _LocalEndpoint._IN_SERVICE

Expand All @@ -646,15 +654,21 @@ def describe(self):
return response


def _wait_for_serving_container(serving_port):
"""Placeholder docstring."""
def _wait_for_serving_container(serving_port, timeout=HEALTH_CHECK_TIMEOUT_LIMIT):
"""Wait for the serving container to become healthy.

Args:
serving_port (int): The port the serving container is listening on.
timeout (int): Maximum number of seconds to wait for the container to become healthy.
Defaults to ``HEALTH_CHECK_TIMEOUT_LIMIT``.
"""
i = 0
http = urllib3.PoolManager()

endpoint_url = "http://%s:%d/ping" % (get_docker_host(), serving_port)
while True:
i += 5
if i >= HEALTH_CHECK_TIMEOUT_LIMIT:
if i >= timeout:
raise RuntimeError("Giving up, endpoint didn't launch correctly")

logger.info("Checking if serving container is up, attempt: %s", i)
Expand Down
68 changes: 68 additions & 0 deletions sagemaker-core/tests/unit/local/test_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,47 @@ def test_endpoint_serve(self, mock_container_class, mock_wait, mock_session_clas
assert endpoint.state == endpoint._IN_SERVICE
mock_container.serve.assert_called_once()

@patch("sagemaker.core.local.local_session.LocalSession")
@patch("sagemaker.core.local.entities._wait_for_serving_container")
@patch("sagemaker.core.local.entities._SageMakerContainer")
def test_endpoint_serve_custom_health_check_timeout(
    self, mock_container_class, mock_wait, mock_session_class
):
    """Check that serve() forwards ``local.health_check_timeout`` to the health-check wait.

    The session config sets only ``health_check_timeout`` (600); the test
    expects ``_wait_for_serving_container`` to be called exactly once with
    port 8080 and ``timeout=600``.
    """
    # Canned responses returned by the fake SageMaker client.
    endpoint_config = {
        "EndpointConfigName": "test-config",
        "ProductionVariants": [
            {
                "VariantName": "AllTraffic",
                "ModelName": "test-model",
                "InitialInstanceCount": 1,
                "InstanceType": "local",
            }
        ],
    }
    model_description = {
        "PrimaryContainer": {
            "Image": "test-image:latest",
            "ModelDataUrl": "s3://bucket/model.tar.gz",
            "Environment": {},
        }
    }

    sagemaker_client = Mock()
    sagemaker_client.describe_endpoint_config.return_value = endpoint_config
    sagemaker_client.describe_model.return_value = model_description

    # Local session whose config carries the custom timeout under test.
    session = Mock()
    session.sagemaker_client = sagemaker_client
    session.config = {"local": {"health_check_timeout": 600}}

    # The container class is patched, so serve() never touches a real container.
    mock_container_class.return_value = Mock()

    _LocalEndpoint("test-endpoint", "test-config", None, session).serve()

    mock_wait.assert_called_once_with(8080, timeout=600)

@patch("sagemaker.core.local.local_session.LocalSession")
def test_endpoint_stop(self, mock_session_class):
"""Test stopping an endpoint"""
Expand Down Expand Up @@ -604,6 +645,33 @@ def test_wait_timeout(self, mock_sleep, mock_get_host, mock_perform_request):
with pytest.raises(RuntimeError, match="Giving up"):
_wait_for_serving_container(8080)

@patch("sagemaker.core.local.entities._perform_request")
@patch("sagemaker.core.local.entities.get_docker_host")
@patch("time.sleep")
def test_wait_custom_timeout(self, mock_sleep, mock_get_host, mock_perform_request):
    """Check that a caller-supplied ``timeout`` bounds the health-check loop.

    Every ping is mocked to return status 500, so the wait must give up with
    ``RuntimeError`` once the custom timeout is reached.
    """
    mock_get_host.return_value = "localhost"
    mock_perform_request.return_value = (None, 500)

    with pytest.raises(RuntimeError, match="Giving up"):
        _wait_for_serving_container(8080, timeout=10)

    # The loop counter is bumped by 5 *before* it is compared with the timeout:
    #   pass 1: i = 5  -> 5 < 10, so one ping request is issued
    #   pass 2: i = 10 -> 10 >= 10, so RuntimeError is raised before a second ping
    # Two loop passes therefore produce exactly one request.
    assert mock_perform_request.call_count == 1

@patch("sagemaker.core.local.entities._perform_request")
@patch("sagemaker.core.local.entities.get_docker_host")
@patch("time.sleep")
def test_wait_uses_default_timeout_constant(self, mock_sleep, mock_get_host, mock_perform_request):
    """Check that ``timeout`` defaults to ``HEALTH_CHECK_TIMEOUT_LIMIT``.

    Two things are verified:
      1. The declared default of the ``timeout`` parameter is the module
         constant, so the test fails if the default is changed.
      2. Calling the function without ``timeout`` still works and issues a
         single ping when the container answers 200 on the first attempt.
    """
    # Local imports keep this test self-contained; the constant is a
    # module-level name in entities.py (it is used there as the default value).
    import inspect

    from sagemaker.core.local.entities import HEALTH_CHECK_TIMEOUT_LIMIT

    timeout_param = inspect.signature(_wait_for_serving_container).parameters["timeout"]
    assert timeout_param.default == HEALTH_CHECK_TIMEOUT_LIMIT

    mock_get_host.return_value = "localhost"
    # Succeed on the first attempt so the loop exits after a single request.
    mock_perform_request.return_value = (Mock(), 200)

    _wait_for_serving_container(8080)

    mock_perform_request.assert_called_once()


class TestPerformRequest:
"""Test cases for _perform_request"""
Expand Down
Loading