Skip to content

Aws

zenml.integrations.aws special

Integrates multiple AWS Tools as Stack Components.

The AWS integration provides a way for our users to manage their secrets through AWS and a way to use the AWS container registry. Additionally, the SageMaker integration submodule provides a way to run ZenML steps in SageMaker.

AWSIntegration (Integration)

Definition of AWS integration for ZenML.

Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """Definition of AWS integration for ZenML."""

    NAME = AWS
    REQUIREMENTS = [
        "sagemaker==2.117.0",
    ]

    @staticmethod
    def activate() -> None:
        """Activate the AWS integration."""
        # Imported purely for its side effects: importing the module
        # registers the AWS service connectors with ZenML.
        from zenml.integrations.aws import service_connectors  # noqa

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Declare the stack component flavors for the AWS integration.

        Returns:
            List of stack component flavors for this integration.
        """
        # Imported lazily so that listing flavors does not pull the heavier
        # flavor modules in at package import time.
        from zenml.integrations.aws.flavors import (
            AWSContainerRegistryFlavor,
            AWSSecretsManagerFlavor,
            SagemakerOrchestratorFlavor,
            SagemakerStepOperatorFlavor,
        )

        declared_flavors: List[Type[Flavor]] = [
            AWSSecretsManagerFlavor,
            AWSContainerRegistryFlavor,
            SagemakerStepOperatorFlavor,
            SagemakerOrchestratorFlavor,
        ]
        return declared_flavors

activate() staticmethod

Activate the AWS integration.

Source code in zenml/integrations/aws/__init__.py
@staticmethod
def activate() -> None:
    """Activate the AWS integration."""
    # Imported for its side effects only: importing the module registers
    # the AWS service connectors with ZenML.
    from zenml.integrations.aws import service_connectors  # noqa

flavors() classmethod

Declare the stack component flavors for the AWS integration.

Returns:

Type Description
List[Type[zenml.stack.flavor.Flavor]]

List of stack component flavors for this integration.

Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors for the AWS integration.

    Returns:
        List of stack component flavors for this integration.
    """
    # Imported lazily so that listing flavors does not pull the heavier
    # flavor modules in at package import time.
    from zenml.integrations.aws.flavors import (
        AWSContainerRegistryFlavor,
        AWSSecretsManagerFlavor,
        SagemakerOrchestratorFlavor,
        SagemakerStepOperatorFlavor,
    )

    declared_flavors: List[Type[Flavor]] = [
        AWSSecretsManagerFlavor,
        AWSContainerRegistryFlavor,
        SagemakerStepOperatorFlavor,
        SagemakerOrchestratorFlavor,
    ]
    return declared_flavors

container_registries special

Initialization of AWS Container Registry integration.

aws_container_registry

Implementation of the AWS container registry integration.

AWSContainerRegistry (BaseContainerRegistry)

Class for AWS Container Registry.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
    """Class for AWS Container Registry."""

    @property
    def config(self) -> AWSContainerRegistryConfig:
        """Returns the `AWSContainerRegistryConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSContainerRegistryConfig, self._config)

    def _get_region(self) -> str:
        """Parses the AWS region from the registry URI.

        Raises:
            RuntimeError: If the region parsing fails due to an invalid URI.

        Returns:
            The region string.
        """
        # ECR URIs look like `<account-id>.dkr.ecr.<region>.amazonaws.com`.
        match = re.fullmatch(
            r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", self.config.uri
        )
        if not match:
            raise RuntimeError(
                f"Unable to parse region from ECR URI {self.config.uri}."
            )

        return match.group(1)

    def prepare_image_push(self, image_name: str) -> None:
        """Logs warning message if trying to push an image for which no repository exists.

        Args:
            image_name: Name of the docker image that will be pushed.

        Raises:
            ValueError: If the docker image name is invalid.
        """
        # BUGFIX: the `describe_repositories` API call is what can raise a
        # `ClientError`, so it must live inside the `try` block; previously it
        # ran before the `try`, meaning the handler below could never catch it
        # and a failed call crashed instead of degrading to a best-effort push.
        # NOTE(review): `describe_repositories` is paginated (max 100 repos
        # per page); registries with more repositories may trigger a spurious
        # warning below — confirm whether pagination matters here.
        try:
            response = boto3.client(
                "ecr", region_name=self._get_region()
            ).describe_repositories()
            repo_uris: List[str] = [
                repository["repositoryUri"]
                for repository in response["repositories"]
            ]
        except (KeyError, ClientError) as e:
            # invalid boto response or failed API call, let's hope for the
            # best and just push
            logger.debug("Error while trying to fetch ECR repositories: %s", e)
            return

        repo_exists = any(
            image_name.startswith(f"{uri}:") for uri in repo_uris
        )
        if not repo_exists:
            # Escape the URI so regex metacharacters in the hostname (the
            # dots) are matched literally rather than as wildcards.
            match = re.search(
                f"{re.escape(self.config.uri)}/(.*):.*", image_name
            )
            if not match:
                raise ValueError(f"Invalid docker image name '{image_name}'.")

            repo_name = match.group(1)
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                f"push an image to it. ZenML is trying to push the image "
                f"{image_name} but could only detect the following "
                f"repositories: {repo_uris}. We will try to push anyway, but "
                f"in case it fails you need to create a repository named "
                f"`{repo_name}`."
            )

    @property
    def post_registration_message(self) -> Optional[str]:
        """Optional message printed after the stack component is registered.

        Returns:
            Info message regarding docker repositories in AWS.
        """
        return (
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. If you want to for example run a pipeline "
            "using our Kubeflow orchestrator, ZenML will automatically build a "
            f"docker image called `{self.config.uri}/zenml-kubeflow:<PIPELINE_NAME>` "
            f"and try to push it. This will fail unless you create the "
            f"repository `zenml-kubeflow` inside your amazon registry."
        )
config: AWSContainerRegistryConfig property readonly

Returns the AWSContainerRegistryConfig config.

Returns:

Type Description
AWSContainerRegistryConfig

The configuration.

post_registration_message: Optional[str] property readonly

Optional message printed after the stack component is registered.

Returns:

Type Description
Optional[str]

Info message regarding docker repositories in AWS.

prepare_image_push(self, image_name)

Logs warning message if trying to push an image for which no repository exists.

Parameters:

Name Type Description Default
image_name str

Name of the docker image that will be pushed.

required

Exceptions:

Type Description
ValueError

If the docker image name is invalid.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs warning message if trying to push an image for which no repository exists.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    # BUGFIX: the `describe_repositories` API call is what can raise a
    # `ClientError`, so it must live inside the `try` block; previously it
    # ran before the `try`, meaning the handler below could never catch it
    # and a failed call crashed instead of degrading to a best-effort push.
    # NOTE(review): `describe_repositories` is paginated (max 100 repos per
    # page); registries with more repositories may trigger a spurious
    # warning below — confirm whether pagination matters here.
    try:
        response = boto3.client(
            "ecr", region_name=self._get_region()
        ).describe_repositories()
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for repository in response["repositories"]
        ]
    except (KeyError, ClientError) as e:
        # invalid boto response or failed API call, let's hope for the
        # best and just push
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    repo_exists = any(
        image_name.startswith(f"{uri}:") for uri in repo_uris
    )
    if not repo_exists:
        # Escape the URI so regex metacharacters in the hostname (the dots)
        # are matched literally rather than as wildcards.
        match = re.search(
            f"{re.escape(self.config.uri)}/(.*):.*", image_name
        )
        if not match:
            raise ValueError(f"Invalid docker image name '{image_name}'.")

        repo_name = match.group(1)
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            f"in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )

flavors special

AWS integration flavors.

aws_container_registry_flavor

AWS container registry flavor.

AWSContainerRegistryConfig (BaseContainerRegistryConfig) pydantic-model

Configuration for AWS Container Registry.

Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryConfig(BaseContainerRegistryConfig):
    """Configuration for AWS Container Registry."""

    @validator("uri")
    def validate_aws_uri(cls, uri: str) -> str:
        """Validates that the URI is in the correct format.

        Args:
            uri: URI to validate.

        Returns:
            URI in the correct format.

        Raises:
            ValueError: If the URI contains a slash character.
        """
        # A valid ECR registry URI is a bare hostname; a `/` would mean a
        # repository path or scheme was included, which is rejected here.
        if "/" not in uri:
            return uri

        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )
validate_aws_uri(uri) classmethod

Validates that the URI is in the correct format.

Parameters:

Name Type Description Default
uri str

URI to validate.

required

Returns:

Type Description
str

URI in the correct format.

Exceptions:

Type Description
ValueError

If the URI contains a slash character.

Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
@validator("uri")
def validate_aws_uri(cls, uri: str) -> str:
    """Validates that the URI is in the correct format.

    Args:
        uri: URI to validate.

    Returns:
        URI in the correct format.

    Raises:
        ValueError: If the URI contains a slash character.
    """
    if "/" in uri:
        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )

    return uri
AWSContainerRegistryFlavor (BaseContainerRegistryFlavor)

AWS Container Registry flavor.

Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryFlavor(BaseContainerRegistryFlavor):
    """AWS Container Registry flavor."""

    @property
    def name(self) -> str:
        """The registered name of this flavor.

        Returns:
            The name of the flavor.
        """
        return AWS_CONTAINER_REGISTRY_FLAVOR

    @property
    def config_class(self) -> Type[AWSContainerRegistryConfig]:
        """The configuration class used by this flavor.

        Returns:
            The config class.
        """
        return AWSContainerRegistryConfig

    @property
    def implementation_class(self) -> Type["AWSContainerRegistry"]:
        """The concrete container registry class for this flavor.

        Returns:
            The implementation class.
        """
        # Imported lazily so the flavor can be inspected without importing
        # the full implementation module.
        from zenml.integrations.aws.container_registries import (
            AWSContainerRegistry,
        )

        return AWSContainerRegistry

    @property
    def service_connector_requirements(
        self,
    ) -> Optional[ServiceConnectorRequirements]:
        """Service connector resource requirements for service connectors.

        Specifies resource requirements that are used to filter the available
        service connector types that are compatible with this flavor.

        Returns:
            Requirements for compatible service connectors, if a service
            connector is required for this flavor.
        """
        return ServiceConnectorRequirements(
            connector_type="aws",
            resource_type="docker-registry",
            resource_id_attr="uri",
        )

    @property
    def docs_url(self) -> Optional[str]:
        """URL of the documentation for this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """URL of the SDK documentation for this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """URL of the logo shown for this flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/container_registry/aws.png"
config_class: Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig] property readonly

Config class for this flavor.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[AWSContainerRegistry] property readonly

Implementation class.

Returns:

