Skip to content

Metadata Stores

zenml.metadata_stores special

Initialization of ZenML's metadata stores.

The configuration of each pipeline, step, and backend, as well as all produced artifacts, is tracked within the metadata store. The metadata store is an SQL database and can be either SQLite or MySQL.

Metadata are the pieces of information tracked about the pipelines, experiments and configurations that you are running with ZenML. Metadata are stored inside the metadata store.

base_metadata_store

Base implementation of a metadata store.

BaseMetadataStore (StackComponent, ABC) pydantic-model

Base class for all ZenML metadata stores.

Source code in zenml/metadata_stores/base_metadata_store.py
class BaseMetadataStore(StackComponent, ABC):
    """Base class for all ZenML metadata stores.

    Wraps a TFX ML Metadata (MLMD) store and builds post-execution views
    (pipelines, runs, steps and artifacts) from its contexts, executions,
    events and artifacts. Subclasses only have to implement
    `get_tfx_metadata_config`.
    """

    # Class Configuration
    TYPE: ClassVar[StackComponentType] = StackComponentType.METADATA_STORE

    upgrade_migration_enabled: bool = True
    _store: Optional[metadata_store.MetadataStore] = None

    @property
    def store(self) -> metadata_store.MetadataStore:
        """General property that hooks into TFX metadata store.

        The MLMD store is created lazily on first access and cached. Upgrade
        migration is only requested when `upgrade_migration_enabled` is set
        AND the config is a direct `ConnectionConfig` (not a client config).

        Returns:
            metadata_store.MetadataStore: TFX metadata store.
        """
        if self._store is None:
            config = self.get_tfx_metadata_config()
            self._store = metadata_store.MetadataStore(
                config,
                enable_upgrade_migration=self.upgrade_migration_enabled
                and isinstance(config, metadata_store_pb2.ConnectionConfig),
            )
        return self._store

    @abstractmethod
    def get_tfx_metadata_config(
        self,
    ) -> Union[
        metadata_store_pb2.ConnectionConfig,
        metadata_store_pb2.MetadataStoreClientConfig,
    ]:
        """Return tfx metadata config.

        Returns:
            tfx metadata config.
        """
        raise NotImplementedError

    @property
    def step_type_mapping(self) -> Dict[int, str]:
        """Maps type_ids to step names.

        Queries the store for all execution types on every access.

        Returns:
            Dict[int, str]: a mapping from type_ids to step names.
        """
        return {
            type_.id: type_.name for type_ in self.store.get_execution_types()
        }

    def _check_if_executions_belong_to_pipeline(
        self,
        executions: List[proto.Execution],
        pipeline: PipelineView,
    ) -> bool:
        """Returns `True` if the executions are associated with the pipeline context.

        Args:
            executions: List of executions.
            pipeline: Pipeline to check.

        Returns:
            `True` if any execution is linked to a context whose ID equals
            the pipeline's ID, `False` otherwise.
        """
        for execution in executions:
            associated_contexts = self.store.get_contexts_by_execution(
                execution.id
            )
            for context in associated_contexts:
                if context.id == pipeline._id:  # noqa
                    return True
        return False

    def _get_step_view_from_execution(
        self, execution: proto.Execution
    ) -> StepView:
        """Get original StepView from an execution.

        Args:
            execution: proto.Execution object from mlmd store.

        Returns:
            Original `StepView` derived from the proto.Execution.

        Raises:
            KeyError: If the execution is not associated with a step.
        """
        impl_name = self.step_type_mapping[execution.type_id].split(".")[-1]

        step_name_property = execution.custom_properties.get(
            INTERNAL_EXECUTION_PARAMETER_PREFIX + PARAM_PIPELINE_PARAMETER_NAME,
            None,
        )
        if step_name_property:
            step_name = json.loads(step_name_property.string_value)
        else:
            raise KeyError(
                f"Step name missing for execution with ID {execution.id}. "
                f"This error probably occurs because you're using ZenML "
                f"version 0.5.4 or newer but your metadata store contains "
                f"data from previous versions."
            )

        step_parameters = {}
        for k, v in execution.custom_properties.items():
            if not k.startswith(INTERNAL_EXECUTION_PARAMETER_PREFIX):
                try:
                    step_parameters[k] = json.loads(v.string_value)
                except JSONDecodeError:
                    # this means there is a property in there that is neither
                    # an internal one or one created by zenml. Therefore, we can
                    # ignore it
                    pass

        # TODO [ENG-222]: This is a lot of querying to the metadata store. We
        #  should refactor and make it nicer. Probably it makes more sense
        #  to first get `executions_ids_for_current_run` and then filter on
        #  `event.execution_id in execution_ids_for_current_run`.
        # Core logic here is that we get the event of this particular execution
        # id that gives us the artifacts of this execution. We then go through
        # all `input` artifacts of this execution and get all events related to
        # that artifact. This in turn gives us other events for which this
        # artifact was an `output` artifact. Then we simply need to sort by
        # time to get the most recent execution (i.e. step) that produced that
        # particular artifact.
        events_for_execution = self.store.get_events_by_execution_ids(
            [execution.id]
        )

        parents_step_ids = set()
        for current_event in events_for_execution:
            if current_event.type == current_event.INPUT:
                # this means the artifact is an input artifact
                events_for_input_artifact = [
                    e
                    for e in self.store.get_events_by_artifact_ids(
                        [current_event.artifact_id]
                    )
                    # should be output type and should NOT be the same id as
                    # the execution we are querying and it should be BEFORE
                    # the time of the current event.
                    if e.type == e.OUTPUT
                    and e.execution_id != current_event.execution_id
                    and e.milliseconds_since_epoch
                    < current_event.milliseconds_since_epoch
                ]

                if not events_for_input_artifact:
                    # No strictly-earlier OUTPUT event from another execution
                    # exists for this input artifact, so there is no parent to
                    # record. Skip it instead of raising IndexError below.
                    continue

                # sort by time
                events_for_input_artifact.sort(
                    key=lambda x: x.milliseconds_since_epoch  # type: ignore[no-any-return] # noqa
                )
                # take the latest one and add execution to the parents.
                parents_step_ids.add(events_for_input_artifact[-1].execution_id)

        return StepView(
            id_=execution.id,
            parents_step_ids=list(parents_step_ids),
            entrypoint_name=impl_name,
            name=step_name,
            parameters=step_parameters,
            metadata_store=self,
        )

    def get_pipelines(self) -> List[PipelineView]:
        """Returns a list of all pipelines stored in this metadata store.

        Returns:
            List[PipelineView]: a list of all pipelines stored in this metadata store.
        """
        pipelines = []
        for pipeline_context in self.store.get_contexts_by_type(
            PIPELINE_CONTEXT_TYPE_NAME
        ):
            pipeline = PipelineView(
                id_=pipeline_context.id,
                name=pipeline_context.name,
                metadata_store=self,
            )
            pipelines.append(pipeline)

        logger.debug("Fetched %d pipelines.", len(pipelines))
        return pipelines

    def get_pipeline(self, pipeline_name: str) -> Optional[PipelineView]:
        """Returns a pipeline for the given name.

        Args:
            pipeline_name: Name of the pipeline.

        Returns:
            PipelineView if found, None otherwise.
        """
        pipeline_context = self.store.get_context_by_type_and_name(
            PIPELINE_CONTEXT_TYPE_NAME, pipeline_name
        )
        if pipeline_context:
            logger.debug("Fetched pipeline with name '%s'", pipeline_name)
            return PipelineView(
                id_=pipeline_context.id,
                name=pipeline_context.name,
                metadata_store=self,
            )
        else:
            logger.info("No pipelines found for name '%s'", pipeline_name)
            return None

    def get_pipeline_runs(
        self, pipeline: PipelineView
    ) -> Dict[str, PipelineRunView]:
        """Gets all runs for the given pipeline.

        Every pipeline-run context in the store is inspected; a run is kept
        only if one of its executions is linked to the pipeline's context.

        Args:
            pipeline: a Pipeline object for which you want the runs.

        Returns:
            A dictionary of pipeline run names to PipelineRunView.
        """
        all_pipeline_runs = self.store.get_contexts_by_type(
            PIPELINE_RUN_CONTEXT_TYPE_NAME
        )
        runs: Dict[str, PipelineRunView] = OrderedDict()

        for run in all_pipeline_runs:
            executions = self.store.get_executions_by_context(run.id)
            if self._check_if_executions_belong_to_pipeline(
                executions, pipeline
            ):
                run_view = PipelineRunView(
                    id_=run.id,
                    name=run.name,
                    executions=executions,
                    metadata_store=self,
                )
                runs[run.name] = run_view

        logger.debug(
            "Fetched %d pipeline runs for pipeline named '%s'.",
            len(runs),
            pipeline.name,
        )

        return runs

    def get_pipeline_run(
        self, pipeline: PipelineView, run_name: str
    ) -> Optional[PipelineRunView]:
        """Gets a specific run for the given pipeline.

        Args:
            pipeline: The pipeline for which to get the run.
            run_name: The name of the run to get.

        Returns:
            The pipeline run with the given name, or None if no run context
            with that name exists or it does not belong to `pipeline`.
        """
        run = self.store.get_context_by_type_and_name(
            PIPELINE_RUN_CONTEXT_TYPE_NAME, run_name
        )

        if not run:
            # No context found for the given run name
            return None

        executions = self.store.get_executions_by_context(run.id)
        if self._check_if_executions_belong_to_pipeline(executions, pipeline):
            logger.debug("Fetched pipeline run with name '%s'", run_name)
            return PipelineRunView(
                id_=run.id,
                name=run.name,
                executions=executions,
                metadata_store=self,
            )

        logger.info("No pipeline run found for name '%s'", run_name)
        return None

    def get_pipeline_run_steps(
        self, pipeline_run: PipelineRunView
    ) -> Dict[str, StepView]:
        """Gets all steps for the given pipeline run.

        Args:
            pipeline_run: The pipeline run to get the steps for.

        Returns:
            A dictionary of step names to step views.
        """
        steps: Dict[str, StepView] = OrderedDict()
        # reverse the executions as they get returned in reverse chronological
        # order from the metadata store
        for execution in reversed(pipeline_run._executions):  # noqa
            step = self._get_step_view_from_execution(execution)
            steps[step.name] = step

        logger.debug(
            "Fetched %d steps for pipeline run '%s'.",
            len(steps),
            pipeline_run.name,
        )

        return steps

    def get_step_by_id(self, step_id: int) -> StepView:
        """Gets a `StepView` by its ID.

        Args:
            step_id (int): The ID of the step to get.

        Returns:
            StepView: The `StepView` with the given ID.
        """
        execution = self.store.get_executions_by_id([step_id])[0]
        return self._get_step_view_from_execution(execution)

    def get_step_status(self, step: StepView) -> ExecutionStatus:
        """Gets the execution status of a single step.

        Args:
            step (StepView): The step to get the status for.

        Returns:
            ExecutionStatus: The status of the step. Any MLMD state other
            than COMPLETE, RUNNING or CACHED is reported as FAILED.
        """
        # Named `execution` (not `proto`) to avoid shadowing the `proto`
        # module used in this class's type annotations.
        execution = self.store.get_executions_by_id([step._id])[0]  # noqa
        state = execution.last_known_state

        if state == execution.COMPLETE:
            return ExecutionStatus.COMPLETED
        elif state == execution.RUNNING:
            return ExecutionStatus.RUNNING
        elif state == execution.CACHED:
            return ExecutionStatus.CACHED
        else:
            return ExecutionStatus.FAILED

    def get_step_artifacts(
        self, step: StepView
    ) -> Tuple[Dict[str, ArtifactView], Dict[str, ArtifactView]]:
        """Returns input and output artifacts for the given step.

        Each event of the step's execution links it to one artifact. INPUT
        events populate `inputs`, OUTPUT events populate `outputs`. Both are
        keyed by the first path-step key of the event.

        Args:
            step: The step for which to get the artifacts.

        Returns:
            A tuple (inputs, outputs) where inputs and outputs
            are both Dicts mapping artifact names
            to the input and output artifacts respectively.
        """
        # maps artifact types to their string representation
        artifact_type_mapping = {
            type_.id: type_.name for type_ in self.store.get_artifact_types()
        }

        events = self.store.get_events_by_execution_ids([step._id])  # noqa
        # Look artifacts up by ID rather than zipping two sorted lists: if the
        # store returns fewer artifacts than there are events (e.g. several
        # events referencing the same artifact ID), a positional zip pairs
        # events with the wrong artifacts and silently drops trailing events.
        artifacts_by_id = {
            artifact_proto.id: artifact_proto
            for artifact_proto in self.store.get_artifacts_by_id(
                [event.artifact_id for event in events]
            )
        }

        inputs: Dict[str, ArtifactView] = {}
        outputs: Dict[str, ArtifactView] = {}

        # Process events in ascending artifact-ID order, as before.
        for event_proto in sorted(events, key=lambda x: x.artifact_id):
            artifact_proto = artifacts_by_id.get(event_proto.artifact_id)
            if artifact_proto is None:
                # NOTE(review): no artifact returned for this event's ID;
                # skip it instead of pairing it with an unrelated artifact.
                continue

            artifact_type = artifact_type_mapping[artifact_proto.type_id]
            artifact_name = event_proto.path.steps[0].key

            materializer = artifact_proto.properties[
                MATERIALIZER_PROPERTY_KEY
            ].string_value

            data_type = artifact_proto.properties[
                DATATYPE_PROPERTY_KEY
            ].string_value

            parent_step_id = step.id
            if event_proto.type == event_proto.INPUT:
                # In the case that this is an input event, we actually need
                # to resolve it via its parents outputs.
                for parent in step.parent_steps:
                    for a in parent.outputs.values():
                        if artifact_proto.id == a.id:
                            parent_step_id = parent.id

            artifact = ArtifactView(
                id_=event_proto.artifact_id,
                type_=artifact_type,
                uri=artifact_proto.uri,
                materializer=materializer,
                data_type=data_type,
                metadata_store=self,
                parent_step_id=parent_step_id,
            )

            if event_proto.type == event_proto.INPUT:
                inputs[artifact_name] = artifact
            elif event_proto.type == event_proto.OUTPUT:
                outputs[artifact_name] = artifact

        logger.debug(
            "Fetched %d inputs and %d outputs for step '%s'.",
            len(inputs),
            len(outputs),
            step.entrypoint_name,
        )

        return inputs, outputs

    def get_producer_step_from_artifact(
        self, artifact: ArtifactView
    ) -> StepView:
        """Returns original StepView from an ArtifactView.

        Args:
            artifact: ArtifactView to be queried.

        Returns:
            Original StepView that produced the artifact, i.e. the first
            execution returned by the store among those with an OUTPUT event
            for this artifact.
        """
        executions_ids = set(
            event.execution_id
            for event in self.store.get_events_by_artifact_ids([artifact.id])
            if event.type == event.OUTPUT
        )
        execution = self.store.get_executions_by_id(executions_ids)[0]
        return self._get_step_view_from_execution(execution)
