Integrations

zenml.integrations special

ZenML integrations module.

The ZenML integrations module contains sub-modules for each integration that we support. These include orchestrators like Apache Airflow, visualization tools like the Facets library, and deep learning libraries like PyTorch.

airflow special

Airflow integration for ZenML.

The Airflow integration sub-module powers an alternative to the local orchestrator. You can enable it by registering the Airflow orchestrator with the CLI tool and then bootstrapping it with the `zenml orchestrator up` command.

AirflowIntegration (Integration)

Definition of Airflow Integration for ZenML.

Source code in zenml/integrations/airflow/__init__.py
class AirflowIntegration(Integration):
    """Definition of Airflow Integration for ZenML."""

    NAME = AIRFLOW
    REQUIREMENTS = ["apache-airflow==2.2.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Airflow integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=AIRFLOW_ORCHESTRATOR_FLAVOR,
                source="zenml.integrations.airflow.orchestrators.AirflowOrchestrator",
                type=StackComponentType.ORCHESTRATOR,
                integration=cls.NAME,
            )
        ]
flavors() classmethod

Declare the stack component flavors for the Airflow integration.

Returns:

`List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]`: List of stack component flavors for this integration.

Source code in zenml/integrations/airflow/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Airflow integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=AIRFLOW_ORCHESTRATOR_FLAVOR,
            source="zenml.integrations.airflow.orchestrators.AirflowOrchestrator",
            type=StackComponentType.ORCHESTRATOR,
            integration=cls.NAME,
        )
    ]
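
Integrations follow the same pattern throughout ZenML: a subclass of `Integration` declares a `NAME`, pip `REQUIREMENTS`, and the stack component flavors it contributes. Below is a minimal sketch of a hypothetical integration modeled on the class above; the import paths are assumptions based on the names used in this documentation and may differ between ZenML versions.

```python
from typing import List

from zenml.enums import StackComponentType
from zenml.integrations.integration import Integration  # assumed import path
from zenml.zen_stores.models.flavor_wrapper import FlavorWrapper


class MyToolIntegration(Integration):
    """Hypothetical integration for a tool called mytool."""

    # Name and requirement are illustrative, not a real integration.
    NAME = "mytool"
    REQUIREMENTS = ["mytool==1.0.0"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        # Each FlavorWrapper points at the class implementing the component.
        return [
            FlavorWrapper(
                name="mytool",
                source="my_package.MyToolOrchestrator",
                type=StackComponentType.ORCHESTRATOR,
                integration=cls.NAME,
            )
        ]
```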

orchestrators special

The Airflow integration enables the use of Airflow as a pipeline orchestrator.

airflow_orchestrator

Implementation of Airflow orchestrator integration.

AirflowOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator responsible for running pipelines using Airflow.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
class AirflowOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines using Airflow."""

    airflow_home: str = ""

    # Class Configuration
    FLAVOR: ClassVar[str] = AIRFLOW_ORCHESTRATOR_FLAVOR

    def __init__(self, **values: Any):
        """Sets environment variables to configure airflow.

        Args:
            **values: Values to set in the orchestrator.
        """
        super().__init__(**values)
        self._set_env()

    @staticmethod
    def _translate_schedule(
        schedule: Optional[Schedule] = None,
    ) -> Dict[str, Any]:
        """Convert ZenML schedule into Airflow schedule.

        The Airflow schedule uses slightly different naming and needs some
        default entries for execution without a schedule.

        Args:
            schedule: Containing the interval, start and end date and
                a boolean flag that defines if past runs should be caught up
                on

        Returns:
            Airflow configuration dict.
        """
        if schedule:
            if schedule.cron_expression:
                return {
                    "schedule_interval": schedule.cron_expression,
                }
            else:
                return {
                    "schedule_interval": schedule.interval_second,
                    "start_date": schedule.start_time,
                    "end_date": schedule.end_time,
                    "catchup": schedule.catchup,
                }

        return {
            "schedule_interval": "@once",
            # set a start time in the past and disable catchup so airflow runs the dag immediately
            "start_date": datetime.datetime.now() - datetime.timedelta(7),
            "catchup": False,
        }

    def prepare_or_run_pipeline(
        self,
        sorted_steps: List[BaseStep],
        pipeline: "BasePipeline",
        pb2_pipeline: Pb2Pipeline,
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Creates an Airflow DAG as the intermediate representation for the pipeline.

        This DAG will be loaded by airflow in the target environment
        and used for orchestration of the pipeline.

        How it works:
        -------------
        A new airflow_dag is instantiated with the pipeline name and, among
        other things, the run schedule.

        For each step of the pipeline a callable is created. This callable
        uses the run_step() method to execute the step. The parameters of
        this callable are pre-filled and an airflow step_operator is created
        within the dag. The dependencies to upstream steps are then
        configured.

        Finally, the dag is fully complete and can be returned.

        Args:
            sorted_steps: List of steps in the pipeline.
            pipeline: The pipeline to be executed.
            pb2_pipeline: The pipeline as a protobuf message.
            stack: The stack on which the pipeline will be deployed.
            runtime_configuration: The runtime configuration.

        Returns:
            The Airflow DAG.
        """
        import airflow
        from airflow.operators import python as airflow_python

        # Instantiate and configure airflow Dag with name and schedule
        airflow_dag = airflow.DAG(
            dag_id=pipeline.name,
            is_paused_upon_creation=False,
            **self._translate_schedule(runtime_configuration.schedule),
        )

        # Dictionary mapping step names to airflow_operators. This will be needed
        # to configure airflow operator dependencies
        step_name_to_airflow_operator = {}

        for step in sorted_steps:
            # Create callable that will be used by airflow to execute the step
            # within the orchestrated environment
            def _step_callable(step_instance: "BaseStep", **kwargs):
                # Use step_instance (bound per step via functools.partial
                # below) rather than the loop variable `step`, which is
                # late-bound and would point at the last step by the time
                # this callable actually runs.
                if self.requires_resources_in_orchestration_environment(
                    step_instance
                ):
                    logger.warning(
                        "Specifying step resources is not yet supported for "
                        "the Airflow orchestrator, ignoring resource "
                        "configuration for step %s.",
                        step_instance.name,
                    )
                # Extract run name for the kwargs that will be passed to the
                # callable
                run_name = kwargs["ti"].get_dagrun().run_id
                self.run_step(
                    step=step_instance,
                    run_name=run_name,
                    pb2_pipeline=pb2_pipeline,
                )

            # Create airflow python operator that contains the step callable
            airflow_operator = airflow_python.PythonOperator(
                dag=airflow_dag,
                task_id=step.name,
                provide_context=True,
                python_callable=functools.partial(
                    _step_callable, step_instance=step
                ),
            )

            # Configure the current airflow operator to run after all upstream
            # operators finished executing
            step_name_to_airflow_operator[step.name] = airflow_operator
            upstream_step_names = self.get_upstream_step_names(
                step=step, pb2_pipeline=pb2_pipeline
            )
            for upstream_step_name in upstream_step_names:
                airflow_operator.set_upstream(
                    step_name_to_airflow_operator[upstream_step_name]
                )

        # Return the finished airflow dag
        return airflow_dag

    @root_validator(skip_on_failure=True)
    def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Sets Airflow home according to orchestrator UUID.

        Args:
            values: Dictionary containing all orchestrator attribute values.

        Returns:
            Dictionary containing all orchestrator attribute values and the airflow home.

        Raises:
            ValueError: If the orchestrator UUID is not set.
        """
        if "uuid" not in values:
            raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")
        values["airflow_home"] = os.path.join(
            io_utils.get_global_config_directory(),
            AIRFLOW_ROOT_DIR,
            str(values["uuid"]),
        )
        return values

    @property
    def dags_directory(self) -> str:
        """Returns path to the airflow dags directory.

        Returns:
            Path to the airflow dags directory.
        """
        return os.path.join(self.airflow_home, "dags")

    @property
    def pid_file(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(self.airflow_home, "airflow_daemon.pid")

    @property
    def log_file(self) -> str:
        """Returns path to the airflow log file.

        Returns:
            str: Path to the airflow log file.
        """
        return os.path.join(self.airflow_home, "airflow_orchestrator.log")

    @property
    def password_file(self) -> str:
        """Returns path to the webserver password file.

        Returns:
            Path to the webserver password file.
        """
        return os.path.join(self.airflow_home, "standalone_admin_password.txt")

    def _set_env(self) -> None:
        """Sets environment variables to configure airflow."""
        os.environ["AIRFLOW_HOME"] = self.airflow_home
        os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = self.dags_directory
        os.environ["AIRFLOW__CORE__DAG_DISCOVERY_SAFE_MODE"] = "false"
        os.environ["AIRFLOW__CORE__LOAD_EXAMPLES"] = "false"
        # check the DAG folder every 10 seconds for new files
        os.environ["AIRFLOW__SCHEDULER__DAG_DIR_LIST_INTERVAL"] = "10"

    def _copy_to_dag_directory_if_necessary(self, dag_filepath: str) -> None:
        """Copies DAG module to the Airflow DAGs directory if not already present.

        Args:
            dag_filepath: Path to the file in which the DAG is defined.
        """
        dags_directory = io_utils.resolve_relative_path(self.dags_directory)

        if dags_directory == os.path.dirname(dag_filepath):
            logger.debug("File is already in airflow DAGs directory.")
        else:
            logger.debug(
                "Copying dag file '%s' to DAGs directory.", dag_filepath
            )
            destination_path = os.path.join(
                dags_directory, os.path.basename(dag_filepath)
            )
            if fileio.exists(destination_path):
                logger.info(
                    "File '%s' already exists, overwriting with new DAG file",
                    destination_path,
                )
            fileio.copy(dag_filepath, destination_path, overwrite=True)

    def _log_webserver_credentials(self) -> None:
        """Logs URL and credentials to log in to the airflow webserver.

        Raises:
            FileNotFoundError: If the password file does not exist.
        """
        if fileio.exists(self.password_file):
            with open(self.password_file) as file:
                password = file.read().strip()
        else:
            raise FileNotFoundError(
                f"Can't find password file '{self.password_file}'"
            )
        logger.info(
            "To inspect your DAGs, login to http://0.0.0.0:8080 "
            "with username: admin password: %s",
            password,
        )

    def runtime_options(self) -> Dict[str, Any]:
        """Runtime options for the airflow orchestrator.

        Returns:
            Runtime options dictionary.
        """
        return {DAG_FILEPATH_OPTION_KEY: None}

    def prepare_pipeline_deployment(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> None:
        """Checks Airflow is running and copies DAG file to the Airflow DAGs directory.

        Args:
            pipeline: Pipeline to be deployed.
            stack: Stack to be deployed.
            runtime_configuration: Runtime configuration for the pipeline.

        Raises:
            RuntimeError: If Airflow is not running or no DAG filepath runtime
                          option is provided.
        """
        if not self.is_running:
            raise RuntimeError(
                "Airflow orchestrator is currently not running. Run `zenml "
                "stack up` to provision resources for the active stack."
            )

        if Environment.in_notebook():
            raise RuntimeError(
                "Unable to run the Airflow orchestrator from within a "
                "notebook. Airflow requires a python file which contains a "
                "global Airflow DAG object and therefore does not work with "
                "notebooks. Please copy your ZenML pipeline code in a python "
                "file and try again."
            )

        try:
            dag_filepath = runtime_configuration[DAG_FILEPATH_OPTION_KEY]
        except KeyError:
            raise RuntimeError(
                f"No DAG filepath found in runtime configuration. Make sure "
                f"to add the filepath to your airflow DAG file as a runtime "
                f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
            )

        self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)

    @property
    def is_running(self) -> bool:
        """Returns whether the airflow daemon is currently running.

        Returns:
            True if the daemon is running, False otherwise.

        Raises:
            RuntimeError: If port 8080 is occupied.
        """
        from airflow.cli.commands.standalone_command import StandaloneCommand
        from airflow.jobs.triggerer_job import TriggererJob

        daemon_running = daemon.check_if_daemon_is_running(self.pid_file)

        command = StandaloneCommand()
        webserver_port_open = command.port_open(8080)

        if not daemon_running:
            if webserver_port_open:
                raise RuntimeError(
                    "The airflow daemon does not seem to be running but "
                    "local port 8080 is occupied. Make sure the port is "
                    "available and try again."
                )

            # exit early so we don't check non-existing airflow databases
            return False

        # we can't use StandaloneCommand().is_ready() here as the
        # Airflow SequentialExecutor apparently does not send a heartbeat
        # while running a task which would result in this returning `False`
        # even if Airflow is running.
        airflow_running = webserver_port_open and command.job_running(
            TriggererJob
        )
        return airflow_running

    @property
    def is_provisioned(self) -> bool:
        """Returns whether the airflow daemon is currently running.

        Returns:
            True if the airflow daemon is running, False otherwise.
        """
        return self.is_running

    def provision(self) -> None:
        """Ensures that Airflow is running."""
        if self.is_running:
            logger.info("Airflow is already running.")
            self._log_webserver_credentials()
            return

        if not fileio.exists(self.dags_directory):
            io_utils.create_dir_recursive_if_not_exists(self.dags_directory)

        from airflow.cli.commands.standalone_command import StandaloneCommand

        try:
            command = StandaloneCommand()
            # Run the daemon with a working directory inside the current
            # zenml repo so the same repo will be used to run the DAGs
            daemon.run_as_daemon(
                command.run,
                pid_file=self.pid_file,
                log_file=self.log_file,
                working_directory=get_source_root_path(),
            )
            while not self.is_running:
                # Wait until the daemon started all the relevant airflow
                # processes
                time.sleep(0.1)
            self._log_webserver_credentials()
        except Exception as e:
            logger.error(e)
            logger.error(
                "An error occurred while starting the Airflow daemon. If you "
                "want to start it manually, use the commands described in the "
                "official Airflow quickstart guide for running Airflow locally."
            )
            self.deprovision()

    def deprovision(self) -> None:
        """Stops the airflow daemon if necessary and tears down resources."""
        if self.is_running:
            daemon.stop_daemon(self.pid_file)

        fileio.rmtree(self.airflow_home)
        logger.info("Airflow spun down.")
dags_directory: str property readonly

Returns path to the airflow dags directory.

Returns:

`str`: Path to the airflow dags directory.

is_provisioned: bool property readonly

Returns whether the airflow daemon is currently running.

Returns:

`bool`: True if the airflow daemon is running, False otherwise.

is_running: bool property readonly

Returns whether the airflow daemon is currently running.

Returns:

`bool`: True if the daemon is running, False otherwise.

Exceptions:

`RuntimeError`: If port 8080 is occupied.

log_file: str property readonly

Returns path to the airflow log file.

Returns:

`str`: Path to the airflow log file.

password_file: str property readonly

Returns path to the webserver password file.

Returns:

`str`: Path to the webserver password file.

pid_file: str property readonly

Returns path to the daemon PID file.

Returns:

`str`: Path to the daemon PID file.

__init__(self, **values) special

Sets environment variables to configure airflow.

Parameters:

`**values` (`Any`): Values to set in the orchestrator. Default: `{}`.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def __init__(self, **values: Any):
    """Sets environment variables to configure airflow.

    Args:
        **values: Values to set in the orchestrator.
    """
    super().__init__(**values)
    self._set_env()
deprovision(self)

Stops the airflow daemon if necessary and tears down resources.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def deprovision(self) -> None:
    """Stops the airflow daemon if necessary and tears down resources."""
    if self.is_running:
        daemon.stop_daemon(self.pid_file)

    fileio.rmtree(self.airflow_home)
    logger.info("Airflow spun down.")
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)

Creates an Airflow DAG as the intermediate representation for the pipeline.

This DAG will be loaded by airflow in the target environment and used for orchestration of the pipeline.

How it works:

A new airflow_dag is instantiated with the pipeline name and, among other things, the run schedule.

For each step of the pipeline a callable is created. This callable uses the run_step() method to execute the step. The parameters of this callable are pre-filled and an airflow step_operator is created within the dag. The dependencies to upstream steps are then configured.

Finally, the dag is fully complete and can be returned.

Parameters:

`sorted_steps` (`List[zenml.steps.base_step.BaseStep]`, required): List of steps in the pipeline.

`pipeline` (`BasePipeline`, required): The pipeline to be executed.

`pb2_pipeline` (`Pipeline`, required): The pipeline as a protobuf message.

`stack` (`Stack`, required): The stack on which the pipeline will be deployed.

`runtime_configuration` (`RuntimeConfiguration`, required): The runtime configuration.

Returns:

`Any`: The Airflow DAG.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_or_run_pipeline(
    self,
    sorted_steps: List[BaseStep],
    pipeline: "BasePipeline",
    pb2_pipeline: Pb2Pipeline,
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Creates an Airflow DAG as the intermediate representation for the pipeline.

    This DAG will be loaded by airflow in the target environment
    and used for orchestration of the pipeline.

    How it works:
    -------------
    A new airflow_dag is instantiated with the pipeline name and, among
    other things, the run schedule.

    For each step of the pipeline a callable is created. This callable
    uses the run_step() method to execute the step. The parameters of
    this callable are pre-filled and an airflow step_operator is created
    within the dag. The dependencies to upstream steps are then
    configured.

    Finally, the dag is fully complete and can be returned.

    Args:
        sorted_steps: List of steps in the pipeline.
        pipeline: The pipeline to be executed.
        pb2_pipeline: The pipeline as a protobuf message.
        stack: The stack on which the pipeline will be deployed.
        runtime_configuration: The runtime configuration.

    Returns:
        The Airflow DAG.
    """
    import airflow
    from airflow.operators import python as airflow_python

    # Instantiate and configure airflow Dag with name and schedule
    airflow_dag = airflow.DAG(
        dag_id=pipeline.name,
        is_paused_upon_creation=False,
        **self._translate_schedule(runtime_configuration.schedule),
    )

    # Dictionary mapping step names to airflow_operators. This will be needed
    # to configure airflow operator dependencies
    step_name_to_airflow_operator = {}

    for step in sorted_steps:
        # Create callable that will be used by airflow to execute the step
        # within the orchestrated environment
        def _step_callable(step_instance: "BaseStep", **kwargs):
            # Use step_instance (bound per step via functools.partial
            # below) rather than the loop variable `step`, which is
            # late-bound and would point at the last step by the time
            # this callable actually runs.
            if self.requires_resources_in_orchestration_environment(
                step_instance
            ):
                logger.warning(
                    "Specifying step resources is not yet supported for "
                    "the Airflow orchestrator, ignoring resource "
                    "configuration for step %s.",
                    step_instance.name,
                )
            # Extract run name for the kwargs that will be passed to the
            # callable
            run_name = kwargs["ti"].get_dagrun().run_id
            self.run_step(
                step=step_instance,
                run_name=run_name,
                pb2_pipeline=pb2_pipeline,
            )

        # Create airflow python operator that contains the step callable
        airflow_operator = airflow_python.PythonOperator(
            dag=airflow_dag,
            task_id=step.name,
            provide_context=True,
            python_callable=functools.partial(
                _step_callable, step_instance=step
            ),
        )

        # Configure the current airflow operator to run after all upstream
        # operators finished executing
        step_name_to_airflow_operator[step.name] = airflow_operator
        upstream_step_names = self.get_upstream_step_names(
            step=step, pb2_pipeline=pb2_pipeline
        )
        for upstream_step_name in upstream_step_names:
            airflow_operator.set_upstream(
                step_name_to_airflow_operator[upstream_step_name]
            )

    # Return the finished airflow dag
    return airflow_dag
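
Stripped of the ZenML plumbing, the construction above is the standard Airflow pattern: one `PythonOperator` per step, with `functools.partial` pre-binding the step and `set_upstream()` expressing the dependencies. A minimal sketch with two hypothetical steps (requires an Airflow 2.x installation):

```python
import datetime
import functools

import airflow
from airflow.operators import python as airflow_python


def run_step(step_name: str, **kwargs) -> None:
    # Stand-in for BaseOrchestrator.run_step(): the task instance ("ti")
    # passed by Airflow provides the run id, just as in the code above.
    run_name = kwargs["ti"].get_dagrun().run_id
    print(f"Running step '{step_name}' in run '{run_name}'")


dag = airflow.DAG(
    dag_id="example_pipeline",
    schedule_interval="@once",
    start_date=datetime.datetime.now() - datetime.timedelta(days=7),
    catchup=False,
)

load_data = airflow_python.PythonOperator(
    dag=dag,
    task_id="load_data",
    python_callable=functools.partial(run_step, step_name="load_data"),
)
train_model = airflow_python.PythonOperator(
    dag=dag,
    task_id="train_model",
    python_callable=functools.partial(run_step, step_name="train_model"),
)
# train_model only starts after load_data has finished.
train_model.set_upstream(load_data)
```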
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)

Checks that Airflow is running and copies the DAG file to the Airflow DAGs directory.

Parameters:

`pipeline` (`BasePipeline`, required): Pipeline to be deployed.

`stack` (`Stack`, required): Stack to be deployed.

`runtime_configuration` (`RuntimeConfiguration`, required): Runtime configuration for the pipeline.

Exceptions:

`RuntimeError`: If Airflow is not running or no DAG filepath runtime option is provided.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_pipeline_deployment(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> None:
    """Checks Airflow is running and copies DAG file to the Airflow DAGs directory.

    Args:
        pipeline: Pipeline to be deployed.
        stack: Stack to be deployed.
        runtime_configuration: Runtime configuration for the pipeline.

    Raises:
        RuntimeError: If Airflow is not running or no DAG filepath runtime
                      option is provided.
    """
    if not self.is_running:
        raise RuntimeError(
            "Airflow orchestrator is currently not running. Run `zenml "
            "stack up` to provision resources for the active stack."
        )

    if Environment.in_notebook():
        raise RuntimeError(
            "Unable to run the Airflow orchestrator from within a "
            "notebook. Airflow requires a python file which contains a "
            "global Airflow DAG object and therefore does not work with "
            "notebooks. Please copy your ZenML pipeline code in a python "
            "file and try again."
        )

    try:
        dag_filepath = runtime_configuration[DAG_FILEPATH_OPTION_KEY]
    except KeyError:
        raise RuntimeError(
            f"No DAG filepath found in runtime configuration. Make sure "
            f"to add the filepath to your airflow DAG file as a runtime "
            f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
        )

    self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)
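
The runtime option lookup above can be illustrated in isolation. A minimal sketch, assuming `DAG_FILEPATH_OPTION_KEY` resolves to the string `"dag_filepath"` (the actual value lives in the orchestrator's module):

```python
from typing import Any, Dict

DAG_FILEPATH_OPTION_KEY = "dag_filepath"  # assumed value of the constant


def get_dag_filepath(runtime_configuration: Dict[str, Any]) -> str:
    # Mirrors the check in prepare_pipeline_deployment(): fail early when
    # the caller did not pass the path of the file defining the DAG.
    try:
        return runtime_configuration[DAG_FILEPATH_OPTION_KEY]
    except KeyError:
        raise RuntimeError(
            f"No DAG filepath found in runtime configuration "
            f"(key: '{DAG_FILEPATH_OPTION_KEY}')."
        )


print(get_dag_filepath({DAG_FILEPATH_OPTION_KEY: "/path/to/my_dag.py"}))
```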
provision(self)

Ensures that Airflow is running.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def provision(self) -> None:
    """Ensures that Airflow is running."""
    if self.is_running:
        logger.info("Airflow is already running.")
        self._log_webserver_credentials()
        return

    if not fileio.exists(self.dags_directory):
        io_utils.create_dir_recursive_if_not_exists(self.dags_directory)

    from airflow.cli.commands.standalone_command import StandaloneCommand

    try:
        command = StandaloneCommand()
        # Run the daemon with a working directory inside the current
        # zenml repo so the same repo will be used to run the DAGs
        daemon.run_as_daemon(
            command.run,
            pid_file=self.pid_file,
            log_file=self.log_file,
            working_directory=get_source_root_path(),
        )
        while not self.is_running:
            # Wait until the daemon started all the relevant airflow
            # processes
            time.sleep(0.1)
        self._log_webserver_credentials()
    except Exception as e:
        logger.error(e)
        logger.error(
            "An error occurred while starting the Airflow daemon. If you "
            "want to start it manually, use the commands described in the "
            "official Airflow quickstart guide for running Airflow locally."
        )
        self.deprovision()
runtime_options(self)

Runtime options for the airflow orchestrator.

Returns:

`Dict[str, Any]`: Runtime options dictionary.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def runtime_options(self) -> Dict[str, Any]:
    """Runtime options for the airflow orchestrator.

    Returns:
        Runtime options dictionary.
    """
    return {DAG_FILEPATH_OPTION_KEY: None}
set_airflow_home(values) classmethod

Sets Airflow home according to orchestrator UUID.

Parameters:

`values` (`Dict[str, Any]`, required): Dictionary containing all orchestrator attribute values.

Returns:

`Dict[str, Any]`: Dictionary containing all orchestrator attribute values and the airflow home.

Exceptions:

`ValueError`: If the orchestrator UUID is not set.

Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
@root_validator(skip_on_failure=True)
def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Sets Airflow home according to orchestrator UUID.

    Args:
        values: Dictionary containing all orchestrator attribute values.

    Returns:
        Dictionary containing all orchestrator attribute values and the airflow home.

    Raises:
        ValueError: If the orchestrator UUID is not set.
    """
    if "uuid" not in values:
        raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")
    values["airflow_home"] = os.path.join(
        io_utils.get_global_config_directory(),
        AIRFLOW_ROOT_DIR,
        str(values["uuid"]),
    )
    return values

aws special

Integrates multiple AWS Tools as Stack Components.

The AWS integration provides a way for our users to manage their secrets through AWS and to use the AWS container registry. Additionally, the Sagemaker integration submodule provides a way to run ZenML steps in Sagemaker.

AWSIntegration (Integration)

Definition of AWS integration for ZenML.

Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
    """Definition of AWS integration for ZenML."""

    NAME = AWS
    REQUIREMENTS = ["boto3==1.21.0", "sagemaker==2.82.2"]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the AWS integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=AWS_SECRET_MANAGER_FLAVOR,
                source="zenml.integrations.aws.secrets_managers"
                ".AWSSecretsManager",
                type=StackComponentType.SECRETS_MANAGER,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=AWS_CONTAINER_REGISTRY_FLAVOR,
                source="zenml.integrations.aws.container_registries"
                ".AWSContainerRegistry",
                type=StackComponentType.CONTAINER_REGISTRY,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR,
                source="zenml.integrations.aws.step_operators"
                ".SagemakerStepOperator",
                type=StackComponentType.STEP_OPERATOR,
                integration=cls.NAME,
            ),
        ]
flavors() classmethod

Declare the stack component flavors for the AWS integration.

Returns:

`List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]`: List of stack component flavors for this integration.

Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the AWS integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=AWS_SECRET_MANAGER_FLAVOR,
            source="zenml.integrations.aws.secrets_managers"
            ".AWSSecretsManager",
            type=StackComponentType.SECRETS_MANAGER,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=AWS_CONTAINER_REGISTRY_FLAVOR,
            source="zenml.integrations.aws.container_registries"
            ".AWSContainerRegistry",
            type=StackComponentType.CONTAINER_REGISTRY,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR,
            source="zenml.integrations.aws.step_operators"
            ".SagemakerStepOperator",
            type=StackComponentType.STEP_OPERATOR,
            integration=cls.NAME,
        ),
    ]

container_registries special

Initialization of AWS Container Registry integration.

aws_container_registry

Implementation of the AWS container registry integration.

AWSContainerRegistry (BaseContainerRegistry) pydantic-model

Class for AWS Container Registry.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
    """Class for AWS Container Registry."""

    # Class Configuration
    FLAVOR: ClassVar[str] = AWS_CONTAINER_REGISTRY_FLAVOR

    @validator("uri")
    def validate_aws_uri(cls, uri: str) -> str:
        """Validates that the URI is in the correct format.

        Args:
            uri: URI to validate.

        Returns:
            URI in the correct format.

        Raises:
            ValueError: If the URI contains a slash character.
        """
        if "/" in uri:
            raise ValueError(
                "Property `uri` can not contain a `/`. An example of a valid "
                "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
            )

        return uri

    def _get_region(self) -> str:
        """Parses the AWS region from the registry URI.

        Raises:
            RuntimeError: If the region parsing fails due to an invalid URI.

        Returns:
            The region string.
        """
        match = re.fullmatch(r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", self.uri)
        if not match:
            raise RuntimeError(
                f"Unable to parse region from ECR URI {self.uri}."
            )

        return match.group(1)

    def prepare_image_push(self, image_name: str) -> None:
        """Logs warning message if trying to push an image for which no repository exists.

        Args:
            image_name: Name of the docker image that will be pushed.

        Raises:
            ValueError: If the docker image name is invalid.
        """
        response = boto3.client(
            "ecr", region_name=self._get_region()
        ).describe_repositories()
        try:
            repo_uris: List[str] = [
                repository["repositoryUri"]
                for repository in response["repositories"]
            ]
        except (KeyError, ClientError) as e:
            # invalid boto response, let's hope for the best and just push
            logger.debug("Error while trying to fetch ECR repositories: %s", e)
            return

        repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
        if not repo_exists:
            match = re.search(f"{self.uri}/(.*):.*", image_name)
            if not match:
                raise ValueError(f"Invalid docker image name '{image_name}'.")

            repo_name = match.group(1)
            logger.warning(
                "Amazon ECR requires you to create a repository before you can "
                f"push an image to it. ZenML is trying to push the image "
                f"{image_name} but could only detect the following "
                f"repositories: {repo_uris}. We will try to push anyway, but "
                f"in case it fails you need to create a repository named "
                f"`{repo_name}`."
            )

    @property
    def post_registration_message(self) -> Optional[str]:
        """Optional message printed after the stack component is registered.

        Returns:
            Info message regarding docker repositories in AWS.
        """
        return (
            "Amazon ECR requires you to create a repository before you can "
            "push an image to it. If you want to for example run a pipeline "
            "using our Kubeflow orchestrator, ZenML will automatically build a "
            f"docker image called `{self.uri}/zenml-kubeflow:<PIPELINE_NAME>` "
            f"and try to push it. This will fail unless you create the "
            f"repository `zenml-kubeflow` inside your amazon registry."
        )
post_registration_message: Optional[str] property readonly

Optional message printed after the stack component is registered.

Returns:

`Optional[str]`: Info message regarding docker repositories in AWS.

prepare_image_push(self, image_name)

Logs a warning message if trying to push an image for which no repository exists.

Parameters:

`image_name` (`str`, required): Name of the docker image that will be pushed.

Exceptions:

`ValueError`: If the docker image name is invalid.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
    """Logs warning message if trying to push an image for which no repository exists.

    Args:
        image_name: Name of the docker image that will be pushed.

    Raises:
        ValueError: If the docker image name is invalid.
    """
    response = boto3.client(
        "ecr", region_name=self._get_region()
    ).describe_repositories()
    try:
        repo_uris: List[str] = [
            repository["repositoryUri"]
            for repository in response["repositories"]
        ]
    except (KeyError, ClientError) as e:
        # invalid boto response, let's hope for the best and just push
        logger.debug("Error while trying to fetch ECR repositories: %s", e)
        return

    repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
    if not repo_exists:
        match = re.search(f"{self.uri}/(.*):.*", image_name)
        if not match:
            raise ValueError(f"Invalid docker image name '{image_name}'.")

        repo_name = match.group(1)
        logger.warning(
            "Amazon ECR requires you to create a repository before you can "
            f"push an image to it. ZenML is trying to push the image "
            f"{image_name} but could only detect the following "
            f"repositories: {repo_uris}. We will try to push anyway, but "
            f"in case it fails you need to create a repository named "
            f"`{repo_name}`."
        )
validate_aws_uri(uri) classmethod

Validates that the URI is in the correct format.

Parameters:

`uri` (`str`, required): URI to validate.

Returns:

`str`: URI in the correct format.

Exceptions:

`ValueError`: If the URI contains a slash character.

Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
@validator("uri")
def validate_aws_uri(cls, uri: str) -> str:
    """Validates that the URI is in the correct format.

    Args:
        uri: URI to validate.

    Returns:
        URI in the correct format.

    Raises:
        ValueError: If the URI contains a slash character.
    """
    if "/" in uri:
        raise ValueError(
            "Property `uri` can not contain a `/`. An example of a valid "
            "URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
        )

    return uri
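
Both regular expressions used by this class can be exercised standalone: `_get_region()` expects a URI of the form `<account-id>.dkr.ecr.<region>.amazonaws.com`, and `prepare_image_push()` extracts the repository name from a fully qualified image name. A quick sketch:

```python
import re

uri = "715803424592.dkr.ecr.us-east-1.amazonaws.com"

# Region parsing, as in AWSContainerRegistry._get_region().
match = re.fullmatch(r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", uri)
assert match is not None and match.group(1) == "us-east-1"

# Repository name extraction, as in prepare_image_push().
image_name = f"{uri}/zenml-kubeflow:my_pipeline"
repo_match = re.search(f"{uri}/(.*):.*", image_name)
assert repo_match is not None and repo_match.group(1) == "zenml-kubeflow"
```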

secrets_managers special

AWS Secrets Manager.

aws_secrets_manager

Implementation of the AWS Secrets Manager integration.

AWSSecretsManager (BaseSecretsManager) pydantic-model

Class to interact with the AWS secrets manager.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
class AWSSecretsManager(BaseSecretsManager):
    """Class to interact with the AWS secrets manager."""

    region_name: str

    # Class configuration
    FLAVOR: ClassVar[str] = AWS_SECRET_MANAGER_FLAVOR
    SUPPORTS_SCOPING: ClassVar[bool] = True
    CLIENT: ClassVar[Any] = None

    @classmethod
    def _validate_scope(
        cls,
        scope: SecretsManagerScope,
        namespace: Optional[str],
    ) -> None:
        """Validate the scope and namespace value.

        Args:
            scope: Scope value.
            namespace: Optional namespace value.
        """
        if namespace:
            cls.validate_secret_name_or_namespace(namespace)

    @classmethod
    def _ensure_client_connected(cls, region_name: str) -> None:
        """Ensure that the client is connected to the AWS secrets manager.

        Args:
            region_name: the AWS region name
        """
        if cls.CLIENT is None:
            # Create a Secrets Manager client
            session = boto3.session.Session()
            cls.CLIENT = session.client(
                service_name="secretsmanager", region_name=region_name
            )

    @classmethod
    def validate_secret_name_or_namespace(cls, name: str) -> None:
        """Validate a secret name or namespace.

        AWS secret names must contain only alphanumeric characters and the
        characters /_+=.@-. The `/` character is only used internally to delimit
        scopes.

        Args:
            name: the secret name or namespace

        Raises:
            ValueError: if the secret name or namespace is invalid
        """
        if not re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name):
            raise ValueError(
                f"Invalid secret name or namespace '{name}'. Must contain "
                f"only alphanumeric characters and the characters _+=.@-."
            )

    def _get_secret_tags(
        self, secret: BaseSecretSchema
    ) -> List[Dict[str, str]]:
        """Return a list of AWS secret tag values for a given secret.

        Args:
            secret: the secret object

        Returns:
            A list of AWS secret tag values
        """
        metadata = self._get_secret_metadata(secret)
        return [{"Key": k, "Value": v} for k, v in metadata.items()]

    def _get_secret_scope_filters(
        self,
        secret_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return a list of AWS filters for the entire scope or just a scoped secret.

        These filters can be used when querying the AWS Secrets Manager
        for all secrets or for a single secret available in the configured
        scope. For more information see: https://docs.aws.amazon.com/secretsmanager/latest/userguide/manage_search-secret.html

        Example AWS filters for all secrets in the current (namespace) scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Example AWS filters for a particular secret in the current (namespace)
        scope:

        ```python
        [
            {
                "Key": "tag-key",
                "Values": ["zenml_secret_name"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_secret"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_scope"],
            },
            {
                "Key": "tag-value",
                "Values": ["namespace"],
            },
            {
                "Key": "tag-key",
                "Values": ["zenml_namespace"],
            },
            {
                "Key": "tag-value",
                "Values": ["my_namespace"],
            },
        ]
        ```

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of AWS filters uniquely identifying all secrets
            or a named secret within the configured scope.
        """
        metadata = self._get_secret_scope_metadata(secret_name)
        filters: List[Dict[str, Any]] = []
        for k, v in metadata.items():
            filters.append(
                {
                    "Key": "tag-key",
                    "Values": [
                        k,
                    ],
                }
            )
            filters.append(
                {
                    "Key": "tag-value",
                    "Values": [
                        str(v),
                    ],
                }
            )

        return filters

    def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
        """List all secrets matching a name.

        This method lists all the secrets in the current scope without loading
        their contents. An optional secret name can be supplied to filter out
        all but a single secret identified by name.

        Args:
            secret_name: Optional secret name to filter for.

        Returns:
            A list of secret names in the current scope and the optional
            secret name.
        """
        self._ensure_client_connected(self.region_name)

        filters: List[Dict[str, Any]] = []
        prefix: Optional[str] = None
        if self.scope == SecretsManagerScope.NONE:
            # unscoped (legacy) secrets don't have tags. We want to filter out
            # non-legacy secrets
            filters = [
                {
                    "Key": "tag-key",
                    "Values": [
                        "!zenml_scope",
                    ],
                },
            ]
            if secret_name:
                prefix = secret_name
        else:
            filters = self._get_secret_scope_filters()
            if secret_name:
                prefix = self._get_scoped_secret_name(secret_name)
            else:
                # add the name prefix to the filters to account for the fact
                # that AWS does not do exact matching but prefix-matching on the
                # filters
                prefix = self._get_scoped_secret_name_prefix()

        if prefix:
            filters.append(
                {
                    "Key": "name",
                    "Values": [
                        f"{prefix}",
                    ],
                }
            )

        # TODO [ENG-720]: Deal with pagination in the aws secret manager when
        #  listing all secrets
        # TODO [ENG-721]: take out this magic maxresults number
        response = self.CLIENT.list_secrets(MaxResults=100, Filters=filters)
        results = []
        for secret in response["SecretList"]:
            name = self._get_unscoped_secret_name(secret["Name"])
            # keep only the names that are in scope and filter by secret name,
            # if one was given
            if name and (not secret_name or secret_name == name):
                results.append(name)

        return results

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        self.validate_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.region_name)

        if self._list_secrets(secret.name):
            raise SecretExistsError(
                f"A Secret with the name {secret.name} already exists"
            )

        secret_value = json.dumps(secret_to_dict(secret, encode=False))
        kwargs: Dict[str, Any] = {
            "Name": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
            "Tags": self._get_secret_tags(secret),
        }

        self.CLIENT.create_secret(**kwargs)

        logger.debug("Created AWS secret: %s", kwargs["Name"])

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Gets a secret.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            KeyError: if the secret does not exist
        """
        self.validate_secret_name_or_namespace(secret_name)
        self._ensure_client_connected(self.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        get_secret_value_response = self.CLIENT.get_secret_value(
            SecretId=self._get_scoped_secret_name(secret_name)
        )
        if "SecretString" not in get_secret_value_response:
            get_secret_value_response = None

        return secret_from_dict(
            json.loads(get_secret_value_response["SecretString"]),
            secret_name=secret_name,
            decode=False,
        )

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        return self._list_secrets()

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret.

        Args:
            secret: the secret to update

        Raises:
            KeyError: if the secret does not exist
        """
        self.validate_secret_name_or_namespace(secret.name)
        self._ensure_client_connected(self.region_name)

        if not self._list_secrets(secret.name):
            raise KeyError(f"Can't find the specified secret '{secret.name}'")

        secret_value = json.dumps(secret_to_dict(secret))

        kwargs = {
            "SecretId": self._get_scoped_secret_name(secret.name),
            "SecretString": secret_value,
        }

        self.CLIENT.put_secret_value(**kwargs)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret.

        Args:
            secret_name: the name of the secret to delete

        Raises:
            KeyError: if the secret does not exist
        """
        self._ensure_client_connected(self.region_name)

        if not self._list_secrets(secret_name):
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        self.CLIENT.delete_secret(
            SecretId=self._get_scoped_secret_name(secret_name),
            ForceDeleteWithoutRecovery=True,
        )

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets.

        This method will force delete all your secrets. You will not be able to
        recover them once this method is called.
        """
        self._ensure_client_connected(self.region_name)
        for secret_name in self._list_secrets():
            self.CLIENT.delete_secret(
                SecretId=self._get_scoped_secret_name(secret_name),
                ForceDeleteWithoutRecovery=True,
            )
delete_all_secrets(self)

Delete all existing secrets.

This method will force delete all your secrets. You will not be able to recover them once this method is called.

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets.

    This method will force delete all your secrets. You will not be able to
    recover them once this method is called.
    """
    self._ensure_client_connected(self.region_name)
    for secret_name in self._list_secrets():
        self.CLIENT.delete_secret(
            SecretId=self._get_scoped_secret_name(secret_name),
            ForceDeleteWithoutRecovery=True,
        )
delete_secret(self, secret_name)

Delete an existing secret.

Parameters:

`secret_name` (`str`, required): the name of the secret to delete

Exceptions:

`KeyError`: if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete an existing secret.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if the secret does not exist
    """
    self._ensure_client_connected(self.region_name)

    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    self.CLIENT.delete_secret(
        SecretId=self._get_scoped_secret_name(secret_name),
        ForceDeleteWithoutRecovery=True,
    )
get_all_secret_keys(self)

Get all secret keys.

Returns:

`List[str]`: A list of all secret keys

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Get all secret keys.

    Returns:
        A list of all secret keys
    """
    return self._list_secrets()
get_secret(self, secret_name)

Gets a secret.

Parameters:

`secret_name` (`str`, required): the name of the secret to get

Returns:

`BaseSecretSchema`: The secret.

Exceptions:

`KeyError`: if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Gets a secret.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist
    """
    self.validate_secret_name_or_namespace(secret_name)
    self._ensure_client_connected(self.region_name)

    if not self._list_secrets(secret_name):
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    get_secret_value_response = self.CLIENT.get_secret_value(
        SecretId=self._get_scoped_secret_name(secret_name)
    )
    if "SecretString" not in get_secret_value_response:
        get_secret_value_response = None

    return secret_from_dict(
        json.loads(get_secret_value_response["SecretString"]),
        secret_name=secret_name,
        decode=False,
    )
register_secret(self, secret)

Registers a new secret.

Parameters:

`secret` (`BaseSecretSchema`, required): the secret to register

Exceptions:

`SecretExistsError`: if the secret already exists

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Registers a new secret.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    self.validate_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.region_name)

    if self._list_secrets(secret.name):
        raise SecretExistsError(
            f"A Secret with the name {secret.name} already exists"
        )

    secret_value = json.dumps(secret_to_dict(secret, encode=False))
    kwargs: Dict[str, Any] = {
        "Name": self._get_scoped_secret_name(secret.name),
        "SecretString": secret_value,
        "Tags": self._get_secret_tags(secret),
    }

    self.CLIENT.create_secret(**kwargs)

    logger.debug("Created AWS secret: %s", kwargs["Name"])
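
As `register_secret` shows, a secret's key-value contents are serialized into a single JSON string and stored under one scoped AWS secret name. A minimal sketch of that round trip, with a plain dict standing in for the output of `secret_to_dict` and the input of `secret_from_dict`:

```python
import json

# A secret's contents, as produced by secret_to_dict(secret, encode=False).
contents = {"username": "admin", "password": "example-password"}

# What gets stored as the SecretString of the AWS secret ...
secret_string = json.dumps(contents)

# ... and what get_secret() recovers from the GetSecretValue response.
restored = json.loads(secret_string)
assert restored == contents
```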
update_secret(self, secret)

Update an existing secret.

Parameters:

`secret` (`BaseSecretSchema`, required): the secret to update

Exceptions:

`KeyError`: if the secret does not exist

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    self.validate_secret_name_or_namespace(secret.name)
    self._ensure_client_connected(self.region_name)

    if not self._list_secrets(secret.name):
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    secret_value = json.dumps(secret_to_dict(secret))

    kwargs = {
        "SecretId": self._get_scoped_secret_name(secret.name),
        "SecretString": secret_value,
    }

    self.CLIENT.put_secret_value(**kwargs)
validate_secret_name_or_namespace(name) classmethod

Validate a secret name or namespace.

AWS secret names must contain only alphanumeric characters and the characters /_+=.@-. The / character is only used internally to delimit scopes.

Parameters:

`name` (`str`, required): the secret name or namespace

Exceptions:

`ValueError`: if the secret name or namespace is invalid

Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
@classmethod
def validate_secret_name_or_namespace(cls, name: str) -> None:
    """Validate a secret name or namespace.

    AWS secret names must contain only alphanumeric characters and the
    characters /_+=.@-. The `/` character is only used internally to delimit
    scopes.

    Args:
        name: the secret name or namespace

    Raises:
        ValueError: if the secret name or namespace is invalid
    """
    if not re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name):
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. Must contain "
            f"only alphanumeric characters and the characters _+=.@-."
        )
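
As a quick illustration of the rule enforced above (the names are invented):

```python
import re

AWS_NAME_PATTERN = r"[a-zA-Z0-9_+=\.@\-]*"

assert re.fullmatch(AWS_NAME_PATTERN, "my_secret-v2")       # allowed
assert re.fullmatch(AWS_NAME_PATTERN, "my/secret") is None  # '/' is reserved for scopes
```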

step_operators special

Initialization of the Sagemaker Step Operator.

sagemaker_step_operator

Implementation of the Sagemaker Step Operator.

SagemakerStepOperator (BaseStepOperator, PipelineDockerImageBuilder) pydantic-model

Step operator to run a step on Sagemaker.

This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.

Attributes:

Name Type Description
role str

The role that is assigned to the jobs running in Sagemaker.

instance_type str

The type of the compute instance where jobs will run.

base_image Optional[str]

The base image to use for building the docker image that will be executed.

bucket Optional[str]

Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}".

experiment_name Optional[str]

The name for the experiment to which the job will be associated. If not provided, the job runs will be independent.

Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator, PipelineDockerImageBuilder):
    """Step operator to run a step on Sagemaker.

    This class defines code that builds an image with the ZenML entrypoint
    to run using Sagemaker's Estimator.

    Attributes:
        role: The role that is assigned to the jobs running in Sagemaker.
        instance_type: The type of the compute instance where jobs will run.
        base_image: The base image to use for building the docker
            image that will be executed.
        bucket: Name of the S3 bucket to use for storing artifacts
            from the job run. If not provided, a default bucket will be created
            based on the following format: "sagemaker-{region}-{aws-account-id}".
        experiment_name: The name for the experiment to which the job
            will be associated. If not provided, the job runs will be
            independent.
    """

    role: str
    instance_type: str

    base_image: Optional[str] = None
    bucket: Optional[str] = None
    experiment_name: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR

    _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
        ("base_image", "docker_parent_image")
    )

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains a container registry.

        Returns:
            A validator that checks that the stack contains a container registry.
        """

        def _ensure_local_orchestrator(stack: Stack) -> Tuple[bool, str]:
            return (
                stack.orchestrator.FLAVOR == "local",
                "Local orchestrator is required",
            )

        return StackValidator(
            required_components={StackComponentType.CONTAINER_REGISTRY},
            custom_validation_function=_ensure_local_orchestrator,
        )

    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        docker_configuration: "DockerConfiguration",
        entrypoint_command: List[str],
        resource_configuration: "ResourceConfiguration",
    ) -> None:
        """Launches a step on Sagemaker.

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            docker_configuration: The Docker configuration for this step.
            entrypoint_command: Command that executes the step.
            resource_configuration: The resource configuration for this step.
        """
        image_name = self.build_and_push_docker_image(
            pipeline_name=pipeline_name,
            docker_configuration=docker_configuration,
            stack=Repository().active_stack,
            runtime_configuration=RuntimeConfiguration(),
            entrypoint=" ".join(entrypoint_command),
        )

        if not resource_configuration.empty:
            logger.warning(
                "Specifying custom step resources is not supported for "
                "the SageMaker step operator. If you want to run this step "
                "operator on specific resources, you can do so by configuring "
                "a different instance type like this: "
                "`zenml step-operator update %s "
                "--instance_type=<INSTANCE_TYPE>`",
                self.name,
            )

        session = sagemaker.Session(default_bucket=self.bucket)
        estimator = sagemaker.estimator.Estimator(
            image_name,
            self.role,
            instance_count=1,
            instance_type=self.instance_type,
            sagemaker_session=session,
        )

        # Sagemaker doesn't allow any underscores in job/experiment/trial names
        sanitized_run_name = run_name.replace("_", "-")

        experiment_config = {}
        if self.experiment_name:
            experiment_config = {
                "ExperimentName": self.experiment_name,
                "TrialName": sanitized_run_name,
            }

        estimator.fit(
            wait=True,
            experiment_config=experiment_config,
            job_name=sanitized_run_name,
        )
validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates that the stack contains a container registry and a local orchestrator.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A validator that checks that the stack contains a container registry and a local orchestrator.

launch(self, pipeline_name, run_name, docker_configuration, entrypoint_command, resource_configuration)

Launches a step on Sagemaker.

Parameters:

Name Type Description Default
pipeline_name str

Name of the pipeline which the step to be executed is part of.

required
run_name str

Name of the pipeline run which the step to be executed is part of.

required
docker_configuration DockerConfiguration

The Docker configuration for this step.

required
entrypoint_command List[str]

Command that executes the step.

required
resource_configuration ResourceConfiguration

The resource configuration for this step.

required
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    docker_configuration: "DockerConfiguration",
    entrypoint_command: List[str],
    resource_configuration: "ResourceConfiguration",
) -> None:
    """Launches a step on Sagemaker.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        docker_configuration: The Docker configuration for this step.
        entrypoint_command: Command that executes the step.
        resource_configuration: The resource configuration for this step.
    """
    image_name = self.build_and_push_docker_image(
        pipeline_name=pipeline_name,
        docker_configuration=docker_configuration,
        stack=Repository().active_stack,
        runtime_configuration=RuntimeConfiguration(),
        entrypoint=" ".join(entrypoint_command),
    )

    if not resource_configuration.empty:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the SageMaker step operator. If you want to run this step "
            "operator on specific resources, you can do so by configuring "
            "a different instance type like this: "
            "`zenml step-operator update %s "
            "--instance_type=<INSTANCE_TYPE>`",
            self.name,
        )

    session = sagemaker.Session(default_bucket=self.bucket)
    estimator = sagemaker.estimator.Estimator(
        image_name,
        self.role,
        instance_count=1,
        instance_type=self.instance_type,
        sagemaker_session=session,
    )

    # Sagemaker doesn't allow any underscores in job/experiment/trial names
    sanitized_run_name = run_name.replace("_", "-")

    experiment_config = {}
    if self.experiment_name:
        experiment_config = {
            "ExperimentName": self.experiment_name,
            "TrialName": sanitized_run_name,
        }

    estimator.fit(
        wait=True,
        experiment_config=experiment_config,
        job_name=sanitized_run_name,
    )
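
For context, the launch above reduces to a few SageMaker SDK calls once the image is built and pushed. A minimal sketch under assumed values (the image URI, role ARN and instance type are placeholders, not values from this codebase):

```python
import sagemaker

# Uses the default "sagemaker-{region}-{aws-account-id}" bucket unless one is given.
session = sagemaker.Session()

estimator = sagemaker.estimator.Estimator(
    "123456789012.dkr.ecr.eu-west-1.amazonaws.com/zenml-step:latest",  # placeholder image
    "arn:aws:iam::123456789012:role/sagemaker-execution-role",  # placeholder role
    instance_count=1,
    instance_type="ml.m5.large",
    sagemaker_session=session,
)

# SageMaker rejects underscores in job names, hence the sanitization above.
estimator.fit(wait=True, job_name="my-pipeline-run-1")
```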

azure special

Initialization of the ZenML Azure integration.

The Azure integration submodule provides a way to run ZenML pipelines in a cloud environment. Specifically, it enables the use of cloud artifact stores and provides an io module to handle file operations on Azure Blob Storage. The Azure Step Operator integration submodule provides a way to run ZenML steps in AzureML.

AzureIntegration (Integration)

Definition of Azure integration for ZenML.

Source code in zenml/integrations/azure/__init__.py
class AzureIntegration(Integration):
    """Definition of Azure integration for ZenML."""

    NAME = AZURE
    REQUIREMENTS = [
        "adlfs==2021.10.0",
        "azure-keyvault-keys",
        "azure-keyvault-secrets",
        "azure-identity",
        "azureml-core==1.42.0.post1",
    ]

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declares the flavors for the integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=AZURE_ARTIFACT_STORE_FLAVOR,
                source="zenml.integrations.azure.artifact_stores"
                ".AzureArtifactStore",
                type=StackComponentType.ARTIFACT_STORE,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=AZURE_SECRETS_MANAGER_FLAVOR,
                source="zenml.integrations.azure.secrets_managers"
                ".AzureSecretsManager",
                type=StackComponentType.SECRETS_MANAGER,
                integration=cls.NAME,
            ),
            FlavorWrapper(
                name=AZUREML_STEP_OPERATOR_FLAVOR,
                source="zenml.integrations.azure.step_operators"
                ".AzureMLStepOperator",
                type=StackComponentType.STEP_OPERATOR,
                integration=cls.NAME,
            ),
        ]
flavors() classmethod

Declares the flavors for the integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/azure/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declares the flavors for the integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=AZURE_ARTIFACT_STORE_FLAVOR,
            source="zenml.integrations.azure.artifact_stores"
            ".AzureArtifactStore",
            type=StackComponentType.ARTIFACT_STORE,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=AZURE_SECRETS_MANAGER_FLAVOR,
            source="zenml.integrations.azure.secrets_managers"
            ".AzureSecretsManager",
            type=StackComponentType.SECRETS_MANAGER,
            integration=cls.NAME,
        ),
        FlavorWrapper(
            name=AZUREML_STEP_OPERATOR_FLAVOR,
            source="zenml.integrations.azure.step_operators"
            ".AzureMLStepOperator",
            type=StackComponentType.STEP_OPERATOR,
            integration=cls.NAME,
        ),
    ]

artifact_stores special

Initialization of the Azure Artifact Store integration.

azure_artifact_store

Implementation of the Azure Artifact Store integration.

AzureArtifactStore (BaseArtifactStore, AuthenticationMixin) pydantic-model

Artifact Store for Microsoft Azure based artifacts.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
class AzureArtifactStore(BaseArtifactStore, AuthenticationMixin):
    """Artifact Store for Microsoft Azure based artifacts."""

    _filesystem: Optional[adlfs.AzureBlobFileSystem] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = AZURE_ARTIFACT_STORE_FLAVOR
    SUPPORTED_SCHEMES: ClassVar[Set[str]] = {"abfs://", "az://"}

    @property
    def filesystem(self) -> adlfs.AzureBlobFileSystem:
        """The adlfs filesystem to access this artifact store.

        Returns:
            The adlfs filesystem to access this artifact store.
        """
        if not self._filesystem:
            secret = self.get_authentication_secret(
                expected_schema_type=AzureSecretSchema
            )
            credentials = secret.content if secret else {}

            self._filesystem = adlfs.AzureBlobFileSystem(
                **credentials,
                anon=False,
                use_listings_cache=False,
            )
        return self._filesystem

    @classmethod
    def _split_path(cls, path: PathType) -> Tuple[str, str]:
        """Splits a path into the filesystem prefix and remainder.

        Example:
        ```python
        prefix, remainder = AzureArtifactStore._split_path("az://my_container/test.txt")
        print(prefix, remainder)  # "az://" "my_container/test.txt"
        ```

        Args:
            path: The path to split.

        Returns:
            A tuple of the filesystem prefix and the remainder.
        """
        path = convert_to_str(path)
        prefix = ""
        for potential_prefix in cls.SUPPORTED_SCHEMES:
            if path.startswith(potential_prefix):
                prefix = potential_prefix
                path = path[len(potential_prefix) :]
                break

        return prefix, path

    def open(self, path: PathType, mode: str = "r") -> Any:
        """Open a file at the given path.

        Args:
            path: Path of the file to open.
            mode: Mode in which to open the file. Currently, only
                'rb' and 'wb' to read and write binary files are supported.

        Returns:
            A file-like object.
        """
        return self.filesystem.open(path=path, mode=mode)

    def copyfile(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Copy a file.

        Args:
            src: The path to copy from.
            dst: The path to copy to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to copy to destination '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to copy anyway."
            )

        # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
        #  manually remove it first
        self.filesystem.copy(path1=src, path2=dst)

    def exists(self, path: PathType) -> bool:
        """Check whether a path exists.

        Args:
            path: The path to check.

        Returns:
            True if the path exists, False otherwise.
        """
        return self.filesystem.exists(path=path)  # type: ignore[no-any-return]

    def glob(self, pattern: PathType) -> List[PathType]:
        """Return all paths that match the given glob pattern.

        The glob pattern may include:
        - '*' to match any number of characters
        - '?' to match a single character
        - '[...]' to match one of the characters inside the brackets
        - '**' as the full name of a path component to match to search
            in subdirectories of any depth (e.g. '/some_dir/**/some_file')

        Args:
            pattern: The glob pattern to match, see details above.

        Returns:
            A list of paths that match the given glob pattern.
        """
        prefix, _ = self._split_path(pattern)
        return [
            f"{prefix}{path}" for path in self.filesystem.glob(path=pattern)
        ]

    def isdir(self, path: PathType) -> bool:
        """Check whether a path is a directory.

        Args:
            path: The path to check.

        Returns:
            True if the path is a directory, False otherwise.
        """
        return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]

    def listdir(self, path: PathType) -> List[PathType]:
        """Return a list of files in a directory.

        Args:
            path: The path to list.

        Returns:
            A list of files in the given directory.
        """
        _, path = self._split_path(path)

        def _extract_basename(file_dict: Dict[str, Any]) -> str:
            """Extracts the basename from a dictionary returned by the Azure filesystem.

            Args:
                file_dict: A dictionary returned by the Azure filesystem.

            Returns:
                The basename of the file.
            """
            file_path = cast(str, file_dict["name"])
            base_name = file_path[len(path) :]
            return base_name.lstrip("/")

        return [
            _extract_basename(dict_)
            for dict_ in self.filesystem.listdir(path=path)
        ]

    def makedirs(self, path: PathType) -> None:
        """Create a directory at the given path.

        If needed also create missing parent directories.

        Args:
            path: The path to create.
        """
        self.filesystem.makedirs(path=path, exist_ok=True)

    def mkdir(self, path: PathType) -> None:
        """Create a directory at the given path.

        Args:
            path: The path to create.
        """
        self.filesystem.makedir(path=path)

    def remove(self, path: PathType) -> None:
        """Remove the file at the given path.

        Args:
            path: The path to remove.
        """
        self.filesystem.rm_file(path=path)

    def rename(
        self, src: PathType, dst: PathType, overwrite: bool = False
    ) -> None:
        """Rename source file to destination file.

        Args:
            src: The path of the file to rename.
            dst: The path to rename the source file to.
            overwrite: If a file already exists at the destination, this
                method will overwrite it if overwrite=`True` and
                raise a FileExistsError otherwise.

        Raises:
            FileExistsError: If a file already exists at the destination
                and overwrite is not set to `True`.
        """
        if not overwrite and self.filesystem.exists(dst):
            raise FileExistsError(
                f"Unable to rename file to '{convert_to_str(dst)}', "
                f"file already exists. Set `overwrite=True` to rename anyway."
            )

        # TODO [ENG-152]: Check if it works with overwrite=True or if we need
        #  to manually remove it first
        self.filesystem.rename(path1=src, path2=dst)

    def rmtree(self, path: PathType) -> None:
        """Remove the given directory.

        Args:
            path: The path of the directory to remove.
        """
        self.filesystem.delete(path=path, recursive=True)

    def stat(self, path: PathType) -> Dict[str, Any]:
        """Return stat info for the given path.

        Args:
            path: The path to get stat info for.

        Returns:
            Stat info.
        """
        return self.filesystem.stat(path=path)  # type: ignore[no-any-return]

    def walk(
        self,
        top: PathType,
        topdown: bool = True,
        onerror: Optional[Callable[..., None]] = None,
    ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
        """Return an iterator that walks the contents of the given directory.

        Args:
            top: Path of directory to walk.
            topdown: Unused argument to conform to interface.
            onerror: Unused argument to conform to interface.

        Yields:
            An Iterable of Tuples, each of which contains the path of the
            current directory, a list of directories inside the current
            directory and a list of files inside the current directory.
        """
        # TODO [ENG-153]: Additional params
        prefix, _ = self._split_path(top)
        for (
            directory,
            subdirectories,
            files,
        ) in self.filesystem.walk(path=top):
            yield f"{prefix}{directory}", subdirectories, files
filesystem: AzureBlobFileSystem property readonly

The adlfs filesystem to access this artifact store.

Returns:

Type Description
AzureBlobFileSystem

The adlfs filesystem to access this artifact store.

copyfile(self, src, dst, overwrite=False)

Copy a file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path to copy from.

required
dst Union[bytes, str]

The path to copy to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def copyfile(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Copy a file.

    Args:
        src: The path to copy from.
        dst: The path to copy to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    if not overwrite and self.filesystem.exists(dst):
        raise FileExistsError(
            f"Unable to copy to destination '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to copy anyway."
        )

    # TODO [ENG-151]: Check if it works with overwrite=True or if we need to
    #  manually remove it first
    self.filesystem.copy(path1=src, path2=dst)
exists(self, path)

Check whether a path exists.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path exists, False otherwise.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def exists(self, path: PathType) -> bool:
    """Check whether a path exists.

    Args:
        path: The path to check.

    Returns:
        True if the path exists, False otherwise.
    """
    return self.filesystem.exists(path=path)  # type: ignore[no-any-return]
glob(self, pattern)

Return all paths that match the given glob pattern.

The glob pattern may include: '*' to match any number of characters, '?' to match a single character, '[...]' to match one of the characters inside the brackets, and '**' as the full name of a path component to match to search in subdirectories of any depth (e.g. '/some_dir/**/some_file').

Parameters:

Name Type Description Default
pattern Union[bytes, str]

The glob pattern to match, see details above.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of paths that match the given glob pattern.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
    """Return all paths that match the given glob pattern.

    The glob pattern may include:
    - '*' to match any number of characters
    - '?' to match a single character
    - '[...]' to match one of the characters inside the brackets
    - '**' as the full name of a path component to match to search
        in subdirectories of any depth (e.g. '/some_dir/**/some_file')

    Args:
        pattern: The glob pattern to match, see details above.

    Returns:
        A list of paths that match the given glob pattern.
    """
    prefix, _ = self._split_path(pattern)
    return [
        f"{prefix}{path}" for path in self.filesystem.glob(path=pattern)
    ]
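
A hedged sketch of what the prefix handling above amounts to with plain adlfs (the account, credentials and paths are invented):

```python
import adlfs

# Hypothetical credentials; in the artifact store these come from the
# authentication secret.
fs = adlfs.AzureBlobFileSystem(account_name="myaccount", account_key="<key>")

# adlfs returns paths without the "az://" scheme, which is why the store
# re-attaches the prefix after globbing.
matches = fs.glob("my_container/data/**/*.csv")
paths = [f"az://{p}" for p in matches]
```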
isdir(self, path)

Check whether a path is a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to check.

required

Returns:

Type Description
bool

True if the path is a directory, False otherwise.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def isdir(self, path: PathType) -> bool:
    """Check whether a path is a directory.

    Args:
        path: The path to check.

    Returns:
        True if the path is a directory, False otherwise.
    """
    return self.filesystem.isdir(path=path)  # type: ignore[no-any-return]
listdir(self, path)

Return a list of files in a directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to list.

required

Returns:

Type Description
List[Union[bytes, str]]

A list of files in the given directory.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
    """Return a list of files in a directory.

    Args:
        path: The path to list.

    Returns:
        A list of files in the given directory.
    """
    _, path = self._split_path(path)

    def _extract_basename(file_dict: Dict[str, Any]) -> str:
        """Extracts the basename from a dictionary returned by the Azure filesystem.

        Args:
            file_dict: A dictionary returned by the Azure filesystem.

        Returns:
            The basename of the file.
        """
        file_path = cast(str, file_dict["name"])
        base_name = file_path[len(path) :]
        return base_name.lstrip("/")

    return [
        _extract_basename(dict_)
        for dict_ in self.filesystem.listdir(path=path)
    ]
makedirs(self, path)

Create a directory at the given path.

If needed also create missing parent directories.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to create.

required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def makedirs(self, path: PathType) -> None:
    """Create a directory at the given path.

    If needed also create missing parent directories.

    Args:
        path: The path to create.
    """
    self.filesystem.makedirs(path=path, exist_ok=True)
mkdir(self, path)

Create a directory at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to create.

required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def mkdir(self, path: PathType) -> None:
    """Create a directory at the given path.

    Args:
        path: The path to create.
    """
    self.filesystem.makedir(path=path)
open(self, path, mode='r')

Open a file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

Path of the file to open.

required
mode str

Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported.

'r'

Returns:

Type Description
Any

A file-like object.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
    """Open a file at the given path.

    Args:
        path: Path of the file to open.
        mode: Mode in which to open the file. Currently, only
            'rb' and 'wb' to read and write binary files are supported.

    Returns:
        A file-like object.
    """
    return self.filesystem.open(path=path, mode=mode)
remove(self, path)

Remove the file at the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to remove.

required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def remove(self, path: PathType) -> None:
    """Remove the file at the given path.

    Args:
        path: The path to remove.
    """
    self.filesystem.rm_file(path=path)
rename(self, src, dst, overwrite=False)

Rename source file to destination file.

Parameters:

Name Type Description Default
src Union[bytes, str]

The path of the file to rename.

required
dst Union[bytes, str]

The path to rename the source file to.

required
overwrite bool

If a file already exists at the destination, this method will overwrite it if overwrite=True and raise a FileExistsError otherwise.

False

Exceptions:

Type Description
FileExistsError

If a file already exists at the destination and overwrite is not set to True.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def rename(
    self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
    """Rename source file to destination file.

    Args:
        src: The path of the file to rename.
        dst: The path to rename the source file to.
        overwrite: If a file already exists at the destination, this
            method will overwrite it if overwrite=`True` and
            raise a FileExistsError otherwise.

    Raises:
        FileExistsError: If a file already exists at the destination
            and overwrite is not set to `True`.
    """
    if not overwrite and self.filesystem.exists(dst):
        raise FileExistsError(
            f"Unable to rename file to '{convert_to_str(dst)}', "
            f"file already exists. Set `overwrite=True` to rename anyway."
        )

    # TODO [ENG-152]: Check if it works with overwrite=True or if we need
    #  to manually remove it first
    self.filesystem.rename(path1=src, path2=dst)
rmtree(self, path)

Remove the given directory.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path of the directory to remove.

required
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def rmtree(self, path: PathType) -> None:
    """Remove the given directory.

    Args:
        path: The path of the directory to remove.
    """
    self.filesystem.delete(path=path, recursive=True)
stat(self, path)

Return stat info for the given path.

Parameters:

Name Type Description Default
path Union[bytes, str]

The path to get stat info for.

required

Returns:

Type Description
Dict[str, Any]

Stat info.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
    """Return stat info for the given path.

    Args:
        path: The path to get stat info for.

    Returns:
        Stat info.
    """
    return self.filesystem.stat(path=path)  # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)

Return an iterator that walks the contents of the given directory.

Parameters:

Name Type Description Default
top Union[bytes, str]

Path of directory to walk.

required
topdown bool

Unused argument to conform to interface.

True
onerror Optional[Callable[..., NoneType]]

Unused argument to conform to interface.

None

Yields:

Type Description
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]]

An Iterable of Tuples, each of which contains the path of the current directory, a list of directories inside the current directory and a list of files inside the current directory.

Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def walk(
    self,
    top: PathType,
    topdown: bool = True,
    onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
    """Return an iterator that walks the contents of the given directory.

    Args:
        top: Path of directory to walk.
        topdown: Unused argument to conform to interface.
        onerror: Unused argument to conform to interface.

    Yields:
        An Iterable of Tuples, each of which contains the path of the
        current directory, a list of directories inside the current
        directory and a list of files inside the current directory.
    """
    # TODO [ENG-153]: Additional params
    prefix, _ = self._split_path(top)
    for (
        directory,
        subdirectories,
        files,
    ) in self.filesystem.walk(path=top):
        yield f"{prefix}{directory}", subdirectories, files
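
Taken together, the methods above are thin wrappers around the fsspec-style interface of adlfs. A short sketch of the delegated calls (credentials and paths are placeholders):

```python
import adlfs

fs = adlfs.AzureBlobFileSystem(account_name="myaccount", account_key="<key>")

# makedirs/open/walk mirror the artifact store methods of the same names.
fs.makedirs("my_container/artifacts/run-1", exist_ok=True)
with fs.open("my_container/artifacts/run-1/out.txt", mode="wb") as f:
    f.write(b"hello")

for directory, subdirectories, files in fs.walk("my_container/artifacts"):
    print(directory, subdirectories, files)
```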

secrets_managers special

Initialization of the Azure Secrets Manager integration.

azure_secrets_manager

Implementation of the Azure Secrets Manager integration.

AzureSecretsManager (BaseSecretsManager) pydantic-model

Class to interact with the Azure secrets manager.

Attributes:

Name Type Description
key_vault_name str

Name of an Azure Key Vault that this secrets manager will use to store secrets.

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
class AzureSecretsManager(BaseSecretsManager):
    """Class to interact with the Azure secrets manager.

    Attributes:
        key_vault_name: Name of an Azure Key Vault that this secrets manager
            will use to store secrets.
    """

    key_vault_name: str

    # Class configuration
    FLAVOR: ClassVar[str] = AZURE_SECRETS_MANAGER_FLAVOR
    SUPPORTS_SCOPING: ClassVar[bool] = True
    CLIENT: ClassVar[Any] = None

    @classmethod
    def _ensure_client_connected(cls, vault_name: str) -> None:
        if cls.CLIENT is None:
            KVUri = f"https://{vault_name}.vault.azure.net"

            credential = DefaultAzureCredential()
            cls.CLIENT = SecretClient(vault_url=KVUri, credential=credential)

    @classmethod
    def _validate_scope(
        cls,
        scope: SecretsManagerScope,
        namespace: Optional[str],
    ) -> None:
        """Validate the scope and namespace value.

        Args:
            scope: Scope value.
            namespace: Optional namespace value.
        """
        if namespace:
            cls.validate_secret_name_or_namespace(namespace, scope)

    @classmethod
    def validate_secret_name_or_namespace(
        cls,
        name: str,
        scope: SecretsManagerScope,
    ) -> None:
        """Validate a secret name or namespace.

        Azure secret names must contain only alphanumeric characters and the
        character `-`.

        Given that we also save secret names and namespaces as labels, we are
        also limited by the 256-character maximum size that Azure imposes on
        label values. An arbitrary limit of 100 characters is used here as
        the maximum size for the secret name and namespace.

        Args:
            name: the secret name or namespace
            scope: the current scope

        Raises:
            ValueError: if the secret name or namespace is invalid
        """
        if scope == SecretsManagerScope.NONE:
            # to preserve backwards compatibility, we don't validate the
            # secret name for unscoped secrets.
            return

        if not re.fullmatch(r"[0-9a-zA-Z-]+", name):
            raise ValueError(
                f"Invalid secret name or namespace '{name}'. Must contain "
                f"only alphanumeric characters and the character -."
            )

        if len(name) > 100:
            raise ValueError(
                f"Invalid secret name or namespace '{name}'. The length is "
                f"limited to maximum 100 characters."
            )

    def validate_secret_name(self, name: str) -> None:
        """Validate a secret name.

        Args:
            name: the secret name
        """
        self.validate_secret_name_or_namespace(name, self.scope)

    def _create_or_update_secret(self, secret: BaseSecretSchema) -> None:
        """Creates a new secret or updated an existing one.

        Args:
            secret: the secret to register or update
        """
        if self.scope == SecretsManagerScope.NONE:
            # legacy, non-scoped secrets

            for key, value in secret.content.items():
                encoded_key = base64.b64encode(
                    f"{secret.name}-{key}".encode()
                ).hex()
                azure_secret_name = f"zenml-{encoded_key}"

                self.CLIENT.set_secret(azure_secret_name, value)
                self.CLIENT.update_secret_properties(
                    azure_secret_name,
                    tags={
                        ZENML_GROUP_KEY: secret.name,
                        ZENML_KEY_NAME: key,
                        ZENML_SCHEMA_NAME: secret.TYPE,
                    },
                )

                logger.debug(
                    "Secret `%s` written to the Azure Key Vault.",
                    azure_secret_name,
                )
        else:
            azure_secret_name = self._get_scoped_secret_name(
                secret.name,
                separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
            )
            self.CLIENT.set_secret(
                azure_secret_name,
                json.dumps(secret_to_dict(secret)),
            )
            self.CLIENT.update_secret_properties(
                azure_secret_name,
                tags=self._get_secret_metadata(secret),
            )

    def register_secret(self, secret: BaseSecretSchema) -> None:
        """Registers a new secret.

        Args:
            secret: the secret to register

        Raises:
            SecretExistsError: if the secret already exists
        """
        self.validate_secret_name(secret.name)
        self._ensure_client_connected(self.key_vault_name)

        if secret.name in self.get_all_secret_keys():
            raise SecretExistsError(
                f"A Secret with the name '{secret.name}' already exists."
            )

        self._create_or_update_secret(secret)

    def get_secret(self, secret_name: str) -> BaseSecretSchema:
        """Get a secret by its name.

        Args:
            secret_name: the name of the secret to get

        Returns:
            The secret.

        Raises:
            KeyError: if the secret does not exist
            ValueError: if a secret key tag is missing or a secret key
                is named 'name'
        """
        self.validate_secret_name(secret_name)
        self._ensure_client_connected(self.key_vault_name)
        zenml_secret: Optional[BaseSecretSchema] = None

        if self.scope == SecretsManagerScope.NONE:
            # Legacy secrets are mapped to multiple Azure secrets, one for
            # each secret key

            secret_contents = {}
            zenml_schema_name = ""

            for secret_property in self.CLIENT.list_properties_of_secrets():
                tags = secret_property.tags

                if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
                    secret_key = tags.get(ZENML_KEY_NAME)
                    if not secret_key:
                        raise ValueError("Missing secret key tag.")

                    if secret_key == "name":
                        raise ValueError("The secret's key cannot be 'name'.")

                    response = self.CLIENT.get_secret(secret_property.name)
                    secret_contents[secret_key] = response.value

                    zenml_schema_name = tags.get(ZENML_SCHEMA_NAME)

            if secret_contents:
                secret_contents["name"] = secret_name

                secret_schema = SecretSchemaClassRegistry.get_class(
                    secret_schema=zenml_schema_name
                )
                zenml_secret = secret_schema(**secret_contents)
        else:
            # Scoped secrets are mapped 1-to-1 with Azure secrets

            try:
                response = self.CLIENT.get_secret(
                    self._get_scoped_secret_name(
                        secret_name,
                        separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
                    ),
                )

                scope_tags = self._get_secret_scope_metadata(secret_name)

                # all scope tags need to be included in the Azure secret tags,
                # otherwise the secret does not belong to the current scope,
                # even if it has the same name
                if scope_tags.items() <= response.properties.tags.items():
                    zenml_secret = secret_from_dict(
                        json.loads(response.value), secret_name=secret_name
                    )
            except ResourceNotFoundError:
                pass

        if not zenml_secret:
            raise KeyError(f"Can't find the specified secret '{secret_name}'")

        return zenml_secret

    def get_all_secret_keys(self) -> List[str]:
        """Get all secret keys.

        Returns:
            A list of all secret keys
        """
        self._ensure_client_connected(self.key_vault_name)

        set_of_secrets = set()

        for secret_property in self.CLIENT.list_properties_of_secrets():
            tags = secret_property.tags
            if not tags:
                continue

            if self.scope == SecretsManagerScope.NONE:
                # legacy, non-scoped secrets
                if ZENML_GROUP_KEY in tags:
                    set_of_secrets.add(tags.get(ZENML_GROUP_KEY))
                continue

            scope_tags = self._get_secret_scope_metadata()
            # all scope tags need to be included in the Azure secret tags,
            # otherwise the secret does not belong to the current scope
            if scope_tags.items() <= tags.items():
                set_of_secrets.add(tags.get(ZENML_SECRET_NAME_LABEL))

        return list(set_of_secrets)

    def update_secret(self, secret: BaseSecretSchema) -> None:
        """Update an existing secret by creating new versions of the existing secrets.

        Args:
            secret: the secret to update

        Raises:
            KeyError: if the secret does not exist
        """
        self.validate_secret_name(secret.name)
        self._ensure_client_connected(self.key_vault_name)

        if secret.name not in self.get_all_secret_keys():
            raise KeyError(f"Can't find the specified secret '{secret.name}'")

        self._create_or_update_secret(secret)

    def delete_secret(self, secret_name: str) -> None:
        """Delete an existing secret. by name.

        Args:
            secret_name: the name of the secret to delete

        Raises:
            KeyError: if the secret no longer exists
        """
        self.validate_secret_name(secret_name)
        self._ensure_client_connected(self.key_vault_name)

        if self.scope == SecretsManagerScope.NONE:
            # legacy, non-scoped secrets

            # Go through all Azure secrets and delete the ones with the
            # secret_name as label.
            for secret_property in self.CLIENT.list_properties_of_secrets():
                tags = secret_property.tags
                if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
                    self.CLIENT.begin_delete_secret(
                        secret_property.name
                    ).result()

        else:
            if secret_name not in self.get_all_secret_keys():
                raise KeyError(
                    f"Can't find the specified secret '{secret_name}'"
                )
            self.CLIENT.begin_delete_secret(
                self._get_scoped_secret_name(
                    secret_name,
                    separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
                ),
            ).result()

    def delete_all_secrets(self) -> None:
        """Delete all existing secrets."""
        self._ensure_client_connected(self.key_vault_name)

        # List all secrets.
        for secret_property in self.CLIENT.list_properties_of_secrets():

            tags = secret_property.tags
            if not tags:
                continue

            if self.scope == SecretsManagerScope.NONE:
                # legacy, non-scoped secrets
                if ZENML_GROUP_KEY in tags:
                    logger.info(
                        "Deleted key-value pair {`%s`, `***`} from secret "
                        "`%s`",
                        secret_property.name,
                        tags.get(ZENML_GROUP_KEY),
                    )
                    self.CLIENT.begin_delete_secret(
                        secret_property.name
                    ).result()
                continue

            scope_tags = self._get_secret_scope_metadata()
            # all scope tags need to be included in the Azure secret tags,
            # otherwise the secret does not belong to the current scope
            if scope_tags.items() <= tags.items():
                self.CLIENT.begin_delete_secret(secret_property.name).result()
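
For orientation, the class above reduces to a handful of `azure-keyvault-secrets` calls. A minimal sketch (the vault URL, scope separator and secret values are placeholders):

```python
import base64
import json

from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient

client = SecretClient(
    vault_url="https://my-vault.vault.azure.net",  # placeholder vault
    credential=DefaultAzureCredential(),
)

# Scoped secrets: one Key Vault secret holding the JSON-serialized contents.
client.set_secret("zenml-user-me-mysecret", json.dumps({"api_key": "abc123"}))
print(client.get_secret("zenml-user-me-mysecret").value)

# Legacy unscoped secrets: one Key Vault secret per key, under a mangled name
# derived from "<secret-name>-<key>" as in _create_or_update_secret above.
legacy_name = "zenml-" + base64.b64encode(b"mysecret-api_key").hex()
```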
delete_all_secrets(self)

Delete all existing secrets.

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def delete_all_secrets(self) -> None:
    """Delete all existing secrets."""
    self._ensure_client_connected(self.key_vault_name)

    # List all secrets.
    for secret_property in self.CLIENT.list_properties_of_secrets():

        tags = secret_property.tags
        if not tags:
            continue

        if self.scope == SecretsManagerScope.NONE:
            # legacy, non-scoped secrets
            if ZENML_GROUP_KEY in tags:
                logger.info(
                    "Deleted key-value pair {`%s`, `***`} from secret "
                    "`%s`",
                    secret_property.name,
                    tags.get(ZENML_GROUP_KEY),
                )
                self.CLIENT.begin_delete_secret(
                    secret_property.name
                ).result()
            continue

        scope_tags = self._get_secret_scope_metadata()
        # all scope tags need to be included in the Azure secret tags,
        # otherwise the secret does not belong to the current scope
        if scope_tags.items() <= tags.items():
            self.CLIENT.begin_delete_secret(secret_property.name).result()
delete_secret(self, secret_name)

Delete an existing secret by name.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to delete

required

Exceptions:

Type Description
KeyError

if the secret no longer exists

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
    """Delete an existing secret. by name.

    Args:
        secret_name: the name of the secret to delete

    Raises:
        KeyError: if the secret no longer exists
    """
    self.validate_secret_name(secret_name)
    self._ensure_client_connected(self.key_vault_name)

    if self.scope == SecretsManagerScope.NONE:
        # legacy, non-scoped secrets

        # Go through all Azure secrets and delete the ones with the
        # secret_name as label.
        for secret_property in self.CLIENT.list_properties_of_secrets():
            tags = secret_property.tags
            if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
                self.CLIENT.begin_delete_secret(
                    secret_property.name
                ).result()

    else:
        if secret_name not in self.get_all_secret_keys():
            raise KeyError(
                f"Can't find the specified secret '{secret_name}'"
            )
        self.CLIENT.begin_delete_secret(
            self._get_scoped_secret_name(
                secret_name,
                separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
            ),
        ).result()
get_all_secret_keys(self)

Get all secret keys.

Returns:

Type Description
List[str]

A list of all secret keys

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
    """Get all secret keys.

    Returns:
        A list of all secret keys
    """
    self._ensure_client_connected(self.key_vault_name)

    set_of_secrets = set()

    for secret_property in self.CLIENT.list_properties_of_secrets():
        tags = secret_property.tags
        if not tags:
            continue

        if self.scope == SecretsManagerScope.NONE:
            # legacy, non-scoped secrets
            if ZENML_GROUP_KEY in tags:
                set_of_secrets.add(tags.get(ZENML_GROUP_KEY))
            continue

        scope_tags = self._get_secret_scope_metadata()
        # all scope tags need to be included in the Azure secret tags,
        # otherwise the secret does not belong to the current scope
        if scope_tags.items() <= tags.items():
            set_of_secrets.add(tags.get(ZENML_SECRET_NAME_LABEL))

    return list(set_of_secrets)
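
The scope check used above (`scope_tags.items() <= tags.items()`) relies on dictionary item views supporting subset comparison: a secret belongs to the current scope only if every scope tag appears, with the same value, among the Azure tags. A tiny self-contained illustration with invented tag names:

```python
scope_tags = {"zenml_scope": "user", "zenml_user": "me"}
azure_tags = {"zenml_scope": "user", "zenml_user": "me", "extra": "x"}

print(scope_tags.items() <= azure_tags.items())  # True: scope tags are a subset

# Same keys but a differing value means the secret is in a different scope.
print(scope_tags.items() <= {**azure_tags, "zenml_user": "other"}.items())  # False
```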
get_secret(self, secret_name)

Get a secret by its name.

Parameters:

Name Type Description Default
secret_name str

the name of the secret to get

required

Returns:

Type Description
BaseSecretSchema

The secret.

Exceptions:

Type Description
KeyError

if the secret does not exist

ValueError

if a secret key tag is missing or a secret key is named 'name'

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
    """Get a secret by its name.

    Args:
        secret_name: the name of the secret to get

    Returns:
        The secret.

    Raises:
        KeyError: if the secret does not exist
        ValueError: if a secret key tag is missing or a secret key
            is named 'name'
    """
    self.validate_secret_name(secret_name)
    self._ensure_client_connected(self.key_vault_name)
    zenml_secret: Optional[BaseSecretSchema] = None

    if self.scope == SecretsManagerScope.NONE:
        # Legacy secrets are mapped to multiple Azure secrets, one for
        # each secret key

        secret_contents = {}
        zenml_schema_name = ""

        for secret_property in self.CLIENT.list_properties_of_secrets():
            tags = secret_property.tags

            if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
                secret_key = tags.get(ZENML_KEY_NAME)
                if not secret_key:
                    raise ValueError("Missing secret key tag.")

                if secret_key == "name":
                    raise ValueError("The secret's key cannot be 'name'.")

                response = self.CLIENT.get_secret(secret_property.name)
                secret_contents[secret_key] = response.value

                zenml_schema_name = tags.get(ZENML_SCHEMA_NAME)

        if secret_contents:
            secret_contents["name"] = secret_name

            secret_schema = SecretSchemaClassRegistry.get_class(
                secret_schema=zenml_schema_name
            )
            zenml_secret = secret_schema(**secret_contents)
    else:
        # Scoped secrets are mapped 1-to-1 with Azure secrets

        try:
            response = self.CLIENT.get_secret(
                self._get_scoped_secret_name(
                    secret_name,
                    separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
                ),
            )

            scope_tags = self._get_secret_scope_metadata(secret_name)

            # all scope tags need to be included in the Azure secret tags,
            # otherwise the secret does not belong to the current scope,
            # even if it has the same name
            if scope_tags.items() <= response.properties.tags.items():
                zenml_secret = secret_from_dict(
                    json.loads(response.value), secret_name=secret_name
                )
        except ResourceNotFoundError:
            pass

    if not zenml_secret:
        raise KeyError(f"Can't find the specified secret '{secret_name}'")

    return zenml_secret
register_secret(self, secret)

Registers a new secret.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to register

required

Exceptions:

Type Description
SecretExistsError

if the secret already exists

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
    """Registers a new secret.

    Args:
        secret: the secret to register

    Raises:
        SecretExistsError: if the secret already exists
    """
    self.validate_secret_name(secret.name)
    self._ensure_client_connected(self.key_vault_name)

    if secret.name in self.get_all_secret_keys():
        raise SecretExistsError(
            f"A Secret with the name '{secret.name}' already exists."
        )

    self._create_or_update_secret(secret)
update_secret(self, secret)

Update an existing secret by creating new versions of the existing secrets.

Parameters:

Name Type Description Default
secret BaseSecretSchema

the secret to update

required

Exceptions:

Type Description
KeyError

if the secret does not exist

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
    """Update an existing secret by creating new versions of the existing secrets.

    Args:
        secret: the secret to update

    Raises:
        KeyError: if the secret does not exist
    """
    self.validate_secret_name(secret.name)
    self._ensure_client_connected(self.key_vault_name)

    if secret.name not in self.get_all_secret_keys():
        raise KeyError(f"Can't find the specified secret '{secret.name}'")

    self._create_or_update_secret(secret)
validate_secret_name(self, name)

Validate a secret name.

Parameters:

Name Type Description Default
name str

the secret name

required
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def validate_secret_name(self, name: str) -> None:
    """Validate a secret name.

    Args:
        name: the secret name
    """
    self.validate_secret_name_or_namespace(name, self.scope)
validate_secret_name_or_namespace(name, scope) classmethod

Validate a secret name or namespace.

Azure secret names must contain only alphanumeric characters and the character -.

Given that we also save secret names and namespaces as labels, we are also limited by the 256-character maximum size that Azure imposes on label values. An arbitrary limit of 100 characters is used here as the maximum size for the secret name and namespace.

Parameters:

Name Type Description Default
name str

the secret name or namespace

required
scope SecretsManagerScope

the current scope

required

Exceptions:

Type Description
ValueError

if the secret name or namespace is invalid

Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
@classmethod
def validate_secret_name_or_namespace(
    cls,
    name: str,
    scope: SecretsManagerScope,
) -> None:
    """Validate a secret name or namespace.

    Azure secret names must contain only alphanumeric characters and the
    character `-`.

    Given that we also save secret names and namespaces as labels, we are
    also limited by the 256-character maximum size that Azure imposes on
    label values. An arbitrary limit of 100 characters is used here as
    the maximum size for the secret name and namespace.

    Args:
        name: the secret name or namespace
        scope: the current scope

    Raises:
        ValueError: if the secret name or namespace is invalid
    """
    if scope == SecretsManagerScope.NONE:
        # to preserve backwards compatibility, we don't validate the
        # secret name for unscoped secrets.
        return

    if not re.fullmatch(r"[0-9a-zA-Z-]+", name):
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. Must contain "
            f"only alphanumeric characters and the character -."
        )

    if len(name) > 100:
        raise ValueError(
            f"Invalid secret name or namespace '{name}'. The length is "
            f"limited to maximum 100 characters."
        )
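
The same rule in isolation, with invented names (note that unscoped secrets skip this validation entirely, as shown above):

```python
import re

AZURE_NAME_PATTERN = r"[0-9a-zA-Z-]+"

assert re.fullmatch(AZURE_NAME_PATTERN, "training-secret")          # allowed
assert re.fullmatch(AZURE_NAME_PATTERN, "training_secret") is None  # '_' rejected
assert len("training-secret") <= 100                                # length cap
```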

step_operators special

Initialization of AzureML Step Operator integration.

azureml_step_operator

Implementation of the ZenML AzureML Step Operator.

AzureMLStepOperator (BaseStepOperator, PipelineDockerImageBuilder) pydantic-model

Step operator to run a step on AzureML.

This class defines code that can set up an AzureML environment and run the ZenML entrypoint command in it.

Attributes:

Name Type Description
subscription_id str

The Azure account's subscription ID

resource_group str

The resource group to which the AzureML workspace is deployed.

workspace_name str

The name of the AzureML Workspace.

compute_target_name str

The name of the configured ComputeTarget. An instance of it has to be created on the portal if it doesn't exist already.

environment_name Optional[str]

The name of the environment if there already exists one.

docker_base_image Optional[str]

The custom docker base image that the environment should use.

tenant_id Optional[str]

The Azure Tenant ID.

service_principal_id Optional[str]

The ID for the service principal that is created to allow apps to access secure resources.

service_principal_password Optional[str]

Password for the service principal.

Source code in zenml/integrations/azure/step_operators/azureml_step_operator.py
class AzureMLStepOperator(BaseStepOperator, PipelineDockerImageBuilder):
    """Step operator to run a step on AzureML.

    This class defines code that can set up an AzureML environment and run the
    ZenML entrypoint command in it.

    Attributes:
        subscription_id: The Azure account's subscription ID
        resource_group: The resource group to which the AzureML workspace
            is deployed.
        workspace_name: The name of the AzureML Workspace.
        compute_target_name: The name of the configured ComputeTarget.
            An instance of it has to be created on the portal if it doesn't
            exist already.
        environment_name: The name of an AzureML environment to reuse,
            if one already exists.
        docker_base_image: The custom docker base image that the
            environment should use.
        tenant_id: The Azure Tenant ID.
        service_principal_id: The ID for the service principal that is created
            to allow apps to access secure resources.
        service_principal_password: Password for the service principal.
    """

    subscription_id: str
    resource_group: str
    workspace_name: str
    compute_target_name: str

    # Environment
    environment_name: Optional[str] = None
    docker_base_image: Optional[str] = None

    # Service principal authentication
    # https://docs.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication#configure-a-service-principal
    tenant_id: Optional[str] = SecretField()
    service_principal_id: Optional[str] = SecretField()
    service_principal_password: Optional[str] = SecretField()

    # Class Configuration
    FLAVOR: ClassVar[str] = AZUREML_STEP_OPERATOR_FLAVOR

    _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
        ("docker_base_image", "docker_parent_image")
    )

    def _get_authentication(self) -> Optional[AbstractAuthentication]:
        """Returns the authentication object for the AzureML environment.

        Returns:
            The authentication object for the AzureML environment.
        """
        if (
            self.tenant_id
            and self.service_principal_id
            and self.service_principal_password
        ):
            return ServicePrincipalAuthentication(
                tenant_id=self.tenant_id,
                service_principal_id=self.service_principal_id,
                service_principal_password=self.service_principal_password,
            )
        return None

    def _prepare_environment(
        self,
        workspace: Workspace,
        docker_configuration: "DockerConfiguration",
        run_name: str,
    ) -> Environment:
        """Prepares the environment in which Azure will run all jobs.

        Args:
            workspace: The AzureML Workspace that has configuration
                for a storage account, container registry among other
                things.
            docker_configuration: The Docker configuration for this step.
            run_name: The name of the pipeline run that can be used
                for naming environments and runs.

        Returns:
            The AzureML Environment object.
        """
        requirements_files = self._gather_requirements_files(
            docker_configuration=docker_configuration,
            stack=Repository().active_stack,
        )
        requirements = list(
            itertools.chain.from_iterable(
                r[1].split("\n") for r in requirements_files
            )
        )
        requirements.append(f"zenml=={zenml.__version__}")
        logger.info(
            "Using requirements for AzureML step operator environment: %s",
            requirements,
        )
        if self.environment_name:
            environment = Environment.get(
                workspace=workspace, name=self.environment_name
            )
            if not environment.python.conda_dependencies:
                environment.python.conda_dependencies = (
                    CondaDependencies.create(
                        python_version=ZenMLEnvironment.python_version()
                    )
                )

            for requirement in requirements:
                environment.python.conda_dependencies.add_pip_package(
                    requirement
                )
        else:
            environment = Environment(name=f"zenml-{run_name}")
            environment.python.conda_dependencies = CondaDependencies.create(
                pip_packages=requirements,
                python_version=ZenMLEnvironment.python_version(),
            )

            parent_image = (
                docker_configuration.parent_image or self.docker_parent_image
            )

            if parent_image:
                # replace the default azure base image
                environment.docker.base_image = parent_image

        environment_variables = {
            "ENV_ZENML_PREVENT_PIPELINE_EXECUTION": "True",
        }
        # set credentials to access azure storage
        for key in [
            "AZURE_STORAGE_ACCOUNT_KEY",
            "AZURE_STORAGE_ACCOUNT_NAME",
            "AZURE_STORAGE_CONNECTION_STRING",
            "AZURE_STORAGE_SAS_TOKEN",
        ]:
            value = os.getenv(key)
            if value:
                environment_variables[key] = value

        environment_variables[
            ENV_ZENML_CONFIG_PATH
        ] = f"./{DOCKER_IMAGE_ZENML_CONFIG_DIR}"
        environment_variables.update(docker_configuration.environment)

        environment.environment_variables = environment_variables
        return environment

    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        docker_configuration: "DockerConfiguration",
        entrypoint_command: List[str],
        resource_configuration: "ResourceConfiguration",
    ) -> None:
        """Launches a step on AzureML.

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            docker_configuration: The Docker configuration for this step.
            entrypoint_command: Command that executes the step.
            resource_configuration: The resource configuration for this step.
        """
        if not resource_configuration.empty:
            logger.warning(
                "Specifying custom step resources is not supported for "
                "the AzureML step operator. If you want to run this step "
                "operator on specific resources, you can do so by creating an "
                "Azure compute target (https://docs.microsoft.com/en-us/azure/machine-learning/concept-compute-target) "
                "with a specific machine type and then updating this step "
                "operator: `zenml step-operator update %s "
                "--compute_target_name=<COMPUTE_TARGET_NAME>`",
                self.name,
            )

        unused_docker_fields = [
            "dockerfile",
            "build_context_root",
            "build_options",
            "docker_target_repository",
            "dockerignore",
            "copy_files",
            "copy_profile",
        ]
        ignored_docker_fields = (
            docker_configuration.__fields_set__.intersection(
                unused_docker_fields
            )
        )

        if ignored_docker_fields:
            logger.warning(
                "The AzureML step operator currently does not support all "
                "options defined in your Docker configuration. Ignoring all "
                "values set for the attributes: %s",
                ignored_docker_fields,
            )

        workspace = Workspace.get(
            subscription_id=self.subscription_id,
            resource_group=self.resource_group,
            name=self.workspace_name,
            auth=self._get_authentication(),
        )

        source_directory = get_source_root_path()
        with _include_active_profile(
            build_context_root=source_directory,
            load_config_path=PurePosixPath(
                f"./{DOCKER_IMAGE_ZENML_CONFIG_DIR}"
            ),
        ):
            environment = self._prepare_environment(
                workspace=workspace,
                docker_configuration=docker_configuration,
                run_name=run_name,
            )
            compute_target = ComputeTarget(
                workspace=workspace, name=self.compute_target_name
            )

            run_config = ScriptRunConfig(
                source_directory=source_directory,
                environment=environment,
                compute_target=compute_target,
                command=entrypoint_command,
            )

            experiment = Experiment(workspace=workspace, name=pipeline_name)
            run = experiment.submit(config=run_config)

        run.display_name = run_name
        run.wait_for_completion(show_output=True)
launch(self, pipeline_name, run_name, docker_configuration, entrypoint_command, resource_configuration)

Launches a step on AzureML.

Parameters:

Name Type Description Default
pipeline_name str

Name of the pipeline which the step to be executed is part of.

required
run_name str

Name of the pipeline run which the step to be executed is part of.

required
docker_configuration DockerConfiguration

The Docker configuration for this step.

required
entrypoint_command List[str]

Command that executes the step.

required
resource_configuration ResourceConfiguration

The resource configuration for this step.

required
Source code in zenml/integrations/azure/step_operators/azureml_step_operator.py
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    docker_configuration: "DockerConfiguration",
    entrypoint_command: List[str],
    resource_configuration: "ResourceConfiguration",
) -> None:
    """Launches a step on AzureML.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        docker_configuration: The Docker configuration for this step.
        entrypoint_command: Command that executes the step.
        resource_configuration: The resource configuration for this step.
    """
    if not resource_configuration.empty:
        logger.warning(
            "Specifying custom step resources is not supported for "
            "the AzureML step operator. If you want to run this step "
            "operator on specific resources, you can do so by creating an "
            "Azure compute target (https://docs.microsoft.com/en-us/azure/machine-learning/concept-compute-target) "
            "with a specific machine type and then updating this step "
            "operator: `zenml step-operator update %s "
            "--compute_target_name=<COMPUTE_TARGET_NAME>`",
            self.name,
        )

    unused_docker_fields = [
        "dockerfile",
        "build_context_root",
        "build_options",
        "docker_target_repository",
        "dockerignore",
        "copy_files",
        "copy_profile",
    ]
    ignored_docker_fields = (
        docker_configuration.__fields_set__.intersection(
            unused_docker_fields
        )
    )

    if ignored_docker_fields:
        logger.warning(
            "The AzureML step operator currently does not support all "
            "options defined in your Docker configuration. Ignoring all "
            "values set for the attributes: %s",
            ignored_docker_fields,
        )

    workspace = Workspace.get(
        subscription_id=self.subscription_id,
        resource_group=self.resource_group,
        name=self.workspace_name,
        auth=self._get_authentication(),
    )

    source_directory = get_source_root_path()
    with _include_active_profile(
        build_context_root=source_directory,
        load_config_path=PurePosixPath(
            f"./{DOCKER_IMAGE_ZENML_CONFIG_DIR}"
        ),
    ):
        environment = self._prepare_environment(
            workspace=workspace,
            docker_configuration=docker_configuration,
            run_name=run_name,
        )
        compute_target = ComputeTarget(
            workspace=workspace, name=self.compute_target_name
        )

        run_config = ScriptRunConfig(
            source_directory=source_directory,
            environment=environment,
            compute_target=compute_target,
            command=entrypoint_command,
        )

        experiment = Experiment(workspace=workspace, name=pipeline_name)
        run = experiment.submit(config=run_config)

    run.display_name = run_name
    run.wait_for_completion(show_output=True)
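
As a usage illustration, the sketch below constructs the step operator directly in Python. All identifier values are placeholders, and the name field and direct construction are assumptions about the stack component API; in practice the component is usually registered via the ZenML CLI and added to a stack.

from zenml.integrations.azure.step_operators.azureml_step_operator import (
    AzureMLStepOperator,
)

step_operator = AzureMLStepOperator(
    name="azureml",  # assumed stack component name field
    subscription_id="<SUBSCRIPTION_ID>",
    resource_group="<RESOURCE_GROUP>",
    workspace_name="<WORKSPACE_NAME>",
    compute_target_name="<COMPUTE_TARGET_NAME>",
    # Optional service principal authentication: if any of the three
    # values below is missing, _get_authentication() returns None and
    # the AzureML SDK falls back to its default authentication flow.
    tenant_id="<TENANT_ID>",
    service_principal_id="<SERVICE_PRINCIPAL_ID>",
    service_principal_password="<SERVICE_PRINCIPAL_PASSWORD>",
)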

constants

Constants for ZenML integrations.

dash special

Initialization of the Dash integration.

DashIntegration (Integration)

Definition of Dash integration for ZenML.

Source code in zenml/integrations/dash/__init__.py
class DashIntegration(Integration):
    """Definition of Dash integration for ZenML."""

    NAME = DASH
    REQUIREMENTS = [
        "dash>=2.0.0",
        "dash-cytoscape>=0.3.0",
        "dash-bootstrap-components>=1.0.1",
        "jupyter-dash>=0.4.2",
    ]

visualizers special

Initialization of the Pipeline Run Visualizer.

pipeline_run_lineage_visualizer

Implementation of the pipeline run lineage visualizer.

PipelineRunLineageVisualizer (BaseVisualizer)

Implementation of a lineage diagram via the dash and dash-cytoscape libraries.

Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
class PipelineRunLineageVisualizer(BaseVisualizer):
    """Implementation of a lineage diagram via the dash and dash-cytoscape libraries."""

    ARTIFACT_PREFIX = "artifact_"
    STEP_PREFIX = "step_"
    STATUS_CLASS_MAPPING = {
        ExecutionStatus.CACHED: "green",
        ExecutionStatus.FAILED: "red",
        ExecutionStatus.RUNNING: "yellow",
        ExecutionStatus.COMPLETED: "blue",
    }

    def visualize(
        self,
        object: PipelineRunView,
        magic: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> dash.Dash:
        """Method to visualize pipeline runs via the Dash library.

        The layout puts every layer of the DAG in a column.

        Args:
            object: The pipeline run to visualize.
            magic: If True, the visualization is rendered in a magic mode.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The Dash application.
        """
        external_stylesheets = [
            dbc.themes.BOOTSTRAP,
            dbc.icons.BOOTSTRAP,
        ]
        if magic and Environment.in_notebook():
            # Only import jupyter_dash in this case
            from jupyter_dash import JupyterDash  # noqa

            JupyterDash.infer_jupyter_proxy_config()

            app = JupyterDash(
                __name__,
                external_stylesheets=external_stylesheets,
            )
            mode = "inline"
        else:
            if magic:
                cli_utils.warning(
                    "Cannot set magic flag in non-notebook environments."
                )
            # fall back to a standard Dash app outside notebooks
            app = dash.Dash(
                __name__,
                external_stylesheets=external_stylesheets,
            )
            mode = None
        nodes, edges, first_step_id = [], [], None
        for step in object.steps:
            step_output_artifacts = list(step.outputs.values())
            execution_id = (
                step_output_artifacts[0].producer_step.id
                if step_output_artifacts
                else step.id
            )
            step_id = self.STEP_PREFIX + str(step.id)
            if first_step_id is None:
                first_step_id = step_id
            nodes.append(
                {
                    "data": {
                        "id": step_id,
                        "execution_id": execution_id,
                        "label": f"{execution_id} / {step.entrypoint_name}",
                        "entrypoint_name": step.entrypoint_name,  # redundant for consistency
                        "name": step.name,  # redundant for consistency
                        "type": "step",
                        "parameters": step.parameters,
                        "inputs": {k: v.uri for k, v in step.inputs.items()},
                        "outputs": {k: v.uri for k, v in step.outputs.items()},
                    },
                    "classes": self.STATUS_CLASS_MAPPING[step.status],
                }
            )

            for artifact_name, artifact in step.outputs.items():
                nodes.append(
                    {
                        "data": {
                            "id": self.ARTIFACT_PREFIX + str(artifact.id),
                            "execution_id": artifact.id,
                            "label": f"{artifact.id} / {artifact_name} ("
                            f"{artifact.data_type})",
                            "type": "artifact",
                            "name": artifact_name,
                            "is_cached": artifact.is_cached,
                            "artifact_type": artifact.type,
                            "artifact_data_type": artifact.data_type,
                            "parent_step_id": artifact.parent_step_id,
                            "producer_step_id": artifact.producer_step.id,
                            "uri": artifact.uri,
                        },
                        "classes": f"rectangle "
                        f"{self.STATUS_CLASS_MAPPING[step.status]}",
                    }
                )
                edges.append(
                    {
                        "data": {
                            "source": self.STEP_PREFIX + str(step.id),
                            "target": self.ARTIFACT_PREFIX + str(artifact.id),
                        },
                        "classes": f"edge-arrow "
                        f"{self.STATUS_CLASS_MAPPING[step.status]}"
                        + (" dashed" if artifact.is_cached else " solid"),
                    }
                )

            for artifact_name, artifact in step.inputs.items():
                edges.append(
                    {
                        "data": {
                            "source": self.ARTIFACT_PREFIX + str(artifact.id),
                            "target": self.STEP_PREFIX + str(step.id),
                        },
                        "classes": "edge-arrow "
                        + (
                            f"{self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]} dashed"
                            if artifact.is_cached
                            else f"{self.STATUS_CLASS_MAPPING[step.status]} solid"
                        ),
                    }
                )

        app.layout = dbc.Row(
            [
                dbc.Container(f"Run: {object.name}", class_name="h1"),
                dbc.Row(
                    [
                        dbc.Col(
                            [
                                dbc.Row(
                                    [
                                        html.Span(
                                            [
                                                html.Span(
                                                    [
                                                        html.I(
                                                            className="bi bi-circle-fill me-1"
                                                        ),
                                                        "Step",
                                                    ],
                                                    className="me-2",
                                                ),
                                                html.Span(
                                                    [
                                                        html.I(
                                                            className="bi bi-square-fill me-1"
                                                        ),
                                                        "Artifact",
                                                    ],
                                                    className="me-4",
                                                ),
                                                dbc.Badge(
                                                    "Completed",
                                                    color=COLOR_BLUE,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Cached",
                                                    color=COLOR_GREEN,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Running",
                                                    color=COLOR_YELLOW,
                                                    className="me-1",
                                                ),
                                                dbc.Badge(
                                                    "Failed",
                                                    color=COLOR_RED,
                                                    className="me-1",
                                                ),
                                            ]
                                        ),
                                    ]
                                ),
                                dbc.Row(
                                    [
                                        cyto.Cytoscape(
                                            id="cytoscape",
                                            layout={
                                                "name": "breadthfirst",
                                                "roots": f'[id = "{first_step_id}"]',
                                            },
                                            elements=edges + nodes,
                                            stylesheet=STYLESHEET,
                                            style={
                                                "width": "100%",
                                                "height": "800px",
                                            },
                                            zoom=1,
                                        )
                                    ]
                                ),
                                dbc.Row(
                                    [
                                        dbc.Button(
                                            "Reset",
                                            id="bt-reset",
                                            color="primary",
                                            className="me-1",
                                        )
                                    ]
                                ),
                            ]
                        ),
                        dbc.Col(
                            [
                                dcc.Markdown(id="markdown-selected-node-data"),
                            ]
                        ),
                    ]
                ),
            ],
            className="p-5",
        )

        @app.callback(  # type: ignore[misc]
            Output("markdown-selected-node-data", "children"),
            Input("cytoscape", "selectedNodeData"),
        )
        def display_data(data_list: List[Dict[str, Any]]) -> str:
            """Callback for the text area below the graph.

            Args:
                data_list: The selected node data.

            Returns:
                str: The selected node data.
            """
            if data_list is None:
                return "Click on a node in the diagram."

            text = ""
            for data in data_list:
                text += f'## {data["execution_id"]} / {data["name"]}' + "\n\n"
                if data["type"] == "artifact":
                    for item in [
                        "artifact_data_type",
                        "is_cached",
                        "producer_step_id",
                        "parent_step_id",
                        "uri",
                    ]:
                        text += f"**{item}**: {data[item]}" + "\n\n"
                elif data["type"] == "step":
                    text += "### Inputs:" + "\n\n"
                    for k, v in data["inputs"].items():
                        text += f"**{k}**: {v}" + "\n\n"
                    text += "### Outputs:" + "\n\n"
                    for k, v in data["outputs"].items():
                        text += f"**{k}**: {v}" + "\n\n"
                    text += "### Params:" + "\n\n"
                    for k, v in data["parameters"].items():
                        text += f"**{k}**: {v}" + "\n\n"
            return text

        @app.callback(  # type: ignore[misc]
            [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
            [Input("bt-reset", "n_clicks")],
        )
        def reset_layout(
            n_clicks: int,
        ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
            """Resets the layout.

            Args:
                n_clicks: The number of clicks on the reset button.

            Returns:
                The zoom and the elements.
            """
            logger.debug("%s clicks on the reset button.", n_clicks)
            return [1, edges + nodes]

        if mode is not None:
            app.run_server(mode=mode)
        else:
            app.run_server()
        return app
visualize(self, object, magic=False, *args, **kwargs)

Method to visualize pipeline runs via the Dash library.

The layout puts every layer of the DAG in a column.

Parameters:

Name Type Description Default
object PipelineRunView

The pipeline run to visualize.

required
magic bool

If True, the visualization is rendered in a magic mode.

False
*args Any

Additional positional arguments.

()
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
Dash

The Dash application.

Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
def visualize(
    self,
    object: PipelineRunView,
    magic: bool = False,
    *args: Any,
    **kwargs: Any,
) -> dash.Dash:
    """Method to visualize pipeline runs via the Dash library.

    The layout puts every layer of the DAG in a column.

    Args:
        object: The pipeline run to visualize.
        magic: If True, the visualization is rendered in a magic mode.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        The Dash application.
    """
    external_stylesheets = [
        dbc.themes.BOOTSTRAP,
        dbc.icons.BOOTSTRAP,
    ]
    if magic and Environment.in_notebook():
        # Only import jupyter_dash in this case
        from jupyter_dash import JupyterDash  # noqa

        JupyterDash.infer_jupyter_proxy_config()

        app = JupyterDash(
            __name__,
            external_stylesheets=external_stylesheets,
        )
        mode = "inline"
    else:
        if magic:
            cli_utils.warning(
                "Cannot set magic flag in non-notebook environments."
            )
        # fall back to a standard Dash app outside notebooks
        app = dash.Dash(
            __name__,
            external_stylesheets=external_stylesheets,
        )
        mode = None
    nodes, edges, first_step_id = [], [], None
    for step in object.steps:
        step_output_artifacts = list(step.outputs.values())
        execution_id = (
            step_output_artifacts[0].producer_step.id
            if step_output_artifacts
            else step.id
        )
        step_id = self.STEP_PREFIX + str(step.id)
        if first_step_id is None:
            first_step_id = step_id
        nodes.append(
            {
                "data": {
                    "id": step_id,
                    "execution_id": execution_id,
                    "label": f"{execution_id} / {step.entrypoint_name}",
                    "entrypoint_name": step.entrypoint_name,  # redundant for consistency
                    "name": step.name,  # redundant for consistency
                    "type": "step",
                    "parameters": step.parameters,
                    "inputs": {k: v.uri for k, v in step.inputs.items()},
                    "outputs": {k: v.uri for k, v in step.outputs.items()},
                },
                "classes": self.STATUS_CLASS_MAPPING[step.status],
            }
        )

        for artifact_name, artifact in step.outputs.items():
            nodes.append(
                {
                    "data": {
                        "id": self.ARTIFACT_PREFIX + str(artifact.id),
                        "execution_id": artifact.id,
                        "label": f"{artifact.id} / {artifact_name} ("
                        f"{artifact.data_type})",
                        "type": "artifact",
                        "name": artifact_name,
                        "is_cached": artifact.is_cached,
                        "artifact_type": artifact.type,
                        "artifact_data_type": artifact.data_type,
                        "parent_step_id": artifact.parent_step_id,
                        "producer_step_id": artifact.producer_step.id,
                        "uri": artifact.uri,
                    },
                    "classes": f"rectangle "
                    f"{self.STATUS_CLASS_MAPPING[step.status]}",
                }
            )
            edges.append(
                {
                    "data": {
                        "source": self.STEP_PREFIX + str(step.id),
                        "target": self.ARTIFACT_PREFIX + str(artifact.id),
                    },
                    "classes": f"edge-arrow "
                    f"{self.STATUS_CLASS_MAPPING[step.status]}"
                    + (" dashed" if artifact.is_cached else " solid"),
                }
            )

        for artifact_name, artifact in step.inputs.items():
            edges.append(
                {
                    "data": {
                        "source": self.ARTIFACT_PREFIX + str(artifact.id),
                        "target": self.STEP_PREFIX + str(step.id),
                    },
                    "classes": "edge-arrow "
                    + (
                        f"{self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]} dashed"
                        if artifact.is_cached
                        else f"{self.STATUS_CLASS_MAPPING[step.status]} solid"
                    ),
                }
            )

    app.layout = dbc.Row(
        [
            dbc.Container(f"Run: {object.name}", class_name="h1"),
            dbc.Row(
                [
                    dbc.Col(
                        [
                            dbc.Row(
                                [
                                    html.Span(
                                        [
                                            html.Span(
                                                [
                                                    html.I(
                                                        className="bi bi-circle-fill me-1"
                                                    ),
                                                    "Step",
                                                ],
                                                className="me-2",
                                            ),
                                            html.Span(
                                                [
                                                    html.I(
                                                        className="bi bi-square-fill me-1"
                                                    ),
                                                    "Artifact",
                                                ],
                                                className="me-4",
                                            ),
                                            dbc.Badge(
                                                "Completed",
                                                color=COLOR_BLUE,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Cached",
                                                color=COLOR_GREEN,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Running",
                                                color=COLOR_YELLOW,
                                                className="me-1",
                                            ),
                                            dbc.Badge(
                                                "Failed",
                                                color=COLOR_RED,
                                                className="me-1",
                                            ),
                                        ]
                                    ),
                                ]
                            ),
                            dbc.Row(
                                [
                                    cyto.Cytoscape(
                                        id="cytoscape",
                                        layout={
                                            "name": "breadthfirst",
                                            "roots": f'[id = "{first_step_id}"]',
                                        },
                                        elements=edges + nodes,
                                        stylesheet=STYLESHEET,
                                        style={
                                            "width": "100%",
                                            "height": "800px",
                                        },
                                        zoom=1,
                                    )
                                ]
                            ),
                            dbc.Row(
                                [
                                    dbc.Button(
                                        "Reset",
                                        id="bt-reset",
                                        color="primary",
                                        className="me-1",
                                    )
                                ]
                            ),
                        ]
                    ),
                    dbc.Col(
                        [
                            dcc.Markdown(id="markdown-selected-node-data"),
                        ]
                    ),
                ]
            ),
        ],
        className="p-5",
    )

    @app.callback(  # type: ignore[misc]
        Output("markdown-selected-node-data", "children"),
        Input("cytoscape", "selectedNodeData"),
    )
    def display_data(data_list: List[Dict[str, Any]]) -> str:
        """Callback for the text area below the graph.

        Args:
            data_list: The selected node data.

        Returns:
            str: The selected node data.
        """
        if data_list is None:
            return "Click on a node in the diagram."

        text = ""
        for data in data_list:
            text += f'## {data["execution_id"]} / {data["name"]}' + "\n\n"
            if data["type"] == "artifact":
                for item in [
                    "artifact_data_type",
                    "is_cached",
                    "producer_step_id",
                    "parent_step_id",
                    "uri",
                ]:
                    text += f"**{item}**: {data[item]}" + "\n\n"
            elif data["type"] == "step":
                text += "### Inputs:" + "\n\n"
                for k, v in data["inputs"].items():
                    text += f"**{k}**: {v}" + "\n\n"
                text += "### Outputs:" + "\n\n"
                for k, v in data["outputs"].items():
                    text += f"**{k}**: {v}" + "\n\n"
                text += "### Params:" + "\n\n"
                for k, v in data["parameters"].items():
                    text += f"**{k}**: {v}" + "\n\n"
        return text

    @app.callback(  # type: ignore[misc]
        [Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
        [Input("bt-reset", "n_clicks")],
    )
    def reset_layout(
        n_clicks: int,
    ) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
        """Resets the layout.

        Args:
            n_clicks: The number of clicks on the reset button.

        Returns:
            The zoom and the elements.
        """
        logger.debug("%s clicks on the reset button.", n_clicks)
        return [1, edges + nodes]

    if mode is not None:
        app.run_server(mode=mode)
    else:
        app.run_server()
    return app
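
To tie this together, here is a hedged usage sketch. The repository accessors (get_pipeline and the runs list) are assumptions about this ZenML version's API and may differ; the visualizer call itself matches the signature above.

from zenml.integrations.dash.visualizers.pipeline_run_lineage_visualizer import (
    PipelineRunLineageVisualizer,
)
from zenml.repository import Repository  # assumed entry point for run views

# Fetch the latest run of a pipeline (accessor names are assumptions).
pipeline = Repository().get_pipeline("my_pipeline")
last_run = pipeline.runs[-1]

# Build and serve the Dash lineage app. Inside a notebook, pass
# magic=True to render the diagram inline via jupyter_dash instead.
PipelineRunLineageVisualizer().visualize(last_run)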

deepchecks special

Deepchecks integration for ZenML.

The Deepchecks integration provides a way to validate the data in your pipelines. It can detect data anomalies and lets you define checks that ensure data quality.

The integration includes custom materializers to store Deepchecks SuiteResults and a visualizer that renders the results conveniently in a notebook or in the browser.

DeepchecksIntegration (Integration)

Definition of Deepchecks integration for ZenML.

Source code in zenml/integrations/deepchecks/__init__.py
class DeepchecksIntegration(Integration):
    """Definition of [Deepchecks](https://github.com/deepchecks/deepchecks) integration for ZenML."""

    NAME = DEEPCHECKS
    REQUIREMENTS = ["deepchecks[vision]==0.8.0", "torchvision==0.11.2"]

    @staticmethod
    def activate() -> None:
        """Activate the Deepchecks integration."""
        from zenml.integrations.deepchecks import materializers  # noqa
        from zenml.integrations.deepchecks import visualizers  # noqa

    @classmethod
    def flavors(cls) -> List[FlavorWrapper]:
        """Declare the stack component flavors for the Deepchecks integration.

        Returns:
            List of stack component flavors for this integration.
        """
        return [
            FlavorWrapper(
                name=DEEPCHECKS_DATA_VALIDATOR_FLAVOR,
                source="zenml.integrations.deepchecks.data_validators.DeepchecksDataValidator",
                type=StackComponentType.DATA_VALIDATOR,
                integration=cls.NAME,
            ),
        ]
activate() staticmethod

Activate the Deepchecks integration.

Source code in zenml/integrations/deepchecks/__init__.py
@staticmethod
def activate() -> None:
    """Activate the Deepchecks integration."""
    from zenml.integrations.deepchecks import materializers  # noqa
    from zenml.integrations.deepchecks import visualizers  # noqa
flavors() classmethod

Declare the stack component flavors for the Deepchecks integration.

Returns:

Type Description
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper]

List of stack component flavors for this integration.

Source code in zenml/integrations/deepchecks/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
    """Declare the stack component flavors for the Deepchecks integration.

    Returns:
        List of stack component flavors for this integration.
    """
    return [
        FlavorWrapper(
            name=DEEPCHECKS_DATA_VALIDATOR_FLAVOR,
            source="zenml.integrations.deepchecks.data_validators.DeepchecksDataValidator",
            type=StackComponentType.DATA_VALIDATOR,
            integration=cls.NAME,
        ),
    ]
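
For completeness, a minimal sketch of inspecting and activating the integration by hand. Normally the requirements are installed with the ZenML CLI and activation happens automatically when the integration registry loads, so calling activate() directly is rarely needed (the registry behavior is an assumption here).

from zenml.integrations.deepchecks import DeepchecksIntegration

print(DeepchecksIntegration.NAME)          # the DEEPCHECKS constant
print(DeepchecksIntegration.REQUIREMENTS)  # pinned deepchecks + torchvision

# Import-register the custom materializers and visualizers listed above.
DeepchecksIntegration.activate()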

data_validators special

Initialization of the Deepchecks data validator for ZenML.

deepchecks_data_validator

Implementation of the Deepchecks data validator.

DeepchecksDataValidator (BaseDataValidator) pydantic-model

Deepchecks data validator stack component.

Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
class DeepchecksDataValidator(BaseDataValidator):
    """Deepchecks data validator stack component."""

    # Class Configuration
    FLAVOR: ClassVar[str] = DEEPCHECKS_DATA_VALIDATOR_FLAVOR
    NAME: ClassVar[str] = "Deepchecks"

    @staticmethod
    def _split_checks(
        check_list: Sequence[str],
    ) -> Tuple[Sequence[str], Sequence[str]]:
        """Split a list of check identifiers in two lists, one for tabular and one for computer vision checks.

        Args:
            check_list: A list of check identifiers.

        Returns:
            List of tabular check identifiers and list of computer vision
            check identifiers.
        """
        tabular_checks = list(
            filter(
                lambda check: DeepchecksValidationCheck.is_tabular_check(check),
                check_list,
            )
        )
        vision_checks = list(
            filter(
                lambda check: DeepchecksValidationCheck.is_vision_check(check),
                check_list,
            )
        )
        return tabular_checks, vision_checks

    # flake8: noqa: C901
    @classmethod
    def _create_and_run_check_suite(
        cls,
        check_enum: Type[DeepchecksValidationCheck],
        reference_dataset: Union[pd.DataFrame, DataLoader[Any]],
        comparison_dataset: Optional[
            Union[pd.DataFrame, DataLoader[Any]]
        ] = None,
        model: Optional[Union[ClassifierMixin, Module]] = None,
        check_list: Optional[Sequence[str]] = None,
        dataset_kwargs: Dict[str, Any] = {},
        check_kwargs: Dict[str, Dict[str, Any]] = {},
        run_kwargs: Dict[str, Any] = {},
    ) -> SuiteResult:
        """Create and run a Deepchecks check suite corresponding to the input parameters.

        This method contains generic logic common to all Deepchecks data
        validator methods that validates the input arguments and uses them to
        generate and run a Deepchecks check suite.

        Args:
            check_enum: ZenML enum type grouping together Deepchecks checks with
                the same characteristics. This is used to generate a default
                list of checks, if a custom list isn't provided via the
                `check_list` argument.
            reference_dataset: Primary (reference) dataset argument used during
                validation.
            comparison_dataset: Optional secondary (comparison) dataset argument
                used during comparison checks.
            model: Optional model argument used during validation.
            check_list: Optional list of ZenML Deepchecks check identifiers
                specifying the list of Deepchecks checks to be performed.
            dataset_kwargs: Additional keyword arguments to be passed to the
                Deepchecks tabular.Dataset or vision.VisionData constructor.
            check_kwargs: Additional keyword arguments to be passed to the
                Deepchecks check object constructors. Arguments are grouped for
                each check and indexed using the full check class name or
                check enum value as dictionary keys.
            run_kwargs: Additional keyword arguments to be passed to the
                Deepchecks Suite `run` method.

        Returns:
            Deepchecks SuiteResult object with the Suite run results.

        Raises:
            TypeError: If the datasets, model and check list arguments combine
                data types and/or checks from different categories (tabular and
                computer vision).
        """
        # Detect what type of check to perform (tabular or computer vision) from
        # the dataset/model datatypes and the check list. At the same time,
        # validate the combination of data types used for dataset and model
        # arguments and the check list.
        is_tabular = False
        is_vision = False
        for dataset in [reference_dataset, comparison_dataset]:
            if dataset is None:
                continue
            if isinstance(dataset, pd.DataFrame):
                is_tabular = True
            elif isinstance(dataset, DataLoader):
                is_vision = True
            else:
                raise TypeError(
                    f"Unsupported dataset data type found: {type(dataset)}. "
                    f"Supported data types are {str(pd.DataFrame)} for tabular "
                    f"data and {str(DataLoader)} for computer vision data."
                )

        if model:
            if isinstance(model, ClassifierMixin):
                is_tabular = True
            elif isinstance(model, Module):
                is_vision = True
            else:
                raise TypeError(
                    f"Unsupported model data type found: {type(model)}. "
                    f"Supported data types are {str(ClassifierMixin)} for "
                    f"tabular data and {str(Module)} for computer vision "
                    f"data."
                )

        if is_tabular and is_vision:
            raise TypeError(
                f"Tabular and computer vision data types used for datasets and "
                f"models cannot be mixed. They must all belong to the same "
                f"category. Supported data types for tabular data are "
                f"{str(pd.DataFrame)} for datasets and {str(ClassifierMixin)} "
                f"for models. Supported data types for computer vision data "
                f"are {str(DataLoader)} for datasets and {str(Module)} "
                f"for models."
            )

        if not check_list:
            # default to executing all the checks listed in the supplied
            # checks enum type if a custom check list is not supplied
            tabular_checks, vision_checks = cls._split_checks(
                check_enum.values()
            )
            if is_tabular:
                check_list = tabular_checks
                vision_checks = []
            else:
                check_list = vision_checks
                tabular_checks = []
        else:
            tabular_checks, vision_checks = cls._split_checks(check_list)

        if tabular_checks and vision_checks:
            raise TypeError(
                f"The check list cannot mix tabular checks "
                f"({tabular_checks}) and computer vision checks ("
                f"{vision_checks})."
            )

        if is_tabular and vision_checks:
            raise TypeError(
                f"Tabular data types used for datasets and models can only "
                f"be used with tabular validation checks. The following "
                f"computer vision checks included in the check list are "
                f"not valid: {vision_checks}."
            )

        if is_vision and tabular_checks:
            raise TypeError(
                f"Computer vision data types used for datasets and models "
                f"can only be used with computer vision validation checks. "
                f"The following tabular checks included in the check list "
                f"are not valid: {tabular_checks}."
            )

        check_classes = map(
            lambda check: (
                check,
                check_enum.get_check_class(check),
            ),
            check_list,
        )

        # use the pipeline name and the step name to generate a unique suite
        # name
        try:
            # get pipeline name and step name
            step_env = cast(
                StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
            )
            suite_name = f"{step_env.pipeline_name}_{step_env.step_name}"
        except KeyError:
            # if not running inside a pipeline step, use random values
            suite_name = f"suite_{random_str(5)}"

        if is_tabular:
            dataset_class = TabularData
            suite_class = TabularSuite
            full_suite = full_tabular_suite()
        else:
            dataset_class = VisionData
            suite_class = VisionSuite
            full_suite = full_vision_suite()

        train_dataset = dataset_class(reference_dataset, **dataset_kwargs)
        test_dataset = None
        if comparison_dataset is not None:
            test_dataset = dataset_class(comparison_dataset, **dataset_kwargs)
        suite = suite_class(name=suite_name)

        # Some Deepchecks checks require a minimum configuration such as
        # conditions to be configured (see https://docs.deepchecks.com/stable/user-guide/general/customizations/examples/plot_configure_check_conditions.html#sphx-glr-user-guide-general-customizations-examples-plot-configure-check-conditions-py)
        # for their execution to have meaning. For checks that don't have
        # custom configuration attributes explicitly specified in the
        # `check_kwargs` input parameter, we use the default check
        # instances extracted from the full suite shipped with Deepchecks.
        default_checks = {
            check.__class__: check for check in full_suite.checks.values()
        }
        for check_name, check_class in check_classes:
            extra_kwargs = check_kwargs.get(check_name, {})
            default_check = default_checks.get(check_class)
            check: BaseCheck
            if extra_kwargs or not default_check:
                check = check_class(**extra_kwargs)
            else:
                check = default_check

            # extract the condition kwargs from the check kwargs
            for arg_name, condition_kwargs in extra_kwargs.items():
                if not arg_name.startswith("condition_") or not isinstance(
                    condition_kwargs, dict
                ):
                    continue
                condition_method = getattr(check, f"add_{arg_name}", None)
                if not condition_method or not callable(condition_method):
                    logger.warning(
                        f"Deepchecks check type {check.__class__} has no "
                        f"condition named {arg_name}. Ignoring the check "
                        f"argument."
                    )
                    continue
                condition_method(**condition_kwargs)

            suite.add(check)
        return suite.run(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            model=model,
            **run_kwargs,
        )

    def data_validation(
        self,
        dataset: Union[pd.DataFrame, DataLoader[Any]],
        comparison_dataset: Optional[Any] = None,
        check_list: Optional[Sequence[str]] = None,
        dataset_kwargs: Dict[str, Any] = {},
        check_kwargs: Dict[str, Dict[str, Any]] = {},
        run_kwargs: Dict[str, Any] = {},
        **kwargs: Any,
    ) -> SuiteResult:
        """Run one or more Deepchecks data validation checks on a dataset.

        Call this method to analyze and identify potential integrity problems
        with a single dataset (e.g. missing values, conflicting labels, mixed
        data types etc.) and to run dataset comparison checks (e.g. data
        drift checks). Dataset comparison checks require that a second
        dataset be supplied via the `comparison_dataset` argument.

        The `check_list` argument may be used to specify a custom set of
        Deepchecks data integrity checks to perform, identified by
        `DeepchecksDataIntegrityCheck` and `DeepchecksDataDriftCheck` enum
        values. If omitted:

        * if the `comparison_dataset` is omitted, a suite with all available
        data integrity checks will be performed on the input data. See
        `DeepchecksDataIntegrityCheck` for a list of Deepchecks builtin
        checks that are compatible with this method.

        * if the `comparison_dataset` is supplied, a suite with all
        available data drift checks will be performed on the input
        data. See `DeepchecksDataDriftCheck` for a list of Deepchecks
        builtin checks that are compatible with this method.

        Args:
            dataset: Target dataset to be validated.
            comparison_dataset: Optional second dataset to be used for data
                comparison checks (e.g. data drift checks).
            check_list: Optional list of ZenML Deepchecks check identifiers
                specifying the data validation checks to be performed.
                `DeepchecksDataIntegrityCheck` enum values should be used for
                single data validation checks and `DeepchecksDataDriftCheck`
                enum values for data comparison checks. If not supplied, the
                entire set of checks applicable to the input dataset(s)
                will be performed.
            dataset_kwargs: Additional keyword arguments to be passed to the
                Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
            check_kwargs: Additional keyword arguments to be passed to the
                Deepchecks check object constructors. Arguments are grouped for
                each check and indexed using the full check class name or
                check enum value as dictionary keys.
            run_kwargs: Additional keyword arguments to be passed to the
                Deepchecks Suite `run` method.
            kwargs: Additional keyword arguments (unused).

        Returns:
            A Deepchecks SuiteResult with the results of the validation.
        """
        check_enum: Type[DeepchecksValidationCheck]
        if comparison_dataset is None:
            check_enum = DeepchecksDataIntegrityCheck
        else:
            check_enum = DeepchecksDataDriftCheck

        return self._create_and_run_check_suite(
            check_enum=check_enum,
            reference_dataset=dataset,
            comparison_dataset=comparison_dataset,
            check_list=check_list,
            dataset_kwargs=dataset_kwargs,
            check_kwargs=check_kwargs,
            run_kwargs=run_kwargs,
        )

    def model_validation(
        self,
        dataset: Union[pd.DataFrame, DataLoader[Any]],
        model: Union[ClassifierMixin, Module],
        comparison_dataset: Optional[Any] = None,
        check_list: Optional[Sequence[str]] = None,
        dataset_kwargs: Dict[str, Any] = {},
        check_kwargs: Dict[str, Dict[str, Any]] = {},
        run_kwargs: Dict[str, Any] = {},
        **kwargs: Any,
    ) -> Any:
        """Run one or more Deepchecks model validation checks.

        Call this method to perform model validation checks (e.g. confusion
        matrix validation, performance reports, model error analyses, etc).
        A second dataset is required for model performance comparison tests
        (i.e. tests that identify changes in a model behavior by comparing how
        it performs on two different datasets).

        The `check_list` argument may be used to specify a custom set of
        Deepchecks model validation checks to perform, identified by
        `DeepchecksModelValidationCheck` and `DeepchecksModelDriftCheck` enum
        values. If omitted:

            * if the `comparison_dataset` is omitted, a suite with all available
            model validation checks will be performed on the input data. See
            `DeepchecksModelValidationCheck` for a list of Deepchecks builtin
            checks that are compatible with this method.

            * if the `comparison_dataset` is supplied, a suite with all
            available model comparison checks will be performed on the input
            data. See `DeepchecksModelDriftCheck` for a list of Deepchecks
            builtin checks that are compatible with this method.

        Args:
            dataset: Target dataset to be validated.
            model: Target model to be validated.
            comparison_dataset: Optional second dataset to be used for model
                comparison checks.
            check_list: Optional list of ZenML Deepchecks check identifiers
                specifying the model validation checks to be performed.
                `DeepchecksModelValidationCheck` enum values should be used for
                model validation checks and `DeepchecksModelDriftCheck` enum
                values for model comparison checks. If not supplied, the
                entire set of checks applicable to the input dataset(s)
                will be performed.
            dataset_kwargs: Additional keyword arguments to be passed to the
                Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
            check_kwargs: Additional keyword arguments to be passed to the
                Deepchecks check object constructors. Arguments are grouped for
                each check and indexed using the full check class name or
                check enum value as dictionary keys.
            run_kwargs: Additional keyword arguments to be passed to the
                Deepchecks Suite `run` method.
            kwargs: Additional keyword arguments (unused).

        Returns:
            A Deepchecks SuiteResult with the results of the validation.
        """
        check_enum: Type[DeepchecksValidationCheck]
        if comparison_dataset is None:
            check_enum = DeepchecksModelValidationCheck
        else:
            check_enum = DeepchecksModelDriftCheck

        return self._create_and_run_check_suite(
            check_enum=check_enum,
            reference_dataset=dataset,
            comparison_dataset=comparison_dataset,
            model=model,
            check_list=check_list,
            dataset_kwargs=dataset_kwargs,
            check_kwargs=check_kwargs,
            run_kwargs=run_kwargs,
        )
data_validation(self, dataset, comparison_dataset=None, check_list=None, dataset_kwargs={}, check_kwargs={}, run_kwargs={}, **kwargs)

Run one or more Deepchecks data validation checks on a dataset.

Call this method to analyze and identify potential integrity problems with a single dataset (e.g. missing values, conflicting labels, mixed data types, etc.) or to run dataset comparison checks (e.g. data drift checks). Dataset comparison checks require that a second dataset be supplied via the comparison_dataset argument.

The check_list argument may be used to specify a custom set of Deepchecks data integrity checks to perform, identified by DeepchecksDataIntegrityCheck and DeepchecksDataDriftCheck enum values. If omitted:

  • if the comparison_dataset is omitted, a suite with all available data integrity checks will be performed on the input data. See DeepchecksDataIntegrityCheck for a list of Deepchecks builtin checks that are compatible with this method.

  • if the comparison_dataset is supplied, a suite with all available data drift checks will be performed on the input data. See DeepchecksDataDriftCheck for a list of Deepchecks builtin checks that are compatible with this method.

Parameters:

Name Type Description Default
dataset Union[pandas.core.frame.DataFrame, torch.utils.data.dataloader.DataLoader[Any]]

Target dataset to be validated.

required
comparison_dataset Optional[Any]

Optional second dataset to be used for data comparison checks (e.g. data drift checks).

None
check_list Optional[Sequence[str]]

Optional list of ZenML Deepchecks check identifiers specifying the data validation checks to be performed. DeepchecksDataIntegrityCheck enum values should be used for single data validation checks and DeepchecksDataDriftCheck enum values for data comparison checks. If not supplied, the entire set of checks applicable to the input dataset(s) will be performed.

None
dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

{}
check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

{}
run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

{}
kwargs Any

Additional keyword arguments (unused).

{}

Returns:

Type Description
SuiteResult

A Deepchecks SuiteResult with the results of the validation.

Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
def data_validation(
    self,
    dataset: Union[pd.DataFrame, DataLoader[Any]],
    comparison_dataset: Optional[Any] = None,
    check_list: Optional[Sequence[str]] = None,
    dataset_kwargs: Dict[str, Any] = {},
    check_kwargs: Dict[str, Dict[str, Any]] = {},
    run_kwargs: Dict[str, Any] = {},
    **kwargs: Any,
) -> SuiteResult:
    """Run one or more Deepchecks data validation checks on a dataset.

    Call this method to analyze and identify potential integrity problems
    with a single dataset (e.g. missing values, conflicting labels, mixed
    data types, etc.) or to run dataset comparison checks (e.g. data drift
    checks). Dataset comparison checks require that a second dataset be
    supplied via the `comparison_dataset` argument.

    The `check_list` argument may be used to specify a custom set of
    Deepchecks data integrity checks to perform, identified by
    `DeepchecksDataIntegrityCheck` and `DeepchecksDataDriftCheck` enum
    values. If omitted:

    * if the `comparison_dataset` is omitted, a suite with all available
    data integrity checks will be performed on the input data. See
    `DeepchecksDataIntegrityCheck` for a list of Deepchecks builtin
    checks that are compatible with this method.

    * if the `comparison_dataset` is supplied, a suite with all
    available data drift checks will be performed on the input
    data. See `DeepchecksDataDriftCheck` for a list of Deepchecks
    builtin checks that are compatible with this method.

    Args:
        dataset: Target dataset to be validated.
        comparison_dataset: Optional second dataset to be used for data
            comparison checks (e.g. data drift checks).
        check_list: Optional list of ZenML Deepchecks check identifiers
            specifying the data validation checks to be performed.
            `DeepchecksDataIntegrityCheck` enum values should be used for
            single data validation checks and `DeepchecksDataDriftCheck`
            enum values for data comparison checks. If not supplied, the
            entire set of checks applicable to the input dataset(s)
            will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
        kwargs: Additional keyword arguments (unused).

    Returns:
        A Deepchecks SuiteResult with the results of the validation.
    """
    check_enum: Type[DeepchecksValidationCheck]
    if comparison_dataset is None:
        check_enum = DeepchecksDataIntegrityCheck
    else:
        check_enum = DeepchecksDataDriftCheck

    return self._create_and_run_check_suite(
        check_enum=check_enum,
        reference_dataset=dataset,
        comparison_dataset=comparison_dataset,
        check_list=check_list,
        dataset_kwargs=dataset_kwargs,
        check_kwargs=check_kwargs,
        run_kwargs=run_kwargs,
    )
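
To make the calling convention concrete, here is a minimal, hypothetical sketch of invoking data_validation directly with a comparison dataset, assuming an active stack with the Deepchecks data validator registered. The import paths and the columns check argument are illustrative assumptions; the check enum value is taken from the DeepchecksDataDriftCheck listing further down this page.

# Hypothetical sketch: run a single data drift check on two DataFrames.
import pandas as pd

from zenml.integrations.deepchecks.data_validators import (  # assumed path
    DeepchecksDataValidator,
)
from zenml.integrations.deepchecks.validation_checks import (
    DeepchecksDataDriftCheck,
)

reference = pd.DataFrame({"age": [23, 45, 31, 60], "income": [40, 80, 55, 95]})
target = pd.DataFrame({"age": [52, 19, 44, 70], "income": [90, 30, 70, 120]})

data_validator = DeepchecksDataValidator.get_active_data_validator()

# Supplying `comparison_dataset` selects the data drift check enum, as in
# the source above. `check_kwargs` groups constructor arguments per check
# enum value; the `columns` argument is a hypothetical example.
suite_result = data_validator.data_validation(
    dataset=reference,
    comparison_dataset=target,
    check_list=[DeepchecksDataDriftCheck.TABULAR_TRAIN_TEST_FEATURE_DRIFT],
    check_kwargs={
        DeepchecksDataDriftCheck.TABULAR_TRAIN_TEST_FEATURE_DRIFT: {
            "columns": ["age", "income"],
        }
    },
)
print(suite_result)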
model_validation(self, dataset, model, comparison_dataset=None, check_list=None, dataset_kwargs={}, check_kwargs={}, run_kwargs={}, **kwargs)

Run one or more Deepchecks model validation checks.

Call this method to perform model validation checks (e.g. confusion matrix validation, performance reports, model error analyses, etc.). A second dataset is required for model performance comparison tests (i.e. tests that identify changes in a model's behavior by comparing how it performs on two different datasets).

The check_list argument may be used to specify a custom set of Deepchecks model validation checks to perform, identified by DeepchecksModelValidationCheck and DeepchecksModelDriftCheck enum values. If omitted:

  • if the comparison_dataset is omitted, a suite with all available model validation checks will be performed on the input data. See DeepchecksModelValidationCheck for a list of Deepchecks builtin checks that are compatible with this method.

  • if the comparison_dataset is supplied, a suite with all available model comparison checks will be performed on the input data. See DeepchecksModelDriftCheck for a list of Deepchecks builtin checks that are compatible with this method.

Parameters:

Name Type Description Default
dataset Union[pandas.core.frame.DataFrame, torch.utils.data.dataloader.DataLoader[Any]]

Target dataset to be validated.

required
model Union[sklearn.base.ClassifierMixin, torch.nn.modules.module.Module]

Target model to be validated.

required
comparison_dataset Optional[Any]

Optional second dataset to be used for model comparison checks.

None
check_list Optional[Sequence[str]]

Optional list of ZenML Deepchecks check identifiers specifying the model validation checks to be performed. DeepchecksModelValidationCheck enum values should be used for model validation checks and DeepchecksModelDriftCheck enum values for model comparison checks. If not supplied, the entire set of checks applicable to the input dataset(s) will be performed.

None
dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

{}
check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

{}
run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

{}
kwargs Any

Additional keyword arguments (unused).

{}

Returns:

Type Description
Any

A Deepchecks SuiteResult with the results of the validation.

Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
def model_validation(
    self,
    dataset: Union[pd.DataFrame, DataLoader[Any]],
    model: Union[ClassifierMixin, Module],
    comparison_dataset: Optional[Any] = None,
    check_list: Optional[Sequence[str]] = None,
    dataset_kwargs: Dict[str, Any] = {},
    check_kwargs: Dict[str, Dict[str, Any]] = {},
    run_kwargs: Dict[str, Any] = {},
    **kwargs: Any,
) -> Any:
    """Run one or more Deepchecks model validation checks.

    Call this method to perform model validation checks (e.g. confusion
    matrix validation, performance reports, model error analyses, etc.).
    A second dataset is required for model performance comparison tests
    (i.e. tests that identify changes in a model's behavior by comparing how
    it performs on two different datasets).

    The `check_list` argument may be used to specify a custom set of
    Deepchecks model validation checks to perform, identified by
    `DeepchecksModelValidationCheck` and `DeepchecksModelDriftCheck` enum
    values. If omitted:

        * if the `comparison_dataset` is omitted, a suite with all available
        model validation checks will be performed on the input data. See
        `DeepchecksModelValidationCheck` for a list of Deepchecks builtin
        checks that are compatible with this method.

        * if the `comparison_dataset` is supplied, a suite with all
        available model comparison checks will be performed on the input
        data. See `DeepchecksModelDriftCheck` for a list of Deepchecks
        builtin checks that are compatible with this method.

    Args:
        dataset: Target dataset to be validated.
        model: Target model to be validated.
        comparison_dataset: Optional second dataset to be used for model
            comparison checks.
        check_list: Optional list of ZenML Deepchecks check identifiers
            specifying the model validation checks to be performed.
            `DeepchecksModelValidationCheck` enum values should be used for
            model validation checks and `DeepchecksModelDriftCheck` enum
            values for model comparison checks. If not supplied, the
            entire set of checks applicable to the input dataset(s)
            will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
        kwargs: Additional keyword arguments (unused).

    Returns:
        A Deepchecks SuiteResult with the results of the validation.
    """
    check_enum: Type[DeepchecksValidationCheck]
    if comparison_dataset is None:
        check_enum = DeepchecksModelValidationCheck
    else:
        check_enum = DeepchecksModelDriftCheck

    return self._create_and_run_check_suite(
        check_enum=check_enum,
        reference_dataset=dataset,
        comparison_dataset=comparison_dataset,
        model=model,
        check_list=check_list,
        dataset_kwargs=dataset_kwargs,
        check_kwargs=check_kwargs,
        run_kwargs=run_kwargs,
    )
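
Similarly, a minimal hypothetical sketch of a direct model_validation call on a fitted scikit-learn classifier. The import path is an assumption; dataset_kwargs is forwarded to the deepchecks.tabular.Dataset constructor, here used to mark the label column.

# Hypothetical sketch: run the full model validation suite on one dataset.
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

from zenml.integrations.deepchecks.data_validators import (  # assumed path
    DeepchecksDataValidator,
)

df = pd.DataFrame({"feature": [0.1, 0.4, 0.35, 0.8], "label": [0, 0, 1, 1]})
model = DecisionTreeClassifier().fit(df[["feature"]], df["label"])

data_validator = DeepchecksDataValidator.get_active_data_validator()

# Without a `comparison_dataset`, all model validation checks run.
# `dataset_kwargs` is passed through to `deepchecks.tabular.Dataset`.
suite_result = data_validator.model_validation(
    dataset=df,
    model=model,
    dataset_kwargs={"label": "label", "cat_features": []},
)
print(suite_result)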

materializers special

Deepchecks materializers.

deepchecks_dataset_materializer

Implementation of Deepchecks dataset materializer.

DeepchecksDatasetMaterializer (BaseMaterializer)

Materializer to read data to and from Deepchecks dataset.

Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
class DeepchecksDatasetMaterializer(BaseMaterializer):
    """Materializer to read data to and from Deepchecks dataset."""

    ASSOCIATED_TYPES = (Dataset,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Dataset:
        """Reads pandas dataframes and creates deepchecks.Dataset from it.

        Args:
            data_type: The type of the data to read.

        Returns:
            A Deepchecks Dataset.
        """
        super().handle_input(data_type)

        # Outsource to pandas
        pandas_materializer = PandasMaterializer(self.artifact)
        df = pandas_materializer.handle_input(data_type)

        # Recreate from pandas dataframe
        return Dataset(df)

    def handle_return(self, df: Dataset) -> None:
        """Serializes pandas dataframe within a Dataset object.

        Args:
            df: A deepchecks.Dataset object.
        """
        super().handle_return(df)

        # Outsource to pandas
        pandas_materializer = PandasMaterializer(self.artifact)
        pandas_materializer.handle_return(df.data)
handle_input(self, data_type)

Reads a pandas DataFrame and creates a deepchecks.Dataset from it.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
Dataset

A Deepchecks Dataset.

Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> Dataset:
    """Reads pandas dataframes and creates deepchecks.Dataset from it.

    Args:
        data_type: The type of the data to read.

    Returns:
        A Deepchecks Dataset.
    """
    super().handle_input(data_type)

    # Outsource to pandas
    pandas_materializer = PandasMaterializer(self.artifact)
    df = pandas_materializer.handle_input(data_type)

    # Recreate from pandas dataframe
    return Dataset(df)
handle_return(self, df)

Serializes the pandas DataFrame within a Dataset object.

Parameters:

Name Type Description Default
df Dataset

A deepchecks.Dataset object.

required
Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
def handle_return(self, df: Dataset) -> None:
    """Serializes pandas dataframe within a Dataset object.

    Args:
        df: A deepchecks.Dataset object.
    """
    super().handle_return(df)

    # Outsource to pandas
    pandas_materializer = PandasMaterializer(self.artifact)
    pandas_materializer.handle_return(df.data)
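
Note that only the underlying dataframe is persisted: handle_return writes df.data and handle_input rebuilds the Dataset with the default constructor, so Dataset-level metadata such as the label column is not round-tripped. A standalone sketch of the same round-trip, using only the public Deepchecks tabular API:

import pandas as pd
from deepchecks.tabular import Dataset

df = pd.DataFrame({"a": [1, 2, 3], "label": [0, 1, 0]})
dataset = Dataset(df, label="label")

# What the materializer effectively does: persist the raw dataframe...
persisted = dataset.data
# ...and later rebuild a Dataset from it with default settings only
# (the `label` association from above is lost).
restored = Dataset(persisted)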
deepchecks_results_materializer

Implementation of Deepchecks suite results materializer.

DeepchecksResultMaterializer (BaseMaterializer)

Materializer to read data to and from CheckResult and SuiteResult objects.

Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
class DeepchecksResultMaterializer(BaseMaterializer):
    """Materializer to read data to and from CheckResult and SuiteResult objects."""

    ASSOCIATED_TYPES = (
        CheckResult,
        SuiteResult,
    )
    ASSOCIATED_ARTIFACT_TYPES = (DataAnalysisArtifact,)

    def handle_input(
        self, data_type: Type[Any]
    ) -> Union[CheckResult, SuiteResult]:
        """Reads a Deepchecks check or suite result from a serialized JSON file.

        Args:
            data_type: The type of the data to read.

        Returns:
            A Deepchecks CheckResult or SuiteResult.

        Raises:
            RuntimeError: if the input data type is not supported.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)

        json_res = io_utils.read_file_contents_as_string(filepath)
        if data_type == SuiteResult:
            res = SuiteResult.from_json(json_res)
        elif data_type == CheckResult:
            res = CheckResult.from_json(json_res)
        else:
            raise RuntimeError(f"Unknown data type: {data_type}")
        return res

    def handle_return(self, result: Union[CheckResult, SuiteResult]) -> None:
        """Creates a JSON serialization for a CheckResult or SuiteResult.

        Args:
            result: A Deepchecks CheckResult or SuiteResult.
        """
        super().handle_return(result)

        filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)

        serialized_json = result.to_json(True)
        io_utils.write_file_contents_as_string(filepath, serialized_json)
handle_input(self, data_type)

Reads a Deepchecks check or suite result from a serialized JSON file.

Parameters:

Name Type Description Default
data_type Type[Any]

The type of the data to read.

required

Returns:

Type Description
Union[deepchecks.core.check_result.CheckResult, deepchecks.core.suite.SuiteResult]

A Deepchecks CheckResult or SuiteResult.

Exceptions:

Type Description
RuntimeError

if the input data type is not supported.

Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
def handle_input(
    self, data_type: Type[Any]
) -> Union[CheckResult, SuiteResult]:
    """Reads a Deepchecks check or suite result from a serialized JSON file.

    Args:
        data_type: The type of the data to read.

    Returns:
        A Deepchecks CheckResult or SuiteResult.

    Raises:
        RuntimeError: if the input data type is not supported.
    """
    super().handle_input(data_type)
    filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)

    json_res = io_utils.read_file_contents_as_string(filepath)
    if data_type == SuiteResult:
        res = SuiteResult.from_json(json_res)
    elif data_type == CheckResult:
        res = CheckResult.from_json(json_res)
    else:
        raise RuntimeError(f"Unknown data type: {data_type}")
    return res
handle_return(self, result)

Creates a JSON serialization for a CheckResult or SuiteResult.

Parameters:

Name Type Description Default
result Union[deepchecks.core.check_result.CheckResult, deepchecks.core.suite.SuiteResult]

A Deepchecks CheckResult or SuiteResult.

required
Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
def handle_return(self, result: Union[CheckResult, SuiteResult]) -> None:
    """Creates a JSON serialization for a CheckResult or SuiteResult.

    Args:
        result: A Deepchecks CheckResult or SuiteResult.
    """
    super().handle_return(result)

    filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)

    serialized_json = result.to_json(True)
    io_utils.write_file_contents_as_string(filepath, serialized_json)
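
The persistence logic above relies entirely on Deepchecks' own JSON round-trip (to_json / from_json). A minimal standalone sketch of that round-trip, with a local file path standing in for the artifact URI (the helper names are hypothetical):

from deepchecks.core.suite import SuiteResult

def save_suite_result(result: SuiteResult, filepath: str) -> None:
    # `to_json(True)` mirrors the call in `handle_return` above.
    with open(filepath, "w") as f:
        f.write(result.to_json(True))

def load_suite_result(filepath: str) -> SuiteResult:
    # Mirrors `handle_input` for the SuiteResult case.
    with open(filepath) as f:
        return SuiteResult.from_json(f.read())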

steps special

Initialization of the Deepchecks Standard Steps.

deepchecks_data_drift

Implementation of the Deepchecks data drift validation step.

DeepchecksDataDriftCheckStep (BaseStep)

Deepchecks data drift validator step.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStep(BaseStep):
    """Deepchecks data drift validator step."""

    def entrypoint(  # type: ignore[override]
        self,
        reference_dataset: pd.DataFrame,
        target_dataset: pd.DataFrame,
        config: DeepchecksDataDriftCheckStepConfig,
    ) -> SuiteResult:
        """Main entrypoint for the Deepchecks data drift validator step.

        Args:
            reference_dataset: Reference dataset for the data drift check.
            target_dataset: Target dataset to be used for the data drift check.
            config: the configuration for the step

        Returns:
            A Deepchecks suite result with the validation results.
        """
        data_validator = cast(
            DeepchecksDataValidator,
            DeepchecksDataValidator.get_active_data_validator(),
        )

        return data_validator.data_validation(
            dataset=reference_dataset,
            comparison_dataset=target_dataset,
            check_list=cast(Optional[Sequence[str]], config.check_list),
            dataset_kwargs=config.dataset_kwargs,
            check_kwargs=config.check_kwargs,
            run_kwargs=config.run_kwargs,
        )
CONFIG_CLASS (BaseStepConfig) pydantic-model

Config class for the Deepchecks data drift validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataDriftCheck]]

Optional list of DeepchecksDataDriftCheck identifiers specifying the subset of Deepchecks data drift checks to be performed. If not supplied, the entire set of data drift checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks data drift validator step.

    Attributes:
        check_list: Optional list of DeepchecksDataDriftCheck identifiers
            specifying the subset of Deepchecks data drift checks to be
            performed. If not supplied, the entire set of data drift checks will
            be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksDataDriftCheck]] = None
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, reference_dataset, target_dataset, config)

Main entrypoint for the Deepchecks data drift validator step.

Parameters:

Name Type Description Default
reference_dataset DataFrame

Reference dataset for the data drift check.

required
target_dataset DataFrame

Target dataset to be used for the data drift check.

required
config DeepchecksDataDriftCheckStepConfig

the configuration for the step

required

Returns:

Type Description
SuiteResult

A Deepchecks suite result with the validation results.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
def entrypoint(  # type: ignore[override]
    self,
    reference_dataset: pd.DataFrame,
    target_dataset: pd.DataFrame,
    config: DeepchecksDataDriftCheckStepConfig,
) -> SuiteResult:
    """Main entrypoint for the Deepchecks data drift validator step.

    Args:
        reference_dataset: Reference dataset for the data drift check.
        target_dataset: Target dataset to be used for the data drift check.
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    data_validator = cast(
        DeepchecksDataValidator,
        DeepchecksDataValidator.get_active_data_validator(),
    )

    return data_validator.data_validation(
        dataset=reference_dataset,
        comparison_dataset=target_dataset,
        check_list=cast(Optional[Sequence[str]], config.check_list),
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
DeepchecksDataDriftCheckStepConfig (BaseStepConfig) pydantic-model

Config class for the Deepchecks data drift validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataDriftCheck]]

Optional list of DeepchecksDataDriftCheck identifiers specifying the subset of Deepchecks data drift checks to be performed. If not supplied, the entire set of data drift checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks data drift validator step.

    Attributes:
        check_list: Optional list of DeepchecksDataDriftCheck identifiers
            specifying the subset of Deepchecks data drift checks to be
            performed. If not supplied, the entire set of data drift checks will
            be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksDataDriftCheck]] = None
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_data_drift_check_step(step_name, config)

Shortcut function to create a new instance of the DeepchecksDataDriftCheckStep step.

The returned DeepchecksDataDriftCheckStep can be used in a pipeline to run data drift checks on two input pd.DataFrame datasets and return the results as a Deepchecks SuiteResult object.

Parameters:

Name Type Description Default
step_name str

The name of the step

required
config DeepchecksDataDriftCheckStepConfig

The configuration for the step

required

Returns:

Type Description
BaseStep

a DeepchecksDataDriftCheckStep step instance

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
def deepchecks_data_drift_check_step(
    step_name: str,
    config: DeepchecksDataDriftCheckStepConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the DeepchecksDataDriftCheckStep step.

    The returned DeepchecksDataDriftCheckStep can be used in a pipeline to
    run data drift checks on two input pd.DataFrame datasets and return the results
    as a Deepchecks SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksDataDriftCheckStep step instance
    """
    return clone_step(DeepchecksDataDriftCheckStep, step_name)(config=config)
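
As an illustration, a hypothetical sketch wiring this shortcut step into a ZenML pipeline alongside two importer steps. The importers, pipeline name, and steps-module import path are assumptions; the check enum value appears in the DeepchecksDataDriftCheck listing later on this page.

import pandas as pd

from zenml.pipelines import pipeline
from zenml.steps import step

from zenml.integrations.deepchecks.steps import (  # assumed path
    DeepchecksDataDriftCheckStepConfig,
    deepchecks_data_drift_check_step,
)
from zenml.integrations.deepchecks.validation_checks import (
    DeepchecksDataDriftCheck,
)


@step
def reference_importer() -> pd.DataFrame:
    """Hypothetical importer for the reference dataset."""
    return pd.DataFrame({"age": [23, 45, 31], "income": [40, 80, 55]})


@step
def target_importer() -> pd.DataFrame:
    """Hypothetical importer for the target dataset."""
    return pd.DataFrame({"age": [52, 19, 44], "income": [90, 30, 70]})


drift_check = deepchecks_data_drift_check_step(
    step_name="data_drift_check",
    config=DeepchecksDataDriftCheckStepConfig(
        check_list=[DeepchecksDataDriftCheck.TABULAR_TRAIN_TEST_FEATURE_DRIFT],
    ),
)


@pipeline
def drift_pipeline(reference_importer, target_importer, drift_check):
    reference = reference_importer()
    target = target_importer()
    drift_check(reference_dataset=reference, target_dataset=target)


drift_pipeline(
    reference_importer=reference_importer(),
    target_importer=target_importer(),
    drift_check=drift_check,
).run()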
deepchecks_data_integrity

Implementation of the Deepchecks data integrity validation step.

DeepchecksDataIntegrityCheckStep (BaseStep)

Deepchecks data integrity validator step.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStep(BaseStep):
    """Deepchecks data integrity validator step."""

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: DeepchecksDataIntegrityCheckStepConfig,
    ) -> SuiteResult:
        """Main entrypoint for the Deepchecks data integrity validator step.

        Args:
            dataset: a Pandas DataFrame to validate
            config: the configuration for the step

        Returns:
            A Deepchecks suite result with the validation results.
        """
        data_validator = cast(
            DeepchecksDataValidator,
            DeepchecksDataValidator.get_active_data_validator(),
        )

        return data_validator.data_validation(
            dataset=dataset,
            check_list=cast(Optional[Sequence[str]], config.check_list),
            dataset_kwargs=config.dataset_kwargs,
            check_kwargs=config.check_kwargs,
            run_kwargs=config.run_kwargs,
        )
CONFIG_CLASS (BaseStepConfig) pydantic-model

Config class for the Deepchecks data integrity validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataIntegrityCheck]]

Optional list of DeepchecksDataIntegrityCheck identifiers specifying the subset of Deepchecks data integrity checks to be performed. If not supplied, the entire set of data integrity checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks data integrity validator step.

    Attributes:
        check_list: Optional list of DeepchecksDataIntegrityCheck identifiers
            specifying the subset of Deepchecks data integrity checks to be
            performed. If not supplied, the entire set of data integrity checks
            will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksDataIntegrityCheck]] = None
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, dataset, config)

Main entrypoint for the Deepchecks data integrity validator step.

Parameters:

Name Type Description Default
dataset DataFrame

a Pandas DataFrame to validate

required
config DeepchecksDataIntegrityCheckStepConfig

the configuration for the step

required

Returns:

Type Description
SuiteResult

A Deepchecks suite result with the validation results.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: DeepchecksDataIntegrityCheckStepConfig,
) -> SuiteResult:
    """Main entrypoint for the Deepchecks data integrity validator step.

    Args:
        dataset: a Pandas DataFrame to validate
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    data_validator = cast(
        DeepchecksDataValidator,
        DeepchecksDataValidator.get_active_data_validator(),
    )

    return data_validator.data_validation(
        dataset=dataset,
        check_list=cast(Optional[Sequence[str]], config.check_list),
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
DeepchecksDataIntegrityCheckStepConfig (BaseStepConfig) pydantic-model

Config class for the Deepchecks data integrity validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataIntegrityCheck]]

Optional list of DeepchecksDataIntegrityCheck identifiers specifying the subset of Deepchecks data integrity checks to be performed. If not supplied, the entire set of data integrity checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks data integrity validator step.

    Attributes:
        check_list: Optional list of DeepchecksDataIntegrityCheck identifiers
            specifying the subset of Deepchecks data integrity checks to be
            performed. If not supplied, the entire set of data integrity checks
            will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksDataIntegrityCheck]] = None
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_data_integrity_check_step(step_name, config)

Shortcut function to create a new instance of the DeepchecksDataIntegrityCheckStep step.

The returned DeepchecksDataIntegrityCheckStep can be used in a pipeline to run data integrity checks on an input pd.DataFrame and return the results as a Deepchecks SuiteResult object.

Parameters:

Name Type Description Default
step_name str

The name of the step

required
config DeepchecksDataIntegrityCheckStepConfig

The configuration for the step

required

Returns:

Type Description
BaseStep

a DeepchecksDataIntegrityCheckStep step instance

Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
def deepchecks_data_integrity_check_step(
    step_name: str,
    config: DeepchecksDataIntegrityCheckStepConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the DeepchecksDataIntegrityCheckStep step.

    The returned DeepchecksDataIntegrityCheckStep can be used in a pipeline to
    run data integrity checks on an input pd.DataFrame and return the results
    as a Deepchecks SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksDataIntegrityCheckStep step instance
    """
    return clone_step(DeepchecksDataIntegrityCheckStep, step_name)(
        config=config
    )
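
Step creation follows the same pattern as the data drift shortcut above; a brief hypothetical sketch (steps-module import path assumed):

from zenml.integrations.deepchecks.steps import (  # assumed path
    DeepchecksDataIntegrityCheckStepConfig,
    deepchecks_data_integrity_check_step,
)

# With a default config, the full data integrity suite is run on the
# step's input DataFrame.
integrity_check = deepchecks_data_integrity_check_step(
    step_name="data_integrity_check",
    config=DeepchecksDataIntegrityCheckStepConfig(),
)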
deepchecks_model_drift

Implementation of the Deepchecks model drift validation step.

DeepchecksModelDriftCheckStep (BaseStep)

Deepchecks model drift step.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStep(BaseStep):
    """Deepchecks model drift step."""

    def entrypoint(  # type: ignore[override]
        self,
        reference_dataset: pd.DataFrame,
        target_dataset: pd.DataFrame,
        model: ClassifierMixin,
        config: DeepchecksModelDriftCheckStepConfig,
    ) -> SuiteResult:
        """Main entrypoint for the Deepchecks model drift step.

        Args:
            reference_dataset: Reference dataset for the model drift check.
            target_dataset: Target dataset to be used for the model drift check.
            model: a scikit-learn model to validate
            config: the configuration for the step

        Returns:
            A Deepchecks suite result with the validation results.
        """
        data_validator = cast(
            DeepchecksDataValidator,
            DeepchecksDataValidator.get_active_data_validator(),
        )

        return data_validator.model_validation(
            dataset=reference_dataset,
            comparison_dataset=target_dataset,
            model=model,
            check_list=cast(Optional[Sequence[str]], config.check_list),
            dataset_kwargs=config.dataset_kwargs,
            check_kwargs=config.check_kwargs,
            run_kwargs=config.run_kwargs,
        )
CONFIG_CLASS (BaseStepConfig) pydantic-model

Config class for the Deepchecks model drift validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelDriftCheck]]

Optional list of DeepchecksModelDriftCheck identifiers specifying the subset of Deepchecks model drift checks to be performed. If not supplied, the entire set of model drift checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks model drift validator step.

    Attributes:
        check_list: Optional list of DeepchecksModelDriftCheck identifiers
            specifying the subset of Deepchecks model drift checks to be
            performed. If not supplied, the entire set of model drift checks
            will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksModelDriftCheck]] = None
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, reference_dataset, target_dataset, model, config)

Main entrypoint for the Deepchecks model drift step.

Parameters:

Name Type Description Default
reference_dataset DataFrame

Reference dataset for the model drift check.

required
target_dataset DataFrame

Target dataset to be used for the model drift check.

required
model ClassifierMixin

a scikit-learn model to validate

required
config DeepchecksModelDriftCheckStepConfig

the configuration for the step

required

Returns:

Type Description
SuiteResult

A Deepchecks suite result with the validation results.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
def entrypoint(  # type: ignore[override]
    self,
    reference_dataset: pd.DataFrame,
    target_dataset: pd.DataFrame,
    model: ClassifierMixin,
    config: DeepchecksModelDriftCheckStepConfig,
) -> SuiteResult:
    """Main entrypoint for the Deepchecks model drift step.

    Args:
        reference_dataset: Reference dataset for the model drift check.
        target_dataset: Target dataset to be used for the model drift check.
        model: a scikit-learn model to validate
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    data_validator = cast(
        DeepchecksDataValidator,
        DeepchecksDataValidator.get_active_data_validator(),
    )

    return data_validator.model_validation(
        dataset=reference_dataset,
        comparison_dataset=target_dataset,
        model=model,
        check_list=cast(Optional[Sequence[str]], config.check_list),
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
DeepchecksModelDriftCheckStepConfig (BaseStepConfig) pydantic-model

Config class for the Deepchecks model drift validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelDriftCheck]]

Optional list of DeepchecksModelDriftCheck identifiers specifying the subset of Deepchecks model drift checks to be performed. If not supplied, the entire set of model drift checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks model drift validator step.

    Attributes:
        check_list: Optional list of DeepchecksModelDriftCheck identifiers
            specifying the subset of Deepchecks model drift checks to be
            performed. If not supplied, the entire set of model drift checks
            will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksModelDriftCheck]] = None
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_model_drift_check_step(step_name, config)

Shortcut function to create a new instance of the DeepchecksModelDriftCheckStep step.

The returned DeepchecksModelDriftCheckStep can be used in a pipeline to run model drift checks on two input pd.DataFrame datasets and an input scikit-learn ClassifierMixin model and return the results as a Deepchecks SuiteResult object.

Parameters:

Name Type Description Default
step_name str

The name of the step

required
config DeepchecksModelDriftCheckStepConfig

The configuration for the step

required

Returns:

Type Description
BaseStep

a DeepchecksModelDriftCheckStep step instance

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
def deepchecks_model_drift_check_step(
    step_name: str,
    config: DeepchecksModelDriftCheckStepConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the DeepchecksModelDriftCheckStep step.

    The returned DeepchecksModelDriftCheckStep can be used in a pipeline to
    run model drift checks on two input pd.DataFrame datasets and an input
    scikit-learn ClassifierMixin model and return the results as a Deepchecks
    SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksModelDriftCheckStep step instance
    """
    return clone_step(DeepchecksModelDriftCheckStep, step_name)(config=config)
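
Within a pipeline, the returned step takes two DataFrames and a fitted classifier, matching the entrypoint signature above; a brief hypothetical sketch (steps-module import path assumed):

from zenml.integrations.deepchecks.steps import (  # assumed path
    DeepchecksModelDriftCheckStepConfig,
    deepchecks_model_drift_check_step,
)

model_drift_check = deepchecks_model_drift_check_step(
    step_name="model_drift_check",
    config=DeepchecksModelDriftCheckStepConfig(),
)

# Inside a @pipeline function (see the data drift sketch earlier):
#     model_drift_check(
#         reference_dataset=train_df,
#         target_dataset=test_df,
#         model=classifier,
#     )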
deepchecks_model_validation

Implementation of the Deepchecks model validation step.

DeepchecksModelValidationCheckStep (BaseStep)

Deepchecks model validation step.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStep(BaseStep):
    """Deepchecks model validation step."""

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        model: ClassifierMixin,
        config: DeepchecksModelValidationCheckStepConfig,
    ) -> SuiteResult:
        """Main entrypoint for the Deepchecks model validation step.

        Args:
            dataset: a Pandas DataFrame to use for the validation
            model: a scikit-learn model to validate
            config: the configuration for the step

        Returns:
            A Deepchecks suite result with the validation results.
        """
        data_validator = cast(
            DeepchecksDataValidator,
            DeepchecksDataValidator.get_active_data_validator(),
        )

        return data_validator.model_validation(
            dataset=dataset,
            model=model,
            check_list=cast(Optional[Sequence[str]], config.check_list),
            dataset_kwargs=config.dataset_kwargs,
            check_kwargs=config.check_kwargs,
            run_kwargs=config.run_kwargs,
        )
CONFIG_CLASS (BaseStepConfig) pydantic-model

Config class for the Deepchecks model validation validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelValidationCheck]]

Optional list of DeepchecksModelValidationCheck identifiers specifying the subset of Deepchecks model validation checks to be performed. If not supplied, the entire set of model validation checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks model validation validator step.

    Attributes:
        check_list: Optional list of DeepchecksModelValidationCheck identifiers
            specifying the subset of Deepchecks model validation checks to be
            performed. If not supplied, the entire set of model validation checks
            will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksModelValidationCheck]] = None
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, dataset, model, config)

Main entrypoint for the Deepchecks model validation step.

Parameters:

Name Type Description Default
dataset DataFrame

a Pandas DataFrame to use for the validation

required
model ClassifierMixin

a scikit-learn model to validate

required
config DeepchecksModelValidationCheckStepConfig

the configuration for the step

required

Returns:

Type Description
SuiteResult

A Deepchecks suite result with the validation results.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    model: ClassifierMixin,
    config: DeepchecksModelValidationCheckStepConfig,
) -> SuiteResult:
    """Main entrypoint for the Deepchecks model validation step.

    Args:
        dataset: a Pandas DataFrame to use for the validation
        model: a scikit-learn model to validate
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    data_validator = cast(
        DeepchecksDataValidator,
        DeepchecksDataValidator.get_active_data_validator(),
    )

    return data_validator.model_validation(
        dataset=dataset,
        model=model,
        check_list=cast(Optional[Sequence[str]], config.check_list),
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
DeepchecksModelValidationCheckStepConfig (BaseStepConfig) pydantic-model

Config class for the Deepchecks model validation validator step.

Attributes:

Name Type Description
check_list Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelValidationCheck]]

Optional list of DeepchecksModelValidationCheck identifiers specifying the subset of Deepchecks model validation checks to be performed. If not supplied, the entire set of model validation checks will be performed.

dataset_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks tabular.Dataset or vision.VisionData constructor.

check_kwargs Dict[str, Dict[str, Any]]

Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys.

run_kwargs Dict[str, Any]

Additional keyword arguments to be passed to the Deepchecks Suite run method.

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks model validation validator step.

    Attributes:
        check_list: Optional list of DeepchecksModelValidationCheck identifiers
            specifying the subset of Deepchecks model validation checks to be
            performed. If not supplied, the entire set of model validation checks
            will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksModelValidationCheck]] = None
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_model_validation_check_step(step_name, config)

Shortcut function to create a new instance of the DeepchecksModelValidationCheckStep step.

The returned DeepchecksModelValidationCheckStep can be used in a pipeline to run model validation checks on an input pd.DataFrame dataset and an input scikit-learn ClassifierMixin model and return the results as a Deepchecks SuiteResult object.

Parameters:

Name Type Description Default
step_name str

The name of the step

required
config DeepchecksModelValidationCheckStepConfig

The configuration for the step

required

Returns:

Type Description
BaseStep

a DeepchecksModelValidationCheckStep step instance

Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
def deepchecks_model_validation_check_step(
    step_name: str,
    config: DeepchecksModelValidationCheckStepConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the DeepchecksModelValidationCheckStep step.

    The returned DeepchecksModelValidationCheckStep can be used in a pipeline to
    run model validation checks on an input pd.DataFrame dataset and an input
    scikit-learn ClassifierMixin model and return the results as a Deepchecks
    SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksModelValidationCheckStep step instance
    """
    return clone_step(DeepchecksModelValidationCheckStep, step_name)(
        config=config
    )
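
The same pattern applies here, with a single dataset and a model; a brief hypothetical sketch (steps-module import path assumed):

from zenml.integrations.deepchecks.steps import (  # assumed path
    DeepchecksModelValidationCheckStepConfig,
    deepchecks_model_validation_check_step,
)

model_validation_check = deepchecks_model_validation_check_step(
    step_name="model_validation_check",
    config=DeepchecksModelValidationCheckStepConfig(),
)

# Inside a @pipeline function:
#     model_validation_check(dataset=train_df, model=classifier)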

validation_checks

Definition of the Deepchecks validation check types.

DeepchecksDataDriftCheck (DeepchecksValidationCheck)

Categories of Deepchecks data drift checks.

This list reflects the set of train-test validation checks provided by Deepchecks:

  • for tabular data: https://docs.deepchecks.com/stable/checks_gallery/tabular.html#train-test-validation
  • for computer vision: https://docs.deepchecks.com/stable/checks_gallery/vision.html#train-test-validation

All these checks inherit from deepchecks.tabular.TrainTestCheck or deepchecks.vision.TrainTestCheck and require two datasets as input.

Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksDataDriftCheck(DeepchecksValidationCheck):
    """Categories of Deepchecks data drift checks.

    This list reflects the set of train-test validation checks provided by
    Deepchecks:

      * [for tabular data](https://docs.deepchecks.com/stable/checks_gallery/tabular.html#train-test-validation)
      * [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#train-test-validation)

    All these checks inherit from `deepchecks.tabular.TrainTestCheck` or
    `deepchecks.vision.TrainTestCheck` and require two datasets as input.
    """

    TABULAR_CATEGORY_MISMATCH_TRAIN_TEST = resolve_class(
        tabular_checks.CategoryMismatchTrainTest
    )
    TABULAR_DATASET_SIZE_COMPARISON = resolve_class(
        tabular_checks.DatasetsSizeComparison
    )
    TABULAR_DATE_TRAIN_TEST_LEAKAGE_DUPLICATES = resolve_class(
        tabular_checks.DateTrainTestLeakageDuplicates
    )
    TABULAR_DATE_TRAIN_TEST_LEAKAGE_OVERLAP = resolve_class(
        tabular_checks.DateTrainTestLeakageOverlap
    )
    TABULAR_DOMINANT_FREQUENCY_CHANGE = resolve_class(
        tabular_checks.DominantFrequencyChange
    )
    TABULAR_FEATURE_LABEL_CORRELATION_CHANGE = resolve_class(
        tabular_checks.FeatureLabelCorrelationChange
    )
    TABULAR_INDEX_LEAKAGE = resolve_class(tabular_checks.IndexTrainTestLeakage)
    TABULAR_NEW_LABEL_TRAIN_TEST = resolve_class(
        tabular_checks.NewLabelTrainTest
    )
    TABULAR_STRING_MISMATCH_COMPARISON = resolve_class(
        tabular_checks.StringMismatchComparison
    )
    TABULAR_TRAIN_TEST_FEATURE_DRIFT = resolve_class(
        tabular_checks.TrainTestFeatureDrift
    )
    TABULAR_TRAIN_TEST_LABEL_DRIFT = resolve_class(
        tabular_checks.TrainTestLabelDrift
    )
    TABULAR_TRAIN_TEST_SAMPLES_MIX = resolve_class(
        tabular_checks.TrainTestSamplesMix
    )
    TABULAR_WHOLE_DATASET_DRIFT = resolve_class(
        tabular_checks.WholeDatasetDrift
    )

    VISION_FEATURE_LABEL_CORRELATION_CHANGE = resolve_class(
        vision_checks.FeatureLabelCorrelationChange
    )
    VISION_HEATMAP_COMPARISON = resolve_class(vision_checks.HeatmapComparison)
    VISION_IMAGE_DATASET_DRIFT = resolve_class(vision_checks.ImageDatasetDrift)
    VISION_IMAGE_PROPERTY_DRIFT = resolve_class(
        vision_checks.ImagePropertyDrift
    )
    VISION_NEW_LABELS = resolve_class(vision_checks.NewLabels)
    VISION_SIMILAR_IMAGE_LEAKAGE = resolve_class(
        vision_checks.SimilarImageLeakage
    )
    VISION_TRAIN_TEST_LABEL_DRIFT = resolve_class(
        vision_checks.TrainTestLabelDrift
    )
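
Each enum member stores the full dotted path of the corresponding Deepchecks check class as its string value (as produced by resolve_class), so the categories can be inspected programmatically. A small illustrative sketch:

from zenml.integrations.deepchecks.validation_checks import (
    DeepchecksDataDriftCheck,
)

# Print every tabular data drift identifier alongside the Deepchecks check
# class path stored as its value.
for check in DeepchecksDataDriftCheck:
    if check.name.startswith("TABULAR_"):
        print(f"{check.name} -> {check.value}")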
DeepchecksDataIntegrityCheck (DeepchecksValidationCheck)

Categories of Deepchecks data integrity checks.

This list reflects the set of data integrity checks provided by Deepchecks:

* for tabular data: https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#data-integrity
* for computer vision: https://docs.deepchecks.com/en/stable/checks_gallery/vision.html#data-integrity

All these checks inherit from deepchecks.tabular.SingleDatasetCheck or deepchecks.vision.SingleDatasetCheck and require a single dataset as input.

Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksDataIntegrityCheck(DeepchecksValidationCheck):
    """Categories of Deepchecks data integrity checks.

    This list reflects the set of data integrity checks provided by Deepchecks:

      * [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#data-integrity)
      * [for computer vision](https://docs.deepchecks.com/en/stable/checks_gallery/vision.html#data-integrity)

    All these checks inherit from `deepchecks.tabular.SingleDatasetCheck` or
    `deepchecks.vision.SingleDatasetCheck` and require a single dataset as input.
    """

    TABULAR_COLUMNS_INFO = resolve_class(tabular_checks.ColumnsInfo)
    TABULAR_CONFLICTING_LABELS = resolve_class(tabular_checks.ConflictingLabels)
    TABULAR_DATA_DUPLICATES = resolve_class(tabular_checks.DataDuplicates)
    TABULAR_FEATURE_FEATURE_CORRELATION = resolve_class(
        FeatureFeatureCorrelation
    )
    TABULAR_FEATURE_LABEL_CORRELATION = resolve_class(
        tabular_checks.FeatureLabelCorrelation
    )
    TABULAR_IDENTIFIER_LEAKAGE = resolve_class(tabular_checks.IdentifierLeakage)
    TABULAR_IS_SINGLE_VALUE = resolve_class(tabular_checks.IsSingleValue)
    TABULAR_MIXED_DATA_TYPES = resolve_class(tabular_checks.MixedDataTypes)
    TABULAR_MIXED_NULLS = resolve_class(tabular_checks.MixedNulls)
    TABULAR_OUTLIER_SAMPLE_DETECTION = resolve_class(
        tabular_checks.OutlierSampleDetection
    )
    TABULAR_SPECIAL_CHARS = resolve_class(tabular_checks.SpecialCharacters)
    TABULAR_STRING_LENGTH_OUT_OF_BOUNDS = resolve_class(
        tabular_checks.StringLengthOutOfBounds
    )
    TABULAR_STRING_MISMATCH = resolve_class(tabular_checks.StringMismatch)

    VISION_IMAGE_PROPERTY_OUTLIERS = resolve_class(
        vision_checks.ImagePropertyOutliers
    )
    VISION_LABEL_PROPERTY_OUTLIERS = resolve_class(
        vision_checks.LabelPropertyOutliers
    )
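
As an illustration of what those stored values encode, the hedged sketch below maps an enum value back to the underlying Deepchecks class. The load_check_class helper is hypothetical, written here for demonstration; it is not part of ZenML's API:

import importlib

from zenml.integrations.deepchecks.validation_checks import (
    DeepchecksDataIntegrityCheck,
)

def load_check_class(check_path: str) -> type:
    """Hypothetical helper: import a check class from its dotted path."""
    module_name, class_name = check_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

# Resolve the DataDuplicates check class from its enum value and
# instantiate it with default arguments (requires deepchecks installed).
check_cls = load_check_class(
    DeepchecksDataIntegrityCheck.TABULAR_DATA_DUPLICATES.value
)
check = check_cls()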
DeepchecksModelDriftCheck (DeepchecksValidationCheck)

Categories of Deepchecks model drift checks.

This list includes a subset of the model evaluation checks provided by Deepchecks that require two datasets and a mandatory model as input:

* for tabular data: https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#model-evaluation
* for computer vision: https://docs.deepchecks.com/stable/checks_gallery/vision.html#model-evaluation

All these checks inherit from deepchecks.tabular.TrainTestCheck or deepchecks.vision.TrainTestCheck and require two datasets and a mandatory model as input.

Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksModelDriftCheck(DeepchecksValidationCheck):
    """Categories of Deepchecks model drift checks.

    This list includes a subset of the model evaluation checks provided by
    Deepchecks that require two datasets and a mandatory model as input:

      * [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#model-evaluation)
      * [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#model-evaluation)

    All these checks inherit from `deepchecks.tabular.TrainTestCheck` or
    `deepchecks.vision.TrainTestCheck` and require two datasets and a mandatory
    model as input.
    """

    TABULAR_BOOSTING_OVERFIT = resolve_class(tabular_checks.BoostingOverfit)
    TABULAR_MODEL_ERROR_ANALYSIS = resolve_class(
        tabular_checks.ModelErrorAnalysis
    )
    TABULAR_PERFORMANCE_REPORT = resolve_class(tabular_checks.PerformanceReport)
    TABULAR_SIMPLE_MODEL_COMPARISON = resolve_class(
        tabular_checks.SimpleModelComparison
    )
    TABULAR_TRAIN_TEST_PREDICTION_DRIFT = resolve_class(
        tabular_checks.TrainTestPredictionDrift
    )
    TABULAR_UNUSED_FEATURES = resolve_class(tabular_checks.UnusedFeatures)

    VISION_CLASS_PERFORMANCE = resolve_class(vision_checks.ClassPerformance)
    VISION_MODEL_ERROR_ANALYSIS = resolve_class(
        vision_checks.ModelErrorAnalysis
    )
    VISION_SIMPLE_MODEL_COMPARISON = resolve_class(
        vision_checks.SimpleModelComparison
    )
    VISION_TRAIN_TEST_PREDICTION_DRIFT = resolve_class(
        vision_checks.TrainTestPredictionDrift
    )
DeepchecksModelValidationCheck (DeepchecksValidationCheck)