Type Description
Type[AWSContainerRegistry]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[zenml.models.service_connector_models.ServiceConnectorRequirements] property readonly

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[zenml.models.service_connector_models.ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service connector is required for this flavor.

aws_secrets_manager_flavor

AWS secrets manager flavor.

AWSSecretsManagerConfig (BaseSecretsManagerConfig) pydantic-model

Configuration for the AWS Secrets Manager.

Attributes:

Name Type Description
region_name str

The region name of the AWS Secrets Manager.

Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerConfig(BaseSecretsManagerConfig):
    """Configuration for the AWS Secrets Manager.

    Attributes:
        region_name: The region name of the AWS Secrets Manager.
    """

    SUPPORTS_SCOPING: ClassVar[bool] = True

    region_name: str

    @classmethod
    def _validate_scope(
        cls,
        scope: "SecretsManagerScope",
        namespace: Optional[str],
    ) -> None:
        """Validate the scope and namespace value.

        Args:
            scope: Scope value.
            namespace: Optional namespace value.
        """
        # Only a non-empty namespace needs validation; the scope value is
        # accepted as-is.
        if not namespace:
            return
        validate_aws_secret_name_or_namespace(namespace)
AWSSecretsManagerFlavor (BaseSecretsManagerFlavor)

Class for the AWSSecretsManagerFlavor.

Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerFlavor(BaseSecretsManagerFlavor):
    """Class for the `AWSSecretsManagerFlavor`."""

    @property
    def name(self) -> str:
        """The registered name of this flavor.

        Returns:
            Name of the flavor.
        """
        return AWS_SECRET_MANAGER_FLAVOR

    @property
    def config_class(self) -> Type[AWSSecretsManagerConfig]:
        """The configuration class used by this flavor.

        Returns:
            Config class for this flavor.
        """
        return AWSSecretsManagerConfig

    @property
    def implementation_class(self) -> Type["AWSSecretsManager"]:
        """The concrete secrets manager class for this flavor.

        Returns:
            Implementation class.
        """
        # Imported lazily so the flavor can be inspected without importing
        # the full implementation module.
        from zenml.integrations.aws.secrets_managers import AWSSecretsManager

        return AWSSecretsManager

    @property
    def docs_url(self) -> Optional[str]:
        """URL of the documentation for this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """URL of the SDK documentation for this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """URL of the logo shown for this flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/secrets_managers/aws.png"
config_class: Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig] property readonly

Config class for this flavor.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig]

Config class for this flavor.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[AWSSecretsManager] property readonly

Implementation class.

Returns:

Type Description
Type[AWSSecretsManager]

Implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

Name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

validate_aws_secret_name_or_namespace(name)

Validate a secret name or namespace.

AWS secret names must contain only alphanumeric characters and the characters /_+=.@-. The / character is only used internally to delimit scopes.

Parameters:

Name Type Description Default
name str

the secret name or namespace

required

Exceptions:

Type Description
ValueError

if the secret name or namespace is invalid

Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
def validate_aws_secret_name_or_namespace(name: str) -> None:
    """Validate a secret name or namespace.

    AWS secret names must contain only alphanumeric characters and the
    characters /_+=.@-. The `/` character is only used internally to delimit
    scopes.

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the secret name or namespace is invalid
    """
    # `/` is deliberately absent from the allowed set below: it is reserved
    # as the internal scope delimiter. Note that `fullmatch` against a `*`
    # quantifier also accepts the empty string.
    allowed = re.compile(r"[a-zA-Z0-9_+=\.@\-]*")
    if allowed.fullmatch(name) is None:
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. Must contain "
            f"only alphanumeric characters and the characters _+=.@-."
        )

sagemaker_orchestrator_flavor

Amazon SageMaker orchestrator flavor.

SagemakerOrchestratorConfig (BaseOrchestratorConfig, SagemakerOrchestratorSettings) pydantic-model

Config for the Sagemaker orchestrator.

Attributes:

Name Type Description
synchronous bool

Whether to run the processing job synchronously or asynchronously. Defaults to False.

execution_role str

The IAM role ARN to use for the pipeline.

bucket Optional[str]

Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".

Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseOrchestratorConfig, SagemakerOrchestratorSettings
):
    """Config for the Sagemaker orchestrator.

    Attributes:
        synchronous: Whether to run the processing job synchronously or
            asynchronously. Defaults to False.
        execution_role: The IAM role ARN to use for the pipeline.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be created
            based on the following format:
            "sagemaker-{region}-{aws-account-id}".
    """

    # Whether runs block until the processing job completes.
    synchronous: bool = False
    # Required: ARN of the IAM role to use for the pipeline.
    execution_role: str
    # Optional artifact bucket; a default is derived when omitted (see
    # class docstring for the naming scheme).
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server.

        Returns:
            True if this config is for a remote component, False otherwise.
        """
        # Execution happens on SageMaker, so this component is always remote.
        return True
is_remote: bool property readonly

Checks if this stack component is running remotely.

This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.

Returns:

Type Description
bool

True if this config is for a remote component, False otherwise.

SagemakerOrchestratorFlavor (BaseOrchestratorFlavor)

Flavor for the Sagemaker orchestrator.

Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorFlavor(BaseOrchestratorFlavor):
    """Flavor for the Sagemaker orchestrator."""

    @property
    def name(self) -> str:
        """The registered name of this flavor.

        Returns:
            The name of the flavor.
        """
        # NOTE(review): this returns the *step operator* flavor constant;
        # presumably both constants resolve to the same "sagemaker" string —
        # confirm against the integration's constants.
        return AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    @property
    def config_class(self) -> Type[SagemakerOrchestratorConfig]:
        """Returns SagemakerOrchestratorConfig config class.

        Returns:
            The config class.
        """
        return SagemakerOrchestratorConfig

    @property
    def implementation_class(self) -> Type["SagemakerOrchestrator"]:
        """The concrete orchestrator class for this flavor.

        Returns:
            The implementation class.
        """
        # Imported lazily so the flavor can be inspected without importing
        # the full orchestrator implementation.
        from zenml.integrations.aws.orchestrators import SagemakerOrchestrator

        return SagemakerOrchestrator

    @property
    def docs_url(self) -> Optional[str]:
        """URL of the documentation for this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """URL of the SDK documentation for this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """URL of the logo shown for this flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/sagemaker.png"
config_class: Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig] property readonly

Returns SagemakerOrchestratorConfig config class.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[SagemakerOrchestrator] property readonly

Implementation class.

Returns:

Type Description
Type[SagemakerOrchestrator]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

SagemakerOrchestratorSettings (BaseSettings) pydantic-model

Settings for the Sagemaker orchestrator.

Attributes:

Name Type Description
instance_type str

The instance type to use for the processing job.

processor_role Optional[str]

The IAM role to use for the step execution on a Processor.

volume_size_in_gb int

The size of the EBS volume to use for the processing job.

max_runtime_in_seconds int

The maximum runtime in seconds for the processing job.

processor_tags Dict[str, str]

Tags to apply to the Processor assigned to the step.

processor_args Dict[str, Any]

Arguments that are directly passed to the SageMaker Processor for a specific step, allowing for overriding the default settings provided when configuring the component. See https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.processing.Processor for a full list of arguments. For processor_args.instance_type, check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types.

input_data_s3_mode str

How data is made available to the container. Two possible input modes: File, Pipe.

input_data_s3_uri Union[str, Dict[str, str]]

S3 URI where data is located if not locally, e.g. s3://my-bucket/my-data/train. How data will be made available to the container is configured with input_data_s3_mode. Two possible input types: - str: S3 location where training data is saved. - Dict[str, str]: (ChannelName, S3Location) which represent channels (e.g. training, validation, testing) where specific parts of the data are saved in S3.

output_data_s3_mode str

How data is uploaded to the S3 bucket. Two possible output modes: EndOfJob, Continuous.

output_data_s3_uri Union[str, Dict[str, str]]

S3 URI where data is uploaded after or during processing run. e.g. s3://my-bucket/my-data/output. How data will be made available to the container is configured with output_data_s3_mode. Two possible input types: - str: S3 location where data will be uploaded from a local folder named /opt/ml/processing/output/data. - Dict[str, str]: (ChannelName, S3Location) which represent channels (e.g. output_one, output_two) where specific parts of the data are stored locally for S3 upload. Data must be available locally in /opt/ml/processing/output/data/.

Source code in zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py
class SagemakerOrchestratorSettings(BaseSettings):
    """Settings for the Sagemaker orchestrator.

    Attributes:
        instance_type: The instance type to use for the processing job.
        processor_role: The IAM role to use for the step execution on a Processor.
        volume_size_in_gb: The size of the EBS volume to use for the processing
            job.
        max_runtime_in_seconds: The maximum runtime in seconds for the
            processing job.
        processor_tags: Tags to apply to the Processor assigned to the step.
        processor_args: Arguments that are directly passed to the SageMaker
            Processor for a specific step, allowing for overriding the default
            settings provided when configuring the component. See
            https://sagemaker.readthedocs.io/en/stable/api/training/processing.html#sagemaker.processing.Processor
            for a full list of arguments.
            For processor_args.instance_type, check
            https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
            for a list of available instance types.
        input_data_s3_mode: How data is made available to the container.
            Two possible input modes: File, Pipe.
        input_data_s3_uri: S3 URI where data is located if not locally,
            e.g. s3://my-bucket/my-data/train. How data will be made available
            to the container is configured with input_data_s3_mode. Two possible
            input types:
                - str: S3 location where training data is saved.
                - Dict[str, str]: (ChannelName, S3Location) which represent
                    channels (e.g. training, validation, testing) where
                    specific parts of the data are saved in S3.
        output_data_s3_mode: How data is uploaded to the S3 bucket.
            Two possible output modes: EndOfJob, Continuous.
        output_data_s3_uri: S3 URI where data is uploaded after or during processing run.
            e.g. s3://my-bucket/my-data/output. How data will be made available
            to the container is configured with output_data_s3_mode. Two possible
            input types:
                - str: S3 location where data will be uploaded from a local folder
                    named /opt/ml/processing/output/data.
                - Dict[str, str]: (ChannelName, S3Location) which represent
                    channels (e.g. output_one, output_two) where
                    specific parts of the data are stored locally for S3 upload.
                    Data must be available locally in /opt/ml/processing/output/data/<ChannelName>.
    """

    # Instance type for the processing job.
    instance_type: str = "ml.t3.medium"
    # Per-step IAM role for the Processor; None means no per-step override
    # (presumably falls back to the component-level role — confirm).
    processor_role: Optional[str] = None
    # EBS volume size in GB for the processing job.
    volume_size_in_gb: int = 30
    # Hard cap on job runtime in seconds (86400 s = 24 h).
    max_runtime_in_seconds: int = 86400
    # Tags applied to the Processor assigned to the step.
    processor_tags: Dict[str, str] = {}

    # Extra kwargs forwarded directly to the SageMaker Processor.
    processor_args: Dict[str, Any] = {}
    # Input access mode for the container: "File" or "Pipe".
    input_data_s3_mode: str = "File"
    # S3 URI, or mapping of channel name -> S3 URI, for input data.
    input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = None

    # Output upload mode: "EndOfJob" or "Continuous".
    output_data_s3_mode: str = "EndOfJob"
    # S3 URI, or mapping of channel name -> S3 URI, for output data.
    output_data_s3_uri: Optional[Union[str, Dict[str, str]]] = None

sagemaker_step_operator_flavor

Amazon SageMaker step operator flavor.

SagemakerStepOperatorConfig (BaseStepOperatorConfig, SagemakerStepOperatorSettings) pydantic-model

Config for the Sagemaker step operator.

Attributes:

Name Type Description
role str

The role that has to be assigned to the jobs which are running in Sagemaker.

bucket Optional[str]

Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".

Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseStepOperatorConfig, SagemakerStepOperatorSettings
):
    """Config for the Sagemaker step operator.

    Attributes:
        role: The role that has to be assigned to the jobs which are
            running in Sagemaker.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be created
            based on the following format: "sagemaker-{region}-{aws-account-id}".
    """

    # Required: IAM role assigned to the jobs running in SageMaker.
    role: str
    # Optional artifact bucket; a default is derived when omitted (see
    # class docstring for the naming scheme).
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Checks if this stack component is running remotely.

        This designation is used to determine if the stack component can be
        used with a local ZenML database or if it requires a remote ZenML
        server.

        Returns:
            True if this config is for a remote component, False otherwise.
        """
        # Execution happens on SageMaker, so this component is always remote.
        return True
is_remote: bool property readonly

Checks if this stack component is running remotely.

This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.

Returns:

Type Description
bool

True if this config is for a remote component, False otherwise.

SagemakerStepOperatorFlavor (BaseStepOperatorFlavor)

Flavor for the Sagemaker step operator.

Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorFlavor(BaseStepOperatorFlavor):
    """Flavor for the Sagemaker step operator."""

    @property
    def name(self) -> str:
        """Name under which this flavor is registered.

        Returns:
            The name of the flavor.
        """
        return AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    @property
    def docs_url(self) -> Optional[str]:
        """URL of the documentation page explaining this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """URL of the SDK documentation page for this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """URL of the logo representing this flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/sagemaker.png"

    @property
    def config_class(self) -> Type[SagemakerStepOperatorConfig]:
        """Config class used by this flavor.

        Returns:
            The config class.
        """
        return SagemakerStepOperatorConfig

    @property
    def implementation_class(self) -> Type["SagemakerStepOperator"]:
        """Class that implements this flavor.

        Returns:
            The implementation class.
        """
        # Imported locally so the step operator module is only loaded
        # when the implementation is actually requested.
        from zenml.integrations.aws.step_operators import SagemakerStepOperator

        return SagemakerStepOperator
config_class: Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig] property readonly

Returns SagemakerStepOperatorConfig config class.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[SagemakerStepOperator] property readonly

Implementation class.

Returns:

Type Description
Type[SagemakerStepOperator]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

SagemakerStepOperatorSettings (BaseSettings) pydantic-model

Settings for the Sagemaker step operator.

Attributes:

Name Type Description
experiment_name Optional[str]

The name for the experiment to which the job will be associated. If not provided, the job run will be independent.

input_data_s3_uri Union[str, Dict[str, str]]

S3 URI where training data is located if not locally, e.g. s3://my-bucket/my-data/train. How data will be made available to the container is configured with estimator_args.input_mode. Two possible input types: - str: S3 location where training data is saved. - Dict[str, str]: (ChannelName, S3Location) which represent channels (e.g. training, validation, testing) where specific parts of the data are saved in S3.

estimator_args Dict[str, Any]

Arguments that are directly passed to the SageMaker Estimator. See https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator for a full list of arguments. For estimator_args.instance_type, check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types.

Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorSettings(BaseSettings):
    """Settings for the Sagemaker step operator.

    Attributes:
        experiment_name: Name of the experiment the job is associated
            with. If not provided, the job run is independent.
        input_data_s3_uri: S3 URI of the training data when it is not
            available locally, e.g. s3://my-bucket/my-data/train. How the
            data is made available to the container is configured with
            estimator_args.input_mode. Two input types are supported:
                - str: S3 location where training data is saved.
                - Dict[str, str]: (ChannelName, S3Location) pairs that
                    represent channels (e.g. training, validation,
                    testing) holding specific parts of the data in S3.
        estimator_args: Keyword arguments forwarded directly to the
            SageMaker Estimator. See
            https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator
            for the full list of arguments.
            For estimator_args.instance_type, check
            https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
            for a list of available instance types.
    """

    # Deprecated attribute — kept so existing configurations still
    # validate; see the deprecation validator below.
    instance_type: Optional[str] = None
    experiment_name: Optional[str] = None
    input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = None
    estimator_args: Dict[str, Any] = {}

    _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
        "instance_type"
    )

orchestrators special

AWS Sagemaker orchestrator.

sagemaker_orchestrator

Implementation of the SageMaker orchestrator.

SagemakerOrchestrator (ContainerizedOrchestrator)

Orchestrator responsible for running pipelines on Sagemaker.

Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
class SagemakerOrchestrator(ContainerizedOrchestrator):
    """Orchestrator responsible for running pipelines on Sagemaker."""

    @property
    def config(self) -> SagemakerOrchestratorConfig:
        """Returns the `SagemakerOrchestratorConfig` config.

        Returns:
            The configuration.
        """
        return cast(SagemakerOrchestratorConfig, self._config)

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack.

        In the remote case, checks that the stack contains a container registry,
        image builder and only remote components.

        Returns:
            A `StackValidator` instance.
        """

        def _validate_remote_components(
            stack: "Stack",
        ) -> Tuple[bool, str]:
            # Fail on the first local component: steps executing inside
            # Sagemaker cannot reach components running on the local machine.
            for component in stack.components.values():
                if not component.config.is_local:
                    continue

                return False, (
                    f"The Sagemaker orchestrator runs pipelines remotely, "
                    f"but the '{component.name}' {component.type.value} is "
                    "a local stack component and will not be available in "
                    "the Sagemaker step.\nPlease ensure that you always "
                    "use non-local stack components with the Sagemaker "
                    "orchestrator."
                )

            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_remote_components,
        )

    def get_orchestrator_run_id(self) -> str:
        """Returns the run id of the active orchestrator run.

        Important: This needs to be a unique ID and return the same value for
        all steps of a pipeline run.

        Returns:
            The orchestrator run id.

        Raises:
            RuntimeError: If the run id cannot be read from the environment.
        """
        # The env var is injected into each step container via the `env`
        # passed to the Processor in `prepare_or_run_pipeline` below.
        try:
            return os.environ[ENV_ZENML_SAGEMAKER_RUN_ID]
        except KeyError:
            raise RuntimeError(
                "Unable to read run id from environment variable "
                f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
            )

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the Sagemaker orchestrator.

        Returns:
            The settings class.
        """
        return SagemakerOrchestratorSettings

    def prepare_or_run_pipeline(
        self,
        deployment: "PipelineDeploymentResponseModel",
        stack: "Stack",
        environment: Dict[str, str],
    ) -> None:
        """Prepares or runs a pipeline on Sagemaker.

        Builds one SageMaker `ProcessingStep` per ZenML step, assembles them
        into a SageMaker `Pipeline`, creates it and starts an execution.

        Args:
            deployment: The deployment to prepare or run.
            stack: The stack to run on.
            environment: Environment variables to set in the orchestration
                environment.

        """
        if deployment.schedule:
            logger.warning(
                "The Sagemaker Orchestrator currently does not support the "
                "use of schedules. The `schedule` will be ignored "
                "and the pipeline will be run immediately."
            )

        # sagemaker requires pipelineName to use alphanum and hyphens only
        unsanitized_orchestrator_run_name = get_orchestrator_run_name(
            pipeline_name=deployment.pipeline_configuration.name
        )
        # replace all non-alphanum and non-hyphens with hyphens
        orchestrator_run_name = re.sub(
            r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
        )

        session = sagemaker.Session(default_bucket=self.config.bucket)

        sagemaker_steps = []
        for step_name, step in deployment.step_configurations.items():
            image = self.get_image(deployment=deployment, step_name=step_name)
            command = StepEntrypointConfiguration.get_entrypoint_command()
            arguments = StepEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name, deployment_id=deployment.id
            )
            entrypoint = command + arguments

            step_settings = cast(
                SagemakerOrchestratorSettings, self.get_settings(step)
            )

            # PIPELINE_EXECUTION_ARN is a SageMaker placeholder resolved to
            # the actual execution ARN at runtime; it doubles as the unique
            # orchestrator run id read by `get_orchestrator_run_id`. Note
            # that the same `environment` dict is mutated for every step.
            environment[
                ENV_ZENML_SAGEMAKER_RUN_ID
            ] = ExecutionVariables.PIPELINE_EXECUTION_ARN

            # Retrieve Processor arguments provided in the Step settings.
            # NOTE(review): when `processor_args` is non-empty this aliases
            # the settings' own dict, so the `setdefault` calls below mutate
            # the step settings in place — confirm this is intended.
            processor_args_for_step = step_settings.processor_args or {}

            # Set default values from configured orchestrator Component to arguments
            # to be used when they are not present in processor_args.
            processor_args_for_step.setdefault(
                "instance_type", step_settings.instance_type
            )
            processor_args_for_step.setdefault(
                "role",
                step_settings.processor_role or self.config.execution_role,
            )
            processor_args_for_step.setdefault(
                "volume_size_in_gb", step_settings.volume_size_in_gb
            )
            processor_args_for_step.setdefault(
                "max_runtime_in_seconds", step_settings.max_runtime_in_seconds
            )
            processor_args_for_step.setdefault(
                "tags",
                [step_settings.processor_tags]
                if step_settings.processor_tags
                else None,
            )

            # Set values that cannot be overwritten
            processor_args_for_step["image_uri"] = image
            processor_args_for_step["instance_count"] = 1
            processor_args_for_step["sagemaker_session"] = session
            processor_args_for_step["entrypoint"] = entrypoint
            processor_args_for_step["base_job_name"] = orchestrator_run_name
            processor_args_for_step["env"] = environment

            # Construct S3 inputs to container for step
            inputs = None

            if step_settings.input_data_s3_uri is None:
                pass
            elif isinstance(step_settings.input_data_s3_uri, str):
                # Single URI: mount under a fixed container input path.
                inputs = [
                    ProcessingInput(
                        source=step_settings.input_data_s3_uri,
                        destination="/opt/ml/processing/input/data",
                        s3_input_mode=step_settings.input_data_s3_mode,
                    )
                ]
            elif isinstance(step_settings.input_data_s3_uri, dict):
                # Channel mapping: one input per (channel, S3 URI) pair,
                # mounted under a per-channel subdirectory.
                inputs = []
                for channel, s3_uri in step_settings.input_data_s3_uri.items():
                    inputs.append(
                        ProcessingInput(
                            source=s3_uri,
                            destination=f"/opt/ml/processing/input/data/{channel}",
                            s3_input_mode=step_settings.input_data_s3_mode,
                        )
                    )

            # Construct S3 outputs from container for step
            outputs = None

            if step_settings.output_data_s3_uri is None:
                pass
            elif isinstance(step_settings.output_data_s3_uri, str):
                outputs = [
                    ProcessingOutput(
                        source="/opt/ml/processing/output/data",
                        destination=step_settings.output_data_s3_uri,
                        s3_upload_mode=step_settings.output_data_s3_mode,
                    )
                ]
            elif isinstance(step_settings.output_data_s3_uri, dict):
                outputs = []
                for (
                    channel,
                    s3_uri,
                ) in step_settings.output_data_s3_uri.items():
                    outputs.append(
                        ProcessingOutput(
                            source=f"/opt/ml/processing/output/data/{channel}",
                            destination=s3_uri,
                            s3_upload_mode=step_settings.output_data_s3_mode,
                        )
                    )

            # Create Processor and ProcessingStep
            processor = sagemaker.processing.Processor(
                **processor_args_for_step
            )
            sagemaker_step = ProcessingStep(
                name=step_name,
                processor=processor,
                # ZenML upstream steps map directly to SageMaker step
                # dependencies, preserving the DAG ordering.
                depends_on=step.spec.upstream_steps,
                inputs=inputs,
                outputs=outputs,
            )
            sagemaker_steps.append(sagemaker_step)

        # construct the pipeline from the sagemaker_steps
        pipeline = Pipeline(
            name=orchestrator_run_name,
            steps=sagemaker_steps,
            sagemaker_session=session,
        )

        pipeline.create(role_arn=self.config.execution_role)
        pipeline_execution = pipeline.start()

        # mainly for testing purposes, we wait for the pipeline to finish
        if self.config.synchronous:
            logger.info(
                "Executing synchronously. Waiting for pipeline to finish..."
            )
            pipeline_execution.wait(
                delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
            )
            logger.info("Pipeline completed successfully.")

    def _get_region_name(self) -> str:
        """Returns the AWS region name.

        Returns:
            The region name.

        Raises:
            RuntimeError: If the region name cannot be retrieved.
        """
        try:
            return cast(str, sagemaker.Session().boto_region_name)
        except Exception as e:
            raise RuntimeError(
                "Unable to get region name. Please ensure that you have "
                "configured your AWS credentials correctly."
            ) from e

    def get_pipeline_run_metadata(
        self, run_id: UUID
    ) -> Dict[str, "MetadataType"]:
        """Get general component-specific metadata for a pipeline run.

        Args:
            run_id: The ID of the pipeline run.

        Returns:
            A dictionary of metadata.
        """
        # The env var holds the SageMaker pipeline execution ARN (see
        # `prepare_or_run_pipeline`).
        run_metadata: Dict[str, "MetadataType"] = {
            "pipeline_execution_arn": os.environ[ENV_ZENML_SAGEMAKER_RUN_ID],
        }
        try:
            region_name = self._get_region_name()
        except RuntimeError:
            logger.warning("Unable to get region name from AWS Sagemaker.")
            return run_metadata

        # The last ARN path segment is the execution id, used to filter
        # CloudWatch log streams belonging to this run.
        aws_run_id = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID].split("/")[-1]
        orchestrator_logs_url = (
            f"https://{region_name}.console.aws.amazon.com/"
            f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
            f"/$252Faws$252Fsagemaker$252FProcessingJobs$3FlogStreamNameFilter"
            f"$3Dpipelines-{aws_run_id}-"
        )
        run_metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_logs_url)
        return run_metadata
config: SagemakerOrchestratorConfig property readonly

Returns the SagemakerOrchestratorConfig config.

Returns:

Type Description
SagemakerOrchestratorConfig

The configuration.

settings_class: Optional[Type[BaseSettings]] property readonly

Settings class for the Sagemaker orchestrator.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates the stack.

In the remote case, checks that the stack contains a container registry, image builder and only remote components.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A StackValidator instance.

get_orchestrator_run_id(self)

Returns the run id of the active orchestrator run.

Important: This needs to be a unique ID and return the same value for all steps of a pipeline run.

Returns:

Type Description
str

The orchestrator run id.

Exceptions:

Type Description
RuntimeError

If the run id cannot be read from the environment.

Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_orchestrator_run_id(self) -> str:
    """Return the ID of the active orchestrator run.

    Important: this must be a unique ID and yield the same value for
    every step of a pipeline run.

    Returns:
        The orchestrator run id.

    Raises:
        RuntimeError: If the run id cannot be read from the environment.
    """
    run_id = os.environ.get(ENV_ZENML_SAGEMAKER_RUN_ID)
    if run_id is None:
        raise RuntimeError(
            "Unable to read run id from environment variable "
            f"{ENV_ZENML_SAGEMAKER_RUN_ID}."
        )
    return run_id
get_pipeline_run_metadata(self, run_id)

Get general component-specific metadata for a pipeline run.

Parameters:

Name Type Description Default
run_id UUID

The ID of the pipeline run.

required

Returns:

Type Description
Dict[str, MetadataType]

A dictionary of metadata.

Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def get_pipeline_run_metadata(
    self, run_id: UUID
) -> Dict[str, "MetadataType"]:
    """Get general component-specific metadata for a pipeline run.

    Args:
        run_id: The ID of the pipeline run.

    Returns:
        A dictionary of metadata.
    """
    # The env var holds the SageMaker pipeline execution ARN set when the
    # pipeline was launched.
    run_metadata: Dict[str, "MetadataType"] = {
        "pipeline_execution_arn": os.environ[ENV_ZENML_SAGEMAKER_RUN_ID],
    }
    try:
        region_name = self._get_region_name()
    except RuntimeError:
        # Without a region we cannot build the console URL; return what
        # we have instead of failing the whole metadata collection.
        logger.warning("Unable to get region name from AWS Sagemaker.")
        return run_metadata

    # The last ARN path segment is the execution id, used to filter
    # CloudWatch log streams belonging to this run.
    aws_run_id = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID].split("/")[-1]
    orchestrator_logs_url = (
        f"https://{region_name}.console.aws.amazon.com/"
        f"cloudwatch/home?region={region_name}#logsV2:log-groups/log-group"
        f"/$252Faws$252Fsagemaker$252FProcessingJobs$3FlogStreamNameFilter"
        f"$3Dpipelines-{aws_run_id}-"
    )
    run_metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_logs_url)
    return run_metadata
prepare_or_run_pipeline(self, deployment, stack, environment)

Prepares or runs a pipeline on Sagemaker.

Parameters:

Name Type Description Default
deployment PipelineDeploymentResponseModel

The deployment to prepare or run.

required
stack Stack

The stack to run on.

required
environment Dict[str, str]

Environment variables to set in the orchestration environment.

required
Source code in zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
def prepare_or_run_pipeline(
    self,
    deployment: "PipelineDeploymentResponseModel",
    stack: "Stack",
    environment: Dict[str, str],
) -> None:
    """Prepares or runs a pipeline on Sagemaker.

    Builds one SageMaker `ProcessingStep` per ZenML step, assembles them
    into a SageMaker `Pipeline`, creates it and starts an execution.

    Args:
        deployment: The deployment to prepare or run.
        stack: The stack to run on.
        environment: Environment variables to set in the orchestration
            environment.

    """
    if deployment.schedule:
        logger.warning(
            "The Sagemaker Orchestrator currently does not support the "
            "use of schedules. The `schedule` will be ignored "
            "and the pipeline will be run immediately."
        )

    # sagemaker requires pipelineName to use alphanum and hyphens only
    unsanitized_orchestrator_run_name = get_orchestrator_run_name(
        pipeline_name=deployment.pipeline_configuration.name
    )
    # replace all non-alphanum and non-hyphens with hyphens
    orchestrator_run_name = re.sub(
        r"[^a-zA-Z0-9\-]", "-", unsanitized_orchestrator_run_name
    )

    session = sagemaker.Session(default_bucket=self.config.bucket)

    sagemaker_steps = []
    for step_name, step in deployment.step_configurations.items():
        image = self.get_image(deployment=deployment, step_name=step_name)
        command = StepEntrypointConfiguration.get_entrypoint_command()
        arguments = StepEntrypointConfiguration.get_entrypoint_arguments(
            step_name=step_name, deployment_id=deployment.id
        )
        entrypoint = command + arguments

        step_settings = cast(
            SagemakerOrchestratorSettings, self.get_settings(step)
        )

        # PIPELINE_EXECUTION_ARN is a SageMaker placeholder resolved to
        # the actual execution ARN at runtime; it serves as the unique
        # orchestrator run id inside the step containers. Note that the
        # same `environment` dict is mutated for every step.
        environment[
            ENV_ZENML_SAGEMAKER_RUN_ID
        ] = ExecutionVariables.PIPELINE_EXECUTION_ARN

        # Retrieve Processor arguments provided in the Step settings.
        # NOTE(review): when `processor_args` is non-empty this aliases
        # the settings' own dict, so the `setdefault` calls below mutate
        # the step settings in place — confirm this is intended.
        processor_args_for_step = step_settings.processor_args or {}

        # Set default values from configured orchestrator Component to arguments
        # to be used when they are not present in processor_args.
        processor_args_for_step.setdefault(
            "instance_type", step_settings.instance_type
        )
        processor_args_for_step.setdefault(
            "role",
            step_settings.processor_role or self.config.execution_role,
        )
        processor_args_for_step.setdefault(
            "volume_size_in_gb", step_settings.volume_size_in_gb
        )
        processor_args_for_step.setdefault(
            "max_runtime_in_seconds", step_settings.max_runtime_in_seconds
        )
        processor_args_for_step.setdefault(
            "tags",
            [step_settings.processor_tags]
            if step_settings.processor_tags
            else None,
        )

        # Set values that cannot be overwritten
        processor_args_for_step["image_uri"] = image
        processor_args_for_step["instance_count"] = 1
        processor_args_for_step["sagemaker_session"] = session
        processor_args_for_step["entrypoint"] = entrypoint
        processor_args_for_step["base_job_name"] = orchestrator_run_name
        processor_args_for_step["env"] = environment

        # Construct S3 inputs to container for step
        inputs = None

        if step_settings.input_data_s3_uri is None:
            pass
        elif isinstance(step_settings.input_data_s3_uri, str):
            # Single URI: mount under a fixed container input path.
            inputs = [
                ProcessingInput(
                    source=step_settings.input_data_s3_uri,
                    destination="/opt/ml/processing/input/data",
                    s3_input_mode=step_settings.input_data_s3_mode,
                )
            ]
        elif isinstance(step_settings.input_data_s3_uri, dict):
            # Channel mapping: one input per (channel, S3 URI) pair,
            # mounted under a per-channel subdirectory.
            inputs = []
            for channel, s3_uri in step_settings.input_data_s3_uri.items():
                inputs.append(
                    ProcessingInput(
                        source=s3_uri,
                        destination=f"/opt/ml/processing/input/data/{channel}",
                        s3_input_mode=step_settings.input_data_s3_mode,
                    )
                )

        # Construct S3 outputs from container for step
        outputs = None

        if step_settings.output_data_s3_uri is None:
            pass
        elif isinstance(step_settings.output_data_s3_uri, str):
            outputs = [
                ProcessingOutput(
                    source="/opt/ml/processing/output/data",
                    destination=step_settings.output_data_s3_uri,
                    s3_upload_mode=step_settings.output_data_s3_mode,
                )
            ]
        elif isinstance(step_settings.output_data_s3_uri, dict):
            outputs = []
            for (
                channel,
                s3_uri,
            ) in step_settings.output_data_s3_uri.items():
                outputs.append(
                    ProcessingOutput(
                        source=f"/opt/ml/processing/output/data/{channel}",
                        destination=s3_uri,
                        s3_upload_mode=step_settings.output_data_s3_mode,
                    )
                )

        # Create Processor and ProcessingStep
        processor = sagemaker.processing.Processor(
            **processor_args_for_step
        )
        sagemaker_step = ProcessingStep(
            name=step_name,
            processor=processor,
            # ZenML upstream steps map directly to SageMaker step
            # dependencies, preserving the DAG ordering.
            depends_on=step.spec.upstream_steps,
            inputs=inputs,
            outputs=outputs,
        )
        sagemaker_steps.append(sagemaker_step)

    # construct the pipeline from the sagemaker_steps
    pipeline = Pipeline(
        name=orchestrator_run_name,
        steps=sagemaker_steps,
        sagemaker_session=session,
    )

    pipeline.create(role_arn=self.config.execution_role)
    pipeline_execution = pipeline.start()

    # mainly for testing purposes, we wait for the pipeline to finish
    if self.config.synchronous:
        logger.info(
            "Executing synchronously. Waiting for pipeline to finish..."
        )
        pipeline_execution.wait(
            delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS
        )
        logger.info("Pipeline completed successfully.")

secrets_managers special

AWS Secrets Manager.

aws_secrets_manager

Implementation of the AWS Secrets Manager integration.

AWSSecretsManager (BaseSecretsManager)

Class to interact with the AWS secrets manager.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
class AWSSecretsManager(BaseSecretsManager):
    """Class to interact with the AWS secrets manager."""

    CLIENT: ClassVar[Any] = None

    @property
    def config(self) -> AWSSecretsManagerConfig:
        """The `AWSSecretsManagerConfig` of this secrets manager.

        Returns:
            The configuration.
        """
        config = cast(AWSSecretsManagerConfig, self._config)
        return config

    @classmethod
    def _ensure_client_connected(cls, region_name: str) -> None:
        """Lazily initialize the shared AWS Secrets Manager client.

        Args:
            region_name: the AWS region name
        """
        if cls.CLIENT is not None:
            return
        # Create the Secrets Manager client once and share it class-wide.
        session = boto3.session.Session()
        cls.CLIENT = session.client(
            service_name="secretsmanager", region_name=region_name
        )

    def _get_secret_tags(
        self, secret: BaseSecretSchema
    ) -> List[Dict[str, str]]:
        """Return a list of AWS secret tag values for a given secret.

        Args:
            secret: the secret object

        Returns:
            A list of AWS secret tag values
        """
        tags = []
        for key, value in self._get_secret_metadata(secret).items():
            tags.append({"Key": key, "Value": value})
        return tags

    def _get_secret_scope_filters(
        self,
        secret_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return a list of AWS filters for the entire scope or just a scoped secret.

        These filters can be used when querying the AWS Secrets Manager
        for all secrets or for a single secret available in the configured
        scope. For more information see: https://docs.aws.amazon.com/secretsmanager/latest/userguide/manage_search-secret.html

        Example AWS filters for all secrets in the current (namespace) scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Example AWS filters for a particular secret in the current (namespace)
        scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_secret_name"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_secret"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of AWS filters uniquely identifying all secrets
            or a named secret within the configured scope.
        """
        metadata = self._get_secret_scope_metadata(secret_name)
        filters: List[Dict[str, Any]] = []
        # Each metadata item becomes a tag-key/tag-value filter pair.
        for k, v in metadata.items():
            filters.append(
                {
                    "Key": "tag-key",
                    "Values": [
                        k,
                    ],
                }
            )
            filters.append(
                {
                    "Key": "tag-value",
                    "Values": [
                        str(v),
                    ],
                }
            )

        return filters

    def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
        """List all secrets matching a name.

        This method lists all the secrets in the current scope without loading
        their contents. An optional secret name can be supplied to filter out
        all but a single secret identified by name.

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of secret names in the current scope and the optional
            secret name.
        """
        self._ensure_client_connected(self.config.region_name)

        filters: List[Dict[str, Any]] = []
        prefix: Optional[str] = None
        if self.config.scope == SecretsManagerScope.NONE:
            # unscoped (legacy) secrets don't have tags. We want to filter out
            # non-legacy secrets
            # The "!" prefix negates the filter, i.e. match secrets that do
            # NOT carry a "zenml_scope" tag (AWS ListSecrets filter syntax).
            filters = [
                {
                    "Key": "tag-key",
                    "Values": [
                        "!zenml_scope",
                    ],
                },
            ]
            if secret_name:
                prefix = secret_name
        else:
            filters = self._get_secret_scope_filters()
            if secret_name:
                prefix = self._get_scoped_secret_name(secret_name)
            else:
                # add the name prefix to the filters to account for the fact
                # that AWS does not do exact matching but prefix-matching on the
                # filters
                prefix = self._get_scoped_secret_name_prefix()

        if prefix:
            filters.append(
                {
                    "Key": "name",
                    "Values": [
                        f"{prefix}",
                    ],
                }
            )

        # Paginate through ListSecrets results so that scopes with more than
        # one page of secrets are fully enumerated.
        paginator = self.CLIENT.get_paginator(_BOTO_CLIENT_LIST_SECRETS)
        pages = paginator.paginate(
            Filters=filters,
            PaginationConfig={
                "PageSize": 100,
            },
        )
        results = []
        for page in pages:
            for secret in page[_PAGINATOR_RESPONSE_SECRETS_LIST_KEY]:
                name = self._get_unscoped_secret_name(secret["Name"])
                # keep only the names that are in scope and filter by secret name,
                # if one was given
                if name and (not secret_name or secret_name == name):
                    results.append(name)

        return results

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        validate_aws_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.config.region_name)

        # Registration must not overwrite an existing secret.
        if self._list_secrets(secret.name):
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists"
            )

        scoped_name = self._get_scoped_secret_name(secret.name)
        self.CLIENT.create_secret(
            Name=scoped_name,
            SecretString=json.dumps(secret_to_dict(secret, encode=False)),
            Tags=self._get_secret_tags(secret),
        )

        logger.debug("Created AWS secret: %s", scoped_name)

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets a secret.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            KeyError: if the secret does not exist or has no string value
        """
        validate_aws_secret_name_or_namespace(secret_name)
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        get_secret_value_response = self.CLIENT.get_secret_value(
            SecretId=self._get_scoped_secret_name(secret_name)
        )
        if "SecretString" not in get_secret_value_response:
            # Bug fix: the previous code set the response to None here and
            # then immediately subscripted it, which raised a confusing
            # TypeError. A secret stored as binary (no "SecretString" key in
            # the get_secret_value response) is treated as missing instead.
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        return secret_from_dict(
            json.loads(get_secret_value_response["SecretString"]),
            secret_name=secret_name,
            decode=False,
        )

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        # All keys in the current scope, without loading secret contents.
        all_secret_names = self._list_secrets()
        return all_secret_names

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret.

        Args:
            secret: the secret to update

        Raises:
            KeyError: if the secret does not exist
        """
        validate_aws_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret.name):
            raise KeyError(f"Can't find the specified secret '{secret.name}'")

        # Serialize with encode=False to stay consistent with
        # `register_secret` (which stores unencoded values) and `get_secret`
        # (which reads back with decode=False). The previous default
        # (encode=True) stored encoded values that `get_secret` would then
        # return without decoding.
        secret_value = json.dumps(secret_to_dict(secret, encode=False))

        kwargs: Dict[str, Any] = {
            "SecretId": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
        }

        self.CLIENT.put_secret_value(**kwargs)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        Args:
            secret_name: the name of the secret to delete

        Raises:
            KeyError: if the secret does not exist
        """
        self._ensure_client_connected(self.config.region_name)

        matching = self._list_secrets(secret_name)
        if not matching:
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        # Force-delete immediately, skipping the AWS recovery window.
        self.CLIENT.delete_secret(
            SecretId=self._get_scoped_secret_name(secret_name),
            ForceDeleteWithoutRecovery=True,
        )

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets.

        This method will force delete all your secrets. You will not be able
        to recover them once this method is called.
        """
        self._ensure_client_connected(self.config.region_name)
        secret_names = self._list_secrets()
        for name in secret_names:
            # Force-delete skips the recovery window: the secret is gone
            # immediately and cannot be restored.
            self.CLIENT.delete_secret(
                SecretId=self._get_scoped_secret_name(name),
                ForceDeleteWithoutRecovery=True,
            )