step_type_mapping: Dict[int, str] property readonly

Maps type_ids to step names.

Returns:

Type Description
Dict[int, str]

a mapping from type_ids to step names.

store: MetadataStore property readonly

General property that hooks into TFX metadata store.

Returns:

Type Description
metadata_store.MetadataStore

TFX metadata store.

get_pipeline(self, pipeline_name)

Returns a pipeline for the given name.

Parameters:

Name Type Description Default
pipeline_name str

Name of the pipeline.

required

Returns:

Type Description
Optional[zenml.post_execution.pipeline.PipelineView]

PipelineView if found, None otherwise.

Source code in zenml/metadata_stores/base_metadata_store.py
def get_pipeline(self, pipeline_name: str) -> Optional[PipelineView]:
    """Look up a single pipeline by its name.

    Args:
        pipeline_name: Name of the pipeline.

    Returns:
        PipelineView if found, None otherwise.
    """
    context = self.store.get_context_by_type_and_name(
        PIPELINE_CONTEXT_TYPE_NAME, pipeline_name
    )
    if not context:
        # Nothing registered under this name.
        logger.info("No pipelines found for name '%s'", pipeline_name)
        return None

    logger.debug("Fetched pipeline with name '%s'", pipeline_name)
    return PipelineView(
        id_=context.id,
        name=context.name,
        metadata_store=self,
    )
