From da2c76d960c660414be7c05bd1e57028c28fb1f7 Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Thu, 16 Apr 2026 19:59:31 -0700 Subject: [PATCH 1/3] fix: [v3] FrameworkProcessor and ModelTrainer: 4 regressions (including dropping Code (5765) --- .../src/sagemaker/core/processing.py | 176 ++++++++++- .../tests/unit/test_processing_regressions.py | 275 ++++++++++++++++++ .../train/common_utils/trainer_wait.py | 39 ++- .../src/sagemaker/train/model_trainer.py | 8 +- .../src/sagemaker/train/templates.py | 21 ++ 5 files changed, 501 insertions(+), 18 deletions(-) create mode 100644 sagemaker-core/tests/unit/test_processing_regressions.py diff --git a/sagemaker-core/src/sagemaker/core/processing.py b/sagemaker-core/src/sagemaker/core/processing.py index b507ae1a93..1c15d769bb 100644 --- a/sagemaker-core/src/sagemaker/core/processing.py +++ b/sagemaker-core/src/sagemaker/core/processing.py @@ -632,7 +632,10 @@ def submit(request): transformed = transform(serialized_request, "CreateProcessingJobRequest") # Remove tags from transformed dict as ProcessingJob resource doesn't accept it transformed.pop("tags", None) - return ProcessingJob(**transformed) + processing_job = ProcessingJob(**transformed) + # Store the sagemaker_session on the job so wait/refresh can use it + processing_job._sagemaker_session = self.sagemaker_session + return processing_job def _get_process_args(self, inputs, outputs, experiment_config): """Gets a dict of arguments for a new Amazon SageMaker processing job.""" @@ -936,6 +939,26 @@ def _handle_user_code_url(self, code, kms_key=None): ) return user_code_s3_uri + def _get_code_upload_bucket_and_prefix(self): + """Get the S3 bucket and prefix for code uploads. + + If code_location is set (on FrameworkProcessor), parse it to extract + bucket and prefix. Otherwise, use the session's default bucket. + + Returns: + tuple: (bucket, prefix) for S3 uploads. + """ + code_location = getattr(self, 'code_location', None) + if code_location: + parsed = urlparse(code_location) + bucket = parsed.netloc + prefix = parsed.path.lstrip('/') + return bucket, prefix + return ( + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix or "", + ) + def _upload_code(self, code, kms_key=None): """Uploads a code file or directory specified as a string and returns the S3 URI. @@ -950,11 +973,13 @@ def _upload_code(self, code, kms_key=None): """ from sagemaker.core.workflow.utilities import _pipeline_config + bucket, prefix = self._get_code_upload_bucket_and_prefix() + if _pipeline_config and _pipeline_config.code_hash: desired_s3_uri = s3.s3_path_join( "s3://", - self.sagemaker_session.default_bucket(), - self.sagemaker_session.default_bucket_prefix, + bucket, + prefix, _pipeline_config.pipeline_name, self._CODE_CONTAINER_INPUT_NAME, _pipeline_config.code_hash, @@ -962,8 +987,8 @@ def _upload_code(self, code, kms_key=None): else: desired_s3_uri = s3.s3_path_join( "s3://", - self.sagemaker_session.default_bucket(), - self.sagemaker_session.default_bucket_prefix, + bucket, + prefix, self._current_job_name, "input", self._CODE_CONTAINER_INPUT_NAME, @@ -1153,11 +1178,12 @@ def _package_code( item_path = os.path.join(source_dir, item) tar.add(item_path, arcname=item) - # Upload to S3 + # Upload to S3 - honor code_location if set + bucket, prefix = self._get_code_upload_bucket_and_prefix() s3_uri = s3.s3_path_join( "s3://", - self.sagemaker_session.default_bucket(), - self.sagemaker_session.default_bucket_prefix or "", + bucket, + prefix, job_name, "source", "sourcedir.tar.gz", @@ -1174,6 +1200,36 @@ def _package_code( os.unlink(tmp.name) return s3_uri + @staticmethod + def _get_codeartifact_command(codeartifact_repo_arn): + """Parse a CodeArtifact repository ARN and return the login command. + + Args: + codeartifact_repo_arn (str): The ARN of the CodeArtifact repository. + Format: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository} + + Returns: + str: The bash command to login to CodeArtifact via pip. + """ + # Parse ARN: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository} + parts = codeartifact_repo_arn.split(':') + region = parts[3] + domain_owner = parts[4] + resource = parts[5] # repository/{domain}/{repository} + resource_parts = resource.split('/') + domain = resource_parts[1] + repository = resource_parts[2] + + return ( + f'if ! hash aws 2>/dev/null; then\n' + f' echo "AWS CLI is not installed. Skipping CodeArtifact login."\n' + f'else\n' + f' aws codeartifact login --tool pip ' + f'--domain {domain} --domain-owner {domain_owner} ' + f'--repository {repository} --region {region}\n' + f'fi' + ) + @_telemetry_emitter(feature=Feature.PROCESSING, func_name="FrameworkProcessor.run") @runnable_by_pipeline def run( @@ -1189,6 +1245,7 @@ def run( job_name: Optional[str] = None, experiment_config: Optional[Dict[str, str]] = None, kms_key: Optional[str] = None, + codeartifact_repo_arn: Optional[str] = None, ): """Runs a processing job. @@ -1216,10 +1273,16 @@ def run( experiment_config (dict[str, str]): Experiment management configuration. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). + codeartifact_repo_arn (str): The ARN of the CodeArtifact repository to use + for pip authentication when installing requirements.txt dependencies + (default: None). Format: + arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository} Returns: None or pipeline step arguments in case the Processor instance is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession` """ + self._codeartifact_repo_arn = codeartifact_repo_arn + s3_runproc_sh, inputs, job_name = self._pack_and_upload_code( code, source_dir, @@ -1346,6 +1409,33 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): def _generate_framework_script(self, user_script: str) -> str: """Generate the framework entrypoint file (as text) for a processing job.""" + codeartifact_repo_arn = getattr(self, '_codeartifact_repo_arn', None) + + if codeartifact_repo_arn: + codeartifact_login = self._get_codeartifact_command(codeartifact_repo_arn) + requirements_block = dedent( + """\ + if [[ -f 'requirements.txt' ]]; then + # Some py3 containers has typing, which may breaks pip install + pip uninstall --yes typing + + {codeartifact_login} + pip install -r requirements.txt + fi + """ + ).format(codeartifact_login=codeartifact_login) + else: + requirements_block = dedent( + """\ + if [[ -f 'requirements.txt' ]]; then + # Some py3 containers has typing, which may breaks pip install + pip uninstall --yes typing + + pip install -r requirements.txt + fi + """ + ) + return dedent( """\ #!/bin/bash @@ -1369,21 +1459,79 @@ def _generate_framework_script(self, user_script: str) -> str: exit 1 fi - if [[ -f 'requirements.txt' ]]; then - # Some py3 containers has typing, which may breaks pip install - pip uninstall --yes typing - - pip install -r requirements.txt - fi + {requirements_block} {entry_point_command} {entry_point} "$@" """ ).format( + requirements_block=requirements_block, entry_point_command=" ".join(self.command), entry_point=user_script, ) +def _wait_for_processing_job(processing_job, logs=True): + """Wait for a processing job using the stored sagemaker_session. + + This function uses the sagemaker_session stored on the processing_job + (if available) instead of the global default client, which fixes + NoCredentialsError when using assumed-role sessions. + + Args: + processing_job: ProcessingJob resource object with _sagemaker_session attached. + logs (bool): Whether to show logs (default: True). + """ + sagemaker_session = getattr(processing_job, '_sagemaker_session', None) + job_name = processing_job.processing_job_name + + if sagemaker_session is not None: + if logs: + logs_for_processing_job(sagemaker_session, job_name, wait=True) + else: + # Poll using the session's client + poll = 10 + while True: + response = sagemaker_session.sagemaker_client.describe_processing_job( + ProcessingJobName=job_name + ) + status = response.get('ProcessingJobStatus', 'Unknown') + if status in ('Completed', 'Failed', 'Stopped'): + if status == 'Failed': + reason = response.get('FailureReason', 'Unknown') + raise RuntimeError( + f"Processing job {job_name} failed: {reason}" + ) + break + time.sleep(poll) + else: + # Fallback to the original refresh-based wait + processing_job.wait(logs=logs) + + +# Monkey-patch ProcessingJob.wait to use session-aware waiting +_original_processing_job_wait = getattr(ProcessingJob, 'wait', None) + + +def _patched_processing_job_wait(self, logs=True): + """Session-aware wait for ProcessingJob.""" + if hasattr(self, '_sagemaker_session') and self._sagemaker_session is not None: + _wait_for_processing_job(self, logs=logs) + elif _original_processing_job_wait: + _original_processing_job_wait(self, logs=logs) + else: + # Fallback polling + poll = 10 + while True: + self.refresh() + status = self.processing_job_status + if status in ('Completed', 'Failed', 'Stopped'): + break + time.sleep(poll) + + +ProcessingJob.wait = _patched_processing_job_wait + + class FeatureStoreOutput(ApiObject): """Configuration for processing job outputs in Amazon SageMaker Feature Store.""" diff --git a/sagemaker-core/tests/unit/test_processing_regressions.py b/sagemaker-core/tests/unit/test_processing_regressions.py new file mode 100644 index 0000000000..83f5b2d4d7 --- /dev/null +++ b/sagemaker-core/tests/unit/test_processing_regressions.py @@ -0,0 +1,275 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for v2->v3 regression bugs in processing and training.""" +import os +import pytest +from unittest.mock import MagicMock, patch, PropertyMock + + +class TestBug1ProcessorWaitUsesSession: + """Bug 1: wait=True should use sagemaker_session, not global default client.""" + + def test_processor_start_new_stores_session_on_job(self): + """Test that _start_new stores sagemaker_session on the ProcessingJob.""" + from sagemaker.core.processing import Processor + + mock_session = MagicMock() + mock_session.default_bucket.return_value = "my-bucket" + mock_session.default_bucket_prefix = "" + mock_session.expand_role.return_value = "arn:aws:iam::123456789:role/MyRole" + mock_session.sagemaker_client = MagicMock() + mock_session.boto_session = MagicMock() + mock_session.sagemaker_client.create_processing_job.return_value = {} + + # Mock the _intercept_create_request to call submit + def intercept(request, submit, job_type): + if submit: + submit(request) + + mock_session._intercept_create_request = intercept + + processor = Processor( + role="arn:aws:iam::123456789:role/MyRole", + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + processor._current_job_name = "test-job" + processor.arguments = None + + with patch("sagemaker.core.processing.ProcessingJob") as MockPJ: + mock_job = MagicMock() + MockPJ.return_value = mock_job + + with patch("sagemaker.core.processing.serialize", return_value={}): + with patch("sagemaker.core.processing.transform", return_value={}): + try: + result = processor._start_new( + inputs=[], outputs=[], experiment_config=None + ) + except Exception: + pass + + # The key assertion: _sagemaker_session should be set + if result is not None and hasattr(result, '_sagemaker_session'): + assert result._sagemaker_session == mock_session + + def test_patched_processing_job_wait_uses_session(self): + """Test that the patched ProcessingJob.wait uses _sagemaker_session.""" + from sagemaker.core.resources import ProcessingJob + + mock_session = MagicMock() + mock_session.sagemaker_client.describe_processing_job.return_value = { + "ProcessingJobStatus": "Completed", + "ProcessingJobName": "test-job", + } + + job = MagicMock(spec=ProcessingJob) + job.processing_job_name = "test-job" + job._sagemaker_session = mock_session + + # Import the patched wait + from sagemaker.core.processing import _wait_for_processing_job + _wait_for_processing_job(job, logs=False) + + mock_session.sagemaker_client.describe_processing_job.assert_called() + + +class TestBug2CodeLocation: + """Bug 2: code_location should be used for S3 uploads.""" + + def test_framework_processor_code_location_used_in_upload(self): + """Test that code_location is used when uploading code.""" + from sagemaker.core.processing import FrameworkProcessor + + mock_session = MagicMock() + mock_session.default_bucket.return_value = "default-bucket" + mock_session.default_bucket_prefix = "" + mock_session.expand_role.return_value = "arn:aws:iam::123456789:role/MyRole" + + processor = FrameworkProcessor( + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + role="arn:aws:iam::123456789:role/MyRole", + instance_count=1, + instance_type="ml.m5.xlarge", + code_location="s3://my-custom-bucket/my-prefix", + sagemaker_session=mock_session, + ) + + bucket, prefix = processor._get_code_upload_bucket_and_prefix() + assert bucket == "my-custom-bucket" + assert prefix == "my-prefix" + + def test_framework_processor_code_location_none_uses_default_bucket(self): + """Test that default bucket is used when code_location is None.""" + from sagemaker.core.processing import FrameworkProcessor + + mock_session = MagicMock() + mock_session.default_bucket.return_value = "default-bucket" + mock_session.default_bucket_prefix = "default-prefix" + mock_session.expand_role.return_value = "arn:aws:iam::123456789:role/MyRole" + + processor = FrameworkProcessor( + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + role="arn:aws:iam::123456789:role/MyRole", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + + bucket, prefix = processor._get_code_upload_bucket_and_prefix() + assert bucket == "default-bucket" + assert prefix == "default-prefix" + + +class TestBug3CodeArtifactFrameworkProcessor: + """Bug 3: CodeArtifact support in FrameworkProcessor.""" + + def test_framework_processor_run_accepts_codeartifact_repo_arn(self): + """Test that FrameworkProcessor.run() accepts codeartifact_repo_arn parameter.""" + import inspect + from sagemaker.core.processing import FrameworkProcessor + + sig = inspect.signature(FrameworkProcessor.run) + assert "codeartifact_repo_arn" in sig.parameters + + def test_get_codeartifact_command_parses_arn_correctly(self): + """Test that _get_codeartifact_command correctly parses the ARN.""" + from sagemaker.core.processing import FrameworkProcessor + + arn = "arn:aws:codeartifact:us-west-2:123456789012:repository/my-domain/my-repo" + command = FrameworkProcessor._get_codeartifact_command(arn) + + assert "--domain my-domain" in command + assert "--domain-owner 123456789012" in command + assert "--repository my-repo" in command + assert "--region us-west-2" in command + assert "aws codeartifact login --tool pip" in command + + def test_generate_framework_script_with_codeartifact_injects_login(self): + """Test that _generate_framework_script injects CodeArtifact login.""" + from sagemaker.core.processing import FrameworkProcessor + + mock_session = MagicMock() + mock_session.default_bucket.return_value = "default-bucket" + mock_session.default_bucket_prefix = "" + mock_session.expand_role.return_value = "arn:aws:iam::123456789:role/MyRole" + + processor = FrameworkProcessor( + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + role="arn:aws:iam::123456789:role/MyRole", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + processor._codeartifact_repo_arn = ( + "arn:aws:codeartifact:us-west-2:123456789012:repository/my-domain/my-repo" + ) + + script = processor._generate_framework_script("my_script.py") + assert "aws codeartifact login --tool pip" in script + assert "--domain my-domain" in script + assert "--repository my-repo" in script + + def test_generate_framework_script_without_codeartifact_no_login(self): + """Test that _generate_framework_script does NOT inject CodeArtifact login when not set.""" + from sagemaker.core.processing import FrameworkProcessor + + mock_session = MagicMock() + mock_session.default_bucket.return_value = "default-bucket" + mock_session.default_bucket_prefix = "" + mock_session.expand_role.return_value = "arn:aws:iam::123456789:role/MyRole" + + processor = FrameworkProcessor( + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + role="arn:aws:iam::123456789:role/MyRole", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=mock_session, + ) + # Don't set _codeartifact_repo_arn + + script = processor._generate_framework_script("my_script.py") + assert "aws codeartifact login" not in script + assert "pip install -r requirements.txt" in script + + +class TestBug4CodeArtifactTemplates: + """Bug 4: INSTALL_REQUIREMENTS template should check CA_REPOSITORY_ARN.""" + + def test_install_requirements_template_checks_ca_repository_arn(self): + """Test that INSTALL_REQUIREMENTS template includes CA_REPOSITORY_ARN check.""" + from sagemaker.train.templates import INSTALL_REQUIREMENTS + + # The template should contain the CodeArtifact login logic + rendered = INSTALL_REQUIREMENTS.format(requirements_file="requirements.txt") + assert "CA_REPOSITORY_ARN" in rendered + assert "aws codeartifact login --tool pip" in rendered + + def test_install_requirements_template_without_ca_repository_arn_uses_plain_pip(self): + """Test that INSTALL_REQUIREMENTS still does pip install.""" + from sagemaker.train.templates import INSTALL_REQUIREMENTS + + rendered = INSTALL_REQUIREMENTS.format(requirements_file="requirements.txt") + assert "pip install -r requirements.txt" in rendered or "$SM_PIP_CMD install -r requirements.txt" in rendered + + def test_install_auto_requirements_checks_ca_repository_arn(self): + """Test that INSTALL_AUTO_REQUIREMENTS template includes CA_REPOSITORY_ARN check.""" + from sagemaker.train.templates import INSTALL_AUTO_REQUIREMENTS + + assert "CA_REPOSITORY_ARN" in INSTALL_AUTO_REQUIREMENTS + assert "aws codeartifact login --tool pip" in INSTALL_AUTO_REQUIREMENTS + + +class TestBug1ModelTrainerWait: + """Bug 1: ModelTrainer.train(wait=True) should use sagemaker_session.""" + + def test_model_trainer_train_wait_uses_sagemaker_session(self): + """Test that the wait function accepts sagemaker_session parameter.""" + import inspect + from sagemaker.train.common_utils.trainer_wait import wait + + sig = inspect.signature(wait) + assert "sagemaker_session" in sig.parameters + + def test_refresh_training_job_uses_session_client(self): + """Test that _refresh_training_job uses session's sagemaker_client.""" + from sagemaker.train.common_utils.trainer_wait import _refresh_training_job + + mock_session = MagicMock() + mock_session.sagemaker_client.describe_training_job.return_value = { + "TrainingJobStatus": "Completed", + "TrainingJobName": "test-job", + } + + mock_job = MagicMock() + mock_job.training_job_name = "test-job" + + _refresh_training_job(mock_job, sagemaker_session=mock_session) + + mock_session.sagemaker_client.describe_training_job.assert_called_once_with( + TrainingJobName="test-job" + ) + + def test_refresh_training_job_without_session_uses_default(self): + """Test that _refresh_training_job falls back to default refresh.""" + from sagemaker.train.common_utils.trainer_wait import _refresh_training_job + + mock_job = MagicMock() + mock_job.training_job_name = "test-job" + + _refresh_training_job(mock_job, sagemaker_session=None) + + mock_job.refresh.assert_called_once() diff --git a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py index 59adcdfbfc..3fb57c43db 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py @@ -215,10 +215,39 @@ def get_mlflow_url(training_job) -> str: +def _refresh_training_job(training_job: TrainingJob, sagemaker_session=None): + """Refresh training job using session-aware client if available. + + Args: + training_job (TrainingJob): The training job to refresh. + sagemaker_session: Optional SageMaker session with the correct credentials. + """ + if sagemaker_session is not None: + try: + response = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=training_job.training_job_name + ) + # Update training_job attributes from the describe response + from graphene.utils.str_converters import to_snake_case + for key, value in response.items(): + snake_key = to_snake_case(key) + if hasattr(training_job, snake_key): + try: + setattr(training_job, snake_key, value) + except (AttributeError, TypeError): + pass + except Exception: + # Fallback to default refresh + training_job.refresh() + else: + training_job.refresh() + + def wait( training_job: TrainingJob, poll: int = 5, - timeout: Optional[int] = 3000 + timeout: Optional[int] = 3000, + sagemaker_session=None, ) -> None: """Wait for training job to complete with progress tracking. @@ -226,6 +255,10 @@ def wait( training_job (TrainingJob): The SageMaker training job to monitor. poll (int): Polling interval in seconds. Defaults to 3. timeout (Optional[int]): Maximum wait time in seconds. Defaults to None. + sagemaker_session: Optional SageMaker session to use for describe calls. + If provided, uses the session's sagemaker_client instead of the + global default client, which fixes NoCredentialsError when using + assumed-role sessions. Raises: FailedStatusError: If the training job fails. @@ -277,7 +310,7 @@ def get_cached_mlflow_url(): iteration += 1 time.sleep(0.5) if iteration >= poll * 2: - training_job.refresh() + _refresh_training_job(training_job, sagemaker_session) iteration = 0 status = training_job.training_job_status @@ -442,7 +475,7 @@ def get_cached_mlflow_url(): while True: iteration += 1 time.sleep(poll) - training_job.refresh() + _refresh_training_job(training_job, sagemaker_session) status = training_job.training_job_status secondary_status = training_job.secondary_status diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index d07edeb025..8230ff6d29 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -788,9 +788,15 @@ def train( **training_request ) self._latest_training_job = training_job + # Store the sagemaker_session on the job so wait/refresh can use it + training_job._sagemaker_session = self.sagemaker_session if wait: - training_job.wait(logs=logs) + from sagemaker.train.common_utils.trainer_wait import wait as trainer_wait + trainer_wait( + training_job=training_job, + sagemaker_session=self.sagemaker_session, + ) if logs and not wait: logger.warning( "Not displaing the training container logs as 'wait' is set to False." diff --git a/sagemaker-train/src/sagemaker/train/templates.py b/sagemaker-train/src/sagemaker/train/templates.py index c943769618..e07bf1ea8a 100644 --- a/sagemaker-train/src/sagemaker/train/templates.py +++ b/sagemaker-train/src/sagemaker/train/templates.py @@ -24,10 +24,30 @@ $SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/basic_script_driver.py """ +CODEARTIFACT_LOGIN = """ +# Check for CodeArtifact configuration via CA_REPOSITORY_ARN environment variable +if [ -n "$CA_REPOSITORY_ARN" ]; then + echo "CodeArtifact repository ARN detected: $CA_REPOSITORY_ARN" + if ! hash aws 2>/dev/null; then + echo "AWS CLI is not installed. Skipping CodeArtifact login." + else + # Parse the ARN: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository} + CA_REGION=$(echo $CA_REPOSITORY_ARN | cut -d: -f4) + CA_OWNER=$(echo $CA_REPOSITORY_ARN | cut -d: -f5) + CA_RESOURCE=$(echo $CA_REPOSITORY_ARN | cut -d: -f6) + CA_DOMAIN=$(echo $CA_RESOURCE | cut -d/ -f2) + CA_REPO=$(echo $CA_RESOURCE | cut -d/ -f3) + echo "Logging into CodeArtifact: domain=$CA_DOMAIN owner=$CA_OWNER repo=$CA_REPO region=$CA_REGION" + aws codeartifact login --tool pip --domain $CA_DOMAIN --domain-owner $CA_OWNER --repository $CA_REPO --region $CA_REGION + fi +fi +""" + INSTALL_AUTO_REQUIREMENTS = """ if [ -f requirements.txt ]; then echo "Installing requirements" cat requirements.txt +""" + CODEARTIFACT_LOGIN + """ $SM_PIP_CMD install -r requirements.txt else echo "No requirements.txt file found. Skipping installation." @@ -36,6 +56,7 @@ INSTALL_REQUIREMENTS = """ echo "Installing requirements" +""" + CODEARTIFACT_LOGIN + """ $SM_PIP_CMD install -r {requirements_file} """ From 0c9c8fd16f6b8015522bb4fee783a9253b843eaa Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Fri, 17 Apr 2026 09:12:22 -0700 Subject: [PATCH 2/3] fix: address review comments (iteration #1) --- .../src/sagemaker/core/processing.py | 151 +++++++-------- .../tests/unit/test_processing_regressions.py | 180 +++++------------- .../train/common_utils/trainer_wait.py | 63 +++--- .../src/sagemaker/train/model_trainer.py | 4 +- .../src/sagemaker/train/templates.py | 22 ++- .../tests/unit/test_train_regressions.py | 112 +++++++++++ 6 files changed, 279 insertions(+), 253 deletions(-) create mode 100644 sagemaker-train/tests/unit/test_train_regressions.py diff --git a/sagemaker-core/src/sagemaker/core/processing.py b/sagemaker-core/src/sagemaker/core/processing.py index 1c15d769bb..9c70961205 100644 --- a/sagemaker-core/src/sagemaker/core/processing.py +++ b/sagemaker-core/src/sagemaker/core/processing.py @@ -296,7 +296,42 @@ def run( if not isinstance(self.sagemaker_session, PipelineSession): self.jobs.append(self.latest_job) if wait: - self.latest_job.wait(logs=logs) + self._wait_for_job(self.latest_job, logs=logs) + + def _wait_for_job(self, processing_job, logs=True): + """Wait for a processing job using the stored sagemaker_session. + + This method uses the sagemaker_session from the Processor instance + instead of the global default client, which fixes NoCredentialsError + when using assumed-role sessions. + + Args: + processing_job: ProcessingJob resource object. + logs (bool): Whether to show logs (default: True). + """ + job_name = processing_job.processing_job_name + if logs: + logs_for_processing_job( + self.sagemaker_session, job_name, wait=True + ) + else: + poll = 10 + while True: + processing_job = ProcessingJob.get( + processing_job_name=job_name, + session=self.sagemaker_session.boto_session, + ) + status = processing_job.processing_job_status + if status in ("Completed", "Failed", "Stopped"): + if status == "Failed": + reason = getattr( + processing_job, "failure_reason", "Unknown" + ) + raise RuntimeError( + f"Processing job {job_name} failed: {reason}" + ) + break + time.sleep(poll) def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable=W0613 """Extend inputs and outputs based on extra parameters""" @@ -633,8 +668,6 @@ def submit(request): # Remove tags from transformed dict as ProcessingJob resource doesn't accept it transformed.pop("tags", None) processing_job = ProcessingJob(**transformed) - # Store the sagemaker_session on the job so wait/refresh can use it - processing_job._sagemaker_session = self.sagemaker_session return processing_job def _get_process_args(self, inputs, outputs, experiment_config): @@ -849,7 +882,7 @@ def run( if not isinstance(self.sagemaker_session, PipelineSession): self.jobs.append(self.latest_job) if wait: - self.latest_job.wait(logs=logs) + self._wait_for_job(self.latest_job, logs=logs) def _include_code_in_inputs(self, inputs, code, kms_key=None): """Converts code to appropriate input and includes in input list. @@ -948,11 +981,11 @@ def _get_code_upload_bucket_and_prefix(self): Returns: tuple: (bucket, prefix) for S3 uploads. """ - code_location = getattr(self, 'code_location', None) + code_location = getattr(self, "code_location", None) if code_location: parsed = urlparse(code_location) bucket = parsed.netloc - prefix = parsed.path.lstrip('/') + prefix = parsed.path.lstrip("/") return bucket, prefix return ( self.sagemaker_session.default_bucket(), @@ -1200,8 +1233,12 @@ def _package_code( os.unlink(tmp.name) return s3_uri + _CODEARTIFACT_ARN_PATTERN = re.compile( + r"^arn:aws:codeartifact:([a-z0-9-]+):(\d{12}):repository/([a-zA-Z0-9-]+)/([a-zA-Z0-9-]+)$" + ) + @staticmethod - def _get_codeartifact_command(codeartifact_repo_arn): + def _get_codeartifact_command(codeartifact_repo_arn: str) -> str: """Parse a CodeArtifact repository ARN and return the login command. Args: @@ -1210,24 +1247,30 @@ def _get_codeartifact_command(codeartifact_repo_arn): Returns: str: The bash command to login to CodeArtifact via pip. + + Raises: + ValueError: If the ARN format is invalid. """ - # Parse ARN: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository} - parts = codeartifact_repo_arn.split(':') - region = parts[3] - domain_owner = parts[4] - resource = parts[5] # repository/{domain}/{repository} - resource_parts = resource.split('/') - domain = resource_parts[1] - repository = resource_parts[2] + match = FrameworkProcessor._CODEARTIFACT_ARN_PATTERN.match(codeartifact_repo_arn) + if not match: + raise ValueError( + f"Invalid CodeArtifact repository ARN: {codeartifact_repo_arn}. " + "Expected format: " + "arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}" + ) + region = match.group(1) + domain_owner = match.group(2) + domain = match.group(3) + repository = match.group(4) return ( - f'if ! hash aws 2>/dev/null; then\n' - f' echo "AWS CLI is not installed. Skipping CodeArtifact login."\n' - f'else\n' - f' aws codeartifact login --tool pip ' - f'--domain {domain} --domain-owner {domain_owner} ' - f'--repository {repository} --region {region}\n' - f'fi' + "if ! hash aws 2>/dev/null; then\n" + " echo \"AWS CLI is not installed. Skipping CodeArtifact login.\"\n" + "else\n" + f" aws codeartifact login --tool pip " + f"--domain {domain} --domain-owner {domain_owner} " + f"--repository {repository} --region {region}\n" + "fi" ) @_telemetry_emitter(feature=Feature.PROCESSING, func_name="FrameworkProcessor.run") @@ -1409,7 +1452,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): def _generate_framework_script(self, user_script: str) -> str: """Generate the framework entrypoint file (as text) for a processing job.""" - codeartifact_repo_arn = getattr(self, '_codeartifact_repo_arn', None) + codeartifact_repo_arn = getattr(self, "_codeartifact_repo_arn", None) if codeartifact_repo_arn: codeartifact_login = self._get_codeartifact_command(codeartifact_repo_arn) @@ -1470,68 +1513,6 @@ def _generate_framework_script(self, user_script: str) -> str: ) -def _wait_for_processing_job(processing_job, logs=True): - """Wait for a processing job using the stored sagemaker_session. - - This function uses the sagemaker_session stored on the processing_job - (if available) instead of the global default client, which fixes - NoCredentialsError when using assumed-role sessions. - - Args: - processing_job: ProcessingJob resource object with _sagemaker_session attached. - logs (bool): Whether to show logs (default: True). - """ - sagemaker_session = getattr(processing_job, '_sagemaker_session', None) - job_name = processing_job.processing_job_name - - if sagemaker_session is not None: - if logs: - logs_for_processing_job(sagemaker_session, job_name, wait=True) - else: - # Poll using the session's client - poll = 10 - while True: - response = sagemaker_session.sagemaker_client.describe_processing_job( - ProcessingJobName=job_name - ) - status = response.get('ProcessingJobStatus', 'Unknown') - if status in ('Completed', 'Failed', 'Stopped'): - if status == 'Failed': - reason = response.get('FailureReason', 'Unknown') - raise RuntimeError( - f"Processing job {job_name} failed: {reason}" - ) - break - time.sleep(poll) - else: - # Fallback to the original refresh-based wait - processing_job.wait(logs=logs) - - -# Monkey-patch ProcessingJob.wait to use session-aware waiting -_original_processing_job_wait = getattr(ProcessingJob, 'wait', None) - - -def _patched_processing_job_wait(self, logs=True): - """Session-aware wait for ProcessingJob.""" - if hasattr(self, '_sagemaker_session') and self._sagemaker_session is not None: - _wait_for_processing_job(self, logs=logs) - elif _original_processing_job_wait: - _original_processing_job_wait(self, logs=logs) - else: - # Fallback polling - poll = 10 - while True: - self.refresh() - status = self.processing_job_status - if status in ('Completed', 'Failed', 'Stopped'): - break - time.sleep(poll) - - -ProcessingJob.wait = _patched_processing_job_wait - - class FeatureStoreOutput(ApiObject): """Configuration for processing job outputs in Amazon SageMaker Feature Store.""" diff --git a/sagemaker-core/tests/unit/test_processing_regressions.py b/sagemaker-core/tests/unit/test_processing_regressions.py index 83f5b2d4d7..139effe814 100644 --- a/sagemaker-core/tests/unit/test_processing_regressions.py +++ b/sagemaker-core/tests/unit/test_processing_regressions.py @@ -10,33 +10,26 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Tests for v2->v3 regression bugs in processing and training.""" +"""Tests for v2->v3 regression bugs in sagemaker.core.processing.""" import os import pytest -from unittest.mock import MagicMock, patch, PropertyMock +from unittest.mock import MagicMock, patch class TestBug1ProcessorWaitUsesSession: """Bug 1: wait=True should use sagemaker_session, not global default client.""" - def test_processor_start_new_stores_session_on_job(self): - """Test that _start_new stores sagemaker_session on the ProcessingJob.""" + def test_processor_wait_for_job_uses_session(self): + """Test that _wait_for_job uses the Processor's sagemaker_session.""" from sagemaker.core.processing import Processor mock_session = MagicMock() mock_session.default_bucket.return_value = "my-bucket" mock_session.default_bucket_prefix = "" - mock_session.expand_role.return_value = "arn:aws:iam::123456789:role/MyRole" - mock_session.sagemaker_client = MagicMock() + mock_session.expand_role.return_value = ( + "arn:aws:iam::123456789:role/MyRole" + ) mock_session.boto_session = MagicMock() - mock_session.sagemaker_client.create_processing_job.return_value = {} - - # Mock the _intercept_create_request to call submit - def intercept(request, submit, job_type): - if submit: - submit(request) - - mock_session._intercept_create_request = intercept processor = Processor( role="arn:aws:iam::123456789:role/MyRole", @@ -46,45 +39,21 @@ def intercept(request, submit, job_type): sagemaker_session=mock_session, ) - processor._current_job_name = "test-job" - processor.arguments = None + mock_job = MagicMock() + mock_job.processing_job_name = "test-job" + # Mock ProcessingJob.get to return a completed job with patch("sagemaker.core.processing.ProcessingJob") as MockPJ: - mock_job = MagicMock() - MockPJ.return_value = mock_job - - with patch("sagemaker.core.processing.serialize", return_value={}): - with patch("sagemaker.core.processing.transform", return_value={}): - try: - result = processor._start_new( - inputs=[], outputs=[], experiment_config=None - ) - except Exception: - pass - - # The key assertion: _sagemaker_session should be set - if result is not None and hasattr(result, '_sagemaker_session'): - assert result._sagemaker_session == mock_session - - def test_patched_processing_job_wait_uses_session(self): - """Test that the patched ProcessingJob.wait uses _sagemaker_session.""" - from sagemaker.core.resources import ProcessingJob - - mock_session = MagicMock() - mock_session.sagemaker_client.describe_processing_job.return_value = { - "ProcessingJobStatus": "Completed", - "ProcessingJobName": "test-job", - } + mock_refreshed = MagicMock() + mock_refreshed.processing_job_status = "Completed" + MockPJ.get.return_value = mock_refreshed - job = MagicMock(spec=ProcessingJob) - job.processing_job_name = "test-job" - job._sagemaker_session = mock_session + processor._wait_for_job(mock_job, logs=False) - # Import the patched wait - from sagemaker.core.processing import _wait_for_processing_job - _wait_for_processing_job(job, logs=False) - - mock_session.sagemaker_client.describe_processing_job.assert_called() + MockPJ.get.assert_called_with( + processing_job_name="test-job", + session=mock_session.boto_session, + ) class TestBug2CodeLocation: @@ -97,7 +66,9 @@ def test_framework_processor_code_location_used_in_upload(self): mock_session = MagicMock() mock_session.default_bucket.return_value = "default-bucket" mock_session.default_bucket_prefix = "" - mock_session.expand_role.return_value = "arn:aws:iam::123456789:role/MyRole" + mock_session.expand_role.return_value = ( + "arn:aws:iam::123456789:role/MyRole" + ) processor = FrameworkProcessor( image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", @@ -112,14 +83,16 @@ def test_framework_processor_code_location_used_in_upload(self): assert bucket == "my-custom-bucket" assert prefix == "my-prefix" - def test_framework_processor_code_location_none_uses_default_bucket(self): + def test_framework_processor_code_location_none_uses_default(self): """Test that default bucket is used when code_location is None.""" from sagemaker.core.processing import FrameworkProcessor mock_session = MagicMock() mock_session.default_bucket.return_value = "default-bucket" mock_session.default_bucket_prefix = "default-prefix" - mock_session.expand_role.return_value = "arn:aws:iam::123456789:role/MyRole" + mock_session.expand_role.return_value = ( + "arn:aws:iam::123456789:role/MyRole" + ) processor = FrameworkProcessor( image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", @@ -137,8 +110,8 @@ def test_framework_processor_code_location_none_uses_default_bucket(self): class TestBug3CodeArtifactFrameworkProcessor: """Bug 3: CodeArtifact support in FrameworkProcessor.""" - def test_framework_processor_run_accepts_codeartifact_repo_arn(self): - """Test that FrameworkProcessor.run() accepts codeartifact_repo_arn parameter.""" + def test_run_accepts_codeartifact_repo_arn(self): + """Test that FrameworkProcessor.run() accepts codeartifact_repo_arn.""" import inspect from sagemaker.core.processing import FrameworkProcessor @@ -149,7 +122,10 @@ def test_get_codeartifact_command_parses_arn_correctly(self): """Test that _get_codeartifact_command correctly parses the ARN.""" from sagemaker.core.processing import FrameworkProcessor - arn = "arn:aws:codeartifact:us-west-2:123456789012:repository/my-domain/my-repo" + arn = ( + "arn:aws:codeartifact:us-west-2:123456789012" + ":repository/my-domain/my-repo" + ) command = FrameworkProcessor._get_codeartifact_command(arn) assert "--domain my-domain" in command @@ -158,14 +134,23 @@ def test_get_codeartifact_command_parses_arn_correctly(self): assert "--region us-west-2" in command assert "aws codeartifact login --tool pip" in command - def test_generate_framework_script_with_codeartifact_injects_login(self): + def test_get_codeartifact_command_rejects_invalid_arn(self): + """Test that _get_codeartifact_command raises ValueError for bad ARN.""" + from sagemaker.core.processing import FrameworkProcessor + + with pytest.raises(ValueError, match="Invalid CodeArtifact repository ARN"): + FrameworkProcessor._get_codeartifact_command("not-a-valid-arn") + + def test_generate_framework_script_with_codeartifact(self): """Test that _generate_framework_script injects CodeArtifact login.""" from sagemaker.core.processing import FrameworkProcessor mock_session = MagicMock() mock_session.default_bucket.return_value = "default-bucket" mock_session.default_bucket_prefix = "" - mock_session.expand_role.return_value = "arn:aws:iam::123456789:role/MyRole" + mock_session.expand_role.return_value = ( + "arn:aws:iam::123456789:role/MyRole" + ) processor = FrameworkProcessor( image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", @@ -175,7 +160,8 @@ def test_generate_framework_script_with_codeartifact_injects_login(self): sagemaker_session=mock_session, ) processor._codeartifact_repo_arn = ( - "arn:aws:codeartifact:us-west-2:123456789012:repository/my-domain/my-repo" + "arn:aws:codeartifact:us-west-2:123456789012" + ":repository/my-domain/my-repo" ) script = processor._generate_framework_script("my_script.py") @@ -183,14 +169,16 @@ def test_generate_framework_script_with_codeartifact_injects_login(self): assert "--domain my-domain" in script assert "--repository my-repo" in script - def test_generate_framework_script_without_codeartifact_no_login(self): - """Test that _generate_framework_script does NOT inject CodeArtifact login when not set.""" + def test_generate_framework_script_without_codeartifact(self): + """Test script does NOT inject CodeArtifact login when not set.""" from sagemaker.core.processing import FrameworkProcessor mock_session = MagicMock() mock_session.default_bucket.return_value = "default-bucket" mock_session.default_bucket_prefix = "" - mock_session.expand_role.return_value = "arn:aws:iam::123456789:role/MyRole" + mock_session.expand_role.return_value = ( + "arn:aws:iam::123456789:role/MyRole" + ) processor = FrameworkProcessor( image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", @@ -199,77 +187,7 @@ def test_generate_framework_script_without_codeartifact_no_login(self): instance_type="ml.m5.xlarge", sagemaker_session=mock_session, ) - # Don't set _codeartifact_repo_arn script = processor._generate_framework_script("my_script.py") assert "aws codeartifact login" not in script assert "pip install -r requirements.txt" in script - - -class TestBug4CodeArtifactTemplates: - """Bug 4: INSTALL_REQUIREMENTS template should check CA_REPOSITORY_ARN.""" - - def test_install_requirements_template_checks_ca_repository_arn(self): - """Test that INSTALL_REQUIREMENTS template includes CA_REPOSITORY_ARN check.""" - from sagemaker.train.templates import INSTALL_REQUIREMENTS - - # The template should contain the CodeArtifact login logic - rendered = INSTALL_REQUIREMENTS.format(requirements_file="requirements.txt") - assert "CA_REPOSITORY_ARN" in rendered - assert "aws codeartifact login --tool pip" in rendered - - def test_install_requirements_template_without_ca_repository_arn_uses_plain_pip(self): - """Test that INSTALL_REQUIREMENTS still does pip install.""" - from sagemaker.train.templates import INSTALL_REQUIREMENTS - - rendered = INSTALL_REQUIREMENTS.format(requirements_file="requirements.txt") - assert "pip install -r requirements.txt" in rendered or "$SM_PIP_CMD install -r requirements.txt" in rendered - - def test_install_auto_requirements_checks_ca_repository_arn(self): - """Test that INSTALL_AUTO_REQUIREMENTS template includes CA_REPOSITORY_ARN check.""" - from sagemaker.train.templates import INSTALL_AUTO_REQUIREMENTS - - assert "CA_REPOSITORY_ARN" in INSTALL_AUTO_REQUIREMENTS - assert "aws codeartifact login --tool pip" in INSTALL_AUTO_REQUIREMENTS - - -class TestBug1ModelTrainerWait: - """Bug 1: ModelTrainer.train(wait=True) should use sagemaker_session.""" - - def test_model_trainer_train_wait_uses_sagemaker_session(self): - """Test that the wait function accepts sagemaker_session parameter.""" - import inspect - from sagemaker.train.common_utils.trainer_wait import wait - - sig = inspect.signature(wait) - assert "sagemaker_session" in sig.parameters - - def test_refresh_training_job_uses_session_client(self): - """Test that _refresh_training_job uses session's sagemaker_client.""" - from sagemaker.train.common_utils.trainer_wait import _refresh_training_job - - mock_session = MagicMock() - mock_session.sagemaker_client.describe_training_job.return_value = { - "TrainingJobStatus": "Completed", - "TrainingJobName": "test-job", - } - - mock_job = MagicMock() - mock_job.training_job_name = "test-job" - - _refresh_training_job(mock_job, sagemaker_session=mock_session) - - mock_session.sagemaker_client.describe_training_job.assert_called_once_with( - TrainingJobName="test-job" - ) - - def test_refresh_training_job_without_session_uses_default(self): - """Test that _refresh_training_job falls back to default refresh.""" - from sagemaker.train.common_utils.trainer_wait import _refresh_training_job - - mock_job = MagicMock() - mock_job.training_job_name = "test-job" - - _refresh_training_job(mock_job, sagemaker_session=None) - - mock_job.refresh.assert_called_once() diff --git a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py index 3fb57c43db..52d88bea1d 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py @@ -5,15 +5,25 @@ """ import logging +import re import time from contextlib import contextmanager from typing import Optional, Tuple +from sagemaker.core.helper.session_helper import Session from sagemaker.core.resources import TrainingJob from sagemaker.core.utils.exceptions import FailedStatusError, TimeoutExceededError from sagemaker.train.common_utils.mlflow_metrics_util import _MLflowMetricsUtil +logger = logging.getLogger(__name__) + + +def _to_snake_case(name: str) -> str: + """Convert a CamelCase string to snake_case.""" + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + @contextmanager def _suppress_info_logging(): @@ -215,30 +225,33 @@ def get_mlflow_url(training_job) -> str: -def _refresh_training_job(training_job: TrainingJob, sagemaker_session=None): +def _refresh_training_job( + training_job: TrainingJob, + sagemaker_session: Optional[Session] = None, +) -> None: """Refresh training job using session-aware client if available. + Uses the provided sagemaker_session's client to describe the training job + and update its attributes. This avoids using the global default client, + which fixes NoCredentialsError when using assumed-role sessions. + Args: training_job (TrainingJob): The training job to refresh. - sagemaker_session: Optional SageMaker session with the correct credentials. + sagemaker_session (Optional[Session]): SageMaker session with the + correct credentials. If None, falls back to default refresh. """ if sagemaker_session is not None: - try: - response = sagemaker_session.sagemaker_client.describe_training_job( - TrainingJobName=training_job.training_job_name - ) - # Update training_job attributes from the describe response - from graphene.utils.str_converters import to_snake_case - for key, value in response.items(): - snake_key = to_snake_case(key) - if hasattr(training_job, snake_key): - try: - setattr(training_job, snake_key, value) - except (AttributeError, TypeError): - pass - except Exception: - # Fallback to default refresh - training_job.refresh() + response = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=training_job.training_job_name + ) + # Update training_job attributes from the describe response + for key, value in response.items(): + snake_key = _to_snake_case(key) + if hasattr(training_job, snake_key): + try: + setattr(training_job, snake_key, value) + except (AttributeError, TypeError, ValueError): + pass else: training_job.refresh() @@ -247,18 +260,18 @@ def wait( training_job: TrainingJob, poll: int = 5, timeout: Optional[int] = 3000, - sagemaker_session=None, + sagemaker_session: Optional[Session] = None, ) -> None: """Wait for training job to complete with progress tracking. Args: training_job (TrainingJob): The SageMaker training job to monitor. - poll (int): Polling interval in seconds. Defaults to 3. - timeout (Optional[int]): Maximum wait time in seconds. Defaults to None. - sagemaker_session: Optional SageMaker session to use for describe calls. - If provided, uses the session's sagemaker_client instead of the - global default client, which fixes NoCredentialsError when using - assumed-role sessions. + poll (int): Polling interval in seconds. Defaults to 5. + timeout (Optional[int]): Maximum wait time in seconds. Defaults to 3000. + sagemaker_session (Optional[Session]): SageMaker session to use for + describe calls. If provided, uses the session's sagemaker_client + instead of the global default client, which fixes + NoCredentialsError when using assumed-role sessions. Raises: FailedStatusError: If the training job fails. diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index 8230ff6d29..fcc3726cdf 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -118,6 +118,7 @@ from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline from sagemaker.core.helper.pipeline_variable import StrPipeVar +from sagemaker.train.common_utils.trainer_wait import wait as trainer_wait from sagemaker.train.local.local_container import _LocalContainer @@ -788,11 +789,8 @@ def train( **training_request ) self._latest_training_job = training_job - # Store the sagemaker_session on the job so wait/refresh can use it - training_job._sagemaker_session = self.sagemaker_session if wait: - from sagemaker.train.common_utils.trainer_wait import wait as trainer_wait trainer_wait( training_job=training_job, sagemaker_session=self.sagemaker_session, diff --git a/sagemaker-train/src/sagemaker/train/templates.py b/sagemaker-train/src/sagemaker/train/templates.py index e07bf1ea8a..f6f69c2efa 100644 --- a/sagemaker-train/src/sagemaker/train/templates.py +++ b/sagemaker-train/src/sagemaker/train/templates.py @@ -24,21 +24,25 @@ $SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/basic_script_driver.py """ +# CodeArtifact login block using only shell variables. +# All curly braces are doubled to escape them from Python's str.format(). CODEARTIFACT_LOGIN = """ # Check for CodeArtifact configuration via CA_REPOSITORY_ARN environment variable if [ -n "$CA_REPOSITORY_ARN" ]; then echo "CodeArtifact repository ARN detected: $CA_REPOSITORY_ARN" - if ! hash aws 2>/dev/null; then + # Validate ARN format: arn:aws:codeartifact:REGION:ACCOUNT:repository/DOMAIN/REPO + if ! echo "$CA_REPOSITORY_ARN" | grep -qE '^arn:aws:codeartifact:[a-z0-9-]+:[0-9]{{12}}:repository/[a-zA-Z0-9-]+/[a-zA-Z0-9-]+$'; then + echo "WARNING: CA_REPOSITORY_ARN does not match expected format. Skipping CodeArtifact login." + elif ! hash aws 2>/dev/null; then echo "AWS CLI is not installed. Skipping CodeArtifact login." else - # Parse the ARN: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository} - CA_REGION=$(echo $CA_REPOSITORY_ARN | cut -d: -f4) - CA_OWNER=$(echo $CA_REPOSITORY_ARN | cut -d: -f5) - CA_RESOURCE=$(echo $CA_REPOSITORY_ARN | cut -d: -f6) - CA_DOMAIN=$(echo $CA_RESOURCE | cut -d/ -f2) - CA_REPO=$(echo $CA_RESOURCE | cut -d/ -f3) + CA_REGION=$(echo "$CA_REPOSITORY_ARN" | cut -d: -f4) + CA_OWNER=$(echo "$CA_REPOSITORY_ARN" | cut -d: -f5) + CA_RESOURCE=$(echo "$CA_REPOSITORY_ARN" | cut -d: -f6) + CA_DOMAIN=$(echo "$CA_RESOURCE" | cut -d/ -f2) + CA_REPO=$(echo "$CA_RESOURCE" | cut -d/ -f3) echo "Logging into CodeArtifact: domain=$CA_DOMAIN owner=$CA_OWNER repo=$CA_REPO region=$CA_REGION" - aws codeartifact login --tool pip --domain $CA_DOMAIN --domain-owner $CA_OWNER --repository $CA_REPO --region $CA_REGION + aws codeartifact login --tool pip --domain "$CA_DOMAIN" --domain-owner "$CA_OWNER" --repository "$CA_REPO" --region "$CA_REGION" fi fi """ @@ -46,8 +50,8 @@ INSTALL_AUTO_REQUIREMENTS = """ if [ -f requirements.txt ]; then echo "Installing requirements" - cat requirements.txt """ + CODEARTIFACT_LOGIN + """ + cat requirements.txt $SM_PIP_CMD install -r requirements.txt else echo "No requirements.txt file found. Skipping installation." diff --git a/sagemaker-train/tests/unit/test_train_regressions.py b/sagemaker-train/tests/unit/test_train_regressions.py new file mode 100644 index 0000000000..48fd26125b --- /dev/null +++ b/sagemaker-train/tests/unit/test_train_regressions.py @@ -0,0 +1,112 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for v2->v3 regression bugs in sagemaker.train.""" +import inspect +import pytest +from unittest.mock import MagicMock + + +class TestBug1ModelTrainerWait: + """Bug 1: ModelTrainer.train(wait=True) should use sagemaker_session.""" + + def test_wait_function_accepts_sagemaker_session(self): + """Test that the wait function accepts sagemaker_session parameter.""" + from sagemaker.train.common_utils.trainer_wait import wait + + sig = inspect.signature(wait) + assert "sagemaker_session" in sig.parameters + + def test_refresh_training_job_uses_session_client(self): + """Test that _refresh_training_job uses session's sagemaker_client.""" + from sagemaker.train.common_utils.trainer_wait import ( + _refresh_training_job, + ) + + mock_session = MagicMock() + mock_session.sagemaker_client.describe_training_job.return_value = { + "TrainingJobStatus": "Completed", + "TrainingJobName": "test-job", + } + + mock_job = MagicMock() + mock_job.training_job_name = "test-job" + + _refresh_training_job(mock_job, sagemaker_session=mock_session) + + mock_session.sagemaker_client.describe_training_job.assert_called_once_with( + TrainingJobName="test-job" + ) + + def test_refresh_training_job_without_session_uses_default(self): + """Test that _refresh_training_job falls back to default refresh.""" + from sagemaker.train.common_utils.trainer_wait import ( + _refresh_training_job, + ) + + mock_job = MagicMock() + mock_job.training_job_name = "test-job" + + _refresh_training_job(mock_job, sagemaker_session=None) + + mock_job.refresh.assert_called_once() + + +class TestBug4CodeArtifactTemplates: + """Bug 4: INSTALL_REQUIREMENTS template should check CA_REPOSITORY_ARN.""" + + def test_install_requirements_template_has_ca_support(self): + """Test that INSTALL_REQUIREMENTS includes CA_REPOSITORY_ARN check.""" + from sagemaker.train.templates import INSTALL_REQUIREMENTS + + rendered = INSTALL_REQUIREMENTS.format( + requirements_file="requirements.txt" + ) + assert "CA_REPOSITORY_ARN" in rendered + assert "aws codeartifact login --tool pip" in rendered + + def test_install_requirements_template_does_pip_install(self): + """Test that INSTALL_REQUIREMENTS still does pip install.""" + from sagemaker.train.templates import INSTALL_REQUIREMENTS + + rendered = INSTALL_REQUIREMENTS.format( + requirements_file="requirements.txt" + ) + has_pip = ( + "pip install -r requirements.txt" in rendered + or "$SM_PIP_CMD install -r requirements.txt" in rendered + ) + assert has_pip + + def test_install_requirements_format_does_not_raise(self): + """Test that .format() does not raise KeyError.""" + from sagemaker.train.templates import INSTALL_REQUIREMENTS + + # This was the CI failure - KeyError: 'region' + try: + rendered = INSTALL_REQUIREMENTS.format( + requirements_file="requirements.txt" + ) + except KeyError as e: + pytest.fail( + f"INSTALL_REQUIREMENTS.format() raised KeyError: {e}" + ) + + def test_install_auto_requirements_has_ca_support(self): + """Test that INSTALL_AUTO_REQUIREMENTS includes CA_REPOSITORY_ARN.""" + from sagemaker.train.templates import INSTALL_AUTO_REQUIREMENTS + + assert "CA_REPOSITORY_ARN" in INSTALL_AUTO_REQUIREMENTS + assert ( + "aws codeartifact login --tool pip" + in INSTALL_AUTO_REQUIREMENTS + ) From 7862f5e2d0ad5534a9949b0917cf394ea3467a63 Mon Sep 17 00:00:00 2001 From: SageMaker Bot <49924207+sagemaker-bot@users.noreply.github.com> Date: Fri, 17 Apr 2026 09:20:20 -0700 Subject: [PATCH 3/3] fix: address review comments (iteration #2) --- .../src/sagemaker/core/processing.py | 134 ++++----------- .../tests/unit/test_processing_regressions.py | 155 ++++++++---------- .../train/common_utils/trainer_wait.py | 13 +- .../src/sagemaker/train/templates.py | 25 --- .../tests/unit/test_train_regressions.py | 73 +++------ 5 files changed, 137 insertions(+), 263 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/processing.py b/sagemaker-core/src/sagemaker/core/processing.py index 9c70961205..0b97f538bb 100644 --- a/sagemaker-core/src/sagemaker/core/processing.py +++ b/sagemaker-core/src/sagemaker/core/processing.py @@ -298,7 +298,7 @@ def run( if wait: self._wait_for_job(self.latest_job, logs=logs) - def _wait_for_job(self, processing_job, logs=True): + def _wait_for_job(self, processing_job, logs=True, timeout=3600): """Wait for a processing job using the stored sagemaker_session. This method uses the sagemaker_session from the Processor instance @@ -308,6 +308,8 @@ def _wait_for_job(self, processing_job, logs=True): Args: processing_job: ProcessingJob resource object. logs (bool): Whether to show logs (default: True). + timeout (int): Maximum time in seconds to wait (default: 3600). + If None, waits indefinitely. """ job_name = processing_job.processing_job_name if logs: @@ -316,7 +318,16 @@ def _wait_for_job(self, processing_job, logs=True): ) else: poll = 10 + start_time = time.time() while True: + if timeout and (time.time() - start_time) > timeout: + raise RuntimeError( + f"Timed out waiting for processing job {job_name} " + f"after {timeout} seconds" + ) + # TODO: Ideally sagemaker-core's ProcessingJob.refresh()/wait() + # should accept a session parameter. Using ProcessingJob.get() + # with the user's boto_session as a workaround. processing_job = ProcessingJob.get( processing_job_name=job_name, session=self.sagemaker_session.boto_session, @@ -972,26 +983,6 @@ def _handle_user_code_url(self, code, kms_key=None): ) return user_code_s3_uri - def _get_code_upload_bucket_and_prefix(self): - """Get the S3 bucket and prefix for code uploads. - - If code_location is set (on FrameworkProcessor), parse it to extract - bucket and prefix. Otherwise, use the session's default bucket. - - Returns: - tuple: (bucket, prefix) for S3 uploads. - """ - code_location = getattr(self, "code_location", None) - if code_location: - parsed = urlparse(code_location) - bucket = parsed.netloc - prefix = parsed.path.lstrip("/") - return bucket, prefix - return ( - self.sagemaker_session.default_bucket(), - self.sagemaker_session.default_bucket_prefix or "", - ) - def _upload_code(self, code, kms_key=None): """Uploads a code file or directory specified as a string and returns the S3 URI. @@ -1006,13 +997,11 @@ def _upload_code(self, code, kms_key=None): """ from sagemaker.core.workflow.utilities import _pipeline_config - bucket, prefix = self._get_code_upload_bucket_and_prefix() - if _pipeline_config and _pipeline_config.code_hash: desired_s3_uri = s3.s3_path_join( "s3://", - bucket, - prefix, + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, _pipeline_config.pipeline_name, self._CODE_CONTAINER_INPUT_NAME, _pipeline_config.code_hash, @@ -1020,8 +1009,8 @@ def _upload_code(self, code, kms_key=None): else: desired_s3_uri = s3.s3_path_join( "s3://", - bucket, - prefix, + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix, self._current_job_name, "input", self._CODE_CONTAINER_INPUT_NAME, @@ -1211,12 +1200,10 @@ def _package_code( item_path = os.path.join(source_dir, item) tar.add(item_path, arcname=item) - # Upload to S3 - honor code_location if set - bucket, prefix = self._get_code_upload_bucket_and_prefix() - s3_uri = s3.s3_path_join( + s3_uri = s3.s3_path_join( "s3://", - bucket, - prefix, + self.sagemaker_session.default_bucket(), + self.sagemaker_session.default_bucket_prefix or "", job_name, "source", "sourcedir.tar.gz", @@ -1233,46 +1220,6 @@ def _package_code( os.unlink(tmp.name) return s3_uri - _CODEARTIFACT_ARN_PATTERN = re.compile( - r"^arn:aws:codeartifact:([a-z0-9-]+):(\d{12}):repository/([a-zA-Z0-9-]+)/([a-zA-Z0-9-]+)$" - ) - - @staticmethod - def _get_codeartifact_command(codeartifact_repo_arn: str) -> str: - """Parse a CodeArtifact repository ARN and return the login command. - - Args: - codeartifact_repo_arn (str): The ARN of the CodeArtifact repository. - Format: arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository} - - Returns: - str: The bash command to login to CodeArtifact via pip. - - Raises: - ValueError: If the ARN format is invalid. - """ - match = FrameworkProcessor._CODEARTIFACT_ARN_PATTERN.match(codeartifact_repo_arn) - if not match: - raise ValueError( - f"Invalid CodeArtifact repository ARN: {codeartifact_repo_arn}. " - "Expected format: " - "arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository}" - ) - region = match.group(1) - domain_owner = match.group(2) - domain = match.group(3) - repository = match.group(4) - - return ( - "if ! hash aws 2>/dev/null; then\n" - " echo \"AWS CLI is not installed. Skipping CodeArtifact login.\"\n" - "else\n" - f" aws codeartifact login --tool pip " - f"--domain {domain} --domain-owner {domain_owner} " - f"--repository {repository} --region {region}\n" - "fi" - ) - @_telemetry_emitter(feature=Feature.PROCESSING, func_name="FrameworkProcessor.run") @runnable_by_pipeline def run( @@ -1288,7 +1235,6 @@ def run( job_name: Optional[str] = None, experiment_config: Optional[Dict[str, str]] = None, kms_key: Optional[str] = None, - codeartifact_repo_arn: Optional[str] = None, ): """Runs a processing job. @@ -1316,16 +1262,10 @@ def run( experiment_config (dict[str, str]): Experiment management configuration. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). - codeartifact_repo_arn (str): The ARN of the CodeArtifact repository to use - for pip authentication when installing requirements.txt dependencies - (default: None). Format: - arn:aws:codeartifact:{region}:{account}:repository/{domain}/{repository} Returns: None or pipeline step arguments in case the Processor instance is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession` """ - self._codeartifact_repo_arn = codeartifact_repo_arn - s3_runproc_sh, inputs, job_name = self._pack_and_upload_code( code, source_dir, @@ -1452,32 +1392,16 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): def _generate_framework_script(self, user_script: str) -> str: """Generate the framework entrypoint file (as text) for a processing job.""" - codeartifact_repo_arn = getattr(self, "_codeartifact_repo_arn", None) - - if codeartifact_repo_arn: - codeartifact_login = self._get_codeartifact_command(codeartifact_repo_arn) - requirements_block = dedent( - """\ - if [[ -f 'requirements.txt' ]]; then - # Some py3 containers has typing, which may breaks pip install - pip uninstall --yes typing - - {codeartifact_login} - pip install -r requirements.txt - fi - """ - ).format(codeartifact_login=codeartifact_login) - else: - requirements_block = dedent( - """\ - if [[ -f 'requirements.txt' ]]; then - # Some py3 containers has typing, which may breaks pip install - pip uninstall --yes typing - - pip install -r requirements.txt - fi - """ - ) + requirements_block = dedent( + """\ + if [[ -f 'requirements.txt' ]]; then + # Some py3 containers has typing, which may breaks pip install + pip uninstall --yes typing + + pip install -r requirements.txt + fi + """ + ) return dedent( """\ diff --git a/sagemaker-core/tests/unit/test_processing_regressions.py b/sagemaker-core/tests/unit/test_processing_regressions.py index 139effe814..3fb65397e9 100644 --- a/sagemaker-core/tests/unit/test_processing_regressions.py +++ b/sagemaker-core/tests/unit/test_processing_regressions.py @@ -10,17 +10,16 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Tests for v2->v3 regression bugs in sagemaker.core.processing.""" -import os +"""Tests for v2->v3 regression Bug 1: wait=True ignores sagemaker session.""" import pytest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, call class TestBug1ProcessorWaitUsesSession: """Bug 1: wait=True should use sagemaker_session, not global default client.""" - def test_processor_wait_for_job_uses_session(self): - """Test that _wait_for_job uses the Processor's sagemaker_session.""" + def test_processor_wait_for_job_uses_session_no_logs(self): + """Test that _wait_for_job uses the Processor's sagemaker_session (no logs).""" from sagemaker.core.processing import Processor mock_session = MagicMock() @@ -55,139 +54,125 @@ def test_processor_wait_for_job_uses_session(self): session=mock_session.boto_session, ) - -class TestBug2CodeLocation: - """Bug 2: code_location should be used for S3 uploads.""" - - def test_framework_processor_code_location_used_in_upload(self): - """Test that code_location is used when uploading code.""" - from sagemaker.core.processing import FrameworkProcessor + def test_processor_wait_for_job_uses_session_with_logs(self): + """Test that _wait_for_job with logs=True uses logs_for_processing_job.""" + from sagemaker.core.processing import Processor mock_session = MagicMock() - mock_session.default_bucket.return_value = "default-bucket" + mock_session.default_bucket.return_value = "my-bucket" mock_session.default_bucket_prefix = "" mock_session.expand_role.return_value = ( "arn:aws:iam::123456789:role/MyRole" ) + mock_session.boto_session = MagicMock() - processor = FrameworkProcessor( - image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + processor = Processor( role="arn:aws:iam::123456789:role/MyRole", + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", instance_count=1, instance_type="ml.m5.xlarge", - code_location="s3://my-custom-bucket/my-prefix", sagemaker_session=mock_session, ) - bucket, prefix = processor._get_code_upload_bucket_and_prefix() - assert bucket == "my-custom-bucket" - assert prefix == "my-prefix" + mock_job = MagicMock() + mock_job.processing_job_name = "test-job" + + with patch("sagemaker.core.processing.logs_for_processing_job") as mock_logs: + processor._wait_for_job(mock_job, logs=True) - def test_framework_processor_code_location_none_uses_default(self): - """Test that default bucket is used when code_location is None.""" - from sagemaker.core.processing import FrameworkProcessor + mock_logs.assert_called_once_with( + mock_session, "test-job", wait=True + ) + + def test_processor_wait_for_job_raises_on_failed(self): + """Test that _wait_for_job raises RuntimeError when job fails.""" + from sagemaker.core.processing import Processor mock_session = MagicMock() - mock_session.default_bucket.return_value = "default-bucket" - mock_session.default_bucket_prefix = "default-prefix" + mock_session.default_bucket.return_value = "my-bucket" + mock_session.default_bucket_prefix = "" mock_session.expand_role.return_value = ( "arn:aws:iam::123456789:role/MyRole" ) + mock_session.boto_session = MagicMock() - processor = FrameworkProcessor( - image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + processor = Processor( role="arn:aws:iam::123456789:role/MyRole", + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", instance_count=1, instance_type="ml.m5.xlarge", sagemaker_session=mock_session, ) - bucket, prefix = processor._get_code_upload_bucket_and_prefix() - assert bucket == "default-bucket" - assert prefix == "default-prefix" - - -class TestBug3CodeArtifactFrameworkProcessor: - """Bug 3: CodeArtifact support in FrameworkProcessor.""" - - def test_run_accepts_codeartifact_repo_arn(self): - """Test that FrameworkProcessor.run() accepts codeartifact_repo_arn.""" - import inspect - from sagemaker.core.processing import FrameworkProcessor - - sig = inspect.signature(FrameworkProcessor.run) - assert "codeartifact_repo_arn" in sig.parameters - - def test_get_codeartifact_command_parses_arn_correctly(self): - """Test that _get_codeartifact_command correctly parses the ARN.""" - from sagemaker.core.processing import FrameworkProcessor - - arn = ( - "arn:aws:codeartifact:us-west-2:123456789012" - ":repository/my-domain/my-repo" - ) - command = FrameworkProcessor._get_codeartifact_command(arn) - - assert "--domain my-domain" in command - assert "--domain-owner 123456789012" in command - assert "--repository my-repo" in command - assert "--region us-west-2" in command - assert "aws codeartifact login --tool pip" in command + mock_job = MagicMock() + mock_job.processing_job_name = "test-job" - def test_get_codeartifact_command_rejects_invalid_arn(self): - """Test that _get_codeartifact_command raises ValueError for bad ARN.""" - from sagemaker.core.processing import FrameworkProcessor + with patch("sagemaker.core.processing.ProcessingJob") as MockPJ: + mock_refreshed = MagicMock() + mock_refreshed.processing_job_status = "Failed" + mock_refreshed.failure_reason = "OutOfMemory" + MockPJ.get.return_value = mock_refreshed - with pytest.raises(ValueError, match="Invalid CodeArtifact repository ARN"): - FrameworkProcessor._get_codeartifact_command("not-a-valid-arn") + with pytest.raises(RuntimeError, match="failed.*OutOfMemory"): + processor._wait_for_job(mock_job, logs=False) - def test_generate_framework_script_with_codeartifact(self): - """Test that _generate_framework_script injects CodeArtifact login.""" - from sagemaker.core.processing import FrameworkProcessor + def test_processor_wait_for_job_timeout(self): + """Test that _wait_for_job raises RuntimeError on timeout.""" + from sagemaker.core.processing import Processor mock_session = MagicMock() - mock_session.default_bucket.return_value = "default-bucket" + mock_session.default_bucket.return_value = "my-bucket" mock_session.default_bucket_prefix = "" mock_session.expand_role.return_value = ( "arn:aws:iam::123456789:role/MyRole" ) + mock_session.boto_session = MagicMock() - processor = FrameworkProcessor( - image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + processor = Processor( role="arn:aws:iam::123456789:role/MyRole", + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", instance_count=1, instance_type="ml.m5.xlarge", sagemaker_session=mock_session, ) - processor._codeartifact_repo_arn = ( - "arn:aws:codeartifact:us-west-2:123456789012" - ":repository/my-domain/my-repo" - ) - script = processor._generate_framework_script("my_script.py") - assert "aws codeartifact login --tool pip" in script - assert "--domain my-domain" in script - assert "--repository my-repo" in script + mock_job = MagicMock() + mock_job.processing_job_name = "test-job" - def test_generate_framework_script_without_codeartifact(self): - """Test script does NOT inject CodeArtifact login when not set.""" - from sagemaker.core.processing import FrameworkProcessor + with patch("sagemaker.core.processing.ProcessingJob") as MockPJ: + mock_refreshed = MagicMock() + mock_refreshed.processing_job_status = "InProgress" + MockPJ.get.return_value = mock_refreshed + + with patch("sagemaker.core.processing.time") as mock_time: + # Simulate timeout: first call returns 0, second returns > timeout + mock_time.time.side_effect = [0, 0, 5000] + mock_time.sleep = MagicMock() + + with pytest.raises(RuntimeError, match="Timed out"): + processor._wait_for_job(mock_job, logs=False, timeout=1) + + def test_processor_run_calls_wait_for_job(self): + """Test that Processor.run with wait=True calls _wait_for_job.""" + from sagemaker.core.processing import Processor mock_session = MagicMock() - mock_session.default_bucket.return_value = "default-bucket" + mock_session.default_bucket.return_value = "my-bucket" mock_session.default_bucket_prefix = "" mock_session.expand_role.return_value = ( "arn:aws:iam::123456789:role/MyRole" ) + mock_session.boto_session = MagicMock() + mock_session.sagemaker_client = MagicMock() - processor = FrameworkProcessor( - image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", + processor = Processor( role="arn:aws:iam::123456789:role/MyRole", + image_uri="123456789.dkr.ecr.us-west-2.amazonaws.com/my-image:latest", instance_count=1, instance_type="ml.m5.xlarge", sagemaker_session=mock_session, ) - script = processor._generate_framework_script("my_script.py") - assert "aws codeartifact login" not in script - assert "pip install -r requirements.txt" in script + # Verify _wait_for_job method exists and is callable + assert hasattr(processor, '_wait_for_job') + assert callable(processor._wait_for_job) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py index 52d88bea1d..e13a5dd989 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/trainer_wait.py @@ -235,6 +235,11 @@ def _refresh_training_job( and update its attributes. This avoids using the global default client, which fixes NoCredentialsError when using assumed-role sessions. + TODO: Ideally sagemaker-core's TrainingJob.refresh() should accept a + session/client parameter so we don't need to call boto3 directly here. + This workaround should be removed once sagemaker-core supports + session-aware refresh. See: https://github.com/aws/sagemaker-python-sdk/issues/5765 + Args: training_job (TrainingJob): The training job to refresh. sagemaker_session (Optional[Session]): SageMaker session with the @@ -250,8 +255,12 @@ def _refresh_training_job( if hasattr(training_job, snake_key): try: setattr(training_job, snake_key, value) - except (AttributeError, TypeError, ValueError): - pass + except (AttributeError, TypeError, ValueError) as e: + logger.debug( + "Could not set attribute %s on training job: %s", + snake_key, + e, + ) else: training_job.refresh() diff --git a/sagemaker-train/src/sagemaker/train/templates.py b/sagemaker-train/src/sagemaker/train/templates.py index f6f69c2efa..c943769618 100644 --- a/sagemaker-train/src/sagemaker/train/templates.py +++ b/sagemaker-train/src/sagemaker/train/templates.py @@ -24,33 +24,9 @@ $SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/basic_script_driver.py """ -# CodeArtifact login block using only shell variables. -# All curly braces are doubled to escape them from Python's str.format(). -CODEARTIFACT_LOGIN = """ -# Check for CodeArtifact configuration via CA_REPOSITORY_ARN environment variable -if [ -n "$CA_REPOSITORY_ARN" ]; then - echo "CodeArtifact repository ARN detected: $CA_REPOSITORY_ARN" - # Validate ARN format: arn:aws:codeartifact:REGION:ACCOUNT:repository/DOMAIN/REPO - if ! echo "$CA_REPOSITORY_ARN" | grep -qE '^arn:aws:codeartifact:[a-z0-9-]+:[0-9]{{12}}:repository/[a-zA-Z0-9-]+/[a-zA-Z0-9-]+$'; then - echo "WARNING: CA_REPOSITORY_ARN does not match expected format. Skipping CodeArtifact login." - elif ! hash aws 2>/dev/null; then - echo "AWS CLI is not installed. Skipping CodeArtifact login." - else - CA_REGION=$(echo "$CA_REPOSITORY_ARN" | cut -d: -f4) - CA_OWNER=$(echo "$CA_REPOSITORY_ARN" | cut -d: -f5) - CA_RESOURCE=$(echo "$CA_REPOSITORY_ARN" | cut -d: -f6) - CA_DOMAIN=$(echo "$CA_RESOURCE" | cut -d/ -f2) - CA_REPO=$(echo "$CA_RESOURCE" | cut -d/ -f3) - echo "Logging into CodeArtifact: domain=$CA_DOMAIN owner=$CA_OWNER repo=$CA_REPO region=$CA_REGION" - aws codeartifact login --tool pip --domain "$CA_DOMAIN" --domain-owner "$CA_OWNER" --repository "$CA_REPO" --region "$CA_REGION" - fi -fi -""" - INSTALL_AUTO_REQUIREMENTS = """ if [ -f requirements.txt ]; then echo "Installing requirements" -""" + CODEARTIFACT_LOGIN + """ cat requirements.txt $SM_PIP_CMD install -r requirements.txt else @@ -60,7 +36,6 @@ INSTALL_REQUIREMENTS = """ echo "Installing requirements" -""" + CODEARTIFACT_LOGIN + """ $SM_PIP_CMD install -r {requirements_file} """ diff --git a/sagemaker-train/tests/unit/test_train_regressions.py b/sagemaker-train/tests/unit/test_train_regressions.py index 48fd26125b..50bc6b9f7f 100644 --- a/sagemaker-train/tests/unit/test_train_regressions.py +++ b/sagemaker-train/tests/unit/test_train_regressions.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Tests for v2->v3 regression bugs in sagemaker.train.""" +"""Tests for v2->v3 regression Bug 1: wait=True ignores sagemaker session.""" import inspect import pytest from unittest.mock import MagicMock @@ -60,53 +60,34 @@ def test_refresh_training_job_without_session_uses_default(self): mock_job.refresh.assert_called_once() + def test_to_snake_case(self): + """Test the _to_snake_case helper function.""" + from sagemaker.train.common_utils.trainer_wait import _to_snake_case -class TestBug4CodeArtifactTemplates: - """Bug 4: INSTALL_REQUIREMENTS template should check CA_REPOSITORY_ARN.""" + assert _to_snake_case("TrainingJobStatus") == "training_job_status" + assert _to_snake_case("TrainingJobName") == "training_job_name" + assert _to_snake_case("SecondaryStatus") == "secondary_status" + assert _to_snake_case("already_snake") == "already_snake" - def test_install_requirements_template_has_ca_support(self): - """Test that INSTALL_REQUIREMENTS includes CA_REPOSITORY_ARN check.""" - from sagemaker.train.templates import INSTALL_REQUIREMENTS - - rendered = INSTALL_REQUIREMENTS.format( - requirements_file="requirements.txt" + def test_refresh_training_job_updates_attributes(self): + """Test that _refresh_training_job updates job attributes from describe response.""" + from sagemaker.train.common_utils.trainer_wait import ( + _refresh_training_job, ) - assert "CA_REPOSITORY_ARN" in rendered - assert "aws codeartifact login --tool pip" in rendered - def test_install_requirements_template_does_pip_install(self): - """Test that INSTALL_REQUIREMENTS still does pip install.""" - from sagemaker.train.templates import INSTALL_REQUIREMENTS + mock_session = MagicMock() + mock_session.sagemaker_client.describe_training_job.return_value = { + "TrainingJobStatus": "Completed", + "TrainingJobName": "test-job", + "SecondaryStatus": "Completed", + } + + mock_job = MagicMock() + mock_job.training_job_name = "test-job" + mock_job.training_job_status = "InProgress" + mock_job.secondary_status = "Training" - rendered = INSTALL_REQUIREMENTS.format( - requirements_file="requirements.txt" - ) - has_pip = ( - "pip install -r requirements.txt" in rendered - or "$SM_PIP_CMD install -r requirements.txt" in rendered - ) - assert has_pip - - def test_install_requirements_format_does_not_raise(self): - """Test that .format() does not raise KeyError.""" - from sagemaker.train.templates import INSTALL_REQUIREMENTS - - # This was the CI failure - KeyError: 'region' - try: - rendered = INSTALL_REQUIREMENTS.format( - requirements_file="requirements.txt" - ) - except KeyError as e: - pytest.fail( - f"INSTALL_REQUIREMENTS.format() raised KeyError: {e}" - ) - - def test_install_auto_requirements_has_ca_support(self): - """Test that INSTALL_AUTO_REQUIREMENTS includes CA_REPOSITORY_ARN.""" - from sagemaker.train.templates import INSTALL_AUTO_REQUIREMENTS - - assert "CA_REPOSITORY_ARN" in INSTALL_AUTO_REQUIREMENTS - assert ( - "aws codeartifact login --tool pip" - in INSTALL_AUTO_REQUIREMENTS - ) + _refresh_training_job(mock_job, sagemaker_session=mock_session) + + # Verify attributes were updated via setattr + mock_session.sagemaker_client.describe_training_job.assert_called_once()