config: AWSSecretsManagerConfig property readonly

Returns the AWSSecretsManagerConfig config.

Returns:

Type Description
AWSSecretsManagerConfig

The configuration.

delete_all_secrets(self)

Delete all existing secrets.

This method will force delete all your secrets. You will not be able to recover them once this method is called.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets.

    This method will force delete all your secrets. You will not be able to
    recover them once this method is called.
    """
    self._ensure_client_connected(self.config.region_name)
    for secret_name in self._list_secrets():
        self.CLIENT.delete_secret(
            SecretId=self._get_scoped_secret_name(secret_name),
            ForceDeleteWithoutRecovery=True,
        )
delete_secret(self, secret_name)

Delete an existing secret.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to delete

required

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete an existing secret.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if the secret does not exist
    """
    self._ensure_client_connected(self.config.region_name)

    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    self.CLIENT.delete_secret(
        SecretId=self._get_scoped_secret_name(secret_name),
        ForceDeleteWithoutRecovery=True,
    )
get_all_secret_keys(self)

Get all secret keys.

Returns:

Type Description
List[str]

A list of all secret keys

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Get all secret keys.

    Returns:
        A list of all secret keys
    """
    return self._list_secrets()
get_secret(self, secret_name)

Gets a secret.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to get

required

Returns:

Type Description
BaseSecretSchema

The secret.

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Gets a secret.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist
    """
    validate_aws_secret_name_or_namespace(secret_name)
    self._ensure_client_connected(self.config.region_name)

    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    get_secret_value_response = self.CLIENT.get_secret_value(
        SecretId=self._get_scoped_secret_name(secret_name)
    )
    if "SecretString" not in get_secret_value_response:
        get_secret_value_response = None

    return secret_from_dict(
        json.loads(get_secret_value_response["SecretString"]),
        secret_name=secret_name,
        decode=False,
    )