get_pipeline_run(self, pipeline, run_name)

Gets a specific run for the given pipeline.

Parameters:

Name Type Description Default
pipeline PipelineView

The pipeline for which to get the run.

required
run_name str

The name of the run to get.

required

Returns:

Type Description
Optional[zenml.post_execution.pipeline_run.PipelineRunView]

The pipeline run with the given name.

Source code in zenml/metadata_stores/base_metadata_store.py
def get_pipeline_run(
    self, pipeline: PipelineView, run_name: str
) -> Optional[PipelineRunView]:
    """Fetch one named run, provided it belongs to the given pipeline.

    Args:
        pipeline: The pipeline for which to get the run.
        run_name: The name of the run to get.

    Returns:
        The pipeline run with the given name, or None if there is no such
        run context or it is not associated with `pipeline`.
    """
    run_context = self.store.get_context_by_type_and_name(
        PIPELINE_RUN_CONTEXT_TYPE_NAME, run_name
    )
    if not run_context:
        # No context found for the given run name
        return None

    run_executions = self.store.get_executions_by_context(run_context.id)
    belongs_to_pipeline = self._check_if_executions_belong_to_pipeline(
        run_executions, pipeline
    )
    if not belongs_to_pipeline:
        logger.info("No pipeline run found for name '%s'", run_name)
        return None

    logger.debug("Fetched pipeline run with name '%s'", run_name)
    return PipelineRunView(
        id_=run_context.id,
        name=run_context.name,
        executions=run_executions,
        metadata_store=self,
    )
