Client
zenml.client
Client implementation.
Client
ZenML client class.
The ZenML client manages configuration options for ZenML stacks as well as their components.
Source code in zenml/client.py
class Client(metaclass=ClientMetaClass):
"""ZenML client class.
The ZenML client manages configuration options for ZenML stacks as well
as their components.
"""
def __init__(
    self,
    root: Optional[Path] = None,
) -> None:
    """Create (or reuse) the Client and bind it to a repository root.

    Client is a singleton: constructing it repeatedly yields the same
    instance, with one exception. Passing an explicit `root` (reserved for
    internal use and tests; user code must never do this) produces an
    anonymous Client. That Client is independent of the global singleton
    and has no effect on the rest of the ZenML core code.

    To point the global Client at a different repository root, call
    `Client().activate_root(<new-root>)` instead of constructing a new
    Client.

    Args:
        root: (internal use) custom root directory for the client. When
            omitted, the repository root is looked up via the
            `ZENML_REPOSITORY_PATH` environment variable (if set) and by
            recursively searching the parent directories of the current
            working directory.
    """
    # Start with no root and no local configuration; `_set_active_root`
    # fills both in.
    self._config: Optional[ClientConfiguration] = None
    self._root: Optional[Path] = None
    self._set_active_root(root)
@classmethod
def get_instance(cls) -> Optional["Client"]:
    """Look up the global Client singleton.

    Returns:
        The singleton Client, or None when no Client has been
        initialized yet.
    """
    instance = cls._global_client
    return instance
@classmethod
def _reset_instance(cls, client: Optional["Client"] = None) -> None:
    """Replace the global Client singleton.

    Internal/testing helper only.

    Args:
        client: Instance to install as the global singleton. Passing
            None (the default) clears the singleton.
    """
    setattr(cls, "_global_client", client)
def _set_active_root(self, root: Optional[Path] = None) -> None:
    """Make the supplied path the active repository root.

    If a repository root is found, its client configuration is loaded and
    used. If none is found, any previously loaded local configuration is
    dropped, so the global configuration manages the active stack,
    project, etc. In both cases the resulting configuration is then
    sanitized.

    Args:
        root: The path to set as the active repository root. If not set,
            the repository root is determined using the environment
            variable `ZENML_REPOSITORY_PATH` (if set) and by recursively
            searching in the parent directories of the current working
            directory.
    """
    enable_warnings = handle_bool_env_var(
        ENV_ZENML_ENABLE_REPO_INIT_WARNINGS, True
    )
    self._root = self.find_repository(root, enable_warnings=enable_warnings)

    if not self._root:
        if enable_warnings:
            logger.info("Running without an active repository root.")
        # Without a root there is no local configuration. Clear whatever
        # was loaded for a previously active root so the client does not
        # keep acting on a stale local config.
        self._config = None
    else:
        logger.debug("Using repository root %s.", self._root)
        self._config = self._load_config()

    # Sanitize the client configuration to reflect the current settings.
    self._sanitize_config()
def _config_path(self) -> Optional[str]:
    """Compute the location of the client's `config.yaml`.

    Returns:
        The path of the client configuration file as a string, or None
        when the client has no configuration directory (no active root).
    """
    directory = self.config_directory
    return str(directory / "config.yaml") if directory else None
def _sanitize_config(self) -> None:
    """Bring the local client configuration in line with the zen store.

    The zen store validates the locally stored active project and stack.
    The validated values are then written back into the local
    configuration. This is a no-op when the client has no local
    configuration.
    """
    config = self._config
    if config:
        project, stack = self.zen_store.validate_active_config(
            config.active_project_name,
            config.active_stack_id,
            config_name="repo",
        )
        config.active_stack_id = stack.id
        config.set_active_project(project)
def _load_config(self) -> Optional[ClientConfiguration]:
    """Build the client configuration that belongs to the active root.

    Whether or not the configuration file already exists on disk, a
    `ClientConfiguration` bound to the config file path is returned. Only
    the debug message differs.

    Returns:
        The client configuration, or None if the client has no active
        root.
    """
    config_path = self._config_path()
    if not config_path:
        return None

    if not fileio.exists(config_path):
        logger.debug(
            "No client configuration file found, creating default "
            "configuration."
        )
    else:
        logger.debug(f"Loading client configuration from {config_path}.")
    return ClientConfiguration(config_path)
@staticmethod
@track(event=AnalyticsEvent.INITIALIZE_REPO)
def initialize(
    root: Optional[Path] = None,
) -> None:
    """Create a new ZenML repository.

    Args:
        root: Directory in which to create the repository. Defaults to
            the current working directory.

    Raises:
        InitializationException: If `root` already holds a ZenML
            repository.
    """
    repo_root = root if root else Path.cwd()
    logger.debug("Initializing new repository at path %s.", repo_root)

    if Client.is_repository_directory(repo_root):
        raise InitializationException(
            f"Found existing ZenML repository at path '{repo_root}'."
        )

    io_utils.create_dir_recursive_if_not_exists(
        str(repo_root / REPOSITORY_DIRECTORY_NAME)
    )
    # Constructing a Client with an explicit root initializes the
    # repository configuration at that custom path.
    Client(root=repo_root)
@property
def uses_local_configuration(self) -> bool:
    """Whether this client has its own repository-local configuration.

    Returns:
        True when a local client configuration is loaded, False when the
        client relies on the global configuration.
    """
    config = self._config
    return config is not None
@staticmethod
def is_repository_directory(path: Path) -> bool:
    """Tell whether `path` contains a ZenML repository.

    A directory counts as a repository when it has a
    `REPOSITORY_DIRECTORY_NAME` subdirectory.

    Args:
        path: The directory to inspect.

    Returns:
        True if a ZenML repository exists at `path`, False otherwise.
    """
    marker_dir = str(path / REPOSITORY_DIRECTORY_NAME)
    return fileio.isdir(marker_dir)
@staticmethod
def find_repository(
    path: Optional[Path] = None, enable_warnings: bool = False
) -> Optional[Path]:
    """Locate a ZenML repository directory.

    Lookup order: the explicit `path` argument, then the
    `ZENML_REPOSITORY_PATH` environment variable. If either yields a path,
    only that exact directory is checked. Otherwise the search starts at
    the current working directory and walks up through its parents until
    a repository or the filesystem root is reached.

    Args:
        path: Optional path to look for the repository.
        enable_warnings: If `True`, a warning is logged when no
            repository can be found.

    Returns:
        Absolute (resolved) path to the ZenML repository directory, or
        None if no repository directory was found.
    """
    if not path:
        env_var_path = os.getenv(ENV_ZENML_REPOSITORY_PATH)
        if env_var_path:
            path = Path(env_var_path)

    # Parent directories are only searched when no explicit location was
    # supplied, neither as an argument nor via the environment variable.
    search_parent_directories = not path
    if search_parent_directories:
        path = Path.cwd()
        warning_message = (
            f"Unable to find ZenML repository in your current working "
            f"directory ({path}) or any parent directories. If you "
            f"want to use an existing repository which is in a different "
            f"location, set the environment variable "
            f"'{ENV_ZENML_REPOSITORY_PATH}'. If you want to create a new "
            f"repository, run `zenml init`."
        )
    else:
        warning_message = (
            f"Unable to find ZenML repository at path '{path}'. Make sure "
            f"to create a ZenML repository by calling `zenml init` when "
            f"specifying an explicit repository path in code or via the "
            f"environment variable '{ENV_ZENML_REPOSITORY_PATH}'."
        )

    candidate = path
    while True:
        if Client.is_repository_directory(candidate):
            return candidate.resolve()
        if not search_parent_directories or io_utils.is_root(str(candidate)):
            break
        candidate = candidate.parent

    if enable_warnings:
        logger.warning(warning_message)
    return None
@property
def zen_store(self) -> "BaseZenStore":
    """Shortcut to the zen store of the global configuration.

    Returns:
        The zen store held by `GlobalConfiguration()`.
    """
    global_config = GlobalConfiguration()
    return global_config.zen_store
@property
def root(self) -> Optional[Path]:
    """The repository root this client is bound to.

    Returns:
        The active repository root, or None if no root is active.
    """
    active_root = self._root
    return active_root
@property
def config_directory(self) -> Optional[Path]:
    """The client's configuration directory inside the repository root.

    Returns:
        `root / REPOSITORY_DIRECTORY_NAME`, or None if the client has no
        active root.
    """
    repo_root = self.root
    return repo_root / REPOSITORY_DIRECTORY_NAME if repo_root else None
def activate_root(self, root: Optional[Path] = None) -> None:
    """Switch the client to another repository root.

    Args:
        root: The path to set as the active repository root. If not set,
            the repository root is determined using the environment
            variable `ZENML_REPOSITORY_PATH` (if set) and by recursively
            searching in the parent directories of the current working
            directory.
    """
    # Delegates to the same helper the constructor uses.
    self._set_active_root(root=root)
@track(event=AnalyticsEvent.SET_PROJECT)
def set_active_project(
    self, project_name_or_id: Union[str, UUID]
) -> "ProjectModel":
    """Make a project the active one for this client.

    The change goes to the local client configuration if one is loaded.
    Otherwise it is applied to the global configuration.

    Args:
        project_name_or_id: The name or ID of the project to activate.

    Returns:
        The model of the now-active project.

    Raises:
        KeyError: Presumably raised by the zen store when no such project
            exists (per the original inline note; not verifiable here).
    """
    project = self.zen_store.get_project(
        project_name_or_id=project_name_or_id
    )
    if not self._config:
        # No local configuration: set the active project globally.
        GlobalConfiguration().set_active_project(project)
    else:
        self._config.set_active_project(project)
    return project
@property
def active_project_name(self) -> str:
    """Name of the project this client currently works in.

    Falls back to the globally configured project when no project is set
    locally (via `active_project`).

    Returns:
        The name of the active project.
    """
    project = self.active_project
    return project.name
@property
def active_project(self) -> "ProjectModel":
    """The project this client currently works in.

    A project set in the local client configuration wins. Otherwise the
    active project of the global configuration is used. A warning is
    logged whenever the resolved project is not the default project.

    Returns:
        The active project.

    Raises:
        RuntimeError: If neither configuration has an active project.
    """
    local_project: Optional["ProjectModel"] = (
        self._config._active_project if self._config else None
    )
    project = local_project or GlobalConfiguration().active_project
    if not project:
        raise RuntimeError(
            "No active project is configured. Run "
            "`zenml project set PROJECT_NAME` to set the active "
            "project."
        )

    from zenml.zen_stores.base_zen_store import DEFAULT_PROJECT_NAME

    if project.name != DEFAULT_PROJECT_NAME:
        logger.warning(
            f"You are running with a non-default project "
            f"'{project.name}'. Any stacks, components, "
            f"pipelines and pipeline runs produced in this "
            f"project will currently not be accessible through "
            f"the dashboard. However, this will be possible "
            f"in the near future."
        )
    return project
@property
def active_user(self) -> "UserModel":
    """The user the zen store currently acts as.

    Returns:
        The zen store's active user.
    """
    store = self.zen_store
    return store.active_user
@property
def stacks(self) -> List["HydratedStackModel"]:
    """Hydrated stack models visible in the active project.

    Combines the stacks owned by the active user (non-shared) with all
    shared stacks of the project. The models are fetched without loading
    any installed integrations.

    Returns:
        The owned stacks followed by the shared stacks.
    """
    owned = self.zen_store.list_stacks(
        project_name_or_id=self.active_project_name,
        user_name_or_id=self.active_user.id,
        is_shared=False,
        hydrated=True,
    )
    shared = self.zen_store.list_stacks(
        project_name_or_id=self.active_project_name,
        is_shared=True,
        hydrated=True,
    )
    # `hydrated=True` makes the store return hydrated models.
    return cast(List["HydratedStackModel"], owned) + cast(
        List["HydratedStackModel"], shared
    )
@property
def active_stack_model(self) -> "HydratedStackModel":
    """Hydrated model of the stack this client currently uses.

    A stack ID from the local client configuration wins. Otherwise the
    globally configured active stack ID is used.

    Returns:
        The hydrated model of the active stack.

    Raises:
        RuntimeError: If no active stack ID is configured anywhere.
    """
    local_stack_id = self._config.active_stack_id if self._config else None
    stack_id = local_stack_id or GlobalConfiguration().active_stack_id
    if not stack_id:
        raise RuntimeError(
            "No active stack is configured. Run "
            "`zenml stack set STACK_NAME` to set the active stack."
        )
    stack = self.zen_store.get_stack(stack_id=stack_id)
    return stack.to_hydrated_model()
@property
def active_stack(self) -> "Stack":
    """The active stack of this client as a `Stack` object.

    Returns:
        A `Stack` built from `active_stack_model`.
    """
    from zenml.stack.stack import Stack

    stack_model = self.active_stack_model
    return Stack.from_model(stack_model)
@track(event=AnalyticsEvent.SET_STACK)
def activate_stack(self, stack: "StackModel") -> None:
    """Set a registered stack as the active stack.

    The stack becomes active in the local client configuration if one is
    loaded. Otherwise it becomes active in the global configuration.

    Args:
        stack: Model of the stack to activate.

    Raises:
        KeyError: If the stack is not registered in the zen store.
    """
    # Make sure the stack is registered before pointing a config at it.
    try:
        self.zen_store.get_stack(stack_id=stack.id)
    except KeyError as err:
        # Chain the store's lookup error so the root cause stays visible.
        raise KeyError(
            f"Stack '{stack.name}' cannot be activated since it is "
            "not registered yet. Please register it first."
        ) from err

    if self._config:
        self._config.active_stack_id = stack.id
    else:
        # Set the active stack globally only if the client doesn't use a
        # local configuration.
        GlobalConfiguration().active_stack_id = stack.id
def _validate_stack_configuration(self, stack: "StackModel") -> None:
    """Validate the configuration of a stack before it is stored.

    Every component referenced by the stack is loaded, and its
    configuration is re-built through its flavor's config class. A
    warning is logged if the stack mixes components that rely on local
    resources with components that run remotely.

    Args:
        stack: The stack to validate.

    Raises:
        KeyError: If the stack references a component that is not
            registered.
        ValidationError: If the stack configuration is invalid (for
            example, it lacks an Artifact Store or an Orchestrator).
    """
    # Imported once, outside the loop that needs it.
    from zenml.stack import Flavor

    local_components: List[str] = []
    remote_components: List[str] = []

    for component_type, component_ids in stack.components.items():
        for component_id in component_ids:
            try:
                component = self.get_stack_component_by_id(
                    component_id=component_id
                )
            except KeyError as err:
                raise KeyError(
                    f"Cannot register stack '{stack.name}' since it has an "
                    f"unregistered {component_type} with id "
                    f"'{component_id}'."
                ) from err

            # Re-create the configuration through the flavor's config
            # class so it is validated.
            flavor_model = self.get_flavor_by_name_and_type(
                name=component.flavor, component_type=component.type
            )
            flavor = Flavor.from_model(flavor_model)
            configuration = flavor.config_class(**component.configuration)

            component_label = f"{component.type.value}: {component.name}"
            if configuration.is_local:
                local_components.append(component_label)
            elif configuration.is_remote:
                remote_components.append(component_label)

    if local_components and remote_components:
        logger.warning(
            f"You are configuring a stack that is composed of components "
            f"that are relying on local resources "
            f"({', '.join(local_components)}) as well as "
            f"components that are running remotely "
            f"({', '.join(remote_components)}). This is not recommended as "
            f"it can lead to unexpected behavior, especially if the remote "
            f"components need to access the local resources. Please make "
            f"sure that your stack is configured correctly, or try to use "
            f"component flavors or configurations that do not require "
            f"local resources."
        )

    if not stack.is_valid:
        # The two literals were previously joined as "A validstack".
        raise ValidationError(
            "Stack configuration is invalid. A valid "
            "stack must contain an Artifact Store and "
            "an Orchestrator."
        )
def register_stack(self, stack: "StackModel") -> "StackModel":
    """Validate a stack and store it in the zen store.

    Args:
        stack: The stack to register.

    Returns:
        The stack model returned by the zen store after creation.
    """
    self._validate_stack_configuration(stack=stack)
    return self.zen_store.create_stack(stack=stack)
def update_stack(self, stack: "StackModel") -> None:
    """Validate a stack and write it to the zen store as an update.

    Args:
        stack: The new version of the stack.
    """
    # Reject invalid configurations before touching the store.
    self._validate_stack_configuration(stack=stack)
    store = self.zen_store
    store.update_stack(stack=stack)
def deregister_stack(self, stack: "StackModel") -> None:
    """Remove a stack from the zen store.

    If the store raises KeyError for the deletion, a warning is logged
    instead of propagating the error.

    Args:
        stack: The model of the stack to deregister.

    Raises:
        ValueError: If the stack is the currently active stack of this
            client.
    """
    is_active = stack.id == self.active_stack_model.id
    if is_active:
        raise ValueError(
            f"Unable to deregister active stack "
            f"'{stack.name}'. Make "
            f"sure to designate a new active stack before deleting this "
            f"one."
        )

    try:
        self.zen_store.delete_stack(stack_id=stack.id)
    except KeyError:
        logger.warning(
            "Unable to deregister stack with name '%s': No stack "
            "with this name could be found.",
            stack.name,
        )
    else:
        logger.info("Deregistered stack with name '%s'.", stack.name)
# .------------.
# | COMPONENTS |
# '------------'
def _validate_stack_component_configuration(
    self,
    component_type: "StackComponentType",
    configuration: "StackComponentConfig",
) -> None:
    """Warn about component configurations that may not fit the zen store.

    Logs warnings when:
    - a remote component is used with a local zen store;
    - a component relying on local resources is used with a remote zen
      store, with an extra warning for local orchestrators and step
      operators, which run pipeline code on the local host.

    Args:
        component_type: The type of the component.
        configuration: The component configuration to validate.
    """
    from zenml.enums import StackComponentType, StoreType

    store_is_local = self.zen_store.is_local_store()

    if configuration.is_remote and store_is_local:
        if self.zen_store.type == StoreType.REST:
            # A local REST store means a ZenML server running on this
            # host (these two messages were previously swapped).
            logger.warning(
                "You are configuring a stack component that is running "
                "remotely while using a local ZenML server. The component "
                "may not be able to reach the local ZenML server and will "
                "therefore not be functional. Please consider deploying "
                "and/or using a remote ZenML server instead."
            )
        else:
            # Any other local store talks directly to a local database.
            logger.warning(
                "You are configuring a stack component that is running "
                "remotely while using a local database. The component "
                "may not be able to reach the local database and will "
                "therefore not be functional. Please consider deploying "
                "and/or using a remote ZenML server instead."
            )
    elif configuration.is_local and not store_is_local:
        logger.warning(
            "You are configuring a stack component that is using "
            "local resources while connected to a remote ZenML server. The "
            "stack component may not be usable from other hosts or by "
            "other users. You should consider using a non-local stack "
            "component alternative instead."
        )
        # Only relevant when connected to a remote server, as the
        # message itself states.
        if component_type in (
            StackComponentType.ORCHESTRATOR,
            StackComponentType.STEP_OPERATOR,
        ):
            logger.warning(
                "You are configuring a stack component that is running "
                "pipeline code on your local host while connected to a "
                "remote ZenML server. This will significantly affect the "
                "performance of your pipelines. You will likely encounter "
                "long running times caused by network latency. You should "
                "consider using a non-local stack component alternative "
                "instead."
            )
def register_stack_component(
    self,
    component: "ComponentModel",
) -> "ComponentModel":
    """Validate a stack component and store it in the zen store.

    The raw configuration is re-built through the flavor's config class.
    The resulting dict replaces `component.configuration` before
    registration.

    Args:
        component: The component to register (its `configuration` is
            modified in place).

    Returns:
        The component model returned by the zen store after creation.
    """
    # Resolve the flavor so its config class can validate the raw dict.
    flavor_model = self.get_flavor_by_name_and_type(
        name=component.flavor, component_type=component.type
    )
    from zenml.stack import Flavor

    config_class = Flavor.from_model(flavor_model).config_class
    configuration = config_class(**component.configuration)

    # Store the validated configuration back on the model.
    component.configuration = configuration.dict()
    self._validate_stack_component_configuration(
        component.type, configuration=configuration
    )
    return self.zen_store.create_stack_component(component=component)
def update_stack_component(
    self,
    component: "ComponentModel",
) -> "ComponentModel":
    """Validate an updated stack component and send it to the zen store.

    The flavor used for validation comes from the already-registered
    component with the same ID, not from the incoming model.

    Args:
        component: The new version of the component (its `configuration`
            is modified in place).

    Returns:
        The component model returned by the zen store after the update.
    """
    registered = self.get_stack_component_by_id(
        component.id,
    )
    flavor_model = self.get_flavor_by_name_and_type(
        name=registered.flavor,
        component_type=registered.type,
    )
    from zenml.stack import Flavor

    config_class = Flavor.from_model(flavor_model).config_class
    configuration = config_class(**component.configuration)

    component.configuration = configuration.dict()
    self._validate_stack_component_configuration(
        component.type, configuration=configuration
    )
    return self.zen_store.update_stack_component(component=component)
def deregister_stack_component(self, component: "ComponentModel") -> None:
    """Remove a registered stack component from the zen store.

    If the store raises KeyError for the deletion, a warning is logged
    instead of propagating the error.

    Args:
        component: The model of the component to delete.
    """
    try:
        self.zen_store.delete_stack_component(component_id=component.id)
    except KeyError:
        logger.warning(
            "Unable to deregister stack component (type: %s) with name "
            "'%s': No stack component with this name could be found.",
            component.type,
            component.name,
        )
    else:
        logger.info(
            "Deregistered stack component (type: %s) with name '%s'.",
            component.type,
            component.name,
        )
def get_stack_component_by_id(self, component_id: UUID) -> "ComponentModel":
    """Fetch a registered stack component by its ID.

    Args:
        component_id: The ID of the component to fetch.

    Returns:
        The registered stack component.

    Raises:
        KeyError: Presumably raised by the zen store when no component
            has this ID (callers in this class catch KeyError).
    """
    # Log the requested ID. The previous code passed the builtin `id`
    # function by mistake.
    logger.debug(
        "Fetching stack component with id '%s'.",
        component_id,
    )
    return self.zen_store.get_stack_component(component_id=component_id)
def list_stack_components_by_type(
    self, type_: "StackComponentType", is_shared: bool = False
) -> List["ComponentModel"]:
    """Fetch the stack components of one type visible in the active project.

    Always returns the active user's non-shared components followed by
    all shared components of that type.

    Args:
        type_: The type of the components to fetch.
        is_shared: Currently ignored. Both owned and shared components
            are always returned. NOTE(review): confirm whether this
            parameter should filter the result.

    Returns:
        The owned components followed by the shared ones.
    """
    owned = self.zen_store.list_stack_components(
        project_name_or_id=self.active_project_name,
        user_name_or_id=self.active_user.id,
        type=type_,
        is_shared=False,
    )
    shared = self.zen_store.list_stack_components(
        project_name_or_id=self.active_project_name,
        is_shared=True,
        type=type_,
    )
    return owned + shared
# .---------.
# | FLAVORS |
# '---------'
@property
def flavors(self) -> List["FlavorModel"]:
    """All flavor models, as returned by `get_flavors`.

    Returns:
        The list of flavor models.
    """
    all_flavors = self.get_flavors()
    return all_flavors
def create_flavor(self, flavor: "FlavorModel") -> "FlavorModel":
    """Create a new custom flavor in the zen store.

    The flavor's source is validated and resolved to a flavor class. The
    model that gets stored is built from an instance of that class,
    scoped to the active project and user.

    Args:
        flavor: The flavor to create (its `source` and `type` are used).

    Returns:
        The created flavor (in model form), as returned by the zen store.
    """
    from zenml.utils.source_utils import validate_flavor_source

    flavor_class = validate_flavor_source(
        source=flavor.source,
        component_type=flavor.type,
    )
    # Instantiate the flavor once and derive every field from that one
    # instance (previously it was constructed three times).
    flavor_instance = flavor_class()
    flavor_model = flavor_instance.to_model()
    flavor_model.project = self.active_project.id
    flavor_model.user = self.active_user.id
    flavor_model.name = flavor_instance.name
    flavor_model.config_schema = flavor_instance.config_schema
    return self.zen_store.create_flavor(flavor=flavor_model)
def delete_flavor(self, flavor: "FlavorModel") -> None:
    """Delete a custom flavor from the zen store.

    If the store raises KeyError for the deletion, a warning is logged
    instead of propagating the error.

    Args:
        flavor: The flavor to delete.
    """
    try:
        self.zen_store.delete_flavor(flavor_id=flavor.id)
    except KeyError:
        logger.warning(
            f"Unable to delete flavor '{flavor.name}' of type "
            f"'{flavor.type}': No flavor with this name could be found.",
        )
    else:
        logger.info(
            f"Deleted flavor '{flavor.name}' of type '{flavor.type}'.",
        )
def get_flavors(self) -> List["FlavorModel"]:
    """Fetch the built-in ZenML flavors plus the custom ones.

    Custom flavors are filtered by the active user and the active project.
    NOTE(review): `get_flavors_by_type` does not filter by user; confirm
    the difference is intended.

    Returns:
        Built-in flavors followed by custom flavors.
    """
    from zenml.stack.flavor_registry import flavor_registry

    builtin = flavor_registry.flavors
    custom = self.zen_store.list_flavors(
        user_name_or_id=self.active_user.id,
        project_name_or_id=self.active_project.id,
    )
    return builtin + custom
def get_flavors_by_type(
    self, component_type: "StackComponentType"
) -> List["FlavorModel"]:
    """Fetch all flavors (built-in and custom) for one component type.

    Custom flavors are taken from the active project.

    Args:
        component_type: The component type whose flavors are wanted.

    Returns:
        Built-in flavors followed by custom flavors of that type.
    """
    logger.debug(f"Fetching the flavors of type {component_type}.")
    from zenml.stack.flavor_registry import flavor_registry

    builtin = flavor_registry.get_flavors_by_type(
        component_type=component_type
    )
    custom = self.zen_store.list_flavors(
        project_name_or_id=self.active_project.id,
        component_type=component_type,
    )
    return builtin + custom
def get_flavor_by_name_and_type(
    self, name: str, component_type: "StackComponentType"
) -> "FlavorModel":
    """Resolve a flavor by name and component type.

    A single matching custom flavor in the active project takes
    precedence over a built-in ZenML flavor of the same name; a warning
    is logged when it shadows one.

    Args:
        name: The name of the flavor to fetch.
        component_type: The component type of the flavor.

    Returns:
        The matching custom flavor, or else the built-in flavor.

    Raises:
        KeyError: If several custom flavors match, or if neither a custom
            nor a built-in flavor matches.
    """
    logger.debug(
        f"Fetching the flavor of type {component_type} with name {name}."
    )
    from zenml.stack.flavor_registry import flavor_registry

    try:
        builtin_flavor = flavor_registry.get_flavor_by_name_and_type(
            component_type=component_type,
            name=name,
        )
    except KeyError:
        builtin_flavor = None

    custom_flavors = self.zen_store.list_flavors(
        project_name_or_id=self.active_project.id,
        component_type=component_type,
        name=name,
    )

    if not custom_flavors:
        if builtin_flavor:
            return builtin_flavor
        raise KeyError(
            f"No flavor with name '{name}' and type '{component_type}' "
            "exists."
        )

    if len(custom_flavors) > 1:
        raise KeyError(
            f"More than one flavor with name {name} and type "
            f"{component_type} exists."
        )

    if builtin_flavor:
        logger.warning(
            f"There is a custom implementation for the flavor "
            f"'{name}' of a {component_type}, which is currently "
            f"overwriting the same flavor provided by ZenML."
        )
    return custom_flavors[0]
# .------------------.
# | Pipelines & Runs |
# '------------------'
def get_pipeline_by_name(self, name: str) -> "PipelineModel":
    """Look up a pipeline by name in the active project.

    Args:
        name: The name of the pipeline to fetch.

    Returns:
        The pipeline model returned by the zen store.

    Raises:
        KeyError: Presumably raised by the zen store when no such pipeline
            exists (`register_pipeline` relies on this).
    """
    store = self.zen_store
    return store.get_pipeline_in_project(
        pipeline_name=name,
        project_name_or_id=self.active_project_name,
    )
def register_pipeline(
    self,
    pipeline_name: str,
    pipeline_spec: "PipelineSpec",
    pipeline_docstring: Optional[str],
) -> UUID:
    """Register a pipeline in the zen store within the active project.

    This does one of three things:
        A) If there is no pipeline with this name, register a new one.
        B) If a pipeline with this name and the same spec exists, reuse it.
        C) If a pipeline with this name but a different spec exists,
           raise an error.

    Args:
        pipeline_name: The name of the pipeline to register.
        pipeline_spec: The spec of the pipeline.
        pipeline_docstring: The docstring of the pipeline.

    Returns:
        The ID of the existing or newly registered pipeline.

    Raises:
        AlreadyExistsException: If the project already contains a pipeline
            with the same name but a different spec.
    """
    try:
        existing_pipeline = self.get_pipeline_by_name(pipeline_name)
    except KeyError:
        existing_pipeline = None

    # A) No pipeline with this name yet: register a new one. This runs
    # outside the `except` block so creation errors are not reported as
    # occurring "during handling" of the lookup KeyError.
    if existing_pipeline is None:
        from zenml.models import PipelineModel

        pipeline = PipelineModel(
            project=self.active_project.id,
            user=self.active_user.id,
            name=pipeline_name,
            spec=pipeline_spec,
            docstring=pipeline_docstring,
        )
        pipeline = self.zen_store.create_pipeline(pipeline=pipeline)
        logger.info(f"Registered new pipeline with name {pipeline.name}.")
        return pipeline.id

    # B) A pipeline with the same spec exists: reuse it.
    if pipeline_spec == existing_pipeline.spec:
        logger.debug("Did not register pipeline since it already exists.")
        return existing_pipeline.id

    # C) A pipeline with a different spec exists: refuse.
    error_msg = (
        f"Cannot run pipeline '{pipeline_name}' since this name has "
        "already been registered with a different pipeline "
        "configuration. You have three options to resolve this issue:\n"
        "1) You can register a new pipeline by changing the name "
        "of your pipeline, e.g., via `@pipeline(name='new_pipeline_name')`."
        "\n2) You can execute the current run without linking it to any "
        "pipeline by setting the 'unlisted' argument to `True`, e.g., "
        "via `my_pipeline_instance.run(unlisted=True)`. "
        "Unlisted runs are not linked to any pipeline, but are still "
        "tracked by ZenML and can be accessed via the 'All Runs' tab.\n"
        "3) You can delete the existing pipeline via "
        f"`zenml pipeline delete {pipeline_name}`. This will then "
        "change all existing runs of this pipeline to become unlisted."
    )
    raise AlreadyExistsException(error_msg)
def export_pipeline_runs(self, filename: str) -> None:
    """Export all pipeline runs of the active project to a YAML file.

    Each run is serialized with a `steps` list. Each step carries an
    `output_artifacts` list with its artifacts ordered by creation time.
    If the project has no runs, a warning is logged and nothing is
    written.

    Args:
        filename: The filename to export the pipeline runs to.
    """
    import json

    from zenml.utils.yaml_utils import write_yaml

    pipeline_runs = self.zen_store.list_runs(
        project_name_or_id=self.active_project.id
    )
    if not pipeline_runs:
        logger.warning("No pipeline runs found. Nothing to export.")
        return

    def _serialize_step(step):  # one step plus its sorted output artifacts
        step_dict = json.loads(step.json())
        step_artifacts = self.zen_store.list_artifacts(
            parent_step_id=step.id
        )
        step_dict["output_artifacts"] = [
            json.loads(artifact.json())
            for artifact in sorted(step_artifacts, key=lambda a: a.created)
        ]
        return step_dict

    yaml_data = []
    for pipeline_run in pipeline_runs:
        run_dict = json.loads(pipeline_run.json())
        run_steps = self.zen_store.list_run_steps(run_id=pipeline_run.id)
        run_dict["steps"] = [_serialize_step(step) for step in run_steps]
        yaml_data.append(run_dict)

    write_yaml(filename, yaml_data)
    logger.info(f"Exported {len(yaml_data)} pipeline runs to {filename}.")
def import_pipeline_runs(self, filename: str) -> None:
    """Import pipeline runs from a YAML file into the active project.

    Runs, steps and artifacts are re-created in the zen store under the
    active user and project. IDs from the file are remapped to the IDs
    returned by the store so that step parents and step input artifacts
    reference the newly created entities.

    Args:
        filename: The filename from which to import the pipeline runs.
    """
    from datetime import datetime

    from zenml.models.pipeline_models import (
        ArtifactModel,
        PipelineRunModel,
        StepRunModel,
    )
    from zenml.utils.yaml_utils import read_yaml

    # Old (exported) ID -> new (imported) ID.
    step_id_mapping: Dict[str, UUID] = {}
    artifact_id_mapping: Dict[str, UUID] = {}

    yaml_data = read_yaml(filename)
    for pipeline_run_dict in yaml_data:
        steps = pipeline_run_dict.pop("steps")
        pipeline_run_dict.pop("id")
        pipeline_run = PipelineRunModel.parse_obj(pipeline_run_dict)
        pipeline_run.updated = datetime.now()
        pipeline_run.user = self.active_user.id
        pipeline_run.project = self.active_project.id
        pipeline_run.stack_id = None
        pipeline_run.pipeline_id = None
        pipeline_run.mlmd_id = None
        pipeline_run = self.zen_store.create_run(pipeline_run)

        for step_dict in steps:
            artifacts = step_dict.pop("output_artifacts")
            step_id = step_dict.pop("id")
            step = StepRunModel.parse_obj(step_dict)
            step.pipeline_run_id = pipeline_run.id
            # NOTE(review): assumes parent steps and producer steps of
            # input artifacts appear earlier in the file; otherwise these
            # lookups raise KeyError.
            step.parent_step_ids = [
                step_id_mapping[str(parent_step_id)]
                for parent_step_id in step.parent_step_ids
            ]
            step.input_artifacts = {
                input_name: artifact_id_mapping[str(artifact_id)]
                for input_name, artifact_id in step.input_artifacts.items()
            }
            step.updated = datetime.now()
            step.mlmd_id = None
            step.mlmd_parent_step_ids = []
            step = self.zen_store.create_run_step(step)
            step_id_mapping[str(step_id)] = step.id

            for artifact_dict in artifacts:
                artifact_id = artifact_dict.pop("id")
                artifact = ArtifactModel.parse_obj(artifact_dict)
                artifact.parent_step_id = step.id
                artifact.producer_step_id = step_id_mapping[
                    str(artifact.producer_step_id)
                ]
                artifact.updated = datetime.now()
                artifact.mlmd_id = None
                artifact.mlmd_parent_step_id = None
                artifact.mlmd_producer_step_id = None
                # Use the model returned by the store, as is done for runs
                # and steps, so the mapping records the ID the store
                # actually assigned.
                artifact = self.zen_store.create_artifact(artifact)
                artifact_id_mapping[str(artifact_id)] = artifact.id

    logger.info(f"Imported {len(yaml_data)} pipeline runs from {filename}.")
def migrate_pipeline_runs(
self,
database: str,
database_type: str = "sqlite",
mysql_host: Optional[str] = None,
mysql_port: int = 3306,
mysql_username: Optional[str] = None,
mysql_password: Optional[str] = None,
) -> None:
"""Migrate pipeline runs from a metadata store of ZenML < 0.20.0.
Every pipeline run found in the old MLMD metadata store is recreated in
the active zen store, together with its steps and their output
artifacts. The new runs are attached to the active user and project.
Stack, pipeline and MLMD IDs are not carried over, and step
configurations are left empty. Steps that are still marked as running in
the old store are recorded as failed.
Args:
database: The metadata store database from which to migrate the
pipeline runs. Either a path to a SQLite database or a database
name for a MySQL database.
database_type: The type of the metadata store database
("sqlite" | "mysql"). Defaults to "sqlite".
mysql_host: The host of the MySQL database.
mysql_port: The port of the MySQL database. Defaults to 3306.
mysql_username: The username of the MySQL database.
mysql_password: The password of the MySQL database. An empty string
is accepted; only `None` is rejected.
Raises:
NotImplementedError: If the database type is not supported.
RuntimeError: If no pipeline runs exist.
ValueError: If the database type is "mysql" but the MySQL host,
username or password are not provided.
"""
from tfx.orchestration import metadata
from zenml.enums import ExecutionStatus
from zenml.models.pipeline_models import (
ArtifactModel,
PipelineRunModel,
StepRunModel,
)
from zenml.zen_stores.metadata_store import MetadataStore
# Define MLMD connection config based on the database type.
if database_type == "sqlite":
mlmd_config = metadata.sqlite_metadata_connection_config(database)
elif database_type == "mysql":
# Host and username must be non-empty. The password only has to be set,
# so an empty password is allowed.
if not mysql_host or not mysql_username or mysql_password is None:
raise ValueError(
"Migration from MySQL requires username, password and host "
"to be set."
)
mlmd_config = metadata.mysql_metadata_connection_config(
database=database,
host=mysql_host,
port=mysql_port,
username=mysql_username,
password=mysql_password,
)
else:
raise NotImplementedError(
"Migrating pipeline runs is only supported for SQLite and MySQL."
)
metadata_store = MetadataStore(config=mlmd_config)
# Dicts to keep track of MLMD IDs, which we need to resolve later: they
# map old MLMD step/artifact IDs to the IDs of the newly created entities.
step_mlmd_id_mapping: Dict[int, UUID] = {}
artifact_mlmd_id_mapping: Dict[int, UUID] = {}
# Get all pipeline runs from the metadata store.
pipeline_runs = metadata_store.get_all_runs()
if not pipeline_runs:
raise RuntimeError("No pipeline runs found in the metadata store.")
# For each run, first store the pipeline run, then all steps, then all
# output artifacts of each step.
# Runs, steps, and artifacts need to be sorted chronologically to ensure
# that the MLMD IDs of producer steps and parent steps can be resolved.
for mlmd_run in sorted(pipeline_runs, key=lambda x: x.mlmd_id):
steps = metadata_store.get_pipeline_run_steps(
mlmd_run.mlmd_id
).values()
# Mark all steps that haven't finished yet as failed.
step_statuses = []
for step in steps:
status = metadata_store.get_step_status(step.mlmd_id)
if status == ExecutionStatus.RUNNING:
status = ExecutionStatus.FAILED
step_statuses.append(status)
pipeline_run = PipelineRunModel(
user=self.active_user.id, # Old user might not exist.
project=self.active_project.id, # Old project might not exist.
name=mlmd_run.name,
stack_id=None, # Stack might not exist in new DB.
pipeline_id=None, # Pipeline might not exist in new DB.
status=ExecutionStatus.run_status(step_statuses),
pipeline_configuration=mlmd_run.pipeline_configuration,
num_steps=len(steps),
mlmd_id=None, # Run might not exist in new MLMD.
)
new_run = self.zen_store.create_run(pipeline_run)
for step, step_status in sorted(
zip(steps, step_statuses), key=lambda x: x[0].mlmd_id
):
# Relies on parent steps having lower MLMD IDs (see the sort
# comment above), so they are already in the mapping.
parent_step_ids = [
step_mlmd_id_mapping[mlmd_parent_step_id]
for mlmd_parent_step_id in step.mlmd_parent_step_ids
]
inputs = metadata_store.get_step_input_artifacts(
step_id=step.mlmd_id,
step_parent_step_ids=step.mlmd_parent_step_ids,
)
outputs = metadata_store.get_step_output_artifacts(
step_id=step.mlmd_id
)
input_artifacts = {
input_name: artifact_mlmd_id_mapping[mlmd_artifact.mlmd_id]
for input_name, mlmd_artifact in inputs.items()
}
step_run = StepRunModel(
name=step.name,
pipeline_run_id=new_run.id,
parent_step_ids=parent_step_ids,
input_artifacts=input_artifacts,
status=step_status,
entrypoint_name=step.entrypoint_name,
parameters=step.parameters,
step_configuration={}, # The old step configuration is not migrated.
mlmd_parent_step_ids=[],
num_outputs=len(outputs),
)
new_step = self.zen_store.create_run_step(step_run)
step_mlmd_id_mapping[step.mlmd_id] = new_step.id
for output_name, mlmd_artifact in sorted(
outputs.items(), key=lambda x: x[1].mlmd_id
):
producer_step_id = step_mlmd_id_mapping[
mlmd_artifact.mlmd_producer_step_id
]
artifact = ArtifactModel(
name=output_name,
parent_step_id=new_step.id,
producer_step_id=producer_step_id,
type=mlmd_artifact.type,
uri=mlmd_artifact.uri,
materializer=mlmd_artifact.materializer,
data_type=mlmd_artifact.data_type,
is_cached=mlmd_artifact.is_cached,
)
new_artifact = self.zen_store.create_artifact(artifact)
# Later steps reference this artifact by its old MLMD ID.
artifact_mlmd_id_mapping[
mlmd_artifact.mlmd_id
] = new_artifact.id
logger.info(f"Migrated {len(pipeline_runs)} pipeline runs.")
def delete_user(self, user_name_or_id: str) -> None:
    """Delete a user.

    Args:
        user_name_or_id: The name or ID of the user to delete.
    """
    # Go through this instance's `zen_store`, as the other Client methods
    # do, instead of constructing a new `Client()` here.
    self.zen_store.delete_user(user_name_or_id=user_name_or_id)
def delete_project(self, project_name_or_id: str) -> None:
    """Delete a project.

    The currently active project cannot be deleted.

    Args:
        project_name_or_id: The name or ID of the project to delete.

    Raises:
        IllegalOperationError: If the project to delete is the active
            project.
    """
    project = self.zen_store.get_project(project_name_or_id)
    # Compare by resolved name so that both a name and an ID argument are
    # checked against the active project.
    if self.active_project_name == project.name:
        raise IllegalOperationError(
            f"Project '{project_name_or_id}' cannot be deleted since it is "
            "currently active. Please set another project as active first."
        )
    # Delete through the same store used for the lookup above rather than
    # constructing a new `Client()`.
    self.zen_store.delete_project(project_name_or_id=project_name_or_id)
active_project: ProjectModel
property
readonly
Get the currently active project of the local client.
If no active project is configured locally for the client, the active project in the global configuration is used instead.
Returns:
Type | Description |
---|---|
ProjectModel |
The active project. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the active project is not set. |
active_project_name: str
property
readonly
The name of the active project for this client.
If no active project is configured locally for the client, the active project in the global configuration is used instead.
Returns:
Type | Description |
---|---|
str |
The name of the active project. |
active_stack: Stack
property
readonly
The active stack for this client.
Returns:
Type | Description |
---|---|
Stack |
The active stack for this client. |
active_stack_model: HydratedStackModel
property
readonly
The model of the active stack for this client.
If no active stack is configured locally for the client, the active stack in the global configuration is used instead.
Returns:
Type | Description |
---|---|
HydratedStackModel |
The model of the active stack for this client. |
Exceptions:
Type | Description |
---|---|
RuntimeError |
If the active stack is not set. |
active_user: UserModel
property
readonly
Get the user that is currently in use.
Returns:
Type | Description |
---|---|
UserModel |
The active user. |
config_directory: Optional[pathlib.Path]
property
readonly
The configuration directory of this client.
Returns:
Type | Description |
---|---|
Optional[pathlib.Path] |
The configuration directory of this client, or None, if the client doesn't have an active root. |
flavors: List[FlavorModel]
property
readonly
Fetches all the flavor models.
Returns:
Type | Description |
---|---|
List[FlavorModel] |
The list of flavor models. |
root: Optional[pathlib.Path]
property
readonly
The root directory of this client.
Returns:
Type | Description |
---|---|
Optional[pathlib.Path] |
The root directory of this client, or None, if the client has not been initialized. |
stacks: List[HydratedStackModel]
property
readonly
All stack models in the active project, owned by the user or shared.
This property is intended as a quick way to get information about the components of the registered stacks without loading all installed integrations.
Returns:
Type | Description |
---|---|
List[HydratedStackModel] |
A list of all stacks available in the current project and owned by the current user. |
uses_local_configuration: bool
property
readonly
Check if the client is using a local configuration.
Returns:
Type | Description |
---|---|
bool |
True if the client is using a local configuration, False otherwise. |
zen_store: BaseZenStore
property
readonly
Shortcut to return the global zen store.
Returns:
Type | Description |
---|---|
BaseZenStore |
The global zen store. |
__init__(self, root=None)
special
Initializes the global client instance.
Client is a singleton class: only one instance can exist. Calling this constructor multiple times will always yield the same instance (see the exception below).
The `root` argument is only meant for internal use and testing purposes.
User code must never pass it to the constructor.
When a custom `root` value is passed, an anonymous Client instance
is created and returned independently of the Client singleton, and
it will have no effect as far as the rest of the ZenML core code is
concerned.
Instead of creating a new Client instance to reflect a different
repository root, change the active root in the global Client by calling
`Client().activate_root(<new-root>)`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
root |
Optional[pathlib.Path] |
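(internal use) custom root directory for the client. If
no path is given, the repository root is determined using the
environment variable `ZENML_REPOSITORY_PATH` (if set) and by
recursively searching in the parent directories of the current
working directory. Only used to initialize new clients internally. |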
None |
Source code in zenml/client.py
def __init__(self, root: Optional[Path] = None) -> None:
    """Set up the global client instance.

    `Client` is a singleton: constructing it repeatedly yields the same
    instance. The only exception is a custom `root`, which is reserved for
    internal use and tests and must never be passed by user code. With a
    custom `root`, an anonymous Client is created and returned
    independently of the singleton. It has no effect on the rest of the
    ZenML core code.

    To point the global Client at another repository root, call
    `Client().activate_root(<new-root>)` instead of creating a new Client.

    Args:
        root: (internal use) custom root directory for the client. When
            omitted, the repository root is taken from the
            `ZENML_REPOSITORY_PATH` environment variable (if set) or found
            by recursively searching the parent directories of the current
            working directory. Only used to initialize new clients
            internally.
    """
    # Both attributes start out unset; `_set_active_root` fills them in.
    self._root: Optional[Path] = None
    self._config: Optional[ClientConfiguration] = None
    self._set_active_root(root)
activate_root(self, root=None)
Set the active repository root directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
root |
Optional[pathlib.Path] |
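The path to set as the active repository root. If not set,
the repository root is determined using the environment
variable `ZENML_REPOSITORY_PATH` (if set) and by recursively
searching in the parent directories of the current working
directory. |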
None |
Source code in zenml/client.py
def activate_root(self, root: Optional[Path] = None) -> None:
    """Switch this client to a different repository root directory.

    Args:
        root: Repository root to activate. When omitted, the root is taken
            from the `ZENML_REPOSITORY_PATH` environment variable (if set)
            or found by recursively searching the parent directories of
            the current working directory.
    """
    new_root = root
    self._set_active_root(new_root)
activate_stack(*args, **kwargs)
Inner decorator function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Arguments to be passed to the function. |
() |
**kwargs |
Any |
Keyword arguments to be passed to the function. |
{} |
Returns:
Type | Description |
---|---|
Any |
Result of the function. |
Source code in zenml/client.py
def inner_func(*args: Any, **kwargs: Any) -> Any:
    """Call the wrapped function, then emit an analytics event.

    Tracking is best-effort. Any exception raised while tracking is logged
    at debug level and swallowed, and the wrapped function's result is
    always returned.

    Args:
        *args: Arguments to be passed to the function.
        **kwargs: Keyword arguments to be passed to the function.

    Returns:
        Result of the function.
    """
    result = func(*args, **kwargs)
    try:
        # A tracker passed as the first positional argument (e.g. `self`)
        # is used to send the event.
        tracker: Optional[AnalyticsTrackerMixin] = (
            args[0]
            if len(args) and isinstance(args[0], AnalyticsTrackerMixin)
            else None
        )
        # The first tracked model among the result and the call arguments
        # emits the event itself.
        candidates = [result] + list(args) + list(kwargs.values())
        tracked_model = next(
            (
                candidate
                for candidate in candidates
                if isinstance(candidate, AnalyticsTrackedModelMixin)
            ),
            None,
        )
        if tracked_model is not None:
            tracked_model.track_event(event_name, tracker=tracker)
        elif tracker:
            tracker.track_event(event_name, metadata)
        else:
            track_event(event_name, metadata)
    except Exception as e:
        logger.debug(f"Analytics tracking failure for {func}: {e}")
    return result
create_flavor(self, flavor)
Creates a new flavor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor |
FlavorModel |
The flavor to create. |
required |
Returns:
Type | Description |
---|---|
FlavorModel |
The created flavor (in model form). |
Source code in zenml/client.py
def create_flavor(self, flavor: "FlavorModel") -> "FlavorModel":
    """Creates a new flavor.

    The flavor source is validated and loaded. The flavor model is then
    rebuilt from the loaded flavor class and registered for the active
    user and project.

    Args:
        flavor: The flavor to create.

    Returns:
        The created flavor (in model form).
    """
    from zenml.utils.source_utils import validate_flavor_source

    flavor_class = validate_flavor_source(
        source=flavor.source,
        component_type=flavor.type,
    )
    # Instantiate the flavor once and reuse it, instead of constructing a
    # new instance for every attribute that is read.
    flavor_instance = flavor_class()
    flavor_model = flavor_instance.to_model()
    flavor_model.project = self.active_project.id
    flavor_model.user = self.active_user.id
    flavor_model.name = flavor_instance.name
    flavor_model.config_schema = flavor_instance.config_schema
    return self.zen_store.create_flavor(flavor=flavor_model)
delete_flavor(self, flavor)
Deletes a flavor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor |
FlavorModel |
The flavor to delete. |
required |
Source code in zenml/client.py
def delete_flavor(self, flavor: "FlavorModel") -> None:
    """Remove a flavor from the zen store.

    A flavor the store does not know (signalled by `KeyError`) is reported
    as a warning instead of an error.

    Args:
        flavor: The flavor to delete.
    """
    label = f"flavor '{flavor.name}' of type '{flavor.type}'"
    try:
        self.zen_store.delete_flavor(flavor_id=flavor.id)
        logger.info(f"Deleted {label}.")
    except KeyError:
        logger.warning(
            f"Unable to delete {label}: No flavor with this name could be "
            "found."
        )
delete_project(self, project_name_or_id)
Delete a project.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
project_name_or_id |
str |
The name or ID of the project to delete. |
required |
Exceptions:
Type | Description |
---|---|
IllegalOperationError |
If the project to delete is the active project. |
Source code in zenml/client.py
def delete_project(self, project_name_or_id: str) -> None:
    """Delete a project.

    The currently active project cannot be deleted.

    Args:
        project_name_or_id: The name or ID of the project to delete.

    Raises:
        IllegalOperationError: If the project to delete is the active
            project.
    """
    project = self.zen_store.get_project(project_name_or_id)
    # Compare by resolved name so that both a name and an ID argument are
    # checked against the active project.
    if self.active_project_name == project.name:
        raise IllegalOperationError(
            f"Project '{project_name_or_id}' cannot be deleted since it is "
            "currently active. Please set another project as active first."
        )
    # Delete through the same store used for the lookup above rather than
    # constructing a new `Client()`.
    self.zen_store.delete_project(project_name_or_id=project_name_or_id)
delete_user(self, user_name_or_id)
Delete a user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_name_or_id |
str |
The name or ID of the user to delete. |
required |
Source code in zenml/client.py
def delete_user(self, user_name_or_id: str) -> None:
    """Delete a user.

    Args:
        user_name_or_id: The name or ID of the user to delete.
    """
    # Go through this instance's `zen_store`, as the other Client methods
    # do, instead of constructing a new `Client()` here.
    self.zen_store.delete_user(user_name_or_id=user_name_or_id)
deregister_stack(self, stack)
Deregisters a stack.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack |
StackModel |
The model of the stack to deregister. |
required |
Exceptions:
Type | Description |
---|---|
ValueError |
If the stack is the currently active stack for this client. |
Source code in zenml/client.py
def deregister_stack(self, stack: "StackModel") -> None:
    """Remove a registered stack from the zen store.

    The client's active stack is refused. A stack unknown to the store
    (signalled by `KeyError`) is reported as a warning.

    Args:
        stack: The model of the stack to deregister.

    Raises:
        ValueError: If the stack is the currently active stack for this
            client.
    """
    is_active_stack = stack.id == self.active_stack_model.id
    if is_active_stack:
        raise ValueError(
            f"Unable to deregister active stack '{stack.name}'. Make sure "
            "to designate a new active stack before deleting this one."
        )
    try:
        self.zen_store.delete_stack(stack_id=stack.id)
        logger.info("Deregistered stack with name '%s'.", stack.name)
    except KeyError:
        logger.warning(
            "Unable to deregister stack with name '%s': No stack with this "
            "name could be found.",
            stack.name,
        )
deregister_stack_component(self, component)
Deletes a registered stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component |
ComponentModel |
The model of the component to delete. |
required |
Source code in zenml/client.py
def deregister_stack_component(self, component: "ComponentModel") -> None:
    """Remove a registered stack component from the zen store.

    A component unknown to the store (signalled by `KeyError`) is reported
    as a warning instead of an error.

    Args:
        component: The model of the component to delete.
    """
    component_type, component_name = component.type, component.name
    try:
        self.zen_store.delete_stack_component(component_id=component.id)
        logger.info(
            "Deregistered stack component (type: %s) with name '%s'.",
            component_type,
            component_name,
        )
    except KeyError:
        logger.warning(
            "Unable to deregister stack component (type: %s) with name "
            "'%s': No stack component with this name could be found.",
            component_type,
            component_name,
        )
export_pipeline_runs(self, filename)
Export all pipeline runs to a YAML file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
filename |
str |
The filename to export the pipeline runs to. |
required |
Source code in zenml/client.py
def export_pipeline_runs(self, filename: str) -> None:
    """Export all pipeline runs of the active project to a YAML file.

    Runs, the steps of each run and the output artifacts of each step are
    written in chronological (`created`) order. `import_pipeline_runs`
    resolves parent steps, producer steps and input artifacts through the
    IDs it has already imported, so referenced entities must precede the
    entities that reference them in the file.

    Args:
        filename: The filename to export the pipeline runs to.
    """
    import json

    from zenml.utils.yaml_utils import write_yaml

    pipeline_runs = self.zen_store.list_runs(
        project_name_or_id=self.active_project.id
    )
    if not pipeline_runs:
        logger.warning("No pipeline runs found. Nothing to export.")
        return
    yaml_data = []
    # The store does not guarantee any ordering, so sort explicitly (as is
    # already done for artifacts) to keep the file importable.
    for pipeline_run in sorted(pipeline_runs, key=lambda x: x.created):
        run_dict = json.loads(pipeline_run.json())
        run_dict["steps"] = []
        steps = self.zen_store.list_run_steps(run_id=pipeline_run.id)
        for step in sorted(steps, key=lambda x: x.created):
            step_dict = json.loads(step.json())
            step_dict["output_artifacts"] = []
            artifacts = self.zen_store.list_artifacts(
                parent_step_id=step.id
            )
            for artifact in sorted(artifacts, key=lambda x: x.created):
                artifact_dict = json.loads(artifact.json())
                step_dict["output_artifacts"].append(artifact_dict)
            run_dict["steps"].append(step_dict)
        yaml_data.append(run_dict)
    write_yaml(filename, yaml_data)
    logger.info(f"Exported {len(yaml_data)} pipeline runs to {filename}.")
find_repository(path=None, enable_warnings=False)
staticmethod
Search for a ZenML repository directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Optional[pathlib.Path] |
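Optional path to look for the repository. If no path is
given, this function tries to find the repository using the
environment variable `ZENML_REPOSITORY_PATH` (if set) and by
recursively searching in the parent directories of the current
working directory. |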
None |
enable_warnings |
bool |
If `True`, warnings are printed if the repository root cannot be found. |
False |
Returns:
Type | Description |
---|---|
Optional[pathlib.Path] |
Absolute path to a ZenML repository directory or None if no repository directory was found. |
Source code in zenml/client.py
@staticmethod
def find_repository(
    path: Optional[Path] = None, enable_warnings: bool = False
) -> Optional[Path]:
    """Locate a ZenML repository directory.

    Args:
        path: Optional path to look for the repository. Without it, the
            `ZENML_REPOSITORY_PATH` environment variable is consulted. If
            that is not set either, the current working directory and its
            parent directories are searched.
        enable_warnings: If `True`, a warning is logged when no repository
            root can be found.

    Returns:
        Absolute path to a ZenML repository directory or None if no
        repository directory was found.
    """
    if not path:
        env_var_path = os.getenv(ENV_ZENML_REPOSITORY_PATH)
        if env_var_path:
            path = Path(env_var_path)

    if path:
        # An explicit path (argument or environment variable) is checked
        # on its own; its parent directories are not searched.
        search_parent_directories = False
        warning_message = (
            f"Unable to find ZenML repository at path '{path}'. Make sure "
            f"to create a ZenML repository by calling `zenml init` when "
            f"specifying an explicit repository path in code or via the "
            f"environment variable '{ENV_ZENML_REPOSITORY_PATH}'."
        )
    else:
        path = Path.cwd()
        search_parent_directories = True
        warning_message = (
            f"Unable to find ZenML repository in your current working "
            f"directory ({path}) or any parent directories. If you "
            f"want to use an existing repository which is in a different "
            f"location, set the environment variable "
            f"'{ENV_ZENML_REPOSITORY_PATH}'. If you want to create a new "
            f"repository, run `zenml init`."
        )

    # Walk upwards from `path` until a repository is found, parent search
    # is disabled, or the filesystem root has been checked.
    candidate = path
    while True:
        if Client.is_repository_directory(candidate):
            return candidate.resolve()
        if not search_parent_directories or io_utils.is_root(str(candidate)):
            break
        candidate = candidate.parent

    if enable_warnings:
        logger.warning(warning_message)
    return None
get_flavor_by_name_and_type(self, name, component_type)
Fetches a registered flavor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_type |
StackComponentType |
The type of the component to fetch. |
required |
name |
str |
The name of the flavor to fetch. |
required |
Returns:
Type | Description |
---|---|
FlavorModel |
The registered flavor. |
Exceptions:
Type | Description |
---|---|
KeyError |
If no flavor exists for the given type and name. |
Source code in zenml/client.py
def get_flavor_by_name_and_type(
    self, name: str, component_type: "StackComponentType"
) -> "FlavorModel":
    """Look up a flavor by its name and component type.

    A custom flavor registered in the active project takes precedence over
    a built-in ZenML flavor with the same name and type. A warning is
    logged when such an override happens.

    Args:
        component_type: The type of the component to fetch.
        name: The name of the flavor to fetch.

    Returns:
        The registered flavor.

    Raises:
        KeyError: If no flavor exists for the given type and name, or if
            more than one custom flavor matches.
    """
    logger.debug(
        f"Fetching the flavor of type {component_type} with name {name}."
    )
    from zenml.stack.flavor_registry import flavor_registry

    try:
        zenml_flavor = flavor_registry.get_flavor_by_name_and_type(
            component_type=component_type,
            name=name,
        )
    except KeyError:
        zenml_flavor = None

    custom_flavors = self.zen_store.list_flavors(
        project_name_or_id=self.active_project.id,
        component_type=component_type,
        name=name,
    )

    if not custom_flavors:
        if zenml_flavor:
            return zenml_flavor
        raise KeyError(
            f"No flavor with name '{name}' and type '{component_type}' "
            "exists."
        )

    if len(custom_flavors) > 1:
        raise KeyError(
            f"More than one flavor with name {name} and type "
            f"{component_type} exists."
        )
    if zenml_flavor:
        logger.warning(
            f"There is a custom implementation for the flavor "
            f"'{name}' of a {component_type}, which is currently "
            f"overwriting the same flavor provided by ZenML."
        )
    return custom_flavors[0]
get_flavors(self)
Fetches all the flavor models.
Returns:
Type | Description |
---|---|
List[FlavorModel] |
A list of all the flavor models. |
Source code in zenml/client.py
def get_flavors(self) -> List["FlavorModel"]:
    """Collect the built-in flavors and the user's custom flavors.

    Built-in ZenML flavors come first, followed by the custom flavors of
    the active user in the active project.

    Returns:
        A list of all the flavor models.
    """
    from zenml.stack.flavor_registry import flavor_registry

    builtin_flavors = flavor_registry.flavors
    user_flavors = self.zen_store.list_flavors(
        user_name_or_id=self.active_user.id,
        project_name_or_id=self.active_project.id,
    )
    return builtin_flavors + user_flavors
get_flavors_by_type(self, component_type)
Fetches the list of flavors for a stack component type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_type |
StackComponentType |
The type of the component to fetch. |
required |
Returns:
Type | Description |
---|---|
List[FlavorModel] |
The list of flavors. |
Source code in zenml/client.py
def get_flavors_by_type(
    self, component_type: "StackComponentType"
) -> List["FlavorModel"]:
    """Fetches the list of flavors for a stack component type.

    Built-in ZenML flavors of the given type come first, followed by the
    matching custom flavors of the active project.

    Args:
        component_type: The type of the component to fetch.

    Returns:
        The list of flavors.
    """
    logger.debug(f"Fetching the flavors of type {component_type}.")
    from zenml.stack.flavor_registry import flavor_registry

    builtin_flavors = flavor_registry.get_flavors_by_type(
        component_type=component_type
    )
    project_flavors = self.zen_store.list_flavors(
        project_name_or_id=self.active_project.id,
        component_type=component_type,
    )
    return builtin_flavors + project_flavors
get_instance()
classmethod
Return the Client singleton instance.
Returns:
Type | Description |
---|---|
Optional[Client] |
The Client singleton instance or None, if the Client hasn't been initialized yet. |
Source code in zenml/client.py
@classmethod
def get_instance(cls) -> Optional["Client"]:
    """Return the global Client singleton, if one exists.

    Returns:
        The Client singleton instance, or None if the Client hasn't been
        initialized yet.
    """
    global_client: Optional["Client"] = cls._global_client
    return global_client
get_pipeline_by_name(self, name)
Fetches a pipeline by name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
The name of the pipeline to fetch. |
required |
Returns:
Type | Description |
---|---|
PipelineModel |
The pipeline model. |
Source code in zenml/client.py
def get_pipeline_by_name(self, name: str) -> "PipelineModel":
    """Look up a pipeline of the active project by its name.

    Args:
        name: The name of the pipeline to fetch.

    Returns:
        The pipeline model.
    """
    store = self.zen_store
    return store.get_pipeline_in_project(
        pipeline_name=name,
        project_name_or_id=self.active_project_name,
    )
get_stack_component_by_id(self, component_id)
Fetches a registered stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component_id |
UUID |
The id of the component to fetch. |
required |
Returns:
Type | Description |
---|---|
ComponentModel |
The registered stack component. |
Source code in zenml/client.py
def get_stack_component_by_id(self, component_id: UUID) -> "ComponentModel":
    """Fetches a registered stack component.

    Args:
        component_id: The id of the component to fetch.

    Returns:
        The registered stack component.
    """
    # Log the requested component ID (not the builtin `id` function).
    logger.debug(
        "Fetching stack component with id '%s'.",
        component_id,
    )
    return self.zen_store.get_stack_component(component_id=component_id)
import_pipeline_runs(self, filename)
Import pipeline runs from a YAML file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
filename |
str |
The filename from which to import the pipeline runs. |
required |
Source code in zenml/client.py
def import_pipeline_runs(self, filename: str) -> None:
    """Import pipeline runs from a YAML file.

    The file is expected to have the layout written by
    `export_pipeline_runs`: a list of run dicts, each with a `steps` list
    whose entries carry an `output_artifacts` list. IDs found in the file
    are remapped to the IDs assigned when the entities are created in the
    zen store. Parent steps, producer steps and input artifacts therefore
    have to appear in the file before the entities referencing them.

    Args:
        filename: The filename from which to import the pipeline runs.
    """
    from datetime import datetime

    from zenml.models.pipeline_models import (
        ArtifactModel,
        PipelineRunModel,
        StepRunModel,
    )
    from zenml.utils.yaml_utils import read_yaml

    # Old IDs from the file (as strings) -> IDs of the newly created
    # entities, used to rewrite cross-references between runs/steps/artifacts.
    step_id_mapping: Dict[str, UUID] = {}
    artifact_id_mapping: Dict[str, UUID] = {}
    yaml_data = read_yaml(filename)
    for pipeline_run_dict in yaml_data:
        steps = pipeline_run_dict.pop("steps")
        pipeline_run_dict.pop("id")
        pipeline_run = PipelineRunModel.parse_obj(pipeline_run_dict)
        pipeline_run.updated = datetime.now()
        # Attach the run to the active user/project and drop references
        # that may not exist in this store.
        pipeline_run.user = self.active_user.id
        pipeline_run.project = self.active_project.id
        pipeline_run.stack_id = None
        pipeline_run.pipeline_id = None
        pipeline_run.mlmd_id = None
        pipeline_run = self.zen_store.create_run(pipeline_run)
        for step_dict in steps:
            artifacts = step_dict.pop("output_artifacts")
            step_id = step_dict.pop("id")
            step = StepRunModel.parse_obj(step_dict)
            step.pipeline_run_id = pipeline_run.id
            step.parent_step_ids = [
                step_id_mapping[str(parent_step_id)]
                for parent_step_id in step.parent_step_ids
            ]
            step.input_artifacts = {
                input_name: artifact_id_mapping[str(artifact_id)]
                for input_name, artifact_id in step.input_artifacts.items()
            }
            step.updated = datetime.now()
            step.mlmd_id = None
            step.mlmd_parent_step_ids = []
            step = self.zen_store.create_run_step(step)
            step_id_mapping[str(step_id)] = step.id
            for artifact_dict in artifacts:
                artifact_id = artifact_dict.pop("id")
                artifact = ArtifactModel.parse_obj(artifact_dict)
                artifact.parent_step_id = step.id
                artifact.producer_step_id = step_id_mapping[
                    str(artifact.producer_step_id)
                ]
                artifact.updated = datetime.now()
                artifact.mlmd_id = None
                artifact.mlmd_parent_step_id = None
                artifact.mlmd_producer_step_id = None
                created_artifact = self.zen_store.create_artifact(artifact)
                # Record the ID assigned by the store (the same way runs
                # and steps are handled above), not the ID of the unsaved
                # local model.
                artifact_id_mapping[str(artifact_id)] = created_artifact.id
    logger.info(f"Imported {len(yaml_data)} pipeline runs from {filename}.")
initialize(*args, **kwargs)
staticmethod
Inner decorator function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Arguments to be passed to the function. |
() |
**kwargs |
Any |
Keyword arguments to be passed to the function. |
{} |
Returns:
Type | Description |
---|---|
Any |
Result of the function. |
Source code in zenml/client.py
def inner_func(*args: Any, **kwargs: Any) -> Any:
    """Call the wrapped function, then emit an analytics event.

    Tracking is best-effort. Any exception raised while tracking is logged
    at debug level and swallowed, and the wrapped function's result is
    always returned.

    Args:
        *args: Arguments to be passed to the function.
        **kwargs: Keyword arguments to be passed to the function.

    Returns:
        Result of the function.
    """
    result = func(*args, **kwargs)
    try:
        # A tracker passed as the first positional argument (e.g. `self`)
        # is used to send the event.
        tracker: Optional[AnalyticsTrackerMixin] = (
            args[0]
            if len(args) and isinstance(args[0], AnalyticsTrackerMixin)
            else None
        )
        # The first tracked model among the result and the call arguments
        # emits the event itself.
        candidates = [result] + list(args) + list(kwargs.values())
        tracked_model = next(
            (
                candidate
                for candidate in candidates
                if isinstance(candidate, AnalyticsTrackedModelMixin)
            ),
            None,
        )
        if tracked_model is not None:
            tracked_model.track_event(event_name, tracker=tracker)
        elif tracker:
            tracker.track_event(event_name, metadata)
        else:
            track_event(event_name, metadata)
    except Exception as e:
        logger.debug(f"Analytics tracking failure for {func}: {e}")
    return result
is_repository_directory(path)
staticmethod
Checks whether a ZenML client exists at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Path |
The path to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if a ZenML client exists at the given path, False otherwise. |
Source code in zenml/client.py
@staticmethod
def is_repository_directory(path: Path) -> bool:
    """Tell whether a ZenML repository exists at `path`.

    A directory counts as a repository when it contains a
    `REPOSITORY_DIRECTORY_NAME` subdirectory.

    Args:
        path: The path to check.

    Returns:
        True if a ZenML client exists at the given path,
        False otherwise.
    """
    return fileio.isdir(str(path / REPOSITORY_DIRECTORY_NAME))
list_stack_components_by_type(self, type_, is_shared=False)
Fetches all registered stack components of a given type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
type_ |
StackComponentType |
The type of the components to fetch. |
required |
is_shared |
bool |
Whether to fetch shared components or not. |
False |
Returns:
Type | Description |
---|---|
List[ComponentModel] |
The registered stack components. |
Source code in zenml/client.py
def list_stack_components_by_type(
    self, type_: "StackComponentType", is_shared: bool = False
) -> List["ComponentModel"]:
    """Fetch the stack components of one type visible to the active user.

    The result holds the active user's unshared components, followed by
    all shared components of the active project.

    Args:
        type_: The type of the components to fetch.
        is_shared: Whether to fetch shared components or not.
            NOTE(review): this argument is currently not used; owned and
            shared components are always both returned.

    Returns:
        The registered stack components.
    """
    owned = self.zen_store.list_stack_components(
        project_name_or_id=self.active_project_name,
        user_name_or_id=self.active_user.id,
        type=type_,
        is_shared=False,
    )
    shared = self.zen_store.list_stack_components(
        project_name_or_id=self.active_project_name,
        is_shared=True,
        type=type_,
    )
    return owned + shared
migrate_pipeline_runs(self, database, database_type='sqlite', mysql_host=None, mysql_port=3306, mysql_username=None, mysql_password=None)
Migrate pipeline runs from a metadata store of ZenML < 0.20.0.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
database |
str |
The metadata store database from which to migrate the pipeline runs. Either a path to a SQLite database or a database name for a MySQL database. |
required |
database_type |
str |
The type of the metadata store database ("sqlite" | "mysql"). Defaults to "sqlite". |
'sqlite' |
mysql_host |
Optional[str] |
The host of the MySQL database. |
None |
mysql_port |
int |
The port of the MySQL database. Defaults to 3306. |
3306 |
mysql_username |
Optional[str] |
The username of the MySQL database. |
None |
mysql_password |
Optional[str] |
The password of the MySQL database. |
None |
Exceptions:
Type | Description |
---|---|
NotImplementedError |
If the database type is not supported. |
RuntimeError |
If no pipeline runs exist. |
ValueError |
If the database type is "mysql" but the MySQL host, username or password are not provided. |
Source code in zenml/client.py
def migrate_pipeline_runs(
    self,
    database: str,
    database_type: str = "sqlite",
    mysql_host: Optional[str] = None,
    mysql_port: int = 3306,
    mysql_username: Optional[str] = None,
    mysql_password: Optional[str] = None,
) -> None:
    """Migrate pipeline runs from a metadata store of ZenML < 0.20.0.

    Every run found in the legacy MLMD metadata store is recreated in the
    active ZenStore under the active user and project. Each run is followed
    by its steps and then by each step's output artifacts. Stack, pipeline
    and MLMD references of the recreated runs are set to ``None``, because
    those objects might not exist in the new database. Steps that were
    still marked as running are migrated as failed.

    Args:
        database: The metadata store database from which to migrate the
            pipeline runs. Either a path to a SQLite database or a database
            name for a MySQL database.
        database_type: The type of the metadata store database
            ("sqlite" | "mysql"). Defaults to "sqlite".
        mysql_host: The host of the MySQL database.
        mysql_port: The port of the MySQL database. Defaults to 3306.
        mysql_username: The username of the MySQL database.
        mysql_password: The password of the MySQL database. An empty
            string is accepted; only ``None`` is rejected.

    Raises:
        NotImplementedError: If the database type is not supported.
        RuntimeError: If no pipeline runs exist.
        ValueError: If the database type is "mysql" but the MySQL host,
            username or password are not provided.
    """
    # Imported lazily: these are only needed for this one-off migration
    # (the `tfx` package must be importable for it to work).
    from tfx.orchestration import metadata
    from zenml.enums import ExecutionStatus
    from zenml.models.pipeline_models import (
        ArtifactModel,
        PipelineRunModel,
        StepRunModel,
    )
    from zenml.zen_stores.metadata_store import MetadataStore

    # Define MLMD connection config based on the database type.
    if database_type == "sqlite":
        mlmd_config = metadata.sqlite_metadata_connection_config(database)
    elif database_type == "mysql":
        # Host and username must be non-empty; the password only needs to
        # be given (an empty password is allowed).
        if not mysql_host or not mysql_username or mysql_password is None:
            raise ValueError(
                "Migration from MySQL requires username, password and host "
                "to be set."
            )
        mlmd_config = metadata.mysql_metadata_connection_config(
            database=database,
            host=mysql_host,
            port=mysql_port,
            username=mysql_username,
            password=mysql_password,
        )
    else:
        raise NotImplementedError(
            "Migrating pipeline runs is only supported for SQLite and MySQL."
        )
    metadata_store = MetadataStore(config=mlmd_config)

    # Dicts to keep track of MLMD IDs, which we need to resolve later:
    # legacy MLMD step/artifact ID -> ID of the newly created ZenStore object.
    step_mlmd_id_mapping: Dict[int, UUID] = {}
    artifact_mlmd_id_mapping: Dict[int, UUID] = {}

    # Get all pipeline runs from the metadata store.
    pipeline_runs = metadata_store.get_all_runs()
    if not pipeline_runs:
        raise RuntimeError("No pipeline runs found in the metadata store.")

    # For each run, first store the pipeline run, then all steps, then all
    # output artifacts of each step.
    # Runs, steps, and artifacts need to be sorted chronologically to ensure
    # that the MLMD IDs of producer steps and parent steps can be resolved.
    # NOTE(review): "chronologically" here means sorted by MLMD ID, which is
    # assumed to increase over time.
    for mlmd_run in sorted(pipeline_runs, key=lambda x: x.mlmd_id):
        # A dict values view: it is iterated twice below and passed to len().
        steps = metadata_store.get_pipeline_run_steps(
            mlmd_run.mlmd_id
        ).values()

        # Mark all steps that haven't finished yet as failed.
        step_statuses = []
        for step in steps:
            status = metadata_store.get_step_status(step.mlmd_id)
            if status == ExecutionStatus.RUNNING:
                status = ExecutionStatus.FAILED
            step_statuses.append(status)

        # The run's overall status is derived from its (adjusted) step
        # statuses.
        pipeline_run = PipelineRunModel(
            user=self.active_user.id,  # Old user might not exist.
            project=self.active_project.id,  # Old project might not exist.
            name=mlmd_run.name,
            stack_id=None,  # Stack might not exist in new DB.
            pipeline_id=None,  # Pipeline might not exist in new DB.
            status=ExecutionStatus.run_status(step_statuses),
            pipeline_configuration=mlmd_run.pipeline_configuration,
            num_steps=len(steps),
            mlmd_id=None,  # Run might not exist in new MLMD.
        )
        new_run = self.zen_store.create_run(pipeline_run)

        for step, step_status in sorted(
            zip(steps, step_statuses), key=lambda x: x[0].mlmd_id
        ):
            # Parent steps must already have been migrated; a missing entry
            # raises KeyError here.
            parent_step_ids = [
                step_mlmd_id_mapping[mlmd_parent_step_id]
                for mlmd_parent_step_id in step.mlmd_parent_step_ids
            ]
            inputs = metadata_store.get_step_input_artifacts(
                step_id=step.mlmd_id,
                step_parent_step_ids=step.mlmd_parent_step_ids,
            )
            outputs = metadata_store.get_step_output_artifacts(
                step_id=step.mlmd_id
            )
            # Inputs are outputs of previously migrated steps; translate
            # their legacy MLMD IDs into the new artifact IDs.
            input_artifacts = {
                input_name: artifact_mlmd_id_mapping[mlmd_artifact.mlmd_id]
                for input_name, mlmd_artifact in inputs.items()
            }
            step_run = StepRunModel(
                name=step.name,
                pipeline_run_id=new_run.id,
                parent_step_ids=parent_step_ids,
                input_artifacts=input_artifacts,
                status=step_status,
                entrypoint_name=step.entrypoint_name,
                parameters=step.parameters,
                step_configuration={},
                mlmd_parent_step_ids=[],
                num_outputs=len(outputs),
            )
            new_step = self.zen_store.create_run_step(step_run)
            step_mlmd_id_mapping[step.mlmd_id] = new_step.id

            for output_name, mlmd_artifact in sorted(
                outputs.items(), key=lambda x: x[1].mlmd_id
            ):
                # The producer may be this step or, for cached artifacts,
                # presumably an earlier step that has already been migrated.
                producer_step_id = step_mlmd_id_mapping[
                    mlmd_artifact.mlmd_producer_step_id
                ]
                artifact = ArtifactModel(
                    name=output_name,
                    parent_step_id=new_step.id,
                    producer_step_id=producer_step_id,
                    type=mlmd_artifact.type,
                    uri=mlmd_artifact.uri,
                    materializer=mlmd_artifact.materializer,
                    data_type=mlmd_artifact.data_type,
                    is_cached=mlmd_artifact.is_cached,
                )
                new_artifact = self.zen_store.create_artifact(artifact)
                artifact_mlmd_id_mapping[
                    mlmd_artifact.mlmd_id
                ] = new_artifact.id

    logger.info(f"Migrated {len(pipeline_runs)} pipeline runs.")
register_pipeline(self, pipeline_name, pipeline_spec, pipeline_docstring)
Registers a pipeline in the ZenStore within the active project.
This will do one of the following three things: A) If there is no pipeline with this name, register a new pipeline. B) If a pipeline exists that has the same config, use that pipeline. C) If a pipeline with different config exists, raise an error.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
The name of the pipeline to register. |
required |
pipeline_spec |
PipelineSpec |
The spec of the pipeline. |
required |
pipeline_docstring |
Optional[str] |
The docstring of the pipeline. |
required |
Returns:
Type | Description |
---|---|
UUID |
The id of the existing or newly registered pipeline. |
Exceptions:
Type | Description |
---|---|
AlreadyExistsException |
If there is an existing pipeline in the project with the same name but a different configuration. |
Source code in zenml/client.py
def register_pipeline(
    self,
    pipeline_name: str,
    pipeline_spec: "PipelineSpec",
    pipeline_docstring: Optional[str],
) -> UUID:
    """Registers a pipeline in the ZenStore within the active project.

    This will do one of the following three things:
    A) If there is no pipeline with this name, register a new pipeline.
    B) If a pipeline exists that has the same config, use that pipeline.
    C) If a pipeline with different config exists, raise an error.

    Args:
        pipeline_name: The name of the pipeline to register.
        pipeline_spec: The spec of the pipeline.
        pipeline_docstring: The docstring of the pipeline.

    Returns:
        The id of the existing or newly registered pipeline.

    Raises:
        AlreadyExistsException: If there is an existing pipeline in the
            project with the same name but a different configuration.
    """
    try:
        existing_pipeline = self.get_pipeline_by_name(pipeline_name)
    except KeyError:
        # A) If there is no pipeline with this name, register a new pipeline.
        #    A KeyError from the lookup is treated as "no such pipeline".
        from zenml.models import PipelineModel

        pipeline = PipelineModel(
            project=self.active_project.id,
            user=self.active_user.id,
            name=pipeline_name,
            spec=pipeline_spec,
            docstring=pipeline_docstring,
        )
        pipeline = self.zen_store.create_pipeline(pipeline=pipeline)
        logger.info(f"Registered new pipeline with name {pipeline.name}.")
        return pipeline.id

    # B) If a pipeline exists that has the same config, use that pipeline.
    if pipeline_spec == existing_pipeline.spec:
        logger.debug("Did not register pipeline since it already exists.")
        return existing_pipeline.id

    # C) If a pipeline with different config exists, raise an error.
    error_msg = (
        f"Cannot run pipeline '{pipeline_name}' since this name has "
        "already been registered with a different pipeline "
        "configuration. You have three options to resolve this issue:\n"
        "1) You can register a new pipeline by changing the name "
        "of your pipeline, e.g., via `@pipeline(name='new_pipeline_name')`."
        "\n2) You can execute the current run without linking it to any "
        "pipeline by setting the 'unlisted' argument to `True`, e.g., "
        "via `my_pipeline_instance.run(unlisted=True)`. "
        "Unlisted runs are not linked to any pipeline, but are still "
        "tracked by ZenML and can be accessed via the 'All Runs' tab. \n"
        "3) You can delete the existing pipeline via "
        f"`zenml pipeline delete {pipeline_name}`. This will then "
        "change all existing runs of this pipeline to become unlisted."
    )
    raise AlreadyExistsException(error_msg)
register_stack(self, stack)
Registers a stack and its components.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack |
StackModel |
The stack to register. |
required |
Returns:
Type | Description |
---|---|
StackModel |
The model of the registered stack. |
Source code in zenml/client.py
def register_stack(self, stack: "StackModel") -> "StackModel":
    """Validate a stack and create it in the ZenStore.

    The stack configuration is checked first; only a stack that passes
    validation is handed to the ZenStore for creation.

    Args:
        stack: The stack to register.

    Returns:
        The stack model returned by the ZenStore after creation.
    """
    self._validate_stack_configuration(stack=stack)
    return self.zen_store.create_stack(stack=stack)
register_stack_component(self, component)
Registers a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component |
ComponentModel |
The component to register. |
required |
Returns:
Type | Description |
---|---|
ComponentModel |
The model of the registered component. |
Source code in zenml/client.py
def register_stack_component(
    self,
    component: "ComponentModel",
) -> "ComponentModel":
    """Validate a stack component's configuration and create it.

    The flavor is resolved from the component's flavor name and type. The
    raw ``component.configuration`` is parsed through that flavor's config
    class, and the parsed config's ``dict()`` replaces the raw
    configuration on the model. Component-level validation runs next, and
    finally the component is created in the ZenStore.

    Args:
        component: The component to register.

    Returns:
        The component model returned by the ZenStore after creation.
    """
    flavor_model = self.get_flavor_by_name_and_type(
        name=component.flavor, component_type=component.type
    )

    from zenml.stack import Flavor

    # Instantiating the config class is what validates the raw values.
    config_class = Flavor.from_model(flavor_model).config_class
    validated_config = config_class(**component.configuration)
    component.configuration = validated_config.dict()

    self._validate_stack_component_configuration(
        component.type, configuration=validated_config
    )
    return self.zen_store.create_stack_component(component=component)
set_active_project(*args, **kwargs)
Inner decorator function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Arguments to be passed to the function. |
() |
**kwargs |
Any |
Keyword arguments to be passed to the function. |
{} |
Returns:
Type | Description |
---|---|
Any |
Result of the function. |
Source code in zenml/client.py
def inner_func(*args: Any, **kwargs: Any) -> Any:
    """Inner decorator function.

    Calls the wrapped function, then makes a best-effort attempt to emit an
    analytics event for the call. Tracking is attempted in this order:

    1. If the result or any argument is an ``AnalyticsTrackedModelMixin``,
       the first such object tracks the event itself (passing along the
       tracker, if any).
    2. Otherwise, if the first positional argument is an
       ``AnalyticsTrackerMixin``, it tracks the event with ``metadata``.
    3. Otherwise the module-level ``track_event`` is used.

    Any exception raised while tracking is logged at debug level and
    swallowed. Exceptions raised by the wrapped function itself are not
    caught.

    NOTE(review): ``func``, ``event_name`` and ``metadata`` are closure
    variables of the enclosing decorator, which is not shown here;
    ``metadata`` is presumably the event metadata given to that decorator.

    Args:
        *args: Arguments to be passed to the function.
        **kwargs: Keyword arguments to be passed to the function.

    Returns:
        Result of the function.
    """
    result = func(*args, **kwargs)
    try:
        tracker: Optional[AnalyticsTrackerMixin] = None
        # For bound methods, args[0] is the instance; use it as the tracker
        # when it supports tracking.
        if len(args) and isinstance(args[0], AnalyticsTrackerMixin):
            tracker = args[0]
        for obj in [result] + list(args) + list(kwargs.values()):
            if isinstance(obj, AnalyticsTrackedModelMixin):
                obj.track_event(event_name, tracker=tracker)
                break
        else:
            # for-else: runs only if no tracked model was found above.
            if tracker:
                tracker.track_event(event_name, metadata)
            else:
                track_event(event_name, metadata)
    except Exception as e:
        # Analytics must never break the decorated call.
        logger.debug(f"Analytics tracking failure for {func}: {e}")
    return result
update_stack(self, stack)
Updates a stack and its components.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack |
StackModel |
The new stack to use as the updated version. |
required |
Source code in zenml/client.py
def update_stack(self, stack: "StackModel") -> None:
    """Validate a stack and store it as the new version in the ZenStore.

    Validation happens before the ZenStore is touched. Whatever the
    ZenStore's update call returns is discarded.

    Args:
        stack: The new stack to use as the updated version.
    """
    self._validate_stack_configuration(stack=stack)
    store = self.zen_store
    store.update_stack(stack=stack)
update_stack_component(self, component)
Updates a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
component |
ComponentModel |
The new component to update with. |
required |
Returns:
Type | Description |
---|---|
ComponentModel |
The updated component. |
Source code in zenml/client.py
def update_stack_component(
    self,
    component: "ComponentModel",
) -> "ComponentModel":
    """Validate and persist a new version of an existing stack component.

    The component currently stored under ``component.id`` is loaded to
    find its flavor (by the stored flavor name and type). The incoming
    ``component.configuration`` is parsed through that flavor's config
    class, and the parsed config's ``dict()`` replaces the raw
    configuration. Component-level validation runs next, and the result
    is sent to the ZenStore.

    Args:
        component: The new component to update with.

    Returns:
        The updated component as returned by the ZenStore.
    """
    stored_component = self.get_stack_component_by_id(
        component.id,
    )
    # NOTE(review): the flavor is looked up with the *stored* type, while
    # the validation below uses the incoming `component.type`; these are
    # presumably always equal - confirm.
    stored_flavor_model = self.get_flavor_by_name_and_type(
        name=stored_component.flavor,
        component_type=stored_component.type,
    )

    from zenml.stack import Flavor

    config_class = Flavor.from_model(stored_flavor_model).config_class
    new_configuration = config_class(**component.configuration)
    component.configuration = new_configuration.dict()

    self._validate_stack_component_configuration(
        component.type, configuration=new_configuration
    )
    return self.zen_store.update_stack_component(component=component)
ClientConfiguration (FileSyncModel)
pydantic-model
Pydantic object used for serializing client configuration options.
Attributes:
Name | Type | Description |
---|---|---|
active_stack_id |
Optional ID of the active stack. |
|
active_project_name |
Optional name of the active project. |
Source code in zenml/client.py
class ClientConfiguration(FileSyncModel):
    """Pydantic object used for serializing client configuration options.

    Attributes:
        active_stack_id: Optional ID of the active stack.
        active_project_name: Optional name of the active project.
    """

    # ID (a UUID, not a name) of the active stack, if any.
    active_stack_id: Optional[UUID]
    # Name of the active project, if any.
    active_project_name: Optional[str]
    # Cached model of the active project. The leading underscore makes it a
    # private attribute (see `underscore_attrs_are_private` below), so it is
    # mutable and not included in serialization.
    _active_project: Optional["ProjectModel"] = None

    def set_active_project(self, project: "ProjectModel") -> None:
        """Set the project for the local client.

        Records the project's name in the `active_project_name` field and
        caches the full project model in `_active_project`.

        Args:
            project: The project to set active.
        """
        self.active_project_name = project.name
        self._active_project = project

    class Config:
        """Pydantic configuration class."""

        # Validate attributes when assigning them. We need to set this in order
        # to have a mix of mutable and immutable attributes
        validate_assignment = True
        # Allow extra attributes from configs of previous ZenML versions to
        # permit downgrading
        extra = "allow"
        # all attributes with leading underscore are private and therefore
        # are mutable and not included in serialization
        underscore_attrs_are_private = True
Config
Pydantic configuration class.
Source code in zenml/client.py
class Config:
    """Pydantic settings for the client configuration model."""

    # Attributes whose names start with an underscore are private: they stay
    # mutable and are left out of serialization.
    underscore_attrs_are_private = True

    # Tolerate unknown keys written by earlier ZenML versions so that
    # downgrading keeps working.
    extra = "allow"

    # Re-validate on every attribute assignment; this is what allows the
    # model to mix mutable and immutable attributes.
    validate_assignment = True
set_active_project(self, project)
Set the project for the local client.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
project |
ProjectModel |
The project to set active. |
required |
Source code in zenml/client.py
def set_active_project(self, project: "ProjectModel") -> None:
    """Make ``project`` the active project of this client configuration.

    Records the project's name and caches the full project model.

    Args:
        project: The project to set active.
    """
    # Targets are assigned left to right: the name first, then the cached
    # model.
    self.active_project_name, self._active_project = project.name, project
ClientMetaClass (ABCMeta)
Client singleton metaclass.
This metaclass is used to enforce a singleton instance of the Client class with the following additional properties:
- the singleton Client instance is created on first access to reflect the global configuration and local client configuration.
- the Client shouldn't be accessed from within pipeline steps (a warning is logged if this is attempted).
Source code in zenml/client.py
class ClientMetaClass(ABCMeta):
    """Metaclass that makes ``Client`` a lazily created singleton.

    Calling the class without arguments returns one shared instance per
    class. The instance is created on first use and stored in the
    ``_global_client`` class attribute. Calling it with any positional or
    keyword argument returns a brand-new instance that is *not* stored as
    the singleton.

    NOTE(review): earlier documentation stated that a warning is logged when
    the Client is accessed from within pipeline steps; no such check is
    performed by this metaclass.
    """

    def __init__(cls, *args: Any, **kwargs: Any) -> None:
        """Finish creating the class and give it an empty singleton slot.

        Args:
            *args: Positional arguments forwarded to ``ABCMeta.__init__``.
            **kwargs: Keyword arguments forwarded to ``ABCMeta.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Each class built by this metaclass tracks its own singleton.
        cls._global_client: Optional["Client"] = None

    def __call__(cls, *args: Any, **kwargs: Any) -> "Client":
        """Return the singleton Client, or a detached one for custom args.

        Args:
            *args: Positional arguments for the Client constructor.
            **kwargs: Keyword arguments for the Client constructor.

        Returns:
            Client: The global Client instance, or a new standalone
            instance when constructor arguments were given.
        """
        if args or kwargs:
            # Custom arguments bypass the singleton entirely.
            return cast("Client", super().__call__(*args, **kwargs))

        instance = cls._global_client
        if not instance:
            # First argument-less call: build and remember the singleton.
            instance = cast("Client", super().__call__(*args, **kwargs))
            cls._global_client = instance
        return instance
__call__(cls, *args, **kwargs)
special
Create or return the global Client instance.
If the Client constructor is called with custom arguments, the singleton functionality of the metaclass is bypassed: a new Client instance is created and returned immediately and without saving it as the global Client singleton.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Positional arguments. |
() |
**kwargs |
Any |
Keyword arguments. |
{} |
Returns:
Type | Description |
---|---|
Client |
The global Client instance. |
Source code in zenml/client.py
def __call__(cls, *args: Any, **kwargs: Any) -> "Client":
    """Return the singleton Client, or a detached one for custom args.

    Without arguments, the shared instance stored in ``_global_client`` is
    returned, and it is created on first use. With any positional or
    keyword argument, a new instance is created and returned without being
    stored as the singleton.

    Args:
        *args: Positional arguments for the Client constructor.
        **kwargs: Keyword arguments for the Client constructor.

    Returns:
        Client: The global Client instance, or a new standalone instance
        when constructor arguments were given.
    """
    if args or kwargs:
        return cast("Client", super().__call__(*args, **kwargs))

    instance = cls._global_client
    if not instance:
        instance = cast("Client", super().__call__(*args, **kwargs))
        cls._global_client = instance
    return instance
__init__(cls, *args, **kwargs)
special
Initialize the Client class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args |
Any |
Positional arguments. |
() |
**kwargs |
Any |
Keyword arguments. |
{} |
Source code in zenml/client.py
def __init__(cls, *args: Any, **kwargs: Any) -> None:
    """Finish creating the class and give it an empty singleton slot.

    Args:
        *args: Positional arguments forwarded to ``ABCMeta.__init__``.
        **kwargs: Keyword arguments forwarded to ``ABCMeta.__init__``.
    """
    super().__init__(*args, **kwargs)
    # No singleton exists until the class is first called without arguments.
    cls._global_client: Optional["Client"] = None