register_secret(self, secret)

Registers a new secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to register

required

Exceptions:

Type Description
SecretExistsError

if the secret already exists

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Registers a new secret.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    if self._list_secrets(secret.name):
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    secret_value = json.dumps(secret_to_dict(secret, encode=False))
    kwargs: Dict[str, Any] = {
        "Name": self._get_scoped_secret_name(secret.name),
        "SecretString": secret_value,
        "Tags": self._get_secret_tags(secret),
    }

    self.CLIENT.create_secret(**kwargs)

    logger.debug("Created AWS secret: %s", kwargs["Name"])
update_secret(self, secret)

Update an existing secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to update

required

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    if not self._list_secrets(secret.name):
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    secret_value = json.dumps(secret_to_dict(secret))

    kwargs = {
        "SecretId": self._get_scoped_secret_name(secret.name),
        "SecretString": secret_value,
    }

    self.CLIENT.put_secret_value(**kwargs)

service_connectors special

AWS Service Connector.

aws_service_connector

AWS Service Connector.

The AWS Service Connector implements various authentication methods for AWS services:

  • Explicit AWS secret key (access key, secret key)
  • Explicit AWS STS tokens (access key, secret key, session token)
  • IAM roles (i.e. generating temporary STS tokens on the fly by assuming an IAM role)
  • IAM user federation tokens
  • STS Session tokens
AWSAuthenticationMethods (StrEnum)

AWS Authentication methods.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSAuthenticationMethods(StrEnum):
    """AWS Authentication methods."""

    IMPLICIT = "implicit"
    SECRET_KEY = "secret-key"
    STS_TOKEN = "sts-token"
    IAM_ROLE = "iam-role"
    SESSION_TOKEN = "session-token"
    FEDERATION_TOKEN = "federation-token"
AWSBaseConfig (AuthenticationConfig) pydantic-model

AWS base configuration.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSBaseConfig(AuthenticationConfig):
    """AWS base configuration."""

    region: str = Field(
        title="AWS Region",
    )
    endpoint_url: Optional[str] = Field(
        default=None,
        title="AWS Endpoint URL",
    )
AWSImplicitConfig (AWSBaseConfig) pydantic-model

AWS implicit configuration.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSImplicitConfig(AWSBaseConfig):
    """AWS implicit configuration."""

    profile_name: Optional[str] = Field(
        default=None,
        title="AWS Profile Name",
    )
AWSSecretKey (AuthenticationConfig) pydantic-model

AWS secret key credentials.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSSecretKey(AuthenticationConfig):
    """AWS secret key credentials."""

    aws_access_key_id: SecretStr = Field(
        title="AWS Access Key ID",
    )
    aws_secret_access_key: SecretStr = Field(
        title="AWS Secret Access Key",
    )
AWSSecretKeyConfig (AWSBaseConfig, AWSSecretKey) pydantic-model

AWS secret key authentication configuration.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSSecretKeyConfig(AWSBaseConfig, AWSSecretKey):
    """AWS secret key authentication configuration."""
AWSServiceConnector (ServiceConnector) pydantic-model