get_pipeline_run_steps(self, pipeline_run)

Gets all steps for the given pipeline run.

Parameters:

Name Type Description Default
pipeline_run PipelineRunView

The pipeline run to get the steps for.

required

Returns:

Type Description
Dict[str, zenml.post_execution.step.StepView]

A dictionary of step names to step views.

Source code in zenml/metadata_stores/base_metadata_store.py
def get_pipeline_run_steps(
    self, pipeline_run: PipelineRunView
) -> Dict[str, StepView]:
    """Build step views for every execution of a pipeline run.

    Args:
        pipeline_run: The pipeline run to get the steps for.

    Returns:
        A dictionary of step names to step views, in chronological order.
    """
    # The metadata store hands executions back newest-first, so walk them
    # in reverse to get chronological order.
    oldest_first = reversed(pipeline_run._executions)  # noqa
    step_views = (
        self._get_step_view_from_execution(execution)
        for execution in oldest_first
    )
    steps: Dict[str, StepView] = OrderedDict(
        (step_view.name, step_view) for step_view in step_views
    )

    logger.debug(
        "Fetched %d steps for pipeline run '%s'.",
        len(steps),
        pipeline_run.name,
    )
    return steps
get_pipeline_runs(self, pipeline)

Gets all runs for the given pipeline.

Parameters:

Name Type Description Default
pipeline PipelineView

a Pipeline object for which you want the runs.

required

Returns:

Type Description
Dict[str, zenml.post_execution.pipeline_run.PipelineRunView]

A dictionary of pipeline run names to PipelineRunView.

Source code in zenml/metadata_stores/base_metadata_store.py
def get_pipeline_runs(
    self, pipeline: PipelineView
) -> Dict[str, PipelineRunView]:
    """Collect every run that belongs to the given pipeline.

    Args:
        pipeline: a Pipeline object for which you want the runs.

    Returns:
        A dictionary of pipeline run names to PipelineRunView.
    """
    run_contexts = self.store.get_contexts_by_type(
        PIPELINE_RUN_CONTEXT_TYPE_NAME
    )
    runs: Dict[str, PipelineRunView] = OrderedDict()

    for run_context in run_contexts:
        run_executions = self.store.get_executions_by_context(run_context.id)
        # Only keep runs whose executions are linked to this pipeline.
        if not self._check_if_executions_belong_to_pipeline(
            run_executions, pipeline
        ):
            continue
        runs[run_context.name] = PipelineRunView(
            id_=run_context.id,
            name=run_context.name,
            executions=run_executions,
            metadata_store=self,
        )

    logger.debug(
        "Fetched %d pipeline runs for pipeline named '%s'.",
        len(runs),
        pipeline.name,
    )
    return runs
get_pipelines(self)

Returns a list of all pipelines stored in this metadata store.

Returns:

Type Description
List[PipelineView]

a list of all pipelines stored in this metadata store.

