MLflow

zenml.integrations.mlflow special

Initialization for the ZenML MLflow integration.

The MLflow integration currently enables you to use MLflow tracking as a convenient way to visualize your experiment runs within the MLflow UI.
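
A minimal sketch of inspecting the integration from Python. Installing the requirements is typically done with the ZenML CLI (zenml integration install mlflow); check_installation is the generic helper inherited from the base Integration class:

from zenml.integrations.mlflow import MlflowIntegration

# The pinned package requirements of the integration.
print(MlflowIntegration.REQUIREMENTS)

# True if all requirements are installed in the current environment.
print(MlflowIntegration.check_installation())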

MlflowIntegration (Integration)

Definition of MLflow integration for ZenML.

Source code in zenml/integrations/mlflow/__init__.py
class MlflowIntegration(Integration):
    """Definition of MLflow integration for ZenML."""

    NAME = MLFLOW
    # We need to pin protobuf to a version <=4 here, as this mlflow release
    # does not pin it. They fixed this in a later version, so we can probably
    # remove this once we update the mlflow version.
    REQUIREMENTS = [
        "mlflow>=1.24.0,<=2.2.0",
        "mlserver>=0.5.3",
        "mlserver-mlflow>=0.5.3",
    ]

    @classmethod
    def activate(cls) -> None:
        """Activate the MLflow integration."""
        from zenml.integrations.mlflow import services  # noqa

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Declare the stack component flavors for the MLflow integration.

        Returns:
            List of stack component flavors for this integration.
        """
        from zenml.integrations.mlflow.flavors import (
            MLFlowExperimentTrackerFlavor,
            MLFlowModelDeployerFlavor,
            MLFlowModelRegistryFlavor,
        )

        return [
            MLFlowModelDeployerFlavor,
            MLFlowExperimentTrackerFlavor,
            MLFlowModelRegistryFlavor,
        ]

activate() classmethod

Activate the MLflow integration.

Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the MLflow integration."""
    from zenml.integrations.mlflow import services  # noqa

flavors() classmethod

Declare the stack component flavors for the MLflow integration.

Returns:

    List[Type[zenml.stack.flavor.Flavor]]: List of stack component flavors for this integration.

Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors for the MLflow integration.

    Returns:
        List of stack component flavors for this integration.
    """
    from zenml.integrations.mlflow.flavors import (
        MLFlowExperimentTrackerFlavor,
        MLFlowModelDeployerFlavor,
        MLFlowModelRegistryFlavor,
    )

    return [
        MLFlowModelDeployerFlavor,
        MLFlowExperimentTrackerFlavor,
        MLFlowModelRegistryFlavor,
    ]
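
As a quick illustration (assuming the integration's requirements are installed), the declared flavors can be listed by name:

from zenml.integrations.mlflow import MlflowIntegration

for flavor_class in MlflowIntegration.flavors():
    # Each entry is a Flavor subclass; instantiating it exposes its name.
    print(flavor_class().name)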

experiment_trackers special

Initialization of the MLflow experiment tracker.

mlflow_experiment_tracker

Implementation of the MLflow experiment tracker for ZenML.

MLFlowExperimentTracker (BaseExperimentTracker)

Track experiments using MLflow.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
class MLFlowExperimentTracker(BaseExperimentTracker):
    """Track experiments using MLflow."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the experiment tracker and validate the tracking uri.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, **kwargs)
        self._ensure_valid_tracking_uri()

    def _ensure_valid_tracking_uri(self) -> None:
        """Ensures that the tracking uri is a valid mlflow tracking uri.

        Raises:
            ValueError: If the tracking uri is not valid.
        """
        tracking_uri = self.config.tracking_uri
        if tracking_uri:
            valid_schemes = DATABASE_ENGINES + ["http", "https", "file"]
            if not any(
                tracking_uri.startswith(scheme) for scheme in valid_schemes
            ) and not is_databricks_tracking_uri(tracking_uri):
                raise ValueError(
                    f"MLflow tracking uri does not start with one of the valid "
                    f"schemes {valid_schemes} or its value is not set to "
                    f"'databricks'. See "
                    f"https://www.mlflow.org/docs/latest/tracking.html#where-runs-are-recorded "
                    f"for more information."
                )

    @property
    def config(self) -> MLFlowExperimentTrackerConfig:
        """Returns the `MLFlowExperimentTrackerConfig` config.

        Returns:
            The configuration.
        """
        return cast(MLFlowExperimentTrackerConfig, self._config)

    @property
    def local_path(self) -> Optional[str]:
        """Path to the local directory where the MLflow artifacts are stored.

        Returns:
            None if configured with a remote tracking URI, otherwise the
            path to the local MLflow artifact store directory.
        """
        tracking_uri = self.get_tracking_uri()
        if is_remote_mlflow_tracking_uri(tracking_uri):
            return None
        else:
            assert tracking_uri.startswith("file:")
            return tracking_uri[5:]

    @property
    def validator(self) -> Optional["StackValidator"]:
        """Checks the stack has a `LocalArtifactStore` if no tracking uri was specified.

        Returns:
            An optional `StackValidator`.
        """
        if self.config.tracking_uri:
            # user specified a tracking uri, do nothing
            return None
        else:
            # try to fall back to a tracking uri inside the zenml artifact
            # store. this only works in case of a local artifact store, so we
            # make sure to prevent stack with other artifact stores for now
            return StackValidator(
                custom_validation_function=lambda stack: (
                    isinstance(stack.artifact_store, LocalArtifactStore),
                    "MLflow experiment tracker without a specified tracking "
                    "uri only works with a local artifact store.",
                )
            )

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the Mlflow experiment tracker.

        Returns:
            The settings class.
        """
        return MLFlowExperimentTrackerSettings

    @staticmethod
    def _local_mlflow_backend() -> str:
        """Gets the local MLflow backend inside the ZenML artifact repository directory.

        Returns:
            The MLflow tracking URI for the local MLflow backend.
        """
        client = Client()
        artifact_store = client.active_stack.artifact_store
        local_mlflow_tracking_uri = os.path.join(artifact_store.path, "mlruns")
        if not os.path.exists(local_mlflow_tracking_uri):
            os.makedirs(local_mlflow_tracking_uri)
        return "file:" + local_mlflow_tracking_uri

    def get_tracking_uri(self) -> str:
        """Returns the configured tracking URI or a local fallback.

        Returns:
            The tracking URI.
        """
        return self.config.tracking_uri or self._local_mlflow_backend()

    def prepare_step_run(self, info: "StepRunInfo") -> None:
        """Sets the MLflow tracking uri and credentials.

        Args:
            info: Info about the step that will be executed.
        """
        self.configure_mlflow()
        settings = cast(
            MLFlowExperimentTrackerSettings,
            self.get_settings(info),
        )

        experiment_name = settings.experiment_name or info.pipeline.name
        experiment = self._set_active_experiment(experiment_name)
        run_id = self.get_run_id(
            experiment_name=experiment_name, run_name=info.run_name
        )

        tags = settings.tags.copy()
        tags.update(self._get_internal_tags())

        mlflow.start_run(
            run_id=run_id,
            run_name=info.run_name,
            experiment_id=experiment.experiment_id,
            tags=tags,
        )

        if settings.nested:
            mlflow.start_run(run_name=info.config.name, nested=True, tags=tags)

    def get_step_run_metadata(
        self, info: "StepRunInfo"
    ) -> Dict[str, "MetadataType"]:
        """Get component- and step-specific metadata after a step ran.

        Args:
            info: Info about the step that was executed.

        Returns:
            A dictionary of metadata.
        """
        return {
            METADATA_EXPERIMENT_TRACKER_URL: Uri(self.get_tracking_uri()),
            "mlflow_run_id": mlflow.active_run().info.run_id,
            "mlflow_experiment_id": mlflow.active_run().info.experiment_id,
        }

    def disable_autologging(self) -> None:
        """Disables MLflow autologging."""
        from mlflow import (
            fastai,
            gluon,
            lightgbm,
            pytorch,
            sklearn,
            spark,
            statsmodels,
            tensorflow,
            xgboost,
        )

        # There is no way to disable auto-logging for all frameworks at once.
        # If auto-logging is explicitly enabled for a framework by calling its
        # autolog() method, it cannot be disabled by calling
        # `mlflow.autolog(disable=True)`. Therefore, we need to disable
        # auto-logging for all frameworks explicitly.

        tensorflow.autolog(disable=True)
        gluon.autolog(disable=True)
        xgboost.autolog(disable=True)
        lightgbm.autolog(disable=True)
        statsmodels.autolog(disable=True)
        spark.autolog(disable=True)
        sklearn.autolog(disable=True)
        fastai.autolog(disable=True)
        pytorch.autolog(disable=True)

    def cleanup_step_run(
        self,
        info: "StepRunInfo",
        step_failed: bool,
    ) -> None:
        """Stops active MLflow runs and resets the MLflow tracking uri.

        Args:
            info: Info about the step that was executed.
            step_failed: Whether the step failed or not.
        """
        status = "FAILED" if step_failed else "FINISHED"
        self.disable_autologging()
        mlflow_utils.stop_zenml_mlflow_runs(status)
        mlflow.set_tracking_uri("")

    def configure_mlflow(self) -> None:
        """Configures the MLflow tracking URI and any additional credentials."""
        tracking_uri = self.get_tracking_uri()
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_registry_uri(tracking_uri)

        if is_databricks_tracking_uri(tracking_uri):
            if self.config.databricks_host:
                os.environ[DATABRICKS_HOST] = self.config.databricks_host
            if self.config.tracking_username:
                os.environ[DATABRICKS_USERNAME] = self.config.tracking_username
            if self.config.tracking_password:
                os.environ[DATABRICKS_PASSWORD] = self.config.tracking_password
            if self.config.tracking_token:
                os.environ[DATABRICKS_TOKEN] = self.config.tracking_token
        else:
            if self.config.tracking_username:
                os.environ[
                    MLFLOW_TRACKING_USERNAME
                ] = self.config.tracking_username
            if self.config.tracking_password:
                os.environ[
                    MLFLOW_TRACKING_PASSWORD
                ] = self.config.tracking_password
            if self.config.tracking_token:
                os.environ[MLFLOW_TRACKING_TOKEN] = self.config.tracking_token

        os.environ[MLFLOW_TRACKING_INSECURE_TLS] = (
            "true" if self.config.tracking_insecure_tls else "false"
        )

    def get_run_id(self, experiment_name: str, run_name: str) -> Optional[str]:
        """Gets the if of a run with the given name and experiment.

        Args:
            experiment_name: Name of the experiment in which to search for the
                run.
            run_name: Name of the run to search.

        Returns:
            The id of the run if it exists.
        """
        self.configure_mlflow()
        experiment_name = self._adjust_experiment_name(experiment_name)

        runs = mlflow.search_runs(
            experiment_names=[experiment_name],
            filter_string=f'tags.mlflow.runName = "{run_name}"',
            run_view_type=3,  # 3 == mlflow.entities.ViewType.ALL (include deleted runs)
            output_format="list",
        )
        if not runs:
            return None

        run: Run = runs[0]
        if mlflow_utils.is_zenml_run(run):
            return cast(str, run.info.run_id)
        else:
            return None

    def _set_active_experiment(self, experiment_name: str) -> Experiment:
        """Sets the active MLflow experiment.

        If no experiment with this name exists, it is created and then
        activated.

        Args:
            experiment_name: Name of the experiment to activate.

        Raises:
            RuntimeError: If the experiment creation or activation failed.

        Returns:
            The experiment.
        """
        experiment_name = self._adjust_experiment_name(experiment_name)

        mlflow.set_experiment(experiment_name=experiment_name)
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if not experiment:
            raise RuntimeError("Failed to set active mlflow experiment.")
        return experiment

    def _adjust_experiment_name(self, experiment_name: str) -> str:
        """Prepends a slash to the experiment name if using Databricks.

        Databricks requires the experiment name to be an absolute path within
        the Databricks workspace.

        Args:
            experiment_name: The experiment name.

        Returns:
            The potentially adjusted experiment name.
        """
        tracking_uri = self.get_tracking_uri()

        if (
            tracking_uri
            and is_databricks_tracking_uri(tracking_uri)
            and not experiment_name.startswith("/")
        ):
            return f"/{experiment_name}"
        else:
            return experiment_name

    @staticmethod
    def _get_internal_tags() -> Dict[str, Any]:
        """Gets ZenML internal tags for MLflow runs.

        Returns:
            Internal tags.
        """
        return {mlflow_utils.ZENML_TAG_KEY: zenml.__version__}

config: MLFlowExperimentTrackerConfig property readonly

Returns the MLFlowExperimentTrackerConfig config.

Returns:

    MLFlowExperimentTrackerConfig: The configuration.

local_path: Optional[str] property readonly

Path to the local directory where the MLflow artifacts are stored.

Returns:

    Optional[str]: None if configured with a remote tracking URI, otherwise the path to the local MLflow artifact store directory.

settings_class: Optional[Type[BaseSettings]] property readonly

Settings class for the Mlflow experiment tracker.

Returns:

    Optional[Type[BaseSettings]]: The settings class.

validator: Optional[StackValidator] property readonly

Checks the stack has a LocalArtifactStore if no tracking uri was specified.

Returns:

    Optional[StackValidator]: An optional StackValidator.

__init__(self, *args, **kwargs) special

Initialize the experiment tracker and validate the tracking uri.

Parameters:

    *args (Any): Variable length argument list. Default: ()
    **kwargs (Any): Arbitrary keyword arguments. Default: {}

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialize the experiment tracker and validate the tracking uri.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """
    super().__init__(*args, **kwargs)
    self._ensure_valid_tracking_uri()

cleanup_step_run(self, info, step_failed)

Stops active MLflow runs and resets the MLflow tracking uri.

Parameters:

    info (StepRunInfo, required): Info about the step that was executed.
    step_failed (bool, required): Whether the step failed or not.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def cleanup_step_run(
    self,
    info: "StepRunInfo",
    step_failed: bool,
) -> None:
    """Stops active MLflow runs and resets the MLflow tracking uri.

    Args:
        info: Info about the step that was executed.
        step_failed: Whether the step failed or not.
    """
    status = "FAILED" if step_failed else "FINISHED"
    self.disable_autologging()
    mlflow_utils.stop_zenml_mlflow_runs(status)
    mlflow.set_tracking_uri("")

configure_mlflow(self)

Configures the MLflow tracking URI and any additional credentials.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def configure_mlflow(self) -> None:
    """Configures the MLflow tracking URI and any additional credentials."""
    tracking_uri = self.get_tracking_uri()
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_registry_uri(tracking_uri)

    if is_databricks_tracking_uri(tracking_uri):
        if self.config.databricks_host:
            os.environ[DATABRICKS_HOST] = self.config.databricks_host
        if self.config.tracking_username:
            os.environ[DATABRICKS_USERNAME] = self.config.tracking_username
        if self.config.tracking_password:
            os.environ[DATABRICKS_PASSWORD] = self.config.tracking_password
        if self.config.tracking_token:
            os.environ[DATABRICKS_TOKEN] = self.config.tracking_token
    else:
        if self.config.tracking_username:
            os.environ[
                MLFLOW_TRACKING_USERNAME
            ] = self.config.tracking_username
        if self.config.tracking_password:
            os.environ[
                MLFLOW_TRACKING_PASSWORD
            ] = self.config.tracking_password
        if self.config.tracking_token:
            os.environ[MLFLOW_TRACKING_TOKEN] = self.config.tracking_token

    os.environ[MLFLOW_TRACKING_INSECURE_TLS] = (
        "true" if self.config.tracking_insecure_tls else "false"
    )

disable_autologging(self)

Disables MLflow autologging.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def disable_autologging(self) -> None:
    """Disables MLflow autologging."""
    from mlflow import (
        fastai,
        gluon,
        lightgbm,
        pytorch,
        sklearn,
        spark,
        statsmodels,
        tensorflow,
        xgboost,
    )

    # There is no way to disable auto-logging for all frameworks at once.
    # If auto-logging is explicitly enabled for a framework by calling its
    # autolog() method, it cannot be disabled by calling
    # `mlflow.autolog(disable=True)`. Therefore, we need to disable
    # auto-logging for all frameworks explicitly.

    tensorflow.autolog(disable=True)
    gluon.autolog(disable=True)
    xgboost.autolog(disable=True)
    lightgbm.autolog(disable=True)
    statsmodels.autolog(disable=True)
    spark.autolog(disable=True)
    sklearn.autolog(disable=True)
    fastai.autolog(disable=True)
    pytorch.autolog(disable=True)

get_run_id(self, experiment_name, run_name)

Gets the ID of a run with the given name and experiment.

Parameters:

    experiment_name (str, required): Name of the experiment in which to search for the run.
    run_name (str, required): Name of the run to search.

Returns:

    Optional[str]: The ID of the run if it exists.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_run_id(self, experiment_name: str, run_name: str) -> Optional[str]:
    """Gets the if of a run with the given name and experiment.

    Args:
        experiment_name: Name of the experiment in which to search for the
            run.
        run_name: Name of the run to search.

    Returns:
        The id of the run if it exists.
    """
    self.configure_mlflow()
    experiment_name = self._adjust_experiment_name(experiment_name)

    runs = mlflow.search_runs(
        experiment_names=[experiment_name],
        filter_string=f'tags.mlflow.runName = "{run_name}"',
        run_view_type=3,  # 3 == mlflow.entities.ViewType.ALL (include deleted runs)
        output_format="list",
    )
    if not runs:
        return None

    run: Run = runs[0]
    if mlflow_utils.is_zenml_run(run):
        return cast(str, run.info.run_id)
    else:
        return None

get_step_run_metadata(self, info)

Get component- and step-specific metadata after a step ran.

Parameters:

    info (StepRunInfo, required): Info about the step that was executed.

Returns:

    Dict[str, MetadataType]: A dictionary of metadata.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_step_run_metadata(
    self, info: "StepRunInfo"
) -> Dict[str, "MetadataType"]:
    """Get component- and step-specific metadata after a step ran.

    Args:
        info: Info about the step that was executed.

    Returns:
        A dictionary of metadata.
    """
    return {
        METADATA_EXPERIMENT_TRACKER_URL: Uri(self.get_tracking_uri()),
        "mlflow_run_id": mlflow.active_run().info.run_id,
        "mlflow_experiment_id": mlflow.active_run().info.experiment_id,
    }

get_tracking_uri(self)

Returns the configured tracking URI or a local fallback.

Returns:

    str: The tracking URI.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_tracking_uri(self) -> str:
    """Returns the configured tracking URI or a local fallback.

    Returns:
        The tracking URI.
    """
    return self.config.tracking_uri or self._local_mlflow_backend()

prepare_step_run(self, info)

Sets the MLflow tracking uri and credentials.

Parameters:

    info (StepRunInfo, required): Info about the step that will be executed.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def prepare_step_run(self, info: "StepRunInfo") -> None:
    """Sets the MLflow tracking uri and credentials.

    Args:
        info: Info about the step that will be executed.
    """
    self.configure_mlflow()
    settings = cast(
        MLFlowExperimentTrackerSettings,
        self.get_settings(info),
    )

    experiment_name = settings.experiment_name or info.pipeline.name
    experiment = self._set_active_experiment(experiment_name)
    run_id = self.get_run_id(
        experiment_name=experiment_name, run_name=info.run_name
    )

    tags = settings.tags.copy()
    tags.update(self._get_internal_tags())

    mlflow.start_run(
        run_id=run_id,
        run_name=info.run_name,
        experiment_id=experiment.experiment_id,
        tags=tags,
    )

    if settings.nested:
        mlflow.start_run(run_name=info.config.name, nested=True, tags=tags)
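
As a usage sketch: ZenML calls prepare_step_run before the step body executes, so mlflow logging calls inside a step that references this experiment tracker land in the managed run. The component name "mlflow_tracker" below is a hypothetical registered tracker:

import mlflow
from zenml.steps import step

@step(experiment_tracker="mlflow_tracker")
def train() -> float:
    accuracy = 0.92  # placeholder metric for illustration
    # Recorded in the MLflow run that ZenML started for this step.
    mlflow.log_metric("accuracy", accuracy)
    return accuracy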

flavors special

MLFlow integration flavors.

mlflow_experiment_tracker_flavor

MLflow experiment tracker flavor.

MLFlowExperimentTrackerConfig (BaseExperimentTrackerConfig, MLFlowExperimentTrackerSettings) pydantic-model

Config for the MLflow experiment tracker.

Attributes:

    tracking_uri (Optional[str]): The uri of the mlflow tracking server. If no uri is set, your stack must contain a LocalArtifactStore and ZenML will point MLflow to a subdirectory of your artifact store instead.
    tracking_username (Optional[str]): Username for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.
    tracking_password (Optional[str]): Password for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.
    tracking_token (Optional[str]): Token for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.
    tracking_insecure_tls (bool): Skips verification of TLS connection to the MLflow tracking server if set to True.
    databricks_host (Optional[str]): The host of the Databricks workspace with the MLflow managed server to connect to. This is only required if the tracking_uri value is set to "databricks".

Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
class MLFlowExperimentTrackerConfig(  # type: ignore[misc] # https://github.com/pydantic/pydantic/issues/4173
    BaseExperimentTrackerConfig, MLFlowExperimentTrackerSettings
):
    """Config for the MLflow experiment tracker.

    Attributes:
        tracking_uri: The uri of the mlflow tracking server. If no uri is set,
            your stack must contain a `LocalArtifactStore` and ZenML will
            point MLflow to a subdirectory of your artifact store instead.
        tracking_username: Username for authenticating with the MLflow
            tracking server. When a remote tracking uri is specified,
            either `tracking_token` or `tracking_username` and
            `tracking_password` must be specified.
        tracking_password: Password for authenticating with the MLflow
            tracking server. When a remote tracking uri is specified,
            either `tracking_token` or `tracking_username` and
            `tracking_password` must be specified.
        tracking_token: Token for authenticating with the MLflow
            tracking server. When a remote tracking uri is specified,
            either `tracking_token` or `tracking_username` and
            `tracking_password` must be specified.
        tracking_insecure_tls: Skips verification of TLS connection to the
            MLflow tracking server if set to `True`.
        databricks_host: The host of the Databricks workspace with the MLflow
            managed server to connect to. This is only required if
            `tracking_uri` value is set to `"databricks"`.
    """

    tracking_uri: Optional[str] = None
    tracking_username: Optional[str] = SecretField()
    tracking_password: Optional[str] = SecretField()
    tracking_token: Optional[str] = SecretField()
    tracking_insecure_tls: bool = False
    databricks_host: Optional[str] = None

    @root_validator(skip_on_failure=True)
    def _ensure_authentication_if_necessary(
        cls, values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Ensures that credentials or a token for authentication exist.

        We make this check when running MLflow tracking with a remote backend.

        Args:
            values: The values to validate.

        Returns:
            The validated values.

        Raises:
            ValueError: If neither credentials nor a token are provided.
        """
        tracking_uri = values.get("tracking_uri")

        if tracking_uri:
            if is_databricks_tracking_uri(tracking_uri):
                # If the tracking uri is "databricks", then we need the databricks
                # host to be set.
                databricks_host = values.get("databricks_host")

                if not databricks_host:
                    raise ValueError(
                        "MLflow experiment tracking with a Databricks MLflow "
                        "managed tracking server requires the `databricks_host` "
                        "to be set in your stack component. To update your "
                        "component, run `zenml experiment-tracker update "
                        "<NAME> --databricks_host=DATABRICKS_HOST` "
                        "and specify the hostname of your Databricks workspace."
                    )

            if is_remote_mlflow_tracking_uri(tracking_uri):
                # we need either username + password or a token to authenticate to
                # the remote backend
                basic_auth = values.get("tracking_username") and values.get(
                    "tracking_password"
                )
                token_auth = values.get("tracking_token")

                if not (basic_auth or token_auth):
                    raise ValueError(
                        f"MLflow experiment tracking with a remote backend "
                        f"{tracking_uri} is only possible when specifying either "
                        f"username and password or an authentication token in your "
                        f"stack component. To update your component, run the "
                        f"following command: `zenml experiment-tracker update "
                        f"<NAME> --tracking_username=MY_USERNAME "
                        f"--tracking_password=MY_PASSWORD "
                        f"--tracking_token=MY_TOKEN` and specify either your "
                        f"username and password or token."
                    )

        return values

    @property
    def is_local(self) -> bool:
        """Checks if this stack component is running locally.

        This designation is used to determine if the stack component can be
        shared with other users or if it is only usable on the local host.

        Returns:
            True if this config is for a local component, False otherwise.
        """
        if not self.tracking_uri or not is_remote_mlflow_tracking_uri(
            self.tracking_uri
        ):
            return True
        return False

is_local: bool property readonly

Checks if this stack component is running locally.

This designation is used to determine if the stack component can be shared with other users or if it is only usable on the local host.

Returns:

    bool: True if this config is for a local component, False otherwise.

MLFlowExperimentTrackerFlavor (BaseExperimentTrackerFlavor)

Class for the MLFlowExperimentTrackerFlavor.

Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
class MLFlowExperimentTrackerFlavor(BaseExperimentTrackerFlavor):
    """Class for the `MLFlowExperimentTrackerFlavor`."""

    @property
    def name(self) -> str:
        """Name of the flavor.

        Returns:
            The name of the flavor.
        """
        return MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR

    @property
    def docs_url(self) -> Optional[str]:
        """A url to point at docs explaining this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A url to point at SDK docs explaining this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """A url to represent the flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/experiment_tracker/mlflow.png"

    @property
    def config_class(self) -> Type[MLFlowExperimentTrackerConfig]:
        """Returns `MLFlowExperimentTrackerConfig` config class.

        Returns:
                The config class.
        """
        return MLFlowExperimentTrackerConfig

    @property
    def implementation_class(self) -> Type["MLFlowExperimentTracker"]:
        """Implementation class for this flavor.

        Returns:
            The implementation class.
        """
        from zenml.integrations.mlflow.experiment_trackers import (
            MLFlowExperimentTracker,
        )

        return MLFlowExperimentTracker

config_class: Type[zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor.MLFlowExperimentTrackerConfig] property readonly

Returns MLFlowExperimentTrackerConfig config class.

Returns:

    Type[zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor.MLFlowExperimentTrackerConfig]: The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

    Optional[str]: A flavor docs url.

implementation_class: Type[MLFlowExperimentTracker] property readonly

Implementation class for this flavor.

Returns:

    Type[MLFlowExperimentTracker]: The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

    str: The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

    str: The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

    Optional[str]: A flavor SDK docs url.

MLFlowExperimentTrackerSettings (BaseSettings) pydantic-model

Settings for the MLflow experiment tracker.

Attributes:

    experiment_name (Optional[str]): The MLflow experiment name.
    nested (bool): If True, will create a nested sub-run for the step.
    tags (Dict[str, Any]): Tags for the Mlflow run.

Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
class MLFlowExperimentTrackerSettings(BaseSettings):
    """Settings for the MLflow experiment tracker.

    Attributes:
        experiment_name: The MLflow experiment name.
        nested: If `True`, will create a nested sub-run for the step.
        tags: Tags for the Mlflow run.
    """

    experiment_name: Optional[str] = None
    nested: bool = False
    tags: Dict[str, Any] = {}
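
A sketch of overriding these settings on a single step; the settings key "experiment_tracker.mlflow" follows ZenML's "<component_type>.<flavor>" convention, and "mlflow_tracker" is a hypothetical registered component name:

from zenml.steps import step

@step(
    experiment_tracker="mlflow_tracker",
    settings={
        "experiment_tracker.mlflow": {
            "experiment_name": "my_experiment",
            "nested": True,
            "tags": {"team": "research"},
        }
    },
)
def evaluate() -> None:
    ...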

is_databricks_tracking_uri(tracking_uri)

Checks whether the given tracking uri is a Databricks tracking uri.

Parameters:

    tracking_uri (str, required): The tracking uri to check.

Returns:

    bool: True if the tracking uri is a Databricks tracking uri, False otherwise.

Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
def is_databricks_tracking_uri(tracking_uri: str) -> bool:
    """Checks whether the given tracking uri is a Databricks tracking uri.

    Args:
        tracking_uri: The tracking uri to check.

    Returns:
        `True` if the tracking uri is a Databricks tracking uri, `False`
        otherwise.
    """
    return tracking_uri == "databricks"

is_remote_mlflow_tracking_uri(tracking_uri)

Checks whether the given tracking uri is remote or not.

Parameters:

    tracking_uri (str, required): The tracking uri to check.

Returns:

    bool: True if the tracking uri is remote, False otherwise.

Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
def is_remote_mlflow_tracking_uri(tracking_uri: str) -> bool:
    """Checks whether the given tracking uri is remote or not.

    Args:
        tracking_uri: The tracking uri to check.

    Returns:
        `True` if the tracking uri is remote, `False` otherwise.
    """
    return any(
        tracking_uri.startswith(prefix) for prefix in ["http://", "https://"]
    ) or is_databricks_tracking_uri(tracking_uri)
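
A few illustrative checks, consistent with the two helpers above:

from zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor import (
    is_databricks_tracking_uri,
    is_remote_mlflow_tracking_uri,
)

assert is_databricks_tracking_uri("databricks")
assert is_remote_mlflow_tracking_uri("https://mlflow.example.com")
# Databricks counts as remote, while a local file store does not.
assert is_remote_mlflow_tracking_uri("databricks")
assert not is_remote_mlflow_tracking_uri("file:/tmp/mlruns")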

mlflow_model_deployer_flavor

MLflow model deployer flavor.

MLFlowModelDeployerConfig (BaseModelDeployerConfig) pydantic-model

Configuration for the MLflow model deployer.

Attributes:

    service_path (str): The path where the local MLflow deployment service configuration, PID and log files are stored.

Source code in zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py
class MLFlowModelDeployerConfig(BaseModelDeployerConfig):
    """Configuration for the MLflow model deployer.

    Attributes:
        service_path: the path where the local MLflow deployment service
            configuration, PID and log files are stored.
    """

    service_path: str = ""

    @property
    def is_local(self) -> bool:
        """Checks if this stack component is running locally.

        This designation is used to determine if the stack component can be
        shared with other users or if it is only usable on the local host.

        Returns:
            True if this config is for a local component, False otherwise.
        """
        return True

is_local: bool property readonly

Checks if this stack component is running locally.

This designation is used to determine if the stack component can be shared with other users or if it is only usable on the local host.

Returns:

    bool: True if this config is for a local component, False otherwise.

MLFlowModelDeployerFlavor (BaseModelDeployerFlavor)

Model deployer flavor for MLflow models.

Source code in zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py
class MLFlowModelDeployerFlavor(BaseModelDeployerFlavor):
    """Model deployer flavor for MLflow models."""

    @property
    def name(self) -> str:
        """Name of the flavor.

        Returns:
            The name of the flavor.
        """
        return MLFLOW_MODEL_DEPLOYER_FLAVOR

    @property
    def docs_url(self) -> Optional[str]:
        """A url to point at docs explaining this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A url to point at SDK docs explaining this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """A url to represent the flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png"

    @property
    def config_class(self) -> Type[MLFlowModelDeployerConfig]:
        """Returns `MLFlowModelDeployerConfig` config class.

        Returns:
                The config class.
        """
        return MLFlowModelDeployerConfig

    @property
    def implementation_class(self) -> Type["MLFlowModelDeployer"]:
        """Implementation class for this flavor.

        Returns:
            The implementation class.
        """
        from zenml.integrations.mlflow.model_deployers import (
            MLFlowModelDeployer,
        )

        return MLFlowModelDeployer

config_class: Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig] property readonly

Returns MLFlowModelDeployerConfig config class.

Returns:

    Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig]: The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

    Optional[str]: A flavor docs url.

implementation_class: Type[MLFlowModelDeployer] property readonly

Implementation class for this flavor.

Returns:

    Type[MLFlowModelDeployer]: The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

    str: The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

    str: The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

    Optional[str]: A flavor SDK docs url.

mlflow_model_registry_flavor

MLflow model registry flavor.

MLFlowModelRegistryConfig (BaseModelRegistryConfig) pydantic-model

Configuration for the MLflow model registry.

Source code in zenml/integrations/mlflow/flavors/mlflow_model_registry_flavor.py
class MLFlowModelRegistryConfig(BaseModelRegistryConfig):
    """Configuration for the MLflow model registry."""

MLFlowModelRegistryFlavor (BaseModelRegistryFlavor)

Model registry flavor for MLflow models.

Source code in zenml/integrations/mlflow/flavors/mlflow_model_registry_flavor.py
class MLFlowModelRegistryFlavor(BaseModelRegistryFlavor):
    """Model registry flavor for MLflow models."""

    @property
    def name(self) -> str:
        """Name of the flavor.

        Returns:
            The name of the flavor.
        """
        return MLFLOW_MODEL_REGISTRY_FLAVOR

    @property
    def docs_url(self) -> Optional[str]:
        """A url to point at docs explaining this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A url to point at SDK docs explaining this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """A url to represent the flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png"

    @property
    def config_class(self) -> Type[MLFlowModelRegistryConfig]:
        """Returns `MLFlowModelRegistryConfig` config class.

        Returns:
                The config class.
        """
        return MLFlowModelRegistryConfig

    @property
    def implementation_class(self) -> Type["MLFlowModelRegistry"]:
        """Implementation class for this flavor.

        Returns:
            The implementation class.
        """
        from zenml.integrations.mlflow.model_registries import (
            MLFlowModelRegistry,
        )

        return MLFlowModelRegistry

config_class: Type[zenml.integrations.mlflow.flavors.mlflow_model_registry_flavor.MLFlowModelRegistryConfig] property readonly

Returns MLFlowModelRegistryConfig config class.

Returns:

    Type[zenml.integrations.mlflow.flavors.mlflow_model_registry_flavor.MLFlowModelRegistryConfig]: The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

    Optional[str]: A flavor docs url.

implementation_class: Type[MLFlowModelRegistry] property readonly

Implementation class for this flavor.

Returns:

    Type[MLFlowModelRegistry]: The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

    str: The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

    str: The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

    Optional[str]: A flavor SDK docs url.

mlflow_utils

Implementation of utils specific to the MLflow integration.

get_missing_mlflow_experiment_tracker_error()

Returns description of how to add an MLflow experiment tracker to your stack.

Returns:

    ValueError: If no MLflow experiment tracker is registered in the active stack.

Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_missing_mlflow_experiment_tracker_error() -> ValueError:
    """Returns description of how to add an MLflow experiment tracker to your stack.

    Returns:
        ValueError: If no MLflow experiment tracker is registered in the active stack.
    """
    return ValueError(
        "The active stack needs to have a MLflow experiment tracker "
        "component registered to be able to track experiments using "
        "MLflow. You can create a new stack with a MLflow experiment "
        "tracker component or update your existing stack to add this "
        "component, e.g.:\n\n"
        "  'zenml experiment-tracker register mlflow_tracker "
        "--type=mlflow'\n"
        "  'zenml stack register stack-name -e mlflow_tracker ...'\n"
    )

get_tracking_uri()

Gets the MLflow tracking URI from the active experiment tracking stack component.

Returns:

    str: MLflow tracking URI.

Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_tracking_uri() -> str:
    """Gets the MLflow tracking URI from the active experiment tracking stack component.

    # noqa: DAR401

    Returns:
        MLflow tracking URI.
    """
    from zenml.integrations.mlflow.experiment_trackers.mlflow_experiment_tracker import (
        MLFlowExperimentTracker,
    )

    tracker = Client().active_stack.experiment_tracker
    if tracker is None or not isinstance(tracker, MLFlowExperimentTracker):
        raise get_missing_mlflow_experiment_tracker_error()

    return tracker.get_tracking_uri()
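
A minimal sketch, assuming the active stack has an MLflow experiment tracker: point a plain MLflow client at the same backend that ZenML logs to. Note that mlflow.search_experiments is only available in the newer MLflow releases within the pinned range:

import mlflow
from zenml.integrations.mlflow.mlflow_utils import get_tracking_uri

mlflow.set_tracking_uri(get_tracking_uri())

# List the experiments recorded at this tracking backend.
for experiment in mlflow.search_experiments():
    print(experiment.name)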

is_zenml_run(run)

Checks whether an MLflow run is a ZenML run.

Parameters:

    run (Run, required): The run to check.

Returns:

    bool: True if the run is a ZenML run, False otherwise.

Source code in zenml/integrations/mlflow/mlflow_utils.py
def is_zenml_run(run: Run) -> bool:
    """Checks if a MLflow run is a ZenML run or not.

    Args:
        run: The run to check.

    Returns:
        If the run is a ZenML run.
    """
    return ZENML_TAG_KEY in run.data.tags

stop_zenml_mlflow_runs(status)

Stops active ZenML Mlflow runs.

This function stops all MLflow active runs until no active run exists or a non-ZenML run is active.

Parameters:

    status (str, required): The status to set the run to.

Source code in zenml/integrations/mlflow/mlflow_utils.py
def stop_zenml_mlflow_runs(status: str) -> None:
    """Stops active ZenML Mlflow runs.

    This function stops all MLflow active runs until no active run exists or
    a non-ZenML run is active.

    Args:
        status: The status to set the run to.
    """
    active_run = mlflow.active_run()
    while active_run:
        if is_zenml_run(active_run):
            logger.debug("Stopping mlflow run %s.", active_run.info.run_id)
            mlflow.end_run(status=status)
            active_run = mlflow.active_run()
        else:
            break

model_deployers special

Initialization of the MLflow model deployers.

mlflow_model_deployer

Implementation of the MLflow model deployer.

MLFlowModelDeployer (BaseModelDeployer)

MLflow implementation of the BaseModelDeployer.
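
A hedged usage sketch: fetch the deployer from the active stack and inspect running model servers (get_active_model_deployer is the generic helper on BaseModelDeployer):

from zenml.integrations.mlflow.model_deployers import MLFlowModelDeployer

model_deployer = MLFlowModelDeployer.get_active_model_deployer()
for service in model_deployer.find_model_server(running=True):
    # Returns PREDICTION_URL, MODEL_URI, SERVICE_PATH, etc. for each server.
    print(MLFlowModelDeployer.get_model_server_info(service))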

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
class MLFlowModelDeployer(BaseModelDeployer):
    """MLflow implementation of the BaseModelDeployer."""

    NAME: ClassVar[str] = "MLflow"
    FLAVOR: ClassVar[Type[BaseModelDeployerFlavor]] = MLFlowModelDeployerFlavor

    _service_path: Optional[str] = None

    @property
    def config(self) -> MLFlowModelDeployerConfig:
        """Returns the `MLFlowModelDeployerConfig` config.

        Returns:
            The configuration.
        """
        return cast(MLFlowModelDeployerConfig, self._config)

    @staticmethod
    def get_service_path(id_: UUID) -> str:
        """Get the path where local MLflow service information is stored.

        This is where the deployment service configuration, PID and log files
        are stored.

        Args:
            id_: The ID of the MLflow model deployer.

        Returns:
            The service path.
        """
        service_path = os.path.join(
            GlobalConfiguration().local_stores_path,
            str(id_),
        )
        create_dir_recursive_if_not_exists(service_path)
        return service_path

    @property
    def local_path(self) -> str:
        """Returns the path to the root directory.

        This is where all configurations for MLflow deployment daemon processes
        are stored.

        If the service path is not set in the config by the user, the path is
        set to a local default path according to the component ID.

        Returns:
            The path to the local service root directory.
        """
        if self._service_path is not None:
            return self._service_path

        if self.config.service_path:
            self._service_path = self.config.service_path
        else:
            self._service_path = self.get_service_path(self.id)

        create_dir_recursive_if_not_exists(self._service_path)
        return self._service_path

    @staticmethod
    def get_model_server_info(  # type: ignore[override]
        service_instance: "MLFlowDeploymentService",
    ) -> Dict[str, Optional[str]]:
        """Return implementation specific information relevant to the user.

        Args:
            service_instance: Instance of an MLFlowDeploymentService

        Returns:
            A dictionary containing the information.
        """
        return {
            "PREDICTION_URL": service_instance.endpoint.prediction_url,
            "MODEL_URI": service_instance.config.model_uri,
            "MODEL_NAME": service_instance.config.model_name,
            "REGISTRY_MODEL_NAME": service_instance.config.registry_model_name,
            "REGISTRY_MODEL_VERSION": service_instance.config.registry_model_version,
            "SERVICE_PATH": service_instance.status.runtime_path,
            "DAEMON_PID": str(service_instance.status.pid),
        }

    def deploy_model(
        self,
        config: ServiceConfig,
        replace: bool = False,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    ) -> BaseService:
        """Create a new MLflow deployment service or update an existing one.

        This should serve the supplied model and deployment configuration.

        This method has two modes of operation, depending on the `replace`
        argument value:

          * if `replace` is False, calling this method will create a new MLflow
            deployment server to reflect the model and other configuration
            parameters specified in the supplied MLflow service `config`.

          * if `replace` is True, this method will first attempt to find an
            existing MLflow deployment service that is *equivalent* to the
            supplied configuration parameters. Two or more MLflow deployment
            services are considered equivalent if they have the same
            `pipeline_name`, `pipeline_step_name` and `model_name` configuration
            parameters. To put it differently, two MLflow deployment services
            are equivalent if they serve versions of the same model deployed by
            the same pipeline step. If an equivalent MLflow deployment is found,
            it will be updated in place to reflect the new configuration
            parameters.

        Callers should set `replace` to True if they want a continuous model
        deployment workflow that doesn't spin up a new MLflow deployment
        server for each new model version. If multiple equivalent MLflow
        deployment servers are found, one is selected at random to be updated
        and the others are deleted.

        Args:
            config: the configuration of the model to be deployed with MLflow.
            replace: set this flag to True to find and update an equivalent
                MLflow deployment server with the new model instead of
                creating and starting a new deployment server.
            timeout: the timeout in seconds to wait for the MLflow server
                to be provisioned and successfully started or updated. If set
                to 0, the method will return immediately after the MLflow
                server is provisioned, without waiting for it to fully start.

        Returns:
            The ZenML MLflow deployment service object that can be used to
            interact with the MLflow model server.
        """
        config = cast(MLFlowDeploymentConfig, config)
        service = None

        # if replace is True, remove all existing services
        if replace is True:
            existing_services = self.find_model_server(
                pipeline_name=config.pipeline_name,
                pipeline_step_name=config.pipeline_step_name,
                model_name=config.model_name,
            )

            for existing_service in existing_services:
                if service is None:
                    # keep the most recently created service
                    service = cast(MLFlowDeploymentService, existing_service)
                try:
                    # delete the older services and don't wait for them to
                    # be deprovisioned
                    self._clean_up_existing_service(
                        existing_service=cast(
                            MLFlowDeploymentService, existing_service
                        ),
                        timeout=timeout,
                        force=True,
                    )
                except RuntimeError:
                    # ignore errors encountered while stopping old services
                    pass
        if service:
            logger.info(
                f"Updating an existing MLflow deployment service: {service}"
            )

            # set the root runtime path with the stack component's UUID
            config.root_runtime_path = self.local_path
            service.stop(timeout=timeout, force=True)
            service.update(config)
            service.start(timeout=timeout)
        else:
            # create a new MLFlowDeploymentService instance
            service = self._create_new_service(timeout, config)
            logger.info(f"Created a new MLflow deployment service: {service}")

        return cast(BaseService, service)

    def _clean_up_existing_service(
        self,
        timeout: int,
        force: bool,
        existing_service: MLFlowDeploymentService,
    ) -> None:
        # stop the older service
        existing_service.stop(timeout=timeout, force=force)

        # delete the old configuration file
        if existing_service.status.runtime_path:
            shutil.rmtree(existing_service.status.runtime_path)

    # The step will receive a config from the user that mentions the number
    # of workers etc. The step implementation will create a new config using
    # all values from the user and add values like pipeline name and model_uri.
    def _create_new_service(
        self, timeout: int, config: MLFlowDeploymentConfig
    ) -> MLFlowDeploymentService:
        """Creates a new MLFlowDeploymentService.

        Args:
            timeout: the timeout in seconds to wait for the MLflow server
                to be provisioned and successfully started or updated.
            config: the configuration of the model to be deployed with MLflow.

        Returns:
            The MLFlowDeploymentService object that can be used to interact
            with the MLflow model server.
        """
        # set the root runtime path with the stack component's UUID
        config.root_runtime_path = self.local_path
        # create a new service for the new model
        service = MLFlowDeploymentService(config)
        service.start(timeout=timeout)

        return service

    def find_model_server(
        self,
        running: bool = False,
        service_uuid: Optional[UUID] = None,
        pipeline_name: Optional[str] = None,
        pipeline_run_id: Optional[str] = None,
        pipeline_step_name: Optional[str] = None,
        model_name: Optional[str] = None,
        model_uri: Optional[str] = None,
        model_type: Optional[str] = None,
        registry_model_name: Optional[str] = None,
        registry_model_version: Optional[str] = None,
    ) -> List[BaseService]:
        """Finds one or more model servers that match the given criteria.

        Args:
            running: If true, only running services will be returned.
            service_uuid: The UUID of the service that was originally used
                to deploy the model.
            pipeline_name: Name of the pipeline that the deployed model was part
                of.
            pipeline_run_id: ID of the pipeline run which the deployed model
                was part of.
            pipeline_step_name: The name of the pipeline model deployment step
                that deployed the model.
            model_name: Name of the deployed model.
            model_uri: URI of the deployed model.
            model_type: Type/format of the deployed model. Not used in this
                MLflow case.
            registry_model_name: Name of the registered model that the
                deployed model belongs to.
            registry_model_version: Version of the registered model that
                the deployed model belongs to.

        Returns:
            One or more Service objects representing model servers that match
            the input search criteria.

        Raises:
            TypeError: if any of the input arguments are of an invalid type.
        """
        services = []
        config = MLFlowDeploymentConfig(
            model_name=model_name or "",
            model_uri=model_uri or "",
            pipeline_name=pipeline_name or "",
            pipeline_run_id=pipeline_run_id or "",
            pipeline_step_name=pipeline_step_name or "",
            registry_model_name=registry_model_name,
            registry_model_version=registry_model_version,
        )

        # find all services that match the input criteria
        for root, _, files in os.walk(self.local_path):
            if service_uuid and Path(root).name != str(service_uuid):
                continue
            for file in files:
                if file == SERVICE_DAEMON_CONFIG_FILE_NAME:
                    service_config_path = os.path.join(root, file)
                    logger.debug(
                        "Loading service daemon configuration from %s",
                        service_config_path,
                    )
                    existing_service_config = None
                    with open(service_config_path, "r") as f:
                        existing_service_config = f.read()
                    existing_service = (
                        ServiceRegistry().load_service_from_json(
                            existing_service_config
                        )
                    )
                    if not isinstance(
                        existing_service, MLFlowDeploymentService
                    ):
                        raise TypeError(
                            f"Expected service type MLFlowDeploymentService but got "
                            f"{type(existing_service)} instead"
                        )
                    existing_service.update_status()
                    if self._matches_search_criteria(existing_service, config):
                        if not running or existing_service.is_running:
                            services.append(
                                cast(BaseService, existing_service)
                            )

        return services

    def _matches_search_criteria(
        self,
        existing_service: MLFlowDeploymentService,
        config: MLFlowDeploymentConfig,
    ) -> bool:
        """Returns true if a service matches the input criteria.

        If any of the values in the input criteria are None, they are ignored.
        This allows listing services just by common pipeline names or step
        names, etc.

        Args:
            existing_service: The materialized Service instance derived from
                the config of the older (existing) service
            config: The MLFlowDeploymentConfig object passed to the
                deploy_model function holding parameters of the new service
                to be created.

        Returns:
            True if the service matches the input criteria.
        """
        existing_service_config = existing_service.config
        # check if the existing service matches the input criteria
        if (
            (
                not config.pipeline_name
                or existing_service_config.pipeline_name
                == config.pipeline_name
            )
            and (
                not config.model_name
                or existing_service_config.model_name == config.model_name
            )
            and (
                not config.pipeline_step_name
                or existing_service_config.pipeline_step_name
                == config.pipeline_step_name
            )
            and (
                not config.pipeline_run_id
                or existing_service_config.pipeline_run_id
                == config.pipeline_run_id
            )
            and (
                (
                    not config.registry_model_name
                    and not config.registry_model_version
                )
                or (
                    existing_service_config.registry_model_name
                    == config.registry_model_name
                    and existing_service_config.registry_model_version
                    == config.registry_model_version
                )
            )
        ):
            return True

        return False

    def stop_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Method to stop a model server.

        Args:
            uuid: UUID of the model server to stop.
            timeout: Timeout in seconds to wait for the service to stop.
            force: If True, force the service to stop.
        """
        # get list of all services
        existing_services = self.find_model_server(service_uuid=uuid)

        # if the service exists, stop it
        if existing_services:
            existing_services[0].stop(timeout=timeout, force=force)

    def start_model_server(
        self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
    ) -> None:
        """Method to start a model server.

        Args:
            uuid: UUID of the model server to start.
            timeout: Timeout in seconds to wait for the service to start.
        """
        # get list of all services
        existing_services = self.find_model_server(service_uuid=uuid)

        # if the service exists, start it
        if existing_services:
            existing_services[0].start(timeout=timeout)

    def delete_model_server(
        self,
        uuid: UUID,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Method to delete all configuration of a model server.

        Args:
            uuid: UUID of the model server to delete.
            timeout: Timeout in seconds to wait for the service to stop.
            force: If True, force the service to stop.
        """
        # get list of all services
        existing_services = self.find_model_server(service_uuid=uuid)

        # if the service exists, clean it up
        if existing_services:
            service = cast(MLFlowDeploymentService, existing_services[0])
            self._clean_up_existing_service(
                existing_service=service, timeout=timeout, force=force
            )
config: MLFlowModelDeployerConfig property readonly

Returns the MLFlowModelDeployerConfig config.

Returns:

Type Description
MLFlowModelDeployerConfig

The configuration.

local_path: str property readonly

Returns the path to the root directory.

This is where all configurations for MLflow deployment daemon processes are stored.

If the service path is not set in the config by the user, the path is set to a local default path according to the component ID.

Returns:

Type Description
str

The path to the local service root directory.

FLAVOR (BaseModelDeployerFlavor)

Model deployer flavor for MLflow models.

Source code in zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py
class MLFlowModelDeployerFlavor(BaseModelDeployerFlavor):
    """Model deployer flavor for MLflow models."""

    @property
    def name(self) -> str:
        """Name of the flavor.

        Returns:
            The name of the flavor.
        """
        return MLFLOW_MODEL_DEPLOYER_FLAVOR

    @property
    def docs_url(self) -> Optional[str]:
        """A url to point at docs explaining this flavor.

        Returns:
            A flavor docs url.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """A url to point at SDK docs explaining this flavor.

        Returns:
            A flavor SDK docs url.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """A url to represent the flavor in the dashboard.

        Returns:
            The flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png"

    @property
    def config_class(self) -> Type[MLFlowModelDeployerConfig]:
        """Returns `MLFlowModelDeployerConfig` config class.

        Returns:
            The config class.
        """
        return MLFlowModelDeployerConfig

    @property
    def implementation_class(self) -> Type["MLFlowModelDeployer"]:
        """Implementation class for this flavor.

        Returns:
            The implementation class.
        """
        from zenml.integrations.mlflow.model_deployers import (
            MLFlowModelDeployer,
        )

        return MLFlowModelDeployer
config_class: Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig] property readonly

Returns MLFlowModelDeployerConfig config class.

Returns:

Type Description
Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[MLFlowModelDeployer] property readonly

Implementation class for this flavor.

Returns:

Type Description
Type[MLFlowModelDeployer]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

delete_model_server(self, uuid, timeout=60, force=False)

Method to delete all configuration of a model server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to delete.

required
timeout int

Timeout in seconds to wait for the service to stop.

60
force bool

If True, force the service to stop.

False
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def delete_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Method to delete all configuration of a model server.

    Args:
        uuid: UUID of the model server to delete.
        timeout: Timeout in seconds to wait for the service to stop.
        force: If True, force the service to stop.
    """
    # get list of all services
    existing_services = self.find_model_server(service_uuid=uuid)

    # if the service exists, clean it up
    if existing_services:
        service = cast(MLFlowDeploymentService, existing_services[0])
        self._clean_up_existing_service(
            existing_service=service, timeout=timeout, force=force
        )
deploy_model(self, config, replace=False, timeout=60)

Create a new MLflow deployment service or update an existing one.

This should serve the supplied model and deployment configuration.

This method has two modes of operation, depending on the replace argument value:

  • if replace is False, calling this method will create a new MLflow deployment server to reflect the model and other configuration parameters specified in the supplied MLflow service config.

  • if replace is True, this method will first attempt to find an existing MLflow deployment service that is equivalent to the supplied configuration parameters. Two or more MLflow deployment services are considered equivalent if they have the same pipeline_name, pipeline_step_name and model_name configuration parameters. To put it differently, two MLflow deployment services are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent MLflow deployment is found, it will be updated in place to reflect the new configuration parameters.

Callers should set replace to True if they want a continuous model deployment workflow that doesn't spin up a new MLflow deployment server for each new model version. If multiple equivalent MLflow deployment servers are found, one is selected at random to be updated and the others are deleted.

Parameters:

Name Type Description Default
config ServiceConfig

the configuration of the model to be deployed with MLflow.

required
replace bool

set this flag to True to find and update an equivalent MLflow deployment server with the new model instead of creating and starting a new deployment server.

False
timeout int

the timeout in seconds to wait for the MLflow server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the MLflow server is provisioned, without waiting for it to fully start.

60

Returns:

Type Description
BaseService

The ZenML MLflow deployment service object that can be used to interact with the MLflow model server.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def deploy_model(
    self,
    config: ServiceConfig,
    replace: bool = False,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
    """Create a new MLflow deployment service or update an existing one.

    This should serve the supplied model and deployment configuration.

    This method has two modes of operation, depending on the `replace`
    argument value:

      * if `replace` is False, calling this method will create a new MLflow
        deployment server to reflect the model and other configuration
        parameters specified in the supplied MLflow service `config`.

      * if `replace` is True, this method will first attempt to find an
        existing MLflow deployment service that is *equivalent* to the
        supplied configuration parameters. Two or more MLflow deployment
        services are considered equivalent if they have the same
        `pipeline_name`, `pipeline_step_name` and `model_name` configuration
        parameters. To put it differently, two MLflow deployment services
        are equivalent if they serve versions of the same model deployed by
        the same pipeline step. If an equivalent MLflow deployment is found,
        it will be updated in place to reflect the new configuration
        parameters.

    Callers should set `replace` to True if they want a continuous model
    deployment workflow that doesn't spin up a new MLflow deployment
    server for each new model version. If multiple equivalent MLflow
    deployment servers are found, one is selected at random to be updated
    and the others are deleted.

    Args:
        config: the configuration of the model to be deployed with MLflow.
        replace: set this flag to True to find and update an equivalent
            MLflow deployment server with the new model instead of
            creating and starting a new deployment server.
        timeout: the timeout in seconds to wait for the MLflow server
            to be provisioned and successfully started or updated. If set
            to 0, the method will return immediately after the MLflow
            server is provisioned, without waiting for it to fully start.

    Returns:
        The ZenML MLflow deployment service object that can be used to
        interact with the MLflow model server.
    """
    config = cast(MLFlowDeploymentConfig, config)
    service = None

    # if replace is True, remove all existing services
    if replace is True:
        existing_services = self.find_model_server(
            pipeline_name=config.pipeline_name,
            pipeline_step_name=config.pipeline_step_name,
            model_name=config.model_name,
        )

        for existing_service in existing_services:
            if service is None:
                # keep the most recently created service
                service = cast(MLFlowDeploymentService, existing_service)
            try:
                # delete the older services and don't wait for them to
                # be deprovisioned
                self._clean_up_existing_service(
                    existing_service=cast(
                        MLFlowDeploymentService, existing_service
                    ),
                    timeout=timeout,
                    force=True,
                )
            except RuntimeError:
                # ignore errors encountered while stopping old services
                pass
    if service:
        logger.info(
            f"Updating an existing MLflow deployment service: {service}"
        )

        # set the root runtime path with the stack component's UUID
        config.root_runtime_path = self.local_path
        service.stop(timeout=timeout, force=True)
        service.update(config)
        service.start(timeout=timeout)
    else:
        # create a new MLFlowDeploymentService instance
        service = self._create_new_service(timeout, config)
        logger.info(f"Created a new MLflow deployment service: {service}")

    return cast(BaseService, service)
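
A minimal usage sketch, not part of the module source: deploying a model with replace=True so that successive pipeline runs update the same server instead of spawning a new one per model version. The stack access via Client().active_stack and all concrete names are illustrative assumptions; the config fields are the ones documented above.

from zenml.client import Client
from zenml.integrations.mlflow.services import MLFlowDeploymentConfig

# Fetch the MLflow model deployer registered in the active stack.
model_deployer = Client().active_stack.model_deployer

config = MLFlowDeploymentConfig(
    model_name="my-classifier",            # placeholder values throughout
    model_uri="runs:/<run_id>/model",
    pipeline_name="training_pipeline",
    pipeline_step_name="model_deployer_step",
)

# With replace=True, an equivalent server (same pipeline_name,
# pipeline_step_name and model_name) is updated in place rather than
# a new one being created for every model version.
service = model_deployer.deploy_model(config=config, replace=True, timeout=120)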
find_model_server(self, running=False, service_uuid=None, pipeline_name=None, pipeline_run_id=None, pipeline_step_name=None, model_name=None, model_uri=None, model_type=None, registry_model_name=None, registry_model_version=None)

Finds one or more model servers that match the given criteria.

Parameters:

Name Type Description Default
running bool

If true, only running services will be returned.

False
service_uuid Optional[uuid.UUID]

The UUID of the service that was originally used to deploy the model.

None
pipeline_name Optional[str]

Name of the pipeline that the deployed model was part of.

None
pipeline_run_id Optional[str]

ID of the pipeline run which the deployed model was part of.

None
pipeline_step_name Optional[str]

The name of the pipeline model deployment step that deployed the model.

None
model_name Optional[str]

Name of the deployed model.

None
model_uri Optional[str]

URI of the deployed model.

None
model_type Optional[str]

Type/format of the deployed model. Not used in this MLflow case.

None
registry_model_name Optional[str]

Name of the registered model that the deployed model belongs to.

None
registry_model_version Optional[str]

Version of the registered model that the deployed model belongs to.

None

Returns:

Type Description
List[zenml.services.service.BaseService]

One or more Service objects representing model servers that match the input search criteria.

Exceptions:

Type Description
TypeError

if any of the input arguments are of an invalid type.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def find_model_server(
    self,
    running: bool = False,
    service_uuid: Optional[UUID] = None,
    pipeline_name: Optional[str] = None,
    pipeline_run_id: Optional[str] = None,
    pipeline_step_name: Optional[str] = None,
    model_name: Optional[str] = None,
    model_uri: Optional[str] = None,
    model_type: Optional[str] = None,
    registry_model_name: Optional[str] = None,
    registry_model_version: Optional[str] = None,
) -> List[BaseService]:
    """Finds one or more model servers that match the given criteria.

    Args:
        running: If true, only running services will be returned.
        service_uuid: The UUID of the service that was originally used
            to deploy the model.
        pipeline_name: Name of the pipeline that the deployed model was part
            of.
        pipeline_run_id: ID of the pipeline run which the deployed model
            was part of.
        pipeline_step_name: The name of the pipeline model deployment step
            that deployed the model.
        model_name: Name of the deployed model.
        model_uri: URI of the deployed model.
        model_type: Type/format of the deployed model. Not used in this
            MLflow case.
        registry_model_name: Name of the registered model that the
            deployed model belongs to.
        registry_model_version: Version of the registered model that
            the deployed model belongs to.

    Returns:
        One or more Service objects representing model servers that match
        the input search criteria.

    Raises:
        TypeError: if any of the input arguments are of an invalid type.
    """
    services = []
    config = MLFlowDeploymentConfig(
        model_name=model_name or "",
        model_uri=model_uri or "",
        pipeline_name=pipeline_name or "",
        pipeline_run_id=pipeline_run_id or "",
        pipeline_step_name=pipeline_step_name or "",
        registry_model_name=registry_model_name,
        registry_model_version=registry_model_version,
    )

    # find all services that match the input criteria
    for root, _, files in os.walk(self.local_path):
        if service_uuid and Path(root).name != str(service_uuid):
            continue
        for file in files:
            if file == SERVICE_DAEMON_CONFIG_FILE_NAME:
                service_config_path = os.path.join(root, file)
                logger.debug(
                    "Loading service daemon configuration from %s",
                    service_config_path,
                )
                existing_service_config = None
                with open(service_config_path, "r") as f:
                    existing_service_config = f.read()
                existing_service = (
                    ServiceRegistry().load_service_from_json(
                        existing_service_config
                    )
                )
                if not isinstance(
                    existing_service, MLFlowDeploymentService
                ):
                    raise TypeError(
                        f"Expected service type MLFlowDeploymentService but got "
                        f"{type(existing_service)} instead"
                    )
                existing_service.update_status()
                if self._matches_search_criteria(existing_service, config):
                    if not running or existing_service.is_running:
                        services.append(
                            cast(BaseService, existing_service)
                        )

    return services
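
A short querying sketch under the same assumptions as the deployment sketch above: criteria left as None are ignored, so filtering by pipeline and step alone returns every server that step has deployed.

from zenml.client import Client

model_deployer = Client().active_stack.model_deployer

# Only running servers deployed by one (hypothetical) pipeline step.
services = model_deployer.find_model_server(
    running=True,
    pipeline_name="training_pipeline",
    pipeline_step_name="model_deployer_step",
)
for service in services:
    print(service.uuid, service.is_running)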
get_model_server_info(service_instance) staticmethod

Return implementation specific information relevant to the user.

Parameters:

Name Type Description Default
service_instance MLFlowDeploymentService

Instance of an MLFlowDeploymentService

required

Returns:

Type Description
Dict[str, Optional[str]]

A dictionary containing the information.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "MLFlowDeploymentService",
) -> Dict[str, Optional[str]]:
    """Return implementation specific information relevant to the user.

    Args:
        service_instance: Instance of an MLFlowDeploymentService

    Returns:
        A dictionary containing the information.
    """
    return {
        "PREDICTION_URL": service_instance.endpoint.prediction_url,
        "MODEL_URI": service_instance.config.model_uri,
        "MODEL_NAME": service_instance.config.model_name,
        "REGISTRY_MODEL_NAME": service_instance.config.registry_model_name,
        "REGISTRY_MODEL_VERSION": service_instance.config.registry_model_version,
        "SERVICE_PATH": service_instance.status.runtime_path,
        "DAEMON_PID": str(service_instance.status.pid),
    }
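
A sketch of reading back connection details for a discovered service; the dictionary keys are exactly the ones returned by the method, while the stack access is the same assumption as in the earlier sketches.

from zenml.client import Client
from zenml.integrations.mlflow.model_deployers import MLFlowModelDeployer

model_deployer = Client().active_stack.model_deployer
services = model_deployer.find_model_server(running=True)
if services:
    info = MLFlowModelDeployer.get_model_server_info(services[0])
    print(info["PREDICTION_URL"])  # where to send inference requests
    print(info["SERVICE_PATH"])    # local daemon runtime directory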
get_service_path(id_) staticmethod

Get the path where local MLflow service information is stored.

This is where the deployment service configuration, PID and log files are stored.

Parameters:

Name Type Description Default
id_ UUID

The ID of the MLflow model deployer.

required

Returns:

Type Description
str

The service path.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_service_path(id_: UUID) -> str:
    """Get the path where local MLflow service information is stored.

    This is where the deployment service configuration, PID and log files
    are stored.

    Args:
        id_: The ID of the MLflow model deployer.

    Returns:
        The service path.
    """
    service_path = os.path.join(
        GlobalConfiguration().local_stores_path,
        str(id_),
    )
    create_dir_recursive_if_not_exists(service_path)
    return service_path
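
A sketch of the resulting layout, with a placeholder UUID: each deployer keeps its daemon files in a subdirectory of the local stores path named after its component ID, and the directory is created on first access.

from uuid import UUID
from zenml.integrations.mlflow.model_deployers import MLFlowModelDeployer

deployer_id = UUID("00000000-0000-0000-0000-000000000000")  # placeholder
path = MLFlowModelDeployer.get_service_path(deployer_id)
print(path)  # <local_stores_path>/00000000-0000-0000-0000-000000000000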
start_model_server(self, uuid, timeout=60)

Method to start a model server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to start.

required
timeout int

Timeout in seconds to wait for the service to start.

60
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def start_model_server(
    self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
) -> None:
    """Method to start a model server.

    Args:
        uuid: UUID of the model server to start.
        timeout: Timeout in seconds to wait for the service to start.
    """
    # get list of all services
    existing_services = self.find_model_server(service_uuid=uuid)

    # if the service exists, start it
    if existing_services:
        existing_services[0].start(timeout=timeout)
stop_model_server(self, uuid, timeout=60, force=False)

Method to stop a model server.

Parameters:

Name Type Description Default
uuid UUID

UUID of the model server to stop.

required
timeout int

Timeout in seconds to wait for the service to stop.

60
force bool

If True, force the service to stop.

False
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def stop_model_server(
    self,
    uuid: UUID,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Method to stop a model server.

    Args:
        uuid: UUID of the model server to stop.
        timeout: Timeout in seconds to wait for the service to stop.
        force: If True, force the service to stop.
    """
    # get list of all services
    existing_services = self.find_model_server(service_uuid=uuid)

    # if the service exists, stop it
    if existing_services:
        existing_services[0].stop(timeout=timeout, force=force)
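
A lifecycle sketch tying the three methods above together, under the same placeholder assumptions as the earlier sketches: all three look the server up by UUID and do nothing if no matching service exists.

from zenml.client import Client

model_deployer = Client().active_stack.model_deployer
services = model_deployer.find_model_server(pipeline_name="training_pipeline")

if services:
    server_uuid = services[0].uuid
    # Pause the daemon, bring it back up, then remove it entirely.
    model_deployer.stop_model_server(uuid=server_uuid, timeout=60)
    model_deployer.start_model_server(uuid=server_uuid, timeout=60)
    model_deployer.delete_model_server(uuid=server_uuid, force=True)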

model_registries special

Initialization of the MLflow model registry.

mlflow_model_registry

Implementation of the MLflow model registry for ZenML.

MLFlowModelRegistry (BaseModelRegistry)

Register models using MLflow.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
class MLFlowModelRegistry(BaseModelRegistry):
    """Register models using MLflow."""

    _client: Optional[MlflowClient] = None

    @property
    def config(self) -> MLFlowModelRegistryConfig:
        """Returns the `MLFlowModelRegistryConfig` config.

        Returns:
            The configuration.
        """
        return cast(MLFlowModelRegistryConfig, self._config)

    def configure_mlflow(self) -> None:
        """Configures the MLflow Client with the experiment tracker config."""
        experiment_tracker = Client().active_stack.experiment_tracker
        assert isinstance(experiment_tracker, MLFlowExperimentTracker)
        experiment_tracker.configure_mlflow()

    @property
    def mlflow_client(self) -> MlflowClient:
        """Get the MLflow client.

        Returns:
            The MLFlowClient.
        """
        if not self._client:
            self.configure_mlflow()
            self._client = mlflow.tracking.MlflowClient()
        return self._client

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains an mlflow experiment tracker.

        Returns:
            A StackValidator instance.
        """

        def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]:
            """Validates that all the requirements are met for the stack.

            Args:
                stack: The stack to validate.

            Returns:
                A tuple of (is_valid, error_message).
            """
            # Validate that the experiment tracker is an mlflow experiment tracker.
            experiment_tracker = stack.experiment_tracker
            assert experiment_tracker is not None
            if experiment_tracker.flavor != "mlflow":
                return False, (
                    "The MLflow model registry requires a MLflow experiment "
                    "tracker. You should register a MLflow experiment "
                    "tracker to the stack using the following command: "
                    "`zenml stack update model_registry -e mlflow_tracker"
                )
            mlflow_version = mlflow.version.VERSION
            if (
                not mlflow_version >= "2.1.1"
                and experiment_tracker.config.is_local
            ):
                return False, (
                    "The MLflow model registry requires MLflow version "
                    f"2.1.1 or higher to use a local MLflow registry. "
                    f"Your current MLflow version is {mlflow_version}."
                    "You can upgrade MLflow using the following command: "
                    "`pip install --upgrade mlflow`"
                )
            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.EXPERIMENT_TRACKER,
            },
            custom_validation_function=_validate_stack_requirements,
        )

    # ---------
    # Model Registration Methods
    # ---------

    def register_model(
        self,
        name: str,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> RegisteredModel:
        """Register a model to the MLflow model registry.

        Args:
            name: The name of the model.
            description: The description of the model.
            metadata: The metadata of the model.

        Raises:
            RuntimeError: If the model already exists.

        Returns:
            The registered model.
        """
        # Check if the model already exists; `get_model` raises a
        # `KeyError` if it does not.
        try:
            self.get_model(name)
        except KeyError:
            pass
        else:
            raise RuntimeError(
                f"Model with name {name} already exists in the MLflow model "
                f"registry. Please use a different name.",
            )
        # Register model.
        try:
            registered_model = self.mlflow_client.create_registered_model(
                name=name,
                description=description,
                tags=metadata,
            )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to register model with name {name} to the MLflow "
                f"model registry: {str(e)}",
            )

        # Return the registered model.
        return RegisteredModel(
            name=registered_model.name,
            description=registered_model.description,
            metadata=registered_model.tags,
        )

    def delete_model(
        self,
        name: str,
    ) -> None:
        """Delete a model from the MLflow model registry.

        Args:
            name: The name of the model.

        Raises:
            RuntimeError: If the model does not exist.
        """
        # Check if model exists.
        self.get_model(name=name)
        # Delete the registered model.
        try:
            self.mlflow_client.delete_registered_model(
                name=name,
            )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to delete model with name {name} from MLflow model "
                f"registry: {str(e)}",
            )

    def update_model(
        self,
        name: str,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        remove_metadata: Optional[List[str]] = None,
    ) -> RegisteredModel:
        """Update a model in the MLflow model registry.

        Args:
            name: The name of the model.
            description: The description of the model.
            metadata: The metadata of the model.
            remove_metadata: The metadata to remove from the model.

        Raises:
            RuntimeError: If mlflow fails to update the model.

        Returns:
            The updated model.
        """
        # Check if model exists.
        self.get_model(name=name)
        # Update the registered model description.
        if description:
            try:
                self.mlflow_client.update_registered_model(
                    name=name,
                    description=description,
                )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to update description for the model {name} in MLflow "
                    f"model registry: {str(e)}",
                )
        # Update the registered model tags.
        if metadata:
            try:
                for tag, value in metadata.items():
                    self.mlflow_client.set_registered_model_tag(
                        name=name,
                        key=tag,
                        value=value,
                    )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to update tags for the model {name} in MLflow model "
                    f"registry: {str(e)}",
                )
        # Remove tags from the registered model.
        if remove_metadata:
            try:
                for tag in remove_metadata:
                    self.mlflow_client.delete_registered_model_tag(
                        name=name,
                        key=tag,
                    )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to remove tags for the model {name} in MLflow model "
                    f"registry: {str(e)}",
                )
        # Return the updated registered model.
        return self.get_model(name)

    def get_model(self, name: str) -> RegisteredModel:
        """Get a model from the MLflow model registry.

        Args:
            name: The name of the model.

        Returns:
            The model.

        Raises:
            KeyError: If mlflow fails to get the model.
        """
        # Get the registered model.
        try:
            registered_model = self.mlflow_client.get_registered_model(
                name=name,
            )
        except MlflowException as e:
            raise KeyError(
                f"Failed to get model with name {name} from the MLflow model "
                f"registry: {str(e)}",
            )
        # Return the registered model.
        return RegisteredModel(
            name=registered_model.name,
            description=registered_model.description,
            metadata=registered_model.tags,
        )

    def list_models(
        self,
        name: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> List[RegisteredModel]:
        """List models in the MLflow model registry.

        Args:
            name: A name to filter the models by.
            metadata: The metadata to filter the models by.

        Returns:
            A list of models (RegisteredModel)
        """
        # Set the filter string.
        filter_string = ""
        if name:
            filter_string += f"name='{name}'"
        if metadata:
            for tag, value in metadata.items():
                if filter_string:
                    filter_string += " AND "
                filter_string += f"tags.{tag}='{value}'"

        # Get the registered models.
        registered_models = self.mlflow_client.search_registered_models(
            filter_string=filter_string,
            max_results=100,
        )
        # Return the registered models.
        return [
            RegisteredModel(
                name=registered_model.name,
                description=registered_model.description,
                metadata=registered_model.tags,
            )
            for registered_model in registered_models
        ]

    # ---------
    # Model Version Methods
    # ---------

    def register_model_version(
        self,
        name: str,
        version: Optional[str] = None,
        model_source_uri: Optional[str] = None,
        description: Optional[str] = None,
        metadata: ModelRegistryModelMetadata = Field(
            default_factory=ModelRegistryModelMetadata
        ),
        **kwargs: Any,
    ) -> ModelVersion:
        """Register a model version to the MLflow model registry.

        Args:
            name: The name of the model.
            model_source_uri: The source URI of the model.
            version: The version of the model.
            description: The description of the model version.
            metadata: The registry metadata of the model version.
            **kwargs: Additional keyword arguments.

        Raises:
            RuntimeError: If the registered model does not exist.

        Returns:
            The registered model version.
        """
        # Check if the model exists, if not create it.
        try:
            self.get_model(name=name)
        except KeyError:
            logger.info(
                f"No registered model with name {name} found. Creating a new"
                "registered model."
            )
            self.register_model(
                name=name,
            )
        try:
            # Inform the user that the version is ignored.
            if version:
                logger.info(
                    f"MLflow model registry does not take a version as an argument. "
                    f"Registering a new version for the model `'{name}'` "
                    f"a version will be assigned automatically."
                )
            # Set the run ID and link.
            run_id = metadata.dict().get("mlflow_run_id", None)
            run_link = metadata.dict().get("mlflow_run_link", None)
            # Register the model version.
            registered_model_version = self.mlflow_client.create_model_version(
                name=name,
                source=model_source_uri,
                run_id=run_id,
                run_link=run_link,
                description=description,
                tags=metadata.dict(),
            )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to register model version with name '{name}' and "
                f"version '{version}' to the MLflow model registry."
                f"Error: {e}"
            )
        # Return the registered model version.
        return self._cast_mlflow_version_to_model_version(
            registered_model_version
        )

    def delete_model_version(
        self,
        name: str,
        version: str,
    ) -> None:
        """Delete a model version from the MLflow model registry.

        Args:
            name: The name of the model.
            version: The version of the model.

        Raises:
            RuntimeError: If mlflow fails to delete the model version.
        """
        self.get_model_version(name=name, version=version)
        try:
            self.mlflow_client.delete_model_version(
                name=name,
                version=version,
            )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to delete model version '{version}' of model '{name}'."
                f"From the MLflow model registry: {str(e)}",
            )

    def update_model_version(
        self,
        name: str,
        version: str,
        description: Optional[str] = None,
        metadata: ModelRegistryModelMetadata = Field(
            default_factory=ModelRegistryModelMetadata
        ),
        remove_metadata: Optional[List[str]] = None,
        stage: Optional[ModelVersionStage] = None,
    ) -> ModelVersion:
        """Update a model version in the MLflow model registry.

        Args:
            name: The name of the model.
            version: The version of the model.
            description: The description of the model version.
            metadata: The metadata of the model version.
            remove_metadata: The metadata to remove from the model version.
            stage: The stage of the model version.

        Raises:
            RuntimeError: If mlflow fails to update the model version.

        Returns:
            The updated model version.
        """
        self.get_model_version(name=name, version=version)
        # Update the model description.
        if description:
            try:
                self.mlflow_client.update_model_version(
                    name=name,
                    version=version,
                    description=description,
                )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to update the description of model version "
                    f"'{name}:{version}' in the MLflow model registry: {str(e)}"
                )
        # Update the model tags.
        if metadata:
            try:
                for key, value in metadata.dict().items():
                    self.mlflow_client.set_model_version_tag(
                        name=name,
                        version=version,
                        key=key,
                        value=value,
                    )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to update the tags of model version "
                    f"'{name}:{version}' in the MLflow model registry: {str(e)}"
                )
        # Remove the model tags.
        if remove_metadata:
            try:
                for key in remove_metadata:
                    self.mlflow_client.delete_model_version_tag(
                        name=name,
                        version=version,
                        key=key,
                    )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to remove the tags of model version "
                    f"'{name}:{version}' in the MLflow model registry: {str(e)}"
                )
        # Update the model stage.
        if stage:
            try:
                self.mlflow_client.transition_model_stage(
                    name=name,
                    version=version,
                    stage=stage.value,
                )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to update the current stage of model version "
                    f"'{name}:{version}' in the MLflow model registry: {str(e)}"
                )
        return self.get_model_version(name, version)

    def get_model_version(
        self,
        name: str,
        version: str,
    ) -> ModelVersion:
        """Get a model version from the MLflow model registry.

        Args:
            name: The name of the model.
            version: The version of the model.

        Raises:
            KeyError: If the model version does not exist.

        Returns:
            The model version.
        """
        # Get the model version from the MLflow model registry.
        try:
            mlflow_model_version = self.mlflow_client.get_model_version(
                name=name,
                version=version,
            )
        except MlflowException as e:
            raise KeyError(
                f"Failed to get model version '{name}:{version}' from the "
                f"MLflow model registry: {str(e)}"
            )
        # Return the model version.
        return self._cast_mlflow_version_to_model_version(
            mlflow_model_version=mlflow_model_version,
        )

    def list_model_versions(
        self,
        name: Optional[str] = None,
        model_source_uri: Optional[str] = None,
        metadata: ModelRegistryModelMetadata = Field(
            default_factory=ModelRegistryModelMetadata
        ),
        stage: Optional[ModelVersionStage] = None,
        count: Optional[int] = None,
        created_after: Optional[datetime] = None,
        created_before: Optional[datetime] = None,
        order_by_date: Optional[str] = None,
        **kwargs: Any,
    ) -> List[ModelVersion]:
        """List model versions from the MLflow model registry.

        Args:
            name: The name of the model.
            model_source_uri: The model source URI.
            metadata: The metadata of the model version.
            stage: The stage of the model version.
            count: The maximum number of model versions to return.
            created_after: The minimum creation time of the model versions.
            created_before: The maximum creation time of the model versions.
            order_by_date: The order of the model versions by creation time,
                either ascending or descending.
            kwargs: Additional keyword arguments.

        Returns:
            The model versions.
        """
        # Set the filter string.
        filter_string = ""
        if name:
            filter_string += f"name='{name}'"
        if model_source_uri:
            if filter_string:
                filter_string += " AND "
            filter_string += f"source='{model_source_uri}'"
        if "mlflow_run_id" in kwargs and kwargs["mlflow_run_id"]:
            if filter_string:
                filter_string += " AND "
            filter_string += f"run_id='{kwargs['mlflow_run_id']}'"
        if metadata:
            for tag, value in metadata.dict().items():
                if value:
                    if filter_string:
                        filter_string += " AND "
                    filter_string += f"tags.{tag}='{value}'"
        # Get the model versions.
        mlflow_model_versions = self.mlflow_client.search_model_versions(
            filter_string=filter_string,
        )
        # Cast the MLflow model versions to the ZenML model version class.
        model_versions = [
            self._cast_mlflow_version_to_model_version(
                mlflow_model_version=mlflow_model_version,
            )
            for mlflow_model_version in mlflow_model_versions
        ]
        # Filter the model versions by stage.
        if stage:
            model_versions = [
                model_version
                for model_version in model_versions
                if model_version.stage == stage
            ]
        # Filter the model versions by creation time.
        if created_after:
            model_versions = [
                model_version
                for model_version in model_versions
                if model_version.created_at
                and model_version.created_at >= created_after
            ]
        if created_before:
            model_versions = [
                model_version
                for model_version in model_versions
                if model_version.created_at
                and model_version.created_at <= created_before
            ]
        # Sort the model versions by creation time.
        if order_by_date == "asc":
            model_versions = sorted(
                model_versions,
                key=lambda model_version: model_version.created_at
                if model_version.created_at is not None
                else float("-inf"),
            )
        elif order_by_date == "desc":
            model_versions = sorted(
                model_versions,
                key=lambda model_version: model_version.created_at
                if model_version.created_at is not None
                else float("inf"),
                reverse=True,
            )
        # Return the model versions.
        if count:
            return model_versions[:count]
        return model_versions

    def load_model_version(
        self,
        name: str,
        version: str,
        **kwargs: Any,
    ) -> Any:
        """Load a model version from the MLflow model registry.

        This method loads the model version from the MLflow model registry
        and returns the model. The model is loaded using the `mlflow.pyfunc`
        module which takes care of loading the model from the model source
        URI for the right framework.

        Args:
            name: The name of the model.
            version: The version of the model.
            kwargs: Additional keyword arguments.

        Returns:
            The model version.

        Raises:
            KeyError: If the model version does not exist.
        """
        try:
            self.get_model_version(name=name, version=version)
        except KeyError:
            raise KeyError(
                f"Failed to load model version '{name}:{version}' from the "
                f"MLflow model registry: Model version does not exist."
            )
        # Load the model version.
        mlflow_model_version = self.mlflow_client.get_model_version(
            name=name,
            version=version,
        )
        return load_model(
            model_uri=mlflow_model_version.source,
            **kwargs,
        )

    def get_model_uri_artifact_store(
        self,
        model_version: ModelVersion,
    ) -> str:
        """Get the model URI artifact store.

        Args:
            model_version: The model version.

        Returns:
            The model URI artifact store.
        """
        artifact_store_path = (
            f"{Client().active_stack.artifact_store.path}/mlflow"
        )
        model_source_uri = model_version.model_source_uri.rsplit(":")[-1]
        return artifact_store_path + model_source_uri

    def _cast_mlflow_version_to_model_version(
        self,
        mlflow_model_version: MLflowModelVersion,
    ) -> ModelVersion:
        """Cast an MLflow model version to a model version.

        Args:
            mlflow_model_version: The MLflow model version.

        Returns:
            The model version.
        """
        metadata = mlflow_model_version.tags or {}
        if mlflow_model_version.run_id:
            metadata["mlflow_run_id"] = mlflow_model_version.run_id
        if mlflow_model_version.run_link:
            metadata["mlflow_run_link"] = mlflow_model_version.run_link

        try:
            from mlflow.models import get_model_info

            model_library = (
                get_model_info(model_uri=mlflow_model_version.source)
                .flavors.get("python_function", {})
                .get("loader_module")
            )
        except ImportError:
            model_library = None
        return ModelVersion(
            registered_model=RegisteredModel(name=mlflow_model_version.name),
            model_format=MLFLOW_MODEL_FORMAT,
            model_library=model_library,
            version=mlflow_model_version.version,
            created_at=datetime.fromtimestamp(
                int(mlflow_model_version.creation_timestamp) / 1e3
            ),
            stage=ModelVersionStage(mlflow_model_version.current_stage),
            description=mlflow_model_version.description,
            last_updated_at=datetime.fromtimestamp(
                int(mlflow_model_version.last_updated_timestamp) / 1e3
            ),
            metadata=ModelRegistryModelMetadata(**metadata),
            model_source_uri=mlflow_model_version.source,
        )
config: MLFlowModelRegistryConfig property readonly

Returns the MLFlowModelRegistryConfig config.

Returns:

Type Description
MLFlowModelRegistryConfig

The configuration.

mlflow_client: MlflowClient property readonly

Get the MLflow client.

Returns:

Type Description
MlflowClient

The MLFlowClient.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates that the stack contains an mlflow experiment tracker.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A StackValidator instance.

configure_mlflow(self)

Configures the MLflow Client with the experiment tracker config.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def configure_mlflow(self) -> None:
    """Configures the MLflow Client with the experiment tracker config."""
    experiment_tracker = Client().active_stack.experiment_tracker
    assert isinstance(experiment_tracker, MLFlowExperimentTracker)
    experiment_tracker.configure_mlflow()
delete_model(self, name)

Delete a model from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required

Exceptions:

Type Description
RuntimeError

If the model does not exist.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def delete_model(
    self,
    name: str,
) -> None:
    """Delete a model from the MLflow model registry.

    Args:
        name: The name of the model.

    Raises:
        RuntimeError: If the model does not exist.
    """
    # Check if model exists.
    self.get_model(name=name)
    # Delete the registered model.
    try:
        self.mlflow_client.delete_registered_model(
            name=name,
        )
    except MlflowException as e:
        raise RuntimeError(
            f"Failed to delete model with name {name} from MLflow model "
            f"registry: {str(e)}",
        )
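
A registry usage sketch, assuming an MLflow model registry is part of the active stack and accessed via Client().active_stack.model_registry (by analogy with the stack accesses in the source above); all names and metadata are placeholders.

from zenml.client import Client

model_registry = Client().active_stack.model_registry

# Register, inspect, then delete a hypothetical model.
model_registry.register_model(
    name="my-classifier",
    description="A demo model",
    metadata={"team": "ml-platform"},
)
print(model_registry.get_model("my-classifier").metadata)
model_registry.delete_model(name="my-classifier")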
delete_model_version(self, name, version)

Delete a model version from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required

Exceptions:

Type Description
RuntimeError

If mlflow fails to delete the model version.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def delete_model_version(
    self,
    name: str,
    version: str,
) -> None:
    """Delete a model version from the MLflow model registry.

    Args:
        name: The name of the model.
        version: The version of the model.

    Raises:
        RuntimeError: If mlflow fails to delete the model version.
    """
    self.get_model_version(name=name, version=version)
    try:
        self.mlflow_client.delete_model_version(
            name=name,
            version=version,
        )
    except MlflowException as e:
        raise RuntimeError(
            f"Failed to delete model version '{version}' of model '{name}'."
            f"From the MLflow model registry: {str(e)}",
        )
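
The version-level counterpart, under the same placeholder assumptions: a missing version surfaces as a KeyError from the get_model_version check before any deletion is attempted, and other versions of the model are left untouched.

from zenml.client import Client

model_registry = Client().active_stack.model_registry
model_registry.delete_model_version(name="my-classifier", version="3")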
get_model(self, name)

Get a model from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required

Returns:

Type Description
RegisteredModel

The model.

Exceptions:

Type Description
KeyError

If mlflow fails to get the model.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def get_model(self, name: str) -> RegisteredModel:
    """Get a model from the MLflow model registry.

    Args:
        name: The name of the model.

    Returns:
        The model.

    Raises:
        KeyError: If mlflow fails to get the model.
    """
    # Get the registered model.
    try:
        registered_model = self.mlflow_client.get_registered_model(
            name=name,
        )
    except MlflowException as e:
        raise KeyError(
            f"Failed to get model with name {name} from the MLflow model "
            f"registry: {str(e)}",
        )
    # Return the registered model.
    return RegisteredModel(
        name=registered_model.name,
        description=registered_model.description,
        metadata=registered_model.tags,
    )
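
Because a missing model surfaces as a KeyError wrapping the underlying MlflowException, callers can probe for existence as in this sketch (same placeholder assumptions as above):

from zenml.client import Client

model_registry = Client().active_stack.model_registry

try:
    model = model_registry.get_model("my-classifier")
    print(model.name, model.description)
except KeyError:
    print("No such model registered yet.")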
get_model_uri_artifact_store(self, model_version)

Get the model URI artifact store.

Parameters:

Name Type Description Default
model_version ModelVersion

The model version.

required

Returns:

Type Description
str

The model URI artifact store.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def get_model_uri_artifact_store(
    self,
    model_version: ModelVersion,
) -> str:
    """Get the model URI artifact store.

    Args:
        model_version: The model version.

    Returns:
        The model URI artifact store.
    """
    artifact_store_path = (
        f"{Client().active_stack.artifact_store.path}/mlflow"
    )
    model_source_uri = model_version.model_source_uri.rsplit(":")[-1]
    return artifact_store_path + model_source_uri
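
A worked example of the path arithmetic above, with purely illustrative paths: the URI scheme is stripped via rsplit(":") and the remainder is appended to the artifact store's /mlflow directory.

# With an artifact store rooted at "s3://my-bucket" and a model version
# whose source is "mlflow-artifacts:/12/abc123/artifacts/model":
uri = "mlflow-artifacts:/12/abc123/artifacts/model"
print("s3://my-bucket/mlflow" + uri.rsplit(":")[-1])
# -> s3://my-bucket/mlflow/12/abc123/artifacts/model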
get_model_version(self, name, version)

Get a model version from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required

Exceptions:

Type Description
KeyError

If the model version does not exist.

Returns:

Type Description
ModelVersion

The model version.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def get_model_version(
    self,
    name: str,
    version: str,
) -> ModelVersion:
    """Get a model version from the MLflow model registry.

    Args:
        name: The name of the model.
        version: The version of the model.

    Raises:
        KeyError: If the model version does not exist.

    Returns:
        The model version.
    """
    # Get the model version from the MLflow model registry.
    try:
        mlflow_model_version = self.mlflow_client.get_model_version(
            name=name,
            version=version,
        )
    except MlflowException as e:
        raise KeyError(
            f"Failed to get model version '{name}:{version}' from the "
            f"MLflow model registry: {str(e)}"
        )
    # Return the model version.
    return self._cast_mlflow_version_to_model_version(
        mlflow_model_version=mlflow_model_version,
    )
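
A sketch combining get_model_version with load_model_version (shown in the class source above), which hands the source URI to mlflow.pyfunc for framework-appropriate loading; the model name and version string are placeholders.

from zenml.client import Client

model_registry = Client().active_stack.model_registry

version = model_registry.get_model_version(name="my-classifier", version="1")
print(version.model_source_uri, version.stage)

# Load the actual model object for local inference.
model = model_registry.load_model_version(name="my-classifier", version="1")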
list_model_versions(self, name=None, model_source_uri=None, metadata=ModelRegistryModelMetadata(), stage=None, count=None, created_after=None, created_before=None, order_by_date=None, **kwargs)

List model versions from the MLflow model registry.

Parameters:

Name Type Description Default
name Optional[str]

The name of the model.

None
model_source_uri Optional[str]

The model source URI.

None
metadata ModelRegistryModelMetadata

The metadata of the model version.

Field(default_factory=ModelRegistryModelMetadata)
stage Optional[zenml.model_registries.base_model_registry.ModelVersionStage]

The stage of the model version.

None
count Optional[int]

The maximum number of model versions to return.

None
created_after Optional[datetime.datetime]

The minimum creation time of the model versions.

None
created_before Optional[datetime.datetime]

The maximum creation time of the model versions.

None
order_by_date Optional[str]

The order of the model versions by creation time, either "asc" (ascending) or "desc" (descending).

None
kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
List[zenml.model_registries.base_model_registry.ModelVersion]

The model versions.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def list_model_versions(
    self,
    name: Optional[str] = None,
    model_source_uri: Optional[str] = None,
    metadata: ModelRegistryModelMetadata = Field(
        default_factory=ModelRegistryModelMetadata
    ),
    stage: Optional[ModelVersionStage] = None,
    count: Optional[int] = None,
    created_after: Optional[datetime] = None,
    created_before: Optional[datetime] = None,
    order_by_date: Optional[str] = None,
    **kwargs: Any,
) -> List[ModelVersion]:
    """List model versions from the MLflow model registry.

    Args:
        name: The name of the model.
        model_source_uri: The model source URI.
        metadata: The metadata of the model version.
        stage: The stage of the model version.
        count: The maximum number of model versions to return.
        created_after: The minimum creation time of the model versions.
        created_before: The maximum creation time of the model versions.
        order_by_date: The order of the model versions by creation time,
            either "asc" (ascending) or "desc" (descending).
        kwargs: Additional keyword arguments.

    Returns:
        The model versions.
    """
    # Set the filter string.
    filter_string = ""
    if name:
        filter_string += f"name='{name}'"
    if model_source_uri:
        if filter_string:
            filter_string += " AND "
        filter_string += f"source='{model_source_uri}'"
    if "mlflow_run_id" in kwargs and kwargs["mlflow_run_id"]:
        if filter_string:
            filter_string += " AND "
        filter_string += f"run_id='{kwargs['mlflow_run_id']}'"
    if metadata:
        for tag, value in metadata.dict().items():
            if value:
                if filter_string:
                    filter_string += " AND "
                filter_string += f"tags.{tag}='{value}'"
    # Get the model versions.
    mlflow_model_versions = self.mlflow_client.search_model_versions(
        filter_string=filter_string,
    )
    # Cast the MLflow model versions to the ZenML model version class.
    model_versions = [
        self._cast_mlflow_version_to_model_version(
            mlflow_model_version=mlflow_model_version,
        )
        for mlflow_model_version in mlflow_model_versions
    ]
    # Filter the model versions by stage.
    if stage:
        model_versions = [
            model_version
            for model_version in model_versions
            if model_version.stage == stage
        ]
    # Filter the model versions by creation time.
    if created_after:
        model_versions = [
            model_version
            for model_version in model_versions
            if model_version.created_at
            and model_version.created_at >= created_after
        ]
    if created_before:
        model_versions = [
            model_version
            for model_version in model_versions
            if model_version.created_at
            and model_version.created_at <= created_before
        ]
    # Sort the model versions by creation time.
    if order_by_date == "asc":
        model_versions = sorted(
            model_versions,
            key=lambda model_version: model_version.created_at
            if model_version.created_at is not None
            else float("-inf"),
        )
    elif order_by_date == "desc":
        model_versions = sorted(
            model_versions,
            key=lambda model_version: model_version.created_at
            if model_version.created_at is not None
            else float("inf"),
            reverse=True,
        )
    # Return the model versions.
    if count:
        return model_versions[:count]
    return model_versions
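
Example (a sketch; all filter values are hypothetical): list the five most recent versions of a model created after a given date:

from datetime import datetime

from zenml.client import Client

model_registry = Client().active_stack.model_registry
versions = model_registry.list_model_versions(
    name="my_model",                     # hypothetical model name
    created_after=datetime(2023, 1, 1),  # drop versions created before this date
    order_by_date="desc",                # newest first
    count=5,                             # return at most five versions
)
for version in versions:
    print(version.version, version.created_at)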
list_models(self, name=None, metadata=None)

List models in the MLflow model registry.

Parameters:

Name Type Description Default
name Optional[str]

A name to filter the models by.

None
metadata Optional[Dict[str, str]]

The metadata to filter the models by.

None

Returns:

Type Description
List[zenml.model_registries.base_model_registry.RegisteredModel]

A list of models (RegisteredModel).

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def list_models(
    self,
    name: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
) -> List[RegisteredModel]:
    """List models in the MLflow model registry.

    Args:
        name: A name to filter the models by.
        metadata: The metadata to filter the models by.

    Returns:
        A list of models (RegisteredModel).
    """
    # Set the filter string.
    filter_string = ""
    if name:
        filter_string += f"name='{name}'"
    if metadata:
        for tag, value in metadata.items():
            if filter_string:
                filter_string += " AND "
            filter_string += f"tags.{tag}='{value}'"

    # Get the registered models.
    registered_models = self.mlflow_client.search_registered_models(
        filter_string=filter_string,
        max_results=100,
    )
    # Return the registered models.
    return [
        RegisteredModel(
            name=registered_model.name,
            description=registered_model.description,
            metadata=registered_model.tags,
        )
        for registered_model in registered_models
    ]
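
Example (a sketch; the tag key and value are hypothetical): filter registered models by name and by an MLflow tag:

from zenml.client import Client

model_registry = Client().active_stack.model_registry
models = model_registry.list_models(
    name="my_model",           # optional name filter
    metadata={"team": "nlp"},  # hypothetical tag filter, matched as tags.team='nlp'
)
print([model.name for model in models])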
load_model_version(self, name, version, **kwargs)

Load a model version from the MLflow model registry.

This method loads the model version from the MLflow model registry and returns the model. The model is loaded using the mlflow.pyfunc module which takes care of loading the model from the model source URI for the right framework.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required
kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
Any

The model version.

Exceptions:

Type Description
KeyError

If the model version does not exist.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def load_model_version(
    self,
    name: str,
    version: str,
    **kwargs: Any,
) -> Any:
    """Load a model version from the MLflow model registry.

    This method loads the model version from the MLflow model registry
    and returns the model. The model is loaded using the `mlflow.pyfunc`
    module which takes care of loading the model from the model source
    URI for the right framework.

    Args:
        name: The name of the model.
        version: The version of the model.
        kwargs: Additional keyword arguments.

    Returns:
        The model version.

    Raises:
        KeyError: If the model version does not exist.
    """
    try:
        self.get_model_version(name=name, version=version)
    except KeyError:
        raise KeyError(
            f"Failed to load model version '{name}:{version}' from the "
            f"MLflow model registry: Model version does not exist."
        )
    # Load the model version.
    mlflow_model_version = self.mlflow_client.get_model_version(
        name=name,
        version=version,
    )
    return load_model(
        model_uri=mlflow_model_version.source,
        **kwargs,
    )
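
Example (a sketch; name, version, and input data are hypothetical): because the model comes back as an mlflow.pyfunc model, it exposes a framework-agnostic predict method:

import pandas as pd

from zenml.client import Client

model_registry = Client().active_stack.model_registry
# Raises KeyError if the hypothetical version "1" of "my_model" does not exist.
model = model_registry.load_model_version(name="my_model", version="1")
# pyfunc models accept, among other input formats, a pandas DataFrame.
predictions = model.predict(pd.DataFrame({"feature": [1.0, 2.0]}))
print(predictions)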
register_model(self, name, description=None, metadata=None)

Register a model to the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
description Optional[str]

The description of the model.

None
metadata Optional[Dict[str, str]]

The metadata of the model.

None

Exceptions:

Type Description
RuntimeError

If the model already exists.

Returns:

Type Description
RegisteredModel

The registered model.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def register_model(
    self,
    name: str,
    description: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
) -> RegisteredModel:
    """Register a model to the MLflow model registry.

    Args:
        name: The name of the model.
        description: The description of the model.
        metadata: The metadata of the model.

    Raises:
        RuntimeError: If the model already exists.

    Returns:
        The registered model.
    """
    # Check if the model already exists: `get_model` raises a KeyError
    # if and only if no model with this name is registered yet.
    try:
        self.get_model(name)
    except KeyError:
        pass
    else:
        raise RuntimeError(
            f"Model with name {name} already exists in the MLflow model "
            f"registry. Please use a different name.",
        )
    # Register model.
    try:
        registered_model = self.mlflow_client.create_registered_model(
            name=name,
            description=description,
            tags=metadata,
        )
    except MlflowException as e:
        raise RuntimeError(
            f"Failed to register model with name {name} to the MLflow "
            f"model registry: {str(e)}",
        )

    # Return the registered model.
    return RegisteredModel(
        name=registered_model.name,
        description=registered_model.description,
        metadata=registered_model.tags,
    )
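
Example (a sketch; name, description, and tags are hypothetical):

from zenml.client import Client

model_registry = Client().active_stack.model_registry
registered_model = model_registry.register_model(
    name="my_model",                   # hypothetical model name
    description="A demo model",
    metadata={"team": "ml-platform"},  # stored as MLflow registered-model tags
)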
register_model_version(self, name, version=None, model_source_uri=None, description=None, metadata=Field(default_factory=ModelRegistryModelMetadata), **kwargs)

Register a model version to the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
model_source_uri Optional[str]

The source URI of the model.

None
version Optional[str]

The version of the model.

None
description Optional[str]

The description of the model version.

None
metadata ModelRegistryModelMetadata

The registry metadata of the model version.

Field(default_factory=ModelRegistryModelMetadata)
**kwargs Any

Additional keyword arguments.

{}

Exceptions:

Type Description
RuntimeError

If the registered model does not exist.

Returns:

Type Description
ModelVersion

The registered model version.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def register_model_version(
    self,
    name: str,
    version: Optional[str] = None,
    model_source_uri: Optional[str] = None,
    description: Optional[str] = None,
    metadata: ModelRegistryModelMetadata = Field(
        default_factory=ModelRegistryModelMetadata
    ),
    **kwargs: Any,
) -> ModelVersion:
    """Register a model version to the MLflow model registry.

    Args:
        name: The name of the model.
        model_source_uri: The source URI of the model.
        version: The version of the model.
        description: The description of the model version.
        metadata: The registry metadata of the model version.
        **kwargs: Additional keyword arguments.

    Raises:
        RuntimeError: If the registered model does not exist.

    Returns:
        The registered model version.
    """
    # Check if the model exists, if not create it.
    try:
        self.get_model(name=name)
    except KeyError:
        logger.info(
            f"No registered model with name {name} found. Creating a new"
            "registered model."
        )
        self.register_model(
            name=name,
        )
    try:
        # Inform the user that the version is ignored.
        if version:
            logger.info(
                f"MLflow model registry does not take a version as an argument. "
                f"Registering a new version for the model `'{name}'` "
                f"a version will be assigned automatically."
            )
        # Set the run ID and link.
        run_id = metadata.dict().get("mlflow_run_id", None)
        run_link = metadata.dict().get("mlflow_run_link", None)
        # Register the model version.
        registered_model_version = self.mlflow_client.create_model_version(
            name=name,
            source=model_source_uri,
            run_id=run_id,
            run_link=run_link,
            description=description,
            tags=metadata.dict(),
        )
    except MlflowException as e:
        raise RuntimeError(
            f"Failed to register model version with name '{name}' and "
            f"version '{version}' to the MLflow model registry."
            f"Error: {e}"
        )
    # Return the registered model version.
    return self._cast_mlflow_version_to_model_version(
        registered_model_version
    )
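
Example (a sketch; the model name, source URI, and run ID are hypothetical; as noted above, the version number itself is assigned by MLflow):

from zenml.client import Client
from zenml.model_registries.base_model_registry import (
    ModelRegistryModelMetadata,
)

model_registry = Client().active_stack.model_registry
model_version = model_registry.register_model_version(
    name="my_model",                                              # hypothetical name
    model_source_uri="s3://bucket/mlflow/1/abc/artifacts/model",  # hypothetical URI
    description="Version trained on the January data",
    metadata=ModelRegistryModelMetadata(
        mlflow_run_id="0123456789abcdef",                         # hypothetical run ID
    ),
)
print(model_version.version)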
update_model(self, name, description=None, metadata=None, remove_metadata=None)

Update a model in the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
description Optional[str]

The description of the model.

None
metadata Optional[Dict[str, str]]

The metadata of the model.

None
remove_metadata Optional[List[str]]

The metadata to remove from the model.

None

Exceptions:

Type Description
RuntimeError

If mlflow fails to update the model.

Returns:

Type Description
RegisteredModel

The updated model.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def update_model(
    self,
    name: str,
    description: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
    remove_metadata: Optional[List[str]] = None,
) -> RegisteredModel:
    """Update a model in the MLflow model registry.

    Args:
        name: The name of the model.
        description: The description of the model.
        metadata: The metadata of the model.
        remove_metadata: The metadata to remove from the model.

    Raises:
        RuntimeError: If mlflow fails to update the model.

    Returns:
        The updated model.
    """
    # Check if model exists.
    self.get_model(name=name)
    # Update the registered model description.
    if description:
        try:
            self.mlflow_client.update_registered_model(
                name=name,
                description=description,
            )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to update description for the model {name} in MLflow "
                f"model registry: {str(e)}",
            )
    # Update the registered model tags.
    if metadata:
        try:
            for tag, value in metadata.items():
                self.mlflow_client.set_registered_model_tag(
                    name=name,
                    key=tag,
                    value=value,
                )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to update tags for the model {name} in MLflow model "
                f"registry: {str(e)}",
            )
    # Remove tags from the registered model.
    if remove_metadata:
        try:
            for tag in remove_metadata:
                self.mlflow_client.delete_registered_model_tag(
                    name=name,
                    key=tag,
                )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to remove tags for the model {name} in MLflow model "
                f"registry: {str(e)}",
            )
    # Return the updated registered model.
    return self.get_model(name)
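
Example (a sketch; all values are hypothetical): update the description, set one tag, and remove another:

from zenml.client import Client

model_registry = Client().active_stack.model_registry
updated_model = model_registry.update_model(
    name="my_model",
    description="Updated description",
    metadata={"owner": "team-a"},        # tags to set or overwrite
    remove_metadata=["deprecated_tag"],  # tag keys to delete
)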
update_model_version(self, name, version, description=None, metadata=Field(default_factory=ModelRegistryModelMetadata), remove_metadata=None, stage=None)

Update a model version in the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required
description Optional[str]

The description of the model version.

None
metadata ModelRegistryModelMetadata

The metadata of the model version.

Field(default_factory=ModelRegistryModelMetadata)
remove_metadata Optional[List[str]]

The metadata to remove from the model version.

None
stage Optional[zenml.model_registries.base_model_registry.ModelVersionStage]

The stage of the model version.

None

Exceptions:

Type Description
RuntimeError

If mlflow fails to update the model version.

Returns:

Type Description
ModelVersion

The updated model version.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def update_model_version(
    self,
    name: str,
    version: str,
    description: Optional[str] = None,
    metadata: ModelRegistryModelMetadata = Field(
        default_factory=ModelRegistryModelMetadata
    ),
    remove_metadata: Optional[List[str]] = None,
    stage: Optional[ModelVersionStage] = None,
) -> ModelVersion:
    """Update a model version in the MLflow model registry.

    Args:
        name: The name of the model.
        version: The version of the model.
        description: The description of the model version.
        metadata: The metadata of the model version.
        remove_metadata: The metadata to remove from the model version.
        stage: The stage of the model version.

    Raises:
        RuntimeError: If mlflow fails to update the model version.

    Returns:
        The updated model version.
    """
    self.get_model_version(name=name, version=version)
    # Update the model description.
    if description:
        try:
            self.mlflow_client.update_model_version(
                name=name,
                version=version,
                description=description,
            )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to update the description of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            )
    # Update the model tags.
    if metadata:
        try:
            for key, value in metadata.dict().items():
                self.mlflow_client.set_model_version_tag(
                    name=name,
                    version=version,
                    key=key,
                    value=value,
                )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to update the tags of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            )
    # Remove the model tags.
    if remove_metadata:
        try:
            for key in remove_metadata:
                self.mlflow_client.delete_model_version_tag(
                    name=name,
                    version=version,
                    key=key,
                )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to remove the tags of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            )
    # Update the model stage.
    if stage:
        try:
            self.mlflow_client.transition_model_stage(
                name=name,
                version=version,
                stage=stage.value,
            )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to update the current stage of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            )
    return self.get_model_version(name, version)
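
Example (a sketch; name and version are hypothetical, and ModelVersionStage is assumed to expose a PRODUCTION member): promote a version to a new stage:

from zenml.client import Client
from zenml.model_registries.base_model_registry import ModelVersionStage

model_registry = Client().active_stack.model_registry
model_version = model_registry.update_model_version(
    name="my_model",
    version="1",
    stage=ModelVersionStage.PRODUCTION,  # transition this version to Production
)
print(model_version.stage)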

services special

Initialization of the MLflow Service.

mlflow_deployment

Implementation of the MLflow deployment functionality.

MLFlowDeploymentConfig (LocalDaemonServiceConfig) pydantic-model

MLflow model deployment configuration.

Attributes:

Name Type Description
model_uri str

URI of the MLflow model to serve

model_name str

the name of the model

workers int

number of workers to use for the prediction service

registry_model_name Optional[str]

the name of the model in the registry

registry_model_version Optional[str]

the version of the model in the registry

mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

timeout int

timeout in seconds for starting and stopping the service

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentConfig(LocalDaemonServiceConfig):
    """MLflow model deployment configuration.

    Attributes:
        model_uri: URI of the MLflow model to serve
        model_name: the name of the model
        workers: number of workers to use for the prediction service
        registry_model_name: the name of the model in the registry
        registry_model_version: the version of the model in the registry
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
        timeout: timeout in seconds for starting and stopping the service
    """

    # TODO: ServiceConfig should have additional fields such as "pipeline_run_uuid"
    #  and "pipeline_uuid" to allow for better tracking of the service.
    model_uri: str
    model_name: str
    registry_model_name: Optional[str] = None
    registry_model_version: Optional[str] = None
    registry_model_stage: Optional[str] = None
    workers: int = 1
    mlserver: bool = False
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
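
Example (a sketch; the model URI is hypothetical and the inherited service-config fields are left at their defaults): configure a local two-worker MLServer deployment:

from zenml.integrations.mlflow.services.mlflow_deployment import (
    MLFlowDeploymentConfig,
)

config = MLFlowDeploymentConfig(
    model_uri="file:///tmp/mlflow/1/abc/artifacts/model",  # hypothetical URI
    model_name="model",
    workers=2,
    mlserver=True,  # serve with MLServer instead of the built-in scoring server
    timeout=60,
)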
MLFlowDeploymentEndpoint (LocalDaemonServiceEndpoint) pydantic-model

A service endpoint exposed by the MLflow deployment daemon.

Attributes:

Name Type Description
config MLFlowDeploymentEndpointConfig

service endpoint configuration

monitor HTTPEndpointHealthMonitor

optional service endpoint health monitor

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpoint(LocalDaemonServiceEndpoint):
    """A service endpoint exposed by the MLflow deployment daemon.

    Attributes:
        config: service endpoint configuration
        monitor: optional service endpoint health monitor
    """

    config: MLFlowDeploymentEndpointConfig
    monitor: HTTPEndpointHealthMonitor

    @property
    def prediction_url(self) -> Optional[str]:
        """Gets the prediction URL for the endpoint.

        Returns:
            the prediction URL for the endpoint
        """
        uri = self.status.uri
        if not uri:
            return None
        return os.path.join(uri, self.config.prediction_url_path)
prediction_url: Optional[str] property readonly

Gets the prediction URL for the endpoint.

Returns:

Type Description
Optional[str]

the prediction URL for the endpoint

MLFlowDeploymentEndpointConfig (LocalDaemonServiceEndpointConfig) pydantic-model

MLflow daemon service endpoint configuration.

Attributes:

Name Type Description
prediction_url_path str

URI subpath for prediction requests

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpointConfig(LocalDaemonServiceEndpointConfig):
    """MLflow daemon service endpoint configuration.

    Attributes:
        prediction_url_path: URI subpath for prediction requests
    """

    prediction_url_path: str
MLFlowDeploymentService (LocalDaemonService, BaseDeploymentService) pydantic-model

MLflow deployment service used to start a local prediction server for MLflow models.

Attributes:

Name Type Description
SERVICE_TYPE ClassVar[zenml.services.service_type.ServiceType]

a service type descriptor with information describing the MLflow deployment service class

config MLFlowDeploymentConfig

service configuration

endpoint MLFlowDeploymentEndpoint

optional service endpoint

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentService(LocalDaemonService, BaseDeploymentService):
    """MLflow deployment service used to start a local prediction server for MLflow models.

    Attributes:
        SERVICE_TYPE: a service type descriptor with information describing
            the MLflow deployment service class
        config: service configuration
        endpoint: optional service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="mlflow-deployment",
        type="model-serving",
        flavor="mlflow",
        description="MLflow prediction service",
    )

    config: MLFlowDeploymentConfig
    endpoint: MLFlowDeploymentEndpoint

    def __init__(
        self,
        config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
        **attrs: Any,
    ) -> None:
        """Initialize the MLflow deployment service.

        Args:
            config: service configuration
            attrs: additional attributes to set on the service
        """
        # ensure that the endpoint is created before the service is initialized
        # TODO [ENG-700]: implement a service factory or builder for MLflow
        #   deployment services
        if (
            isinstance(config, MLFlowDeploymentConfig)
            and "endpoint" not in attrs
        ):
            if config.mlserver:
                prediction_url_path = MLSERVER_PREDICTION_URL_PATH
                healthcheck_uri_path = MLSERVER_HEALTHCHECK_URL_PATH
                use_head_request = False
            else:
                prediction_url_path = MLFLOW_PREDICTION_URL_PATH
                healthcheck_uri_path = MLFLOW_HEALTHCHECK_URL_PATH
                use_head_request = True

            endpoint = MLFlowDeploymentEndpoint(
                config=MLFlowDeploymentEndpointConfig(
                    protocol=ServiceEndpointProtocol.HTTP,
                    prediction_url_path=prediction_url_path,
                ),
                monitor=HTTPEndpointHealthMonitor(
                    config=HTTPEndpointHealthMonitorConfig(
                        healthcheck_uri_path=healthcheck_uri_path,
                        use_head_request=use_head_request,
                    )
                ),
            )
            attrs["endpoint"] = endpoint
        super().__init__(config=config, **attrs)

    def run(self) -> None:
        """Start the service."""
        logger.info(
            "Starting MLflow prediction service as blocking "
            "process... press CTRL+C once to stop it."
        )

        self.endpoint.prepare_for_start()
        try:
            backend_kwargs: Dict[str, Any] = {}
            serve_kwargs: Dict[str, Any] = {}
            mlflow_version = MLFLOW_VERSION.split(".")
            # MLflow version 1.26 introduces an additional mandatory
            # `timeout` argument to the `PyFuncBackend.serve` function
            if int(mlflow_version[1]) >= 26 or int(mlflow_version[0]) >= 2:
                serve_kwargs["timeout"] = None
            # MLflow 2.0+ requires the env_manager to be set to "local"
            # to deploy the model in the local environment
            if int(mlflow_version[0]) >= 2:
                backend_kwargs["env_manager"] = "local"
            backend = PyFuncBackend(
                config={},
                no_conda=True,
                workers=self.config.workers,
                install_mlflow=False,
                **backend_kwargs,
            )
            backend.serve(
                model_uri=self.config.model_uri,
                port=self.endpoint.status.port,
                host="localhost",
                enable_mlserver=self.config.mlserver,
                **serve_kwargs,
            )
        except KeyboardInterrupt:
            logger.info(
                "MLflow prediction service stopped. Resuming normal execution."
            )

    @property
    def prediction_url(self) -> Optional[str]:
        """Get the URI where the prediction service is answering requests.

        Returns:
            The URI where the prediction service can be contacted to process
            HTTP/REST inference requests, or None, if the service isn't running.
        """
        if not self.is_running:
            return None
        return self.endpoint.prediction_url

    def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
        """Make a prediction using the service.

        Args:
            request: a numpy array representing the request

        Returns:
            A numpy array representing the prediction returned by the service.

        Raises:
            Exception: if the service is not running
            ValueError: if the prediction endpoint is unknown.
        """
        if not self.is_running:
            raise Exception(
                "MLflow prediction service is not running. "
                "Please start the service before making predictions."
            )

        if self.endpoint.prediction_url is not None:
            response = requests.post(
                self.endpoint.prediction_url,
                json={"instances": request.tolist()},
            )
        else:
            raise ValueError("No endpoint known for prediction.")
        response.raise_for_status()
        if int(MLFLOW_VERSION.split(".")[0]) <= 1:
            return np.array(response.json())
        else:
            # Mlflow 2.0+ returns a dictionary with the predictions
            # under the "predictions" key
            return np.array(response.json()["predictions"])
prediction_url: Optional[str] property readonly

Get the URI where the prediction service is answering requests.

Returns:

Type Description
Optional[str]

The URI where the prediction service can be contacted to process HTTP/REST inference requests, or None, if the service isn't running.

__init__(self, config, **attrs) special

Initialize the MLflow deployment service.

Parameters:

Name Type Description Default
config Union[zenml.integrations.mlflow.services.mlflow_deployment.MLFlowDeploymentConfig, Dict[str, Any]]

service configuration

required
attrs Any

additional attributes to set on the service

{}
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def __init__(
    self,
    config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
    **attrs: Any,
) -> None:
    """Initialize the MLflow deployment service.

    Args:
        config: service configuration
        attrs: additional attributes to set on the service
    """
    # ensure that the endpoint is created before the service is initialized
    # TODO [ENG-700]: implement a service factory or builder for MLflow
    #   deployment services
    if (
        isinstance(config, MLFlowDeploymentConfig)
        and "endpoint" not in attrs
    ):
        if config.mlserver:
            prediction_url_path = MLSERVER_PREDICTION_URL_PATH
            healthcheck_uri_path = MLSERVER_HEALTHCHECK_URL_PATH
            use_head_request = False
        else:
            prediction_url_path = MLFLOW_PREDICTION_URL_PATH
            healthcheck_uri_path = MLFLOW_HEALTHCHECK_URL_PATH
            use_head_request = True

        endpoint = MLFlowDeploymentEndpoint(
            config=MLFlowDeploymentEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
                prediction_url_path=prediction_url_path,
            ),
            monitor=HTTPEndpointHealthMonitor(
                config=HTTPEndpointHealthMonitorConfig(
                    healthcheck_uri_path=healthcheck_uri_path,
                    use_head_request=use_head_request,
                )
            ),
        )
        attrs["endpoint"] = endpoint
    super().__init__(config=config, **attrs)
predict(self, request)

Make a prediction using the service.

Parameters:

Name Type Description Default
request NDArray[Any]

a numpy array representing the request

required

Returns:

Type Description
NDArray[Any]

A numpy array representing the prediction returned by the service.

Exceptions:

Type Description
Exception

if the service is not running

ValueError

if the prediction endpoint is unknown.

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def predict(self, request: "NDArray[Any]") -> "NDArray[Any]":
    """Make a prediction using the service.

    Args:
        request: a numpy array representing the request

    Returns:
        A numpy array representing the prediction returned by the service.

    Raises:
        Exception: if the service is not running
        ValueError: if the prediction endpoint is unknown.
    """
    if not self.is_running:
        raise Exception(
            "MLflow prediction service is not running. "
            "Please start the service before making predictions."
        )

    if self.endpoint.prediction_url is not None:
        response = requests.post(
            self.endpoint.prediction_url,
            json={"instances": request.tolist()},
        )
    else:
        raise ValueError("No endpoint known for prediction.")
    response.raise_for_status()
    if int(MLFLOW_VERSION.split(".")[0]) <= 1:
        return np.array(response.json())
    else:
        # Mlflow 2.0+ returns a dictionary with the predictions
        # under the "predictions" key
        return np.array(response.json()["predictions"])
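
Example (a sketch; the lookup arguments are hypothetical and the import path is assumed from the integration's module layout): send a request to a service found through the active model deployer:

import numpy as np

from zenml.integrations.mlflow.model_deployers.mlflow_model_deployer import (
    MLFlowModelDeployer,
)

model_deployer = MLFlowModelDeployer.get_active_model_deployer()
services = model_deployer.find_model_server(
    pipeline_name="training_pipeline",               # hypothetical values
    pipeline_step_name="mlflow_model_deployer_step",
    model_name="model",
)
if services and services[0].is_running:
    request = np.array([[1.0, 2.0, 3.0]])            # hypothetical input row
    prediction = services[0].predict(request)
    print(prediction)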
run(self)

Start the service.

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def run(self) -> None:
    """Start the service."""
    logger.info(
        "Starting MLflow prediction service as blocking "
        "process... press CTRL+C once to stop it."
    )

    self.endpoint.prepare_for_start()
    try:
        backend_kwargs: Dict[str, Any] = {}
        serve_kwargs: Dict[str, Any] = {}
        mlflow_version = MLFLOW_VERSION.split(".")
        # MLflow version 1.26 introduces an additional mandatory
        # `timeout` argument to the `PyFuncBackend.serve` function
        if int(mlflow_version[1]) >= 26 or int(mlflow_version[0]) >= 2:
            serve_kwargs["timeout"] = None
        # MLflow 2.0+ requires the env_manager to be set to "local"
        # to deploy the model in the local environment
        if int(mlflow_version[0]) >= 2:
            backend_kwargs["env_manager"] = "local"
        backend = PyFuncBackend(
            config={},
            no_conda=True,
            workers=self.config.workers,
            install_mlflow=False,
            **backend_kwargs,
        )
        backend.serve(
            model_uri=self.config.model_uri,
            port=self.endpoint.status.port,
            host="localhost",
            enable_mlserver=self.config.mlserver,
            **serve_kwargs,
        )
    except KeyboardInterrupt:
        logger.info(
            "MLflow prediction service stopped. Resuming normal execution."
        )

steps special

Initialization of the MLflow standard interface steps.

mlflow_deployer

Implementation of the MLflow model deployer pipeline step.

MLFlowDeployerParameters (BaseParameters) pydantic-model

Model deployer step parameters for MLflow.

Attributes:

Name Type Description
model_name str

the name of the MLflow model logged in the MLflow artifact store for the current pipeline.

experiment_name Optional[str]

Name of the MLflow experiment in which the model was logged.

run_name Optional[str]

Name of the MLflow run in which the model was logged.

workers int

number of workers to use for the prediction service

mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

registry_model_name Optional[str]

the name of the model in the model registry

registry_model_version Optional[str]

the version of the model in the model registry

registry_model_stage Optional[zenml.model_registries.base_model_registry.ModelVersionStage]

the stage of the model in the model registry

replace_existing bool

whether to create a new deployment service or not; this parameter is only used when deploying a model that is registered in the MLflow model registry. Default is True.

timeout int

the number of seconds to wait for the service to start/stop.

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
class MLFlowDeployerParameters(BaseParameters):
    """Model deployer step parameters for MLflow.

    Attributes:
        model_name: the name of the MLflow model logged in the MLflow artifact
            store for the current pipeline.
        experiment_name: Name of the MLflow experiment in which the model was
            logged.
        run_name: Name of the MLflow run in which the model was logged.
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
        registry_model_name: the name of the model in the model registry
        registry_model_version: the version of the model in the model registry
        registry_model_stage: the stage of the model in the model registry
        replace_existing: whether to create a new deployment service or not;
            this parameter is only used when deploying a model that is
            registered in the MLflow model registry. Default is True.
        timeout: the number of seconds to wait for the service to start/stop.
    """

    model_name: str = "model"
    registry_model_name: Optional[str] = None
    registry_model_version: Optional[str] = None
    registry_model_stage: Optional[ModelVersionStage] = None
    experiment_name: Optional[str] = None
    run_name: Optional[str] = None
    replace_existing: bool = True
    workers: int = 1
    mlserver: bool = False
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
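
Example (a sketch; the values are illustrative): parameters for a two-worker MLServer deployment with a five-minute start/stop timeout:

from zenml.integrations.mlflow.steps.mlflow_deployer import (
    MLFlowDeployerParameters,
)

params = MLFlowDeployerParameters(
    model_name="model",  # name under which the model was logged to MLflow
    workers=2,
    mlserver=True,
    timeout=300,
)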
mlflow_model_deployer_step (BaseStep)

Model deployer pipeline step for MLflow.

This step deploys a model logged in the MLflow artifact store to a deployment service. The user would typically use this step in a pipeline that deploys a model that was already registered in the MLflow model registry, either manually or by using the mlflow_model_registry_step.

Parameters:

Name Type Description Default
deploy_decision

whether to deploy the model or not

required
model

the model artifact to deploy

required
params

parameters for the deployer step

required

Returns:

Type Description

MLflow deployment service

Exceptions:

Type Description
ValueError

if the MLflow experiment tracker is not found

PARAMETERS_CLASS (BaseParameters) pydantic-model

Model deployer step parameters for MLflow.

Attributes:

Name Type Description
model_name str

the name of the MLflow model logged in the MLflow artifact store for the current pipeline.

experiment_name Optional[str]

Name of the MLflow experiment in which the model was logged.

run_name Optional[str]

Name of the MLflow run in which the model was logged.

workers int

number of workers to use for the prediction service

mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

registry_model_name Optional[str]

the name of the model in the model registry

registry_model_version Optional[str]

the version of the model in the model registry

registry_model_stage Optional[zenml.model_registries.base_model_registry.ModelVersionStage]

the stage of the model in the model registry

replace_existing bool

whether to create a new deployment service or not; this parameter is only used when deploying a model that is registered in the MLflow model registry. Default is True.

timeout int

the number of seconds to wait for the service to start/stop.

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
class MLFlowDeployerParameters(BaseParameters):
    """Model deployer step parameters for MLflow.

    Attributes:
        model_name: the name of the MLflow model logged in the MLflow artifact
            store for the current pipeline.
        experiment_name: Name of the MLflow experiment in which the model was
            logged.
        run_name: Name of the MLflow run in which the model was logged.
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
        registry_model_name: the name of the model in the model registry
        registry_model_version: the version of the model in the model registry
        registry_model_stage: the stage of the model in the model registry
        replace_existing: whether to create a new deployment service or not;
            this parameter is only used when deploying a model that is
            registered in the MLflow model registry. Default is True.
        timeout: the number of seconds to wait for the service to start/stop.
    """

    model_name: str = "model"
    registry_model_name: Optional[str] = None
    registry_model_version: Optional[str] = None
    registry_model_stage: Optional[ModelVersionStage] = None
    experiment_name: Optional[str] = None
    run_name: Optional[str] = None
    replace_existing: bool = True
    workers: int = 1
    mlserver: bool = False
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
entrypoint(deploy_decision, model, params) staticmethod

Model deployer pipeline step for MLflow.

This step deploys a model logged in the MLflow artifact store to a deployment service. The user would typically use this step in a pipeline that deploys a model that was already registered in the MLflow model registry, either manually or by using the mlflow_model_registry_step.

Parameters:

Name Type Description Default
deploy_decision bool

whether to deploy the model or not

required
model UnmaterializedArtifact

the model artifact to deploy

required
params MLFlowDeployerParameters

parameters for the deployer step

required

Returns:

Type Description
MLFlowDeploymentService

MLflow deployment service

Exceptions:

Type Description
ValueError

if the MLflow experiment tracker is not found

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
@step(enable_cache=False)
def mlflow_model_deployer_step(
    deploy_decision: bool,
    model: UnmaterializedArtifact,
    params: MLFlowDeployerParameters,
) -> MLFlowDeploymentService:
    """Model deployer pipeline step for MLflow.

    This step deploys a model logged in the MLflow artifact store to a
    deployment service. The user would typically use this step in a pipeline
    that deploys a model that was already registered in the MLflow model
    registry, either manually or by using the `mlflow_model_registry_step`.

    Args:
        deploy_decision: whether to deploy the model or not
        model: the model artifact to deploy
        params: parameters for the deployer step

    Returns:
        MLflow deployment service

    Raises:
        ValueError: if the MLflow experiment tracker is not found
    """
    model_deployer = cast(
        MLFlowModelDeployer, MLFlowModelDeployer.get_active_model_deployer()
    )

    experiment_tracker = Client().active_stack.experiment_tracker

    if not isinstance(experiment_tracker, MLFlowExperimentTracker):
        raise ValueError(
            "MLflow model deployer step requires an MLflow experiment "
            "tracker. Please add an MLflow experiment tracker to your "
            "stack."
        )

    # get pipeline name, step name and run id
    step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
    pipeline_name = step_env.pipeline_name
    run_name = step_env.run_name
    step_name = step_env.step_name

    # Configure Mlflow so the client points to the correct store
    experiment_tracker.configure_mlflow()
    client = MlflowClient()
    mlflow_run_id = experiment_tracker.get_run_id(
        experiment_name=params.experiment_name or pipeline_name,
        run_name=params.run_name or run_name,
    )

    model_uri = ""
    if mlflow_run_id and client.list_artifacts(
        mlflow_run_id, params.model_name
    ):
        model_uri = artifact_utils.get_artifact_uri(
            run_id=mlflow_run_id, artifact_path=params.model_name
        )

    # fetch existing services with same pipeline name, step name and model name
    existing_services = model_deployer.find_model_server(
        pipeline_name=pipeline_name,
        pipeline_step_name=step_name,
        model_name=params.model_name,
    )

    # create a config for the new model service
    predictor_cfg = MLFlowDeploymentConfig(
        model_name=params.model_name or "",
        model_uri=model_uri,
        workers=params.workers,
        mlserver=params.mlserver,
        registry_model_name=params.registry_model_name or "",
        registry_model_version=params.registry_model_version or "",
        pipeline_name=pipeline_name,
        pipeline_run_id=run_name,
        pipeline_step_name=step_name,
        timeout=params.timeout,
    )

    # Creating a new service with inactive state and status by default
    service = MLFlowDeploymentService(predictor_cfg)
    if existing_services:
        service = cast(MLFlowDeploymentService, existing_services[0])

    # check for conditions to deploy the model
    if not model_uri:
        # an MLflow model was not trained in the current run, so we simply reuse
        # the currently running service created for the same model, if any
        if not existing_services:
            logger.warning(
                f"An MLflow model with name `{params.model_name}` was not "
                f"logged in the current pipeline run and no running MLflow "
                f"model server was found. Please ensure that your pipeline "
                f"includes a step with a MLflow experiment configured that "
                "trains a model and logs it to MLflow. This could also happen "
                "if the current pipeline run did not log an MLflow model  "
                f"because the training step was cached."
            )
            # return an inactive service just because we have to return
            # something
            return service
        logger.info(
            f"An MLflow model with name `{params.model_name}` was not "
            f"trained in the current pipeline run. Reusing the existing "
            f"MLflow model server."
        )
        if not service.is_running:
            service.start(params.timeout)

        # return the existing service
        return service

    # even when the deploy decision is negative, if an existing model server
    # is not running for this pipeline/step, we still have to serve the
    # current model, to ensure that a model server is available at all times
    if not deploy_decision and existing_services:
        logger.info(
            f"Skipping model deployment because the model quality does not "
            f"meet the criteria. Reusing last model server deployed by step "
            f"'{step_name}' and pipeline '{pipeline_name}' for model "
            f"'{params.model_name}'..."
        )
        # even when the deploy decision is negative, we still need to start
        # the previous model server if it is no longer running, to ensure
        # that a model server is available at all times
        if not service.is_running:
            service.start(params.timeout)
        return service

    # create a new model deployment and replace an old one if it exists
    new_service = cast(
        MLFlowDeploymentService,
        model_deployer.deploy_model(
            replace=True,
            config=predictor_cfg,
            timeout=params.timeout,
        ),
    )

    logger.info(
        f"MLflow deployment service started and reachable at:\n"
        f"    {new_service.prediction_url}\n"
    )

    return new_service
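
Example (a rough sketch of the continuous-deployment pattern this step is meant for; the trainer and trigger steps, the experiment tracker component name, and all values are illustrative, and a real trainer would be adapted to your own data and framework):

import mlflow
from sklearn.base import ClassifierMixin
from sklearn.datasets import load_iris
from sklearn.svm import SVC

from zenml.integrations.mlflow.steps.mlflow_deployer import (
    MLFlowDeployerParameters,
    mlflow_model_deployer_step,
)
from zenml.pipelines import pipeline
from zenml.steps import step


@step(experiment_tracker="mlflow_tracker")  # hypothetical tracker component name
def svc_trainer() -> ClassifierMixin:
    """Train a classifier and autolog it to MLflow under the name "model"."""
    mlflow.sklearn.autolog()
    X, y = load_iris(return_X_y=True)
    return SVC().fit(X, y)


@step
def deployment_trigger() -> bool:
    """Stand-in for a real evaluation step; always approves deployment."""
    return True


@pipeline
def continuous_deployment_pipeline(trainer, trigger, deployer):
    model = trainer()
    decision = trigger()
    deployer(deploy_decision=decision, model=model)


continuous_deployment_pipeline(
    trainer=svc_trainer(),
    trigger=deployment_trigger(),
    deployer=mlflow_model_deployer_step(
        params=MLFlowDeployerParameters(model_name="model", timeout=120)
    ),
).run()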
mlflow_model_registry_deployer_step (BaseStep)

Model deployer pipeline step for MLflow.

Parameters:

Name Type Description Default
params

parameters for the deployer step

required

Returns:

Type Description

MLflow deployment service

Exceptions:

Type Description
ValueError

if the registry_model_name is not provided

ValueError

if neither registry_model_version nor registry_model_stage is provided

ValueError

if no MLflow model registry is found in the current active stack

LookupError

if no model version is found in the MLflow model registry.

PARAMETERS_CLASS (BaseParameters) pydantic-model

Model deployer step parameters for MLflow.

Attributes:

Name Type Description
model_name str

the name of the MLflow model logged in the MLflow artifact store for the current pipeline.

experiment_name Optional[str]

Name of the MLflow experiment in which the model was logged.

run_name Optional[str]

Name of the MLflow run in which the model was logged.

workers int

number of workers to use for the prediction service

mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

registry_model_name Optional[str]

the name of the model in the model registry

registry_model_version Optional[str]

the version of the model in the model registry

registry_model_stage Optional[zenml.model_registries.base_model_registry.ModelVersionStage]

the stage of the model in the model registry

replace_existing bool

whether to create a new deployment service or not; this parameter is only used when deploying a model that is registered in the MLflow model registry. Default is True.

timeout int

the number of seconds to wait for the service to start/stop.

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
class MLFlowDeployerParameters(BaseParameters):
    """Model deployer step parameters for MLflow.

    Attributes:
        model_name: the name of the MLflow model logged in the MLflow artifact
            store for the current pipeline.
        experiment_name: Name of the MLflow experiment in which the model was
            logged.
        run_name: Name of the MLflow run in which the model was logged.
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
        registry_model_name: the name of the model in the model registry
        registry_model_version: the version of the model in the model registry
        registry_model_stage: the stage of the model in the model registry
        replace_existing: whether to create a new deployment service or not;
            this parameter is only used when deploying a model that is
            registered in the MLflow model registry. Default is True.
        timeout: the number of seconds to wait for the service to start/stop.
    """

    model_name: str = "model"
    registry_model_name: Optional[str] = None
    registry_model_version: Optional[str] = None
    registry_model_stage: Optional[ModelVersionStage] = None
    experiment_name: Optional[str] = None
    run_name: Optional[str] = None
    replace_existing: bool = True
    workers: int = 1
    mlserver: bool = False
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
entrypoint(params) staticmethod

Model deployer pipeline step for MLflow.

Parameters:

Name Type Description Default
params MLFlowDeployerParameters

parameters for the deployer step

required

Returns:

Type Description
MLFlowDeploymentService

MLflow deployment service

Exceptions:

Type Description
ValueError

if the registry_model_name is not provided

ValueError

if neither registry_model_version nor registry_model_stage is provided

ValueError

if no MLflow model registry is found in the current active stack

LookupError

if no model version is found in the MLflow model registry.

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
@step(enable_cache=False)
def mlflow_model_registry_deployer_step(
    params: MLFlowDeployerParameters,
) -> MLFlowDeploymentService:
    """Model deployer pipeline step for MLflow.

    Args:
        params: parameters for the deployer step

    Returns:
        MLflow deployment service

    Raises:
        ValueError: if the registry_model_name is not provided
        ValueError: if neither registry_model_version nor registry_model_stage
            is provided
        ValueError: if no MLflow model registry is found in the current
            active stack
        LookupError: if no model version is found in the MLflow model registry.
    """
    if not params.registry_model_name:
        raise ValueError(
            "registry_model_name must be provided to the MLflow"
            "model registry deployer step."
        )
    elif not params.registry_model_version and not params.registry_model_stage:
        raise ValueError(
            "Either registry_model_version or registry_model_stage must"
            "be provided in addition to registry_model_name to the MLflow"
            "model registry deployer step. Since the"
            "mlflow_model_registry_deployer_step is used in conjunction with"
            "the mlflow_model_registry."
        )

    model_deployer = cast(
        MLFlowModelDeployer, MLFlowModelDeployer.get_active_model_deployer()
    )

    # fetch the MLflow model registry
    model_registry = Client().active_stack.model_registry
    if not isinstance(model_registry, MLFlowModelRegistry):
        raise ValueError(
            "The MLflow model registry step can only be used with an "
            "MLflow model registry."
        )

    # fetch the model version
    if params.registry_model_version:
        try:
            model_version = model_registry.get_model_version(
                name=params.registry_model_name,
                version=params.registry_model_version,
            )
        except KeyError:
            model_version = None
    elif params.registry_model_stage:
        model_version = model_registry.get_latest_model_version(
            name=params.registry_model_name,
            stage=params.registry_model_stage,
        )
    if not model_version:
        raise LookupError(
            f"No Model Version found for model name "
            f"{params.registry_model_name} and version "
            f"{params.registry_model_version} or stage "
            f"{params.registry_model_stage}"
        )
    if model_version.model_format != MLFLOW_MODEL_FORMAT:
        raise ValueError(
            f"Model version {model_version.version} of model "
            f"{model_version.registered_model.name} is not an MLflow model."
            f"Only MLflow models can be deployed with the MLflow deployer "
            f"using this step."
        )
    # fetch existing services for the same registered model name
    existing_services = (
        model_deployer.find_model_server(
            registry_model_name=model_version.registered_model.name,
        )
        if params.replace_existing
        else None
    )

    # create a config for the new model service
    metadata = model_version.metadata or ModelRegistryModelMetadata()
    predictor_cfg = MLFlowDeploymentConfig(
        model_name=params.model_name or "",
        model_uri=model_version.model_source_uri,
        registry_model_name=model_version.registered_model.name,
        registry_model_version=model_version.version,
        registry_model_stage=model_version.stage.value,
        workers=params.workers,
        mlserver=params.mlserver,
        pipeline_name=metadata.zenml_pipeline_name or "",
        pipeline_run_id=metadata.zenml_pipeline_run_id or "",
        pipeline_step_name=metadata.zenml_step_name or "",
        timeout=params.timeout,
    )

    # Creating a new service with inactive state and status by default
    service = MLFlowDeploymentService(predictor_cfg)
    if existing_services:
        service = cast(MLFlowDeploymentService, existing_services[0])

    # check if the model is already deployed but not running
    if existing_services and not service.is_running:
        service.start(params.timeout)
        return service

    # create a new model deployment and replace an old one if it exists
    new_service = cast(
        MLFlowDeploymentService,
        model_deployer.deploy_model(
            replace=True,
            config=predictor_cfg,
            timeout=params.timeout,
        ),
    )

    logger.info(
        f"MLflow deployment service started and reachable at:\n"
        f"    {new_service.prediction_url}\n"
    )

    return new_service
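
For orientation, here is a minimal sketch of wiring this step into a continuous-deployment pipeline. The parameter class name MLFlowDeployerParameters and the import paths are assumptions based on this module's layout; the field names (registry_model_name, registry_model_stage, replace_existing, timeout) come from the step source above.

from zenml.integrations.mlflow.steps.mlflow_deployer import (
    MLFlowDeployerParameters,
    mlflow_model_registry_deployer_step,
)
from zenml.pipelines import pipeline


@pipeline
def deployment_pipeline(deployer):
    # The step takes no inputs; it looks the model version up in the registry.
    deployer()


deployment_pipeline(
    deployer=mlflow_model_registry_deployer_step(
        params=MLFlowDeployerParameters(  # assumed parameter class name
            registry_model_name="my-classifier",  # required
            registry_model_stage="Production",  # or registry_model_version="1"
            replace_existing=True,
            timeout=60,
        )
    )
).run()
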
mlflow_deployer_step(enable_cache=True, name=None)

Creates a pipeline step to deploy a given ML model with a local MLflow prediction server.

The returned step can be used in a pipeline to implement continuous deployment for an MLflow model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| enable_cache | bool | Specify whether caching is enabled for this step. If no value is passed, caching is enabled by default. | True |
| name | Optional[str] | Name of the step. | None |

Returns:

| Type | Description |
| --- | --- |
| Type[zenml.steps.base_step.BaseStep] | An MLflow model deployer pipeline step. |

Source code in zenml/integrations/mlflow/steps/mlflow_deployer.py
def mlflow_deployer_step(
    enable_cache: bool = True,
    name: Optional[str] = None,
) -> Type[BaseStep]:
    """Creates a pipeline step to deploy a given ML model with a local MLflow prediction server.

    The returned step can be used in a pipeline to implement continuous
    deployment for an MLflow model.

    Args:
        enable_cache: Specify whether caching is enabled for this step. If no
            value is passed, caching is enabled by default.
        name: Name of the step.

    Returns:
        an MLflow model deployer pipeline step
    """
    logger.warning(
        "The `mlflow_deployer_step` function is deprecated. Please "
        "use the built-in `mlflow_model_deployer_step` step instead."
    )
    return mlflow_model_deployer_step
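
Since the wrapper simply returns the built-in step, migrating away from it is a one-line change. A sketch, with the import path assumed from this module's location:

# Deprecated usage:
#   deployer = mlflow_deployer_step()
# Equivalent replacement (import path assumed from this module's location):
from zenml.integrations.mlflow.steps.mlflow_deployer import (
    mlflow_model_deployer_step,
)

deployer = mlflow_model_deployer_step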

mlflow_registry

Implementation of the MLflow model registration pipeline step.

MLFlowRegistryParameters (BaseParameters) pydantic-model

Model registry step parameters for MLflow.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | str | Name of the registered model. | required |
| version | Optional[str] | Version of the registered model. | None |
| trained_model_name | Optional[str] | Name of the model to be deployed. | "model" |
| experiment_name | Optional[str] | Name of the experiment to be used for the run. | None |
| run_name | Optional[str] | Name of the run to be created. | None |
| run_id | Optional[str] | ID of the run to be used. | None |
| model_source_uri | Optional[str] | URI of the model source. If not provided, the model will be fetched from the MLflow tracking server. | None |
| description | Optional[str] | Description of the model. | None |
| metadata | ModelRegistryModelMetadata | Metadata of the model version to be added to the model registry. | ModelRegistryModelMetadata() |

Source code in zenml/integrations/mlflow/steps/mlflow_registry.py
class MLFlowRegistryParameters(BaseParameters):
    """Model registry step parameters for MLflow.

    Args:
        name: Name of the registered model.
        version: Version of the registered model.
        trained_model_name: Name of the model to be deployed.
        experiment_name: Name of the experiment to be used for the run.
        run_name: Name of the run to be created.
        run_id: ID of the run to be used.
        model_source_uri: URI of the model source. If not provided, the model
            will be fetched from the MLflow tracking server.
        description: Description of the model.
        metadata: Metadata of the model version to be added to the model registry.
    """

    name: str
    version: Optional[str] = None
    trained_model_name: Optional[str] = "model"
    model_source_uri: Optional[str] = None
    experiment_name: Optional[str] = None
    run_name: Optional[str] = None
    run_id: Optional[str] = None
    description: Optional[str] = None
    metadata: ModelRegistryModelMetadata = Field(
        default_factory=ModelRegistryModelMetadata
    )
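
As a quick illustration, only name is required when constructing these parameters; everything else falls back to the defaults shown above, and any unset zenml_* metadata fields are filled in automatically by the registration step below. The concrete values here are hypothetical.

params = MLFlowRegistryParameters(
    name="my-classifier",
    version="2",  # the registration step defaults to "1" when unset
    trained_model_name="model",  # artifact path of the model inside the run
    description="Classifier trained on the latest data snapshot",
    metadata=ModelRegistryModelMetadata(zenml_workspace="production"),
)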
mlflow_register_model_step (BaseStep)

MLflow model registry step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | UnmaterializedArtifact | Model to be registered. Not used in the step itself, but required to trigger the step once the model is trained. | required |
| params | MLFlowRegistryParameters | Parameters for the step. | required |

Exceptions:

| Type | Description |
| --- | --- |
| ValueError | If the model registry is not an MLflow model registry. |
| ValueError | If the experiment tracker is not an MLflow experiment tracker. |
| RuntimeError | If no model source URI is provided and no model is found. |
| RuntimeError | If no run ID is provided and no run is found. |

PARAMETERS_CLASS (BaseParameters) pydantic-model

Model registry step parameters for MLflow. This is the same MLFlowRegistryParameters class documented above; its parameters and source code are identical.

entrypoint(model, params) staticmethod

MLflow model registry step.

Parameters and exceptions are identical to those documented for mlflow_register_model_step above.

Source code in zenml/integrations/mlflow/steps/mlflow_registry.py
@step(enable_cache=True)
def mlflow_register_model_step(
    model: UnmaterializedArtifact,
    params: MLFlowRegistryParameters,
) -> None:
    """MLflow model registry step.

    Args:
        model: Model to be registered. This is not used in the step itself,
            but is required to trigger the step when the model is trained.
        params: Parameters for the step.

    Raises:
        ValueError: If the model registry is not an MLflow model registry.
        ValueError: If the experiment tracker is not an MLflow experiment tracker.
        RuntimeError: If no model source URI is provided and no model is found.
        RuntimeError: If no run ID is provided and no run is found.
    """
    # get the experiment tracker and check if it is an MLflow experiment tracker.
    experiment_tracker = Client().active_stack.experiment_tracker
    if not isinstance(experiment_tracker, MLFlowExperimentTracker):
        raise ValueError(
            "The MLflow model registry step can only be used with an "
            "MLflow experiment tracker. Please add an MLflow experiment "
            "tracker to your stack."
        )

    # fetch the MLflow model registry
    model_registry = Client().active_stack.model_registry
    if not isinstance(model_registry, MLFlowModelRegistry):
        raise ValueError(
            "The MLflow model registry step can only be used with an "
            "MLflow model registry."
        )

    # get the pipeline name, run name/ID and workspace from the step environment
    step_env = cast(StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME])
    pipeline_name = step_env.pipeline_name
    pipeline_run_id = step_env.run_name
    pipeline_run_uuid = str(step_env.step_run_info.run_id)
    zenml_workspace = str(model_registry.workspace)

    # Get MLflow run ID either from params or from experiment tracker using
    # pipeline name and run name
    mlflow_run_id = params.run_id or experiment_tracker.get_run_id(
        experiment_name=params.experiment_name or pipeline_name,
        run_name=params.run_name or pipeline_run_id,
    )
    # If no run ID could be found, raise an error
    if not mlflow_run_id:
        raise RuntimeError(
            f"Could not find MLflow run for experiment {pipeline_name} "
            f"and run {pipeline_run_id}."
        )

    # Get MLflow client
    client = model_registry.mlflow_client
    # Lastly, check if the run ID is valid
    try:
        client.get_run(run_id=mlflow_run_id).info.run_id
    except Exception:
        raise RuntimeError(
            f"Could not find MLflow run with ID {mlflow_run_id}."
        )

    # Set model source URI
    model_source_uri = params.model_source_uri or None

    # If no model source URI is set, check whether the run has a model artifact.
    if not params.model_source_uri and client.list_artifacts(
        mlflow_run_id, params.trained_model_name
    ):
        model_source_uri = artifact_utils.get_artifact_uri(
            run_id=mlflow_run_id, artifact_path=params.trained_model_name
        )
    if not model_source_uri:
        raise RuntimeError(
            "No model source URI provided or no model found in the "
            "MLflow tracking server for the given inputs."
        )

    # Fill in any missing metadata fields from the current run
    if params.metadata.zenml_version is None:
        params.metadata.zenml_version = __version__
    if params.metadata.zenml_pipeline_name is None:
        params.metadata.zenml_pipeline_name = pipeline_name
    if params.metadata.zenml_pipeline_run_id is None:
        params.metadata.zenml_pipeline_run_id = pipeline_run_id
    if params.metadata.zenml_pipeline_run_uuid is None:
        params.metadata.zenml_pipeline_run_uuid = pipeline_run_uuid
    if params.metadata.zenml_workspace is None:
        params.metadata.zenml_workspace = zenml_workspace

    # Register model version
    model_version = model_registry.register_model_version(
        name=params.name,
        version=params.version or "1",
        model_source_uri=model_source_uri,
        description=params.description,
        metadata=params.metadata,
    )

    logger.info(
        f"Registered model {params.name} "
        f"with version {model_version.version} "
        f"from source {model_source_uri}."
    )
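
Finally, a hedged sketch of a pipeline that trains a model under the MLflow experiment tracker and then registers it. The trainer step, the tracker name mlflow_tracker, and the import paths are assumptions; the step and parameter names come from the source above.

import mlflow
from sklearn.linear_model import LogisticRegression

from zenml.integrations.mlflow.steps.mlflow_registry import (
    MLFlowRegistryParameters,
    mlflow_register_model_step,
)
from zenml.pipelines import pipeline
from zenml.steps import step


@step(experiment_tracker="mlflow_tracker")  # assumed tracker name in the stack
def trainer() -> LogisticRegression:
    """Hypothetical trainer that logs its model to the active MLflow run."""
    model = LogisticRegression().fit([[0.0], [1.0]], [0, 1])
    mlflow.sklearn.log_model(model, "model")  # matches trained_model_name
    return model


@pipeline
def training_pipeline(train, register):
    model = train()
    register(model)


training_pipeline(
    train=trainer(),
    register=mlflow_register_model_step(
        params=MLFlowRegistryParameters(name="my-classifier")
    ),
).run()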