AWS service connector.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSServiceConnector(ServiceConnector):
    """AWS service connector."""

    config: AWSBaseConfig

    _account_id: Optional[str] = None
    _session_cache: Dict[
        Tuple[str, Optional[str], Optional[str]],
        Tuple[boto3.Session, Optional[datetime.datetime]],
    ] = {}

    @classmethod
    def _get_connector_type(cls) -> ServiceConnectorTypeModel:
        """Get the service connector type specification.

        Returns:
            The service connector type specification.
        """
        return AWS_SERVICE_CONNECTOR_TYPE_SPEC

    @property
    def account_id(self) -> str:
        """Get the AWS account ID.

        Returns:
            The AWS account ID.

        Raises:
            AuthorizationException: If the AWS account ID could not be
                determined.
        """
        if self._account_id is None:
            logger.debug("Getting account ID from AWS...")
            try:
                session, _ = self.get_boto3_session(self.auth_method)
                sts_client = session.client("sts")
                response = sts_client.get_caller_identity()
            except (ClientError, BotoCoreError) as e:
                raise AuthorizationException(
                    f"Failed to fetch the AWS account ID: {e}"
                ) from e

            self._account_id = response["Account"]

        return self._account_id

    def get_boto3_session(
        self,
        auth_method: str,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
    ) -> Tuple[boto3.Session, Optional[datetime.datetime]]:
        """Get a boto3 session for the specified resource.

        Args:
            auth_method: The authentication method to use.
            resource_type: The resource type to get a boto3 session for.
            resource_id: The resource ID to get a boto3 session for.

        Returns:
            A boto3 session for the specified resource and its expiration
            timestamp, if applicable.
        """
        # We maintain a cache of all sessions to avoid re-authenticating
        # multiple times for the same resource
        key = (auth_method, resource_type, resource_id)
        if key in self._session_cache:
            session, expires_at = self._session_cache[key]
            if expires_at is None:
                return session, None

            # Refresh expired sessions
            now = datetime.datetime.now(datetime.timezone.utc)
            expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
            if expires_at > now:
                return session, expires_at

        logger.debug(
            f"Creating boto3 session for auth method '{auth_method}', "
            f"resource type '{resource_type}' and resource ID "
            f"'{resource_id}'..."
        )
        session, expires_at = self._authenticate(
            auth_method, resource_type, resource_id
        )
        self._session_cache[key] = (session, expires_at)
        return session, expires_at

    def _get_iam_policy(
        self,
        region_id: str,
        resource_type: Optional[str],
        resource_id: Optional[str] = None,
    ) -> Optional[str]:
        """Get the IAM inline policy to use for the specified resource.

        Args:
            region_id: The AWS region ID to get the IAM inline policy for.
            resource_type: The resource type to get the IAM inline policy for.
            resource_id: The resource ID to get the IAM inline policy for.

        Returns:
            The IAM inline policy to use for the specified resource.
        """
        if resource_type == S3_RESOURCE_TYPE:
            if resource_id:
                bucket = self._parse_s3_resource_id(resource_id)
                resource = [
                    f"arn:aws:s3:::{bucket}",
                    f"arn:aws:s3:::{bucket}/*",
                ]
            else:
                resource = ["arn:aws:s3:::*", "arn:aws:s3:::*/*"]
            policy = {
                "Version": "2012-10-17",
                "Statement": [
                    {
                        "Sid": "AllowS3BucketAccess",
                        "Effect": "Allow",
                        "Action": [
                            "s3:ListBucket",
                            "s3:GetObject",
                            "s3:PutObject",
                            "s3:DeleteObject",
                            "s3:ListAllMyBuckets",
                        ],
                        "Resource": resource,
                    },
                ],
            }
            return json.dumps(policy)
        elif resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
            if resource_id:
                cluster_name = self._parse_eks_resource_id(resource_id)
                resource = [
                    f"arn:aws:eks:{region_id}:*:cluster/{cluster_name}",
                ]
            else:
                resource = [f"arn:aws:eks:{region_id}:*:cluster/*"]
            policy = {
                "Version": "2012-10-17",
                "Statement": [
                    {
                        "Sid": "AllowEKSClusterAccess",
                        "Effect": "Allow",
                        "Action": [
                            "eks:ListClusters",
                            "eks:DescribeCluster",
                        ],
                        "Resource": resource,
                    },
                ],
            }
            return json.dumps(policy)
        elif resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
            resource = [
                f"arn:aws:ecr:{region_id}:*:repository/*",
                f"arn:aws:ecr:{region_id}:*:repository",
            ]
            policy = {
                "Version": "2012-10-17",
                "Statement": [
                    {
                        "Sid": "AllowECRRepositoryAccess",
                        "Effect": "Allow",
                        "Action": [
                            "ecr:DescribeRegistry",
                            "ecr:DescribeRepositories",
                            "ecr:ListRepositories",
                            "ecr:BatchGetImage",
                            "ecr:DescribeImages",
                            "ecr:BatchCheckLayerAvailability",
                            "ecr:GetDownloadUrlForLayer",
                            "ecr:InitiateLayerUpload",
                            "ecr:UploadLayerPart",
                            "ecr:CompleteLayerUpload",
                            "ecr:PutImage",
                        ],
                        "Resource": resource,
                    },
                    {
                        "Sid": "AllowECRRepositoryGetToken",
                        "Effect": "Allow",
                        "Action": [
                            "ecr:GetAuthorizationToken",
                        ],
                        "Resource": ["*"],
                    },
                ],
            }
            return json.dumps(policy)

        return None

    def _authenticate(
        self,
        auth_method: str,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
    ) -> Tuple[boto3.Session, Optional[datetime.datetime]]:
        """Authenticate to AWS and return a boto3 session.

        Args:
            auth_method: The authentication method to use.
            resource_type: The resource type to authenticate for.
            resource_id: The resource ID to authenticate for.

        Returns:
            An authenticated boto3 session and the expiration time of the
            temporary credentials if applicable.

        Raises:
            AuthorizationException: If the IAM role authentication method is
                used and the role cannot be assumed.
            NotImplementedError: If the authentication method is not supported.
        """
        cfg = self.config
        if auth_method == AWSAuthenticationMethods.IMPLICIT:
            assert isinstance(cfg, AWSImplicitConfig)
            # Create a boto3 session and use the default credentials provider
            session = boto3.Session(
                profile_name=cfg.profile_name, region_name=cfg.region
            )

            credentials = session.get_credentials()
            if not credentials:
                raise AuthorizationException(
                    "Failed to get AWS credentials from the default provider. "
                    "Please check your AWS configuration or attached IAM role."
                )
            if credentials.token:
                # Temporary credentials were generated. It's not possible to
                # determine the expiration time of the temporary credentials
                # from the boto3 session, so we assume the default IAM role
                # expiration date is used
                expiration_time = datetime.datetime.now(
                    tz=datetime.timezone.utc
                ) + datetime.timedelta(
                    seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
                )
                return session, expiration_time

            return session, None
        elif auth_method == AWSAuthenticationMethods.SECRET_KEY:
            assert isinstance(cfg, AWSSecretKeyConfig)
            # Create a boto3 session using long-term AWS credentials
            session = boto3.Session(
                aws_access_key_id=cfg.aws_access_key_id.get_secret_value(),
                aws_secret_access_key=cfg.aws_secret_access_key.get_secret_value(),
                region_name=cfg.region,
            )
            return session, None
        elif auth_method == AWSAuthenticationMethods.STS_TOKEN:
            assert isinstance(cfg, STSTokenConfig)
            # Create a boto3 session using a temporary AWS STS token
            session = boto3.Session(
                aws_access_key_id=cfg.aws_access_key_id.get_secret_value(),
                aws_secret_access_key=cfg.aws_secret_access_key.get_secret_value(),
                aws_session_token=cfg.aws_session_token.get_secret_value(),
                region_name=cfg.region,
            )
            return session, cfg.expires_at
        elif auth_method in [
            AWSAuthenticationMethods.IAM_ROLE,
            AWSAuthenticationMethods.SESSION_TOKEN,
            AWSAuthenticationMethods.FEDERATION_TOKEN,
        ]:
            assert isinstance(cfg, AWSSecretKey)

            # Create a boto3 session
            session = boto3.Session(
                aws_access_key_id=cfg.aws_access_key_id.get_secret_value(),
                aws_secret_access_key=cfg.aws_secret_access_key.get_secret_value(),
                region_name=cfg.region,
            )

            sts = session.client("sts", region_name=cfg.region)
            session_name = "zenml-connector"
            if self.id:
                session_name += f"-{self.id}"

            # Next steps are different for each authentication method

            # The IAM role and federation token authentication methods
            # accept a managed IAM policy that restricts/grants permissions.
            # If one isn't explicitly configured, we generate one based on the
            # resource specified by the resource type and ID (if present).
            if auth_method in [
                AWSAuthenticationMethods.IAM_ROLE,
                AWSAuthenticationMethods.FEDERATION_TOKEN,
            ]:
                assert isinstance(cfg, AWSSessionPolicy)
                policy_kwargs = {}
                policy = cfg.policy
                if not cfg.policy and not cfg.policy_arns:
                    policy = self._get_iam_policy(
                        region_id=cfg.region,
                        resource_type=resource_type,
                        resource_id=resource_id,
                    )
                if policy:
                    policy_kwargs["Policy"] = policy
                elif cfg.policy_arns:
                    policy_kwargs["PolicyArns"] = cfg.policy_arns

                if auth_method == AWSAuthenticationMethods.IAM_ROLE:
                    assert isinstance(cfg, IAMRoleAuthenticationConfig)

                    try:
                        response = sts.assume_role(
                            RoleArn=cfg.role_arn,
                            RoleSessionName=session_name,
                            DurationSeconds=self.expiration_seconds,
                            **policy_kwargs,
                        )
                    except (ClientError, BotoCoreError) as e:
                        raise AuthorizationException(
                            f"Failed to assume IAM role {cfg.role_arn} "
                            f"using the AWS credentials configured in the "
                            f"connector: {e}"
                        ) from e

                else:
                    assert isinstance(cfg, FederationTokenAuthenticationConfig)

                    try:
                        response = sts.get_federation_token(
                            Name=session_name[:32],
                            DurationSeconds=self.expiration_seconds,
                            **policy_kwargs,
                        )
                    except (ClientError, BotoCoreError) as e:
                        raise AuthorizationException(
                            "Failed to get federation token "
                            "using the AWS credentials configured in the "
                            f"connector: {e}"
                        ) from e

            else:
                assert isinstance(cfg, SessionTokenAuthenticationConfig)
                try:
                    response = sts.get_session_token(
                        DurationSeconds=self.expiration_seconds,
                    )
                except (ClientError, BotoCoreError) as e:
                    raise AuthorizationException(
                        "Failed to get session token "
                        "using the AWS credentials configured in the "
                        f"connector: {e}"
                    ) from e

            session = boto3.Session(
                aws_access_key_id=response["Credentials"]["AccessKeyId"],
                aws_secret_access_key=response["Credentials"][
                    "SecretAccessKey"
                ],
                aws_session_token=response["Credentials"]["SessionToken"],
            )
            expiration = response["Credentials"]["Expiration"]
            # Add the UTC timezone to the expiration time
            expiration = expiration.replace(tzinfo=datetime.timezone.utc)
            return session, expiration

        raise NotImplementedError(
            f"Authentication method '{auth_method}' is not supported by "
            "the AWS connector."
        )

    @classmethod
    def _get_eks_bearer_token(
        cls,
        session: boto3.Session,
        cluster_id: str,
        region: str,
    ) -> str:
        """Generate a bearer token for authenticating to the EKS API server.

        Based on: https://github.com/kubernetes-sigs/aws-iam-authenticator/blob/master/README.md#api-authorization-from-outside-a-cluster

        Args:
            session: An authenticated boto3 session to use for generating the
                token.
            cluster_id: The name of the EKS cluster.
            region: The AWS region the EKS cluster is in.

        Returns:
            A bearer token for authenticating to the EKS API server.
        """
        client = session.client("sts", region_name=region)
        service_id = client.meta.service_model.service_id

        signer = RequestSigner(
            service_id,
            region,
            "sts",
            "v4",
            session.get_credentials(),
            session.events,
        )

        params = {
            "method": "GET",
            "url": f"https://sts.{region}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15",
            "body": {},
            "headers": {"x-k8s-aws-id": cluster_id},
            "context": {},
        }

        signed_url = signer.generate_presigned_url(
            params,
            region_name=region,
            expires_in=EKS_KUBE_API_TOKEN_EXPIRATION,
            operation_name="",
        )

        base64_url = base64.urlsafe_b64encode(
            signed_url.encode("utf-8")
        ).decode("utf-8")

        # remove any base64 encoding padding:
        return "k8s-aws-v1." + re.sub(r"=*", "", base64_url)

    def _parse_s3_resource_id(self, resource_id: str) -> str:
        """Validate and convert an S3 resource ID to an S3 bucket name.

        Args:
            resource_id: The resource ID to convert.

        Returns:
            The S3 bucket name.

        Raises:
            ValueError: If the provided resource ID is not a valid S3 bucket
                name, ARN or URI.
        """
        # The resource ID could mean different things:
        #
        # - an S3 bucket ARN
        # - an S3 bucket URI
        # - the S3 bucket name
        #
        # We need to extract the bucket name from the provided resource ID
        bucket_name: Optional[str] = None
        if re.match(
            r"^arn:aws:s3:::[a-z0-9-]+(/.*)*$",
            resource_id,
        ):
            # The resource ID is an S3 bucket ARN
            bucket_name = resource_id.split(":")[-1].split("/")[0]
        elif re.match(
            r"^s3://[a-z0-9-]+(/.*)*$",
            resource_id,
        ):
            # The resource ID is an S3 bucket URI
            bucket_name = resource_id.split("/")[2]
        elif re.match(
            r"^[a-z0-9][a-z0-9-]{1,61}[a-z0-9]$",
            resource_id,
        ):
            # The resource ID is the S3 bucket name
            bucket_name = resource_id
        else:
            raise ValueError(
                f"Invalid resource ID for an S3 bucket: {resource_id}. "
                f"Supported formats are:\n"
                f"S3 bucket ARN: arn:aws:s3:::<bucket-name>\n"
                f"S3 bucket URI: s3://<bucket-name>\n"
                f"S3 bucket name: <bucket-name>"
            )

        return bucket_name

    def _parse_ecr_resource_id(
        self,
        resource_id: str,
    ) -> str:
        """Validate and convert an ECR resource ID to an ECR registry ID.

        Args:
            resource_id: The resource ID to convert.

        Returns:
            The ECR registry ID (AWS account ID).

        Raises:
            ValueError: If the provided resource ID is not a valid ECR
                repository ARN or URI.
        """
        # The resource ID could mean different things:
        #
        # - an ECR repository ARN
        # - an ECR repository URI
        #
        # We need to extract the region ID and registry ID from
        # the provided resource ID
        config_region_id = self.config.region
        region_id: Optional[str] = None
        if re.match(
            r"^arn:aws:ecr:[a-z0-9-]+:\d{12}:repository(/.+)*$",
            resource_id,
        ):
            # The resource ID is an ECR repository ARN
            registry_id = resource_id.split(":")[4]
            region_id = resource_id.split(":")[3]
        elif re.match(
            r"^(http[s]?://)?\d{12}\.dkr\.ecr\.[a-z0-9-]+\.amazonaws\.com(/.+)*$",
            resource_id,
        ):
            # The resource ID is an ECR repository URI
            registry_id = resource_id.split(".")[0].split("/")[-1]
            region_id = resource_id.split(".")[3]
        else:
            raise ValueError(
                f"Invalid resource ID for a ECR registry: {resource_id}. "
                f"Supported formats are:\n"
                f"ECR repository ARN: arn:aws:ecr:<region>:<account-id>:repository[/<repository-name>]\n"
                f"ECR repository URI: [https://]<account-id>.dkr.ecr.<region>.amazonaws.com[/<repository-name>]"
            )

        # If the connector is configured with a region and the resource ID
        # is an ECR repository ARN or URI that specifies a different region
        # we raise an error
        if region_id and region_id != config_region_id:
            raise ValueError(
                f"The AWS region for the {resource_id} ECR repository region "
                f"'{region_id}' does not match the region configured in "
                f"the connector: '{config_region_id}'."
            )

        return registry_id

    def _parse_eks_resource_id(self, resource_id: str) -> str:
        """Validate and convert an EKS resource ID to an AWS region and EKS cluster name.

        Args:
            resource_id: The resource ID to convert.

        Returns:
            The EKS cluster name.

        Raises:
            ValueError: If the provided resource ID is not a valid EKS cluster
                name or ARN.
        """
        # The resource ID could mean different things:
        #
        # - an EKS cluster ARN
        # - an EKS cluster ID
        #
        # We need to extract the cluster name and region ID from the
        # provided resource ID
        config_region_id = self.config.region
        cluster_name: Optional[str] = None
        region_id: Optional[str] = None
        if re.match(
            r"^arn:aws:eks:[a-z0-9-]+:\d{12}:cluster/.+$",
            resource_id,
        ):
            # The resource ID is an EKS cluster ARN
            cluster_name = resource_id.split("/")[-1]
            region_id = resource_id.split(":")[3]
        elif re.match(
            r"^[a-z0-9]+[a-z0-9_-]*$",
            resource_id,
        ):
            # Assume the resource ID is an EKS cluster name
            cluster_name = resource_id
        else:
            raise ValueError(
                f"Invalid resource ID for a EKS cluster: {resource_id}. "
                f"Supported formats are:\n"
                f"EKS cluster ARN: arn:aws:eks:<region>:<account-id>:cluster/<cluster-name>\n"
                f"ECR cluster name: <cluster-name>"
            )

        # If the connector is configured with a region and the resource ID
        # is an EKS registry ARN or URI that specifies a different region
        # we raise an error
        if region_id and region_id != config_region_id:
            raise ValueError(
                f"The AWS region for the {resource_id} EKS cluster "
                f"({region_id}) does not match the region configured in "
                f"the connector ({config_region_id})."
            )

        return cluster_name

    def _canonical_resource_id(
        self, resource_type: str, resource_id: str
    ) -> str:
        """Convert a resource ID to its canonical form.

        Args:
            resource_type: The resource type to canonicalize.
            resource_id: The resource ID to canonicalize.

        Returns:
            The canonical resource ID.
        """
        if resource_type == S3_RESOURCE_TYPE:
            # Canonical S3 form is the bucket URI
            bucket_name = self._parse_s3_resource_id(resource_id)
            return f"s3://{bucket_name}"
        if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
            # Canonical EKS form is the plain cluster name
            return self._parse_eks_resource_id(resource_id)
        if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
            # Canonical ECR form is the registry URI built from the account
            # (registry) ID and the region configured in the connector
            account_id = self._parse_ecr_resource_id(resource_id)
            return f"{account_id}.dkr.ecr.{self.config.region}.amazonaws.com"
        # Any other resource type is passed through unchanged
        return resource_id

    def _get_default_resource_id(self, resource_type: str) -> str:
        """Get the default resource ID for a resource type.

        Args:
            resource_type: The type of the resource to get a default resource ID
                for. Only called with resource types that do not support
                multiple instances.

        Returns:
            The default resource ID for the resource type.

        Raises:
            RuntimeError: If the ECR registry ID (AWS account ID)
                cannot be retrieved from AWS because the connector is not
                authorized.
        """
        if resource_type == AWS_RESOURCE_TYPE:
            # Generic AWS access defaults to the configured region
            return self.config.region

        if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
            # The default ECR registry URI is derived from the AWS account ID
            # (which doubles as the registry ID) obtained from the caller
            # identity AWS service, plus the configured region.
            account_id = self.account_id
            return (
                f"{account_id}.dkr.ecr.{self.config.region}.amazonaws.com"
            )

        raise RuntimeError(
            f"Default resource ID for '{resource_type}' not available."
        )

    def _connect_to_resource(
        self,
        **kwargs: Any,
    ) -> Any:
        """Authenticate and connect to an AWS resource.

        Depending on the configured resource type, this initializes and
        returns:

        - a boto3 session, if the resource type is a generic AWS resource
        - a boto3 S3 client, for the S3 resource type

        Docker and Kubernetes resources cannot be connected to directly;
        for those, a connector client object must be generated instead via
        `get_connector_client`.

        Args:
            kwargs: Additional implementation specific keyword arguments to pass
                to the session or client constructor.

        Returns:
            A boto3 session for AWS generic resources and a boto3 S3 client for
            S3 resources.

        Raises:
            NotImplementedError: If the connector instance does not support
                directly connecting to the indicated resource type.
        """
        resource_type = self.resource_type
        resource_id = self.resource_id

        assert resource_type is not None
        assert resource_id is not None

        # Authentication against AWS happens first, regardless of the
        # resource type being connected to.
        session, _ = self.get_boto3_session(
            self.auth_method,
            resource_type=resource_type,
            resource_id=resource_id,
        )

        if resource_type == AWS_RESOURCE_TYPE:
            return session

        if resource_type == S3_RESOURCE_TYPE:
            # Validate that the resource ID is a valid S3 bucket name
            self._parse_s3_resource_id(resource_id)

            # Create an S3 client for the bucket
            s3_client = session.client(
                "s3",
                region_name=self.config.region,
                endpoint_url=self.config.endpoint_url,
            )

            # The S3 client does not expose the credentials it was created
            # with, but some consumers need them to configure 3rd party
            # services, so we stash them on the client object for later
            # retrieval.
            s3_client.credentials = session.get_credentials()
            return s3_client

        raise NotImplementedError(
            f"Connecting to {resource_type} resources is not directly "
            "supported by the AWS connector. Please call the "
            f"`get_connector_client` method to get a {resource_type} connector "
            "instance for the resource."
        )

    def _configure_local_client(
        self,
        **kwargs: Any,
    ) -> None:
        """Configure a local client to authenticate and connect to a resource.

        This method uses the connector's configuration to configure a local
        client or SDK installed on the localhost for the indicated resource.

        Args:
            kwargs: Additional implementation specific keyword arguments to use
                to configure the client.

        Raises:
            NotImplementedError: If the connector instance does not support
                local configuration for the configured resource type or
                authentication method.
        """
        resource_type = self.resource_type

        # Neither the generic AWS resource type nor S3 have a local client
        # that this connector knows how to configure.
        if resource_type in (AWS_RESOURCE_TYPE, S3_RESOURCE_TYPE):
            raise NotImplementedError(
                f"Local client configuration for resource type "
                f"{resource_type} is not supported"
            )

        raise NotImplementedError(
            f"Configuring the local client for {resource_type} resources is "
            "not directly supported by the AWS connector. Please call the "
            f"`get_connector_client` method to get a {resource_type} connector "
            "instance for the resource."
        )

    @classmethod
    def _auto_configure(
        cls,
        auth_method: Optional[str] = None,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
        region_name: Optional[str] = None,
        profile_name: Optional[str] = None,
        role_arn: Optional[str] = None,
        **kwargs: Any,
    ) -> "AWSServiceConnector":
        """Auto-configure the connector.

        Instantiate an AWS connector with a configuration extracted from the
        authentication configuration available in the environment (e.g.
        environment variables or local AWS client/SDK configuration files).

        Args:
            auth_method: The particular authentication method to use. If not
                specified, the connector implementation must decide which
                authentication method to use or raise an exception.
            resource_type: The type of resource to configure.
            resource_id: The ID of the resource to configure. The
                implementation may choose to either require or ignore this
                parameter if it does not support or detect an resource type that
                supports multiple instances.
            region_name: The name of the AWS region to use. If not specified,
                the implicit region is used.
            profile_name: The name of the AWS profile to use. If not specified,
                the implicit profile is used.
            role_arn: The ARN of the AWS role to assume. Applicable only if the
                IAM role authentication method is specified or long-term
                credentials are discovered.
            kwargs: Additional implementation specific keyword arguments to use.

        Returns:
            An AWS connector instance configured with authentication credentials
            automatically extracted from the environment.

        Raises:
            NotImplementedError: If the connector implementation does not
                support auto-configuration for the specified authentication
                method.
            ValueError: If the supplied arguments are not valid.
            AuthorizationException: If no AWS credentials can be loaded from
                the environment.
        """
        auth_config: AWSBaseConfig
        expiration_seconds: Optional[int] = None
        expires_at: Optional[datetime.datetime] = None
        if auth_method == AWSAuthenticationMethods.IMPLICIT:
            # Implicit authentication stores no credentials at all: only the
            # profile name and region go into the connector configuration.
            if region_name is None:
                raise ValueError(
                    "The AWS region name must be specified when using the "
                    "implicit authentication method"
                )
            auth_config = AWSImplicitConfig(
                profile_name=profile_name,
                region=region_name,
            )
        else:
            # Initialize an AWS session with the default configuration loaded
            # from the environment.
            session = boto3.Session(
                profile_name=profile_name, region_name=region_name
            )

            region_name = region_name or session.region_name
            if not region_name:
                raise ValueError(
                    "The AWS region name was not specified and could not "
                    "be determined from the AWS session"
                )
            # NOTE: this reads the endpoint URL through the private botocore
            # session attached to the boto3 session.
            endpoint_url = session._session.get_config_variable("endpoint_url")

            # Extract the AWS credentials from the session and store them in
            # the connector secrets.
            credentials = session.get_credentials()
            if not credentials:
                raise AuthorizationException(
                    "Could not determine the AWS credentials from the "
                    "environment"
                )
            if credentials.token:
                # The session picked up temporary STS credentials
                # NOTE(review): the `None` entry in this list is redundant,
                # since `auth_method` is truthy at this point.
                if auth_method and auth_method not in [
                    None,
                    AWSAuthenticationMethods.STS_TOKEN,
                    AWSAuthenticationMethods.IAM_ROLE,
                ]:
                    raise NotImplementedError(
                        f"The specified authentication method '{auth_method}' "
                        "could not be used to auto-configure the connector. "
                    )

                # An explicitly requested STS_TOKEN method skips the IAM role
                # resolution below and captures the temporary credentials as-is.
                if (
                    credentials.method == "assume-role"
                    and auth_method != AWSAuthenticationMethods.STS_TOKEN
                ):
                    # In the special case of IAM role authentication, the
                    # credentials in the boto3 session are the temporary STS
                    # credentials instead of the long-lived credentials, and the
                    # role ARN is not known. We have to dig deeper into the
                    # botocore session internals to retrieve the role ARN and
                    # the original long-lived credentials.

                    botocore_session = session._session
                    profile_config = botocore_session.get_scoped_config()
                    source_profile = profile_config.get("source_profile")
                    role_arn = profile_config.get("role_arn")
                    # _build_profile_map() is a private botocore API mapping
                    # profile names to their raw config entries.
                    profile_map = botocore_session._build_profile_map()
                    if not (
                        role_arn
                        and source_profile
                        and source_profile in profile_map
                    ):
                        raise AuthorizationException(
                            "Could not determine the IAM role ARN and source "
                            "profile credentials from the environment"
                        )

                    auth_method = AWSAuthenticationMethods.IAM_ROLE
                    source_profile_config = profile_map[source_profile]
                    auth_config = IAMRoleAuthenticationConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=source_profile_config.get(
                            "aws_access_key_id"
                        ),
                        aws_secret_access_key=source_profile_config.get(
                            "aws_secret_access_key",
                        ),
                        role_arn=role_arn,
                    )
                    expiration_seconds = DEFAULT_IAM_ROLE_TOKEN_EXPIRATION

                else:
                    if auth_method == AWSAuthenticationMethods.IAM_ROLE:
                        raise NotImplementedError(
                            f"The specified authentication method "
                            f"'{auth_method}' could not be used to "
                            "auto-configure the connector."
                        )

                    # Temporary credentials were picked up from the local
                    # configuration. It's not possible to determine the
                    # expiration time of the temporary credentials from the
                    # boto3 session, so we assume the default IAM role
                    # expiration period is used
                    expires_at = datetime.datetime.now(
                        tz=datetime.timezone.utc
                    ) + datetime.timedelta(
                        seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
                    )

                    auth_method = AWSAuthenticationMethods.STS_TOKEN
                    auth_config = STSTokenConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                        aws_session_token=credentials.token,
                    )
            else:
                # The session picked up long-lived credentials
                if not auth_method:
                    if role_arn:
                        auth_method = AWSAuthenticationMethods.IAM_ROLE
                    else:
                        # If no authentication method was specified, use the
                        # session token as a default recommended authentication
                        # method to be used with long-lived credentials.
                        auth_method = AWSAuthenticationMethods.SESSION_TOKEN

                # NOTE(review): region_name was already resolved and validated
                # above; this repeated check is redundant but harmless.
                region_name = region_name or session.region_name
                if not region_name:
                    raise ValueError(
                        "The AWS region name was not specified and could not "
                        "be determined from the AWS session"
                    )

                if auth_method == AWSAuthenticationMethods.STS_TOKEN:
                    # Generate a session token from the long-lived credentials
                    # and store it in the connector secrets.
                    sts_client = session.client("sts")
                    response = sts_client.get_session_token(
                        DurationSeconds=DEFAULT_STS_TOKEN_EXPIRATION
                    )
                    credentials = response["Credentials"]
                    auth_config = STSTokenConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=credentials["AccessKeyId"],
                        aws_secret_access_key=credentials["SecretAccessKey"],
                        aws_session_token=credentials["SessionToken"],
                    )
                    expires_at = datetime.datetime.now(
                        tz=datetime.timezone.utc
                    ) + datetime.timedelta(
                        seconds=DEFAULT_STS_TOKEN_EXPIRATION
                    )

                elif auth_method == AWSAuthenticationMethods.IAM_ROLE:
                    if not role_arn:
                        raise ValueError(
                            "The ARN of the AWS role to assume must be "
                            "specified when using the IAM role authentication "
                            "method."
                        )
                    auth_config = IAMRoleAuthenticationConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                        role_arn=role_arn,
                    )
                    expiration_seconds = DEFAULT_IAM_ROLE_TOKEN_EXPIRATION

                elif auth_method == AWSAuthenticationMethods.SECRET_KEY:
                    # Long-lived credentials stored as-is; no expiration.
                    auth_config = AWSSecretKeyConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                    )

                elif auth_method == AWSAuthenticationMethods.SESSION_TOKEN:
                    auth_config = SessionTokenAuthenticationConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                    )
                    expiration_seconds = DEFAULT_STS_TOKEN_EXPIRATION

                else:  # auth_method is AWSAuthenticationMethods.FEDERATION_TOKEN
                    auth_config = FederationTokenAuthenticationConfig(
                        region=region_name,
                        endpoint_url=endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                    )
                    expiration_seconds = DEFAULT_STS_TOKEN_EXPIRATION

        # The resource ID is only retained for resource types that support
        # multiple instances; the generic AWS type does not.
        return cls(
            auth_method=auth_method,
            resource_type=resource_type,
            resource_id=resource_id
            if resource_type not in [AWS_RESOURCE_TYPE, None]
            else None,
            expiration_seconds=expiration_seconds,
            expires_at=expires_at,
            config=auth_config,
        )

    def _verify(
        self,
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
    ) -> List[str]:
        """Verify and list all the resources that the connector can access.

        Args:
            resource_type: The type of the resource to verify. If omitted and
                if the connector supports multiple resource types, the
                implementation must verify that it can authenticate and connect
                to any and all of the supported resource types.
            resource_id: The ID of the resource to connect to. Omitted if a
                resource type is not specified. It has the same value as the
                default resource ID if the supplied resource type doesn't
                support multiple instances. If the supplied resource type does
                allows multiple instances, this parameter may still be omitted
                to fetch a list of resource IDs identifying all the resources
                of the indicated type that the connector can access.

        Returns:
            The list of resources IDs in canonical format identifying the
            resources that the connector can access. This list is empty only
            if the resource type is not specified (i.e. for multi-type
            connectors).

        Raises:
            AuthorizationException: If the connector cannot authenticate or
                access the specified resource.
        """
        # If the resource type is not specified, treat this the
        # same as a generic AWS connector.
        session, _ = self.get_boto3_session(
            self.auth_method,
            resource_type=resource_type or AWS_RESOURCE_TYPE,
            resource_id=resource_id,
        )

        # Verify that the AWS account is accessible
        assert isinstance(session, boto3.Session)
        sts_client = session.client("sts")
        try:
            sts_client.get_caller_identity()
        except (ClientError, BotoCoreError) as err:
            msg = f"failed to verify AWS account access: {err}"
            logger.debug(msg)
            raise AuthorizationException(msg) from err

        # Without a resource type there is nothing more to enumerate: a
        # multi-type connector only proves generic account access here.
        if not resource_type:
            return []

        if resource_type == AWS_RESOURCE_TYPE:
            # The generic AWS type has a fixed (default) resource ID supplied
            # by the caller.
            assert resource_id is not None
            return [resource_id]

        if resource_type == S3_RESOURCE_TYPE:
            s3_client = session.client(
                "s3",
                region_name=self.config.region,
                endpoint_url=self.config.endpoint_url,
            )
            if not resource_id:
                # List all S3 buckets
                try:
                    response = s3_client.list_buckets()
                except (ClientError, BotoCoreError) as e:
                    msg = f"failed to list S3 buckets: {e}"
                    logger.error(msg)
                    raise AuthorizationException(msg) from e

                # Return bucket URIs in the canonical s3:// form
                return [
                    f"s3://{bucket['Name']}" for bucket in response["Buckets"]
                ]
            else:
                # Check if the specified S3 bucket exists
                bucket_name = self._parse_s3_resource_id(resource_id)
                try:
                    s3_client.head_bucket(Bucket=bucket_name)
                    return [resource_id]
                except (ClientError, BotoCoreError) as e:
                    msg = f"failed to fetch S3 bucket {bucket_name}: {e}"
                    logger.error(msg)
                    raise AuthorizationException(msg) from e

        if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
            # The registry resource ID is expected to be pre-resolved by the
            # caller (hence the assert); listing repositories below is used
            # only as an ECR access check, not to enumerate resources.
            assert resource_id is not None

            ecr_client = session.client(
                "ecr",
                region_name=self.config.region,
                endpoint_url=self.config.endpoint_url,
            )
            # List all ECR repositories
            try:
                repositories = ecr_client.describe_repositories()
            except (ClientError, BotoCoreError) as e:
                msg = f"failed to list ECR repositories: {e}"
                logger.error(msg)
                raise AuthorizationException(msg) from e

            if len(repositories["repositories"]) == 0:
                raise AuthorizationException(
                    "the AWS connector does not have access to any ECR "
                    "repositories. Please adjust the AWS permissions "
                    "associated with the authentication credentials to "
                    "include access to at least one ECR repository."
                )

            return [resource_id]

        if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
            eks_client = session.client(
                "eks",
                region_name=self.config.region,
                endpoint_url=self.config.endpoint_url,
            )
            if not resource_id:
                # List all EKS clusters; the plain cluster names returned by
                # EKS are already in canonical form.
                try:
                    clusters = eks_client.list_clusters()
                except (ClientError, BotoCoreError) as e:
                    msg = f"Failed to list EKS clusters: {e}"
                    logger.error(msg)
                    raise AuthorizationException(msg) from e

                return cast(List[str], clusters["clusters"])
            else:
                # Check if the specified EKS cluster exists
                cluster_name = self._parse_eks_resource_id(resource_id)
                try:
                    clusters = eks_client.describe_cluster(name=cluster_name)
                except (ClientError, BotoCoreError) as e:
                    msg = f"Failed to fetch EKS cluster {cluster_name}: {e}"
                    logger.error(msg)
                    raise AuthorizationException(msg) from e

                return [resource_id]

        # Unrecognized resource types yield no resources
        return []

    def _get_connector_client(
        self,
        resource_type: str,
        resource_id: str,
    ) -> "ServiceConnector":
        """Get a connector instance that can be used to connect to a resource.

        This method generates a client-side connector instance that can be used
        to connect to a resource of the given type. The client-side connector
        is configured with temporary AWS credentials extracted from the
        current connector and, depending on resource type, it may also be
        of a different connector type:

        - a Kubernetes connector for Kubernetes clusters
        - a Docker connector for Docker registries

        Args:
            resource_type: The type of the resources to connect to.
            resource_id: The ID of a particular resource to connect to.

        Returns:
            An AWS, Kubernetes or Docker connector instance that can be used to
            connect to the specified resource.

        Raises:
            AuthorizationException: If authentication failed.
            ValueError: If the resource type is not supported.
        """
        connector_name = ""
        if self.name:
            connector_name = self.name
        if resource_id:
            connector_name += f" ({resource_type} | {resource_id} client)"
        else:
            connector_name += f" ({resource_type} client)"

        logger.debug(f"Getting connector client for {connector_name}")

        if resource_type in [AWS_RESOURCE_TYPE, S3_RESOURCE_TYPE]:
            auth_method = self.auth_method
            if self.auth_method in [
                AWSAuthenticationMethods.SECRET_KEY,
                AWSAuthenticationMethods.STS_TOKEN,
            ]:
                if (
                    self.resource_type == resource_type
                    and self.resource_id == resource_id
                ):
                    # If the requested type and resource ID are the same as
                    # those configured, we can return the current connector
                    # instance because it's fully formed and ready to use
                    # to connect to the specified resource
                    return self

                # The secret key and STS token authentication methods do not
                # involve generating temporary credentials, so we can just
                # use the current connector configuration
                config = self.config
                expires_at = self.expires_at
            else:
                # Get an authenticated boto3 session
                session, expires_at = self.get_boto3_session(
                    self.auth_method,
                    resource_type=resource_type,
                    resource_id=resource_id,
                )
                assert isinstance(session, boto3.Session)
                credentials = session.get_credentials()

                if (
                    self.auth_method == AWSAuthenticationMethods.IMPLICIT
                    and credentials.token is None
                ):
                    # The implicit authentication method may involve picking up
                    # long-lived credentials from the environment
                    auth_method = AWSAuthenticationMethods.SECRET_KEY
                    config = AWSSecretKeyConfig(
                        region=self.config.region,
                        endpoint_url=self.config.endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                    )
                else:
                    assert credentials.token is not None

                    # Use the temporary credentials extracted from the boto3
                    # session
                    auth_method = AWSAuthenticationMethods.STS_TOKEN
                    config = STSTokenConfig(
                        region=self.config.region,
                        endpoint_url=self.config.endpoint_url,
                        aws_access_key_id=credentials.access_key,
                        aws_secret_access_key=credentials.secret_key,
                        aws_session_token=credentials.token,
                    )

            # Create a client-side AWS connector instance that is fully formed
            # and ready to use to connect to the specified resource (i.e. has
            # all the necessary configuration and credentials, a resource type
            # and a resource ID where applicable)
            return AWSServiceConnector(
                id=self.id,
                name=connector_name,
                auth_method=auth_method,
                resource_type=resource_type,
                resource_id=resource_id,
                config=config,
                expires_at=expires_at,
            )

        if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE:
            assert resource_id is not None

            # Get an authenticated boto3 session
            session, expires_at = self.get_boto3_session(
                self.auth_method,
                resource_type=resource_type,
                resource_id=resource_id,
            )
            assert isinstance(session, boto3.Session)

            registry_id = self._parse_ecr_resource_id(resource_id)

            ecr_client = session.client(
                "ecr",
                region_name=self.config.region,
                endpoint_url=self.config.endpoint_url,
            )

            assert isinstance(ecr_client, BaseClient)
            assert registry_id is not None

            try:
                auth_token = ecr_client.get_authorization_token(
                    registryIds=[
                        registry_id,
                    ]
                )
            except (ClientError, BotoCoreError) as e:
                raise AuthorizationException(
                    f"Failed to get authorization token from ECR: {e}"
                ) from e

            token = auth_token["authorizationData"][0]["authorizationToken"]
            endpoint = auth_token["authorizationData"][0]["proxyEndpoint"]
            # The token is base64 encoded and has the format
            # "username:password"
            username, token = (
                base64.b64decode(token).decode("utf-8").split(":")
            )

            # Create a client-side Docker connector instance with the temporary
            # Docker credentials
            return DockerServiceConnector(
                id=self.id,
                name=connector_name,
                auth_method=DockerAuthenticationMethods.PASSWORD,
                resource_type=resource_type,
                config=DockerConfiguration(
                    username=username,
                    password=token,
                    registry=endpoint,
                ),
                expires_at=expires_at,
            )

        if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE:
            assert resource_id is not None

            # Get an authenticated boto3 session
            session, expires_at = self.get_boto3_session(
                self.auth_method,
                resource_type=resource_type,
                resource_id=resource_id,
            )
            assert isinstance(session, boto3.Session)

            cluster_name = self._parse_eks_resource_id(resource_id)

            # Get a boto3 EKS client
            eks_client = session.client(
                "eks",
                region_name=self.config.region,
                endpoint_url=self.config.endpoint_url,
            )

            assert isinstance(eks_client, BaseClient)

            try:
                cluster = eks_client.describe_cluster(name=cluster_name)
            except (ClientError, BotoCoreError) as e:
                raise AuthorizationException(
                    f"Failed to get EKS cluster {cluster_name}: {e}"
                ) from e

            try:
                user_token = self._get_eks_bearer_token(
                    session=session,
                    cluster_id=cluster_name,
                    region=self.config.region,
                )
            except (ClientError, BotoCoreError) as e:
                raise AuthorizationException(
                    f"Failed to get EKS bearer token: {e}"
                ) from e

            # get cluster details
            cluster_arn = cluster["cluster"]["arn"]
            cluster_ca_cert = cluster["cluster"]["certificateAuthority"][
                "data"
            ]
            cluster_server = cluster["cluster"]["endpoint"]

            # Create a client-side Kubernetes connector instance with the
            # temporary Kubernetes credentials
            return KubernetesServiceConnector(
                id=self.id,
                name=connector_name,
                auth_method=KubernetesAuthenticationMethods.TOKEN,
                resource_type=resource_type,
                config=KubernetesTokenConfig(
                    cluster_name=cluster_arn,
                    certificate_authority=cluster_ca_cert,
                    server=cluster_server,
                    token=user_token,
                ),
                expires_at=expires_at,
            )

        raise ValueError(f"Unsupported resource type: {resource_type}")