Source code in zenml/metadata_stores/base_metadata_store.py
def get_pipelines(self) -> List[PipelineView]:
    """List every pipeline known to this metadata store.

    Returns:
        List[PipelineView]: one view per pipeline context in the store.
    """
    pipeline_contexts = self.store.get_contexts_by_type(
        PIPELINE_CONTEXT_TYPE_NAME
    )
    pipelines = [
        PipelineView(
            id_=context.id,
            name=context.name,
            metadata_store=self,
        )
        for context in pipeline_contexts
    ]

    logger.debug("Fetched %d pipelines.", len(pipelines))
    return pipelines
get_producer_step_from_artifact(self, artifact)

Returns original StepView from an ArtifactView.

Parameters:

Name Type Description Default
artifact ArtifactView

ArtifactView to be queried.

required

Returns:

Type Description
StepView

Original StepView that produced the artifact.

Source code in zenml/metadata_stores/base_metadata_store.py
def get_producer_step_from_artifact(
    self, artifact: ArtifactView
) -> StepView:
    """Resolve the step whose execution produced the given artifact.

    Args:
        artifact: ArtifactView to be queried.

    Returns:
        Original StepView that produced the artifact.
    """
    artifact_events = self.store.get_events_by_artifact_ids([artifact.id])
    # Producers are the executions that emitted this artifact as an OUTPUT.
    producer_ids = {
        event.execution_id
        for event in artifact_events
        if event.type == event.OUTPUT
    }
    producer = self.store.get_executions_by_id(producer_ids)[0]
    return self._get_step_view_from_execution(producer)
get_step_artifacts(self, step)

Returns input and output artifacts for the given step.

Parameters:

Name Type Description Default
step StepView

The step for which to get the artifacts.

required

Returns:

Type Description
Tuple[Dict[str, zenml.post_execution.artifact.ArtifactView], Dict[str, zenml.post_execution.artifact.ArtifactView]]

A tuple (inputs, outputs) where inputs and outputs are both Dicts mapping artifact names to the input and output artifacts respectively.

Source code in zenml/metadata_stores/base_metadata_store.py
def get_step_artifacts(
    self, step: StepView
) -> Tuple[Dict[str, ArtifactView], Dict[str, ArtifactView]]:
    """Returns input and output artifacts for the given step.

    Each event of the step's execution links it to one artifact. INPUT
    events populate `inputs`, OUTPUT events populate `outputs`. Both are
    keyed by the first path-step key of the event.

    Args:
        step: The step for which to get the artifacts.

    Returns:
        A tuple (inputs, outputs) where inputs and outputs
        are both Dicts mapping artifact names
        to the input and output artifacts respectively.
    """
    # maps artifact types to their string representation
    artifact_type_mapping = {
        type_.id: type_.name for type_ in self.store.get_artifact_types()
    }

    events = self.store.get_events_by_execution_ids([step._id])  # noqa
    # Look artifacts up by ID rather than zipping two sorted lists: if the
    # store returns fewer artifacts than there are events (e.g. several
    # events referencing the same artifact ID), a positional zip pairs
    # events with the wrong artifacts and silently drops trailing events.
    artifacts_by_id = {
        artifact_proto.id: artifact_proto
        for artifact_proto in self.store.get_artifacts_by_id(
            [event.artifact_id for event in events]
        )
    }

    inputs: Dict[str, ArtifactView] = {}
    outputs: Dict[str, ArtifactView] = {}

    # Process events in ascending artifact-ID order, as before.
    for event_proto in sorted(events, key=lambda x: x.artifact_id):
        artifact_proto = artifacts_by_id.get(event_proto.artifact_id)
        if artifact_proto is None:
            # NOTE(review): no artifact returned for this event's ID;
            # skip it instead of pairing it with an unrelated artifact.
            continue

        artifact_type = artifact_type_mapping[artifact_proto.type_id]
        artifact_name = event_proto.path.steps[0].key

        materializer = artifact_proto.properties[
            MATERIALIZER_PROPERTY_KEY
        ].string_value

        data_type = artifact_proto.properties[
            DATATYPE_PROPERTY_KEY
        ].string_value

        parent_step_id = step.id
        if event_proto.type == event_proto.INPUT:
            # In the case that this is an input event, we actually need
            # to resolve it via its parents outputs.
            for parent in step.parent_steps:
                for a in parent.outputs.values():
                    if artifact_proto.id == a.id:
                        parent_step_id = parent.id

        artifact = ArtifactView(
            id_=event_proto.artifact_id,
            type_=artifact_type,
            uri=artifact_proto.uri,
            materializer=materializer,
            data_type=data_type,
            metadata_store=self,
            parent_step_id=parent_step_id,
        )

        if event_proto.type == event_proto.INPUT:
            inputs[artifact_name] = artifact
        elif event_proto.type == event_proto.OUTPUT:
            outputs[artifact_name] = artifact

    logger.debug(
        "Fetched %d inputs and %d outputs for step '%s'.",
        len(inputs),
        len(outputs),
        step.entrypoint_name,
    )

    return inputs, outputs
get_step_by_id(self, step_id)

Gets a StepView by its ID.

Parameters:

Name Type Description Default
step_id int

The ID of the step to get.

required

Returns:

Type Description
StepView

The StepView with the given ID.

