Integrations

zenml.integrations special

ZenML integrations module.

The ZenML integrations module contains sub-modules for each integration that we support. This includes orchestrators like Apache Airflow, visualization tools like the Facets library, and deep learning libraries like PyTorch.

airflow special

Airflow integration for ZenML.

The Airflow integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Airflow orchestrator with the CLI tool, then bootstrapping the required resources with the zenml stack up command.

AirflowIntegration (Integration)

Definition of Airflow Integration for ZenML.

Source code in zenml/integrations/airflow/__init__.py
class AirflowIntegration(Integration):
    """Definition of Airflow Integration for ZenML."""

    NAME = AIRFLOW
    REQUIREMENTS = ["apache-airflow==2.2.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Airflow integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=AIRFLOW_ORCHESTRATOR_FLAVOR,
                source="zenml.integrations.airflow.orchestrators.AirflowOrchestrator",
                type=StackComponentType.ORCHESTRATOR,
                integration=cls.NAME,
            )
        ]
flavors() classmethod

Declare the stack component flavors for the Airflow integration.

Returns:

    List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]: List of stack component flavors for this integration.

Source code in zenml/integrations/airflow/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Airflow integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=AIRFLOW_ORCHESTRATOR_FLAVOR,
            source="zenml.integrations.airflow.orchestrators.AirflowOrchestrator",
            type=StackComponentType.ORCHESTRATOR,
            integration=cls.NAME,
        )
    ]
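
As a quick illustration, here is a minimal sketch (assuming the airflow integration's requirements are installed) that inspects the flavors this integration declares:

```python
from zenml.integrations.airflow import AirflowIntegration

# Each FlavorWrapper records the flavor name, the stack component type it
# implements, and the source path of the implementing class.
for flavor in AirflowIntegration.flavors():
    print(flavor.name, flavor.type, flavor.source)
```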

orchestrators special

The Airflow integration enables the use of Airflow as a pipeline orchestrator.

airflow_orchestrator

Implementation of Airflow orchestrator integration.

AirflowOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator responsible for running pipelines using Airflow.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
class AirflowOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines using Airflow."""

    airflow_home: str = ""

    # Class Configuration
    FLAVOR: ClassVar[str] = AIRFLOW_ORCHESTRATOR_FLAVOR

    def __init__(self, **values: Any):
        """Sets environment variables to configure airflow.

        Args:
            **values: Values to set in the orchestrator.
        """
        super().__init__(**values)
        self._set_env()

    @staticmethod
    def _translate_schedule(
        schedule: Optional[Schedule] = None,
    ) -> Dict[str, Any]:
        """Convert ZenML schedule into Airflow schedule.

        The Airflow schedule uses slightly different naming and needs some
        default entries for execution without a schedule.

        Args:
            schedule: Containing the interval, start and end date and
                a boolean flag that defines if past runs should be caught up
                on

        Returns:
            Airflow configuration dict.
        """
        if schedule:
            if schedule.cron_expression:
                return {
                    "schedule_interval": schedule.cron_expression,
                }
            else:
                return {
                    "schedule_interval": schedule.interval_second,
                    "start_date": schedule.start_time,
                    "end_date": schedule.end_time,
                    "catchup": schedule.catchup,
                }

        return {
            "schedule_interval": "@once",
            # set a start time in the past and disable catchup so airflow runs the dag immediately
            "start_date": datetime.datetime.now() - datetime.timedelta(7),
            "catchup": False,
        }

    def prepare_or_run_pipeline(
        self,
        sorted_steps: List[BaseStep],
        pipeline: "BasePipeline",
        pb2_pipeline: Pb2Pipeline,
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Creates an Airflow DAG as the intermediate representation for the pipeline.

        This DAG will be loaded by airflow in the target environment
        and used for orchestration of the pipeline.

        How it works:
        -------------
        A new airflow_dag is instantiated with the pipeline name and, among
        other things, the run schedule.

        For each step of the pipeline a callable is created. This callable
        uses the run_step() method to execute the step. The parameters of
        this callable are pre-filled and an airflow step_operator is created
        within the dag. The dependencies to upstream steps are then
        configured.

        Finally, the dag is fully complete and can be returned.

        Args:
            sorted_steps: List of steps in the pipeline.
            pipeline: The pipeline to be executed.
            pb2_pipeline: The pipeline as a protobuf message.
            stack: The stack on which the pipeline will be deployed.
            runtime_configuration: The runtime configuration.

        Returns:
            The Airflow DAG.
        """
        import airflow
        from airflow.operators import python as airflow_python

        # Instantiate and configure airflow Dag with name and schedule
        airflow_dag = airflow.DAG(
            dag_id=pipeline.name,
            is_paused_upon_creation=False,
            **self._translate_schedule(runtime_configuration.schedule),
        )

        # Dictionary mapping step names to airflow_operators. This will be needed
        # to configure airflow operator dependencies
        step_name_to_airflow_operator = {}

        for step in sorted_steps:
            # Create callable that will be used by airflow to execute the step
            # within the orchestrated environment
            def _step_callable(step_instance: "BaseStep", **kwargs):
                # Extract run name for the kwargs that will be passed to the
                # callable
                run_name = kwargs["ti"].get_dagrun().run_id
                self.run_step(
                    step=step_instance,
                    run_name=run_name,
                    pb2_pipeline=pb2_pipeline,
                )

            # Create airflow python operator that contains the step callable
            airflow_operator = airflow_python.PythonOperator(
                dag=airflow_dag,
                task_id=step.name,
                provide_context=True,
                python_callable=functools.partial(
                    _step_callable, step_instance=step
                ),
            )

            # Configure the current airflow operator to run after all upstream
            # operators finished executing
            step_name_to_airflow_operator[step.name] = airflow_operator
            upstream_step_names = self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            )
            for upstream_step_name in upstream_step_names:
                airflow_operator.set_upstream(
                    step_name_to_airflow_operator[upstream_step_name]
                )

        # Return the finished airflow dag
        return airflow_dag

    @root_validator(skip_on_failure=True)
    def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Sets Airflow home according to orchestrator UUID.

        Args:
            values: Dictionary containing all orchestrator attribute values.

        Returns:
            Dictionary containing all orchestrator attribute values and the airflow home.

        Raises:
            ValueError: If the orchestrator UUID is not set.
        """
        if "uuid" not in values:
            raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")
        values["airflow_home"] = os.path.join(
            io_utils.get_global_config_directory(),
            AIRFLOW_ROOT_DIR,
            str(values["uuid"]),
        )
        return values

    @property
    def dags_directory(self) -> str:
        """Returns path to the airflow dags directory.

        Returns:
            Path to the airflow dags directory.
        """
        return os.path.join(self.airflow_home, "dags")

    @property
    def pid_file(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(self.airflow_home, "airflow_daemon.pid")

    @property
    def log_file(self) -> str:
        """Returns path to the airflow log file.

        Returns:
            str: Path to the airflow log file.
        """
        return os.path.join(self.airflow_home, "airflow_orchestrator.log")

    @property
    def password_file(self) -> str:
        """Returns path to the webserver password file.

        Returns:
            Path to the webserver password file.
        """
        return os.path.join(self.airflow_home, "standalone_admin_password.txt")

    def _set_env(self) -> None:
        """Sets environment variables to configure airflow."""
        os.environ["AIRFLOW_HOME"] = self.airflow_home
        os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = self.dags_directory
        os.environ["AIRFLOW__CORE__DAG_DISCOVERY_SAFE_MODE"] = "false"
        os.environ["AIRFLOW__CORE__LOAD_EXAMPLES"] = "false"
        # check the DAG folder every 10 seconds for new files
        os.environ["AIRFLOW__SCHEDULER__DAG_DIR_LIST_INTERVAL"] = "10"

    def _copy_to_dag_directory_if_necessary(self, dag_filepath: str) -> None:
        """Copies DAG module to the Airflow DAGs directory if not already present.

        Args:
            dag_filepath: Path to the file in which the DAG is defined.
        """
        dags_directory = io_utils.resolve_relative_path(self.dags_directory)

        if dags_directory == os.path.dirname(dag_filepath):
            logger.debug("File is already in airflow DAGs directory.")
        else:
            logger.debug(
                "Copying dag file '%s' to DAGs directory.", dag_filepath
            )
            destination_path = os.path.join(
                dags_directory, os.path.basename(dag_filepath)
            )
            if fileio.exists(destination_path):
                logger.info(
                    "File '%s' already exists, overwriting with new DAG file",
                    destination_path,
                )
            fileio.copy(dag_filepath, destination_path, overwrite=True)

    def _log_webserver_credentials(self) -> None:
        """Logs URL and credentials to log in to the airflow webserver.

        Raises:
            FileNotFoundError: If the password file does not exist.
        """
        if fileio.exists(self.password_file):
            with open(self.password_file) as file:
                password = file.read().strip()
        else:
            raise FileNotFoundError(
                f"Can't find password file '{self.password_file}'"
            )
        logger.info(
            "To inspect your DAGs, login to http://0.0.0.0:8080 "
            "with username: admin password: %s",
            password,
        )

    def runtime_options(self) -> Dict[str, Any]:
        """Runtime options for the airflow orchestrator.

        Returns:
            Runtime options dictionary.
        """
        return {DAG_FILEPATH_OPTION_KEY: None}

    def prepare_pipeline_deployment(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Checks Airflow is running and copies DAG file to the Airflow DAGs directory.

        Args:
            pipeline: Pipeline to be deployed.
            stack: Stack to be deployed.
            runtime_configuration: Runtime configuration for the pipeline.

        Raises:
            RuntimeError: If Airflow is not running or no DAG filepath runtime
                          option is provided.
        """
        if not self.is_running:
            raise RuntimeError(
                "Airflow orchestrator is currently not running. Run `zenml "
                "stack up` to provision resources for the active stack."
            )

        try:
            dag_filepath = runtime_configuration[DAG_FILEPATH_OPTION_KEY]
        except KeyError:
            raise RuntimeError(
                f"No DAG filepath found in runtime configuration. Make sure "
                f"to add the filepath to your airflow DAG file as a runtime "
                f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
            )

        self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)

    @property
    def is_running(self) -> bool:
        """Returns whether the airflow daemon is currently running.

        Returns:
            True if the daemon is running, False otherwise.

        Raises:
            RuntimeError: If port 8080 is occupied.
        """
        from airflow.cli.commands.standalone_command import StandaloneCommand
        from airflow.jobs.triggerer_job import TriggererJob

        daemon_running = daemon.check_if_daemon_is_running(self.pid_file)

        command = StandaloneCommand()
        webserver_port_open = command.port_open(8080)

        if not daemon_running:
            if webserver_port_open:
                raise RuntimeError(
                    "The airflow daemon does not seem to be running but "
                    "local port 8080 is occupied. Make sure the port is "
                    "available and try again."
                )

            # exit early so we don't check non-existing airflow databases
            return False

        # we can't use StandaloneCommand().is_ready() here as the
        # Airflow SequentialExecutor apparently does not send a heartbeat
        # while running a task which would result in this returning `False`
        # even if Airflow is running.
        airflow_running = webserver_port_open and command.job_running(
            TriggererJob
        )
        return airflow_running

    @property
    def is_provisioned(self) -> bool:
        """Returns whether the airflow daemon is currently running.

        Returns:
            True if the airflow daemon is running, False otherwise.
        """
        return self.is_running

    def provision(self) -> None:
        """Ensures that Airflow is running."""
        if self.is_running:
            logger.info("Airflow is already running.")
            self._log_webserver_credentials()
            return

        if not fileio.exists(self.dags_directory):
            io_utils.create_dir_recursive_if_not_exists(self.dags_directory)

        from airflow.cli.commands.standalone_command import StandaloneCommand

        try:
            command = StandaloneCommand()
            # Run the daemon with a working directory inside the current
            # zenml repo so the same repo will be used to run the DAGs
            daemon.run_as_daemon(
                command.run,
                pid_file=self.pid_file,
                log_file=self.log_file,
                working_directory=get_source_root_path(),
            )
            while not self.is_running:
                # Wait until the daemon started all the relevant airflow
                # processes
                time.sleep(0.1)
            self._log_webserver_credentials()
        except Exception as e:
            logger.error(e)
            logger.error(
                "An error occurred while starting the Airflow daemon. If you "
                "want to start it manually, use the commands described in the "
                "official Airflow quickstart guide for running Airflow locally."
            )
            self.deprovision()

    def deprovision(self) -> None:
        """Stops the airflow daemon if necessary and tears down resources."""
        if self.is_running:
            daemon.stop_daemon(self.pid_file)

        fileio.rmtree(self.airflow_home)
        logger.info("Airflow spun down.")
dags_directory: str property readonly

Returns path to the airflow dags directory.

Returns:

    str: Path to the airflow dags directory.

is_provisioned: bool property readonly

Returns whether the airflow daemon is currently running.

Returns:

    bool: True if the airflow daemon is running, False otherwise.

is_running: bool property readonly

Returns whether the airflow daemon is currently running.

Returns:

    bool: True if the daemon is running, False otherwise.

Exceptions:

    RuntimeError: If port 8080 is occupied.

log_file: str property readonly

Returns path to the airflow log file.

Returns:

    str: Path to the airflow log file.

password_file: str property readonly

Returns path to the webserver password file.

Returns:

    str: Path to the webserver password file.

pid_file: str property readonly

Returns path to the daemon PID file.

Returns:

    str: Path to the daemon PID file.

__init__(self, **values) special

Sets environment variables to configure airflow.

Parameters:

    **values (Any): Values to set in the orchestrator. Default: {}

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def __init__(self, **values: Any):
    """Sets environment variables to configure airflow.

    Args:
        **values: Values to set in the orchestrator.
    """
    super().__init__(**values)
    self._set_env()
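
A minimal construction sketch, assuming the base stack component only requires a name and generates a default uuid (from which the root validator derives airflow_home):

```python
from zenml.integrations.airflow.orchestrators import AirflowOrchestrator

# Instantiation runs the pydantic validators and then _set_env(), which
# exports AIRFLOW_HOME and related environment variables for this instance.
orchestrator = AirflowOrchestrator(name="airflow_orchestrator")
print(orchestrator.airflow_home)  # <global config dir>/<airflow root>/<uuid>
```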
deprovision(self)

Stops the airflow daemon if necessary and tears down resources.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def deprovision(self) -> None:
    """Stops the airflow daemon if necessary and tears down resources."""
    if self.is_running:
        daemon.stop_daemon(self.pid_file)

    fileio.rmtree(self.airflow_home)
    logger.info("Airflow spun down.")
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)

Creates an Airflow DAG as the intermediate representation for the pipeline.

This DAG will be loaded by airflow in the target environment and used for orchestration of the pipeline.

How it works:

A new airflow_dag is instantiated with the pipeline name and, among other things, the run schedule.

For each step of the pipeline a callable is created. This callable uses the run_step() method to execute the step. The parameters of this callable are pre-filled and an airflow step_operator is created within the dag. The dependencies to upstream steps are then configured.

Finally, the dag is fully complete and can be returned.

Parameters:

    sorted_steps (List[zenml.steps.base_step.BaseStep]): List of steps in the pipeline. Required.
    pipeline (BasePipeline): The pipeline to be executed. Required.
    pb2_pipeline (Pipeline): The pipeline as a protobuf message. Required.
    stack (Stack): The stack on which the pipeline will be deployed. Required.
    runtime_configuration (RuntimeConfiguration): The runtime configuration. Required.

Returns:

    Any: The Airflow DAG.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List[BaseStep],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Creates an Airflow DAG as the intermediate representation for the pipeline.

    This DAG will be loaded by airflow in the target environment
    and used for orchestration of the pipeline.

    How it works:
    -------------
    A new airflow_dag is instantiated with the pipeline name and, among
    other things, the run schedule.

    For each step of the pipeline a callable is created. This callable
    uses the run_step() method to execute the step. The parameters of
    this callable are pre-filled and an airflow step_operator is created
    within the dag. The dependencies to upstream steps are then
    configured.

    Finally, the dag is fully complete and can be returned.

    Args:
        sorted_steps: List of steps in the pipeline.
        pipeline: The pipeline to be executed.
        pb2_pipeline: The pipeline as a protobuf message.
        stack: The stack on which the pipeline will be deployed.
        runtime_configuration: The runtime configuration.

    Returns:
        The Airflow DAG.
    """
    import airflow
    from airflow.operators import python as airflow_python

    # Instantiate and configure airflow Dag with name and schedule
    airflow_dag = airflow.DAG(
        dag_id=pipeline.name,
        is_paused_upon_creation=False,
        **self._translate_schedule(runtime_configuration.schedule),
    )

    # Dictionary mapping step names to airflow_operators. This will be needed
    # to configure airflow operator dependencies
    step_name_to_airflow_operator = {}

    for step in sorted_steps:
        # Create callable that will be used by airflow to execute the step
        # within the orchestrated environment
        def _step_callable(step_instance: "BaseStep", **kwargs):
            # Extract run name for the kwargs that will be passed to the
            # callable
            run_name = kwargs["ti"].get_dagrun().run_id
            self.run_step(
                step=step_instance,
                run_name=run_name,
                pb2_pipeline=pb2_pipeline,
            )

        # Create airflow python operator that contains the step callable
        airflow_operator = airflow_python.PythonOperator(
            dag=airflow_dag,
            task_id=step.name,
            provide_context=True,
            python_callable=functools.partial(
                _step_callable, step_instance=step
            ),
        )

        # Configure the current airflow operator to run after all upstream
        # operators finished executing
        step_name_to_airflow_operator[step.name] = airflow_operator
        upstream_step_names = self.get_upstream_step_names(
            step=step, pb2_pipeline=pb2_pipeline
        )
        for upstream_step_name in upstream_step_names:
            airflow_operator.set_upstream(
                step_name_to_airflow_operator[upstream_step_name]
            )

    # Return the finished airflow dag
    return airflow_dag
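
The schedule handling can be seen in isolation by calling the static helper directly; a small sketch:

```python
from zenml.integrations.airflow.orchestrators.airflow_orchestrator import (
    AirflowOrchestrator,
)

# Without a ZenML schedule, the DAG is configured to run exactly once:
# start_date is backdated by a week and catchup is disabled, so Airflow
# triggers the run immediately.
print(AirflowOrchestrator._translate_schedule(None))
# {'schedule_interval': '@once', 'start_date': datetime(...), 'catchup': False}
```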
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)

Checks Airflow is running and copies DAG file to the Airflow DAGs directory.

Parameters:

    pipeline (BasePipeline): Pipeline to be deployed. Required.
    stack (Stack): Stack to be deployed. Required.
    runtime_configuration (RuntimeConfiguration): Runtime configuration for the pipeline. Required.

Exceptions:

    RuntimeError: If Airflow is not running or no DAG filepath runtime option is provided.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Checks Airflow is running and copies DAG file to the Airflow DAGs directory.

    Args:
        pipeline: Pipeline to be deployed.
        stack: Stack to be deployed.
        runtime_configuration: Runtime configuration for the pipeline.

    Raises:
        RuntimeError: If Airflow is not running or no DAG filepath runtime
                      option is provided.
    """
    if not self.is_running:
        raise RuntimeError(
            "Airflow orchestrator is currently not running. Run `zenml "
            "stack up` to provision resources for the active stack."
        )

    try:
        dag_filepath = runtime_configuration[DAG_FILEPATH_OPTION_KEY]
    except KeyError:
        raise RuntimeError(
            f"No DAG filepath found in runtime configuration. Make sure "
            f"to add the filepath to your airflow DAG file as a runtime "
            f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
        )

    self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)
provision(self)

Ensures that Airflow is running.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def provision(self) -> None:
    """Ensures that Airflow is running."""
    if self.is_running:
        logger.info("Airflow is already running.")
        self._log_webserver_credentials()
        return

    if not fileio.exists(self.dags_directory):
        io_utils.create_dir_recursive_if_not_exists(self.dags_directory)

    from airflow.cli.commands.standalone_command import StandaloneCommand

    try:
        command = StandaloneCommand()
        # Run the daemon with a working directory inside the current
        # zenml repo so the same repo will be used to run the DAGs
        daemon.run_as_daemon(
            command.run,
            pid_file=self.pid_file,
            log_file=self.log_file,
            working_directory=get_source_root_path(),
        )
        while not self.is_running:
            # Wait until the daemon started all the relevant airflow
            # processes
            time.sleep(0.1)
        self._log_webserver_credentials()
    except Exception as e:
        logger.error(e)
        logger.error(
            "An error occurred while starting the Airflow daemon. If you "
            "want to start it manually, use the commands described in the "
            "official Airflow quickstart guide for running Airflow locally."
        )
        self.deprovision()
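
A hedged lifecycle sketch tying the provisioning methods together (direct instantiation shown for illustration; the orchestrator is normally managed through the ZenML CLI):

```python
from zenml.integrations.airflow.orchestrators import AirflowOrchestrator

orchestrator = AirflowOrchestrator(name="airflow_orchestrator")

# provision() starts the Airflow standalone daemon in the background
# (webserver on port 8080) and logs the admin credentials.
orchestrator.provision()
assert orchestrator.is_running

# deprovision() stops the daemon and removes the Airflow home directory.
orchestrator.deprovision()
```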
runtime_options(self)

Runtime options for the airflow orchestrator.

Returns:

    Dict[str, Any]: Runtime options dictionary.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def runtime_options(self) -> Dict[str, Any]:
    """Runtime options for the airflow orchestrator.

    Returns:
        Runtime options dictionary.
    """
    return {DAG_FILEPATH_OPTION_KEY: None}
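
A hedged usage sketch: the value for this option is supplied when the pipeline is run. Here pipeline_instance is a hypothetical pipeline object, and passing runtime options as keyword arguments to run() is an assumption:

```python
from zenml.integrations.airflow.orchestrators.airflow_orchestrator import (
    DAG_FILEPATH_OPTION_KEY,
)

# Point the orchestrator at the file defining this DAG so that
# prepare_pipeline_deployment() can copy it into the Airflow DAGs directory.
# `pipeline_instance` is hypothetical and must be defined elsewhere.
pipeline_instance.run(**{DAG_FILEPATH_OPTION_KEY: __file__})
```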
set_airflow_home(values) classmethod

Sets Airflow home according to orchestrator UUID.

Parameters:

    values (Dict[str, Any]): Dictionary containing all orchestrator attribute values. Required.

Returns:

    Dict[str, Any]: Dictionary containing all orchestrator attribute values and the airflow home.

Exceptions:

    ValueError: If the orchestrator UUID is not set.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
@root_validator(skip_on_failure=True)
def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Sets Airflow home according to orchestrator UUID.

    Args:
        values: Dictionary containing all orchestrator attribute values.

    Returns:
        Dictionary containing all orchestrator attribute values and the airflow home.

    Raises:
        ValueError: If the orchestrator UUID is not set.
    """
    if "uuid" not in values:
        raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")
    values["airflow_home"] = os.path.join(
        io_utils.get_global_config_directory(),
        AIRFLOW_ROOT_DIR,
        str(values["uuid"]),
    )
    return values

aws special

Integrates multiple AWS Tools as Stack Components.

The AWS integration provides a way for our users to manage their secrets through AWS and a way to use the AWS container registry. Additionally, the Sagemaker integration submodule provides a way to run ZenML steps in Sagemaker.

AWSIntegration (Integration)

Definition of AWS integration for ZenML.

Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """Definition of AWS integration for ZenML."""

    NAME = AWS
    REQUIREMENTS = ["boto3==1.21.21", "sagemaker==2.82.2"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the AWS integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=AWS_SECRET_MANAGER_FLAVOR,
                source="zenml.integrations.aws.secrets_managers"
                ".AWSSecretsManager",
                type=StackComponentType.SECRETS_MANAGER,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=AWS_CONTAINER_REGISTRY_FLAVOR,
                source="zenml.integrations.aws.container_registries"
                ".AWSContainerRegistry",
                type=StackComponentType.CONTAINER_REGISTRY,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR,
                source="zenml.integrations.aws.step_operators"
                ".SagemakerStepOperator",
                type=StackComponentType.STEP_OPERATOR,
                integration=cls.NAME,
            ),
        ]
flavors() classmethod

Declare the stack component flavors for the AWS integration.

Returns:

    List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]: List of stack component flavors for this integration.

Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the AWS integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=AWS_SECRET_MANAGER_FLAVOR,
            source="zenml.integrations.aws.secrets_managers"
            ".AWSSecretsManager",
            type=StackComponentType.SECRETS_MANAGER,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=AWS_CONTAINER_REGISTRY_FLAVOR,
            source="zenml.integrations.aws.container_registries"
            ".AWSContainerRegistry",
            type=StackComponentType.CONTAINER_REGISTRY,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR,
            source="zenml.integrations.aws.step_operators"
            ".SagemakerStepOperator",
            type=StackComponentType.STEP_OPERATOR,
            integration=cls.NAME,
        ),
    ]

container_registries special

Initialization of AWS Container Registry integration.

aws_container_registry

Implementation of the AWS container registry integration.

AWSContainerRegistry (BaseContainerRegistry) pydantic-model

Class for AWS Container Registry.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
    """Class for AWS Container Registry."""

    # Class Configuration
    FLAVOR: ClassVar[str] = AWS_CONTAINER_REGISTRY_FLAVOR

    @validator("uri")
    def validate_aws_uri(cls, uri: str) -> str:
        """Validates that the URI is in the correct format.

        Args:
            uri: URI to validate.

        Returns:
            URI in the correct format.

        Raises:
            ValueError: If the URI contains a slash character.
        """
        if "/" in uri:
            raise ValueError(
                "Property `uri` can not contain a `/`. An example of a valid "
                "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
            )

        return uri

    def prepare_image_push(self, image_name: str) -> None:
        """Logs warning message if trying to push an image for which no repository exists.

        Args:
            image_name: Name of the docker image that will be pushed.

        Raises:
            ValueError: If the docker image name is invalid.
        """
        response = boto3.client("ecr").describe_repositories()
        try:
            repo_uris: List[str] = [
                repository["repositoryUri"]
                for repository in response["repositories"]
            ]
        except (KeyError, ClientError) as e:
            # invalid boto response, let's hope for the best and just push
            logger.debug("Error while trying to fetch ECR repositories: %s", e)
            return

        repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
        if not repo_exists:
            match = re.search(f"{self.uri}/(.*):.*", image_name)
            if not match:
                raise ValueError(f"Invalid docker image name '{image_name}'.")

            repo_name = match.group(1)
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                f"push an image to it. ZenML is trying to push the image "
                f"{image_name} but could only detect the following "
                f"repositories: {repo_uris}. We will try to push anyway, but "
                f"in case it fails you need to create a repository named "
                f"`{repo_name}`."
            )

    @property
    def post_registration_message(self) -> Optional[str]:
        """Optional message printed after the stack component is registered.

        Returns:
            Info message regarding docker repositories in AWS.
        """
        return (
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. If you want to for example run a pipeline "
            "using our Kubeflow orchestrator, ZenML will automatically build a "
            f"docker image called `{self.uri}/zenml-kubeflow:<PIPELINE_NAME>` "
            f"and try to push it. This will fail unless you create the "
            f"repository `zenml-kubeflow` inside your amazon registry."
        )
post_registration_message: Optional[str] property readonly

Optional message printed after the stack component is registered.

Returns:

    Optional[str]: Info message regarding docker repositories in AWS.

prepare_image_push(self, image_name)

Logs warning message if trying to push an image for which no repository exists.

Parameters:

    image_name (str): Name of the docker image that will be pushed. Required.

Exceptions:

    ValueError: If the docker image name is invalid.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs warning message if trying to push an image for which no repository exists.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    response = boto3.client("ecr").describe_repositories()
    try:
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for repository in response["repositories"]
        ]
    except (KeyError, ClientError) as e:
        # invalid boto response, let's hope for the best and just push
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
    if not repo_exists:
        match = re.search(f"{self.uri}/(.*):.*", image_name)
        if not match:
            raise ValueError(f"Invalid docker image name '{image_name}'.")

        repo_name = match.group(1)
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            f"in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )
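
The repository-name extraction relies on a simple pattern over the fully qualified image name; a standalone sketch with a hypothetical registry URI:

```python
import re

uri = "715803424592.dkr.ecr.us-east-1.amazonaws.com"  # hypothetical registry
image_name = f"{uri}/zenml-kubeflow:my_pipeline"

# The same pattern prepare_image_push uses to recover the repository name.
match = re.search(f"{uri}/(.*):.*", image_name)
assert match and match.group(1) == "zenml-kubeflow"
```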
validate_aws_uri(uri) classmethod

Validates that the URI is in the correct format.

Parameters:

    uri (str): URI to validate. Required.

Returns:

    str: URI in the correct format.

Exceptions:

    ValueError: If the URI contains a slash character.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
@validator("uri")
def validate_aws_uri(cls, uri: str) -> str:
    """Validates that the URI is in the correct format.

    Args:
        uri: URI to validate.

    Returns:
        URI in the correct format.

    Raises:
        ValueError: If the URI contains a slash character.
    """
    if "/" in uri:
        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )

    return uri
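
A minimal sketch of the constraint in practice (direct instantiation shown for illustration; registries are normally registered via the CLI):

```python
from zenml.integrations.aws.container_registries import AWSContainerRegistry

# Valid: a bare ECR registry URI without any repository path.
registry = AWSContainerRegistry(
    name="ecr_registry",
    uri="715803424592.dkr.ecr.us-east-1.amazonaws.com",
)

# Invalid: appending a repository path would raise a ValueError in the
# `uri` validator, e.g. uri=".../my-repo".
```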

secrets_managers special

AWS Secrets Manager.

aws_secrets_manager

Implementation of the AWS Secrets Manager integration.

AWSSecretsManager (BaseSecretsManager) pydantic-model

Class to interact with the AWS secrets manager.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
class AWSSecretsManager(BaseSecretsManager):
    """Class to interact with the AWS secrets manager."""

    region_name: str = DEFAULT_AWS_REGION

    # Class configuration
    FLAVOR: ClassVar[str] = AWS_SECRET_MANAGER_FLAVOR
    CLIENT: ClassVar[Any] = None

    @classmethod
    def _ensure_client_connected(cls, region_name: str) -> None:
        """Ensure that the client is connected to the AWS secrets manager.

        Args:
            region_name: the AWS region name
        """
        if cls.CLIENT is None:
            # Create a Secrets Manager client
            session = boto3.session.Session()
            cls.CLIENT = session.client(
                service_name="secretsmanager", region_name=region_name
            )

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        self._ensure_client_connected(self.region_name)
        secret_value = jsonify_secret_contents(secret)

        if secret.name in self.get_all_secret_keys():
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists"
            )

        kwargs = {"Name": secret.name, "SecretString": secret_value}

        self.CLIENT.create_secret(**kwargs)

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets a secret.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            RuntimeError: if the secret does not exist
        """
        self._ensure_client_connected(self.region_name)
        get_secret_value_response = self.CLIENT.get_secret_value(
            SecretId=secret_name
        )
        if "SecretString" not in get_secret_value_response:
            raise RuntimeError(f"No secrets found within the {secret_name}")
        secret_contents: Dict[str, str] = json.loads(
            get_secret_value_response["SecretString"]
        )

        zenml_schema_name = secret_contents.pop(ZENML_SCHEMA_NAME)
        secret_contents["name"] = secret_name

        secret_schema = SecretSchemaClassRegistry.get_class(
            secret_schema=zenml_schema_name
        )
        return secret_schema(**secret_contents)

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        self._ensure_client_connected(self.region_name)

        # TODO [ENG-720]: Deal with pagination in the aws secret manager when
        #  listing all secrets
        # TODO [ENG-721]: take out this magic maxresults number
        response = self.CLIENT.list_secrets(MaxResults=100)
        return [secret["Name"] for secret in response["SecretList"]]

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret.

        Args:
            secret: the secret to update
        """
        self._ensure_client_connected(self.region_name)

        secret_value = jsonify_secret_contents(secret)

        kwargs = {"SecretId": secret.name, "SecretString": secret_value}

        self.CLIENT.put_secret_value(**kwargs)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        Args:
            secret_name: the name of the secret to delete
        """
        self._ensure_client_connected(self.region_name)
        self.CLIENT.delete_secret(
            SecretId=secret_name, ForceDeleteWithoutRecovery=False
        )

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets.

        This method will force delete all your secrets. You will not be able to
        recover them once this method is called.
        """
        self._ensure_client_connected(self.region_name)
        for secret_name in self.get_all_secret_keys():
            self.CLIENT.delete_secret(
                SecretId=secret_name, ForceDeleteWithoutRecovery=True
            )
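
A hedged round-trip sketch, assuming AWS credentials are configured and that MySecretSchema is a hypothetical BaseSecretSchema subclass defined elsewhere:

```python
from zenml.integrations.aws.secrets_managers import AWSSecretsManager

secrets_manager = AWSSecretsManager(name="aws_secrets", region_name="us-east-1")

secret = MySecretSchema(name="db_creds", password="...")  # hypothetical schema
secrets_manager.register_secret(secret)            # raises if the name exists
fetched = secrets_manager.get_secret("db_creds")   # rebuilt via schema registry
secrets_manager.delete_secret("db_creds")          # recoverable delete
```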
delete_all_secrets(self)

Delete all existing secrets.

This method will force delete all your secrets. You will not be able to recover them once this method is called.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets.

    This method will force delete all your secrets. You will not be able to
    recover them once this method is called.
    """
    self._ensure_client_connected(self.region_name)
    for secret_name in self.get_all_secret_keys():
        self.CLIENT.delete_secret(
            SecretId=secret_name, ForceDeleteWithoutRecovery=True
        )
delete_secret(self, secret_name)

Delete an existing secret.

Parameters:

    secret_name (str): The name of the secret to delete. Required.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete an existing secret.

    Args:
        secret_name: the name of the secret to delete
    """
    self._ensure_client_connected(self.region_name)
    self.CLIENT.delete_secret(
        SecretId=secret_name, ForceDeleteWithoutRecovery=False
    )
get_all_secret_keys(self)

Get all secret keys.

Returns:

    List[str]: A list of all secret keys.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Get all secret keys.

    Returns:
        A list of all secret keys
    """
    self._ensure_client_connected(self.region_name)

    # TODO [ENG-720]: Deal with pagination in the aws secret manager when
    #  listing all secrets
    # TODO [ENG-721]: take out this magic maxresults number
    response = self.CLIENT.list_secrets(MaxResults=100)
    return [secret["Name"] for secret in response["SecretList"]]
get_secret(self, secret_name)

Gets a secret.

Parameters:

    secret_name (str): The name of the secret to get. Required.

Returns:

    BaseSecretSchema: The secret.

Exceptions:

    RuntimeError: If the secret does not exist.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Gets a secret.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        RuntimeError: if the secret does not exist
    """
    self._ensure_client_connected(self.region_name)
    get_secret_value_response = self.CLIENT.get_secret_value(
        SecretId=secret_name
    )
    if "SecretString" not in get_secret_value_response:
        raise RuntimeError(f"No secrets found within the {secret_name}")
    secret_contents: Dict[str, str] = json.loads(
        get_secret_value_response["SecretString"]
    )

    zenml_schema_name = secret_contents.pop(ZENML_SCHEMA_NAME)
    secret_contents["name"] = secret_name

    secret_schema = SecretSchemaClassRegistry.get_class(
        secret_schema=zenml_schema_name
    )
    return secret_schema(**secret_contents)
register_secret(self, secret)

Registers a new secret.

Parameters:

    secret (BaseSecretSchema): The secret to register. Required.

Exceptions:

    SecretExistsError: If the secret already exists.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Registers a new secret.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    self._ensure_client_connected(self.region_name)
    secret_value = jsonify_secret_contents(secret)

    if secret.name in self.get_all_secret_keys():
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    kwargs = {"Name": secret.name, "SecretString": secret_value}

    self.CLIENT.create_secret(**kwargs)
update_secret(self, secret)

Update an existing secret.

Parameters:

    secret (BaseSecretSchema): The secret to update. Required.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret.

    Args:
        secret: the secret to update
    """
    self._ensure_client_connected(self.region_name)

    secret_value = jsonify_secret_contents(secret)

    kwargs = {"SecretId": secret.name, "SecretString": secret_value}

    self.CLIENT.put_secret_value(**kwargs)
jsonify_secret_contents(secret)

Adds the secret type to the secret contents.

This persists the schema type in the secrets backend, so that the correct SecretSchema can be retrieved when the secret is queried from the backend.

Parameters:

    secret (BaseSecretSchema): Should be a subclass of the BaseSecretSchema class. Required.

Returns:

    str: Jsonified dictionary containing all key-value pairs and the ZenML schema type.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def jsonify_secret_contents(secret: BaseSecretSchema) -> str:
    """Adds the secret type to the secret contents.

    This persists the schema type in the secrets backend, so that the correct
    SecretSchema can be retrieved when the secret is queried from the backend.

    Args:
        secret: should be a subclass of the BaseSecretSchema class

    Returns:
        jsonified dictionary containing all key-value pairs and the ZenML schema
        type
    """
    secret_contents = secret.content
    secret_contents[ZENML_SCHEMA_NAME] = secret.TYPE
    return json.dumps(secret_contents)
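
For illustration, a self-contained sketch of the round trip this enables; the literal key name and schema type here are assumptions, not the real constants:

```python
import json

ZENML_SCHEMA_NAME = "zenml_schema_name"  # assumption: the reserved key name

# What jsonify_secret_contents effectively stores: the secret's key-value
# pairs plus its schema type under the reserved schema-name key.
stored = json.dumps({"password": "hunter2", ZENML_SCHEMA_NAME: "arbitrary"})

# get_secret() later pops the schema key to pick the right SecretSchema class
# and treats the remaining pairs as the secret's contents.
contents = json.loads(stored)
schema_type = contents.pop(ZENML_SCHEMA_NAME)
```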

step_operators special

Initialization of the Sagemaker Step Operator.

sagemaker_step_operator

Implementation of the Sagemaker Step Operator.

SagemakerStepOperator (BaseStepOperator) pydantic-model

Step operator to run a step on Sagemaker.

This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.

Attributes:

    role (str): The role that has to be assigned to the jobs which are running in Sagemaker.
    instance_type (str): The type of the compute instance where jobs will run.
    base_image (Optional[str]): The base image to use for building the docker image that will be executed.
    bucket (Optional[str]): Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".
    experiment_name (Optional[str]): The name for the experiment to which the job will be associated. If not provided, the job runs would be independent.

Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator):
    """Step operator to run a step on Sagemaker.

    This class defines code that builds an image with the ZenML entrypoint
    to run using Sagemaker's Estimator.

    Attributes:
        role: The role that has to be assigned to the jobs which are
            running in Sagemaker.
        instance_type: The type of the compute instance where jobs will run.
        base_image: [Optional] The base image to use for building the docker
            image that will be executed.
        bucket: [Optional] Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be created
            based on the following format: "sagemaker-{region}-{aws-account-id}".
        experiment_name: [Optional] The name for the experiment to which the job
            will be associated. If not provided, the job runs would be
            independent.
    """

    role: str
    instance_type: str

    base_image: Optional[str] = None
    bucket: Optional[str] = None
    experiment_name: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains a container registry.

        Returns:
            A validator that checks that the stack contains a container registry.
        """

        def _ensure_local_orchestrator(stack: Stack) -> Tuple[bool, str]:
            return (
                stack.orchestrator.FLAVOR == "local",
                "Local orchestrator is required",
            )

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_ensure_local_orchestrator,
        )

    def _build_docker_image(
        self,
        pipeline_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> str:
        repo = Repository()
        container_registry = repo.active_stack.container_registry

        if not container_registry:
            raise RuntimeError("Missing container registry")

        registry_uri = container_registry.uri.rstrip("/")
        image_name = f"{registry_uri}/zenml-sagemaker:{pipeline_name}"

        docker_utils.build_docker_image(
            build_context_path=get_source_root_path(),
            image_name=image_name,
            entrypoint=" ".join(entrypoint_command),
            requirements=set(requirements),
            base_image=self.base_image,
        )
        container_registry.push_image(image_name)
        return docker_utils.get_image_digest(image_name) or image_name

    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> None:
        """Launches a step on Sagemaker.

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            entrypoint_command: Command that executes the step.
            requirements: List of pip requirements that must be installed
                inside the step operator environment.
        """
        image_name = self._build_docker_image(
            pipeline_name=pipeline_name,
            requirements=requirements,
            entrypoint_command=entrypoint_command,
        )

        session = sagemaker.Session(default_bucket=self.bucket)
        estimator = sagemaker.estimator.Estimator(
            image_name,
            self.role,
            instance_count=1,
            instance_type=self.instance_type,
            sagemaker_session=session,
        )

        # Sagemaker doesn't allow any underscores in job/experiment/trial names
        sanitized_run_name = run_name.replace("_", "-")

        experiment_config = {}
        if self.experiment_name:
            experiment_config = {
                "ExperimentName": self.experiment_name,
                "TrialName": sanitized_run_name,
            }

        estimator.fit(
            wait=True,
            experiment_config=experiment_config,
            job_name=sanitized_run_name,
        )
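
A hedged configuration sketch (all values are placeholders; in practice the component is registered through the ZenML CLI):

```python
from zenml.integrations.aws.step_operators import SagemakerStepOperator

step_operator = SagemakerStepOperator(
    name="sagemaker",
    role="arn:aws:iam::123456789012:role/zenml-sagemaker",  # hypothetical role
    instance_type="ml.m5.large",
    bucket="my-zenml-artifacts",          # optional; default bucket otherwise
    experiment_name="zenml-experiments",  # optional experiment grouping
)
```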
validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates that the stack contains a container registry.

Returns:

    Optional[zenml.stack.stack_validator.StackValidator]: A validator that checks that the stack contains a container registry.

launch(self, pipeline_name, run_name, requirements, entrypoint_command)

Launches a step on Sagemaker.

Parameters:

    pipeline_name (str): Name of the pipeline which the step to be executed is part of. Required.
    run_name (str): Name of the pipeline run which the step to be executed is part of. Required.
    entrypoint_command (List[str]): Command that executes the step. Required.
    requirements (List[str]): List of pip requirements that must be installed inside the step operator environment. Required.
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
) -> None:
    """Launches a step on Sagemaker.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        entrypoint_command: Command that executes the step.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.
    """
    image_name = self._build_docker_image(
        pipeline_name=pipeline_name,
        requirements=requirements,
        entrypoint_command=entrypoint_command,
    )

    session = sagemaker.Session(default_bucket=self.bucket)
    estimator = sagemaker.estimator.Estimator(
        image_name,
        self.role,
        instance_count=1,
        instance_type=self.instance_type,
        sagemaker_session=session,
    )

    # Sagemaker doesn't allow any underscores in job/experiment/trial names
    sanitized_run_name = run_name.replace("_", "-")

    experiment_config = {}
    if self.experiment_name:
        experiment_config = {
            "ExperimentName": self.experiment_name,
            "TrialName": sanitized_run_name,
        }

    estimator.fit(
        wait=True,
        experiment_config=experiment_config,
        job_name=sanitized_run_name,
    )

azure special

Initialization of the ZenML Azure integration.

The Azure integration submodule provides a way to run ZenML pipelines in a cloud environment. Specifically, it allows the use of cloud artifact stores and provides an io module to handle file operations on Azure Blob Storage. The Azure Step Operator integration submodule provides a way to run ZenML steps in AzureML.

AzureIntegration (Integration)

Definition of Azure integration for ZenML.

Source code in zenml/integrations/azure/__init__.py
class AzureIntegration(Integration):
    """Definition of Azure integration for ZenML."""

    NAME = AZURE
    REQUIREMENTS = [
        "adlfs==2021.10.0",
        "azure-keyvault-keys",
        "azure-keyvault-secrets",
        "azure-identity",
        "azureml-core==1.39.0.post1",
    ]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declares the flavors for the integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=AZURE_ARTIFACT_STORE_FLAVOR,
                source="zenml.integrations.azure.artifact_stores"
                ".AzureArtifactStore",
                type=StackComponentType.ARTIFACT_STORE,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=AZURE_SECRETS_MANAGER_FLAVOR,
                source="zenml.integrations.azure.secrets_managers"
                ".AzureSecretsManager",
                type=StackComponentType.SECRETS_MANAGER,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=AZUREML_STEP_OPERATOR_FLAVOR,
                source="zenml.integrations.azure.step_operators"
                ".AzureMLStepOperator",
                type=StackComponentType.STEP_OPERATOR,
                integration=cls.NAME,
            ),
        ]
flavors() classmethod

Declares the flavors for the integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/azure/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declares the flavors for the integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=AZURE_ARTIFACT_STORE_FLAVOR,
            source="zenml.integrations.azure.artifact_stores"
            ".AzureArtifactStore",
            type=StackComponentType.ARTIFACT_STORE,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=AZURE_SECRETS_MANAGER_FLAVOR,
            source="zenml.integrations.azure.secrets_managers"
            ".AzureSecretsManager",
            type=StackComponentType.SECRETS_MANAGER,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=AZUREML_STEP_OPERATOR_FLAVOR,
            source="zenml.integrations.azure.step_operators"
            ".AzureMLStepOperator",
            type=StackComponentType.STEP_OPERATOR,
            integration=cls.NAME,
        ),
    ]
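
As a quick sketch, the declared flavors can also be inspected programmatically; the attributes printed below follow the FlavorWrapper construction shown above:

from zenml.integrations.azure import AzureIntegration

# List the stack component flavors the Azure integration declares.
for flavor in AzureIntegration.flavors():
    print(flavor.name, flavor.type, flavor.source)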

artifact_stores special

Initialization of the Azure Artifact Store integration.

azure_artifact_store

Implementation of the Azure Artifact Store integration.

AzureArtifactStore (BaseArtifactStore, AuthenticationMixin) pydantic-model

Artifact Store for Microsoft Azure based artifacts.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
class AzureArtifactStore(BaseArtifactStore, AuthenticationMixin):
    """Artifact Store for Microsoft Azure based artifacts."""

    _filesystem: Optional[adlfs.AzureBlobFileSystem] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = AZURE_ARTIFACT_STORE_FLAVOR
    SUPPORTED_SCHEMES: ClassVar[Set[str]] = {"abfs://", "az://"}

    @property
    def filesystem(self) -> adlfs.AzureBlobFileSystem:
        """The adlfs filesystem to access this artifact store.

        Returns:
            The adlfs filesystem to access this artifact store.
        """
        if not self._filesystem:
            secret = self.get_authentication_secret(
                expected_schema_type=AzureSecretSchema
            )
            credentials = secret.content if secret else {}

            self._filesystem = adlfs.AzureBlobFileSystem(
                **credentials,
                anon=False,
                use_listings_cache=False,
            )
        return self._filesystem

    @classmethod
    def _split_path(cls, path: PathType) -> Tuple[str, str]:
        """Splits a path into the filesystem prefix and remainder.

        Example:
        ```python
        prefix, remainder = AzureArtifactStore._split_path("az://my_container/test.txt")
        print(prefix, remainder)  # "az://" "my_container/test.txt"
        ```

        Args:
            path: The path to split.

        Returns:
            A tuple of the filesystem prefix and the remainder.
        """
        path = convert_to_str(path)
        prefix = ""
        for potential_prefix in cls.SUPPORTED_SCHEMES:
            if path.startswith(potential_prefix):
                prefix = potential_prefix
                path = path[len(potential_prefix) :]
                break

        return prefix, path

    def open(self, path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file. Currently, only
                'rb' and 'wb' to read and write binary files are supported.

        Returns:
            A file-like object.
        """
        return self.filesystem.open(path=path, mode=mode)

    def copyfile(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Copy a file.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )

        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        self.filesystem.copy(path1=src, path2=dst)

    def exists(self, path: PathType) -> bool:
        """Check whether a path exists.

        Args:
            path: The path to check.

        Returns:
            True if the path exists, False otherwise.
        """
        return self.filesystem.exists(path=path)  # type: ignore[no-any-return]

    def glob(self, pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to search
            in subdirectories of any depth (e.g. '/some_dir/**/some_file')

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of paths that match the given glob pattern.
        """
        prefix, _ = self._split_path(pattern)
        return [
            f"{prefix}{path}" for path in self.filesystem.glob(path=pattern)
        ]

    def isdir(self, path: PathType) -> bool:
        """Check whether a path is a directory.

        Args:
            path: The path to check.

        Returns:
            True if the path is a directory, False otherwise.
        """
        return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]

    def listdir(self, path: PathType) -> List[PathType]:
        """Return a list of files in a directory.

        Args:
            path: The path to list.

        Returns:
            A list of files in the given directory.
        """
        _, path = self._split_path(path)

        def _extract_basename(file_dict: Dict[str, Any]) -> str:
            """Extracts the basename from a dictionary returned by the Azure filesystem.

            Args:
                file_dict: A dictionary returned by the Azure filesystem.

            Returns:
                The basename of the file.
            """
            file_path = cast(str, file_dict["name"])
            base_name = file_path[len(path) :]
            return base_name.lstrip("/")

        return [
            _extract_basename(dict_)
            for dict_ in self.filesystem.listdir(path=path)
        ]

    def makedirs(self, path: PathType) -> None:
        """Create a directory at the given path.

        If needed also create missing parent directories.

        Args:
            path: The path to create.
        """
        self.filesystem.makedirs(path=path, exist_ok=True)

    def mkdir(self, path: PathType) -> None:
        """Create a directory at the given path.

        Args:
            path: The path to create.
        """
        self.filesystem.makedir(path=path)

    def remove(self, path: PathType) -> None:
        """Remove the file at the given path.

        Args:
            path: The path to remove.
        """
        self.filesystem.rm_file(path=path)

    def rename(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Rename source file to destination file.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        self.filesystem.rename(path1=src, path2=dst)

    def rmtree(self, path: PathType) -> None:
        """Remove the given directory.

        Args:
            path: The path of the directory to remove.
        """
        self.filesystem.delete(path=path, recursive=True)

    def stat(self, path: PathType) -> Dict[str, Any]:
        """Return stat info for the given path.

        Args:
            path: The path to get stat info for.

        Returns:
            Stat info.
        """
        return self.filesystem.stat(path=path)  # type: ignore[no-any-return]

    def walk(
        self,
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Return an iterator that walks the contents of the given directory.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Yields:
            An Iterable of Tuples, each of which contains the path of the
            current directory, a list of directories inside the current
            directory and a list of files inside the current directory.
        """
        # TODO [ENG-153]: Additional params
        prefix, _ = self._split_path(top)
        for (
            directory,
            subdirectories,
            files,
        ) in self.filesystem.walk(path=top):
            yield f"{prefix}{directory}", subdirectories, files
filesystem: AzureBlobFileSystem property readonly

The adlfs filesystem to access this artifact store.

Returns:

Type Description
AzureBlobFileSystem

The adlfs filesystem to access this artifact store.
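
A hedged construction sketch, assuming the usual artifact store base fields (a component name and a root path) and Azure credentials that adlfs can resolve, for example via a configured authentication secret:

# Hypothetical sketch; `name` and `path` are assumed base-class fields,
# and "my-container" is a placeholder container.
store = AzureArtifactStore(
    name="azure_store",
    path="az://my-container/zenml",  # must start with a supported scheme
)
fs = store.filesystem  # the adlfs filesystem is built lazily, then cached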

copyfile(self, src, dst, overwrite=False)

Copy a file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path to copy from.

required
dst Union[bytes, str]

The path to copy to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def copyfile(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Copy a file.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    if not overwrite and self.filesystem.exists(dst):
        raise FileExistsError(
            f"Unable to copy to destination '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to copy anyway."
        )

    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    self.filesystem.copy(path1=src, path2=dst)
exists(self, path)

Check whether a path exists.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path exists, False otherwise.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def exists(self, path: PathType) -> bool:
    """Check whether a path exists.

    Args:
        path: The path to check.

    Returns:
        True if the path exists, False otherwise.
    """
    return self.filesystem.exists(path=path)  # type: ignore[no-any-return]
glob(self, pattern)

Return all paths that match the given glob pattern.

The glob pattern may include: - '*' to match any number of characters - '?' to match a single character - '[...]' to match one of the characters inside the brackets - '**' as the full name of a path component to search in subdirectories of any depth (e.g. '/some_dir/**/some_file')

Parameters:

Name Type Description Default
pattern Union[bytes, str]

The glob pattern to match, see details above.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths that match the given glob pattern.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
    """Return all paths that match the given glob pattern.

    The glob pattern may include:
    - '*' to match any number of characters
    - '?' to match a single character
    - '[...]' to match one of the characters inside the brackets
    - '**' as the full name of a path component to search
        in subdirectories of any depth (e.g. '/some_dir/**/some_file')

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        A list of paths that match the given glob pattern.
    """
    prefix, _ = self._split_path(pattern)
    return [
        f"{prefix}{path}" for path in self.filesystem.glob(path=pattern)
    ]
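
Continuing the hypothetical store sketch above, a glob call could look like this; note that the scheme prefix stripped off by adlfs is re-attached to every result:

# Placeholder pattern; results keep their "az://" prefix.
csv_files = store.glob("az://my-container/zenml/data/**/*.csv")
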
isdir(self, path)

Check whether a path is a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path is a directory, False otherwise.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def isdir(self, path: PathType) -> bool:
    """Check whether a path is a directory.

    Args:
        path: The path to check.

    Returns:
        True if the path is a directory, False otherwise.
    """
    return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]
listdir(self, path)

Return a list of files in a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to list.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of files in the given directory.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
    """Return a list of files in a directory.

    Args:
        path: The path to list.

    Returns:
        A list of files in the given directory.
    """
    _, path = self._split_path(path)

    def _extract_basename(file_dict: Dict[str, Any]) -> str:
        """Extracts the basename from a dictionary returned by the Azure filesystem.

        Args:
            file_dict: A dictionary returned by the Azure filesystem.

        Returns:
            The basename of the file.
        """
        file_path = cast(str, file_dict["name"])
        base_name = file_path[len(path) :]
        return base_name.lstrip("/")

    return [
        _extract_basename(dict_)
        for dict_ in self.filesystem.listdir(path=path)
    ]
makedirs(self, path)

Create a directory at the given path.

If needed also create missing parent directories.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to create.

required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def makedirs(self, path: PathType) -> None:
    """Create a directory at the given path.

    If needed also create missing parent directories.

    Args:
        path: The path to create.
    """
    self.filesystem.makedirs(path=path, exist_ok=True)
mkdir(self, path)

Create a directory at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to create.

required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def mkdir(self, path: PathType) -> None:
    """Create a directory at the given path.

    Args:
        path: The path to create.
    """
    self.filesystem.makedir(path=path)
open(self, path, mode='r')

Open a file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

Path of the file to open.

required
mode str

Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported.

'r'

Returns:

Type Description
Any

A file-like object.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
    """Open a file at the given path.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file. Currently, only
            'rb' and 'wb' to read and write binary files are supported.

    Returns:
        A file-like object.
    """
    return self.filesystem.open(path=path, mode=mode)
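
A small write-then-read round trip with the hypothetical store from above, using the binary modes the docstring supports:

# Write a file and read it back (placeholder path).
with store.open("az://my-container/zenml/hello.txt", mode="wb") as f:
    f.write(b"hello")
with store.open("az://my-container/zenml/hello.txt", mode="rb") as f:
    assert f.read() == b"hello"
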
remove(self, path)

Remove the file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to remove.

required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def remove(self, path: PathType) -> None:
    """Remove the file at the given path.

    Args:
        path: The path to remove.
    """
    self.filesystem.rm_file(path=path)
rename(self, src, dst, overwrite=False)

Rename source file to destination file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path of the file to rename.

required
dst Union[bytes, str]

The path to rename the source file to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def rename(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Rename source file to destination file.

    Args:
        src: The path of the file to rename.
        dst: The path to rename the source file to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    if not overwrite and self.filesystem.exists(dst):
        raise FileExistsError(
            f"Unable to rename file to '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to rename anyway."
        )

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    self.filesystem.rename(path1=src, path2=dst)
rmtree(self, path)

Remove the given directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to remove.

required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def rmtree(self, path: PathType) -> None:
    """Remove the given directory.

    Args:
        path: The path of the directory to remove.
    """
    self.filesystem.delete(path=path, recursive=True)
stat(self, path)

Return stat info for the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to get stat info for.

required

Returns:

Type Description
Dict[str, Any]

Stat info.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
    """Return stat info for the given path.

    Args:
        path: The path to get stat info for.

    Returns:
        Stat info.
    """
    return self.filesystem.stat(path=path)  # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)

Return an iterator that walks the contents of the given directory.

Parameters:

Name Type Description Default
top Union[bytes, str]

Path of directory to walk.

required
topdown bool

Unused argument to conform to interface.

True
onerror Optional[Callable[..., NoneType]]

Unused argument to conform to interface.

None

Yields:

Type Description
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]]

An Iterable of Tuples, each of which contains the path of the current directory, a list of directories inside the current directory and a list of files inside the current directory.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def walk(
    self,
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Return an iterator that walks the contents of the given directory.

    Args:
        top: Path of directory to walk.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Yields:
        An Iterable of Tuples, each of which contains the path of the
        current directory, a list of directories inside the current
        directory and a list of files inside the current directory.
    """
    # TODO [ENG-153]: Additional params
    prefix, _ = self._split_path(top)
    for (
        directory,
        subdirectories,
        files,
    ) in self.filesystem.walk(path=top):
        yield f"{prefix}{directory}", subdirectories, files

secrets_managers special

Initialization of the Azure Secrets Manager integration.

azure_secrets_manager

Implementation of the Azure Secrets Manager integration.

AzureSecretsManager (BaseSecretsManager) pydantic-model

Class to interact with the Azure secrets manager.

Attributes:

Name Type Description
key_vault_name str

The name of the Azure Key Vault that contains the secrets.

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
class AzureSecretsManager(BaseSecretsManager):
    """Class to interact with the Azure secrets manager.

    Attributes:
        key_vault_name: The name of the Azure Key Vault that contains
            the secrets.
    """

    key_vault_name: str

    # Class configuration
    FLAVOR: ClassVar[str] = AZURE_SECRETS_MANAGER_FLAVOR
    CLIENT: ClassVar[Any] = None

    @classmethod
    def _ensure_client_connected(cls, vault_name: str) -> None:
        if cls.CLIENT is None:
            KVUri = f"https://{vault_name}.vault.azure.net"

            credential = DefaultAzureCredential()
            cls.CLIENT = SecretClient(vault_url=KVUri, credential=credential)

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
            ValueError: if the secret name contains an underscore.
        """
        self._ensure_client_connected(self.key_vault_name)

        if "_" in secret.name:
            raise ValueError(
                f"The secret name `{secret.name}` contains an underscore. "
                f"This will cause issues with Azure. Please try again."
            )

        if secret.name in self.get_all_secret_keys():
            raise SecretExistsError(
                f"A Secret with the name '{secret.name}' already exists."
            )

        adjusted_content = prepend_group_name_to_keys(secret)

        for k, v in adjusted_content.items():
            # Create the secret; this stores the value under the
            # supplied name.
            azure_secret = self.CLIENT.set_secret(k, v)
            self.CLIENT.update_secret_properties(
                azure_secret.name,
                tags={
                    ZENML_GROUP_KEY: secret.name,
                    ZENML_SCHEMA_NAME: secret.TYPE,
                },
            )

            logger.debug("Created created secret: %s", azure_secret.name)
            logger.debug("Added value to secret.")

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Get a secret by its name.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            RuntimeError: if the secret does not exist
            ValueError: if one of the secret keys is named 'name'
        """
        self._ensure_client_connected(self.key_vault_name)

        secret_contents = {}
        zenml_schema_name = ""

        for secret_property in self.CLIENT.list_properties_of_secrets():
            response = self.CLIENT.get_secret(secret_property.name)
            tags = response.properties.tags
            if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
                secret_key = remove_group_name_from_key(
                    combined_key_name=response.name, group_name=secret_name
                )
                if secret_key == "name":
                    raise ValueError("The secret's key cannot be 'name'.")

                secret_contents[secret_key] = response.value

                zenml_schema_name = tags.get(ZENML_SCHEMA_NAME)

        if not secret_contents:
            raise RuntimeError(f"No secrets found within the {secret_name}")

        secret_contents["name"] = secret_name

        secret_schema = SecretSchemaClassRegistry.get_class(
            secret_schema=zenml_schema_name
        )
        return secret_schema(**secret_contents)

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        self._ensure_client_connected(self.key_vault_name)

        set_of_secrets = set()

        for secret_property in self.CLIENT.list_properties_of_secrets():
            tags = secret_property.tags
            if tags:
                set_of_secrets.add(tags.get(ZENML_GROUP_KEY))

        return list(set_of_secrets)

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret by creating new versions of the existing secrets.

        Args:
            secret: the secret to update
        """
        self._ensure_client_connected(self.key_vault_name)

        adjusted_content = prepend_group_name_to_keys(secret)

        for k, v in adjusted_content.items():
            self.CLIENT.set_secret(k, v)
            self.CLIENT.update_secret_properties(
                k,
                tags={
                    ZENML_GROUP_KEY: secret.name,
                    ZENML_SCHEMA_NAME: secret.TYPE,
                },
            )

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret. by name.

        In Azure a secret is a single k-v pair. Within ZenML a secret is a
        collection of k-v pairs. As such, deleting a secret will iterate through
        all secrets and delete the ones with the secret_name as label.

        Args:
            secret_name: the name of the secret to delete
        """
        self._ensure_client_connected(self.key_vault_name)

        # Go through all Azure secrets and delete the ones with the secret_name
        #  as label.
        for secret_property in self.CLIENT.list_properties_of_secrets():
            response = self.CLIENT.get_secret(secret_property.name)
            tags = response.properties.tags
            if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
                self.CLIENT.begin_delete_secret(secret_property.name).result()

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets."""
        self._ensure_client_connected(self.key_vault_name)

        # List all secrets.
        for secret_property in self.CLIENT.list_properties_of_secrets():
            response = self.CLIENT.get_secret(secret_property.name)
            tags = response.properties.tags
            if tags and (ZENML_GROUP_KEY in tags or ZENML_SCHEMA_NAME in tags):
                logger.info(
                    "Deleted key-value pair {`%s`, `***`} from secret " "`%s`",
                    secret_property.name,
                    tags.get(ZENML_GROUP_KEY),
                )
                self.CLIENT.begin_delete_secret(secret_property.name).result()
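
A hedged usage sketch: the vault name is a placeholder, credentials are resolved through DefaultAzureCredential, and `my_secret` stands in for a concrete BaseSecretSchema instance:

# Hypothetical flow: register a secret and read it back. The `name`
# field is an assumed base stack-component field; secret names must
# not contain underscores (see register_secret above).
secrets_manager = AzureSecretsManager(
    name="azure_key_vault",
    key_vault_name="my-vault",
)
secrets_manager.register_secret(my_secret)  # my_secret: BaseSecretSchema
retrieved = secrets_manager.get_secret("my-secret")
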
delete_all_secrets(self)

Delete all existing secrets.

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets."""
    self._ensure_client_connected(self.key_vault_name)

    # List all secrets.
    for secret_property in self.CLIENT.list_properties_of_secrets():
        response = self.CLIENT.get_secret(secret_property.name)
        tags = response.properties.tags
        if tags and (ZENML_GROUP_KEY in tags or ZENML_SCHEMA_NAME in tags):
            logger.info(
                "Deleted key-value pair {`%s`, `***`} from secret " "`%s`",
                secret_property.name,
                tags.get(ZENML_GROUP_KEY),
            )
            self.CLIENT.begin_delete_secret(secret_property.name).result()
delete_secret(self, secret_name)

Delete an existing secret by name.

In Azure a secret is a single k-v pair. Within ZenML a secret is a collection of k-v pairs. As such, deleting a secret will iterate through all secrets and delete the ones with the secret_name as label.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to delete

required
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete an existing secret. by name.

    In Azure a secret is a single k-v pair. Within ZenML a secret is a
    collection of k-v pairs. As such, deleting a secret will iterate through
    all secrets and delete the ones with the secret_name as label.

    Args:
        secret_name: the name of the secret to delete
    """
    self._ensure_client_connected(self.key_vault_name)

    # Go through all Azure secrets and delete the ones with the secret_name
    #  as label.
    for secret_property in self.CLIENT.list_properties_of_secrets():
        response = self.CLIENT.get_secret(secret_property.name)
        tags = response.properties.tags
        if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
            self.CLIENT.begin_delete_secret(secret_property.name).result()
get_all_secret_keys(self)

Get all secret keys.

Returns:

Type Description
List[str]

A list of all secret keys

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Get all secret keys.

    Returns:
        A list of all secret keys
    """
    self._ensure_client_connected(self.key_vault_name)

    set_of_secrets = set()

    for secret_property in self.CLIENT.list_properties_of_secrets():
        tags = secret_property.tags
        if tags:
            set_of_secrets.add(tags.get(ZENML_GROUP_KEY))

    return list(set_of_secrets)
get_secret(self, secret_name)

Get a secret by its name.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to get

required

Returns:

Type Description
BaseSecretSchema

The secret.

Exceptions:

Type Description
RuntimeError

if the secret does not exist

ValueError

if one of the secret keys is named 'name'

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Get a secret by its name.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        RuntimeError: if the secret does not exist
        ValueError: if one of the secret keys is named 'name'
    """
    self._ensure_client_connected(self.key_vault_name)

    secret_contents = {}
    zenml_schema_name = ""

    for secret_property in self.CLIENT.list_properties_of_secrets():
        response = self.CLIENT.get_secret(secret_property.name)
        tags = response.properties.tags
        if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
            secret_key = remove_group_name_from_key(
                combined_key_name=response.name, group_name=secret_name
            )
            if secret_key == "name":
                raise ValueError("The secret's key cannot be 'name'.")

            secret_contents[secret_key] = response.value

            zenml_schema_name = tags.get(ZENML_SCHEMA_NAME)

    if not secret_contents:
        raise RuntimeError(f"No secrets found within the {secret_name}")

    secret_contents["name"] = secret_name

    secret_schema = SecretSchemaClassRegistry.get_class(
        secret_schema=zenml_schema_name
    )
    return secret_schema(**secret_contents)
register_secret(self, secret)

Registers a new secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to register

required

Exceptions:

Type Description
SecretExistsError

if the secret already exists

ValueError

if the secret name contains an underscore.

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Registers a new secret.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
        ValueError: if the secret name contains an underscore.
    """
    self._ensure_client_connected(self.key_vault_name)

    if "_" in secret.name:
        raise ValueError(
            f"The secret name `{secret.name}` contains an underscore. "
            f"This will cause issues with Azure. Please try again."
        )

    if secret.name in self.get_all_secret_keys():
        raise SecretExistsError(
            f"A Secret with the name '{secret.name}' already exists."
        )

    adjusted_content = prepend_group_name_to_keys(secret)

    for k, v in adjusted_content.items():
        # Create the secret; this stores the value under the
        # supplied name.
        azure_secret = self.CLIENT.set_secret(k, v)
        self.CLIENT.update_secret_properties(
            azure_secret.name,
            tags={
                ZENML_GROUP_KEY: secret.name,
                ZENML_SCHEMA_NAME: secret.TYPE,
            },
        )

        logger.debug("Created created secret: %s", azure_secret.name)
        logger.debug("Added value to secret.")
update_secret(self, secret)

Update an existing secret by creating new versions of the existing secrets.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to update

required
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret by creating new versions of the existing secrets.

    Args:
        secret: the secret to update
    """
    self._ensure_client_connected(self.key_vault_name)

    adjusted_content = prepend_group_name_to_keys(secret)

    for k, v in adjusted_content.items():
        self.CLIENT.set_secret(k, v)
        self.CLIENT.update_secret_properties(
            k,
            tags={
                ZENML_GROUP_KEY: secret.name,
                ZENML_SCHEMA_NAME: secret.TYPE,
            },
        )
prepend_group_name_to_keys(secret)

Adds the secret group name to the keys of each secret key-value pair.

This allows using the same key across multiple secrets.

Parameters:

Name Type Description Default
secret BaseSecretSchema

The ZenML Secret schema

required

Returns:

Type Description
Dict[str, str]

A dictionary with the secret keys prepended with the group name

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def prepend_group_name_to_keys(secret: BaseSecretSchema) -> Dict[str, str]:
    """Adds the secret group name to the keys of each secret key-value pair.

    This allows using the same key across multiple
    secrets.

    Args:
        secret: The ZenML Secret schema

    Returns:
        A dictionary with the secret keys prepended with the group name
    """
    return {f"{secret.name}-{k}": v for k, v in secret.content.items()}
remove_group_name_from_key(combined_key_name, group_name)

Removes the secret group name from the secret key.

Parameters:

Name Type Description Default
combined_key_name str

Full name as it is within the Azure secrets manager

required
group_name str

Group name (the ZenML Secret name)

required

Returns:

Type Description
str

The cleaned key

Exceptions:

Type Description
RuntimeError

If the group name is not found in the key name

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def remove_group_name_from_key(combined_key_name: str, group_name: str) -> str:
    """Removes the secret group name from the secret key.

    Args:
        combined_key_name: Full name as it is within the Azure secrets manager
        group_name: Group name (the ZenML Secret name)

    Returns:
        The cleaned key

    Raises:
        RuntimeError: If the group name is not found in the key name
    """
    if combined_key_name.startswith(f"{group_name}-"):
        return combined_key_name[len(f"{group_name}-") :]
    else:
        raise RuntimeError(
            f"Key-name `{combined_key_name}` does not have the "
            f"prefix `{group_name}`. Key could not be "
            f"extracted."
        )
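
A worked example of the naming scheme these two helpers implement; the stand-in object below only mimics the `name` and `content` fields the functions actually read:

from types import SimpleNamespace

# Stand-in for a BaseSecretSchema with the group (secret) name "mysql".
secret = SimpleNamespace(
    name="mysql", content={"user": "zenml", "password": "s3cret"}
)

print(prepend_group_name_to_keys(secret))
# {'mysql-user': 'zenml', 'mysql-password': 's3cret'}

print(remove_group_name_from_key("mysql-user", group_name="mysql"))
# 'user'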

step_operators special

Initialization of AzureML Step Operator integration.

azureml_step_operator

Implementation of the ZenML AzureML Step Operator.

AzureMLStepOperator (BaseStepOperator) pydantic-model

Step operator to run a step on AzureML.

This class defines code that can set up an AzureML environment and run the ZenML entrypoint command in it.

Attributes:

Name Type Description
subscription_id str

The Azure account's subscription ID

resource_group str

The resource group to which the AzureML workspace is deployed.

workspace_name str

The name of the AzureML Workspace.

compute_target_name str

The name of the configured ComputeTarget. An instance of it has to be created on the portal if it doesn't exist already.

environment_name Optional[str]

[Optional] The name of the environment if there already exists one.

docker_base_image Optional[str]

[Optional] The custom docker base image that the environment should use.

tenant_id Optional[str]

The Azure Tenant ID.

service_principal_id Optional[str]

The ID for the service principal that is created to allow apps to access secure resources.

service_principal_password Optional[str]

Password for the service principal.

Source code in zenml/integrations/azure/step_operators/azureml_step_operator.py
class AzureMLStepOperator(BaseStepOperator):
    """Step operator to run a step on AzureML.

    This class defines code that can set up an AzureML environment and run the
    ZenML entrypoint command in it.

    Attributes:
        subscription_id: The Azure account's subscription ID
        resource_group: The resource group to which the AzureML workspace
            is deployed.
        workspace_name: The name of the AzureML Workspace.
        compute_target_name: The name of the configured ComputeTarget.
            An instance of it has to be created on the portal if it doesn't
            exist already.
        environment_name: [Optional] The name of the environment if there
            already exists one.
        docker_base_image: [Optional] The custom docker base image that the
            environment should use.
        tenant_id: The Azure Tenant ID.
        service_principal_id: The ID for the service principal that is created
            to allow apps to access secure resources.
        service_principal_password: Password for the service principal.
    """

    subscription_id: str
    resource_group: str
    workspace_name: str
    compute_target_name: str

    # Environment
    environment_name: Optional[str] = None
    docker_base_image: Optional[str] = None

    # Service principal authentication
    # https://docs.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication#configure-a-service-principal
    tenant_id: Optional[str] = None
    service_principal_id: Optional[str] = None
    service_principal_password: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = AZUREML_STEP_OPERATOR_FLAVOR

    def _get_authentication(self) -> Optional[AbstractAuthentication]:
        """Returns the authentication object for the AzureML environment.

        Returns:
            The authentication object for the AzureML environment.
        """
        if (
            self.tenant_id
            and self.service_principal_id
            and self.service_principal_password
        ):
            return ServicePrincipalAuthentication(
                tenant_id=self.tenant_id,
                service_principal_id=self.service_principal_id,
                service_principal_password=self.service_principal_password,
            )
        return None

    def _prepare_environment(
        self, workspace: Workspace, requirements: List[str], run_name: str
    ) -> Environment:
        """Prepares the environment in which Azure will run all jobs.

        Args:
            workspace: The AzureML Workspace that has configuration
                for a storage account, container registry among other
                things.
            requirements: The list of requirements to be installed
                in the environment.
            run_name: The name of the pipeline run that can be used
                for naming environments and runs.

        Returns:
            The AzureML Environment object.
        """
        if self.environment_name:
            environment = Environment.get(
                workspace=workspace, name=self.environment_name
            )
            if not environment.python.conda_dependencies:
                environment.python.conda_dependencies = (
                    CondaDependencies.create(
                        python_version=ZenMLEnvironment.python_version()
                    )
                )

            for requirement in requirements:
                environment.python.conda_dependencies.add_pip_package(
                    requirement
                )
        else:
            environment = Environment(name=f"zenml-{run_name}")
            environment.python.conda_dependencies = CondaDependencies.create(
                pip_packages=requirements,
                python_version=ZenMLEnvironment.python_version(),
            )

            if self.docker_base_image:
                # replace the default azure base image
                environment.docker.base_image = self.docker_base_image

        environment_variables = {
            "ENV_ZENML_PREVENT_PIPELINE_EXECUTION": "True",
        }
        # set credentials to access azure storage
        for key in [
            "AZURE_STORAGE_ACCOUNT_KEY",
            "AZURE_STORAGE_ACCOUNT_NAME",
            "AZURE_STORAGE_CONNECTION_STRING",
            "AZURE_STORAGE_SAS_TOKEN",
        ]:
            value = os.getenv(key)
            if value:
                environment_variables[key] = value

        environment_variables[
            ENV_ZENML_CONFIG_PATH
        ] = f"./{CONTAINER_ZENML_CONFIG_DIR}"

        environment.environment_variables = environment_variables
        return environment

    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> None:
        """Launches a step on AzureML.

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            entrypoint_command: Command that executes the step.
            requirements: List of pip requirements that must be installed
                inside the step operator environment.
        """
        workspace = Workspace.get(
            subscription_id=self.subscription_id,
            resource_group=self.resource_group,
            name=self.workspace_name,
            auth=self._get_authentication(),
        )

        source_directory = get_source_root_path()
        config_path = os.path.join(source_directory, CONTAINER_ZENML_CONFIG_DIR)
        try:

            # Save a copy of the current global configuration with the
            # active profile contents into the build context, to have
            # the configured stacks accessible from within the Azure ML
            # environment.
            GlobalConfiguration().copy_active_configuration(
                config_path,
                load_config_path=f"./{CONTAINER_ZENML_CONFIG_DIR}",
            )

            environment = self._prepare_environment(
                workspace=workspace,
                requirements=requirements,
                run_name=run_name,
            )
            compute_target = ComputeTarget(
                workspace=workspace, name=self.compute_target_name
            )

            run_config = ScriptRunConfig(
                source_directory=source_directory,
                environment=environment,
                compute_target=compute_target,
                command=entrypoint_command,
            )

            experiment = Experiment(workspace=workspace, name=pipeline_name)
            run = experiment.submit(config=run_config)

        finally:
            # Clean up the temporary build files
            fileio.rmtree(config_path)

        run.display_name = run_name
        run.wait_for_completion(show_output=True)
launch(self, pipeline_name, run_name, requirements, entrypoint_command)

Launches a step on AzureML.

Parameters:

Name Type Description Default
pipeline_name str

Name of the pipeline which the step to be executed is part of.

required
run_name str

Name of the pipeline run which the step to be executed is part of.

required
entrypoint_command List[str]

Command that executes the step.

required
requirements List[str]

List of pip requirements that must be installed inside the step operator environment.

required
Source code in zenml/integrations/azure/step_operators/azureml_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
) -> None:
    """Launches a step on AzureML.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        entrypoint_command: Command that executes the step.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.
    """
    workspace = Workspace.get(
        subscription_id=self.subscription_id,
        resource_group=self.resource_group,
        name=self.workspace_name,
        auth=self._get_authentication(),
    )

    source_directory = get_source_root_path()
    config_path = os.path.join(source_directory, CONTAINER_ZENML_CONFIG_DIR)
    try:

        # Save a copy of the current global configuration with the
        # active profile contents into the build context, to have
        # the configured stacks accessible from within the Azure ML
        # environment.
        GlobalConfiguration().copy_active_configuration(
            config_path,
            load_config_path=f"./{CONTAINER_ZENML_CONFIG_DIR}",
        )

        environment = self._prepare_environment(
            workspace=workspace,
            requirements=requirements,
            run_name=run_name,
        )
        compute_target = ComputeTarget(
            workspace=workspace, name=self.compute_target_name
        )

        run_config = ScriptRunConfig(
            source_directory=source_directory,
            environment=environment,
            compute_target=compute_target,
            command=entrypoint_command,
        )

        experiment = Experiment(workspace=workspace, name=pipeline_name)
        run = experiment.submit(config=run_config)

    finally:
        # Clean up the temporary build files
        fileio.rmtree(config_path)

    run.display_name = run_name
    run.wait_for_completion(show_output=True)
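
A hedged configuration sketch: every value below is a placeholder, the field names match the pydantic attributes listed above, and a base `name` field is assumed from the stack component base class:

# Hypothetical configuration; real values depend on your Azure setup.
step_operator = AzureMLStepOperator(
    name="azureml",
    subscription_id="<subscription-id>",
    resource_group="my-resource-group",
    workspace_name="my-workspace",
    compute_target_name="cpu-cluster",
    # Optional service principal authentication; if these are unset,
    # _get_authentication() returns None and azureml-core falls back
    # to its default authentication.
    tenant_id="<tenant-id>",
    service_principal_id="<service-principal-id>",
    service_principal_password="<service-principal-password>",
)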

constants

Constants for ZenML integrations.

dash special

Initialization of the Dash integration.

DashIntegration (Integration)

Definition of Dash integration for ZenML.

Source code in zenml/integrations/dash/__init__.py
class DashIntegration(Integration):
    """Definition of Dash integration for ZenML."""

    NAME = DASH
    REQUIREMENTS = [
        "dash>=2.0.0",
        "dash-cytoscape>=0.3.0",
        "dash-bootstrap-components>=1.0.1",
        "jupyter-dash>=0.4.2",
    ]

visualizers special

Initialization of the Pipeline Run Visualizer.

pipeline_run_lineage_visualizer

Implementation of the pipeline run lineage visualizer.

PipelineRunLineageVisualizer (BasePipelineRunVisualizer)

Implementation of a lineage diagram via the dash and dash-cytoscape libraries.

Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
class PipelineRunLineageVisualizer(BasePipelineRunVisualizer):
    """Implementation of a lineage diagram via the dash and dash-cytoscape libraries."""

    ARTIFACT_PREFIX = "artifact_"
    STEP_PREFIX = "step_"
    STATUS_CLASS_MAPPING = {
        ExecutionStatus.CACHED: "green",
        ExecutionStatus.FAILED: "red",
        ExecutionStatus.RUNNING: "yellow",
        ExecutionStatus.COMPLETED: "blue",
    }

    def visualize(
        self,
        object: PipelineRunView,
        magic: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> dash.Dash:
        """Method to visualize pipeline runs via the Dash library.

        The layout puts every layer of the DAG in a column.

        Args:
            object: The pipeline run to visualize.
            magic: If True, the visualization is rendered in a magic mode.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The Dash application.
        """
        external_stylesheets = [
            dbc.themes.BOOTSTRAP,
            dbc.icons.BOOTSTRAP,
        ]
        if magic:
            if Environment.in_notebook():
                # Only import jupyter_dash in this case
                from jupyter_dash import JupyterDash  # noqa

                JupyterDash.infer_jupyter_proxy_config()

                app = JupyterDash(
                    __name__,
                    external_stylesheets=external_stylesheets,
                )
                mode = "inline"
            else:
                cli_utils.warning(
                    "Cannot set magic flag in non-notebook environments."
                )
                # Fall back to a regular Dash app so that `app` and
                # `mode` are always defined below.
                app = dash.Dash(
                    __name__,
                    external_stylesheets=external_stylesheets,
                )
                mode = None
        else:
            app = dash.Dash(
                __name__,
                external_stylesheets=external_stylesheets,
            )
            mode = None
        nodes, edges, first_step_id = [], [], None
        for step in object.steps:
            step_output_artifacts = list(step.outputs.values())
            execution_id = (
                step_output_artifacts[0].producer_step.id
                if step_output_artifacts
                else step.id
            )
            step_id = self.STEP_PREFIX + str(step.id)
            if first_step_id is None:
                first_step_id = step_id
            nodes.append(
                {
                    "data": {
                        "id": step_id,
                        "execution_id": execution_id,
                        "label": f"{execution_id} / {step.entrypoint_name}",
                        "entrypoint_name": step.entrypoint_name,  # redundant for consistency
                        "name": step.name,  # redundant for consistency
                        "type": "step",
                        "parameters": step.parameters,
                        "inputs": {k: v.uri for k, v in step.inputs.items()},
                        "outputs": {k: v.uri for k, v in step.outputs.items()},
                    },
                    "classes": self.STATUS_CLASS_MAPPING[step.status],
                }
            )

            for artifact_name, artifact in step.outputs.items():
                nodes.append(
                    {
                        "data": {
                            "id": self.ARTIFACT_PREFIX + str(artifact.id),
                            "execution_id": artifact.id,
                            "label": f"{artifact.id} / {artifact_name} ("
                            f"{artifact.data_type})",
                            "type": "artifact",
                            "name": artifact_name,
                            "is_cached": artifact.is_cached,
                            "artifact_type": artifact.type,
                            "artifact_data_type": artifact.data_type,
                            "parent_step_id": artifact.parent_step_id,
                            "producer_step_id": artifact.producer_step.id,
                            "uri": artifact.uri,
                        },
                        "classes": f"rectangle "
                        f"{self.STATUS_CLASS_MAPPING[step.status]}",
                    }
                )
                edges.append(
                    {
                        "data": {
                            "source": self.STEP_PREFIX + str(step.id),
                            "target": self.ARTIFACT_PREFIX + str(artifact.id),
                        },
                        "classes": f"edge-arrow "
                        f"{self.STATUS_CLASS_MAPPING[step.status]}"
                        + (" dashed" if artifact.is_cached else " solid"),
                    }
                )

            for artifact_name, artifact in step.inputs.items():
                edges.append(
                    {
                        "data": {
                            "source": self.ARTIFACT_PREFIX + str(artifact.id),
                            "target": self.STEP_PREFIX + str(step.id),
                        },
                        "classes": "edge-arrow "
                        + (
                            f"{self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]} dashed"
                            if artifact.is_cached
                            else f"{self.STATUS_CLASS_MAPPING[step.status]} solid"
                        ),
                    }
                )

        app.layout = dbc.Row(
            [
                dbc.Container(f"Run: {object.name}", class_name="h1"),
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                dbc.Row(
                                    [
                                        html.Span(
                                            [
                                                html.Span(
                                                    [
                                                        html.I(
                                                            className="bi bi-circle-fill me-1"
                                                        ),
                                                        "Step",
                                                    ],
                                                    className="me-2",
                                                ),
                                                html.Span(
                                                    [
                                                        html.I(
                                                            className="bi bi-square-fill me-1"
                                                        ),
                                                        "Artifact",
                                                    ],
                                                    className="me-4",
                                                ),
                                                dbc.Badge(
                                                    "Completed",
                                                    color=COLOR_BLUE,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Cached",
                                                    color=COLOR_GREEN,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Running",
                                                    color=COLOR_YELLOW,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Failed",
                                                    color=COLOR_RED,
                                                    className="me-1",
                                                ),
                                            ]
                                        ),
                                    ]
                                ),
                                dbc.Row(
                                    [
                                        cyto.Cytoscape(
                                            id="cytoscape",
                                            layout={
                                                "name": "breadthfirst",
                                                "roots": f'[id = "{first_step_id}"]',
                                            },
                                            elements=edges + nodes,
                                            stylesheet=STYLESHEET,
                                            style={
                                                "width": "100%",
                                                "height": "800px",
                                            },
                                            zoom=1,
                                        )
                                    ]
                                ),
                                dbc.Row(
                                    [
                                        dbc.Button(
                                            "Reset",
                                            id="bt-reset",
                                            color="primary",
                                            className="me-1",
                                        )
                                    ]
                                ),
                            ]
                        ),
                        dbc.Col(
                            [
                                dcc.Markdown(id="markdown-selected-node-data"),
                            ]
                        ),
                    ]
                ),
            ],
            className="p-5",
        )

        @app.callback(  # type: ignore[misc]
            Output("markdown-selected-node-data", "children"),
            Input("cytoscape", "selectedNodeData"),
        )
        def display_data(data_list: List[Dict[str, Any]]) -> str:
            """Callback for the text area below the graph.

            Args:
                data_list: The selected node data.

            Returns:
                str: The selected node data.
            """
            if data_list is None:
                return "Click on a node in the diagram."

            text = ""
            for data in data_list:
                text += f'## {data["execution_id"]} / {data["name"]}' + "\n\n"
                if data["type"] == "artifact":
                    for item in [
                        "artifact_data_type",
                        "is_cached",
                        "producer_step_id",
                        "parent_step_id",
                        "uri",
                    ]:
                        text += f"**{item}**: {data[item]}" + "\n\n"
                elif data["type"] == "step":
                    text += "### Inputs:" + "\n\n"
                    for k, v in data["inputs"].items():
                        text += f"**{k}**: {v}" + "\n\n"
                    text += "### Outputs:" + "\n\n"
                    for k, v in data["outputs"].items():
                        text += f"**{k}**: {v}" + "\n\n"
                    text += "### Params:"
                    for k, v in data["parameters"].items():
                        text += f"**{k}**: {v}" + "\n\n"
            return text

        @app.callback(  # type: ignore[misc]
            [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
            [Input("bt-reset", "n_clicks")],
        )
        def reset_layout(
            n_clicks: int,
        ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
            """Resets the layout.

            Args:
                n_clicks: The number of clicks on the reset button.

            Returns:
                The zoom and the elements.
            """
            logger.debug("Reset button clicked %s times.", n_clicks)
            return [1, edges + nodes]

        if mode is not None:
            app.run_server(mode=mode)
        else:
            app.run_server()
        return app
visualize(self, object, magic=False, *args, **kwargs)

Method to visualize pipeline runs via the Dash library.

The layout puts every layer of the DAG in a column.

Parameters:

Name Type Description Default
object PipelineRunView

The pipeline run to visualize.

required
magic bool

If True, the visualization is rendered in a magic mode.

False
*args Any

Additional positional arguments.

()
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
Dash

The Dash application.

Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
def visualize(
    self,
    object: PipelineRunView,
    magic: bool = False,
    *args: Any,
    **kwargs: Any,
) -> dash.Dash:
    """Method to visualize pipeline runs via the Dash library.

    The layout puts every layer of the DAG in a column.

    Args:
        object: The pipeline run to visualize.
        magic: If True, the visualization is rendered in a magic mode.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        The Dash application.
    """
    external_stylesheets = [
        dbc.themes.BOOTSTRAP,
        dbc.icons.BOOTSTRAP,
    ]
    if magic and Environment.in_notebook():
        # Only import jupyter_dash in this case
        from jupyter_dash import JupyterDash  # noqa

        JupyterDash.infer_jupyter_proxy_config()

        app = JupyterDash(
            __name__,
            external_stylesheets=external_stylesheets,
        )
        mode = "inline"
    else:
        if magic:
            cli_utils.warning(
                "Cannot set magic flag in non-notebook environments."
            )
        # Fall back to a standard Dash app so that `app` and `mode`
        # are always defined below.
        app = dash.Dash(
            __name__,
            external_stylesheets=external_stylesheets,
        )
        mode = None
    nodes, edges, first_step_id = [], [], None
    for step in object.steps:
        step_output_artifacts = list(step.outputs.values())
        execution_id = (
            step_output_artifacts[0].producer_step.id
            if step_output_artifacts
            else step.id
        )
        step_id = self.STEP_PREFIX + str(step.id)
        if first_step_id is None:
            first_step_id = step_id
        nodes.append(
            {
                "data": {
                    "id": step_id,
                    "execution_id": execution_id,
                    "label": f"{execution_id} / {step.entrypoint_name}",
                    "entrypoint_name": step.entrypoint_name,  # redundant for consistency
                    "name": step.name,  # redundant for consistency
                    "type": "step",
                    "parameters": step.parameters,
                    "inputs": {k: v.uri for k, v in step.inputs.items()},
                    "outputs": {k: v.uri for k, v in step.outputs.items()},
                },
                "classes": self.STATUS_CLASS_MAPPING[step.status],
            }
        )

        for artifact_name, artifact in step.outputs.items():
            nodes.append(
                {
                    "data": {
                        "id": self.ARTIFACT_PREFIX + str(artifact.id),
                        "execution_id": artifact.id,
                        "label": f"{artifact.id} / {artifact_name} ("
                        f"{artifact.data_type})",
                        "type": "artifact",
                        "name": artifact_name,
                        "is_cached": artifact.is_cached,
                        "artifact_type": artifact.type,
                        "artifact_data_type": artifact.data_type,
                        "parent_step_id": artifact.parent_step_id,
                        "producer_step_id": artifact.producer_step.id,
                        "uri": artifact.uri,
                    },
                    "classes": f"rectangle "
                    f"{self.STATUS_CLASS_MAPPING[step.status]}",
                }
            )
            edges.append(
                {
                    "data": {
                        "source": self.STEP_PREFIX + str(step.id),
                        "target": self.ARTIFACT_PREFIX + str(artifact.id),
                    },
                    "classes": f"edge-arrow "
                    f"{self.STATUS_CLASS_MAPPING[step.status]}"
                    + (" dashed" if artifact.is_cached else " solid"),
                }
            )

        for artifact_name, artifact in step.inputs.items():
            edges.append(
                {
                    "data": {
                        "source": self.ARTIFACT_PREFIX + str(artifact.id),
                        "target": self.STEP_PREFIX + str(step.id),
                    },
                    "classes": "edge-arrow "
                    + (
                        f"{self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]} dashed"
                        if artifact.is_cached
                        else f"{self.STATUS_CLASS_MAPPING[step.status]} solid"
                    ),
                }
            )

    app.layout = dbc.Row(
        [
            dbc.Container(f"Run: {object.name}", class_name="h1"),
            dbc.Row(
                [
                    dbc.Col(
                        [
                            dbc.Row(
                                [
                                    html.Span(
                                        [
                                            html.Span(
                                                [
                                                    html.I(
                                                        className="bi bi-circle-fill me-1"
                                                    ),
                                                    "Step",
                                                ],
                                                className="me-2",
                                            ),
                                            html.Span(
                                                [
                                                    html.I(
                                                        className="bi bi-square-fill me-1"
                                                    ),
                                                    "Artifact",
                                                ],
                                                className="me-4",
                                            ),
                                            dbc.Badge(
                                                "Completed",
                                                color=COLOR_BLUE,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Cached",
                                                color=COLOR_GREEN,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Running",
                                                color=COLOR_YELLOW,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Failed",
                                                color=COLOR_RED,
                                                className="me-1",
                                            ),
                                        ]
                                    ),
                                ]
                            ),
                            dbc.Row(
                                [
                                    cyto.Cytoscape(
                                        id="cytoscape",
                                        layout={
                                            "name": "breadthfirst",
                                            "roots": f'[id = "{first_step_id}"]',
                                        },
                                        elements=edges + nodes,
                                        stylesheet=STYLESHEET,
                                        style={
                                            "width": "100%",
                                            "height": "800px",
                                        },
                                        zoom=1,
                                    )
                                ]
                            ),
                            dbc.Row(
                                [
                                    dbc.Button(
                                        "Reset",
                                        id="bt-reset",
                                        color="primary",
                                        className="me-1",
                                    )
                                ]
                            ),
                        ]
                    ),
                    dbc.Col(
                        [
                            dcc.Markdown(id="markdown-selected-node-data"),
                        ]
                    ),
                ]
            ),
        ],
        className="p-5",
    )

    @app.callback(  # type: ignore[misc]
        Output("markdown-selected-node-data", "children"),
        Input("cytoscape", "selectedNodeData"),
    )
    def display_data(data_list: List[Dict[str, Any]]) -> str:
        """Callback for the text area below the graph.

        Args:
            data_list: The selected node data.

        Returns:
            str: The selected node data.
        """
        if data_list is None:
            return "Click on a node in the diagram."

        text = ""
        for data in data_list:
            text += f'## {data["execution_id"]} / {data["name"]}' + "\n\n"
            if data["type"] == "artifact":
                for item in [
                    "artifact_data_type",
                    "is_cached",
                    "producer_step_id",
                    "parent_step_id",
                    "uri",
                ]:
                    text += f"**{item}**: {data[item]}" + "\n\n"
            elif data["type"] == "step":
                text += "### Inputs:" + "\n\n"
                for k, v in data["inputs"].items():
                    text += f"**{k}**: {v}" + "\n\n"
                text += "### Outputs:" + "\n\n"
                for k, v in data["outputs"].items():
                    text += f"**{k}**: {v}" + "\n\n"
                text += "### Params:"
                for k, v in data["parameters"].items():
                    text += f"**{k}**: {v}" + "\n\n"
        return text

    @app.callback(  # type: ignore[misc]
        [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
        [Input("bt-reset", "n_clicks")],
    )
    def reset_layout(
        n_clicks: int,
    ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
        """Resets the layout.

        Args:
            n_clicks: The number of clicks on the reset button.

        Returns:
            The zoom and the elements.
        """
        logger.debug("Reset button clicked %s times.", n_clicks)
        return [1, edges + nodes]

    if mode is not None:
        app.run_server(mode=mode)
    else:
        app.run_server()
    return app
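
For orientation, here is a short usage sketch of this visualizer. It assumes the post-execution Repository API and a registered pipeline named "my_pipeline" with at least one completed run; the pipeline name is an illustrative assumption, not part of the source above.

from zenml.integrations.dash.visualizers.pipeline_run_lineage_visualizer import (
    PipelineRunLineageVisualizer,
)
from zenml.repository import Repository

# Fetch the latest run of a (hypothetical) pipeline named "my_pipeline".
latest_run = Repository().get_pipeline("my_pipeline").runs[-1]

# Serve the lineage DAG as a Dash app; pass magic=True inside a notebook
# to render it inline instead.
PipelineRunLineageVisualizer().visualize(latest_run)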

evidently special

Initialization of the Evidently integration.

The Evidently integration provides a way to monitor your models in production. It includes a way to detect data drift and different kinds of model performance issues.

The results of Evidently calculations can either be exported as an interactive dashboard (visualized as an html file or in your Jupyter notebook), or as a JSON file.

EvidentlyIntegration (Integration)

Evidently integration for ZenML.

Source code in zenml/integrations/evidently/__init__.py
class EvidentlyIntegration(Integration):
    """[Evidently](https://github.com/evidentlyai/evidently) integration for ZenML."""

    NAME = EVIDENTLY
    REQUIREMENTS = ["evidently==v0.1.41.dev0"]

steps special

Initialization of the Evidently Standard Steps.

evidently_profile

Implementation of the Evidently Profile Step.

EvidentlyProfileConfig (BaseDriftDetectionConfig) pydantic-model

Config class for Evidently profile steps.

column_mapping: properties of the DataFrame's columns used.
profile_section: a string that identifies the profile section to be used. The following are valid options supported by Evidently:
- "datadrift"
- "categoricaltargetdrift"
- "numericaltargetdrift"
- "classificationmodelperformance"
- "regressionmodelperformance"
- "probabilisticmodelperformance"

Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileConfig(BaseDriftDetectionConfig):
    """Config class for Evidently profile steps.

    column_mapping: properties of the DataFrame's columns used
    profile_section: a string that identifies the profile section to be used.
        The following are valid options supported by Evidently:
        - "datadrift"
        - "categoricaltargetdrift"
        - "numericaltargetdrift"
        - "classificationmodelperformance"
        - "regressionmodelperformance"
        - "probabilisticmodelperformance"
    """

    def get_profile_sections_and_tabs(
        self,
    ) -> Tuple[List[ProfileSection], List[Tab]]:
        """Get the profile sections and tabs to be used in the dashboard.

        Returns:
            A tuple of two lists of profile sections and tabs.

        Raises:
            ValueError: if the profile_section is not supported.
        """
        try:
            return (
                [
                    profile_mapper[profile]()
                    for profile in self.profile_sections
                ],
                [
                    dashboard_mapper[profile]()
                    for profile in self.profile_sections
                ],
            )
        except KeyError:
            nl = "\n"
            raise ValueError(
                f"Invalid profile section: {self.profile_sections} \n\n"
                f"Valid and supported options are: {nl}- "
                f'{f"{nl}- ".join(list(profile_mapper.keys()))}'
            )

    column_mapping: Optional[ColumnMapping]
    profile_sections: Sequence[str]
get_profile_sections_and_tabs(self)

Get the profile sections and tabs to be used in the dashboard.

Returns:

Type Description
Tuple[List[evidently.model_profile.sections.base_profile_section.ProfileSection], List[evidently.dashboard.tabs.base_tab.Tab]]

A tuple of two lists of profile sections and tabs.

Exceptions:

Type Description
ValueError

if the profile_section is not supported.

Source code in zenml/integrations/evidently/steps/evidently_profile.py
def get_profile_sections_and_tabs(
    self,
) -> Tuple[List[ProfileSection], List[Tab]]:
    """Get the profile sections and tabs to be used in the dashboard.

    Returns:
        A tuple of two lists of profile sections and tabs.

    Raises:
        ValueError: if the profile_section is not supported.
    """
    try:
        return (
            [
                profile_mapper[profile]()
                for profile in self.profile_sections
            ],
            [
                dashboard_mapper[profile]()
                for profile in self.profile_sections
            ],
        )
    except KeyError:
        nl = "\n"
        raise ValueError(
            f"Invalid profile section: {self.profile_sections} \n\n"
            f"Valid and supported options are: {nl}- "
            f'{f"{nl}- ".join(list(profile_mapper.keys()))}'
        )
EvidentlyProfileStep (BaseDriftDetectionStep)

Step implementation of an Evidently Profile Step.

Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileStep(BaseDriftDetectionStep):
    """Step implementation implementing an Evidently Profile Step."""

    OUTPUT_SPEC = {
        "profile": DataAnalysisArtifact,
        "dashboard": DataAnalysisArtifact,
    }

    def entrypoint(  # type: ignore[override]
        self,
        reference_dataset: pd.DataFrame,
        comparison_dataset: pd.DataFrame,
        config: EvidentlyProfileConfig,
    ) -> Output(  # type:ignore[valid-type]
        profile=dict, dashboard=str
    ):
        """Main entrypoint for the Evidently categorical target drift detection step.

        Args:
            reference_dataset: a Pandas DataFrame
            comparison_dataset: a Pandas DataFrame of new data you wish to
                compare against the reference data
            config: the configuration for the step

        Returns:
            profile: dictionary report extracted from an Evidently Profile
              generated for the data drift
            dashboard: HTML report extracted from an Evidently Dashboard
              generated for the data drift
        """
        sections, tabs = config.get_profile_sections_and_tabs()
        data_drift_dashboard = Dashboard(tabs=tabs)
        data_drift_dashboard.calculate(
            reference_dataset,
            comparison_dataset,
            column_mapping=config.column_mapping or None,
        )
        data_drift_profile = Profile(sections=sections)
        data_drift_profile.calculate(
            reference_dataset,
            comparison_dataset,
            column_mapping=config.column_mapping or None,
        )
        return [data_drift_profile.object(), data_drift_dashboard.html()]
CONFIG_CLASS (BaseDriftDetectionConfig) pydantic-model

Config class for Evidently profile steps.

column_mapping: properties of the DataFrame's columns used.
profile_section: a string that identifies the profile section to be used. The following are valid options supported by Evidently:
- "datadrift"
- "categoricaltargetdrift"
- "numericaltargetdrift"
- "classificationmodelperformance"
- "regressionmodelperformance"
- "probabilisticmodelperformance"

Source code in zenml/integrations/evidently/steps/evidently_profile.py
class EvidentlyProfileConfig(BaseDriftDetectionConfig):
    """Config class for Evidently profile steps.

    column_mapping: properties of the DataFrame's columns used
    profile_section: a string that identifies the profile section to be used.
        The following are valid options supported by Evidently:
        - "datadrift"
        - "categoricaltargetdrift"
        - "numericaltargetdrift"
        - "classificationmodelperformance"
        - "regressionmodelperformance"
        - "probabilisticmodelperformance"
    """

    def get_profile_sections_and_tabs(
        self,
    ) -> Tuple[List[ProfileSection], List[Tab]]:
        """Get the profile sections and tabs to be used in the dashboard.

        Returns:
            A tuple of two lists of profile sections and tabs.

        Raises:
            ValueError: if the profile_section is not supported.
        """
        try:
            return (
                [
                    profile_mapper[profile]()
                    for profile in self.profile_sections
                ],
                [
                    dashboard_mapper[profile]()
                    for profile in self.profile_sections
                ],
            )
        except KeyError:
            nl = "\n"
            raise ValueError(
                f"Invalid profile section: {self.profile_sections} \n\n"
                f"Valid and supported options are: {nl}- "
                f'{f"{nl}- ".join(list(profile_mapper.keys()))}'
            )

    column_mapping: Optional[ColumnMapping]
    profile_sections: Sequence[str]
get_profile_sections_and_tabs(self)

Get the profile sections and tabs to be used in the dashboard.

Returns:

Type Description
Tuple[List[evidently.model_profile.sections.base_profile_section.ProfileSection], List[evidently.dashboard.tabs.base_tab.Tab]]

A tuple of two lists of profile sections and tabs.

Exceptions:

Type Description
ValueError

if the profile_section is not supported.

Source code in zenml/integrations/evidently/steps/evidently_profile.py
def get_profile_sections_and_tabs(
    self,
) -> Tuple[List[ProfileSection], List[Tab]]:
    """Get the profile sections and tabs to be used in the dashboard.

    Returns:
        A tuple of two lists of profile sections and tabs.

    Raises:
        ValueError: if the profile_section is not supported.
    """
    try:
        return (
            [
                profile_mapper[profile]()
                for profile in self.profile_sections
            ],
            [
                dashboard_mapper[profile]()
                for profile in self.profile_sections
            ],
        )
    except KeyError:
        nl = "\n"
        raise ValueError(
            f"Invalid profile section: {self.profile_sections} \n\n"
            f"Valid and supported options are: {nl}- "
            f'{f"{nl}- ".join(list(profile_mapper.keys()))}'
        )
entrypoint(self, reference_dataset, comparison_dataset, config)

Main entrypoint for the Evidently categorical target drift detection step.

Parameters:

Name Type Description Default
reference_dataset DataFrame

a Pandas DataFrame

required
comparison_dataset DataFrame

a Pandas DataFrame of new data you wish to compare against the reference data

required
config EvidentlyProfileConfig

the configuration for the step

required

Returns:

Type Description
profile

dictionary report extracted from an Evidently Profile generated for the data drift

dashboard

HTML report extracted from an Evidently Dashboard generated for the data drift

Source code in zenml/integrations/evidently/steps/evidently_profile.py
def entrypoint(  # type: ignore[override]
    self,
    reference_dataset: pd.DataFrame,
    comparison_dataset: pd.DataFrame,
    config: EvidentlyProfileConfig,
) -> Output(  # type:ignore[valid-type]
    profile=dict, dashboard=str
):
    """Main entrypoint for the Evidently categorical target drift detection step.

    Args:
        reference_dataset: a Pandas DataFrame
        comparison_dataset: a Pandas DataFrame of new data you wish to
            compare against the reference data
        config: the configuration for the step

    Returns:
        profile: dictionary report extracted from an Evidently Profile
          generated for the data drift
        dashboard: HTML report extracted from an Evidently Dashboard
          generated for the data drift
    """
    sections, tabs = config.get_profile_sections_and_tabs()
    data_drift_dashboard = Dashboard(tabs=tabs)
    data_drift_dashboard.calculate(
        reference_dataset,
        comparison_dataset,
        column_mapping=config.column_mapping or None,
    )
    data_drift_profile = Profile(sections=sections)
    data_drift_profile.calculate(
        reference_dataset,
        comparison_dataset,
        column_mapping=config.column_mapping or None,
    )
    return [data_drift_profile.object(), data_drift_dashboard.html()]
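
To make the configuration options above concrete, here is a hedged wiring sketch. It assumes the step and config classes are re-exported from the steps package and follows the usual ZenML step/config instantiation pattern; the variable names and the two DataFrame-producing upstream steps are illustrative assumptions.

from zenml.integrations.evidently.steps import (
    EvidentlyProfileConfig,
    EvidentlyProfileStep,
)

# Instantiate the standard step, requesting only a data drift profile
# (one of the valid profile section names listed above).
drift_detector = EvidentlyProfileStep(
    config=EvidentlyProfileConfig(
        column_mapping=None,
        profile_sections=["datadrift"],
    )
)

# Inside a pipeline definition, the step would then be called with two
# DataFrame-producing upstream steps, e.g.:
#   profile, dashboard = drift_detector(
#       reference_dataset=reference, comparison_dataset=comparison
#   )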

visualizers special

Initialization for Evidently visualizer.

evidently_visualizer

Implementation of the Evidently visualizer.

EvidentlyVisualizer (BaseStepVisualizer)

The implementation of an Evidently Visualizer.

Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
class EvidentlyVisualizer(BaseStepVisualizer):
    """The implementation of an Evidently Visualizer."""

    @abstractmethod
    def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
        """Method to visualize components.

        Args:
            object: StepView fetched from run.get_step().
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        for artifact_view in object.outputs.values():
            # filter out anything but data analysis artifacts
            if (
                artifact_view.type == DataAnalysisArtifact.__name__
                and artifact_view.data_type == "builtins.str"
            ):
                artifact = artifact_view.read()
                self.generate_facet(artifact)

    def generate_facet(self, html_: str) -> None:
        """Generate a Facet Overview.

        Args:
            html_: HTML represented as a string.
        """
        if Environment.in_notebook() or Environment.in_google_colab():
            from IPython.core.display import HTML, display

            display(HTML(html_))
        else:
            logger.warning(
                "The magic functions are only usable in a Jupyter notebook."
            )
            with tempfile.NamedTemporaryFile(
                mode="w", delete=False, suffix=".html", encoding="utf-8"
            ) as f:
                f.write(html_)
                url = f"file:///{f.name}"
                logger.info("Opening %s in a new browser.." % f.name)
                webbrowser.open(url, new=2)
generate_facet(self, html_)

Generate a Facet Overview.

Parameters:

Name Type Description Default
html_ str

HTML represented as a string.

required
Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
def generate_facet(self, html_: str) -> None:
    """Generate a Facet Overview.

    Args:
        html_: HTML represented as a string.
    """
    if Environment.in_notebook() or Environment.in_google_colab():
        from IPython.core.display import HTML, display

        display(HTML(html_))
    else:
        logger.warning(
            "The magic functions are only usable in a Jupyter notebook."
        )
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".html", encoding="utf-8"
        ) as f:
            f.write(html_)
            url = f"file:///{f.name}"
            logger.info("Opening %s in a new browser.." % f.name)
            webbrowser.open(url, new=2)
visualize(self, object, *args, **kwargs)

Method to visualize components.

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}
Source code in zenml/integrations/evidently/visualizers/evidently_visualizer.py
@abstractmethod
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
    """Method to visualize components.

    Args:
        object: StepView fetched from run.get_step().
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.
    """
    for artifact_view in object.outputs.values():
        # filter out anything but data analysis artifacts
        if (
            artifact_view.type == DataAnalysisArtifact.__name__
            and artifact_view.data_type == "builtins.str"
        ):
            artifact = artifact_view.read()
            self.generate_facet(artifact)
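
As a hedged usage sketch, this visualizer can be pointed at the step that produced an Evidently dashboard in the latest run; the pipeline and step names below are illustrative assumptions, not part of the source above.

from zenml.integrations.evidently.visualizers.evidently_visualizer import (
    EvidentlyVisualizer,
)
from zenml.repository import Repository

# Fetch the drift detection step of the latest run (names assumed).
step_view = (
    Repository()
    .get_pipeline("drift_detection_pipeline")
    .runs[-1]
    .get_step("drift_detector")
)

# Renders the HTML inline in a notebook, otherwise opens it in a browser.
EvidentlyVisualizer().visualize(step_view)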

facets special

Facets integration for ZenML.

The Facets integration provides a simple way to visualize post-execution objects like PipelineView, PipelineRunView and StepView. These objects can be extended using the BaseVisualization class. This integration requires that facets-overview be installed in your Python environment.

FacetsIntegration (Integration)

Definition of Facet integration for ZenML.

Source code in zenml/integrations/facets/__init__.py
class FacetsIntegration(Integration):
    """Definition of [Facet](https://pair-code.github.io/facets/) integration for ZenML."""

    NAME = FACETS
    REQUIREMENTS = ["facets-overview>=1.0.0", "IPython"]

visualizers special

Initialization of the Facet Visualizer.

facet_statistics_visualizer

Implementation of the Facet Statistics Visualizer.

FacetStatisticsVisualizer (BaseStepVisualizer)

The implementation of a Facets Statistics Visualizer.

Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
class FacetStatisticsVisualizer(BaseStepVisualizer):
    """The base implementation of a ZenML Visualizer."""

    @abstractmethod
    def visualize(
        self, object: StepView, magic: bool = False, *args: Any, **kwargs: Any
    ) -> None:
        """Method to visualize components.

        Args:
            object: StepView fetched from run.get_step().
            magic: Whether to render in a Jupyter notebook or not.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        datasets = []
        for output_name, artifact_view in object.outputs.items():
            df = artifact_view.read()
            if type(df) is not pd.DataFrame:
                logger.warning(
                    "`%s` is not a pd.DataFrame. You can only visualize "
                    "statistics of steps that output pandas DataFrames. "
                    "Skipping this output.." % output_name
                )
            else:
                datasets.append({"name": output_name, "table": df})
        h = self.generate_html(datasets)
        self.generate_facet(h, magic)

    def generate_html(self, datasets: List[Dict[Text, pd.DataFrame]]) -> str:
        """Generates html for facet.

        Args:
            datasets: List of dicts of DataFrames to be visualized as stats.

        Returns:
            HTML template with proto string embedded.
        """
        proto = GenericFeatureStatisticsGenerator().ProtoFromDataFrames(
            datasets
        )
        protostr = base64.b64encode(proto.SerializeToString()).decode("utf-8")

        template = os.path.join(
            os.path.abspath(os.path.dirname(__file__)),
            "stats.html",
        )
        html_template = io_utils.read_file_contents_as_string(template)

        html_ = html_template.replace("protostr", protostr)
        return html_

    def generate_facet(self, html_: str, magic: bool = False) -> None:
        """Generate a Facet Overview.

        Args:
            html_: HTML represented as a string.
            magic: Whether to magically materialize facet in a notebook.

        Raises:
            EnvironmentError: If magic is True and not in a notebook.
        """
        if magic:
            if not (Environment.in_notebook() or Environment.in_google_colab()):
                raise EnvironmentError(
                    "The magic functions are only usable in a Jupyter notebook."
                )
            display(HTML(html_))
        else:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
                io_utils.write_file_contents_as_string(f.name, html_)
                url = f"file:///{f.name}"
                logger.info("Opening %s in a new browser.." % f.name)
                webbrowser.open(url, new=2)
generate_facet(self, html_, magic=False)

Generate a Facet Overview.

Parameters:

Name Type Description Default
html_ str

HTML represented as a string.

required
magic bool

Whether to magically materialize facet in a notebook.

False

Exceptions:

Type Description
EnvironmentError

If magic is True and not in a notebook.

Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def generate_facet(self, html_: str, magic: bool = False) -> None:
    """Generate a Facet Overview.

    Args:
        html_: HTML represented as a string.
        magic: Whether to magically materialize facet in a notebook.

    Raises:
        EnvironmentError: If magic is True and not in a notebook.
    """
    if magic:
        if not (Environment.in_notebook() or Environment.in_google_colab()):
            raise EnvironmentError(
                "The magic functions are only usable in a Jupyter notebook."
            )
        display(HTML(html_))
    else:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
            io_utils.write_file_contents_as_string(f.name, html_)
            url = f"file:///{f.name}"
            logger.info("Opening %s in a new browser.." % f.name)
            webbrowser.open(url, new=2)
generate_html(self, datasets)

Generates html for facet.

Parameters:

Name Type Description Default
datasets List[Dict[str, pandas.core.frame.DataFrame]]

List of dicts of DataFrames to be visualized as stats.

required

Returns:

Type Description
str

HTML template with proto string embedded.

Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
def generate_html(self, datasets: List[Dict[Text, pd.DataFrame]]) -> str:
    """Generates html for facet.

    Args:
        datasets: List of dicts of DataFrames to be visualized as stats.

    Returns:
        HTML template with proto string embedded.
    """
    proto = GenericFeatureStatisticsGenerator().ProtoFromDataFrames(
        datasets
    )
    protostr = base64.b64encode(proto.SerializeToString()).decode("utf-8")

    template = os.path.join(
        os.path.abspath(os.path.dirname(__file__)),
        "stats.html",
    )
    html_template = io_utils.read_file_contents_as_string(template)

    html_ = html_template.replace("protostr", protostr)
    return html_
visualize(self, object, magic=False, *args, **kwargs)

Method to visualize components.

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
magic bool

Whether to render in a Jupyter notebook or not.

False
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}
Source code in zenml/integrations/facets/visualizers/facet_statistics_visualizer.py
@abstractmethod
def visualize(
    self, object: StepView, magic: bool = False, *args: Any, **kwargs: Any
) -> None:
    """Method to visualize components.

    Args:
        object: StepView fetched from run.get_step().
        magic: Whether to render in a Jupyter notebook or not.
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.
    """
    datasets = []
    for output_name, artifact_view in object.outputs.items():
        df = artifact_view.read()
        if type(df) is not pd.DataFrame:
            logger.warning(
                "`%s` is not a pd.DataFrame. You can only visualize "
                "statistics of steps that output pandas DataFrames. "
                "Skipping this output.." % output_name
            )
        else:
            datasets.append({"name": output_name, "table": df})
    h = self.generate_html(datasets)
    self.generate_facet(h, magic)
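
For reference, a hedged usage sketch of this visualizer on a step that outputs pandas DataFrames; the pipeline and step names are illustrative assumptions, not part of the source above.

from zenml.integrations.facets.visualizers.facet_statistics_visualizer import (
    FacetStatisticsVisualizer,
)
from zenml.repository import Repository

# Fetch a DataFrame-producing step from the latest run (names assumed).
step_view = (
    Repository()
    .get_pipeline("my_pipeline")
    .runs[-1]
    .get_step("importer")
)

# Pass magic=True to render the statistics inline in a Jupyter notebook.
FacetStatisticsVisualizer().visualize(step_view)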

feast special

Initialization for Feast integration.

The Feast integration offers a way to connect to a Feast Feature Store. ZenML implements a dedicated stack component that you can access from within your ZenML steps.

FeastIntegration (Integration)

Definition of Feast integration for ZenML.

Source code in zenml/integrations/feast/__init__.py
class FeastIntegration(Integration):
    """Definition of Feast integration for ZenML."""

    NAME = FEAST
    REQUIREMENTS = ["feast[redis]>=0.19.4", "redis-server"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Feast integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=FEAST_FEATURE_STORE_FLAVOR,
                source="zenml.integrations.feast.feature_store.FeastFeatureStore",
                type=StackComponentType.FEATURE_STORE,
                integration=cls.NAME,
            )
        ]
flavors() classmethod

Declare the stack component flavors for the Feast integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/feast/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Feast integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=FEAST_FEATURE_STORE_FLAVOR,
            source="zenml.integrations.feast.feature_store.FeastFeatureStore",
            type=StackComponentType.FEATURE_STORE,
            integration=cls.NAME,
        )
    ]

feature_stores special

Feast Feature Store integration for ZenML.

Feature stores allow data teams to serve data via an offline store and an online low-latency store where data is kept in sync between the two. It also offers a centralized registry where features (and feature schemas) are stored for use within a team or wider organization. Feature stores are a relatively recent addition to commonly-used machine learning stacks. Feast is a leading open-source feature store, first developed by Gojek in collaboration with Google.

feast_feature_store

Implementation of the Feast Feature Store for ZenML.

FeastFeatureStore (BaseFeatureStore) pydantic-model

Class to interact with the Feast feature store.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
class FeastFeatureStore(BaseFeatureStore):
    """Class to interact with the Feast feature store."""

    FLAVOR: ClassVar[str] = FEAST_FEATURE_STORE_FLAVOR

    online_host: str = "localhost"
    online_port: int = 6379
    feast_repo: str

    def _validate_connection(self) -> None:
        """Validates the connection to the feature store.

        Raises:
            ConnectionError: If the online component (Redis) is not available.
        """
        client = redis.Redis(host=self.online_host, port=self.online_port)
        try:
            client.ping()
        except redis.exceptions.ConnectionError as e:
            raise redis.exceptions.ConnectionError(
                "Could not connect to feature store's online component. "
                "Please make sure that Redis is running."
            ) from e

    def get_historical_features(
        self,
        entity_df: Union[pd.DataFrame, str],
        features: List[str],
        full_feature_names: bool = False,
    ) -> pd.DataFrame:
        """Returns the historical features for training or batch scoring.

        Args:
            entity_df: The entity DataFrame or entity name.
            features: The features to retrieve.
            full_feature_names: Whether to return the full feature names.

        Raises:
            ConnectionError: If the online component (Redis) is not available.

        Returns:
            The historical features as a Pandas DataFrame.
        """
        fs = FeatureStore(repo_path=self.feast_repo)

        return fs.get_historical_features(
            entity_df=entity_df,
            features=features,
            full_feature_names=full_feature_names,
        ).to_df()

    def get_online_features(
        self,
        entity_rows: List[Dict[str, Any]],
        features: List[str],
        full_feature_names: bool = False,
    ) -> Dict[str, Any]:
        """Returns the latest online feature data.

        Args:
            entity_rows: The entity rows to retrieve.
            features: The features to retrieve.
            full_feature_names: Whether to return the full feature names.

        Raises:
            ConnectionError: If the online component (Redis) is not available.

        Returns:
            The latest online feature data as a dictionary.
        """
        self._validate_connection()
        fs = FeatureStore(repo_path=self.feast_repo)

        return fs.get_online_features(  # type: ignore[no-any-return]
            entity_rows=entity_rows,
            features=features,
            full_feature_names=full_feature_names,
        ).to_dict()

    def get_data_sources(self) -> List[str]:
        """Returns the data sources' names.

        Raises:
            ConnectionError: If the online component (Redis) is not available.

        Returns:
            The data sources' names.
        """
        self._validate_connection()
        fs = FeatureStore(repo_path=self.feast_repo)
        return [ds.name for ds in fs.list_data_sources()]

    def get_entities(self) -> List[str]:
        """Returns the entity names.

        Raises:
            ConnectionError: If the online component (Redis) is not available.

        Returns:
            The entity names.
        """
        self._validate_connection()
        fs = FeatureStore(repo_path=self.feast_repo)
        return [ds.name for ds in fs.list_entities()]

    def get_feature_services(self) -> List[str]:
        """Returns the feature service names.

        Raises:
            ConnectionError: If the online component (Redis) is not available.

        Returns:
            The feature service names.
        """
        self._validate_connection()
        fs = FeatureStore(repo_path=self.feast_repo)
        return [ds.name for ds in fs.list_feature_services()]

    def get_feature_views(self) -> List[str]:
        """Returns the feature view names.

        Raises:
            ConnectionError: If the online component (Redis) is not available.

        Returns:
            The feature view names.
        """
        self._validate_connection()
        fs = FeatureStore(repo_path=self.feast_repo)
        return [ds.name for ds in fs.list_feature_views()]

    def get_project(self) -> str:
        """Returns the project name.

        Raises:
            ConnectionError: If the online component (Redis) is not available.

        Returns:
            The project name.
        """
        fs = FeatureStore(repo_path=self.feast_repo)
        return str(fs.project)

    def get_registry(self) -> Registry:
        """Returns the feature store registry.

        Raises:
            ConnectionError: If the online component (Redis) is not available.

        Returns:
            The registry.
        """
        fs: FeatureStore = FeatureStore(repo_path=self.feast_repo)
        return fs.registry

    def get_feast_version(self) -> str:
        """Returns the version of Feast used.

        Raises:
            ConnectionError: If the online component (Redis) is not available.

        Returns:
            The version of Feast currently being used.
        """
        fs = FeatureStore(repo_path=self.feast_repo)
        return str(fs.version())
get_data_sources(self)

Returns the data sources' names.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
List[str]

The data sources' names.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_data_sources(self) -> List[str]:
    """Returns the data sources' names.

    Raises:
        ConnectionError: If the online component (Redis) is not available.

    Returns:
        The data sources' names.
    """
    self._validate_connection()
    fs = FeatureStore(repo_path=self.feast_repo)
    return [ds.name for ds in fs.list_data_sources()]
get_entities(self)

Returns the entity names.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
List[str]

The entity names.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_entities(self) -> List[str]:
    """Returns the entity names.

    Raises:
        ConnectionError: If the online component (Redis) is not available.

    Returns:
        The entity names.
    """
    self._validate_connection()
    fs = FeatureStore(repo_path=self.feast_repo)
    return [ds.name for ds in fs.list_entities()]
get_feast_version(self)

Returns the version of Feast used.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
str

The version of Feast currently being used.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_feast_version(self) -> str:
    """Returns the version of Feast used.

    Raises:
        ConnectionError: If the online component (Redis) is not available.

    Returns:
        The version of Feast currently being used.
    """
    fs = FeatureStore(repo_path=self.feast_repo)
    return str(fs.version())
get_feature_services(self)

Returns the feature service names.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
List[str]

The feature service names.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_feature_services(self) -> List[str]:
    """Returns the feature service names.

    Raises:
        ConnectionError: If the online component (Redis) is not available.

    Returns:
        The feature service names.
    """
    self._validate_connection()
    fs = FeatureStore(repo_path=self.feast_repo)
    return [ds.name for ds in fs.list_feature_services()]
get_feature_views(self)

Returns the feature view names.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
List[str]

The feature view names.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_feature_views(self) -> List[str]:
    """Returns the feature view names.

    Raises:
        ConnectionError: If the online component (Redis) is not available.

    Returns:
        The feature view names.
    """
    self._validate_connection()
    fs = FeatureStore(repo_path=self.feast_repo)
    return [ds.name for ds in fs.list_feature_views()]
get_historical_features(self, entity_df, features, full_feature_names=False)

Returns the historical features for training or batch scoring.

Parameters:

Name Type Description Default
entity_df Union[pandas.core.frame.DataFrame, str]

The entity DataFrame or entity name.

required
features List[str]

The features to retrieve.

required
full_feature_names bool

Whether to return the full feature names.

False

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
DataFrame

The historical features as a Pandas DataFrame.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_historical_features(
    self,
    entity_df: Union[pd.DataFrame, str],
    features: List[str],
    full_feature_names: bool = False,
) -> pd.DataFrame:
    """Returns the historical features for training or batch scoring.

    Args:
        entity_df: The entity DataFrame or entity name.
        features: The features to retrieve.
        full_feature_names: Whether to return the full feature names.

    Raises:
        ConnectionError: If the online component (Redis) is not available.

    Returns:
        The historical features as a Pandas DataFrame.
    """
    fs = FeatureStore(repo_path=self.feast_repo)

    return fs.get_historical_features(
        entity_df=entity_df,
        features=features,
        full_feature_names=full_feature_names,
    ).to_df()
get_online_features(self, entity_rows, features, full_feature_names=False)

Returns the latest online feature data.

Parameters:

Name Type Description Default
entity_rows List[Dict[str, Any]]

The entity rows to retrieve.

required
features List[str]

The features to retrieve.

required
full_feature_names bool

Whether to return the full feature names.

False

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
Dict[str, Any]

The latest online feature data as a dictionary.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_online_features(
    self,
    entity_rows: List[Dict[str, Any]],
    features: List[str],
    full_feature_names: bool = False,
) -> Dict[str, Any]:
    """Returns the latest online feature data.

    Args:
        entity_rows: The entity rows to retrieve.
        features: The features to retrieve.
        full_feature_names: Whether to return the full feature names.

    Raises:
        ConnectionError: If the online component (Redis) is not available.

    Returns:
        The latest online feature data as a dictionary.
    """
    self._validate_connection()
    fs = FeatureStore(repo_path=self.feast_repo)

    return fs.get_online_features(  # type: ignore[no-any-return]
        entity_rows=entity_rows,
        features=features,
        full_feature_names=full_feature_names,
    ).to_dict()
get_project(self)

Returns the project name.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
str

The project name.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_project(self) -> str:
    """Returns the project name.

    Raises:
        ConnectionError: If the online component (Redis) is not available.

    Returns:
        The project name.
    """
    fs = FeatureStore(repo_path=self.feast_repo)
    return str(fs.project)
get_registry(self)

Returns the feature store registry.

Exceptions:

Type Description
ConnectionError

If the online component (Redis) is not available.

Returns:

Type Description
Registry

The registry.

Source code in zenml/integrations/feast/feature_stores/feast_feature_store.py
def get_registry(self) -> Registry:
    """Returns the feature store registry.

    Raises:
        ConnectionError: If the online component (Redis) is not available.

    Returns:
        The registry.
    """
    fs: FeatureStore = FeatureStore(repo_path=self.feast_repo)
    return fs.registry
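
To show how this component is typically reached from inside a pipeline, here is a hedged sketch that pulls online features through the active stack. It assumes that StepContext exposes the active stack as context.stack; the step name, entity key, and feature reference are illustrative assumptions.

from typing import Any, Dict

from zenml.steps import StepContext, step


@step
def feature_loader(context: StepContext) -> Dict[str, Any]:
    """Loads online features from the stack's Feast feature store."""
    # Assumes the active stack was registered with a Feast feature store.
    feature_store = context.stack.feature_store
    return feature_store.get_online_features(
        entity_rows=[{"driver_id": 1001}],  # illustrative entity key
        features=["driver_hourly_stats:avg_daily_trips"],  # illustrative ref
    )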

gcp special

Initialization of the GCP ZenML integration.

The GCP integration submodule provides a way to run ZenML pipelines in a cloud environment. Specifically, it allows the use of cloud artifact stores and metadata stores, and provides an I/O module to handle file operations on Google Cloud Storage (GCS).

Additionally, the GCP secrets manager integration submodule provides a way to access the GCP secrets manager from within your ZenML Pipeline runs.

The Vertex AI integration submodule provides a way to run ZenML pipelines in a Vertex AI environment.
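
As a small, hedged illustration of the file-operations side: once a GCS-backed artifact store is part of the active stack, ZenML's I/O utilities can address gs:// URIs directly. The bucket and object names below are assumptions, not part of the source.

from zenml.io import fileio

# Read an object from a (hypothetical) bucket through the gs:// scheme.
with fileio.open("gs://my-zenml-bucket/artifacts/example.txt", "r") as f:
    print(f.read())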

GcpIntegration (Integration)

Definition of Google Cloud Platform integration for ZenML.

Source code in zenml/integrations/gcp/__init__.py
class GcpIntegration(Integration):
    """Definition of Google Cloud Platform integration for ZenML."""

    NAME = GCP
    REQUIREMENTS = [
        "kfp",
        "gcsfs",
        "google-cloud-secret-manager",
        "google-cloud-aiplatform>=1.11.0",
    ]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the GCP integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=GCP_ARTIFACT_STORE_FLAVOR,
                source="zenml.integrations.gcp.artifact_stores"
                ".GCPArtifactStore",
                type=StackComponentType.ARTIFACT_STORE,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=GCP_SECRETS_MANAGER_FLAVOR,
                source="zenml.integrations.gcp.secrets_manager."
                "GCPSecretsManager",
                type=StackComponentType.SECRETS_MANAGER,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=GCP_VERTEX_ORCHESTRATOR_FLAVOR,
                source="zenml.integrations.gcp.orchestrators"
                ".VertexOrchestrator",
                type=StackComponentType.ORCHESTRATOR,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=GCP_VERTEX_STEP_OPERATOR_FLAVOR,
                source="zenml.integrations.gcp.step_operators"
                ".VertexStepOperator",
                type=StackComponentType.STEP_OPERATOR,
                integration=cls.NAME,
            ),
        ]
flavors() classmethod

Declare the stack component flavors for the GCP integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/gcp/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the GCP integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=GCP_ARTIFACT_STORE_FLAVOR,
            source="zenml.integrations.gcp.artifact_stores"
            ".GCPArtifactStore",
            type=StackComponentType.ARTIFACT_STORE,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=GCP_SECRETS_MANAGER_FLAVOR,
            source="zenml.integrations.gcp.secrets_manager."
            "GCPSecretsManager",
            type=StackComponentType.SECRETS_MANAGER,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=GCP_VERTEX_ORCHESTRATOR_FLAVOR,
            source="zenml.integrations.gcp.orchestrators"
            ".VertexOrchestrator",
            type=StackComponentType.ORCHESTRATOR,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=GCP_VERTEX_STEP_OPERATOR_FLAVOR,
            source="zenml.integrations.gcp.step_operators"
            ".VertexStepOperator",
            type=StackComponentType.STEP_OPERATOR,
            integration=cls.NAME,
        ),
    ]

artifact_stores special

Initialization of the GCP Artifact Store.

gcp_artifact_store

Implementation of the GCP Artifact Store.

GCPArtifactStore (BaseArtifactStore, AuthenticationMixin) pydantic-model

Artifact Store for Google Cloud Storage based artifacts.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
class GCPArtifactStore(BaseArtifactStore, AuthenticationMixin):
    """Artifact Store for Google Cloud Storage based artifacts."""

    _filesystem: Optional[gcsfs.GCSFileSystem] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = GCP_ARTIFACT_STORE_FLAVOR
    SUPPORTED_SCHEMES: ClassVar[Set[str]] = {GCP_PATH_PREFIX}

    @property
    def filesystem(self) -> gcsfs.GCSFileSystem:
        """The gcsfs filesystem to access this artifact store.

        Returns:
            The gcsfs filesystem to access this artifact store.
        """
        if not self._filesystem:
            secret = self.get_authentication_secret(
                expected_schema_type=GCPSecretSchema
            )
            token = secret.get_credential_dict() if secret else None
            self._filesystem = gcsfs.GCSFileSystem(token=token)

        return self._filesystem

    def open(self, path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file. Currently, only
                'rb' and 'wb' to read and write binary files are supported.

        Returns:
            A file-like object that can be used to read or write to the file.
        """
        return self.filesystem.open(path=path, mode=mode)

    def copyfile(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Copy a file.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )
        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        self.filesystem.copy(path1=src, path2=dst)

    def exists(self, path: PathType) -> bool:
        """Check whether a path exists.

        Args:
            path: The path to check.

        Returns:
            True if the path exists, False otherwise.
        """
        return self.filesystem.exists(path=path)  # type: ignore[no-any-return]

    def glob(self, pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to match to search
          in subdirectories of any depth (e.g. '/some_dir/**/some_file')

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of paths that match the given glob pattern.
        """
        return [
            f"{GCP_PATH_PREFIX}{path}"
            for path in self.filesystem.glob(path=pattern)
        ]

    def isdir(self, path: PathType) -> bool:
        """Check whether a path is a directory.

        Args:
            path: The path to check.

        Returns:
            True if the path is a directory, False otherwise.
        """
        return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]

    def listdir(self, path: PathType) -> List[PathType]:
        """Return a list of files in a directory.

        Args:
            path: The path of the directory to list.

        Returns:
            A list of paths of files in the directory.
        """
        path_without_prefix = convert_to_str(path)
        if path_without_prefix.startswith(GCP_PATH_PREFIX):
            path_without_prefix = path_without_prefix[len(GCP_PATH_PREFIX) :]

        def _extract_basename(file_dict: Dict[str, Any]) -> str:
            """Extracts the basename from a file info dict returned by GCP.

            Args:
                file_dict: A file info dict returned by the GCP filesystem.

            Returns:
                The basename of the file.
            """
            file_path = cast(str, file_dict["name"])
            base_name = file_path[len(path_without_prefix) :]
            return base_name.lstrip("/")

        return [
            _extract_basename(dict_)
            for dict_ in self.filesystem.listdir(path=path)
        ]

    def makedirs(self, path: PathType) -> None:
        """Create a directory at the given path.

        If needed also create missing parent directories.

        Args:
            path: The path of the directory to create.
        """
        self.filesystem.makedirs(path=path, exist_ok=True)

    def mkdir(self, path: PathType) -> None:
        """Create a directory at the given path.

        Args:
            path: The path of the directory to create.
        """
        self.filesystem.makedir(path=path)

    def remove(self, path: PathType) -> None:
        """Remove the file at the given path.

        Args:
            path: The path of the file to remove.
        """
        self.filesystem.rm_file(path=path)

    def rename(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Rename source file to destination file.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        self.filesystem.rename(path1=src, path2=dst)

    def rmtree(self, path: PathType) -> None:
        """Remove the given directory.

        Args:
            path: The path of the directory to remove.
        """
        self.filesystem.delete(path=path, recursive=True)

    def stat(self, path: PathType) -> Dict[str, Any]:
        """Return stat info for the given path.

        Args:
            path: the path to get stat info for.

        Returns:
            A dictionary with the stat info.
        """
        return self.filesystem.stat(path=path)  # type: ignore[no-any-return]

    def walk(
        self,
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Return an iterator that walks the contents of the given directory.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Yields:
            An Iterable of Tuples, each of which contains the path of the
            current directory, a list of directories inside the current
            directory and a list of files inside the current directory.
        """
        # TODO [ENG-153]: Additional params
        for (
            directory,
            subdirectories,
            files,
        ) in self.filesystem.walk(path=top):
            yield f"{GCP_PATH_PREFIX}{directory}", subdirectories, files
filesystem: GCSFileSystem property readonly

The gcsfs filesystem to access this artifact store.

Returns:

Type Description
GCSFileSystem

The gcsfs filesystem to access this artifact store.
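
As a rough sketch of what this lazy property boils down to (the bucket path is an assumption; passing token=None falls back to the default Google credentials, and gcsfs also accepts a credentials dict or a path to a service-account key file):

import gcsfs

# token may be None (default credentials), a service-account info dict
# (as produced by the authentication secret above), or a key file path.
token = None

fs = gcsfs.GCSFileSystem(token=token)
print(fs.exists("gs://my-bucket/some/file.txt"))  # illustrative path
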

copyfile(self, src, dst, overwrite=False)

Copy a file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path to copy from.

required
dst Union[bytes, str]

The path to copy to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def copyfile(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Copy a file.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    if not overwrite and self.filesystem.exists(dst):
        raise FileExistsError(
            f"Unable to copy to destination '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to copy anyway."
        )
    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    self.filesystem.copy(path1=src, path2=dst)
exists(self, path)

Check whether a path exists.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path exists, False otherwise.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def exists(self, path: PathType) -> bool:
    """Check whether a path exists.

    Args:
        path: The path to check.

    Returns:
        True if the path exists, False otherwise.
    """
    return self.filesystem.exists(path=path)  # type: ignore[no-any-return]
glob(self, pattern)

Return all paths that match the given glob pattern.

The glob pattern may include:
- '*' to match any number of characters
- '?' to match a single character
- '[...]' to match one of the characters inside the brackets
- '**' as the full name of a path component to match to search in subdirectories of any depth (e.g. '/some_dir/**/some_file')

Parameters:

Name Type Description Default
pattern Union[bytes, str]

The glob pattern to match, see details above.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths that match the given glob pattern.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
    """Return all paths that match the given glob pattern.

    The glob pattern may include:
    - '*' to match any number of characters
    - '?' to match a single character
    - '[...]' to match one of the characters inside the brackets
    - '**' as the full name of a path component to match to search
      in subdirectories of any depth (e.g. '/some_dir/**/some_file')

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        A list of paths that match the given glob pattern.
    """
    return [
        f"{GCP_PATH_PREFIX}{path}"
        for path in self.filesystem.glob(path=pattern)
    ]
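
A hypothetical call against a configured GCPArtifactStore instance, here named store (the bucket and layout are assumptions):

# All CSV files at any depth below gs://my-bucket/data:
deep_matches = store.glob("gs://my-bucket/data/**/*.csv")

# CSV files directly inside gs://my-bucket/data only:
shallow_matches = store.glob("gs://my-bucket/data/*.csv")
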
isdir(self, path)

Check whether a path is a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path is a directory, False otherwise.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def isdir(self, path: PathType) -> bool:
    """Check whether a path is a directory.

    Args:
        path: The path to check.

    Returns:
        True if the path is a directory, False otherwise.
    """
    return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]
listdir(self, path)

Return a list of files in a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to list.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths of files in the directory.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
    """Return a list of files in a directory.

    Args:
        path: The path of the directory to list.

    Returns:
        A list of paths of files in the directory.
    """
    path_without_prefix = convert_to_str(path)
    if path_without_prefix.startswith(GCP_PATH_PREFIX):
        path_without_prefix = path_without_prefix[len(GCP_PATH_PREFIX) :]

    def _extract_basename(file_dict: Dict[str, Any]) -> str:
        """Extracts the basename from a file info dict returned by GCP.

        Args:
            file_dict: A file info dict returned by the GCP filesystem.

        Returns:
            The basename of the file.
        """
        file_path = cast(str, file_dict["name"])
        base_name = file_path[len(path_without_prefix) :]
        return base_name.lstrip("/")

    return [
        _extract_basename(dict_)
        for dict_ in self.filesystem.listdir(path=path)
    ]
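
The prefix stripping in _extract_basename can be traced with plain strings (assuming GCP_PATH_PREFIX is "gs://"; the paths are made up):

GCP_PATH_PREFIX = "gs://"  # assumed value of the constant

path_without_prefix = "gs://my-bucket/data"[len(GCP_PATH_PREFIX):]
# -> "my-bucket/data"

# gcsfs returns object names without the scheme:
file_dict = {"name": "my-bucket/data/train.csv"}
base_name = file_dict["name"][len(path_without_prefix):].lstrip("/")
# -> "train.csv"
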
makedirs(self, path)

Create a directory at the given path.

If needed also create missing parent directories.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to create.

required
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def makedirs(self, path: PathType) -> None:
    """Create a directory at the given path.

    If needed also create missing parent directories.

    Args:
        path: The path of the directory to create.
    """
    self.filesystem.makedirs(path=path, exist_ok=True)
mkdir(self, path)

Create a directory at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to create.

required
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def mkdir(self, path: PathType) -> None:
    """Create a directory at the given path.

    Args:
        path: The path of the directory to create.
    """
    self.filesystem.makedir(path=path)
open(self, path, mode='r')

Open a file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

Path of the file to open.

required
mode str

Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported.

'r'

Returns:

Type Description
Any

A file-like object that can be used to read or write to the file.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
    """Open a file at the given path.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file. Currently, only
            'rb' and 'wb' to read and write binary files are supported.

    Returns:
        A file-like object that can be used to read or write to the file.
    """
    return self.filesystem.open(path=path, mode=mode)
remove(self, path)

Remove the file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the file to remove.

required
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def remove(self, path: PathType) -> None:
    """Remove the file at the given path.

    Args:
        path: The path of the file to remove.
    """
    self.filesystem.rm_file(path=path)
rename(self, src, dst, overwrite=False)

Rename source file to destination file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path of the file to rename.

required
dst Union[bytes, str]

The path to rename the source file to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def rename(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Rename source file to destination file.

    Args:
        src: The path of the file to rename.
        dst: The path to rename the source file to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    if not overwrite and self.filesystem.exists(dst):
        raise FileExistsError(
            f"Unable to rename file to '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to rename anyway."
        )

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    self.filesystem.rename(path1=src, path2=dst)
rmtree(self, path)

Remove the given directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to remove.

required
Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def rmtree(self, path: PathType) -> None:
    """Remove the given directory.

    Args:
        path: The path of the directory to remove.
    """
    self.filesystem.delete(path=path, recursive=True)
stat(self, path)

Return stat info for the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

the path to get stat info for.

required

Returns:

Type Description
Dict[str, Any]

A dictionary with the stat info.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
    """Return stat info for the given path.

    Args:
        path: the path to get stat info for.

    Returns:
        A dictionary with the stat info.
    """
    return self.filesystem.stat(path=path)  # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)

Return an iterator that walks the contents of the given directory.

Parameters:

Name Type Description Default
top Union[bytes, str]

Path of directory to walk.

required
topdown bool

Unused argument to conform to interface.

True
onerror Optional[Callable[..., NoneType]]

Unused argument to conform to interface.

None

Yields:

Type Description
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]]

An Iterable of Tuples, each of which contains the path of the current directory, a list of directories inside the current directory and a list of files inside the current directory.

Source code in zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py
def walk(
    self,
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Return an iterator that walks the contents of the given directory.

    Args:
        top: Path of directory to walk.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Yields:
        An Iterable of Tuples, each of which contains the path of the
        current directory, a list of directories inside the current
        directory and a list of files inside the current directory.
    """
    # TODO [ENG-153]: Additional params
    for (
        directory,
        subdirectories,
        files,
    ) in self.filesystem.walk(path=top):
        yield f"{GCP_PATH_PREFIX}{directory}", subdirectories, files
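
Usage mirrors os.walk, except that the yielded directory paths carry the gs:// prefix (store and the bucket are assumptions):

# store is a configured GCPArtifactStore; the bucket is illustrative.
for directory, subdirectories, files in store.walk("gs://my-bucket/data"):
    print(directory, subdirectories, files)
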

constants

Constants for the VertexAI integration.

google_credentials_mixin

Implementation of the Google credentials mixin.

GoogleCredentialsMixin (BaseModel) pydantic-model

Mixin for Google Cloud Platform credentials.

Attributes:

Name Type Description
service_account_path Optional[str]

path to the service account credentials file to be used for authentication. If not provided, the default credentials will be used.

Source code in zenml/integrations/gcp/google_credentials_mixin.py
class GoogleCredentialsMixin(BaseModel):
    """Mixin for Google Cloud Platform credentials.

    Attributes:
        service_account_path: path to the service account credentials file to be
            used for authentication. If not provided, the default credentials
            will be used.
    """

    service_account_path: Optional[str] = None

    def _get_authentication(self) -> Tuple["Credentials", str]:
        """Get GCP credentials and the project ID associated with the credentials.

        If `service_account_path` is provided, then the credentials will be
        loaded from the file at that path. Otherwise, the default credentials
        will be used.

        Returns:
            A tuple containing the credentials and the project ID associated to
            the credentials.
        """
        if self.service_account_path:
            credentials, project_id = load_credentials_from_file(
                self.service_account_path
            )
        else:
            credentials, project_id = default()
        return credentials, project_id
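
A minimal sketch of the two authentication paths, using google.auth directly (the key file path is a placeholder):

from google.auth import default, load_credentials_from_file

service_account_path = None  # e.g. "/path/to/service-account.json"

if service_account_path:
    # Credentials from an explicit service-account key file.
    credentials, project_id = load_credentials_from_file(service_account_path)
else:
    # Application-default credentials (gcloud, workload identity, ...).
    credentials, project_id = default()
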

orchestrators special

Initialization for the VertexAI orchestrator.

vertex_entrypoint_configuration

Implementation of the VertexAI entrypoint configuration.

VertexEntrypointConfiguration (StepEntrypointConfiguration)

Entrypoint configuration for running steps on Vertex AI Pipelines.

Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
class VertexEntrypointConfiguration(StepEntrypointConfiguration):
    """Entrypoint configuration for running steps on Vertex AI Pipelines."""

    @classmethod
    def get_custom_entrypoint_options(cls) -> Set[str]:
        """Vertex AI Pipelines specific entrypoint options.

        The argument `VERTEX_JOB_ID_OPTION` makes it possible to specify the
        job id of the Vertex AI Pipeline and to retrieve it during step
        execution via the `get_run_name` method.

        Returns:
            The set of custom entrypoint options.
        """
        return {VERTEX_JOB_ID_OPTION}

    @classmethod
    def get_custom_entrypoint_arguments(
        cls, step: "BaseStep", *args: Any, **kwargs: Any
    ) -> List[str]:
        """Sets the value for the `VERTEX_JOB_ID_OPTION` argument.

        Args:
            step: The step to be executed.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            A list of arguments for the entrypoint.
        """
        return [f"--{VERTEX_JOB_ID_OPTION}", kwargs[VERTEX_JOB_ID_OPTION]]

    def get_run_name(self, pipeline_name: str) -> str:
        """Returns the Vertex AI Pipeline job id.

        Args:
            pipeline_name: The name of the pipeline.

        Returns:
            The Vertex AI Pipeline job id.
        """
        job_id: str = self.entrypoint_args[VERTEX_JOB_ID_OPTION]
        return job_id
get_custom_entrypoint_arguments(step, *args, **kwargs) classmethod

Sets the value for the VERTEX_JOB_ID_OPTION argument.

Parameters:

Name Type Description Default
step BaseStep

The step to be executed.

required
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
List[str]

A list of arguments for the entrypoint.

Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_arguments(
    cls, step: "BaseStep", *args: Any, **kwargs: Any
) -> List[str]:
    """Sets the value for the `VERTEX_JOB_ID_OPTION` argument.

    Args:
        step: The step to be executed.
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        A list of arguments for the entrypoint.
    """
    return [f"--{VERTEX_JOB_ID_OPTION}", kwargs[VERTEX_JOB_ID_OPTION]]
get_custom_entrypoint_options() classmethod

Vertex AI Pipelines specific entrypoint options.

The argument VERTEX_JOB_ID_OPTION makes it possible to specify the job id of the Vertex AI Pipeline and to retrieve it during step execution via the get_run_name method.

Returns:

Type Description
Set[str]

The set of custom entrypoint options.

Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_options(cls) -> Set[str]:
    """Vertex AI Pipelines specific entrypoint options.

    The argument `VERTEX_JOB_ID_OPTION` makes it possible to specify the
    job id of the Vertex AI Pipeline and to retrieve it during step
    execution via the `get_run_name` method.

    Returns:
        The set of custom entrypoint options.
    """
    return {VERTEX_JOB_ID_OPTION}
get_run_name(self, pipeline_name)

Returns the Vertex AI Pipeline job id.

Parameters:

Name Type Description Default
pipeline_name str

The name of the pipeline.

required

Returns:

Type Description
str

The Vertex AI Pipeline job id.

Source code in zenml/integrations/gcp/orchestrators/vertex_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> str:
    """Returns the Vertex AI Pipeline job id.

    Args:
        pipeline_name: The name of the pipeline.

    Returns:
        The Vertex AI Pipeline job id.
    """
    job_id: str = self.entrypoint_args[VERTEX_JOB_ID_OPTION]
    return job_id
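
To illustrate the round-trip (assuming VERTEX_JOB_ID_OPTION is the string "vertex_job_id"; the placeholder and job id values are made up):

VERTEX_JOB_ID_OPTION = "vertex_job_id"  # assumed constant value

# Compile time: the orchestrator appends the option with a Vertex
# placeholder (dslv2.PIPELINE_JOB_ID_PLACEHOLDER in the real code).
arguments = [f"--{VERTEX_JOB_ID_OPTION}", "<job-id-placeholder>"]

# Run time: the parsed entrypoint arguments expose the resolved job id,
# which get_run_name() returns as the run name.
entrypoint_args = {VERTEX_JOB_ID_OPTION: "my-pipeline-20220621-120000"}
job_id = entrypoint_args[VERTEX_JOB_ID_OPTION]
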
vertex_orchestrator

Implementation of the VertexAI orchestrator.

VertexOrchestrator (BaseOrchestrator, GoogleCredentialsMixin) pydantic-model

Orchestrator responsible for running pipelines on Vertex AI.

Attributes:

Name Type Description
custom_docker_base_image_name Optional[str]

Name of the Docker image that should be used as the base for the image that will be used to execute each of the steps. If no custom base image is given, a basic image of the active ZenML version will be used. Note: This image needs to have ZenML installed, otherwise the pipeline execution will fail. For that reason, you might want to extend the ZenML Docker images found here: https://hub.docker.com/r/zenmldocker/zenml/

project Optional[str]

GCP project name. If None, the project will be inferred from the environment.

location str

Name of the GCP region where the pipeline job will be executed. Vertex AI Pipelines is available in the following regions: https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability

pipeline_root Optional[str]

a Cloud Storage URI that will be used by the Vertex AI Pipelines. If not provided but the artifact store in the stack used to execute the pipeline is a zenml.integrations.gcp.artifact_stores.GCPArtifactStore, then a subdirectory of the artifact store will be used.

encryption_spec_key_name Optional[str]

The Cloud KMS resource identifier of the customer managed encryption key used to protect the job. Has the form projects/<PRJCT>/locations/<REGION>/keyRings/<KR>/cryptoKeys/<KEY>. The key needs to be in the same region as where the compute resource is created.

workload_service_account Optional[str]

the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. If not provided, the default service account will be used.

network Optional[str]

the full name of the Compute Engine Network to which the job should be peered. For example, projects/12345/global/networks/myVPC. If not provided, the job will not be peered with any network.

synchronous bool

If True, running a pipeline using this orchestrator will block until all steps have finished running on the Vertex AI Pipelines service.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
class VertexOrchestrator(BaseOrchestrator, GoogleCredentialsMixin):
    """Orchestrator responsible for running pipelines on Vertex AI.

    Attributes:
        custom_docker_base_image_name: Name of the Docker image that should be
            used as the base for the image that will be used to execute each of
            the steps. If no custom base image is given, a basic image of the
            active ZenML version will be used. **Note**: This image needs to
            have ZenML installed, otherwise the pipeline execution will fail.
            For that reason, you might want to extend the ZenML Docker images found
            here: https://hub.docker.com/r/zenmldocker/zenml/
        project: GCP project name. If `None`, the project will be inferred from
            the environment.
        location: Name of GCP region where the pipeline job will be executed.
            Vertex AI Pipelines is available in the following regions:
            https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability
        pipeline_root: a Cloud Storage URI that will be used by the Vertex AI
            Pipelines.
            If not provided but the artifact store in the stack used to execute
            the pipeline is a
            `zenml.integrations.gcp.artifact_stores.GCPArtifactStore`,
            then a subdirectory of the artifact store will be used.
        encryption_spec_key_name: The Cloud KMS resource identifier of the
            customer managed encryption key used to protect the job. Has the
            form:
            `projects/<PRJCT>/locations/<REGION>/keyRings/<KR>/cryptoKeys/<KEY>`.
            The key needs to be in the same region as where the compute
            resource is created.
        workload_service_account: the service account for workload run-as
            account. Users submitting jobs must have act-as permission on this
            run-as account.
            If not provided, the default service account will be used.
        network: the full name of the Compute Engine Network to which the job
            should be peered. For example,
            `projects/12345/global/networks/myVPC`. If not provided, the job
            will not be peered with any network.
        synchronous: If `True`, running a pipeline using this orchestrator will
            block until all steps have finished running on the Vertex AI
            Pipelines service.
    """

    custom_docker_base_image_name: Optional[str] = None
    project: Optional[str] = None
    location: str
    pipeline_root: Optional[str] = None
    labels: Dict[str, str] = {}
    encryption_spec_key_name: Optional[str] = None
    workload_service_account: Optional[str] = None
    network: Optional[str] = None
    synchronous: bool = False

    _pipeline_root: str

    FLAVOR: ClassVar[str] = GCP_VERTEX_ORCHESTRATOR_FLAVOR

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains a container registry.

        Also validates that the artifact store and metadata store used are not
        local.

        Returns:
            A StackValidator instance.
        """

        def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]:
            """Validates that all the stack components are not local.

            Args:
                stack: The stack to validate.

            Returns:
                A tuple of (is_valid, error_message).
            """
            # Validate that the container registry is not local.
            container_registry = stack.container_registry
            if container_registry and container_registry.is_local:
                return False, (
                    f"The Vertex orchestrator does not support local "
                    f"container registries. You should replace the component '"
                    f"{container_registry.name}' "
                    f"{container_registry.TYPE.value} to a remote one."
                )

            # Validate that the rest of the components are not local.
            for stack_comp in stack.components.values():
                local_path = stack_comp.local_path
                if not local_path:
                    continue
                return False, (
                    f"The '{stack_comp.name}' {stack_comp.TYPE.value} is a "
                    f"local stack component. The Vertex AI Pipelines "
                    f"orchestrator requires that all the components in the "
                    f"stack used to execute the pipeline have to be not local, "
                    f"because there is no way for Vertex to connect to your "
                    f"local machine. You should use a flavor of "
                    f"{stack_comp.TYPE.value} other than '"
                    f"{stack_comp.FLAVOR}'."
                )

            # If the `pipeline_root` has not been defined in the orchestrator
            # configuration, and the artifact store is not a GCP artifact store,
            # then raise an error.
            if (
                not self.pipeline_root
                and stack.artifact_store.FLAVOR != GCP_ARTIFACT_STORE_FLAVOR
            ):
                return False, (
                    f"The attribute `pipeline_root` has not been set and it "
                    f"cannot be generated using the path of the artifact store "
                    f"because it is not a "
                    f"`zenml.integrations.gcp.artifact_store.GCPArtifactStore`."
                    f" To solve this issue, set the `pipeline_root` attribute "
                    f"manually executing the following command: "
                    f"`zenml orchestrator update {stack.orchestrator.name} "
                    f'--pipeline_root="<Cloud Storage URI>"`.'
                )

            return True, ""

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_validate_stack_requirements,
        )

    def get_docker_image_name(self, pipeline_name: str) -> str:
        """Returns the full docker image name including registry and tag.

        Args:
            pipeline_name: The name of the pipeline.

        Returns:
            The full docker image name including registry and tag.
        """
        base_image_name = f"zenml-vertex:{pipeline_name}"
        container_registry = Repository().active_stack.container_registry

        if container_registry:
            registry_uri = container_registry.uri.rstrip("/")
            return f"{registry_uri}/{base_image_name}"

        return base_image_name

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory for files for this orchestrator.

        Returns:
            The path to the root directory for all files concerning this
            orchestrator.
        """
        return os.path.join(
            get_global_config_directory(), "vertex", str(self.uuid)
        )

    @property
    def pipeline_directory(self) -> str:
        """Returns path to directory where kubeflow pipelines files are stored.

        Returns:
            Path to the pipeline directory.
        """
        return os.path.join(self.root_directory, "pipelines")

    def prepare_pipeline_deployment(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Build a Docker image for the current environment.

        This uploads it to a container registry if configured.

        Args:
            pipeline: The pipeline to be deployed.
            stack: The stack that will be used to deploy the pipeline.
            runtime_configuration: The runtime configuration for the pipeline.

        Raises:
            RuntimeError: If the container registry is missing.
        """
        from zenml.utils import docker_utils

        repo = Repository()
        container_registry = repo.active_stack.container_registry

        if not container_registry:
            raise RuntimeError("Missing container registry")

        image_name = self.get_docker_image_name(pipeline.name)

        requirements = {*stack.requirements(), *pipeline.requirements}

        logger.debug(
            "Vertex AI Pipelines service docker container requirements %s",
            requirements,
        )

        docker_utils.build_docker_image(
            build_context_path=get_source_root_path(),
            image_name=image_name,
            dockerignore_path=pipeline.dockerignore_file,
            requirements=requirements,
            base_image=self.custom_docker_base_image_name,
        )
        container_registry.push_image(image_name)

    def prepare_or_run_pipeline(
        self,
        sorted_steps: List["BaseStep"],
        pipeline: "BasePipeline",
        pb2_pipeline: "Pb2Pipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Creates a KFP JSON pipeline.

        # noqa: DAR402

        This is an intermediary representation of the pipeline which is then
        deployed to Vertex AI Pipelines service.

        How it works:
        -------------
        Before this method is called the `prepare_pipeline_deployment()` method
        builds a Docker image that contains the code for the pipeline, all
        steps, and the context around these files.

        Based on this Docker image a callable is created which builds
        container_ops for each step (`_construct_kfp_pipeline`). The function
        `kfp.components.load_component_from_text` is used to create the
        `ContainerOp`, because using the `dsl.ContainerOp` class directly is
        deprecated when using the Kubeflow SDK v2. The step entrypoint command
        with the entrypoint arguments is the command that will be executed by
        the container created using the previously created Docker image.

        This callable is then compiled into a JSON file that is used as the
        intermediary representation of the Kubeflow pipeline.

        This file then is submitted to the Vertex AI Pipelines service for
        execution.

        Args:
            sorted_steps: List of sorted steps.
            pipeline: Zenml Pipeline instance.
            pb2_pipeline: Protobuf Pipeline instance.
            stack: The stack the pipeline was run on.
            runtime_configuration: The Runtime configuration of the current run.

        Raises:
            ValueError: If the attribute `pipeline_root` is not set and it
                cannot be generated using the path of the artifact store in the
                stack because it is not a
                `zenml.integrations.gcp.artifact_store.GCPArtifactStore`.
        """
        # If the `pipeline_root` has not been defined in the orchestrator
        # configuration, try to create it from the artifact store if it is a
        # `GCPArtifactStore`.
        if not self.pipeline_root:
            artifact_store = stack.artifact_store
            self._pipeline_root = f"{artifact_store.path.rstrip('/')}/vertex_pipeline_root/{pipeline.name}/{runtime_configuration.run_name}"
            logger.info(
                "The attribute `pipeline_root` has not been set in the "
                "orchestrator configuration. One has been generated "
                "automatically based on the path of the `GCPArtifactStore` "
                "artifact store in the stack used to execute the pipeline. "
                "The generated `pipeline_root` is `%s`.",
                self._pipeline_root,
            )
        else:
            self._pipeline_root = self.pipeline_root

        # Build the Docker image that will be used to run the steps of the
        # pipeline.
        image_name = self.get_docker_image_name(pipeline.name)
        image_name = get_image_digest(image_name) or image_name

        def _construct_kfp_pipeline() -> None:
            """Create a `ContainerOp` for each step.

            This should contain the name of the Docker image and configures the
            entrypoint of the Docker image to run the step.

            Additionally, this gives each `ContainerOp` information about its
            direct downstream steps.

            If this callable is passed to the `compile()` method of
            `KFPV2Compiler` all `dsl.ContainerOp` instances will be
            automatically added to a singular `dsl.Pipeline` instance.
            """
            step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}

            for step in sorted_steps:
                # The command will be needed to eventually call the python step
                # within the docker container
                command = VertexEntrypointConfiguration.get_entrypoint_command()

                # The arguments are passed to configure the entrypoint of the
                # docker container when the step is called.
                arguments = VertexEntrypointConfiguration.get_entrypoint_arguments(
                    step=step,
                    pb2_pipeline=pb2_pipeline,
                    **{VERTEX_JOB_ID_OPTION: dslv2.PIPELINE_JOB_ID_PLACEHOLDER},
                )

                # Create the `ContainerOp` for the step. Using the
                # `dsl.ContainerOp`
                # class directly is deprecated when using the Kubeflow SDK v2.
                container_op = kfp.components.load_component_from_text(
                    f"""
                    name: {step.name}
                    implementation:
                        container:
                            image: {image_name}
                            command: {command + arguments}"""
                )()

                # Set upstream tasks as a dependency of the current step
                upstream_step_names = self.get_upstream_step_names(
                    step=step, pb2_pipeline=pb2_pipeline
                )
                for upstream_step_name in upstream_step_names:
                    upstream_container_op = step_name_to_container_op[
                        upstream_step_name
                    ]
                    container_op.after(upstream_container_op)

                step_name_to_container_op[step.name] = container_op

        # Save the generated pipeline to a file.
        assert runtime_configuration.run_name
        fileio.makedirs(self.pipeline_directory)
        pipeline_file_path = os.path.join(
            self.pipeline_directory,
            f"{runtime_configuration.run_name}.json",
        )

        # Compile the pipeline using the Kubeflow SDK V2 compiler, which
        # generates a JSON representation of the pipeline that can later be
        # uploaded to the Vertex AI Pipelines service.
        logger.debug(
            "Compiling pipeline using Kubeflow SDK V2 compiler and saving it "
            "to `%s`",
            pipeline_file_path,
        )
        KFPV2Compiler().compile(
            pipeline_func=_construct_kfp_pipeline,
            package_path=pipeline_file_path,
            pipeline_name=_clean_pipeline_name(pipeline.name),
        )

        # Using the Google Cloud AIPlatform client, upload and execute the
        # pipeline on the Vertex AI Pipelines service.
        self._upload_and_run_pipeline(
            pipeline_name=pipeline.name,
            pipeline_file_path=pipeline_file_path,
            runtime_configuration=runtime_configuration,
            enable_cache=pipeline.enable_cache,
        )

    def _upload_and_run_pipeline(
        self,
        pipeline_name: str,
        pipeline_file_path: str,
        runtime_configuration: "RuntimeConfiguration",
        enable_cache: bool,
    ) -> None:
        """Uploads and run the pipeline on the Vertex AI Pipelines service.

        Args:
            pipeline_name: Name of the pipeline.
            pipeline_file_path: Path of the JSON file containing the compiled
                Kubeflow pipeline (compiled with Kubeflow SDK v2).
            runtime_configuration: Runtime configuration of the pipeline run.
            enable_cache: Whether caching is enabled for this pipeline run.
        """
        # We have to replace the hyphens in the pipeline name with underscores
        # and lower case the string, because the Vertex AI Pipelines service
        # requires this format.
        assert runtime_configuration.run_name
        job_id = _clean_pipeline_name(runtime_configuration.run_name)

        # Warn the user that scheduling is not available when using the
        # Vertex orchestrator.
        if runtime_configuration.schedule:
            logger.warning(
                "Pipeline scheduling configuration was provided, but Vertex "
                "AI Pipelines does not support scheduling yet."
            )

        # Get the credentials that would be used to create the Vertex AI
        # Pipelines job.
        credentials, project_id = self._get_authentication()
        if self.project and self.project != project_id:
            logger.warning(
                "Authenticated with project `%s`, but this orchestrator is "
                "configured to use the project `%s`.",
                project_id,
                self.project,
            )

        # If the project was set in the configuration, use it. Otherwise, use
        # the project that was used to authenticate.
        project_id = self.project if self.project else project_id

        # Instantiate the Vertex AI Pipelines job
        run = aiplatform.PipelineJob(
            display_name=pipeline_name,
            template_path=pipeline_file_path,
            job_id=job_id,
            pipeline_root=self._pipeline_root,
            parameter_values=None,
            enable_caching=enable_cache,
            encryption_spec_key_name=self.encryption_spec_key_name,
            labels=self.labels,
            credentials=credentials,
            project=project_id,
            location=self.location,
        )

        logger.info(
            "Submitting pipeline job with job_id `%s` to Vertex AI Pipelines "
            "service.",
            job_id,
        )

        # Submit the job to Vertex AI Pipelines service.
        try:
            if self.workload_service_account:
                logger.info(
                    "The Vertex AI Pipelines job workload will be executed "
                    "using `%s` "
                    "service account.",
                    self.workload_service_account,
                )

            if self.network:
                logger.info(
                    "The Vertex AI Pipelines job will be peered with `%s` "
                    "network.",
                    self.network,
                )

            run.submit(
                service_account=self.workload_service_account,
                network=self.network,
            )
            logger.info(
                "View the Vertex AI Pipelines job at %s", run._dashboard_uri()
            )

            if self.synchronous:
                logger.info(
                    "Waiting for the Vertex AI Pipelines job to finish..."
                )
                run.wait()

        except google_exceptions.ClientError as e:
            logger.warning(
                "Failed to create the Vertex AI Pipelines job: %s", e
            )

        except RuntimeError as e:
            logger.error(
                "The Vertex AI Pipelines job execution has failed: %s", e
            )
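
Condensed, the submission in _upload_and_run_pipeline amounts to the following aiplatform calls (project, region, bucket, and file paths are placeholders):

from google.cloud import aiplatform

job = aiplatform.PipelineJob(
    display_name="my-pipeline",
    template_path="/tmp/my-pipeline.json",  # compiled KFP v2 JSON
    job_id="my-pipeline-20220621-120000",
    pipeline_root="gs://my-bucket/vertex_pipeline_root",
    enable_caching=True,
    project="my-gcp-project",
    location="europe-west4",
)
job.submit(service_account=None, network=None)
job.wait()  # only when running synchronously
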
pipeline_directory: str property readonly

Returns path to the directory where Kubeflow Pipelines files are stored.

Returns:

Type Description
str

Path to the pipeline directory.

root_directory: str property readonly

Returns path to the root directory for files for this orchestrator.

Returns:

Type Description
str

The path to the root directory for all files concerning this orchestrator.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates that the stack contains a container registry.

Also validates that the artifact store and metadata store used are not local.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A StackValidator instance.

get_docker_image_name(self, pipeline_name)

Returns the full docker image name including registry and tag.

Parameters:

Name Type Description Default
pipeline_name str

The name of the pipeline.

required

Returns:

Type Description
str

The full docker image name including registry and tag.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Returns the full docker image name including registry and tag.

    Args:
        pipeline_name: The name of the pipeline.

    Returns:
        The full docker image name including registry and tag.
    """
    base_image_name = f"zenml-vertex:{pipeline_name}"
    container_registry = Repository().active_stack.container_registry

    if container_registry:
        registry_uri = container_registry.uri.rstrip("/")
        return f"{registry_uri}/{base_image_name}"

    return base_image_name
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)

Creates a KFP JSON pipeline.

This is an intermediary representation of the pipeline which is then deployed to Vertex AI Pipelines service.

How it works:

Before this method is called, the prepare_pipeline_deployment() method builds a Docker image that contains the code for the pipeline, all steps, and the context around these files.

Based on this Docker image a callable is created which builds container_ops for each step (_construct_kfp_pipeline). The function kfp.components.load_component_from_text is used to create the ContainerOp, because using the dsl.ContainerOp class directly is deprecated when using the Kubeflow SDK v2. The step entrypoint command with the entrypoint arguments is the command that will be executed by the container created using the previously created Docker image.

This callable is then compiled into a JSON file that is used as the intermediary representation of the Kubeflow pipeline.

This file then is submitted to the Vertex AI Pipelines service for execution.

Parameters:

Name Type Description Default
sorted_steps List[BaseStep]

List of sorted steps.

required
pipeline BasePipeline

Zenml Pipeline instance.

required
pb2_pipeline Pb2Pipeline

Protobuf Pipeline instance.

required
stack Stack

The stack the pipeline was run on.

required
runtime_configuration RuntimeConfiguration

The Runtime configuration of the current run.

required

Exceptions:

Type Description
ValueError

If the attribute pipeline_root is not set and it cannot be generated using the path of the artifact store in the stack because it is not a zenml.integrations.gcp.artifact_store.GCPArtifactStore.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List["BaseStep"],
    pipeline: "BasePipeline",
    pb2_pipeline: "Pb2Pipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Creates a KFP JSON pipeline.

    # noqa: DAR402

    This is an intermediary representation of the pipeline which is then
    deployed to Vertex AI Pipelines service.

    How it works:
    -------------
    Before this method is called the `prepare_pipeline_deployment()` method
    builds a Docker image that contains the code for the pipeline, all
    steps, and the context around these files.

    Based on this Docker image a callable is created which builds
    container_ops for each step (`_construct_kfp_pipeline`). The function
    `kfp.components.load_component_from_text` is used to create the
    `ContainerOp`, because using the `dsl.ContainerOp` class directly is
    deprecated when using the Kubeflow SDK v2. The step entrypoint command
    with the entrypoint arguments is the command that will be executed by
    the container created using the previously created Docker image.

    This callable is then compiled into a JSON file that is used as the
    intermediary representation of the Kubeflow pipeline.

    This file then is submitted to the Vertex AI Pipelines service for
    execution.

    Args:
        sorted_steps: List of sorted steps.
        pipeline: Zenml Pipeline instance.
        pb2_pipeline: Protobuf Pipeline instance.
        stack: The stack the pipeline was run on.
        runtime_configuration: The Runtime configuration of the current run.

    Raises:
        ValueError: If the attribute `pipeline_root` is not set and it
            cannot be generated using the path of the artifact store in the
            stack because it is not a
            `zenml.integrations.gcp.artifact_store.GCPArtifactStore`.
    """
    # If the `pipeline_root` has not been defined in the orchestrator
    # configuration, try to create it from the artifact store if it is a
    # `GCPArtifactStore`.
    if not self.pipeline_root:
        artifact_store = stack.artifact_store
        self._pipeline_root = f"{artifact_store.path.rstrip('/')}/vertex_pipeline_root/{pipeline.name}/{runtime_configuration.run_name}"
        logger.info(
            "The attribute `pipeline_root` has not been set in the "
            "orchestrator configuration. One has been generated "
            "automatically based on the path of the `GCPArtifactStore` "
            "artifact store in the stack used to execute the pipeline. "
            "The generated `pipeline_root` is `%s`.",
            self._pipeline_root,
        )
    else:
        self._pipeline_root = self.pipeline_root

    # Build the Docker image that will be used to run the steps of the
    # pipeline.
    image_name = self.get_docker_image_name(pipeline.name)
    image_name = get_image_digest(image_name) or image_name

    def _construct_kfp_pipeline() -> None:
        """Create a `ContainerOp` for each step.

        This should contain the name of the Docker image and configures the
        entrypoint of the Docker image to run the step.

        Additionally, this gives each `ContainerOp` information about its
        direct downstream steps.

        If this callable is passed to the `compile()` method of
        `KFPV2Compiler` all `dsl.ContainerOp` instances will be
        automatically added to a singular `dsl.Pipeline` instance.
        """
        step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}

        for step in sorted_steps:
            # The command will be needed to eventually call the python step
            # within the docker container
            command = VertexEntrypointConfiguration.get_entrypoint_command()

            # The arguments are passed to configure the entrypoint of the
            # docker container when the step is called.
            arguments = VertexEntrypointConfiguration.get_entrypoint_arguments(
                step=step,
                pb2_pipeline=pb2_pipeline,
                **{VERTEX_JOB_ID_OPTION: dslv2.PIPELINE_JOB_ID_PLACEHOLDER},
            )

            # Create the `ContainerOp` for the step. Using the
            # `dsl.ContainerOp` class directly is deprecated when using the
            # Kubeflow SDK v2.
            container_op = kfp.components.load_component_from_text(
                f"""
                name: {step.name}
                implementation:
                    container:
                        image: {image_name}
                        command: {command + arguments}"""
            )()

            # Set upstream tasks as a dependency of the current step
            upstream_step_names = self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            )
            for upstream_step_name in upstream_step_names:
                upstream_container_op = step_name_to_container_op[
                    upstream_step_name
                ]
                container_op.after(upstream_container_op)

            step_name_to_container_op[step.name] = container_op

    # Save the generated pipeline to a file.
    assert runtime_configuration.run_name
    fileio.makedirs(self.pipeline_directory)
    pipeline_file_path = os.path.join(
        self.pipeline_directory,
        f"{runtime_configuration.run_name}.json",
    )

    # Compile the pipeline using the Kubeflow SDK V2 compiler, which
    # generates a JSON representation of the pipeline that can later be
    # uploaded to the Vertex AI Pipelines service.
    logger.debug(
        "Compiling pipeline using Kubeflow SDK V2 compiler and saving it "
        "to `%s`",
        pipeline_file_path,
    )
    KFPV2Compiler().compile(
        pipeline_func=_construct_kfp_pipeline,
        package_path=pipeline_file_path,
        pipeline_name=_clean_pipeline_name(pipeline.name),
    )

    # Using the Google Cloud AIPlatform client, upload and execute the
    # pipeline on the Vertex AI Pipelines service.
    self._upload_and_run_pipeline(
        pipeline_name=pipeline.name,
        pipeline_file_path=pipeline_file_path,
        runtime_configuration=runtime_configuration,
        enable_cache=pipeline.enable_cache,
    )
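
For illustration, here is a minimal sketch of the `kfp.components.load_component_from_text` pattern used above, assuming the Kubeflow Pipelines SDK v1.x (`kfp`) is installed; the component name, image, and command are hypothetical placeholders, not ZenML's actual entrypoint:

import kfp

# Load a component factory from an inline YAML component spec.
component = kfp.components.load_component_from_text(
    """
    name: trainer
    implementation:
        container:
            image: gcr.io/my-project/zenml-vertex:my_pipeline
            command: [python, -m, my_entrypoint_module]
    """
)


def pipeline_func() -> None:
    # Calling the loaded component factory yields a ContainerOp-like task.
    first = component()
    second = component()
    # `.after()` encodes the dependency between steps, as done for the
    # upstream steps in the method above.
    second.after(first)

Compiling `pipeline_func` with the Kubeflow SDK V2 compiler, as the method above does, then serializes the assembled pipeline to a JSON file.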
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)

Build a Docker image for the current environment.

This uploads it to a container registry if configured.

Parameters:

Name Type Description Default
pipeline BasePipeline

The pipeline to be deployed.

required
stack Stack

The stack that will be used to deploy the pipeline.

required
runtime_configuration RuntimeConfiguration

The runtime configuration for the pipeline.

required

Exceptions:

Type Description
RuntimeError

If the container registry is missing.

Source code in zenml/integrations/gcp/orchestrators/vertex_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Build a Docker image for the current environment.

    This uploads it to a container registry if configured.

    Args:
        pipeline: The pipeline to be deployed.
        stack: The stack that will be used to deploy the pipeline.
        runtime_configuration: The runtime configuration for the pipeline.

    Raises:
        RuntimeError: If the container registry is missing.
    """
    from zenml.utils import docker_utils

    repo = Repository()
    container_registry = repo.active_stack.container_registry

    if not container_registry:
        raise RuntimeError("Missing container registry")

    image_name = self.get_docker_image_name(pipeline.name)

    requirements = {*stack.requirements(), *pipeline.requirements}

    logger.debug(
        "Vertex AI Pipelines service docker container requirements %s",
        requirements,
    )

    docker_utils.build_docker_image(
        build_context_path=get_source_root_path(),
        image_name=image_name,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=requirements,
        base_image=self.custom_docker_base_image_name,
    )
    container_registry.push_image(image_name)
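
A side note on the image name produced here: before submission, the orchestrator pins the image to its digest where possible (see `prepare_or_run_pipeline` above). A minimal sketch of that pattern, using ZenML's `docker_utils` and a hypothetical image name:

from zenml.utils.docker_utils import get_image_digest

image_name = "gcr.io/my-project/zenml-vertex:my_pipeline"  # hypothetical

# Prefer the immutable digest over the mutable tag, falling back to the
# tag if the image has not been pushed yet.
pinned_image = get_image_digest(image_name) or image_name

This ensures Vertex AI runs exactly the image that was built and pushed, even if the tag is later moved.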

secrets_manager special

ZenML integration for GCP Secrets Manager.

The GCP Secrets Manager integration allows your pipeline to directly access the GCP Secret Manager and use the secrets it contains at runtime.

gcp_secrets_manager

Implementation of the GCP Secrets Manager.

GCPSecretsManager (BaseSecretsManager) pydantic-model

Class to interact with the GCP secrets manager.

Attributes:

Name Type Description
project_id str

The project_id of the GCP project that contains the Secret Manager. This is necessary to access the correct GCP project.

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
class GCPSecretsManager(BaseSecretsManager):
    """Class to interact with the GCP secrets manager.

    Attributes:
        project_id: The project_id of the GCP project that contains the
                    Secret Manager. This is necessary to access the correct
                    GCP project.
    """

    project_id: str

    # Class configuration
    FLAVOR: ClassVar[str] = GCP_SECRETS_MANAGER_FLAVOR
    CLIENT: ClassVar[Any] = None

    @classmethod
    def _ensure_client_connected(cls) -> None:
        if cls.CLIENT is None:
            cls.CLIENT = secretmanager.SecretManagerServiceClient()

    @property
    def parent_name(self) -> str:
        """Construct the GCP parent path to the secret manager.

        Returns:
            The parent path to the secret manager
        """
        return f"projects/{self.project_id}"

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        self._ensure_client_connected()

        if secret.name in self.get_all_secret_keys():
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists."
            )

        adjusted_content = prepend_group_name_to_keys(secret)
        for k, v in adjusted_content.items():
            # Create the secret; this only creates an empty secret with the
            # supplied name.
            gcp_secret = self.CLIENT.create_secret(
                request={
                    "parent": self.parent_name,
                    "secret_id": k,
                    "secret": {
                        "replication": {"automatic": {}},
                        "labels": [
                            (ZENML_GROUP_KEY, secret.name),
                            (ZENML_SCHEMA_NAME, secret.TYPE),
                        ],
                    },
                }
            )

            logger.debug("Created empty secret: %s", gcp_secret.name)

            self.CLIENT.add_secret_version(
                request={
                    "parent": gcp_secret.name,
                    "payload": {"data": str(v).encode()},
                }
            )

            logger.debug("Added value to secret.")

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Get a secret by its name.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            RuntimeError: if the secret does not exist
        """
        self._ensure_client_connected()

        secret_contents = {}
        zenml_schema_name = ""

        # List all secrets.
        for secret in self.CLIENT.list_secrets(
            request={"parent": self.parent_name}
        ):
            if (
                ZENML_GROUP_KEY in secret.labels
                and secret_name == secret.labels[ZENML_GROUP_KEY]
            ):

                secret_version_name = secret.name + "/versions/latest"

                response = self.CLIENT.access_secret_version(
                    request={"name": secret_version_name}
                )

                secret_value = response.payload.data.decode("UTF-8")

                secret_key = remove_group_name_from_key(
                    secret.name.split("/")[-1], secret_name
                )

                secret_contents[secret_key] = secret_value

                zenml_schema_name = secret.labels[ZENML_SCHEMA_NAME]

        if not secret_contents:
            raise RuntimeError(f"No secrets found within the secret {secret_name}.")

        secret_contents["name"] = secret_name

        secret_schema = SecretSchemaClassRegistry.get_class(
            secret_schema=zenml_schema_name
        )
        return secret_schema(**secret_contents)

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        self._ensure_client_connected()

        set_of_secrets = set()

        # List all secrets.
        for secret in self.CLIENT.list_secrets(
            request={"parent": self.parent_name}
        ):
            if ZENML_GROUP_KEY in secret.labels:
                group_key = secret.labels[ZENML_GROUP_KEY]
                set_of_secrets.add(group_key)

        return list(set_of_secrets)

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret by creating new versions of the existing secrets.

        Args:
            secret: the secret to update
        """
        self._ensure_client_connected()

        adjusted_content = prepend_group_name_to_keys(secret)

        for k, v in adjusted_content.items():
            # Add a new version to the existing GCP secret for this key.
            version_parent = self.CLIENT.secret_path(self.project_id, k)
            payload = {"data": str(v).encode()}

            self.CLIENT.add_secret_version(
                request={"parent": version_parent, "payload": payload}
            )

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret by name.

        In GCP a secret is a single k-v
        pair. Within ZenML a secret is a collection of k-v pairs. As such,
        deleting a secret will iterate through all secrets and delete the ones
        with the secret_name as label.

        Args:
            secret_name: the name of the secret to delete
        """
        self._ensure_client_connected()

        # Go through all gcp secrets and delete the ones with the secret_name
        #  as label.
        for secret in self.CLIENT.list_secrets(
            request={"parent": self.parent_name}
        ):
            if (
                ZENML_GROUP_KEY in secret.labels
                and secret_name == secret.labels[ZENML_GROUP_KEY]
            ):
                self.CLIENT.delete_secret(request={"name": secret.name})

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets."""
        self._ensure_client_connected()

        # List all secrets.
        for secret in self.CLIENT.list_secrets(
            request={"parent": self.parent_name}
        ):
            if (
                ZENML_GROUP_KEY in secret.labels
                or ZENML_SCHEMA_NAME in secret.labels
            ):
                logger.info(
                    "Deleted key-value pair {`%s`, `***`} from secret " "`%s`",
                    secret.name.split("/")[-1],
                    secret.labels[ZENML_GROUP_KEY],
                )
                self.CLIENT.delete_secret(request={"name": secret.name})
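
To make the grouping concrete, here is a hedged sketch of how one ZenML secret fans out into multiple GCP secrets, assuming `google-cloud-secret-manager` is installed; the project, secret names, and label key are hypothetical placeholders (the label key stands in for `ZENML_GROUP_KEY`):

from google.cloud import secretmanager

client = secretmanager.SecretManagerServiceClient()
parent = "projects/my-gcp-project"  # hypothetical project

# A ZenML secret "db" with keys "user" and "password" becomes two GCP
# secrets ("db_user" and "db_password") that share a group label.
for key, value in {"db_user": "admin", "db_password": "hunter2"}.items():
    gcp_secret = client.create_secret(
        request={
            "parent": parent,
            "secret_id": key,
            "secret": {
                "replication": {"automatic": {}},
                "labels": {"zenml_group_key": "db"},  # stand-in label key
            },
        }
    )
    client.add_secret_version(
        request={
            "parent": gcp_secret.name,
            "payload": {"data": value.encode()},
        }
    )

Reading the secret back then means listing all GCP secrets with the matching group label and stripping the group prefix from each key, which is what `get_secret` above does.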
parent_name: str property readonly

Construct the GCP parent path to the secret manager.

Returns:

Type Description
str

The parent path to the secret manager

delete_all_secrets(self)

Delete all existing secrets.

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets."""
    self._ensure_client_connected()

    # List all secrets.
    for secret in self.CLIENT.list_secrets(
        request={"parent": self.parent_name}
    ):
        if (
            ZENML_GROUP_KEY in secret.labels
            or ZENML_SCHEMA_NAME in secret.labels
        ):
            logger.info(
                "Deleted key-value pair {`%s`, `***`} from secret " "`%s`",
                secret.name.split("/")[-1],
                secret.labels[ZENML_GROUP_KEY],
            )
            self.CLIENT.delete_secret(request={"name": secret.name})
delete_secret(self, secret_name)

Delete an existing secret by name.

In GCP a secret is a single k-v pair. Within ZenML a secret is a collection of k-v pairs. As such, deleting a secret will iterate through all secrets and delete the ones with the secret_name as label.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to delete

required
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete an existing secret by name.

    In GCP a secret is a single k-v
    pair. Within ZenML a secret is a collection of k-v pairs. As such,
    deleting a secret will iterate through all secrets and delete the ones
    with the secret_name as label.

    Args:
        secret_name: the name of the secret to delete
    """
    self._ensure_client_connected()

    # Go through all gcp secrets and delete the ones with the secret_name
    #  as label.
    for secret in self.CLIENT.list_secrets(
        request={"parent": self.parent_name}
    ):
        if (
            ZENML_GROUP_KEY in secret.labels
            and secret_name == secret.labels[ZENML_GROUP_KEY]
        ):
            self.CLIENT.delete_secret(request={"name": secret.name})
get_all_secret_keys(self)

Get all secret keys.

Returns:

Type Description
List[str]

A list of all secret keys

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Get all secret keys.

    Returns:
        A list of all secret keys
    """
    self._ensure_client_connected()

    set_of_secrets = set()

    # List all secrets.
    for secret in self.CLIENT.list_secrets(
        request={"parent": self.parent_name}
    ):
        if ZENML_GROUP_KEY in secret.labels:
            group_key = secret.labels[ZENML_GROUP_KEY]
            set_of_secrets.add(group_key)

    return list(set_of_secrets)
get_secret(self, secret_name)

Get a secret by its name.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to get

required

Returns:

Type Description
BaseSecretSchema

The secret.

Exceptions:

Type Description
RuntimeError

if the secret does not exist

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Get a secret by its name.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        RuntimeError: if the secret does not exist
    """
    self._ensure_client_connected()

    secret_contents = {}
    zenml_schema_name = ""

    # List all secrets.
    for secret in self.CLIENT.list_secrets(
        request={"parent": self.parent_name}
    ):
        if (
            ZENML_GROUP_KEY in secret.labels
            and secret_name == secret.labels[ZENML_GROUP_KEY]
        ):

            secret_version_name = secret.name + "/versions/latest"

            response = self.CLIENT.access_secret_version(
                request={"name": secret_version_name}
            )

            secret_value = response.payload.data.decode("UTF-8")

            secret_key = remove_group_name_from_key(
                secret.name.split("/")[-1], secret_name
            )

            secret_contents[secret_key] = secret_value

            zenml_schema_name = secret.labels[ZENML_SCHEMA_NAME]

    if not secret_contents:
        raise RuntimeError(f"No secrets found within the secret {secret_name}.")

    secret_contents["name"] = secret_name

    secret_schema = SecretSchemaClassRegistry.get_class(
        secret_schema=zenml_schema_name
    )
    return secret_schema(**secret_contents)
register_secret(self, secret)

Registers a new secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to register

required

Exceptions:

Type Description
SecretExistsError

if the secret already exists

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Registers a new secret.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    self._ensure_client_connected()

    if secret.name in self.get_all_secret_keys():
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists."
        )

    adjusted_content = prepend_group_name_to_keys(secret)
    for k, v in adjusted_content.items():
        # Create the secret; this only creates an empty secret with the
        # supplied name.
        gcp_secret = self.CLIENT.create_secret(
            request={
                "parent": self.parent_name,
                "secret_id": k,
                "secret": {
                    "replication": {"automatic": {}},
                    "labels": [
                        (ZENML_GROUP_KEY, secret.name),
                        (ZENML_SCHEMA_NAME, secret.TYPE),
                    ],
                },
            }
        )

        logger.debug("Created empty secret: %s", gcp_secret.name)

        self.CLIENT.add_secret_version(
            request={
                "parent": gcp_secret.name,
                "payload": {"data": str(v).encode()},
            }
        )

        logger.debug("Added value to secret.")
update_secret(self, secret)

Update an existing secret by creating new versions of the existing secrets.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to update

required
Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret by creating new versions of the existing secrets.

    Args:
        secret: the secret to update
    """
    self._ensure_client_connected()

    adjusted_content = prepend_group_name_to_keys(secret)

    for k, v in adjusted_content.items():
        # Add a new version to the existing GCP secret for this key.
        version_parent = self.CLIENT.secret_path(self.project_id, k)
        payload = {"data": str(v).encode()}

        self.CLIENT.add_secret_version(
            request={"parent": version_parent, "payload": payload}
        )
prepend_group_name_to_keys(secret)

Adds the secret group name to the keys of each secret key-value pair.

This allows using the same key across multiple secrets.

Parameters:

Name Type Description Default
secret BaseSecretSchema

The ZenML Secret schema

required

Returns:

Type Description
Dict[str, str]

A dictionary with the keys prepended with the group name

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def prepend_group_name_to_keys(secret: BaseSecretSchema) -> Dict[str, str]:
    """Adds the secret group name to the keys of each secret key-value pair.

    This allows using the same key across multiple secrets.

    Args:
        secret: The ZenML Secret schema

    Returns:
        A dictionary with the keys prepended with the group name
    """
    return {f"{secret.name}_{k}": v for k, v in secret.content.items()}
remove_group_name_from_key(combined_key_name, group_name)

Removes the secret group name from the secret key.

Parameters:

Name Type Description Default
combined_key_name str

Full name as it is within the gcp secrets manager

required
group_name str

Group name (the ZenML Secret name)

required

Returns:

Type Description
str

The cleaned key

Exceptions:

Type Description
RuntimeError

If the group name is not found in the key

Source code in zenml/integrations/gcp/secrets_manager/gcp_secrets_manager.py
def remove_group_name_from_key(combined_key_name: str, group_name: str) -> str:
    """Removes the secret group name from the secret key.

    Args:
        combined_key_name: Full name as it is within the gcp secrets manager
        group_name: Group name (the ZenML Secret name)

    Returns:
        The cleaned key

    Raises:
        RuntimeError: If the group name is not found in the key
    """
    if combined_key_name.startswith(group_name + "_"):
        return combined_key_name[len(group_name + "_") :]
    else:
        raise RuntimeError(
            f"Key-name `{combined_key_name}` does not have the "
            f"prefix `{group_name}`. Key could not be "
            f"extracted."
        )
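
A quick worked example of the two helpers above; since `prepend_group_name_to_keys` only reads the `name` and `content` attributes, a small stand-in object is used here instead of a full `BaseSecretSchema`:

from types import SimpleNamespace

from zenml.integrations.gcp.secrets_manager.gcp_secrets_manager import (
    prepend_group_name_to_keys,
    remove_group_name_from_key,
)

# Stand-in for a ZenML secret named "db" with a single key-value pair.
secret = SimpleNamespace(name="db", content={"password": "hunter2"})

combined = prepend_group_name_to_keys(secret)  # type: ignore[arg-type]
assert combined == {"db_password": "hunter2"}

# Stripping the group name recovers the original key.
assert remove_group_name_from_key("db_password", "db") == "password"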

step_operators special

Initialization for the VertexAI Step Operator.

vertex_step_operator

Implementation of a VertexAI step operator.

Code heavily inspired by the TFX implementation: https://github.com/tensorflow/tfx/blob/master/tfx/extensions/google_cloud_ai_platform/training_clients.py

VertexStepOperator (BaseStepOperator, GoogleCredentialsMixin) pydantic-model

Step operator to run a step on Vertex AI.

This class defines code that can set up a Vertex AI environment and run the ZenML entrypoint command in it.

Attributes:

Name Type Description
region str

Region name, e.g., europe-west1.

project Optional[str]

[Optional] GCP project name. If left None, inferred from the environment.

accelerator_type Optional[str]

[Optional] Accelerator type from list: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#AcceleratorType

accelerator_count int

[Optional] Defines number of accelerators to be used for the job.

machine_type str

[Optional] Machine type specified here: https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types

base_image Optional[str]

[Optional] Base image for building the custom job container.

encryption_spec_key_name Optional[str]

[Optional] Encryption spec key name.

Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
class VertexStepOperator(BaseStepOperator, GoogleCredentialsMixin):
    """Step operator to run a step on Vertex AI.

    This class defines code that can set up a Vertex AI environment and run the
    ZenML entrypoint command in it.

    Attributes:
        region: Region name, e.g., `europe-west1`.
        project: [Optional] GCP project name. If left None, inferred from the
            environment.
        accelerator_type: [Optional] Accelerator type from list: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#AcceleratorType
        accelerator_count: [Optional] Defines number of accelerators to be
            used for the job.
        machine_type: [Optional] Machine type specified here: https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types
        base_image: [Optional] Base image for building the custom job container.
        encryption_spec_key_name: [Optional] Encryption spec key name.
    """

    region: str
    project: Optional[str] = None
    accelerator_type: Optional[str] = None
    accelerator_count: int = 0
    machine_type: str = "n1-standard-4"
    base_image: Optional[str] = None

    # customer managed encryption key resource name
    # will be applied to all Vertex AI resources if set
    encryption_spec_key_name: Optional[str] = None

    # Class configuration
    FLAVOR: ClassVar[str] = GCP_VERTEX_STEP_OPERATOR_FLAVOR

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains a container registry.

        Returns:
            StackValidator: Validator for the stack.
        """

        def _ensure_local_orchestrator(stack: Stack) -> Tuple[bool, str]:
            # For now this only works on local orchestrator and GCP artifact
            #  store
            return (
                (
                    stack.orchestrator.FLAVOR == "local"
                    and stack.artifact_store.FLAVOR == "gcp"
                ),
                "Only local orchestrator and GCP artifact store are currently "
                "supported",
            )

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_ensure_local_orchestrator,
        )

    @validator("accelerator_type")
    def validate_accelerator_enum(cls, accelerator_type: Optional[str]) -> None:
        """Validates that the accelerator type is valid.

        Args:
            accelerator_type: Accelerator type

        Raises:
            ValueError: If the accelerator type is not valid.
        """
        accepted_vals = list(
            aiplatform.gapic.AcceleratorType.__members__.keys()
        )
        if accelerator_type and accelerator_type.upper() not in accepted_vals:
            raise ValueError(
                f"Accelerator must be one of the following: {accepted_vals}"
            )

    def _build_and_push_docker_image(
        self,
        pipeline_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> str:
        """Builds and pushes a docker image.

        Args:
            pipeline_name: Pipeline name
            requirements: Requirements
            entrypoint_command: Entrypoint command

        Returns:
            Docker image name

        Raises:
            RuntimeError: If no container registry is found in the stack.
        """
        repo = Repository()
        container_registry = repo.active_stack.container_registry

        if not container_registry:
            raise RuntimeError("Missing container registry")

        registry_uri = container_registry.uri.rstrip("/")
        image_name = f"{registry_uri}/zenml-vertex:{pipeline_name}"

        docker_utils.build_docker_image(
            build_context_path=get_source_root_path(),
            image_name=image_name,
            entrypoint=" ".join(entrypoint_command),
            requirements=set(requirements),
            base_image=self.base_image,
        )
        container_registry.push_image(image_name)
        return docker_utils.get_image_digest(image_name) or image_name

    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> None:
        """Launches a step on Vertex AI.

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            entrypoint_command: Command that executes the step.
            requirements: List of pip requirements that must be installed
                inside the step operator environment.

        Raises:
            RuntimeError: If the run fails.
            ConnectionError: If the run fails due to a connection error.
        """
        job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}

        # Step 1: Authenticate with Google
        credentials, project_id = self._get_authentication()
        if self.project:
            if self.project != project_id:
                logger.warning(
                    "Authenticated with project `%s`, but this orchestrator is "
                    "configured to use the project `%s`.",
                    project_id,
                    self.project,
                )
        else:
            self.project = project_id

        # Step 2: Build and push image
        image_name = self._build_and_push_docker_image(
            pipeline_name=pipeline_name,
            requirements=requirements,
            entrypoint_command=entrypoint_command,
        )

        # Step 3: Launch the job
        # The AI Platform services require regional API endpoints.
        client_options = {"api_endpoint": self.region + VERTEX_ENDPOINT_SUFFIX}
        # Initialize client that will be used to create and send requests.
        # This client only needs to be created once, and can be reused for multiple requests.
        client = aiplatform.gapic.JobServiceClient(
            credentials=credentials, client_options=client_options
        )
        custom_job = {
            "display_name": run_name,
            "job_spec": {
                "worker_pool_specs": [
                    {
                        "machine_spec": {
                            "machine_type": self.machine_type,
                            "accelerator_type": self.accelerator_type,
                            "accelerator_count": self.accelerator_count
                            if self.accelerator_type
                            else 0,
                        },
                        "replica_count": 1,
                        "container_spec": {
                            "image_uri": image_name,
                            "command": [],
                            "args": [],
                        },
                    }
                ]
            },
            "labels": job_labels,
            "encryption_spec": {"kmsKeyName": self.encryption_spec_key_name}
            if self.encryption_spec_key_name
            else {},
        }
        logger.debug("Vertex AI Job=%s", custom_job)

        parent = f"projects/{self.project}/locations/{self.region}"
        logger.info(
            "Submitting custom job='%s', path='%s' to Vertex AI Training.",
            custom_job["display_name"],
            parent,
        )
        response = client.create_custom_job(
            parent=parent, custom_job=custom_job
        )
        logger.debug("Vertex AI response:", response)

        # Step 4: Monitor the job

        # Monitors the long-running operation by polling the job state
        # periodically, and retries the polling when a transient connectivity
        # issue is encountered.
        #
        # Long-running operation monitoring:
        #   The possible states of "get job" response can be found at
        #   https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
        #   where SUCCEEDED/FAILED/CANCELED are considered to be final states.
        #   The following logic will keep polling the state of the job until
        #   the job enters a final state.
        #
        # During the polling, if a connection error was encountered, the GET
        # request will be retried by recreating the Python API client to
        # refresh the lifecycle of the connection being used. See
        # https://github.com/googleapis/google-api-python-client/issues/218
        # for a detailed description of the problem. If the error persists for
        # CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function
        # will raise ConnectionError.
        retry_count = 0
        job_id = response.name

        while response.state not in VERTEX_JOB_STATES_COMPLETED:
            time.sleep(POLLING_INTERVAL_IN_SECONDS)
            try:
                response = client.get_custom_job(name=job_id)
                retry_count = 0
            # Handle transient connection error.
            except ConnectionError as err:
                if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
                    retry_count += 1
                    logger.warning(
                        "ConnectionError (%s) encountered when polling job: "
                        "%s. Trying to recreate the API client.",
                        err,
                        job_id,
                    )
                    # Recreate the Python API client.
                    client = aiplatform.gapic.JobServiceClient(
                        client_options=client_options
                    )
                else:
                    logger.error(
                        "Request failed after %s retries.",
                        CONNECTION_ERROR_RETRY_LIMIT,
                    )
                    raise

            if response.state in VERTEX_JOB_STATES_FAILED:
                err_msg = (
                    "Job '{}' did not succeed.  Detailed response {}.".format(
                        job_id, response
                    )
                )
                logger.error(err_msg)
                raise RuntimeError(err_msg)

        # Cloud training complete
        logger.info("Job '%s' successful.", job_id)
validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates that the stack contains a container registry.

Returns:

Type Description
StackValidator

Validator for the stack.

launch(self, pipeline_name, run_name, requirements, entrypoint_command)

Launches a step on Vertex AI.

Parameters:

Name Type Description Default
pipeline_name str

Name of the pipeline which the step to be executed is part of.

required
run_name str

Name of the pipeline run which the step to be executed is part of.

required
entrypoint_command List[str]

Command that executes the step.

required
requirements List[str]

List of pip requirements that must be installed inside the step operator environment.

required

Exceptions:

Type Description
RuntimeError

If the run fails.

ConnectionError

If the run fails due to a connection error.

Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
) -> None:
    """Launches a step on Vertex AI.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        entrypoint_command: Command that executes the step.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.

    Raises:
        RuntimeError: If the run fails.
        ConnectionError: If the run fails due to a connection error.
    """
    job_labels = {"source": f"zenml-{__version__.replace('.', '_')}"}

    # Step 1: Authenticate with Google
    credentials, project_id = self._get_authentication()
    if self.project:
        if self.project != project_id:
            logger.warning(
                "Authenticated with project `%s`, but this orchestrator is "
                "configured to use the project `%s`.",
                project_id,
                self.project,
            )
    else:
        self.project = project_id

    # Step 2: Build and push image
    image_name = self._build_and_push_docker_image(
        pipeline_name=pipeline_name,
        requirements=requirements,
        entrypoint_command=entrypoint_command,
    )

    # Step 3: Launch the job
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": self.region + VERTEX_ENDPOINT_SUFFIX}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    client = aiplatform.gapic.JobServiceClient(
        credentials=credentials, client_options=client_options
    )
    custom_job = {
        "display_name": run_name,
        "job_spec": {
            "worker_pool_specs": [
                {
                    "machine_spec": {
                        "machine_type": self.machine_type,
                        "accelerator_type": self.accelerator_type,
                        "accelerator_count": self.accelerator_count
                        if self.accelerator_type
                        else 0,
                    },
                    "replica_count": 1,
                    "container_spec": {
                        "image_uri": image_name,
                        "command": [],
                        "args": [],
                    },
                }
            ]
        },
        "labels": job_labels,
        "encryption_spec": {"kmsKeyName": self.encryption_spec_key_name}
        if self.encryption_spec_key_name
        else {},
    }
    logger.debug("Vertex AI Job=%s", custom_job)

    parent = f"projects/{self.project}/locations/{self.region}"
    logger.info(
        "Submitting custom job='%s', path='%s' to Vertex AI Training.",
        custom_job["display_name"],
        parent,
    )
    response = client.create_custom_job(
        parent=parent, custom_job=custom_job
    )
    logger.debug("Vertex AI response:", response)

    # Step 4: Monitor the job

    # Monitors the long-running operation by polling the job state
    # periodically, and retries the polling when a transient connectivity
    # issue is encountered.
    #
    # Long-running operation monitoring:
    #   The possible states of "get job" response can be found at
    #   https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
    #   where SUCCEEDED/FAILED/CANCELED are considered to be final states.
    #   The following logic will keep polling the state of the job until
    #   the job enters a final state.
    #
    # During the polling, if a connection error was encountered, the GET
    # request will be retried by recreating the Python API client to
    # refresh the lifecycle of the connection being used. See
    # https://github.com/googleapis/google-api-python-client/issues/218
    # for a detailed description of the problem. If the error persists for
    # CONNECTION_ERROR_RETRY_LIMIT consecutive attempts, the function
    # will raise ConnectionError.
    retry_count = 0
    job_id = response.name

    while response.state not in VERTEX_JOB_STATES_COMPLETED:
        time.sleep(POLLING_INTERVAL_IN_SECONDS)
        try:
            response = client.get_custom_job(name=job_id)
            retry_count = 0
        # Handle transient connection error.
        except ConnectionError as err:
            if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
                retry_count += 1
                logger.warning(
                    "ConnectionError (%s) encountered when polling job: "
                    "%s. Trying to recreate the API client.",
                    err,
                    job_id,
                )
                # Recreate the Python API client.
                client = aiplatform.gapic.JobServiceClient(
                    client_options=client_options
                )
            else:
                logger.error(
                    "Request failed after %s retries.",
                    CONNECTION_ERROR_RETRY_LIMIT,
                )
                raise

        if response.state in VERTEX_JOB_STATES_FAILED:
            err_msg = (
                "Job '{}' did not succeed.  Detailed response {}.".format(
                    job_id, response
                )
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg)

    # Cloud training complete
    logger.info("Job '%s' successful.", job_id)
validate_accelerator_enum(accelerator_type) classmethod

Validates that the accelerator type is valid.

Parameters:

Name Type Description Default
accelerator_type Optional[str]

Accelerator type

required

Exceptions:

Type Description
ValueError

If the accelerator type is not valid.

Source code in zenml/integrations/gcp/step_operators/vertex_step_operator.py
@validator("accelerator_type")
def validate_accelerator_enum(cls, accelerator_type: Optional[str]) -> None:
    """Validates that the accelerator type is valid.

    Args:
        accelerator_type: Accelerator type

    Raises:
        ValueError: If the accelerator type is not valid.
    """
    accepted_vals = list(
        aiplatform.gapic.AcceleratorType.__members__.keys()
    )
    if accelerator_type and accelerator_type.upper() not in accepted_vals:
        raise ValueError(
            f"Accelerator must be one of the following: {accepted_vals}"
        )
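
To see which values pass this check, one can inspect the enum directly; a small sketch, assuming `google-cloud-aiplatform` is installed:

from google.cloud import aiplatform

accepted_vals = list(aiplatform.gapic.AcceleratorType.__members__.keys())
print(accepted_vals)  # includes, e.g., "NVIDIA_TESLA_T4"

# The validator upper-cases the input, so lowercase spellings also pass:
assert "nvidia_tesla_t4".upper() in accepted_vals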

github special

Initialization of the GitHub ZenML integration.

The GitHub integration provides a way to orchestrate pipelines using GitHub Actions.

GitHubIntegration (Integration)

Definition of GitHub integration for ZenML.

Source code in zenml/integrations/github/__init__.py
class GitHubIntegration(Integration):
    """Definition of GitHub integration for ZenML."""

    NAME = GITHUB
    REQUIREMENTS: List[str] = ["PyNaCl~=1.5.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the GitHub integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=GITHUB_ORCHESTRATOR_FLAVOR,
                source="zenml.integrations.github.orchestrators.GitHubActionsOrchestrator",
                type=StackComponentType.ORCHESTRATOR,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=GITHUB_SECRET_MANAGER_FLAVOR,
                source="zenml.integrations.github.secrets_managers.GitHubSecretsManager",
                type=StackComponentType.SECRETS_MANAGER,
                integration=cls.NAME,
            ),
        ]
flavors() classmethod

Declare the stack component flavors for the GitHub integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/github/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the GitHub integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=GITHUB_ORCHESTRATOR_FLAVOR,
            source="zenml.integrations.github.orchestrators.GitHubActionsOrchestrator",
            type=StackComponentType.ORCHESTRATOR,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=GITHUB_SECRET_MANAGER_FLAVOR,
            source="zenml.integrations.github.secrets_managers.GitHubSecretsManager",
            type=StackComponentType.SECRETS_MANAGER,
            integration=cls.NAME,
        ),
    ]

orchestrators special

Initialization of the GitHub Actions Orchestrator.

github_actions_entrypoint_configuration

Implementation of the GitHub Actions Orchestrator entrypoint.

GitHubActionsEntrypointConfiguration (StepEntrypointConfiguration)

Entrypoint configuration for running steps on GitHub Actions runners.

Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
class GitHubActionsEntrypointConfiguration(StepEntrypointConfiguration):
    """Entrypoint configuration for running steps on GitHub Actions runners."""

    @classmethod
    def get_custom_entrypoint_options(cls) -> Set[str]:
        """GitHub Actions specific entrypoint options.

        Returns:
            Set with the custom run id option.
        """
        return {RUN_ID_OPTION}

    @classmethod
    def get_custom_entrypoint_arguments(
        cls, step: BaseStep, **kwargs: Any
    ) -> List[str]:
        """Adds a run id argument for the entrypoint.

        Args:
            step: Step for which the arguments are passed.
            **kwargs: Additional keyword arguments.

        Returns:
            GitHub Actions placeholder for the run id option.
        """
        # These placeholders in the workflow file will be replaced with
        # concrete values by the GitHub Actions runner
        run_id = (
            "${{ github.run_id }}_${{ github.run_number }}_"
            "${{ github.run_attempt }}"
        )
        return [f"--{RUN_ID_OPTION}", run_id]

    def get_run_name(self, pipeline_name: str) -> str:
        """Returns the pipeline run name.

        Args:
            pipeline_name: Name of the pipeline which will run.

        Returns:
            The run name.
        """
        run_id = cast(str, self.entrypoint_args[RUN_ID_OPTION])
        return f"{pipeline_name}-{run_id}"
get_custom_entrypoint_arguments(step, **kwargs) classmethod

Adds a run id argument for the entrypoint.

Parameters:

Name Type Description Default
step BaseStep

Step for which the arguments are passed.

required
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
List[str]

GitHub Actions placeholder for the run id option.

Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_arguments(
    cls, step: BaseStep, **kwargs: Any
) -> List[str]:
    """Adds a run id argument for the entrypoint.

    Args:
        step: Step for which the arguments are passed.
        **kwargs: Additional keyword arguments.

    Returns:
        GitHub Actions placeholder for the run id option.
    """
    # These placeholders in the workflow file will be replaced with
    # concrete values by the GitHub Actions runner
    run_id = (
        "${{ github.run_id }}_${{ github.run_number }}_"
        "${{ github.run_attempt }}"
    )
    return [f"--{RUN_ID_OPTION}", run_id]
get_custom_entrypoint_options() classmethod

GitHub Actions specific entrypoint options.

Returns:

Type Description
Set[str]

Set with the custom run id option.

Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_options(cls) -> Set[str]:
    """GitHub Actions specific entrypoint options.

    Returns:
        Set with the custom run id option.
    """
    return {RUN_ID_OPTION}
get_run_name(self, pipeline_name)

Returns the pipeline run name.

Parameters:

Name Type Description Default
pipeline_name str

Name of the pipeline which will run.

required

Returns:

Type Description
str

The run name.

Source code in zenml/integrations/github/orchestrators/github_actions_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> str:
    """Returns the pipeline run name.

    Args:
        pipeline_name: Name of the pipeline which will run.

    Returns:
        The run name.
    """
    run_id = cast(str, self.entrypoint_args[RUN_ID_OPTION])
    return f"{pipeline_name}-{run_id}"
github_actions_orchestrator

Implementation of the GitHub Actions Orchestrator.

GitHubActionsOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator responsible for running pipelines using GitHub Actions.

Attributes:

Name Type Description
custom_docker_base_image_name Optional[str]

Name of a docker image that should be used as the base for the image that will be run on GitHub Action runners. If no custom image is given, a basic image of the active ZenML version will be used. Note: This image needs to have ZenML installed, otherwise the pipeline execution will fail. For that reason, you might want to extend the ZenML docker images found here: https://hub.docker.com/r/zenmldocker/zenml/

skip_dirty_repository_check bool

If True, this orchestrator will not raise an exception when trying to run a pipeline while there are still untracked/uncommitted files in the git repository.

skip_github_repository_check bool

If True, the orchestrator will not check if your git repository is pointing to a GitHub remote.

push bool

If True, this orchestrator will automatically commit and push the GitHub workflow file when running a pipeline. If False, the workflow file will be written to the correct location but needs to be committed and pushed manually.

Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
class GitHubActionsOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines using GitHub Actions.

    Attributes:
        custom_docker_base_image_name: Name of a docker image that should be
            used as the base for the image that will be run on GitHub Action
            runners. If no custom image is given, a basic image of the active
            ZenML version will be used. **Note**: This image needs to have
            ZenML installed, otherwise the pipeline execution will fail. For
            that reason, you might want to extend the ZenML docker images
            found here: https://hub.docker.com/r/zenmldocker/zenml/
        skip_dirty_repository_check: If `True`, this orchestrator will not
            raise an exception when trying to run a pipeline while there are
            still untracked/uncommitted files in the git repository.
        skip_github_repository_check: If `True`, the orchestrator will not check
            if your git repository is pointing to a GitHub remote.
        push: If `True`, this orchestrator will automatically commit and push
            the GitHub workflow file when running a pipeline. If `False`, the
            workflow file will be written to the correct location but needs to
            be committed and pushed manually.
    """

    custom_docker_base_image_name: Optional[str] = None
    skip_dirty_repository_check: bool = False
    skip_github_repository_check: bool = False
    push: bool = False

    _git_repo: Optional[Repo] = None

    # Class configuration
    FLAVOR: ClassVar[str] = GITHUB_ORCHESTRATOR_FLAVOR

    @property
    def git_repo(self) -> Repo:
        """Returns the git repository for the current working directory.

        Returns:
            Git repository for the current working directory.

        Raises:
            RuntimeError: If there is no git repository for the current working
                directory or the repository remote is not pointing to GitHub.
        """
        if not self._git_repo:
            try:
                self._git_repo = Repo(search_parent_directories=True)
            except InvalidGitRepositoryError:
                raise RuntimeError(
                    "Unable to find git repository in current working "
                    f"directory {os.getcwd()} or its parent directories."
                )

            remote_url = self.git_repo.remote().url
            is_github_repo = any(
                remote_url.startswith(prefix)
                for prefix in GITHUB_REMOTE_URL_PREFIXES
            )
            if not (is_github_repo or self.skip_github_repository_check):
                raise RuntimeError(
                    f"The remote URL '{remote_url}' of your git repo "
                    f"({self._git_repo.git_dir}) is not pointing to a GitHub "
                    "repository. The GitHub Actions orchestrator runs "
                    "pipelines using GitHub Actions and therefore only works "
                    "with GitHub repositories. If you want to skip this check "
                    "and run this orchestrator anyway, run: \n"
                    f"`zenml orchestrator update {self.name} "
                    "--skip_github_repository_check=true`"
                )

        return self._git_repo

    @property
    def workflow_directory(self) -> str:
        """Returns path to the GitHub workflows directory.

        Returns:
            The GitHub workflows directory.
        """
        assert self.git_repo.working_dir
        return os.path.join(self.git_repo.working_dir, ".github", "workflows")

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validator that ensures that the stack is compatible.

        Makes sure that the stack contains a container registry and only
        remote components.

        Returns:
            The stack validator.
        """

        def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]:
            container_registry = stack.container_registry
            assert container_registry is not None

            if container_registry.is_local:
                return False, (
                    "The GitHub Actions orchestrator requires a remote "
                    f"container registry, but the '{container_registry.name}' "
                    "container registry of your active stack points to a local "
                    f"URI '{container_registry.uri}'. Please make sure stacks "
                    "with a GitHub Actions orchestrator always contain remote "
                    "container registries."
                )

            if container_registry.requires_authentication:
                return False, (
                    "The GitHub Actions orchestrator currently only works with "
                    "GitHub container registries or public container "
                    f"registries, but your {container_registry.FLAVOR} "
                    f"container registry '{container_registry.name}' requires "
                    "authentication."
                )

            for component in stack.components.values():
                if component.local_path:
                    return False, (
                        "The GitHub Actions orchestrator runs pipelines on "
                        "remote GitHub Actions runners, but the "
                        f"'{component.name}' {component.TYPE.value} of your "
                        "active stack is a local component. Please make sure "
                        "to only use remote stack components in combination "
                        "with the GitHub Actions orchestrator. "
                    )

            return True, ""

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_validate_local_requirements,
        )

    def get_docker_image_name(self, pipeline_name: str) -> str:
        """Returns the full docker image name including registry and tag.

        Args:
            pipeline_name: Name of the pipeline for which to generate a docker
                image name.

        Returns:
            The docker image name.
        """
        container_registry = Repository().active_stack.container_registry
        assert container_registry  # should never happen due to validation
        return f"{container_registry.uri}/zenml-github-actions:{pipeline_name}"

    def _docker_login_step(
        self,
        container_registry: BaseContainerRegistry,
    ) -> Optional[Dict[str, Any]]:
        """GitHub Actions step for authenticating with the container registry.

        Args:
            container_registry: The container registry which (potentially)
                requires a step to authenticate.

        Returns:
            Dictionary specifying the GitHub Actions step for authenticating
            with the container registry if that is required, `None` otherwise.
        """
        if (
            isinstance(container_registry, GitHubContainerRegistry)
            and container_registry.automatic_token_authentication
        ):
            # Use GitHub Actions specific placeholder if the container registry
            # specifies automatic token authentication
            username = "${{ github.actor }}"
            password = "${{ secrets.GITHUB_TOKEN }}"
        # TODO: Uncomment these lines once we support different private
        #  container registries in GitHub Actions
        # elif container_registry.requires_authentication:
        #     username = cast(str, container_registry.username)
        #     password = cast(str, container_registry.password)
        else:
            return None

        return {
            "name": "Authenticate with the container registry",
            "uses": DOCKER_LOGIN_ACTION,
            "with": {
                "registry": container_registry.uri,
                "username": username,
                "password": password,
            },
        }

    def _write_environment_file_step(
        self,
        file_name: str,
        secrets_manager: Optional[BaseSecretsManager] = None,
    ) -> Optional[Dict[str, Any]]:
        """GitHub Actions step for writing secrets to an environment file.

        Args:
            file_name: Name of the environment file that should be written.
            secrets_manager: Secrets manager that will be used to read secrets
                during pipeline execution.

        Returns:
            Dictionary specifying the GitHub Actions step for writing the
            environment file.
        """
        if not isinstance(secrets_manager, GitHubSecretsManager):
            return None

        # Always include the environment variable that specifies whether
        # we're running in a GitHub Action workflow so the secret manager knows
        # how to query secret values
        command = (
            f'echo {ENV_IN_GITHUB_ACTIONS}="${ENV_IN_GITHUB_ACTIONS}" '
            f"> {file_name}; "
        )

        # Write all ZenML secrets into the environment file. Explicitly writing
        # these `${{ secrets.<SECRET_NAME> }}` placeholders into the workflow
        # yaml is the only way for us to access the GitHub secrets in a GitHub
        # Actions workflow.
        append_secret_placeholder = (
            "echo {secret_name}=${{{{ secrets.{secret_name} }}}} >> {file}; "
        )
        for secret_name in secrets_manager.get_all_secret_keys(
            include_prefix=True
        ):
            command += append_secret_placeholder.format(
                secret_name=secret_name, file=file_name
            )

        return {
            "name": "Write environment file",
            "run": command,
        }

    def prepare_pipeline_deployment(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Builds and uploads a docker image.

        Args:
            pipeline: The pipeline for which the image is built.
            stack: The stack on which the pipeline will be executed.
            runtime_configuration: Runtime configuration for the pipeline run.

        Raises:
            RuntimeError: If the orchestrator should only run in a clean git
                repository and the repository is dirty.
        """
        if not self.skip_dirty_repository_check and self.git_repo.is_dirty(
            untracked_files=True
        ):
            raise RuntimeError(
                "Trying to run a pipeline from within a dirty (=containing "
                "untracked/uncommitted files) git repository."
                "If you want this orchestrator to skip the dirty repo check in "
                f"the future, run\n `zenml orchestrator update {self.name} "
                "--skip_dirty_repository_check=true`"
            )

        image_name = self.get_docker_image_name(pipeline.name)
        requirements = {*stack.requirements(), *pipeline.requirements}

        logger.debug(
            "Github actions docker image requirements: %s", requirements
        )

        docker_utils.build_docker_image(
            build_context_path=source_utils.get_source_root_path(),
            image_name=image_name,
            dockerignore_path=pipeline.dockerignore_file,
            requirements=requirements,
            base_image=self.custom_docker_base_image_name,
        )

        assert stack.container_registry  # should never happen due to validation
        stack.container_registry.push_image(image_name)

        # Store the docker image digest in the runtime configuration so it gets
        # tracked in the ZenStore
        image_digest = docker_utils.get_image_digest(image_name) or image_name
        runtime_configuration["docker_image"] = image_digest

    def prepare_or_run_pipeline(
        self,
        sorted_steps: List[BaseStep],
        pipeline: "BasePipeline",
        pb2_pipeline: Pb2Pipeline,
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Writes a GitHub Action workflow yaml and optionally pushes it.

        Args:
             sorted_steps: List of sorted steps
             pipeline: Zenml Pipeline instance
             pb2_pipeline: Protobuf Pipeline instance
             stack: The stack the pipeline was run on
             runtime_configuration: The Runtime configuration of the current run

        Raises:
            ValueError: If a schedule without a cron expression or with an
                invalid cron expression is passed.
        """
        schedule = runtime_configuration.schedule

        workflow_name = pipeline.name
        if schedule:
            # Add a suffix to the workflow filename so a scheduled pipeline
            # doesn't get overwritten by future schedules or single runs.
            datetime_string = datetime.now().strftime("%y_%m_%d_%H_%M_%S")
            workflow_name += f"-scheduled-{datetime_string}"

        workflow_path = os.path.join(
            self.workflow_directory,
            f"{workflow_name}.yaml",
        )

        # Store the encoded pb2 pipeline once as an environment variable.
        # We will replace the entrypoint argument later to reduce the size
        # of the workflow file.
        encoded_pb2_pipeline = string_utils.b64_encode(
            json_format.MessageToJson(pb2_pipeline)
        )
        workflow_dict: Dict[str, Any] = {
            "name": workflow_name,
            "env": {ENV_ENCODED_ZENML_PIPELINE: encoded_pb2_pipeline},
        }

        if schedule:
            if not schedule.cron_expression:
                raise ValueError(
                    "GitHub Action workflows can only be scheduled using cron "
                    "expressions and not using a periodic schedule. If you "
                    "want to schedule pipelines using this GitHub Action "
                    "orchestrator, please include a cron expression in your "
                    "schedule object. For more information on GitHub workflow "
                    "schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
                )

            # GitHub workflows require a schedule interval of at least 5
            # minutes. Invalid cron expressions would be something like
            # `*/3 * * * *` (all stars except the first part of the expression,
            # which will have the format `*/minute_interval`)
            if re.fullmatch(r"\*/[1-4]( \*){4,}", schedule.cron_expression):
                raise ValueError(
                    "GitHub workflows requires a schedule interval of at "
                    "least 5 minutes which is incompatible with your cron "
                    f"expression '{schedule.cron_expression}'. An example of a "
                    "valid cron expression would be '* 1 * * *' to run "
                    "every hour. For more information on GitHub workflow "
                    "schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
                )

            logger.warning(
                "GitHub only runs scheduled workflows once the "
                "workflow file is merged to the default branch of the "
                "repository (https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-branches#about-the-default-branch). "
                "Please make sure to merge your current branch into the "
                "default branch for this scheduled pipeline to run."
            )
            workflow_dict["on"] = {
                "schedule": [{"cron": schedule.cron_expression}]
            }
        else:
            # The pipeline should only run once. The only fool-proof way to
            # execute a workflow exactly once seems to be running on specific
            # tags.
            # We don't want to create tags for each pipeline run though, so
            # instead we only run this workflow if the workflow file is
            # modified. As long as users don't manually modify these files this
            # should be sufficient.
            workflow_path_in_repo = os.path.relpath(
                workflow_path, self.git_repo.working_dir
            )
            workflow_dict["on"] = {"push": {"paths": [workflow_path_in_repo]}}

        image_name = self.get_docker_image_name(pipeline.name)
        image_name = docker_utils.get_image_digest(image_name) or image_name

        # Prepare the step that writes an environment file which will get
        # passed to the docker image
        env_file_name = ".zenml_docker_env"
        write_env_file_step = self._write_environment_file_step(
            file_name=env_file_name, secrets_manager=stack.secrets_manager
        )
        docker_run_args = (
            ["--env-file", env_file_name] if write_env_file_step else []
        )

        # Prepare the docker login step if necessary
        container_registry = stack.container_registry
        assert container_registry
        docker_login_step = self._docker_login_step(container_registry)

        # The base command that each job will execute with specific arguments
        base_command = [
            "docker",
            "run",
            *docker_run_args,
            image_name,
        ] + GitHubActionsEntrypointConfiguration.get_entrypoint_command()

        jobs = {}
        for step in sorted_steps:
            job_steps = []

            # Copy the shared dicts here to avoid creating yaml anchors (which
            # are currently not supported in GitHub workflow yaml files)
            if write_env_file_step:
                job_steps.append(copy.deepcopy(write_env_file_step))

            if docker_login_step:
                job_steps.append(copy.deepcopy(docker_login_step))

            entrypoint_args = (
                GitHubActionsEntrypointConfiguration.get_entrypoint_arguments(
                    step=step,
                    pb2_pipeline=pb2_pipeline,
                )
            )

            # Replace the encoded string by a global environment variable to
            # keep the workflow file small
            index = entrypoint_args.index(f"--{PIPELINE_JSON_OPTION}")
            entrypoint_args[index + 1] = f"${ENV_ENCODED_ZENML_PIPELINE}"

            command = base_command + entrypoint_args
            docker_run_step = {
                "name": "Run the docker image",
                "run": " ".join(command),
            }

            job_steps.append(docker_run_step)
            job_dict = {
                "runs-on": "ubuntu-latest",
                "needs": self.get_upstream_step_names(
                    step=step, pb2_pipeline=pb2_pipeline
                ),
                "steps": job_steps,
            }
            jobs[step.name] = job_dict

        workflow_dict["jobs"] = jobs

        fileio.makedirs(self.workflow_directory)
        yaml_utils.write_yaml(workflow_path, workflow_dict, sort_keys=False)
        logger.info("Wrote GitHub workflow file to %s", workflow_path)

        if self.push:
            # Add, commit and push the pipeline workflow yaml
            self.git_repo.index.add(workflow_path)
            self.git_repo.index.commit(
                "[ZenML GitHub Actions Orchestrator] Add github workflow for "
                f"pipeline {pipeline.name}."
            )
            self.git_repo.remote().push()
            logger.info("Pushed workflow file '%s'", workflow_path)
        else:
            logger.info(
                "Automatically committing and pushing is disabled for this "
                "orchestrator. To run the pipeline, you'll have to commit and "
                "push the workflow file '%s' manually.\n"
                "If you want to update this orchestrator to automatically "
                "commit and push in the future, run "
                "`zenml orchestrator update %s --push=true`",
                workflow_path,
                self.name,
            )
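
The environment file step above composes a single shell command string. As a minimal standalone sketch, assuming two hypothetical secret names and a literal stand-in for the ENV_IN_GITHUB_ACTIONS constant:

# Hypothetical sketch of the env-file command composition; the secret names
# and the "IN_GITHUB_ACTIONS" variable name are assumptions for illustration.
file_name = ".zenml_docker_env"
secret_names = ["ZENML_SECRET_FOO", "ZENML_SECRET_BAR"]  # assumed names

command = f'echo IN_GITHUB_ACTIONS="$IN_GITHUB_ACTIONS" > {file_name}; '
for secret_name in secret_names:
    # The doubled braces render literal `${{ secrets.NAME }}` placeholders,
    # which GitHub Actions substitutes when the workflow runs.
    command += f"echo {secret_name}=${{{{ secrets.{secret_name} }}}} >> {file_name}; "

print(command)

Each job then passes this file to `docker run --env-file .zenml_docker_env ...`, which is why the generated workflow prepends the step to every job.
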
git_repo: Repo property readonly

Returns the git repository for the current working directory.

Returns:

Type Description
Repo

Git repository for the current working directory.

Exceptions:

Type Description
RuntimeError

If there is no git repository for the current working directory or the repository remote is not pointing to GitHub.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validator that ensures that the stack is compatible.

Makes sure that the stack contains a container registry and only remote components.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

The stack validator.

workflow_directory: str property readonly

Returns path to the GitHub workflows directory.

Returns:

Type Description
str

The GitHub workflows directory.

get_docker_image_name(self, pipeline_name)

Returns the full docker image name including registry and tag.

Parameters:

Name Type Description Default
pipeline_name str

Name of the pipeline for which to generate a docker image name.

required

Returns:

Type Description
str

The docker image name.

Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Returns the full docker image name including registry and tag.

    Args:
        pipeline_name: Name of the pipeline for which to generate a docker
            image name.

    Returns:
        The docker image name.
    """
    container_registry = Repository().active_stack.container_registry
    assert container_registry  # should never happen due to validation
    return f"{container_registry.uri}/zenml-github-actions:{pipeline_name}"
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)

Writes a GitHub Action workflow yaml and optionally pushes it.

Parameters:

Name Type Description Default
sorted_steps List[zenml.steps.base_step.BaseStep]

List of sorted steps

required
pipeline BasePipeline

Zenml Pipeline instance

required
pb2_pipeline Pipeline

Protobuf Pipeline instance

required
stack Stack

The stack the pipeline was run on

required
runtime_configuration RuntimeConfiguration

The Runtime configuration of the current run

required

Exceptions:

Type Description
ValueError

If a schedule without a cron expression or with an invalid cron expression is passed.

Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List[BaseStep],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Writes a GitHub Action workflow yaml and optionally pushes it.

    Args:
         sorted_steps: List of sorted steps
         pipeline: Zenml Pipeline instance
         pb2_pipeline: Protobuf Pipeline instance
         stack: The stack the pipeline was run on
         runtime_configuration: The Runtime configuration of the current run

    Raises:
        ValueError: If a schedule without a cron expression or with an
            invalid cron expression is passed.
    """
    schedule = runtime_configuration.schedule

    workflow_name = pipeline.name
    if schedule:
        # Add a suffix to the workflow filename so a scheduled pipeline
        # doesn't get overwritten by future schedules or single runs.
        datetime_string = datetime.now().strftime("%y_%m_%d_%H_%M_%S")
        workflow_name += f"-scheduled-{datetime_string}"

    workflow_path = os.path.join(
        self.workflow_directory,
        f"{workflow_name}.yaml",
    )

    # Store the encoded pb2 pipeline once as an environment variable.
    # We will replace the entrypoint argument later to reduce the size
    # of the workflow file.
    encoded_pb2_pipeline = string_utils.b64_encode(
        json_format.MessageToJson(pb2_pipeline)
    )
    workflow_dict: Dict[str, Any] = {
        "name": workflow_name,
        "env": {ENV_ENCODED_ZENML_PIPELINE: encoded_pb2_pipeline},
    }

    if schedule:
        if not schedule.cron_expression:
            raise ValueError(
                "GitHub Action workflows can only be scheduled using cron "
                "expressions and not using a periodic schedule. If you "
                "want to schedule pipelines using this GitHub Action "
                "orchestrator, please include a cron expression in your "
                "schedule object. For more information on GitHub workflow "
                "schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
            )

        # GitHub workflows require a schedule interval of at least 5
        # minutes. Invalid cron expressions would be something like
        # `*/3 * * * *` (all stars except the first part of the expression,
        # which will have the format `*/minute_interval`)
        if re.fullmatch(r"\*/[1-4]( \*){4,}", schedule.cron_expression):
            raise ValueError(
                "GitHub workflows requires a schedule interval of at "
                "least 5 minutes which is incompatible with your cron "
                f"expression '{schedule.cron_expression}'. An example of a "
                "valid cron expression would be '* 1 * * *' to run "
                "every hour. For more information on GitHub workflow "
                "schedules check out https://docs.github.com/en/actions/using-workflows/events-that-trigger-workflows#schedule."
            )

        logger.warning(
            "GitHub only runs scheduled workflows once the "
            "workflow file is merged to the default branch of the "
            "repository (https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-branches#about-the-default-branch). "
            "Please make sure to merge your current branch into the "
            "default branch for this scheduled pipeline to run."
        )
        workflow_dict["on"] = {
            "schedule": [{"cron": schedule.cron_expression}]
        }
    else:
        # The pipeline should only run once. The only fool-proof way to
        # execute a workflow exactly once seems to be running on specific
        # tags.
        # We don't want to create tags for each pipeline run though, so
        # instead we only run this workflow if the workflow file is
        # modified. As long as users don't manually modify these files this
        # should be sufficient.
        workflow_path_in_repo = os.path.relpath(
            workflow_path, self.git_repo.working_dir
        )
        workflow_dict["on"] = {"push": {"paths": [workflow_path_in_repo]}}

    image_name = self.get_docker_image_name(pipeline.name)
    image_name = docker_utils.get_image_digest(image_name) or image_name

    # Prepare the step that writes an environment file which will get
    # passed to the docker image
    env_file_name = ".zenml_docker_env"
    write_env_file_step = self._write_environment_file_step(
        file_name=env_file_name, secrets_manager=stack.secrets_manager
    )
    docker_run_args = (
        ["--env-file", env_file_name] if write_env_file_step else []
    )

    # Prepare the docker login step if necessary
    container_registry = stack.container_registry
    assert container_registry
    docker_login_step = self._docker_login_step(container_registry)

    # The base command that each job will execute with specific arguments
    base_command = [
        "docker",
        "run",
        *docker_run_args,
        image_name,
    ] + GitHubActionsEntrypointConfiguration.get_entrypoint_command()

    jobs = {}
    for step in sorted_steps:
        job_steps = []

        # Copy the shared dicts here to avoid creating yaml anchors (which
        # are currently not supported in GitHub workflow yaml files)
        if write_env_file_step:
            job_steps.append(copy.deepcopy(write_env_file_step))

        if docker_login_step:
            job_steps.append(copy.deepcopy(docker_login_step))

        entrypoint_args = (
            GitHubActionsEntrypointConfiguration.get_entrypoint_arguments(
                step=step,
                pb2_pipeline=pb2_pipeline,
            )
        )

        # Replace the encoded string by a global environment variable to
        # keep the workflow file small
        index = entrypoint_args.index(f"--{PIPELINE_JSON_OPTION}")
        entrypoint_args[index + 1] = f"${ENV_ENCODED_ZENML_PIPELINE}"

        command = base_command + entrypoint_args
        docker_run_step = {
            "name": "Run the docker image",
            "run": " ".join(command),
        }

        job_steps.append(docker_run_step)
        job_dict = {
            "runs-on": "ubuntu-latest",
            "needs": self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            ),
            "steps": job_steps,
        }
        jobs[step.name] = job_dict

    workflow_dict["jobs"] = jobs

    fileio.makedirs(self.workflow_directory)
    yaml_utils.write_yaml(workflow_path, workflow_dict, sort_keys=False)
    logger.info("Wrote GitHub workflow file to %s", workflow_path)

    if self.push:
        # Add, commit and push the pipeline workflow yaml
        self.git_repo.index.add(workflow_path)
        self.git_repo.index.commit(
            "[ZenML GitHub Actions Orchestrator] Add github workflow for "
            f"pipeline {pipeline.name}."
        )
        self.git_repo.remote().push()
        logger.info("Pushed workflow file '%s'", workflow_path)
    else:
        logger.info(
            "Automatically committing and pushing is disabled for this "
            "orchestrator. To run the pipeline, you'll have to commit and "
            "push the workflow file '%s' manually.\n"
            "If you want to update this orchestrator to automatically "
            "commit and push in the future, run "
            "`zenml orchestrator update %s --push=true`",
            workflow_path,
            self.name,
        )
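
The minimum-interval check in the method above is a single regular expression. A small standalone sketch of how it classifies cron expressions (the sample expressions are illustrative):

import re

# Same pattern as in the orchestrator: it matches `*/1` through `*/4` minute
# intervals, which would run more often than GitHub's 5-minute minimum.
pattern = r"\*/[1-4]( \*){4,}"

print(bool(re.fullmatch(pattern, "*/3 * * * *")))  # True  -> rejected
print(bool(re.fullmatch(pattern, "*/5 * * * *")))  # False -> accepted
print(bool(re.fullmatch(pattern, "0 * * * *")))    # False -> accepted (hourly)
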
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)

Builds and uploads a docker image.

Parameters:

Name Type Description Default
pipeline BasePipeline

The pipeline for which the image is built.

required
stack Stack

The stack on which the pipeline will be executed.

required
runtime_configuration RuntimeConfiguration

Runtime configuration for the pipeline run.

required

Exceptions:

Type Description
RuntimeError

If the orchestrator should only run in a clean git repository and the repository is dirty.

Source code in zenml/integrations/github/orchestrators/github_actions_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Builds and uploads a docker image.

    Args:
        pipeline: The pipeline for which the image is built.
        stack: The stack on which the pipeline will be executed.
        runtime_configuration: Runtime configuration for the pipeline run.

    Raises:
        RuntimeError: If the orchestrator should only run in a clean git
            repository and the repository is dirty.
    """
    if not self.skip_dirty_repository_check and self.git_repo.is_dirty(
        untracked_files=True
    ):
        raise RuntimeError(
            "Trying to run a pipeline from within a dirty (=containing "
            "untracked/uncommitted files) git repository."
            "If you want this orchestrator to skip the dirty repo check in "
            f"the future, run\n `zenml orchestrator update {self.name} "
            "--skip_dirty_repository_check=true`"
        )

    image_name = self.get_docker_image_name(pipeline.name)
    requirements = {*stack.requirements(), *pipeline.requirements}

    logger.debug(
        "Github actions docker image requirements: %s", requirements
    )

    docker_utils.build_docker_image(
        build_context_path=source_utils.get_source_root_path(),
        image_name=image_name,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=requirements,
        base_image=self.custom_docker_base_image_name,
    )

    assert stack.container_registry  # should never happen due to validation
    stack.container_registry.push_image(image_name)

    # Store the docker image digest in the runtime configuration so it gets
    # tracked in the ZenStore
    image_digest = docker_utils.get_image_digest(image_name) or image_name
    runtime_configuration["docker_image"] = image_digest
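
The dirty-repository check relies on GitPython's `Repo.is_dirty`. A minimal sketch of the same check outside of ZenML, assuming the current working directory sits inside the repository:

from git import Repo  # GitPython

# Open the repository that contains the current working directory.
repo = Repo(search_parent_directories=True)

# Mirrors the orchestrator's check: True if there are uncommitted changes
# or untracked files anywhere in the working tree.
if repo.is_dirty(untracked_files=True):
    print("Repository is dirty; commit or stash your changes first.")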

secrets_managers special

Initialization of the GitHub Secrets Manager.

github_secrets_manager

Implementation of the GitHub Secrets Manager.

GitHubSecretsManager (BaseSecretsManager) pydantic-model

Class to interact with the GitHub secrets manager.

Attributes:

Name Type Description
owner str

The owner (either individual or organization) of the repository.

repository str

Name of the GitHub repository.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
class GitHubSecretsManager(BaseSecretsManager):
    """Class to interact with the GitHub secrets manager.

    Attributes:
        owner: The owner (either individual or organization) of the repository.
        repository: Name of the GitHub repository.
    """

    owner: str
    repository: str

    _session: Optional[requests.Session] = None

    # Class configuration
    FLAVOR: ClassVar[str] = GITHUB_SECRET_MANAGER_FLAVOR

    @property
    def post_registration_message(self) -> Optional[str]:
        """Info message regarding GitHub API authentication env variables.

        Returns:
            The info message.
        """
        return AUTHENTICATION_CREDENTIALS_MESSAGE

    @property
    def session(self) -> requests.Session:
        """Session to send requests to the GitHub API.

        Returns:
            Session to use for GitHub API calls.

        Raises:
            RuntimeError: If authentication credentials for the GitHub API are
                not set.
        """
        if not self._session:
            session = requests.Session()
            github_username = os.getenv(ENV_GITHUB_USERNAME)
            authentication_token = os.getenv(ENV_GITHUB_AUTHENTICATION_TOKEN)

            if not github_username or not authentication_token:
                raise RuntimeError(
                    "Missing authentication credentials for GitHub secrets "
                    "manager. " + AUTHENTICATION_CREDENTIALS_MESSAGE
                )

            session.auth = HTTPBasicAuth(github_username, authentication_token)
            session.headers["Accept"] = "application/vnd.github.v3+json"
            self._session = session

        return self._session

    def _send_request(
        self, method: str, resource: Optional[str] = None, **kwargs: Any
    ) -> requests.Response:
        """Sends an HTTP request to the GitHub API.

        Args:
            method: Method of the HTTP request that should be sent.
            resource: Optional resource to which the request should be sent. If
                none is given, the default GitHub API secrets endpoint will be
                used.
            **kwargs: Will be passed to the `requests` library.

        Returns:
            HTTP response.

        # noqa: DAR402

        Raises:
            HTTPError: If the request failed due to a client or server error.
        """
        url = (
            f"https://api.github.com/repos/{self.owner}/{self.repository}"
            f"/actions/secrets"
        )
        if resource:
            url += resource

        response = self.session.request(method=method, url=url, **kwargs)
        # Raise an exception in case of a client or server error
        response.raise_for_status()

        return response

    def _encrypt_secret(self, secret_value: str) -> Tuple[str, str]:
        """Encrypts a secret value.

        This method first fetches a public key from the GitHub API and then uses
        this key to encrypt the secret value. This is needed in order to
        register GitHub secrets using the API.

        Args:
            secret_value: Secret value to encrypt.

        Returns:
            The encrypted secret value and the key id of the GitHub public key.
        """
        from nacl.encoding import Base64Encoder
        from nacl.public import PublicKey, SealedBox

        response_json = self._send_request("GET", resource="/public-key").json()
        public_key = PublicKey(
            response_json["key"].encode("utf-8"), Base64Encoder
        )
        sealed_box = SealedBox(public_key)
        encrypted_bytes = sealed_box.encrypt(secret_value.encode("utf-8"))
        encrypted_string = base64.b64encode(encrypted_bytes).decode("utf-8")
        return encrypted_string, cast(str, response_json["key_id"])

    def _has_secret(self, secret_name: str) -> bool:
        """Checks whether a secret exists for the given name.

        Args:
            secret_name: Name of the secret which should be checked.

        Returns:
            `True` if a secret with the given name exists, `False` otherwise.
        """
        secret_name = _convert_secret_name(secret_name, remove_prefix=True)
        return secret_name in self.get_all_secret_keys(include_prefix=False)

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets the value of a secret.

        This method only works when called from within a GitHub Actions
        environment.

        Args:
            secret_name: The name of the secret to get.

        Returns:
            The secret.

        Raises:
            KeyError: If a secret with this name doesn't exist.
            RuntimeError: If not inside a GitHub Actions environment.
        """
        full_secret_name = _convert_secret_name(secret_name, add_prefix=True)
        # Raise a KeyError if the secret doesn't exist. We can do that even
        # if we're not inside a GitHub Actions environment
        if not self._has_secret(secret_name):
            raise KeyError(
                f"Unable to find secret '{secret_name}'. Please check the "
                "GitHub UI to see if a **Repository** secret called "
                f"'{full_secret_name}' exists. (ZenML uses the "
                f"'{GITHUB_SECRET_PREFIX}' to differentiate ZenML "
                "secrets from other GitHub secrets)"
            )

        if not inside_github_action_environment():
            stack_name = Repository().active_stack_name
            commands = [
                f"zenml stack copy {stack_name} <NEW_STACK_NAME>",
                "zenml secrets_manager register <NEW_SECRETS_MANAGER_NAME> "
                "--flavor=local",
                "zenml stack update <NEW_STACK_NAME> "
                "--secrets_manager=<NEW_SECRETS_MANAGER_NAME>",
                "zenml stack set <NEW_STACK_NAME>",
                f"zenml secret register {secret_name} ...",
            ]

            raise RuntimeError(
                "Getting GitHub secrets is only possible within a GitHub "
                "Actions workflow. If you need this secret to access "
                "stack components (e.g. your metadata store to fetch pipelines "
                "during the post-execution workflow) locally, you need to "
                "register this secret in a different secrets manager. "
                "You can do this by running the following commands: \n\n"
                + "\n".join(commands)
            )

        # If we're running inside a GitHub Actions environment using a
        # workflow generated by the GitHub Actions orchestrator, all ZenML
        # secrets stored in the GitHub secrets manager will be accessible as
        # environment variables
        secret_value = cast(str, os.getenv(full_secret_name))

        secret_dict = json.loads(string_utils.b64_decode(secret_value))
        schema_class = SecretSchemaClassRegistry.get_class(
            secret_schema=secret_dict[SECRET_SCHEMA_DICT_KEY]
        )
        secret_content = secret_dict[SECRET_CONTENT_DICT_KEY]

        return schema_class(name=secret_name, **secret_content)

    def get_all_secret_keys(self, include_prefix: bool = False) -> List[str]:
        """Get all secret keys.

        If we're running inside a GitHub Actions environment, this will return
        the names of all environment variables starting with a ZenML internal
        prefix. Otherwise, this will return all GitHub **Repository** secrets
        created by ZenML.

        Args:
            include_prefix: Whether or not the internal prefix that is used to
                differentiate ZenML secrets from other GitHub secrets should be
                included in the returned names.

        Returns:
            List of all secret keys.
        """
        if inside_github_action_environment():
            potential_secret_keys = list(os.environ)
        else:
            logger.info(
                "Fetching list of secrets for repository %s/%s",
                self.owner,
                self.repository,
            )
            response = self._send_request("GET", params={"per_page": 100})
            potential_secret_keys = [
                secret_dict["name"]
                for secret_dict in response.json()["secrets"]
            ]

        keys = [
            _convert_secret_name(key, remove_prefix=not include_prefix)
            for key in potential_secret_keys
            if key.startswith(GITHUB_SECRET_PREFIX)
        ]

        return keys

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: The secret to register.

        Raises:
            SecretExistsError: If a secret with this name already exists.
        """
        if self._has_secret(secret.name):
            raise SecretExistsError(
                f"A secret with name '{secret.name}' already exists for this "
                "GitHub repository. If you want to register a new value for "
                f"this secret, please run `zenml secret delete {secret.name}` "
                f"followed by `zenml secret register {secret.name} ...`."
            )

        secret_dict = {
            SECRET_SCHEMA_DICT_KEY: secret.TYPE,
            SECRET_CONTENT_DICT_KEY: secret.content,
        }
        secret_value = string_utils.b64_encode(json.dumps(secret_dict))
        encrypted_secret, public_key_id = self._encrypt_secret(
            secret_value=secret_value
        )
        body = {
            "encrypted_value": encrypted_secret,
            "key_id": public_key_id,
        }

        full_secret_name = _convert_secret_name(secret.name, add_prefix=True)
        self._send_request("PUT", resource=f"/{full_secret_name}", json=body)

    def update_secret(self, secret: BaseSecretSchema) -> NoReturn:
        """Update an existing secret.

        Args:
            secret: The secret to update.

        Raises:
            NotImplementedError: Always, as this functionality is not possible
                using GitHub secrets, which don't allow us to retrieve the
                secret values outside of a GitHub Actions environment.
        """
        raise NotImplementedError(
            "Updating secrets is not possible with the GitHub secrets manager "
            "as it is not possible to retrieve GitHub secrets values outside "
            "of a GitHub Actions environment."
        )

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        Args:
            secret_name: The name of the secret to delete.
        """
        full_secret_name = _convert_secret_name(secret_name, add_prefix=True)
        self._send_request("DELETE", resource=f"/{full_secret_name}")

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets."""
        for secret_name in self.get_all_secret_keys(include_prefix=False):
            self.delete_secret(secret_name=secret_name)
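
The payload that `register_secret` uploads and `get_secret` later decodes is plain JSON wrapped in base64. A sketch of the round trip using only the standard library (the dictionary keys stand in for the SECRET_SCHEMA_DICT_KEY and SECRET_CONTENT_DICT_KEY constants, whose exact values are assumptions here):

import base64
import json

# Assumed key names for illustration.
secret_dict = {"schema": "arbitrary", "content": {"api_key": "dummy-value"}}

# Encoding as applied before upload: JSON -> base64 string.
encoded = base64.b64encode(json.dumps(secret_dict).encode("utf-8")).decode("utf-8")

# Decoding as applied when reading the value back from an environment variable.
decoded = json.loads(base64.b64decode(encoded).decode("utf-8"))
assert decoded == secret_dict
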
post_registration_message: Optional[str] property readonly

Info message regarding GitHub API authentication env variables.

Returns:

Type Description
Optional[str]

The info message.

session: Session property readonly

Session to send requests to the GitHub API.

Returns:

Type Description
Session

Session to use for GitHub API calls.

Exceptions:

Type Description
RuntimeError

If authentication credentials for the GitHub API are not set.

delete_all_secrets(self)

Delete all existing secrets.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets."""
    for secret_name in self.get_all_secret_keys(include_prefix=False):
        self.delete_secret(secret_name=secret_name)
delete_secret(self, secret_name)

Delete an existing secret.

Parameters:

Name Type Description Default
secret_name str

The name of the secret to delete.

required
Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete an existing secret.

    Args:
        secret_name: The name of the secret to delete.
    """
    full_secret_name = _convert_secret_name(secret_name, add_prefix=True)
    self._send_request("DELETE", resource=f"/{full_secret_name}")
get_all_secret_keys(self, include_prefix=False)

Get all secret keys.

If we're running inside a GitHub Actions environment, this will return the names of all environment variables starting with a ZenML internal prefix. Otherwise, this will return all GitHub Repository secrets created by ZenML.

Parameters:

Name Type Description Default
include_prefix bool

Whether or not the internal prefix that is used to differentiate ZenML secrets from other GitHub secrets should be included in the returned names.

False

Returns:

Type Description
List[str]

List of all secret keys.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def get_all_secret_keys(self, include_prefix: bool = False) -> List[str]:
    """Get all secret keys.

    If we're running inside a GitHub Actions environment, this will return
    the names of all environment variables starting with a ZenML internal
    prefix. Otherwise, this will return all GitHub **Repository** secrets
    created by ZenML.

    Args:
        include_prefix: Whether or not the internal prefix that is used to
            differentiate ZenML secrets from other GitHub secrets should be
            included in the returned names.

    Returns:
        List of all secret keys.
    """
    if inside_github_action_environment():
        potential_secret_keys = list(os.environ)
    else:
        logger.info(
            "Fetching list of secrets for repository %s/%s",
            self.owner,
            self.repository,
        )
        response = self._send_request("GET", params={"per_page": 100})
        potential_secret_keys = [
            secret_dict["name"]
            for secret_dict in response.json()["secrets"]
        ]

    keys = [
        _convert_secret_name(key, remove_prefix=not include_prefix)
        for key in potential_secret_keys
        if key.startswith(GITHUB_SECRET_PREFIX)
    ]

    return keys
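
The helper `_convert_secret_name` is not reproduced on this page. Judging only from how it is called (the `add_prefix`/`remove_prefix` keyword arguments and the GITHUB_SECRET_PREFIX constant), a plausible sketch might look like the following; this is an assumption, not the actual implementation:

GITHUB_SECRET_PREFIX = "ZENML_"  # assumed value for illustration


def _convert_secret_name(
    secret_name: str, add_prefix: bool = False, remove_prefix: bool = False
) -> str:
    """Hypothetical sketch: toggles the ZenML prefix on a secret name."""
    if add_prefix and not secret_name.startswith(GITHUB_SECRET_PREFIX):
        return GITHUB_SECRET_PREFIX + secret_name
    if remove_prefix and secret_name.startswith(GITHUB_SECRET_PREFIX):
        return secret_name[len(GITHUB_SECRET_PREFIX) :]
    return secret_name
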
get_secret(self, secret_name)

Gets the value of a secret.

This method only works when called from within a GitHub Actions environment.

Parameters:

Name Type Description Default
secret_name str

The name of the secret to get.

required

Returns:

Type Description
BaseSecretSchema

The secret.

Exceptions:

Type Description
KeyError

If a secret with this name doesn't exist.

RuntimeError

If not inside a GitHub Actions environment.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Gets the value of a secret.

    This method only works when called from within a GitHub Actions
    environment.

    Args:
        secret_name: The name of the secret to get.

    Returns:
        The secret.

    Raises:
        KeyError: If a secret with this name doesn't exist.
        RuntimeError: If not inside a GitHub Actions environment.
    """
    full_secret_name = _convert_secret_name(secret_name, add_prefix=True)
    # Raise a KeyError if the secret doesn't exist. We can do that even
    # if we're not inside a GitHub Actions environment
    if not self._has_secret(secret_name):
        raise KeyError(
            f"Unable to find secret '{secret_name}'. Please check the "
            "GitHub UI to see if a **Repository** secret called "
            f"'{full_secret_name}' exists. (ZenML uses the "
            f"'{GITHUB_SECRET_PREFIX}' to differentiate ZenML "
            "secrets from other GitHub secrets)"
        )

    if not inside_github_action_environment():
        stack_name = Repository().active_stack_name
        commands = [
            f"zenml stack copy {stack_name} <NEW_STACK_NAME>",
            "zenml secrets_manager register <NEW_SECRETS_MANAGER_NAME> "
            "--flavor=local",
            "zenml stack update <NEW_STACK_NAME> "
            "--secrets_manager=<NEW_SECRETS_MANAGER_NAME>",
            "zenml stack set <NEW_STACK_NAME>",
            f"zenml secret register {secret_name} ...",
        ]

        raise RuntimeError(
            "Getting GitHub secrets is only possible within a GitHub "
            "Actions workflow. If you need this secret to access "
            "stack components (e.g. your metadata store to fetch pipelines "
            "during the post-execution workflow) locally, you need to "
            "register this secret in a different secrets manager. "
            "You can do this by running the following commands: \n\n"
            + "\n".join(commands)
        )

    # If we're running inside a GitHub Actions environment using a
    # workflow generated by the GitHub Actions orchestrator, all ZenML
    # secrets stored in the GitHub secrets manager will be accessible as
    # environment variables
    secret_value = cast(str, os.getenv(full_secret_name))

    secret_dict = json.loads(string_utils.b64_decode(secret_value))
    schema_class = SecretSchemaClassRegistry.get_class(
        secret_schema=secret_dict[SECRET_SCHEMA_DICT_KEY]
    )
    secret_content = secret_dict[SECRET_CONTENT_DICT_KEY]

    return schema_class(name=secret_name, **secret_content)
register_secret(self, secret)

Registers a new secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

The secret to register.

required

Exceptions:

Type Description
SecretExistsError

If a secret with this name already exists.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Registers a new secret.

    Args:
        secret: The secret to register.

    Raises:
        SecretExistsError: If a secret with this name already exists.
    """
    if self._has_secret(secret.name):
        raise SecretExistsError(
            f"A secret with name '{secret.name}' already exists for this "
            "GitHub repository. If you want to register a new value for "
            f"this secret, please run `zenml secret delete {secret.name}` "
            f"followed by `zenml secret register {secret.name} ...`."
        )

    secret_dict = {
        SECRET_SCHEMA_DICT_KEY: secret.TYPE,
        SECRET_CONTENT_DICT_KEY: secret.content,
    }
    secret_value = string_utils.b64_encode(json.dumps(secret_dict))
    encrypted_secret, public_key_id = self._encrypt_secret(
        secret_value=secret_value
    )
    body = {
        "encrypted_value": encrypted_secret,
        "key_id": public_key_id,
    }

    full_secret_name = _convert_secret_name(secret.name, add_prefix=True)
    self._send_request("PUT", resource=f"/{full_secret_name}", json=body)
update_secret(self, secret)

Update an existing secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

The secret to update.

required

Exceptions:

Type Description
NotImplementedError

Always, as this functionality is not possible using GitHub secrets, which don't allow us to retrieve the secret values outside of a GitHub Actions environment.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> NoReturn:
    """Update an existing secret.

    Args:
        secret: The secret to update.

    Raises:
        NotImplementedError: Always, as this functionality is not possible
            using GitHub secrets, which don't allow us to retrieve the
            secret values outside of a GitHub Actions environment.
    """
    raise NotImplementedError(
        "Updating secrets is not possible with the GitHub secrets manager "
        "as it is not possible to retrieve GitHub secrets values outside "
        "of a GitHub Actions environment."
    )
inside_github_action_environment()

Returns whether the current code is executing in a GitHub Actions environment.

Returns:

Type Description
bool

True if running in a GitHub Actions environment, False otherwise.

Source code in zenml/integrations/github/secrets_managers/github_secrets_manager.py
def inside_github_action_environment() -> bool:
    """Returns if the current code is executing in a GitHub Actions environment.

    Returns:
        `True` if running in a GitHub Actions environment, `False` otherwise.
    """
    return os.getenv(ENV_IN_GITHUB_ACTIONS) == "true"
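
The check reduces to a single environment variable comparison. This variable is forwarded into the pipeline container through the environment file step of the generated workflow, so the secrets manager can tell which lookup strategy to use. A tiny sketch, using an assumed literal name standing in for ENV_IN_GITHUB_ACTIONS:

import os

# "IN_GITHUB_ACTIONS" is an assumed stand-in for the ENV_IN_GITHUB_ACTIONS constant.
in_actions = os.getenv("IN_GITHUB_ACTIONS") == "true"
print(in_actions)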

graphviz special

Initialization of the Graphviz integration.

GraphvizIntegration (Integration)

Definition of Graphviz integration for ZenML.

Source code in zenml/integrations/graphviz/__init__.py
class GraphvizIntegration(Integration):
    """Definition of Graphviz integration for ZenML."""

    NAME = GRAPHVIZ
    REQUIREMENTS = ["graphviz>=0.17"]
    SYSTEM_REQUIREMENTS = {"graphviz": "dot"}
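
`SYSTEM_REQUIREMENTS` maps a human-readable requirement name to a binary that must be available on the `PATH` (`dot` for Graphviz). A hedged sketch of how such a presence check could be performed; the lookup below is an illustration, not ZenML's internal logic:

import shutil

SYSTEM_REQUIREMENTS = {"graphviz": "dot"}

for name, binary in SYSTEM_REQUIREMENTS.items():
    if shutil.which(binary) is None:
        print(f"System requirement '{name}' is missing: '{binary}' not on PATH")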

visualizers special

Initialization of Graphviz visualizers.

pipeline_run_dag_visualizer

Implementation of the Graphviz pipeline run DAG visualizer.

PipelineRunDagVisualizer (BasePipelineRunVisualizer)

Visualize the lineage of runs in a pipeline.

Source code in zenml/integrations/graphviz/visualizers/pipeline_run_dag_visualizer.py
class PipelineRunDagVisualizer(BasePipelineRunVisualizer):
    """Visualize the lineage of runs in a pipeline."""

    ARTIFACT_DEFAULT_COLOR = "blue"
    ARTIFACT_CACHED_COLOR = "green"
    ARTIFACT_SHAPE = "box"
    ARTIFACT_PREFIX = "artifact_"
    STEP_COLOR = "#431D93"
    STEP_SHAPE = "ellipse"
    STEP_PREFIX = "step_"
    FONT = "Roboto"

    @abstractmethod
    def visualize(
        self, object: PipelineRunView, *args: Any, **kwargs: Any
    ) -> graphviz.Digraph:
        """Creates a pipeline lineage diagram using graphviz.

        Args:
            object: The pipeline run view to visualize.
            *args: Additional arguments to pass to the visualization.
            **kwargs: Additional keyword arguments to pass to the visualization.

        Returns:
            A graphviz digraph object.
        """
        logger.warning(
            "This integration is not completed yet. Results might be unexpected."
        )

        dot = graphviz.Digraph(comment=object.name)

        # link the steps together
        for step in object.steps:
            # add each step as a node
            dot.node(
                self.STEP_PREFIX + str(step.id),
                step.entrypoint_name,
                shape=self.STEP_SHAPE,
            )
            # for each parent of a step, add an edge

            for artifact_name, artifact in step.outputs.items():
                dot.node(
                    self.ARTIFACT_PREFIX + str(artifact.id),
                    f"{artifact_name} \n" f"({artifact._data_type})",
                    shape=self.ARTIFACT_SHAPE,
                )
                dot.edge(
                    self.STEP_PREFIX + str(step.id),
                    self.ARTIFACT_PREFIX + str(artifact.id),
                )

            for artifact_name, artifact in step.inputs.items():
                dot.edge(
                    self.ARTIFACT_PREFIX + str(artifact.id),
                    self.STEP_PREFIX + str(step.id),
                )

        with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
            dot.render(filename=f.name, format="png", view=True, cleanup=True)
        return dot
visualize(self, object, *args, **kwargs)

Creates a pipeline lineage diagram using graphviz.

Parameters:

Name Type Description Default
object PipelineRunView

The pipeline run view to visualize.

required
*args Any

Additional arguments to pass to the visualization.

()
**kwargs Any

Additional keyword arguments to pass to the visualization.

{}

Returns:

Type Description
Digraph

A graphviz digraph object.

Source code in zenml/integrations/graphviz/visualizers/pipeline_run_dag_visualizer.py
@abstractmethod
def visualize(
    self, object: PipelineRunView, *args: Any, **kwargs: Any
) -> graphviz.Digraph:
    """Creates a pipeline lineage diagram using graphviz.

    Args:
        object: The pipeline run view to visualize.
        *args: Additional arguments to pass to the visualization.
        **kwargs: Additional keyword arguments to pass to the visualization.

    Returns:
        A graphviz digraph object.
    """
    logger.warning(
        "This integration is not completed yet. Results might be unexpected."
    )

    dot = graphviz.Digraph(comment=object.name)

    # link the steps together
    for step in object.steps:
        # add each step as a node
        dot.node(
            self.STEP_PREFIX + str(step.id),
            step.entrypoint_name,
            shape=self.STEP_SHAPE,
        )
        # for each parent of a step, add an edge

        for artifact_name, artifact in step.outputs.items():
            dot.node(
                self.ARTIFACT_PREFIX + str(artifact.id),
                f"{artifact_name} \n" f"({artifact._data_type})",
                shape=self.ARTIFACT_SHAPE,
            )
            dot.edge(
                self.STEP_PREFIX + str(step.id),
                self.ARTIFACT_PREFIX + str(artifact.id),
            )

        for artifact_name, artifact in step.inputs.items():
            dot.edge(
                self.ARTIFACT_PREFIX + str(artifact.id),
                self.STEP_PREFIX + str(step.id),
            )

    with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
        dot.render(filename=f.name, format="png", view=True, cleanup=True)
    return dot
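
Because `visualize` is declared with `@abstractmethod`, the class cannot be instantiated directly. A hedged usage sketch with a trivial concrete subclass; the pipeline name is hypothetical and the post-execution accessors are assumed to match this ZenML version:

from zenml.integrations.graphviz.visualizers.pipeline_run_dag_visualizer import (
    PipelineRunDagVisualizer,
)
from zenml.repository import Repository


class DagVisualizer(PipelineRunDagVisualizer):
    """Concrete subclass that simply delegates to the parent implementation."""

    def visualize(self, object, *args, **kwargs):
        return super().visualize(object, *args, **kwargs)


# Hypothetical pipeline name; fetch its latest run via the post-execution API.
pipeline = Repository().get_pipeline("my_pipeline")
latest_run = pipeline.runs[-1]
DagVisualizer().visualize(latest_run)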

great_expectations special

Great Expectation integration for ZenML.

The Great Expectations integration enables you to use Great Expectations as a way of profiling and validating your data.

GreatExpectationsIntegration (Integration)

Definition of Great Expectations integration for ZenML.

Source code in zenml/integrations/great_expectations/__init__.py
class GreatExpectationsIntegration(Integration):
    """Definition of Great Expectations integration for ZenML."""

    NAME = GREAT_EXPECTATIONS
    REQUIREMENTS = [
        "great-expectations~=0.15.11",
    ]

    @staticmethod
    def activate() -> None:
        """Activate the Great Expectations integration."""
        from zenml.integrations.great_expectations import materializers  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Great Expectations integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR,
                source="zenml.integrations.great_expectations.data_validators.GreatExpectationsDataValidator",
                type=StackComponentType.DATA_VALIDATOR,
                integration=cls.NAME,
            ),
        ]
activate() staticmethod

Activate the Great Expectations integration.

Source code in zenml/integrations/great_expectations/__init__.py
@staticmethod
def activate() -> None:
    """Activate the Great Expectations integration."""
    from zenml.integrations.great_expectations import materializers  # noqa
flavors() classmethod

Declare the stack component flavors for the Great Expectations integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/great_expectations/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Great Expectations integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR,
            source="zenml.integrations.great_expectations.data_validators.GreatExpectationsDataValidator",
            type=StackComponentType.DATA_VALIDATOR,
            integration=cls.NAME,
        ),
    ]

data_validators special

Initialization of the Great Expectations data validator for ZenML.

ge_data_validator

Implementation of the Great Expectations data validator.

GreatExpectationsDataValidator (BaseDataValidator) pydantic-model

Great Expectations data validator stack component.

Attributes:

Name Type Description
context_root_dir Optional[str]

location of an already initialized Great Expectations data context. If configured, the data validator will only be usable with local orchestrators.

context_config Optional[Dict[str, Any]]

in-line Great Expectations data context configuration.

configure_zenml_stores bool

if set, ZenML will automatically configure stores that use the Artifact Store as a backend. If neither context_root_dir nor context_config are set, this is the default behavior.

configure_local_docs bool

configure a local data docs site where Great Expectations docs are generated and can be visualized locally.

Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
class GreatExpectationsDataValidator(BaseDataValidator):
    """Great Expectations data validator stack component.

    Attributes:
        context_root_dir: location of an already initialized Great Expectations
            data context. If configured, the data validator will only be usable
            with local orchestrators.
        context_config: in-line Great Expectations data context configuration.
        configure_zenml_stores: if set, ZenML will automatically configure
            stores that use the Artifact Store as a backend. If neither
            `context_root_dir` nor `context_config` are set, this is the default
            behavior.
        configure_local_docs: configure a local data docs site where Great
            Expectations docs are generated and can be visualized locally.
    """

    context_root_dir: Optional[str] = None
    context_config: Optional[Dict[str, Any]] = None
    configure_zenml_stores: bool = False
    configure_local_docs: bool = True
    _context: BaseDataContext = None

    # Class Configuration
    FLAVOR: ClassVar[str] = GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR

    @validator("context_root_dir")
    def _ensure_valid_context_root_dir(
        cls, context_root_dir: Optional[str] = None
    ) -> Optional[str]:
        """Ensures that the root directory is an absolute path and points to an existing path.

        Args:
            context_root_dir: The context_root_dir value to validate.

        Returns:
            The context_root_dir if it is valid.

        Raises:
            ValueError: If the context_root_dir is not valid.
        """
        if context_root_dir:
            context_root_dir = os.path.abspath(context_root_dir)
            if not fileio.exists(context_root_dir):
                raise ValueError(
                    f"The Great Expectations context_root_dir value doesn't "
                    f"point to an existing data context path: {context_root_dir}"
                )
        return context_root_dir

    @root_validator(pre=True)
    def _convert_context_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Converts context_config from JSON/YAML string format to a dict.

        Args:
            values: Values passed to the object constructor

        Returns:
            Values passed to the object constructor

        Raises:
            ValueError: If the context_config value is not a valid JSON/YAML or
                if the GE configuration extracted from it fails GE validation.
        """
        context_config = values.get("context_config")
        if context_config and not isinstance(context_config, dict):
            try:
                context_config_dict = yaml.safe_load(context_config)
            except yaml.parser.ParserError as e:
                raise ValueError(
                    f"Malformed `context_config` value. Only JSON and YAML formats "
                    f"are supported: {str(e)}"
                )
            try:
                context_config = DataContextConfig(**context_config_dict)
                BaseDataContext(project_config=context_config)
            except Exception as e:
                raise ValueError(f"Invalid `context_config` value: {str(e)}")

            values["context_config"] = context_config_dict
        return values

    @classmethod
    def get_data_context(cls) -> BaseDataContext:
        """Get the Great Expectations data context managed by ZenML.

        Call this method to retrieve the data context managed by ZenML
        through the active Great Expectations data validator stack component.
        The default context as returned by the GE `get_context()` call will be
        returned if a Great Expectations data validator component is not present
        in the active stack.

        Returns:
            A Great Expectations data context managed by ZenML as configured
            through the active data validator stack component, or the default
            data context if a Great Expectations data validator stack component
            is not present in the active stack.
        """
        repo = Repository(skip_repository_check=True)  # type: ignore[call-arg]
        data_validator = repo.active_stack.data_validator
        if data_validator and isinstance(data_validator, cls):
            return data_validator.data_context

        logger.warning(
            f"The active stack does not include a Great Expectations data "
            f"validator component. It is highly recommended to register one to "
            f"be able to configure Great Expectations and use it together "
            f"with ZenML in a consistent manner that uses the same "
            f"configuration across different stack configurations and runtime "
            f"environments. You can create a new stack with a Great "
            f"Expectations data validator component or update your existing "
            f"stack to add this component, e.g.:\n\n"
            f"  `zenml data-validator register great_expectations "
            f"--flavor={GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR} ...`\n"
            f"  `zenml stack register stack-name -dv great_expectations ...`\n"
            f"  or:\n"
            f"  `zenml stack update stack-name -dv great_expectations`\n\n"
            f"Defaulting to the local Great Expectations data context..."
        )
        return ge.get_context()

    @property
    def local_path(self) -> Optional[str]:
        """Return a local path where this component stores information.

        If an existing local GE data context is used, it is
        interpreted as a local path that needs to be accessible in
        all runtime environments.

        Returns:
            The local path where this component stores information.
        """
        return self.context_root_dir

    def get_store_config(self, class_name: str, prefix: str) -> Dict[str, Any]:
        """Generate a Great Expectations store configuration.

        Args:
            class_name: The store class name
            prefix: The path prefix for the ZenML store configuration

        Returns:
            A dictionary with the GE store configuration.
        """
        return {
            "class_name": class_name,
            "store_backend": {
                "module_name": ZenMLArtifactStoreBackend.__module__,
                "class_name": ZenMLArtifactStoreBackend.__name__,
                "prefix": f"{str(self.uuid)}/{prefix}",
            },
        }

    def get_data_docs_config(
        self, prefix: str, local: bool = False
    ) -> Dict[str, Any]:
        """Generate Great Expectations data docs configuration.

        Args:
            prefix: The path prefix for the ZenML data docs configuration
            local: Whether the data docs site is local or remote.

        Returns:
            A dictionary with the GE data docs site configuration.
        """
        if local:
            store_backend = {
                "class_name": "TupleFilesystemStoreBackend",
                "base_directory": f"{self.root_directory}/{prefix}",
            }
        else:
            store_backend = {
                "module_name": ZenMLArtifactStoreBackend.__module__,
                "class_name": ZenMLArtifactStoreBackend.__name__,
                "prefix": f"{str(self.uuid)}/{prefix}",
            }

        return {
            "class_name": "SiteBuilder",
            "store_backend": store_backend,
            "site_index_builder": {
                "class_name": "DefaultSiteIndexBuilder",
            },
        }

    @property
    def data_context(self) -> BaseDataContext:
        """Returns the Great Expectations data context configured for this component.

        Returns:
            The Great Expectations data context configured for this component.
        """
        if not self._context:
            expectations_store_name = "zenml_expectations_store"
            validations_store_name = "zenml_validations_store"
            checkpoint_store_name = "zenml_checkpoint_store"
            profiler_store_name = "zenml_profiler_store"
            evaluation_parameter_store_name = "evaluation_parameter_store"

            zenml_context_config = dict(
                stores={
                    expectations_store_name: self.get_store_config(
                        "ExpectationsStore", "expectations"
                    ),
                    validations_store_name: self.get_store_config(
                        "ValidationsStore", "validations"
                    ),
                    checkpoint_store_name: self.get_store_config(
                        "CheckpointStore", "checkpoints"
                    ),
                    profiler_store_name: self.get_store_config(
                        "ProfilerStore", "profilers"
                    ),
                    evaluation_parameter_store_name: {
                        "class_name": "EvaluationParameterStore"
                    },
                },
                expectations_store_name=expectations_store_name,
                validations_store_name=validations_store_name,
                checkpoint_store_name=checkpoint_store_name,
                profiler_store_name=profiler_store_name,
                evaluation_parameter_store_name=evaluation_parameter_store_name,
                data_docs_sites={
                    "zenml_artifact_store": self.get_data_docs_config(
                        "data_docs"
                    )
                },
            )

            configure_zenml_stores = self.configure_zenml_stores
            if self.context_root_dir:
                # initialize the local data context, if a local path was
                # configured
                self._context = DataContext(self.context_root_dir)
            else:
                # create an in-memory data context configuration that is not
                # backed by a local YAML file (see https://docs.greatexpectations.io/docs/guides/setup/configuring_data_contexts/how_to_instantiate_a_data_context_without_a_yml_file/).
                if self.context_config:
                    context_config = DataContextConfig(**self.context_config)
                else:
                    context_config = DataContextConfig(**zenml_context_config)
                    # skip adding the stores after initialization, as they are
                    # already baked in the initial configuration
                    configure_zenml_stores = False
                self._context = BaseDataContext(project_config=context_config)

            if configure_zenml_stores:
                self._context.config.expectations_store_name = (
                    expectations_store_name
                )
                self._context.config.validations_store_name = (
                    validations_store_name
                )
                self._context.config.checkpoint_store_name = (
                    checkpoint_store_name
                )
                self._context.config.profiler_store_name = profiler_store_name
                self._context.config.evaluation_parameter_store_name = (
                    evaluation_parameter_store_name
                )
                for store_name, store_config in zenml_context_config[  # type: ignore[attr-defined]
                    "stores"
                ].items():
                    self._context.add_store(
                        store_name=store_name,
                        store_config=store_config,
                    )
                for site_name, site_config in zenml_context_config[  # type: ignore[attr-defined]
                    "data_docs_sites"
                ].items():
                    self._context.config.data_docs_sites[
                        site_name
                    ] = site_config

            if self.configure_local_docs:

                repo = Repository(skip_repository_check=True)  # type: ignore[call-arg]
                artifact_store = repo.active_stack.artifact_store
                if artifact_store.FLAVOR != "local":
                    self._context.config.data_docs_sites[
                        "zenml_local"
                    ] = self.get_data_docs_config("data_docs", local=True)

        return self._context

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory for all local files concerning this data validator.

        Returns:
            Path to the root directory.
        """
        path = os.path.join(
            io_utils.get_global_config_directory(),
            self.FLAVOR,
            str(self.uuid),
        )

        if not os.path.exists(path):
            fileio.makedirs(path)

        return path
data_context: BaseDataContext property readonly

Returns the Great Expectations data context configured for this component.

Returns:

Type Description
BaseDataContext

The Great Expectations data context configured for this component.

local_path: Optional[str] property readonly

Return a local path where this component stores information.

If an existing local GE data context is used, it is interpreted as a local path that needs to be accessible in all runtime environments.

Returns:

Type Description
Optional[str]

The local path where this component stores information.

root_directory: str property readonly

Returns path to the root directory for all local files concerning this data validator.

Returns:

Type Description
str

Path to the root directory.

get_data_context() classmethod

Get the Great Expectations data context managed by ZenML.

Call this method to retrieve the data context managed by ZenML through the active Great Expectations data validator stack component. The default context as returned by the GE get_context() call will be returned if a Great Expectations data validator component is not present in the active stack.

Returns:

Type Description
BaseDataContext

A Great Expectations data context managed by ZenML as configured through the active data validator stack component, or the default data context if a Great Expectations data validator stack component is not present in the active stack.

Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
@classmethod
def get_data_context(cls) -> BaseDataContext:
    """Get the Great Expectations data context managed by ZenML.

    Call this method to retrieve the data context managed by ZenML
    through the active Great Expectations data validator stack component.
    The default context as returned by the GE `get_context()` call will be
    returned if a Great Expectations data validator component is not present
    in the active stack.

    Returns:
        A Great Expectations data context managed by ZenML as configured
        through the active data validator stack component, or the default
        data context if a Great Expectations data validator stack component
        is not present in the active stack.
    """
    repo = Repository(skip_repository_check=True)  # type: ignore[call-arg]
    data_validator = repo.active_stack.data_validator
    if data_validator and isinstance(data_validator, cls):
        return data_validator.data_context

    logger.warning(
        f"The active stack does not include a Great Expectations data "
        f"validator component. It is highly recommended to register one to "
        f"be able to configure Great Expectations and use it together "
        f"with ZenML in a consistent manner that uses the same "
        f"configuration across different stack configurations and runtime "
        f"environments. You can create a new stack with a Great "
        f"Expectations data validator component or update your existing "
        f"stack to add this component, e.g.:\n\n"
        f"  `zenml data-validator register great_expectations "
        f"--flavor={GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR} ...`\n"
        f"  `zenml stack register stack-name -dv great_expectations ...`\n"
        f"  or:\n"
        f"  `zenml stack update stack-name -dv great_expectations`\n\n"
        f"Defaulting to the local Great Expectations data context..."
    )
    return ge.get_context()
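
As a usage illustration, here is a minimal sketch of calling this method from inside a custom step, assuming the active stack contains a Great Expectations data validator (the step name is illustrative):

import pandas as pd
from zenml.steps import step
from zenml.integrations.great_expectations.data_validators.ge_data_validator import (
    GreatExpectationsDataValidator,
)

@step
def count_suites(dataset: pd.DataFrame) -> int:
    """Counts the expectation suites known to the managed data context."""
    context = GreatExpectationsDataValidator.get_data_context()
    return len(context.list_expectation_suite_names())
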
get_data_docs_config(self, prefix, local=False)

Generate Great Expectations data docs configuration.

Parameters:

Name Type Description Default
prefix str

The path prefix for the ZenML data docs configuration

required
local bool

Whether the data docs site is local or remote.

False

Returns:

Type Description
Dict[str, Any]

A dictionary with the GE data docs site configuration.

Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def get_data_docs_config(
    self, prefix: str, local: bool = False
) -> Dict[str, Any]:
    """Generate Great Expectations data docs configuration.

    Args:
        prefix: The path prefix for the ZenML data docs configuration
        local: Whether the data docs site is local or remote.

    Returns:
        A dictionary with the GE data docs site configuration.
    """
    if local:
        store_backend = {
            "class_name": "TupleFilesystemStoreBackend",
            "base_directory": f"{self.root_directory}/{prefix}",
        }
    else:
        store_backend = {
            "module_name": ZenMLArtifactStoreBackend.__module__,
            "class_name": ZenMLArtifactStoreBackend.__name__,
            "prefix": f"{str(self.uuid)}/{prefix}",
        }

    return {
        "class_name": "SiteBuilder",
        "store_backend": store_backend,
        "site_index_builder": {
            "class_name": "DefaultSiteIndexBuilder",
        },
    }
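
For orientation, a sketch of the local variant's result; `validator` stands for an assumed GreatExpectationsDataValidator instance and the base directory expands under its root_directory:

site_config = validator.get_data_docs_config("data_docs", local=True)
# {
#     "class_name": "SiteBuilder",
#     "store_backend": {
#         "class_name": "TupleFilesystemStoreBackend",
#         "base_directory": "<root_directory>/data_docs",
#     },
#     "site_index_builder": {"class_name": "DefaultSiteIndexBuilder"},
# }
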
get_store_config(self, class_name, prefix)

Generate a Great Expectations store configuration.

Parameters:

Name Type Description Default
class_name str

The store class name

required
prefix str

The path prefix for the ZenML store configuration

required

Returns:

Type Description
Dict[str, Any]

A dictionary with the GE store configuration.

Source code in zenml/integrations/great_expectations/data_validators/ge_data_validator.py
def get_store_config(self, class_name: str, prefix: str) -> Dict[str, Any]:
    """Generate a Great Expectations store configuration.

    Args:
        class_name: The store class name
        prefix: The path prefix for the ZenML store configuration

    Returns:
        A dictionary with the GE store configuration.
    """
    return {
        "class_name": class_name,
        "store_backend": {
            "module_name": ZenMLArtifactStoreBackend.__module__,
            "class_name": ZenMLArtifactStoreBackend.__name__,
            "prefix": f"{str(self.uuid)}/{prefix}",
        },
    }
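
For orientation, a sketch of the generated configuration; `validator` stands for an assumed instance and the uuid segment of the prefix is a placeholder:

store_config = validator.get_store_config("ExpectationsStore", "expectations")
# {
#     "class_name": "ExpectationsStore",
#     "store_backend": {
#         "module_name": "zenml.integrations.great_expectations.ge_store_backend",
#         "class_name": "ZenMLArtifactStoreBackend",
#         "prefix": "<component-uuid>/expectations",
#     },
# }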

ge_store_backend

Great Expectations store plugin for ZenML.

ZenMLArtifactStoreBackend (TupleStoreBackend)

Great Expectations store backend that uses the active ZenML Artifact Store as a store.
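
Great Expectations normally instantiates this backend itself from the module_name/class_name entries emitted by GreatExpectationsDataValidator.get_store_config(). As a minimal sketch, it can also be constructed directly, assuming an active ZenML stack whose Artifact Store is reachable (the prefix is illustrative):

from zenml.integrations.great_expectations.ge_store_backend import (
    ZenMLArtifactStoreBackend,
)

backend = ZenMLArtifactStoreBackend(prefix="demo/expectations")
print(backend.root_path)  # "<artifact-store-path>/great_expectations"
print(backend.config)     # {"prefix": "demo/expectations", ...}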

Source code in zenml/integrations/great_expectations/ge_store_backend.py
class ZenMLArtifactStoreBackend(TupleStoreBackend):  # type: ignore[misc]
    """Great Expectations store backend that uses the active ZenML Artifact Store as a store."""

    def __init__(
        self,
        prefix: str = "",
        **kwargs: Any,
    ) -> None:
        """Create a Great Expectations ZenML store backend instance.

        Args:
            prefix: Subpath prefix to use for this store backend.
            kwargs: Additional keyword arguments passed by the Great Expectations
                core. These are transparently passed to the `TupleStoreBackend`
                constructor.
        """
        super().__init__(**kwargs)

        repo = Repository(skip_repository_check=True)  # type: ignore[call-arg]
        artifact_store = repo.active_stack.artifact_store
        self.root_path = os.path.join(artifact_store.path, "great_expectations")

        # extract the protocol used in the artifact store root path
        protocols = [
            scheme
            for scheme in artifact_store.SUPPORTED_SCHEMES
            if self.root_path.startswith(scheme)
        ]
        if protocols:
            self.proto = protocols[0]
        else:
            self.proto = ""

        if prefix:
            if self.platform_specific_separator:
                prefix = prefix.strip(os.sep)
            prefix = prefix.strip("/")
        self.prefix = prefix

        # Initialize with store_backend_id if not part of an HTMLSiteStore
        if not self._suppress_store_backend_id:
            _ = self.store_backend_id

        self._config = {
            "prefix": prefix,
            "module_name": self.__class__.__module__,
            "class_name": self.__class__.__name__,
        }
        self._config.update(kwargs)
        filter_properties_dict(
            properties=self._config, clean_falsy=True, inplace=True
        )

    def _build_object_path(
        self, key: Tuple[str, ...], is_prefix: bool = False
    ) -> str:
        """Build a filepath corresponding to an object key.

        Args:
            key: Great Expectation object key.
            is_prefix: If True, the key will be interpreted as a prefix instead
                of a full key identifier.

        Returns:
            The file path pointing to where the object is stored.
        """
        if not isinstance(key, tuple):
            key = key.to_tuple()  # type: ignore[attr-defined]
        if not is_prefix:
            object_relative_path = self._convert_key_to_filepath(key)
        elif key:
            object_relative_path = os.path.join(*key)
        else:
            object_relative_path = ""
        if self.prefix:
            object_key = os.path.join(self.prefix, object_relative_path)
        else:
            object_key = object_relative_path
        return os.path.join(self.root_path, object_key)

    def _get(self, key: Tuple[str, ...]) -> str:
        """Get the value of an object from the store.

        Args:
            key: object key identifier.

        Raises:
            InvalidKeyError: if the key doesn't point to an existing object.

        Returns:
            str: the object's contents
        """
        filepath: str = self._build_object_path(key)
        if fileio.exists(filepath):
            contents = io_utils.read_file_contents_as_string(filepath).rstrip(
                "\n"
            )
        else:
            raise InvalidKeyError(
                f"Unable to retrieve object from {self.__class__.__name__} with "
                f"the following Key: {str(filepath)}"
            )
        return contents

    def _set(self, key: Tuple[str, ...], value: str, **kwargs: Any) -> str:
        """Set the value of an object in the store.

        Args:
            key: object key identifier.
            value: object value to set.
            kwargs: additional keyword arguments (ignored).

        Returns:
            The file path where the object was stored.
        """
        filepath: str = self._build_object_path(key)
        if not io_utils.is_remote(filepath):
            parent_dir = str(Path(filepath).parent)
            os.makedirs(parent_dir, exist_ok=True)

        with fileio.open(filepath, "wb") as outfile:
            if isinstance(value, str):
                outfile.write(value.encode("utf-8"))
            else:
                outfile.write(value)
        return filepath

    def _move(
        self,
        source_key: Tuple[str, ...],
        dest_key: Tuple[str, ...],
        **kwargs: Any,
    ) -> None:
        """Associate an object with a different key in the store.

        Args:
            source_key: current object key identifier.
            dest_key: new object key identifier.
            kwargs: additional keyword arguments (ignored).
        """
        source_path = self._build_object_path(source_key)
        dest_path = self._build_object_path(dest_key)

        if fileio.exists(source_path):
            if not io_utils.is_remote(dest_path):
                parent_dir = str(Path(dest_path).parent)
                os.makedirs(parent_dir, exist_ok=True)
            fileio.rename(source_path, dest_path, overwrite=True)

    def list_keys(self, prefix: Tuple[str, ...] = ()) -> List[Tuple[str, ...]]:
        """List the keys of all objects identified by a partial key.

        Args:
            prefix: partial object key identifier.

        Returns:
            List of keys identifying all objects present in the store that
            match the input partial key.
        """
        key_list = []
        list_path = self._build_object_path(prefix, is_prefix=True)
        root_path = self._build_object_path(tuple(), is_prefix=True)
        for root, dirs, files in fileio.walk(list_path):
            for file_ in files:
                filepath = os.path.relpath(
                    os.path.join(str(root), str(file_)), root_path
                )

                if self.filepath_prefix and not filepath.startswith(
                    self.filepath_prefix
                ):
                    continue
                elif self.filepath_suffix and not filepath.endswith(
                    self.filepath_suffix
                ):
                    continue
                key = self._convert_filepath_to_key(filepath)
                if key and not self.is_ignored_key(key):
                    key_list.append(key)
        return key_list

    def remove_key(self, key: Tuple[str, ...]) -> bool:
        """Delete an object from the store.

        Args:
            key: object key identifier.

        Returns:
            True if the object existed in the store and was removed, otherwise
            False.
        """
        filepath: str = self._build_object_path(key)

        if fileio.exists(filepath):
            fileio.remove(filepath)
            if not io_utils.is_remote(filepath):
                parent_dir = str(Path(filepath).parent)
                self.rrmdir(self.root_path, str(parent_dir))
            return True
        return False

    def _has_key(self, key: Tuple[str, ...]) -> bool:
        """Check if an object is present in the store.

        Args:
            key: object key identifier.

        Returns:
            True if the object is present in the store, otherwise False.
        """
        filepath: str = self._build_object_path(key)
        result = fileio.exists(filepath)
        return result

    def get_url_for_key(
        self, key: Tuple[str, ...], protocol: Optional[str] = None
    ) -> str:
        """Get the URL of an object in the store.

        Args:
            key: object key identifier.
            protocol: optional protocol to use instead of the store protocol.

        Returns:
            The URL of the object in the store.
        """
        filepath = self._build_object_path(key)
        if not protocol and not io_utils.is_remote(filepath):
            protocol = "file:"
        if protocol:
            filepath = filepath.replace(self.proto, f"{protocol}//", 1)

        return filepath

    def get_public_url_for_key(
        self, key: str, protocol: Optional[str] = None
    ) -> str:
        """Get the public URL of an object in the store.

        Args:
            key: object key identifier.
            protocol: optional protocol to use instead of the store protocol.

        Returns:
            The public URL where the object can be accessed.

        Raises:
            StoreBackendError: if a `base_public_path` attribute was not
                configured for the store.
        """
        if not self.base_public_path:
            raise StoreBackendError(
                f"Error: No base_public_path was configured! A public URL was "
                f"requested but `base_public_path` was not configured for the "
                f"{self.__class__.__name__}"
            )
        filepath = self._convert_key_to_filepath(key)
        public_url = self.base_public_path + filepath.replace(self.proto, "")
        return cast(str, public_url)

    @staticmethod
    def rrmdir(start_path: str, end_path: str) -> None:
        """Recursively removes empty dirs between start_path and end_path inclusive.

        Args:
            start_path: Directory to use as a starting point.
            end_path: Directory to use as a destination point.
        """
        while not os.listdir(end_path) and start_path != end_path:
            os.rmdir(end_path)
            end_path = os.path.dirname(end_path)

    @property
    def config(self) -> Dict[str, Any]:
        """Get the store configuration.

        Returns:
            The store configuration.
        """
        return self._config
config: Dict[str, Any] property readonly

Get the store configuration.

Returns:

Type Description
Dict[str, Any]

The store configuration.

__init__(self, prefix='', **kwargs) special

Create a Great Expectations ZenML store backend instance.

Parameters:

Name Type Description Default
prefix str

Subpath prefix to use for this store backend.

''
kwargs Any

Additional keyword arguments passed by the Great Expectations core. These are transparently passed to the TupleStoreBackend constructor.

{}
Source code in zenml/integrations/great_expectations/ge_store_backend.py
def __init__(
    self,
    prefix: str = "",
    **kwargs: Any,
) -> None:
    """Create a Great Expectations ZenML store backend instance.

    Args:
        prefix: Subpath prefix to use for this store backend.
        kwargs: Additional keyword arguments passed by the Great Expectations
            core. These are transparently passed to the `TupleStoreBackend`
            constructor.
    """
    super().__init__(**kwargs)

    repo = Repository(skip_repository_check=True)  # type: ignore[call-arg]
    artifact_store = repo.active_stack.artifact_store
    self.root_path = os.path.join(artifact_store.path, "great_expectations")

    # extract the protocol used in the artifact store root path
    protocols = [
        scheme
        for scheme in artifact_store.SUPPORTED_SCHEMES
        if self.root_path.startswith(scheme)
    ]
    if protocols:
        self.proto = protocols[0]
    else:
        self.proto = ""

    if prefix:
        if self.platform_specific_separator:
            prefix = prefix.strip(os.sep)
        prefix = prefix.strip("/")
    self.prefix = prefix

    # Initialize with store_backend_id if not part of an HTMLSiteStore
    if not self._suppress_store_backend_id:
        _ = self.store_backend_id

    self._config = {
        "prefix": prefix,
        "module_name": self.__class__.__module__,
        "class_name": self.__class__.__name__,
    }
    self._config.update(kwargs)
    filter_properties_dict(
        properties=self._config, clean_falsy=True, inplace=True
    )
get_public_url_for_key(self, key, protocol=None)

Get the public URL of an object in the store.

Parameters:

Name Type Description Default
key str

object key identifier.

required
protocol Optional[str]

optional protocol to use instead of the store protocol.

None

Returns:

Type Description
str

The public URL where the object can be accessed.

Exceptions:

Type Description
StoreBackendError

if a base_public_path attribute was not configured for the store.

Source code in zenml/integrations/great_expectations/ge_store_backend.py
def get_public_url_for_key(
    self, key: str, protocol: Optional[str] = None
) -> str:
    """Get the public URL of an object in the store.

    Args:
        key: object key identifier.
        protocol: optional protocol to use instead of the store protocol.

    Returns:
        The public URL where the object can be accessed.

    Raises:
        StoreBackendError: if a `base_public_path` attribute was not
            configured for the store.
    """
    if not self.base_public_path:
        raise StoreBackendError(
            f"Error: No base_public_path was configured! A public URL was "
            f"requested but `base_public_path` was not configured for the "
            f"{self.__class__.__name__}"
        )
    filepath = self._convert_key_to_filepath(key)
    public_url = self.base_public_path + filepath.replace(self.proto, "")
    return cast(str, public_url)
get_url_for_key(self, key, protocol=None)

Get the URL of an object in the store.

Parameters:

Name Type Description Default
key Tuple[str, ...]

object key identifier.

required
protocol Optional[str]

optional protocol to use instead of the store protocol.

None

Returns:

Type Description
str

The URL of the object in the store.

Source code in zenml/integrations/great_expectations/ge_store_backend.py
def get_url_for_key(
    self, key: Tuple[str, ...], protocol: Optional[str] = None
) -> str:
    """Get the URL of an object in the store.

    Args:
        key: object key identifier.
        protocol: optional protocol to use instead of the store protocol.

    Returns:
        The URL of the object in the store.
    """
    filepath = self._build_object_path(key)
    if not protocol and not io_utils.is_remote(filepath):
        protocol = "file:"
    if protocol:
        filepath = filepath.replace(self.proto, f"{protocol}//", 1)

    return filepath
list_keys(self, prefix=())

List the keys of all objects identified by a partial key.

Parameters:

Name Type Description Default
prefix Tuple[str, ...]

partial object key identifier.

()

Returns:

Type Description
List[Tuple[str, ...]]

List of keys identifying all objects present in the store that match the input partial key.

Source code in zenml/integrations/great_expectations/ge_store_backend.py
def list_keys(self, prefix: Tuple[str, ...] = ()) -> List[Tuple[str, ...]]:
    """List the keys of all objects identified by a partial key.

    Args:
        prefix: partial object key identifier.

    Returns:
        List of keys identifying all objects present in the store that
        match the input partial key.
    """
    key_list = []
    list_path = self._build_object_path(prefix, is_prefix=True)
    root_path = self._build_object_path(tuple(), is_prefix=True)
    for root, dirs, files in fileio.walk(list_path):
        for file_ in files:
            filepath = os.path.relpath(
                os.path.join(str(root), str(file_)), root_path
            )

            if self.filepath_prefix and not filepath.startswith(
                self.filepath_prefix
            ):
                continue
            elif self.filepath_suffix and not filepath.endswith(
                self.filepath_suffix
            ):
                continue
            key = self._convert_filepath_to_key(filepath)
            if key and not self.is_ignored_key(key):
                key_list.append(key)
    return key_list
remove_key(self, key)

Delete an object from the store.

Parameters:

Name Type Description Default
key Tuple[str, ...]

object key identifier.

required

Returns:

Type Description
bool

True if the object existed in the store and was removed, otherwise False.

Source code in zenml/integrations/great_expectations/ge_store_backend.py
def remove_key(self, key: Tuple[str, ...]) -> bool:
    """Delete an object from the store.

    Args:
        key: object key identifier.

    Returns:
        True if the object existed in the store and was removed, otherwise
        False.
    """
    filepath: str = self._build_object_path(key)

    if fileio.exists(filepath):
        fileio.remove(filepath)
        if not io_utils.is_remote(filepath):
            parent_dir = str(Path(filepath).parent)
            self.rrmdir(self.root_path, str(parent_dir))
        return True
    return False
rrmdir(start_path, end_path) staticmethod

Recursively removes empty directories from end_path up to, but not including, start_path.

Parameters:

Name Type Description Default
start_path str

Directory to use as a starting point.

required
end_path str

Directory to use as a destination point.

required
Source code in zenml/integrations/great_expectations/ge_store_backend.py
@staticmethod
def rrmdir(start_path: str, end_path: str) -> None:
    """Recursively removes empty dirs between start_path and end_path inclusive.

    Args:
        start_path: Directory to use as a starting point.
        end_path: Directory to use as a destination point.
    """
    while not os.listdir(end_path) and start_path != end_path:
        os.rmdir(end_path)
        end_path = os.path.dirname(end_path)

materializers special

Materializers for Great Expectations serializable objects.

ge_materializer

Implementation of the Great Expectations materializers.

GreatExpectationsMaterializer (BaseMaterializer)

Materializer to read/write Great Expectations objects.

Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
class GreatExpectationsMaterializer(BaseMaterializer):
    """Materializer to read/write Great Expectation objects."""

    ASSOCIATED_TYPES = (
        ExpectationSuite,
        CheckpointResult,
    )
    ASSOCIATED_ARTIFACT_TYPES = (DataAnalysisArtifact,)

    @staticmethod
    def preprocess_checkpoint_result_dict(
        artifact_dict: Dict[str, Any]
    ) -> None:
        """Pre-processes a GE checkpoint dict before it is used to de-serialize a GE CheckpointResult object.

        The GE CheckpointResult object is not fully de-serializable
        due to some missing code in the GE codebase. We need to compensate
        for this by manually converting some of the attributes to
        their correct data types.

        Args:
            artifact_dict: A dict containing the GE checkpoint result.
        """

        def preprocess_run_result(key: str, value: Any) -> Any:
            if key == "validation_result":
                return ExpectationSuiteValidationResult(**value)
            return value

        artifact_dict["checkpoint_config"] = CheckpointConfig(
            **artifact_dict["checkpoint_config"]
        )
        validation_dict = {}
        for result_ident, results in artifact_dict["run_results"].items():
            validation_ident = (
                ValidationResultIdentifier.from_fixed_length_tuple(
                    result_ident.split("::")[1].split("/")
                )
            )
            validation_results = {
                result_name: preprocess_run_result(result_name, result)
                for result_name, result in results.items()
            }
            validation_dict[validation_ident] = validation_results
        artifact_dict["run_results"] = validation_dict

    def handle_input(self, data_type: Type[Any]) -> SerializableDictDot:
        """Reads and returns a Great Expectations object.

        Args:
            data_type: The type of the data to read.

        Returns:
            A loaded Great Expectations object.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
        artifact_dict = yaml_utils.read_json(filepath)
        data_type = import_class_by_path(artifact_dict.pop("data_type"))

        if data_type is CheckpointResult:
            self.preprocess_checkpoint_result_dict(artifact_dict)

        return data_type(**artifact_dict)

    def handle_return(self, obj: SerializableDictDot) -> None:
        """Writes a Great Expectations object.

        Args:
            obj: A Great Expectations object.
        """
        super().handle_return(obj)
        filepath = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
        artifact_dict = obj.to_json_dict()
        artifact_type = type(obj)
        artifact_dict[
            "data_type"
        ] = f"{artifact_type.__module__}.{artifact_type.__name__}"
        yaml_utils.write_json(filepath, artifact_dict)
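
As a usage illustration: returning one of the ASSOCIATED_TYPES from a step is enough for ZenML to apply this materializer automatically. A minimal sketch (the suite name is illustrative):

from great_expectations.core.expectation_suite import ExpectationSuite
from zenml.steps import step

@step
def make_suite() -> ExpectationSuite:
    """Produces an (empty) expectation suite stored via this materializer."""
    return ExpectationSuite(expectation_suite_name="demo_suite")
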
handle_input(self, data_type)

Reads and returns a Great Expectations object.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
SerializableDictDot

A loaded Great Expectations object.

Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
def handle_input(self, data_type: Type[Any]) -> SerializableDictDot:
    """Reads and returns a Great Expectations object.

    Args:
        data_type: The type of the data to read.

    Returns:
        A loaded Great Expectations object.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
    artifact_dict = yaml_utils.read_json(filepath)
    data_type = import_class_by_path(artifact_dict.pop("data_type"))

    if data_type is CheckpointResult:
        self.preprocess_checkpoint_result_dict(artifact_dict)

    return data_type(**artifact_dict)
handle_return(self, obj)

Writes a Great Expectations object.

Parameters:

Name Type Description Default
obj SerializableDictDot

A Great Expectations object.

required
Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
def handle_return(self, obj: SerializableDictDot) -> None:
    """Writes a Great Expectations object.

    Args:
        obj: A Great Expectations object.
    """
    super().handle_return(obj)
    filepath = os.path.join(self.artifact.uri, ARTIFACT_FILENAME)
    artifact_dict = obj.to_json_dict()
    artifact_type = type(obj)
    artifact_dict[
        "data_type"
    ] = f"{artifact_type.__module__}.{artifact_type.__name__}"
    yaml_utils.write_json(filepath, artifact_dict)
preprocess_checkpoint_result_dict(artifact_dict) staticmethod

Pre-processes a GE checkpoint dict before it is used to de-serialize a GE CheckpointResult object.

The GE CheckpointResult object is not fully de-serializable due to some missing code in the GE codebase. We need to compensate for this by manually converting some of the attributes to their correct data types.

Parameters:

Name Type Description Default
artifact_dict Dict[str, Any]

A dict containing the GE checkpoint result.

required
Source code in zenml/integrations/great_expectations/materializers/ge_materializer.py
@staticmethod
def preprocess_checkpoint_result_dict(
    artifact_dict: Dict[str, Any]
) -> None:
    """Pre-processes a GE checkpoint dict before it is used to de-serialize a GE CheckpointResult object.

    The GE CheckpointResult object is not fully de-serializable
    due to some missing code in the GE codebase. We need to compensate
    for this by manually converting some of the attributes to
    their correct data types.

    Args:
        artifact_dict: A dict containing the GE checkpoint result.
    """

    def preprocess_run_result(key: str, value: Any) -> Any:
        if key == "validation_result":
            return ExpectationSuiteValidationResult(**value)
        return value

    artifact_dict["checkpoint_config"] = CheckpointConfig(
        **artifact_dict["checkpoint_config"]
    )
    validation_dict = {}
    for result_ident, results in artifact_dict["run_results"].items():
        validation_ident = (
            ValidationResultIdentifier.from_fixed_length_tuple(
                result_ident.split("::")[1].split("/")
            )
        )
        validation_results = {
            result_name: preprocess_run_result(result_name, result)
            for result_name, result in results.items()
        }
        validation_dict[validation_ident] = validation_results
    artifact_dict["run_results"] = validation_dict

steps special

Great Expectations data profiling and validation standard steps.

ge_profiler

Great Expectations data profiling standard step.

GreatExpectationsProfilerConfig (BaseStepConfig) pydantic-model

Config class for a Great Expectations profiler step.

Attributes:

Name Type Description
expectation_suite_name str

The name of the expectation suite to create or update.

data_asset_name Optional[str]

The name of the data asset to run the expectation suite on.

profiler_kwargs Optional[Dict[str, Any]]

A dictionary of keyword arguments to pass to the profiler.

overwrite_existing_suite bool

Whether to overwrite an existing expectation suite.

Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
class GreatExpectationsProfilerConfig(BaseStepConfig):
    """Config class for a Great Expectations profiler step.

    Attributes:
        expectation_suite_name: The name of the expectation suite to create
            or update.
        data_asset_name: The name of the data asset to run the expectation suite on.
        profiler_kwargs: A dictionary of keyword arguments to pass to the profiler.
        overwrite_existing_suite: Whether to overwrite an existing expectation suite.
    """

    expectation_suite_name: str
    data_asset_name: Optional[str] = None
    profiler_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)
    overwrite_existing_suite: bool = True
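
An illustrative configuration; the suite and asset names are placeholders, and ignored_columns is one of UserConfigurableProfiler's keyword arguments forwarded via profiler_kwargs:

profiler_config = GreatExpectationsProfilerConfig(
    expectation_suite_name="my_suite",
    data_asset_name="my_dataset",
    profiler_kwargs={"ignored_columns": ["id"]},
)
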
GreatExpectationsProfilerStep (BaseStep)

Standard Great Expectations profiling step implementation.

Use this standard Great Expectations profiling step to build an Expectation Suite automatically by running a UserConfigurableProfiler on an input dataset as covered in the official GE documentation.

Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
class GreatExpectationsProfilerStep(BaseStep):
    """Standard Great Expectations profiling step implementation.

    Use this standard Great Expectations profiling step to build an Expectation
    Suite automatically by running a UserConfigurableProfiler on an input
    dataset [as covered in the official GE documentation](https://docs.greatexpectations.io/docs/guides/expectations/how_to_create_and_edit_expectations_with_a_profiler).
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: GreatExpectationsProfilerConfig,
    ) -> ExpectationSuite:
        """Standard Great Expectations data profiling step entrypoint.

        Args:
            dataset: The dataset from which the expectation suite will be inferred.
            config: The configuration for the step.

        Returns:
            The generated Great Expectations suite.
        """
        context = GreatExpectationsDataValidator.get_data_context()

        suite_exists = False
        if context.expectations_store.has_key(  # noqa
            ExpectationSuiteIdentifier(config.expectation_suite_name)
        ):
            suite_exists = True
            suite = context.get_expectation_suite(config.expectation_suite_name)
            if not config.overwrite_existing_suite:
                logger.info(
                    f"Expectation Suite `{config.expectation_suite_name}` "
                    f"already exists and `overwrite_existing_suite` is not set "
                    f"in the step configuration. Skipping re-running the "
                    f"profiler."
                )
                return suite

        batch_request = create_batch_request(
            context, dataset, config.data_asset_name
        )

        try:
            if suite_exists:
                validator = context.get_validator(
                    batch_request=batch_request,
                    expectation_suite_name=config.expectation_suite_name,
                )
            else:
                validator = context.get_validator(
                    batch_request=batch_request,
                    create_expectation_suite_with_name=config.expectation_suite_name,
                )

            profiler = UserConfigurableProfiler(
                profile_dataset=validator, **config.profiler_kwargs
            )

            suite = profiler.build_suite()
            context.save_expectation_suite(
                expectation_suite=suite,
                expectation_suite_name=config.expectation_suite_name,
            )

            context.build_data_docs()
        finally:
            context.delete_datasource(batch_request.datasource_name)

        return suite
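
A minimal wiring sketch, assuming the usual ZenML pattern of passing the config object under the entrypoint's parameter name (all names are illustrative):

from zenml.integrations.great_expectations.steps.ge_profiler import (
    GreatExpectationsProfilerConfig,
    GreatExpectationsProfilerStep,
)

ge_profiler_step = GreatExpectationsProfilerStep(
    config=GreatExpectationsProfilerConfig(
        expectation_suite_name="my_suite",
    )
)
# In a pipeline definition:
#     suite = ge_profiler_step(dataset=imported_df)
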
CONFIG_CLASS (BaseStepConfig) pydantic-model

Config class for a Great Expectations profiler step.

Attributes:

Name Type Description
expectation_suite_name str

The name of the expectation suite to create or update.

data_asset_name Optional[str]

The name of the data asset to run the expectation suite on.

profiler_kwargs Optional[Dict[str, Any]]

A dictionary of keyword arguments to pass to the profiler.

overwrite_existing_suite bool

Whether to overwrite an existing expectation suite.

Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
class GreatExpectationsProfilerConfig(BaseStepConfig):
    """Config class for a Great Expectations profiler step.

    Attributes:
        expectation_suite_name: The name of the expectation suite to create
            or update.
        data_asset_name: The name of the data asset to run the expectation suite on.
        profiler_kwargs: A dictionary of keyword arguments to pass to the profiler.
        overwrite_existing_suite: Whether to overwrite an existing expectation suite.
    """

    expectation_suite_name: str
    data_asset_name: Optional[str] = None
    profiler_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)
    overwrite_existing_suite: bool = True
entrypoint(self, dataset, config)

Standard Great Expectations data profiling step entrypoint.

Parameters:

Name Type Description Default
dataset DataFrame

The dataset from which the expectation suite will be inferred.

required
config GreatExpectationsProfilerConfig

The configuration for the step.

required

Returns:

Type Description
ExpectationSuite

The generated Great Expectations suite.

Source code in zenml/integrations/great_expectations/steps/ge_profiler.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: GreatExpectationsProfilerConfig,
) -> ExpectationSuite:
    """Standard Great Expectations data profiling step entrypoint.

    Args:
        dataset: The dataset from which the expectation suite will be inferred.
        config: The configuration for the step.

    Returns:
        The generated Great Expectations suite.
    """
    context = GreatExpectationsDataValidator.get_data_context()

    suite_exists = False
    if context.expectations_store.has_key(  # noqa
        ExpectationSuiteIdentifier(config.expectation_suite_name)
    ):
        suite_exists = True
        suite = context.get_expectation_suite(config.expectation_suite_name)
        if not config.overwrite_existing_suite:
            logger.info(
                f"Expectation Suite `{config.expectation_suite_name}` "
                f"already exists and `overwrite_existing_suite` is not set "
                f"in the step configuration. Skipping re-running the "
                f"profiler."
            )
            return suite

    batch_request = create_batch_request(
        context, dataset, config.data_asset_name
    )

    try:
        if suite_exists:
            validator = context.get_validator(
                batch_request=batch_request,
                expectation_suite_name=config.expectation_suite_name,
            )
        else:
            validator = context.get_validator(
                batch_request=batch_request,
                create_expectation_suite_with_name=config.expectation_suite_name,
            )

        profiler = UserConfigurableProfiler(
            profile_dataset=validator, **config.profiler_kwargs
        )

        suite = profiler.build_suite()
        context.save_expectation_suite(
            expectation_suite=suite,
            expectation_suite_name=config.expectation_suite_name,
        )

        context.build_data_docs()
    finally:
        context.delete_datasource(batch_request.datasource_name)

    return suite
ge_validator

Great Expectations data validation standard step.

GreatExpectationsValidatorConfig (BaseStepConfig) pydantic-model

Config class for a Great Expectations checkpoint step.

Source code in zenml/integrations/great_expectations/steps/ge_validator.py
class GreatExpectationsValidatorConfig(BaseStepConfig):
    """Config class for a Great Expectations checkpoint step."""

    expectation_suite_name: str  # the expectation suite to validate against
    data_asset_name: Optional[str] = None  # optional custom data asset name
    action_list: Optional[List[Dict[str, Any]]] = None  # checkpoint actions; step defaults are used if omitted
    exit_on_error: bool = False  # whether a failed validation should abort the run
GreatExpectationsValidatorStep (BaseStep)

Standard Great Expectations data validation step implementation.

Use this standard Great Expectations data validation step to run an existing Expectation Suite on an input dataset as covered in the official GE documentation.

Source code in zenml/integrations/great_expectations/steps/ge_validator.py
class GreatExpectationsValidatorStep(BaseStep):
    """Standard Great Expectations data validation step implementation.

    Use this standard Great Expectations data validation step to run an
    existing Expectation Suite on an input dataset [as covered in the official GE documentation](https://docs.greatexpectations.io/docs/guides/validation/how_to_validate_data_by_running_a_checkpoint).
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        condition: bool,
        config: GreatExpectationsValidatorConfig,
    ) -> CheckpointResult:
        """Standard Great Expectations data validation step entrypoint.

        Args:
            dataset: The dataset to run the expectation suite on.
            condition: This dummy argument can be used as a condition to enforce
                that this step is only run after another step has completed. This
                is useful for example if the Expectation Suite used to validate
                the data is computed in a `GreatExpectationsProfilerStep` that
                is part of the same pipeline.
            config: The configuration for the step.

        Returns:
            The Great Expectations validation (checkpoint) result.
        """
        # get pipeline name, step name and run id
        step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
        run_id = step_env.pipeline_run_id
        step_name = step_env.step_name

        context = GreatExpectationsDataValidator.get_data_context()

        checkpoint_name = f"{run_id}_{step_name}"

        batch_request = create_batch_request(
            context, dataset, config.data_asset_name
        )

        action_list = config.action_list or [
            {
                "name": "store_validation_result",
                "action": {"class_name": "StoreValidationResultAction"},
            },
            {
                "name": "store_evaluation_params",
                "action": {"class_name": "StoreEvaluationParametersAction"},
            },
            {
                "name": "update_data_docs",
                "action": {"class_name": "UpdateDataDocsAction"},
            },
        ]

        checkpoint_config = {
            "name": checkpoint_name,
            "run_name_template": f"{run_id}",
            "config_version": 1,
            "class_name": "Checkpoint",
            "expectation_suite_name": config.expectation_suite_name,
            "action_list": action_list,
        }
        context.add_checkpoint(**checkpoint_config)

        try:
            results = context.run_checkpoint(
                checkpoint_name=checkpoint_name,
                validations=[{"batch_request": batch_request}],
            )
        finally:
            context.delete_datasource(batch_request.datasource_name)
            context.delete_checkpoint(checkpoint_name)

        return results
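
A minimal wiring sketch mirroring the profiler step above (all names are illustrative). Per the entrypoint docstring, condition is only an ordering signal, e.g. an output of an upstream step that guarantees the expectation suite already exists:

from zenml.integrations.great_expectations.steps.ge_validator import (
    GreatExpectationsValidatorConfig,
    GreatExpectationsValidatorStep,
)

ge_validator_step = GreatExpectationsValidatorStep(
    config=GreatExpectationsValidatorConfig(
        expectation_suite_name="my_suite",
    )
)
# In a pipeline definition:
#     results = ge_validator_step(dataset=imported_df, condition=suite_ready)
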
CONFIG_CLASS (BaseStepConfig) pydantic-model

Config class for a Great Expectations checkpoint step.

Source code in zenml/integrations/great_expectations/steps/ge_validator.py
class GreatExpectationsValidatorConfig(BaseStepConfig):
    """Config class for a Great Expectations checkpoint step."""

    expectation_suite_name: str  # the expectation suite to validate against
    data_asset_name: Optional[str] = None  # optional custom data asset name
    action_list: Optional[List[Dict[str, Any]]] = None  # checkpoint actions; step defaults are used if omitted
    exit_on_error: bool = False  # whether a failed validation should abort the run
entrypoint(self, dataset, condition, config)

Standard Great Expectations data validation step entrypoint.

Parameters:

Name Type Description Default
dataset DataFrame

The dataset to run the expectation suite on.

required
condition bool

This dummy argument can be used as a condition to enforce that this step is only run after another step has completed. This is useful for example if the Expectation Suite used to validate the data is computed in a GreatExpectationsProfilerStep that is part of the same pipeline.

required
config GreatExpectationsValidatorConfig

The configuration for the step.

required

Returns:

Type Description
CheckpointResult

The Great Expectations validation (checkpoint) result.

Source code in zenml/integrations/great_expectations/steps/ge_validator.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    condition: bool,
    config: GreatExpectationsValidatorConfig,
) -> CheckpointResult:
    """Standard Great Expectations data validation step entrypoint.

    Args:
        dataset: The dataset to run the expectation suite on.
        condition: This dummy argument can be used as a condition to enforce
            that this step is only run after another step has completed. This
            is useful for example if the Expectation Suite used to validate
            the data is computed in a `GreatExpectationsProfilerStep` that
            is part of the same pipeline.
        config: The configuration for the step.

    Returns:
        The Great Expectations validation (checkpoint) result.
    """
    # get pipeline name, step name and run id
    step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
    run_id = step_env.pipeline_run_id
    step_name = step_env.step_name

    context = GreatExpectationsDataValidator.get_data_context()

    checkpoint_name = f"{run_id}_{step_name}"

    batch_request = create_batch_request(
        context, dataset, config.data_asset_name
    )

    action_list = config.action_list or [
        {
            "name": "store_validation_result",
            "action": {"class_name": "StoreValidationResultAction"},
        },
        {
            "name": "store_evaluation_params",
            "action": {"class_name": "StoreEvaluationParametersAction"},
        },
        {
            "name": "update_data_docs",
            "action": {"class_name": "UpdateDataDocsAction"},
        },
    ]

    checkpoint_config = {
        "name": checkpoint_name,
        "run_name_template": f"{run_id}",
        "config_version": 1,
        "class_name": "Checkpoint",
        "expectation_suite_name": config.expectation_suite_name,
        "action_list": action_list,
    }
    context.add_checkpoint(**checkpoint_config)

    try:
        results = context.run_checkpoint(
            checkpoint_name=checkpoint_name,
            validations=[{"batch_request": batch_request}],
        )
    finally:
        context.delete_datasource(batch_request.datasource_name)
        context.delete_checkpoint(checkpoint_name)

    return results
utils

Utilities shared by the Great Expectations standard steps.

create_batch_request(context, dataset, data_asset_name)

Create a temporary runtime GE batch request from a dataset step artifact.

Parameters:

Name Type Description Default
context BaseDataContext

Great Expectations data context.

required
dataset DataFrame

Input dataset.

required
data_asset_name Optional[str]

Optional custom name for the data asset.

required

Returns:

Type Description
RuntimeBatchRequest

A Great Expectations runtime batch request.

Source code in zenml/integrations/great_expectations/steps/utils.py
def create_batch_request(
    context: BaseDataContext,
    dataset: pd.DataFrame,
    data_asset_name: Optional[str],
) -> RuntimeBatchRequest:
    """Create a temporary runtime GE batch request from a dataset step artifact.

    Args:
        context: Great Expectations data context.
        dataset: Input dataset.
        data_asset_name: Optional custom name for the data asset.

    Returns:
        A Great Expectations runtime batch request.
    """
    # get pipeline name, step name and run id
    step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
    pipeline_name = step_env.pipeline_name
    run_id = step_env.pipeline_run_id
    step_name = step_env.step_name

    datasource_name = f"{run_id}_{step_name}"
    data_connector_name = datasource_name
    data_asset_name = data_asset_name or f"{pipeline_name}_{step_name}"
    batch_identifier = "default"

    datasource_config = {
        "name": datasource_name,
        "class_name": "Datasource",
        "module_name": "great_expectations.datasource",
        "execution_engine": {
            "module_name": "great_expectations.execution_engine",
            "class_name": "PandasExecutionEngine",
        },
        "data_connectors": {
            data_connector_name: {
                "class_name": "RuntimeDataConnector",
                "batch_identifiers": [batch_identifier],
            },
        },
    }

    context.add_datasource(**datasource_config)
    batch_request = RuntimeBatchRequest(
        datasource_name=datasource_name,
        data_connector_name=data_connector_name,
        data_asset_name=data_asset_name,
        runtime_parameters={"batch_data": dataset},
        batch_identifiers={batch_identifier: batch_identifier},
    )

    return batch_request
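The returned batch request plugs into standard Great Expectations APIs. A brief sketch, assuming it runs inside a ZenML step (so the step environment is populated) and that the named expectation suite already exists:

import pandas as pd

from zenml.integrations.great_expectations.data_validators import (
    GreatExpectationsDataValidator,
)

df = pd.DataFrame({"value": [1, 2, 3]})

context = GreatExpectationsDataValidator.get_data_context()
batch_request = create_batch_request(context, df, "my_data_asset")

# validate the in-memory DataFrame against an existing expectation suite
validator = context.get_validator(
    batch_request=batch_request,
    expectation_suite_name="my_suite",
)
print(validator.validate())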

visualizers special

Great Expectations visualizers for expectation suites and validation results.

ge_visualizer

Great Expectations visualizers for expectation suites and validation results.

GreatExpectationsVisualizer (BaseStepVisualizer)

The implementation of a Great Expectations Visualizer.

Source code in zenml/integrations/great_expectations/visualizers/ge_visualizer.py
class GreatExpectationsVisualizer(BaseStepVisualizer):
    """The implementation of a Great Expectations Visualizer."""

    def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
        """Method to visualize a Great Expectations resource.

        Args:
            object: StepView fetched from run.get_step().
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        for artifact_view in object.outputs.values():
            # filter out anything but Great Expectations data analysis artifacts
            if (
                artifact_view.type == DataAnalysisArtifact.__name__
                and artifact_view.data_type.startswith("great_expectations.")
            ):
                artifact = artifact_view.read()
                if isinstance(artifact, CheckpointResult):
                    result = cast(CheckpointResult, artifact)
                    identifier = next(iter(result.run_results.keys()))
                else:
                    suite = cast(ExpectationSuite, artifact)
                    identifier = ExpectationSuiteIdentifier(
                        suite.expectation_suite_name
                    )

                context = GreatExpectationsDataValidator.get_data_context()
                context.open_data_docs(identifier)
visualize(self, object, *args, **kwargs)

Method to visualize a Great Expectations resource.

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}
Source code in zenml/integrations/great_expectations/visualizers/ge_visualizer.py
def visualize(self, object: StepView, *args: Any, **kwargs: Any) -> None:
    """Method to visualize a Great Expectations resource.

    Args:
        object: StepView fetched from run.get_step().
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.
    """
    for artifact_view in object.outputs.values():
        # filter out anything but Great Expectations data analysis artifacts
        if (
            artifact_view.type == DataAnalysisArtifact.__name__
            and artifact_view.data_type.startswith("great_expectations.")
        ):
            artifact = artifact_view.read()
            if isinstance(artifact, CheckpointResult):
                result = cast(CheckpointResult, artifact)
                identifier = next(iter(result.run_results.keys()))
            else:
                suite = cast(ExpectationSuite, artifact)
                identifier = ExpectationSuiteIdentifier(
                    suite.expectation_suite_name
                )

            context = GreatExpectationsDataValidator.get_data_context()
            context.open_data_docs(identifier)
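A short post-execution sketch (the pipeline and step names are hypothetical, and the accessor calls follow this ZenML version's post-execution API): fetch the last run from the repository and open the corresponding Data Docs page.

from zenml.integrations.great_expectations.visualizers import (
    GreatExpectationsVisualizer,
)
from zenml.repository import Repository

repo = Repository()
last_run = repo.get_pipeline("validation_pipeline").runs[-1]
step_view = last_run.get_step("validator")

# opens the Data Docs page for the validation result in the browser
GreatExpectationsVisualizer().visualize(step_view)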

huggingface special

Initialization of the Huggingface integration.

HuggingfaceIntegration (Integration)

Definition of Huggingface integration for ZenML.

Source code in zenml/integrations/huggingface/__init__.py
class HuggingfaceIntegration(Integration):
    """Definition of Huggingface integration for ZenML."""

    NAME = HUGGINGFACE
    REQUIREMENTS = ["transformers", "datasets"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration."""
        from zenml.integrations.huggingface import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/huggingface/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration."""
    from zenml.integrations.huggingface import materializers  # noqa

materializers special

Initialization of Huggingface materializers.

huggingface_datasets_materializer

Implementation of the Huggingface datasets materializer.

HFDatasetMaterializer (BaseMaterializer)

Materializer to read and write Huggingface datasets.

Source code in zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py
class HFDatasetMaterializer(BaseMaterializer):
    """Materializer to read data to and from huggingface datasets."""

    ASSOCIATED_TYPES = (Dataset, DatasetDict)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Dataset:
        """Reads Dataset.

        Args:
            data_type: The type of the dataset to read.

        Returns:
            The dataset read from the specified dir.
        """
        super().handle_input(data_type)

        return load_from_disk(
            os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
        )

    def handle_return(self, ds: Type[Any]) -> None:
        """Writes a Dataset to the specified dir.

        Args:
            ds: The Dataset to write.
        """
        super().handle_return(ds)
        temp_dir = TemporaryDirectory()
        ds.save_to_disk(temp_dir.name)
        io_utils.copy_dir(
            temp_dir.name, os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
        )
handle_input(self, data_type)

Reads Dataset.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the dataset to read.

required

Returns:

Type Description
Dataset

The dataset read from the specified dir.

Source code in zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py
def handle_input(self, data_type: Type[Any]) -> Dataset:
    """Reads Dataset.

    Args:
        data_type: The type of the dataset to read.

    Returns:
        The dataset read from the specified dir.
    """
    super().handle_input(data_type)

    return load_from_disk(
        os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
    )
handle_return(self, ds)

Writes a Dataset to the specified dir.

Parameters:

Name Type Description Default
ds Type[Any]

The Dataset to write.

required
Source code in zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py
def handle_return(self, ds: Type[Any]) -> None:
    """Writes a Dataset to the specified dir.

    Args:
        ds: The Dataset to write.
    """
    super().handle_return(ds)
    temp_dir = TemporaryDirectory()
    ds.save_to_disk(temp_dir.name)
    io_utils.copy_dir(
        temp_dir.name, os.path.join(self.artifact.uri, DEFAULT_DATASET_DIR)
    )
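With the integration activated, any step that returns a Dataset (or DatasetDict) is persisted through this materializer automatically. A minimal hypothetical sketch:

from datasets import Dataset

from zenml.steps import step


@step
def load_data() -> Dataset:
    # the returned Dataset is written to the artifact store via
    # handle_return (save_to_disk into a temp dir, then copy_dir)
    return Dataset.from_dict({"text": ["good", "bad"], "label": [1, 0]})
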
huggingface_pt_model_materializer

Implementation of the Huggingface PyTorch model materializer.

HFPTModelMaterializer (BaseMaterializer)

Materializer to read and write Huggingface PyTorch pretrained models.

Source code in zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py
class HFPTModelMaterializer(BaseMaterializer):
    """Materializer to read torch model to and from huggingface pretrained model."""

    ASSOCIATED_TYPES = (PreTrainedModel,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> PreTrainedModel:
        """Reads HFModel.

        Args:
            data_type: The type of the model to read.

        Returns:
            The model read from the specified dir.
        """
        super().handle_input(data_type)

        config = AutoConfig.from_pretrained(
            os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR)
        )
        architecture = config.architectures[0]
        model_cls = getattr(
            importlib.import_module("transformers"), architecture
        )
        return model_cls.from_pretrained(
            os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR)
        )

    def handle_return(self, model: Type[Any]) -> None:
        """Writes a Model to the specified dir.

        Args:
            model: The Torch Model to write.
        """
        super().handle_return(model)
        temp_dir = TemporaryDirectory()
        model.save_pretrained(temp_dir.name)
        io_utils.copy_dir(
            temp_dir.name,
            os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR),
        )
handle_input(self, data_type)

Reads HFModel.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the model to read.

required

Returns:

Type Description
PreTrainedModel

The model read from the specified dir.

Source code in zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py
def handle_input(self, data_type: Type[Any]) -> PreTrainedModel:
    """Reads HFModel.

    Args:
        data_type: The type of the model to read.

    Returns:
        The model read from the specified dir.
    """
    super().handle_input(data_type)

    config = AutoConfig.from_pretrained(
        os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR)
    )
    architecture = config.architectures[0]
    model_cls = getattr(
        importlib.import_module("transformers"), architecture
    )
    return model_cls.from_pretrained(
        os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR)
    )
handle_return(self, model)

Writes a Model to the specified dir.

Parameters:

Name Type Description Default
model Type[Any]

The Torch Model to write.

required
Source code in zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py
def handle_return(self, model: Type[Any]) -> None:
    """Writes a Model to the specified dir.

    Args:
        model: The Torch Model to write.
    """
    super().handle_return(model)
    temp_dir = TemporaryDirectory()
    model.save_pretrained(temp_dir.name)
    io_utils.copy_dir(
        temp_dir.name,
        os.path.join(self.artifact.uri, DEFAULT_PT_MODEL_DIR),
    )
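A hypothetical sketch of a step whose returned model this materializer persists; the checkpoint name is only an example. On load, handle_input re-resolves the concrete architecture class from the stored config:

from transformers import AutoModelForSequenceClassification, PreTrainedModel

from zenml.steps import step


@step
def load_model() -> PreTrainedModel:
    # saved with save_pretrained(); re-loaded later through the
    # architecture class recorded in the model's config
    return AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )
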
huggingface_tf_model_materializer

Implementation of the Huggingface TF model materializer.

HFTFModelMaterializer (BaseMaterializer)

Materializer to read and write Huggingface TensorFlow pretrained models.

Source code in zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py
class HFTFModelMaterializer(BaseMaterializer):
    """Materializer to read Tensorflow model to and from huggingface pretrained model."""

    ASSOCIATED_TYPES = (TFPreTrainedModel,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> TFPreTrainedModel:
        """Reads HFModel.

        Args:
            data_type: The type of the model to read.

        Returns:
            The model read from the specified dir.
        """
        super().handle_input(data_type)

        config = AutoConfig.from_pretrained(
            os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR)
        )
        architecture = "TF" + config.architectures[0]
        model_cls = getattr(
            importlib.import_module("transformers"), architecture
        )
        return model_cls.from_pretrained(
            os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR)
        )

    def handle_return(self, model: Type[Any]) -> None:
        """Writes a Model to the specified dir.

        Args:
            model: The TF Model to write.
        """
        super().handle_return(model)
        temp_dir = TemporaryDirectory()
        model.save_pretrained(temp_dir.name)
        io_utils.copy_dir(
            temp_dir.name,
            os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR),
        )
handle_input(self, data_type)

Reads HFModel.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the model to read.

required

Returns:

Type Description
TFPreTrainedModel

The model read from the specified dir.

Source code in zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py
def handle_input(self, data_type: Type[Any]) -> TFPreTrainedModel:
    """Reads HFModel.

    Args:
        data_type: The type of the model to read.

    Returns:
        The model read from the specified dir.
    """
    super().handle_input(data_type)

    config = AutoConfig.from_pretrained(
        os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR)
    )
    architecture = "TF" + config.architectures[0]
    model_cls = getattr(
        importlib.import_module("transformers"), architecture
    )
    return model_cls.from_pretrained(
        os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR)
    )
handle_return(self, model)

Writes a Model to the specified dir.

Parameters:

Name Type Description Default
model Type[Any]

The TF Model to write.

required
Source code in zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py
def handle_return(self, model: Type[Any]) -> None:
    """Writes a Model to the specified dir.

    Args:
        model: The TF Model to write.
    """
    super().handle_return(model)
    temp_dir = TemporaryDirectory()
    model.save_pretrained(temp_dir.name)
    io_utils.copy_dir(
        temp_dir.name,
        os.path.join(self.artifact.uri, DEFAULT_TF_MODEL_DIR),
    )
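The only difference from the PyTorch materializer is how the class is resolved: the TensorFlow variant of an architecture is obtained by prepending "TF" to the architecture name stored in the config. Illustratively (the model name is an example):

from importlib import import_module

from transformers import AutoConfig

config = AutoConfig.from_pretrained("bert-base-uncased")
architecture = "TF" + config.architectures[0]  # e.g. "TFBertForMaskedLM"
model_cls = getattr(import_module("transformers"), architecture)
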
huggingface_tokenizer_materializer

Implementation of the Huggingface tokenizer materializer.

HFTokenizerMaterializer (BaseMaterializer)

Materializer to read and write Huggingface tokenizers.

Source code in zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py
class HFTokenizerMaterializer(BaseMaterializer):
    """Materializer to read tokenizer to and from huggingface tokenizer."""

    ASSOCIATED_TYPES = (PreTrainedTokenizerBase,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> PreTrainedTokenizerBase:
        """Reads Tokenizer.

        Args:
            data_type: The type of the tokenizer to read.

        Returns:
            The tokenizer read from the specified dir.
        """
        super().handle_input(data_type)

        return AutoTokenizer.from_pretrained(
            os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR)
        )

    def handle_return(self, tokenizer: Type[Any]) -> None:
        """Writes a Tokenizer to the specified dir.

        Args:
            tokenizer: The HFTokenizer to write.
        """
        super().handle_return(tokenizer)
        temp_dir = TemporaryDirectory()
        tokenizer.save_pretrained(temp_dir.name)
        io_utils.copy_dir(
            temp_dir.name,
            os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR),
        )
handle_input(self, data_type)

Reads Tokenizer.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the tokenizer to read.

required

Returns:

Type Description
PreTrainedTokenizerBase

The tokenizer read from the specified dir.

Source code in zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py
def handle_input(self, data_type: Type[Any]) -> PreTrainedTokenizerBase:
    """Reads Tokenizer.

    Args:
        data_type: The type of the tokenizer to read.

    Returns:
        The tokenizer read from the specified dir.
    """
    super().handle_input(data_type)

    return AutoTokenizer.from_pretrained(
        os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR)
    )
handle_return(self, tokenizer)

Writes a Tokenizer to the specified dir.

Parameters:

Name Type Description Default
tokenizer Type[Any]

The HFTokenizer to write.

required
Source code in zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py
def handle_return(self, tokenizer: Type[Any]) -> None:
    """Writes a Tokenizer to the specified dir.

    Args:
        tokenizer: The HFTokenizer to write.
    """
    super().handle_return(tokenizer)
    temp_dir = TemporaryDirectory()
    tokenizer.save_pretrained(temp_dir.name)
    io_utils.copy_dir(
        temp_dir.name,
        os.path.join(self.artifact.uri, DEFAULT_TOKENIZER_DIR),
    )
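Analogously for tokenizers; note that handle_input always loads through AutoTokenizer, whatever concrete tokenizer class was saved. A minimal hypothetical sketch:

from transformers import AutoTokenizer, PreTrainedTokenizerBase

from zenml.steps import step


@step
def load_tokenizer() -> PreTrainedTokenizerBase:
    # persisted via save_pretrained() and re-loaded with AutoTokenizer
    return AutoTokenizer.from_pretrained("distilbert-base-uncased")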

integration

Base and meta classes for ZenML integrations.

Integration

Base class for integrations in ZenML.

Source code in zenml/integrations/integration.py
class Integration(metaclass=IntegrationMeta):
    """Base class for integration in ZenML."""

    NAME = "base_integration"

    REQUIREMENTS: List[str] = []

    SYSTEM_REQUIREMENTS: Dict[str, str] = {}

    @classmethod
    def check_installation(cls) -> bool:
        """Method to check whether the required packages are installed.

        Returns:
            True if all required packages are installed, False otherwise.
        """
        try:
            for requirement, command in cls.SYSTEM_REQUIREMENTS.items():
                result = shutil.which(command)

                if result is None:
                    logger.debug(
                        "Unable to find the required packages for %s on your "
                        "system. Please install the packages on your system "
                        "and try again.",
                        requirement,
                    )
                    return False

            for r in cls.REQUIREMENTS:
                pkg_resources.get_distribution(r)
            logger.debug(
                f"Integration {cls.NAME} is installed correctly with "
                f"requirements {cls.REQUIREMENTS}."
            )
            return True
        except pkg_resources.DistributionNotFound as e:
            logger.debug(
                f"Unable to find required package '{e.req}' for "
                f"integration {cls.NAME}."
            )
            return False
        except pkg_resources.VersionConflict as e:
            logger.debug(
                f"VersionConflict error when loading installation {cls.NAME}: "
                f"{str(e)}"
            )
            return False

    @classmethod
    def activate(cls) -> None:
        """Abstract method to activate the integration."""

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Abstract method to declare new stack component flavors."""
activate() classmethod

Abstract method to activate the integration.

Source code in zenml/integrations/integration.py
@classmethod
def activate(cls) -> None:
    """Abstract method to activate the integration."""
check_installation() classmethod

Method to check whether the required packages are installed.

Returns:

Type Description
bool

True if all required packages are installed, False otherwise.

Source code in zenml/integrations/integration.py
@classmethod
def check_installation(cls) -> bool:
    """Method to check whether the required packages are installed.

    Returns:
        True if all required packages are installed, False otherwise.
    """
    try:
        for requirement, command in cls.SYSTEM_REQUIREMENTS.items():
            result = shutil.which(command)

            if result is None:
                logger.debug(
                    "Unable to find the required packages for %s on your "
                    "system. Please install the packages on your system "
                    "and try again.",
                    requirement,
                )
                return False

        for r in cls.REQUIREMENTS:
            pkg_resources.get_distribution(r)
        logger.debug(
            f"Integration {cls.NAME} is installed correctly with "
            f"requirements {cls.REQUIREMENTS}."
        )
        return True
    except pkg_resources.DistributionNotFound as e:
        logger.debug(
            f"Unable to find required package '{e.req}' for "
            f"integration {cls.NAME}."
        )
        return False
    except pkg_resources.VersionConflict as e:
        logger.debug(
            f"VersionConflict error when loading installation {cls.NAME}: "
            f"{str(e)}"
        )
        return False
flavors() classmethod

Abstract method to declare new stack component flavors.

Source code in zenml/integrations/integration.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Abstract method to declare new stack component flavors."""

IntegrationMeta (type)

Metaclass responsible for registering different Integration subclasses.

Source code in zenml/integrations/integration.py
class IntegrationMeta(type):
    """Metaclass responsible for registering different Integration subclasses."""

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "IntegrationMeta":
        """Hook into creation of an Integration class.

        Args:
            name: The name of the class being created.
            bases: The base classes of the class being created.
            dct: The dictionary of attributes of the class being created.

        Returns:
            The newly created class.
        """
        cls = cast(Type["Integration"], super().__new__(mcs, name, bases, dct))
        if name != "Integration":
            integration_registry.register_integration(cls.NAME, cls)
        return cls
__new__(mcs, name, bases, dct) special staticmethod

Hook into creation of an Integration class.

Parameters:

Name Type Description Default
name str

The name of the class being created.

required
bases Tuple[Type[Any], ...]

The base classes of the class being created.

required
dct Dict[str, Any]

The dictionary of attributes of the class being created.

required

Returns:

Type Description
IntegrationMeta

The newly created class.

Source code in zenml/integrations/integration.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "IntegrationMeta":
    """Hook into creation of an Integration class.

    Args:
        name: The name of the class being created.
        bases: The base classes of the class being created.
        dct: The dictionary of attributes of the class being created.

    Returns:
        The newly created class.
    """
    cls = cast(Type["Integration"], super().__new__(mcs, name, bases, dct))
    if name != "Integration":
        integration_registry.register_integration(cls.NAME, cls)
    return cls
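Because IntegrationMeta registers every non-base subclass at class-creation time, defining a new integration is purely declarative. A hypothetical sketch (the package and requirement names are placeholders):

from typing import List

from zenml.integrations.integration import Integration


class MyToolIntegration(Integration):
    """Hypothetical integration; registered automatically on definition."""

    NAME = "my_tool"
    REQUIREMENTS: List[str] = ["my-tool>=1.0"]

    @classmethod
    def activate(cls) -> None:
        """Imports modules whose side effects register materializers."""
        # placeholder import; a real integration would import its own modules
        import my_tool.materializers  # noqa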

kubeflow special

Initialization of the Kubeflow integration for ZenML.

The Kubeflow integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Kubeflow orchestrator with the CLI tool.

KubeflowIntegration (Integration)

Definition of Kubeflow Integration for ZenML.

Source code in zenml/integrations/kubeflow/__init__.py
class KubeflowIntegration(Integration):
    """Definition of Kubeflow Integration for ZenML."""

    NAME = KUBEFLOW
    REQUIREMENTS = ["kfp==1.8.9"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Kubeflow integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=KUBEFLOW_METADATA_STORE_FLAVOR,
                source="zenml.integrations.kubeflow.metadata_stores.KubeflowMetadataStore",
                type=StackComponentType.METADATA_STORE,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=KUBEFLOW_ORCHESTRATOR_FLAVOR,
                source="zenml.integrations.kubeflow.orchestrators.KubeflowOrchestrator",
                type=StackComponentType.ORCHESTRATOR,
                integration=cls.NAME,
            ),
        ]
flavors() classmethod

Declare the stack component flavors for the Kubeflow integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/kubeflow/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Kubeflow integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=KUBEFLOW_METADATA_STORE_FLAVOR,
            source="zenml.integrations.kubeflow.metadata_stores.KubeflowMetadataStore",
            type=StackComponentType.METADATA_STORE,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=KUBEFLOW_ORCHESTRATOR_FLAVOR,
            source="zenml.integrations.kubeflow.orchestrators.KubeflowOrchestrator",
            type=StackComponentType.ORCHESTRATOR,
            integration=cls.NAME,
        ),
    ]

metadata_stores special

Initialization of the Kubeflow metadata store for ZenML.

kubeflow_metadata_store

Implementation of the Kubeflow metadata store.

KubeflowMetadataStore (BaseMetadataStore) pydantic-model

Kubeflow GRPC backend for ZenML metadata store.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
class KubeflowMetadataStore(BaseMetadataStore):
    """Kubeflow GRPC backend for ZenML metadata store."""

    upgrade_migration_enabled: bool = False
    host: str = "127.0.0.1"
    port: int = DEFAULT_KFP_METADATA_GRPC_PORT

    # Class Configuration
    FLAVOR: ClassVar[str] = KUBEFLOW_METADATA_STORE_FLAVOR

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains a KFP orchestrator.

        Returns:
            The stack validator.
        """

        def _ensure_kfp_orchestrator(stack: Stack) -> Tuple[bool, str]:
            return (
                stack.orchestrator.FLAVOR == KUBEFLOW,
                "The Kubeflow metadata store can only be used with a Kubeflow "
                "orchestrator.",
            )

        return StackValidator(
            custom_validation_function=_ensure_kfp_orchestrator
        )

    def get_tfx_metadata_config(
        self,
    ) -> Union[
        metadata_store_pb2.ConnectionConfig,
        metadata_store_pb2.MetadataStoreClientConfig,
    ]:
        """Return tfx metadata config for the kubeflow metadata store.

        Returns:
            The tfx metadata config for the kubeflow metadata store.

        Raises:
            RuntimeError: If the metadata store is not running.
        """
        connection_config = metadata_store_pb2.MetadataStoreClientConfig()
        if inside_kfp_pod():
            connection_config.host = os.environ["METADATA_GRPC_SERVICE_HOST"]
            connection_config.port = int(
                os.environ["METADATA_GRPC_SERVICE_PORT"]
            )
        else:
            if not self.is_running:
                raise RuntimeError(
                    "The KFP metadata daemon is not running. Please run the "
                    "following command to start it first:\n\n"
                    "    'zenml metadata-store up'\n"
                )
            connection_config.host = self.host
            connection_config.port = self.port
        return connection_config

    @property
    def kfp_orchestrator(self) -> KubeflowOrchestrator:
        """Returns the Kubeflow orchestrator in the active stack.

        Returns:
            The Kubeflow orchestrator in the active stack.
        """
        repo = Repository(skip_repository_check=True)  # type: ignore[call-arg]
        return cast(KubeflowOrchestrator, repo.active_stack.orchestrator)

    @property
    def kubernetes_context(self) -> str:
        """Returns the kubernetes context.

        This is the context of the cluster where the Kubeflow Pipelines
        services are running.

        Returns:
            The kubernetes context.
        """
        kubernetes_context = self.kfp_orchestrator.kubernetes_context

        # will never happen, but mypy doesn't know that
        assert kubernetes_context is not None
        return kubernetes_context

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory.

        This is for all files concerning this KFP metadata store.

        Note: the root directory for the KFP metadata store is relative to the
        root directory of the KFP orchestrator, because it is a sub-component
        of it.

        Returns:
            Path to the root directory.
        """
        return os.path.join(
            self.kfp_orchestrator.root_directory,
            "metadata-store",
            str(self.uuid),
        )

    @property
    def _pid_file_path(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(self.root_directory, "kubeflow_daemon.pid")

    @property
    def _log_file(self) -> str:
        """Path of the daemon log file.

        Returns:
            Path to the daemon log file.
        """
        return os.path.join(self.root_directory, "kubeflow_daemon.log")

    @property
    def is_provisioned(self) -> bool:
        """If the component provisioned resources to run locally.

        Returns:
            True if the component provisioned resources to run locally.
        """
        return fileio.exists(self.root_directory)

    @property
    def is_running(self) -> bool:
        """If the component is running locally.

        Returns:
            True if the component is running locally, False otherwise.
        """
        if sys.platform != "win32":
            from zenml.utils.daemon import check_if_daemon_is_running

            if not check_if_daemon_is_running(self._pid_file_path):
                return False
        else:
            # Daemon functionality is not supported on Windows, so the PID
            # file won't exist. This if clause exists just for mypy to not
            # complain about missing functions
            pass

        return True

    def provision(self) -> None:
        """Provisions resources to run the component locally."""
        logger.info("Provisioning local Kubeflow Pipelines deployment...")
        fileio.makedirs(self.root_directory)

    def deprovision(self) -> None:
        """Deprovisions all local resources of the component."""
        if fileio.exists(self._log_file):
            fileio.remove(self._log_file)

        logger.info("Local kubeflow pipelines deployment deprovisioned.")

    def resume(self) -> None:
        """Resumes the local k3d cluster."""
        if self.is_running:
            logger.info("Local kubeflow pipelines deployment already running.")
            return

        self.start_kfp_metadata_daemon()
        self.wait_until_metadata_store_ready()

    def suspend(self) -> None:
        """Suspends the local k3d cluster."""
        if not self.is_running:
            logger.info("Local kubeflow pipelines deployment not running.")
            return

        self.stop_kfp_metadata_daemon()

    def start_kfp_metadata_daemon(self) -> None:
        """Starts a daemon process that forwards ports.

        This is so the Kubeflow Pipelines Metadata MySQL database is accessible
        on the localhost.

        Raises:
            ProvisioningError: if the daemon fails to start.
        """
        command = [
            "kubectl",
            "--context",
            self.kubernetes_context,
            "--namespace",
            "kubeflow",
            "port-forward",
            "svc/metadata-grpc-service",
            f"{self.port}:8080",
        ]

        if sys.platform == "win32":
            logger.warning(
                "Daemon functionality not supported on Windows. "
                "In order to access the Kubeflow Pipelines Metadata locally, "
                "please run '%s' in a separate command line shell.",
                self.port,
                " ".join(command),
            )
        elif not networking_utils.port_available(self.port):
            raise ProvisioningError(
                f"Unable to port-forward Kubeflow Pipelines Metadata to local "
                f"port {self.port} because the port is occupied. In order to "
                f"access the Kubeflow Pipelines Metadata locally, please "
                f"change the metadata store configuration to use an available "
                f"port or stop the other process currently using the port."
            )
        else:
            from zenml.utils import daemon

            def _daemon_function() -> None:
                """Forwards the port of the Kubeflow Pipelines Metadata pod ."""
                subprocess.check_call(command)

            daemon.run_as_daemon(
                _daemon_function,
                pid_file=self._pid_file_path,
                log_file=self._log_file,
            )
            logger.info(
                "Started Kubeflow Pipelines Metadata daemon (check the daemon"
                "logs at %s in case you're not able to access the pipeline"
                "metadata).",
                self._log_file,
            )

    def stop_kfp_metadata_daemon(self) -> None:
        """Stops the KFP Metadata daemon process if it is running."""
        if fileio.exists(self._pid_file_path):
            if sys.platform == "win32":
                # Daemon functionality is not supported on Windows, so the PID
                # file won't exist. This if clause exists just for mypy to not
                # complain about missing functions
                pass
            else:
                from zenml.utils import daemon

                daemon.stop_daemon(self._pid_file_path)
                fileio.remove(self._pid_file_path)

    def wait_until_metadata_store_ready(
        self, timeout: int = DEFAULT_KFP_METADATA_DAEMON_TIMEOUT
    ) -> None:
        """Waits until the metadata store connection is ready.

        An irrecoverable error could occur, or the timeout could expire, so
        the method checks for both.

        Args:
            timeout: The maximum time to wait for the metadata store to be
                ready.

        Raises:
            RuntimeError: if the metadata store is not ready after the timeout
        """
        logger.info(
            "Waiting for the Kubeflow metadata store to be ready (this might "
            "take a few minutes)."
        )
        while True:
            try:
                # it doesn't matter what we call here as long as it exercises
                # the MLMD connection
                self.get_pipelines()
                break
            except Exception as e:
                logger.info(
                    "The Kubeflow metadata store is not ready yet. Waiting for "
                    "10 seconds..."
                )
                if timeout <= 0:
                    raise RuntimeError(
                        f"An unexpected error was encountered while waiting for the "
                        f"Kubeflow metadata store to be functional: {str(e)}"
                    ) from e
                timeout -= 10
                time.sleep(10)

        logger.info("The Kubeflow metadata store is functional.")
is_provisioned: bool property readonly

If the component provisioned resources to run locally.

Returns:

Type Description
bool

True if the component provisioned resources to run locally.

is_running: bool property readonly

If the component is running locally.

Returns:

Type Description
bool

True if the component is running locally, False otherwise.

kfp_orchestrator: KubeflowOrchestrator property readonly

Returns the Kubeflow orchestrator in the active stack.

Returns:

Type Description
KubeflowOrchestrator

The Kubeflow orchestrator in the active stack.

kubernetes_context: str property readonly

Returns the kubernetes context.

This is the context of the cluster where the Kubeflow Pipelines services are running.

Returns:

Type Description
str

The kubernetes context.

root_directory: str property readonly

Returns path to the root directory.

This is for all files concerning this KFP metadata store.

Note: the root directory for the KFP metadata store is relative to the root directory of the KFP orchestrator, because it is a sub-component of it.

Returns:

Type Description
str

Path to the root directory.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates that the stack contains a KFP orchestrator.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

The stack validator.

deprovision(self)

Deprovisions all local resources of the component.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def deprovision(self) -> None:
    """Deprovisions all local resources of the component."""
    if fileio.exists(self._log_file):
        fileio.remove(self._log_file)

    logger.info("Local kubeflow pipelines deployment deprovisioned.")
get_tfx_metadata_config(self)

Return tfx metadata config for the kubeflow metadata store.

Returns:

Type Description
Union[ml_metadata.proto.metadata_store_pb2.ConnectionConfig, ml_metadata.proto.metadata_store_pb2.MetadataStoreClientConfig]

The tfx metadata config for the kubeflow metadata store.

Exceptions:

Type Description
RuntimeError

If the metadata store is not running.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def get_tfx_metadata_config(
    self,
) -> Union[
    metadata_store_pb2.ConnectionConfig,
    metadata_store_pb2.MetadataStoreClientConfig,
]:
    """Return tfx metadata config for the kubeflow metadata store.

    Returns:
        The tfx metadata config for the kubeflow metadata store.

    Raises:
        RuntimeError: If the metadata store is not running.
    """
    connection_config = metadata_store_pb2.MetadataStoreClientConfig()
    if inside_kfp_pod():
        connection_config.host = os.environ["METADATA_GRPC_SERVICE_HOST"]
        connection_config.port = int(
            os.environ["METADATA_GRPC_SERVICE_PORT"]
        )
    else:
        if not self.is_running:
            raise RuntimeError(
                "The KFP metadata daemon is not running. Please run the "
                "following command to start it first:\n\n"
                "    'zenml metadata-store up'\n"
            )
        connection_config.host = self.host
        connection_config.port = self.port
    return connection_config
provision(self)

Provisions resources to run the component locally.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def provision(self) -> None:
    """Provisions resources to run the component locally."""
    logger.info("Provisioning local Kubeflow Pipelines deployment...")
    fileio.makedirs(self.root_directory)
resume(self)

Resumes the local k3d cluster.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def resume(self) -> None:
    """Resumes the local k3d cluster."""
    if self.is_running:
        logger.info("Local kubeflow pipelines deployment already running.")
        return

    self.start_kfp_metadata_daemon()
    self.wait_until_metadata_store_ready()
start_kfp_metadata_daemon(self)

Starts a daemon process that forwards ports.

This is so the Kubeflow Pipelines Metadata MySQL database is accessible on the localhost.

Exceptions:

Type Description
ProvisioningError

if the daemon fails to start.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def start_kfp_metadata_daemon(self) -> None:
    """Starts a daemon process that forwards ports.

    This is so the Kubeflow Pipelines Metadata MySQL database is accessible
    on the localhost.

    Raises:
        ProvisioningError: if the daemon fails to start.
    """
    command = [
        "kubectl",
        "--context",
        self.kubernetes_context,
        "--namespace",
        "kubeflow",
        "port-forward",
        "svc/metadata-grpc-service",
        f"{self.port}:8080",
    ]

    if sys.platform == "win32":
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Kubeflow Pipelines Metadata locally, "
            "please run '%s' in a separate command line shell.",
            self.port,
            " ".join(command),
        )
    elif not networking_utils.port_available(self.port):
        raise ProvisioningError(
            f"Unable to port-forward Kubeflow Pipelines Metadata to local "
            f"port {self.port} because the port is occupied. In order to "
            f"access the Kubeflow Pipelines Metadata locally, please "
            f"change the metadata store configuration to use an available "
            f"port or stop the other process currently using the port."
        )
    else:
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Forwards the port of the Kubeflow Pipelines Metadata pod ."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function,
            pid_file=self._pid_file_path,
            log_file=self._log_file,
        )
        logger.info(
            "Started Kubeflow Pipelines Metadata daemon (check the daemon"
            "logs at %s in case you're not able to access the pipeline"
            "metadata).",
            self._log_file,
        )
stop_kfp_metadata_daemon(self)

Stops the KFP Metadata daemon process if it is running.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def stop_kfp_metadata_daemon(self) -> None:
    """Stops the KFP Metadata daemon process if it is running."""
    if fileio.exists(self._pid_file_path):
        if sys.platform == "win32":
            # Daemon functionality is not supported on Windows, so the PID
            # file won't exist. This if clause exists just for mypy to not
            # complain about missing functions
            pass
        else:
            from zenml.utils import daemon

            daemon.stop_daemon(self._pid_file_path)
            fileio.remove(self._pid_file_path)
suspend(self)

Suspends the local k3d cluster.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def suspend(self) -> None:
    """Suspends the local k3d cluster."""
    if not self.is_running:
        logger.info("Local kubeflow pipelines deployment not running.")
        return

    self.stop_kfp_metadata_daemon()
wait_until_metadata_store_ready(self, timeout=60)

Waits until the metadata store connection is ready.

An irrecoverable error could occur, or the timeout could expire, so the method checks for both.

Parameters:

Name Type Description Default
timeout int

The maximum time to wait for the metadata store to be ready.

60

Exceptions:

Type Description
RuntimeError

if the metadata store is not ready after the timeout

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def wait_until_metadata_store_ready(
    self, timeout: int = DEFAULT_KFP_METADATA_DAEMON_TIMEOUT
) -> None:
    """Waits until the metadata store connection is ready.

    An irrecoverable error could occur, or the timeout could expire, so
    the method checks for both.

    Args:
        timeout: The maximum time to wait for the metadata store to be
            ready.

    Raises:
        RuntimeError: if the metadata store is not ready after the timeout
    """
    logger.info(
        "Waiting for the Kubeflow metadata store to be ready (this might "
        "take a few minutes)."
    )
    while True:
        try:
            # it doesn't matter what we call here as long as it exercises
            # the MLMD connection
            self.get_pipelines()
            break
        except Exception as e:
            logger.info(
                "The Kubeflow metadata store is not ready yet. Waiting for "
                "10 seconds..."
            )
            if timeout <= 0:
                raise RuntimeError(
                    f"An unexpected error was encountered while waiting for the "
                    f"Kubeflow metadata store to be functional: {str(e)}"
                ) from e
            timeout -= 10
            time.sleep(10)

    logger.info("The Kubeflow metadata store is functional.")
inside_kfp_pod()

Returns whether the current Python process is running inside a KFP pod.

Returns:

Type Description
bool

True if the current python process is running inside a KFP Pod, False otherwise.

Source code in zenml/integrations/kubeflow/metadata_stores/kubeflow_metadata_store.py
def inside_kfp_pod() -> bool:
    """Returns if the current python process is running inside a KFP Pod.

    Returns:
        True if the current python process is running inside a KFP Pod, False otherwise.
    """
    if "KFP_POD_NAME" not in os.environ:
        return False

    try:
        k8s_config.load_incluster_config()
        return True
    except k8s_config.ConfigException:
        return False
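Putting the module together: a hedged configuration sketch (the component name and port are illustrative; as enforced by the validator above, the store requires a Kubeflow orchestrator in the same stack):

from zenml.integrations.kubeflow.metadata_stores import KubeflowMetadataStore

# illustrative values; the gRPC service becomes reachable on this local
# port once the port-forward daemon is started via `zenml metadata-store up`
metadata_store = KubeflowMetadataStore(
    name="kubeflow_metadata_store",
    host="127.0.0.1",
    port=8081,
)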

orchestrators special

Initialization of the Kubeflow ZenML orchestrator.

kubeflow_entrypoint_configuration

Implementation of the Kubeflow entrypoint configuration.

KubeflowEntrypointConfiguration (StepEntrypointConfiguration)

Entrypoint configuration for running steps on kubeflow.

This class writes a markdown file that will be displayed in the KFP UI.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
class KubeflowEntrypointConfiguration(StepEntrypointConfiguration):
    """Entrypoint configuration for running steps on kubeflow.

    This class writes a markdown file that will be displayed in the KFP UI.
    """

    @classmethod
    def get_custom_entrypoint_options(cls) -> Set[str]:
        """Kubeflow specific entrypoint options.

        The metadata ui path option expects a path where the markdown file
        that will be displayed in the kubeflow UI should be written. The same
        path needs to be added as an output artifact called
        `mlpipeline-ui-metadata` for the corresponding `kfp.dsl.ContainerOp`.

        Returns:
            The set of custom entrypoint options.
        """
        return {METADATA_UI_PATH_OPTION}

    @classmethod
    def get_custom_entrypoint_arguments(
        cls, step: BaseStep, *args: Any, **kwargs: Any
    ) -> List[str]:
        """Sets the metadata ui path argument to the value passed in via the keyword args.

        Args:
            step: The step that is being executed.
            *args: The positional arguments passed to the step.
            **kwargs: The keyword arguments passed to the step.

        Returns:
            A list of strings that will be used as arguments to the step.
        """
        return [
            f"--{METADATA_UI_PATH_OPTION}",
            kwargs[METADATA_UI_PATH_OPTION],
        ]

    def get_run_name(self, pipeline_name: str) -> str:
        """Returns the Kubeflow pipeline run name.

        Args:
            pipeline_name: The name of the pipeline.

        Returns:
            The Kubeflow pipeline run name.
        """
        k8s_config.load_incluster_config()
        run_id = os.environ["KFP_RUN_ID"]
        return kfp.Client().get_run(run_id).run.name  # type: ignore[no-any-return]

    def post_run(
        self,
        pipeline_name: str,
        step_name: str,
        pipeline_node: Pb2PipelineNode,
        execution_info: Optional[data_types.ExecutionInfo] = None,
    ) -> None:
        """Writes a markdown file that will display information.

        This will be about the step execution and input/output artifacts in the
        KFP UI.

        Args:
            pipeline_name: The name of the pipeline.
            step_name: The name of the step.
            pipeline_node: The pipeline node that is being executed.
            execution_info: The execution info of the step.
        """
        if execution_info:
            utils.dump_ui_metadata(
                node=pipeline_node,
                execution_info=execution_info,
                metadata_ui_path=self.entrypoint_args[METADATA_UI_PATH_OPTION],
            )
get_custom_entrypoint_arguments(step, *args, **kwargs) classmethod

Sets the metadata ui path argument to the value passed in via the keyword args.

Parameters:

Name Type Description Default
step BaseStep

The step that is being executed.

required
*args Any

The positional arguments passed to the step.

()
**kwargs Any

The keyword arguments passed to the step.

{}

Returns:

Type Description
List[str]

A list of strings that will be used as arguments to the step.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_arguments(
    cls, step: BaseStep, *args: Any, **kwargs: Any
) -> List[str]:
    """Sets the metadata ui path argument to the value passed in via the keyword args.

    Args:
        step: The step that is being executed.
        *args: The positional arguments passed to the step.
        **kwargs: The keyword arguments passed to the step.

    Returns:
        A list of strings that will be used as arguments to the step.
    """
    return [
        f"--{METADATA_UI_PATH_OPTION}",
        kwargs[METADATA_UI_PATH_OPTION],
    ]
get_custom_entrypoint_options() classmethod

Kubeflow specific entrypoint options.

The metadata ui path option expects a path where the markdown file that will be displayed in the kubeflow UI should be written. The same path needs to be added as an output artifact called mlpipeline-ui-metadata for the corresponding kfp.dsl.ContainerOp.

Returns:

Type Description
Set[str]

The set of custom entrypoint options.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_options(cls) -> Set[str]:
    """Kubeflow specific entrypoint options.

    The metadata ui path option expects a path where the markdown file
    that will be displayed in the kubeflow UI should be written. The same
    path needs to be added as an output artifact called
    `mlpipeline-ui-metadata` for the corresponding `kfp.dsl.ContainerOp`.

    Returns:
        The set of custom entrypoint options.
    """
    return {METADATA_UI_PATH_OPTION}
get_run_name(self, pipeline_name)

Returns the Kubeflow pipeline run name.

Parameters:

Name Type Description Default
pipeline_name str

The name of the pipeline.

required

Returns:

Type Description
str

The Kubeflow pipeline run name.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> str:
    """Returns the Kubeflow pipeline run name.

    Args:
        pipeline_name: The name of the pipeline.

    Returns:
        The Kubeflow pipeline run name.
    """
    k8s_config.load_incluster_config()
    run_id = os.environ["KFP_RUN_ID"]
    return kfp.Client().get_run(run_id).run.name  # type: ignore[no-any-return]
post_run(self, pipeline_name, step_name, pipeline_node, execution_info=None)

Writes a markdown file that will display information.

This will be about the step execution and input/output artifacts in the KFP UI.

Parameters:

Name Type Description Default
pipeline_name str

The name of the pipeline.

required
step_name str

The name of the step.

required
pipeline_node PipelineNode

The pipeline node that is being executed.

required
execution_info Optional[tfx.orchestration.portable.data_types.ExecutionInfo]

The execution info of the step.

None
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_entrypoint_configuration.py
def post_run(
    self,
    pipeline_name: str,
    step_name: str,
    pipeline_node: Pb2PipelineNode,
    execution_info: Optional[data_types.ExecutionInfo] = None,
) -> None:
    """Writes a markdown file that will display information.

    This will be about the step execution and input/output artifacts in the
    KFP UI.

    Args:
        pipeline_name: The name of the pipeline.
        step_name: The name of the step.
        pipeline_node: The pipeline node that is being executed.
        execution_info: The execution info of the step.
    """
    if execution_info:
        utils.dump_ui_metadata(
            node=pipeline_node,
            execution_info=execution_info,
            metadata_ui_path=self.entrypoint_args[METADATA_UI_PATH_OPTION],
        )
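The markdown file mentioned above follows KFP's mlpipeline-ui-metadata convention. As a rough hand-written illustration of that file format (not the exact content that utils.dump_ui_metadata produces):

import json

metadata = {
    "outputs": [
        {
            "type": "markdown",
            "storage": "inline",
            "source": "# Step finished\nInputs and outputs would be listed here.",
        }
    ]
}

with open("/tmp/mlpipeline-ui-metadata.json", "w") as f:
    json.dump(metadata, f)
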
kubeflow_orchestrator

Implementation of the Kubeflow orchestrator.

KubeflowOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator responsible for running pipelines using Kubeflow.

Attributes:

Name Type Description
custom_docker_base_image_name Optional[str]

Name of a docker image that should be used as the base for the image that will be run on KFP pods. If no custom image is given, a basic image of the active ZenML version will be used. Note: This image needs to have ZenML installed, otherwise the pipeline execution will fail. For that reason, you might want to extend the ZenML docker images found here: https://hub.docker.com/r/zenmldocker/zenml/

kubeflow_pipelines_ui_port int

A local port to which the KFP UI will be forwarded.

kubeflow_hostname Optional[str]

The hostname to use to talk to the Kubeflow Pipelines API. If not set, the hostname will be derived from the Kubernetes API proxy.

kubernetes_context Optional[str]

Optional name of a kubernetes context to run pipelines in. If not set, the current active context will be used. You can find the active context by running kubectl config current-context.

synchronous bool

If True, running a pipeline using this orchestrator will block until all steps finished running on KFP.

skip_local_validations bool

If True, the local validations will be skipped.

skip_cluster_provisioning bool

If True, the k3d cluster provisioning will be skipped.

skip_ui_daemon_provisioning bool

If True, provisioning the KFP UI daemon will be skipped.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
class KubeflowOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines using Kubeflow.

    Attributes:
        custom_docker_base_image_name: Name of a docker image that should be
            used as the base for the image that will be run on KFP pods. If no
            custom image is given, a basic image of the active ZenML version
            will be used. **Note**: This image needs to have ZenML installed,
            otherwise the pipeline execution will fail. For that reason, you
            might want to extend the ZenML docker images found here:
            https://hub.docker.com/r/zenmldocker/zenml/
        kubeflow_pipelines_ui_port: A local port to which the KFP UI will be
            forwarded.
        kubeflow_hostname: The hostname to use to talk to the Kubeflow Pipelines
            API. If not set, the hostname will be derived from the Kubernetes
            API proxy.
        kubernetes_context: Optional name of a kubernetes context to run
            pipelines in. If not set, the current active context will be used.
            You can find the active context by running `kubectl config
            current-context`.
        synchronous: If `True`, running a pipeline using this orchestrator will
            block until all steps finished running on KFP.
        skip_local_validations: If `True`, the local validations will be
            skipped.
        skip_cluster_provisioning: If `True`, the k3d cluster provisioning will
            be skipped.
        skip_ui_daemon_provisioning: If `True`, provisioning the KFP UI daemon
            will be skipped.
    """

    custom_docker_base_image_name: Optional[str] = None
    kubeflow_pipelines_ui_port: int = DEFAULT_KFP_UI_PORT
    kubeflow_hostname: Optional[str] = None
    kubernetes_context: Optional[str] = None
    synchronous: bool = False
    skip_local_validations: bool = False
    skip_cluster_provisioning: bool = False
    skip_ui_daemon_provisioning: bool = False

    # Class Configuration
    FLAVOR: ClassVar[str] = KUBEFLOW_ORCHESTRATOR_FLAVOR

    @staticmethod
    def _get_k3d_cluster_name(uuid: UUID) -> str:
        """Returns the k3d cluster name corresponding to the orchestrator UUID.

        Args:
            uuid: The UUID of the orchestrator.

        Returns:
            The k3d cluster name.
        """
        # k3d only allows cluster names with up to 32 characters; use the
        # first 8 chars of the orchestrator UUID as identifier
        return f"zenml-kubeflow-{str(uuid)[:8]}"

    @staticmethod
    def _get_k3d_kubernetes_context(uuid: UUID) -> str:
        """Gets the k3d kubernetes context.

        Args:
            uuid: The UUID of the orchestrator.

        Returns:
            The name of the kubernetes context associated with the k3d
                cluster managed locally by ZenML corresponding to the orchestrator UUID.
        """
        return f"k3d-{KubeflowOrchestrator._get_k3d_cluster_name(uuid)}"

    @root_validator(skip_on_failure=True)
    def set_default_kubernetes_context(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Pydantic root_validator.

        This sets the default `kubernetes_context` value to the value that is
        used to create the locally managed k3d cluster, if not explicitly set.

        Args:
            values: Values passed to the object constructor

        Returns:
            Values passed to the Pydantic constructor
        """
        if not values.get("kubernetes_context"):
            # not likely, due to Pydantic validation, but mypy complains
            assert "uuid" in values
            values["kubernetes_context"] = cls._get_k3d_kubernetes_context(
                values["uuid"]
            )

        return values

    def get_kubernetes_contexts(self) -> Tuple[List[str], str]:
        """Get the list of configured Kubernetes contexts and the active context.

        Returns:
            A tuple containing the list of configured Kubernetes contexts and
            the active context.

        Raises:
            RuntimeError: if the Kubernetes configuration cannot be loaded
        """
        try:
            contexts, active_context = k8s_config.list_kube_config_contexts()
        except k8s_config.config_exception.ConfigException as e:
            raise RuntimeError(
                "Could not load the Kubernetes configuration"
            ) from e

        context_names = [c["name"] for c in contexts]
        active_context_name = active_context["name"]
        return context_names, active_context_name

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains a container registry.

        Also check that requirements are met for local components.

        Returns:
            A `StackValidator` instance.
        """

        def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]:

            container_registry = stack.container_registry

            # should not happen, because the stack validation takes care of
            # this, but just in case
            assert container_registry is not None

            contexts, active_context = self.get_kubernetes_contexts()

            if self.kubernetes_context not in contexts:
                if not self.is_local:
                    return False, (
                        f"Could not find a Kubernetes context named "
                        f"'{self.kubernetes_context}' in the local Kubernetes "
                        f"configuration. Please make sure that the Kubernetes "
                        f"cluster is running and that the kubeconfig file is "
                        f"configured correctly. To list all configured "
                        f"contexts, run:\n\n"
                        f"  `kubectl config get-contexts`\n"
                    )
            elif self.kubernetes_context != active_context:
                logger.warning(
                    f"The Kubernetes context '{self.kubernetes_context}' "
                    f"configured for the Kubeflow orchestrator is not the "
                    f"same as the active context in the local Kubernetes "
                    f"configuration. If this is not deliberate, you should "
                    f"update the orchestrator's `kubernetes_context` field by "
                    f"running:\n\n"
                    f"  `zenml orchestrator update {self.name} "
                    f"--kubernetes_context={active_context}`\n"
                    f"To list all configured contexts, run:\n\n"
                    f"  `kubectl config get-contexts`\n"
                    f"To set the active context to be the same as the one "
                    f"configured in the Kubeflow orchestrator and silence "
                    f"this warning, run:\n\n"
                    f"  `kubectl config use-context "
                    f"{self.kubernetes_context}`\n"
                )

            silence_local_validations_msg = (
                f"To silence this warning, set the "
                f"`skip_local_validations` attribute to True in the "
                f"orchestrator configuration by running:\n\n"
                f"  'zenml orchestrator update {self.name} "
                f"--skip_local_validations=True'\n"
            )

            if not self.skip_local_validations and not self.is_local:

                # if the orchestrator is not running in a local k3d cluster,
                # we cannot have any other local components in our stack,
                # because we cannot mount the local path into the container.
                # This may result in problems when running the pipeline, because
                # the local components will not be available inside the
                # Kubeflow containers.

                # go through all stack components and identify those that
                # advertise a local path where they persist information that
                # they need to be available when running pipelines.
                for stack_comp in stack.components.values():
                    local_path = stack_comp.local_path
                    if not local_path:
                        continue
                    return False, (
                        f"The Kubeflow orchestrator is configured to run "
                        f"pipelines in a remote Kubernetes cluster designated "
                        f"by the '{self.kubernetes_context}' configuration "
                        f"context, but the '{stack_comp.name}' "
                        f"{stack_comp.TYPE.value} is a local stack component "
                        f"and will not be available in the Kubeflow pipeline "
                        f"step.\nPlease ensure that you always use non-local "
                        f"stack components with a remote Kubeflow orchestrator, "
                        f"otherwise you may run into pipeline execution "
                        f"problems. You should use a flavor of "
                        f"{stack_comp.TYPE.value} other than "
                        f"'{stack_comp.FLAVOR}'.\n"
                        + silence_local_validations_msg
                    )

                # if the orchestrator is remote, the container registry must
                # also be remote.
                if container_registry.is_local:
                    return False, (
                        f"The Kubeflow orchestrator is configured to run "
                        f"pipelines in a remote Kubernetes cluster designated "
                        f"by the '{self.kubernetes_context}' configuration "
                        f"context, but the '{container_registry.name}' "
                        f"container registry URI '{container_registry.uri}' "
                        f"points to a local container registry. Please ensure "
                        f"that you always use non-local stack components with "
                        f"a remote Kubeflow orchestrator, otherwise you will "
                        f"run into problems. You should use a flavor of "
                        f"container registry other than "
                        f"'{container_registry.FLAVOR}'.\n"
                        + silence_local_validations_msg
                    )

            if not self.skip_local_validations and self.is_local:

                # if the orchestrator is local, the container registry must
                # also be local.
                if not container_registry.is_local:
                    return False, (
                        f"The Kubeflow orchestrator is configured to run "
                        f"pipelines in a local k3d Kubernetes cluster "
                        f"designated by the '{self.kubernetes_context}' "
                        f"configuration context, but the container registry "
                        f"URI '{container_registry.uri}' doesn't match the "
                        f"expected format 'localhost:$PORT'. "
                        f"The local Kubeflow orchestrator only works with a "
                        f"local container registry because it cannot "
                        f"currently authenticate to external container "
                        f"registries. You should use a flavor of container "
                        f"registry other than '{container_registry.FLAVOR}'.\n"
                        + silence_local_validations_msg
                    )

            return True, ""

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_validate_local_requirements,
        )

    def get_docker_image_name(self, pipeline_name: str) -> str:
        """Returns the full docker image name including registry and tag.

        Args:
            pipeline_name: The name of the pipeline.

        Returns:
            The full docker image name including registry and tag.
        """
        base_image_name = f"zenml-kubeflow:{pipeline_name}"
        container_registry = Repository().active_stack.container_registry

        if container_registry:
            registry_uri = container_registry.uri.rstrip("/")
            return f"{registry_uri}/{base_image_name}"
        else:
            return base_image_name

    @property
    def is_local(self) -> bool:
        """Checks if the KFP orchestrator is running locally.

        Returns:
            `True` if the KFP orchestrator is running locally (i.e. in
            the local k3d cluster managed by ZenML).
        """
        return self.kubernetes_context == self._get_k3d_kubernetes_context(
            self.uuid
        )

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory for all files concerning this orchestrator.

        Returns:
            Path to the root directory.
        """
        return os.path.join(
            io_utils.get_global_config_directory(),
            "kubeflow",
            str(self.uuid),
        )

    @property
    def pipeline_directory(self) -> str:
        """Returns path to a directory in which the kubeflow pipeline files are stored.

        Returns:
            Path to the pipeline directory.
        """
        return os.path.join(self.root_directory, "pipelines")

    def prepare_pipeline_deployment(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Builds a docker image for the current environment.

        This function also uploads it to a container registry if configured.

        Args:
            pipeline: The pipeline to be deployed.
            stack: The stack to be deployed.
            runtime_configuration: The runtime configuration to be used.
        """
        from zenml.utils import docker_utils

        image_name = self.get_docker_image_name(pipeline.name)

        requirements = {*stack.requirements(), *pipeline.requirements}

        logger.debug("Kubeflow docker container requirements: %s", requirements)

        docker_utils.build_docker_image(
            build_context_path=get_source_root_path(),
            image_name=image_name,
            dockerignore_path=pipeline.dockerignore_file,
            requirements=requirements,
            base_image=self.custom_docker_base_image_name,
            environment_vars=self._get_environment_vars_from_secrets(
                pipeline.secrets
            ),
        )

        assert stack.container_registry  # should never happen due to validation
        stack.container_registry.push_image(image_name)

        # Store the docker image digest in the runtime configuration so it gets
        # tracked in the ZenStore
        image_digest = docker_utils.get_image_digest(image_name) or image_name
        runtime_configuration["docker_image"] = image_digest

    @staticmethod
    def _configure_container_op(container_op: dsl.ContainerOp) -> None:
        """Makes changes in place to the configuration of the container op.

        Configures persistent mounted volumes for each stack component that
        writes to a local path. Adds some labels to the container_op and applies
        some functions to it.

        Args:
            container_op: The kubeflow container operation to configure.

        Raises:
            ValueError: If the local path is not in the global config directory.
        """
        # Path to a metadata file that will be displayed in the KFP UI
        # This metadata file needs to be in a mounted emptyDir to avoid
        # sporadic failures with the (not mature) PNS executor
        # See these links for more information about limitations of PNS +
        # security context:
        # https://www.kubeflow.org/docs/components/pipelines/installation/localcluster-deployment/#deploying-kubeflow-pipelines
        # https://argoproj.github.io/argo-workflows/empty-dir/
        # KFP will switch to the Emissary executor (soon), when this emptyDir
        # mount will not be necessary anymore, but for now it's still in alpha
        # status (https://www.kubeflow.org/docs/components/pipelines/installation/choose-executor/#emissary-executor)
        volumes: Dict[str, k8s_client.V1Volume] = {
            "/outputs": k8s_client.V1Volume(
                name="outputs", empty_dir=k8s_client.V1EmptyDirVolumeSource()
            ),
        }

        stack = Repository().active_stack
        global_cfg_dir = io_utils.get_global_config_directory()

        # go through all stack components and identify those that advertise
        # a local path where they persist information that they need to be
        # available when running pipelines. For those that do, mount them
        # into the Kubeflow container.
        has_local_repos = False
        for stack_comp in stack.components.values():
            local_path = stack_comp.local_path
            if not local_path:
                continue
            # double-check this convention, just in case it wasn't respected
            # as documented in `StackComponent.local_path`
            if not local_path.startswith(global_cfg_dir):
                raise ValueError(
                    f"Local path {local_path} for component {stack_comp.name} "
                    f"is not in the global config directory ({global_cfg_dir})."
                )
            has_local_repos = True
            host_path = k8s_client.V1HostPathVolumeSource(
                path=local_path, type="Directory"
            )
            volume_name = f"{stack_comp.TYPE.value}-{stack_comp.name}"
            volumes[local_path] = k8s_client.V1Volume(
                name=re.sub(r"[^0-9a-zA-Z-]+", "-", volume_name)
                .strip("-")
                .lower(),
                host_path=host_path,
            )
            logger.debug(
                "Adding host path volume for %s %s (path: %s) "
                "in kubeflow pipelines container.",
                stack_comp.TYPE.value,
                stack_comp.name,
                local_path,
            )
        container_op.add_pvolumes(volumes)

        if has_local_repos:
            if sys.platform == "win32":
                # File permissions are not checked on Windows. This if clause
                # prevents mypy from complaining about unused 'type: ignore'
                # statements
                pass
            else:
                # Run KFP containers in the context of the local UID/GID
                # to ensure that the artifact and metadata stores can be shared
                # with the local pipeline runs.
                container_op.container.security_context = (
                    k8s_client.V1SecurityContext(
                        run_as_user=os.getuid(),
                        run_as_group=os.getgid(),
                    )
                )
                logger.debug(
                    "Setting security context UID and GID to local user/group "
                    "in kubeflow pipelines container."
                )

        # Add environment variables for Azure Blob Storage to pod in case they
        # are set locally
        # TODO [ENG-699]: remove this as soon as we implement credential
        #  handling
        for key in [
            "AZURE_STORAGE_ACCOUNT_KEY",
            "AZURE_STORAGE_ACCOUNT_NAME",
            "AZURE_STORAGE_CONNECTION_STRING",
            "AZURE_STORAGE_SAS_TOKEN",
        ]:
            value = os.getenv(key)
            if value:
                container_op.container.add_env_variable(
                    k8s_client.V1EnvVar(name=key, value=value)
                )

        # Add some pod labels to the container_op
        for k, v in KFP_POD_LABELS.items():
            container_op.add_pod_label(k, v)

        # Mounts configmap containing Metadata gRPC server configuration.
        container_op.apply(utils.mount_config_map_op("metadata-grpc-configmap"))

    def prepare_or_run_pipeline(
        self,
        sorted_steps: List["BaseStep"],
        pipeline: "BasePipeline",
        pb2_pipeline: Pb2Pipeline,
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Creates a kfp yaml file.

        This file serves as an intermediary representation of the pipeline,
        which is then deployed to the Kubeflow Pipelines instance.

        How it works:
        -------------
        Before this method is called the `prepare_pipeline_deployment()`
        method builds a docker image that contains the code for the
        pipeline, all steps, and the context around these files.

        Based on this docker image a callable is created which builds
        container_ops for each step (`_construct_kfp_pipeline`).
        To do this the entrypoint of the docker image is configured to
        run the correct step within the docker image. The dependencies
        between these container_ops are then also configured onto each
        container_op by pointing at the downstream steps.

        This callable is then compiled into a kfp yaml file that is used as
        the intermediary representation of the kubeflow pipeline.

        This file, together with some metadata and runtime configuration, is
        then uploaded into the Kubeflow Pipelines cluster for execution.

        Args:
            sorted_steps: A list of steps sorted by their order in the
                pipeline.
            pipeline: The pipeline object.
            pb2_pipeline: The pipeline object in protobuf format.
            stack: The stack object.
            runtime_configuration: The runtime configuration object.

        Raises:
            RuntimeError: If you try to run the pipelines in a notebook environment.
        """
        # First check whether the code is running in a notebook
        if Environment.in_notebook():
            raise RuntimeError(
                "The Kubeflow orchestrator cannot run pipelines in a notebook "
                "environment. The reason is that it is non-trivial to create "
                "a Docker image of a notebook. Please consider refactoring "
                "your notebook cells into separate scripts in a Python module "
                "and run the code outside of a notebook when using this "
                "orchestrator."
            )

        image_name = self.get_docker_image_name(pipeline.name)
        image_name = get_image_digest(image_name) or image_name

        # Create a callable for future compilation into a dsl.Pipeline.
        def _construct_kfp_pipeline() -> None:
            """Create a container_op for each step.

            Each container_op contains the name of the docker image and
            configures the entrypoint of the docker image to run the step.

            Additionally, this gives each container_op information about its
            direct downstream steps.

            If this callable is passed to the `_create_and_write_workflow()`
            method of a KFPCompiler all dsl.ContainerOp instances will be
            automatically added to a singular dsl.Pipeline instance.
            """
            # Dictionary of container_ops indexed by the associated step name
            step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}

            for step in sorted_steps:
                # The command will be needed to eventually call the python step
                # within the docker container
                command = (
                    KubeflowEntrypointConfiguration.get_entrypoint_command()
                )

                # The arguments are passed to configure the entrypoint of the
                # docker container when the step is called.
                metadata_ui_path = "/outputs/mlpipeline-ui-metadata.json"
                arguments = (
                    KubeflowEntrypointConfiguration.get_entrypoint_arguments(
                        step=step,
                        pb2_pipeline=pb2_pipeline,
                        **{METADATA_UI_PATH_OPTION: metadata_ui_path},
                    )
                )

                # Create a container_op - the kubeflow equivalent of a step. It
                # contains the name of the step, the name of the docker image,
                # the command to use to run the step entrypoint
                # (e.g. `python -m zenml.entrypoints.step_entrypoint`)
                # and the arguments to be passed along with the command. Find
                # out more about how these arguments are parsed and used
                # in the base entrypoint `run()` method.
                container_op = dsl.ContainerOp(
                    name=step.name,
                    image=image_name,
                    command=command,
                    arguments=arguments,
                    output_artifact_paths={
                        "mlpipeline-ui-metadata": metadata_ui_path,
                    },
                )

                # Mounts persistent volumes, configmaps and adds labels to the
                # container op
                self._configure_container_op(container_op=container_op)

                # Find the upstream container ops of the current step and
                # configure the current container op to run after them
                upstream_step_names = self.get_upstream_step_names(
                    step=step, pb2_pipeline=pb2_pipeline
                )
                for upstream_step_name in upstream_step_names:
                    upstream_container_op = step_name_to_container_op[
                        upstream_step_name
                    ]
                    container_op.after(upstream_container_op)

                # Update dictionary of container ops with the current one
                step_name_to_container_op[step.name] = container_op

        # Get a filepath to use to save the finished yaml to
        assert runtime_configuration.run_name
        fileio.makedirs(self.pipeline_directory)
        pipeline_file_path = os.path.join(
            self.pipeline_directory, f"{runtime_configuration.run_name}.yaml"
        )

        # write the argo pipeline yaml
        KFPCompiler()._create_and_write_workflow(
            pipeline_func=_construct_kfp_pipeline,
            pipeline_name=pipeline.name,
            package_path=pipeline_file_path,
        )

        # using the kfp client uploads the pipeline to kubeflow pipelines and
        # runs it there
        self._upload_and_run_pipeline(
            pipeline_name=pipeline.name,
            pipeline_file_path=pipeline_file_path,
            runtime_configuration=runtime_configuration,
            enable_cache=pipeline.enable_cache,
        )

    def _upload_and_run_pipeline(
        self,
        pipeline_name: str,
        pipeline_file_path: str,
        runtime_configuration: "RuntimeConfiguration",
        enable_cache: bool,
    ) -> None:
        """Tries to upload and run a KFP pipeline.

        Args:
            pipeline_name: Name of the pipeline.
            pipeline_file_path: Path to the pipeline definition file.
            runtime_configuration: Runtime configuration of the pipeline run.
            enable_cache: Whether caching is enabled for this pipeline run.
        """
        try:
            logger.info(
                "Running in kubernetes context '%s'.",
                self.kubernetes_context,
            )

            # upload the pipeline to Kubeflow and start it
            client = kfp.Client(
                host=self.kubeflow_hostname,
                kube_context=self.kubernetes_context,
            )
            if runtime_configuration.schedule:
                try:
                    experiment = client.get_experiment(pipeline_name)
                    logger.info(
                        "A recurring run has already been created with this "
                        "pipeline. Creating new recurring run now.."
                    )
                except (ValueError, ApiException):
                    experiment = client.create_experiment(pipeline_name)
                    logger.info(
                        "Creating a new recurring run for pipeline '%s'.. ",
                        pipeline_name,
                    )
                logger.info(
                    "You can see all recurring runs under the '%s' experiment.'",
                    pipeline_name,
                )

                schedule = runtime_configuration.schedule
                interval_seconds = (
                    schedule.interval_second.seconds
                    if schedule.interval_second
                    else None
                )
                result = client.create_recurring_run(
                    experiment_id=experiment.id,
                    job_name=runtime_configuration.run_name,
                    pipeline_package_path=pipeline_file_path,
                    enable_caching=enable_cache,
                    cron_expression=schedule.cron_expression,
                    start_time=schedule.utc_start_time,
                    end_time=schedule.utc_end_time,
                    interval_second=interval_seconds,
                    no_catchup=not schedule.catchup,
                )

                logger.info("Started recurring run with ID '%s'.", result.id)
            else:
                logger.info(
                    "No schedule detected. Creating a one-off pipeline run.."
                )
                result = client.create_run_from_pipeline_package(
                    pipeline_file_path,
                    arguments={},
                    run_name=runtime_configuration.run_name,
                    enable_caching=enable_cache,
                )
                logger.info(
                    "Started one-off pipeline run with ID '%s'.", result.run_id
                )

                if self.synchronous:
                    # TODO [ENG-698]: Allow configuration of the timeout as a
                    #  runtime option
                    client.wait_for_run_completion(
                        run_id=result.run_id, timeout=1200
                    )
        except urllib3.exceptions.HTTPError as error:
            logger.warning(
                f"Failed to upload Kubeflow pipeline: %s. "
                f"Please make sure your kubernetes config is present and the "
                f"{self.kubernetes_context} kubernetes context is configured "
                f"correctly.",
                error,
            )

    @property
    def _pid_file_path(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(self.root_directory, "kubeflow_daemon.pid")

    @property
    def log_file(self) -> str:
        """Path of the daemon log file.

        Returns:
            Path of the daemon log file.
        """
        return os.path.join(self.root_directory, "kubeflow_daemon.log")

    @property
    def _k3d_cluster_name(self) -> str:
        """Returns the K3D cluster name.

        Returns:
            The K3D cluster name.
        """
        return self._get_k3d_cluster_name(self.uuid)

    def _get_k3d_registry_name(self, port: int) -> str:
        """Returns the K3D registry name.

        Args:
            port: Port of the registry.

        Returns:
            The registry name.
        """
        return f"k3d-zenml-kubeflow-registry.localhost:{port}"

    @property
    def _k3d_registry_config_path(self) -> str:
        """Returns the path to the K3D registry config yaml.

        Returns:
            str: Path to the K3D registry config yaml.
        """
        return os.path.join(self.root_directory, "k3d_registry.yaml")

    def _get_kfp_ui_daemon_port(self) -> int:
        """Port to use for the KFP UI daemon.

        Returns:
            Port to use for the KFP UI daemon.
        """
        port = self.kubeflow_pipelines_ui_port
        if port == DEFAULT_KFP_UI_PORT and not networking_utils.port_available(
            port
        ):
            # if the user didn't specify a specific port and the default
            # port is occupied, fall back to a random open port
            port = networking_utils.find_available_port()
        return port

    def list_manual_setup_steps(
        self, container_registry_name: str, container_registry_path: str
    ) -> None:
        """Logs manual steps needed to setup the Kubeflow local orchestrator.

        Args:
            container_registry_name: Name of the container registry.
            container_registry_path: Path to the container registry.
        """
        if not self.is_local:
            # Make sure we're not telling users to deploy Kubeflow on their
            # remote clusters
            logger.warning(
                "This Kubeflow orchestrator is configured to use a non-local "
                f"Kubernetes context {self.kubernetes_context}. Manually "
                f"deploying Kubeflow Pipelines is only possible for local "
                f"Kubeflow orchestrators."
            )
            return

        global_config_dir_path = io_utils.get_global_config_directory()
        kubeflow_commands = [
            f"> k3d cluster create {self._k3d_cluster_name} --image {local_deployment_utils.K3S_IMAGE_NAME} --registry-create {container_registry_name} --registry-config {container_registry_path} --volume {global_config_dir_path}:{global_config_dir_path}\n",
            f"> kubectl --context {self.kubernetes_context} apply -k github.com/kubeflow/pipelines/manifests/kustomize/cluster-scoped-resources?ref={KFP_VERSION}&timeout=5m",
            f"> kubectl --context {self.kubernetes_context} wait --timeout=60s --for condition=established crd/applications.app.k8s.io",
            f"> kubectl --context {self.kubernetes_context} apply -k github.com/kubeflow/pipelines/manifests/kustomize/env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=5m",
            f"> kubectl --context {self.kubernetes_context} --namespace kubeflow port-forward svc/ml-pipeline-ui {self.kubeflow_pipelines_ui_port}:80",
        ]

        logger.info(
            "If you wish to spin up this Kubeflow local orchestrator manually, "
            "please enter the following commands:\n"
        )
        logger.info("\n".join(kubeflow_commands))

    @property
    def is_provisioned(self) -> bool:
        """Returns if a local k3d cluster for this orchestrator exists.

        Returns:
            True if a local k3d cluster exists, False otherwise.
        """
        if not local_deployment_utils.check_prerequisites(
            skip_k3d=self.skip_cluster_provisioning or not self.is_local,
            skip_kubectl=self.skip_cluster_provisioning
            and self.skip_ui_daemon_provisioning,
        ):
            # if any prerequisites are missing there is certainly no
            # local deployment running
            return False

        return self.is_cluster_provisioned

    @property
    def is_running(self) -> bool:
        """Checks if the local k3d cluster and UI daemon are both running.

        Returns:
            True if the local k3d cluster and UI daemon for this orchestrator are both running.
        """
        return (
            self.is_provisioned
            and self.is_cluster_running
            and self.is_daemon_running
        )

    @property
    def is_suspended(self) -> bool:
        """Checks if the local k3d cluster and UI daemon are both stopped.

        Returns:
            True if the cluster and daemon for this orchestrator are both stopped, False otherwise.
        """
        return (
            self.is_provisioned
            and (self.skip_cluster_provisioning or not self.is_cluster_running)
            and (self.skip_ui_daemon_provisioning or not self.is_daemon_running)
        )

    @property
    def is_cluster_provisioned(self) -> bool:
        """Returns if the local k3d cluster for this orchestrator is provisioned.

        For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations,
        this always returns True.

        Returns:
            True if the local k3d cluster is provisioned, False otherwise.
        """
        if self.skip_cluster_provisioning or not self.is_local:
            return True
        return local_deployment_utils.k3d_cluster_exists(
            cluster_name=self._k3d_cluster_name
        )

    @property
    def is_cluster_running(self) -> bool:
        """Returns if the local k3d cluster for this orchestrator is running.

        For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations,
        this always returns True.

        Returns:
            True if the local k3d cluster is running, False otherwise.
        """
        if self.skip_cluster_provisioning or not self.is_local:
            return True
        return local_deployment_utils.k3d_cluster_running(
            cluster_name=self._k3d_cluster_name
        )

    @property
    def is_daemon_running(self) -> bool:
        """Returns if the local Kubeflow UI daemon for this orchestrator is running.

        Returns:
            True if the daemon is running, False otherwise.
        """
        if self.skip_ui_daemon_provisioning:
            return True

        if sys.platform != "win32":
            from zenml.utils.daemon import check_if_daemon_is_running

            return check_if_daemon_is_running(self._pid_file_path)
        else:
            return True

    def provision(self) -> None:
        """Provisions a local Kubeflow Pipelines deployment.

        Raises:
            ProvisioningError: If the provisioning fails.
        """
        if self.skip_cluster_provisioning:
            return

        if self.is_running:
            logger.info(
                "Found already existing local Kubeflow Pipelines deployment. "
                "If there are any issues with the existing deployment, please "
                "run 'zenml stack down --force' to delete it."
            )
            return

        if not local_deployment_utils.check_prerequisites():
            raise ProvisioningError(
                "Unable to provision local Kubeflow Pipelines deployment: "
                "Please install 'k3d' and 'kubectl' and try again."
            )

        container_registry = Repository().active_stack.container_registry

        # should not happen, because the stack validation takes care of this,
        # but just in case
        assert container_registry is not None

        fileio.makedirs(self.root_directory)

        if not self.is_local:
            # don't provision any resources if using a remote KFP installation
            return

        logger.info("Provisioning local Kubeflow Pipelines deployment...")

        container_registry_port = int(container_registry.uri.split(":")[-1])
        container_registry_name = self._get_k3d_registry_name(
            port=container_registry_port
        )
        local_deployment_utils.write_local_registry_yaml(
            yaml_path=self._k3d_registry_config_path,
            registry_name=container_registry_name,
            registry_uri=container_registry.uri,
        )

        try:
            local_deployment_utils.create_k3d_cluster(
                cluster_name=self._k3d_cluster_name,
                registry_name=container_registry_name,
                registry_config_path=self._k3d_registry_config_path,
            )
            kubernetes_context = self.kubernetes_context

            # will never happen, but mypy doesn't know that
            assert kubernetes_context is not None

            local_deployment_utils.deploy_kubeflow_pipelines(
                kubernetes_context=kubernetes_context
            )

            artifact_store = Repository().active_stack.artifact_store
            if isinstance(artifact_store, LocalArtifactStore):
                local_deployment_utils.add_hostpath_to_kubeflow_pipelines(
                    kubernetes_context=kubernetes_context,
                    local_path=artifact_store.path,
                )
        except Exception as e:
            logger.error(e)
            logger.error(
                "Unable to spin up local Kubeflow Pipelines deployment."
            )

            self.list_manual_setup_steps(
                container_registry_name, self._k3d_registry_config_path
            )
            self.deprovision()

    def deprovision(self) -> None:
        """Deprovisions a local Kubeflow Pipelines deployment."""
        if self.skip_cluster_provisioning:
            return

        if not self.skip_ui_daemon_provisioning and self.is_daemon_running:
            local_deployment_utils.stop_kfp_ui_daemon(
                pid_file_path=self._pid_file_path
            )

        if self.is_local:
            # only deprovision resources for local KFP installations
            local_deployment_utils.delete_k3d_cluster(
                cluster_name=self._k3d_cluster_name
            )

            logger.info("Local kubeflow pipelines deployment deprovisioned.")

        if fileio.exists(self.log_file):
            fileio.remove(self.log_file)

    def resume(self) -> None:
        """Resumes the local k3d cluster.

        Raises:
            ProvisioningError: If the k3d cluster is not provisioned.
        """
        if self.is_running:
            logger.info("Local kubeflow pipelines deployment already running.")
            return

        if not self.is_provisioned:
            raise ProvisioningError(
                "Unable to resume local kubeflow pipelines deployment: No "
                "resources provisioned for local deployment."
            )

        kubernetes_context = self.kubernetes_context

        # will never happen, but mypy doesn't know that
        assert kubernetes_context is not None

        if (
            not self.skip_cluster_provisioning
            and self.is_local
            and not self.is_cluster_running
        ):
            # don't resume any resources if using a remote KFP installation
            local_deployment_utils.start_k3d_cluster(
                cluster_name=self._k3d_cluster_name
            )

            local_deployment_utils.wait_until_kubeflow_pipelines_ready(
                kubernetes_context=kubernetes_context
            )

        if not self.is_daemon_running:
            local_deployment_utils.start_kfp_ui_daemon(
                pid_file_path=self._pid_file_path,
                log_file_path=self.log_file,
                port=self._get_kfp_ui_daemon_port(),
                kubernetes_context=kubernetes_context,
            )

    def suspend(self) -> None:
        """Suspends the local k3d cluster."""
        if not self.is_provisioned:
            logger.info("Local kubeflow pipelines deployment not provisioned.")
            return

        if not self.skip_ui_daemon_provisioning and self.is_daemon_running:
            local_deployment_utils.stop_kfp_ui_daemon(
                pid_file_path=self._pid_file_path
            )

        if (
            not self.skip_cluster_provisioning
            and self.is_local
            and self.is_cluster_running
        ):
            # don't suspend any resources if using a remote KFP installation
            local_deployment_utils.stop_k3d_cluster(
                cluster_name=self._k3d_cluster_name
            )

    def _get_environment_vars_from_secrets(
        self, secrets: List[str]
    ) -> Dict[str, str]:
        """Get key-value pairs from list of secrets provided by the user.

        Args:
            secrets: List of secrets provided by the user.

        Returns:
            A dictionary of key-value pairs.

        Raises:
            ProvisioningError: If the stack has no secrets manager.
        """
        environment_vars: Dict[str, str] = {}
        secret_manager = Repository().active_stack.secrets_manager
        if secrets and secret_manager:
            for secret in secrets:
                secret_schema = secret_manager.get_secret(secret)
                environment_vars.update(secret_schema.content)
        elif secrets and not secret_manager:
            raise ProvisioningError(
                "Unable to provision local Kubeflow Pipelines deployment: "
                f"You passed in the following secrets: { ', '.join(secrets) }, "
                "however, no secrets manager is registered for the current "
                "stack."
            )
        else:
            # No secrets provided by the user.
            pass
        return environment_vars
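
As a standalone illustration of the _get_environment_vars_from_secrets helper above, this sketch flattens secrets into environment variables using a made-up in-memory store in place of a real secrets manager:

# Hypothetical stand-in for a secrets manager; real key-value pairs come
# from `secret_manager.get_secret(...)`.
fake_secret_store = {
    "mysql_creds": {"DB_USER": "admin", "DB_PASSWORD": "example"},
}

environment_vars = {}
for secret_name in ["mysql_creds"]:
    environment_vars.update(fake_secret_store[secret_name])

print(environment_vars)  # {'DB_USER': 'admin', 'DB_PASSWORD': 'example'}
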
is_cluster_provisioned: bool property readonly

Returns if the local k3d cluster for this orchestrator is provisioned.

For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations, this always returns True.

Returns:

Type Description
bool

True if the local k3d cluster is provisioned, False otherwise.

is_cluster_running: bool property readonly

Returns if the local k3d cluster for this orchestrator is running.

For remote (i.e. not managed by ZenML) Kubeflow Pipelines installations, this always returns True.

Returns:

Type Description
bool

True if the local k3d cluster is running, False otherwise.

is_daemon_running: bool property readonly

Returns if the local Kubeflow UI daemon for this orchestrator is running.

Returns:

Type Description
bool

True if the daemon is running, False otherwise.
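
On POSIX systems this boils down to a PID-file liveness test. Here is a minimal sketch of that pattern; the real helper is zenml.utils.daemon.check_if_daemon_is_running, and this shows only the underlying idea, not its implementation:

import os

def pid_file_alive(pid_file_path: str) -> bool:
    # No PID file means no daemon was started (or it already cleaned up).
    if not os.path.exists(pid_file_path):
        return False
    with open(pid_file_path) as f:
        pid = int(f.read().strip())
    try:
        os.kill(pid, 0)  # signal 0 checks process existence without killing it
        return True
    except OSError:
        return False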

is_local: bool property readonly

Checks if the KFP orchestrator is running locally.

Returns:

Type Description
bool

True if the KFP orchestrator is running locally (i.e. in the local k3d cluster managed by ZenML).
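
The check relies purely on the naming convention for the ZenML-managed k3d cluster. A small runnable sketch of how that context name is derived (the UUID here is freshly generated, just for illustration):

from uuid import uuid4

uuid = uuid4()
cluster_name = f"zenml-kubeflow-{str(uuid)[:8]}"  # mirrors _get_k3d_cluster_name
local_context = f"k3d-{cluster_name}"             # mirrors _get_k3d_kubernetes_context
print(local_context)  # e.g. 'k3d-zenml-kubeflow-1a2b3c4d'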

is_provisioned: bool property readonly

Returns if a local k3d cluster for this orchestrator exists.

Returns:

Type Description
bool

True if a local k3d cluster exists, False otherwise.

is_running: bool property readonly

Checks if the local k3d cluster and UI daemon are both running.

Returns:

Type Description
bool

True if the local k3d cluster and UI daemon for this orchestrator are both running.

is_suspended: bool property readonly

Checks if the local k3d cluster and UI daemon are both stopped.

Returns:

Type Description
bool

True if the cluster and daemon for this orchestrator are both stopped, False otherwise.
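
Taken together, the lifecycle predicates above can be restated as a toy helper for a purely local setup with no skip_* flags set (a summary sketch, not library code):

def lifecycle_state(provisioned, cluster_running, daemon_running):
    # Mirrors is_provisioned / is_running / is_suspended above.
    if not provisioned:
        return "deprovisioned"
    if cluster_running and daemon_running:
        return "running"
    if not cluster_running and not daemon_running:
        return "suspended"
    return "partially running"  # e.g. cluster up but UI daemon stopped

print(lifecycle_state(True, False, False))  # 'suspended'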

log_file: str property readonly

Path of the daemon log file.

Returns:

Type Description
str

Path of the daemon log file.

pipeline_directory: str property readonly

Returns path to a directory in which the kubeflow pipeline files are stored.

Returns:

Type Description
str

Path to the pipeline directory.

root_directory: str property readonly

Returns path to the root directory for all files concerning this orchestrator.

Returns:

Type Description
str

Path to the root directory.
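
All orchestrator-local files hang off this root. Assuming a global config directory of ~/.config/zenml (the actual location is platform dependent), the layout can be sketched as:

import os

root = os.path.join(
    os.path.expanduser("~"), ".config", "zenml", "kubeflow", "<uuid>"
)
print(os.path.join(root, "pipelines"))            # pipeline_directory
print(os.path.join(root, "k3d_registry.yaml"))    # _k3d_registry_config_path
print(os.path.join(root, "kubeflow_daemon.pid"))  # _pid_file_path
print(os.path.join(root, "kubeflow_daemon.log"))  # log_file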

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates that the stack contains a container registry.

Also check that requirements are met for local components.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A StackValidator instance.
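
The custom validation function passed to the StackValidator follows a simple contract: it receives the stack and returns a (passed, error_message) tuple. A minimal sketch of that contract with a made-up rule:

from typing import Tuple

def _require_remote_registry(stack) -> Tuple[bool, str]:
    # Hypothetical rule: the container registry must not be local.
    registry = stack.container_registry
    if registry is not None and registry.is_local:
        return False, "A remote container registry is required."
    return True, ""

# StackValidator(required_components={StackComponentType.CONTAINER_REGISTRY},
#                custom_validation_function=_require_remote_registry)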

deprovision(self)

Deprovisions a local Kubeflow Pipelines deployment.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def deprovision(self) -> None:
    """Deprovisions a local Kubeflow Pipelines deployment."""
    if self.skip_cluster_provisioning:
        return

    if not self.skip_ui_daemon_provisioning and self.is_daemon_running:
        local_deployment_utils.stop_kfp_ui_daemon(
            pid_file_path=self._pid_file_path
        )

    if self.is_local:
        # only deprovision resources for local KFP installations
        local_deployment_utils.delete_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )

        logger.info("Local kubeflow pipelines deployment deprovisioned.")

    if fileio.exists(self.log_file):
        fileio.remove(self.log_file)
get_docker_image_name(self, pipeline_name)

Returns the full docker image name including registry and tag.

Parameters:

Name Type Description Default
pipeline_name str

The name of the pipeline.

required

Returns:

Type Description
str

The full docker image name including registry and tag.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Returns the full docker image name including registry and tag.

    Args:
        pipeline_name: The name of the pipeline.

    Returns:
        The full docker image name including registry and tag.
    """
    base_image_name = f"zenml-kubeflow:{pipeline_name}"
    container_registry = Repository().active_stack.container_registry

    if container_registry:
        registry_uri = container_registry.uri.rstrip("/")
        return f"{registry_uri}/{base_image_name}"
    else:
        return base_image_name
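
The naming scheme is easy to restate in isolation; the registry URI below is hypothetical:

from typing import Optional

def build_image_name(pipeline_name: str, registry_uri: Optional[str]) -> str:
    # Mirrors the logic above: prefix with the registry URI when present.
    base = f"zenml-kubeflow:{pipeline_name}"
    if registry_uri:
        return f"{registry_uri.rstrip('/')}/{base}"
    return base

print(build_image_name("mnist", "localhost:5000/"))  # localhost:5000/zenml-kubeflow:mnist
print(build_image_name("mnist", None))               # zenml-kubeflow:mnist
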
get_kubernetes_contexts(self)

Get the list of configured Kubernetes contexts and the active context.

Returns:

Type Description
Tuple[List[str], str]

A tuple containing the list of configured Kubernetes contexts and the active context.

Exceptions:

Type Description
RuntimeError

if the Kubernetes configuration cannot be loaded

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def get_kubernetes_contexts(self) -> Tuple[List[str], str]:
    """Get the list of configured Kubernetes contexts and the active context.

    Returns:
        A tuple containing the list of configured Kubernetes contexts and
        the active context.

    Raises:
        RuntimeError: if the Kubernetes configuration cannot be loaded
    """
    try:
        contexts, active_context = k8s_config.list_kube_config_contexts()
    except k8s_config.config_exception.ConfigException as e:
        raise RuntimeError(
            "Could not load the Kubernetes configuration"
        ) from e

    context_names = [c["name"] for c in contexts]
    active_context_name = active_context["name"]
    return context_names, active_context_name
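
For reference, the underlying kubernetes client call can be tried directly, assuming a kubeconfig exists at the default location:

from kubernetes import config as k8s_config

contexts, active = k8s_config.list_kube_config_contexts()
print([c["name"] for c in contexts])  # all configured context names
print(active["name"])                 # the currently active context
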
list_manual_setup_steps(self, container_registry_name, container_registry_path)

Logs manual steps needed to set up the Kubeflow local orchestrator.

Parameters:

Name Type Description Default
container_registry_name str

Name of the container registry.

required
container_registry_path str

Path to the container registry.

required
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def list_manual_setup_steps(
    self, container_registry_name: str, container_registry_path: str
) -> None:
    """Logs manual steps needed to setup the Kubeflow local orchestrator.

    Args:
        container_registry_name: Name of the container registry.
        container_registry_path: Path to the container registry.
    """
    if not self.is_local:
        # Make sure we're not telling users to deploy Kubeflow on their
        # remote clusters
        logger.warning(
            "This Kubeflow orchestrator is configured to use a non-local "
            f"Kubernetes context {self.kubernetes_context}. Manually "
            f"deploying Kubeflow Pipelines is only possible for local "
            f"Kubeflow orchestrators."
        )
        return

    global_config_dir_path = io_utils.get_global_config_directory()
    kubeflow_commands = [
        f"> k3d cluster create {self._k3d_cluster_name} --image {local_deployment_utils.K3S_IMAGE_NAME} --registry-create {container_registry_name} --registry-config {container_registry_path} --volume {global_config_dir_path}:{global_config_dir_path}\n",
        f"> kubectl --context {self.kubernetes_context} apply -k github.com/kubeflow/pipelines/manifests/kustomize/cluster-scoped-resources?ref={KFP_VERSION}&timeout=5m",
        f"> kubectl --context {self.kubernetes_context} wait --timeout=60s --for condition=established crd/applications.app.k8s.io",
        f"> kubectl --context {self.kubernetes_context} apply -k github.com/kubeflow/pipelines/manifests/kustomize/env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=5m",
        f"> kubectl --context {self.kubernetes_context} --namespace kubeflow port-forward svc/ml-pipeline-ui {self.kubeflow_pipelines_ui_port}:80",
    ]

    logger.info(
        "If you wish to spin up this Kubeflow local orchestrator manually, "
        "please enter the following commands:\n"
    )
    logger.info("\n".join(kubeflow_commands))
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)

Creates a kfp yaml file.

This file serves as an intermediary representation of the pipeline, which is then deployed to the Kubeflow Pipelines instance.

How it works:

Before this method is called, the prepare_pipeline_deployment() method builds a docker image that contains the code for the pipeline, all steps, and the context around these files.

Based on this docker image a callable is created which builds container_ops for each step (_construct_kfp_pipeline). To do this the entrypoint of the docker image is configured to run the correct step within the docker image. The dependencies between these container_ops are then also configured onto each container_op by pointing at the downstream steps.

This callable is then compiled into a kfp yaml file that is used as the intermediary representation of the kubeflow pipeline.

This file, together with some metadata and runtime configuration, is then uploaded into the Kubeflow Pipelines cluster for execution.

Parameters:

Name Type Description Default
sorted_steps List[BaseStep]

A list of steps sorted by their order in the pipeline.

required
pipeline BasePipeline

The pipeline object.

required
pb2_pipeline Pipeline

The pipeline object in protobuf format.

required
stack Stack

The stack object.

required
runtime_configuration RuntimeConfiguration

The runtime configuration object.

required

Exceptions:

Type Description
RuntimeError

If you try to run the pipelines in a notebook environment.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List["BaseStep"],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Creates a kfp yaml file.

    This file serves as an intermediary representation of the pipeline,
    which is then deployed to the Kubeflow Pipelines instance.

    How it works:
    -------------
    Before this method is called the `prepare_pipeline_deployment()`
    method builds a docker image that contains the code for the
    pipeline, all steps, and the context around these files.

    Based on this docker image a callable is created which builds
    container_ops for each step (`_construct_kfp_pipeline`).
    To do this the entrypoint of the docker image is configured to
    run the correct step within the docker image. The dependencies
    between these container_ops are then also configured onto each
    container_op by pointing at the downstream steps.

    This callable is then compiled into a kfp yaml file that is used as
    the intermediary representation of the kubeflow pipeline.

    This file, together with some metadata and runtime configuration, is
    then uploaded into the Kubeflow Pipelines cluster for execution.

    Args:
        sorted_steps: A list of steps sorted by their order in the
            pipeline.
        pipeline: The pipeline object.
        pb2_pipeline: The pipeline object in protobuf format.
        stack: The stack object.
        runtime_configuration: The runtime configuration object.

    Raises:
        RuntimeError: If you try to run the pipelines in a notebook environment.
    """
    # First check whether the code is running in a notebook
    if Environment.in_notebook():
        raise RuntimeError(
            "The Kubeflow orchestrator cannot run pipelines in a notebook "
            "environment. The reason is that it is non-trivial to create "
            "a Docker image of a notebook. Please consider refactoring "
            "your notebook cells into separate scripts in a Python module "
            "and run the code outside of a notebook when using this "
            "orchestrator."
        )

    image_name = self.get_docker_image_name(pipeline.name)
    image_name = get_image_digest(image_name) or image_name

    # Create a callable for future compilation into a dsl.Pipeline.
    def _construct_kfp_pipeline() -> None:
        """Create a container_op for each step.

        Each container_op contains the name of the docker image and
        configures the entrypoint of the docker image to run the step.

        Additionally, this gives each container_op information about its
        direct downstream steps.

        If this callable is passed to the `_create_and_write_workflow()`
        method of a KFPCompiler all dsl.ContainerOp instances will be
        automatically added to a singular dsl.Pipeline instance.
        """
        # Dictionary of container_ops indexed by the associated step name
        step_name_to_container_op: Dict[str, dsl.ContainerOp] = {}

        for step in sorted_steps:
            # The command will be needed to eventually call the python step
            # within the docker container
            command = (
                KubeflowEntrypointConfiguration.get_entrypoint_command()
            )

            # The arguments are passed to configure the entrypoint of the
            # docker container when the step is called.
            metadata_ui_path = "/outputs/mlpipeline-ui-metadata.json"
            arguments = (
                KubeflowEntrypointConfiguration.get_entrypoint_arguments(
                    step=step,
                    pb2_pipeline=pb2_pipeline,
                    **{METADATA_UI_PATH_OPTION: metadata_ui_path},
                )
            )

            # Create a container_op - the kubeflow equivalent of a step. It
            # contains the name of the step, the name of the docker image,
            # the command to use to run the step entrypoint
            # (e.g. `python -m zenml.entrypoints.step_entrypoint`)
            # and the arguments to be passed along with the command. Find
            # out more about how these arguments are parsed and used
            # in the base entrypoint `run()` method.
            container_op = dsl.ContainerOp(
                name=step.name,
                image=image_name,
                command=command,
                arguments=arguments,
                output_artifact_paths={
                    "mlpipeline-ui-metadata": metadata_ui_path,
                },
            )

            # Mounts persistent volumes, configmaps and adds labels to the
            # container op
            self._configure_container_op(container_op=container_op)

            # Find the upstream container ops of the current step and
            # configure the current container op to run after them
            upstream_step_names = self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            )
            for upstream_step_name in upstream_step_names:
                upstream_container_op = step_name_to_container_op[
                    upstream_step_name
                ]
                container_op.after(upstream_container_op)

            # Update dictionary of container ops with the current one
            step_name_to_container_op[step.name] = container_op

    # Get a filepath to which the finished yaml will be saved
    assert runtime_configuration.run_name
    fileio.makedirs(self.pipeline_directory)
    pipeline_file_path = os.path.join(
        self.pipeline_directory, f"{runtime_configuration.run_name}.yaml"
    )

    # write the argo pipeline yaml
    KFPCompiler()._create_and_write_workflow(
        pipeline_func=_construct_kfp_pipeline,
        pipeline_name=pipeline.name,
        package_path=pipeline_file_path,
    )

    # Using the KFP client, upload the pipeline to Kubeflow Pipelines and
    # run it there
    self._upload_and_run_pipeline(
        pipeline_name=pipeline.name,
        pipeline_file_path=pipeline_file_path,
        runtime_configuration=runtime_configuration,
        enable_cache=pipeline.enable_cache,
    )
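
To make the container_op wiring above concrete, here is a minimal, self-contained sketch (not ZenML code; the image name, commands, and step names are illustrative placeholders) of two chained dsl.ContainerOp instances compiled to YAML with the standard KFP SDK v1 compiler:

import kfp
from kfp import dsl


@dsl.pipeline(name="two-step-pipeline")
def two_step_pipeline() -> None:
    """Two chained container ops, wired like the upstream steps above."""
    step_1 = dsl.ContainerOp(
        name="step_1",
        image="my-registry/my-image:latest",  # placeholder image
        command=["python", "-m", "my_package.step_entrypoint"],  # placeholder
        arguments=["--step_name", "step_1"],
    )
    step_2 = dsl.ContainerOp(
        name="step_2",
        image="my-registry/my-image:latest",
        command=["python", "-m", "my_package.step_entrypoint"],
        arguments=["--step_name", "step_2"],
    )
    # Equivalent of the `container_op.after(upstream_container_op)` call above.
    step_2.after(step_1)


if __name__ == "__main__":
    # Writes the workflow YAML, analogous to `_create_and_write_workflow`.
    kfp.compiler.Compiler().compile(two_step_pipeline, "pipeline.yaml")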
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)

Builds a docker image for the current environment.

This function also uploads it to a container registry if configured.

Parameters:

    pipeline (BasePipeline): The pipeline to be deployed. Required.
    stack (Stack): The stack to be deployed. Required.
    runtime_configuration (RuntimeConfiguration): The runtime configuration to be used. Required.
Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Builds a docker image for the current environment.

    This function also uploads it to a container registry if configured.

    Args:
        pipeline: The pipeline to be deployed.
        stack: The stack to be deployed.
        runtime_configuration: The runtime configuration to be used.
    """
    from zenml.utils import docker_utils

    image_name = self.get_docker_image_name(pipeline.name)

    requirements = {*stack.requirements(), *pipeline.requirements}

    logger.debug("Kubeflow docker container requirements: %s", requirements)

    docker_utils.build_docker_image(
        build_context_path=get_source_root_path(),
        image_name=image_name,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=requirements,
        base_image=self.custom_docker_base_image_name,
        environment_vars=self._get_environment_vars_from_secrets(
            pipeline.secrets
        ),
    )

    assert stack.container_registry  # should never happen due to validation
    stack.container_registry.push_image(image_name)

    # Store the docker image digest in the runtime configuration so it gets
    # tracked in the ZenStore
    image_digest = docker_utils.get_image_digest(image_name) or image_name
    runtime_configuration["docker_image"] = image_digest
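
As a rough sketch of the same build-and-push flow, assuming the `docker` Python SDK rather than ZenML's `docker_utils` (the registry address and tag are placeholders):

import docker

client = docker.from_env()

# Build an image from the current source root (placeholder tag).
image, _build_logs = client.images.build(
    path=".", tag="localhost:5000/zenml-pipeline:latest"
)

# Push it to the configured container registry.
client.images.push("localhost:5000/zenml-pipeline", tag="latest")

# The local image ID; a registry digest would be tracked similarly.
print(image.attrs["Id"])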
provision(self)

Provisions a local Kubeflow Pipelines deployment.

Exceptions:

    ProvisioningError: If the provisioning fails.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def provision(self) -> None:
    """Provisions a local Kubeflow Pipelines deployment.

    Raises:
        ProvisioningError: If the provisioning fails.
    """
    if self.skip_cluster_provisioning:
        return

    if self.is_running:
        logger.info(
            "Found already existing local Kubeflow Pipelines deployment. "
            "If there are any issues with the existing deployment, please "
            "run 'zenml stack down --force' to delete it."
        )
        return

    if not local_deployment_utils.check_prerequisites():
        raise ProvisioningError(
            "Unable to provision local Kubeflow Pipelines deployment: "
            "Please install 'k3d' and 'kubectl' and try again."
        )

    container_registry = Repository().active_stack.container_registry

    # should not happen, because the stack validation takes care of this,
    # but just in case
    assert container_registry is not None

    fileio.makedirs(self.root_directory)

    if not self.is_local:
        # don't provision any resources if using a remote KFP installation
        return

    logger.info("Provisioning local Kubeflow Pipelines deployment...")

    container_registry_port = int(container_registry.uri.split(":")[-1])
    container_registry_name = self._get_k3d_registry_name(
        port=container_registry_port
    )
    local_deployment_utils.write_local_registry_yaml(
        yaml_path=self._k3d_registry_config_path,
        registry_name=container_registry_name,
        registry_uri=container_registry.uri,
    )

    try:
        local_deployment_utils.create_k3d_cluster(
            cluster_name=self._k3d_cluster_name,
            registry_name=container_registry_name,
            registry_config_path=self._k3d_registry_config_path,
        )
        kubernetes_context = self.kubernetes_context

        # will never happen, but mypy doesn't know that
        assert kubernetes_context is not None

        local_deployment_utils.deploy_kubeflow_pipelines(
            kubernetes_context=kubernetes_context
        )

        artifact_store = Repository().active_stack.artifact_store
        if isinstance(artifact_store, LocalArtifactStore):
            local_deployment_utils.add_hostpath_to_kubeflow_pipelines(
                kubernetes_context=kubernetes_context,
                local_path=artifact_store.path,
            )
    except Exception as e:
        logger.error(e)
        logger.error(
            "Unable to spin up local Kubeflow Pipelines deployment."
        )

        self.list_manual_setup_steps(
            container_registry_name, self._k3d_registry_config_path
        )
        self.deprovision()
resume(self)

Resumes the local k3d cluster.

Exceptions:

    ProvisioningError: If the k3d cluster is not provisioned.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def resume(self) -> None:
    """Resumes the local k3d cluster.

    Raises:
        ProvisioningError: If the k3d cluster is not provisioned.
    """
    if self.is_running:
        logger.info("Local kubeflow pipelines deployment already running.")
        return

    if not self.is_provisioned:
        raise ProvisioningError(
            "Unable to resume local kubeflow pipelines deployment: No "
            "resources provisioned for local deployment."
        )

    kubernetes_context = self.kubernetes_context

    # will never happen, but mypy doesn't know that
    assert kubernetes_context is not None

    if (
        not self.skip_cluster_provisioning
        and self.is_local
        and not self.is_cluster_running
    ):
        # don't resume any resources if using a remote KFP installation
        local_deployment_utils.start_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )

        local_deployment_utils.wait_until_kubeflow_pipelines_ready(
            kubernetes_context=kubernetes_context
        )

    if not self.is_daemon_running:
        local_deployment_utils.start_kfp_ui_daemon(
            pid_file_path=self._pid_file_path,
            log_file_path=self.log_file,
            port=self._get_kfp_ui_daemon_port(),
            kubernetes_context=kubernetes_context,
        )
set_default_kubernetes_context(values) classmethod

Pydantic root_validator.

This sets the default kubernetes_context value to the value that is used to create the locally managed k3d cluster, if not explicitly set.

Parameters:

    values (Dict[str, Any]): Values passed to the object constructor. Required.

Returns:

    Dict[str, Any]: Values passed to the Pydantic constructor.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
@root_validator(skip_on_failure=True)
def set_default_kubernetes_context(
    cls, values: Dict[str, Any]
) -> Dict[str, Any]:
    """Pydantic root_validator.

    This sets the default `kubernetes_context` value to the value that is
    used to create the locally managed k3d cluster, if not explicitly set.

    Args:
        values: Values passed to the object constructor

    Returns:
        Values passed to the Pydantic constructor
    """
    if not values.get("kubernetes_context"):
        # not likely, due to Pydantic validation, but mypy complains
        assert "uuid" in values
        values["kubernetes_context"] = cls._get_k3d_kubernetes_context(
            values["uuid"]
        )

    return values
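
The same default-filling pattern in isolation, as a hedged sketch (the k3d context naming scheme shown here is hypothetical, not necessarily ZenML's):

from typing import Any, Dict, Optional

from pydantic import BaseModel, root_validator


class Component(BaseModel):
    uuid: str
    kubernetes_context: Optional[str] = None

    @root_validator(skip_on_failure=True)
    def set_default_context(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Fill `kubernetes_context` from `uuid` if it wasn't set explicitly."""
        if not values.get("kubernetes_context"):
            # Hypothetical naming scheme for the locally managed k3d cluster.
            values["kubernetes_context"] = f"k3d-zenml-kubeflow-{values['uuid']}"
        return values


print(Component(uuid="1234").kubernetes_context)  # k3d-zenml-kubeflow-1234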
suspend(self)

Suspends the local k3d cluster.

Source code in zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py
def suspend(self) -> None:
    """Suspends the local k3d cluster."""
    if not self.is_provisioned:
        logger.info("Local kubeflow pipelines deployment not provisioned.")
        return

    if not self.skip_ui_daemon_provisioning and self.is_daemon_running:
        local_deployment_utils.stop_kfp_ui_daemon(
            pid_file_path=self._pid_file_path
        )

    if (
        not self.skip_cluster_provisioning
        and self.is_local
        and self.is_cluster_running
    ):
        # don't suspend any resources if using a remote KFP installation
        local_deployment_utils.stop_k3d_cluster(
            cluster_name=self._k3d_cluster_name
        )
local_deployment_utils

Utilities for the local Kubeflow Pipelines deployment.

add_hostpath_to_kubeflow_pipelines(kubernetes_context, local_path)

Patches the Kubeflow Pipelines deployment to mount a local folder.

This folder serves as a hostpath for visualization purposes.

This function reconfigures the Kubeflow pipelines deployment to use a shared local folder to support loading the TensorBoard viewer and other pipeline visualization results from a local artifact store, as described here:

https://github.com/kubeflow/pipelines/blob/master/docs/config/volume-support.md

Parameters:

    kubernetes_context (str): The kubernetes context on which Kubeflow Pipelines should be patched. Required.
    local_path (str): The path to the local folder to mount as a hostpath. Required.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def add_hostpath_to_kubeflow_pipelines(
    kubernetes_context: str, local_path: str
) -> None:
    """Patches the Kubeflow Pipelines deployment to mount a local folder.

    This folder serves as a hostpath for visualization purposes.

    This function reconfigures the Kubeflow pipelines deployment to use a
    shared local folder to support loading the TensorBoard viewer and other
    pipeline visualization results from a local artifact store, as described
    here:

    https://github.com/kubeflow/pipelines/blob/master/docs/config/volume-support.md

    Args:
        kubernetes_context: The kubernetes context on which Kubeflow Pipelines
            should be patched.
        local_path: The path to the local folder to mount as a hostpath.
    """
    logger.info("Patching Kubeflow Pipelines to mount a local folder.")

    pod_template = {
        "spec": {
            "serviceAccountName": "kubeflow-pipelines-viewer",
            "containers": [
                {
                    "volumeMounts": [
                        {
                            "mountPath": local_path,
                            "name": "local-artifact-store",
                        }
                    ]
                }
            ],
            "volumes": [
                {
                    "hostPath": {
                        "path": local_path,
                        "type": "Directory",
                    },
                    "name": "local-artifact-store",
                }
            ],
        }
    }
    pod_template_json = json.dumps(pod_template, indent=2)
    config_map_data = {"data": {"viewer-pod-template.json": pod_template_json}}
    config_map_data_json = json.dumps(config_map_data, indent=2)

    logger.debug(
        "Adding host path volume for local path `%s` to kubeflow pipeline"
        "viewer pod template configuration.",
        local_path,
    )
    subprocess.check_call(
        [
            "kubectl",
            "--context",
            kubernetes_context,
            "-n",
            "kubeflow",
            "patch",
            "configmap/ml-pipeline-ui-configmap",
            "--type",
            "merge",
            "-p",
            config_map_data_json,
        ]
    )

    deployment_patch = {
        "spec": {
            "template": {
                "spec": {
                    "containers": [
                        {
                            "name": "ml-pipeline-ui",
                            "volumeMounts": [
                                {
                                    "mountPath": local_path,
                                    "name": "local-artifact-store",
                                }
                            ],
                        }
                    ],
                    "volumes": [
                        {
                            "hostPath": {
                                "path": local_path,
                                "type": "Directory",
                            },
                            "name": "local-artifact-store",
                        }
                    ],
                }
            }
        }
    }
    deployment_patch_json = json.dumps(deployment_patch, indent=2)

    logger.debug(
        "Adding host path volume for local path `%s` to the kubeflow UI",
        local_path,
    )
    subprocess.check_call(
        [
            "kubectl",
            "--context",
            kubernetes_context,
            "-n",
            "kubeflow",
            "patch",
            "deployment/ml-pipeline-ui",
            "--type",
            "strategic",
            "-p",
            deployment_patch_json,
        ]
    )
    wait_until_kubeflow_pipelines_ready(kubernetes_context=kubernetes_context)

    logger.info("Finished patching Kubeflow Pipelines setup.")
check_prerequisites(skip_k3d=False, skip_kubectl=False)

Checks that the prerequisites for a local Kubeflow Pipelines deployment (the k3d and kubectl commands) are installed.

Parameters:

    skip_k3d (bool): Whether to skip the check for the k3d command. Defaults to False.
    skip_kubectl (bool): Whether to skip the check for the kubectl command. Defaults to False.

Returns:

    bool: Whether all prerequisites are installed.

Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def check_prerequisites(
    skip_k3d: bool = False, skip_kubectl: bool = False
) -> bool:
    """Checks prerequisites for a local kubeflow pipelines deployment.

    It makes sure they are installed.

    Args:
        skip_k3d: Whether to skip the check for the k3d command.
        skip_kubectl: Whether to skip the check for the kubectl command.

    Returns:
        Whether all prerequisites are installed.
    """
    k3d_installed = skip_k3d or shutil.which("k3d") is not None
    kubectl_installed = skip_kubectl or shutil.which("kubectl") is not None
    logger.debug(
        "Local kubeflow deployment prerequisites: K3D - %s, Kubectl - %s",
        k3d_installed,
        kubectl_installed,
    )
    return k3d_installed and kubectl_installed
create_k3d_cluster(cluster_name, registry_name, registry_config_path)

Creates a K3D cluster.

Parameters:

    cluster_name (str): Name of the cluster to create. Required.
    registry_name (str): Name of the registry to create for this cluster. Required.
    registry_config_path (str): Path to the registry config file. Required.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def create_k3d_cluster(
    cluster_name: str, registry_name: str, registry_config_path: str
) -> None:
    """Creates a K3D cluster.

    Args:
        cluster_name: Name of the cluster to create.
        registry_name: Name of the registry to create for this cluster.
        registry_config_path: Path to the registry config file.
    """
    logger.info("Creating local K3D cluster '%s'.", cluster_name)
    global_config_dir_path = io_utils.get_global_config_directory()
    subprocess.check_call(
        [
            "k3d",
            "cluster",
            "create",
            cluster_name,
            "--image",
            K3S_IMAGE_NAME,
            "--registry-create",
            registry_name,
            "--registry-config",
            registry_config_path,
            "--volume",
            f"{global_config_dir_path}:{global_config_dir_path}",
        ]
    )
    logger.info("Finished K3D cluster creation.")
delete_k3d_cluster(cluster_name)

Deletes a K3D cluster with the given name.

Parameters:

    cluster_name (str): Name of the cluster to delete. Required.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def delete_k3d_cluster(cluster_name: str) -> None:
    """Deletes a K3D cluster with the given name.

    Args:
        cluster_name: Name of the cluster to delete.
    """
    subprocess.check_call(["k3d", "cluster", "delete", cluster_name])
    logger.info("Deleted local k3d cluster '%s'.", cluster_name)
deploy_kubeflow_pipelines(kubernetes_context)

Deploys Kubeflow Pipelines.

Parameters:

    kubernetes_context (str): The kubernetes context on which Kubeflow Pipelines should be deployed. Required.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def deploy_kubeflow_pipelines(kubernetes_context: str) -> None:
    """Deploys Kubeflow Pipelines.

    Args:
        kubernetes_context: The kubernetes context on which Kubeflow Pipelines
            should be deployed.
    """
    logger.info("Deploying Kubeflow Pipelines.")
    subprocess.check_call(
        [
            "kubectl",
            "--context",
            kubernetes_context,
            "apply",
            "-k",
            f"github.com/kubeflow/pipelines/manifests/kustomize/cluster-scoped-resources?ref={KFP_VERSION}&timeout=5m",
        ]
    )
    subprocess.check_call(
        [
            "kubectl",
            "--context",
            kubernetes_context,
            "wait",
            "--timeout=60s",
            "--for",
            "condition=established",
            "crd/applications.app.k8s.io",
        ]
    )
    subprocess.check_call(
        [
            "kubectl",
            "--context",
            kubernetes_context,
            "apply",
            "-k",
            f"github.com/kubeflow/pipelines/manifests/kustomize/env/platform-agnostic-pns?ref={KFP_VERSION}&timeout=5m",
        ]
    )

    wait_until_kubeflow_pipelines_ready(kubernetes_context=kubernetes_context)
    logger.info("Finished Kubeflow Pipelines setup.")
k3d_cluster_exists(cluster_name)

Checks whether there exists a K3D cluster with the given name.

Parameters:

    cluster_name (str): Name of the cluster to check. Required.

Returns:

    bool: Whether the cluster exists.

Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def k3d_cluster_exists(cluster_name: str) -> bool:
    """Checks whether there exists a K3D cluster with the given name.

    Args:
        cluster_name: Name of the cluster to check.

    Returns:
        Whether the cluster exists.
    """
    output = subprocess.check_output(
        ["k3d", "cluster", "list", "--output", "json"]
    )
    clusters = json.loads(output)
    for cluster in clusters:
        if cluster["name"] == cluster_name:
            return True
    return False
k3d_cluster_running(cluster_name)

Checks whether the K3D cluster with the given name is running.

Parameters:

    cluster_name (str): Name of the cluster to check. Required.

Returns:

    bool: Whether the cluster is running.

Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def k3d_cluster_running(cluster_name: str) -> bool:
    """Checks whether the K3D cluster with the given name is running.

    Args:
        cluster_name: Name of the cluster to check.

    Returns:
        Whether the cluster is running.
    """
    output = subprocess.check_output(
        ["k3d", "cluster", "list", "--output", "json"]
    )
    clusters = json.loads(output)
    for cluster in clusters:
        if cluster["name"] == cluster_name:
            server_count: int = cluster["serversCount"]
            servers_running: int = cluster["serversRunning"]
            return servers_running == server_count
    return False
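
For illustration, the JSON payload these two helpers parse looks roughly like this (fields trimmed to the ones actually read; the values are made up):

import json

# Made-up `k3d cluster list --output json` payload.
output = json.dumps(
    [{"name": "zenml-cluster", "serversCount": 1, "serversRunning": 1}]
)

clusters = json.loads(output)
running = any(
    c["name"] == "zenml-cluster" and c["serversRunning"] == c["serversCount"]
    for c in clusters
)
print(running)  # True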
kubeflow_pipelines_ready(kubernetes_context)

Returns whether all Kubeflow Pipelines pods are ready.

Parameters:

    kubernetes_context (str): The kubernetes context in which the pods should be checked. Required.

Returns:

    bool: Whether all Kubeflow Pipelines pods are ready.

Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def kubeflow_pipelines_ready(kubernetes_context: str) -> bool:
    """Returns whether all Kubeflow Pipelines pods are ready.

    Args:
        kubernetes_context: The kubernetes context in which the pods
            should be checked.

    Returns:
        Whether all Kubeflow Pipelines pods are ready.
    """
    try:
        subprocess.check_call(
            [
                "kubectl",
                "--context",
                kubernetes_context,
                "--namespace",
                "kubeflow",
                "wait",
                "--for",
                "condition=ready",
                "--timeout=0s",
                "pods",
                "-l",
                "application-crd-id=kubeflow-pipelines",
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        return True
    except subprocess.CalledProcessError:
        return False
start_k3d_cluster(cluster_name)

Starts a K3D cluster with the given name.

Parameters:

    cluster_name (str): Name of the cluster to start. Required.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def start_k3d_cluster(cluster_name: str) -> None:
    """Starts a K3D cluster with the given name.

    Args:
        cluster_name: Name of the cluster to start.
    """
    subprocess.check_call(["k3d", "cluster", "start", cluster_name])
    logger.info("Started local k3d cluster '%s'.", cluster_name)
start_kfp_ui_daemon(pid_file_path, log_file_path, port, kubernetes_context)

Starts a daemon process that forwards ports.

This is so the Kubeflow Pipelines UI is accessible in the browser.

Parameters:

    pid_file_path (str): Path where the file with the daemon's process ID should be written. Required.
    log_file_path (str): Path to a file where the daemon logs should be written. Required.
    port (int): Port on which the UI should be accessible. Required.
    kubernetes_context (str): The kubernetes context for the cluster where Kubeflow Pipelines is running. Required.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def start_kfp_ui_daemon(
    pid_file_path: str,
    log_file_path: str,
    port: int,
    kubernetes_context: str,
) -> None:
    """Starts a daemon process that forwards ports.

    This is so the Kubeflow Pipelines UI is accessible in the browser.

    Args:
        pid_file_path: Path where the file with the daemon's process ID should
            be written.
        log_file_path: Path to a file where the daemon logs should be written.
        port: Port on which the UI should be accessible.
        kubernetes_context: The kubernetes context for the cluster where
            Kubeflow Pipelines is running.
    """
    command = [
        "kubectl",
        "--context",
        kubernetes_context,
        "--namespace",
        "kubeflow",
        "port-forward",
        "svc/ml-pipeline-ui",
        f"{port}:80",
    ]

    if not networking_utils.port_available(port):
        modified_command = command.copy()
        modified_command[-1] = "PORT:80"
        logger.warning(
            "Unable to port-forward Kubeflow Pipelines UI to local port %d "
            "because the port is occupied. In order to access the Kubeflow "
            "Pipelines UI at http://localhost:PORT/, please run '%s' in a "
            "separate command line shell (replace PORT with a free port of "
            "your choice).",
            port,
            " ".join(modified_command),
        )
    elif sys.platform == "win32":
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Kubeflow Pipelines UI at "
            "http://localhost:%d/, please run '%s' in a separate command "
            "line shell.",
            port,
            " ".join(command),
        )
    else:
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Port-forwards the Kubeflow Pipelines UI pod."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function, pid_file=pid_file_path, log_file=log_file_path
        )
        logger.info(
            "Started Kubeflow Pipelines UI daemon (check the daemon logs at %s "
            "in case you're not able to view the UI). The Kubeflow Pipelines "
            "UI should now be accessible at http://localhost:%d/.",
            log_file_path,
            port,
        )
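
The port check used above lives in ZenML's networking_utils; a minimal standard-library sketch of the same idea is to try binding the port locally:

import socket


def port_available(port: int) -> bool:
    """Return True if nothing is listening on the given local port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            s.bind(("127.0.0.1", port))
            return True
        except OSError:
            return False


print(port_available(8080))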
stop_k3d_cluster(cluster_name)

Stops a K3D cluster with the given name.

Parameters:

    cluster_name (str): Name of the cluster to stop. Required.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def stop_k3d_cluster(cluster_name: str) -> None:
    """Stops a K3D cluster with the given name.

    Args:
        cluster_name: Name of the cluster to stop.
    """
    subprocess.check_call(["k3d", "cluster", "stop", cluster_name])
    logger.info("Stopped local k3d cluster '%s'.", cluster_name)
stop_kfp_ui_daemon(pid_file_path)

Stops the KFP UI daemon process if it is running.

Parameters:

    pid_file_path (str): Path to the file with the daemon's process ID. Required.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def stop_kfp_ui_daemon(pid_file_path: str) -> None:
    """Stops the KFP UI daemon process if it is running.

    Args:
        pid_file_path: Path to the file with the daemon's process ID.
    """
    if fileio.exists(pid_file_path):
        if sys.platform == "win32":
            # Daemon functionality is not supported on Windows, so the PID
            # file won't exist. This if clause exists just for mypy to not
            # complain about missing functions
            pass
        else:
            from zenml.utils import daemon

            daemon.stop_daemon(pid_file_path)
            fileio.remove(pid_file_path)
            logger.info("Stopped Kubeflow Pipelines UI daemon.")
wait_until_kubeflow_pipelines_ready(kubernetes_context)

Waits until all Kubeflow Pipelines pods are ready.

Parameters:

    kubernetes_context (str): The kubernetes context in which the pods should be checked. Required.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def wait_until_kubeflow_pipelines_ready(kubernetes_context: str) -> None:
    """Waits until all Kubeflow Pipelines pods are ready.

    Args:
        kubernetes_context: The kubernetes context in which the pods
            should be checked.
    """
    logger.info(
        "Waiting for all Kubeflow Pipelines pods to be ready (this might "
        "take a few minutes)."
    )
    while True:
        logger.info("Current pod status:")
        subprocess.check_call(
            [
                "kubectl",
                "--context",
                kubernetes_context,
                "--namespace",
                "kubeflow",
                "get",
                "pods",
            ]
        )
        if kubeflow_pipelines_ready(kubernetes_context=kubernetes_context):
            break

        logger.info("One or more pods not ready yet, waiting for 30 seconds...")
        time.sleep(30)
write_local_registry_yaml(yaml_path, registry_name, registry_uri)

Writes a K3D registry config file.

Parameters:

    yaml_path (str): Path where the config file should be written to. Required.
    registry_name (str): Name of the registry. Required.
    registry_uri (str): URI of the registry. Required.
Source code in zenml/integrations/kubeflow/orchestrators/local_deployment_utils.py
def write_local_registry_yaml(
    yaml_path: str, registry_name: str, registry_uri: str
) -> None:
    """Writes a K3D registry config file.

    Args:
        yaml_path: Path where the config file should be written to.
        registry_name: Name of the registry.
        registry_uri: URI of the registry.
    """
    yaml_content = {
        "mirrors": {registry_uri: {"endpoint": [f"http://{registry_name}"]}}
    }
    yaml_utils.write_yaml(yaml_path, yaml_content)
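
A usage sketch with illustrative names, followed by the file it produces:

write_local_registry_yaml(
    yaml_path="registry.yaml",
    registry_name="k3d-zenml-registry:5000",  # illustrative
    registry_uri="localhost:5000",  # illustrative
)

# registry.yaml then contains:
#
# mirrors:
#   localhost:5000:
#     endpoint:
#     - http://k3d-zenml-registry:5000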
utils

Utilities for the ZenML Kubeflow orchestrator implementation.

dump_ui_metadata(node, execution_info, metadata_ui_path)

Dump KFP UI metadata JSON file for visualization purposes.

For general components we just render a simple Markdown file for exec_properties/inputs/outputs.

Parameters:

    node (PipelineNode): Associated TFX node. Required.
    execution_info (ExecutionInfo): Runtime execution info for this component, including materialized inputs/outputs/execution properties and id. Required.
    metadata_ui_path (str): Path to dump the UI metadata to. Required.
Source code in zenml/integrations/kubeflow/orchestrators/utils.py
def dump_ui_metadata(
    node: PipelineNode,
    execution_info: data_types.ExecutionInfo,
    metadata_ui_path: str,
) -> None:
    """Dump KFP UI metadata json file for visualization purpose.

    For general components we just render a simple Markdown file for
        exec_properties/inputs/outputs.

    Args:
        node: associated TFX node.
        execution_info: runtime execution info for this component, including
            materialized inputs/outputs/execution properties and id.
        metadata_ui_path: path to dump ui metadata.
    """
    exec_properties_list = [
        "**{}**: {}".format(
            _sanitize_underscore(name), _sanitize_underscore(exec_property)
        )
        for name, exec_property in execution_info.exec_properties.items()
    ]
    src_str_exec_properties = "# Execution properties:\n{}".format(
        "\n\n".join(exec_properties_list) or "No execution property."
    )

    def _dump_input_populated_artifacts(
        node_inputs: MutableMapping[str, InputSpec],
        name_to_artifacts: Dict[str, List[artifact.Artifact]],
    ) -> List[str]:
        """Dump artifacts markdown string for inputs.

        Args:
            node_inputs: maps from input name to input spec proto.
            name_to_artifacts: maps from input key to list of populated
                artifacts.

        Returns:
            A list of dumped markdown string, each of which represents a channel.
        """
        rendered_list = []
        for name, spec in node_inputs.items():
            # Need to look for materialized artifacts in the execution decision.
            rendered_artifacts = "".join(
                [
                    _render_artifact_as_mdstr(single_artifact)
                    for single_artifact in name_to_artifacts.get(name, [])
                ]
            )
            # There must be at least one channel in an input, and all channels
            # in an input share the same artifact type.
            artifact_type = spec.channels[0].artifact_query.type.name
            rendered_list.append(
                "## {name}\n\n**Type**: {channel_type}\n\n{artifacts}".format(
                    name=_sanitize_underscore(name),
                    channel_type=_sanitize_underscore(artifact_type),
                    artifacts=rendered_artifacts,
                )
            )

        return rendered_list

    def _dump_output_populated_artifacts(
        node_outputs: MutableMapping[str, OutputSpec],
        name_to_artifacts: Dict[str, List[artifact.Artifact]],
    ) -> List[str]:
        """Dump artifacts markdown string for outputs.

        Args:
            node_outputs: maps from output name to output spec proto.
            name_to_artifacts: maps from output key to list of populated
                artifacts.

        Returns:
            A list of dumped markdown string, each of which represents a channel.
        """
        rendered_list = []
        for name, spec in node_outputs.items():
            # Need to look for materialized artifacts in the execution decision.
            rendered_artifacts = "".join(
                [
                    _render_artifact_as_mdstr(single_artifact)
                    for single_artifact in name_to_artifacts.get(name, [])
                ]
            )
            # There must be at least one channel in an input, and all channels
            # in an input share the same artifact type.
            artifact_type = spec.artifact_spec.type.name
            rendered_list.append(
                "## {name}\n\n**Type**: {channel_type}\n\n{artifacts}".format(
                    name=_sanitize_underscore(name),
                    channel_type=_sanitize_underscore(artifact_type),
                    artifacts=rendered_artifacts,
                )
            )

        return rendered_list

    src_str_inputs = "# Inputs:\n{}".format(
        "".join(
            _dump_input_populated_artifacts(
                node_inputs=node.inputs.inputs,
                name_to_artifacts=execution_info.input_dict or {},
            )
        )
        or "No input."
    )

    src_str_outputs = "# Outputs:\n{}".format(
        "".join(
            _dump_output_populated_artifacts(
                node_outputs=node.outputs.outputs,
                name_to_artifacts=execution_info.output_dict or {},
            )
        )
        or "No output."
    )

    outputs = [
        {
            "storage": "inline",
            "source": "{exec_properties}\n\n{inputs}\n\n{outputs}".format(
                exec_properties=src_str_exec_properties,
                inputs=src_str_inputs,
                outputs=src_str_outputs,
            ),
            "type": "markdown",
        }
    ]
    # Add TensorBoard view for ModelRun outputs.
    for name, spec in node.outputs.outputs.items():
        if (
            spec.artifact_spec.type.name
            == standard_artifacts.ModelRun.TYPE_NAME
            or spec.artifact_spec.type.name == ModelArtifact.TYPE_NAME
        ):
            output_model = execution_info.output_dict[name][0]
            source = output_model.uri

            # For local artifact repository, use a path that is relative to
            # the point where the local artifact folder is mounted as a volume
            artifact_store = Repository().active_stack.artifact_store
            if isinstance(artifact_store, LocalArtifactStore):
                source = os.path.relpath(source, artifact_store.path)
                source = f"volume://local-artifact-store/{source}"
            # Add TensorBoard view.
            tensorboard_output = {
                "type": "tensorboard",
                "source": source,
            }
            outputs.append(tensorboard_output)

    metadata_dict = {"outputs": outputs}

    with open(metadata_ui_path, "w") as f:
        json.dump(metadata_dict, f)
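
The file written here follows Kubeflow Pipelines' UI metadata format: a JSON object with an `outputs` list whose entries the KFP UI renders, e.g. as inline markdown or a TensorBoard viewer. A minimal hand-written sketch (the output path and TensorBoard log URI are placeholders):

import json

metadata = {
    "outputs": [
        # Rendered as inline markdown in the KFP UI.
        {"storage": "inline", "type": "markdown", "source": "# Hello"},
        # Rendered as a TensorBoard viewer (placeholder log URI).
        {"type": "tensorboard", "source": "gs://my-bucket/logs"},
    ]
}

with open("mlpipeline-ui-metadata.json", "w") as f:
    json.dump(metadata, f)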
mount_config_map_op(config_map_name)

Mounts all key-value pairs found in the named Kubernetes ConfigMap.

All key-value pairs in the ConfigMap are mounted as environment variables.

Parameters:

    config_map_name (str): The name of the ConfigMap resource. Required.

Returns:

    Callable[[kfp.dsl._container_op.ContainerOp], NoneType]: An OpFunc for mounting the ConfigMap.

Source code in zenml/integrations/kubeflow/orchestrators/utils.py
def mount_config_map_op(
    config_map_name: str,
) -> Callable[[dsl.ContainerOp], None]:
    """Mounts all key-value pairs found in the named Kubernetes ConfigMap.

    All key-value pairs in the ConfigMap are mounted as environment variables.

    Args:
        config_map_name: The name of the ConfigMap resource.

    Returns:
        An OpFunc for mounting the ConfigMap.
    """

    def mount_config_map(container_op: dsl.ContainerOp) -> None:
        """Mounts all key-value pairs found in the Kubernetes ConfigMap.

        Args:
            container_op: The container op to mount the ConfigMap.
        """
        config_map_ref = k8s_client.V1ConfigMapEnvSource(
            name=config_map_name, optional=True
        )
        container_op.container.add_env_from(
            k8s_client.V1EnvFromSource(config_map_ref=config_map_ref)
        )

    return mount_config_map
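
A usage sketch inside a pipeline definition (image, command, and ConfigMap name are illustrative): applying the returned OpFunc makes the ConfigMap's key-value pairs available as environment variables in the step container.

from kfp import dsl


@dsl.pipeline(name="configured-pipeline")
def configured_pipeline() -> None:
    op = dsl.ContainerOp(
        name="my_step",
        image="my-registry/my-image:latest",  # placeholder image
        command=["python", "-m", "my_package.step_entrypoint"],  # placeholder
    )
    # Mount all key-value pairs of the named ConfigMap as env variables.
    op.apply(mount_config_map_op("my-config-map"))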

kubernetes special

Kubernetes integration for Kubernetes-native orchestration.

The Kubernetes integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Kubernetes orchestrator with the CLI tool.

KubernetesIntegration (Integration)

Definition of Kubernetes integration for ZenML.

Source code in zenml/integrations/kubernetes/__init__.py
class KubernetesIntegration(Integration):
    """Definition of Kubernetes integration for ZenML."""

    NAME = KUBERNETES
    REQUIREMENTS = ["kubernetes==18.20.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Kubernetes integration.

        Returns:
            List of new stack component flavors.
        """
        return [
            FlavorWrapper(
                name=KUBERNETES_METADATA_STORE_FLAVOR,
                source="zenml.integrations.kubernetes.metadata_stores.KubernetesMetadataStore",
                type=StackComponentType.METADATA_STORE,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=KUBERNETES_ORCHESTRATOR_FLAVOR,
                source="zenml.integrations.kubernetes.orchestrators.KubernetesOrchestrator",
                type=StackComponentType.ORCHESTRATOR,
                integration=cls.NAME,
            ),
        ]
flavors() classmethod

Declare the stack component flavors for the Kubernetes integration.

Returns:

    List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]: List of new stack component flavors.

Source code in zenml/integrations/kubernetes/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Kubernetes integration.

    Returns:
        List of new stack component flavors.
    """
    return [
        FlavorWrapper(
            name=KUBERNETES_METADATA_STORE_FLAVOR,
            source="zenml.integrations.kubernetes.metadata_stores.KubernetesMetadataStore",
            type=StackComponentType.METADATA_STORE,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=KUBERNETES_ORCHESTRATOR_FLAVOR,
            source="zenml.integrations.kubernetes.orchestrators.KubernetesOrchestrator",
            type=StackComponentType.ORCHESTRATOR,
            integration=cls.NAME,
        ),
    ]

metadata_stores special

Initialization of the Kubernetes metadata store for ZenML.

kubernetes_metadata_store

Implementation of Kubernetes metadata store.

KubernetesMetadataStore (BaseMetadataStore) pydantic-model

Kubernetes metadata store (MySQL database deployed in the cluster).

Attributes:

    deployment_name (str): Name of the Kubernetes deployment and corresponding service/pod that will be created when calling provision().
    kubernetes_context (str): Name of the Kubernetes context in which to deploy and provision the MySQL database.
    kubernetes_namespace (str): Name of the Kubernetes namespace. Defaults to "zenml".
    storage_capacity (str): Storage capacity of the metadata store. Defaults to "10Gi" (=10GB).

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
class KubernetesMetadataStore(BaseMetadataStore):
    """Kubernetes metadata store (MySQL database deployed in the cluster).

    Attributes:
        deployment_name: Name of the Kubernetes deployment and corresponding
            service/pod that will be created when calling `provision()`.
        kubernetes_context: Name of the Kubernetes context in which to deploy
            and provision the MySQL database.
        kubernetes_namespace: Name of the Kubernetes namespace.
            Defaults to "zenml".
        storage_capacity: Storage capacity of the metadata store.
            Defaults to `"10Gi"` (=10GB).
    """

    deployment_name: str
    kubernetes_context: str
    kubernetes_namespace: str = "zenml"
    storage_capacity: str = "10Gi"
    _k8s_core_api: k8s_client.CoreV1Api = None
    _k8s_apps_api: k8s_client.AppsV1Api = None

    FLAVOR: ClassVar[str] = KUBERNETES_METADATA_STORE_FLAVOR

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initiate the Pydantic object and initialize the Kubernetes clients.

        Args:
            *args: The positional arguments to pass to the Pydantic object.
            **kwargs: The keyword arguments to pass to the Pydantic object.
        """
        super().__init__(*args, **kwargs)
        self._initialize_k8s_clients()

    def _initialize_k8s_clients(self) -> None:
        """Initialize the Kubernetes clients."""
        kube_utils.load_kube_config(context=self.kubernetes_context)
        self._k8s_core_api = k8s_client.CoreV1Api()
        self._k8s_apps_api = k8s_client.AppsV1Api()

    @root_validator(skip_on_failure=False)
    def check_required_attributes(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Pydantic root_validator.

        This ensures that both `deployment_name` and `kubernetes_context` are
        set and raises an error with a custom error message otherwise.

        Args:
            values: Values passed to the Pydantic constructor.

        Raises:
            StackComponentInterfaceError: if either `deployment_name` or
                `kubernetes_context` is not defined.

        Returns:
            Values passed to the Pydantic constructor.
        """
        usage_note = (
            "Note: the `kubernetes` metadata store flavor is a special "
            "subtype of the `mysql` metadata store that deploys a fresh "
            "MySQL database within your Kubernetes cluster when running "
            "`zenml stack up`. "
            "If you already have a MySQL database running in your cluster "
            "(or elsewhere), simply use the `mysql` metadata store flavor "
            "instead."
        )

        for required_field in ("deployment_name", "kubernetes_context"):
            if required_field not in values:
                raise StackComponentInterfaceError(
                    f"Required field `{required_field}` missing for "
                    "`KubernetesMetadataStore`. " + usage_note
                )

        return values

    @property
    def deployment_exists(self) -> bool:
        """Check whether a MySQL deployment exists in the cluster.

        Returns:
            Whether a MySQL deployment exists in the cluster.
        """
        resp = self._k8s_apps_api.list_namespaced_deployment(
            namespace=self.kubernetes_namespace
        )
        for i in resp.items:
            if i.metadata.name == self.deployment_name:
                return True
        return False

    @property
    def is_provisioned(self) -> bool:
        """If the component provisioned resources to run.

        Checks whether the required MySQL deployment exists.

        Returns:
            True if the component provisioned resources to run.
        """
        return super().is_provisioned and self.deployment_exists

    @property
    def is_running(self) -> bool:
        """If the component is running.

        Returns:
            True if `is_provisioned` else False.
        """
        if sys.platform != "win32":
            from zenml.utils.daemon import check_if_daemon_is_running

            if not check_if_daemon_is_running(self._pid_file_path):
                return False
        else:
            # Daemon functionality is not supported on Windows, so the PID
            # file won't exist. This if clause exists just for mypy to not
            # complain about missing functions
            pass

        return self.is_provisioned

    def provision(self) -> None:
        """Provision the metadata store.

        Creates a deployment with a MySQL database running in it.
        """
        logger.info("Provisioning Kubernetes MySQL metadata store...")
        kube_utils.create_namespace(
            core_api=self._k8s_core_api, namespace=self.kubernetes_namespace
        )
        kube_utils.create_mysql_deployment(
            core_api=self._k8s_core_api,
            apps_api=self._k8s_apps_api,
            namespace=self.kubernetes_namespace,
            storage_capacity=self.storage_capacity,
            deployment_name=self.deployment_name,
        )

        # wait a bit, then make sure deployment pod is alive and running.
        logger.info("Trying to reach Kubernetes MySQL metadata store pod...")
        time.sleep(10)
        kube_utils.wait_pod(
            core_api=self._k8s_core_api,
            pod_name=self.pod_name,
            namespace=self.kubernetes_namespace,
            exit_condition_lambda=kube_utils.pod_is_not_pending,
        )
        logger.info("Kubernetes MySQL metadata store pod is up and running.")

    def deprovision(self) -> None:
        """Deprovision the metadata store by deleting the MySQL deployment."""
        logger.info("Deleting Kubernetes MySQL metadata store...")
        self.suspend()
        kube_utils.delete_deployment(
            apps_api=self._k8s_apps_api,
            deployment_name=self.deployment_name,
            namespace=self.kubernetes_namespace,
        )

    # TODO: code duplication with kubeflow metadata store below.

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory for all files concerning this orchestrator.

        Returns:
            Path to the root directory.
        """
        return os.path.join(
            io_utils.get_global_config_directory(),
            self.FLAVOR,
            str(self.uuid),
        )

    @property
    def _pid_file_path(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(
            self.root_directory, DEFAULT_KUBERNETES_METADATA_DAEMON_PID_FILE
        )

    @property
    def _log_file(self) -> str:
        """Path of the daemon log file.

        Returns:
            Path to the daemon log file.
        """
        return os.path.join(
            self.root_directory, DEFAULT_KUBERNETES_METADATA_DAEMON_LOG_FILE
        )

    def resume(self) -> None:
        """Resumes the metadata store."""
        self.start_metadata_daemon()
        self.wait_until_metadata_store_ready(
            timeout=DEFAULT_KUBERNETES_METADATA_DAEMON_TIMEOUT
        )

    def suspend(self) -> None:
        """Suspends the metadata store."""
        self.stop_metadata_daemon()

    @property
    def pod_name(self) -> str:
        """Name of the Kubernetes pod where the MySQL database is deployed.

        Returns:
            Name of the Kubernetes pod.
        """
        pod_list = self._k8s_core_api.list_namespaced_pod(
            namespace=self.kubernetes_namespace,
            label_selector=f"app={self.deployment_name}",
        )
        return pod_list.items[0].metadata.name  # type: ignore[no-any-return]

    @property
    def host(self) -> str:
        """Get the MySQL host required to access the metadata store.

        This overrides the MySQL host to use localhost when
        running outside of the cluster so we can access the metadata store
        locally for post execution.

        Raises:
            RuntimeError: If the metadata store is not running.

        Returns:
            MySQL host.
        """
        if kube_utils.is_inside_kubernetes():
            return DEFAULT_KUBERNETES_MYSQL_HOST
        if not self.is_running:
            raise RuntimeError(
                "The Kubernetes metadata daemon is not running. Please run the "
                "following command to start it first:\n\n"
                "    'zenml metadata-store up'\n"
            )
        return DEFAULT_KUBERNETES_MYSQL_LOCAL_HOST

    @property
    def port(self) -> int:
        """Get the MySQL port required to access the metadata store.

        Returns:
            int: MySQL port.
        """
        return DEFAULT_KUBERNETES_MYSQL_PORT

    def get_tfx_metadata_config(
        self,
    ) -> Union[
        metadata_store_pb2.ConnectionConfig,
        metadata_store_pb2.MetadataStoreClientConfig,
    ]:
        """Return tfx metadata config for the Kubernetes metadata store.

        Returns:
            The tfx metadata config.
        """
        config = MySQLDatabaseConfig(
            host=self.host,
            port=self.port,
            database=DEFAULT_KUBERNETES_MYSQL_DATABASE,
            user=DEFAULT_KUBERNETES_MYSQL_USERNAME,
            password=DEFAULT_KUBERNETES_MYSQL_PASSWORD,
        )
        connection_config = metadata_store_pb2.ConnectionConfig(mysql=config)
        return connection_config

    def start_metadata_daemon(self) -> None:
        """Starts a daemon process that forwards ports.

        This is so the MySQL database in the Kubernetes cluster is accessible
        on the localhost.

        Raises:
            ProvisioningError: if the daemon fails to start.
        """
        command = [
            "kubectl",
            "--context",
            self.kubernetes_context,
            "--namespace",
            self.kubernetes_namespace,
            "port-forward",
            f"svc/{self.deployment_name}",
            f"{self.port}:{self.port}",
        ]
        if sys.platform == "win32":
            logger.warning(
                "Daemon functionality not supported on Windows. "
                "In order to access the Kubernetes metadata store locally, "
                "please run '%s' in a separate command line shell.",
                " ".join(command),
            )
        elif not networking_utils.port_available(self.port):
            raise ProvisioningError(
                f"Unable to port-forward Kubernetes Metadata to local "
                f"port {self.port} because the port is occupied. In order to "
                f"access the Kubernetes Metadata locally, please "
                f"change the metadata store configuration to use an available "
                f"port or stop the other process currently using the port."
            )
        else:
            from zenml.utils import daemon

            def _daemon_function() -> None:
                """Forwards the port of the Kubernetes metadata store pod ."""
                subprocess.check_call(command)

            daemon.run_as_daemon(
                _daemon_function,
                pid_file=self._pid_file_path,
                log_file=self._log_file,
            )
            logger.info(
                "Started Kubernetes Metadata daemon (check the daemon"
                "logs at %s in case you're not able to access the pipeline"
                "metadata).",
                self._log_file,
            )

    def stop_metadata_daemon(self) -> None:
        """Stops the Kubernetes metadata daemon process if it is running."""
        if sys.platform != "win32" and fileio.exists(self._pid_file_path):
            from zenml.utils import daemon

            daemon.stop_daemon(self._pid_file_path)
            fileio.remove(self._pid_file_path)

    def wait_until_metadata_store_ready(self, timeout: int) -> None:
        """Waits until the metadata store connection is ready.

        Potentially an irrecoverable error could occur or the timeout could
        expire, so it checks for this.

        Args:
            timeout: The maximum time to wait for the metadata store to be
                ready.

        Raises:
            RuntimeError: if the metadata store is not ready after the timeout
        """
        logger.info(
            "Waiting for the Kubernetes metadata store to be ready (this "
            "might take a few minutes)."
        )
        while True:
            try:
                # it doesn't matter what we call here as long as it exercises
                # the MLMD connection
                self.get_pipelines()
                break
            except Exception as e:
                logger.info(
                    "The Kubernetes metadata store is not ready yet. Waiting "
                    "for 10 seconds..."
                )
                if timeout <= 0:
                    raise RuntimeError(
                        f"An unexpected error was encountered while waiting "
                        f"for the Kubernetes metadata store to be functional: "
                        f"{str(e)}"
                    ) from e
                timeout -= 10
                time.sleep(10)

        logger.info("The Kubernetes metadata store is functional.")
deployment_exists: bool property readonly

Check whether a MySQL deployment exists in the cluster.

Returns:

    bool: Whether a MySQL deployment exists in the cluster.

host: str property readonly

Get the MySQL host required to access the metadata store.

This overrides the MySQL host to use localhost when running outside of the cluster so we can access the metadata store locally for post execution.

Exceptions:

    RuntimeError: If the metadata store is not running.

Returns:

    str: MySQL host.

is_provisioned: bool property readonly

If the component provisioned resources to run.

Checks whether the required MySQL deployment exists.

Returns:

    bool: True if the component provisioned resources to run.

is_running: bool property readonly

If the component is running.

Returns:

    bool: True if `is_provisioned` else False.

pod_name: str property readonly

Name of the Kubernetes pod where the MySQL database is deployed.

Returns:

    str: Name of the Kubernetes pod.

port: int property readonly

Get the MySQL port required to access the metadata store.

Returns:

    int: MySQL port.

root_directory: str property readonly

Returns path to the root directory for all files concerning this orchestrator.

Returns:

    str: Path to the root directory.

__init__(self, *args, **kwargs) special

Initialize the Pydantic object and the Kubernetes clients.

Parameters:

    *args (Any): The positional arguments to pass to the Pydantic object. Defaults to ().
    **kwargs (Any): The keyword arguments to pass to the Pydantic object. Defaults to {}.
Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initiate the Pydantic object and initialize the Kubernetes clients.

    Args:
        *args: The positional arguments to pass to the Pydantic object.
        **kwargs: The keyword arguments to pass to the Pydantic object.
    """
    super().__init__(*args, **kwargs)
    self._initialize_k8s_clients()
check_required_attributes(values) classmethod

Pydantic root_validator.

This ensures that both deployment_name and kubernetes_context are set and raises an error with a custom error message otherwise.

Parameters:

    values (Dict[str, Any]): Values passed to the Pydantic constructor. Required.

Exceptions:

    StackComponentInterfaceError: If either deployment_name or kubernetes_context is not defined.

Returns:

    Dict[str, Any]: Values passed to the Pydantic constructor.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
@root_validator(skip_on_failure=False)
def check_required_attributes(
    cls, values: Dict[str, Any]
) -> Dict[str, Any]:
    """Pydantic root_validator.

    This ensures that both `deployment_name` and `kubernetes_context` are
    set and raises an error with a custom error message otherwise.

    Args:
        values: Values passed to the Pydantic constructor.

    Raises:
        StackComponentInterfaceError: if either `deployment_name` or
            `kubernetes_context` is not defined.

    Returns:
        Values passed to the Pydantic constructor.
    """
    usage_note = (
        "Note: the `kubernetes` metadata store flavor is a special "
        "subtype of the `mysql` metadata store that deploys a fresh "
        "MySQL database within your Kubernetes cluster when running "
        "`zenml stack up`. "
        "If you already have a MySQL database running in your cluster "
        "(or elsewhere), simply use the `mysql` metadata store flavor "
        "instead."
    )

    for required_field in ("deployment_name", "kubernetes_context"):
        if required_field not in values:
            raise StackComponentInterfaceError(
                f"Required field `{required_field}` missing for "
                "`KubernetesMetadataStore`. " + usage_note
            )

    return values
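
For illustration, a store of this flavor could be registered with something like `zenml metadata-store register k8s_store --flavor=kubernetes --deployment_name=mysql --kubernetes_context=<my-context>` before running `zenml stack up` (the exact CLI flags are assumed to mirror the attribute names validated above).
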
deprovision(self)

Deprovision the metadata store by deleting the MySQL deployment.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def deprovision(self) -> None:
    """Deprovision the metadata store by deleting the MySQL deployment."""
    logger.info("Deleting Kubernetes MySQL metadata store...")
    self.suspend()
    kube_utils.delete_deployment(
        apps_api=self._k8s_apps_api,
        deployment_name=self.deployment_name,
        namespace=self.kubernetes_namespace,
    )
get_tfx_metadata_config(self)

Return tfx metadata config for the Kubernetes metadata store.

Returns:

Type Description
Union[ml_metadata.proto.metadata_store_pb2.ConnectionConfig, ml_metadata.proto.metadata_store_pb2.MetadataStoreClientConfig]

The tfx metadata config.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def get_tfx_metadata_config(
    self,
) -> Union[
    metadata_store_pb2.ConnectionConfig,
    metadata_store_pb2.MetadataStoreClientConfig,
]:
    """Return tfx metadata config for the Kubernetes metadata store.

    Returns:
        The tfx metadata config.
    """
    config = MySQLDatabaseConfig(
        host=self.host,
        port=self.port,
        database=DEFAULT_KUBERNETES_MYSQL_DATABASE,
        user=DEFAULT_KUBERNETES_MYSQL_USERNAME,
        password=DEFAULT_KUBERNETES_MYSQL_PASSWORD,
    )
    connection_config = metadata_store_pb2.ConnectionConfig(mysql=config)
    return connection_config
provision(self)

Provision the metadata store.

Creates a deployment with a MySQL database running in it.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def provision(self) -> None:
    """Provision the metadata store.

    Creates a deployment with a MySQL database running in it.
    """
    logger.info("Provisioning Kubernetes MySQL metadata store...")
    kube_utils.create_namespace(
        core_api=self._k8s_core_api, namespace=self.kubernetes_namespace
    )
    kube_utils.create_mysql_deployment(
        core_api=self._k8s_core_api,
        apps_api=self._k8s_apps_api,
        namespace=self.kubernetes_namespace,
        storage_capacity=self.storage_capacity,
        deployment_name=self.deployment_name,
    )

    # wait a bit, then make sure deployment pod is alive and running.
    logger.info("Trying to reach Kubernetes MySQL metadata store pod...")
    time.sleep(10)
    kube_utils.wait_pod(
        core_api=self._k8s_core_api,
        pod_name=self.pod_name,
        namespace=self.kubernetes_namespace,
        exit_condition_lambda=kube_utils.pod_is_not_pending,
    )
    logger.info("Kubernetes MySQL metadata store pod is up and running.")
resume(self)

Resumes the metadata store.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def resume(self) -> None:
    """Resumes the metadata store."""
    self.start_metadata_daemon()
    self.wait_until_metadata_store_ready(
        timeout=DEFAULT_KUBERNETES_METADATA_DAEMON_TIMEOUT
    )
start_metadata_daemon(self)

Starts a daemon process that forwards ports.

This is so the MySQL database in the Kubernetes cluster is accessible on the localhost.

Exceptions:

Type Description
ProvisioningError

if the daemon fails to start.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def start_metadata_daemon(self) -> None:
    """Starts a daemon process that forwards ports.

    This is so the MySQL database in the Kubernetes cluster is accessible
    on the localhost.

    Raises:
        ProvisioningError: if the daemon fails to start.
    """
    command = [
        "kubectl",
        "--context",
        self.kubernetes_context,
        "--namespace",
        self.kubernetes_namespace,
        "port-forward",
        f"svc/{self.deployment_name}",
        f"{self.port}:{self.port}",
    ]
    if sys.platform == "win32":
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Kubernetes Metadata locally, "
            "please run '%s' in a separate command line shell.",
            " ".join(command),
        )
    elif not networking_utils.port_available(self.port):
        raise ProvisioningError(
            f"Unable to port-forward Kubernetes Metadata to local "
            f"port {self.port} because the port is occupied. In order to "
            f"access the Kubernetes Metadata locally, please "
            f"change the metadata store configuration to use an available "
            f"port or stop the other process currently using the port."
        )
    else:
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Forwards the port of the Kubernetes metadata store pod ."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function,
            pid_file=self._pid_file_path,
            log_file=self._log_file,
        )
        logger.info(
            "Started Kubernetes Metadata daemon (check the daemon"
            "logs at %s in case you're not able to access the pipeline"
            "metadata).",
            self._log_file,
        )
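
For reference, the daemon simply wraps the equivalent manual command, `kubectl --context <kubernetes_context> --namespace <kubernetes_namespace> port-forward svc/<deployment_name> <port>:<port>`, which is also the command that Windows users are asked to run in a separate shell.
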
stop_metadata_daemon(self)

Stops the Kubernetes metadata daemon process if it is running.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def stop_metadata_daemon(self) -> None:
    """Stops the Kubernetes metadata daemon process if it is running."""
    if sys.platform != "win32" and fileio.exists(self._pid_file_path):
        from zenml.utils import daemon

        daemon.stop_daemon(self._pid_file_path)
        fileio.remove(self._pid_file_path)
suspend(self)

Suspends the metadata store.

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def suspend(self) -> None:
    """Suspends the metadata store."""
    self.stop_metadata_daemon()
wait_until_metadata_store_ready(self, timeout)

Waits until the metadata store connection is ready.

An irrecoverable error could occur, or the timeout could expire, so this method checks for both cases.

Parameters:

Name Type Description Default
timeout int

The maximum time to wait for the metadata store to be ready.

required

Exceptions:

Type Description
RuntimeError

if the metadata store is not ready after the timeout

Source code in zenml/integrations/kubernetes/metadata_stores/kubernetes_metadata_store.py
def wait_until_metadata_store_ready(self, timeout: int) -> None:
    """Waits until the metadata store connection is ready.

    An irrecoverable error could occur, or the timeout could expire, so
    this method checks for both cases.

    Args:
        timeout: The maximum time to wait for the metadata store to be
            ready.

    Raises:
        RuntimeError: if the metadata store is not ready after the timeout
    """
    logger.info(
        "Waiting for the Kubernetes metadata store to be ready (this "
        "might take a few minutes)."
    )
    while True:
        try:
            # it doesn't matter what we call here as long as it exercises
            # the MLMD connection
            self.get_pipelines()
            break
        except Exception as e:
            logger.info(
                "The Kubernetes metadata store is not ready yet. Waiting "
                "for 10 seconds..."
            )
            if timeout <= 0:
                raise RuntimeError(
                    f"An unexpected error was encountered while waiting "
                    f"for the Kubernetes metadata store to be functional: "
                    f"{str(e)}"
                ) from e
            timeout -= 10
            time.sleep(10)

    logger.info("The Kubernetes metadata store is functional.")

orchestrators special

Kubernetes-native orchestration.

dag_runner

DAG (Directed Acyclic Graph) Runners.

NodeStatus (Enum)

Status of the execution of a node.

Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
class NodeStatus(Enum):
    """Status of the execution of a node."""

    WAITING = "Waiting"
    RUNNING = "Running"
    COMPLETED = "Completed"
ThreadedDagRunner

Multi-threaded DAG Runner.

This class expects a DAG of strings in adjacency list representation, as well as a custom run_fn as input, then calls run_fn(node) for each string node in the DAG.

Steps that can be executed in parallel will be started in separate threads.

Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
class ThreadedDagRunner:
    """Multi-threaded DAG Runner.

    This class expects a DAG of strings in adjacency list representation, as
    well as a custom `run_fn` as input, then calls `run_fn(node)` for each
    string node in the DAG.

    Steps that can be executed in parallel will be started in separate threads.
    """

    def __init__(
        self, dag: Dict[str, List[str]], run_fn: Callable[[str], Any]
    ) -> None:
        """Define attributes and initialize all nodes in waiting state.

        Args:
            dag: Adjacency list representation of a DAG.
                E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as
                `dag={2: [1], 3: [1], 4: [2, 3]}`
            run_fn: A function `run_fn(node)` that runs a single node
        """
        self.dag = dag
        self.reversed_dag = reverse_dag(dag)
        self.run_fn = run_fn
        self.nodes = dag.keys()
        self.node_states = {node: NodeStatus.WAITING for node in self.nodes}
        self._lock = threading.Lock()

    def _can_run(self, node: str) -> bool:
        """Determine whether a node is ready to be run.

        This is the case if the node has not run yet and all of its upstream
        nodes have already completed.

        Args:
            node: The node.

        Returns:
            True if the node can run else False.
        """
        # Check that node has not run yet.
        if not self.node_states[node] == NodeStatus.WAITING:
            return False

        # Check that all upstream nodes of this node have already completed.
        for upstream_node in self.dag[node]:
            if not self.node_states[upstream_node] == NodeStatus.COMPLETED:
                return False

        return True

    def _run_node(self, node: str) -> None:
        """Run a single node.

        Calls the user-defined run_fn, then calls `self._finish_node`.

        Args:
            node: The node.
        """
        self.run_fn(node)
        self._finish_node(node)

    def _run_node_in_thread(self, node: str) -> threading.Thread:
        """Run a single node in a separate thread.

        First updates the node status to running.
        Then calls self._run_node() in a new thread and returns the thread.

        Args:
            node: The node.

        Returns:
            The thread in which the node was run.
        """
        # Update node status to running.
        assert self.node_states[node] == NodeStatus.WAITING
        with self._lock:
            self.node_states[node] = NodeStatus.RUNNING

        # Run node in new thread.
        thread = threading.Thread(target=self._run_node, args=(node,))
        thread.start()
        return thread

    def _finish_node(self, node: str) -> None:
        """Finish a node run.

        First updates the node status to completed.
        Then starts all other nodes that can now be run and waits for them.

        Args:
            node: The node.
        """
        # Update node status to completed.
        assert self.node_states[node] == NodeStatus.RUNNING
        with self._lock:
            self.node_states[node] = NodeStatus.COMPLETED

        # Run downstream nodes.
        threads = []
        for downstream_node in self.reversed_dag[node]:
            if self._can_run(downstream_node):
                thread = self._run_node_in_thread(downstream_node)
                threads.append(thread)

        # Wait for all downstream nodes to complete.
        for thread in threads:
            thread.join()

    def run(self) -> None:
        """Call `self.run_fn` on all nodes in `self.dag`.

        The order of execution is determined using topological sort.
        Each node is run in a separate thread to enable parallelism.
        """
        # Run all nodes that can be started immediately.
        # These will, in turn, start other nodes once all of their respective
        # upstream nodes have completed.
        threads = []
        for node in self.nodes:
            if self._can_run(node):
                thread = self._run_node_in_thread(node)
                threads.append(thread)

        # Wait till all nodes have completed.
        for thread in threads:
            thread.join()

        # Make sure all nodes were run, otherwise print a warning.
        for node in self.nodes:
            if self.node_states[node] == NodeStatus.WAITING:
                upstream_nodes = self.dag[node]
                logger.warning(
                    f"Node `{node}` was never run, because it was still"
                    f" waiting for the following nodes: `{upstream_nodes}`."
                )
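
A minimal usage sketch (the module path is taken from the source reference above; the DAG and run_fn are made up for illustration):

from zenml.integrations.kubernetes.orchestrators.dag_runner import (
    ThreadedDagRunner,
)

# Toy DAG with edges 1->2, 1->3, 2->4, 3->4. Every node, including the
# source node "1", must appear as a key.
dag = {"1": [], "2": ["1"], "3": ["1"], "4": ["2", "3"]}

def run_fn(node: str) -> None:
    print(f"Running step {node}.")

# Runs "1" first, then "2" and "3" in parallel threads, then "4".
ThreadedDagRunner(dag=dag, run_fn=run_fn).run()
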
__init__(self, dag, run_fn) special

Define attributes and initialize all nodes in waiting state.

Parameters:

Name Type Description Default
dag Dict[str, List[str]]

Adjacency list representation of a DAG. E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as dag={1: [], 2: [1], 3: [1], 4: [2, 3]}. Note that every node (including source nodes) must appear as a key.

required
run_fn Callable[[str], Any]

A function run_fn(node) that runs a single node

required
Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
def __init__(
    self, dag: Dict[str, List[str]], run_fn: Callable[[str], Any]
) -> None:
    """Define attributes and initialize all nodes in waiting state.

    Args:
        dag: Adjacency list representation of a DAG.
            E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as
            `dag={2: [1], 3: [1], 4: [2, 3]}`
        run_fn: A function `run_fn(node)` that runs a single node
    """
    self.dag = dag
    self.reversed_dag = reverse_dag(dag)
    self.run_fn = run_fn
    self.nodes = dag.keys()
    self.node_states = {node: NodeStatus.WAITING for node in self.nodes}
    self._lock = threading.Lock()
run(self)

Call self.run_fn on all nodes in self.dag.

The order of execution is determined using topological sort. Each node is run in a separate thread to enable parallelism.

Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
def run(self) -> None:
    """Call `self.run_fn` on all nodes in `self.dag`.

    The order of execution is determined using topological sort.
    Each node is run in a separate thread to enable parallelism.
    """
    # Run all nodes that can be started immediately.
    # These will, in turn, start other nodes once all of their respective
    # upstream nodes have completed.
    threads = []
    for node in self.nodes:
        if self._can_run(node):
            thread = self._run_node_in_thread(node)
            threads.append(thread)

    # Wait till all nodes have completed.
    for thread in threads:
        thread.join()

    # Make sure all nodes were run, otherwise print a warning.
    for node in self.nodes:
        if self.node_states[node] == NodeStatus.WAITING:
            upstream_nodes = self.dag[node]
            logger.warning(
                f"Node `{node}` was never run, because it was still"
                f" waiting for the following nodes: `{upstream_nodes}`."
            )
reverse_dag(dag)

Reverse a DAG.

Parameters:

Name Type Description Default
dag Dict[str, List[str]]

Adjacency list representation of a DAG.

required

Returns:

Type Description
Dict[str, List[str]]

Adjacency list representation of the reversed DAG.

Source code in zenml/integrations/kubernetes/orchestrators/dag_runner.py
def reverse_dag(dag: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Reverse a DAG.

    Args:
        dag: Adjacency list representation of a DAG.

    Returns:
        Adjacency list representation of the reversed DAG.
    """
    reversed_dag = defaultdict(list)

    # Reverse all edges in the graph.
    for node, upstream_nodes in dag.items():
        for upstream_node in upstream_nodes:
            reversed_dag[upstream_node].append(node)

    # Add nodes without incoming edges back in.
    for node in dag:
        if node not in reversed_dag:
            reversed_dag[node] = []

    return reversed_dag
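
For example (the function returns a defaultdict, shown here as a plain dict it compares equal to):

from zenml.integrations.kubernetes.orchestrators.dag_runner import reverse_dag

downstream = reverse_dag({"1": [], "2": ["1"], "3": ["1"], "4": ["2", "3"]})
# Each node now maps to its downstream nodes:
# downstream == {"1": ["2", "3"], "2": ["4"], "3": ["4"], "4": []}
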
kube_utils

Utilities for Kubernetes related functions.

Internal interface: no backwards compatibility guarantees. Adjusted from https://github.com/tensorflow/tfx/blob/master/tfx/utils/kube_utils.py.

PodPhase (Enum)

Phase of the Kubernetes pod.

Pod phases are defined in https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
class PodPhase(enum.Enum):
    """Phase of the Kubernetes pod.

    Pod phases are defined in
    https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase.
    """

    PENDING = "Pending"
    RUNNING = "Running"
    SUCCEEDED = "Succeeded"
    FAILED = "Failed"
    UNKNOWN = "Unknown"
create_edit_service_account(core_api, rbac_api, service_account_name, namespace, cluster_role_binding_name='zenml-edit')

Create a new Kubernetes service account with "edit" rights.

Parameters:

Name Type Description Default
core_api CoreV1Api

Client of Core V1 API of Kubernetes API.

required
rbac_api RbacAuthorizationV1Api

Client of Rbac Authorization V1 API of Kubernetes API.

required
service_account_name str

Name of the service account.

required
namespace str

Kubernetes namespace.

required
cluster_role_binding_name str

Name of the cluster role binding. Defaults to "zenml-edit".

'zenml-edit'
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def create_edit_service_account(
    core_api: k8s_client.CoreV1Api,
    rbac_api: k8s_client.RbacAuthorizationV1Api,
    service_account_name: str,
    namespace: str,
    cluster_role_binding_name: str = "zenml-edit",
) -> None:
    """Create a new Kubernetes service account with "edit" rights.

    Args:
        core_api: Client of Core V1 API of Kubernetes API.
        rbac_api: Client of Rbac Authorization V1 API of Kubernetes API.
        service_account_name: Name of the service account.
        namespace: Kubernetes namespace.
        cluster_role_binding_name: Name of the cluster role binding.
            Defaults to "zenml-edit".
    """
    crb_manifest = build_cluster_role_binding_manifest_for_service_account(
        name=cluster_role_binding_name,
        role_name="edit",
        service_account_name=service_account_name,
        namespace=namespace,
    )
    _if_not_exists(rbac_api.create_cluster_role_binding)(body=crb_manifest)

    sa_manifest = build_service_account_manifest(
        name=service_account_name, namespace=namespace
    )
    _if_not_exists(core_api.create_namespaced_service_account)(
        namespace=namespace,
        body=sa_manifest,
    )
create_mysql_deployment(core_api, apps_api, deployment_name, namespace, storage_capacity='10Gi', volume_name='mysql-pv-volume', volume_claim_name='mysql-pv-claim')

Create a Kubernetes deployment with a MySQL database running on it.

Parameters:

Name Type Description Default
core_api CoreV1Api

Client of Core V1 API of Kubernetes API.

required
apps_api AppsV1Api

Client of Apps V1 API of Kubernetes API.

required
namespace str

Kubernetes namespace.

required
storage_capacity str

Storage capacity of the database. Defaults to "10Gi".

'10Gi'
deployment_name str

Name of the deployment.

required
volume_name str

Name of the persistent volume. Defaults to "mysql-pv-volume".

'mysql-pv-volume'
volume_claim_name str

Name of the persistent volume claim. Defaults to "mysql-pv-claim".

'mysql-pv-claim'
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def create_mysql_deployment(
    core_api: k8s_client.CoreV1Api,
    apps_api: k8s_client.AppsV1Api,
    deployment_name: str,
    namespace: str,
    storage_capacity: str = "10Gi",
    volume_name: str = "mysql-pv-volume",
    volume_claim_name: str = "mysql-pv-claim",
) -> None:
    """Create a Kubernetes deployment with a MySQL database running on it.

    Args:
        core_api: Client of Core V1 API of Kubernetes API.
        apps_api: Client of Apps V1 API of Kubernetes API.
        namespace: Kubernetes namespace.
        storage_capacity: Storage capacity of the database.
            Defaults to `"10Gi"`.
        deployment_name: Name of the deployment.
        volume_name: Name of the persistent volume.
            Defaults to `"mysql-pv-volume"`.
        volume_claim_name: Name of the persistent volume claim.
            Defaults to `"mysql-pv-claim"`.
    """
    pvc_manifest = build_persistent_volume_claim_manifest(
        name=volume_claim_name,
        namespace=namespace,
        storage_request=storage_capacity,
    )
    _if_not_exists(core_api.create_namespaced_persistent_volume_claim)(
        namespace=namespace,
        body=pvc_manifest,
    )
    pv_manifest = build_persistent_volume_manifest(
        name=volume_name, storage_capacity=storage_capacity
    )
    _if_not_exists(core_api.create_persistent_volume)(body=pv_manifest)
    deployment_manifest = build_mysql_deployment_manifest(
        name=deployment_name,
        namespace=namespace,
        pv_claim_name=volume_claim_name,
    )
    _if_not_exists(apps_api.create_namespaced_deployment)(
        body=deployment_manifest, namespace=namespace
    )
    service_manifest = build_mysql_service_manifest(
        name=deployment_name, namespace=namespace
    )
    _if_not_exists(core_api.create_namespaced_service)(
        namespace=namespace, body=service_manifest
    )
create_namespace(core_api, namespace)

Create a Kubernetes namespace.

Parameters:

Name Type Description Default
core_api CoreV1Api

Client of Core V1 API of Kubernetes API.

required
namespace str

Kubernetes namespace.

required
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def create_namespace(core_api: k8s_client.CoreV1Api, namespace: str) -> None:
    """Create a Kubernetes namespace.

    Args:
        core_api: Client of Core V1 API of Kubernetes API.
        namespace: Kubernetes namespace.
    """
    manifest = build_namespace_manifest(namespace)
    _if_not_exists(core_api.create_namespace)(body=manifest)
delete_deployment(apps_api, deployment_name, namespace)

Delete a Kubernetes deployment.

Parameters:

Name Type Description Default
apps_api AppsV1Api

Client of Apps V1 API of Kubernetes API.

required
deployment_name str

Name of the deployment to be deleted.

required
namespace str

Kubernetes namespace containing the deployment.

required
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def delete_deployment(
    apps_api: k8s_client.AppsV1Api, deployment_name: str, namespace: str
) -> None:
    """Delete a Kubernetes deployment.

    Args:
        apps_api: Client of Apps V1 API of Kubernetes API.
        deployment_name: Name of the deployment to be deleted.
        namespace: Kubernetes namespace containing the deployment.
    """
    options = k8s_client.V1DeleteOptions()
    apps_api.delete_namespaced_deployment(
        name=deployment_name,
        namespace=namespace,
        body=options,
        propagation_policy="Foreground",
    )
get_pod(core_api, pod_name, namespace)

Get a pod from Kubernetes metadata API.

Parameters:

Name Type Description Default
core_api CoreV1Api

Client of CoreV1Api of Kubernetes API.

required
pod_name str

The name of the pod.

required
namespace str

The namespace of the pod.

required

Exceptions:

Type Description
RuntimeError

When it sees unexpected errors from Kubernetes API.

Returns:

Type Description
Optional[kubernetes.client.models.v1_pod.V1Pod]

The found pod object. None if it's not found.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def get_pod(
    core_api: k8s_client.CoreV1Api, pod_name: str, namespace: str
) -> Optional[k8s_client.V1Pod]:
    """Get a pod from Kubernetes metadata API.

    Args:
        core_api: Client of `CoreV1Api` of Kubernetes API.
        pod_name: The name of the pod.
        namespace: The namespace of the pod.

    Raises:
        RuntimeError: When it sees unexpected errors from Kubernetes API.

    Returns:
        The found pod object. None if it's not found.
    """
    try:
        return core_api.read_namespaced_pod(name=pod_name, namespace=namespace)
    except k8s_client.rest.ApiException as e:
        if e.status == 404:
            return None
        raise RuntimeError from e
is_inside_kubernetes()

Check whether we are inside a Kubernetes cluster or on a remote host.

Returns:

Type Description
bool

True if inside a Kubernetes cluster, else False.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def is_inside_kubernetes() -> bool:
    """Check whether we are inside a Kubernetes cluster or on a remote host.

    Returns:
        True if inside a Kubernetes cluster, else False.
    """
    try:
        k8s_config.load_incluster_config()
        return True
    except k8s_config.ConfigException:
        return False
load_kube_config(context=None)

Load the Kubernetes client config.

Depending on the environment (inside a running Kubernetes cluster or on a remote host), a different location will be searched for the config file.

Parameters:

Name Type Description Default
context Optional[str]

Name of the Kubernetes context. If not provided, uses the currently active context.

None
Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def load_kube_config(context: Optional[str] = None) -> None:
    """Load the Kubernetes client config.

    Depending on the environment (inside a running Kubernetes cluster or on
    a remote host), a different location will be searched for the config
    file.

    Args:
        context: Name of the Kubernetes context. If not provided, uses the
            currently active context.
    """
    try:
        k8s_config.load_incluster_config()
    except k8s_config.ConfigException:
        k8s_config.load_kube_config(context=context)
pod_failed(pod)

Check if pod status is 'Failed'.

Parameters:

Name Type Description Default
pod V1Pod

Kubernetes pod.

required

Returns:

Type Description
bool

True if pod status is 'Failed' else False.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def pod_failed(pod: k8s_client.V1Pod) -> bool:
    """Check if pod status is 'Failed'.

    Args:
        pod: Kubernetes pod.

    Returns:
        True if pod status is 'Failed' else False.
    """
    return pod.status.phase == PodPhase.FAILED.value  # type: ignore[no-any-return]
pod_is_done(pod)

Check if pod status is 'Succeeded'.

Parameters:

Name Type Description Default
pod V1Pod

Kubernetes pod.

required

Returns:

Type Description
bool

True if pod status is 'Succeeded' else False.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def pod_is_done(pod: k8s_client.V1Pod) -> bool:
    """Check if pod status is 'Succeeded'.

    Args:
        pod: Kubernetes pod.

    Returns:
        True if pod status is 'Succeeded' else False.
    """
    return pod.status.phase == PodPhase.SUCCEEDED.value  # type: ignore[no-any-return]
pod_is_not_pending(pod)

Check if pod status is not 'Pending'.

Parameters:

Name Type Description Default
pod V1Pod

Kubernetes pod.

required

Returns:

Type Description
bool

False if the pod status is 'Pending' else True.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def pod_is_not_pending(pod: k8s_client.V1Pod) -> bool:
    """Check if pod status is not 'Pending'.

    Args:
        pod: Kubernetes pod.

    Returns:
        False if the pod status is 'Pending' else True.
    """
    return pod.status.phase != PodPhase.PENDING.value  # type: ignore[no-any-return]
sanitize_pod_name(pod_name)

Sanitize pod names so they conform to Kubernetes pod naming convention.

Parameters:

Name Type Description Default
pod_name str

Arbitrary input pod name.

required

Returns:

Type Description
str

Sanitized pod name.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def sanitize_pod_name(pod_name: str) -> str:
    """Sanitize pod names so they conform to Kubernetes pod naming convention.

    Args:
        pod_name: Arbitrary input pod name.

    Returns:
        Sanitized pod name.
    """
    pod_name = re.sub(r"[^a-z0-9-]", "-", pod_name.lower())
    pod_name = re.sub(r"^[-]+", "", pod_name)
    return re.sub(r"[-]+", "-", pod_name)
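
For example (the input value is made up for illustration):

from zenml.integrations.kubernetes.orchestrators.kube_utils import (
    sanitize_pod_name,
)

# Lowercased, illegal characters replaced by "-", hyphen runs collapsed:
assert sanitize_pod_name("My_Pipeline Run #1") == "my-pipeline-run-1"
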
wait_pod(core_api, pod_name, namespace, exit_condition_lambda, timeout_sec=0, exponential_backoff=False, stream_logs=False)

Wait for a pod to meet an exit condition.

Parameters:

Name Type Description Default
core_api CoreV1Api

Client of CoreV1Api of Kubernetes API.

required
pod_name str

The name of the pod.

required
namespace str

The namespace of the pod.

required
exit_condition_lambda Callable[[kubernetes.client.models.v1_pod.V1Pod], bool]

A lambda that is called periodically to check the pod; waiting stops once it returns True.

required
timeout_sec int

Timeout in seconds to wait for pod to reach exit condition, or 0 to wait for an unlimited duration. Defaults to unlimited.

0
exponential_backoff bool

Whether to use exponential back off for polling. Defaults to False.

False
stream_logs bool

Whether to stream the pod logs to zenml.logger.info(). Defaults to False.

False

Exceptions:

Type Description
RuntimeError

when the function times out.

Returns:

Type Description
V1Pod

The pod object which meets the exit condition.

Source code in zenml/integrations/kubernetes/orchestrators/kube_utils.py
def wait_pod(
    core_api: k8s_client.CoreV1Api,
    pod_name: str,
    namespace: str,
    exit_condition_lambda: Callable[[k8s_client.V1Pod], bool],
    timeout_sec: int = 0,
    exponential_backoff: bool = False,
    stream_logs: bool = False,
) -> k8s_client.V1Pod:
    """Wait for a pod to meet an exit condition.

    Args:
        core_api: Client of `CoreV1Api` of Kubernetes API.
        pod_name: The name of the pod.
        namespace: The namespace of the pod.
        exit_condition_lambda: A lambda that is called periodically to
            check the pod; waiting stops once it returns True.
        timeout_sec: Timeout in seconds to wait for pod to reach exit
            condition, or 0 to wait for an unlimited duration.
            Defaults to unlimited.
        exponential_backoff: Whether to use exponential back off for polling.
            Defaults to False.
        stream_logs: Whether to stream the pod logs to
            `zenml.logger.info()`. Defaults to False.

    Raises:
        RuntimeError: when the function times out.

    Returns:
        The pod object which meets the exit condition.
    """
    start_time = datetime.datetime.utcnow()

    # Link to exponential back-off algorithm used here:
    # https://cloud.google.com/storage/docs/exponential-backoff
    backoff_interval = 1
    maximum_backoff = 32

    logged_lines = 0

    while True:
        resp = get_pod(core_api, pod_name, namespace)

        # Stream logs to `zenml.logger.info()`.
        # TODO: can we do this without parsing all logs every time?
        if stream_logs and pod_is_not_pending(resp):
            response = core_api.read_namespaced_pod_log(
                name=pod_name,
                namespace=namespace,
            )
            logs = response.splitlines()
            if len(logs) > logged_lines:
                for line in logs[logged_lines:]:
                    logger.info(line)
                logged_lines = len(logs)

        # Raise an error if the pod failed.
        if pod_failed(resp):
            raise RuntimeError(f"Pod `{namespace}:{pod_name}` failed.")

        # Check if pod is in desired state (e.g. finished / running / ...).
        if exit_condition_lambda(resp):
            return resp

        # Check if wait timed out.
        elapse_time = datetime.datetime.utcnow() - start_time
        if elapse_time.seconds >= timeout_sec and timeout_sec != 0:
            raise RuntimeError(
                f"Waiting for pod `{namespace}:{pod_name}` timed out after "
                f"{timeout_sec} seconds."
            )

        # Wait (using exponential backoff).
        time.sleep(backoff_interval)
        if exponential_backoff and backoff_interval < maximum_backoff:
            backoff_interval *= 2
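
A usage sketch under assumed values (the pod name and namespace are hypothetical; the client is the standard kubernetes Python client):

from kubernetes import client as k8s_client

from zenml.integrations.kubernetes.orchestrators import kube_utils

kube_utils.load_kube_config()  # falls back to the local kubeconfig
core_api = k8s_client.CoreV1Api()

# Poll with exponential backoff until the pod has succeeded, streaming
# its logs, and give up after 10 minutes.
pod = kube_utils.wait_pod(
    core_api=core_api,
    pod_name="zenml-run-abc123",
    namespace="zenml",
    exit_condition_lambda=kube_utils.pod_is_done,
    timeout_sec=600,
    exponential_backoff=True,
    stream_logs=True,
)
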
kubernetes_orchestrator

Kubernetes-native orchestrator.

KubernetesOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator for running ZenML pipelines using native Kubernetes.

Attributes:

Name Type Description
custom_docker_base_image_name Optional[str]

Name of a Docker image that should be used as the base for the image that will be run on Kubernetes pods. If no custom image is given, a basic image of the active ZenML version will be used. Note: This image needs to have ZenML installed, otherwise the pipeline execution will fail. For that reason, you might want to extend the ZenML Docker images found here: https://hub.docker.com/r/zenmldocker/zenml/

kubernetes_context Optional[str]

Optional name of a Kubernetes context to run pipelines in. If not set, the current active context will be used. You can find the active context by running kubectl config current-context.

kubernetes_namespace str

Name of the Kubernetes namespace to be used. If not provided, the "zenml" namespace will be used.

synchronous bool

If True, running a pipeline using this orchestrator will block until all steps finished running on Kubernetes.

skip_config_loading bool

If True, don't load the Kubernetes context and clients. This is only useful for unit testing.
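
For illustration, such an orchestrator could be registered with something like `zenml orchestrator register k8s_orchestrator --flavor=kubernetes --kubernetes_context=<my-context> --synchronous=true` (flag names assumed to mirror the attributes above, following the `zenml orchestrator update` pattern used in the validator below).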

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
class KubernetesOrchestrator(BaseOrchestrator):
    """Orchestrator for running ZenML pipelines using native Kubernetes.

    Attributes:
        custom_docker_base_image_name: Name of a Docker image that should be
            used as the base for the image that will be run on Kubernetes pods.
            If no custom image is given, a basic image of the active ZenML
            version will be used.
            **Note**: This image needs to have ZenML installed,
            otherwise the pipeline execution will fail. For that reason, you
            might want to extend the ZenML Docker images found here:
            https://hub.docker.com/r/zenmldocker/zenml/
        kubernetes_context: Optional name of a Kubernetes context to run
            pipelines in. If not set, the current active context will be used.
            You can find the active context by running `kubectl config
            current-context`.
        kubernetes_namespace: Name of the Kubernetes namespace to be used.
            If not provided, the `zenml` namespace will be used.
        synchronous: If `True`, running a pipeline using this orchestrator will
            block until all steps finished running on Kubernetes.
        skip_config_loading: If `True`, don't load the Kubernetes context and
            clients. This is only useful for unit testing.
    """

    custom_docker_base_image_name: Optional[str] = None
    kubernetes_context: Optional[str] = None
    kubernetes_namespace: str = "zenml"
    synchronous: bool = False
    skip_config_loading: bool = False
    _k8s_core_api: k8s_client.CoreV1Api = None
    _k8s_batch_api: k8s_client.BatchV1beta1Api = None
    _k8s_rbac_api: k8s_client.RbacAuthorizationV1Api = None

    FLAVOR: ClassVar[str] = KUBERNETES_ORCHESTRATOR_FLAVOR

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the Pydantic object the Kubernetes clients.

        Args:
            *args: The positional arguments to pass to the Pydantic object.
            **kwargs: The keyword arguments to pass to the Pydantic object.
        """
        super().__init__(*args, **kwargs)
        self._initialize_k8s_clients()

    def _initialize_k8s_clients(self) -> None:
        """Initialize the Kubernetes clients."""
        if self.skip_config_loading:
            return
        kube_utils.load_kube_config(context=self.kubernetes_context)
        self._k8s_core_api = k8s_client.CoreV1Api()
        self._k8s_batch_api = k8s_client.BatchV1beta1Api()
        self._k8s_rbac_api = k8s_client.RbacAuthorizationV1Api()

    def get_kubernetes_contexts(self) -> Tuple[List[str], str]:
        """Get list of configured Kubernetes contexts and the active context.

        Raises:
            RuntimeError: if the Kubernetes configuration cannot be loaded.

        Returns:
            context_names: List of configured Kubernetes contexts.
            active_context_name: Name of the active Kubernetes context.
        """
        try:
            contexts, active_context = k8s_config.list_kube_config_contexts()
        except k8s_config.config_exception.ConfigException as e:
            raise RuntimeError(
                "Could not load the Kubernetes configuration"
            ) from e

        context_names = [c["name"] for c in contexts]
        active_context_name = active_context["name"]
        return context_names, active_context_name

    @property
    def validator(self) -> Optional[StackValidator]:
        """Defines the validator that checks whether the stack is valid.

        Returns:
            Stack validator.
        """

        def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]:
            """Validates that the stack contains no local components.

            Args:
                stack: The stack.

            Returns:
                Whether the stack is valid or not.
                An explanation why the stack is invalid, if applicable.
            """
            container_registry = stack.container_registry

            # should not happen, because the stack validation takes care of
            # this, but just in case
            assert container_registry is not None

            if not self.skip_config_loading:
                contexts, active_context = self.get_kubernetes_contexts()
                if self.kubernetes_context not in contexts:
                    return False, (
                        f"Could not find a Kubernetes context named "
                        f"'{self.kubernetes_context}' in the local Kubernetes "
                        f"configuration. Please make sure that the Kubernetes "
                        f"cluster is running and that the kubeconfig file is "
                        f"configured correctly. To list all configured "
                        f"contexts, run:\n\n"
                        f"  `kubectl config get-contexts`\n"
                    )
                if self.kubernetes_context != active_context:
                    logger.warning(
                        f"The Kubernetes context '{self.kubernetes_context}' "
                        f"configured for the Kubernetes orchestrator is not "
                        f"the same as the active context in the local "
                        f"Kubernetes configuration. If this is not deliberate,"
                        f" you should update the orchestrator's "
                        f"`kubernetes_context` field by running:\n\n"
                        f"  `zenml orchestrator update {self.name} "
                        f"--kubernetes_context={active_context}`\n"
                        f"To list all configured contexts, run:\n\n"
                        f"  `kubectl config get-contexts`\n"
                        f"To set the active context to be the same as the one "
                        f"configured in the Kubernetes orchestrator and "
                        f"silence this warning, run:\n\n"
                        f"  `kubectl config use-context "
                        f"{self.kubernetes_context}`\n"
                    )

            # Check that all stack components are non-local.
            for stack_comp in stack.components.values():
                if stack_comp.local_path:
                    return False, (
                        f"The Kubernetes orchestrator currently only supports "
                        f"remote stacks, but the '{stack_comp.name}' "
                        f"{stack_comp.TYPE.value} is a local component. "
                        f"Please make sure to only use non-local stack "
                        f"components with a Kubernetes orchestrator."
                    )

            # if the orchestrator is remote, the container registry must
            # also be remote.
            if container_registry.is_local:
                return False, (
                    f"The Kubernetes orchestrator requires a remote container "
                    f"registry, but the '{container_registry.name}' container "
                    f"registry of your active stack points to a local URI "
                    f"'{container_registry.uri}'. Please make sure stacks "
                    f"with a Kubernetes orchestrator always contain remote "
                    f"container registries."
                )

            return True, ""

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_validate_local_requirements,
        )

    def get_docker_image_name(self, pipeline_name: str) -> str:
        """Return the full Docker image name including registry and tag.

        Args:
            pipeline_name: Name of a ZenML pipeline.

        Returns:
            Docker image name.
        """
        container_registry = Repository().active_stack.container_registry
        assert container_registry
        registry_uri = container_registry.uri.rstrip("/")
        return f"{registry_uri}/zenml-kubernetes:{pipeline_name}"

    def prepare_pipeline_deployment(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Build a Docker image and upload it to the container registry.

        Args:
            pipeline: A ZenML pipeline.
            stack: A ZenML stack.
            runtime_configuration: The runtime configuration of the pipeline.
        """
        from zenml.utils import docker_utils

        image_name = self.get_docker_image_name(pipeline.name)

        requirements = {*stack.requirements(), *pipeline.requirements}

        logger.debug("Kubernetes container requirements: %s", requirements)

        docker_utils.build_docker_image(
            build_context_path=get_source_root_path(),
            image_name=image_name,
            dockerignore_path=pipeline.dockerignore_file,
            requirements=requirements,
            base_image=self.custom_docker_base_image_name,
        )

        assert stack.container_registry  # should never happen due to validation
        stack.container_registry.push_image(image_name)

        # Store the Docker image digest in the runtime configuration so it gets
        # tracked in the ZenStore
        image_digest = docker_utils.get_image_digest(image_name) or image_name
        runtime_configuration["docker_image"] = image_digest

    def prepare_or_run_pipeline(
        self,
        sorted_steps: List["BaseStep"],
        pipeline: "BasePipeline",
        pb2_pipeline: Pb2Pipeline,
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Run pipeline in Kubernetes.

        Args:
            sorted_steps: List of steps in execution order.
            pipeline: ZenML pipeline.
            pb2_pipeline: ZenML pipeline in TFX pb2 format.
            stack: ZenML stack.
            runtime_configuration: The runtime configuration of the pipeline.

        Raises:
            RuntimeError: If trying to run from a Jupyter notebook.
        """
        # First check whether the code is running in a notebook.
        if Environment.in_notebook():
            raise RuntimeError(
                "The Kubernetes orchestrator cannot run pipelines in a notebook "
                "environment. The reason is that it is non-trivial to create "
                "a Docker image of a notebook. Please consider refactoring "
                "your notebook cells into separate scripts in a Python module "
                "and run the code outside of a notebook when using this "
                "orchestrator."
            )

        assert runtime_configuration.run_name, "Run name must be set"

        run_name = runtime_configuration.run_name
        pipeline_name = pipeline.name
        pod_name = kube_utils.sanitize_pod_name(run_name)

        # Get Docker image name (for all pods).
        image_name = self.get_docker_image_name(pipeline.name)
        image_name = get_image_digest(image_name) or image_name

        # Get pipeline DAG as dict {"step": ["upstream_step_1", ...], ...}
        pipeline_dag: Dict[str, List[str]] = {
            step.name: self.get_upstream_step_names(step, pb2_pipeline)
            for step in sorted_steps
        }

        # Build entrypoint command and args for the orchestrator pod.
        # This will internally also build the command/args for all step pods.
        command = (
            KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_command()
        )
        args = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_arguments(
            run_name=run_name,
            pipeline_name=pipeline_name,
            image_name=image_name,
            kubernetes_namespace=self.kubernetes_namespace,
            pb2_pipeline=pb2_pipeline,
            sorted_steps=sorted_steps,
            pipeline_dag=pipeline_dag,
        )

        # Authorize pod to run Kubernetes commands inside the cluster.
        service_account_name = "zenml-service-account"
        kube_utils.create_edit_service_account(
            core_api=self._k8s_core_api,
            rbac_api=self._k8s_rbac_api,
            service_account_name=service_account_name,
            namespace=self.kubernetes_namespace,
        )

        # Schedule as CRON job if CRON schedule is given.
        if runtime_configuration.schedule:
            if not runtime_configuration.schedule.cron_expression:
                raise RuntimeError(
                    "The Kubernetes orchestrator only supports scheduling via "
                    "CRON jobs, but the run was configured with a manual "
                    "schedule. Use `Schedule(cron_expression=...)` instead."
                )
            cron_expression = runtime_configuration.schedule.cron_expression
            cron_job_manifest = build_cron_job_manifest(
                cron_expression=cron_expression,
                run_name=run_name,
                pod_name=pod_name,
                pipeline_name=pipeline_name,
                image_name=image_name,
                command=command,
                args=args,
                service_account_name=service_account_name,
            )
            self._k8s_batch_api.create_namespaced_cron_job(
                body=cron_job_manifest, namespace=self.kubernetes_namespace
            )
            logger.info(
                f"Scheduling Kubernetes run `{pod_name}` with CRON expression "
                f'`"{cron_expression}"`.'
            )
            return

        # Create and run the orchestrator pod.
        pod_manifest = build_pod_manifest(
            run_name=run_name,
            pod_name=pod_name,
            pipeline_name=pipeline_name,
            image_name=image_name,
            command=command,
            args=args,
            service_account_name=service_account_name,
        )
        self._k8s_core_api.create_namespaced_pod(
            namespace=self.kubernetes_namespace,
            body=pod_manifest,
        )

        # Wait for the orchestrator pod to finish and stream logs.
        if self.synchronous:
            logger.info("Waiting for Kubernetes orchestrator pod...")
            kube_utils.wait_pod(
                core_api=self._k8s_core_api,
                pod_name=pod_name,
                namespace=self.kubernetes_namespace,
                exit_condition_lambda=kube_utils.pod_is_done,
                stream_logs=True,
            )
        else:
            logger.info(
                f"Orchestration started asynchronously in pod "
                f"`{self.kubernetes_namespace}:{pod_name}`. "
                f"Run the following command to inspect the logs: "
                f"`kubectl logs {pod_name} -n {self.kubernetes_namespace}`."
            )
validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Defines the validator that checks whether the stack is valid.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

Stack validator.

__init__(self, *args, **kwargs) special

Initialize the Pydantic object and the Kubernetes clients.

Parameters:

Name Type Description Default
*args Any

The positional arguments to pass to the Pydantic object.

()
**kwargs Any

The keyword arguments to pass to the Pydantic object.

{}
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialize the Pydantic object the Kubernetes clients.

    Args:
        *args: The positional arguments to pass to the Pydantic object.
        **kwargs: The keyword arguments to pass to the Pydantic object.
    """
    super().__init__(*args, **kwargs)
    self._initialize_k8s_clients()
get_docker_image_name(self, pipeline_name)

Return the full Docker image name including registry and tag.

Parameters:

Name Type Description Default
pipeline_name str

Name of a ZenML pipeline.

required

Returns:

Type Description
str

Docker image name.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def get_docker_image_name(self, pipeline_name: str) -> str:
    """Return the full Docker image name including registry and tag.

    Args:
        pipeline_name: Name of a ZenML pipeline.

    Returns:
        Docker image name.
    """
    container_registry = Repository().active_stack.container_registry
    assert container_registry
    registry_uri = container_registry.uri.rstrip("/")
    return f"{registry_uri}/zenml-kubernetes:{pipeline_name}"
get_kubernetes_contexts(self)

Get list of configured Kubernetes contexts and the active context.

Exceptions:

Type Description
RuntimeError

if the Kubernetes configuration cannot be loaded.

Returns:

Type Description
context_names

List of configured Kubernetes contexts.

active_context_name

Name of the active Kubernetes context.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def get_kubernetes_contexts(self) -> Tuple[List[str], str]:
    """Get list of configured Kubernetes contexts and the active context.

    Raises:
        RuntimeError: if the Kubernetes configuration cannot be loaded.

    Returns:
        context_names: List of configured Kubernetes contexts.
        active_context_name: Name of the active Kubernetes context.
    """
    try:
        contexts, active_context = k8s_config.list_kube_config_contexts()
    except k8s_config.config_exception.ConfigException as e:
        raise RuntimeError(
            "Could not load the Kubernetes configuration"
        ) from e

    context_names = [c["name"] for c in contexts]
    active_context_name = active_context["name"]
    return context_names, active_context_name
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)

Run pipeline in Kubernetes.

Parameters:

Name Type Description Default
sorted_steps List[BaseStep]

List of steps in execution order.

required
pipeline BasePipeline

ZenML pipeline.

required
pb2_pipeline Pipeline

ZenML pipeline in TFX pb2 format.

required
stack Stack

ZenML stack.

required
runtime_configuration RuntimeConfiguration

The runtime configuration of the pipeline.

required

Exceptions:

Type Description
RuntimeError

If trying to run from a Jupyter notebook.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List["BaseStep"],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Run pipeline in Kubernetes.

    Args:
        sorted_steps: List of steps in execution order.
        pipeline: ZenML pipeline.
        pb2_pipeline: ZenML pipeline in TFX pb2 format.
        stack: ZenML stack.
        runtime_configuration: The runtime configuration of the pipeline.

    Raises:
        RuntimeError: If trying to run from a Jupyter notebook.
    """
    # First check whether the code is running in a notebook.
    if Environment.in_notebook():
        raise RuntimeError(
            "The Kubernetes orchestrator cannot run pipelines in a notebook "
            "environment. The reason is that it is non-trivial to create "
            "a Docker image of a notebook. Please consider refactoring "
            "your notebook cells into separate scripts in a Python module "
            "and run the code outside of a notebook when using this "
            "orchestrator."
        )

    assert runtime_configuration.run_name, "Run name must be set"

    run_name = runtime_configuration.run_name
    pipeline_name = pipeline.name
    pod_name = kube_utils.sanitize_pod_name(run_name)

    # Get Docker image name (for all pods).
    image_name = self.get_docker_image_name(pipeline.name)
    image_name = get_image_digest(image_name) or image_name

    # Get pipeline DAG as dict {"step": ["upstream_step_1", ...], ...}
    pipeline_dag: Dict[str, List[str]] = {
        step.name: self.get_upstream_step_names(step, pb2_pipeline)
        for step in sorted_steps
    }

    # Build entrypoint command and args for the orchestrator pod.
    # This will internally also build the command/args for all step pods.
    command = (
        KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_command()
    )
    args = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_arguments(
        run_name=run_name,
        pipeline_name=pipeline_name,
        image_name=image_name,
        kubernetes_namespace=self.kubernetes_namespace,
        pb2_pipeline=pb2_pipeline,
        sorted_steps=sorted_steps,
        pipeline_dag=pipeline_dag,
    )

    # Authorize pod to run Kubernetes commands inside the cluster.
    service_account_name = "zenml-service-account"
    kube_utils.create_edit_service_account(
        core_api=self._k8s_core_api,
        rbac_api=self._k8s_rbac_api,
        service_account_name=service_account_name,
        namespace=self.kubernetes_namespace,
    )

    # Schedule as CRON job if CRON schedule is given.
    if runtime_configuration.schedule:
        if not runtime_configuration.schedule.cron_expression:
            raise RuntimeError(
                "The Kubernetes orchestrator only supports scheduling via "
                "CRON jobs, but the run was configured with a manual "
                "schedule. Use `Schedule(cron_expression=...)` instead."
            )
        cron_expression = runtime_configuration.schedule.cron_expression
        cron_job_manifest = build_cron_job_manifest(
            cron_expression=cron_expression,
            run_name=run_name,
            pod_name=pod_name,
            pipeline_name=pipeline_name,
            image_name=image_name,
            command=command,
            args=args,
            service_account_name=service_account_name,
        )
        self._k8s_batch_api.create_namespaced_cron_job(
            body=cron_job_manifest, namespace=self.kubernetes_namespace
        )
        logger.info(
            f"Scheduling Kubernetes run `{pod_name}` with CRON expression "
            f'`"{cron_expression}"`.'
        )
        return

    # Create and run the orchestrator pod.
    pod_manifest = build_pod_manifest(
        run_name=run_name,
        pod_name=pod_name,
        pipeline_name=pipeline_name,
        image_name=image_name,
        command=command,
        args=args,
        service_account_name=service_account_name,
    )
    self._k8s_core_api.create_namespaced_pod(
        namespace=self.kubernetes_namespace,
        body=pod_manifest,
    )

    # Wait for the orchestrator pod to finish and stream logs.
    if self.synchronous:
        logger.info("Waiting for Kubernetes orchestrator pod...")
        kube_utils.wait_pod(
            core_api=self._k8s_core_api,
            pod_name=pod_name,
            namespace=self.kubernetes_namespace,
            exit_condition_lambda=kube_utils.pod_is_done,
            stream_logs=True,
        )
    else:
        logger.info(
            f"Orchestration started asynchronously in pod "
            f"`{self.kubernetes_namespace}:{pod_name}`. "
            f"Run the following command to inspect the logs: "
            f"`kubectl logs {pod_name} -n {self.kubernetes_namespace}`."
        )
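
The CRON branch above is the only scheduling mode this orchestrator supports. A minimal sketch of triggering it from client code, assuming `my_pipeline` is an already-assembled ZenML pipeline instance (a hypothetical name) and the active stack uses this orchestrator:

```python
from zenml.pipelines import Schedule

# `my_pipeline` is hypothetical. With the Kubernetes orchestrator in the
# active stack, a cron-based schedule makes `prepare_or_run_pipeline`
# create a CronJob instead of a one-off orchestrator pod.
my_pipeline.run(schedule=Schedule(cron_expression="*/15 * * * *"))
```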
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)

Build a Docker image and upload it to the container registry.

Parameters:

Name Type Description Default
pipeline BasePipeline

A ZenML pipeline.

required
stack Stack

A ZenML stack.

required
runtime_configuration RuntimeConfiguration

The runtime configuration of the pipeline.

required
Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Build a Docker image and upload it to the container registry.

    Args:
        pipeline: A ZenML pipeline.
        stack: A ZenML stack.
        runtime_configuration: The runtime configuration of the pipeline.
    """
    from zenml.utils import docker_utils

    image_name = self.get_docker_image_name(pipeline.name)

    requirements = {*stack.requirements(), *pipeline.requirements}

    logger.debug("Kubernetes container requirements: %s", requirements)

    docker_utils.build_docker_image(
        build_context_path=get_source_root_path(),
        image_name=image_name,
        dockerignore_path=pipeline.dockerignore_file,
        requirements=requirements,
        base_image=self.custom_docker_base_image_name,
    )

    assert stack.container_registry  # should never happen due to validation
    stack.container_registry.push_image(image_name)

    # Store the Docker image digest in the runtime configuration so it gets
    # tracked in the ZenStore
    image_digest = docker_utils.get_image_digest(image_name) or image_name
    runtime_configuration["docker_image"] = image_digest
kubernetes_orchestrator_entrypoint

Entrypoint of the Kubernetes master/orchestrator pod.

main()

Entrypoint of the k8s master/orchestrator pod.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py
def main() -> None:
    """Entrypoint of the k8s master/orchestrator pod."""
    # Log to the container's stdout so it can be streamed by the client.
    logger.info("Kubernetes orchestrator pod started.")

    # Parse / extract args.
    args = parse_args()
    pipeline_config = args.pipeline_config
    step_command = pipeline_config["step_command"]
    fixed_step_args = pipeline_config["fixed_step_args"]
    step_specific_args = pipeline_config["step_specific_args"]
    pipeline_dag = pipeline_config["pipeline_dag"]

    # Get Kubernetes Core API for running kubectl commands later.
    kube_utils.load_kube_config()
    core_api = k8s_client.CoreV1Api()

    # Patch run name (only needed for CRON scheduling)
    run_name = patch_run_name_for_cron_scheduling(
        args.run_name, fixed_step_args
    )

    def run_step_on_kubernetes(step_name: str) -> None:
        """Run a pipeline step in a separate Kubernetes pod.

        Args:
            step_name: Name of the step.
        """
        # Define Kubernetes pod name.
        pod_name = f"{run_name}-{step_name}"
        pod_name = kube_utils.sanitize_pod_name(pod_name)

        # Build list of args for this step.
        step_args = [*fixed_step_args, *step_specific_args[step_name]]

        # Define Kubernetes pod manifest.
        pod_manifest = build_pod_manifest(
            pod_name=pod_name,
            run_name=run_name,
            pipeline_name=args.pipeline_name,
            image_name=args.image_name,
            command=step_command,
            args=step_args,
        )

        # Create and run pod.
        core_api.create_namespaced_pod(
            namespace=args.kubernetes_namespace,
            body=pod_manifest,
        )

        # Wait for pod to finish.
        logger.info(f"Waiting for pod of step `{step_name}` to start...")
        kube_utils.wait_pod(
            core_api=core_api,
            pod_name=pod_name,
            namespace=args.kubernetes_namespace,
            exit_condition_lambda=kube_utils.pod_is_done,
            stream_logs=True,
        )
        logger.info(f"Pod of step `{step_name}` completed.")

    ThreadedDagRunner(dag=pipeline_dag, run_fn=run_step_on_kubernetes).run()

    logger.info("Orchestration pod completed.")
parse_args()

Parse entrypoint arguments.

Returns:

Type Description
Namespace

Parsed args.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py
def parse_args() -> argparse.Namespace:
    """Parse entrypoint arguments.

    Returns:
        Parsed args.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_name", type=str, required=True)
    parser.add_argument("--pipeline_name", type=str, required=True)
    parser.add_argument("--image_name", type=str, required=True)
    parser.add_argument("--kubernetes_namespace", type=str, required=True)
    parser.add_argument("--pipeline_config", type=json.loads, required=True)
    return parser.parse_args()
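
Because `--pipeline_config` is declared with `type=json.loads`, all complex data reaches the pod as a single JSON string and is decoded back into nested Python objects. A self-contained sketch (step and command names are illustrative):

```python
import json

# Hypothetical, minimal pipeline_config as it would be serialized by
# KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_arguments():
pipeline_config_json = json.dumps(
    {
        "sorted_steps": ["importer", "trainer"],
        "step_command": ["python", "-m", "some.entrypoint"],
        "fixed_step_args": ["--run_name", "run-1"],
        "step_specific_args": {"importer": [], "trainer": []},
        "pipeline_dag": {"importer": [], "trainer": ["importer"]},
    }
)

# argparse applies `json.loads` to the raw CLI string, so the pod gets the
# original nested dict back:
assert json.loads(pipeline_config_json)["pipeline_dag"]["trainer"] == ["importer"]
```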
patch_run_name_for_cron_scheduling(run_name, fixed_step_args)

Adjust run name according to the Kubernetes orchestrator pod name.

This is required for scheduling via CRON jobs, since each job would otherwise have the same run name, which ZenML does not support.

Parameters:

Name Type Description Default
run_name str

Initial run name.

required
fixed_step_args List[str]

Fixed entrypoint args for the step pods. We also need to patch the run name in there.

required

Returns:

Type Description
str

New unique run name.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py
def patch_run_name_for_cron_scheduling(
    run_name: str, fixed_step_args: List[str]
) -> str:
    """Adjust run name according to the Kubernetes orchestrator pod name.

    This is required for scheduling via CRON jobs, since each job would
    otherwise have the same run name, which ZenML does not support.

    Args:
        run_name: Initial run name.
        fixed_step_args: Fixed entrypoint args for the step pods.
            We also need to patch the run name in there.

    Returns:
        New unique run name.
    """
    # Get name of the orchestrator pod.
    host_name = socket.gethostname()

    # If we are not running as CRON job, we don't need to do anything.
    if host_name == kube_utils.sanitize_pod_name(run_name):
        return run_name

    # Otherwise, define new run_name.
    job_id = host_name.split("-")[-1]
    run_name = f"{run_name}-{job_id}"

    # Then also adjust run_name in fixed_step_args.
    for i, arg in enumerate(fixed_step_args):
        if arg == "--run_name":
            fixed_step_args[i + 1] = run_name

    return run_name
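
A conceptual trace of this logic with an illustrative hostname (the numeric job id is made up):

```python
# Inside a CRON-triggered orchestrator pod:
#
#   run_name  = "my-run"
#   host_name = socket.gethostname()  # e.g. "my-run-27649885"
#
# Since host_name != sanitize_pod_name(run_name), the run is CRON-triggered,
# so the job id ("27649885") is appended and the value following
# "--run_name" is rewritten in place:
fixed_step_args = ["--run_name", "my-run", "--image_name", "my-image"]
run_name = patch_run_name_for_cron_scheduling("my-run", fixed_step_args)
# -> fixed_step_args == ["--run_name", run_name, "--image_name", "my-image"]
```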
kubernetes_orchestrator_entrypoint_configuration

Entrypoint configuration for the Kubernetes master/orchestrator pod.

KubernetesOrchestratorEntrypointConfiguration

Entrypoint configuration for the k8s master/orchestrator pod.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
class KubernetesOrchestratorEntrypointConfiguration:
    """Entrypoint configuration for the k8s master/orchestrator pod."""

    @classmethod
    def get_entrypoint_options(cls) -> Set[str]:
        """Gets all the options required for running this entrypoint.

        Returns:
            Entrypoint options.
        """
        options = {
            RUN_NAME_OPTION,
            PIPELINE_NAME_OPTION,
            IMAGE_NAME_OPTION,
            NAMESPACE_OPTION,
            PIPELINE_CONFIG_OPTION,
        }
        return options

    @classmethod
    def get_entrypoint_command(cls) -> List[str]:
        """Returns a command that runs the entrypoint module.

        Returns:
            Entrypoint command.
        """
        command = [
            "python",
            "-m",
            "zenml.integrations.kubernetes.orchestrators.kubernetes_orchestrator_entrypoint",
        ]
        return command

    @classmethod
    def get_entrypoint_arguments(
        cls,
        run_name: str,
        pipeline_name: str,
        image_name: str,
        kubernetes_namespace: str,
        pb2_pipeline: Pb2Pipeline,
        sorted_steps: List[BaseStep],
        pipeline_dag: Dict[str, List[str]],
    ) -> List[str]:
        """Gets all arguments that the entrypoint command should be called with.

        Args:
            run_name: Name of the ZenML run.
            pipeline_name: Name of the ZenML pipeline.
            image_name: Name of the Docker image.
            kubernetes_namespace: Name of the Kubernetes namespace.
            pb2_pipeline: ZenML pipeline in TFX pb2 format.
            sorted_steps: List of steps in execution order.
            pipeline_dag: For each step, list of steps that need to run before.

        Returns:
            List of entrypoint arguments.
        """

        def _get_step_args(step: BaseStep) -> List[str]:
            """Get the entrypoint args for a specific step.

            Args:
                step: ZenML step for which to get entrypoint args.

            Returns:
                Entrypoint args of the step.
            """
            return (
                KubernetesStepEntrypointConfiguration.get_entrypoint_arguments(
                    step=step,
                    pb2_pipeline=pb2_pipeline,
                    **{RUN_NAME_OPTION: run_name},
                )
            )

        # Get name, command, and args for each step
        step_names = [step.name for step in sorted_steps]
        step_command = (
            KubernetesStepEntrypointConfiguration.get_entrypoint_command()
        )
        fixed_step_args = []
        if len(sorted_steps) > 0:
            first_step_args = _get_step_args(sorted_steps[0])
            fixed_step_args = split_step_args(first_step_args)[0]
        step_specific_args = {
            step.name: split_step_args(_get_step_args(step))[1]
            for step in sorted_steps
        }  # e.g.: {"trainer": train_step_args, ...}

        # Serialize all complex datatype args into a single JSON string
        pipeline_config = {
            "sorted_steps": step_names,
            "step_command": step_command,
            "fixed_step_args": fixed_step_args,
            "step_specific_args": step_specific_args,
            "pipeline_dag": pipeline_dag,
        }
        pipeline_config_json = json.dumps(pipeline_config)

        # Define entrypoint args.
        args = [
            f"--{RUN_NAME_OPTION}",
            run_name,
            f"--{PIPELINE_NAME_OPTION}",
            pipeline_name,
            f"--{IMAGE_NAME_OPTION}",
            image_name,
            f"--{NAMESPACE_OPTION}",
            kubernetes_namespace,
            f"--{PIPELINE_CONFIG_OPTION}",
            pipeline_config_json,
        ]

        return args
get_entrypoint_arguments(run_name, pipeline_name, image_name, kubernetes_namespace, pb2_pipeline, sorted_steps, pipeline_dag) classmethod

Gets all arguments that the entrypoint command should be called with.

Parameters:

Name Type Description Default
run_name str

Name of the ZenML run.

required
pipeline_name str

Name of the ZenML pipeline.

required
image_name str

Name of the Docker image.

required
kubernetes_namespace str

Name of the Kubernetes namespace.

required
pb2_pipeline Pipeline

ZenML pipeline in TFX pb2 format.

required
sorted_steps List[zenml.steps.base_step.BaseStep]

List of steps in execution order.

required
pipeline_dag Dict[str, List[str]]

For each step, list of steps that need to run before.

required

Returns:

Type Description
List[str]

List of entrypoint arguments.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
@classmethod
def get_entrypoint_arguments(
    cls,
    run_name: str,
    pipeline_name: str,
    image_name: str,
    kubernetes_namespace: str,
    pb2_pipeline: Pb2Pipeline,
    sorted_steps: List[BaseStep],
    pipeline_dag: Dict[str, List[str]],
) -> List[str]:
    """Gets all arguments that the entrypoint command should be called with.

    Args:
        run_name: Name of the ZenML run.
        pipeline_name: Name of the ZenML pipeline.
        image_name: Name of the Docker image.
        kubernetes_namespace: Name of the Kubernetes namespace.
        pb2_pipeline: ZenML pipeline in TFX pb2 format.
        sorted_steps: List of steps in execution order.
        pipeline_dag: For each step, list of steps that need to run before.

    Returns:
        List of entrypoint arguments.
    """

    def _get_step_args(step: BaseStep) -> List[str]:
        """Get the entrypoint args for a specific step.

        Args:
            step: ZenML step for which to get entrypoint args.

        Returns:
            Entrypoint args of the step.
        """
        return (
            KubernetesStepEntrypointConfiguration.get_entrypoint_arguments(
                step=step,
                pb2_pipeline=pb2_pipeline,
                **{RUN_NAME_OPTION: run_name},
            )
        )

    # Get name, command, and args for each step
    step_names = [step.name for step in sorted_steps]
    step_command = (
        KubernetesStepEntrypointConfiguration.get_entrypoint_command()
    )
    fixed_step_args = []
    if len(sorted_steps) > 0:
        first_step_args = _get_step_args(sorted_steps[0])
        fixed_step_args = split_step_args(first_step_args)[0]
    step_specific_args = {
        step.name: split_step_args(_get_step_args(step))[1]
        for step in sorted_steps
    }  # e.g.: {"trainer": train_step_args, ...}

    # Serialize all complex datatype args into a single JSON string
    pipeline_config = {
        "sorted_steps": step_names,
        "step_command": step_command,
        "fixed_step_args": fixed_step_args,
        "step_specific_args": step_specific_args,
        "pipeline_dag": pipeline_dag,
    }
    pipeline_config_json = json.dumps(pipeline_config)

    # Define entrypoint args.
    args = [
        f"--{RUN_NAME_OPTION}",
        run_name,
        f"--{PIPELINE_NAME_OPTION}",
        pipeline_name,
        f"--{IMAGE_NAME_OPTION}",
        image_name,
        f"--{NAMESPACE_OPTION}",
        kubernetes_namespace,
        f"--{PIPELINE_CONFIG_OPTION}",
        pipeline_config_json,
    ]

    return args
get_entrypoint_command() classmethod

Returns a command that runs the entrypoint module.

Returns:

Type Description
List[str]

Entrypoint command.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
@classmethod
def get_entrypoint_command(cls) -> List[str]:
    """Returns a command that runs the entrypoint module.

    Returns:
        Entrypoint command.
    """
    command = [
        "python",
        "-m",
        "zenml.integrations.kubernetes.orchestrators.kubernetes_orchestrator_entrypoint",
    ]
    return command
get_entrypoint_options() classmethod

Gets all the options required for running this entrypoint.

Returns:

Type Description
Set[str]

Entrypoint options.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
    """Gets all the options required for running this entrypoint.

    Returns:
        Entrypoint options.
    """
    options = {
        RUN_NAME_OPTION,
        PIPELINE_NAME_OPTION,
        IMAGE_NAME_OPTION,
        NAMESPACE_OPTION,
        PIPELINE_CONFIG_OPTION,
    }
    return options
split_step_args(step_args)

Split step args into fixed and step-specific.

We want to have them separate so we can send the fixed args to the orchestrator pod only once.

Parameters:

Name Type Description Default
step_args List[str]

list of ALL step args. E.g. ["--arg1", "arg1_value", "--arg2", "arg2_value", ...].

required

Returns:

Type Description
Tuple[List[str], List[str]]

Tuple (fixed step args, step-specific args).

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py
def split_step_args(step_args: List[str]) -> Tuple[List[str], List[str]]:
    """Split step args into fixed and step-specific.

    We want to have them separate so we can send the fixed args to the
    orchestrator pod only once.

    Args:
        step_args: list of ALL step args.
            E.g. ["--arg1", "arg1_value", "--arg2", "arg2_value", ...].

    Returns:
        Tuple (fixed step args, step-specific args).
    """
    fixed_args = []
    step_specific_args = []
    for i, arg in enumerate(step_args):
        if not arg.startswith("--"):  # arg is a value, not an option
            continue
        option_and_value = step_args[i : i + 2]  # e.g. ["--name", "Aria"]
        is_fixed = arg[2:] not in STEP_SPECIFIC_STEP_ENTRYPOINT_OPTIONS
        if is_fixed:
            fixed_args += option_and_value
        else:
            step_specific_args += option_and_value
    return fixed_args, step_specific_args
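
A usage sketch with illustrative option names; it assumes, for the sake of the example, that `step_source` is listed in `STEP_SPECIFIC_STEP_ENTRYPOINT_OPTIONS` (that constant is not shown on this page) while `run_name` is not:

```python
step_args = [
    "--run_name", "run-1",             # same for every step -> fixed
    "--step_source", "steps.trainer",  # differs per step -> step-specific
]
fixed, specific = split_step_args(step_args)
# fixed    == ["--run_name", "run-1"]
# specific == ["--step_source", "steps.trainer"]
```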
kubernetes_step_entrypoint_configuration

Entrypoint configuration for the Kubernetes worker/step pods.

KubernetesStepEntrypointConfiguration (StepEntrypointConfiguration)

Entrypoint configuration for running steps on Kubernetes.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
class KubernetesStepEntrypointConfiguration(StepEntrypointConfiguration):
    """Entrypoint configuration for running steps on Kubernetes."""

    @classmethod
    def get_custom_entrypoint_options(cls) -> Set[str]:
        """Kubernetes specific entrypoint options.

        The argument `RUN_NAME_OPTION` is needed for `get_run_name` to have
        consistent values between steps.

        Returns:
            Set of entrypoint options.
        """
        return {RUN_NAME_OPTION}

    @classmethod
    def get_custom_entrypoint_arguments(
        cls, step: "BaseStep", *args: Any, **kwargs: Any
    ) -> List[str]:
        """Kubernetes specific entrypoint arguments.

        Sets the value for the `RUN_NAME_OPTION` argument.

        Args:
            step: ZenML step for which the entrypoint is built.
            args: additional (unused) arguments.
            kwargs: keyword args; needs to include `RUN_NAME_OPTION`.

        Returns:
            List of entrypoint arguments.
        """
        return [
            f"--{RUN_NAME_OPTION}",
            kwargs[RUN_NAME_OPTION],
        ]

    def get_run_name(self, pipeline_name: str) -> str:
        """Returns the ZenML run name.

        Args:
            pipeline_name: Name of the ZenML pipeline (unused).

        Returns:
            ZenML run name.
        """
        job_id: str = self.entrypoint_args[RUN_NAME_OPTION]
        return job_id
get_custom_entrypoint_arguments(step, *args, **kwargs) classmethod

Kubernetes specific entrypoint arguments.

Sets the value for the RUN_NAME_OPTION argument.

Parameters:

Name Type Description Default
step BaseStep

ZenML step for which the entrypoint is built.

required
args Any

additional (unused) arguments.

()
kwargs Any

keyword args; needs to include RUN_NAME_OPTION.

{}

Returns:

Type Description
List[str]

List of entrypoint arguments.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_arguments(
    cls, step: "BaseStep", *args: Any, **kwargs: Any
) -> List[str]:
    """Kubernetes specific entrypoint arguments.

    Sets the value for the `RUN_NAME_OPTION` argument.

    Args:
        step: ZenML step for which the entrypoint is built.
        args: additional (unused) arguments.
        kwargs: keyword args; needs to include `RUN_NAME_OPTION`.

    Returns:
        List of entrypoint arguments.
    """
    return [
        f"--{RUN_NAME_OPTION}",
        kwargs[RUN_NAME_OPTION],
    ]
get_custom_entrypoint_options() classmethod

Kubernetes specific entrypoint options.

The argument RUN_NAME_OPTION is needed for get_run_name to have consistent values between steps.

Returns:

Type Description
Set[str]

Set of entrypoint options.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
@classmethod
def get_custom_entrypoint_options(cls) -> Set[str]:
    """Kubernetes specific entrypoint options.

    The argument `RUN_NAME_OPTION` is needed for `get_run_name` to have
    consistent values between steps.

    Returns:
        Set of entrypoint options.
    """
    return {RUN_NAME_OPTION}
get_run_name(self, pipeline_name)

Returns the ZenML run name.

Parameters:

Name Type Description Default
pipeline_name str

Name of the ZenML pipeline (unused).

required

Returns:

Type Description
str

ZenML run name.

Source code in zenml/integrations/kubernetes/orchestrators/kubernetes_step_entrypoint_configuration.py
def get_run_name(self, pipeline_name: str) -> str:
    """Returns the ZenML run name.

    Args:
        pipeline_name: Name of the ZenML pipeline (unused).

    Returns:
        ZenML run name.
    """
    job_id: str = self.entrypoint_args[RUN_NAME_OPTION]
    return job_id
manifest_utils

Utility functions for building manifests for k8s pods.

build_cluster_role_binding_manifest_for_service_account(name, role_name, service_account_name, namespace='default')

Build a manifest for a cluster role binding of a service account.

Parameters:

Name Type Description Default
name str

Name of the cluster role binding.

required
role_name str

Name of the role.

required
service_account_name str

Name of the service account.

required
namespace str

Kubernetes namespace. Defaults to "default".

'default'

Returns:

Type Description
Dict[str, Any]

Manifest for a cluster role binding of a service account.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_cluster_role_binding_manifest_for_service_account(
    name: str,
    role_name: str,
    service_account_name: str,
    namespace: str = "default",
) -> Dict[str, Any]:
    """Build a manifest for a cluster role binding of a service account.

    Args:
        name: Name of the cluster role binding.
        role_name: Name of the role.
        service_account_name: Name of the service account.
        namespace: Kubernetes namespace. Defaults to "default".

    Returns:
        Manifest for a cluster role binding of a service account.
    """
    return {
        "apiVersion": "rbac.authorization.k8s.io/v1",
        "kind": "ClusterRoleBinding",
        "metadata": {"name": name},
        "subjects": [
            {
                "kind": "ServiceAccount",
                "name": service_account_name,
                "namespace": namespace,
            }
        ],
        "roleRef": {
            "kind": "ClusterRole",
            "name": role_name,
            "apiGroup": "rbac.authorization.k8s.io",
        },
    }
build_cron_job_manifest(cron_expression, pod_name, run_name, pipeline_name, image_name, command, args, service_account_name=None)

Create a manifest for launching a pod as scheduled CRON job.

Parameters:

Name Type Description Default
cron_expression str

CRON job schedule expression, e.g. "* * * * *".

required
pod_name str

Name of the pod.

required
run_name str

Name of the ZenML run.

required
pipeline_name str

Name of the ZenML pipeline.

required
image_name str

Name of the Docker image.

required
command List[str]

Command to execute the entrypoint in the pod.

required
args List[str]

Arguments provided to the entrypoint command.

required
service_account_name Optional[str]

Optional name of a service account. Can be used to assign certain roles to a pod, e.g., to allow it to run Kubernetes commands from within the cluster.

None

Returns:

Type Description
Dict[str, Any]

CRON job manifest.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_cron_job_manifest(
    cron_expression: str,
    pod_name: str,
    run_name: str,
    pipeline_name: str,
    image_name: str,
    command: List[str],
    args: List[str],
    service_account_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Create a manifest for launching a pod as scheduled CRON job.

    Args:
        cron_expression: CRON job schedule expression, e.g. "* * * * *".
        pod_name: Name of the pod.
        run_name: Name of the ZenML run.
        pipeline_name: Name of the ZenML pipeline.
        image_name: Name of the Docker image.
        command: Command to execute the entrypoint in the pod.
        args: Arguments provided to the entrypoint command.
        service_account_name: Optional name of a service account.
            Can be used to assign certain roles to a pod, e.g., to allow it to
            run Kubernetes commands from within the cluster.

    Returns:
        CRON job manifest.
    """
    pod_manifest = build_pod_manifest(
        pod_name=pod_name,
        run_name=run_name,
        pipeline_name=pipeline_name,
        image_name=image_name,
        command=command,
        args=args,
        service_account_name=service_account_name,
    )
    return {
        "apiVersion": "batch/v1beta1",
        "kind": "CronJob",
        "metadata": pod_manifest["metadata"],
        "spec": {
            "schedule": cron_expression,
            "jobTemplate": {
                "metadata": pod_manifest["metadata"],
                "spec": {"template": {"spec": pod_manifest["spec"]}},
            },
        },
    }
build_mysql_deployment_manifest(name='mysql', namespace='default', port=3306, pv_claim_name='mysql-pv-claim')

Build a manifest for deploying a MySQL database.

Parameters:

Name Type Description Default
name str

Name of the deployment. Defaults to "mysql".

'mysql'
namespace str

Kubernetes namespace. Defaults to "default".

'default'
port int

Port where MySQL is running. Defaults to 3306.

3306
pv_claim_name str

Name of the required persistent volume claim. Defaults to "mysql-pv-claim".

'mysql-pv-claim'

Returns:

Type Description
Dict[str, Any]

Manifest for deploying a MySQL database.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_mysql_deployment_manifest(
    name: str = "mysql",
    namespace: str = "default",
    port: int = 3306,
    pv_claim_name: str = "mysql-pv-claim",
) -> Dict[str, Any]:
    """Build a manifest for deploying a MySQL database.

    Args:
        name: Name of the deployment. Defaults to "mysql".
        namespace: Kubernetes namespace. Defaults to "default".
        port: Port where MySQL is running. Defaults to 3306.
        pv_claim_name: Name of the required persistent volume claim.
            Defaults to `"mysql-pv-claim"`.

    Returns:
        Manifest for deploying a MySQL database.
    """
    return {
        "apiVersion": "apps/v1",
        "kind": "Deployment",
        "metadata": {"name": name, "namespace": namespace},
        "spec": {
            "selector": {
                "matchLabels": {
                    "app": name,
                },
            },
            "strategy": {
                "type": "Recreate",
            },
            "template": {
                "metadata": {
                    "labels": {"app": name},
                },
                "spec": {
                    "containers": [
                        {
                            "image": "gcr.io/ml-pipeline/mysql:5.6",
                            "name": name,
                            "env": [
                                {
                                    "name": "MYSQL_ALLOW_EMPTY_PASSWORD",
                                    "value": '"true"',
                                }
                            ],
                            "ports": [{"containerPort": port, "name": name}],
                            "volumeMounts": [
                                {
                                    "name": "mysql-persistent-storage",
                                    "mountPath": "/var/lib/mysql",
                                }
                            ],
                        }
                    ],
                    "volumes": [
                        {
                            "name": "mysql-persistent-storage",
                            "persistentVolumeClaim": {
                                "claimName": pv_claim_name
                            },
                        }
                    ],
                },
            },
        },
    }
build_mysql_service_manifest(name='mysql', namespace='default', port=3306)

Build a manifest for a service relating to a deployed MySQL database.

Parameters:

Name Type Description Default
name str

Name of the service. Defaults to "mysql".

'mysql'
namespace str

Kubernetes namespace. Defaults to "default".

'default'
port int

Port where MySQL is running. Defaults to 3306.

3306

Returns:

Type Description
Dict[str, Any]

Manifest for the MySQL service.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_mysql_service_manifest(
    name: str = "mysql",
    namespace: str = "default",
    port: int = 3306,
) -> Dict[str, Any]:
    """Build a manifest for a service relating to a deployed MySQL database.

    Args:
        name: Name of the service. Defaults to "mysql".
        namespace: Kubernetes namespace. Defaults to "default".
        port: Port where MySQL is running. Defaults to 3306.

    Returns:
        Manifest for the MySQL service.
    """
    return {
        "apiVersion": "v1",
        "kind": "Service",
        "metadata": {
            "name": name,
            "namespace": namespace,
        },
        "spec": {
            "selector": {"app": "mysql"},
            "clusterIP": "None",
            "ports": [{"port": port}],
        },
    }
build_namespace_manifest(namespace)

Build the manifest for a new namespace.

Parameters:

Name Type Description Default
namespace str

Kubernetes namespace.

required

Returns:

Type Description
Dict[str, Any]

Manifest of the new namespace.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_namespace_manifest(namespace: str) -> Dict[str, Any]:
    """Build the manifest for a new namespace.

    Args:
        namespace: Kubernetes namespace.

    Returns:
        Manifest of the new namespace.
    """
    return {
        "apiVersion": "v1",
        "kind": "Namespace",
        "metadata": {
            "name": namespace,
        },
    }
build_persistent_volume_claim_manifest(name, namespace='default', storage_request='10Gi')

Build a manifest for a persistent volume claim.

Parameters:

Name Type Description Default
name str

Name of the persistent volume claim.

required
namespace str

Kubernetes namespace. Defaults to "default".

'default'
storage_request str

Size of the storage to request. Defaults to "10Gi".

'10Gi'

Returns:

Type Description
Dict[str, Any]

Manifest for a persistent volume claim.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_persistent_volume_claim_manifest(
    name: str,
    namespace: str = "default",
    storage_request: str = "10Gi",
) -> Dict[str, Any]:
    """Build a manifest for a persistent volume claim.

    Args:
        name: Name of the persistent volume claim.
        namespace: Kubernetes namespace. Defaults to "default".
        storage_request: Size of the storage to request. Defaults to `"10Gi"`.

    Returns:
        Manifest for a persistent volume claim.
    """
    return {
        "apiVersion": "v1",
        "kind": "PersistentVolumeClaim",
        "metadata": {
            "name": name,
            "namespace": namespace,
        },
        "spec": {
            "storageClassName": "manual",
            "accessModes": ["ReadWriteOnce"],
            "resources": {
                "requests": {
                    "storage": storage_request,
                }
            },
        },
    }
build_persistent_volume_manifest(name, namespace='default', storage_capacity='10Gi', path='/mnt/data')

Build a manifest for a persistent volume.

Parameters:

Name Type Description Default
name str

Name of the persistent volume.

required
namespace str

Kubernetes namespace. Defaults to "default".

'default'
storage_capacity str

Storage capacity of the volume. Defaults to "10Gi".

'10Gi'
path str

Path where the volume is mounted. Defaults to "/mnt/data".

'/mnt/data'

Returns:

Type Description
Dict[str, Any]

Manifest for a persistent volume.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_persistent_volume_manifest(
    name: str,
    namespace: str = "default",
    storage_capacity: str = "10Gi",
    path: str = "/mnt/data",
) -> Dict[str, Any]:
    """Build a manifest for a persistent volume.

    Args:
        name: Name of the persistent volume.
        namespace: Kubernetes namespace. Defaults to "default".
        storage_capacity: Storage capacity of the volume. Defaults to `"10Gi"`.
        path: Path where the volume is mounted. Defaults to `"/mnt/data"`.

    Returns:
        Manifest for a persistent volume.
    """
    return {
        "apiVersion": "v1",
        "kind": "PersistentVolume",
        "metadata": {
            "name": name,
            "namespace": namespace,
            "labels": {"type": "local"},
        },
        "spec": {
            "storageClassName": "manual",
            "capacity": {"storage": storage_capacity},
            "accessModes": ["ReadWriteOnce"],
            "hostPath": {"path": path},
        },
    }
build_pod_manifest(pod_name, run_name, pipeline_name, image_name, command, args, service_account_name=None)

Build a Kubernetes pod manifest for a ZenML run or step.

Parameters:

Name Type Description Default
pod_name str

Name of the pod.

required
run_name str

Name of the ZenML run.

required
pipeline_name str

Name of the ZenML pipeline.

required
image_name str

Name of the Docker image.

required
command List[str]

Command to execute the entrypoint in the pod.

required
args List[str]

Arguments provided to the entrypoint command.

required
service_account_name Optional[str]

Optional name of a service account. Can be used to assign certain roles to a pod, e.g., to allow it to run Kubernetes commands from within the cluster.

None

Returns:

Type Description
Dict[str, Any]

Pod manifest.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_pod_manifest(
    pod_name: str,
    run_name: str,
    pipeline_name: str,
    image_name: str,
    command: List[str],
    args: List[str],
    service_account_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Build a Kubernetes pod manifest for a ZenML run or step.

    Args:
        pod_name: Name of the pod.
        run_name: Name of the ZenML run.
        pipeline_name: Name of the ZenML pipeline.
        image_name: Name of the Docker image.
        command: Command to execute the entrypoint in the pod.
        args: Arguments provided to the entrypoint command.
        service_account_name: Optional name of a service account.
            Can be used to assign certain roles to a pod, e.g., to allow it to
            run Kubernetes commands from within the cluster.

    Returns:
        Pod manifest.
    """
    manifest = {
        "apiVersion": "v1",
        "kind": "Pod",
        "metadata": {
            "name": pod_name,
            "labels": {
                "run": run_name,
                "pipeline": pipeline_name,
            },
        },
        "spec": {
            "restartPolicy": "Never",
            "containers": [
                {
                    "name": "main",
                    "image": image_name,
                    "command": command,
                    "args": args,
                    "env": [
                        {
                            "name": ENV_ZENML_ENABLE_REPO_INIT_WARNINGS,
                            "value": "False",
                        }
                    ],
                }
            ],
        },
    }
    if service_account_name is not None:
        spec = cast(Dict[str, Any], manifest["spec"])  # narrow the type for mypy
        spec["serviceAccountName"] = service_account_name
    return manifest
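
A usage sketch under assumed values (pod, image, and namespace names are illustrative); it submits the manifest with the same `create_namespaced_pod` call the orchestrator uses above:

```python
from kubernetes import client as k8s_client
from kubernetes import config as k8s_config

manifest = build_pod_manifest(
    pod_name="zenml-run-1-trainer",
    run_name="run-1",
    pipeline_name="my_pipeline",
    image_name="gcr.io/my-project/zenml:run-1",
    command=["python", "-m", "my_entrypoint"],  # hypothetical entrypoint
    args=["--run_name", "run-1"],
    service_account_name="zenml-service-account",
)

k8s_config.load_kube_config()
k8s_client.CoreV1Api().create_namespaced_pod(namespace="zenml", body=manifest)
```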
build_service_account_manifest(name, namespace='default')

Build the manifest for a service account.

Parameters:

Name Type Description Default
name str

Name of the service account.

required
namespace str

Kubernetes namespace. Defaults to "default".

'default'

Returns:

Type Description
Dict[str, Any]

Manifest for a service account.

Source code in zenml/integrations/kubernetes/orchestrators/manifest_utils.py
def build_service_account_manifest(
    name: str, namespace: str = "default"
) -> Dict[str, Any]:
    """Build the manifest for a service account.

    Args:
        name: Name of the service account.
        namespace: Kubernetes namespace. Defaults to "default".

    Returns:
        Manifest for a service account.
    """
    return {
        "apiVersion": "v1",
        "kind": "ServiceAccount",
        "metadata": {
            "name": name,
            "namespace": namespace,
        },
    }

lightgbm special

Initialization of the LightGBM integration.

LightGBMIntegration (Integration)

Definition of lightgbm integration for ZenML.

Source code in zenml/integrations/lightgbm/__init__.py
class LightGBMIntegration(Integration):
    """Definition of lightgbm integration for ZenML."""

    NAME = LIGHTGBM
    REQUIREMENTS = ["lightgbm>=1.0.0"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration."""
        from zenml.integrations.lightgbm import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/lightgbm/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration."""
    from zenml.integrations.lightgbm import materializers  # noqa

materializers special

Initialization of the LightGBM materializers.

lightgbm_booster_materializer

Implementation of the LightGBM booster materializer.

LightGBMBoosterMaterializer (BaseMaterializer)

Materializer to read and write lightgbm.Booster models.

Source code in zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py
class LightGBMBoosterMaterializer(BaseMaterializer):
    """Materializer to read data to and from lightgbm.Booster."""

    ASSOCIATED_TYPES = (lgb.Booster,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> lgb.Booster:
        """Reads a lightgbm Booster model from a serialized JSON file.

        Args:
            data_type: A lightgbm Booster type.

        Returns:
            A lightgbm Booster object.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Create a temporary folder
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)

        # Copy from artifact store to temporary file
        fileio.copy(filepath, temp_file)
        booster = lgb.Booster(model_file=temp_file)

        # Cleanup and return
        fileio.rmtree(temp_dir)
        return booster

    def handle_return(self, booster: lgb.Booster) -> None:
        """Creates a JSON serialization for a lightgbm Booster model.

        Args:
            booster: A lightgbm Booster model.
        """
        super().handle_return(booster)

        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Make a temporary phantom artifact
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".txt", delete=False
        ) as f:
            booster.save_model(f.name)
            # Copy it into artifact store
            fileio.copy(f.name, filepath)

        # Close and remove the temporary file
        f.close()
        fileio.remove(f.name)
handle_input(self, data_type)

Reads a lightgbm Booster model from a serialized text file.

Parameters:

Name Type Description Default
data_type Type[Any]

A lightgbm Booster type.

required

Returns:

Type Description
Booster

A lightgbm Booster object.

Source code in zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py
def handle_input(self, data_type: Type[Any]) -> lgb.Booster:
    """Reads a lightgbm Booster model from a serialized JSON file.

    Args:
        data_type: A lightgbm Booster type.

    Returns:
        A lightgbm Booster object.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Create a temporary folder
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)

    # Copy from artifact store to temporary file
    fileio.copy(filepath, temp_file)
    booster = lgb.Booster(model_file=temp_file)

    # Cleanup and return
    fileio.rmtree(temp_dir)
    return booster
handle_return(self, booster)

Serializes a lightgbm Booster model to a text file.

Parameters:

Name Type Description Default
booster Booster

A lightgbm Booster model.

required
Source code in zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py
def handle_return(self, booster: lgb.Booster) -> None:
    """Creates a JSON serialization for a lightgbm Booster model.

    Args:
        booster: A lightgbm Booster model.
    """
    super().handle_return(booster)

    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Make a temporary phantom artifact
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False
    ) as f:
        booster.save_model(f.name)
        # Copy it into artifact store
        fileio.copy(f.name, filepath)

    # Close and remove the temporary file
    f.close()
    fileio.remove(f.name)
lightgbm_dataset_materializer

Implementation of the LightGBM materializer.

LightGBMDatasetMaterializer (BaseMaterializer)

Materializer to read and write lightgbm.Dataset objects.

Source code in zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py
class LightGBMDatasetMaterializer(BaseMaterializer):
    """Materializer to read data to and from lightgbm.Dataset."""

    ASSOCIATED_TYPES = (lgb.Dataset,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> lgb.Dataset:
        """Reads a lightgbm.Dataset binary file and loads it.

        Args:
            data_type: A lightgbm.Dataset type.

        Returns:
            A lightgbm.Dataset object.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Create a temporary folder
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)

        # Copy from artifact store to temporary file
        fileio.copy(filepath, temp_file)
        matrix = lgb.Dataset(temp_file, free_raw_data=False)

        # No cleanup here: the Dataset is lazily loaded from the file
        return matrix

    def handle_return(self, matrix: lgb.Dataset) -> None:
        """Creates a binary serialization for a lightgbm.Dataset object.

        Args:
            matrix: A lightgbm.Dataset object.
        """
        super().handle_return(matrix)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Make a temporary phantom artifact
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
        matrix.save_binary(temp_file)

        # Copy it into artifact store
        fileio.copy(temp_file, filepath)
        fileio.rmtree(temp_dir)
handle_input(self, data_type)

Reads a lightgbm.Dataset binary file and loads it.

Parameters:

Name Type Description Default
data_type Type[Any]

A lightgbm.Dataset type.

required

Returns:

Type Description
Dataset

A lightgbm.Dataset object.

Source code in zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> lgb.Dataset:
    """Reads a lightgbm.Dataset binary file and loads it.

    Args:
        data_type: A lightgbm.Dataset type.

    Returns:
        A lightgbm.Dataset object.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Create a temporary folder
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)

    # Copy from artifact store to temporary file
    fileio.copy(filepath, temp_file)
    matrix = lgb.Dataset(temp_file, free_raw_data=False)

    # No cleanup here: the Dataset is lazily loaded from the file
    return matrix
handle_return(self, matrix)

Creates a binary serialization for a lightgbm.Dataset object.

Parameters:

Name Type Description Default
matrix Dataset

A lightgbm.Dataset object.

required
Source code in zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py
def handle_return(self, matrix: lgb.Dataset) -> None:
    """Creates a binary serialization for a lightgbm.Dataset object.

    Args:
        matrix: A lightgbm.Dataset object.
    """
    super().handle_return(matrix)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Make a temporary phantom artifact
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
    matrix.save_binary(temp_file)

    # Copy it into artifact store
    fileio.copy(temp_file, filepath)
    fileio.rmtree(temp_dir)
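
Steps never call these materializers directly: ZenML matches a step's return type against `ASSOCIATED_TYPES` and resolves the materializer automatically. A minimal sketch of a step whose `lgb.Booster` return value would be persisted via `LightGBMBoosterMaterializer` (data and parameters are illustrative, and a materializer for the `pandas.DataFrame` input is assumed to be available):

```python
import lightgbm as lgb
import pandas as pd
from zenml.steps import step


@step
def train_booster(data: pd.DataFrame) -> lgb.Booster:
    """Train a tiny model; ZenML persists the returned Booster."""
    dataset = lgb.Dataset(data.drop(columns=["label"]), label=data["label"])
    return lgb.train({"objective": "binary"}, dataset)
```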

mlflow special

Initialization for the ZenML MLflow integration.

The MLflow integration currently enables you to use MLflow tracking as a convenient way to visualize your experiment runs within the MLflow UI.

MlflowIntegration (Integration)

Definition of MLflow integration for ZenML.

Source code in zenml/integrations/mlflow/__init__.py
class MlflowIntegration(Integration):
    """Definition of MLflow integration for ZenML."""

    NAME = MLFLOW
    REQUIREMENTS = [
        "mlflow>=1.2.0,<1.26.0",
        "mlserver>=0.5.3",
        "mlserver-mlflow>=0.5.3",
    ]

    @classmethod
    def activate(cls) -> None:
        """Activate the MLflow integration."""
        from zenml.integrations.mlflow import services  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the MLflow integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=MLFLOW_MODEL_DEPLOYER_FLAVOR,
                source="zenml.integrations.mlflow.model_deployers.MLFlowModelDeployer",
                type=StackComponentType.MODEL_DEPLOYER,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR,
                source="zenml.integrations.mlflow.experiment_trackers.MLFlowExperimentTracker",
                type=StackComponentType.EXPERIMENT_TRACKER,
                integration=cls.NAME,
            ),
        ]
activate() classmethod

Activate the MLflow integration.

Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the MLflow integration."""
    from zenml.integrations.mlflow import services  # noqa
flavors() classmethod

Declare the stack component flavors for the MLflow integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the MLflow integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=MLFLOW_MODEL_DEPLOYER_FLAVOR,
            source="zenml.integrations.mlflow.model_deployers.MLFlowModelDeployer",
            type=StackComponentType.MODEL_DEPLOYER,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR,
            source="zenml.integrations.mlflow.experiment_trackers.MLFlowExperimentTracker",
            type=StackComponentType.EXPERIMENT_TRACKER,
            integration=cls.NAME,
        ),
    ]

experiment_trackers special

Initialization of the MLflow experiment tracker.

mlflow_experiment_tracker

Implementation of the MLflow experiment tracker for ZenML.

MLFlowExperimentTracker (BaseExperimentTracker) pydantic-model

Stores MLflow configuration options.

ZenML should take care of configuring MLflow for you, but should you still need access to the configuration inside your step you can do it using a step context:

from zenml.steps import StepContext

@enable_mlflow
@step
def my_step(context: StepContext, ...):
    context.stack.experiment_tracker  # get the tracking_uri etc. from here

Attributes:

Name Type Description
tracking_uri Optional[str]

The uri of the mlflow tracking server. If no uri is set, your stack must contain a LocalArtifactStore and ZenML will point MLflow to a subdirectory of your artifact store instead.

tracking_username Optional[str]

Username for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_password Optional[str]

Password for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_token Optional[str]

Token for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_insecure_tls bool

Skips verification of TLS connection to the MLflow tracking server if set to True.
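
As a quick illustration of when these credentials are required (the `is_remote_tracking_uri` helper is defined in the source below; the URIs are illustrative):

```python
from zenml.integrations.mlflow.experiment_trackers import (
    MLFlowExperimentTracker,
)

# Remote backends require a token or username + password:
assert MLFlowExperimentTracker.is_remote_tracking_uri("https://mlflow.example.com")
# Local file-based backends do not:
assert not MLFlowExperimentTracker.is_remote_tracking_uri("file:/tmp/mlruns")
```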

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
class MLFlowExperimentTracker(BaseExperimentTracker):
    """Stores Mlflow configuration options.

    ZenML should take care of configuring MLflow for you, but should you still
    need access to the configuration inside your step you can do it using a
    step context:
    ```python
    from zenml.steps import StepContext

    @enable_mlflow
    @step
    def my_step(context: StepContext, ...):
        context.stack.experiment_tracker  # get the tracking_uri etc. from here
    ```

    Attributes:
        tracking_uri: The uri of the mlflow tracking server. If no uri is set,
            your stack must contain a `LocalArtifactStore` and ZenML will
            point MLflow to a subdirectory of your artifact store instead.
        tracking_username: Username for authenticating with the MLflow
            tracking server. When a remote tracking uri is specified,
            either `tracking_token` or `tracking_username` and
            `tracking_password` must be specified.
        tracking_password: Password for authenticating with the MLflow
            tracking server. When a remote tracking uri is specified,
            either `tracking_token` or `tracking_username` and
            `tracking_password` must be specified.
        tracking_token: Token for authenticating with the MLflow
            tracking server. When a remote tracking uri is specified,
            either `tracking_token` or `tracking_username` and
            `tracking_password` must be specified.
        tracking_insecure_tls: Skips verification of TLS connection to the
            MLflow tracking server if set to `True`.
    """

    tracking_uri: Optional[str] = None
    tracking_username: Optional[str] = None
    tracking_password: Optional[str] = None
    tracking_token: Optional[str] = None
    tracking_insecure_tls: bool = False

    # Class Configuration
    FLAVOR: ClassVar[str] = MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR

    @validator("tracking_uri")
    def _ensure_valid_tracking_uri(
        cls, tracking_uri: Optional[str] = None
    ) -> Optional[str]:
        """Ensures that the tracking uri is a valid mlflow tracking uri.

        Args:
            tracking_uri: The tracking uri to validate.

        Returns:
            The tracking uri if it is valid.

        Raises:
            ValueError: If the tracking uri is not valid.
        """
        if tracking_uri:
            valid_schemes = DATABASE_ENGINES + ["http", "https", "file"]
            if not any(
                tracking_uri.startswith(scheme) for scheme in valid_schemes
            ):
                raise ValueError(
                    f"MLflow tracking uri does not start with one of the valid "
                    f"schemes {valid_schemes}. See "
                    f"https://www.mlflow.org/docs/latest/tracking.html#where-runs-are-recorded "
                    f"for more information."
                )
        return tracking_uri

    @root_validator(skip_on_failure=True)
    def _ensure_authentication_if_necessary(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Ensures that credentials or a token for authentication exist.

        We make this check when running MLflow tracking with a remote backend.

        Args:
            values: The values to validate.

        Returns:
            The validated values.

        Raises:
            ValueError: If neither credentials nor a token are provided.
        """
        tracking_uri = values.get("tracking_uri")

        if tracking_uri and cls.is_remote_tracking_uri(tracking_uri):
            # we need either username + password or a token to authenticate to
            # the remote backend
            basic_auth = values.get("tracking_username") and values.get(
                "tracking_password"
            )
            token_auth = values.get("tracking_token")

            if not (basic_auth or token_auth):
                raise ValueError(
                    f"MLflow experiment tracking with a remote backend "
                    f"{tracking_uri} is only possible when specifying either "
                    f"username and password or an authentication token in your "
                    f"stack component. To update your component, run the "
                    f"following command: `zenml experiment-tracker update "
                    f"{values['name']} --tracking_username=MY_USERNAME "
                    f"--tracking_password=MY_PASSWORD "
                    f"--tracking_token=MY_TOKEN` and specify either your "
                    f"username and password or token."
                )

        return values

    @staticmethod
    def is_remote_tracking_uri(tracking_uri: str) -> bool:
        """Checks whether the given tracking uri is remote or not.

        Args:
            tracking_uri: The tracking uri to check.

        Returns:
            `True` if the tracking uri is remote, `False` otherwise.
        """
        return any(
            tracking_uri.startswith(prefix)
            for prefix in ["http://", "https://"]
        )

    @staticmethod
    def _local_mlflow_backend() -> str:
        """Gets the local MLflow backend inside the ZenML artifact repository directory.

        Returns:
            The MLflow tracking URI for the local MLflow backend.
        """
        repo = Repository(skip_repository_check=True)  # type: ignore[call-arg]
        artifact_store = repo.active_stack.artifact_store
        local_mlflow_backend_uri = os.path.join(artifact_store.path, "mlruns")
        if not os.path.exists(local_mlflow_backend_uri):
            os.makedirs(local_mlflow_backend_uri)
        return "file:" + local_mlflow_backend_uri

    def get_tracking_uri(self) -> str:
        """Returns the configured tracking URI or a local fallback.

        Returns:
            The tracking URI.
        """
        return self.tracking_uri or self._local_mlflow_backend()

    def configure_mlflow(self) -> None:
        """Configures the MLflow tracking URI and any additional credentials."""
        mlflow.set_tracking_uri(self.get_tracking_uri())

        if self.tracking_username:
            os.environ[MLFLOW_TRACKING_USERNAME] = self.tracking_username
        if self.tracking_password:
            os.environ[MLFLOW_TRACKING_PASSWORD] = self.tracking_password
        if self.tracking_token:
            os.environ[MLFLOW_TRACKING_TOKEN] = self.tracking_token
        os.environ[MLFLOW_TRACKING_INSECURE_TLS] = (
            "true" if self.tracking_insecure_tls else "false"
        )

    def prepare_step_run(self) -> None:
        """Sets the MLflow tracking uri and credentials."""
        self.configure_mlflow()

    def cleanup_step_run(self) -> None:
        """Resets the MLflow tracking uri."""
        mlflow.set_tracking_uri("")

    @property
    def validator(self) -> Optional["StackValidator"]:
        """Checks the stack has a `LocalArtifactStore` if no tracking uri was specified.

        Returns:
            An optional `StackValidator`.
        """
        if self.tracking_uri:
            # user specified a tracking uri, do nothing
            return None
        else:
            # try to fall back to a tracking uri inside the zenml artifact
            # store. this only works in case of a local artifact store, so we
            # make sure to prevent stacks with other artifact stores for now
            return StackValidator(
                custom_validation_function=lambda stack: (
                    isinstance(stack.artifact_store, LocalArtifactStore),
                    "MLflow experiment tracker without a specified tracking "
                    "uri only works with a local artifact store.",
                )
            )

    @property
    def active_experiment(self) -> Optional[Experiment]:
        """Returns the currently active MLflow experiment.

        Returns:
            The active experiment or `None` if no experiment is active.
        """
        step_env = Environment().step_environment

        if not step_env:
            # we're not inside a step
            return None

        mlflow.set_experiment(experiment_name=step_env.pipeline_name)
        return mlflow.get_experiment_by_name(step_env.pipeline_name)

    @property
    def active_run(self) -> Optional[mlflow.ActiveRun]:
        """Returns the currently active MLflow run.

        Returns:
            The active MLflow run, or `None` if not running inside a step.
        """
        step_env = Environment().step_environment

        if not self.active_experiment or not step_env:
            return None

        experiment_id = self.active_experiment.experiment_id

        # TODO [ENG-458]: find a solution to avoid race-conditions while
        #  creating the same MLflow run from parallel steps
        runs = mlflow.search_runs(
            experiment_ids=[experiment_id],
            filter_string=f'tags.mlflow.runName = "{step_env.pipeline_run_id}"',
            output_format="list",
        )

        run_id = runs[0].info.run_id if runs else None

        current_active_run = mlflow.active_run()
        if current_active_run and current_active_run.info.run_id == run_id:
            return current_active_run

        return mlflow.start_run(
            run_id=run_id,
            run_name=step_env.pipeline_run_id,
            experiment_id=experiment_id,
        )
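
To illustrate the validation logic above, here is a minimal sketch of configuring the tracker for a remote backend with token-based authentication. The component name, URI and token below are placeholder values:

from zenml.integrations.mlflow.experiment_trackers.mlflow_experiment_tracker import (
    MLFlowExperimentTracker,
)

# A remote (http/https) tracking URI requires either `tracking_token` or
# `tracking_username` + `tracking_password`; otherwise the root validator
# raises a ValueError.
tracker = MLFlowExperimentTracker(
    name="mlflow_tracker",  # hypothetical component name
    tracking_uri="https://mlflow.example.com",  # placeholder URI
    tracking_token="MY_TOKEN",  # placeholder token
)
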
active_experiment: Optional[mlflow.entities.experiment.Experiment] property readonly

Returns the currently active MLflow experiment.

Returns:

Type Description
Optional[mlflow.entities.experiment.Experiment]

The active experiment or None if no experiment is active.

active_run: Optional[mlflow.tracking.fluent.ActiveRun] property readonly

Returns the currently active MLflow run.

Returns:

Type Description
Optional[mlflow.tracking.fluent.ActiveRun]

The active MLflow run, or None if not running inside a step.

validator: Optional[StackValidator] property readonly

Checks that the stack has a LocalArtifactStore if no tracking uri was specified.

Returns:

Type Description
Optional[StackValidator]

An optional StackValidator.

cleanup_step_run(self)

Resets the MLflow tracking uri.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def cleanup_step_run(self) -> None:
    """Resets the MLflow tracking uri."""
    mlflow.set_tracking_uri("")
configure_mlflow(self)

Configures the MLflow tracking URI and any additional credentials.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def configure_mlflow(self) -> None:
    """Configures the MLflow tracking URI and any additional credentials."""
    mlflow.set_tracking_uri(self.get_tracking_uri())

    if self.tracking_username:
        os.environ[MLFLOW_TRACKING_USERNAME] = self.tracking_username
    if self.tracking_password:
        os.environ[MLFLOW_TRACKING_PASSWORD] = self.tracking_password
    if self.tracking_token:
        os.environ[MLFLOW_TRACKING_TOKEN] = self.tracking_token
    os.environ[MLFLOW_TRACKING_INSECURE_TLS] = (
        "true" if self.tracking_insecure_tls else "false"
    )
get_tracking_uri(self)

Returns the configured tracking URI or a local fallback.

Returns:

Type Description
str

The tracking URI.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_tracking_uri(self) -> str:
    """Returns the configured tracking URI or a local fallback.

    Returns:
        The tracking URI.
    """
    return self.tracking_uri or self._local_mlflow_backend()
is_remote_tracking_uri(tracking_uri) staticmethod

Checks whether the given tracking uri is remote or not.

Parameters:

Name Type Description Default
tracking_uri str

The tracking uri to check.

required

Returns:

Type Description
bool

True if the tracking uri is remote, False otherwise.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
@staticmethod
def is_remote_tracking_uri(tracking_uri: str) -> bool:
    """Checks whether the given tracking uri is remote or not.

    Args:
        tracking_uri: The tracking uri to check.

    Returns:
        `True` if the tracking uri is remote, `False` otherwise.
    """
    return any(
        tracking_uri.startswith(prefix)
        for prefix in ["http://", "https://"]
    )
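
For example, assuming the class is imported, the check behaves as follows:

MLFlowExperimentTracker.is_remote_tracking_uri("https://mlflow.example.com")  # True
MLFlowExperimentTracker.is_remote_tracking_uri("file:/path/to/mlruns")  # False
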
prepare_step_run(self)

Sets the MLflow tracking uri and credentials.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def prepare_step_run(self) -> None:
    """Sets the MLflow tracking uri and credentials."""
    self.configure_mlflow()

mlflow_step_decorator

Implementation of the MLflow StepDecorator.

enable_mlflow(_step=None)

Decorator to enable mlflow for a step function.

Apply this decorator to a ZenML pipeline step to enable MLflow experiment tracking. The MLflow tracking configuration (tracking URI, experiment name, run name) will be automatically configured before the step code is executed, so the step can simply use the mlflow module to log metrics and artifacts, like so:

@enable_mlflow
@step
def tf_evaluator(
    x_test: np.ndarray,
    y_test: np.ndarray,
    model: tf.keras.Model,
) -> float:
    _, test_acc = model.evaluate(x_test, y_test, verbose=2)
    mlflow.log_metric("val_accuracy", test_acc)
    return test_acc

You can also use this decorator with our class-based API like so:

@enable_mlflow
class TFEvaluator(BaseStep):
    def entrypoint(
        self,
        x_test: np.ndarray,
        y_test: np.ndarray,
        model: tf.keras.Model,
    ) -> float:
        ...

All MLflow artifacts and metrics logged from all the steps in a pipeline run are by default grouped under a single experiment named after the pipeline. To log MLflow artifacts and metrics from a step in a separate MLflow experiment, pass a custom experiment_name argument value to the decorator.

Parameters:

Name Type Description Default
_step Optional[~S]

The decorated step class.

None

Returns:

Type Description
Union[~S, Callable[[~S], ~S]]

The inner decorator which enhances the input step class with mlflow tracking functionality

Source code in zenml/integrations/mlflow/mlflow_step_decorator.py
def enable_mlflow(
    _step: Optional[S] = None,
) -> Union[S, Callable[[S], S]]:
    """Decorator to enable mlflow for a step function.

    Apply this decorator to a ZenML pipeline step to enable MLflow experiment
    tracking. The MLflow tracking configuration (tracking URI, experiment name,
    run name) will be automatically configured before the step code is executed,
    so the step can simply use the `mlflow` module to log metrics and artifacts,
    like so:

    ```python
    @enable_mlflow
    @step
    def tf_evaluator(
        x_test: np.ndarray,
        y_test: np.ndarray,
        model: tf.keras.Model,
    ) -> float:
        _, test_acc = model.evaluate(x_test, y_test, verbose=2)
        mlflow.log_metric("val_accuracy", test_acc)
        return test_acc
    ```

    You can also use this decorator with our class-based API like so:

    ```python
    @enable_mlflow
    class TFEvaluator(BaseStep):
        def entrypoint(
            self,
            x_test: np.ndarray,
            y_test: np.ndarray,
            model: tf.keras.Model,
        ) -> float:
            ...
    ```

    All MLflow artifacts and metrics logged from all the steps in a pipeline
    run are by default grouped under a single experiment named after the
    pipeline. To log MLflow artifacts and metrics from a step in a separate
    MLflow experiment, pass a custom `experiment_name` argument value to the
    decorator.

    Args:
        _step: The decorated step class.

    Returns:
        The inner decorator which enhances the input step class with mlflow
        tracking functionality.
    """

    def inner_decorator(_step: S) -> S:

        logger.debug(
            "Applying 'enable_mlflow' decorator to step %s", _step.__name__
        )
        if not issubclass(_step, BaseStep):
            raise RuntimeError(
                "The `enable_mlflow` decorator can only be applied to a ZenML "
                "`step` decorated function or a BaseStep subclass."
            )
        source_fn = getattr(_step, STEP_INNER_FUNC_NAME)
        new_entrypoint = mlflow_step_entrypoint()(source_fn)
        if _step._created_by_functional_api():
            # If the step was created by the functional API, the old entrypoint
            # was a static method -> make sure the new one is as well
            new_entrypoint = staticmethod(new_entrypoint)

        setattr(_step, STEP_INNER_FUNC_NAME, new_entrypoint)
        return _step

    if _step is None:
        return inner_decorator
    else:
        return inner_decorator(_step)
mlflow_step_entrypoint()

Decorator for a step entrypoint to enable mlflow.

Returns:

Type Description
Callable[[~F], ~F]

the input function enhanced with mlflow tracking functionality

Source code in zenml/integrations/mlflow/mlflow_step_decorator.py
def mlflow_step_entrypoint() -> Callable[[F], F]:
    """Decorator for a step entrypoint to enable mlflow.

    Returns:
        the input function enhanced with mlflow tracking functionality
    """

    def inner_decorator(func: F) -> F:

        logger.debug(
            "Applying 'mlflow_step_entrypoint' decorator to step entrypoint %s",
            func.__name__,
        )

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa
            logger.debug(
                "Setting up MLflow backend before running step entrypoint %s",
                func.__name__,
            )
            experiment_tracker = Repository(  # type: ignore[call-arg]
                skip_repository_check=True
            ).active_stack.experiment_tracker

            if not isinstance(experiment_tracker, MLFlowExperimentTracker):
                raise get_missing_mlflow_experiment_tracker_error()

            active_run = experiment_tracker.active_run
            if not active_run:
                raise RuntimeError("No active mlflow run configured.")

            with active_run:
                return func(*args, **kwargs)

        return cast(F, wrapper)

    return inner_decorator

mlflow_utils

Implementation of utils specific to the MLflow integration.

get_missing_mlflow_experiment_tracker_error()

Returns description of how to add an MLflow experiment tracker to your stack.

Returns:

Type Description
ValueError

An error explaining how to register an MLflow experiment tracker in the active stack.

Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_missing_mlflow_experiment_tracker_error() -> ValueError:
    """Returns description of how to add an MLflow experiment tracker to your stack.

    Returns:
        ValueError: An error explaining how to register an MLflow experiment tracker in the active stack.
    """
    return ValueError(
        "The active stack needs to have a MLflow experiment tracker "
        "component registered to be able to track experiments using "
        "MLflow. You can create a new stack with a MLflow experiment "
        "tracker component or update your existing stack to add this "
        "component, e.g.:\n\n"
        "  'zenml experiment-tracker register mlflow_tracker "
        "--type=mlflow'\n"
        "  'zenml stack register stack-name -e mlflow_tracker ...'\n"
    )
get_tracking_uri()

Gets the MLflow tracking URI from the active experiment tracking stack component.

Returns:

Type Description
str

MLflow tracking URI.

Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_tracking_uri() -> str:
    """Gets the MLflow tracking URI from the active experiment tracking stack component.

    # noqa: DAR401

    Returns:
        MLflow tracking URI.
    """
    tracker = Repository().active_stack.experiment_tracker
    if tracker is None or not isinstance(tracker, MLFlowExperimentTracker):
        raise get_missing_mlflow_experiment_tracker_error()

    return tracker.get_tracking_uri()
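
A short sketch of how this helper can be used to point the mlflow client at the same backend the active tracker writes to (it assumes an MLflow experiment tracker is registered in the active stack):

import mlflow

from zenml.integrations.mlflow.mlflow_utils import get_tracking_uri

# Configure the mlflow client to read from the tracker's backend and list
# the runs recorded in the currently active experiment.
mlflow.set_tracking_uri(get_tracking_uri())
runs = mlflow.search_runs(output_format="list")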

model_deployers special

Initialization of the MLflow model deployers.

mlflow_model_deployer

Implementation of the MLflow model deployer.

MLFlowModelDeployer (BaseModelDeployer) pydantic-model

MLflow implementation of the BaseModelDeployer.

Attributes:

Name Type Description
service_path str

the path where the local MLflow deployment service configuration, PID and log files are stored.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
class MLFlowModelDeployer(BaseModelDeployer):
    """MLflow implementation of the BaseModelDeployer.

    Attributes:
        service_path: the path where the local MLflow deployment service
        configuration, PID and log files are stored.
    """

    service_path: str = ""

    # Class Configuration
    FLAVOR: ClassVar[str] = MLFLOW_MODEL_DEPLOYER_FLAVOR

    @root_validator(skip_on_failure=True)
    def set_service_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Sets the service_path attribute value according to the component UUID.

        Args:
            values: the dictionary of values to be validated.

        Returns:
            The validated dictionary of values.
        """
        if values.get("service_path"):
            return values

        # not likely to happen, due to Pydantic validation, but mypy complains
        assert "uuid" in values

        values["service_path"] = cls.get_service_path(values["uuid"])
        return values

    @staticmethod
    def get_service_path(uuid: uuid.UUID) -> str:
        """Get the path where local MLflow service information is stored.

        This includes the deployment service configuration, PID and log files.

        Args:
            uuid: The UUID of the MLflow model deployer.

        Returns:
            The service path.
        """
        service_path = os.path.join(
            get_global_config_directory(),
            LOCAL_STORES_DIRECTORY_NAME,
            str(uuid),
        )
        create_dir_recursive_if_not_exists(service_path)
        return service_path

    @property
    def local_path(self) -> str:
        """Returns the path to the root directory.

        This is where all configurations for MLflow deployment daemon processes are stored.

        Returns:
            The path to the local service root directory.
        """
        return self.service_path

    @staticmethod
    def get_model_server_info(  # type: ignore[override]
        service_instance: "MLFlowDeploymentService",
    ) -> Dict[str, Optional[str]]:
        """Return implementation specific information relevant to the user.

        Args:
            service_instance: Instance of an MLFlowDeploymentService

        Returns:
            A dictionary containing the information.
        """
        return {
            "PREDICTION_URL": service_instance.endpoint.prediction_url,
            "MODEL_URI": service_instance.config.model_uri,
            "MODEL_NAME": service_instance.config.model_name,
            "SERVICE_PATH": service_instance.status.runtime_path,
            "DAEMON_PID": str(service_instance.status.pid),
        }

    @staticmethod
    def get_active_model_deployer() -> "MLFlowModelDeployer":
        """Returns the MLFlowModelDeployer component of the active stack.

        Returns:
            The MLFlowModelDeployer component of the active stack.

        Raises:
            TypeError: If the active stack does not contain an MLFlowModelDeployer component.
        """
        model_deployer = Repository(  # type: ignore[call-arg]
            skip_repository_check=True
        ).active_stack.model_deployer

        if not model_deployer or not isinstance(
            model_deployer, MLFlowModelDeployer
        ):
            raise TypeError(
                f"The active stack needs to have an MLflow model deployer "
                f"component registered to be able to deploy models with MLflow. "
                f"You can create a new stack with an MLflow model "
                f"deployer component or update your existing stack to add this "
                f"component, e.g.:\n\n"
                f"  'zenml model-deployer register mlflow --flavor={MLFLOW_MODEL_DEPLOYER_FLAVOR}'\n"
                f"  'zenml stack create stack-name -d mlflow ...'\n"
            )
        return model_deployer

    def deploy_model(
        self,
        config: ServiceConfig,
        replace: bool = False,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    ) -> BaseService:
        """Create a new MLflow deployment service or update an existing one.

        This should serve the supplied model and deployment configuration.

        This method has two modes of operation, depending on the `replace`
        argument value:

          * if `replace` is False, calling this method will create a new MLflow
            deployment server to reflect the model and other configuration
            parameters specified in the supplied MLflow service `config`.

          * if `replace` is True, this method will first attempt to find an
            existing MLflow deployment service that is *equivalent* to the
            supplied configuration parameters. Two or more MLflow deployment
            services are considered equivalent if they have the same
            `pipeline_name`, `pipeline_step_name` and `model_name` configuration
            parameters. To put it differently, two MLflow deployment services
            are equivalent if they serve versions of the same model deployed by
            the same pipeline step. If an equivalent MLflow deployment is found,
            it will be updated in place to reflect the new configuration
            parameters.

        Callers should set `replace` to True if they want a continuous model
        deployment workflow that doesn't spin up a new MLflow deployment
        server for each new model version. If multiple equivalent MLflow
        deployment servers are found, one is selected at random to be updated
        and the others are deleted.

        Args:
            config: the configuration of the model to be deployed with MLflow.
            replace: set this flag to True to find and update an equivalent
                MLflow deployment server with the new model instead of
                creating and starting a new deployment server.
            timeout: the timeout in seconds to wait for the MLflow server
                to be provisioned and successfully started or updated. If set
                to 0, the method will return immediately after the MLflow
                server is provisioned, without waiting for it to fully start.

        Returns:
            The ZenML MLflow deployment service object that can be used to
            interact with the MLflow model server.
        """
        config = cast(MLFlowDeploymentConfig, config)
        service = None

        # if replace is True, remove all existing services
        if replace is True:
            existing_services = self.find_model_server(
                pipeline_name=config.pipeline_name,
                pipeline_step_name=config.pipeline_step_name,
                model_name=config.model_name,
            )

            for existing_service in existing_services:
                if service is None:
                    # keep the most recently created service
                    service = cast(MLFlowDeploymentService, existing_service)
                try:
                    # delete the older services and don't wait for them to
                    # be deprovisioned
                    self._clean_up_existing_service(
                        existing_service=cast(
                            MLFlowDeploymentService, existing_service
                        ),
                        timeout=timeout,
                        force=True,
                    )
                except RuntimeError:
                    # ignore errors encountered while stopping old services
                    pass
        if service:
            logger.info(
                f"Updating an existing MLflow deployment service: {service}"
            )

            # set the root runtime path with the stack component's UUID
            config.root_runtime_path = self.local_path
            service.stop(timeout=timeout, force=True)
            service.update(config)
            service.start(timeout=timeout)
        else:
            # create a new MLFlowDeploymentService instance
            service = self._create_new_service(timeout, config)
            logger.info(f"Created a new MLflow deployment service: {service}")

        return cast(BaseService, service)

    def _clean_up_existing_service(
        self,
        timeout: int,
        force: bool,
        existing_service: MLFlowDeploymentService,
    ) -> None:
        # stop the older service
        existing_service.stop(timeout=timeout, force=force)

        # delete the old configuration file
        service_directory_path = existing_service.status.runtime_path or ""
        shutil.rmtree(service_directory_path)

    # The step receives a config from the user that specifies e.g. the number
    # of workers. The step implementation creates a new config using all the
    # user-provided values and adds values like the pipeline name and model_uri.
    def _create_new_service(
        self, timeout: int, config: MLFlowDeploymentConfig
    ) -> MLFlowDeploymentService:
        """Creates a new MLFlowDeploymentService.

        Args:
            timeout: the timeout in seconds to wait for the MLflow server
                to be provisioned and successfully started or updated.
            config: the configuration of the model to be deployed with MLflow.

        Returns:
            The MLFlowDeploymentService object that can be used to interact
            with the MLflow model server.
        """
        # set the root runtime path with the stack component's UUID
        config.root_runtime_path = self.local_path
        # create a new service for the new model
        service = MLFlowDeploymentService(config)
        service.start(timeout=timeout)

        return service

    def find_model_server(
        self,
        running: bool = False,
        service_uuid: Optional[UUID] = None,
        pipeline_name: Optional[str] = None,
        pipeline_run_id: Optional[str] = None,
        pipeline_step_name: Optional[str] = None,
        model_name: Optional[str] = None,
        model_uri: Optional[str] = None,
        model_type: Optional[str] = None,
    ) -> List[BaseService]:
        """Finds one or more model servers that match the given criteria.

        Args:
            running: If true, only running services will be returned.
            service_uuid: The UUID of the service that was originally used
                to deploy the model.
            pipeline_name: Name of the pipeline that the deployed model was part
                of.
            pipeline_run_id: ID of the pipeline run which the deployed model
                was part of.
            pipeline_step_name: The name of the pipeline model deployment step
                that deployed the model.
            model_name: Name of the deployed model.
            model_uri: URI of the deployed model.
            model_type: Type/format of the deployed model. Not used in this
                MLflow case.

        Returns:
            One or more Service objects representing model servers that match
            the input search criteria.

        Raises:
            TypeError: if one of the stored services is not an MLFlowDeploymentService.
        """
        services = []
        config = MLFlowDeploymentConfig(
            model_name=model_name or "",
            model_uri=model_uri or "",
            pipeline_name=pipeline_name or "",
            pipeline_run_id=pipeline_run_id or "",
            pipeline_step_name=pipeline_step_name or "",
        )

        # find all services that match the input criteria
        for root, _, files in os.walk(self.local_path):
            if service_uuid and Path(root).name != str(service_uuid):
                continue
            for file in files:
                if file == SERVICE_DAEMON_CONFIG_FILE_NAME:
                    service_config_path = os.path.join(root, file)
                    logger.debug(
                        "Loading service daemon configuration from %s",
                        service_config_path,
                    )
                    existing_service_config = None
                    with open(service_config_path, "r") as f:
                        existing_service_config = f.read()
                    existing_service = ServiceRegistry().load_service_from_json(
                        existing_service_config
                    )
                    if not isinstance(
                        existing_service, MLFlowDeploymentService
                    ):
                        raise TypeError(
                            f"Expected service type MLFlowDeploymentService but got "
                            f"{type(existing_service)} instead"
                        )
                    existing_service.update_status()
                    if self._matches_search_criteria(existing_service, config):
                        if not running or existing_service.is_running:
                            services.append(cast(BaseService, existing_service))

        return services

    def _matches_search_criteria(
        self,
        existing_service: MLFlowDeploymentService,
        config: MLFlowDeploymentConfig,
    ) -> bool:
        """Returns true if a service matches the input criteria.

        If any of the values in the input criteria are None, they are ignored.
        This allows listing services just by common pipeline names or step
        names, etc.

        Args:
            existing_service: The materialized Service instance derived from
                the config of the older (existing) service
            config: The MLFlowDeploymentConfig object passed to the
                deploy_model function holding parameters of the new service
                to be created.

        Returns:
            True if the service matches the input criteria.
        """
        existing_service_config = existing_service.config

        # check if the existing service matches the input criteria
        if (
            (
                not config.pipeline_name
                or existing_service_config.pipeline_name == config.pipeline_name
            )
            and (
                not config.model_name
                or existing_service_config.model_name == config.model_name
            )
            and (
                not config.pipeline_step_name
                or existing_service_config.pipeline_step_name
                == config.pipeline_step_name
            )
            and (
                not config.pipeline_run_id
                or existing_service_config.pipeline_run_id
                == config.pipeline_run_id
            )
        ):
            return True

        return False

    def stop_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Method to stop a model server.

        Args:
            uuid: UUID of the model server to stop.
            timeout: Timeout in seconds to wait for the service to stop.
            force: If True, force the service to stop.
        """
        # get list of all services
        existing_services = self.find_model_server(service_uuid=uuid)

        # if the service exists, stop it
        if existing_services:
            existing_services[0].stop(timeout=timeout, force=force)

    def start_model_server(
        self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
    ) -> None:
        """Method to start a model server.

        Args:
            uuid: UUID of the model server to start.
            timeout: Timeout in seconds to wait for the service to start.
        """
        # get list of all services
        existing_services = self.find_model_server(service_uuid=uuid)

        # if the service exists, start it
        if existing_services:
            existing_services[0].start(timeout=timeout)

    def delete_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Method to delete all configuration of a model server.

        Args:
            uuid: UUID of the model server to delete.
            timeout: Timeout in seconds to wait for the service to stop.
            force: If True, force the service to stop.
        """
        # get list of all services
        existing_services = self.find_model_server(service_uuid=uuid)

        # if the service exists, clean it up
        if existing_services:
            service = cast(MLFlowDeploymentService, existing_services[0])
            self._clean_up_existing_service(
                existing_service=service, timeout=timeout, force=force
            )
local_path: str property readonly

Returns the path to the root directory.

This is where all configurations for MLflow deployment daemon processes are stored.

Returns:

Type Description
str

The path to the local service root directory.

delete_model_server(self, uuid, timeout=10, force=False)

Method to delete all configuration of a model server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to delete.

required
timeout int

Timeout in seconds to wait for the service to stop.

10
force bool

If True, force the service to stop.

False
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def delete_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Method to delete all configuration of a model server.

    Args:
        uuid: UUID of the model server to delete.
        timeout: Timeout in seconds to wait for the service to stop.
        force: If True, force the service to stop.
    """
    # get list of all services
    existing_services = self.find_model_server(service_uuid=uuid)

    # if the service exists, clean it up
    if existing_services:
        service = cast(MLFlowDeploymentService, existing_services[0])
        self._clean_up_existing_service(
            existing_service=service, timeout=timeout, force=force
        )
deploy_model(self, config, replace=False, timeout=10)

Create a new MLflow deployment service or update an existing one.

This should serve the supplied model and deployment configuration.

This method has two modes of operation, depending on the replace argument value:

  • if replace is False, calling this method will create a new MLflow deployment server to reflect the model and other configuration parameters specified in the supplied MLflow service config.

  • if replace is True, this method will first attempt to find an existing MLflow deployment service that is equivalent to the supplied configuration parameters. Two or more MLflow deployment services are considered equivalent if they have the same pipeline_name, pipeline_step_name and model_name configuration parameters. To put it differently, two MLflow deployment services are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent MLflow deployment is found, it will be updated in place to reflect the new configuration parameters.

Callers should set replace to True if they want a continuous model deployment workflow that doesn't spin up a new MLflow deployment server for each new model version. If multiple equivalent MLflow deployment servers are found, one is selected at random to be updated and the others are deleted.

Parameters:

Name Type Description Default
config ServiceConfig

the configuration of the model to be deployed with MLflow.

required
replace bool

set this flag to True to find and update an equivalent MLflow deployment server with the new model instead of creating and starting a new deployment server.

False
timeout int

the timeout in seconds to wait for the MLflow server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the MLflow server is provisioned, without waiting for it to fully start.

10

Returns:

Type Description
BaseService

The ZenML MLflow deployment service object that can be used to interact with the MLflow model server.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def deploy_model(
    self,
    config: ServiceConfig,
    replace: bool = False,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
    """Create a new MLflow deployment service or update an existing one.

    This should serve the supplied model and deployment configuration.

    This method has two modes of operation, depending on the `replace`
    argument value:

      * if `replace` is False, calling this method will create a new MLflow
        deployment server to reflect the model and other configuration
        parameters specified in the supplied MLflow service `config`.

      * if `replace` is True, this method will first attempt to find an
        existing MLflow deployment service that is *equivalent* to the
        supplied configuration parameters. Two or more MLflow deployment
        services are considered equivalent if they have the same
        `pipeline_name`, `pipeline_step_name` and `model_name` configuration
        parameters. To put it differently, two MLflow deployment services
        are equivalent if they serve versions of the same model deployed by
        the same pipeline step. If an equivalent MLflow deployment is found,
        it will be updated in place to reflect the new configuration
        parameters.

    Callers should set `replace` to True if they want a continuous model
    deployment workflow that doesn't spin up a new MLflow deployment
    server for each new model version. If multiple equivalent MLflow
    deployment servers are found, one is selected at random to be updated
    and the others are deleted.

    Args:
        config: the configuration of the model to be deployed with MLflow.
        replace: set this flag to True to find and update an equivalent
            MLflow deployment server with the new model instead of
            creating and starting a new deployment server.
        timeout: the timeout in seconds to wait for the MLflow server
            to be provisioned and successfully started or updated. If set
            to 0, the method will return immediately after the MLflow
            server is provisioned, without waiting for it to fully start.

    Returns:
        The ZenML MLflow deployment service object that can be used to
        interact with the MLflow model server.
    """
    config = cast(MLFlowDeploymentConfig, config)
    service = None

    # if replace is True, remove all existing services
    if replace is True:
        existing_services = self.find_model_server(
            pipeline_name=config.pipeline_name,
            pipeline_step_name=config.pipeline_step_name,
            model_name=config.model_name,
        )

        for existing_service in existing_services:
            if service is None:
                # keep the most recently created service
                service = cast(MLFlowDeploymentService, existing_service)
            try:
                # delete the older services and don't wait for them to
                # be deprovisioned
                self._clean_up_existing_service(
                    existing_service=cast(
                        MLFlowDeploymentService, existing_service
                    ),
                    timeout=timeout,
                    force=True,
                )
            except RuntimeError:
                # ignore errors encountered while stopping old services
                pass
    if service:
        logger.info(
            f"Updating an existing MLflow deployment service: {service}"
        )

        # set the root runtime path with the stack component's UUID
        config.root_runtime_path = self.local_path
        service.stop(timeout=timeout, force=True)
        service.update(config)
        service.start(timeout=timeout)
    else:
        # create a new MLFlowDeploymentService instance
        service = self._create_new_service(timeout, config)
        logger.info(f"Created a new MLflow deployment service: {service}")

    return cast(BaseService, service)
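
As a rough sketch of a continuous-deployment call, assuming an MLflow model deployer is registered in the active stack (the model URI and the pipeline, step and model names below are placeholders):

from zenml.integrations.mlflow.model_deployers.mlflow_model_deployer import (
    MLFlowModelDeployer,
)
from zenml.integrations.mlflow.services.mlflow_deployment import (
    MLFlowDeploymentConfig,
)

model_deployer = MLFlowModelDeployer.get_active_model_deployer()
config = MLFlowDeploymentConfig(
    model_uri="runs:/<run_id>/model",  # placeholder model URI
    model_name="my_model",  # hypothetical model name
    pipeline_name="training_pipeline",  # hypothetical pipeline name
    pipeline_step_name="model_deployer_step",  # hypothetical step name
    workers=1,
)
# Update an equivalent existing server if one is found, otherwise start a
# new one; wait up to 60 seconds for the server to become ready.
service = model_deployer.deploy_model(config=config, replace=True, timeout=60)
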
find_model_server(self, running=False, service_uuid=None, pipeline_name=None, pipeline_run_id=None, pipeline_step_name=None, model_name=None, model_uri=None, model_type=None)

Finds one or more model servers that match the given criteria.

Parameters:

Name Type Description Default
running bool

If true, only running services will be returned.

False
service_uuid Optional[uuid.UUID]

The UUID of the service that was originally used to deploy the model.

None
pipeline_name Optional[str]

Name of the pipeline that the deployed model was part of.

None
pipeline_run_id Optional[str]

ID of the pipeline run which the deployed model was part of.

None
pipeline_step_name Optional[str]

The name of the pipeline model deployment step that deployed the model.

None
model_name Optional[str]

Name of the deployed model.

None
model_uri Optional[str]

URI of the deployed model.

None
model_type Optional[str]

Type/format of the deployed model. Not used in this MLflow case.

None

Returns:

Type Description
List[zenml.services.service.BaseService]

One or more Service objects representing model servers that match the input search criteria.

Exceptions:

Type Description
TypeError

if one of the stored services is not an MLFlowDeploymentService.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def find_model_server(
    self,
    running: bool = False,
    service_uuid: Optional[UUID] = None,
    pipeline_name: Optional[str] = None,
    pipeline_run_id: Optional[str] = None,
    pipeline_step_name: Optional[str] = None,
    model_name: Optional[str] = None,
    model_uri: Optional[str] = None,
    model_type: Optional[str] = None,
) -> List[BaseService]:
    """Finds one or more model servers that match the given criteria.

    Args:
        running: If true, only running services will be returned.
        service_uuid: The UUID of the service that was originally used
            to deploy the model.
        pipeline_name: Name of the pipeline that the deployed model was part
            of.
        pipeline_run_id: ID of the pipeline run which the deployed model
            was part of.
        pipeline_step_name: The name of the pipeline model deployment step
            that deployed the model.
        model_name: Name of the deployed model.
        model_uri: URI of the deployed model.
        model_type: Type/format of the deployed model. Not used in this
            MLflow case.

    Returns:
        One or more Service objects representing model servers that match
        the input search criteria.

    Raises:
        TypeError: if one of the stored services is not an MLFlowDeploymentService.
    """
    services = []
    config = MLFlowDeploymentConfig(
        model_name=model_name or "",
        model_uri=model_uri or "",
        pipeline_name=pipeline_name or "",
        pipeline_run_id=pipeline_run_id or "",
        pipeline_step_name=pipeline_step_name or "",
    )

    # find all services that match the input criteria
    for root, _, files in os.walk(self.local_path):
        if service_uuid and Path(root).name != str(service_uuid):
            continue
        for file in files:
            if file == SERVICE_DAEMON_CONFIG_FILE_NAME:
                service_config_path = os.path.join(root, file)
                logger.debug(
                    "Loading service daemon configuration from %s",
                    service_config_path,
                )
                existing_service_config = None
                with open(service_config_path, "r") as f:
                    existing_service_config = f.read()
                existing_service = ServiceRegistry().load_service_from_json(
                    existing_service_config
                )
                if not isinstance(
                    existing_service, MLFlowDeploymentService
                ):
                    raise TypeError(
                        f"Expected service type MLFlowDeploymentService but got "
                        f"{type(existing_service)} instead"
                    )
                existing_service.update_status()
                if self._matches_search_criteria(existing_service, config):
                    if not running or existing_service.is_running:
                        services.append(cast(BaseService, existing_service))

    return services
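
For instance, a sketch of listing the running servers deployed by a given pipeline step (the names are placeholders):

existing_services = model_deployer.find_model_server(
    pipeline_name="training_pipeline",  # hypothetical pipeline name
    pipeline_step_name="model_deployer_step",  # hypothetical step name
    running=True,
)
for existing_service in existing_services:
    print(existing_service)
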
get_active_model_deployer() staticmethod

Returns the MLFlowModelDeployer component of the active stack.

Returns:

Type Description
MLFlowModelDeployer

The MLFlowModelDeployer component of the active stack.

Exceptions:

Type Description
TypeError

If the active stack does not contain an MLFlowModelDeployer component.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_active_model_deployer() -> "MLFlowModelDeployer":
    """Returns the MLFlowModelDeployer component of the active stack.

    Returns:
        The MLFlowModelDeployer component of the active stack.

    Raises:
        TypeError: If the active stack does not contain an MLFlowModelDeployer component.
    """
    model_deployer = Repository(  # type: ignore[call-arg]
        skip_repository_check=True
    ).active_stack.model_deployer

    if not model_deployer or not isinstance(
        model_deployer, MLFlowModelDeployer
    ):
        raise TypeError(
            f"The active stack needs to have an MLflow model deployer "
            f"component registered to be able to deploy models with MLflow. "
            f"You can create a new stack with an MLflow model "
            f"deployer component or update your existing stack to add this "
            f"component, e.g.:\n\n"
            f"  'zenml model-deployer register mlflow --flavor={MLFLOW_MODEL_DEPLOYER_FLAVOR}'\n"
            f"  'zenml stack create stack-name -d mlflow ...'\n"
        )
    return model_deployer
get_model_server_info(service_instance) staticmethod

Return implementation specific information relevant to the user.

Parameters:

Name Type Description Default
service_instance MLFlowDeploymentService

Instance of an MLFlowDeploymentService

required

Returns:

Type Description
Dict[str, Optional[str]]

A dictionary containing the information.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "MLFlowDeploymentService",
) -> Dict[str, Optional[str]]:
    """Return implementation specific information relevant to the user.

    Args:
        service_instance: Instance of an MLFlowDeploymentService

    Returns:
        A dictionary containing the information.
    """
    return {
        "PREDICTION_URL": service_instance.endpoint.prediction_url,
        "MODEL_URI": service_instance.config.model_uri,
        "MODEL_NAME": service_instance.config.model_name,
        "SERVICE_PATH": service_instance.status.runtime_path,
        "DAEMON_PID": str(service_instance.status.pid),
    }
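
A brief sketch of inspecting the returned dictionary, reusing a service instance found via find_model_server:

services = model_deployer.find_model_server(running=True)
if services:
    # find_model_server only returns MLflow deployment services here
    info = MLFlowModelDeployer.get_model_server_info(services[0])
    print(info["PREDICTION_URL"])
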
get_service_path(uuid) staticmethod

Get the path where local MLflow service information is stored.

This includes the deployment service configuration, PID and log files.

Parameters:

Name Type Description Default
uuid UUID

The UUID of the MLflow model deployer.

required

Returns:

Type Description
str

The service path.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_service_path(uuid: uuid.UUID) -> str:
    """Get the path where local MLflow service information is stored.

    This includes the deployment service configuration, PID and log files.

    Args:
        uuid: The UUID of the MLflow model deployer.

    Returns:
        The service path.
    """
    service_path = os.path.join(
        get_global_config_directory(),
        LOCAL_STORES_DIRECTORY_NAME,
        str(uuid),
    )
    create_dir_recursive_if_not_exists(service_path)
    return service_path
set_service_path(values) classmethod

Sets the service_path attribute value according to the component UUID.

Parameters:

Name Type Description Default
values Dict[str, Any]

the dictionary of values to be validated.

required

Returns:

Type Description
Dict[str, Any]

The validated dictionary of values.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@root_validator(skip_on_failure=True)
def set_service_path(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Sets the service_path attribute value according to the component UUID.

    Args:
        values: the dictionary of values to be validated.

    Returns:
        The validated dictionary of values.
    """
    if values.get("service_path"):
        return values

    # not likely to happen, due to Pydantic validation, but mypy complains
    assert "uuid" in values

    values["service_path"] = cls.get_service_path(values["uuid"])
    return values
start_model_server(self, uuid, timeout=10)

Method to start a model server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to start.

required
timeout int

Timeout in seconds to wait for the service to start.

10
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def start_model_server(
    self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
) -> None:
    """Method to start a model server.

    Args:
        uuid: UUID of the model server to start.
        timeout: Timeout in seconds to wait for the service to start.
    """
    # get list of all services
    existing_services = self.find_model_server(service_uuid=uuid)

    # if the service exists, start it
    if existing_services:
        existing_services[0].start(timeout=timeout)
stop_model_server(self, uuid, timeout=10, force=False)

Method to stop a model server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to stop.

required
timeout int

Timeout in seconds to wait for the service to stop.

10
force bool

If True, force the service to stop.

False
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def stop_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Method to stop a model server.

    Args:
        uuid: UUID of the model server to stop.
        timeout: Timeout in seconds to wait for the service to stop.
        force: If True, force the service to stop.
    """
    # get list of all services
    existing_services = self.find_model_server(service_uuid=uuid)

    # if the service exists, stop it
    if existing_services:
        existing_services[0].stop(timeout=timeout, force=force)
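
Putting the lifecycle methods together, a minimal sketch that assumes `service` is a deployment service previously obtained from deploy_model or find_model_server:

# Stop the server gracefully, then remove its configuration and log files.
model_deployer.stop_model_server(uuid=service.uuid, timeout=30)
model_deployer.delete_model_server(uuid=service.uuid)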

services special

Initialization of the MLflow Service.

mlflow_deployment

Implementation of the MLflow deployment functionality.

MLFlowDeploymentConfig (LocalDaemonServiceConfig) pydantic-model

MLflow model deployment configuration.

Attributes:

Name Type Description
model_uri str

URI of the MLflow model to serve

model_name str

the name of the model

workers int

number of workers to use for the prediction service

mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentConfig(LocalDaemonServiceConfig):
    """MLflow model deployment configuration.

    Attributes:
        model_uri: URI of the MLflow model to serve
        model_name: the name of the model
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
    """

    model_uri: str
    model_name: str
    workers: int = 1
    mlserver: bool = False
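
For example, a sketch of a configuration that serves a model through the MLServer backend instead of the built-in scoring server (the URI and name are placeholders):

from zenml.integrations.mlflow.services.mlflow_deployment import (
    MLFlowDeploymentConfig,
)

config = MLFlowDeploymentConfig(
    model_uri="runs:/<run_id>/model",  # placeholder model URI
    model_name="my_model",  # hypothetical model name
    workers=2,
    mlserver=True,  # serve via MLServer rather than the built-in server
)
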
MLFlowDeploymentEndpoint (LocalDaemonServiceEndpoint) pydantic-model

A service endpoint exposed by the MLflow deployment daemon.

Attributes:

Name Type Description
config MLFlowDeploymentEndpointConfig

service endpoint configuration

monitor HTTPEndpointHealthMonitor

optional service endpoint health monitor

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpoint(LocalDaemonServiceEndpoint):
    """A service endpoint exposed by the MLflow deployment daemon.

    Attributes:
        config: service endpoint configuration
        monitor: optional service endpoint health monitor
    """

    config: MLFlowDeploymentEndpointConfig
    monitor: HTTPEndpointHealthMonitor

    @property
    def prediction_url(self) -> Optional[str]:
        """Gets the prediction URL for the endpoint.

        Returns:
            the prediction URL for the endpoint
        """
        uri = self.status.uri
        if not uri:
            return None
        return f"{uri}{self.config.prediction_url_path}"
prediction_url: Optional[str] property readonly

Gets the prediction URL for the endpoint.

Returns:

Type Description
Optional[str]

the prediction URL for the endpoint

MLFlowDeploymentEndpointConfig (LocalDaemonServiceEndpointConfig) pydantic-model

MLflow daemon service endpoint configuration.

Attributes:

Name Type Description
prediction_url_path str

URI subpath for prediction requests

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpointConfig(LocalDaemonServiceEndpointConfig):
    """MLflow daemon service endpoint configuration.

    Attributes:
        prediction_url_path: URI subpath for prediction requests
    """

    prediction_url_path: str
MLFlowDeploymentService (LocalDaemonService) pydantic-model

MLflow deployment service used to start a local prediction server for MLflow models.

Attributes:

Name Type Description
SERVICE_TYPE ClassVar[zenml.services.service_type.ServiceType]

a service type descriptor with information describing the MLflow deployment service class

config MLFlowDeploymentConfig

service configuration

endpoint MLFlowDeploymentEndpoint

optional service endpoint

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentService(LocalDaemonService):
    """MLflow deployment service used to start a local prediction server for MLflow models.

    Attributes:
        SERVICE_TYPE: a service type descriptor with information describing
            the MLflow deployment service class
        config: service configuration
        endpoint: optional service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="mlflow-deployment",
        type="model-serving",
        flavor="mlflow",
        description="MLflow prediction service",
    )

    config: MLFlowDeploymentConfig
    endpoint: MLFlowDeploymentEndpoint

    def __init__(
        self,
        config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
        **attrs: Any,
    ) -> None:
        """Initialize the MLflow deployment service.

        Args:
            config: service configuration
            attrs: additional attributes to set on the service
        """
        # ensure that the endpoint is created before the service is initialized
        # TODO [ENG-700]: implement a service factory or builder for MLflow
        #   deployment services
        if (
            isinstance(config, MLFlowDeploymentConfig)
            and "endpoint" not in attrs
        ):
            if config.mlserver:
                prediction_url_path = MLSERVER_PREDICTION_URL_PATH
                healthcheck_uri_path = MLSERVER_HEALTHCHECK_URL_PATH
                use_head_request = False
            else:
                prediction_url_path = MLFLOW_PREDICTION_URL_PATH
                healthcheck_uri_path = MLFLOW_HEALTHCHECK_URL_PATH
                use_head_request = True

            endpoint = MLFlowDeploymentEndpoint(
                config=MLFlowDeploymentEndpointConfig(
                    protocol=ServiceEndpointProtocol.HTTP,
                    prediction_url_path=prediction_url_path,
                ),
                monitor=HTTPEndpointHealthMonitor(
                    config=HTTPEndpointHealthMonitorConfig(
                        healthcheck_uri_path=healthcheck_uri_path,
                        use_head_request=use_head_request,
                    )
                ),
            )
            attrs["endpoint"] = endpoint
        super().__init__(config=config, **attrs)

    def run(self) -> None:
        """Start the service."""
        logger.info(
            "Starting MLflow prediction service as blocking "
            "process... press CTRL+C once to stop it."
        )

        self.endpoint.prepare_for_start()

        try:
            serve_kwargs: Dict[str, Any] = {}
            # MLflow version 1.26 introduces an additional mandatory
            # `timeout` argument to the `PyFuncBackend.serve` function
            if int(MLFLOW_VERSION.split(".")[1]) >= 26:
                serve_kwargs["timeout"] = None

            backend = PyFuncBackend(
                config={},
                no_conda=True,
                workers=self.config.workers,
                install_mlflow=False,
            )
            backend.serve(
                model_uri=self.config.model_uri,
                port=self.endpoint.status.port,
                host="localhost",
                enable_mlserver=self.config.mlserver,
                **serve_kwargs,
            )
        except KeyboardInterrupt:
            logger.info(
                "MLflow prediction service stopped. Resuming normal execution."
            )

    @property
    def prediction_url(self) -> Optional[str]:
        """Get the URI where the prediction service is answering requests.

        Returns:
            The URI where the prediction service can be contacted to process
            HTTP/REST inference requests, or None, if the service isn't running.
        """
        if not self.is_running:
            return None
        return self.endpoint.prediction_url

    def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
        """Make a prediction using the service.

        Args:
            request: a numpy array representing the request

        Returns:
            A numpy array representing the prediction returned by the service.

        Raises:
            Exception: if the service is not running
            ValueError: if the prediction endpoint is unknown.
        """
        if not self.is_running:
            raise Exception(
                "MLflow prediction service is not running. "
                "Please start the service before making predictions."
            )

        if self.endpoint.prediction_url is not None:
            response = requests.post(
                self.endpoint.prediction_url,
                json={"instances": request.tolist()},
            )
        else:
            raise ValueError("No endpoint known for prediction.")
        response.raise_for_status()
        return np.array(response.json())
prediction_url: Optional[str] property readonly

Get the URI where the prediction service is answering requests.

Returns:

Type Description
Optional[str]

The URI where the prediction service can be contacted to process HTTP/REST inference requests, or None, if the service isn't running.

__init__(self, config, **attrs) special

Initialize the MLflow deployment service.

Parameters:

Name Type Description Default
config Union[zenml.integrations.mlflow.services.mlflow_deployment.MLFlowDeploymentConfig, Dict[str, Any]]

service configuration

required
attrs Any

additional attributes to set on the service

{}
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def __init__(
    self,
    config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
    **attrs: Any,
) -> None:
    """Initialize the MLflow deployment service.

    Args:
        config: service configuration
        attrs: additional attributes to set on the service
    """
    # ensure that the endpoint is created before the service is initialized
    # TODO [ENG-700]: implement a service factory or builder for MLflow
    #   deployment services
    if (
        isinstance(config, MLFlowDeploymentConfig)
        and "endpoint" not in attrs
    ):
        if config.mlserver:
            prediction_url_path = MLSERVER_PREDICTION_URL_PATH
            healthcheck_uri_path = MLSERVER_HEALTHCHECK_URL_PATH
            use_head_request = False
        else:
            prediction_url_path = MLFLOW_PREDICTION_URL_PATH
            healthcheck_uri_path = MLFLOW_HEALTHCHECK_URL_PATH
            use_head_request = True

        endpoint = MLFlowDeploymentEndpoint(
            config=MLFlowDeploymentEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
                prediction_url_path=prediction_url_path,
            ),
            monitor=HTTPEndpointHealthMonitor(
                config=HTTPEndpointHealthMonitorConfig(
                    healthcheck_uri_path=healthcheck_uri_path,
                    use_head_request=use_head_request,
                )
            ),
        )
        attrs["endpoint"] = endpoint
    super().__init__(config=config, **attrs)
predict(self, request)

Make a prediction using the service.

Parameters:

Name Type Description Default
request NDArray[Any]

a numpy array representing the request

required

Returns:

Type Description
NDArray[Any]

A numpy array representing the prediction returned by the service.

Exceptions:

Type Description
Exception

if the service is not running

ValueError

if the prediction endpoint is unknown.

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
    """Make a prediction using the service.

    Args:
        request: a numpy array representing the request

    Returns:
        A numpy array representing the prediction returned by the service.

    Raises:
        Exception: if the service is not running
        ValueError: if the prediction endpoint is unknown.
    """
    if not self.is_running:
        raise Exception(
            "MLflow prediction service is not running. "
            "Please start the service before making predictions."
        )

    if self.endpoint.prediction_url is not None:
        response = requests.post(
            self.endpoint.prediction_url,
            json={"instances": request.tolist()},
        )
    else:
        raise ValueError("No endpoint known for prediction.")
    response.raise_for_status()
    return np.array(response.json())
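
For orientation, here is a minimal, hedged usage sketch of predict once a service is up. It assumes service is a running MLFlowDeploymentService obtained elsewhere (for example, returned by a deployer step), and the input shape is hypothetical and must match the deployed model's signature.

import numpy as np

# Hedged sketch: `service` is assumed to be an MLFlowDeploymentService
# obtained from a deployer step; the feature vector is hypothetical.
sample = np.array([[5.1, 3.5, 1.4, 0.2]])

if not service.is_running:
    service.start(timeout=30)  # wait up to 30 seconds for the daemon

# Internally this POSTs {"instances": sample.tolist()} to prediction_url
# and converts the JSON response back into a numpy array.
prediction = service.predict(sample)
print(prediction)
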
run(self)

Start the service.

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def run(self) -> None:
    """Start the service."""
    logger.info(
        "Starting MLflow prediction service as blocking "
        "process... press CTRL+C once to stop it."
    )

    self.endpoint.prepare_for_start()

    try:
        serve_kwargs: Dict[str, Any] = {}
        # MLflow version 1.26 introduces an additional mandatory
        # `timeout` argument to the `PyFuncBackend.serve` function
        if int(MLFLOW_VERSION.split(".")[1]) >= 26:
            serve_kwargs["timeout"] = None

        backend = PyFuncBackend(
            config={},
            no_conda=True,
            workers=self.config.workers,
            install_mlflow=False,
        )
        backend.serve(
            model_uri=self.config.model_uri,
            port=self.endpoint.status.port,
            host="localhost",
            enable_mlserver=self.config.mlserver,
            **serve_kwargs,
        )
    except KeyboardInterrupt:
        logger.info(
            "MLflow prediction service stopped. Resuming normal execution."
        )

steps special

Initialization of the MLflow standard interface steps.

mlflow_deployer

Implementation of the MLflow model deployer pipeline step.

MLFlowDeployerConfig (BaseStepConfig) pydantic-model

Model deployer step configuration for MLflow.

Attributes:

Name Type Description
model_name str

the name of the MLflow model logged in the MLflow artifact store for the current pipeline.

workers int

number of workers to use for the prediction service

mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

timeout int

the number of seconds to wait for the service to start/stop.

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
class MLFlowDeployerConfig(BaseStepConfig):
    """Model deployer step configuration for MLflow.

    Attributes:
        model_name: the name of the MLflow model logged in the MLflow artifact
            store for the current pipeline.
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
        timeout: the number of seconds to wait for the service to start/stop.
    """

    model_name: str = "model"
    model_uri: str = ""
    workers: int = 1
    mlserver: bool = False
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
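
To make the configuration concrete, below is a small, hedged sketch of instantiating this config and binding it to the mlflow_model_deployer_step documented next. It assumes both names are re-exported from zenml.integrations.mlflow.steps; all values shown are hypothetical.

from zenml.integrations.mlflow.steps import (
    MLFlowDeployerConfig,
    mlflow_model_deployer_step,
)

# Hedged sketch: all values are hypothetical. `model_name` must match the
# name under which the training step logged the model to MLflow.
deployer = mlflow_model_deployer_step(
    config=MLFlowDeployerConfig(
        model_name="model",
        workers=2,
        mlserver=False,  # use the MLflow built-in scoring server
        timeout=60,      # seconds to wait for the daemon to start/stop
    )
)
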
mlflow_model_deployer_step (BaseStep)

Model deployer pipeline step for MLflow.

Parameters:

Name Type Description Default
deploy_decision

whether to deploy the model or not

required
model

the model artifact to deploy

required
config

configuration for the deployer step

required

Returns:

Type Description

MLflow deployment service

CONFIG_CLASS (BaseStepConfig) pydantic-model

Model deployer step configuration for MLflow.

Attributes:

Name Type Description
model_name str

the name of the MLflow model logged in the MLflow artifact store for the current pipeline.

workers int

number of workers to use for the prediction service

mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

timeout int

the number of seconds to wait for the service to start/stop.

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
class MLFlowDeployerConfig(BaseStepConfig):
    """Model deployer step configuration for MLflow.

    Attributes:
        model_name: the name of the MLflow model logged in the MLflow artifact
            store for the current pipeline.
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
        timeout: the number of seconds to wait for the service to start/stop.
    """

    model_name: str = "model"
    model_uri: str = ""
    workers: int = 1
    mlserver: bool = False
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
entrypoint(deploy_decision, model, config) staticmethod

Model deployer pipeline step for MLflow.

Parameters:

Name Type Description Default
deploy_decision bool

whether to deploy the model or not

required
model ModelArtifact

the model artifact to deploy

required
config MLFlowDeployerConfig

configuration for the deployer step

required

Returns:

Type Description
MLFlowDeploymentService

MLflow deployment service

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
@enable_mlflow
@step(enable_cache=False)
def mlflow_model_deployer_step(
    deploy_decision: bool,
    model: ModelArtifact,
    config: MLFlowDeployerConfig,
) -> MLFlowDeploymentService:
    """Model deployer pipeline step for MLflow.

    # noqa: DAR401

    Args:
        deploy_decision: whether to deploy the model or not
        model: the model artifact to deploy
        config: configuration for the deployer step

    Returns:
        MLflow deployment service
    """
    model_deployer = MLFlowModelDeployer.get_active_model_deployer()

    # fetch the MLflow artifacts logged during the pipeline run
    experiment_tracker = Repository(  # type: ignore[call-arg]
        skip_repository_check=True
    ).active_stack.experiment_tracker

    if not isinstance(experiment_tracker, MLFlowExperimentTracker):
        raise get_missing_mlflow_experiment_tracker_error()

    client = MlflowClient()
    model_uri = ""
    mlflow_run = experiment_tracker.active_run
    if mlflow_run and client.list_artifacts(
        mlflow_run.info.run_id, config.model_name
    ):
        model_uri = get_artifact_uri(config.model_name)

    # get pipeline name, step name and run id
    step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
    pipeline_name = step_env.pipeline_name
    run_id = step_env.pipeline_run_id
    step_name = step_env.step_name

    # fetch existing services with same pipeline name, step name and model name
    existing_services = model_deployer.find_model_server(
        pipeline_name=pipeline_name,
        pipeline_step_name=step_name,
        model_name=config.model_name,
    )

    # create a config for the new model service
    predictor_cfg = MLFlowDeploymentConfig(
        model_name=config.model_name or "",
        model_uri=model_uri,
        workers=config.workers,
        mlserver=config.mlserver,
        pipeline_name=pipeline_name,
        pipeline_run_id=run_id,
        pipeline_step_name=step_name,
    )

    # Creating a new service with inactive state and status by default
    service = MLFlowDeploymentService(predictor_cfg)
    if existing_services:
        service = cast(MLFlowDeploymentService, existing_services[0])

    # check for conditions to deploy the model
    if not model_uri:
        # an MLflow model was not trained in the current run, so we simply reuse
        # the currently running service created for the same model, if any
        if not existing_services:
            logger.warning(
                f"An MLflow model with name `{config.model_name}` was not "
                f"logged in the current pipeline run and no running MLflow "
                f"model server was found. Please ensure that your pipeline "
                f"includes an `@enable_mlflow` decorated step that trains a "
                f"model and logs it to MLflow. This could also happen if "
                f"the current pipeline run did not log an MLflow model  "
                f"because the training step was cached."
            )
            # return an inactive service just because we have to return
            # something
            return service
        logger.info(
            f"An MLflow model with name `{config.model_name}` was not "
            f"trained in the current pipeline run. Reusing the existing "
            f"MLflow model server."
        )
        if not service.is_running:
            service.start(config.timeout)

        # return the existing service
        return service

    # even when the deploy decision is negative, if an existing model server
    # is not running for this pipeline/step, we still have to serve the
    # current model, to ensure that a model server is available at all times
    if not deploy_decision and existing_services:
        logger.info(
            f"Skipping model deployment because the model quality does not "
            f"meet the criteria. Reusing last model server deployed by step "
            f"'{step_name}' and pipeline '{pipeline_name}' for model "
            f"'{config.model_name}'..."
        )
        # even when the deploy decision is negative, we still need to start
        # the previous model server if it is no longer running, to ensure
        # that a model server is available at all times
        if not service.is_running:
            service.start(config.timeout)
        return service

    # create a new model deployment and replace an old one if it exists
    new_service = cast(
        MLFlowDeploymentService,
        model_deployer.deploy_model(
            replace=True,
            config=predictor_cfg,
            timeout=config.timeout,
        ),
    )

    logger.info(
        f"MLflow deployment service started and reachable at:\n"
        f"    {new_service.prediction_url}\n"
    )

    return new_service
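
Putting the step into context, here is a hedged sketch of a continuous-deployment pipeline that feeds the deployer a trained model and a deployment decision. The trainer and evaluator steps are hypothetical placeholders for your own steps.

from zenml.pipelines import pipeline

# Hedged sketch: `trainer` and `evaluator` are hypothetical user-defined
# steps. The evaluator is assumed to return a bool deploy decision.
@pipeline
def continuous_deployment_pipeline(trainer, evaluator, model_deployer):
    model = trainer()
    deploy_decision = evaluator(model=model)
    model_deployer(deploy_decision=deploy_decision, model=model)

The pipeline would then be assembled with concrete step instances, for example continuous_deployment_pipeline(trainer=my_trainer(), evaluator=my_evaluator(), model_deployer=deployer).run().
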
mlflow_deployer_step(enable_cache=True, name=None)

Creates a pipeline step to deploy a given ML model with a local MLflow prediction server.

The returned step can be used in a pipeline to implement continuous deployment for an MLflow model.

Parameters:

Name Type Description Default
enable_cache bool

Specify whether caching is enabled for this step. If no value is passed, caching is enabled by default

True
name Optional[str]

Name of the step.

None

Returns:

Type Description
Type[zenml.steps.base_step.BaseStep]

an MLflow model deployer pipeline step

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
def mlflow_deployer_step(
    enable_cache: bool = True,
    name: Optional[str] = None,
) -> Type[BaseStep]:
    """Creates a pipeline step to deploy a given ML model with a local MLflow prediction server.

    The returned step can be used in a pipeline to implement continuous
    deployment for an MLflow model.

    Args:
        enable_cache: Specify whether caching is enabled for this step. If no
            value is passed, caching is enabled by default
        name: Name of the step.

    Returns:
        an MLflow model deployer pipeline step
    """
    logger.warning(
        "The `mlflow_deployer_step` function is deprecated. Please "
        "use the built-in `mlflow_model_deployer_step` step instead."
    )
    return mlflow_model_deployer_step

neural_prophet special

Initialization of the Neural Prophet integration.

NeuralProphetIntegration (Integration)

Definition of NeuralProphet integration for ZenML.

Source code in zenml/integrations/neural_prophet/__init__.py
class NeuralProphetIntegration(Integration):
    """Definition of NeuralProphet integration for ZenML."""

    NAME = NEURAL_PROPHET
    REQUIREMENTS = ["neuralprophet>=0.3.2"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration."""
        from zenml.integrations.neural_prophet import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/neural_prophet/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration."""
    from zenml.integrations.neural_prophet import materializers  # noqa

materializers special

Initialization of the Neural Prophet materializer.

neural_prophet_materializer

Implementation of the Neural Prophet materializer.

NeuralProphetMaterializer (BaseMaterializer)

Materializer to read/write NeuralProphet models.

Source code in zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py
class NeuralProphetMaterializer(BaseMaterializer):
    """Materializer to read/write NeuralProphet models."""

    ASSOCIATED_TYPES = (NeuralProphet,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> NeuralProphet:
        """Reads and returns a NeuralProphet model.

        Args:
            data_type: A NeuralProphet model object.

        Returns:
            A loaded NeuralProphet model.
        """
        super().handle_input(data_type)
        return torch.load(  # type: ignore[no-untyped-call]
            os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        )  # noqa

    def handle_return(self, model: NeuralProphet) -> None:
        """Writes a NeuralProphet model.

        Args:
            model: A NeuralProphet model object.
        """
        super().handle_return(model)
        torch.save(model, os.path.join(self.artifact.uri, DEFAULT_FILENAME))
handle_input(self, data_type)

Reads and returns a NeuralProphet model.

Parameters:

Name Type Description Default
data_type Type[Any]

A NeuralProphet model object.

required

Returns:

Type Description
NeuralProphet

A loaded NeuralProphet model.

Source code in zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py
def handle_input(self, data_type: Type[Any]) -> NeuralProphet:
    """Reads and returns a NeuralProphet model.

    Args:
        data_type: A NeuralProphet model object.

    Returns:
        A loaded NeuralProphet model.
    """
    super().handle_input(data_type)
    return torch.load(  # type: ignore[no-untyped-call]
        os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    )  # noqa
handle_return(self, model)

Writes a NeuralProphet model.

Parameters:

Name Type Description Default
model NeuralProphet

A NeuralProphet model object.

required
Source code in zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py
def handle_return(self, model: NeuralProphet) -> None:
    """Writes a NeuralProphet model.

    Args:
        model: A NeuralProphet model object.
    """
    super().handle_return(model)
    torch.save(model, os.path.join(self.artifact.uri, DEFAULT_FILENAME))
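
Because the materializer registers NeuralProphet in its ASSOCIATED_TYPES, any step that returns a NeuralProphet model is persisted through it automatically once the integration is activated. A minimal, hedged sketch; the hyperparameters and fitting logic are hypothetical:

from neuralprophet import NeuralProphet
from zenml.steps import step

@step
def np_trainer() -> NeuralProphet:
    """Builds a NeuralProphet model; ZenML saves it via the materializer."""
    model = NeuralProphet(epochs=10)  # hypothetical hyperparameter
    # model.fit(df, freq="D")        # fit on your own dataframe here
    return model
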

plotly special

Initialization of the Plotly integration.

PlotlyIntegration (Integration)

Definition of Plotly integration for ZenML.

Source code in zenml/integrations/plotly/__init__.py
class PlotlyIntegration(Integration):
    """Definition of Plotly integration for ZenML."""

    NAME = PLOTLY
    REQUIREMENTS = ["plotly>=5.4.0"]

visualizers special

Initialization of the Plotly Visualizer.

pipeline_lineage_visualizer

Implementation of the Plotly Pipeline Lineage Visualizer.

PipelineLineageVisualizer (BasePipelineVisualizer)

Visualize the lineage of runs in a pipeline using plotly.

Source code in zenml/integrations/plotly/visualizers/pipeline_lineage_visualizer.py
class PipelineLineageVisualizer(BasePipelineVisualizer):
    """Visualize the lineage of runs in a pipeline using plotly."""

    def visualize(
        self, object: PipelineView, *args: Any, **kwargs: Any
    ) -> Figure:
        """Creates a pipeline lineage diagram using plotly.

        Args:
            object: The pipeline view to visualize.
            *args: Additional arguments to pass to the visualization.
            **kwargs: Additional keyword arguments to pass to the visualization.

        Returns:
            A plotly figure.
        """
        logger.warning(
            "This integration is not completed yet. Results might be unexpected."
        )

        category_dict = {}
        dimensions = ["run"]
        for run in object.runs:
            category_dict[run.name] = {"run": run.name}
            for step in run.steps:
                category_dict[run.name].update(
                    {
                        step.entrypoint_name: str(step.id),
                    }
                )
                if step.entrypoint_name not in dimensions:
                    dimensions.append(f"{step.entrypoint_name}")

        category_df = pd.DataFrame.from_dict(category_dict, orient="index")

        category_df = category_df.reset_index()

        fig = px.parallel_categories(
            category_df,
            dimensions,
            color=None,
            labels="status",
        )

        fig.show()
        return fig
visualize(self, object, *args, **kwargs)

Creates a pipeline lineage diagram using plotly.

Parameters:

Name Type Description Default
object PipelineView

The pipeline view to visualize.

required
*args Any

Additional arguments to pass to the visualization.

()
**kwargs Any

Additional keyword arguments to pass to the visualization.

{}

Returns:

Type Description
Figure

A plotly figure.

Source code in zenml/integrations/plotly/visualizers/pipeline_lineage_visualizer.py
def visualize(
    self, object: PipelineView, *args: Any, **kwargs: Any
) -> Figure:
    """Creates a pipeline lineage diagram using plotly.

    Args:
        object: The pipeline view to visualize.
        *args: Additional arguments to pass to the visualization.
        **kwargs: Additional keyword arguments to pass to the visualization.

    Returns:
        A plotly figure.
    """
    logger.warning(
        "This integration is not completed yet. Results might be unexpected."
    )

    category_dict = {}
    dimensions = ["run"]
    for run in object.runs:
        category_dict[run.name] = {"run": run.name}
        for step in run.steps:
            category_dict[run.name].update(
                {
                    step.entrypoint_name: str(step.id),
                }
            )
            if step.entrypoint_name not in dimensions:
                dimensions.append(f"{step.entrypoint_name}")

    category_df = pd.DataFrame.from_dict(category_dict, orient="index")

    category_df = category_df.reset_index()

    fig = px.parallel_categories(
        category_df,
        dimensions,
        color=None,
        labels="status",
    )

    fig.show()
    return fig
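
A hedged usage sketch: fetch a post-execution view of a pipeline from the repository and hand it to the visualizer. The pipeline name is hypothetical and must exist in your local repository, and the Repository accessor is an assumption based on the post-execution workflow of this ZenML version.

from zenml.integrations.plotly.visualizers.pipeline_lineage_visualizer import (
    PipelineLineageVisualizer,
)
from zenml.repository import Repository

# Hedged sketch: "my_pipeline" is a hypothetical pipeline name.
pipeline_view = Repository().get_pipeline(pipeline_name="my_pipeline")
if pipeline_view:
    PipelineLineageVisualizer().visualize(pipeline_view)  # opens a plotly figure
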

pytorch special

Initialization of the PyTorch integration.

PytorchIntegration (Integration)

Definition of PyTorch integration for ZenML.

Source code in zenml/integrations/pytorch/__init__.py
class PytorchIntegration(Integration):
    """Definition of PyTorch integration for ZenML."""

    NAME = PYTORCH
    REQUIREMENTS = ["torch"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration."""
        from zenml.integrations.pytorch import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/pytorch/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration."""
    from zenml.integrations.pytorch import materializers  # noqa

materializers special

Initialization of the PyTorch Materializer.

pytorch_dataloader_materializer

Implementation of the PyTorch DataLoader materializer.

PyTorchDataLoaderMaterializer (BaseMaterializer)

Materializer to read/write PyTorch dataloaders.

Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
class PyTorchDataLoaderMaterializer(BaseMaterializer):
    """Materializer to read/write PyTorch dataloaders."""

    ASSOCIATED_TYPES = (DataLoader,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> DataLoader[Any]:
        """Reads and returns a PyTorch dataloader.

        Args:
            data_type: The type of the dataloader to load.

        Returns:
            A loaded PyTorch dataloader.
        """
        super().handle_input(data_type)
        with fileio.open(
            os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
        ) as f:
            return cast(DataLoader[Any], torch.load(f))  # type: ignore[no-untyped-call]  # noqa

    def handle_return(self, dataloader: DataLoader[Any]) -> None:
        """Writes a PyTorch dataloader.

        Args:
            dataloader: A torch.utils.data.DataLoader or a dict to pass into torch.save
        """
        super().handle_return(dataloader)

        # Save entire dataloader to artifact directory
        with fileio.open(
            os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
        ) as f:
            torch.save(dataloader, f)
handle_input(self, data_type)

Reads and returns a PyTorch dataloader.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the dataloader to load.

required

Returns:

Type Description
torch.utils.data.dataloader.DataLoader[Any]

A loaded PyTorch dataloader.

Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
def handle_input(self, data_type: Type[Any]) -> DataLoader[Any]:
    """Reads and returns a PyTorch dataloader.

    Args:
        data_type: The type of the dataloader to load.

    Returns:
        A loaded PyTorch dataloader.
    """
    super().handle_input(data_type)
    with fileio.open(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
    ) as f:
        return cast(DataLoader[Any], torch.load(f))  # type: ignore[no-untyped-call]  # noqa
handle_return(self, dataloader)

Writes a PyTorch dataloader.

Parameters:

Name Type Description Default
dataloader torch.utils.data.dataloader.DataLoader[Any]

A torch.utils.data.DataLoader or a dict to pass into torch.save

required
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
def handle_return(self, dataloader: DataLoader[Any]) -> None:
    """Writes a PyTorch dataloader.

    Args:
        dataloader: A torch.utils.data.DataLoader or a dict to pass into torch.save
    """
    super().handle_return(dataloader)

    # Save entire dataloader to artifact directory
    with fileio.open(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
    ) as f:
        torch.save(dataloader, f)
pytorch_module_materializer

Implementation of the PyTorch Module materializer.

PyTorchModuleMaterializer (BaseMaterializer)

Materializer to read/write Pytorch models.

Inspired by the guide: https://pytorch.org/tutorials/beginner/saving_loading_models.html

Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
class PyTorchModuleMaterializer(BaseMaterializer):
    """Materializer to read/write Pytorch models.

    Inspired by the guide:
    https://pytorch.org/tutorials/beginner/saving_loading_models.html
    """

    ASSOCIATED_TYPES = (Module,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Module:
        """Reads and returns a PyTorch model.

        Only loads the model, not the checkpoint.

        Args:
            data_type: The type of the model to load.

        Returns:
            A loaded pytorch model.
        """
        super().handle_input(data_type)
        with fileio.open(
            os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
        ) as f:
            return torch.load(f)  # type: ignore[no-untyped-call]  # noqa

    def handle_return(self, model: Module) -> None:
        """Writes a PyTorch model, as a model and a checkpoint.

        Args:
            model: A torch.nn.Module or a dict to pass into torch.save
        """
        super().handle_return(model)

        # Save entire model to artifact directory. This is the default behavior
        # for loading model in development phase (training, evaluation)
        with fileio.open(
            os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
        ) as f:
            torch.save(model, f)

        # Also save model checkpoint to artifact directory.
        # This is the default behavior for loading model in production phase (inference)
        if isinstance(model, Module):
            with fileio.open(
                os.path.join(self.artifact.uri, CHECKPOINT_FILENAME), "wb"
            ) as f:
                torch.save(model.state_dict(), f)
handle_input(self, data_type)

Reads and returns a PyTorch model.

Only loads the model, not the checkpoint.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the model to load.

required

Returns:

Type Description
Module

A loaded pytorch model.

Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def handle_input(self, data_type: Type[Any]) -> Module:
    """Reads and returns a PyTorch model.

    Only loads the model, not the checkpoint.

    Args:
        data_type: The type of the model to load.

    Returns:
        A loaded pytorch model.
    """
    super().handle_input(data_type)
    with fileio.open(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
    ) as f:
        return torch.load(f)  # type: ignore[no-untyped-call]  # noqa
handle_return(self, model)

Writes a PyTorch model, as a model and a checkpoint.

Parameters:

Name Type Description Default
model Module

A torch.nn.Module or a dict to pass into torch.save

required
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def handle_return(self, model: Module) -> None:
    """Writes a PyTorch model, as a model and a checkpoint.

    Args:
        model: A torch.nn.Module or a dict to pass into torch.save
    """
    super().handle_return(model)

    # Save entire model to artifact directory. This is the default behavior
    # for loading model in development phase (training, evaluation)
    with fileio.open(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
    ) as f:
        torch.save(model, f)

    # Also save model checkpoint to artifact directory.
    # This is the default behavior for loading model in production phase (inference)
    if isinstance(model, Module):
        with fileio.open(
            os.path.join(self.artifact.uri, CHECKPOINT_FILENAME), "wb"
        ) as f:
            torch.save(model.state_dict(), f)
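
The production path mentioned in the comments above amounts to rebuilding the architecture yourself and loading only the checkpoint's state_dict. A hedged sketch, assuming the checkpoint file name matches the integration's CHECKPOINT_FILENAME and using a stand-in architecture and a hypothetical path:

import torch
from torch import nn

# Hedged sketch: nn.Linear is a stand-in for your real architecture; the path
# is hypothetical and "checkpoint.pt" assumes the integration's CHECKPOINT_FILENAME.
model = nn.Linear(4, 2)
state_dict = torch.load("/path/to/artifact/checkpoint.pt")
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode
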

pytorch_lightning special

Initialization of the PyTorch Lightning integration.

PytorchLightningIntegration (Integration)

Definition of PyTorch Lightning integration for ZenML.

Source code in zenml/integrations/pytorch_lightning/__init__.py
class PytorchLightningIntegration(Integration):
    """Definition of PyTorch Lightning integration for ZenML."""

    NAME = PYTORCH_L
    REQUIREMENTS = ["pytorch_lightning"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration."""
        from zenml.integrations.pytorch_lightning import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/pytorch_lightning/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration."""
    from zenml.integrations.pytorch_lightning import materializers  # noqa

materializers special

Initialization of the PyTorch Lightning Materializer.

pytorch_lightning_materializer

Implementation of the PyTorch Lightning Materializer.

PyTorchLightningMaterializer (BaseMaterializer)

Materializer to read/write PyTorch models.

Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
class PyTorchLightningMaterializer(BaseMaterializer):
    """Materializer to read/write PyTorch models."""

    ASSOCIATED_TYPES = (Trainer,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Trainer:
        """Reads and returns a PyTorch Lightning trainer.

        Args:
            data_type: The type of the trainer to load.

        Returns:
            A PyTorch Lightning trainer object.
        """
        super().handle_input(data_type)
        return Trainer(
            resume_from_checkpoint=os.path.join(
                self.artifact.uri, CHECKPOINT_NAME
            )
        )

    def handle_return(self, trainer: Trainer) -> None:
        """Writes a PyTorch Lightning trainer.

        Args:
            trainer: A PyTorch Lightning trainer object.
        """
        super().handle_return(trainer)
        trainer.save_checkpoint(
            os.path.join(self.artifact.uri, CHECKPOINT_NAME)
        )
handle_input(self, data_type)

Reads and returns a PyTorch Lightning trainer.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the trainer to load.

required

Returns:

Type Description
Trainer

A PyTorch Lightning trainer object.

Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def handle_input(self, data_type: Type[Any]) -> Trainer:
    """Reads and returns a PyTorch Lightning trainer.

    Args:
        data_type: The type of the trainer to load.

    Returns:
        A PyTorch Lightning trainer object.
    """
    super().handle_input(data_type)
    return Trainer(
        resume_from_checkpoint=os.path.join(
            self.artifact.uri, CHECKPOINT_NAME
        )
    )
handle_return(self, trainer)

Writes a PyTorch Lightning trainer.

Parameters:

Name Type Description Default
trainer Trainer

A PyTorch Lightning trainer object.

required
Source code in zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py
def handle_return(self, trainer: Trainer) -> None:
    """Writes a PyTorch Lightning trainer.

    Args:
        trainer: A PyTorch Lightning trainer object.
    """
    super().handle_return(trainer)
    trainer.save_checkpoint(
        os.path.join(self.artifact.uri, CHECKPOINT_NAME)
    )
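
Since handle_input reconstructs the trainer via resume_from_checkpoint, returning a Trainer from a step is all that is needed for downstream steps to resume from the saved checkpoint. A minimal, hedged sketch; the trainer arguments are hypothetical and model fitting is omitted:

from pytorch_lightning import Trainer
from zenml.steps import step

@step
def lightning_trainer() -> Trainer:
    """Returns the Trainer so the materializer checkpoints it."""
    trainer = Trainer(max_epochs=5)  # hypothetical arguments
    # trainer.fit(model, train_dataloader)  # fit your LightningModule here
    return trainer
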

registry

Implementation of a registry to track ZenML integrations.

IntegrationRegistry

Registry to keep track of ZenML Integrations.

Source code in zenml/integrations/registry.py
class IntegrationRegistry(object):
    """Registry to keep track of ZenML Integrations."""

    def __init__(self) -> None:
        """Initializing the integration registry."""
        self._integrations: Dict[str, Type["Integration"]] = {}

    @property
    def integrations(self) -> Dict[str, Type["Integration"]]:
        """Method to get integrations dictionary.

        Returns:
            A dict of integration key to type of `Integration`.
        """
        return self._integrations

    @integrations.setter
    def integrations(self, i: Any) -> None:
        """Setter method for the integrations property.

        Args:
            i: Value to set the integrations property to.

        Raises:
            IntegrationError: If you try to manually set the integrations property.
        """
        raise IntegrationError(
            "Please do not manually change the integrations within the "
            "registry. If you would like to register a new integration "
            "manually, please use "
            "`integration_registry.register_integration()`."
        )

    def register_integration(
        self, key: str, type_: Type["Integration"]
    ) -> None:
        """Method to register an integration with a given name.

        Args:
            key: Name of the integration.
            type_: Type of the integration.
        """
        self._integrations[key] = type_

    def activate_integrations(self) -> None:
        """Method to activate the integrations with are registered in the registry."""
        for name, integration in self._integrations.items():
            if integration.check_installation():
                integration.activate()
                logger.debug(f"Integration `{name}` is activated.")
            else:
                logger.debug(f"Integration `{name}` could not be activated.")

    @property
    def list_integration_names(self) -> List[str]:
        """Get a list of all possible integrations.

        Returns:
            A list of all possible integrations.
        """
        return [name for name in self._integrations]

    def select_integration_requirements(
        self, integration_name: Optional[str] = None
    ) -> List[str]:
        """Select the requirements for a given integration or all integrations.

        Args:
            integration_name: Name of the integration to check.

        Returns:
            List of requirements for the integration.

        Raises:
            KeyError: If the integration is not found.
        """
        if integration_name:
            if integration_name in self.list_integration_names:
                return self._integrations[integration_name].REQUIREMENTS
            else:
                raise KeyError(
                    f"Version {integration_name} does not exist. "
                    f"Currently the following integrations are implemented. "
                    f"{self.list_integration_names}"
                )
        else:
            return [
                requirement
                for name in self.list_integration_names
                for requirement in self._integrations[name].REQUIREMENTS
            ]

    def is_installed(self, integration_name: Optional[str] = None) -> bool:
        """Checks if all requirements for an integration are installed.

        Args:
            integration_name: Name of the integration to check.

        Returns:
            True if all requirements are installed, False otherwise.

        Raises:
            KeyError: If the integration is not found.
        """
        if integration_name in self.list_integration_names:
            return self._integrations[integration_name].check_installation()
        elif not integration_name:
            all_installed = [
                self._integrations[item].check_installation()
                for item in self.list_integration_names
            ]
            return all(all_installed)
        else:
            raise KeyError(
                f"Integration '{integration_name}' not found. "
                f"Currently the following integrations are available: "
                f"{self.list_integration_names}"
            )

    def get_installed_integrations(self) -> List[str]:
        """Returns list of installed integrations.

        Returns:
            List of installed integrations.
        """
        return [
            name
            for name, integration in integration_registry.integrations.items()
            if integration.check_installation()
        ]
integrations: Dict[str, Type[Integration]] property writable

Method to get integrations dictionary.

Returns:

Type Description
Dict[str, Type[Integration]]

A dict of integration key to type of Integration.

list_integration_names: List[str] property readonly

Get a list of all possible integrations.

Returns:

Type Description
List[str]

A list of all possible integrations.

__init__(self) special

Initializing the integration registry.

Source code in zenml/integrations/registry.py
def __init__(self) -> None:
    """Initializing the integration registry."""
    self._integrations: Dict[str, Type["Integration"]] = {}
activate_integrations(self)

Method to activate the integrations which are registered in the registry.

Source code in zenml/integrations/registry.py
def activate_integrations(self) -> None:
    """Method to activate the integrations with are registered in the registry."""
    for name, integration in self._integrations.items():
        if integration.check_installation():
            integration.activate()
            logger.debug(f"Integration `{name}` is activated.")
        else:
            logger.debug(f"Integration `{name}` could not be activated.")
get_installed_integrations(self)

Returns list of installed integrations.

Returns:

Type Description
List[str]

List of installed integrations.

Source code in zenml/integrations/registry.py
def get_installed_integrations(self) -> List[str]:
    """Returns list of installed integrations.

    Returns:
        List of installed integrations.
    """
    return [
        name
        for name, integration in integration_registry.integrations.items()
        if integration.check_installation()
    ]
is_installed(self, integration_name=None)

Checks if all requirements for an integration are installed.

Parameters:

Name Type Description Default
integration_name Optional[str]

Name of the integration to check.

None

Returns:

Type Description
bool

True if all requirements are installed, False otherwise.

Exceptions:

Type Description
KeyError

If the integration is not found.

Source code in zenml/integrations/registry.py
def is_installed(self, integration_name: Optional[str] = None) -> bool:
    """Checks if all requirements for an integration are installed.

    Args:
        integration_name: Name of the integration to check.

    Returns:
        True if all requirements are installed, False otherwise.

    Raises:
        KeyError: If the integration is not found.
    """
    if integration_name in self.list_integration_names:
        return self._integrations[integration_name].check_installation()
    elif not integration_name:
        all_installed = [
            self._integrations[item].check_installation()
            for item in self.list_integration_names
        ]
        return all(all_installed)
    else:
        raise KeyError(
            f"Integration '{integration_name}' not found. "
            f"Currently the following integrations are available: "
            f"{self.list_integration_names}"
        )
register_integration(self, key, type_)

Method to register an integration with a given name.

Parameters:

Name Type Description Default
key str

Name of the integration.

required
type_ Type[Integration]

Type of the integration.

required
Source code in zenml/integrations/registry.py
def register_integration(
    self, key: str, type_: Type["Integration"]
) -> None:
    """Method to register an integration with a given name.

    Args:
        key: Name of the integration.
        type_: Type of the integration.
    """
    self._integrations[key] = type_
select_integration_requirements(self, integration_name=None)

Select the requirements for a given integration or all integrations.

Parameters:

Name Type Description Default
integration_name Optional[str]

Name of the integration to check.

None

Returns:

Type Description
List[str]

List of requirements for the integration.

Exceptions:

Type Description
KeyError

If the integration is not found.

Source code in zenml/integrations/registry.py
def select_integration_requirements(
    self, integration_name: Optional[str] = None
) -> List[str]:
    """Select the requirements for a given integration or all integrations.

    Args:
        integration_name: Name of the integration to check.

    Returns:
        List of requirements for the integration.

    Raises:
        KeyError: If the integration is not found.
    """
    if integration_name:
        if integration_name in self.list_integration_names:
            return self._integrations[integration_name].REQUIREMENTS
        else:
            raise KeyError(
                f"Version {integration_name} does not exist. "
                f"Currently the following integrations are implemented. "
                f"{self.list_integration_names}"
            )
    else:
        return [
            requirement
            for name in self.list_integration_names
            for requirement in self._integrations[name].REQUIREMENTS
        ]
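
Since get_installed_integrations itself references a module-level integration_registry instance, that global can be queried directly. A hedged sketch; the printed output depends on which requirements are installed locally:

from zenml.integrations.registry import integration_registry

print(integration_registry.list_integration_names)        # every registered integration
print(integration_registry.is_installed("s3"))            # True if all s3 requirements are installed
print(integration_registry.select_integration_requirements("s3"))
print(integration_registry.get_installed_integrations())  # names whose requirements are met
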

s3 special

Initialization of the S3 integration.

The S3 integration allows the use of cloud artifact stores and file operations on S3 buckets.

S3Integration (Integration)

Definition of S3 integration for ZenML.

Source code in zenml/integrations/s3/__init__.py
class S3Integration(Integration):
    """Definition of S3 integration for ZenML."""

    NAME = S3
    REQUIREMENTS = ["s3fs==2022.3.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the s3 integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=S3_ARTIFACT_STORE_FLAVOR,
                source="zenml.integrations.s3.artifact_stores.S3ArtifactStore",
                type=StackComponentType.ARTIFACT_STORE,
                integration=cls.NAME,
            )
        ]
flavors() classmethod

Declare the stack component flavors for the s3 integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/s3/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the s3 integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=S3_ARTIFACT_STORE_FLAVOR,
            source="zenml.integrations.s3.artifact_stores.S3ArtifactStore",
            type=StackComponentType.ARTIFACT_STORE,
            integration=cls.NAME,
        )
    ]

artifact_stores special

Initialization of the S3 Artifact Store.

s3_artifact_store

Implementation of the S3 Artifact Store.

S3ArtifactStore (BaseArtifactStore, AuthenticationMixin) pydantic-model

Artifact Store for S3 based artifacts.

All attributes of this class except path will be passed to the s3fs.S3FileSystem initialization. See https://s3fs.readthedocs.io/en/latest/ for more information on how to use those configuration options to connect to any S3-compatible storage.

When you want to register an S3ArtifactStore from the CLI and need to pass client_kwargs, config_kwargs or s3_additional_kwargs, you should pass them as a json string:

zenml artifact-store register my_s3_store --type=s3 --path=s3://my_bucket --client_kwargs='{"endpoint_url": "http://my-s3-endpoint"}'
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
class S3ArtifactStore(BaseArtifactStore, AuthenticationMixin):
    """Artifact Store for S3 based artifacts.

    All attributes of this class except `path` will be passed to the
    `s3fs.S3FileSystem` initialization. See
    [here](https://s3fs.readthedocs.io/en/latest/) for more information on how
    to use those configuration options to connect to any S3-compatible storage.

    When you want to register an S3ArtifactStore from the CLI and need to pass
    `client_kwargs`, `config_kwargs` or `s3_additional_kwargs`, you should pass
    them as a json string:
    ```
    zenml artifact-store register my_s3_store --type=s3 --path=s3://my_bucket \
    --client_kwargs='{"endpoint_url": "http://my-s3-endpoint"}'
    ```
    """

    key: Optional[str] = None
    secret: Optional[str] = None
    token: Optional[str] = None
    client_kwargs: Optional[Dict[str, Any]] = None
    config_kwargs: Optional[Dict[str, Any]] = None
    s3_additional_kwargs: Optional[Dict[str, Any]] = None
    _filesystem: Optional[s3fs.S3FileSystem] = None

    # Class variables
    FLAVOR: ClassVar[str] = S3_ARTIFACT_STORE_FLAVOR
    SUPPORTED_SCHEMES: ClassVar[Set[str]] = {"s3://"}

    @validator(
        "client_kwargs", "config_kwargs", "s3_additional_kwargs", pre=True
    )
    def _convert_json_string(
        cls, value: Union[None, str, Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Converts potential JSON strings passed via the CLI to dictionaries.

        Args:
            value: The value to convert.

        Returns:
            The converted value.

        Raises:
            TypeError: If the value is not a `str`, `Dict` or `None`.
            ValueError: If the value is an invalid json string or a json string
                that does not decode into a dictionary.
        """
        if isinstance(value, str):
            try:
                dict_ = json.loads(value)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid json string '{value}'") from e

            if not isinstance(dict_, Dict):
                raise ValueError(
                    f"Json string '{value}' did not decode into a dictionary."
                )

            return dict_
        elif isinstance(value, Dict) or value is None:
            return value
        else:
            raise TypeError(f"{value} is not a json string or a dictionary.")

    def _get_credentials(
        self,
    ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Gets authentication credentials.

        If an authentication secret is configured, the secret values are
        returned. Otherwise we fallback to the plain text component attributes.

        Returns:
            Tuple (key, secret, token) of credentials used to authenticate with
            the S3 filesystem.
        """
        secret = self.get_authentication_secret(
            expected_schema_type=AWSSecretSchema
        )
        if secret:
            return (
                secret.aws_access_key_id,
                secret.aws_secret_access_key,
                secret.aws_session_token,
            )
        else:
            return self.key, self.secret, self.token

    @property
    def filesystem(self) -> s3fs.S3FileSystem:
        """The s3 filesystem to access this artifact store.

        Returns:
            The s3 filesystem.
        """
        if not self._filesystem:
            key, secret, token = self._get_credentials()

            self._filesystem = s3fs.S3FileSystem(
                key=key,
                secret=secret,
                token=token,
                client_kwargs=self.client_kwargs,
                config_kwargs=self.config_kwargs,
                s3_additional_kwargs=self.s3_additional_kwargs,
            )
        return self._filesystem

    def open(self, path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file. Currently, only
                'rb' and 'wb' to read and write binary files are supported.

        Returns:
            A file-like object.
        """
        return self.filesystem.open(path=path, mode=mode)

    def copyfile(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Copy a file.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )

        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        self.filesystem.copy(path1=src, path2=dst)

    def exists(self, path: PathType) -> bool:
        """Check whether a path exists.

        Args:
            path: The path to check.

        Returns:
            True if the path exists, False otherwise.
        """
        return self.filesystem.exists(path=path)  # type: ignore[no-any-return]

    def glob(self, pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to match to search
            in subdirectories of any depth (e.g. '/some_dir/**/some_file')

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of paths that match the given glob pattern.
        """
        return [f"s3://{path}" for path in self.filesystem.glob(path=pattern)]

    def isdir(self, path: PathType) -> bool:
        """Check whether a path is a directory.

        Args:
            path: The path to check.

        Returns:
            True if the path is a directory, False otherwise.
        """
        return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]

    def listdir(self, path: PathType) -> List[PathType]:
        """Return a list of files in a directory.

        Args:
            path: The path to list.

        Returns:
            A list of paths that are files in the given directory.
        """
        # remove s3 prefix if given, so we can remove the directory later as
        # this method is expected to only return filenames
        path = convert_to_str(path)
        if path.startswith("s3://"):
            path = path[5:]

        def _extract_basename(file_dict: Dict[str, Any]) -> str:
            """Extracts the basename from a file info dict returned by the S3 filesystem.

            Args:
                file_dict: A file info dict returned by the S3 filesystem.

            Returns:
                The basename of the file.
            """
            file_path = cast(str, file_dict["Key"])
            base_name = file_path[len(path) :]
            return base_name.lstrip("/")

        return [
            _extract_basename(dict_)
            for dict_ in self.filesystem.listdir(path=path)
            # s3fs.listdir also returns the root directory, so we filter
            # it out here
            if _extract_basename(dict_)
        ]

    def makedirs(self, path: PathType) -> None:
        """Create a directory at the given path.

        If needed also create missing parent directories.

        Args:
            path: The path to create.
        """
        self.filesystem.makedirs(path=path, exist_ok=True)

    def mkdir(self, path: PathType) -> None:
        """Create a directory at the given path.

        Args:
            path: The path to create.
        """
        self.filesystem.makedir(path=path)

    def remove(self, path: PathType) -> None:
        """Remove the file at the given path.

        Args:
            path: The path of the file to remove.
        """
        self.filesystem.rm_file(path=path)

    def rename(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Rename source file to destination file.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        self.filesystem.rename(path1=src, path2=dst)

    def rmtree(self, path: PathType) -> None:
        """Remove the given directory.

        Args:
            path: The path of the directory to remove.
        """
        self.filesystem.delete(path=path, recursive=True)

    def stat(self, path: PathType) -> Dict[str, Any]:
        """Return stat info for the given path.

        Args:
            path: The path to get stat info for.

        Returns:
            A dictionary containing the stat info.
        """
        return self.filesystem.stat(path=path)  # type: ignore[no-any-return]

    def walk(
        self,
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Return an iterator that walks the contents of the given directory.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Yields:
            An Iterable of Tuples, each of which contains the path of the
                current directory, a list of directories inside the current
                directory and a list of files inside the current directory.
        """
        # TODO [ENG-153]: Additional params
        for directory, subdirectories, files in self.filesystem.walk(path=top):
            yield f"s3://{directory}", subdirectories, files
filesystem: S3FileSystem property readonly

The s3 filesystem to access this artifact store.

Returns:

Type Description
S3FileSystem

The s3 filesystem.

copyfile(self, src, dst, overwrite=False)

Copy a file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path to copy from.

required
dst Union[bytes, str]

The path to copy to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def copyfile(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Copy a file.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    if not overwrite and self.filesystem.exists(dst):
        raise FileExistsError(
            f"Unable to copy to destination '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to copy anyway."
        )

    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    self.filesystem.copy(path1=src, path2=dst)
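
For illustration, a minimal sketch of the overwrite semantics described above; the store variable and all paths are hypothetical, with store assumed to be an already-configured instance of this artifact store:

try:
    store.copyfile("s3://my-bucket/a.txt", "s3://my-bucket/b.txt")
except FileExistsError:
    # the destination already exists; overwriting must be requested explicitly
    store.copyfile(
        "s3://my-bucket/a.txt", "s3://my-bucket/b.txt", overwrite=True
    )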
exists(self, path)

Check whether a path exists.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path exists, False otherwise.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def exists(self, path: PathType) -> bool:
    """Check whether a path exists.

    Args:
        path: The path to check.

    Returns:
        True if the path exists, False otherwise.
    """
    return self.filesystem.exists(path=path)  # type: ignore[no-any-return]
glob(self, pattern)

Return all paths that match the given glob pattern.

The glob pattern may include:
- '*' to match any number of characters
- '?' to match a single character
- '[...]' to match one of the characters inside the brackets
- '**' as the full name of a path component to match to search in subdirectories of any depth (e.g. '/some_dir/**/some_file')

Parameters:

Name Type Description Default
pattern Union[bytes, str]

The glob pattern to match, see details above.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths that match the given glob pattern.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
    """Return all paths that match the given glob pattern.

    The glob pattern may include:
    - '*' to match any number of characters
    - '?' to match a single character
    - '[...]' to match one of the characters inside the brackets
    - '**' as the full name of a path component to match to search
        in subdirectories of any depth (e.g. '/some_dir/**/some_file')

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        A list of paths that match the given glob pattern.
    """
    return [f"s3://{path}" for path in self.filesystem.glob(path=pattern)]
isdir(self, path)

Check whether a path is a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path is a directory, False otherwise.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def isdir(self, path: PathType) -> bool:
    """Check whether a path is a directory.

    Args:
        path: The path to check.

    Returns:
        True if the path is a directory, False otherwise.
    """
    return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]
listdir(self, path)

Return a list of files in a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to list.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths that are files in the given directory.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
    """Return a list of files in a directory.

    Args:
        path: The path to list.

    Returns:
        A list of paths that are files in the given directory.
    """
    # remove s3 prefix if given, so we can remove the directory later as
    # this method is expected to only return filenames
    path = convert_to_str(path)
    if path.startswith("s3://"):
        path = path[5:]

    def _extract_basename(file_dict: Dict[str, Any]) -> str:
        """Extracts the basename from a file info dict returned by the S3 filesystem.

        Args:
            file_dict: A file info dict returned by the S3 filesystem.

        Returns:
            The basename of the file.
        """
        file_path = cast(str, file_dict["Key"])
        base_name = file_path[len(path) :]
        return base_name.lstrip("/")

    return [
        _extract_basename(dict_)
        for dict_ in self.filesystem.listdir(path=path)
        # s3fs.listdir also returns the root directory, so we filter
        # it out here
        if _extract_basename(dict_)
    ]
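
As a standalone illustration of the basename extraction performed above, the following replicates the logic on a made-up S3 'Key' entry:

path = "my-bucket/data"
file_dict = {"Key": "my-bucket/data/train.csv"}
base_name = file_dict["Key"][len(path):].lstrip("/")
print(base_name)  # -> 'train.csv'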
makedirs(self, path)

Create a directory at the given path.

If needed also create missing parent directories.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to create.

required
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def makedirs(self, path: PathType) -> None:
    """Create a directory at the given path.

    If needed also create missing parent directories.

    Args:
        path: The path to create.
    """
    self.filesystem.makedirs(path=path, exist_ok=True)
mkdir(self, path)

Create a directory at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to create.

required
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def mkdir(self, path: PathType) -> None:
    """Create a directory at the given path.

    Args:
        path: The path to create.
    """
    self.filesystem.makedir(path=path)
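
To illustrate the difference between the two documented contracts, a hypothetical sketch (store and all paths are assumptions):

# makedirs creates any missing parent directories
store.makedirs("s3://my-bucket/a/b/c")

# mkdir only creates the final component; per the contract above, the
# parent 's3://my-bucket/a/b' is expected to exist already
store.mkdir("s3://my-bucket/a/b/d")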
open(self, path, mode='r')

Open a file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

Path of the file to open.

required
mode str

Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported.

'r'

Returns:

Type Description
Any

A file-like object.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
    """Open a file at the given path.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file. Currently, only
            'rb' and 'wb' to read and write binary files are supported.

    Returns:
        A file-like object.
    """
    return self.filesystem.open(path=path, mode=mode)
remove(self, path)

Remove the file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the file to remove.

required
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def remove(self, path: PathType) -> None:
    """Remove the file at the given path.

    Args:
        path: The path of the file to remove.
    """
    self.filesystem.rm_file(path=path)
rename(self, src, dst, overwrite=False)

Rename source file to destination file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path of the file to rename.

required
dst Union[bytes, str]

The path to rename the source file to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def rename(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Rename source file to destination file.

    Args:
        src: The path of the file to rename.
        dst: The path to rename the source file to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    if not overwrite and self.filesystem.exists(dst):
        raise FileExistsError(
            f"Unable to rename file to '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to rename anyway."
        )

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    self.filesystem.rename(path1=src, path2=dst)
rmtree(self, path)

Remove the given directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to remove.

required
Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def rmtree(self, path: PathType) -> None:
    """Remove the given directory.

    Args:
        path: The path of the directory to remove.
    """
    self.filesystem.delete(path=path, recursive=True)
stat(self, path)

Return stat info for the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to get stat info for.

required

Returns:

Type Description
Dict[str, Any]

A dictionary containing the stat info.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
    """Return stat info for the given path.

    Args:
        path: The path to get stat info for.

    Returns:
        A dictionary containing the stat info.
    """
    return self.filesystem.stat(path=path)  # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)

Return an iterator that walks the contents of the given directory.

Parameters:

Name Type Description Default
top Union[bytes, str]

Path of directory to walk.

required
topdown bool

Unused argument to conform to interface.

True
onerror Optional[Callable[..., NoneType]]

Unused argument to conform to interface.

None

Yields:

Type Description
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]]

An Iterable of Tuples, each of which contains the path of the current directory, a list of directories inside the current directory and a list of files inside the current directory.

Source code in zenml/integrations/s3/artifact_stores/s3_artifact_store.py
def walk(
    self,
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Return an iterator that walks the contents of the given directory.

    Args:
        top: Path of directory to walk.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Yields:
        An Iterable of Tuples, each of which contains the path of the
            current directory, a list of directories inside the current
            directory and a list of files inside the current directory.
    """
    # TODO [ENG-153]: Additional params
    for directory, subdirectories, files in self.filesystem.walk(path=top):
        yield f"s3://{directory}", subdirectories, files

scipy special

Initialization of the Scipy integration.

ScipyIntegration (Integration)

Definition of scipy integration for ZenML.

Source code in zenml/integrations/scipy/__init__.py
class ScipyIntegration(Integration):
    """Definition of scipy integration for ZenML."""

    NAME = SCIPY
    REQUIREMENTS = ["scipy"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration."""
        from zenml.integrations.scipy import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/scipy/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration."""
    from zenml.integrations.scipy import materializers  # noqa

materializers special

Initialization of the Scipy materializers.

sparse_materializer

Implementation of the Scipy Sparse Materializer.

SparseMaterializer (BaseMaterializer)

Materializer to read and write scipy sparse matrices.

Source code in zenml/integrations/scipy/materializers/sparse_materializer.py
class SparseMaterializer(BaseMaterializer):
    """Materializer to read and write scipy sparse matrices."""

    ASSOCIATED_TYPES = (spmatrix,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> spmatrix:
        """Reads spmatrix from npz file.

        Args:
            data_type: The type of the spmatrix to load.

        Returns:
            A spmatrix object.
        """
        super().handle_input(data_type)
        with fileio.open(
            os.path.join(self.artifact.uri, DATA_FILENAME), "rb"
        ) as f:
            mat = load_npz(f)
        return mat

    def handle_return(self, mat: spmatrix) -> None:
        """Writes a spmatrix to the artifact store as a npz file.

        Args:
            mat: The spmatrix to write.
        """
        super().handle_return(mat)
        with fileio.open(
            os.path.join(self.artifact.uri, DATA_FILENAME), "wb"
        ) as f:
            save_npz(f, mat)
handle_input(self, data_type)

Reads spmatrix from npz file.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the spmatrix to load.

required

Returns:

Type Description
spmatrix

A spmatrix object.

Source code in zenml/integrations/scipy/materializers/sparse_materializer.py
def handle_input(self, data_type: Type[Any]) -> spmatrix:
    """Reads spmatrix from npz file.

    Args:
        data_type: The type of the spmatrix to load.

    Returns:
        A spmatrix object.
    """
    super().handle_input(data_type)
    with fileio.open(
        os.path.join(self.artifact.uri, DATA_FILENAME), "rb"
    ) as f:
        mat = load_npz(f)
    return mat
handle_return(self, mat)

Writes a spmatrix to the artifact store as a npz file.

Parameters:

Name Type Description Default
mat spmatrix

The spmatrix to write.

required
Source code in zenml/integrations/scipy/materializers/sparse_materializer.py
def handle_return(self, mat: spmatrix) -> None:
    """Writes a spmatrix to the artifact store as a npz file.

    Args:
        mat: The spmatrix to write.
    """
    super().handle_return(mat)
    with fileio.open(
        os.path.join(self.artifact.uri, DATA_FILENAME), "wb"
    ) as f:
        save_npz(f, mat)
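
As a minimal, self-contained sketch of the npz round trip the materializer performs, the following uses scipy directly against a local file instead of the artifact store:

from scipy.sparse import csr_matrix, load_npz, save_npz

mat = csr_matrix([[0, 1], [2, 0]])
with open("matrix.npz", "wb") as f:
    save_npz(f, mat)
with open("matrix.npz", "rb") as f:
    restored = load_npz(f)
# the matrices differ in zero positions, i.e. they are equal
assert (mat != restored).nnz == 0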

seldon special

Initialization of the Seldon integration.

The Seldon Core integration allows you to use the Seldon Core model serving platform to implement continuous model deployment.

SeldonIntegration (Integration)

Definition of Seldon Core integration for ZenML.

Source code in zenml/integrations/seldon/__init__.py
class SeldonIntegration(Integration):
    """Definition of Seldon Core integration for ZenML."""

    NAME = SELDON
    REQUIREMENTS = [
        "kubernetes==18.20.0",
    ]

    @classmethod
    def activate(cls) -> None:
        """Activate the Seldon Core integration."""
        from zenml.integrations.seldon import secret_schemas  # noqa
        from zenml.integrations.seldon import services  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Seldon Core.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=SELDON_MODEL_DEPLOYER_FLAVOR,
                source="zenml.integrations.seldon.model_deployers.SeldonModelDeployer",
                type=StackComponentType.MODEL_DEPLOYER,
                integration=cls.NAME,
            )
        ]
activate() classmethod

Activate the Seldon Core integration.

Source code in zenml/integrations/seldon/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the Seldon Core integration."""
    from zenml.integrations.seldon import secret_schemas  # noqa
    from zenml.integrations.seldon import services  # noqa
flavors() classmethod

Declare the stack component flavors for the Seldon Core.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/seldon/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Seldon Core.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=SELDON_MODEL_DEPLOYER_FLAVOR,
            source="zenml.integrations.seldon.model_deployers.SeldonModelDeployer",
            type=StackComponentType.MODEL_DEPLOYER,
            integration=cls.NAME,
        )
    ]

model_deployers special

Initialization of the Seldon Model Deployer.

seldon_model_deployer

Implementation of the Seldon Model Deployer.

SeldonModelDeployer (BaseModelDeployer) pydantic-model

Seldon Core model deployer stack component implementation.

Attributes:

Name Type Description
kubernetes_context Optional[str]

the Kubernetes context to use to contact the remote Seldon Core installation. If not specified, the current configuration is used. Depending on where the Seldon model deployer is being used, this can be either a locally active context or an in-cluster Kubernetes configuration (if running inside a pod).

kubernetes_namespace Optional[str]

the Kubernetes namespace where the Seldon Core deployment servers are provisioned and managed by ZenML. If not specified, the namespace set in the current configuration is used. Depending on where the Seldon model deployer is being used, this can be either the current namespace configured in the locally active context or the namespace in the context of which the pod is running (if running inside a pod).

base_url str

the base URL of the Kubernetes ingress used to expose the Seldon Core deployment servers.

secret Optional[str]

the name of a ZenML secret containing the credentials used by Seldon Core storage initializers to authenticate to the Artifact Store (i.e. the storage backend where models are stored - see https://docs.seldon.io/projects/seldon-core/en/latest/servers/overview.html#handling-credentials).

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
class SeldonModelDeployer(BaseModelDeployer):
    """Seldon Core model deployer stack component implementation.

    Attributes:
        kubernetes_context: the Kubernetes context to use to contact the remote
            Seldon Core installation. If not specified, the current
            configuration is used. Depending on where the Seldon model deployer
            is being used, this can be either a locally active context or an
            in-cluster Kubernetes configuration (if running inside a pod).
        kubernetes_namespace: the Kubernetes namespace where the Seldon Core
            deployment servers are provisioned and managed by ZenML. If not
            specified, the namespace set in the current configuration is used.
            Depending on where the Seldon model deployer is being used, this can
            be either the current namespace configured in the locally active
            context or the namespace in the context of which the pod is running
            (if running inside a pod).
        base_url: the base URL of the Kubernetes ingress used to expose the
            Seldon Core deployment servers.
        secret: the name of a ZenML secret containing the credentials used by
            Seldon Core storage initializers to authenticate to the Artifact
            Store (i.e. the storage backend where models are stored - see
            https://docs.seldon.io/projects/seldon-core/en/latest/servers/overview.html#handling-credentials).
    """

    # Class Configuration
    FLAVOR: ClassVar[str] = SELDON_MODEL_DEPLOYER_FLAVOR

    kubernetes_context: Optional[str]
    kubernetes_namespace: Optional[str]
    base_url: str
    secret: Optional[str]

    # private attributes
    _client: Optional[SeldonClient] = None

    @staticmethod
    def get_model_server_info(  # type: ignore[override]
        service_instance: "SeldonDeploymentService",
    ) -> Dict[str, Optional[str]]:
        """Return implementation specific information that might be relevant to the user.

        Args:
            service_instance: Instance of a SeldonDeploymentService

        Returns:
            Model server information.
        """
        return {
            "PREDICTION_URL": service_instance.prediction_url,
            "MODEL_URI": service_instance.config.model_uri,
            "MODEL_NAME": service_instance.config.model_name,
            "SELDON_DEPLOYMENT": service_instance.seldon_deployment_name,
        }

    @staticmethod
    def get_active_model_deployer() -> "SeldonModelDeployer":
        """Get the Seldon Core model deployer registered in the active stack.

        Returns:
            The Seldon Core model deployer registered in the active stack.

        Raises:
            TypeError: if the Seldon Core model deployer is not available.
        """
        model_deployer = Repository(  # type: ignore [call-arg]
            skip_repository_check=True
        ).active_stack.model_deployer
        if not model_deployer or not isinstance(
            model_deployer, SeldonModelDeployer
        ):
            raise TypeError(
                f"The active stack needs to have a Seldon Core model deployer "
                f"component registered to be able to deploy models with Seldon "
                f"Core. You can create a new stack with a Seldon Core model "
                f"deployer component or update your existing stack to add this "
                f"component, e.g.:\n\n"
                f"  'zenml model-deployer register seldon --flavor={SELDON_MODEL_DEPLOYER_FLAVOR} "
                f"--kubernetes_context=context-name --kubernetes_namespace="
                f"namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
                f"  'zenml stack create stack-name -d seldon ...'\n"
            )
        return model_deployer

    @property
    def seldon_client(self) -> SeldonClient:
        """Get the Seldon Core client associated with this model deployer.

        Returns:
            The Seldon Core client.
        """
        if not self._client:
            self._client = SeldonClient(
                context=self.kubernetes_context,
                namespace=self.kubernetes_namespace,
            )
        return self._client

    @property
    def kubernetes_secret_name(self) -> Optional[str]:
        """Get the Kubernetes secret name associated with this model deployer.

        If a secret is configured for this model deployer, a corresponding
        Kubernetes secret is created in the remote cluster to be used
        by Seldon Core storage initializers to authenticate to the Artifact
        Store. This method returns the unique name that is used for this secret.

        Returns:
            The Seldon Core Kubernetes secret name, or None if no secret is
            configured.
        """
        if not self.secret:
            return None
        return (
            re.sub(r"[^0-9a-zA-Z-]+", "-", f"zenml-seldon-core-{self.secret}")
            .strip("-")
            .lower()
        )

    def _create_or_update_kubernetes_secret(self) -> Optional[str]:
        """Create or update a Kubernetes secret.

        Uses the information stored in the ZenML secret configured for the model deployer.

        Returns:
            The name of the Kubernetes secret that was created or updated, or
            None if no secret was configured.

        Raises:
            RuntimeError: if the secret cannot be created or updated.
        """
        # if a ZenML secret was configured in the model deployer,
        # create a Kubernetes secret as a means to pass this information
        # to the Seldon Core deployment
        if self.secret:

            secret_manager = Repository(  # type: ignore [call-arg]
                skip_repository_check=True
            ).active_stack.secrets_manager

            if not secret_manager or not isinstance(
                secret_manager, BaseSecretsManager
            ):
                raise RuntimeError(
                    f"The active stack doesn't have a secret manager component. "
                    f"The ZenML secret specified in the Seldon Core Model "
                    f"Deployer configuration cannot be fetched: {self.secret}."
                )

            try:
                zenml_secret = secret_manager.get_secret(self.secret)
            except KeyError:
                raise RuntimeError(
                    f"The ZenML secret '{self.secret}' specified in the "
                    f"Seldon Core Model Deployer configuration was not found "
                    f"in the active stack's secret manager."
                )

            # should never happen, just making mypy happy
            assert self.kubernetes_secret_name is not None
            self.seldon_client.create_or_update_secret(
                self.kubernetes_secret_name, zenml_secret
            )

        return self.kubernetes_secret_name

    def _delete_kubernetes_secret(self) -> None:
        """Delete the Kubernetes secret associated with this model deployer.

        Do this if no Seldon Core deployments are using it.
        """
        if self.kubernetes_secret_name:

            # fetch all the Seldon Core deployments that are currently
            # configured to use this secret
            services = self.find_model_server()
            for service in services:
                config = cast(SeldonDeploymentConfig, service.config)
                if config.secret_name == self.kubernetes_secret_name:
                    return
            self.seldon_client.delete_secret(self.kubernetes_secret_name)

    def deploy_model(
        self,
        config: ServiceConfig,
        replace: bool = False,
        timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
    ) -> BaseService:
        """Create a new Seldon Core deployment or update an existing one.

        # noqa: DAR402

        This should serve the supplied model and deployment configuration.

        This method has two modes of operation, depending on the `replace`
        argument value:

          * if `replace` is False, calling this method will create a new Seldon
            Core deployment server to reflect the model and other configuration
            parameters specified in the supplied Seldon deployment `config`.

          * if `replace` is True, this method will first attempt to find an
            existing Seldon Core deployment that is *equivalent* to the supplied
            configuration parameters. Two or more Seldon Core deployments are
            considered equivalent if they have the same `pipeline_name`,
            `pipeline_step_name` and `model_name` configuration parameters. To
            put it differently, two Seldon Core deployments are equivalent if
            they serve versions of the same model deployed by the same pipeline
            step. If an equivalent Seldon Core deployment is found, it will be
            updated in place to reflect the new configuration parameters. This
            allows an existing Seldon Core deployment to retain its prediction
            URL while performing a rolling update to serve a new model version.

        Callers should set `replace` to True if they want a continuous model
        deployment workflow that doesn't spin up a new Seldon Core deployment
        server for each new model version. If multiple equivalent Seldon Core
        deployments are found, the most recently created deployment is selected
        to be updated and the others are deleted.

        Args:
            config: the configuration of the model to be deployed with
                Seldon Core.
            replace: set this flag to True to find and update an equivalent
                Seldon Core deployment server with the new model instead of
                starting a new deployment server.
            timeout: the timeout in seconds to wait for the Seldon Core server
                to be provisioned and successfully started or updated. If set
                to 0, the method will return immediately after the Seldon Core
                server is provisioned, without waiting for it to fully start.

        Returns:
            The ZenML Seldon Core deployment service object that can be used to
            interact with the remote Seldon Core server.

        Raises:
            SeldonClientError: if a Seldon Core client error is encountered
                while provisioning the Seldon Core deployment server.
            RuntimeError: if `timeout` is set to a positive value that is
                exceeded while waiting for the Seldon Core deployment server
                to start, or if an operational failure is encountered before
                it reaches a ready state.
        """
        config = cast(SeldonDeploymentConfig, config)
        service = None

        # if a custom Kubernetes secret is not explicitly specified in the
        # SeldonDeploymentConfig, try to create one from the ZenML secret
        # configured for the model deployer
        config.secret_name = (
            config.secret_name or self._create_or_update_kubernetes_secret()
        )

        # if replace is True, find equivalent Seldon Core deployments
        if replace is True:
            equivalent_services = self.find_model_server(
                running=False,
                pipeline_name=config.pipeline_name,
                pipeline_step_name=config.pipeline_step_name,
                model_name=config.model_name,
            )

            for equivalent_service in equivalent_services:
                if service is None:
                    # keep the most recently created service
                    service = equivalent_service
                else:
                    try:
                        # delete the older services and don't wait for them to
                        # be deprovisioned
                        equivalent_service.stop()
                    except RuntimeError:
                        # ignore errors encountered while stopping old services
                        pass

        if service:
            # update an equivalent service in place
            service.update(config)
            logger.info(
                f"Updating an existing Seldon deployment service: {service}"
            )
        else:
            # create a new service
            service = SeldonDeploymentService(config=config)
            logger.info(f"Creating a new Seldon deployment service: {service}")

        # start the service which in turn provisions the Seldon Core
        # deployment server and waits for it to reach a ready state
        service.start(timeout=timeout)
        return service

    def find_model_server(
        self,
        running: bool = False,
        service_uuid: Optional[UUID] = None,
        pipeline_name: Optional[str] = None,
        pipeline_run_id: Optional[str] = None,
        pipeline_step_name: Optional[str] = None,
        model_name: Optional[str] = None,
        model_uri: Optional[str] = None,
        model_type: Optional[str] = None,
    ) -> List[BaseService]:
        """Find one or more Seldon Core model services that match the given criteria.

        The Seldon Core deployment services that meet the search criteria are
        returned sorted in descending order of their creation time (i.e. more
        recent deployments first).

        Args:
            running: if true, only running services will be returned.
            service_uuid: the UUID of the Seldon Core service that was originally used
                to create the Seldon Core deployment resource.
            pipeline_name: name of the pipeline that the deployed model was part
                of.
            pipeline_run_id: ID of the pipeline run which the deployed model was
                part of.
            pipeline_step_name: the name of the pipeline model deployment step
                that deployed the model.
            model_name: the name of the deployed model.
            model_uri: URI of the deployed model.
            model_type: the Seldon Core server implementation used to serve
                the model.

        Returns:
            One or more Seldon Core service objects representing Seldon Core
            model servers that match the input search criteria.
        """
        # Use a Seldon deployment service configuration to compute the labels
        config = SeldonDeploymentConfig(
            pipeline_name=pipeline_name or "",
            pipeline_run_id=pipeline_run_id or "",
            pipeline_step_name=pipeline_step_name or "",
            model_name=model_name or "",
            model_uri=model_uri or "",
            implementation=model_type or "",
        )
        labels = config.get_seldon_deployment_labels()
        if service_uuid:
            # the service UUID is not a label covered by the Seldon
            # deployment service configuration, so we need to add it
            # separately
            labels["zenml.service_uuid"] = str(service_uuid)

        deployments = self.seldon_client.find_deployments(labels=labels)
        # sort the deployments in descending order of their creation time
        deployments.sort(
            key=lambda deployment: datetime.strptime(
                deployment.metadata.creationTimestamp,
                "%Y-%m-%dT%H:%M:%SZ",
            )
            if deployment.metadata.creationTimestamp
            else datetime.min,
            reverse=True,
        )

        services: List[BaseService] = []
        for deployment in deployments:
            # recreate the Seldon deployment service object from the Seldon
            # deployment resource
            service = SeldonDeploymentService.create_from_deployment(
                deployment=deployment
            )
            if running and not service.is_running:
                # skip non-running services
                continue
            services.append(service)

        return services

    def stop_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Stop a Seldon Core model server.

        Args:
            uuid: UUID of the model server to stop.
            timeout: timeout in seconds to wait for the service to stop.
            force: if True, force the service to stop.

        Raises:
            NotImplementedError: stopping Seldon Core model servers is not
                supported.
        """
        raise NotImplementedError(
            "Stopping Seldon Core model servers is not implemented. Try "
            "deleting the Seldon Core model server instead."
        )

    def start_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
    ) -> None:
        """Start a Seldon Core model deployment server.

        Args:
            uuid: UUID of the model server to start.
            timeout: timeout in seconds to wait for the service to become
                active. If set to 0, the method will return immediately after
                provisioning the service, without waiting for it to become
                active.

        Raises:
            NotImplementedError: since we don't support starting Seldon Core
                model servers.
        """
        raise NotImplementedError(
            "Starting Seldon Core model servers is not implemented"
        )

    def delete_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Delete a Seldon Core model deployment server.

        Args:
            uuid: UUID of the model server to delete.
            timeout: timeout in seconds to wait for the service to stop. If
                set to 0, the method will return immediately after
                deprovisioning the service, without waiting for it to stop.
            force: if True, force the service to stop.
        """
        services = self.find_model_server(service_uuid=uuid)
        if len(services) == 0:
            return
        services[0].stop(timeout=timeout, force=force)

        # if this is the last Seldon Core model server, delete the Kubernetes
        # secret used to store the authentication information for the Seldon
        # Core model server storage initializer
        self._delete_kubernetes_secret()
kubernetes_secret_name: Optional[str] property readonly

Get the Kubernetes secret name associated with this model deployer.

If a secret is configured for this model deployer, a corresponding Kubernetes secret is created in the remote cluster to be used by Seldon Core storage initializers to authenticate to the Artifact Store. This method returns the unique name that is used for this secret.

Returns:

Type Description
Optional[str]

The Seldon Core Kubernetes secret name, or None if no secret is configured.
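
As a standalone illustration of the name sanitization described above, the following replicates the transformation on a made-up ZenML secret name:

import re

secret = "My AWS_creds!"
name = (
    re.sub(r"[^0-9a-zA-Z-]+", "-", f"zenml-seldon-core-{secret}")
    .strip("-")
    .lower()
)
print(name)  # -> 'zenml-seldon-core-my-aws-creds'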

seldon_client: SeldonClient property readonly

Get the Seldon Core client associated with this model deployer.

Returns:

Type Description
SeldonClient

The Seldon Core client.

delete_model_server(self, uuid, timeout=300, force=False)

Delete a Seldon Core model deployment server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to delete.

required
timeout int

timeout in seconds to wait for the service to stop. If set to 0, the method will return immediately after deprovisioning the service, without waiting for it to stop.

300
force bool

if True, force the service to stop.

False
Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def delete_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Delete a Seldon Core model deployment server.

    Args:
        uuid: UUID of the model server to delete.
        timeout: timeout in seconds to wait for the service to stop. If
            set to 0, the method will return immediately after
            deprovisioning the service, without waiting for it to stop.
        force: if True, force the service to stop.
    """
    services = self.find_model_server(service_uuid=uuid)
    if len(services) == 0:
        return
    services[0].stop(timeout=timeout, force=force)

    # if this is the last Seldon Core model server, delete the Kubernetes
    # secret used to store the authentication information for the Seldon
    # Core model server storage initializer
    self._delete_kubernetes_secret()
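
A hypothetical sketch of tearing down a server by its service UUID; the UUID value is made up and an active stack with this model deployer is assumed:

from uuid import UUID

deployer = SeldonModelDeployer.get_active_model_deployer()
deployer.delete_model_server(
    uuid=UUID("12345678-1234-5678-1234-567812345678"),
    timeout=0,  # return right after deprovisioning, don't wait for the stop
)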
deploy_model(self, config, replace=False, timeout=300)

Create a new Seldon Core deployment or update an existing one.

This should serve the supplied model and deployment configuration.

This method has two modes of operation, depending on the replace argument value:

  • if replace is False, calling this method will create a new Seldon Core deployment server to reflect the model and other configuration parameters specified in the supplied Seldon deployment config.

  • if replace is True, this method will first attempt to find an existing Seldon Core deployment that is equivalent to the supplied configuration parameters. Two or more Seldon Core deployments are considered equivalent if they have the same pipeline_name, pipeline_step_name and model_name configuration parameters. To put it differently, two Seldon Core deployments are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent Seldon Core deployment is found, it will be updated in place to reflect the new configuration parameters. This allows an existing Seldon Core deployment to retain its prediction URL while performing a rolling update to serve a new model version.

Callers should set replace to True if they want a continuous model deployment workflow that doesn't spin up a new Seldon Core deployment server for each new model version. If multiple equivalent Seldon Core deployments are found, the most recently created deployment is selected to be updated and the others are deleted.

Parameters:

Name Type Description Default
config ServiceConfig

the configuration of the model to be deployed with Seldon Core.

required
replace bool

set this flag to True to find and update an equivalent Seldon Core deployment server with the new model instead of starting a new deployment server.

False
timeout int

the timeout in seconds to wait for the Seldon Core server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the Seldon Core server is provisioned, without waiting for it to fully start.

300

Returns:

Type Description
BaseService

The ZenML Seldon Core deployment service object that can be used to interact with the remote Seldon Core server.

Exceptions:

Type Description
SeldonClientError

if a Seldon Core client error is encountered while provisioning the Seldon Core deployment server.

RuntimeError

if timeout is set to a positive value that is exceeded while waiting for the Seldon Core deployment server to start, or if an operational failure is encountered before it reaches a ready state.

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def deploy_model(
    self,
    config: ServiceConfig,
    replace: bool = False,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
) -> BaseService:
    """Create a new Seldon Core deployment or update an existing one.

    # noqa: DAR402

    This should serve the supplied model and deployment configuration.

    This method has two modes of operation, depending on the `replace`
    argument value:

      * if `replace` is False, calling this method will create a new Seldon
        Core deployment server to reflect the model and other configuration
        parameters specified in the supplied Seldon deployment `config`.

      * if `replace` is True, this method will first attempt to find an
        existing Seldon Core deployment that is *equivalent* to the supplied
        configuration parameters. Two or more Seldon Core deployments are
        considered equivalent if they have the same `pipeline_name`,
        `pipeline_step_name` and `model_name` configuration parameters. To
        put it differently, two Seldon Core deployments are equivalent if
        they serve versions of the same model deployed by the same pipeline
        step. If an equivalent Seldon Core deployment is found, it will be
        updated in place to reflect the new configuration parameters. This
        allows an existing Seldon Core deployment to retain its prediction
        URL while performing a rolling update to serve a new model version.

    Callers should set `replace` to True if they want a continuous model
    deployment workflow that doesn't spin up a new Seldon Core deployment
    server for each new model version. If multiple equivalent Seldon Core
    deployments are found, the most recently created deployment is selected
    to be updated and the others are deleted.

    Args:
        config: the configuration of the model to be deployed with
            Seldon Core.
        replace: set this flag to True to find and update an equivalent
            Seldon Core deployment server with the new model instead of
            starting a new deployment server.
        timeout: the timeout in seconds to wait for the Seldon Core server
            to be provisioned and successfully started or updated. If set
            to 0, the method will return immediately after the Seldon Core
            server is provisioned, without waiting for it to fully start.

    Returns:
        The ZenML Seldon Core deployment service object that can be used to
        interact with the remote Seldon Core server.

    Raises:
        SeldonClientError: if a Seldon Core client error is encountered
            while provisioning the Seldon Core deployment server.
        RuntimeError: if `timeout` is set to a positive value that is
            exceeded while waiting for the Seldon Core deployment server
            to start, or if an operational failure is encountered before
            it reaches a ready state.
    """
    config = cast(SeldonDeploymentConfig, config)
    service = None

    # if a custom Kubernetes secret is not explicitly specified in the
    # SeldonDeploymentConfig, try to create one from the ZenML secret
    # configured for the model deployer
    config.secret_name = (
        config.secret_name or self._create_or_update_kubernetes_secret()
    )

    # if replace is True, find equivalent Seldon Core deployments
    if replace is True:
        equivalent_services = self.find_model_server(
            running=False,
            pipeline_name=config.pipeline_name,
            pipeline_step_name=config.pipeline_step_name,
            model_name=config.model_name,
        )

        for equivalent_service in equivalent_services:
            if service is None:
                # keep the most recently created service
                service = equivalent_service
            else:
                try:
                    # delete the older services and don't wait for them to
                    # be deprovisioned
                    equivalent_service.stop()
                except RuntimeError:
                    # ignore errors encountered while stopping old services
                    pass

    if service:
        # update an equivalent service in place
        service.update(config)
        logger.info(
            f"Updating an existing Seldon deployment service: {service}"
        )
    else:
        # create a new service
        service = SeldonDeploymentService(config=config)
        logger.info(f"Creating a new Seldon deployment service: {service}")

    # start the service which in turn provisions the Seldon Core
    # deployment server and waits for it to reach a ready state
    service.start(timeout=timeout)
    return service
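
A hedged sketch of the continuous-deployment mode described above; every configuration value is hypothetical, and SeldonDeploymentConfig and SeldonModelDeployer are assumed to be importable from the modules documented in this section:

deployer = SeldonModelDeployer.get_active_model_deployer()
service = deployer.deploy_model(
    config=SeldonDeploymentConfig(
        pipeline_name="training_pipeline",  # hypothetical names throughout
        pipeline_step_name="deploy_step",
        model_name="my-model",
        model_uri="s3://my-bucket/model",
        implementation="SKLEARN_SERVER",
    ),
    replace=True,  # update an equivalent deployment in place if one exists
    timeout=300,
)
print(service.prediction_url)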
find_model_server(self, running=False, service_uuid=None, pipeline_name=None, pipeline_run_id=None, pipeline_step_name=None, model_name=None, model_uri=None, model_type=None)

Find one or more Seldon Core model services that match the given criteria.

The Seldon Core deployment services that meet the search criteria are returned sorted in descending order of their creation time (i.e. more recent deployments first).

Parameters:

Name Type Description Default
running bool

if true, only running services will be returned.

False
service_uuid Optional[uuid.UUID]

the UUID of the Seldon Core service that was originally used to create the Seldon Core deployment resource.

None
pipeline_name Optional[str]

name of the pipeline that the deployed model was part of.

None
pipeline_run_id Optional[str]

ID of the pipeline run which the deployed model was part of.

None
pipeline_step_name Optional[str]

the name of the pipeline model deployment step that deployed the model.

None
model_name Optional[str]

the name of the deployed model.

None
model_uri Optional[str]

URI of the deployed model.

None
model_type Optional[str]

the Seldon Core server implementation used to serve the model.

None

Returns:

Type Description
List[zenml.services.service.BaseService]

One or more Seldon Core service objects representing Seldon Core model servers that match the input search criteria.

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def find_model_server(
    self,
    running: bool = False,
    service_uuid: Optional[UUID] = None,
    pipeline_name: Optional[str] = None,
    pipeline_run_id: Optional[str] = None,
    pipeline_step_name: Optional[str] = None,
    model_name: Optional[str] = None,
    model_uri: Optional[str] = None,
    model_type: Optional[str] = None,
) -> List[BaseService]:
    """Find one or more Seldon Core model services that match the given criteria.

    The Seldon Core deployment services that meet the search criteria are
    returned sorted in descending order of their creation time (i.e. more
    recent deployments first).

    Args:
        running: if true, only running services will be returned.
        service_uuid: the UUID of the Seldon Core service that was originally used
            to create the Seldon Core deployment resource.
        pipeline_name: name of the pipeline that the deployed model was part
            of.
        pipeline_run_id: ID of the pipeline run which the deployed model was
            part of.
        pipeline_step_name: the name of the pipeline model deployment step
            that deployed the model.
        model_name: the name of the deployed model.
        model_uri: URI of the deployed model.
        model_type: the Seldon Core server implementation used to serve
            the model.

    Returns:
        One or more Seldon Core service objects representing Seldon Core
        model servers that match the input search criteria.
    """
    # Use a Seldon deployment service configuration to compute the labels
    config = SeldonDeploymentConfig(
        pipeline_name=pipeline_name or "",
        pipeline_run_id=pipeline_run_id or "",
        pipeline_step_name=pipeline_step_name or "",
        model_name=model_name or "",
        model_uri=model_uri or "",
        implementation=model_type or "",
    )
    labels = config.get_seldon_deployment_labels()
    if service_uuid:
        # the service UUID is not a label covered by the Seldon
        # deployment service configuration, so we need to add it
        # separately
        labels["zenml.service_uuid"] = str(service_uuid)

    deployments = self.seldon_client.find_deployments(labels=labels)
    # sort the deployments in descending order of their creation time
    deployments.sort(
        key=lambda deployment: datetime.strptime(
            deployment.metadata.creationTimestamp,
            "%Y-%m-%dT%H:%M:%SZ",
        )
        if deployment.metadata.creationTimestamp
        else datetime.min,
        reverse=True,
    )

    services: List[BaseService] = []
    for deployment in deployments:
        # recreate the Seldon deployment service object from the Seldon
        # deployment resource
        service = SeldonDeploymentService.create_from_deployment(
            deployment=deployment
        )
        if running and not service.is_running:
            # skip non-running services
            continue
        services.append(service)

    return services
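
A short sketch of querying for running servers; the pipeline name is hypothetical and an active stack with this model deployer is assumed:

deployer = SeldonModelDeployer.get_active_model_deployer()
for service in deployer.find_model_server(
    running=True,
    pipeline_name="training_pipeline",
):
    print(SeldonModelDeployer.get_model_server_info(service))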
get_active_model_deployer() staticmethod

Get the Seldon Core model deployer registered in the active stack.

Returns:

Type Description
SeldonModelDeployer

The Seldon Core model deployer registered in the active stack.

Exceptions:

Type Description
TypeError

if the Seldon Core model deployer is not available.

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
@staticmethod
def get_active_model_deployer() -> "SeldonModelDeployer":
    """Get the Seldon Core model deployer registered in the active stack.

    Returns:
        The Seldon Core model deployer registered in the active stack.

    Raises:
        TypeError: if the Seldon Core model deployer is not available.
    """
    model_deployer = Repository(  # type: ignore [call-arg]
        skip_repository_check=True
    ).active_stack.model_deployer
    if not model_deployer or not isinstance(
        model_deployer, SeldonModelDeployer
    ):
        raise TypeError(
            f"The active stack needs to have a Seldon Core model deployer "
            f"component registered to be able to deploy models with Seldon "
            f"Core. You can create a new stack with a Seldon Core model "
            f"deployer component or update your existing stack to add this "
            f"component, e.g.:\n\n"
            f"  'zenml model-deployer register seldon --flavor={SELDON_MODEL_DEPLOYER_FLAVOR} "
            f"--kubernetes_context=context-name --kubernetes_namespace="
            f"namespace-name --base_url=https://ingress.cluster.kubernetes'\n"
            f"  'zenml stack create stack-name -d seldon ...'\n"
        )
    return model_deployer
get_model_server_info(service_instance) staticmethod

Return implementation specific information that might be relevant to the user.

Parameters:

Name Type Description Default
service_instance SeldonDeploymentService

Instance of a SeldonDeploymentService

required

Returns:

Type Description
Dict[str, Optional[str]]

Model server information.

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "SeldonDeploymentService",
) -> Dict[str, Optional[str]]:
    """Return implementation specific information that might be relevant to the user.

    Args:
        service_instance: Instance of a SeldonDeploymentService

    Returns:
        Model server information.
    """
    return {
        "PREDICTION_URL": service_instance.prediction_url,
        "MODEL_URI": service_instance.config.model_uri,
        "MODEL_NAME": service_instance.config.model_name,
        "SELDON_DEPLOYMENT": service_instance.seldon_deployment_name,
    }
start_model_server(self, uuid, timeout=300)

Start a Seldon Core model deployment server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to start.

required
timeout int

timeout in seconds to wait for the service to become active. If set to 0, the method will return immediately after provisioning the service, without waiting for it to become active.

300

Exceptions:

Type Description
NotImplementedError

since we don't support starting Seldon Core model servers.

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def start_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
) -> None:
    """Start a Seldon Core model deployment server.

    Args:
        uuid: UUID of the model server to start.
        timeout: timeout in seconds to wait for the service to become
            active. If set to 0, the method will return immediately after
            provisioning the service, without waiting for it to become
            active.

    Raises:
        NotImplementedError: since we don't support starting Seldon Core
            model servers.
    """
    raise NotImplementedError(
        "Starting Seldon Core model servers is not implemented"
    )
stop_model_server(self, uuid, timeout=300, force=False)

Stop a Seldon Core model server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to stop.

required
timeout int

timeout in seconds to wait for the service to stop.

300
force bool

if True, force the service to stop.

False

Exceptions:

Type Description
NotImplementedError

stopping Seldon Core model servers is not supported.

Source code in zenml/integrations/seldon/model_deployers/seldon_model_deployer.py
def stop_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Stop a Seldon Core model server.

    Args:
        uuid: UUID of the model server to stop.
        timeout: timeout in seconds to wait for the service to stop.
        force: if True, force the service to stop.

    Raises:
        NotImplementedError: stopping Seldon Core model servers is not
            supported.
    """
    raise NotImplementedError(
        "Stopping Seldon Core model servers is not implemented. Try "
        "deleting the Seldon Core model server instead."
    )

secret_schemas special

Initialization for the Seldon secret schemas.

These are secret schemas that can be used to authenticate Seldon to the Artifact Store used to store served ML models.

secret_schemas

Implementation for Seldon secret schemas.

SeldonAzureSecretSchema (BaseSecretSchema) pydantic-model

Seldon Azure Blob Storage credentials.

Based on: https://rclone.org/azureblob/

Attributes:

Name Type Description
rclone_config_azureblob_type Literal['azureblob']

the rclone config type. Must be set to "azureblob" for this schema.

rclone_config_azureblob_account Optional[str]

Storage account name. Leave blank to use SAS URL or MSI.

rclone_config_azureblob_key Optional[str]

Storage account key. Leave blank to use SAS URL or MSI.

rclone_config_azureblob_sas_url Optional[str]

SAS URL for container level access only. Leave blank if using account/key or MSI.

rclone_config_azureblob_use_msi bool

use a managed service identity to authenticate (only works in Azure).

Source code in zenml/integrations/seldon/secret_schemas/secret_schemas.py
class SeldonAzureSecretSchema(BaseSecretSchema):
    """Seldon Azure Blob Storage credentials.

    Based on: https://rclone.org/azureblob/

    Attributes:
        rclone_config_azureblob_type: the rclone config type. Must be set to
            "azureblob" for this schema.
        rclone_config_azureblob_account: Storage account name. Leave blank to
            use SAS URL or MSI.
        rclone_config_azureblob_key: Storage account key. Leave blank to
            use SAS URL or MSI.
        rclone_config_azureblob_sas_url: SAS URL for container level access
            only. Leave blank if using account/key or MSI.
        rclone_config_azureblob_use_msi: use a managed service identity to
            authenticate (only works in Azure).
    """

    TYPE: ClassVar[str] = SELDON_AZUREBLOB_SECRET_SCHEMA_TYPE

    rclone_config_azureblob_type: Literal["azureblob"] = "azureblob"
    rclone_config_azureblob_account: Optional[str]
    rclone_config_azureblob_key: Optional[str]
    rclone_config_azureblob_sas_url: Optional[str]
    rclone_config_azureblob_use_msi: bool = False
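A minimal sketch of instantiating this schema in Python (the name field is inherited from BaseSecretSchema; all values shown are placeholders, and the same pattern applies to the other schemas below):

azure_secret = SeldonAzureSecretSchema(
    name="seldon-azure-secret",  # hypothetical ZenML secret name
    rclone_config_azureblob_account="mystorageaccount",
    rclone_config_azureblob_key="<storage-account-key>",
)
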
SeldonGSSecretSchema (BaseSecretSchema) pydantic-model

Seldon GCS credentials.

Based on: https://rclone.org/googlecloudstorage/

Attributes:

Name Type Description
rclone_config_gs_type Literal['google cloud storage']

the rclone config type. Must be set to "google cloud storage" for this schema.

rclone_config_gs_client_id Optional[str]

OAuth client id.

rclone_config_gs_client_secret Optional[str]

OAuth client secret.

rclone_config_gs_token Optional[str]

OAuth Access Token as a JSON blob.

rclone_config_gs_project_number Optional[str]

project number.

rclone_config_gs_service_account_credentials Optional[str]

service account credentials JSON blob.

rclone_config_gs_anonymous bool

access public buckets and objects without credentials. Set to True if you only want to download files and don't want to configure credentials.

rclone_config_gs_auth_url Optional[str]

auth server URL.

rclone_config_gs_token_url Optional[str]

token server URL.

Source code in zenml/integrations/seldon/secret_schemas/secret_schemas.py
class SeldonGSSecretSchema(BaseSecretSchema):
    """Seldon GCS credentials.

    Based on: https://rclone.org/googlecloudstorage/

    Attributes:
        rclone_config_gs_type: the rclone config type. Must be set to "google
            cloud storage" for this schema.
        rclone_config_gs_client_id: OAuth client id.
        rclone_config_gs_client_secret: OAuth client secret.
        rclone_config_gs_token: OAuth Access Token as a JSON blob.
        rclone_config_gs_project_number: project number.
        rclone_config_gs_service_account_credentials: service account
            credentials JSON blob.
        rclone_config_gs_anonymous: access public buckets and objects without
            credentials. Set to True if you only want to download files and
            don't want to configure credentials.
        rclone_config_gs_auth_url: auth server URL.
        rclone_config_gs_token_url: token server URL.
    """

    TYPE: ClassVar[str] = SELDON_GS_SECRET_SCHEMA_TYPE

    rclone_config_gs_type: Literal[
        "google cloud storage"
    ] = "google cloud storage"
    rclone_config_gs_client_id: Optional[str]
    rclone_config_gs_client_secret: Optional[str]
    rclone_config_gs_project_number: Optional[str]
    rclone_config_gs_service_account_credentials: Optional[str]
    rclone_config_gs_anonymous: bool = False
    rclone_config_gs_token: Optional[str]
    rclone_config_gs_auth_url: Optional[str]
    rclone_config_gs_token_url: Optional[str]
SeldonS3SecretSchema (BaseSecretSchema) pydantic-model

Seldon S3 credentials.

Based on: https://rclone.org/s3/#amazon-s3

Attributes:

Name Type Description
rclone_config_s3_type Literal['s3']

the rclone config type. Must be set to "s3" for this schema.

rclone_config_s3_provider str

the S3 provider (e.g. aws, ceph, minio).

rclone_config_s3_env_auth bool

get AWS credentials from EC2/ECS metadata (i.e. with an IAM role configured). Only applies if access_key_id and secret_access_key are blank.

rclone_config_s3_access_key_id Optional[str]

AWS Access Key ID.

rclone_config_s3_secret_access_key Optional[str]

AWS Secret Access Key.

rclone_config_s3_session_token Optional[str]

AWS Session Token.

rclone_config_s3_region Optional[str]

region to connect to.

rclone_config_s3_endpoint Optional[str]

S3 API endpoint.

Source code in zenml/integrations/seldon/secret_schemas/secret_schemas.py
class SeldonS3SecretSchema(BaseSecretSchema):
    """Seldon S3 credentials.

    Based on: https://rclone.org/s3/#amazon-s3

    Attributes:
        rclone_config_s3_type: the rclone config type. Must be set to "s3" for
            this schema.
        rclone_config_s3_provider: the S3 provider (e.g. aws, ceph, minio).
        rclone_config_s3_env_auth: get AWS credentials from EC2/ECS metadata
            (i.e. with an IAM role configured). Only applies if access_key_id
            and secret_access_key are blank.
        rclone_config_s3_access_key_id: AWS Access Key ID.
        rclone_config_s3_secret_access_key: AWS Secret Access Key.
        rclone_config_s3_session_token: AWS Session Token.
        rclone_config_s3_region: region to connect to.
        rclone_config_s3_endpoint: S3 API endpoint.

    """

    TYPE: ClassVar[str] = SELDON_S3_SECRET_SCHEMA_TYPE

    rclone_config_s3_type: Literal["s3"] = "s3"
    rclone_config_s3_provider: str = "aws"
    rclone_config_s3_env_auth: bool = False
    rclone_config_s3_access_key_id: Optional[str]
    rclone_config_s3_secret_access_key: Optional[str]
    rclone_config_s3_session_token: Optional[str]
    rclone_config_s3_region: Optional[str]
    rclone_config_s3_endpoint: Optional[str]
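Similarly, a sketch of an S3 secret that relies on implicit EC2/ECS credentials instead of explicit keys (all values are placeholders):

s3_secret = SeldonS3SecretSchema(
    name="seldon-s3-secret",  # hypothetical ZenML secret name
    rclone_config_s3_env_auth=True,  # pick up credentials from the IAM role
    rclone_config_s3_region="us-east-1",
)
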

seldon_client

Implementation of the Seldon client for ZenML.

SeldonClient

A client for interacting with Seldon Deployments.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonClient:
    """A client for interacting with Seldon Deployments."""

    def __init__(self, context: Optional[str], namespace: Optional[str]):
        """Initialize a Seldon Core client.

        Args:
            context: the Kubernetes context to use.
            namespace: the Kubernetes namespace to use.
        """
        self._context = context
        self._namespace = namespace
        self._initialize_k8s_clients()

    def _initialize_k8s_clients(self) -> None:
        """Initialize the Kubernetes clients.

        Raises:
            SeldonClientError: if Kubernetes configuration could not be loaded
        """
        try:
            k8s_config.load_incluster_config()
            if not self._namespace:
                # load the namespace in the context of which the
                # current pod is running
                self._namespace = open(
                    "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
                ).read()
        except k8s_config.config_exception.ConfigException:
            if not self._namespace:
                raise SeldonClientError(
                    "The Kubernetes namespace must be explicitly "
                    "configured when running outside of a cluster."
                )
            try:
                k8s_config.load_kube_config(
                    context=self._context, persist_config=False
                )
            except k8s_config.config_exception.ConfigException as e:
                raise SeldonClientError(
                    "Could not load the Kubernetes configuration"
                ) from e
        self._core_api = k8s_client.CoreV1Api()
        self._custom_objects_api = k8s_client.CustomObjectsApi()

    @staticmethod
    def sanitize_labels(labels: Dict[str, str]) -> None:
        """Update the label values to be valid Kubernetes labels.

        See:
        https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

        Args:
            labels: the labels to sanitize.
        """
        for key, value in labels.items():
            # Kubernetes labels must be alphanumeric, no longer than
            # 63 characters, and must begin and end with an alphanumeric
            # character ([a-z0-9A-Z])
            labels[key] = re.sub(r"[^0-9a-zA-Z-_\.]+", "_", value)[:63].strip(
                "-_."
            )

    @property
    def namespace(self) -> str:
        """Returns the Kubernetes namespace in use by the client.

        Returns:
            The Kubernetes namespace in use by the client.

        Raises:
            RuntimeError: if the namespace has not been configured.
        """
        if not self._namespace:
            # shouldn't happen if the client is initialized, but we need to
            # appease the mypy type checker
            raise RuntimeError("The Kubernetes namespace is not configured")
        return self._namespace

    def create_deployment(
        self,
        deployment: SeldonDeployment,
        poll_timeout: int = 0,
    ) -> SeldonDeployment:
        """Create a Seldon Core deployment resource.

        Args:
            deployment: the Seldon Core deployment resource to create
            poll_timeout: the maximum time to wait for the deployment to become
                available or to fail. If set to 0, the function will return
                immediately without checking the deployment status. If a timeout
                occurs and the deployment is still pending creation, it will
                be returned anyway and no exception will be raised.

        Returns:
            the created Seldon Core deployment resource with updated status.

        Raises:
            SeldonDeploymentExistsError: if a deployment with the same name
                already exists.
            SeldonClientError: if an unknown error occurs during the creation of
                the deployment.
        """
        try:
            logger.debug(f"Creating SeldonDeployment resource: {deployment}")

            # mark the deployment as managed by ZenML, to differentiate
            # between deployments that are created by ZenML and those that
            # are not
            deployment.mark_as_managed_by_zenml()

            response = self._custom_objects_api.create_namespaced_custom_object(
                group="machinelearning.seldon.io",
                version="v1",
                namespace=self._namespace,
                plural="seldondeployments",
                body=deployment.dict(exclude_none=True),
                _request_timeout=poll_timeout or None,
            )
            logger.debug("Seldon Core API response: %s", response)
        except k8s_client.rest.ApiException as e:
            logger.error(
                "Exception when creating SeldonDeployment resource: %s", str(e)
            )
            if e.status == 409:
                raise SeldonDeploymentExistsError(
                    f"A deployment with the name {deployment.name} "
                    f"already exists in namespace {self._namespace}"
                )
            raise SeldonClientError(
                "Exception when creating SeldonDeployment resource"
            ) from e

        created_deployment = self.get_deployment(name=deployment.name)

        while poll_timeout > 0 and created_deployment.is_pending():
            time.sleep(5)
            poll_timeout -= 5
            created_deployment = self.get_deployment(name=deployment.name)

        return created_deployment

    def delete_deployment(
        self,
        name: str,
        force: bool = False,
        poll_timeout: int = 0,
    ) -> None:
        """Delete a Seldon Core deployment resource managed by ZenML.

        Args:
            name: the name of the Seldon Core deployment resource to delete.
            force: if True, the deployment deletion will be forced (the graceful
                period will be set to zero).
            poll_timeout: the maximum time to wait for the deployment to be
                deleted. If set to 0, the function will return immediately
                without checking the deployment status. If a timeout
                occurs and the deployment still exists, this method will
                return and no exception will be raised.

        Raises:
            SeldonClientError: if an unknown error occurs during the deployment
                removal.
        """
        try:
            logger.debug(f"Deleting SeldonDeployment resource: {name}")

            # call `get_deployment` to check that the deployment exists
            # and is managed by ZenML. It will raise
            # a SeldonDeploymentNotFoundError otherwise
            self.get_deployment(name=name)

            response = self._custom_objects_api.delete_namespaced_custom_object(
                group="machinelearning.seldon.io",
                version="v1",
                namespace=self._namespace,
                plural="seldondeployments",
                name=name,
                _request_timeout=poll_timeout or None,
                grace_period_seconds=0 if force else None,
            )
            logger.debug("Seldon Core API response: %s", response)
        except k8s_client.rest.ApiException as e:
            logger.error(
                "Exception when deleting SeldonDeployment resource %s: %s",
                name,
                str(e),
            )
            raise SeldonClientError(
                f"Exception when deleting SeldonDeployment resource {name}"
            ) from e

        while poll_timeout > 0:
            try:
                self.get_deployment(name=name)
            except SeldonDeploymentNotFoundError:
                return
            time.sleep(5)
            poll_timeout -= 5

    def update_deployment(
        self,
        deployment: SeldonDeployment,
        poll_timeout: int = 0,
    ) -> SeldonDeployment:
        """Update a Seldon Core deployment resource.

        Args:
            deployment: the Seldon Core deployment resource to update
            poll_timeout: the maximum time to wait for the deployment to become
                available or to fail. If set to 0, the function will return
                immediately without checking the deployment status. If a timeout
                occurs and the deployment is still pending creation, it will
                be returned anyway and no exception will be raised.

        Returns:
            the updated Seldon Core deployment resource with updated status.

        Raises:
            SeldonClientError: if an unknown error occurs while updating the
                deployment.
        """
        try:
            logger.debug(
                f"Updating SeldonDeployment resource: {deployment.name}"
            )

            # mark the deployment as managed by ZenML, to differentiate
            # between deployments that are created by ZenML and those that
            # are not
            deployment.mark_as_managed_by_zenml()

            # call `get_deployment` to check that the deployment exists
            # and is managed by ZenML. It will raise
            # a SeldonDeploymentNotFoundError otherwise
            self.get_deployment(name=deployment.name)

            response = self._custom_objects_api.patch_namespaced_custom_object(
                group="machinelearning.seldon.io",
                version="v1",
                namespace=self._namespace,
                plural="seldondeployments",
                name=deployment.name,
                body=deployment.dict(exclude_none=True),
                _request_timeout=poll_timeout or None,
            )
            logger.debug("Seldon Core API response: %s", response)
        except k8s_client.rest.ApiException as e:
            logger.error(
                "Exception when updating SeldonDeployment resource: %s", str(e)
            )
            raise SeldonClientError(
                "Exception when creating SeldonDeployment resource"
            ) from e

        updated_deployment = self.get_deployment(name=deployment.name)

        while poll_timeout > 0 and updated_deployment.is_pending():
            time.sleep(5)
            poll_timeout -= 5
            updated_deployment = self.get_deployment(name=deployment.name)

        return updated_deployment

    def get_deployment(self, name: str) -> SeldonDeployment:
        """Get a ZenML managed Seldon Core deployment resource by name.

        Args:
            name: the name of the Seldon Core deployment resource to fetch.

        Returns:
            The Seldon Core deployment resource.

        Raises:
            SeldonDeploymentNotFoundError: if the deployment resource cannot
                be found or is not managed by ZenML.
            SeldonClientError: if an unknown error occurs while fetching
                the deployment.
        """
        try:
            logger.debug(f"Retrieving SeldonDeployment resource: {name}")

            response = self._custom_objects_api.get_namespaced_custom_object(
                group="machinelearning.seldon.io",
                version="v1",
                namespace=self._namespace,
                plural="seldondeployments",
                name=name,
            )
            logger.debug("Seldon Core API response: %s", response)
            try:
                deployment = SeldonDeployment(**response)
            except ValidationError as e:
                logger.error(
                    "Invalid Seldon Core deployment resource: %s\n%s",
                    str(e),
                    str(response),
                )
                raise SeldonDeploymentNotFoundError(
                    f"SeldonDeployment resource {name} could not be parsed"
                )

            # Only Seldon deployments managed by ZenML are returned
            if not deployment.is_managed_by_zenml():
                raise SeldonDeploymentNotFoundError(
                    f"Seldon Deployment {name} is not managed by ZenML"
                )
            return deployment

        except k8s_client.rest.ApiException as e:
            if e.status == 404:
                raise SeldonDeploymentNotFoundError(
                    f"SeldonDeployment resource not found: {name}"
                ) from e
            logger.error(
                "Exception when fetching SeldonDeployment resource %s: %s",
                name,
                str(e),
            )
            raise SeldonClientError(
                f"Unexpected exception when fetching SeldonDeployment "
                f"resource: {name}"
            ) from e

    def find_deployments(
        self,
        name: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
        fields: Optional[Dict[str, str]] = None,
    ) -> List[SeldonDeployment]:
        """Find all ZenML-managed Seldon Core deployment resources matching the given criteria.

        Args:
            name: optional name of the deployment resource to find.
            fields: optional selector to restrict the list of returned
                Seldon deployments by their fields. Defaults to everything.
            labels: optional selector to restrict the list of returned
                Seldon deployments by their labels. Defaults to everything.

        Returns:
            List of Seldon Core deployments that match the given criteria.

        Raises:
            SeldonClientError: if an unknown error occurs while fetching
                the deployments.
        """
        fields = fields or {}
        labels = labels or {}
        # always filter results to only include Seldon deployments managed
        # by ZenML
        labels["app"] = "zenml"
        if name:
            fields = {"metadata.name": name}
        field_selector = (
            ",".join(f"{k}={v}" for k, v in fields.items()) if fields else None
        )
        label_selector = (
            ",".join(f"{k}={v}" for k, v in labels.items()) if labels else None
        )
        try:
            logger.debug(
                f"Searching SeldonDeployment resources with label selector "
                f"'{labels or ''}' and field selector '{fields or ''}'"
            )
            response = self._custom_objects_api.list_namespaced_custom_object(
                group="machinelearning.seldon.io",
                version="v1",
                namespace=self._namespace,
                plural="seldondeployments",
                field_selector=field_selector,
                label_selector=label_selector,
            )
            logger.debug(
                "Seldon Core API returned %s items", len(response["items"])
            )
            deployments = []
            for item in response.get("items") or []:
                try:
                    deployments.append(SeldonDeployment(**item))
                except ValidationError as e:
                    logger.error(
                        "Invalid Seldon Core deployment resource: %s\n%s",
                        str(e),
                        str(item),
                    )
            return deployments
        except k8s_client.rest.ApiException as e:
            logger.error(
                "Exception when searching SeldonDeployment resources with "
                "label selector '%s' and field selector '%s': %s",
                label_selector or "",
                field_selector or "",
            )
            raise SeldonClientError(
                f"Unexpected exception when searching SeldonDeployment "
                f"with labels '{labels or ''}' and field '{fields or ''}'"
            ) from e

    def get_deployment_logs(
        self,
        name: str,
        follow: bool = False,
        tail: Optional[int] = None,
    ) -> Generator[str, bool, None]:
        """Get the logs of a Seldon Core deployment resource.

        Args:
            name: the name of the Seldon Core deployment to get logs for.
            follow: if True, the logs will be streamed as they are written.
            tail: only retrieve the last NUM lines of log output.

        Returns:
            A generator that can be accessed to get the service logs.

        Yields:
            The next log line.

        Raises:
            SeldonClientError: if an unknown error occurs while fetching
                the logs.
        """
        logger.debug(f"Retrieving logs for SeldonDeployment resource: {name}")
        try:
            response = self._core_api.list_namespaced_pod(
                namespace=self._namespace,
                label_selector=f"seldon-deployment-id={name}",
            )
            logger.debug("Kubernetes API response: %s", response)
            pods = response.items
            if not pods:
                raise SeldonClientError(
                    f"The Seldon Core deployment {name} is not currently "
                    f"running: no Kubernetes pods associated with it were found"
                )
            pod = pods[0]
            pod_name = pod.metadata.name

            containers = [c.name for c in pod.spec.containers]
            init_containers = [c.name for c in pod.spec.init_containers]
            container_statuses = {
                c.name: c.started or c.restart_count
                for c in pod.status.container_statuses
            }

            container = "default"
            if container not in containers:
                container = containers[0]
            # the selected container might not have started yet and would have
            # no logs to show; fall back to the first init container in that case
            if not container_statuses[container]:
                container = init_containers[0]

            logger.info(
                f"Retrieving logs for pod: `{pod_name}` and container "
                f"`{container}` in namespace `{self._namespace}`"
            )
            response = self._core_api.read_namespaced_pod_log(
                name=pod_name,
                namespace=self._namespace,
                container=container,
                follow=follow,
                tail_lines=tail,
                _preload_content=False,
            )
        except k8s_client.rest.ApiException as e:
            logger.error(
                "Exception when fetching logs for SeldonDeployment resource "
                "%s: %s",
                name,
                str(e),
            )
            raise SeldonClientError(
                f"Unexpected exception when fetching logs for SeldonDeployment "
                f"resource: {name}"
            ) from e

        try:
            while True:
                line = response.readline().decode("utf-8").rstrip("\n")
                if not line:
                    return
                stop = yield line
                if stop:
                    return
        finally:
            response.release_conn()

    def create_or_update_secret(
        self,
        name: str,
        secret: BaseSecretSchema,
    ) -> None:
        """Create or update a Kubernetes Secret resource.

        Uses the information contained in a ZenML secret.

        Args:
            name: the name of the Secret resource to create.
            secret: a ZenML secret with key-values that should be
                stored in the Secret resource.

        Raises:
            SeldonClientError: if an unknown error occurs during the creation of
                the secret.
            k8s_client.rest.ApiException: unexpected error.
        """
        try:
            logger.debug(f"Creating Secret resource: {name}")

            secret_data = {
                k.upper(): base64.b64encode(str(v).encode("utf-8")).decode(
                    "ascii"
                )
                for k, v in secret.content.items()
                if v is not None
            }

            secret = k8s_client.V1Secret(
                metadata=k8s_client.V1ObjectMeta(
                    name=name,
                    labels={"app": "zenml"},
                ),
                type="Opaque",
                data=secret_data,
            )

            try:
                # check if the secret is already present
                self._core_api.read_namespaced_secret(
                    name=name,
                    namespace=self._namespace,
                )
                # if we got this far, the secret is already present, update it
                # in place
                response = self._core_api.replace_namespaced_secret(
                    name=name,
                    namespace=self._namespace,
                    body=secret,
                )
            except k8s_client.rest.ApiException as e:
                if e.status != 404:
                    # if an error other than 404 is raised here, treat it
                    # as an unexpected error
                    raise
                response = self._core_api.create_namespaced_secret(
                    namespace=self._namespace,
                    body=secret,
                )
            logger.debug("Kubernetes API response: %s", response)
        except k8s_client.rest.ApiException as e:
            logger.error("Exception when creating Secret resource: %s", str(e))
            raise SeldonClientError(
                "Exception when creating Secret resource"
            ) from e

    def delete_secret(
        self,
        name: str,
    ) -> None:
        """Delete a Kubernetes Secret resource managed by ZenML.

        Args:
            name: the name of the Kubernetes Secret resource to delete.

        Raises:
            SeldonClientError: if an unknown error occurs during the removal
                of the secret.
        """
        try:
            logger.debug(f"Deleting Secret resource: {name}")

            response = self._core_api.delete_namespaced_secret(
                name=name,
                namespace=self._namespace,
            )
            logger.debug("Kubernetes API response: %s", response)
        except k8s_client.rest.ApiException as e:
            if e.status == 404:
                # the secret is no longer present, nothing to do
                return
            logger.error(
                "Exception when deleting Secret resource %s: %s",
                name,
                str(e),
            )
            raise SeldonClientError(
                f"Exception when deleting Secret resource {name}"
            ) from e
namespace: str property readonly

Returns the Kubernetes namespace in use by the client.

Returns:

Type Description
str

The Kubernetes namespace in use by the client.

Exceptions:

Type Description
RuntimeError

if the namespace has not been configured.

__init__(self, context, namespace) special

Initialize a Seldon Core client.

Parameters:

Name Type Description Default
context Optional[str]

the Kubernetes context to use.

required
namespace Optional[str]

the Kubernetes namespace to use.

required
Source code in zenml/integrations/seldon/seldon_client.py
def __init__(self, context: Optional[str], namespace: Optional[str]):
    """Initialize a Seldon Core client.

    Args:
        context: the Kubernetes context to use.
        namespace: the Kubernetes namespace to use.
    """
    self._context = context
    self._namespace = namespace
    self._initialize_k8s_clients()
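For example (the context and namespace names are placeholders; when running inside a Kubernetes pod, the in-cluster configuration and the pod's namespace are used instead):

from zenml.integrations.seldon.seldon_client import SeldonClient

client = SeldonClient(context="my-k8s-context", namespace="seldon")
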
create_deployment(self, deployment, poll_timeout=0)

Create a Seldon Core deployment resource.

Parameters:

Name Type Description Default
deployment SeldonDeployment

the Seldon Core deployment resource to create

required
poll_timeout int

the maximum time to wait for the deployment to become available or to fail. If set to 0, the function will return immediately without checking the deployment status. If a timeout occurs and the deployment is still pending creation, it will be returned anyway and no exception will be raised.

0

Returns:

Type Description
SeldonDeployment

the created Seldon Core deployment resource with updated status.

Exceptions:

Type Description
SeldonDeploymentExistsError

if a deployment with the same name already exists.

SeldonClientError

if an unknown error occurs during the creation of the deployment.

Source code in zenml/integrations/seldon/seldon_client.py
def create_deployment(
    self,
    deployment: SeldonDeployment,
    poll_timeout: int = 0,
) -> SeldonDeployment:
    """Create a Seldon Core deployment resource.

    Args:
        deployment: the Seldon Core deployment resource to create
        poll_timeout: the maximum time to wait for the deployment to become
            available or to fail. If set to 0, the function will return
            immediately without checking the deployment status. If a timeout
            occurs and the deployment is still pending creation, it will
            be returned anyway and no exception will be raised.

    Returns:
        the created Seldon Core deployment resource with updated status.

    Raises:
        SeldonDeploymentExistsError: if a deployment with the same name
            already exists.
        SeldonClientError: if an unknown error occurs during the creation of
            the deployment.
    """
    try:
        logger.debug(f"Creating SeldonDeployment resource: {deployment}")

        # mark the deployment as managed by ZenML, to differentiate
        # between deployments that are created by ZenML and those that
        # are not
        deployment.mark_as_managed_by_zenml()

        response = self._custom_objects_api.create_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            body=deployment.dict(exclude_none=True),
            _request_timeout=poll_timeout or None,
        )
        logger.debug("Seldon Core API response: %s", response)
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when creating SeldonDeployment resource: %s", str(e)
        )
        if e.status == 409:
            raise SeldonDeploymentExistsError(
                f"A deployment with the name {deployment.name} "
                f"already exists in namespace {self._namespace}"
            )
        raise SeldonClientError(
            "Exception when creating SeldonDeployment resource"
        ) from e

    created_deployment = self.get_deployment(name=deployment.name)

    while poll_timeout > 0 and created_deployment.is_pending():
        time.sleep(5)
        poll_timeout -= 5
        created_deployment = self.get_deployment(name=deployment.name)

    return created_deployment
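Combined with SeldonDeployment.build (documented later in this section), a deployment can be created and polled until it leaves the pending state, e.g. (reusing the hypothetical client from the example above; all names and the model URI are placeholders):

deployment = SeldonDeployment.build(
    name="my-model",  # hypothetical deployment name
    model_uri="s3://my-bucket/model",
    model_name="my-model",
    implementation="SKLEARN_SERVER",  # a standard Seldon Core server implementation
)
created = client.create_deployment(deployment, poll_timeout=300)
if created.is_pending():
    print("Deployment is still pending after the timeout")
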
create_or_update_secret(self, name, secret)

Create or update a Kubernetes Secret resource.

Uses the information contained in a ZenML secret.

Parameters:

Name Type Description Default
name str

the name of the Secret resource to create.

required
secret BaseSecretSchema

a ZenML secret with key-values that should be stored in the Secret resource.

required

Exceptions:

Type Description
SeldonClientError

if an unknown error occurs during the creation of the secret.

k8s_client.rest.ApiException

unexpected error.

Source code in zenml/integrations/seldon/seldon_client.py
def create_or_update_secret(
    self,
    name: str,
    secret: BaseSecretSchema,
) -> None:
    """Create or update a Kubernetes Secret resource.

    Uses the information contained in a ZenML secret.

    Args:
        name: the name of the Secret resource to create.
        secret: a ZenML secret with key-values that should be
            stored in the Secret resource.

    Raises:
        SeldonClientError: if an unknown error occurs during the creation of
            the secret.
        k8s_client.rest.ApiException: unexpected error.
    """
    try:
        logger.debug(f"Creating Secret resource: {name}")

        secret_data = {
            k.upper(): base64.b64encode(str(v).encode("utf-8")).decode(
                "ascii"
            )
            for k, v in secret.content.items()
            if v is not None
        }

        secret = k8s_client.V1Secret(
            metadata=k8s_client.V1ObjectMeta(
                name=name,
                labels={"app": "zenml"},
            ),
            type="Opaque",
            data=secret_data,
        )

        try:
            # check if the secret is already present
            self._core_api.read_namespaced_secret(
                name=name,
                namespace=self._namespace,
            )
            # if we got this far, the secret is already present, update it
            # in place
            response = self._core_api.replace_namespaced_secret(
                name=name,
                namespace=self._namespace,
                body=secret,
            )
        except k8s_client.rest.ApiException as e:
            if e.status != 404:
                # if an error other than 404 is raised here, treat it
                # as an unexpected error
                raise
            response = self._core_api.create_namespaced_secret(
                namespace=self._namespace,
                body=secret,
            )
        logger.debug("Kubernetes API response: %s", response)
    except k8s_client.rest.ApiException as e:
        logger.error("Exception when creating Secret resource: %s", str(e))
        raise SeldonClientError(
            "Exception when creating Secret resource"
        ) from e
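For example, reusing the hypothetical s3_secret schema instance sketched earlier in the secret schemas section:

# Creates the Secret if missing, otherwise replaces it in place
client.create_or_update_secret("seldon-rclone-secret", s3_secret)
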
delete_deployment(self, name, force=False, poll_timeout=0)

Delete a Seldon Core deployment resource managed by ZenML.

Parameters:

Name Type Description Default
name str

the name of the Seldon Core deployment resource to delete.

required
force bool

if True, the deployment deletion will be forced (the graceful period will be set to zero).

False
poll_timeout int

the maximum time to wait for the deployment to be deleted. If set to 0, the function will return immediately without checking the deployment status. If a timeout occurs and the deployment still exists, this method will return and no exception will be raised.

0

Exceptions:

Type Description
SeldonClientError

if an unknown error occurs during the deployment removal.

Source code in zenml/integrations/seldon/seldon_client.py
def delete_deployment(
    self,
    name: str,
    force: bool = False,
    poll_timeout: int = 0,
) -> None:
    """Delete a Seldon Core deployment resource managed by ZenML.

    Args:
        name: the name of the Seldon Core deployment resource to delete.
        force: if True, the deployment deletion will be forced (the graceful
            period will be set to zero).
        poll_timeout: the maximum time to wait for the deployment to be
            deleted. If set to 0, the function will return immediately
            without checking the deployment status. If a timeout
            occurs and the deployment still exists, this method will
            return and no exception will be raised.

    Raises:
        SeldonClientError: if an unknown error occurs during the deployment
            removal.
    """
    try:
        logger.debug(f"Deleting SeldonDeployment resource: {name}")

        # call `get_deployment` to check that the deployment exists
        # and is managed by ZenML. It will raise
        # a SeldonDeploymentNotFoundError otherwise
        self.get_deployment(name=name)

        response = self._custom_objects_api.delete_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            name=name,
            _request_timeout=poll_timeout or None,
            grace_period_seconds=0 if force else None,
        )
        logger.debug("Seldon Core API response: %s", response)
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when deleting SeldonDeployment resource %s: %s",
            name,
            str(e),
        )
        raise SeldonClientError(
            f"Exception when deleting SeldonDeployment resource {name}"
        ) from e

    while poll_timeout > 0:
        try:
            self.get_deployment(name=name)
        except SeldonDeploymentNotFoundError:
            return
        time.sleep(5)
        poll_timeout -= 5
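For example, to force-delete a hypothetical deployment and wait up to five minutes for it to disappear:

client.delete_deployment("my-model", force=True, poll_timeout=300)
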
delete_secret(self, name)

Delete a Kubernetes Secret resource managed by ZenML.

Parameters:

Name Type Description Default
name str

the name of the Kubernetes Secret resource to delete.

required

Exceptions:

Type Description
SeldonClientError

if an unknown error occurs during the removal of the secret.

Source code in zenml/integrations/seldon/seldon_client.py
def delete_secret(
    self,
    name: str,
) -> None:
    """Delete a Kubernetes Secret resource managed by ZenML.

    Args:
        name: the name of the Kubernetes Secret resource to delete.

    Raises:
        SeldonClientError: if an unknown error occurs during the removal
            of the secret.
    """
    try:
        logger.debug(f"Deleting Secret resource: {name}")

        response = self._core_api.delete_namespaced_secret(
            name=name,
            namespace=self._namespace,
        )
        logger.debug("Kubernetes API response: %s", response)
    except k8s_client.rest.ApiException as e:
        if e.status == 404:
            # the secret is no longer present, nothing to do
            return
        logger.error(
            "Exception when deleting Secret resource %s: %s",
            name,
            str(e),
        )
        raise SeldonClientError(
            f"Exception when deleting Secret resource {name}"
        ) from e
find_deployments(self, name=None, labels=None, fields=None)

Find all ZenML-managed Seldon Core deployment resources matching the given criteria.

Parameters:

Name Type Description Default
name Optional[str]

optional name of the deployment resource to find.

None
fields Optional[Dict[str, str]]

optional selector to restrict the list of returned Seldon deployments by their fields. Defaults to everything.

None
labels Optional[Dict[str, str]]

optional selector to restrict the list of returned Seldon deployments by their labels. Defaults to everything.

None

Returns:

Type Description
List[zenml.integrations.seldon.seldon_client.SeldonDeployment]

List of Seldon Core deployments that match the given criteria.

Exceptions:

Type Description
SeldonClientError

if an unknown error occurs while fetching the deployments.

Source code in zenml/integrations/seldon/seldon_client.py
def find_deployments(
    self,
    name: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    fields: Optional[Dict[str, str]] = None,
) -> List[SeldonDeployment]:
    """Find all ZenML-managed Seldon Core deployment resources matching the given criteria.

    Args:
        name: optional name of the deployment resource to find.
        fields: optional selector to restrict the list of returned
            Seldon deployments by their fields. Defaults to everything.
        labels: optional selector to restrict the list of returned
            Seldon deployments by their labels. Defaults to everything.

    Returns:
        List of Seldon Core deployments that match the given criteria.

    Raises:
        SeldonClientError: if an unknown error occurs while fetching
            the deployments.
    """
    fields = fields or {}
    labels = labels or {}
    # always filter results to only include Seldon deployments managed
    # by ZenML
    labels["app"] = "zenml"
    if name:
        fields = {"metadata.name": name}
    field_selector = (
        ",".join(f"{k}={v}" for k, v in fields.items()) if fields else None
    )
    label_selector = (
        ",".join(f"{k}={v}" for k, v in labels.items()) if labels else None
    )
    try:
        logger.debug(
            f"Searching SeldonDeployment resources with label selector "
            f"'{labels or ''}' and field selector '{fields or ''}'"
        )
        response = self._custom_objects_api.list_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            field_selector=field_selector,
            label_selector=label_selector,
        )
        logger.debug(
            "Seldon Core API returned %s items", len(response["items"])
        )
        deployments = []
        for item in response.get("items") or []:
            try:
                deployments.append(SeldonDeployment(**item))
            except ValidationError as e:
                logger.error(
                    "Invalid Seldon Core deployment resource: %s\n%s",
                    str(e),
                    str(item),
                )
        return deployments
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when searching SeldonDeployment resources with "
            "label selector '%s' and field selector '%s': %s",
            label_selector or "",
            field_selector or "",
        )
        raise SeldonClientError(
            f"Unexpected exception when searching SeldonDeployment "
            f"with labels '{labels or ''}' and field '{fields or ''}'"
        ) from e
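As an example, the ZenML-managed deployments carrying a hypothetical pipeline label can be listed like this:

deployments = client.find_deployments(labels={"pipeline": "training-pipeline"})
for deployment in deployments:
    print(deployment.name, deployment.is_pending())
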
get_deployment(self, name)

Get a ZenML managed Seldon Core deployment resource by name.

Parameters:

Name Type Description Default
name str

the name of the Seldon Core deployment resource to fetch.

required

Returns:

Type Description
SeldonDeployment

The Seldon Core deployment resource.

Exceptions:

Type Description
SeldonDeploymentNotFoundError

if the deployment resource cannot be found or is not managed by ZenML.

SeldonClientError

if an unknown error occurs while fetching the deployment.

Source code in zenml/integrations/seldon/seldon_client.py
def get_deployment(self, name: str) -> SeldonDeployment:
    """Get a ZenML managed Seldon Core deployment resource by name.

    Args:
        name: the name of the Seldon Core deployment resource to fetch.

    Returns:
        The Seldon Core deployment resource.

    Raises:
        SeldonDeploymentNotFoundError: if the deployment resource cannot
            be found or is not managed by ZenML.
        SeldonClientError: if an unknown error occurs while fetching
            the deployment.
    """
    try:
        logger.debug(f"Retrieving SeldonDeployment resource: {name}")

        response = self._custom_objects_api.get_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            name=name,
        )
        logger.debug("Seldon Core API response: %s", response)
        try:
            deployment = SeldonDeployment(**response)
        except ValidationError as e:
            logger.error(
                "Invalid Seldon Core deployment resource: %s\n%s",
                str(e),
                str(response),
            )
            raise SeldonDeploymentNotFoundError(
                f"SeldonDeployment resource {name} could not be parsed"
            )

        # Only Seldon deployments managed by ZenML are returned
        if not deployment.is_managed_by_zenml():
            raise SeldonDeploymentNotFoundError(
                f"Seldon Deployment {name} is not managed by ZenML"
            )
        return deployment

    except k8s_client.rest.ApiException as e:
        if e.status == 404:
            raise SeldonDeploymentNotFoundError(
                f"SeldonDeployment resource not found: {name}"
            ) from e
        logger.error(
            "Exception when fetching SeldonDeployment resource %s: %s",
            name,
            str(e),
        )
        raise SeldonClientError(
            f"Unexpected exception when fetching SeldonDeployment "
            f"resource: {name}"
        ) from e
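For example, to fetch a hypothetical deployment while handling the not-found case explicitly:

from zenml.integrations.seldon.seldon_client import SeldonDeploymentNotFoundError

try:
    deployment = client.get_deployment("my-model")
except SeldonDeploymentNotFoundError:
    print("No ZenML-managed SeldonDeployment with that name")
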
get_deployment_logs(self, name, follow=False, tail=None)

Get the logs of a Seldon Core deployment resource.

Parameters:

Name Type Description Default
name str

the name of the Seldon Core deployment to get logs for.

required
follow bool

if True, the logs will be streamed as they are written.

False
tail Optional[int]

only retrieve the last NUM lines of log output.

None

Returns:

Type Description
Generator[str, bool, NoneType]

A generator that can be accessed to get the service logs.

Yields:

Type Description
Generator[str, bool, NoneType]

The next log line.

Exceptions:

Type Description
SeldonClientError

if an unknown error occurs while fetching the logs.

Source code in zenml/integrations/seldon/seldon_client.py
def get_deployment_logs(
    self,
    name: str,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Get the logs of a Seldon Core deployment resource.

    Args:
        name: the name of the Seldon Core deployment to get logs for.
        follow: if True, the logs will be streamed as they are written.
        tail: only retrieve the last NUM lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.

    Yields:
        The next log line.

    Raises:
        SeldonClientError: if an unknown error occurs while fetching
            the logs.
    """
    logger.debug(f"Retrieving logs for SeldonDeployment resource: {name}")
    try:
        response = self._core_api.list_namespaced_pod(
            namespace=self._namespace,
            label_selector=f"seldon-deployment-id={name}",
        )
        logger.debug("Kubernetes API response: %s", response)
        pods = response.items
        if not pods:
            raise SeldonClientError(
                f"The Seldon Core deployment {name} is not currently "
                f"running: no Kubernetes pods associated with it were found"
            )
        pod = pods[0]
        pod_name = pod.metadata.name

        containers = [c.name for c in pod.spec.containers]
        init_containers = [c.name for c in pod.spec.init_containers]
        container_statuses = {
            c.name: c.started or c.restart_count
            for c in pod.status.container_statuses
        }

        container = "default"
        if container not in containers:
            container = containers[0]
        # the selected container might not have started yet and would have
        # no logs to show; fall back to the first init container in that case
        if not container_statuses[container]:
            container = init_containers[0]

        logger.info(
            f"Retrieving logs for pod: `{pod_name}` and container "
            f"`{container}` in namespace `{self._namespace}`"
        )
        response = self._core_api.read_namespaced_pod_log(
            name=pod_name,
            namespace=self._namespace,
            container=container,
            follow=follow,
            tail_lines=tail,
            _preload_content=False,
        )
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when fetching logs for SeldonDeployment resource "
            "%s: %s",
            name,
            str(e),
        )
        raise SeldonClientError(
            f"Unexpected exception when fetching logs for SeldonDeployment "
            f"resource: {name}"
        ) from e

    try:
        while True:
            line = response.readline().decode("utf-8").rstrip("\n")
            if not line:
                return
            stop = yield line
            if stop:
                return
    finally:
        response.release_conn()
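The returned generator can be consumed like any iterator; because it is a generator, streaming can also be ended early by closing it or by sending a truthy value into it. A simple sketch (the deployment name is a placeholder):

logs = client.get_deployment_logs("my-model", tail=100)
for line in logs:
    print(line)  # logs.close() (or sending True into the generator) stops the stream early
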
sanitize_labels(labels) staticmethod

Update the label values to be valid Kubernetes labels.

See: https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

Parameters:

Name Type Description Default
labels Dict[str, str]

the labels to sanitize.

required
Source code in zenml/integrations/seldon/seldon_client.py
@staticmethod
def sanitize_labels(labels: Dict[str, str]) -> None:
    """Update the label values to be valid Kubernetes labels.

    See:
    https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

    Args:
        labels: the labels to sanitize.
    """
    for key, value in labels.items():
        # Kubernetes labels must be alphanumeric, no longer than
        # 63 characters, and must begin and end with an alphanumeric
        # character ([a-z0-9A-Z])
        labels[key] = re.sub(r"[^0-9a-zA-Z-_\.]+", "_", value)[:63].strip(
            "-_."
        )
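A quick illustration of the in-place sanitization (the dictionary is modified and nothing is returned):

labels = {"pipeline": "my pipeline/run:1"}
SeldonClient.sanitize_labels(labels)
print(labels)  # {'pipeline': 'my_pipeline_run_1'}
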
update_deployment(self, deployment, poll_timeout=0)

Update a Seldon Core deployment resource.

Parameters:

Name Type Description Default
deployment SeldonDeployment

the Seldon Core deployment resource to update

required
poll_timeout int

the maximum time to wait for the deployment to become available or to fail. If set to 0, the function will return immediately without checking the deployment status. If a timeout occurs and the deployment is still pending creation, it will be returned anyway and no exception will be raised.

0

Returns:

Type Description
SeldonDeployment

the updated Seldon Core deployment resource with updated status.

Exceptions:

Type Description
SeldonClientError

if an unknown error occurs while updating the deployment.

Source code in zenml/integrations/seldon/seldon_client.py
def update_deployment(
    self,
    deployment: SeldonDeployment,
    poll_timeout: int = 0,
) -> SeldonDeployment:
    """Update a Seldon Core deployment resource.

    Args:
        deployment: the Seldon Core deployment resource to update
        poll_timeout: the maximum time to wait for the deployment to become
            available or to fail. If set to 0, the function will return
            immediately without checking the deployment status. If a timeout
            occurs and the deployment is still pending creation, it will
            be returned anyway and no exception will be raised.

    Returns:
        the updated Seldon Core deployment resource with updated status.

    Raises:
        SeldonClientError: if an unknown error occurs while updating the
            deployment.
    """
    try:
        logger.debug(
            f"Updating SeldonDeployment resource: {deployment.name}"
        )

        # mark the deployment as managed by ZenML, to differentiate
        # between deployments that are created by ZenML and those that
        # are not
        deployment.mark_as_managed_by_zenml()

        # call `get_deployment` to check that the deployment exists
        # and is managed by ZenML. It will raise
        # a SeldonDeploymentNotFoundError otherwise
        self.get_deployment(name=deployment.name)

        response = self._custom_objects_api.patch_namespaced_custom_object(
            group="machinelearning.seldon.io",
            version="v1",
            namespace=self._namespace,
            plural="seldondeployments",
            name=deployment.name,
            body=deployment.dict(exclude_none=True),
            _request_timeout=poll_timeout or None,
        )
        logger.debug("Seldon Core API response: %s", response)
    except k8s_client.rest.ApiException as e:
        logger.error(
            "Exception when updating SeldonDeployment resource: %s", str(e)
        )
        raise SeldonClientError(
            "Exception when creating SeldonDeployment resource"
        ) from e

    updated_deployment = self.get_deployment(name=deployment.name)

    while poll_timeout > 0 and updated_deployment.is_pending():
        time.sleep(5)
        poll_timeout -= 5
        updated_deployment = self.get_deployment(name=deployment.name)

    return updated_deployment
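For example, a sketch of patching an existing deployment's labels and waiting for the change to settle (the label key and value are placeholders):

deployment.metadata.labels["stage"] = "production"  # hypothetical label change
updated = client.update_deployment(deployment, poll_timeout=300)
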
SeldonClientError (Exception)

Base exception class for all exceptions raised by the SeldonClient.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonClientError(Exception):
    """Base exception class for all exceptions raised by the SeldonClient."""
SeldonClientTimeout (SeldonClientError)

Raised when the Seldon client timed out while waiting for a resource to reach the expected status.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonClientTimeout(SeldonClientError):
    """Raised when the Seldon client timed out while waiting for a resource to reach the expected status."""
SeldonDeployment (BaseModel) pydantic-model

A Seldon Core deployment CRD.

This is a Pydantic representation of some of the fields in the Seldon Core CRD (documented here: https://docs.seldon.io/projects/seldon-core/en/latest/reference/seldon-deployment.html).

Note that not all fields are represented, only those that are relevant to the ZenML integration. The fields that are not represented are silently ignored when the Seldon Deployment is created or updated from an external SeldonDeployment CRD representation.

Attributes:

Name Type Description
kind str

Kubernetes kind field.

apiVersion str

Kubernetes apiVersion field.

metadata SeldonDeploymentMetadata

Kubernetes metadata field.

spec SeldonDeploymentSpec

Seldon Deployment spec entry.

status Optional[zenml.integrations.seldon.seldon_client.SeldonDeploymentStatus]

Seldon Deployment status.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeployment(BaseModel):
    """A Seldon Core deployment CRD.

    This is a Pydantic representation of some of the fields in the Seldon Core
    CRD (documented here:
    https://docs.seldon.io/projects/seldon-core/en/latest/reference/seldon-deployment.html).

    Note that not all fields are represented, only those that are relevant to
    the ZenML integration. The fields that are not represented are silently
    ignored when the Seldon Deployment is created or updated from an external
    SeldonDeployment CRD representation.

    Attributes:
        kind: Kubernetes kind field.
        apiVersion: Kubernetes apiVersion field.
        metadata: Kubernetes metadata field.
        spec: Seldon Deployment spec entry.
        status: Seldon Deployment status.
    """

    kind: str = Field(SELDON_DEPLOYMENT_KIND, const=True)
    apiVersion: str = Field(SELDON_DEPLOYMENT_API_VERSION, const=True)
    metadata: SeldonDeploymentMetadata = Field(
        default_factory=SeldonDeploymentMetadata
    )
    spec: SeldonDeploymentSpec = Field(default_factory=SeldonDeploymentSpec)
    status: Optional[SeldonDeploymentStatus]

    def __str__(self) -> str:
        """Returns a string representation of the Seldon Deployment.

        Returns:
            A string representation of the Seldon Deployment.
        """
        return json.dumps(self.dict(exclude_none=True), indent=4)

    @classmethod
    def build(
        cls,
        name: Optional[str] = None,
        model_uri: Optional[str] = None,
        model_name: Optional[str] = None,
        implementation: Optional[str] = None,
        secret_name: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
        annotations: Optional[Dict[str, str]] = None,
    ) -> "SeldonDeployment":
        """Build a basic Seldon Deployment object.

        Args:
            name: The name of the Seldon Deployment. If not explicitly passed,
                a unique name is autogenerated.
            model_uri: The URI of the model.
            model_name: The name of the model.
            implementation: The implementation of the model.
            secret_name: The name of the Kubernetes secret containing
                environment variable values (e.g. with credentials for the
                artifact store) to use with the deployment service.
            labels: A dictionary of labels to apply to the Seldon Deployment.
            annotations: A dictionary of annotations to apply to the Seldon
                Deployment.

        Returns:
            A minimal SeldonDeployment object built from the provided
            parameters.
        """
        if not name:
            name = f"zenml-{time.time()}"

        if labels is None:
            labels = {}
        if annotations is None:
            annotations = {}

        return SeldonDeployment(
            metadata=SeldonDeploymentMetadata(
                name=name, labels=labels, annotations=annotations
            ),
            spec=SeldonDeploymentSpec(
                name=name,
                predictors=[
                    SeldonDeploymentPredictor(
                        name=model_name or "",
                        graph=SeldonDeploymentPredictiveUnit(
                            name="default",
                            type=SeldonDeploymentPredictiveUnitType.MODEL,
                            modelUri=model_uri or "",
                            implementation=implementation or "",
                            envSecretRefName=secret_name,
                        ),
                    )
                ],
            ),
        )

    def is_managed_by_zenml(self) -> bool:
        """Checks if this Seldon Deployment is managed by ZenML.

        The convention used to differentiate between SeldonDeployment instances
        that are managed by ZenML and those that are not is to set the `app`
        label value to `zenml`.

        Returns:
            True if the Seldon Deployment is managed by ZenML, False
            otherwise.
        """
        return self.metadata.labels.get("app") == "zenml"

    def mark_as_managed_by_zenml(self) -> None:
        """Marks this Seldon Deployment as managed by ZenML.

        The convention used to differentiate between SeldonDeployment instances
        that are managed by ZenML and those that are not is to set the `app`
        label value to `zenml`.
        """
        self.metadata.labels["app"] = "zenml"

    @property
    def name(self) -> str:
        """Returns the name of this Seldon Deployment.

        This is just a shortcut for `self.metadata.name`.

        Returns:
            The name of this Seldon Deployment.
        """
        return self.metadata.name

    @property
    def state(self) -> SeldonDeploymentStatusState:
        """The state of the Seldon Deployment.

        Returns:
            The state of the Seldon Deployment.
        """
        if not self.status:
            return SeldonDeploymentStatusState.UNKNOWN
        return self.status.state

    def is_pending(self) -> bool:
        """Checks if the Seldon Deployment is in a pending state.

        Returns:
            True if the Seldon Deployment is pending, False otherwise.
        """
        return self.state == SeldonDeploymentStatusState.CREATING

    def is_available(self) -> bool:
        """Checks if the Seldon Deployment is in an available state.

        Returns:
            True if the Seldon Deployment is available, False otherwise.
        """
        return self.state == SeldonDeploymentStatusState.AVAILABLE

    def is_failed(self) -> bool:
        """Checks if the Seldon Deployment is in a failed state.

        Returns:
            True if the Seldon Deployment is failed, False otherwise.
        """
        return self.state == SeldonDeploymentStatusState.FAILED

    def get_error(self) -> Optional[str]:
        """Get a message describing the error, if in an error state.

        Returns:
            A message describing the error, if in an error state, otherwise
            None.
        """
        if self.status and self.is_failed():
            return self.status.description
        return None

    def get_pending_message(self) -> Optional[str]:
        """Get a message describing the pending conditions of the Seldon Deployment.

        Returns:
            A message describing the pending condition of the Seldon
            Deployment, or None, if no conditions are pending.
        """
        if not self.status or not self.status.conditions:
            return None
        ready_condition_message = [
            c.message
            for c in self.status.conditions
            if c.type == "Ready" and not c.status
        ]
        if not ready_condition_message:
            return None
        return ready_condition_message[0]

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
name: str property readonly

Returns the name of this Seldon Deployment.

This is just a shortcut for self.metadata.name.

Returns:

Type Description
str

The name of this Seldon Deployment.

state: SeldonDeploymentStatusState property readonly

The state of the Seldon Deployment.

Returns:

Type Description
SeldonDeploymentStatusState

The state of the Seldon Deployment.

Config

Pydantic configuration class.

Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic configuration class."""

    # validate attribute assignments
    validate_assignment = True
    # Ignore extra attributes from the CRD that are not reflected here
    extra = "ignore"
__str__(self) special

Returns a string representation of the Seldon Deployment.

Returns:

Type Description
str

A string representation of the Seldon Deployment.

Source code in zenml/integrations/seldon/seldon_client.py
def __str__(self) -> str:
    """Returns a string representation of the Seldon Deployment.

    Returns:
        A string representation of the Seldon Deployment.
    """
    return json.dumps(self.dict(exclude_none=True), indent=4)
build(name=None, model_uri=None, model_name=None, implementation=None, secret_name=None, labels=None, annotations=None) classmethod

Build a basic Seldon Deployment object.

Parameters:

Name Type Description Default
name Optional[str]

The name of the Seldon Deployment. If not explicitly passed, a unique name is autogenerated.

None
model_uri Optional[str]

The URI of the model.

None
model_name Optional[str]

The name of the model.

None
implementation Optional[str]

The implementation of the model.

None
secret_name Optional[str]

The name of the Kubernetes secret containing environment variable values (e.g. with credentials for the artifact store) to use with the deployment service.

None
labels Optional[Dict[str, str]]

A dictionary of labels to apply to the Seldon Deployment.

None
annotations Optional[Dict[str, str]]

A dictionary of annotations to apply to the Seldon Deployment.

None

Returns:

Type Description
SeldonDeployment

A minimal SeldonDeployment object built from the provided parameters.

Source code in zenml/integrations/seldon/seldon_client.py
@classmethod
def build(
    cls,
    name: Optional[str] = None,
    model_uri: Optional[str] = None,
    model_name: Optional[str] = None,
    implementation: Optional[str] = None,
    secret_name: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    annotations: Optional[Dict[str, str]] = None,
) -> "SeldonDeployment":
    """Build a basic Seldon Deployment object.

    Args:
        name: The name of the Seldon Deployment. If not explicitly passed,
            a unique name is autogenerated.
        model_uri: The URI of the model.
        model_name: The name of the model.
        implementation: The implementation of the model.
        secret_name: The name of the Kubernetes secret containing
            environment variable values (e.g. with credentials for the
            artifact store) to use with the deployment service.
        labels: A dictionary of labels to apply to the Seldon Deployment.
        annotations: A dictionary of annotations to apply to the Seldon
            Deployment.

    Returns:
        A minimal SeldonDeployment object built from the provided
        parameters.
    """
    if not name:
        name = f"zenml-{time.time()}"

    if labels is None:
        labels = {}
    if annotations is None:
        annotations = {}

    return SeldonDeployment(
        metadata=SeldonDeploymentMetadata(
            name=name, labels=labels, annotations=annotations
        ),
        spec=SeldonDeploymentSpec(
            name=name,
            predictors=[
                SeldonDeploymentPredictor(
                    name=model_name or "",
                    graph=SeldonDeploymentPredictiveUnit(
                        name="default",
                        type=SeldonDeploymentPredictiveUnitType.MODEL,
                        modelUri=model_uri or "",
                        implementation=implementation or "",
                        envSecretRefName=secret_name,
                    ),
                )
            ],
        ),
    )
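
A short sketch of build in isolation (the URIs and names are hypothetical):

from zenml.integrations.seldon.seldon_client import SeldonDeployment

deployment = SeldonDeployment.build(
    name="sklearn-demo",
    model_uri="s3://my-bucket/sklearn-model",  # hypothetical URI
    model_name="demo-model",
    implementation="SKLEARN_SERVER",
    labels={"team": "ml"},
)
print(deployment.name)  # "sklearn-demo", via the metadata shortcut
print(deployment)       # __str__ renders the CRD as indented JSON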
get_error(self)

Get a message describing the error, if in an error state.

Returns:

Type Description
Optional[str]

A message describing the error, if in an error state, otherwise None.

Source code in zenml/integrations/seldon/seldon_client.py
def get_error(self) -> Optional[str]:
    """Get a message describing the error, if in an error state.

    Returns:
        A message describing the error, if in an error state, otherwise
        None.
    """
    if self.status and self.is_failed():
        return self.status.description
    return None
get_pending_message(self)

Get a message describing the pending conditions of the Seldon Deployment.

Returns:

Type Description
Optional[str]

A message describing the pending condition of the Seldon Deployment, or None, if no conditions are pending.

Source code in zenml/integrations/seldon/seldon_client.py
def get_pending_message(self) -> Optional[str]:
    """Get a message describing the pending conditions of the Seldon Deployment.

    Returns:
        A message describing the pending condition of the Seldon
        Deployment, or None, if no conditions are pending.
    """
    if not self.status or not self.status.conditions:
        return None
    ready_condition_message = [
        c.message
        for c in self.status.conditions
        if c.type == "Ready" and not c.status
    ]
    if not ready_condition_message:
        return None
    return ready_condition_message[0]
is_available(self)

Checks if the Seldon Deployment is in an available state.

Returns:

Type Description
bool

True if the Seldon Deployment is available, False otherwise.

Source code in zenml/integrations/seldon/seldon_client.py
def is_available(self) -> bool:
    """Checks if the Seldon Deployment is in an available state.

    Returns:
        True if the Seldon Deployment is available, False otherwise.
    """
    return self.state == SeldonDeploymentStatusState.AVAILABLE
is_failed(self)

Checks if the Seldon Deployment is in a failed state.

Returns:

Type Description
bool

True if the Seldon Deployment is failed, False otherwise.

Source code in zenml/integrations/seldon/seldon_client.py
def is_failed(self) -> bool:
    """Checks if the Seldon Deployment is in a failed state.

    Returns:
        True if the Seldon Deployment is failed, False otherwise.
    """
    return self.state == SeldonDeploymentStatusState.FAILED
is_managed_by_zenml(self)

Checks if this Seldon Deployment is managed by ZenML.

The convention used to differentiate between SeldonDeployment instances that are managed by ZenML and those that are not is to set the app label value to zenml.

Returns:

Type Description
bool

True if the Seldon Deployment is managed by ZenML, False otherwise.

Source code in zenml/integrations/seldon/seldon_client.py
def is_managed_by_zenml(self) -> bool:
    """Checks if this Seldon Deployment is managed by ZenML.

    The convention used to differentiate between SeldonDeployment instances
    that are managed by ZenML and those that are not is to set the `app`
    label value to `zenml`.

    Returns:
        True if the Seldon Deployment is managed by ZenML, False
        otherwise.
    """
    return self.metadata.labels.get("app") == "zenml"
is_pending(self)

Checks if the Seldon Deployment is in a pending state.

Returns:

Type Description
bool

True if the Seldon Deployment is pending, False otherwise.

Source code in zenml/integrations/seldon/seldon_client.py
def is_pending(self) -> bool:
    """Checks if the Seldon Deployment is in a pending state.

    Returns:
        True if the Seldon Deployment is pending, False otherwise.
    """
    return self.state == SeldonDeploymentStatusState.CREATING
mark_as_managed_by_zenml(self)

Marks this Seldon Deployment as managed by ZenML.

The convention used to differentiate between SeldonDeployment instances that are managed by ZenML and those that are not is to set the app label value to zenml.

Source code in zenml/integrations/seldon/seldon_client.py
def mark_as_managed_by_zenml(self) -> None:
    """Marks this Seldon Deployment as managed by ZenML.

    The convention used to differentiate between SeldonDeployment instances
    that are managed by ZenML and those that are not is to set the `app`
    label value to `zenml`.
    """
    self.metadata.labels["app"] = "zenml"
SeldonDeploymentExistsError (SeldonClientError)

Raised when a SeldonDeployment resource cannot be created because a resource with the same name already exists.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentExistsError(SeldonClientError):
    """Raised when a SeldonDeployment resource cannot be created because a resource with the same name already exists."""
SeldonDeploymentMetadata (BaseModel) pydantic-model

Metadata for a Seldon Deployment.

Attributes:

Name Type Description
name str

the name of the Seldon Deployment.

labels Dict[str, str]

Kubernetes labels for the Seldon Deployment.

annotations Dict[str, str]

Kubernetes annotations for the Seldon Deployment.

creationTimestamp Optional[str]

the creation timestamp of the Seldon Deployment.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentMetadata(BaseModel):
    """Metadata for a Seldon Deployment.

    Attributes:
        name: the name of the Seldon Deployment.
        labels: Kubernetes labels for the Seldon Deployment.
        annotations: Kubernetes annotations for the Seldon Deployment.
        creationTimestamp: the creation timestamp of the Seldon Deployment.
    """

    name: str
    labels: Dict[str, str] = Field(default_factory=dict)
    annotations: Dict[str, str] = Field(default_factory=dict)
    creationTimestamp: Optional[str]

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config

Pydantic configuration class.

Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic configuration class."""

    # validate attribute assignments
    validate_assignment = True
    # Ignore extra attributes from the CRD that are not reflected here
    extra = "ignore"
SeldonDeploymentNotFoundError (SeldonClientError)

Raised when a particular SeldonDeployment resource is not found or is not managed by ZenML.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentNotFoundError(SeldonClientError):
    """Raised when a particular SeldonDeployment resource is not found or is not managed by ZenML."""
SeldonDeploymentPredictiveUnit (BaseModel) pydantic-model

Seldon Deployment predictive unit.

Attributes:

Name Type Description
name str

the name of the predictive unit.

type Optional[zenml.integrations.seldon.seldon_client.SeldonDeploymentPredictiveUnitType]

predictive unit type.

implementation Optional[str]

the Seldon Core implementation used to serve the model.

modelUri Optional[str]

URI of the model (or models) to serve.

serviceAccountName Optional[str]

the name of the service account to associate with the predictive unit container.

envSecretRefName Optional[str]

the name of a Kubernetes secret that contains environment variables (e.g. credentials) to be configured for the predictive unit container.

children List[zenml.integrations.seldon.seldon_client.SeldonDeploymentPredictiveUnit]

a list of child predictive units that together make up the model serving graph.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentPredictiveUnit(BaseModel):
    """Seldon Deployment predictive unit.

    Attributes:
        name: the name of the predictive unit.
        type: predictive unit type.
        implementation: the Seldon Core implementation used to serve the model.
        modelUri: URI of the model (or models) to serve.
        serviceAccountName: the name of the service account to associate with
            the predictive unit container.
        envSecretRefName: the name of a Kubernetes secret that contains
            environment variables (e.g. credentials) to be configured for the
            predictive unit container.
        children: a list of child predictive units that together make up the
            model serving graph.
    """

    name: str
    type: Optional[
        SeldonDeploymentPredictiveUnitType
    ] = SeldonDeploymentPredictiveUnitType.MODEL
    implementation: Optional[str]
    modelUri: Optional[str]
    serviceAccountName: Optional[str]
    envSecretRefName: Optional[str]
    children: List["SeldonDeploymentPredictiveUnit"] = Field(
        default_factory=list
    )

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config

Pydantic configuration class.

Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic configuration class."""

    # validate attribute assignments
    validate_assignment = True
    # Ignore extra attributes from the CRD that are not reflected here
    extra = "ignore"
SeldonDeploymentPredictiveUnitType (StrEnum)

Predictive unit types for a Seldon Deployment.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentPredictiveUnitType(StrEnum):
    """Predictive unit types for a Seldon Deployment."""

    UNKNOWN_TYPE = "UNKNOWN_TYPE"
    ROUTER = "ROUTER"
    COMBINER = "COMBINER"
    MODEL = "MODEL"
    TRANSFORMER = "TRANSFORMER"
    OUTPUT_TRANSFORMER = "OUTPUT_TRANSFORMER"
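
Predictive units can be nested through the children field to form a serving graph. A minimal sketch, assuming a transformer that pre-processes requests before a MODEL unit (all names and URIs are hypothetical):

from zenml.integrations.seldon.seldon_client import (
    SeldonDeploymentPredictiveUnit,
    SeldonDeploymentPredictiveUnitType,
)

graph = SeldonDeploymentPredictiveUnit(
    name="preprocessor",
    type=SeldonDeploymentPredictiveUnitType.TRANSFORMER,
    children=[
        SeldonDeploymentPredictiveUnit(
            name="classifier",
            type=SeldonDeploymentPredictiveUnitType.MODEL,
            modelUri="gs://my-bucket/classifier",  # hypothetical URI
            implementation="SKLEARN_SERVER",
        )
    ],
)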
SeldonDeploymentPredictor (BaseModel) pydantic-model

Seldon Deployment predictor.

Attributes:

Name Type Description
name str

the name of the predictor.

replicas int

the number of pod replicas for the predictor.

graph SeldonDeploymentPredictiveUnit

the serving graph composed of one or more predictive units.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentPredictor(BaseModel):
    """Seldon Deployment predictor.

    Attributes:
        name: the name of the predictor.
        replicas: the number of pod replicas for the predictor.
        graph: the serving graph composed of one or more predictive units.
    """

    name: str
    replicas: int = 1
    graph: SeldonDeploymentPredictiveUnit = Field(
        default_factory=SeldonDeploymentPredictiveUnit
    )

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config

Pydantic configuration class.

Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic configuration class."""

    # validate attribute assignments
    validate_assignment = True
    # Ignore extra attributes from the CRD that are not reflected here
    extra = "ignore"
SeldonDeploymentSpec (BaseModel) pydantic-model

Spec for a Seldon Deployment.

Attributes:

Name Type Description
name str

the name of the Seldon Deployment.

protocol Optional[str]

the API protocol used for the Seldon Deployment.

predictors List[zenml.integrations.seldon.seldon_client.SeldonDeploymentPredictor]

a list of predictors that make up the serving graph.

replicas int

the default number of pod replicas used for the predictors.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentSpec(BaseModel):
    """Spec for a Seldon Deployment.

    Attributes:
        name: the name of the Seldon Deployment.
        protocol: the API protocol used for the Seldon Deployment.
        predictors: a list of predictors that make up the serving graph.
        replicas: the default number of pod replicas used for the predictors.
    """

    name: str
    protocol: Optional[str]
    predictors: List[SeldonDeploymentPredictor]
    replicas: int = 1

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config

Pydantic configuration class.

Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic configuration class."""

    # validate attribute assignments
    validate_assignment = True
    # Ignore extra attributes from the CRD that are not reflected here
    extra = "ignore"
SeldonDeploymentStatus (BaseModel) pydantic-model

The status of a Seldon Deployment.

Attributes:

Name Type Description
state SeldonDeploymentStatusState

the current state of the Seldon Deployment.

description Optional[str]

a human-readable description of the current state.

replicas Optional[int]

the current number of running pod replicas

address Optional[zenml.integrations.seldon.seldon_client.SeldonDeploymentStatusAddress]

the address where the Seldon Deployment API can be accessed.

conditions List[zenml.integrations.seldon.seldon_client.SeldonDeploymentStatusCondition]

the list of Kubernetes conditions for the Seldon Deployment.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatus(BaseModel):
    """The status of a Seldon Deployment.

    Attributes:
        state: the current state of the Seldon Deployment.
        description: a human-readable description of the current state.
        replicas: the current number of running pod replicas
        address: the address where the Seldon Deployment API can be accessed.
        conditions: the list of Kubernetes conditions for the Seldon Deployment.
    """

    state: SeldonDeploymentStatusState = SeldonDeploymentStatusState.UNKNOWN
    description: Optional[str]
    replicas: Optional[int]
    address: Optional[SeldonDeploymentStatusAddress]
    conditions: List[SeldonDeploymentStatusCondition]

    class Config:
        """Pydantic configuration class."""

        # validate attribute assignments
        validate_assignment = True
        # Ignore extra attributes from the CRD that are not reflected here
        extra = "ignore"
Config

Pydantic configuration class.

Source code in zenml/integrations/seldon/seldon_client.py
class Config:
    """Pydantic configuration class."""

    # validate attribute assignments
    validate_assignment = True
    # Ignore extra attributes from the CRD that are not reflected here
    extra = "ignore"
SeldonDeploymentStatusAddress (BaseModel) pydantic-model

The status address for a Seldon Deployment.

Attributes:

Name Type Description
url str

the URL where the Seldon Deployment API can be accessed internally.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatusAddress(BaseModel):
    """The status address for a Seldon Deployment.

    Attributes:
        url: the URL where the Seldon Deployment API can be accessed internally.
    """

    url: str
SeldonDeploymentStatusCondition (BaseModel) pydantic-model

The Kubernetes status condition entry for a Seldon Deployment.

Attributes:

Name Type Description
type str

Type of runtime condition.

status bool

Status of the condition.

reason Optional[str]

Brief CamelCase string containing reason for the condition's last transition.

message Optional[str]

Human-readable message indicating details about last transition.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatusCondition(BaseModel):
    """The Kubernetes status condition entry for a Seldon Deployment.

    Attributes:
        type: Type of runtime condition.
        status: Status of the condition.
        reason: Brief CamelCase string containing reason for the condition's
            last transition.
        message: Human-readable message indicating details about last
            transition.
    """

    type: str
    status: bool
    reason: Optional[str]
    message: Optional[str]
SeldonDeploymentStatusState (StrEnum)

Possible state values for a Seldon Deployment.

Source code in zenml/integrations/seldon/seldon_client.py
class SeldonDeploymentStatusState(StrEnum):
    """Possible state values for a Seldon Deployment."""

    UNKNOWN = "Unknown"
    AVAILABLE = "Available"
    CREATING = "Creating"
    FAILED = "Failed"

services special

Initialization for Seldon services.

seldon_deployment

Implementation for the Seldon Deployment step.

SeldonDeploymentConfig (ServiceConfig) pydantic-model

Seldon Core deployment service configuration.

Attributes:

Name Type Description
model_uri str

URI of the model (or models) to serve.

model_name str

the name of the model. Multiple versions of the same model should use the same model name.

implementation str

the Seldon Core implementation used to serve the model.

replicas int

number of replicas to use for the prediction service.

secret_name Optional[str]

the name of a Kubernetes secret containing additional configuration parameters for the Seldon Core deployment (e.g. credentials to access the Artifact Store).

model_metadata Dict[str, Any]

optional model metadata information (see https://docs.seldon.io/projects/seldon-core/en/latest/reference/apis/metadata.html).

extra_args Dict[str, Any]

additional arguments to pass to the Seldon Core deployment resource configuration.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
class SeldonDeploymentConfig(ServiceConfig):
    """Seldon Core deployment service configuration.

    Attributes:
        model_uri: URI of the model (or models) to serve.
        model_name: the name of the model. Multiple versions of the same model
            should use the same model name.
        implementation: the Seldon Core implementation used to serve the model.
        replicas: number of replicas to use for the prediction service.
        secret_name: the name of a Kubernetes secret containing additional
            configuration parameters for the Seldon Core deployment (e.g.
            credentials to access the Artifact Store).
        model_metadata: optional model metadata information (see
            https://docs.seldon.io/projects/seldon-core/en/latest/reference/apis/metadata.html).
        extra_args: additional arguments to pass to the Seldon Core deployment
            resource configuration.
    """

    model_uri: str = ""
    model_name: str = "default"
    # TODO [ENG-775]: have an enum of all supported Seldon Core implementations
    implementation: str
    replicas: int = 1
    secret_name: Optional[str]
    model_metadata: Dict[str, Any] = Field(default_factory=dict)
    extra_args: Dict[str, Any] = Field(default_factory=dict)

    def get_seldon_deployment_labels(self) -> Dict[str, str]:
        """Generate labels for the Seldon Core deployment from the service configuration.

        These labels are attached to the Seldon Core deployment resource
        and may be used as label selectors in lookup operations.

        Returns:
            The labels for the Seldon Core deployment.
        """
        labels = {}
        if self.pipeline_name:
            labels["zenml.pipeline_name"] = self.pipeline_name
        if self.pipeline_run_id:
            labels["zenml.pipeline_run_id"] = self.pipeline_run_id
        if self.pipeline_step_name:
            labels["zenml.pipeline_step_name"] = self.pipeline_step_name
        if self.model_name:
            labels["zenml.model_name"] = self.model_name
        if self.model_uri:
            labels["zenml.model_uri"] = self.model_uri
        if self.implementation:
            labels["zenml.model_type"] = self.implementation
        SeldonClient.sanitize_labels(labels)
        return labels

    def get_seldon_deployment_annotations(self) -> Dict[str, str]:
        """Generate annotations for the Seldon Core deployment from the service configuration.

        The annotations are used to store additional information about the
        Seldon Core service associated with the deployment that is not
        available in the labels. One particularly important annotation is
        the serialized service configuration itself, which is used to
        recreate the service configuration from a remote Seldon deployment.

        Returns:
            The annotations for the Seldon Core deployment.
        """
        annotations = {
            "zenml.service_config": self.json(),
            "zenml.version": __version__,
        }
        return annotations

    @classmethod
    def create_from_deployment(
        cls, deployment: SeldonDeployment
    ) -> "SeldonDeploymentConfig":
        """Recreate the configuration of a Seldon Core Service from a deployed instance.

        Args:
            deployment: the Seldon Core deployment resource.

        Returns:
            The Seldon Core service configuration corresponding to the given
            Seldon Core deployment resource.

        Raises:
            ValueError: if the given deployment resource does not contain
                the expected annotations or it contains an invalid or
                incompatible Seldon Core service configuration.
        """
        config_data = deployment.metadata.annotations.get(
            "zenml.service_config"
        )
        if not config_data:
            raise ValueError(
                f"The given deployment resource does not contain a "
                f"'zenml.service_config' annotation: {deployment}"
            )
        try:
            service_config = cls.parse_raw(config_data)
        except ValidationError as e:
            raise ValueError(
                f"The loaded Seldon Core deployment resource contains an "
                f"invalid or incompatible Seldon Core service configuration: "
                f"{config_data}"
            ) from e
        return service_config
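
A minimal sketch of constructing such a configuration by hand (the URI and secret name are hypothetical, and the base ServiceConfig fields are assumed to have defaults):

from zenml.integrations.seldon.services.seldon_deployment import (
    SeldonDeploymentConfig,
)

config = SeldonDeploymentConfig(
    model_uri="s3://my-bucket/model",  # hypothetical model URI
    model_name="demo-model",
    implementation="SKLEARN_SERVER",
    replicas=2,
    secret_name="seldon-init-container-secret",  # hypothetical secret
)
print(config.get_seldon_deployment_labels())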
create_from_deployment(deployment) classmethod

Recreate the configuration of a Seldon Core Service from a deployed instance.

Parameters:

Name Type Description Default
deployment SeldonDeployment

the Seldon Core deployment resource.

required

Returns:

Type Description
SeldonDeploymentConfig

The Seldon Core service configuration corresponding to the given Seldon Core deployment resource.

Exceptions:

Type Description
ValueError

if the given deployment resource does not contain the expected annotations or it contains an invalid or incompatible Seldon Core service configuration.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
@classmethod
def create_from_deployment(
    cls, deployment: SeldonDeployment
) -> "SeldonDeploymentConfig":
    """Recreate the configuration of a Seldon Core Service from a deployed instance.

    Args:
        deployment: the Seldon Core deployment resource.

    Returns:
        The Seldon Core service configuration corresponding to the given
        Seldon Core deployment resource.

    Raises:
        ValueError: if the given deployment resource does not contain
            the expected annotations or it contains an invalid or
            incompatible Seldon Core service configuration.
    """
    config_data = deployment.metadata.annotations.get(
        "zenml.service_config"
    )
    if not config_data:
        raise ValueError(
            f"The given deployment resource does not contain a "
            f"'zenml.service_config' annotation: {deployment}"
        )
    try:
        service_config = cls.parse_raw(config_data)
    except ValidationError as e:
        raise ValueError(
            f"The loaded Seldon Core deployment resource contains an "
            f"invalid or incompatible Seldon Core service configuration: "
            f"{config_data}"
        ) from e
    return service_config
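
Together with get_seldon_deployment_annotations, this enables a round trip: a configuration serialized into the deployment's annotations can be recovered later. A sketch under that assumption (base ServiceConfig fields are assumed to have defaults):

from zenml.integrations.seldon.seldon_client import SeldonDeployment
from zenml.integrations.seldon.services.seldon_deployment import (
    SeldonDeploymentConfig,
)

config = SeldonDeploymentConfig(implementation="SKLEARN_SERVER")
deployment = SeldonDeployment.build(
    name="example",
    annotations=config.get_seldon_deployment_annotations(),
)

# the config is read back from the 'zenml.service_config' annotation
recovered = SeldonDeploymentConfig.create_from_deployment(deployment)
assert recovered.implementation == "SKLEARN_SERVER"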
get_seldon_deployment_annotations(self)

Generate annotations for the Seldon Core deployment from the service configuration.

The annotations are used to store additional information about the Seldon Core service associated with the deployment that is not available in the labels. One particularly important annotation is the serialized service configuration itself, which is used to recreate the service configuration from a remote Seldon deployment.

Returns:

Type Description
Dict[str, str]

The annotations for the Seldon Core deployment.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
def get_seldon_deployment_annotations(self) -> Dict[str, str]:
    """Generate annotations for the Seldon Core deployment from the service configuration.

    The annotations are used to store additional information about the
    Seldon Core service associated with the deployment that is not
    available in the labels. One particularly important annotation is
    the serialized service configuration itself, which is used to
    recreate the service configuration from a remote Seldon deployment.

    Returns:
        The annotations for the Seldon Core deployment.
    """
    annotations = {
        "zenml.service_config": self.json(),
        "zenml.version": __version__,
    }
    return annotations
get_seldon_deployment_labels(self)

Generate labels for the Seldon Core deployment from the service configuration.

These labels are attached to the Seldon Core deployment resource and may be used as label selectors in lookup operations.

Returns:

Type Description
Dict[str, str]

The labels for the Seldon Core deployment.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
def get_seldon_deployment_labels(self) -> Dict[str, str]:
    """Generate labels for the Seldon Core deployment from the service configuration.

    These labels are attached to the Seldon Core deployment resource
    and may be used as label selectors in lookup operations.

    Returns:
        The labels for the Seldon Core deployment.
    """
    labels = {}
    if self.pipeline_name:
        labels["zenml.pipeline_name"] = self.pipeline_name
    if self.pipeline_run_id:
        labels["zenml.pipeline_run_id"] = self.pipeline_run_id
    if self.pipeline_step_name:
        labels["zenml.pipeline_step_name"] = self.pipeline_step_name
    if self.model_name:
        labels["zenml.model_name"] = self.model_name
    if self.model_uri:
        labels["zenml.model_uri"] = self.model_uri
    if self.implementation:
        labels["zenml.model_type"] = self.implementation
    SeldonClient.sanitize_labels(labels)
    return labels
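
Since the labels are plain key/value pairs, they can be assembled into a Kubernetes-style label selector for lookup operations; a small sketch (model details are hypothetical):

from zenml.integrations.seldon.services.seldon_deployment import (
    SeldonDeploymentConfig,
)

config = SeldonDeploymentConfig(
    model_name="demo-model",
    implementation="SKLEARN_SERVER",
)
labels = config.get_seldon_deployment_labels()
# e.g. "zenml.model_name=demo-model,zenml.model_type=SKLEARN_SERVER"
selector = ",".join(f"{key}={value}" for key, value in labels.items())
print(selector)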
SeldonDeploymentService (BaseService) pydantic-model

A service that represents a Seldon Core deployment server.

Attributes:

Name Type Description
config SeldonDeploymentConfig

service configuration.

status SeldonDeploymentServiceStatus

service status.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
class SeldonDeploymentService(BaseService):
    """A service that represents a Seldon Core deployment server.

    Attributes:
        config: service configuration.
        status: service status.
    """

    SERVICE_TYPE = ServiceType(
        name="seldon-deployment",
        type="model-serving",
        flavor="seldon",
        description="Seldon Core prediction service",
    )

    config: SeldonDeploymentConfig = Field(
        default_factory=SeldonDeploymentConfig
    )
    status: SeldonDeploymentServiceStatus = Field(
        default_factory=SeldonDeploymentServiceStatus
    )

    def _get_client(self) -> SeldonClient:
        """Get the Seldon Core client from the active Seldon Core model deployer.

        Returns:
            The Seldon Core client.
        """
        from zenml.integrations.seldon.model_deployers.seldon_model_deployer import (
            SeldonModelDeployer,
        )

        model_deployer = SeldonModelDeployer.get_active_model_deployer()
        return model_deployer.seldon_client

    def check_status(self) -> Tuple[ServiceState, str]:
        """Check the the current operational state of the Seldon Core deployment.

        Returns:
            The operational state of the Seldon Core deployment and a message
            providing additional information about that state (e.g. a
            description of the error, if one is encountered).
        """
        client = self._get_client()
        name = self.seldon_deployment_name
        try:
            deployment = client.get_deployment(name=name)
        except SeldonDeploymentNotFoundError:
            return (ServiceState.INACTIVE, "")

        if deployment.is_available():
            return (
                ServiceState.ACTIVE,
                f"Seldon Core deployment '{name}' is available",
            )

        if deployment.is_failed():
            return (
                ServiceState.ERROR,
                f"Seldon Core deployment '{name}' failed: "
                f"{deployment.get_error()}",
            )

        pending_message = deployment.get_pending_message() or ""
        return (
            ServiceState.PENDING_STARTUP,
            "Seldon Core deployment is being created: " + pending_message,
        )

    @property
    def seldon_deployment_name(self) -> str:
        """Get the name of the Seldon Core deployment.

        The returned name uniquely corresponds to this service instance.

        Returns:
            The name of the Seldon Core deployment.
        """
        return f"zenml-{str(self.uuid)}"

    def _get_seldon_deployment_labels(self) -> Dict[str, str]:
        """Generate the labels for the Seldon Core deployment from the service configuration.

        Returns:
            The labels for the Seldon Core deployment.
        """
        labels = self.config.get_seldon_deployment_labels()
        labels["zenml.service_uuid"] = str(self.uuid)
        SeldonClient.sanitize_labels(labels)
        return labels

    @classmethod
    def create_from_deployment(
        cls, deployment: SeldonDeployment
    ) -> "SeldonDeploymentService":
        """Recreate a Seldon Core service from a Seldon Core deployment resource.

        It then updates the recreated service's operational status.

        Args:
            deployment: the Seldon Core deployment resource.

        Returns:
            The Seldon Core service corresponding to the given
            Seldon Core deployment resource.

        Raises:
            ValueError: if the given deployment resource does not contain
                the expected service_uuid label.
        """
        config = SeldonDeploymentConfig.create_from_deployment(deployment)
        uuid = deployment.metadata.labels.get("zenml.service_uuid")
        if not uuid:
            raise ValueError(
                f"The given deployment resource does not contain a valid "
                f"'zenml.service_uuid' label: {deployment}"
            )
        service = cls(uuid=UUID(uuid), config=config)
        service.update_status()
        return service

    def provision(self) -> None:
        """Provision or update remote Seldon Core deployment instance.

        This should then match the current configuration.
        """
        client = self._get_client()

        name = self.seldon_deployment_name

        deployment = SeldonDeployment.build(
            name=name,
            model_uri=self.config.model_uri,
            model_name=self.config.model_name,
            implementation=self.config.implementation,
            secret_name=self.config.secret_name,
            labels=self._get_seldon_deployment_labels(),
            annotations=self.config.get_seldon_deployment_annotations(),
        )
        deployment.spec.replicas = self.config.replicas
        deployment.spec.predictors[0].replicas = self.config.replicas

        # check if the Seldon deployment already exists
        try:
            client.get_deployment(name=name)
            # update the existing deployment
            client.update_deployment(deployment)
        except SeldonDeploymentNotFoundError:
            # create the deployment
            client.create_deployment(deployment=deployment)

    def deprovision(self, force: bool = False) -> None:
        """Deprovision the remote Seldon Core deployment instance.

        Args:
            force: if True, the remote deployment instance will be
                forcefully deprovisioned.
        """
        client = self._get_client()
        name = self.seldon_deployment_name
        try:
            client.delete_deployment(name=name, force=force)
        except SeldonDeploymentNotFoundError:
            pass

    def get_logs(
        self,
        follow: bool = False,
        tail: Optional[int] = None,
    ) -> Generator[str, bool, None]:
        """Get the logs of a Seldon Core model deployment.

        Args:
            follow: if True, the logs will be streamed as they are written
            tail: only retrieve the last NUM lines of log output.

        Returns:
            A generator that can be accessed to get the service logs.
        """
        return self._get_client().get_deployment_logs(
            self.seldon_deployment_name,
            follow=follow,
            tail=tail,
        )

    @property
    def prediction_url(self) -> Optional[str]:
        """The prediction URI exposed by the prediction service.

        Returns:
            The prediction URI exposed by the prediction service, or None if
            the service is not yet ready.
        """
        from zenml.integrations.seldon.model_deployers.seldon_model_deployer import (
            SeldonModelDeployer,
        )

        if not self.is_running:
            return None
        namespace = self._get_client().namespace
        model_deployer = SeldonModelDeployer.get_active_model_deployer()
        return os.path.join(
            model_deployer.base_url,
            "seldon",
            namespace,
            self.seldon_deployment_name,
            "api/v0.1/predictions",
        )

    def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
        """Make a prediction using the service.

        Args:
            request: a numpy array representing the request

        Returns:
            A numpy array representing the prediction returned by the service.

        Raises:
            Exception: if the service is not yet ready.
            ValueError: if the prediction_url is not set.
        """
        if not self.is_running:
            raise Exception(
                "Seldon prediction service is not running. "
                "Please start the service before making predictions."
            )

        if self.prediction_url is None:
            raise ValueError("`self.prediction_url` is not set, cannot post.")
        response = requests.post(
            self.prediction_url,
            json={"data": {"ndarray": request.tolist()}},
        )
        response.raise_for_status()
        return np.array(response.json()["data"]["ndarray"])
prediction_url: Optional[str] property readonly

The prediction URI exposed by the prediction service.

Returns:

Type Description
Optional[str]

The prediction URI exposed by the prediction service, or None if the service is not yet ready.

seldon_deployment_name: str property readonly

Get the name of the Seldon Core deployment.

The returned name uniquely corresponds to this service instance.

Returns:

Type Description
str

The name of the Seldon Core deployment.

check_status(self)

Check the current operational state of the Seldon Core deployment.

Returns:

Type Description
Tuple[zenml.services.service_status.ServiceState, str]

The operational state of the Seldon Core deployment and a message providing additional information about that state (e.g. a description of the error, if one is encountered).

Source code in zenml/integrations/seldon/services/seldon_deployment.py
def check_status(self) -> Tuple[ServiceState, str]:
    """Check the the current operational state of the Seldon Core deployment.

    Returns:
        The operational state of the Seldon Core deployment and a message
        providing additional information about that state (e.g. a
        description of the error, if one is encountered).
    """
    client = self._get_client()
    name = self.seldon_deployment_name
    try:
        deployment = client.get_deployment(name=name)
    except SeldonDeploymentNotFoundError:
        return (ServiceState.INACTIVE, "")

    if deployment.is_available():
        return (
            ServiceState.ACTIVE,
            f"Seldon Core deployment '{name}' is available",
        )

    if deployment.is_failed():
        return (
            ServiceState.ERROR,
            f"Seldon Core deployment '{name}' failed: "
            f"{deployment.get_error()}",
        )

    pending_message = deployment.get_pending_message() or ""
    return (
        ServiceState.PENDING_STARTUP,
        "Seldon Core deployment is being created: " + pending_message,
    )
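
A hedged polling sketch built on check_status (the service variable is assumed to be an existing SeldonDeploymentService instance):

import time

from zenml.services.service_status import ServiceState

state, message = service.check_status()  # `service` is assumed to exist
while state == ServiceState.PENDING_STARTUP:
    time.sleep(10)
    state, message = service.check_status()
print(state, message)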
create_from_deployment(deployment) classmethod

Recreate a Seldon Core service from a Seldon Core deployment resource.

It then updates the recreated service's operational status.

Parameters:

Name Type Description Default
deployment SeldonDeployment

the Seldon Core deployment resource.

required

Returns:

Type Description
SeldonDeploymentService

The Seldon Core service corresponding to the given Seldon Core deployment resource.

Exceptions:

Type Description
ValueError

if the given deployment resource does not contain the expected service_uuid label.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
@classmethod
def create_from_deployment(
    cls, deployment: SeldonDeployment
) -> "SeldonDeploymentService":
    """Recreate a Seldon Core service from a Seldon Core deployment resource.

    It then updates the recreated service's operational status.

    Args:
        deployment: the Seldon Core deployment resource.

    Returns:
        The Seldon Core service corresponding to the given
        Seldon Core deployment resource.

    Raises:
        ValueError: if the given deployment resource does not contain
            the expected service_uuid label.
    """
    config = SeldonDeploymentConfig.create_from_deployment(deployment)
    uuid = deployment.metadata.labels.get("zenml.service_uuid")
    if not uuid:
        raise ValueError(
            f"The given deployment resource does not contain a valid "
            f"'zenml.service_uuid' label: {deployment}"
        )
    service = cls(uuid=UUID(uuid), config=config)
    service.update_status()
    return service
deprovision(self, force=False)

Deprovision the remote Seldon Core deployment instance.

Parameters:

Name Type Description Default
force bool

if True, the remote deployment instance will be forcefully deprovisioned.

False
Source code in zenml/integrations/seldon/services/seldon_deployment.py
def deprovision(self, force: bool = False) -> None:
    """Deprovision the remote Seldon Core deployment instance.

    Args:
        force: if True, the remote deployment instance will be
            forcefully deprovisioned.
    """
    client = self._get_client()
    name = self.seldon_deployment_name
    try:
        client.delete_deployment(name=name, force=force)
    except SeldonDeploymentNotFoundError:
        pass
get_logs(self, follow=False, tail=None)

Get the logs of a Seldon Core model deployment.

Parameters:

Name Type Description Default
follow bool

if True, the logs will be streamed as they are written

False
tail Optional[int]

only retrieve the last NUM lines of log output.

None

Returns:

Type Description
Generator[str, bool, NoneType]

A generator that can be accessed to get the service logs.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
def get_logs(
    self,
    follow: bool = False,
    tail: Optional[int] = None,
) -> Generator[str, bool, None]:
    """Get the logs of a Seldon Core model deployment.

    Args:
        follow: if True, the logs will be streamed as they are written
        tail: only retrieve the last NUM lines of log output.

    Returns:
        A generator that can be accessed to get the service logs.
    """
    return self._get_client().get_deployment_logs(
        self.seldon_deployment_name,
        follow=follow,
        tail=tail,
    )
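
A small usage sketch (the service variable is assumed to be an existing SeldonDeploymentService instance):

# print the last 100 log lines without following the stream
for line in service.get_logs(tail=100):
    print(line)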
predict(self, request)

Make a prediction using the service.

Parameters:

Name Type Description Default
request NDArray[Any]

a numpy array representing the request

required

Returns:

Type Description
NDArray[Any]

A numpy array representing the prediction returned by the service.

Exceptions:

Type Description
Exception

if the service is not yet ready.

ValueError

if the prediction_url is not set.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
    """Make a prediction using the service.

    Args:
        request: a numpy array representing the request

    Returns:
        A numpy array representing the prediction returned by the service.

    Raises:
        Exception: if the service is not yet ready.
        ValueError: if the prediction_url is not set.
    """
    if not self.is_running:
        raise Exception(
            "Seldon prediction service is not running. "
            "Please start the service before making predictions."
        )

    if self.prediction_url is None:
        raise ValueError("`self.prediction_url` is not set, cannot post.")
    response = requests.post(
        self.prediction_url,
        json={"data": {"ndarray": request.tolist()}},
    )
    response.raise_for_status()
    return np.array(response.json()["data"]["ndarray"])
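
A minimal prediction sketch (the service variable is assumed to be a running SeldonDeploymentService; the request shape depends on the served model):

import numpy as np

request = np.array([[5.1, 3.5, 1.4, 0.2]])  # hypothetical feature row
prediction = service.predict(request)
print(prediction)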
provision(self)

Provision or update the remote Seldon Core deployment instance.

The remote instance is updated to match the current service configuration.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
def provision(self) -> None:
    """Provision or update remote Seldon Core deployment instance.

    This should then match the current configuration.
    """
    client = self._get_client()

    name = self.seldon_deployment_name

    deployment = SeldonDeployment.build(
        name=name,
        model_uri=self.config.model_uri,
        model_name=self.config.model_name,
        implementation=self.config.implementation,
        secret_name=self.config.secret_name,
        labels=self._get_seldon_deployment_labels(),
        annotations=self.config.get_seldon_deployment_annotations(),
    )
    deployment.spec.replicas = self.config.replicas
    deployment.spec.predictors[0].replicas = self.config.replicas

    # check if the Seldon deployment already exists
    try:
        client.get_deployment(name=name)
        # update the existing deployment
        client.update_deployment(deployment)
    except SeldonDeploymentNotFoundError:
        # create the deployment
        client.create_deployment(deployment=deployment)
SeldonDeploymentServiceStatus (ServiceStatus) pydantic-model

Seldon Core deployment service status.

Source code in zenml/integrations/seldon/services/seldon_deployment.py
class SeldonDeploymentServiceStatus(ServiceStatus):
    """Seldon Core deployment service status."""

steps special

Initialization for Seldon steps.

seldon_deployer

Implementation of the Seldon Deployer step.

SeldonDeployerStepConfig (BaseStepConfig) pydantic-model

Seldon model deployer step configuration.

Attributes:

Name Type Description
service_config SeldonDeploymentConfig

Seldon Core deployment service configuration.

timeout int

the timeout in seconds to wait for the Seldon Core deployment server to start or stop.

Source code in zenml/integrations/seldon/steps/seldon_deployer.py
class SeldonDeployerStepConfig(BaseStepConfig):
    """Seldon model deployer step configuration.

    Attributes:
        service_config: Seldon Core deployment service configuration.
        timeout: the timeout in seconds to wait for the Seldon Core
            deployment server to start or stop.
    """

    service_config: SeldonDeploymentConfig
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT
seldon_model_deployer_step (BaseStep)

Seldon Core model deployer pipeline step.

This step can be used in a pipeline to implement continuous deployment for an ML model with Seldon Core.

Parameters:

Name Type Description Default
deploy_decision

whether to deploy the model or not

required
config

configuration for the deployer step

required
model

the model artifact to deploy

required
context

the step context

required

Returns:

Type Description

Seldon Core deployment service
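
A hedged sketch of configuring this step for use in a pipeline (the instantiation pattern and the surrounding pipeline definition are assumptions based on the signatures above; all model details are hypothetical):

from zenml.integrations.seldon.services.seldon_deployment import (
    SeldonDeploymentConfig,
)
from zenml.integrations.seldon.steps.seldon_deployer import (
    SeldonDeployerStepConfig,
    seldon_model_deployer_step,
)

deployer = seldon_model_deployer_step(
    config=SeldonDeployerStepConfig(
        service_config=SeldonDeploymentConfig(
            model_name="demo-model",
            implementation="SKLEARN_SERVER",
            replicas=1,
        ),
        timeout=120,
    )
)
# inside a pipeline definition, the step would then be wired up as:
#     deployer(deploy_decision=decision, model=model)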

CONFIG_CLASS (BaseStepConfig) pydantic-model

Seldon model deployer step configuration.

Attributes:

Name Type Description
service_config SeldonDeploymentConfig

Seldon Core deployment service configuration.

timeout int

the number of seconds to wait for the Seldon Core deployment server to start or stop before timing out

Source code in zenml/integrations/seldon/steps/seldon_deployer.py
class SeldonDeployerStepConfig(BaseStepConfig):
    """Seldon model deployer step configuration.

    Attributes:
        service_config: Seldon Core deployment service configuration.
        timeout: the number of seconds to wait for the Seldon Core
            deployment server to start or stop before timing out.
    """

    service_config: SeldonDeploymentConfig
    timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT
entrypoint(deploy_decision, config, context, model) staticmethod

Seldon Core model deployer pipeline step.

This step can be used in a pipeline to implement continuous deployment for an ML model with Seldon Core.

Parameters:

Name Type Description Default
deploy_decision bool

whether to deploy the model or not

required
config SeldonDeployerStepConfig

configuration for the deployer step

required
model ModelArtifact

the model artifact to deploy

required
context StepContext

the step context

required

Returns:

Type Description
SeldonDeploymentService

Seldon Core deployment service

Source code in zenml/integrations/seldon/steps/seldon_deployer.py
@step(enable_cache=False)
def seldon_model_deployer_step(
    deploy_decision: bool,
    config: SeldonDeployerStepConfig,
    context: StepContext,
    model: ModelArtifact,
) -> SeldonDeploymentService:
    """Seldon Core model deployer pipeline step.

    This step can be used in a pipeline to implement continuous
    deployment for an ML model with Seldon Core.

    Args:
        deploy_decision: whether to deploy the model or not
        config: configuration for the deployer step
        model: the model artifact to deploy
        context: the step context

    Returns:
        Seldon Core deployment service
    """
    model_deployer = SeldonModelDeployer.get_active_model_deployer()

    # get pipeline name, step name and run id
    step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
    pipeline_name = step_env.pipeline_name
    pipeline_run_id = step_env.pipeline_run_id
    step_name = step_env.step_name

    # update the step configuration with the real pipeline runtime information
    config.service_config.pipeline_name = pipeline_name
    config.service_config.pipeline_run_id = pipeline_run_id
    config.service_config.pipeline_step_name = step_name

    def prepare_service_config(model_uri: str) -> SeldonDeploymentConfig:
        """Prepare the model files for model serving.

        This creates and returns a Seldon service configuration for the model.

        This function ensures that the model files are in the correct format
        and file structure required by the Seldon Core server implementation
        used for model serving.

        Args:
            model_uri: the URI of the model artifact being served

        Returns:
        A Seldon service configuration pointing at the prepared model files.

        Raises:
            RuntimeError: if the model files were not found
        """
        served_model_uri = os.path.join(
            context.get_output_artifact_uri(), "seldon"
        )
        fileio.makedirs(served_model_uri)

        # TODO [ENG-773]: determine how to formalize how models are organized into
        #   folders and sub-folders depending on the model type/format and the
        #   Seldon Core protocol used to serve the model.

        # TODO [ENG-791]: auto-detect built-in Seldon server implementation
        #   from the model artifact type

        # TODO [ENG-792]: validate the model artifact type against the
        #   supported built-in Seldon server implementations
        if config.service_config.implementation == "TENSORFLOW_SERVER":
            # the TensorFlow server expects model artifacts to be
            # stored in numbered subdirectories, each representing a model
            # version
            io_utils.copy_dir(model_uri, os.path.join(served_model_uri, "1"))
        elif config.service_config.implementation == "SKLEARN_SERVER":
            # the sklearn server expects model artifacts to be
            # stored in a file called model.joblib
            model_uri = os.path.join(model.uri, "model")
            if not fileio.exists(model_uri):
                raise RuntimeError(
                    f"Expected sklearn model artifact was not found at "
                    f"{model_uri}"
                )
            fileio.copy(
                model_uri, os.path.join(served_model_uri, "model.joblib")
            )
        else:
            # default treatment for all other server implementations is to
            # simply reuse the model from the artifact store path where it
            # is originally stored
            served_model_uri = model_uri

        service_config = config.service_config.copy()
        service_config.model_uri = served_model_uri
        return service_config

    # fetch existing services with same pipeline name, step name and
    # model name
    existing_services = model_deployer.find_model_server(
        pipeline_name=pipeline_name,
        pipeline_step_name=step_name,
        model_name=config.service_config.model_name,
    )

    # if the deploy decision is negative but no existing model server is
    # found for this pipeline/step, we still deploy the current model, to
    # ensure that a model server is available at all times
    if not deploy_decision and existing_services:
        logger.info(
            f"Skipping model deployment because the model quality does not "
            f"meet the criteria. Reusing last model server deployed by step "
            f"'{step_name}' and pipeline '{pipeline_name}' for model "
            f"'{config.service_config.model_name}'..."
        )
        service = cast(SeldonDeploymentService, existing_services[0])
        # even when the deploy decision is negative, we still need to start
        # the previous model server if it is no longer running, to ensure that
        # a model server is available at all times
        if not service.is_running:
            service.start(timeout=config.timeout)
        return service

    # invoke the Seldon Core model deployer to create a new service
    # or update an existing one that was previously deployed for the same
    # model
    service_config = prepare_service_config(model.uri)
    service = cast(
        SeldonDeploymentService,
        model_deployer.deploy_model(
            service_config, replace=True, timeout=config.timeout
        ),
    )

    logger.info(
        f"Seldon deployment service started and reachable at:\n"
        f"    {service.prediction_url}\n"
    )

    return service
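
This step is typically wired into a continuous-deployment pipeline together with a trainer and a quality gate. A minimal sketch of that wiring follows; `svc_trainer` and `deployment_trigger` are hypothetical placeholder steps, not part of this integration.

from zenml.pipelines import pipeline

@pipeline(enable_cache=False)
def continuous_deployment_pipeline(trainer, trigger, deployer):
    # The deployer receives the gate's decision plus the model artifact;
    # `config` and `context` are bound/injected by the step machinery.
    model = trainer()
    decision = trigger(model=model)
    deployer(deploy_decision=decision, model=model)

# Hypothetical assembly, reusing `deployer_config` from the sketch above:
# continuous_deployment_pipeline(
#     trainer=svc_trainer(),
#     trigger=deployment_trigger(),
#     deployer=seldon_model_deployer_step(config=deployer_config),
# ).run()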

sklearn special

Initialization of the sklearn integration.

SklearnIntegration (Integration)

Definition of sklearn integration for ZenML.

Source code in zenml/integrations/sklearn/__init__.py
class SklearnIntegration(Integration):
    """Definition of sklearn integration for ZenML."""

    NAME = SKLEARN
    REQUIREMENTS = ["scikit-learn"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration."""
        from zenml.integrations.sklearn import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/sklearn/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration."""
    from zenml.integrations.sklearn import materializers  # noqa

helpers special

Initialization of the helper functions for the sklearn digits dataset.

digits

Helper functions for the sklearn digits dataset.

get_digits()

Returns the digits dataset in the form of a tuple of numpy arrays.

Returns:

Type Description
Tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int64], NDArray[np.int64]]

Tuple of (training_images, testing_images, training_labels, testing_labels)

Source code in zenml/integrations/sklearn/helpers/digits.py
def get_digits() -> Tuple[
    "NDArray[np.float64]",
    "NDArray[np.float64]",
    "NDArray[np.int64]",
    "NDArray[np.int64]",
]:
    """Returns the digits dataset in the form of a tuple of numpy arrays.

    Returns:
        Tuple of (training_images, testing_images, training_labels, testing_labels)
    """
    digits = load_digits()
    # flatten the images
    n_samples = len(digits.images)
    data = digits.images.reshape((n_samples, -1))

    # Split data into 50% train and 50% test subsets
    X_train, X_test, y_train, y_test = train_test_split(
        data, digits.target, test_size=0.5, shuffle=False
    )
    return X_train, X_test, y_train, y_test
get_digits_model()

Creates a support vector classifier for the digits dataset.

Returns:

Type Description
ClassifierMixin

A support vector classifier.

Source code in zenml/integrations/sklearn/helpers/digits.py
def get_digits_model() -> ClassifierMixin:
    """Creates a support vector classifier for digits dataset.

    Returns:
        A support vector classifier.
    """
    return SVC(gamma=0.001)
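
Taken together, the two helpers support a minimal train-and-score loop using nothing beyond the functions documented above:

from zenml.integrations.sklearn.helpers.digits import (
    get_digits,
    get_digits_model,
)

X_train, X_test, y_train, y_test = get_digits()
model = get_digits_model()  # SVC(gamma=0.001)
model.fit(X_train, y_train)
print(f"test accuracy: {model.score(X_test, y_test):.3f}")
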

materializers special

Initialization of the sklearn materializer.

sklearn_materializer

Implementation of the sklearn materializer.

SklearnMaterializer (BaseMaterializer)

Materializer to read and write sklearn estimators.

Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
class SklearnMaterializer(BaseMaterializer):
    """Materializer to read data to and from sklearn."""

    ASSOCIATED_TYPES = (
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    )
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(
        self, data_type: Type[Any]
    ) -> Union[
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    ]:
        """Reads a base sklearn model from a pickle file.

        Args:
            data_type: The type of the model.

        Returns:
            The model.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        with fileio.open(filepath, "rb") as fid:
            clf = pickle.load(fid)
        return clf

    def handle_return(
        self,
        clf: Union[
            BaseEstimator,
            ClassifierMixin,
            ClusterMixin,
            BiclusterMixin,
            OutlierMixin,
            RegressorMixin,
            MetaEstimatorMixin,
            MultiOutputMixin,
            DensityMixin,
            TransformerMixin,
        ],
    ) -> None:
        """Creates a pickle for a sklearn model.

        Args:
            clf: A sklearn model.
        """
        super().handle_return(clf)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        with fileio.open(filepath, "wb") as fid:
            pickle.dump(clf, fid)
handle_input(self, data_type)

Reads a base sklearn model from a pickle file.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the model.

required

Returns:

Type Description
Union[sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin, sklearn.base.ClusterMixin, sklearn.base.BiclusterMixin, sklearn.base.OutlierMixin, sklearn.base.RegressorMixin, sklearn.base.MetaEstimatorMixin, sklearn.base.MultiOutputMixin, sklearn.base.DensityMixin, sklearn.base.TransformerMixin]

The model.

Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
def handle_input(
    self, data_type: Type[Any]
) -> Union[
    BaseEstimator,
    ClassifierMixin,
    ClusterMixin,
    BiclusterMixin,
    OutlierMixin,
    RegressorMixin,
    MetaEstimatorMixin,
    MultiOutputMixin,
    DensityMixin,
    TransformerMixin,
]:
    """Reads a base sklearn model from a pickle file.

    Args:
        data_type: The type of the model.

    Returns:
        The model.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    with fileio.open(filepath, "rb") as fid:
        clf = pickle.load(fid)
    return clf
handle_return(self, clf)

Creates a pickle for a sklearn model.

Parameters:

Name Type Description Default
clf Union[sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin, sklearn.base.ClusterMixin, sklearn.base.BiclusterMixin, sklearn.base.OutlierMixin, sklearn.base.RegressorMixin, sklearn.base.MetaEstimatorMixin, sklearn.base.MultiOutputMixin, sklearn.base.DensityMixin, sklearn.base.TransformerMixin]

A sklearn model.

required
Source code in zenml/integrations/sklearn/materializers/sklearn_materializer.py
def handle_return(
    self,
    clf: Union[
        BaseEstimator,
        ClassifierMixin,
        ClusterMixin,
        BiclusterMixin,
        OutlierMixin,
        RegressorMixin,
        MetaEstimatorMixin,
        MultiOutputMixin,
        DensityMixin,
        TransformerMixin,
    ],
) -> None:
    """Creates a pickle for a sklearn model.

    Args:
        clf: A sklearn model.
    """
    super().handle_return(clf)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    with fileio.open(filepath, "wb") as fid:
        pickle.dump(clf, fid)
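
In practice the materializer is rarely invoked directly: any step that returns one of the `ASSOCIATED_TYPES` above is persisted through it automatically. A minimal sketch (the step body is illustrative):

from sklearn.base import ClassifierMixin
from sklearn.svm import SVC
from zenml.steps import step

@step
def svc_trainer_step() -> ClassifierMixin:
    # The returned estimator matches ClassifierMixin in ASSOCIATED_TYPES,
    # so SklearnMaterializer pickles it to the artifact store.
    return SVC(gamma=0.001)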

steps special

Initialization of the sklearn standard steps.

sklearn_evaluator

Implementation of the sklearn evaluator step.

SklearnEvaluator (BaseEvaluatorStep)

Simple sklearn evaluator step implementation.

This uses sklearn to evaluate the performance of a given model on a given test dataset.

Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluator(BaseEvaluatorStep):
    """Simple sklearn evaluator step implementation.

    This uses sklearn to evaluate the performance of a given model on a given
    test dataset.
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        model: BaseEstimator,
        config: SklearnEvaluatorConfig,
    ) -> dict:  # type: ignore[type-arg]
        """Method which is responsible for the computation of the evaluation.

        Args:
            dataset: a pandas DataFrame which represents the test dataset
            model: a trained sklearn model
            config: the configuration for the step

        Returns:
            a dictionary which has the evaluation report
        """
        labels = dataset.pop(config.label_class_column)

        predictions = model.predict(dataset)
        predicted_classes = [1 if v > 0.5 else 0 for v in predictions]

        report = classification_report(
            labels, predicted_classes, output_dict=True
        )

        return report  # type: ignore[no-any-return]
CONFIG_CLASS (BaseEvaluatorConfig) pydantic-model

Config class for the sklearn evaluator.

Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluatorConfig(BaseEvaluatorConfig):
    """Config class for the sklearn evaluator."""

    label_class_column: str
entrypoint(self, dataset, model, config)

Computes the evaluation report for the given model on the test dataset.

Parameters:

Name Type Description Default
dataset DataFrame

a pandas DataFrame which represents the test dataset

required
model BaseEstimator

a trained sklearn model

required
config SklearnEvaluatorConfig

the configuration for the step

required

Returns:

Type Description
dict

a dictionary which has the evaluation report

Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    model: BaseEstimator,
    config: SklearnEvaluatorConfig,
) -> dict:  # type: ignore[type-arg]
    """Method which is responsible for the computation of the evaluation.

    Args:
        dataset: a pandas DataFrame which represents the test dataset
        model: a trained sklearn model
        config: the configuration for the step

    Returns:
        a dictionary which has the evaluation report
    """
    labels = dataset.pop(config.label_class_column)

    predictions = model.predict(dataset)
    predicted_classes = [1 if v > 0.5 else 0 for v in predictions]

    report = classification_report(
        labels, predicted_classes, output_dict=True
    )

    return report  # type: ignore[no-any-return]
SklearnEvaluatorConfig (BaseEvaluatorConfig) pydantic-model

Config class for the sklearn evaluator.

Source code in zenml/integrations/sklearn/steps/sklearn_evaluator.py
class SklearnEvaluatorConfig(BaseEvaluatorConfig):
    """Config class for the sklearn evaluator."""

    label_class_column: str
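
Stripped of the step machinery, the evaluation is a plain `classification_report` over thresholded predictions. Note that the `> 0.5` threshold presupposes a model whose `predict()` returns scores or probabilities (such as the sigmoid output of the TensorFlow binary classifier later on this page) rather than class labels. A standalone sketch with illustrative data:

from sklearn.metrics import classification_report

# Scores and labels are illustrative; the step obtains them from
# `model.predict(dataset)` and the configured `label_class_column`.
scores = [0.1, 0.8, 0.4, 0.9]
labels = [0, 1, 0, 1]
predicted_classes = [1 if v > 0.5 else 0 for v in scores]
report = classification_report(labels, predicted_classes, output_dict=True)
print(report["accuracy"])  # 1.0 for this toy input
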
sklearn_splitter

Implementation of the sklearn splitter.

SklearnSplitter (BaseSplitStep)

A simple sklearn splitter step implementation.

This uses sklearn to split a given dataset into train, test and validation splits.

Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitter(BaseSplitStep):
    """A simple sklearn splitter step implementation.

    This uses sklearn to split a given dataset into train, test and validation
    splits.
    """

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: SklearnSplitterConfig,
    ) -> Output(  # type:ignore[valid-type]
        train=pd.DataFrame, test=pd.DataFrame, validation=pd.DataFrame
    ):
        """Method which is responsible for the splitting logic.

        Args:
            dataset: a pandas DataFrame which entire dataset
            config: the configuration for the step

        Returns:
            three DataFrames representing the splits

        Raises:
            KeyError: if the wrong configuration is used
            ValueError: if the ratios are not valid
        """
        if (
            any(
                [
                    split not in config.ratios
                    for split in ["train", "test", "validation"]
                ]
            )
            or len(config.ratios) != 3
        ):
            raise KeyError(
                f"Make sure that you only use 'train', 'test' and "
                f"'validation' as keys in the ratios dict. Current keys: "
                f"{config.ratios.keys()}"
            )

        if sum(config.ratios.values()) != 1:
            raise ValueError(
                f"Make sure that the ratios sum up to 1. Current "
                f"ratios: {config.ratios}"
            )

        train_dataset, test_dataset = train_test_split(
            dataset, test_size=config.ratios["test"]
        )

        train_dataset, val_dataset = train_test_split(
            train_dataset,
            test_size=(
                config.ratios["validation"]
                / (config.ratios["validation"] + config.ratios["train"])
            ),
        )

        return train_dataset, test_dataset, val_dataset
CONFIG_CLASS (BaseSplitStepConfig) pydantic-model

Config class for the sklearn splitter.

Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitterConfig(BaseSplitStepConfig):
    """Config class for the sklearn splitter."""

    ratios: Dict[str, float]
entrypoint(self, dataset, config)

Splits the given dataset into train, test and validation subsets.

Parameters:

Name Type Description Default
dataset DataFrame

a pandas DataFrame containing the entire dataset

required
config SklearnSplitterConfig

the configuration for the step

required

Returns:

Type Description
Output(train=pd.DataFrame, test=pd.DataFrame, validation=pd.DataFrame)

three DataFrames representing the splits

Exceptions:

Type Description
KeyError

if the wrong configuration is used

ValueError

if the ratios are not valid

Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: SklearnSplitterConfig,
) -> Output(  # type:ignore[valid-type]
    train=pd.DataFrame, test=pd.DataFrame, validation=pd.DataFrame
):
    """Method which is responsible for the splitting logic.

    Args:
        dataset: a pandas DataFrame which entire dataset
        config: the configuration for the step

    Returns:
        three DataFrames representing the splits

    Raises:
        KeyError: if the wrong configuration is used
        ValueError: if the ratios are not valid
    """
    if (
        any(
            [
                split not in config.ratios
                for split in ["train", "test", "validation"]
            ]
        )
        or len(config.ratios) != 3
    ):
        raise KeyError(
            f"Make sure that you only use 'train', 'test' and "
            f"'validation' as keys in the ratios dict. Current keys: "
            f"{config.ratios.keys()}"
        )

    if sum(config.ratios.values()) != 1:
        raise ValueError(
            f"Make sure that the ratios sum up to 1. Current "
            f"ratios: {config.ratios}"
        )

    train_dataset, test_dataset = train_test_split(
        dataset, test_size=config.ratios["test"]
    )

    train_dataset, val_dataset = train_test_split(
        train_dataset,
        test_size=(
            config.ratios["validation"]
            / (config.ratios["validation"] + config.ratios["train"])
        ),
    )

    return train_dataset, test_dataset, val_dataset
SklearnSplitterConfig (BaseSplitStepConfig) pydantic-model

Config class for the sklearn splitter.

Source code in zenml/integrations/sklearn/steps/sklearn_splitter.py
class SklearnSplitterConfig(BaseSplitStepConfig):
    """Config class for the sklearn splitter."""

    ratios: Dict[str, float]
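
The `ratios` dict must contain exactly the keys 'train', 'test' and 'validation' and must sum to 1, otherwise the step raises. For example:

from zenml.integrations.sklearn.steps.sklearn_splitter import (
    SklearnSplitterConfig,
)

# Exact binary fractions sum to exactly 1.0, which keeps the strict
# `sum(config.ratios.values()) != 1` check from tripping over
# floating-point rounding.
config = SklearnSplitterConfig(
    ratios={"train": 0.5, "test": 0.25, "validation": 0.25}
)
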
sklearn_standard_scaler

Implementation of the sklearn standard scaler step.

SklearnStandardScaler (BasePreprocessorStep)

Simple StandardScaler step implementation.

This uses the StandardScaler from sklearn to transform the numeric columns of a pd.DataFrame.

Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScaler(BasePreprocessorStep):
    """Simple StandardScaler step implementation.

    This uses the StandardScaler from sklearn to transform the numeric columns
    of a pd.DataFrame.
    """

    def entrypoint(  # type: ignore[override]
        self,
        train_dataset: pd.DataFrame,
        test_dataset: pd.DataFrame,
        validation_dataset: pd.DataFrame,
        statistics: pd.DataFrame,
        schema: pd.DataFrame,
        config: SklearnStandardScalerConfig,
    ) -> Output(  # type:ignore[valid-type]
        train_transformed=pd.DataFrame,
        test_transformed=pd.DataFrame,
        validation_transformed=pd.DataFrame,
    ):
        """Main entrypoint function for the StandardScaler.

        Args:
            train_dataset: pd.DataFrame, the training dataset
            test_dataset: pd.DataFrame, the test dataset
            validation_dataset: pd.DataFrame, the validation dataset
            statistics: pd.DataFrame, the statistics over the train dataset
            schema: pd.DataFrame, the detected schema of the dataset
            config: the configuration for the step

        Returns:
            the transformed train, test and validation datasets as pd.DataFrames
        """
        schema_dict = {k: v[0] for k, v in schema.to_dict().items()}

        # Exclude columns
        feature_set = set(train_dataset.columns) - set(config.exclude_columns)
        for feature, feature_type in schema_dict.items():
            if feature_type != "int64" and feature_type != "float64":
                feature_set.remove(feature)
                logger.warning(
                    f"{feature} column is a not numeric, thus it is excluded "
                    f"from the standard scaling."
                )

        transform_feature_set = feature_set - set(config.ignore_columns)

        # Transform the datasets
        scaler = StandardScaler()
        scaler.mean_ = statistics["mean"][transform_feature_set]
        scaler.scale_ = statistics["std"][transform_feature_set]

        train_dataset[list(transform_feature_set)] = scaler.transform(
            train_dataset[transform_feature_set]
        )
        test_dataset[list(transform_feature_set)] = scaler.transform(
            test_dataset[transform_feature_set]
        )
        validation_dataset[list(transform_feature_set)] = scaler.transform(
            validation_dataset[transform_feature_set]
        )

        return train_dataset, test_dataset, validation_dataset
CONFIG_CLASS (BasePreprocessorConfig) pydantic-model

Config class for the sklearn standard scaler.

ignore_columns: a list of column names which should not be scaled
exclude_columns: a list of column names to be excluded from the dataset

Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScalerConfig(BasePreprocessorConfig):
    """Config class for the sklearn standard scaler.

    ignore_columns: a list of column names which should not be scaled
    exclude_columns: a list of column names to be excluded from the dataset
    """

    ignore_columns: List[str] = []
    exclude_columns: List[str] = []
entrypoint(self, train_dataset, test_dataset, validation_dataset, statistics, schema, config)

Main entrypoint function for the StandardScaler.

Parameters:

Name Type Description Default
train_dataset DataFrame

pd.DataFrame, the training dataset

required
test_dataset DataFrame

pd.DataFrame, the test dataset

required
validation_dataset DataFrame

pd.DataFrame, the validation dataset

required
statistics DataFrame

pd.DataFrame, the statistics over the train dataset

required
schema DataFrame

pd.DataFrame, the detected schema of the dataset

required
config SklearnStandardScalerConfig

the configuration for the step

required

Returns:

Type Description
Output(train_transformed=pd.DataFrame, test_transformed=pd.DataFrame, validation_transformed=pd.DataFrame)

the transformed train, test and validation datasets as pd.DataFrames

Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
def entrypoint(  # type: ignore[override]
    self,
    train_dataset: pd.DataFrame,
    test_dataset: pd.DataFrame,
    validation_dataset: pd.DataFrame,
    statistics: pd.DataFrame,
    schema: pd.DataFrame,
    config: SklearnStandardScalerConfig,
) -> Output(  # type:ignore[valid-type]
    train_transformed=pd.DataFrame,
    test_transformed=pd.DataFrame,
    validation_transformed=pd.DataFrame,
):
    """Main entrypoint function for the StandardScaler.

    Args:
        train_dataset: pd.DataFrame, the training dataset
        test_dataset: pd.DataFrame, the test dataset
        validation_dataset: pd.DataFrame, the validation dataset
        statistics: pd.DataFrame, the statistics over the train dataset
        schema: pd.DataFrame, the detected schema of the dataset
        config: the configuration for the step

    Returns:
        the transformed train, test and validation datasets as pd.DataFrames
    """
    schema_dict = {k: v[0] for k, v in schema.to_dict().items()}

    # Exclude columns
    feature_set = set(train_dataset.columns) - set(config.exclude_columns)
    for feature, feature_type in schema_dict.items():
        if feature_type != "int64" and feature_type != "float64":
            feature_set.remove(feature)
            logger.warning(
                f"{feature} column is a not numeric, thus it is excluded "
                f"from the standard scaling."
            )

    transform_feature_set = feature_set - set(config.ignore_columns)

    # Transform the datasets
    scaler = StandardScaler()
    scaler.mean_ = statistics["mean"][transform_feature_set]
    scaler.scale_ = statistics["std"][transform_feature_set]

    train_dataset[list(transform_feature_set)] = scaler.transform(
        train_dataset[transform_feature_set]
    )
    test_dataset[list(transform_feature_set)] = scaler.transform(
        test_dataset[transform_feature_set]
    )
    validation_dataset[list(transform_feature_set)] = scaler.transform(
        validation_dataset[transform_feature_set]
    )

    return train_dataset, test_dataset, validation_dataset
SklearnStandardScalerConfig (BasePreprocessorConfig) pydantic-model

Config class for the sklearn standard scaler.

ignore_columns: a list of column names which should not be scaled
exclude_columns: a list of column names to be excluded from the dataset

Source code in zenml/integrations/sklearn/steps/sklearn_standard_scaler.py
class SklearnStandardScalerConfig(BasePreprocessorConfig):
    """Config class for the sklearn standard scaler.

    ignore_columns: a list of column names which should not be scaled
    exclude_columns: a list of column names to be excluded from the dataset
    """

    ignore_columns: List[str] = []
    exclude_columns: List[str] = []
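
Note that the step never calls `scaler.fit()`: it injects pre-computed statistics into `mean_` and `scale_` and then transforms. The same pattern in isolation, with illustrative data:

import pandas as pd
from sklearn.preprocessing import StandardScaler

df = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0]})

# In the step, these statistics come from the `statistics` DataFrame
# computed over the training split; here they are computed inline.
scaler = StandardScaler()
scaler.mean_ = df.mean()
scaler.scale_ = df.std()

df[["a", "b"]] = scaler.transform(df[["a", "b"]])
print(df)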

slack special

Slack integration for alerter components.

SlackIntegration (Integration)

Definition of a Slack integration for ZenML.

Implemented using Slack SDK.

Source code in zenml/integrations/slack/__init__.py
class SlackIntegration(Integration):
    """Definition of a Slack integration for ZenML.

    Implemented using [Slack SDK](https://pypi.org/project/slack-sdk/).
    """

    NAME = SLACK
    REQUIREMENTS = ["slack-sdk>=3.16.1", "aiohttp>=3.8.1"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Slack integration.

        Returns:
            List of new flavors defined by the Slack integration.
        """
        return [
            FlavorWrapper(
                name=SLACK_ALERTER_FLAVOR,
                source="zenml.integrations.slack.alerters.slack_alerter.SlackAlerter",
                type=StackComponentType.ALERTER,
                integration=cls.NAME,
            )
        ]
flavors() classmethod

Declare the stack component flavors for the Slack integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of new flavors defined by the Slack integration.

Source code in zenml/integrations/slack/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Slack integration.

    Returns:
        List of new flavors defined by the Slack integration.
    """
    return [
        FlavorWrapper(
            name=SLACK_ALERTER_FLAVOR,
            source="zenml.integrations.slack.alerters.slack_alerter.SlackAlerter",
            type=StackComponentType.ALERTER,
            integration=cls.NAME,
        )
    ]

alerters special

Alerter components defined by the Slack integration.

slack_alerter

Implementation of the Slack flavor of the alerter component.

SlackAlerter (BaseAlerter) pydantic-model

Send messages to Slack channels.

Attributes:

Name Type Description
slack_token str

The Slack token tied to the Slack account to be used.

Source code in zenml/integrations/slack/alerters/slack_alerter.py
class SlackAlerter(BaseAlerter):
    """Send messages to Slack channels.

    Attributes:
        slack_token: The Slack token tied to the Slack account to be used.
    """

    slack_token: str
    default_slack_channel_id: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = SLACK_ALERTER_FLAVOR

    def _get_channel_id(self, config: Optional[BaseAlerterStepConfig]) -> str:
        """Get the Slack channel ID to be used by post/ask.

        Args:
            config: Optional runtime configuration.

        Returns:
            ID of the Slack channel to be used.

        Raises:
            RuntimeError: if config is not of type `BaseAlerterStepConfig`.
            ValueError: if a slack channel was neither defined in the config
                nor in the slack alerter component.
        """
        if not isinstance(config, BaseAlerterStepConfig):
            raise RuntimeError(
                "The config object must be of type `BaseAlerterStepConfig`."
            )
        if (
            isinstance(config, SlackAlerterConfig)
            and hasattr(config, "slack_channel_id")
            and config.slack_channel_id is not None
        ):
            return config.slack_channel_id
        if self.default_slack_channel_id is not None:
            return self.default_slack_channel_id
        raise ValueError(
            "Neither the `SlackAlerterConfig.slack_channel_id` in the runtime "
            "configuration, nor the `default_slack_channel_id` in the alerter "
            "stack component is specified. Please specify at least one."
        )

    def _get_approve_msg_options(
        self, config: Optional[BaseAlerterStepConfig]
    ) -> List[str]:
        """Define which messages will lead to approval during ask().

        Args:
            config: Optional runtime configuration.

        Returns:
            Set of messages that lead to approval in alerter.ask().
        """
        if (
            isinstance(config, SlackAlerterConfig)
            and hasattr(config, "approve_msg_options")
            and config.approve_msg_options is not None
        ):
            return config.approve_msg_options
        return DEFAULT_APPROVE_MSG_OPTIONS

    def _get_disapprove_msg_options(
        self, config: Optional[BaseAlerterStepConfig]
    ) -> List[str]:
        """Define which messages will lead to disapproval during ask().

        Args:
            config: Optional runtime configuration.

        Returns:
            Set of messages that lead to disapproval in alerter.ask().
        """
        if (
            isinstance(config, SlackAlerterConfig)
            and hasattr(config, "disapprove_msg_options")
            and config.disapprove_msg_options is not None
        ):
            return config.disapprove_msg_options
        return DEFAULT_DISAPPROVE_MSG_OPTIONS

    def post(
        self, message: str, config: Optional[BaseAlerterStepConfig]
    ) -> bool:
        """Post a message to a Slack channel.

        Args:
            message: Message to be posted.
            config: Optional runtime configuration.

        Returns:
            True if operation succeeded, else False
        """
        slack_channel_id = self._get_channel_id(config=config)
        client = WebClient(token=self.slack_token)
        try:
            response = client.chat_postMessage(
                channel=slack_channel_id,
                text=message,
            )
            return True
        except SlackApiError as error:
            response = error.response["error"]
            logger.error(f"SlackAlerter.post() failed: {response}")
            return False

    def ask(
        self, message: str, config: Optional[BaseAlerterStepConfig]
    ) -> bool:
        """Post a message to a Slack channel and wait for approval.

        Args:
            message: Initial message to be posted.
            config: Optional runtime configuration.

        Returns:
            True if a user approved the operation, else False
        """
        rtm = RTMClient(token=self.slack_token)
        slack_channel_id = self._get_channel_id(config=config)

        approved = False  # will be modified by handle()

        @RTMClient.run_on(event="hello")  # type: ignore
        def post_initial_message(**payload: Any) -> None:
            """Post an initial message in a channel and start listening.

            Args:
                payload: payload of the received Slack event.
            """
            web_client = payload["web_client"]
            web_client.chat_postMessage(channel=slack_channel_id, text=message)

        @RTMClient.run_on(event="message")  # type: ignore
        def handle(**payload: Any) -> None:
            """Listen / handle messages posted in the channel.

            Args:
                payload: payload of the received Slack event.
            """
            event = payload["data"]
            if event["channel"] == slack_channel_id:

                # approve request (return True)
                if event["text"] in self._get_approve_msg_options(config):
                    print(f"User {event['user']} approved on slack.")
                    nonlocal approved
                    approved = True
                    rtm.stop()  # type: ignore

                # disapprove request (return False)
                elif event["text"] in self._get_disapprove_msg_options(config):
                    print(f"User {event['user']} disapproved on slack.")
                    rtm.stop()  # type:ignore

        # start another thread until `rtm.stop()` is called in handle()
        rtm.start()

        return approved
ask(self, message, config)

Post a message to a Slack channel and wait for approval.

Parameters:

Name Type Description Default
message str

Initial message to be posted.

required
config Optional[zenml.steps.step_interfaces.base_alerter_step.BaseAlerterStepConfig]

Optional runtime configuration.

required

Returns:

Type Description
bool

True if a user approved the operation, else False

Source code in zenml/integrations/slack/alerters/slack_alerter.py
def ask(
    self, message: str, config: Optional[BaseAlerterStepConfig]
) -> bool:
    """Post a message to a Slack channel and wait for approval.

    Args:
        message: Initial message to be posted.
        config: Optional runtime configuration.

    Returns:
        True if a user approved the operation, else False
    """
    rtm = RTMClient(token=self.slack_token)
    slack_channel_id = self._get_channel_id(config=config)

    approved = False  # will be modified by handle()

    @RTMClient.run_on(event="hello")  # type: ignore
    def post_initial_message(**payload: Any) -> None:
        """Post an initial message in a channel and start listening.

        Args:
            payload: payload of the received Slack event.
        """
        web_client = payload["web_client"]
        web_client.chat_postMessage(channel=slack_channel_id, text=message)

    @RTMClient.run_on(event="message")  # type: ignore
    def handle(**payload: Any) -> None:
        """Listen / handle messages posted in the channel.

        Args:
            payload: payload of the received Slack event.
        """
        event = payload["data"]
        if event["channel"] == slack_channel_id:

            # approve request (return True)
            if event["text"] in self._get_approve_msg_options(config):
                print(f"User {event['user']} approved on slack.")
                nonlocal approved
                approved = True
                rtm.stop()  # type: ignore

            # disapprove request (return False)
            elif event["text"] in self._get_disapprove_msg_options(config):
                print(f"User {event['user']} disapproved on slack.")
                rtm.stop()  # type:ignore

    # start another thread until `rtm.stop()` is called in handle()
    rtm.start()

    return approved
post(self, message, config)

Post a message to a Slack channel.

Parameters:

Name Type Description Default
message str

Message to be posted.

required
config Optional[zenml.steps.step_interfaces.base_alerter_step.BaseAlerterStepConfig]

Optional runtime configuration.

required

Returns:

Type Description
bool

True if operation succeeded, else False

Source code in zenml/integrations/slack/alerters/slack_alerter.py
def post(
    self, message: str, config: Optional[BaseAlerterStepConfig]
) -> bool:
    """Post a message to a Slack channel.

    Args:
        message: Message to be posted.
        config: Optional runtime configuration.

    Returns:
        True if operation succeeded, else False
    """
    slack_channel_id = self._get_channel_id(config=config)
    client = WebClient(token=self.slack_token)
    try:
        response = client.chat_postMessage(
            channel=slack_channel_id,
            text=message,
        )
        return True
    except SlackApiError as error:
        response = error.response["error"]
        logger.error(f"SlackAlerter.post() failed: {response}")
        return False
SlackAlerterConfig (BaseAlerterStepConfig) pydantic-model

Slack alerter config.

Source code in zenml/integrations/slack/alerters/slack_alerter.py
class SlackAlerterConfig(BaseAlerterStepConfig):
    """Slack alerter config."""

    # The ID of the Slack channel to use for communication.
    slack_channel_id: Optional[str] = None

    # Set of messages that lead to approval in alerter.ask()
    approve_msg_options: Optional[List[str]] = None

    # Set of messages that lead to disapproval in alerter.ask()
    disapprove_msg_options: Optional[List[str]] = None
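
A hedged sketch of exercising the alerter directly; in a real setup it is registered as a stack component instead of being constructed by hand, and the token, channel ID and `name` argument below are placeholders:

from zenml.integrations.slack.alerters.slack_alerter import (
    SlackAlerter,
    SlackAlerterConfig,
)

alerter = SlackAlerter(
    name="slack_alerter",  # assumption: stack components carry a name
    slack_token="xoxb-REPLACE-ME",
    default_slack_channel_id="C0123456789",
)

# Without a `slack_channel_id` in the runtime config, the alerter falls
# back to `default_slack_channel_id` (see `_get_channel_id()` above).
alerter.post("Pipeline finished.", config=SlackAlerterConfig())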

tensorflow special

Initialization for TensorFlow integration.

TensorflowIntegration (Integration)

Definition of Tensorflow integration for ZenML.

Source code in zenml/integrations/tensorflow/__init__.py
class TensorflowIntegration(Integration):
    """Definition of Tensorflow integration for ZenML."""

    NAME = TENSORFLOW
    REQUIREMENTS = ["tensorflow==2.8.0", "tensorflow_io==0.24.0"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration."""
        # need to import this explicitly to load the Tensorflow file IO support
        # for S3 and other file systems
        import tensorflow_io  # type: ignore [import]

        from zenml.integrations.tensorflow import materializers  # noqa
        from zenml.integrations.tensorflow import services  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/tensorflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration."""
    # need to import this explicitly to load the Tensorflow file IO support
    # for S3 and other file systems
    import tensorflow_io  # type: ignore [import]

    from zenml.integrations.tensorflow import materializers  # noqa
    from zenml.integrations.tensorflow import services  # noqa

materializers special

Initialization for the TensorFlow materializers.

keras_materializer

Implementation of the TensorFlow Keras materializer.

KerasMaterializer (BaseMaterializer)

Materializer to read/write Keras models.

Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
class KerasMaterializer(BaseMaterializer):
    """Materializer to read/write Keras models."""

    ASSOCIATED_TYPES = (keras.Model,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> keras.Model:
        """Reads and returns a Keras model after copying it to temporary path.

        Args:
            data_type: The type of the data to read.

        Returns:
            A tf.keras.Model model.
        """
        super().handle_input(data_type)

        # Create a temporary directory to store the model
        temp_dir = tempfile.TemporaryDirectory()

        # Copy from artifact store to temporary directory
        io_utils.copy_dir(self.artifact.uri, temp_dir.name)

        # Load the model from the temporary directory
        model = keras.models.load_model(temp_dir.name)

        # Cleanup and return
        fileio.rmtree(temp_dir.name)

        return model

    def handle_return(self, model: keras.Model) -> None:
        """Writes a keras model to the artifact store.

        Args:
            model: A tf.keras.Model model.
        """
        super().handle_return(model)

        # Create a temporary directory to store the model
        temp_dir = tempfile.TemporaryDirectory()
        model.save(temp_dir.name)
        io_utils.copy_dir(temp_dir.name, self.artifact.uri)

        # Remove the temporary directory
        fileio.rmtree(temp_dir.name)
handle_input(self, data_type)

Reads and returns a Keras model after copying it to temporary path.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
Model

A tf.keras.Model model.

Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def handle_input(self, data_type: Type[Any]) -> keras.Model:
    """Reads and returns a Keras model after copying it to temporary path.

    Args:
        data_type: The type of the data to read.

    Returns:
        A tf.keras.Model model.
    """
    super().handle_input(data_type)

    # Create a temporary directory to store the model
    temp_dir = tempfile.TemporaryDirectory()

    # Copy from artifact store to temporary directory
    io_utils.copy_dir(self.artifact.uri, temp_dir.name)

    # Load the model from the temporary directory
    model = keras.models.load_model(temp_dir.name)

    # Cleanup and return
    fileio.rmtree(temp_dir.name)

    return model
handle_return(self, model)

Writes a keras model to the artifact store.

Parameters:

Name Type Description Default
model Model

A tf.keras.Model model.

required
Source code in zenml/integrations/tensorflow/materializers/keras_materializer.py
def handle_return(self, model: keras.Model) -> None:
    """Writes a keras model to the artifact store.

    Args:
        model: A tf.keras.Model model.
    """
    super().handle_return(model)

    # Create a temporary directory to store the model
    temp_dir = tempfile.TemporaryDirectory()
    model.save(temp_dir.name)
    io_utils.copy_dir(temp_dir.name, self.artifact.uri)

    # Remove the temporary directory
    fileio.rmtree(temp_dir.name)
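
Both handlers funnel through a local temporary directory because `keras.models.save`/`load_model` operate on filesystem paths while the artifact store may be remote. The same round-trip in isolation:

import tempfile

from tensorflow import keras

model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])

with tempfile.TemporaryDirectory() as tmp:
    model.save(tmp)  # SavedModel format, as in handle_return() above
    restored = keras.models.load_model(tmp)

restored.summary()
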
tf_dataset_materializer

Implementation of the TensorFlow dataset materializer.

TensorflowDatasetMaterializer (BaseMaterializer)

Materializer to read and write tf.data.Dataset objects.

Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
class TensorflowDatasetMaterializer(BaseMaterializer):
    """Materializer to read data to and from tf.data.Dataset."""

    ASSOCIATED_TYPES = (tf.data.Dataset,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Any:
        """Reads data into tf.data.Dataset.

        Args:
            data_type: The type of the data to read.

        Returns:
            A tf.data.Dataset object.
        """
        super().handle_input(data_type)
        path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        return tf.data.experimental.load(path)

    def handle_return(self, dataset: tf.data.Dataset) -> None:
        """Persists a tf.data.Dataset object.

        Args:
            dataset: The dataset to persist.
        """
        super().handle_return(dataset)
        path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
        tf.data.experimental.save(
            dataset, path, compression=None, shard_func=None
        )
handle_input(self, data_type)

Reads data into tf.data.Dataset.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
Any

A tf.data.Dataset object.

Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> Any:
    """Reads data into tf.data.Dataset.

    Args:
        data_type: The type of the data to read.

    Returns:
        A tf.data.Dataset object.
    """
    super().handle_input(data_type)
    path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    return tf.data.experimental.load(path)
handle_return(self, dataset)

Persists a tf.data.Dataset object.

Parameters:

Name Type Description Default
dataset DatasetV2

The dataset to persist.

required
Source code in zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py
def handle_return(self, dataset: tf.data.Dataset) -> None:
    """Persists a tf.data.Dataset object.

    Args:
        dataset: The dataset to persist.
    """
    super().handle_return(dataset)
    path = os.path.join(self.artifact.uri, DEFAULT_FILENAME)
    tf.data.experimental.save(
        dataset, path, compression=None, shard_func=None
    )
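
The underlying `tf.data.experimental.save`/`load` round-trip, as used by the two handlers; with the tensorflow==2.8.0 line pinned by this integration, `load` can infer the element spec from the saved snapshot:

import tensorflow as tf

dataset = tf.data.Dataset.range(5)
path = "/tmp/zenml_tf_dataset_example"  # illustrative path

tf.data.experimental.save(dataset, path, compression=None, shard_func=None)
restored = tf.data.experimental.load(path)
print(list(restored.as_numpy_iterator()))  # [0, 1, 2, 3, 4]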

services special

Initialization for TensorFlow services.

tensorboard_service

Implementation of the TensorBoard service.

TensorboardService (LocalDaemonService) pydantic-model

TensorBoard service.

This can be used to start a local TensorBoard server for one or more models.

Attributes:

Name Type Description
SERVICE_TYPE ClassVar[zenml.services.service_type.ServiceType]

a service type descriptor with information describing the TensorBoard service class

config TensorboardServiceConfig

service configuration

endpoint LocalDaemonServiceEndpoint

optional service endpoint

Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
class TensorboardService(LocalDaemonService):
    """TensorBoard service.

    This can be used to start a local TensorBoard server for one or more models.

    Attributes:
        SERVICE_TYPE: a service type descriptor with information describing
            the TensorBoard service class
        config: service configuration
        endpoint: optional service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="tensorboard",
        type="visualization",
        flavor="tensorboard",
        description="TensorBoard visualization service",
    )

    config: TensorboardServiceConfig
    endpoint: LocalDaemonServiceEndpoint

    def __init__(
        self,
        config: Union[TensorboardServiceConfig, Dict[str, Any]],
        **attrs: Any,
    ) -> None:
        """Initialization for TensorBoard service.

        Args:
            config: service configuration
            **attrs: additional attributes
        """
        # ensure that the endpoint is created before the service is initialized
        # TODO [ENG-697]: implement a service factory or builder for TensorBoard
        #   deployment services
        if (
            isinstance(config, TensorboardServiceConfig)
            and "endpoint" not in attrs
        ):
            endpoint = LocalDaemonServiceEndpoint(
                config=LocalDaemonServiceEndpointConfig(
                    protocol=ServiceEndpointProtocol.HTTP,
                ),
                monitor=HTTPEndpointHealthMonitor(
                    config=HTTPEndpointHealthMonitorConfig(
                        healthcheck_uri_path="",
                        use_head_request=True,
                    )
                ),
            )
            attrs["endpoint"] = endpoint
        super().__init__(config=config, **attrs)

    def run(self) -> None:
        """Initialize and run the TensorBoard server."""
        logger.info(
            "Starting TensorBoard service as blocking "
            "process... press CTRL+C once to stop it."
        )

        self.endpoint.prepare_for_start()

        try:
            tensorboard = program.TensorBoard(
                plugins=default.get_plugins(),
                subcommands=[uploader_subcommand.UploaderSubcommand()],
            )
            tensorboard.configure(
                logdir=self.config.logdir,
                port=self.endpoint.status.port,
                host="localhost",
                max_reload_threads=self.config.max_reload_threads,
                reload_interval=self.config.reload_interval,
            )
            tensorboard.main()
        except KeyboardInterrupt:
            logger.info(
                "TensorBoard service stopped. Resuming normal execution."
            )
__init__(self, config, **attrs) special

Initialization for TensorBoard service.

Parameters:

Name Type Description Default
config Union[zenml.integrations.tensorflow.services.tensorboard_service.TensorboardServiceConfig, Dict[str, Any]]

service configuration

required
**attrs Any

additional attributes

{}
Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
def __init__(
    self,
    config: Union[TensorboardServiceConfig, Dict[str, Any]],
    **attrs: Any,
) -> None:
    """Initialization for TensorBoard service.

    Args:
        config: service configuration
        **attrs: additional attributes
    """
    # ensure that the endpoint is created before the service is initialized
    # TODO [ENG-697]: implement a service factory or builder for TensorBoard
    #   deployment services
    if (
        isinstance(config, TensorboardServiceConfig)
        and "endpoint" not in attrs
    ):
        endpoint = LocalDaemonServiceEndpoint(
            config=LocalDaemonServiceEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
            ),
            monitor=HTTPEndpointHealthMonitor(
                config=HTTPEndpointHealthMonitorConfig(
                    healthcheck_uri_path="",
                    use_head_request=True,
                )
            ),
        )
        attrs["endpoint"] = endpoint
    super().__init__(config=config, **attrs)
run(self)

Initialize and run the TensorBoard server.

Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
def run(self) -> None:
    """Initialize and run the TensorBoard server."""
    logger.info(
        "Starting TensorBoard service as blocking "
        "process... press CTRL+C once to stop it."
    )

    self.endpoint.prepare_for_start()

    try:
        tensorboard = program.TensorBoard(
            plugins=default.get_plugins(),
            subcommands=[uploader_subcommand.UploaderSubcommand()],
        )
        tensorboard.configure(
            logdir=self.config.logdir,
            port=self.endpoint.status.port,
            host="localhost",
            max_reload_threads=self.config.max_reload_threads,
            reload_interval=self.config.reload_interval,
        )
        tensorboard.main()
    except KeyboardInterrupt:
        logger.info(
            "TensorBoard service stopped. Resuming normal execution."
        )
TensorboardServiceConfig (LocalDaemonServiceConfig) pydantic-model

TensorBoard service configuration.

Attributes:

Name Type Description
logdir str

location of TensorBoard log files.

max_reload_threads int

the max number of threads that TensorBoard can use to reload runs. Each thread reloads one run at a time.

reload_interval int

how often the backend should load more data, in seconds. Set to 0 to load just once at startup.

Source code in zenml/integrations/tensorflow/services/tensorboard_service.py
class TensorboardServiceConfig(LocalDaemonServiceConfig):
    """TensorBoard service configuration.

    Attributes:
        logdir: location of TensorBoard log files.
        max_reload_threads: the max number of threads that TensorBoard can use
            to reload runs. Each thread reloads one run at a time.
        reload_interval: how often the backend should load more data, in
            seconds. Set to 0 to load just once at startup.
    """

    logdir: str
    max_reload_threads: int = 1
    reload_interval: int = 5
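
A minimal sketch of starting the service with this config; the `logdir` value is illustrative, and `run()` blocks in-process until CTRL+C (daemonized startup goes through the service base class instead):

from zenml.integrations.tensorflow.services.tensorboard_service import (
    TensorboardService,
    TensorboardServiceConfig,
)

service = TensorboardService(
    config=TensorboardServiceConfig(logdir="/tmp/tb-logs")
)
service.run()  # serves TensorBoard on the endpoint's allocated port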

steps special

Initialization for TensorFlow standard steps.

tensorflow_trainer

Implementation of a TensorFlow trainer step.

TensorflowBinaryClassifier (BaseTrainerStep)

A TensorFlow binary classifier.

This step implementation creates a simple TensorFlow feedforward neural network and trains it on a given pd.DataFrame dataset.

Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifier(BaseTrainerStep):
    """A TensorFlow binary classifier.

    This step implementation creates a simple TensorFlow feedforward
    neural network and trains it on a given pd.DataFrame dataset.
    """

    def entrypoint(  # type: ignore[override]
        self,
        train_dataset: pd.DataFrame,
        validation_dataset: pd.DataFrame,
        config: TensorflowBinaryClassifierConfig,
    ) -> tf.keras.Model:
        """Main entrypoint for the tensorflow trainer.

        Args:
            train_dataset: pd.DataFrame, the training dataset
            validation_dataset: pd.DataFrame, the validation dataset
            config: the configuration of the step

        Returns:
            the trained tf.keras.Model
        """
        model = tf.keras.Sequential()
        model.add(tf.keras.layers.InputLayer(input_shape=config.input_shape))
        model.add(tf.keras.layers.Flatten())

        last_layer = config.layers.pop()
        for layer in config.layers:
            model.add(tf.keras.layers.Dense(layer, activation="relu"))
        model.add(tf.keras.layers.Dense(last_layer, activation="sigmoid"))

        model.compile(
            optimizer=tf.keras.optimizers.Adam(config.learning_rate),
            loss=tf.keras.losses.BinaryCrossentropy(),
            metrics=config.metrics,
        )

        train_target = train_dataset.pop(config.target_column)
        validation_target = validation_dataset.pop(config.target_column)
        model.fit(
            x=train_dataset,
            y=train_target,
            validation_data=(validation_dataset, validation_target),
            batch_size=config.batch_size,
            epochs=config.epochs,
        )
        model.summary()

        return model
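
As a hedged usage sketch, the step can be instantiated with a custom config before being wired into a pipeline; the column name, layer sizes and input shape below are assumptions about your dataset:

from zenml.integrations.tensorflow.steps.tensorflow_trainer import (
    TensorflowBinaryClassifier,
    TensorflowBinaryClassifierConfig,
)

# Hypothetical dataset with 16 feature columns and a binary "label" column.
trainer = TensorflowBinaryClassifier(
    config=TensorflowBinaryClassifierConfig(
        target_column="label",
        layers=[128, 32, 1],  # the final entry must stay at 1 sigmoid unit
        input_shape=(16,),
        epochs=10,
    )
)

Note that entrypoint() pops the last entry of config.layers to build the sigmoid output layer (mutating the list in place), so the list should always end with 1.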
CONFIG_CLASS (BaseTrainerConfig) pydantic-model

Config class for the tensorflow trainer.

target_column: the name of the label column
layers: the number of units in the fully connected layers
input_shape: the shape of the input
learning_rate: the learning rate
metrics: the list of metrics to be computed
epochs: the number of epochs
batch_size: the size of the batch

Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifierConfig(BaseTrainerConfig):
    """Config class for the tensorflow trainer.

    target_column: the name of the label column
    layers: the number of units in the fully connected layers
    input_shape: the shape of the input
    learning_rate: the learning rate
    metrics: the list of metrics to be computed
    epochs: the number of epochs
    batch_size: the size of the batch
    """

    target_column: str
    layers: List[int] = [256, 64, 1]
    input_shape: Tuple[int] = (8,)
    learning_rate: float = 0.001
    metrics: List[str] = ["accuracy"]
    epochs: int = 50
    batch_size: int = 8
entrypoint(self, train_dataset, validation_dataset, config)

Main entrypoint for the tensorflow trainer.

Parameters:

Name Type Description Default
train_dataset DataFrame

pd.DataFrame, the training dataset

required
validation_dataset DataFrame

pd.DataFrame, the validation dataset

required
config TensorflowBinaryClassifierConfig

the configuration of the step

required

Returns:

Type Description
Model

the trained tf.keras.Model

Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
def entrypoint(  # type: ignore[override]
    self,
    train_dataset: pd.DataFrame,
    validation_dataset: pd.DataFrame,
    config: TensorflowBinaryClassifierConfig,
) -> tf.keras.Model:
    """Main entrypoint for the tensorflow trainer.

    Args:
        train_dataset: pd.DataFrame, the training dataset
        validation_dataset: pd.DataFrame, the validation dataset
        config: the configuration of the step

    Returns:
        the trained tf.keras.Model
    """
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.InputLayer(input_shape=config.input_shape))
    model.add(tf.keras.layers.Flatten())

    last_layer = config.layers.pop()
    for i, layer in enumerate(config.layers):
        model.add(tf.keras.layers.Dense(layer, activation="relu"))
    model.add(tf.keras.layers.Dense(last_layer, activation="sigmoid"))

    model.compile(
        optimizer=tf.keras.optimizers.Adam(config.learning_rate),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=config.metrics,
    )

    train_target = train_dataset.pop(config.target_column)
    validation_target = validation_dataset.pop(config.target_column)
    model.fit(
        x=train_dataset,
        y=train_target,
        validation_data=(validation_dataset, validation_target),
        batch_size=config.batch_size,
        epochs=config.epochs,
    )
    model.summary()

    return model
TensorflowBinaryClassifierConfig (BaseTrainerConfig) pydantic-model

Config class for the tensorflow trainer.

target_column: the name of the label column
layers: the number of units in the fully connected layers
input_shape: the shape of the input
learning_rate: the learning rate
metrics: the list of metrics to be computed
epochs: the number of epochs
batch_size: the size of the batch

Source code in zenml/integrations/tensorflow/steps/tensorflow_trainer.py
class TensorflowBinaryClassifierConfig(BaseTrainerConfig):
    """Config class for the tensorflow trainer.

    target_column: the name of the label column
    layers: the number of units in the fully connected layers
    input_shape: the shape of the input
    learning_rate: the learning rate
    metrics: the list of metrics to be computed
    epochs: the number of epochs
    batch_size: the size of the batch
    """

    target_column: str
    layers: List[int] = [256, 64, 1]
    input_shape: Tuple[int] = (8,)
    learning_rate: float = 0.001
    metrics: List[str] = ["accuracy"]
    epochs: int = 50
    batch_size: int = 8

visualizers special

Initialization for TensorFlow visualizer.

tensorboard_visualizer

Implementation of a TensorFlow visualizer step.

TensorboardVisualizer (BaseStepVisualizer)

The implementation of a TensorBoard Visualizer.

Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
class TensorboardVisualizer(BaseStepVisualizer):
    """The implementation of a Whylogs Visualizer."""

    @classmethod
    def find_running_tensorboard_server(
        cls, logdir: str
    ) -> Optional[TensorBoardInfo]:
        """Find a local TensorBoard server instance.

        Finds a server that is running for the supplied logdir location and
        returns its TCP port.

        Args:
            logdir: The logdir location where the TensorBoard server is running.

        Returns:
            The TensorBoardInfo describing the running TensorBoard server or
            None if no server is running for the supplied logdir location.
        """
        for server in get_all():
            if (
                server.logdir == logdir
                and server.pid
                and psutil.pid_exists(server.pid)
            ):
                return server
        return None

    def visualize(
        self,
        object: StepView,
        height: int = 800,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Start a TensorBoard server.

        Allows for the visualization of all models logged as artifacts by the
        indicated step. The server will monitor and display all the models
        logged by past and future step runs.

        Args:
            object: StepView fetched from run.get_step().
            height: Height of the generated visualization.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        for _, artifact_view in object.outputs.items():
            # filter out anything but model artifacts
            if artifact_view.type == ModelArtifact.TYPE_NAME:
                logdir = os.path.dirname(artifact_view.uri)

                # first check if a TensorBoard server is already running for
                # the same logdir location and use that one
                running_server = self.find_running_tensorboard_server(logdir)
                if running_server:
                    self.visualize_tensorboard(running_server.port, height)
                    return

                if sys.platform == "win32":
                    # Daemon service functionality is currently not supported on Windows
                    print(
                        "You can run:\n"
                        f"[italic green]    tensorboard --logdir {logdir}"
                        "[/italic green]\n"
                        "...to visualize the TensorBoard logs for your trained model."
                    )
                else:
                    # start a new TensorBoard server
                    service = TensorboardService(
                        TensorboardServiceConfig(
                            logdir=logdir,
                        )
                    )
                    service.start(timeout=20)
                    if service.endpoint.status.port:
                        self.visualize_tensorboard(
                            service.endpoint.status.port, height
                        )
                return

    def visualize_tensorboard(
        self,
        port: int,
        height: int,
    ) -> None:
        """Generate a visualization of a TensorBoard.

        Args:
            port: the TCP port where the TensorBoard server is listening for
                requests.
            height: Height of the generated visualization.
        """
        if Environment.in_notebook():

            notebook.display(port, height=height)
            return

        print(
            "You can visit:\n"
            f"[italic green]    http://localhost:{port}/[/italic green]\n"
            "...to visualize the TensorBoard logs for your trained model."
        )

    def stop(
        self,
        object: StepView,
    ) -> None:
        """Stop the TensorBoard server previously started for a pipeline step.

        Args:
            object: StepView fetched from run.get_step().
        """
        for _, artifact_view in object.outputs.items():
            # filter out anything but model artifacts
            if artifact_view.type == ModelArtifact.TYPE_NAME:
                logdir = os.path.dirname(artifact_view.uri)

                # first check if a TensorBoard server is already running for
                # the same logdir location and use that one
                running_server = self.find_running_tensorboard_server(logdir)
                if not running_server:
                    return

                logger.debug(
                    "Stopping tensorboard server with PID '%d' ...",
                    running_server.pid,
                )
                try:
                    p = psutil.Process(running_server.pid)
                except psutil.Error:
                    logger.error(
                        "Could not find process for PID '%d' ...",
                        running_server.pid,
                    )
                    continue
                p.kill()
                return
find_running_tensorboard_server(logdir) classmethod

Find a local TensorBoard server instance.

Finds a server that is running for the supplied logdir location and returns its TCP port.

Parameters:

Name Type Description Default
logdir str

The logdir location where the TensorBoard server is running.

required

Returns:

Type Description
Optional[tensorboard.manager.TensorBoardInfo]

The TensorBoardInfo describing the running TensorBoard server or None if no server is running for the supplied logdir location.

Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
@classmethod
def find_running_tensorboard_server(
    cls, logdir: str
) -> Optional[TensorBoardInfo]:
    """Find a local TensorBoard server instance.

    Finds a server that is running for the supplied logdir location and
    returns its TCP port.

    Args:
        logdir: The logdir location where the TensorBoard server is running.

    Returns:
        The TensorBoardInfo describing the running TensorBoard server or
        None if no server is running for the supplied logdir location.
    """
    for server in get_all():
        if (
            server.logdir == logdir
            and server.pid
            and psutil.pid_exists(server.pid)
        ):
            return server
    return None
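
For example (with a placeholder logdir; the path must match the running server's logdir exactly):

from zenml.integrations.tensorflow.visualizers.tensorboard_visualizer import (
    TensorboardVisualizer,
)

info = TensorboardVisualizer.find_running_tensorboard_server(
    "/tmp/my_model/logs"  # hypothetical logdir
)
if info:
    print(f"TensorBoard already running on port {info.port} (PID {info.pid})")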
stop(self, object)

Stop the TensorBoard server previously started for a pipeline step.

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def stop(
    self,
    object: StepView,
) -> None:
    """Stop the TensorBoard server previously started for a pipeline step.

    Args:
        object: StepView fetched from run.get_step().
    """
    for _, artifact_view in object.outputs.items():
        # filter out anything but model artifacts
        if artifact_view.type == ModelArtifact.TYPE_NAME:
            logdir = os.path.dirname(artifact_view.uri)

            # first check if a TensorBoard server is already running for
            # the same logdir location and use that one
            running_server = self.find_running_tensorboard_server(logdir)
            if not running_server:
                return

            logger.debug(
                "Stopping tensorboard server with PID '%d' ...",
                running_server.pid,
            )
            try:
                p = psutil.Process(running_server.pid)
            except psutil.Error:
                logger.error(
                    "Could not find process for PID '%d' ...",
                    running_server.pid,
                )
                continue
            p.kill()
            return
visualize(self, object, height=800, *args, **kwargs)

Start a TensorBoard server.

Allows for the visualization of all models logged as artifacts by the indicated step. The server will monitor and display all the models logged by past and future step runs.

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
height int

Height of the generated visualization.

800
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def visualize(
    self,
    object: StepView,
    height: int = 800,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Start a TensorBoard server.

    Allows for the visualization of all models logged as artifacts by the
    indicated step. The server will monitor and display all the models
    logged by past and future step runs.

    Args:
        object: StepView fetched from run.get_step().
        height: Height of the generated visualization.
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.
    """
    for _, artifact_view in object.outputs.items():
        # filter out anything but model artifacts
        if artifact_view.type == ModelArtifact.TYPE_NAME:
            logdir = os.path.dirname(artifact_view.uri)

            # first check if a TensorBoard server is already running for
            # the same logdir location and use that one
            running_server = self.find_running_tensorboard_server(logdir)
            if running_server:
                self.visualize_tensorboard(running_server.port, height)
                return

            if sys.platform == "win32":
                # Daemon service functionality is currently not supported on Windows
                print(
                    "You can run:\n"
                    f"[italic green]    tensorboard --logdir {logdir}"
                    "[/italic green]\n"
                    "...to visualize the TensorBoard logs for your trained model."
                )
            else:
                # start a new TensorBoard server
                service = TensorboardService(
                    TensorboardServiceConfig(
                        logdir=logdir,
                    )
                )
                service.start(timeout=20)
                if service.endpoint.status.port:
                    self.visualize_tensorboard(
                        service.endpoint.status.port, height
                    )
            return
visualize_tensorboard(self, port, height)

Generate a visualization of a TensorBoard.

Parameters:

Name Type Description Default
port int

the TCP port where the TensorBoard server is listening for requests.

required
height int

Height of the generated visualization.

required
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def visualize_tensorboard(
    self,
    port: int,
    height: int,
) -> None:
    """Generate a visualization of a TensorBoard.

    Args:
        port: the TCP port where the TensorBoard server is listening for
            requests.
        height: Height of the generated visualization.
    """
    if Environment.in_notebook():

        notebook.display(port, height=height)
        return

    print(
        "You can visit:\n"
        f"[italic green]    http://localhost:{port}/[/italic green]\n"
        "...to visualize the TensorBoard logs for your trained model."
    )
get_step(pipeline_name, step_name)

Get the StepView for the specified pipeline and step name.

Parameters:

Name Type Description Default
pipeline_name str

The name of the pipeline.

required
step_name str

The name of the step.

required

Returns:

Type Description
StepView

The StepView for the specified pipeline and step name.

Exceptions:

Type Description
RuntimeError

If the step is not found.

Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def get_step(pipeline_name: str, step_name: str) -> StepView:
    """Get the StepView for the specified pipeline and step name.

    Args:
        pipeline_name: The name of the pipeline.
        step_name: The name of the step.

    Returns:
        The StepView for the specified pipeline and step name.

    Raises:
        RuntimeError: If the step is not found.
    """
    repo = Repository()
    pipeline = repo.get_pipeline(pipeline_name)
    if pipeline is None:
        raise RuntimeError(f"No pipeline with name `{pipeline_name}` was found")

    last_run = pipeline.runs[-1]
    step = last_run.get_step(name=step_name)
    if step is None:
        raise RuntimeError(
            f"No pipeline step with name `{step_name}` was found in "
            f"pipeline `{pipeline_name}`"
        )
    return cast(StepView, step)
stop_tensorboard_server(pipeline_name, step_name)

Stop the TensorBoard server previously started for a pipeline step.

Parameters:

Name Type Description Default
pipeline_name str

the name of the pipeline

required
step_name str

pipeline step name

required
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def stop_tensorboard_server(pipeline_name: str, step_name: str) -> None:
    """Stop the TensorBoard server previously started for a pipeline step.

    Args:
        pipeline_name: the name of the pipeline
        step_name: pipeline step name
    """
    step = get_step(pipeline_name, step_name)
    TensorboardVisualizer().stop(step)
visualize_tensorboard(pipeline_name, step_name)

Start a TensorBoard server.

Allows for the visualization of all models logged as output by the named pipeline step. The server will monitor and display all the models logged by past and future step runs.

Parameters:

Name Type Description Default
pipeline_name str

the name of the pipeline

required
step_name str

pipeline step name

required
Source code in zenml/integrations/tensorflow/visualizers/tensorboard_visualizer.py
def visualize_tensorboard(pipeline_name: str, step_name: str) -> None:
    """Start a TensorBoard server.

    Allows for the visualization of all models logged as output by the named
    pipeline step. The server will monitor and display all the models logged by
    past and future step runs.

    Args:
        pipeline_name: the name of the pipeline
        step_name: pipeline step name
    """
    step = get_step(pipeline_name, step_name)
    TensorboardVisualizer().visualize(step)
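
Taken together, the two module-level helpers give a minimal start/stop workflow; the pipeline and step names below are placeholders:

from zenml.integrations.tensorflow.visualizers.tensorboard_visualizer import (
    stop_tensorboard_server,
    visualize_tensorboard,
)

# Start (or reuse) a TensorBoard server for the trainer step's model output.
visualize_tensorboard("training_pipeline", "trainer")

# ...inspect the logs, then shut the server down again.
stop_tensorboard_server("training_pipeline", "trainer")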

utils

Utility functions for the integrations module.

get_integration_for_module(module_name)

Gets the integration class for a module inside an integration.

If the module given by module_name is not part of a ZenML integration, this method will return None. If it is part of a ZenML integration, it will return the integration class found inside the integration __init__ file.

Parameters:

Name Type Description Default
module_name str

The name of the module to get the integration for.

required

Returns:

Type Description
Optional[Type[zenml.integrations.integration.Integration]]

The integration class for the module.

Source code in zenml/integrations/utils.py
def get_integration_for_module(
    module_name: str,
) -> Optional[Type[Integration]]:
    """Gets the integration class for a module inside an integration.

    If the module given by `module_name` is not part of a ZenML integration,
    this method will return `None`. If it is part of a ZenML integration,
    it will return the integration class found inside the integration
    __init__ file.

    Args:
        module_name: The name of the module to get the integration for.

    Returns:
        The integration class for the module.
    """
    integration_prefix = "zenml.integrations."
    if not module_name.startswith(integration_prefix):
        return None

    integration_module_name = ".".join(module_name.split(".", 3)[:3])
    try:
        integration_module = sys.modules[integration_module_name]
    except KeyError:
        integration_module = importlib.import_module(integration_module_name)

    for name, member in inspect.getmembers(integration_module):
        if (
            member is not Integration
            and isinstance(member, IntegrationMeta)
            and issubclass(member, Integration)
        ):
            return cast(Type[Integration], member)

    return None

get_requirements_for_module(module_name)

Gets requirements for a module inside an integration.

If the module given by module_name is not part of a ZenML integration, this method will return an empty list. If it is part of a ZenML integration, it will return the list of requirements specified inside the integration class found inside the integration __init__ file.

Parameters:

Name Type Description Default
module_name str

The name of the module to get requirements for.

required

Returns:

Type Description
List[str]

A list of requirements for the module.

Source code in zenml/integrations/utils.py
def get_requirements_for_module(module_name: str) -> List[str]:
    """Gets requirements for a module inside an integration.

    If the module given by `module_name` is not part of a ZenML integration,
    this method will return an empty list. If it is part of a ZenML integration,
    it will return the list of requirements specified inside the integration
    class found inside the integration __init__ file.

    Args:
        module_name: The name of the module to get requirements for.

    Returns:
        A list of requirements for the module.
    """
    integration = get_integration_for_module(module_name)
    return integration.REQUIREMENTS if integration else []
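
A short sketch of both helpers; the TensorFlow module path is just an illustrative member of the integrations package:

from zenml.integrations.utils import (
    get_integration_for_module,
    get_requirements_for_module,
)

module = "zenml.integrations.tensorflow.steps.tensorflow_trainer"
integration = get_integration_for_module(module)  # the TensorFlow integration class
print(get_requirements_for_module(module))        # its REQUIREMENTS list

# Modules outside zenml.integrations resolve to None / an empty list.
assert get_integration_for_module("os.path") is None
assert get_requirements_for_module("os.path") == []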

vault special

Initialization for the Vault Secrets Manager integration.

The Vault secrets manager integration submodule provides a way to access the HashiCorp Vault secrets manager from within your ZenML pipeline runs.

VaultSecretsManagerIntegration (Integration)

Definition of HashiCorp Vault integration with ZenML.

Source code in zenml/integrations/vault/__init__.py
class VaultSecretsManagerIntegration(Integration):
    """Definition of HashiCorp Vault integration with ZenML."""

    NAME = VAULT
    REQUIREMENTS = ["hvac>=0.11.2"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Vault integration.

        Returns:
            List of stack component flavors.
        """
        return [
            FlavorWrapper(
                name=VAULT_SECRETS_MANAGER_FLAVOR,
                source="zenml.integrations.vault.secrets_manager.VaultSecretsManager",
                type=StackComponentType.SECRETS_MANAGER,
                integration=cls.NAME,
            )
        ]
flavors() classmethod

Declare the stack component flavors for the Vault integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors.

Source code in zenml/integrations/vault/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Vault integration.

    Returns:
        List of stack component flavors.
    """
    return [
        FlavorWrapper(
            name=VAULT_SECRETS_MANAGER_FLAVOR,
            source="zenml.integrations.vault.secrets_manager.VaultSecretsManager",
            type=StackComponentType.SECRETS_MANAGER,
            integration=cls.NAME,
        )
    ]

secrets_manager special

HashiCorp Vault Secrets Manager.

vault_secrets_manager

Implementation of the HashiCorp Vault Secrets Manager integration.

VaultSecretsManager (BaseSecretsManager) pydantic-model

Class to interact with the Vault secrets manager - Key/value Engine.

Attributes:

Name Type Description
url str

The url of the Vault server.

token str

The token to use to authenticate with Vault.

cert Optional[str]

The path to the certificate to use to authenticate with Vault.

verify Optional[str]

Whether to verify the certificate or not.

mount_point str

The mount point of the secrets manager.

namespace Optional[str]

The namespace of the secrets manager.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
class VaultSecretsManager(BaseSecretsManager):
    """Class to interact with the Vault secrets manager - Key/value Engine.

    Attributes:
        url: The url of the Vault server.
        token: The token to use to authenticate with Vault.
        cert: The path to the certificate to use to authenticate with Vault.
        verify: Whether to verify the certificate or not.
        mount_point: The mount point of the secrets manager.
        namespace: The namespace of the secrets manager.
    """

    # Class configuration
    FLAVOR: ClassVar[str] = VAULT_SECRETS_MANAGER_FLAVOR
    CLIENT: ClassVar[Any] = None

    url: str
    token: str
    mount_point: str
    cert: Optional[str]
    verify: Optional[str]
    namespace: Optional[str]

    @classmethod
    def _ensure_client_connected(cls, url: str, token: str) -> None:
        """Ensure the client is connected.

        This function initializes the client if it is not initialized.

        Args:
            url: The url of the Vault server.
            token: The token to use to authenticate with Vault.
        """
        if cls.CLIENT is None:
            # Create a Vault Secrets Manager client
            cls.CLIENT = hvac.Client(
                url=url,
                token=token,
            )

    def _ensure_client_is_authenticated(self) -> None:
        """Ensure the client is authenticated.

        Raises:
            RuntimeError: If the client is not initialized or authenticated.
        """
        self._ensure_client_connected(url=self.url, token=self.token)

        if not self.CLIENT.is_authenticated():
            raise RuntimeError(
                "There was an error authenticating with Vault. Please check your configuration."
            )

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: The secret to register.

        Raises:
            SecretExistsError: If the secret already exists.
        """
        self._ensure_client_is_authenticated()

        secret_name = prepend_secret_schema_to_secret_name(secret=secret)

        if secret.name in self.get_all_secret_keys():
            raise SecretExistsError(
                f"A Secret with the name '{secret.name}' already exists."
            )

        self.CLIENT.secrets.kv.v2.create_or_update_secret(
            path=f"{ZENML_PATH}/{secret_name}",
            mount_point=self.mount_point,
            secret=secret.content,
        )

        logger.info("Created secret: %s", f"{ZENML_PATH}/{secret_name}")
        logger.info("Added value to secret.")

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets the value of a secret.

        Args:
            secret_name: The name of the secret to get.

        Returns:
            The secret.
        """
        vault_secret_name = self.vault_secret_name(secret_name)

        secret_items = (
            self.CLIENT.secrets.kv.v2.read_secret_version(
                path=f"{ZENML_PATH}/{vault_secret_name}",
                mount_point=self.mount_point,
            )
            .get("data", {})
            .get("data", {})
        )

        secret_schema = SecretSchemaClassRegistry.get_class(
            secret_schema=get_secret_schema_name(vault_secret_name)
        )
        secret_items["name"] = sanitize_secret_name(secret_name)
        return secret_schema(**secret_items)

    def vault_list_secrets(self) -> List[str]:
        """List all secrets in Vault without any reformatting.

        This function tries to get all secrets from Vault and returns
        them as a list of strings (all secrets' names)
        in the format: <secret_schema>-<secret_name>

        Returns:
            A list of all secrets in the secrets manager.
        """
        self._ensure_client_is_authenticated()

        set_of_secrets: Set[str] = set()
        try:
            secrets = self.CLIENT.secrets.kv.v2.list_secrets(
                path=f"{ZENML_PATH}/", mount_point=self.mount_point
            )
        except hvac.exceptions.InvalidPath:
            logger.error(
                f"There are no secrets created within the path `{ZENML_PATH}` "
            )
            return list(set_of_secrets)

        secrets_keys = secrets.get("data", {}).get("keys", [])
        for secret_key in secrets_keys:
            set_of_secrets.add(secret_key)
        return list(set_of_secrets)

    def vault_secret_name(self, secret_name: str) -> str:
        """Get the secret name in the form it is stored in Vault.

        This function retrieves the secret name in the Vault secrets manager, without
        any reformatting. This secret should be in the format `<secret_schema_name>-<secret_name>`.

        Args:
            secret_name: The name of the secret to get.

        Returns:
            The secret name in the Vault secrets manager.

        Raises:
            KeyError: If the secret does not exist.
        """
        self._ensure_client_is_authenticated()

        secrets_keys = self.vault_list_secrets()
        sanitized_secret_name = sanitize_secret_name(secret_name)

        for secret_key in secrets_keys:
            secret_name_without_schema = remove_secret_schema_name(secret_key)
            if sanitized_secret_name == secret_name_without_schema:
                return secret_key
        raise KeyError(f"The secret {secret_name} does not exist.")

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        This function tries to get all secrets from Vault and returns
        them as a list of strings. All secret names are returned without the
        schema name.

        Returns:
            A list of all secret keys in the secrets manager.
        """
        return [
            remove_secret_schema_name(secret_key)
            for secret_key in self.vault_list_secrets()
        ]

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret.

        Args:
            secret: The secret to update.

        Raises:
            KeyError: If the secret does not exist.
        """
        self._ensure_client_is_authenticated()

        secret_name = prepend_secret_schema_to_secret_name(secret=secret)

        if secret_name in self.vault_list_secrets():
            self.CLIENT.secrets.kv.v2.create_or_update_secret(
                path=f"{ZENML_PATH}/{secret_name}",
                mount_point=self.mount_point,
                secret=secret.content,
            )
        else:
            raise KeyError(
                f"A Secret with the name '{secret.name}' does not exist."
            )

        logger.info("Updated secret: %s", f"{ZENML_PATH}/{secret.name}")
        logger.info("Added value to secret.")

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        Args:
            secret_name: The name of the secret to delete.

        Raises:
            KeyError: If the secret does not exist.
        """
        self._ensure_client_is_authenticated()

        try:
            vault_secret_name = self.vault_secret_name(secret_name)
        except KeyError:
            raise KeyError(f"The secret {secret_name} does not exist.")

        self.CLIENT.secrets.kv.v2.delete_metadata_and_all_versions(
            path=f"{ZENML_PATH}/{vault_secret_name}",
            mount_point=self.mount_point,
        )

        logger.info("Deleted secret: %s", f"{ZENML_PATH}/{secret_name}")

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets."""
        self._ensure_client_is_authenticated()

        for secret_name in self.get_all_secret_keys():
            self.delete_secret(secret_name)

        logger.info("Deleted all secrets.")
delete_all_secrets(self)

Delete all existing secrets.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets."""
    self._ensure_client_is_authenticated()

    for secret_name in self.get_all_secret_keys():
        self.delete_secret(secret_name)

    logger.info("Deleted all secrets.")
delete_secret(self, secret_name)

Delete an existing secret.

Parameters:

Name Type Description Default
secret_name str

The name of the secret to delete.

required

Exceptions:

Type Description
KeyError

If the secret does not exist.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete an existing secret.

    Args:
        secret_name: The name of the secret to delete.

    Raises:
        KeyError: If the secret does not exist.
    """
    self._ensure_client_is_authenticated()

    try:
        vault_secret_name = self.vault_secret_name(secret_name)
    except KeyError:
        raise KeyError(f"The secret {secret_name} does not exist.")

    self.CLIENT.secrets.kv.v2.delete_metadata_and_all_versions(
        path=f"{ZENML_PATH}/{vault_secret_name}",
        mount_point=self.mount_point,
    )

    logger.info("Deleted secret: %s", f"{ZENML_PATH}/{secret_name}")
get_all_secret_keys(self)

Get all secret keys.

This function tries to get all secrets from Vault and returns them as a list of strings. All secret names are returned without the schema name.

Returns:

Type Description
List[str]

A list of all secret keys in the secrets manager.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Get all secret keys.

    This function tries to get all secrets from Vault and returns
    them as a list of strings. All secret names are returned without the
    schema name.

    Returns:
        A list of all secret keys in the secrets manager.
    """
    return [
        remove_secret_schema_name(secret_key)
        for secret_key in self.vault_list_secrets()
    ]
get_secret(self, secret_name)

Gets the value of a secret.

Parameters:

Name Type Description Default
secret_name str

The name of the secret to get.

required

Returns:

Type Description
BaseSecretSchema

The secret.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Gets the value of a secret.

    Args:
        secret_name: The name of the secret to get.

    Returns:
        The secret.
    """
    vault_secret_name = self.vault_secret_name(secret_name)

    secret_items = (
        self.CLIENT.secrets.kv.v2.read_secret_version(
            path=f"{ZENML_PATH}/{vault_secret_name}",
            mount_point=self.mount_point,
        )
        .get("data", {})
        .get("data", {})
    )

    secret_schema = SecretSchemaClassRegistry.get_class(
        secret_schema=get_secret_schema_name(vault_secret_name)
    )
    secret_items["name"] = sanitize_secret_name(secret_name)
    return secret_schema(**secret_items)
register_secret(self, secret)

Registers a new secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

The secret to register.

required

Exceptions:

Type Description
SecretExistsError

If the secret already exists.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Registers a new secret.

    Args:
        secret: The secret to register.

    Raises:
        SecretExistsError: If the secret already exists.
    """
    self._ensure_client_is_authenticated()

    secret_name = prepend_secret_schema_to_secret_name(secret=secret)

    if secret.name in self.get_all_secret_keys():
        raise SecretExistsError(
            f"A Secret with the name '{secret.name}' already exists."
        )

    self.CLIENT.secrets.kv.v2.create_or_update_secret(
        path=f"{ZENML_PATH}/{secret_name}",
        mount_point=self.mount_point,
        secret=secret.content,
    )

    logger.info("Created secret: %s", f"{ZENML_PATH}/{secret_name}")
    logger.info("Added value to secret.")
update_secret(self, secret)

Update an existing secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

The secret to update.

required

Exceptions:

Type Description
KeyError

If the secret does not exist.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret.

    Args:
        secret: The secret to update.

    Raises:
        KeyError: If the secret does not exist.
    """
    self._ensure_client_is_authenticated()

    secret_name = prepend_secret_schema_to_secret_name(secret=secret)

    if secret_name in self.vault_list_secrets():
        self.CLIENT.secrets.kv.v2.create_or_update_secret(
            path=f"{ZENML_PATH}/{secret_name}",
            mount_point=self.mount_point,
            secret=secret.content,
        )
    else:
        raise KeyError(
            f"A Secret with the name '{secret.name}' does not exist."
        )

    logger.info("Updated secret: %s", f"{ZENML_PATH}/{secret.name}")
    logger.info("Added value to secret.")
vault_list_secrets(self)

List all secrets in Vault without any reformatting.

This function tries to get all secrets from Vault and returns them as a list of strings (all secrets' names) in the format: <secret_schema>-<secret_name>

Returns:

Type Description
List[str]

A list of all secrets in the secrets manager.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def vault_list_secrets(self) -> List[str]:
    """List all secrets in Vault without any reformatting.

    This function tries to get all secrets from Vault and returns
    them as a list of strings (all secrets' names)
    in the format: <secret_schema>-<secret_name>

    Returns:
        A list of all secrets in the secrets manager.
    """
    self._ensure_client_is_authenticated()

    set_of_secrets: Set[str] = set()
    try:
        secrets = self.CLIENT.secrets.kv.v2.list_secrets(
            path=f"{ZENML_PATH}/", mount_point=self.mount_point
        )
    except hvac.exceptions.InvalidPath:
        logger.error(
            f"There are no secrets created within the path `{ZENML_PATH}` "
        )
        return list(set_of_secrets)

    secrets_keys = secrets.get("data", {}).get("keys", [])
    for secret_key in secrets_keys:
        set_of_secrets.add(secret_key)
    return list(set_of_secrets)
vault_secret_name(self, secret_name)

Get the secret name in the form it is stored in Vault.

This function retrieves the secret name in the Vault secrets manager, without any reformatting. This secret should be in the format <secret_schema_name>-<secret_name>.

Parameters:

Name Type Description Default
secret_name str

The name of the secret to get.

required

Returns:

Type Description
str

The secret name in the Vault secrets manager.

Exceptions:

Type Description
KeyError

If the secret does not exist.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def vault_secret_name(self, secret_name: str) -> str:
    """Get the secret name in the form it is stored in Vault.

    This function retrieves the secret name in the Vault secrets manager, without
    any reformatting. This secret should be in the format `<secret_schema_name>-<secret_name>`.

    Args:
        secret_name: The name of the secret to get.

    Returns:
        The secret name in the Vault secrets manager.

    Raises:
        KeyError: If the secret does not exist.
    """
    self._ensure_client_is_authenticated()

    secrets_keys = self.vault_list_secrets()
    sanitized_secret_name = sanitize_secret_name(secret_name)

    for secret_key in secrets_keys:
        secret_name_without_schema = remove_secret_schema_name(secret_key)
        if sanitized_secret_name == secret_name_without_schema:
            return secret_key
    raise KeyError(f"The secret {secret_name} does not exist.")
get_secret_schema_name(combined_secret_name)

Get the secret schema name from the secret name.

Parameters:

Name Type Description Default
combined_secret_name str

The secret name to get the secret schema name from.

required

Returns:

Type Description
str

The secret schema name.

Exceptions:

Type Description
RuntimeError

If the secret name does not have a secret schema name on it.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def get_secret_schema_name(combined_secret_name: str) -> str:
    """Get the secret schema name from the secret name.

    Args:
        combined_secret_name: The secret name to get the secret schema name from.

    Returns:
        The secret schema name.

    Raises:
        RuntimeError: If the secret name does not have a secret schema name on it.
    """
    if "-" in combined_secret_name:
        return combined_secret_name.split("-")[0]
    else:
        raise RuntimeError(
            f"Secret name `{combined_secret_name}` does not have a "
            f"Secret_schema name on it."
        )
prepend_secret_schema_to_secret_name(secret)

Prepend the secret schema name to the secret name.

Parameters:

Name Type Description Default
secret BaseSecretSchema

The secret to prepend the secret schema name to.

required

Returns:

Type Description
str

The secret name with the secret schema name prepended.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def prepend_secret_schema_to_secret_name(secret: BaseSecretSchema) -> str:
    """Prepend the secret schema name to the secret name.

    Args:
        secret: The secret to prepend the secret schema name to.

    Returns:
        The secret name with the secret schema name prepended.
    """
    secret_name = sanitize_secret_name(secret.name)
    secret_schema_name = secret.TYPE
    return f"{secret_schema_name}-{secret_name}"
remove_secret_schema_name(combined_secret_name)

Remove the secret schema name from the secret name.

Parameters:

Name Type Description Default
combined_secret_name str

The secret name to remove the secret schema name from.

required

Returns:

Type Description
str

The secret name without the secret schema name.

Exceptions:

Type Description
RuntimeError

If the secret name does not have a secret schema name on it.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def remove_secret_schema_name(combined_secret_name: str) -> str:
    """Remove the secret schema name from the secret name.

    Args:
        combined_secret_name: The secret name to remove the secret schema name from.

    Returns:
        The secret name without the secret schema name.

    Raises:
        RuntimeError: If the secret name does not have a secret schema name on it.
    """
    if "-" in combined_secret_name:
        return combined_secret_name.split("-")[1]
    else:
        raise RuntimeError(
            f"Secret name `{combined_secret_name}` does not have a "
            f"Secret_schema name on it."
        )
sanitize_secret_name(secret_name)

Sanitize the secret name to be used in Vault.

Parameters:

Name Type Description Default
secret_name str

The secret name to sanitize.

required

Returns:

Type Description
str

The sanitized secret name.

Source code in zenml/integrations/vault/secrets_manager/vault_secrets_manager.py
def sanitize_secret_name(secret_name: str) -> str:
    """Sanitize the secret name to be used in Vault.

    Args:
        secret_name: The secret name to sanitize.

    Returns:
        The sanitized secret name.
    """
    return re.sub(r"[^0-9a-zA-Z_\.]+", "_", secret_name).strip("-_.")
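
The three naming helpers compose as follows; the combined name is a hypothetical example of the <secret_schema>-<secret_name> layout described above:

from zenml.integrations.vault.secrets_manager.vault_secrets_manager import (
    get_secret_schema_name,
    remove_secret_schema_name,
    sanitize_secret_name,
)

combined = "aws-my_secret"  # hypothetical <secret_schema>-<secret_name>
assert get_secret_schema_name(combined) == "aws"
assert remove_secret_schema_name(combined) == "my_secret"

# Characters outside [0-9a-zA-Z_.] are collapsed to underscores.
assert sanitize_secret_name("my secret/name") == "my_secret_name"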

wandb special

Initialization for the wandb integration.

The wandb integration currently enables you to use wandb tracking as a convenient way to visualize your experiment runs within the wandb UI.

WandbIntegration (Integration)

Definition of the Weights &amp; Biases integration for ZenML.

Source code in zenml/integrations/wandb/__init__.py
class WandbIntegration(Integration):
    """Definition of Plotly integration for ZenML."""

    NAME = WANDB
    REQUIREMENTS = ["wandb>=0.12.12", "Pillow>=9.1.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Weights and Biases integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=WANDB_EXPERIMENT_TRACKER_FLAVOR,
                source="zenml.integrations.wandb.experiment_trackers.WandbExperimentTracker",
                type=StackComponentType.EXPERIMENT_TRACKER,
                integration=cls.NAME,
            )
        ]
flavors() classmethod

Declare the stack component flavors for the Weights and Biases integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/wandb/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Weights and Biases integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=WANDB_EXPERIMENT_TRACKER_FLAVOR,
            source="zenml.integrations.wandb.experiment_trackers.WandbExperimentTracker",
            type=StackComponentType.EXPERIMENT_TRACKER,
            integration=cls.NAME,
        )
    ]

experiment_trackers special

Initialization for the wandb experiment tracker.

wandb_experiment_tracker

Implementation for the wandb experiment tracker.

WandbExperimentTracker (BaseExperimentTracker) pydantic-model

Stores wandb configuration options.

ZenML should take care of configuring wandb for you, but should you still need access to the configuration inside your step you can do it using a step context:

from zenml.steps import StepContext

@enable_wandb
@step
def my_step(context: StepContext, ...)
    context.stack.experiment_tracker  # get the tracking_uri etc. from here

Attributes:

Name Type Description
entity Optional[str]

Name of an existing wandb entity.

project_name Optional[str]

Name of an existing wandb project to log to.

api_key str

API key that should be authorized to log to the configured wandb entity and project.

Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
class WandbExperimentTracker(BaseExperimentTracker):
    """Stores wandb configuration options.

    ZenML should take care of configuring wandb for you, but should you still
    need access to the configuration inside your step you can do it using a
    step context:
    ```python
    from zenml.steps import StepContext

    @enable_wandb
    @step
    def my_step(context: StepContext, ...)
        context.stack.experiment_tracker  # get the tracking_uri etc. from here
    ```

    Attributes:
        entity: Name of an existing wandb entity.
        project_name: Name of an existing wandb project to log to.
        api_key: API key that should be authorized to log to the configured
            wandb entity and project.
    """

    api_key: str
    entity: Optional[str] = None
    project_name: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = WANDB_EXPERIMENT_TRACKER_FLAVOR

    def prepare_step_run(self) -> None:
        """Sets the wandb api key."""
        os.environ[WANDB_API_KEY] = self.api_key

    @contextmanager
    def activate_wandb_run(
        self,
        run_name: str,
        tags: Tuple[str, ...] = (),
        settings: Optional[wandb.Settings] = None,
    ) -> Iterator[None]:
        """Activates a wandb run for the duration of this context manager.

        Anything logged to wandb while this context manager is active will
        automatically be logged to the wandb run configured by the run name
        passed as an argument to this function.

        Args:
            run_name: Name of the wandb run to create.
            tags: Tags to attach to the wandb run.
            settings: Additional settings for the wandb run.

        Yields:
            None
        """
        try:
            logger.info(
                f"Initializing wandb with project name: {self.project_name}, "
                f"run_name: {run_name}, entity: {self.entity}."
            )
            wandb.init(
                project=self.project_name,
                name=run_name,
                entity=self.entity,
                settings=settings,
                tags=tags,
            )
            yield
        finally:
            wandb.finish()
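
Outside of the decorator flow, the tracker can also be driven by hand; in this sketch, tracker stands for the WandbExperimentTracker instance from your active stack, and the run name, tags and settings are assumptions:

import wandb

tracker.prepare_step_run()  # exports WANDB_API_KEY for the wandb client
with tracker.activate_wandb_run(
    run_name="manual_debug_run",  # hypothetical run name
    tags=("debugging",),
    settings=wandb.Settings(start_method="thread"),
):
    wandb.log({"val_accuracy": 0.93})  # goes to the active run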
activate_wandb_run(self, run_name, tags=(), settings=None)

Activates a wandb run for the duration of this context manager.

Anything logged to wandb while this context manager is active will automatically be logged to the wandb run configured by the run name passed as an argument to this function.

Parameters:

Name Type Description Default
run_name str

Name of the wandb run to create.

required
tags Tuple[str, ...]

Tags to attach to the wandb run.

()
settings Optional[wandb.sdk.wandb_settings.Settings]

Additional settings for the wandb run.

None

Yields:

Type Description
Iterator[NoneType]

None

Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
@contextmanager
def activate_wandb_run(
    self,
    run_name: str,
    tags: Tuple[str, ...] = (),
    settings: Optional[wandb.Settings] = None,
) -> Iterator[None]:
    """Activates a wandb run for the duration of this context manager.

    Anything logged to wandb while this context manager is active will
    automatically be logged to the wandb run configured by the run name
    passed as an argument to this function.

    Args:
        run_name: Name of the wandb run to create.
        tags: Tags to attach to the wandb run.
        settings: Additional settings for the wandb run.

    Yields:
        None
    """
    try:
        logger.info(
            f"Initializing wandb with project name: {self.project_name}, "
            f"run_name: {run_name}, entity: {self.entity}."
        )
        wandb.init(
            project=self.project_name,
            name=run_name,
            entity=self.entity,
            settings=settings,
            tags=tags,
        )
        yield
    finally:
        wandb.finish()
prepare_step_run(self)

Sets the wandb api key.

Source code in zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py
def prepare_step_run(self) -> None:
    """Sets the wandb api key."""
    os.environ[WANDB_API_KEY] = self.api_key

wandb_step_decorator

Implementation for the wandb step decorator.

enable_wandb(_step=None, *, settings=None)

Decorator to enable wandb for a step function.

Apply this decorator to a ZenML pipeline step to enable wandb experiment tracking. The wandb tracking configuration (project name, experiment name, entity) will be automatically configured before the step code is executed, so the step can simply use the wandb module to log metrics and artifacts, like so:

@enable_wandb
@step
def tf_evaluator(
    x_test: np.ndarray,
    y_test: np.ndarray,
    model: tf.keras.Model,
) -> float:
    _, test_acc = model.evaluate(x_test, y_test, verbose=2)
    wandb.log({"val_accuracy": test_acc})
    return test_acc

You can also use this decorator with our class-based API like so:

@enable_wandb
class TFEvaluator(BaseStep):
    def entrypoint(
        self,
        x_test: np.ndarray,
        y_test: np.ndarray,
        model: tf.keras.Model,
    ) -> float:
        ...

All wandb artifacts and metrics logged from all the steps in a pipeline run are by default grouped under a single experiment named after the pipeline. To log wandb artifacts and metrics from a step in a separate wandb experiment, pass a custom experiment_name argument value to the decorator.

Parameters:

Name Type Description Default
_step Optional[~S]

The decorated step class.

None
settings Optional[wandb.sdk.wandb_settings.Settings]

wandb settings to use for the step.

None

Returns:

Type Description
Union[~S, Callable[[~S], ~S]]

The inner decorator which enhances the input step class with wandb tracking functionality

Source code in zenml/integrations/wandb/wandb_step_decorator.py
def enable_wandb(
    _step: Optional[S] = None, *, settings: Optional[wandb.Settings] = None
) -> Union[S, Callable[[S], S]]:
    """Decorator to enable wandb for a step function.

    Apply this decorator to a ZenML pipeline step to enable wandb experiment
    tracking. The wandb tracking configuration (project name, experiment name,
    entity) will be automatically configured before the step code is executed,
    so the step can simply use the `wandb` module to log metrics and artifacts,
    like so:

    ```python
    @enable_wandb
    @step
    def tf_evaluator(
        x_test: np.ndarray,
        y_test: np.ndarray,
        model: tf.keras.Model,
    ) -> float:
        _, test_acc = model.evaluate(x_test, y_test, verbose=2)
        wandb.log({"val_accuracy": test_acc})
        return test_acc
    ```

    You can also use this decorator with our class-based API like so:
    ```
    @enable_wandb
    class TFEvaluator(BaseStep):
        def entrypoint(
            self,
            x_test: np.ndarray,
            y_test: np.ndarray,
            model: tf.keras.Model,
        ) -> float:
            ...
    ```

    All wandb artifacts and metrics logged from all the steps in a pipeline
    run are by default grouped under a single experiment named after the
    pipeline. To log wandb artifacts and metrics from a step in a separate
    wandb experiment, pass a custom `experiment_name` argument value to the
    decorator.

    Args:
        _step: The decorated step class.
        settings: wandb settings to use for the step.

    Returns:
        The inner decorator which enhances the input step class with wandb
        tracking functionality
    """

    def inner_decorator(_step: S) -> S:
        """Inner decorator for step enable_wandb.

        Args:
            _step: The decorated step class.

        Returns:
            The decorated step class.

        Raises:
            RuntimeError: If the decorator is not being applied to a ZenML step
                decorated function or a BaseStep subclass.
        """
        logger.debug(
            "Applying 'enable_wandb' decorator to step %s", _step.__name__
        )
        if not issubclass(_step, BaseStep):
            raise RuntimeError(
                "The `enable_wandb` decorator can only be applied to a ZenML "
                "`step` decorated function or a BaseStep subclass."
            )
        source_fn = getattr(_step, STEP_INNER_FUNC_NAME)
        new_entrypoint = wandb_step_entrypoint(
            settings=settings,
        )(source_fn)
        if _step._created_by_functional_api():
            # If the step was created by the functional API, the old entrypoint
            # was a static method -> make sure the new one is as well
            new_entrypoint = staticmethod(new_entrypoint)

        setattr(_step, STEP_INNER_FUNC_NAME, new_entrypoint)
        return _step

    if _step is None:
        return inner_decorator
    else:
        return inner_decorator(_step)
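
Since the decorator also accepts keyword arguments, a parametrized form is possible; this sketch assumes wandb>=0.12 settings semantics and a purely illustrative step:

import wandb
from zenml.integrations.wandb.wandb_step_decorator import enable_wandb
from zenml.steps import step

@enable_wandb(settings=wandb.Settings(start_method="thread"))
@step
def my_evaluator(accuracy: float) -> float:
    wandb.log({"val_accuracy": accuracy})  # tracked in the step's wandb run
    return accuracy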
wandb_step_entrypoint(settings=None)

Decorator for a step entrypoint to enable wandb.

Parameters:

Name Type Description Default
settings Optional[wandb.sdk.wandb_settings.Settings]

wandb settings to use for the step.

None

Returns:

Type Description
Callable[[~F], ~F]

the input function enhanced with wandb tracking functionality

Source code in zenml/integrations/wandb/wandb_step_decorator.py
def wandb_step_entrypoint(
    settings: Optional[wandb.Settings] = None,
) -> Callable[[F], F]:
    """Decorator for a step entrypoint to enable wandb.

    Args:
        settings: wandb settings to use for the step.

    Returns:
        the input function enhanced with wandb tracking functionality
    """

    def inner_decorator(func: F) -> F:
        """Inner decorator for step entrypoint.

        Args:
            func: The decorated function.

        Returns:
            the input function enhanced with wandb tracking functionality
        """
        logger.debug(
            "Applying 'wandb_step_entrypoint' decorator to step entrypoint %s",
            func.__name__,
        )

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa
            """Wrapper function for decorator.

            Args:
                *args: positional arguments to the decorated function.
                **kwargs: keyword arguments to the decorated function.

            Returns:
                The return value of the decorated function.

            Raises:
                ValueError: if the active stack has no active experiment tracker.
            """
            logger.debug(
                "Setting up wandb backend before running step entrypoint %s",
                func.__name__,
            )
            step_env = Environment().step_environment
            run_name = f"{step_env.pipeline_run_id}_{step_env.step_name}"
            tags = (step_env.pipeline_name, step_env.pipeline_run_id)

            experiment_tracker = Repository(  # type: ignore[call-arg]
                skip_repository_check=True
            ).active_stack.experiment_tracker

            if not isinstance(experiment_tracker, WandbExperimentTracker):
                raise ValueError(
                    "The active stack needs to have a wandb experiment tracker "
                    "component registered to be able to track experiments "
                    "using wandb. You can create a new stack with a wandb "
                    "experiment tracker component or update your existing "
                    "stack to add this component, e.g.:\n\n"
                    "  'zenml experiment-tracker register wandb_tracker "
                    "--type=wandb --entity=<WANDB_ENTITY> --project_name="
                    "<WANDB_PROJECT_NAME> --api_key=<WANDB_API_KEY>'\n"
                    "  'zenml stack register stack-name -e wandb_tracker ...'\n"
                )

            with experiment_tracker.activate_wandb_run(
                run_name=run_name,
                tags=tags,
                settings=settings,
            ):
                return func(*args, **kwargs)

        return cast(F, wrapper)

    return inner_decorator

whylogs special

Initialization of the whylogs integration.

WhylogsIntegration (Integration)

Definition of whylogs integration for ZenML.

Source code in zenml/integrations/whylogs/__init__.py
class WhylogsIntegration(Integration):
    """Definition of [whylogs](https://github.com/whylabs/whylogs) integration for ZenML."""

    NAME = WHYLOGS
    REQUIREMENTS = ["whylogs<=0.7.9", "pybars3>=0.9.7"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration."""
        from zenml.integrations.whylogs import materializers  # noqa
        from zenml.integrations.whylogs import visualizers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/whylogs/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration."""
    from zenml.integrations.whylogs import materializers  # noqa
    from zenml.integrations.whylogs import visualizers  # noqa

materializers special

Initialization of the whylogs materializer.

whylogs_materializer

Implementation of the whylogs materializer.

WhylogsMaterializer (BaseMaterializer)

Materializer to read/write whylogs dataset profiles.

Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
class WhylogsMaterializer(BaseMaterializer):
    """Materializer to read/write whylogs dataset profiles."""

    ASSOCIATED_TYPES = (DatasetProfile,)
    ASSOCIATED_ARTIFACT_TYPES = (StatisticsArtifact,)

    def handle_input(self, data_type: Type[Any]) -> DatasetProfile:
        """Reads and returns a whylogs DatasetProfile.

        Args:
            data_type: The type of the data to read.

        Returns:
            A loaded whylogs DatasetProfile.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)
        with fileio.open(filepath, "rb") as f:
            protobuf = DatasetProfile.parse_delimited(f.read())[0]
        return protobuf

    def handle_return(self, profile: DatasetProfile) -> None:
        """Writes a whylogs DatasetProfile.

        Args:
            profile: A DatasetProfile object from whylogs.
        """
        super().handle_return(profile)
        filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)
        protobuf = profile.serialize_delimited()
        with fileio.open(filepath, "wb") as f:
            f.write(protobuf)

        # TODO [ENG-439]: uploading profiles to whylabs should be enabled and
        #  configurable at step level or pipeline level instead of being
        #  globally enabled.
        if os.environ.get("WHYLABS_DEFAULT_ORG_ID"):
            upload_profile(profile)
handle_input(self, data_type)

Reads and returns a whylogs DatasetProfile.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
DatasetProfile

A loaded whylogs DatasetProfile.

Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
def handle_input(self, data_type: Type[Any]) -> DatasetProfile:
    """Reads and returns a whylogs DatasetProfile.

    Args:
        data_type: The type of the data to read.

    Returns:
        A loaded whylogs DatasetProfile.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)
    with fileio.open(filepath, "rb") as f:
        protobuf = DatasetProfile.parse_delimited(f.read())[0]
    return protobuf
handle_return(self, profile)

Writes a whylogs DatasetProfile.

Parameters:

Name Type Description Default
profile DatasetProfile

A DatasetProfile object from whylogs.

required
Source code in zenml/integrations/whylogs/materializers/whylogs_materializer.py
def handle_return(self, profile: DatasetProfile) -> None:
    """Writes a whylogs DatasetProfile.

    Args:
        profile: A DatasetProfile object from whylogs.
    """
    super().handle_return(profile)
    filepath = os.path.join(self.artifact.uri, PROFILE_FILENAME)
    protobuf = profile.serialize_delimited()
    with fileio.open(filepath, "wb") as f:
        f.write(protobuf)

    # TODO [ENG-439]: uploading profiles to whylabs should be enabled and
    #  configurable at step level or pipeline level instead of being
    #  globally enabled.
    if os.environ.get("WHYLABS_DEFAULT_ORG_ID"):
        upload_profile(profile)
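
As a standalone sketch of the serialization round-trip this materializer performs, using the whylogs v0 API pinned by the integration and no artifact store:

```python
import pandas as pd
from whylogs import DatasetProfile

# Build a small profile directly, then round-trip it through the same
# delimited-protobuf serialization used by the materializer.
profile = DatasetProfile(name="example")
profile.track_dataframe(pd.DataFrame({"value": [1, 2, 3]}))

data = profile.serialize_delimited()                # bytes written by handle_return
restored = DatasetProfile.parse_delimited(data)[0]  # read back by handle_input
```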

steps special

Initialization of the whylogs steps.

whylogs_profiler

Implementation of the whylogs profiler step.

WhylogsProfilerConfig (BaseAnalyzerConfig) pydantic-model

Config class for the WhylogsProfiler step.

Attributes:

Name Type Description
dataset_name Optional[str]

the name of the dataset (Optional). If not specified, the pipeline step name is used

dataset_timestamp Optional[datetime.datetime]

timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied.

tags Optional[Dict[str, str]]

custom metadata tags associated with the whylogs profile

Also see WhylogsContext.log_dataframe.

Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerConfig(BaseAnalyzerConfig):
    """Config class for the WhylogsProfiler step.

    Attributes:
        dataset_name: the name of the dataset (Optional). If not specified,
            the pipeline step name is used
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time is used if not
            supplied.
        tags: custom metadata tags associated with the whylogs profile

    Also see `WhylogsContext.log_dataframe`.
    """

    dataset_name: Optional[str] = None
    dataset_timestamp: Optional[datetime.datetime]
    tags: Optional[Dict[str, str]] = None
WhylogsProfilerStep (BaseAnalyzerStep)

Generates a whylogs data profile from a given pd.DataFrame.

Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerStep(BaseAnalyzerStep):
    """Generates a whylogs data profile from a given pd.DataFrame."""

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: WhylogsProfilerConfig,
        context: StepContext,
    ) -> DatasetProfile:
        """Main entrypoint function for the whylogs profiler.

        Args:
            dataset: pd.DataFrame, the given dataset
            config: the configuration of the step
            context: the context of the step

        Returns:
            whylogs profile with statistics generated for the input dataset
        """
        whylogs_context = WhylogsContext(context)
        profile = whylogs_context.profile_dataframe(
            dataset, dataset_name=config.dataset_name, tags=config.tags
        )
        return profile
CONFIG_CLASS (BaseAnalyzerConfig) pydantic-model

Config class for the WhylogsProfiler step.

Attributes:

Name Type Description
dataset_name Optional[str]

the name of the dataset (Optional). If not specified, the pipeline step name is used

dataset_timestamp Optional[datetime.datetime]

timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied.

tags Optional[Dict[str, str]]

custom metadata tags associated with the whylogs profile

Also see WhylogsContext.log_dataframe.

Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
class WhylogsProfilerConfig(BaseAnalyzerConfig):
    """Config class for the WhylogsProfiler step.

    Attributes:
        dataset_name: the name of the dataset (Optional). If not specified,
            the pipeline step name is used
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time is used if not
            supplied.
        tags: custom metadata tags associated with the whylogs profile

    Also see `WhylogsContext.log_dataframe`.
    """

    dataset_name: Optional[str] = None
    dataset_timestamp: Optional[datetime.datetime]
    tags: Optional[Dict[str, str]] = None
entrypoint(self, dataset, config, context)

Main entrypoint function for the whylogs profiler.

Parameters:

Name Type Description Default
dataset DataFrame

pd.DataFrame, the given dataset

required
config WhylogsProfilerConfig

the configuration of the step

required
context StepContext

the context of the step

required

Returns:

Type Description
DatasetProfile

whylogs profile with statistics generated for the input dataset

Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: WhylogsProfilerConfig,
    context: StepContext,
) -> DatasetProfile:
    """Main entrypoint function for the whylogs profiler.

    Args:
        dataset: pd.DataFrame, the given dataset
        config: the configuration of the step
        context: the context of the step

    Returns:
        whylogs profile with statistics generated for the input dataset
    """
    whylogs_context = WhylogsContext(context)
    profile = whylogs_context.profile_dataframe(
        dataset, dataset_name=config.dataset_name, tags=config.tags
    )
    return profile
whylogs_profiler_step(step_name, enable_cache=None, dataset_name=None, dataset_timestamp=None, tags=None)

Shortcut function to create a new instance of the WhylogsProfilerStep step.

The returned WhylogsProfilerStep can be used in a pipeline to generate a whylogs DatasetProfile from a given pd.DataFrame and save it as an artifact.

Parameters:

Name Type Description Default
step_name str

The name of the step

required
enable_cache Optional[bool]

Specify whether caching is enabled for this step. If no value is passed, caching is enabled by default

None
dataset_name Optional[str]

the dataset name to be used for the whylogs profile (Optional). If not specified, the step name is used

None
dataset_timestamp Optional[datetime.datetime]

timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied.

None
tags Optional[Dict[str, str]]

custom metadata tags associated with the whylogs profile

None

Returns:

Type Description
WhylogsProfilerStep

a WhylogsProfilerStep step instance

Source code in zenml/integrations/whylogs/steps/whylogs_profiler.py
def whylogs_profiler_step(
    step_name: str,
    enable_cache: Optional[bool] = None,
    dataset_name: Optional[str] = None,
    dataset_timestamp: Optional[datetime.datetime] = None,
    tags: Optional[Dict[str, str]] = None,
) -> WhylogsProfilerStep:
    """Shortcut function to create a new instance of the WhylogsProfilerStep step.

    The returned WhylogsProfilerStep can be used in a pipeline to generate a
    whylogs DatasetProfile from a given pd.DataFrame and save it as an artifact.

    Args:
        step_name: The name of the step
        enable_cache: Specify whether caching is enabled for this step. If no
            value is passed, caching is enabled by default
        dataset_name: the dataset name to be used for the whylogs profile
            (Optional). If not specified, the step name is used
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time is used if not
            supplied.
        tags: custom metadata tags associated with the whylogs profile

    Returns:
        a WhylogsProfilerStep step instance
    """
    # enable cache explicitly to compensate for the fact that this step
    # takes in a context object
    if enable_cache is None:
        enable_cache = True

    step_type = type(
        step_name,
        (WhylogsProfilerStep,),
        {
            INSTANCE_CONFIGURATION: {
                PARAM_ENABLE_CACHE: enable_cache,
                PARAM_CREATED_BY_FUNCTIONAL_API: True,
            },
        },
    )

    return cast(
        WhylogsProfilerStep,
        step_type(
            WhylogsProfilerConfig(
                dataset_name=dataset_name,
                dataset_timestamp=dataset_timestamp,
                tags=tags,
            )
        ),
    )
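
A sketch of using the shortcut to wire a profiler into a pipeline; the import path is assumed from the module layout above, and the pipeline definition and step names are illustrative:

```python
from zenml.integrations.whylogs.steps import whylogs_profiler_step
from zenml.pipelines import pipeline

# Create a named profiler step instance via the shortcut.
train_data_profiler = whylogs_profiler_step(step_name="train_data_profiler")

@pipeline
def profiling_pipeline(importer, profiler):
    data = importer()
    profiler(dataset=data)  # keyword matches the entrypoint's `dataset` argument

# profiling_pipeline(importer=my_importer(), profiler=train_data_profiler).run()
```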

visualizers special

Initialization of the whylogs visualizer.

whylogs_visualizer

Implementation of the whylogs visualizer step.

WhylogsPlots (StrEnum)

All supported whylogs plot types.

Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
class WhylogsPlots(StrEnum):
    """All supported whylogs plot types."""

    DISTRIBUTION = "plot_distribution"
    MISSING_VALUES = "plot_missing_values"
    UNIQUENESS = "plot_uniqueness"
    DATA_TYPES = "plot_data_types"
    STRING_LENGTH = "plot_string_length"
    TOKEN_LENGTH = "plot_token_length"
    CHAR_POS = "plot_char_pos"
    STRING = "plot_string"
WhylogsVisualizer (BaseStepVisualizer)

The implementation of a Whylogs Visualizer.

Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
class WhylogsVisualizer(BaseStepVisualizer):
    """The implementation of a Whylogs Visualizer."""

    def visualize(
        self,
        object: StepView,
        plots: Optional[List[WhylogsPlots]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Visualize all whylogs dataset profiles present as outputs in the step view.

        Args:
            object: StepView fetched from run.get_step().
            plots: optional list of whylogs plots to visualize. Defaults to
                using all available plot types if not set
            *args: additional positional arguments to pass to the visualize
                method
            **kwargs: additional keyword arguments to pass to the visualize
                method
        """
        whylogs_artifact_datatype = (
            f"{DatasetProfile.__module__}.{DatasetProfile.__name__}"
        )
        for artifact_name, artifact_view in object.outputs.items():
            # filter out anything but whylogs dataset profile artifacts
            if artifact_view.data_type == whylogs_artifact_datatype:
                profile = artifact_view.read()
                # whylogs doesn't currently support visualizing multiple
                # non-related profiles side-by-side, so we open them in
                # separate viewers for now
                self.visualize_profile(artifact_name, profile, plots)

    @staticmethod
    def _get_plot_method(
        visualizer: ProfileVisualizer, plot: WhylogsPlots
    ) -> Any:
        """Get the Whylogs ProfileVisualizer plot method.

        This will be the one corresponding to a WhylogsPlots enum value.

        Args:
            visualizer: a ProfileVisualizer instance
            plot: a WhylogsPlots enum value

        Raises:
            ValueError: if the supplied WhylogsPlots enum value does not
                correspond to a valid ProfileVisualizer plot method

        Returns:
            The ProfileVisualizer plot method corresponding to the input
            WhylogsPlots enum value
        """
        plot_method = getattr(visualizer, plot, None)
        if plot_method is None:
            nl = "\n"
            raise ValueError(
                f"Invalid whylogs plot type: {plot} \n\n"
                f"Valid and supported options are: {nl}- "
                f'{f"{nl}- ".join(WhylogsPlots.names())}'
            )
        return plot_method

    def visualize_profile(
        self,
        name: str,
        profile: DatasetProfile,
        plots: Optional[List[WhylogsPlots]] = None,
    ) -> None:
        """Generate a visualization of a whylogs dataset profile.

        Args:
            name: name identifying the profile if multiple profiles are
                displayed at the same time
            profile: whylogs DatasetProfile to visualize
            plots: optional list of whylogs plots to visualize. Defaults to
                using all available plot types if not set
        """
        if Environment.in_notebook():
            from IPython.core.display import display

            if not plots:
                # default to using all plots if none are supplied
                plots = list(WhylogsPlots)

            for column in sorted(profile.columns):
                for plot in plots:
                    visualizer = ProfileVisualizer()
                    visualizer.set_profiles([profile])
                    plot_method = self._get_plot_method(visualizer, plot)
                    display(plot_method(column))
        else:
            logger.warning(
                "The magic functions are only usable in a Jupyter notebook."
            )
            with tempfile.NamedTemporaryFile(
                delete=False, suffix=f"-{name}.html"
            ) as f:
                logger.info("Opening %s in a new browser.." % f.name)
                profile_viewer([profile], output_path=f.name)
visualize(self, object, plots=None, *args, **kwargs)

Visualize all whylogs dataset profiles present as outputs in the step view.

Parameters:

Name Type Description Default
object StepView

StepView fetched from run.get_step().

required
plots Optional[List[zenml.integrations.whylogs.visualizers.whylogs_visualizer.WhylogsPlots]]

optional list of whylogs plots to visualize. Defaults to using all available plot types if not set

None
*args Any

additional positional arguments to pass to the visualize method

()
**kwargs Any

additional keyword arguments to pass to the visualize method

{}
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
def visualize(
    self,
    object: StepView,
    plots: Optional[List[WhylogsPlots]] = None,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Visualize all whylogs dataset profiles present as outputs in the step view.

    Args:
        object: StepView fetched from run.get_step().
        plots: optional list of whylogs plots to visualize. Defaults to
            using all available plot types if not set
        *args: additional positional arguments to pass to the visualize
            method
        **kwargs: additional keyword arguments to pass to the visualize
            method
    """
    whylogs_artifact_datatype = (
        f"{DatasetProfile.__module__}.{DatasetProfile.__name__}"
    )
    for artifact_name, artifact_view in object.outputs.items():
        # filter out anything but whylogs dataset profile artifacts
        if artifact_view.data_type == whylogs_artifact_datatype:
            profile = artifact_view.read()
            # whylogs doesn't currently support visualizing multiple
            # non-related profiles side-by-side, so we open them in
            # separate viewers for now
            self.visualize_profile(artifact_name, profile, plots)
visualize_profile(self, name, profile, plots=None)

Generate a visualization of a whylogs dataset profile.

Parameters:

Name Type Description Default
name str

name identifying the profile if multiple profiles are displayed at the same time

required
profile DatasetProfile

whylogs DatasetProfile to visualize

required
plots Optional[List[zenml.integrations.whylogs.visualizers.whylogs_visualizer.WhylogsPlots]]

optional list of whylogs plots to visualize. Defaults to using all available plot types if not set

None
Source code in zenml/integrations/whylogs/visualizers/whylogs_visualizer.py
def visualize_profile(
    self,
    name: str,
    profile: DatasetProfile,
    plots: Optional[List[WhylogsPlots]] = None,
) -> None:
    """Generate a visualization of a whylogs dataset profile.

    Args:
        name: name identifying the profile if multiple profiles are
            displayed at the same time
        profile: whylogs DatasetProfile to visualize
        plots: optional list of whylogs plots to visualize. Defaults to
            using all available plot types if not set
    """
    if Environment.in_notebook():
        from IPython.core.display import display

        if not plots:
            # default to using all plots if none are supplied
            plots = list(WhylogsPlots)

        for column in sorted(profile.columns):
            for plot in plots:
                visualizer = ProfileVisualizer()
                visualizer.set_profiles([profile])
                plot_method = self._get_plot_method(visualizer, plot)
                display(plot_method(column))
    else:
        logger.warning(
            "The magic functions are only usable in a Jupyter notebook."
        )
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=f"-{name}.html"
        ) as f:
            logger.info("Opening %s in a new browser.." % f.name)
            profile_viewer([profile], output_path=f.name)
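
For example, a profile produced by a pipeline step could be visualized after the run like so; this is a sketch, and the repository access pattern plus the pipeline/step names are assumptions:

```python
from zenml.integrations.whylogs.visualizers import WhylogsVisualizer
from zenml.repository import Repository

# Fetch the latest run of a pipeline and visualize the whylogs profiles
# produced by one of its steps. Names below are placeholders.
repo = Repository()
pipeline_view = repo.get_pipeline(pipeline_name="profiling_pipeline")
step_view = pipeline_view.runs[-1].get_step(name="profiler")
WhylogsVisualizer().visualize(step_view)
```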

whylogs_context

Implementation of the whylogs step context extension.

WhylogsContext

A step context extension for whylogs.

This can be used to facilitate whylogs data logging and profiling inside a step function.

It acts as a wrapper around the whylogs API that transparently incorporates ZenML-specific information into the generated whylogs dataset profiles, making it possible to associate each profile with the ZenML step run that produced it.

It also simplifies the whylogs profile generation process by abstracting away some of the whylogs specific details, such as whylogs session and logger initialization and management.

Source code in zenml/integrations/whylogs/whylogs_context.py
class WhylogsContext:
    """A step context extension for whylogs.

    This can be used to facilitate whylogs data logging and profiling inside a
    step function.

    It acts as a wrapper built around the whylogs API that transparently
    incorporates ZenML specific information into the generated whylogs dataset
    profiles that can be used to associate whylogs profiles with the
    corresponding ZenML step run that produces them.

    It also simplifies the whylogs profile generation process by abstracting
    away some of the whylogs specific details, such as whylogs session and
    logger initialization and management.
    """

    _session: Optional[Session] = None

    def __init__(
        self,
        step_context: StepContext,
        project: Optional[str] = None,
        pipeline: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> None:
        """Create a ZenML whylogs context based on a generic step context.

        Args:
            step_context: a StepContext instance that provides information
                about the currently running step, such as the step name
            project: optional project name to use for the whylogs session
            pipeline: optional pipeline name to use for the whylogs session
            tags: optional list of tags to apply to all whylogs profiles
                generated through this context
        """
        self._step_context = step_context
        self._project = project
        self._pipeline = pipeline
        self._tags = tags

    def get_whylogs_session(
        self,
    ) -> Session:
        """Get the whylogs session associated with the current step.

        Returns:
            The whylogs Session instance associated with the current step
        """
        if self._session is not None:
            return self._session

        self._session = Session(
            project=self._project or self._step_context.step_name,
            pipeline=self._pipeline or self._step_context.step_name,
            # keeping the writers list empty, serialization is done in the
            # materializer
            writers=[],
        )

        return self._session

    def profile_dataframe(
        self,
        df: pd.DataFrame,
        dataset_name: Optional[str] = None,
        dataset_timestamp: Optional[datetime.datetime] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> DatasetProfile:
        """Generate whylogs statistics for a Pandas dataframe.

        Args:
            df: a Pandas dataframe to profile.
            dataset_name: the name of the dataset (Optional). If not specified,
                the pipeline step name is used
            dataset_timestamp: timestamp to associate with the generated
                dataset profile (Optional). The current time is used if not
                supplied.
            tags: custom metadata tags associated with the whylogs profile

        Returns:
            A whylogs DatasetProfile with the statistics generated from the
            input dataset.
        """
        session = self.get_whylogs_session()
        # TODO [ENG-437]: use a default whylogs dataset_name that is unique across
        #  multiple pipelines
        dataset_name = dataset_name or self._step_context.step_name
        final_tags = self._tags.copy() if self._tags else dict()
        # TODO [ENG-438]: add more zenml specific tags to the whylogs profile, such
        #  as the pipeline name and run ID
        final_tags["zenml.step"] = self._step_context.step_name
        # the datasetId tag is used to identify dataset profiles in whylabs.
        # dataset profiles with the same datasetID are considered to belong
        # to the same dataset/model.
        final_tags.setdefault("datasetId", dataset_name)
        if tags:
            final_tags.update(tags)
        logger = session.logger(
            dataset_name, dataset_timestamp=dataset_timestamp, tags=final_tags
        )
        logger.log_dataframe(df)
        return logger.close()
__init__(self, step_context, project=None, pipeline=None, tags=None) special

Create a ZenML whylogs context based on a generic step context.

Parameters:

Name Type Description Default
step_context StepContext

a StepContext instance that provides information about the currently running step, such as the step name

required
project Optional[str]

optional project name to use for the whylogs session

None
pipeline Optional[str]

optional pipeline name to use for the whylogs session

None
tags Optional[Dict[str, str]]

optional list of tags to apply to all whylogs profiles generated through this context

None
Source code in zenml/integrations/whylogs/whylogs_context.py
def __init__(
    self,
    step_context: StepContext,
    project: Optional[str] = None,
    pipeline: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None,
) -> None:
    """Create a ZenML whylogs context based on a generic step context.

    Args:
        step_context: a StepContext instance that provides information
            about the currently running step, such as the step name
        project: optional project name to use for the whylogs session
        pipeline: optional pipeline name to use for the whylogs session
        tags: optional list of tags to apply to all whylogs profiles
            generated through this context
    """
    self._step_context = step_context
    self._project = project
    self._pipeline = pipeline
    self._tags = tags
get_whylogs_session(self)

Get the whylogs session associated with the current step.

Returns:

Type Description
Session

The whylogs Session instance associated with the current step

Source code in zenml/integrations/whylogs/whylogs_context.py
def get_whylogs_session(
    self,
) -> Session:
    """Get the whylogs session associated with the current step.

    Returns:
        The whylogs Session instance associated with the current step
    """
    if self._session is not None:
        return self._session

    self._session = Session(
        project=self._project or self._step_context.step_name,
        pipeline=self._pipeline or self._step_context.step_name,
        # keeping the writers list empty, serialization is done in the
        # materializer
        writers=[],
    )

    return self._session
profile_dataframe(self, df, dataset_name=None, dataset_timestamp=None, tags=None)

Generate whylogs statistics for a Pandas dataframe.

Parameters:

Name Type Description Default
df DataFrame

a Pandas dataframe to profile.

required
dataset_name Optional[str]

the name of the dataset (Optional). If not specified, the pipeline step name is used

None
dataset_timestamp Optional[datetime.datetime]

timestamp to associate with the generated dataset profile (Optional). The current time is used if not supplied.

None
tags Optional[Dict[str, str]]

custom metadata tags associated with the whylogs profile

None

Returns:

Type Description
DatasetProfile

A whylogs DatasetProfile with the statistics generated from the input dataset.

Source code in zenml/integrations/whylogs/whylogs_context.py
def profile_dataframe(
    self,
    df: pd.DataFrame,
    dataset_name: Optional[str] = None,
    dataset_timestamp: Optional[datetime.datetime] = None,
    tags: Optional[Dict[str, str]] = None,
) -> DatasetProfile:
    """Generate whylogs statistics for a Pandas dataframe.

    Args:
        df: a Pandas dataframe to profile.
        dataset_name: the name of the dataset (Optional). If not specified,
            the pipeline step name is used
        dataset_timestamp: timestamp to associate with the generated
            dataset profile (Optional). The current time is used if not
            supplied.
        tags: custom metadata tags associated with the whylogs profile

    Returns:
        A whylogs DatasetProfile with the statistics generated from the
        input dataset.
    """
    session = self.get_whylogs_session()
    # TODO [ENG-437]: use a default whylogs dataset_name that is unique across
    #  multiple pipelines
    dataset_name = dataset_name or self._step_context.step_name
    final_tags = self._tags.copy() if self._tags else dict()
    # TODO [ENG-438]: add more zenml specific tags to the whylogs profile, such
    #  as the pipeline name and run ID
    final_tags["zenml.step"] = self._step_context.step_name
    # the datasetId tag is used to identify dataset profiles in whylabs.
    # dataset profiles with the same datasetID are considered to belong
    # to the same dataset/model.
    final_tags.setdefault("datasetId", dataset_name)
    if tags:
        final_tags.update(tags)
    logger = session.logger(
        dataset_name, dataset_timestamp=dataset_timestamp, tags=final_tags
    )
    logger.log_dataframe(df)
    return logger.close()

whylogs_step_decorator

Implementation of the whylogs step decorator.

enable_whylogs(_step=None, *, project=None, pipeline=None, tags=None)

Decorator to enable whylogs profiling for a step function.

Apply this decorator to a ZenML pipeline step to enable whylogs profiling. The decorated function will be given access to a StepContext whylogs field that facilitates access to the whylogs dataset profiling API, like so:

@enable_whylogs
@step(enable_cache=True)
def data_loader(
    context: StepContext,
) -> Output(data=pd.DataFrame, profile=DatasetProfile,):
    ...
    data = pd.DataFrame(...)
    profile = context.whylogs.profile_dataframe(data, dataset_name="input_data")
    ...
    return data, profile

You can also use this decorator with our class-based API like so:

@enable_whylogs
class DataLoader(BaseStep):
    def entrypoint(
        self,
        context: StepContext
    ) -> Output(data=pd.DataFrame, profile=DatasetProfile,):
        ...

Parameters:

Name Type Description Default
_step Optional[~S]

The decorated step class.

None
project Optional[str]

optional project name to use for the whylogs session

None
pipeline Optional[str]

optional pipeline name to use for the whylogs session

None
tags Optional[Dict[str, str]]

optional list of tags to apply to all profiles generated by this step

None

Returns:

Type Description
Union[~S, Callable[[~S], ~S]]

the inner decorator which enhances the input step class with whylogs profiling functionality

Source code in zenml/integrations/whylogs/whylogs_step_decorator.py
def enable_whylogs(
    _step: Optional[S] = None,
    *,
    project: Optional[str] = None,
    pipeline: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None,
) -> Union[S, Callable[[S], S]]:
    """Decorator to enable whylogs profiling for a step function.

    Apply this decorator to a ZenML pipeline step to enable whylogs profiling.
    The decorated function will be given access to a StepContext `whylogs`
    field that facilitates access to the whylogs dataset profiling API,
    like so:

    ```python
    @enable_whylogs
    @step(enable_cache=True)
    def data_loader(
        context: StepContext,
    ) -> Output(data=pd.DataFrame, profile=DatasetProfile,):
        ...
        data = pd.DataFrame(...)
        profile = context.whylogs.profile_dataframe(data, dataset_name="input_data")
        ...
        return data, profile
    ```

    You can also use this decorator with our class-based API like so:
    ```
    @enable_whylogs
    class DataLoader(BaseStep):
        def entrypoint(
            self,
            context: StepContext
        ) -> Output(data=pd.DataFrame, profile=DatasetProfile,):
            ...
    ```

    Args:
        _step: The decorated step class.
        project: optional project name to use for the whylogs session
        pipeline: optional pipeline name to use for the whylogs session
        tags: optional list of tags to apply to all profiles generated by this
            step

    Returns:
        the inner decorator which enhances the input step class with whylogs
        profiling functionality
    """

    def inner_decorator(_step: S) -> S:
        source_fn = getattr(_step, STEP_INNER_FUNC_NAME)
        new_entrypoint = whylogs_entrypoint(project, pipeline, tags)(source_fn)
        if _step._created_by_functional_api():
            # If the step was created by the functional API, the old entrypoint
            # was a static method -> make sure the new one is as well
            new_entrypoint = staticmethod(new_entrypoint)

        setattr(_step, STEP_INNER_FUNC_NAME, new_entrypoint)
        return _step

    if _step is None:
        return inner_decorator
    else:
        return inner_decorator(_step)
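
Because `_step` is optional, the decorator also supports a parameterized form; a sketch with illustrative values (import paths are assumptions for this ZenML version):

```python
import pandas as pd
from whylogs import DatasetProfile
from zenml.integrations.whylogs.whylogs_step_decorator import enable_whylogs
from zenml.steps import Output, StepContext, step

# Keyword arguments configure the whylogs session used by the injected
# `context.whylogs` extension. All values here are illustrative.
@enable_whylogs(project="my_project", pipeline="my_pipeline", tags={"env": "dev"})
@step
def data_loader(context: StepContext) -> Output(data=pd.DataFrame, profile=DatasetProfile):
    data = pd.DataFrame({"value": [1, 2, 3]})
    profile = context.whylogs.profile_dataframe(data, dataset_name="input_data")
    return data, profile
```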
whylogs_entrypoint(project=None, pipeline=None, tags=None)

Decorator for a step entrypoint to enable whylogs.

Apply this decorator to a ZenML pipeline step to enable whylogs profiling. The decorated function will be given access to a StepContext whylogs field that facilitates access to the whylogs dataset profiling API, like so:


@step(enable_cache=True)
@whylogs_entrypoint()
def data_loader(
    context: StepContext,
) -> Output(data=pd.DataFrame, profile=DatasetProfile,):
    ...
    data = pd.DataFrame(...)
    profile = context.whylogs.profile_dataframe(data, dataset_name="input_data")
    ...
    return data, profile

Parameters:

Name Type Description Default
project Optional[str]

optional project name to use for the whylogs session

None
pipeline Optional[str]

optional pipeline name to use for the whylogs session

None
tags Optional[Dict[str, str]]

optional list of tags to apply to all profiles generated by this step

None

Returns:

Type Description
Callable[[~F], ~F]

the input function enhanced with whylogs profiling functionality

Source code in zenml/integrations/whylogs/whylogs_step_decorator.py
def whylogs_entrypoint(
    project: Optional[str] = None,
    pipeline: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None,
) -> Callable[[F], F]:
    """Decorator for a step entrypoint to enable whylogs.

    Apply this decorator to a ZenML pipeline step to enable whylogs profiling.
    The decorated function will be given access to a StepContext `whylogs`
    field that facilitates access to the whylogs dataset profiling API,
    like so:

    .. highlight:: python
    .. code-block:: python

        @step(enable_cache=True)
        @whylogs_entrypoint()
        def data_loader(
            context: StepContext,
        ) -> Output(data=pd.DataFrame, profile=DatasetProfile,):
            ...
            data = pd.DataFrame(...)
            profile = context.whylogs.profile_dataframe(data, dataset_name="input_data")
            ...
            return data, profile

    Args:
        project: optional project name to use for the whylogs session
        pipeline: optional pipeline name to use for the whylogs session
        tags: optional list of tags to apply to all profiles generated by this
            step

    Returns:
        the input function enhanced with whylogs profiling functionality
    """

    def inner_decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa

            for arg in args + tuple(kwargs.values()):
                if isinstance(arg, StepContext):
                    arg.__dict__["whylogs"] = WhylogsContext(
                        arg, project, pipeline, tags
                    )
                    break
            return func(*args, **kwargs)

        return cast(F, wrapper)

    return inner_decorator

xgboost special

Initialization of the XGBoost integration.

XgboostIntegration (Integration)

Definition of xgboost integration for ZenML.

Source code in zenml/integrations/xgboost/__init__.py
class XgboostIntegration(Integration):
    """Definition of xgboost integration for ZenML."""

    NAME = XGBOOST
    REQUIREMENTS = ["xgboost>=1.0.0"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration."""
        from zenml.integrations.xgboost import materializers  # noqa
activate() classmethod

Activates the integration.

Source code in zenml/integrations/xgboost/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration."""
    from zenml.integrations.xgboost import materializers  # noqa

materializers special

Initialization of the XGBoost materializers.

xgboost_booster_materializer

Implementation of an XGBoost booster materializer.

XgboostBoosterMaterializer (BaseMaterializer)

Materializer to read data to and from xgboost.Booster.

Source code in zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py
class XgboostBoosterMaterializer(BaseMaterializer):
    """Materializer to read data to and from xgboost.Booster."""

    ASSOCIATED_TYPES = (xgb.Booster,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> xgb.Booster:
        """Reads a xgboost Booster model from a serialized JSON file.

        Args:
            data_type: A xgboost Booster type.

        Returns:
            A xgboost Booster object.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Create a temporary folder
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)

        # Copy from artifact store to temporary file
        fileio.copy(filepath, temp_file)
        booster = xgb.Booster()
        booster.load_model(temp_file)

        # Cleanup and return
        fileio.rmtree(temp_dir)
        return booster

    def handle_return(self, booster: xgb.Booster) -> None:
        """Creates a JSON serialization for a xgboost Booster model.

        Args:
            booster: A xgboost Booster model.
        """
        super().handle_return(booster)

        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Make a temporary phantom artifact
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        ) as f:
            booster.save_model(f.name)
            # Copy it into artifact store
            fileio.copy(f.name, filepath)

        # Close and remove the temporary file
        f.close()
        fileio.remove(f.name)
handle_input(self, data_type)

Reads a xgboost Booster model from a serialized JSON file.

Parameters:

Name Type Description Default
data_type Type[Any]

A xgboost Booster type.

required

Returns:

Type Description
Booster

A xgboost Booster object.

Source code in zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py
def handle_input(self, data_type: Type[Any]) -> xgb.Booster:
    """Reads a xgboost Booster model from a serialized JSON file.

    Args:
        data_type: A xgboost Booster type.

    Returns:
        A xgboost Booster object.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Create a temporary folder
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)

    # Copy from artifact store to temporary file
    fileio.copy(filepath, temp_file)
    booster = xgb.Booster()
    booster.load_model(temp_file)

    # Cleanup and return
    fileio.rmtree(temp_dir)
    return booster
handle_return(self, booster)

Creates a JSON serialization for a xgboost Booster model.

Parameters:

Name Type Description Default
booster Booster

A xgboost Booster model.

required
Source code in zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py
def handle_return(self, booster: xgb.Booster) -> None:
    """Creates a JSON serialization for a xgboost Booster model.

    Args:
        booster: A xgboost Booster model.
    """
    super().handle_return(booster)

    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Make a temporary phantom artifact
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as f:
        booster.save_model(f.name)
        # Copy it into artifact store
        fileio.copy(f.name, filepath)

    # Close and remove the temporary file
    f.close()
    fileio.remove(f.name)
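
A minimal local sketch of the JSON round-trip this materializer performs, without involving an artifact store:

```python
import numpy as np
import xgboost as xgb

# Train a tiny model, save it as JSON, and load it back - the same
# save_model/load_model pair the materializer uses.
dtrain = xgb.DMatrix(np.random.rand(16, 4), label=np.random.randint(0, 2, 16))
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=2)

booster.save_model("model.json")  # serialized as JSON because of the suffix
restored = xgb.Booster()
restored.load_model("model.json")
```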
xgboost_dmatrix_materializer

Implementation of the XGBoost dmatrix materializer.

XgboostDMatrixMaterializer (BaseMaterializer)

Materializer to read data to and from xgboost.DMatrix.

Source code in zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py
class XgboostDMatrixMaterializer(BaseMaterializer):
    """Materializer to read data to and from xgboost.DMatrix."""

    ASSOCIATED_TYPES = (xgb.DMatrix,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> xgb.DMatrix:
        """Reads a xgboost.DMatrix binary file and loads it.

        Args:
            data_type: The datatype which should be read.

        Returns:
            Materialized xgboost matrix.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Create a temporary folder
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)

        # Copy from artifact store to temporary file
        fileio.copy(filepath, temp_file)
        matrix = xgb.DMatrix(temp_file)

        # Cleanup and return
        fileio.rmtree(temp_dir)
        return matrix

    def handle_return(self, matrix: xgb.DMatrix) -> None:
        """Creates a binary serialization for a xgboost.DMatrix object.

        Args:
            matrix: A xgboost.DMatrix object.
        """
        super().handle_return(matrix)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Make a temporary phantom artifact
        with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
            matrix.save_binary(f.name)
            # Copy it into artifact store
            fileio.copy(f.name, filepath)

        # Close and remove the temporary file
        f.close()
        fileio.remove(f.name)
handle_input(self, data_type)

Reads a xgboost.DMatrix binary file and loads it.

Parameters:

Name Type Description Default
data_type Type[Any]

The datatype which should be read.

required

Returns:

Type Description
DMatrix

Materialized xgboost matrix.

Source code in zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py
def handle_input(self, data_type: Type[Any]) -> xgb.DMatrix:
    """Reads a xgboost.DMatrix binary file and loads it.

    Args:
        data_type: The datatype which should be read.

    Returns:
        Materialized xgboost matrix.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Create a temporary folder
    temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
    temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)

    # Copy from artifact store to temporary file
    fileio.copy(filepath, temp_file)
    matrix = xgb.DMatrix(temp_file)

    # Cleanup and return
    fileio.rmtree(temp_dir)
    return matrix
handle_return(self, matrix)

Creates a binary serialization for a xgboost.DMatrix object.

Parameters:

Name Type Description Default
matrix DMatrix

A xgboost.DMatrix object.

required
Source code in zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py
def handle_return(self, matrix: xgb.DMatrix) -> None:
    """Creates a binary serialization for a xgboost.DMatrix object.

    Args:
        matrix: A xgboost.DMatrix object.
    """
    super().handle_return(matrix)
    filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

    # Make a temporary phantom artifact
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
        matrix.save_binary(f.name)
        # Copy it into artifact store
        fileio.copy(f.name, filepath)

    # Close and remove the temporary file
    f.close()
    fileio.remove(f.name)
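
And the analogous binary round-trip for DMatrix objects, again as a local sketch:

```python
import numpy as np
import xgboost as xgb

# save_binary writes the DMatrix to XGBoost's binary buffer format;
# passing the file path to the DMatrix constructor loads it back.
matrix = xgb.DMatrix(np.random.rand(16, 4))
matrix.save_binary("data.buffer")
restored = xgb.DMatrix("data.buffer")
```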