account_id: str property readonly

Get the AWS account ID.

Returns:

Type Description
str

The AWS account ID.

Exceptions:

Type Description
AuthorizationException

If the AWS account ID could not be determined.

get_boto3_session(self, auth_method, resource_type=None, resource_id=None)

Get a boto3 session for the specified resource.

Parameters:

Name Type Description Default
auth_method str

The authentication method to use.

required
resource_type Optional[str]

The resource type to get a boto3 session for.

None
resource_id Optional[str]

The resource ID to get a boto3 session for.

None

Returns:

Type Description
Tuple[boto3.session.Session, Optional[datetime.datetime]]

A boto3 session for the specified resource and its expiration timestamp, if applicable.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
def get_boto3_session(
    self,
    auth_method: str,
    resource_type: Optional[str] = None,
    resource_id: Optional[str] = None,
) -> Tuple[boto3.Session, Optional[datetime.datetime]]:
    """Get a boto3 session for the specified resource.

    Args:
        auth_method: The authentication method to use.
        resource_type: The resource type to get a boto3 session for.
        resource_id: The resource ID to get a boto3 session for.

    Returns:
        A boto3 session for the specified resource and its expiration
        timestamp, if applicable.
    """
    # Sessions are cached per (auth method, resource type, resource ID)
    # triple so the same resource is not re-authenticated repeatedly.
    cache_key = (auth_method, resource_type, resource_id)
    cached = self._session_cache.get(cache_key)
    if cached is not None:
        cached_session, cached_expiry = cached
        # A session without an expiration timestamp never goes stale.
        if cached_expiry is None:
            return cached_session, None

        # The stored expiration is interpreted as UTC; reuse the session
        # only while it is still valid, otherwise fall through to refresh.
        cached_expiry = cached_expiry.replace(tzinfo=datetime.timezone.utc)
        if cached_expiry > datetime.datetime.now(datetime.timezone.utc):
            return cached_session, cached_expiry

    logger.debug(
        f"Creating boto3 session for auth method '{auth_method}', "
        f"resource type '{resource_type}' and resource ID "
        f"'{resource_id}'..."
    )
    fresh_session, fresh_expiry = self._authenticate(
        auth_method, resource_type, resource_id
    )
    # Cache the freshly authenticated session (overwrites expired entries).
    self._session_cache[cache_key] = (fresh_session, fresh_expiry)
    return fresh_session, fresh_expiry