Source code in zenml/metadata_stores/base_metadata_store.py
def get_step_by_id(self, step_id: int) -> StepView:
    """Fetch a single execution by ID and wrap it as a `StepView`.

    Args:
        step_id (int): The ID of the step to get.

    Returns:
        StepView: The `StepView` with the given ID.
    """
    matching_executions = self.store.get_executions_by_id([step_id])
    return self._get_step_view_from_execution(matching_executions[0])
get_step_status(self, step)

Gets the execution status of a single step.

Parameters:

Name Type Description Default
step StepView

The step to get the status for.

required

Returns:

Type Description
ExecutionStatus

The status of the step.

Source code in zenml/metadata_stores/base_metadata_store.py
def get_step_status(self, step: StepView) -> ExecutionStatus:
    """Translate a step's MLMD execution state into an `ExecutionStatus`.

    Args:
        step (StepView): The step to get the status for.

    Returns:
        ExecutionStatus: COMPLETED, RUNNING or CACHED for the matching MLMD
        states; FAILED for any other state.
    """
    # Named `execution` rather than `proto` so the `proto` module name is
    # not shadowed.
    execution = self.store.get_executions_by_id([step._id])[0]  # noqa
    state = execution.last_known_state

    if state == execution.COMPLETE:
        return ExecutionStatus.COMPLETED
    if state == execution.RUNNING:
        return ExecutionStatus.RUNNING
    if state == execution.CACHED:
        return ExecutionStatus.CACHED
    # Every remaining state is reported as a failure.
    return ExecutionStatus.FAILED
get_tfx_metadata_config(self)

Return tfx metadata config.

Returns:

Type Description
Union[ml_metadata.proto.metadata_store_pb2.ConnectionConfig, ml_metadata.proto.metadata_store_pb2.MetadataStoreClientConfig]

tfx metadata config.

Source code in zenml/metadata_stores/base_metadata_store.py
@abstractmethod
def get_tfx_metadata_config(
    self,
) -> Union[
    metadata_store_pb2.ConnectionConfig,
    metadata_store_pb2.MetadataStoreClientConfig,
]:
    """Build the MLMD configuration used to open this metadata store.

    Concrete metadata stores must override this method.

    Returns:
        tfx metadata config.
    """
    raise NotImplementedError()

mysql_metadata_store

Implementation of a MySQL metadata store.

MySQLMetadataStore (BaseMetadataStore) pydantic-model

MySQL backend for ZenML metadata store.

Source code in zenml/metadata_stores/mysql_metadata_store.py
class MySQLMetadataStore(BaseMetadataStore):
    """MySQL backend for ZenML metadata store.

    Credentials can be given directly through `username`/`password`, or
    through a MySQL secret named by `secret` in the active stack's secrets
    manager. A given credential may not be defined in both places. SSL
    options are taken from the secret only.
    """

    port: int = 3306
    host: str
    database: str
    secret: Optional[str] = None
    username: Optional[str] = None
    password: Optional[str] = None

    # Class Configuration
    FLAVOR: ClassVar[str] = "mysql"

    def get_tfx_metadata_config(
        self,
    ) -> Union[
        metadata_store_pb2.ConnectionConfig,
        metadata_store_pb2.MetadataStoreClientConfig,
    ]:
        """Return tfx metadata config for MySQL metadata store.

        When a secret is configured, its SSL key/CA/cert contents are
        written as owner-only `.pem` files under the global config directory
        (`mysql-metadata/<uuid>/`), and their paths are passed as SSL options.

        Returns:
            The tfx metadata config.

        Raises:
            RuntimeError: If you have configured your metadata store incorrectly.
        """
        import os

        config = MySQLDatabaseConfig(
            host=self.host,
            port=self.port,
            database=self.database,
        )

        secret = self._get_mysql_secret()

        # Set the user
        if self.username:
            if secret and secret.user:
                raise RuntimeError(
                    f"Both the metadata store {self.name} and the secret "
                    f"{self.secret} within your secrets manager define "
                    f"a username `{self.username}` and `{secret.user}`. Please "
                    f"make sure that you only use one."
                )
            else:
                config.user = self.username
        else:
            if secret and secret.user:
                config.user = secret.user
            else:
                raise RuntimeError(
                    "Your metadata store does not have a username. Please "
                    "provide it either by defining it upon registration or "
                    "through a MySQL secret."
                )

        # Set the password
        if self.password:
            if secret and secret.password:
                raise RuntimeError(
                    f"Both the metadata store {self.name} and the secret "
                    f"{self.secret} within your secrets manager define "
                    f"a password. Please make sure that you only use one."
                )
            else:
                config.password = self.password
        else:
            if secret and secret.password:
                config.password = secret.password

        # Set the SSL configuration if there is one
        if secret:
            secret_folder = Path(
                GlobalConfiguration().config_directory,
                "mysql-metadata",
                str(self.uuid),
            )

            ssl_option_values = {}
            # Handle the files
            prefix = "ssl_"
            for key in ["ssl_key", "ssl_ca", "ssl_cert"]:
                content = getattr(secret, key)
                if content:
                    fileio.makedirs(str(secret_folder))
                    file_path = Path(secret_folder, f"{key}.pem")
                    # Remove the literal "ssl_" prefix. `str.lstrip` would
                    # strip any of the characters "s", "l", "_" instead.
                    ssl_option_values[key[len(prefix):]] = str(file_path)
                    # Create the file owner-only before any key material is
                    # written, so it is never readable by others.
                    fd = os.open(
                        file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600
                    )
                    with os.fdopen(fd, "w") as f:
                        # The O_CREAT mode only applies to new files; tighten
                        # a pre-existing file before writing into it.
                        file_path.chmod(0o600)
                        f.write(content)

            # Handle additional params
            ssl_option_values[
                "verify_server_cert"
            ] = secret.ssl_verify_server_cert

            ssl_options = MySQLDatabaseConfig.SSLOptions(**ssl_option_values)
            config.ssl_options.CopyFrom(ssl_options)

        return metadata_store_pb2.ConnectionConfig(mysql=config)

    def _get_mysql_secret(self) -> Any:
        """Method which returns a MySQL secret from the secrets manager.

        Returns:
            Any: The MySQL secret, or None if no `secret` is configured.

        Raises:
            RuntimeError: If you don't have a secrets manager as part of your
                stack, if the secret does not exist, or if it does not use
                the MySQL secret schema.
        """
        if not self.secret:
            return None

        active_stack = Repository().active_stack
        secret_manager = active_stack.secrets_manager
        if secret_manager is None:
            raise RuntimeError(
                f"The metadata store `{self.name}` that you are using "
                f"requires a secret. However, your stack "
                f"`{active_stack.name}` does not have a secrets manager."
            )

        try:
            secret = secret_manager.get_secret(self.secret)
        except KeyError as err:
            raise RuntimeError(
                f"The secret `{self.secret}` used for your MySQL metadata "
                f"store `{self.name}` does not exist in your secrets "
                f"manager `{secret_manager.name}`."
            ) from err

        from zenml.metadata_stores.mysql_secret_schema import (
            MYSQLSecretSchema,
        )

        if not isinstance(secret, MYSQLSecretSchema):
            raise RuntimeError(
                f"If you are using a secret with a MySQL Metadata "
                f"Store, please make sure to use the schema: "
                f"{MYSQLSecretSchema.TYPE}"
            )
        return secret
