Skip to content

Aws

zenml.integrations.aws special

Integrates multiple AWS tools as ZenML stack components.

The AWS integration lets users manage their secrets through AWS Secrets Manager and use the AWS container registry (Amazon ECR). Additionally, the SageMaker integration submodule provides a way to run ZenML steps in SageMaker.

AWSIntegration (Integration)

Definition of AWS integration for ZenML.

Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """ZenML integration bundling the AWS stack component flavors."""

    NAME = AWS
    REQUIREMENTS = ["boto3==1.21.0", "sagemaker==2.82.2"]

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Collect the stack component flavors shipped with this integration.

        The flavor module is imported inside the method rather than at module
        level, so it is only loaded when the flavors are requested.

        Returns:
            The AWS secrets manager, AWS container registry and SageMaker
            step operator flavors, in that order.
        """
        import zenml.integrations.aws.flavors as aws_flavors

        return [
            aws_flavors.AWSSecretsManagerFlavor,
            aws_flavors.AWSContainerRegistryFlavor,
            aws_flavors.SagemakerStepOperatorFlavor,
        ]

flavors() classmethod

Declare the stack component flavors for the AWS integration.

Returns:

Type Description
List[Type[zenml.stack.flavor.Flavor]]

List of stack component flavors for this integration.

Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Collect the stack component flavors shipped with the AWS integration.

    The flavor module is imported inside the method rather than at module
    level, so it is only loaded when the flavors are requested.

    Returns:
        The AWS secrets manager, AWS container registry and SageMaker step
        operator flavors, in that order.
    """
    import zenml.integrations.aws.flavors as aws_flavors

    return [
        aws_flavors.AWSSecretsManagerFlavor,
        aws_flavors.AWSContainerRegistryFlavor,
        aws_flavors.SagemakerStepOperatorFlavor,
    ]

container_registries special

Initialization of AWS Container Registry integration.

aws_container_registry

Implementation of the AWS container registry integration.

AWSContainerRegistry (BaseContainerRegistry)

Class for AWS Container Registry.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
    """Class for AWS Container Registry (Amazon ECR)."""

    @property
    def config(self) -> AWSContainerRegistryConfig:
        """Returns the `AWSContainerRegistryConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSContainerRegistryConfig, self._config)

    def _get_region(self) -> str:
        """Parses the AWS region from the registry URI.

        The URI is expected to look like
        `<account_id>.dkr.ecr.<region>.amazonaws.com`.

        Raises:
            RuntimeError: If the region parsing fails due to an invalid URI.

        Returns:
            The region string.
        """
        match = re.fullmatch(
            r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", self.config.uri
        )
        if not match:
            raise RuntimeError(
                f"Unable to parse region from ECR URI {self.config.uri}."
            )

        return match.group(1)

    def prepare_image_push(self, image_name: str) -> None:
        """Logs a warning if no ECR repository exists for the image.

        Amazon ECR requires a repository to exist before an image can be
        pushed to it. This method lists all repositories of the registry
        region and warns if none of them matches `image_name`. If listing the
        repositories fails with a `ClientError` or returns an unexpected
        response, this is logged at debug level and the push is attempted
        anyway.

        Args:
            image_name: Name of the docker image that will be pushed.

        Raises:
            ValueError: If the docker image name is invalid.
        """
        try:
            # The AWS call must be inside the `try`: this is where a
            # `ClientError` (e.g. missing permissions) is actually raised.
            client = boto3.client("ecr", region_name=self._get_region())
            # `describe_repositories` returns results in pages, so iterate
            # over all of them to avoid missing existing repositories.
            paginator = client.get_paginator("describe_repositories")
            repo_uris: List[str] = [
                repository["repositoryUri"]
                for page in paginator.paginate()
                for repository in page["repositories"]
            ]
        except (KeyError, ClientError) as e:
            # invalid boto response, let's hope for the best and just push
            logger.debug("Error while trying to fetch ECR repositories: %s", e)
            return

        repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
        if not repo_exists:
            # Escape the registry URI so its dots are matched literally.
            match = re.search(
                f"{re.escape(self.config.uri)}/(.*):.*", image_name
            )
            if not match:
                raise ValueError(f"Invalid docker image name '{image_name}'.")

            repo_name = match.group(1)
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                f"push an image to it. ZenML is trying to push the image "
                f"{image_name} but could only detect the following "
                f"repositories: {repo_uris}. We will try to push anyway, but "
                f"in case it fails you need to create a repository named "
                f"`{repo_name}`."
            )

    @property
    def post_registration_message(self) -> Optional[str]:
        """Optional message printed after the stack component is registered.

        Returns:
            Info message regarding docker repositories in AWS.
        """
        return (
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. If you want to for example run a pipeline "
            "using our Kubeflow orchestrator, ZenML will automatically build a "
            f"docker image called `{self.config.uri}/zenml-kubeflow:<PIPELINE_NAME>` "
            f"and try to push it. This will fail unless you create the "
            f"repository `zenml-kubeflow` inside your amazon registry."
        )
config: AWSContainerRegistryConfig property readonly

Returns the AWSContainerRegistryConfig config.

Returns:

Type Description
AWSContainerRegistryConfig

The configuration.

post_registration_message: Optional[str] property readonly

Optional message printed after the stack component is registered.

Returns:

Type Description
Optional[str]

Info message regarding docker repositories in AWS.

prepare_image_push(self, image_name)

Logs a warning if an image is about to be pushed to an ECR repository that does not exist.

Parameters:

Name Type Description Default
image_name str

Name of the docker image that will be pushed.

required

Exceptions:

Type Description
ValueError

If the docker image name is invalid.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs a warning if no ECR repository exists for the image.

    Amazon ECR requires a repository to exist before an image can be pushed
    to it. This method lists all repositories of the registry region and
    warns if none of them matches `image_name`. If listing the repositories
    fails with a `ClientError` or returns an unexpected response, this is
    logged at debug level and the push is attempted anyway.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    try:
        # The AWS call must be inside the `try`: this is where a
        # `ClientError` (e.g. missing permissions) is actually raised.
        client = boto3.client("ecr", region_name=self._get_region())
        # `describe_repositories` returns results in pages, so iterate over
        # all of them to avoid missing existing repositories.
        paginator = client.get_paginator("describe_repositories")
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for page in paginator.paginate()
            for repository in page["repositories"]
        ]
    except (KeyError, ClientError) as e:
        # invalid boto response, let's hope for the best and just push
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
    if not repo_exists:
        # Escape the registry URI so its dots are matched literally.
        match = re.search(f"{re.escape(self.config.uri)}/(.*):.*", image_name)
        if not match:
            raise ValueError(f"Invalid docker image name '{image_name}'.")

        repo_name = match.group(1)
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            f"in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )

flavors special

AWS integration flavors.

aws_container_registry_flavor

AWS container registry flavor.

AWSContainerRegistryConfig (BaseContainerRegistryConfig) pydantic-model

Configuration for AWS Container Registry.

Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryConfig(BaseContainerRegistryConfig):
    """Configuration of an AWS (ECR) container registry."""

    @validator("uri")
    def validate_aws_uri(cls, uri: str) -> str:
        """Reject registry URIs that contain a path component.

        Only the registry host is accepted, e.g.
        `715803424592.dkr.ecr.us-east-1.amazonaws.com`.

        Args:
            uri: The registry URI to check.

        Returns:
            The unchanged URI when it is valid.

        Raises:
            ValueError: If `uri` contains a `/` character.
        """
        if "/" not in uri:
            return uri

        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )
validate_aws_uri(uri) classmethod

Validates that the URI is in the correct format.

Parameters:

Name Type Description Default
uri str

URI to validate.

required

Returns:

Type Description
str

URI in the correct format.

Exceptions:

Type Description
ValueError

If the URI contains a slash character.

Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
@validator("uri")
def validate_aws_uri(cls, uri: str) -> str:
    """Reject registry URIs that contain a path component.

    Only the registry host is accepted, e.g.
    `715803424592.dkr.ecr.us-east-1.amazonaws.com`.

    Args:
        uri: The registry URI to check.

    Returns:
        The unchanged URI when it is valid.

    Raises:
        ValueError: If `uri` contains a `/` character.
    """
    if "/" not in uri:
        return uri

    raise ValueError(
        "Property `uri` can not contain a `/`. An example of a valid "
        "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
    )
AWSContainerRegistryFlavor (BaseContainerRegistryFlavor)

AWS Container Registry flavor.

Source code in zenml/integrations/aws/flavors/aws_container_registry_flavor.py
class AWSContainerRegistryFlavor(BaseContainerRegistryFlavor):
    """Flavor that exposes the AWS container registry to ZenML."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The AWS container registry flavor name.
        """
        flavor_name: str = AWS_CONTAINER_REGISTRY_FLAVOR
        return flavor_name

    @property
    def config_class(self) -> Type[AWSContainerRegistryConfig]:
        """Configuration model used by this flavor.

        Returns:
            The `AWSContainerRegistryConfig` class.
        """
        config_cls = AWSContainerRegistryConfig
        return config_cls

    @property
    def implementation_class(self) -> Type["AWSContainerRegistry"]:
        """Stack component class implementing this flavor.

        The implementation module is imported on first access rather than at
        module level.

        Returns:
            The `AWSContainerRegistry` class.
        """
        import zenml.integrations.aws.container_registries as registries

        return registries.AWSContainerRegistry
config_class: Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig] property readonly

Config class for this flavor.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.aws_container_registry_flavor.AWSContainerRegistryConfig]

The config class.

implementation_class: Type[AWSContainerRegistry] property readonly

Implementation class.

Returns:

Type Description
Type[AWSContainerRegistry]

The implementation class.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

aws_secrets_manager_flavor

AWS secrets manager flavor.

AWSSecretsManagerConfig (BaseSecretsManagerConfig) pydantic-model

Configuration for the AWS Secrets Manager.

Attributes:

Name Type Description
region_name str

The region name of the AWS Secrets Manager.

Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerConfig(BaseSecretsManagerConfig):
    """Settings of a secrets manager backed by AWS Secrets Manager.

    Attributes:
        region_name: The AWS region in which the secrets manager operates.
    """

    SUPPORTS_SCOPING: ClassVar[bool] = True

    region_name: str

    @classmethod
    def _validate_scope(
        cls,
        scope: "SecretsManagerScope",
        namespace: Optional[str],
    ) -> None:
        """Check that the namespace is usable inside an AWS secret name.

        Only the namespace is inspected; any scope value is accepted.

        Args:
            scope: The configured scope (not inspected here).
            namespace: The configured namespace, if any.

        Raises:
            ValueError: If the namespace contains characters that are not
                allowed in AWS secret names.
        """
        # Nothing to check when no (or an empty) namespace is configured.
        if not namespace:
            return
        validate_aws_secret_name_or_namespace(namespace)
AWSSecretsManagerFlavor (BaseSecretsManagerFlavor)

Class for the AWSSecretsManagerFlavor.

Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
class AWSSecretsManagerFlavor(BaseSecretsManagerFlavor):
    """Flavor that exposes the AWS secrets manager to ZenML."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The AWS secrets manager flavor name.
        """
        flavor_name: str = AWS_SECRET_MANAGER_FLAVOR
        return flavor_name

    @property
    def config_class(self) -> Type[AWSSecretsManagerConfig]:
        """Configuration model used by this flavor.

        Returns:
            The `AWSSecretsManagerConfig` class.
        """
        config_cls = AWSSecretsManagerConfig
        return config_cls

    @property
    def implementation_class(self) -> Type["AWSSecretsManager"]:
        """Stack component class implementing this flavor.

        The implementation module is imported on first access rather than at
        module level.

        Returns:
            The `AWSSecretsManager` class.
        """
        import zenml.integrations.aws.secrets_managers as secrets_managers

        return secrets_managers.AWSSecretsManager
config_class: Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig] property readonly

Config class for this flavor.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.aws_secrets_manager_flavor.AWSSecretsManagerConfig]

Config class for this flavor.

implementation_class: Type[AWSSecretsManager] property readonly

Implementation class.

Returns:

Type Description
Type[AWSSecretsManager]

Implementation class.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

Name of the flavor.

validate_aws_secret_name_or_namespace(name)

Validate a secret name or namespace.

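AWS secret names may only contain alphanumeric characters and the characters /_+=.@-. Because ZenML uses the / character internally to delimit scopes, it is not allowed in secret names or namespaces, so only alphanumeric characters and _+=.@- are accepted.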

Parameters:

Name Type Description Default
name str

the secret name or namespace

required

Exceptions:

Type Description
ValueError

if the secret name or namespace is invalid

Source code in zenml/integrations/aws/flavors/aws_secrets_manager_flavor.py
def validate_aws_secret_name_or_namespace(name: str) -> None:
    """Validate a secret name or namespace.

    AWS secret names must contain only alphanumeric characters and the
    characters /_+=.@-. The `/` character is only used internally to delimit
    scopes, so it is rejected here. The empty string is rejected as well:
    AWS does not accept empty secret names, and an empty name would act as
    "no filter" when looking up secrets by name.

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the secret name or namespace is empty or contains
            characters other than alphanumerics and `_+=.@-`.
    """
    # `+` instead of `*` so that the empty string does not match.
    if not re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]+", name):
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. Must be non-empty "
            "and contain only alphanumeric characters and the characters "
            "_+=.@-."
        )

sagemaker_step_operator_flavor

Amazon SageMaker step operator flavor.

SagemakerStepOperatorConfig (BaseStepOperatorConfig, SagemakerStepOperatorSettings) pydantic-model

Config for the Sagemaker step operator.

Attributes:

Name Type Description
role str

The role that has to be assigned to the jobs which are running in Sagemaker.

bucket Optional[str]

Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".

Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseStepOperatorConfig, SagemakerStepOperatorSettings
):
    """Configuration of the SageMaker step operator.

    The fields of `SagemakerStepOperatorSettings` are inherited, so they can
    be set on the config as well.

    Attributes:
        role: The role assigned to the jobs that run in SageMaker.
        bucket: Optional name of the S3 bucket used to store artifacts of the
            job run. When omitted, a default bucket named according to the
            format "sagemaker-{region}-{aws-account-id}" is created.
    """

    role: str
    bucket: Optional[str] = None

    @property
    def is_remote(self) -> bool:
        """Report whether this stack component runs remotely.

        ZenML uses this flag to decide whether the component can be combined
        with a local ZenML database or needs a remote ZenML server.

        Returns:
            `True` unconditionally: this component is always treated as
            remote.
        """
        runs_remotely = True
        return runs_remotely
is_remote: bool property readonly

Checks if this stack component is running remotely.

This designation is used to determine if the stack component can be used with a local ZenML database or if it requires a remote ZenML server.

Returns:

Type Description
bool

True if this config is for a remote component, False otherwise.

SagemakerStepOperatorFlavor (BaseStepOperatorFlavor)

Flavor for the Sagemaker step operator.

Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorFlavor(BaseStepOperatorFlavor):
    """Flavor that exposes the SageMaker step operator to ZenML."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The SageMaker step operator flavor name.
        """
        flavor_name: str = AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR
        return flavor_name

    @property
    def config_class(self) -> Type[SagemakerStepOperatorConfig]:
        """Configuration model used by this flavor.

        Returns:
            The `SagemakerStepOperatorConfig` class.
        """
        config_cls = SagemakerStepOperatorConfig
        return config_cls

    @property
    def implementation_class(self) -> Type["SagemakerStepOperator"]:
        """Stack component class implementing this flavor.

        The implementation module is imported on first access rather than at
        module level.

        Returns:
            The `SagemakerStepOperator` class.
        """
        import zenml.integrations.aws.step_operators as step_operators

        return step_operators.SagemakerStepOperator
config_class: Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig] property readonly

Returns SagemakerStepOperatorConfig config class.

Returns:

Type Description
Type[zenml.integrations.aws.flavors.sagemaker_step_operator_flavor.SagemakerStepOperatorConfig]

The config class.

implementation_class: Type[SagemakerStepOperator] property readonly

Implementation class.

Returns:

Type Description
Type[SagemakerStepOperator]

The implementation class.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

SagemakerStepOperatorSettings (BaseSettings) pydantic-model

Settings for the Sagemaker step operator.

Attributes:

Name Type Description
instance_type Optional[str]

The type of the compute instance where jobs will run. Check https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html for a list of available instance types.

experiment_name Optional[str]

The name of the experiment with which the job will be associated. If not provided, the job runs are independent (not associated with any experiment).

Source code in zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py
class SagemakerStepOperatorSettings(BaseSettings):
    """Per-step settings of the SageMaker step operator.

    Both fields are optional and default to `None`.

    Attributes:
        instance_type: Compute instance type on which the jobs run. See
            https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
            for the available instance types.
        experiment_name: Name of the experiment the job is associated with.
            When not set, the job runs are independent of each other.
    """

    instance_type: Optional[str] = None
    experiment_name: Optional[str] = None

secrets_managers special

AWS Secrets Manager.

aws_secrets_manager

Implementation of the AWS Secrets Manager integration.

AWSSecretsManager (BaseSecretsManager)

Class to interact with the AWS secrets manager.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
class AWSSecretsManager(BaseSecretsManager):
    """Class to interact with the AWS secrets manager.

    Secret contents are stored as JSON strings (`SecretString`). Secret names
    are passed through `_get_scoped_secret_name` and the secrets are tagged
    with the metadata from `_get_secret_metadata`, so that they can be
    filtered by scope when listing.
    """

    # boto3 Secrets Manager client shared by all instances of this class,
    # created lazily by `_ensure_client_connected`.
    CLIENT: ClassVar[Any] = None
    # Region for which `_ensure_client_connected` created `CLIENT`; `None`
    # when the client was not created by that method.
    _CLIENT_REGION: ClassVar[Optional[str]] = None

    @property
    def config(self) -> AWSSecretsManagerConfig:
        """Returns the `AWSSecretsManagerConfig` config.

        Returns:
            The configuration.
        """
        return cast(AWSSecretsManagerConfig, self._config)

    @classmethod
    def _ensure_client_connected(cls, region_name: str) -> None:
        """Ensure that the client is connected to the AWS secrets manager.

        The client is cached on the class. It is created if none exists yet,
        and re-created if the cached client was created by this method for a
        different region than `region_name`.

        Args:
            region_name: the AWS region name
        """
        if cls.CLIENT is not None and cls._CLIENT_REGION in (
            None,
            region_name,
        ):
            # Reuse the cached client. A client whose region is unknown
            # (e.g. one assigned to `CLIENT` directly) is kept as is.
            return

        # Create a Secrets Manager client for the requested region
        session = boto3.session.Session()
        cls.CLIENT = session.client(
            service_name="secretsmanager", region_name=region_name
        )
        cls._CLIENT_REGION = region_name

    def _get_secret_tags(
        self, secret: BaseSecretSchema
    ) -> List[Dict[str, str]]:
        """Return a list of AWS secret tag values for a given secret.

        Args:
            secret: the secret object

        Returns:
            A list of AWS secret tag values
        """
        metadata = self._get_secret_metadata(secret)
        return [{"Key": k, "Value": v} for k, v in metadata.items()]

    def _get_secret_scope_filters(
        self,
        secret_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return a list of AWS filters for the entire scope or just a scoped secret.

        These filters can be used when querying the AWS Secrets Manager
        for all secrets or for a single secret available in the configured
        scope. For more information see: https://docs.aws.amazon.com/secretsmanager/latest/userguide/manage_search-secret.html

        Example AWS filters for all secrets in the current (namespace) scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Example AWS filters for a particular secret in the current (namespace)
        scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_secret_name"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_secret"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of AWS filters uniquely identifying all secrets
            or a named secret within the configured scope.
        """
        metadata = self._get_secret_scope_metadata(secret_name)
        filters: List[Dict[str, Any]] = []
        for k, v in metadata.items():
            filters.append({"Key": "tag-key", "Values": [k]})
            filters.append({"Key": "tag-value", "Values": [str(v)]})

        return filters

    def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
        """List all secrets matching a name.

        This method lists all the secrets in the current scope without loading
        their contents. An optional secret name can be supplied to filter out
        all but a single secret identified by name.

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of secret names in the current scope and the optional
            secret name.
        """
        self._ensure_client_connected(self.config.region_name)

        filters: List[Dict[str, Any]] = []
        prefix: Optional[str] = None
        if self.config.scope == SecretsManagerScope.NONE:
            # unscoped (legacy) secrets don't have tags. We want to filter out
            # non-legacy secrets
            filters = [
                {
                    "Key": "tag-key",
                    "Values": [
                        "!zenml_scope",
                    ],
                },
            ]
            if secret_name:
                prefix = secret_name
        else:
            filters = self._get_secret_scope_filters()
            if secret_name:
                prefix = self._get_scoped_secret_name(secret_name)
            else:
                # add the name prefix to the filters to account for the fact
                # that AWS does not do exact matching but prefix-matching on the
                # filters
                prefix = self._get_scoped_secret_name_prefix()

        if prefix:
            filters.append({"Key": "name", "Values": [prefix]})

        paginator = self.CLIENT.get_paginator(_BOTO_CLIENT_LIST_SECRETS)
        pages = paginator.paginate(
            Filters=filters,
            PaginationConfig={
                "PageSize": 100,
            },
        )
        results = []
        for page in pages:
            for secret in page[_PAGINATOR_RESPONSE_SECRETS_LIST_KEY]:
                name = self._get_unscoped_secret_name(secret["Name"])
                # keep only the names that are in scope and filter by secret name,
                # if one was given (the AWS name filter only matches prefixes)
                if name and (not secret_name or secret_name == name):
                    results.append(name)

        return results

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        validate_aws_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.config.region_name)

        if self._list_secrets(secret.name):
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists"
            )

        secret_value = json.dumps(secret_to_dict(secret, encode=False))
        kwargs: Dict[str, Any] = {
            "Name": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
            "Tags": self._get_secret_tags(secret),
        }

        self.CLIENT.create_secret(**kwargs)

        logger.debug("Created AWS secret: %s", kwargs["Name"])

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets a secret.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            KeyError: if the secret does not exist
            ValueError: if the secret has no string value and therefore
                cannot be decoded as a ZenML secret
        """
        validate_aws_secret_name_or_namespace(secret_name)
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        get_secret_value_response = self.CLIENT.get_secret_value(
            SecretId=self._get_scoped_secret_name(secret_name)
        )
        if "SecretString" not in get_secret_value_response:
            # Secrets registered by this class are stored as a JSON
            # `SecretString`; without one (e.g. a secret stored as
            # `SecretBinary` outside of ZenML) there is nothing to decode.
            raise ValueError(
                f"The AWS secret '{secret_name}' has no string value and "
                "cannot be loaded as a ZenML secret."
            )

        return secret_from_dict(
            json.loads(get_secret_value_response["SecretString"]),
            secret_name=secret_name,
            decode=False,
        )

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        return self._list_secrets()

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret.

        Args:
            secret: the secret to update

        Raises:
            KeyError: if the secret does not exist
        """
        validate_aws_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret.name):
            raise KeyError(f"Can't find the specified secret '{secret.name}'")

        # Store the values unencoded, like `register_secret` does and as
        # `get_secret` expects when reading them back with `decode=False`.
        secret_value = json.dumps(secret_to_dict(secret, encode=False))

        kwargs: Dict[str, Any] = {
            "SecretId": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
        }

        self.CLIENT.put_secret_value(**kwargs)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        The secret is deleted without a recovery window.

        Args:
            secret_name: the name of the secret to delete

        Raises:
            KeyError: if the secret does not exist
        """
        self._ensure_client_connected(self.config.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        self.CLIENT.delete_secret(
            SecretId=self._get_scoped_secret_name(secret_name),
            ForceDeleteWithoutRecovery=True,
        )

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets.

        This method will force delete all secrets in the configured scope.
        You will not be able to recover them once this method is called.
        """
        self._ensure_client_connected(self.config.region_name)
        for secret_name in self._list_secrets():
            self.CLIENT.delete_secret(
                SecretId=self._get_scoped_secret_name(secret_name),
                ForceDeleteWithoutRecovery=True,
            )
config: AWSSecretsManagerConfig property readonly

Returns the AWSSecretsManagerConfig config.

Returns:

Type Description
AWSSecretsManagerConfig

The configuration.

delete_all_secrets(self)

Delete all existing secrets.

This method force-deletes all secrets in the configured scope, without a recovery window. You will not be able to recover them once this method is called.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Force-delete every secret in the configured scope.

    Every name returned by `_list_secrets` is scoped and deleted with
    `ForceDeleteWithoutRecovery=True`, so the secrets cannot be restored
    afterwards.
    """
    self._ensure_client_connected(self.config.region_name)
    scoped_ids = map(self._get_scoped_secret_name, self._list_secrets())
    for secret_id in scoped_ids:
        self.CLIENT.delete_secret(
            SecretId=secret_id,
            ForceDeleteWithoutRecovery=True,
        )
delete_secret(self, secret_name)

Delete an existing secret.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to delete

required

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Force-delete a single secret from the configured scope.

    Args:
        secret_name: Name of the secret to remove; it is passed through
            `_get_scoped_secret_name` before calling AWS.

    Raises:
        KeyError: If no secret named `secret_name` exists in the scope.
    """
    self._ensure_client_connected(self.config.region_name)

    secret_found = bool(self._list_secrets(secret_name))
    if secret_found:
        # No recovery window: the secret is gone immediately.
        self.CLIENT.delete_secret(
            SecretId=self._get_scoped_secret_name(secret_name),
            ForceDeleteWithoutRecovery=True,
        )
        return

    raise KeyError(f"Can't find the specified secret '{secret_name}'")
get_all_secret_keys(self)

Get all secret keys.

Returns:

Type Description
List[str]

A list of all secret keys

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """List the names of all secrets visible to this secrets manager.

    Returns:
        A list of all secret keys
    """
    secret_keys = self._list_secrets()
    return secret_keys
get_secret(self, secret_name)

Gets a secret.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to get

required

Returns:

Type Description
BaseSecretSchema

The secret.

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Gets a secret.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist
        ValueError: if the stored secret has no ``SecretString`` value (for
            example, a secret stored only as ``SecretBinary``), which cannot
            be deserialized into a secret schema.
    """
    validate_aws_secret_name_or_namespace(secret_name)
    self._ensure_client_connected(self.config.region_name)

    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    get_secret_value_response = self.CLIENT.get_secret_value(
        SecretId=self._get_scoped_secret_name(secret_name)
    )
    # GetSecretValue returns the payload either as ``SecretString`` or as
    # ``SecretBinary``. Only the JSON string form can be parsed below, so
    # fail with a clear message instead of subscripting a missing value.
    if "SecretString" not in get_secret_value_response:
        raise ValueError(
            f"The secret '{secret_name}' has no string value "
            "('SecretString') and cannot be loaded by this secrets manager."
        )

    return secret_from_dict(
        json.loads(get_secret_value_response["SecretString"]),
        secret_name=secret_name,
        decode=False,
    )
register_secret(self, secret)

Registers a new secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to register

required

Exceptions:

Type Description
SecretExistsError

if the secret already exists

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Create a new secret in AWS Secrets Manager.

    The secret's contents are serialized to JSON without encoding and stored
    as the AWS ``SecretString``, tagged with the tags returned by
    ``self._get_secret_tags``.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    already_registered = self._list_secrets(secret.name)
    if already_registered:
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    payload = json.dumps(secret_to_dict(secret, encode=False))
    create_args: Dict[str, Any] = dict(
        Name=self._get_scoped_secret_name(secret.name),
        SecretString=payload,
        Tags=self._get_secret_tags(secret),
    )

    self.CLIENT.create_secret(**create_args)

    logger.debug("Created AWS secret: %s", create_args["Name"])
update_secret(self, secret)

Update an existing secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to update

required

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret.

    The new contents replace the stored ``SecretString`` of the secret.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    validate_aws_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.config.region_name)

    if not self._list_secrets(secret.name):
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    # Serialize exactly like ``register_secret`` (``encode=False``), because
    # ``get_secret`` reads values back with ``decode=False``. Using the
    # default encoding here would make an updated secret read back with
    # encoded values.
    secret_value = json.dumps(secret_to_dict(secret, encode=False))

    kwargs: Dict[str, Any] = {
        "SecretId": self._get_scoped_secret_name(secret.name),
        "SecretString": secret_value,
    }

    self.CLIENT.put_secret_value(**kwargs)

step_operators special

Initialization of the Sagemaker Step Operator.

sagemaker_step_operator

Implementation of the Sagemaker Step Operator.

SagemakerStepOperator (BaseStepOperator)

Step operator to run a step on Sagemaker.

This class builds a Docker image containing the ZenML entrypoint and runs it using SageMaker's Estimator.

Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator):
    """Step operator that executes individual steps on SageMaker.

    A Docker image with the ZenML entrypoint is built while the pipeline
    deployment is prepared. Each step assigned to this operator is then run
    by a SageMaker ``Estimator`` using that image.
    """

    @property
    def config(self) -> SagemakerStepOperatorConfig:
        """Returns the `SagemakerStepOperatorConfig` config.

        Returns:
            The configuration.
        """
        typed_config = cast(SagemakerStepOperatorConfig, self._config)
        return typed_config

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the SageMaker step operator.

        Returns:
            The settings class.
        """
        settings_cls = SagemakerStepOperatorSettings
        return settings_cls

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates the stack.

        Returns:
            A validator that requires a container registry and checks that
            both the container registry and the artifact store are remote.
        """

        def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]:
            artifact_store = stack.artifact_store
            if artifact_store.config.is_local:
                return False, (
                    "The SageMaker step operator runs code remotely and "
                    "needs to write files into the artifact store, but the "
                    f"artifact store `{artifact_store.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote artifact store when using the SageMaker "
                    "step operator."
                )

            registry = stack.container_registry
            # Guaranteed by `required_components` of the validator below.
            assert registry is not None

            if registry.config.is_local:
                return False, (
                    "The SageMaker step operator runs code remotely and "
                    "needs to push/pull Docker images, but the "
                    f"container registry `{registry.name}` of the "
                    "active stack is local. Please ensure that your stack "
                    "contains a remote container registry when using the "
                    "SageMaker step operator."
                )

            return True, ""

        return StackValidator(
            custom_validation_function=_validate_remote_components,
            required_components={StackComponentType.CONTAINER_REGISTRY},
        )

    def prepare_pipeline_deployment(
        self,
        deployment: "PipelineDeployment",
        stack: "Stack",
    ) -> None:
        """Build a Docker image and push it to the container registry.

        The image is built only if at least one step of the deployment is
        assigned to this step operator. The resulting image digest is stored
        in each such step's `extra` config.

        Args:
            deployment: The pipeline deployment configuration.
            stack: The stack on which the pipeline will be deployed.
        """
        matching_steps = [
            candidate
            for candidate in deployment.steps.values()
            if candidate.config.step_operator == self.name
        ]
        if not matching_steps:
            return

        builder = PipelineDockerImageBuilder()
        digest = builder.build_and_push_docker_image(
            deployment=deployment,
            stack=stack,
            # Literal "$" + variable name: the container expands the env var
            # that `launch` sets to the actual step command.
            entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
        )
        for candidate in matching_steps:
            candidate.config.extra[SAGEMAKER_DOCKER_IMAGE_KEY] = digest

    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
    ) -> None:
        """Launches a step on SageMaker.

        Custom step resource settings are ignored, with a warning. The
        instance type comes from the step operator settings and defaults to
        ``ml.m5.large``. The call blocks until the SageMaker job finishes.

        Args:
            info: Information about the step run.
            entrypoint_command: Command that executes the step.
        """
        if not info.config.resource_settings.empty:
            logger.warning(
                "Specifying custom step resources is not supported for "
                "the SageMaker step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different instance type like this: "
                "`zenml step-operator update %s "
                "--instance_type=<INSTANCE_TYPE>`",
                self.name,
            )

        docker_image = info.config.extra[SAGEMAKER_DOCKER_IMAGE_KEY]
        env_vars = {_ENTRYPOINT_ENV_VARIABLE: " ".join(entrypoint_command)}

        step_settings = cast(
            SagemakerStepOperatorSettings, self.get_settings(info)
        )

        sagemaker_session = sagemaker.Session(default_bucket=self.config.bucket)
        chosen_instance = step_settings.instance_type or "ml.m5.large"
        estimator = sagemaker.estimator.Estimator(
            docker_image,
            self.config.role,
            environment=env_vars,
            instance_count=1,
            instance_type=chosen_instance,
            sagemaker_session=sagemaker_session,
        )

        # SageMaker rejects underscores in job/experiment/trial names.
        job_name = info.run_name.replace("_", "-")

        experiment_config = (
            {
                "ExperimentName": step_settings.experiment_name,
                "TrialName": job_name,
            }
            if step_settings.experiment_name
            else {}
        )

        estimator.fit(
            wait=True,
            experiment_config=experiment_config,
            job_name=job_name,
        )
config: SagemakerStepOperatorConfig property readonly

Returns the SagemakerStepOperatorConfig config.

Returns:

Type Description
SagemakerStepOperatorConfig

The configuration.

settings_class: Optional[Type[BaseSettings]] property readonly

Settings class for the SageMaker step operator.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates the stack.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A validator that checks that the stack contains a remote container registry and a remote artifact store.

launch(self, info, entrypoint_command)

Launches a step on SageMaker.

Parameters:

Name Type Description Default
info StepRunInfo

Information about the step run.

required
entrypoint_command List[str]

Command that executes the step.

required
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
) -> None:
    """Launches a step on SageMaker.

    Custom step resource settings are ignored, with a warning. The instance
    type comes from the step operator settings and defaults to
    ``ml.m5.large``. The call blocks until the SageMaker job finishes.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step.
    """
    if not info.config.resource_settings.empty:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the SageMaker step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different instance type like this: "
            "`zenml step-operator update %s "
            "--instance_type=<INSTANCE_TYPE>`",
            self.name,
        )

    docker_image = info.config.extra[SAGEMAKER_DOCKER_IMAGE_KEY]
    env_vars = {_ENTRYPOINT_ENV_VARIABLE: " ".join(entrypoint_command)}

    step_settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))

    sagemaker_session = sagemaker.Session(default_bucket=self.config.bucket)
    chosen_instance = step_settings.instance_type or "ml.m5.large"
    estimator = sagemaker.estimator.Estimator(
        docker_image,
        self.config.role,
        environment=env_vars,
        instance_count=1,
        instance_type=chosen_instance,
        sagemaker_session=sagemaker_session,
    )

    # SageMaker rejects underscores in job/experiment/trial names.
    job_name = info.run_name.replace("_", "-")

    experiment_config = (
        {
            "ExperimentName": step_settings.experiment_name,
            "TrialName": job_name,
        }
        if step_settings.experiment_name
        else {}
    )

    estimator.fit(
        wait=True,
        experiment_config=experiment_config,
        job_name=job_name,
    )
prepare_pipeline_deployment(self, deployment, stack)

Build a Docker image and push it to the container registry.

Parameters:

Name Type Description Default
deployment PipelineDeployment

The pipeline deployment configuration.

required
stack Stack

The stack on which the pipeline will be deployed.

required
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def prepare_pipeline_deployment(
    self,
    deployment: "PipelineDeployment",
    stack: "Stack",
) -> None:
    """Build a Docker image and push it to the container registry.

    Nothing is built unless at least one step of the deployment is assigned
    to this step operator. The resulting image digest is stored in each such
    step's `extra` config.

    Args:
        deployment: The pipeline deployment configuration.
        stack: The stack on which the pipeline will be deployed.
    """
    matching_steps = [
        candidate
        for candidate in deployment.steps.values()
        if candidate.config.step_operator == self.name
    ]
    if not matching_steps:
        return

    builder = PipelineDockerImageBuilder()
    digest = builder.build_and_push_docker_image(
        deployment=deployment,
        stack=stack,
        # Literal "$" + variable name: the container expands the env var that
        # `launch` sets to the actual step command.
        entrypoint=f"${_ENTRYPOINT_ENV_VARIABLE}",
    )
    for candidate in matching_steps:
        candidate.config.extra[SAGEMAKER_DOCKER_IMAGE_KEY] = digest