AWSSessionPolicy (AuthenticationConfig) pydantic-model

AWS session IAM policy configuration.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class AWSSessionPolicy(AuthenticationConfig):
    """AWS session IAM policy configuration.

    Optional restrictions on the permissions of temporary session
    credentials. Presumably passed through to the STS call that issues
    the credentials — TODO confirm against the connector's authenticate
    logic.
    """

    # ARNs of managed IAM policies to attach as session policies; None
    # means no managed policy restriction is applied.
    policy_arns: Optional[List[str]] = Field(
        default=None,
        title="ARNs of the IAM managed policies that you want to use as a "
        "managed session policy. The policies must exist in the same account "
        "as the IAM user that is requesting temporary credentials.",
    )
    # Inline IAM policy document (JSON string) to attach as a session
    # policy; None means no inline policy restriction is applied.
    policy: Optional[str] = Field(
        default=None,
        title="An IAM policy in JSON format that you want to use as an inline "
        "session policy",
    )
FederationTokenAuthenticationConfig (AWSSecretKeyConfig, AWSSessionPolicy) pydantic-model

AWS federation token authentication config.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class FederationTokenAuthenticationConfig(
    AWSSecretKeyConfig, AWSSessionPolicy
):
    """AWS federation token authentication config.

    Marker class combining long-lived secret-key credentials
    (`AWSSecretKeyConfig`) with optional session policy restrictions
    (`AWSSessionPolicy`); declares no additional fields of its own.
    """