get_tfx_metadata_config(self)

Return tfx metadata config for MySQL metadata store.

Returns:

Type Description
Union[ml_metadata.proto.metadata_store_pb2.ConnectionConfig, ml_metadata.proto.metadata_store_pb2.MetadataStoreClientConfig]

The tfx metadata config.

Exceptions:

Type Description
RuntimeError

If you have configured your metadata store incorrectly.

Source code in zenml/metadata_stores/mysql_metadata_store.py
def get_tfx_metadata_config(
    self,
) -> Union[
    metadata_store_pb2.ConnectionConfig,
    metadata_store_pb2.MetadataStoreClientConfig,
]:
    """Return tfx metadata config for MySQL metadata store.

    The username and password may each be provided either directly on the
    metadata store or through the MySQL secret referenced by `self.secret`,
    but never in both places. A username is mandatory; a password is not.
    If a secret is configured, any SSL files it contains are written as
    `<name>.pem` files into a per-store folder below the global config
    directory and referenced from the connection's SSL options.

    Returns:
        The tfx metadata config.

    Raises:
        RuntimeError: If you have configured your metadata store incorrectly,
            i.e. the username or password is defined both on the metadata
            store and in the secret, or no username is defined at all.
    """
    config = MySQLDatabaseConfig(
        host=self.host,
        port=self.port,
        database=self.database,
    )

    secret = self._get_mysql_secret()

    # Set the user: exactly one of the store and the secret must define it.
    if self.username:
        if secret and secret.user:
            raise RuntimeError(
                f"Both the metadata store {self.name} and the secret "
                f"{self.secret} within your secrets manager define "
                f"a username `{self.username}` and `{secret.user}`. Please "
                f"make sure that you only use one."
            )
        config.user = self.username
    elif secret and secret.user:
        config.user = secret.user
    else:
        raise RuntimeError(
            "Your metadata store does not have a username. Please "
            "provide it either by defining it upon registration or "
            "through a MySQL secret."
        )

    # Set the password: optional, but must not be defined in both places.
    if self.password:
        if secret and secret.password:
            raise RuntimeError(
                f"Both the metadata store {self.name} and the secret "
                f"{self.secret} within your secrets manager define "
                f"a password. Please make sure that you only use one."
            )
        config.password = self.password
    elif secret and secret.password:
        config.password = secret.password

    # Set the SSL configuration if there is one
    if secret:
        secret_folder = Path(
            GlobalConfiguration().config_directory,
            "mysql-metadata",
            str(self.uuid),
        )

        ssl_prefix = "ssl_"
        ssl_options = {}
        # Write every SSL file the secret provides to disk. The secret field
        # `ssl_<name>` maps to the SSLOptions field `<name>`; slice off the
        # literal prefix (`str.lstrip` would strip a *character set*).
        for key in ("ssl_key", "ssl_ca", "ssl_cert"):
            content = getattr(secret, key)
            if not content:
                continue
            fileio.makedirs(str(secret_folder))
            file_path = Path(secret_folder, f"{key}.pem")
            # Restrict permissions *before* writing the key material so the
            # file is never readable by other users, even briefly. `chmod`
            # also covers a pre-existing file with broader permissions.
            file_path.touch(mode=0o600, exist_ok=True)
            file_path.chmod(0o600)
            with open(file_path, "w") as f:
                f.write(content)
            ssl_options[key[len(ssl_prefix):]] = str(file_path)

        # Handle additional params
        ssl_options["verify_server_cert"] = secret.ssl_verify_server_cert

        config.ssl_options.CopyFrom(
            MySQLDatabaseConfig.SSLOptions(**ssl_options)
        )

    return metadata_store_pb2.ConnectionConfig(mysql=config)

mysql_secret_schema

Secret schema for MySQL metadata store.

MYSQLSecretSchema (BaseSecretSchema) pydantic-model

MySQL secret schema.

Source code in zenml/metadata_stores/mysql_secret_schema.py
class MYSQLSecretSchema(BaseSecretSchema):
    """MySQL secret schema.

    Holds the credentials and SSL material that a MySQL metadata store can
    read from a secrets manager instead of from its own configuration.
    All fields are declared `Optional`.
    """

    # Schema type identifier for secrets of this kind.
    TYPE: ClassVar[str] = MYSQL_METADATA_STORE_SCHEMA_TYPE

    # Database username; the MySQL metadata store rejects it if the store
    # itself also defines a username.
    user: Optional[str]
    # Database password; likewise rejected if the store also defines one.
    password: Optional[str]
    # PEM contents of the SSL files. The MySQL metadata store writes each
    # non-empty one to `<field name>.pem` and passes the path to MySQL.
    ssl_ca: Optional[str]
    ssl_cert: Optional[str]
    ssl_key: Optional[str]
    # Forwarded as the `verify_server_cert` SSL option; off by default.
    ssl_verify_server_cert: Optional[bool] = False

sqlite_metadata_store

Metadata store for SQLite.

SQLiteMetadataStore (BaseMetadataStore) pydantic-model

SQLite backend for ZenML metadata store.

Source code in zenml/metadata_stores/sqlite_metadata_store.py
class SQLiteMetadataStore(BaseMetadataStore):
    """Metadata store backed by a local SQLite database.

    Attributes:
        uri: Local location of the SQLite database; its parent directory is
            exposed as `local_path`. Remote URIs are rejected on validation.
    """

    uri: str

    # Class Configuration
    FLAVOR: ClassVar[str] = "sqlite"

    @validator("uri")
    def ensure_uri_is_local(cls, uri: str) -> str:
        """Validate that `uri` refers to the local filesystem.

        Args:
            uri: The metadata store uri.

        Returns:
            The unchanged uri when it is local.

        Raises:
            ValueError: If the uri points to a remote location.
        """
        if not io_utils.is_remote(uri):
            return uri
        raise ValueError(
            f"Uri '{uri}' specified for SQLiteMetadataStore is not a "
            f"local uri."
        )

    @property
    def local_path(self) -> str:
        """Directory that contains the SQLite database.

        Returns:
            The parent directory of `uri`, as a string.
        """
        database_location = Path(self.uri)
        return str(database_location.parent)

    def get_tfx_metadata_config(
        self,
    ) -> Union[
        metadata_store_pb2.ConnectionConfig,
        metadata_store_pb2.MetadataStoreClientConfig,
    ]:
        """Build the TFX connection config for this SQLite store.

        Returns:
            The tfx metadata config pointing at `uri`.
        """
        connection_config = metadata.sqlite_metadata_connection_config(
            self.uri
        )
        return connection_config
local_path: str property readonly

Path to the local directory where the SQLite DB is stored.

Returns:

Type Description
str

The path to the local directory where the SQLite DB is stored.

ensure_uri_is_local(uri) classmethod

Ensures that the metadata store uri is local.

Parameters:

Name Type Description Default
uri str

The metadata store uri.

required

Returns:

Type Description
str

The metadata store uri.

Exceptions:

Type Description
ValueError

If the uri is not local.

Source code in zenml/metadata_stores/sqlite_metadata_store.py
@validator("uri")
def ensure_uri_is_local(cls, uri: str) -> str:
    """Validate that `uri` refers to the local filesystem.

    Args:
        uri: The metadata store uri.

    Returns:
        The unchanged uri when it is local.

    Raises:
        ValueError: If the uri points to a remote location.
    """
    is_remote = io_utils.is_remote(uri)
    if not is_remote:
        return uri
    raise ValueError(
        f"Uri '{uri}' specified for SQLiteMetadataStore is not a "
        f"local uri."
    )
get_tfx_metadata_config(self)

Return tfx metadata config for sqlite metadata store.

Returns:

Type Description
Union[ml_metadata.proto.metadata_store_pb2.ConnectionConfig, ml_metadata.proto.metadata_store_pb2.MetadataStoreClientConfig]

The tfx metadata config.

Source code in zenml/metadata_stores/sqlite_metadata_store.py
def get_tfx_metadata_config(
    self,
) -> Union[
    metadata_store_pb2.ConnectionConfig,
    metadata_store_pb2.MetadataStoreClientConfig,
]:
    """Build the TFX connection config for the SQLite metadata store.

    Returns:
        The tfx metadata config pointing at `self.uri`.
    """
    connection_config = metadata.sqlite_metadata_connection_config(self.uri)
    return connection_config