IAMRoleAuthenticationConfig (AWSSecretKeyConfig, AWSSessionPolicy) pydantic-model

AWS IAM authentication config.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class IAMRoleAuthenticationConfig(AWSSecretKeyConfig, AWSSessionPolicy):
    """AWS IAM authentication config.

    Secret-key credentials plus the ARN of an IAM role, with optional
    session policy restrictions inherited from `AWSSessionPolicy`.
    """

    # ARN of the IAM role to use. Required: the field has no default.
    role_arn: str = Field(
        title="AWS IAM Role ARN",
    )
STSToken (AWSSecretKey) pydantic-model

AWS STS token.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class STSToken(AWSSecretKey):
    """AWS STS token.

    Extends the access-key/secret-key pair from `AWSSecretKey` with the
    session token that accompanies temporary STS credentials.
    """

    # Temporary session token; declared as SecretStr so its value is
    # masked in logs and reprs.
    aws_session_token: SecretStr = Field(
        title="AWS Session Token",
    )
STSTokenConfig (AWSBaseConfig, STSToken) pydantic-model

AWS STS token authentication configuration.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class STSTokenConfig(AWSBaseConfig, STSToken):
    """AWS STS token authentication configuration.

    Base AWS configuration plus a complete set of temporary STS
    credentials and their expiration timestamp.
    """

    # When the temporary credentials expire; None when the expiration is
    # unknown or not applicable.
    expires_at: Optional[datetime.datetime] = Field(
        default=None,
        title="AWS STS Token Expiration",
    )
SessionTokenAuthenticationConfig (AWSSecretKeyConfig) pydantic-model

AWS session token authentication config.

Source code in zenml/integrations/aws/service_connectors/aws_service_connector.py
class SessionTokenAuthenticationConfig(AWSSecretKeyConfig):
    """AWS session token authentication config.

    Marker subclass of `AWSSecretKeyConfig` declaring no extra fields;
    presumably the long-lived secret key is exchanged for a temporary
    session token elsewhere in the connector — TODO confirm against the
    authenticate logic.
    """

step_operators special

Initialization of the Sagemaker Step Operator.

sagemaker_step_operator

Implementation of the Sagemaker Step Operator.

SagemakerStepOperator (BaseStepOperator)

Step operator to run a step on Sagemaker.

This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.

Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator):
    """Step operator to run a step on Sagemaker.

    This class defines code that builds an image with the ZenML entrypoint
    to run using Sagemaker's Estimator.
    """

    @property
    def config(self) -> SagemakerStepOperatorConfig:
        """Returns the `SagemakerStepOperatorConfig` config.

        Returns:
            The configuration.
        """
        return cast(SagemakerStepOperatorConfig, self._config)

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the SageMaker step operator.

        Returns:
            The settings class.
        """
        return SagemakerStepOperatorSettings

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack.

        Returns:
            A validator that checks that the stack contains a remote container
            registry and a remote artifact store.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            # The step runs remotely, so it can only read/write artifacts
            # from a store that is reachable outside the client machine.
            if stack.artifact_store.config.is_local:
                return False, (
                    "The SageMaker step operator runs code remotely and "
                    "needs to write files into the artifact store, but the "
                    f"artifact store `{stack.artifact_store.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote artifact store when using the SageMaker "
                    "step operator."
                )

            # Presence is guaranteed by `required_components` below.
            container_registry = stack.container_registry
            assert container_registry is not None

            if container_registry.config.is_local:
                return False, (
                    "The SageMaker step operator runs code remotely and "
                    "needs to push/pull Docker images, but the "
                    f"container registry `{container_registry.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote container registry when using the "
                    "SageMaker step operator."
                )

            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.CONTAINER_REGISTRY,
                StackComponentType.IMAGE_BUILDER,
            },
            custom_validation_function=_validate_remote_components,
        )

    def get_docker_builds(
        self, deployment: "PipelineDeploymentBaseModel"
    ) -> List["BuildConfiguration"]:
        """Gets the Docker builds required for the component.

        Args:
            deployment: The pipeline deployment for which to get the builds.

        Returns:
            The required Docker builds, one per step that runs on this
            step operator.
        """
        builds = []
        for step_name, step in deployment.step_configurations.items():
            if step.config.step_operator == self.name:
                build = BuildConfiguration(
                    key=SAGEMAKER_DOCKER_IMAGE_KEY,
                    settings=step.config.docker_settings,
                    step_name=step_name,
                    # Shell-style reference to the env variable that
                    # `launch` fills with the actual entrypoint command.
                    entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
                )
                builds.append(build)

        return builds

    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
        environment: Dict[str, str],
    ) -> None:
        """Launches a step on SageMaker.

        Args:
            info: Information about the step run.
            entrypoint_command: Command that executes the step.
            environment: Environment variables to set in the step operator
                environment.
        """
        if not info.config.resource_settings.empty:
            logger.warning(
                "Specifying custom step resources is not supported for "
                "the SageMaker step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different instance type like this: "
                "`zenml step-operator update %s "
                "--instance_type=<INSTANCE_TYPE>`",
                self.name,
            )

        image_name = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
        environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command)

        settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))

        # Get and default fill SageMaker estimator arguments for full ZenML
        # support. Work on a shallow copy so the (non-serializable) session
        # object injected below never leaks back into the settings object.
        estimator_args = dict(settings.estimator_args)
        session = sagemaker.Session(default_bucket=self.config.bucket)

        estimator_args.setdefault(
            "instance_type", settings.instance_type or "ml.m5.large"
        )

        estimator_args["environment"] = environment
        estimator_args["instance_count"] = 1
        estimator_args["sagemaker_session"] = session

        # Create Estimator
        estimator = sagemaker.estimator.Estimator(
            image_name, self.config.role, **estimator_args
        )

        # SageMaker allows 63 characters at maximum for job names.
        # Truncating to 55 leaves room for "-" plus the 4-character random
        # suffix below, capping the final name at 60 for a safety margin.
        step_name = Client().get_run_step(info.step_run_id).name
        training_job_name = f"{info.pipeline.name}-{step_name}"[:55]
        suffix = random_str(4)
        unique_training_job_name = f"{training_job_name}-{suffix}"

        # Sagemaker doesn't allow any underscores in job/experiment/trial names
        sanitized_training_job_name = unique_training_job_name.replace(
            "_", "-"
        )

        # Construct training input object, if necessary
        inputs = None

        if isinstance(settings.input_data_s3_uri, str):
            # A single URI becomes one TrainingInput.
            inputs = sagemaker.inputs.TrainingInput(
                s3_data=settings.input_data_s3_uri
            )
        elif isinstance(settings.input_data_s3_uri, dict):
            # A dict maps channel names to their S3 URIs.
            inputs = {}
            for channel, s3_uri in settings.input_data_s3_uri.items():
                inputs[channel] = sagemaker.inputs.TrainingInput(
                    s3_data=s3_uri
                )

        experiment_config = {}
        if settings.experiment_name:
            experiment_config = {
                "ExperimentName": settings.experiment_name,
                "TrialName": sanitized_training_job_name,
            }

        # Block until the training job finishes so the step's status
        # reflects the remote outcome.
        estimator.fit(
            wait=True,
            inputs=inputs,
            experiment_config=experiment_config,
            job_name=sanitized_training_job_name,
        )
config: SagemakerStepOperatorConfig property readonly

Returns the SagemakerStepOperatorConfig config.

Returns:

Type Description
SagemakerStepOperatorConfig

The configuration.

settings_class: Optional[Type[BaseSettings]] property readonly

Settings class for the SageMaker step operator.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates the stack.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A validator that checks that the stack contains a remote container registry and a remote artifact store.

get_docker_builds(self, deployment)

Gets the Docker builds required for the component.

Parameters:

Name Type Description Default
deployment PipelineDeploymentBaseModel

The pipeline deployment for which to get the builds.

required

Returns:

Type Description
List[BuildConfiguration]

The required Docker builds.

Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def get_docker_builds(
    self, deployment: "PipelineDeploymentBaseModel"
) -> List["BuildConfiguration"]:
    """Gets the Docker builds required for the component.

    Args:
        deployment: The pipeline deployment for which to get the builds.

    Returns:
        The required Docker builds.
    """
    # One build per step that is configured to run on this step operator;
    # the entrypoint is a shell-style reference to the env variable that
    # is filled with the actual command at launch time.
    return [
        BuildConfiguration(
            key=SAGEMAKER_DOCKER_IMAGE_KEY,
            settings=step.config.docker_settings,
            step_name=step_name,
            entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
        )
        for step_name, step in deployment.step_configurations.items()
        if step.config.step_operator == self.name
    ]
launch(self, info, entrypoint_command, environment)

Launches a step on SageMaker.

Parameters:

Name Type Description Default
info StepRunInfo

Information about the step run.

required
entrypoint_command List[str]

Command that executes the step.

required
environment Dict[str, str]

Environment variables to set in the step operator environment.

required
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
    environment: Dict[str, str],
) -> None:
    """Launches a step on SageMaker.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step.
        environment: Environment variables to set in the step operator
            environment.
    """
    if not info.config.resource_settings.empty:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the SageMaker step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different instance type like this: "
            "`zenml step-operator update %s "
            "--instance_type=<INSTANCE_TYPE>`",
            self.name,
        )

    image_name = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
    environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command)

    settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))

    # Get and default fill SageMaker estimator arguments for full ZenML
    # support. Work on a shallow copy so the (non-serializable) session
    # object injected below never leaks back into the settings object.
    estimator_args = dict(settings.estimator_args)
    session = sagemaker.Session(default_bucket=self.config.bucket)

    estimator_args.setdefault(
        "instance_type", settings.instance_type or "ml.m5.large"
    )

    estimator_args["environment"] = environment
    estimator_args["instance_count"] = 1
    estimator_args["sagemaker_session"] = session

    # Create Estimator
    estimator = sagemaker.estimator.Estimator(
        image_name, self.config.role, **estimator_args
    )

    # SageMaker allows 63 characters at maximum for job names. Truncating
    # to 55 leaves room for "-" plus the 4-character random suffix below,
    # capping the final name at 60 for a safety margin.
    step_name = Client().get_run_step(info.step_run_id).name
    training_job_name = f"{info.pipeline.name}-{step_name}"[:55]
    suffix = random_str(4)
    unique_training_job_name = f"{training_job_name}-{suffix}"

    # Sagemaker doesn't allow any underscores in job/experiment/trial names
    sanitized_training_job_name = unique_training_job_name.replace(
        "_", "-"
    )

    # Construct training input object, if necessary
    inputs = None

    if isinstance(settings.input_data_s3_uri, str):
        # A single URI becomes one TrainingInput.
        inputs = sagemaker.inputs.TrainingInput(
            s3_data=settings.input_data_s3_uri
        )
    elif isinstance(settings.input_data_s3_uri, dict):
        # A dict maps channel names to their S3 URIs.
        inputs = {}
        for channel, s3_uri in settings.input_data_s3_uri.items():
            inputs[channel] = sagemaker.inputs.TrainingInput(
                s3_data=s3_uri
            )

    experiment_config = {}
    if settings.experiment_name:
        experiment_config = {
            "ExperimentName": settings.experiment_name,
            "TrialName": sanitized_training_job_name,
        }

    # Block until the training job finishes so the step's status reflects
    # the remote outcome.
    estimator.fit(
        wait=True,
        inputs=inputs,
        experiment_config=experiment_config,
        job_name=sanitized_training_job_name,
    )