Models
zenml.models
special
Initialization for ZenML models submodule.
base_models
Base domain model definitions.
DomainModel (BaseModel)
pydantic-model
Base domain model.
Used as a base class for all domain models that have the following common characteristics:
- are uniquely identified by an UUID
- have a creation timestamp and a last modified timestamp
Source code in zenml/models/base_models.py
class DomainModel(BaseModel):
    """Base domain model.

    Used as a base class for all domain models that have the following common
    characteristics:

    * are uniquely identified by an UUID
    * have a creation timestamp and a last modified timestamp
    """

    def __hash__(self) -> int:
        """Implementation of hash magic method.

        Only the UUID is hashed so that hashing stays consistent with
        `__eq__`, which treats any two `DomainModel` instances with the same
        `id` as equal regardless of their concrete subclass. Mixing the type
        into the hash would let equal objects hash differently, which breaks
        set membership and dict lookups.

        Returns:
            Hash of the UUID.
        """
        return hash(self.id)

    def __eq__(self, other: Any) -> bool:
        """Implementation of equality magic method.

        Args:
            other: The other object to compare to.

        Returns:
            True if the other object is also a `DomainModel` (of any
            subclass) and has the same UUID, False otherwise.
        """
        return self.id == other.id if isinstance(other, DomainModel) else False

    id: UUID = Field(default_factory=uuid4, title="The unique resource id.")
    # Both timestamps default to `datetime.now()`, i.e. naive local time.
    created: datetime = Field(
        default_factory=datetime.now,
        title="Time when this resource was created.",
    )
    updated: datetime = Field(
        default_factory=datetime.now,
        title="Time when this resource was last updated.",
    )
__eq__(self, other)
special
Implementation of equality magic method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
other |
Any |
The other object to compare to. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the other object is of the same type and has the same UUID. |
Source code in zenml/models/base_models.py
def __eq__(self, other: Any) -> bool:
    """Compare two domain models by their UUID.

    Args:
        other: The object to compare against.

    Returns:
        True when `other` is a `DomainModel` instance (any subclass) whose
        `id` equals this model's `id`; False for any other object.
    """
    if not isinstance(other, DomainModel):
        return False
    return self.id == other.id
__hash__(self)
special
Implementation of hash magic method.
Returns:
Type | Description |
---|---|
int |
Hash of the UUID. |
Source code in zenml/models/base_models.py
def __hash__(self) -> int:
    """Implementation of hash magic method.

    Only the UUID is hashed. `__eq__` considers any two `DomainModel`
    instances with the same `id` equal regardless of their concrete
    subclass, so including `type(self)` in the hash would allow equal
    objects to have different hashes, violating the hash/eq contract.

    Returns:
        Hash of the UUID.
    """
    return hash(self.id)
ProjectScopedDomainModel (UserOwnedDomainModel)
pydantic-model
Base project-scoped domain model.
Used as a base class for all domain models that are project-scoped.
Source code in zenml/models/base_models.py
class ProjectScopedDomainModel(UserOwnedDomainModel):
    """Base project-scoped domain model.

    Used as a base class for all domain models that are project-scoped.
    In addition to the fields inherited from `UserOwnedDomainModel`, it
    stores the UUID of the project the resource belongs to.
    """

    # UUID reference to the owning project. Hydrated subclasses replace this
    # field with a full project model.
    project: UUID = Field(title="The project to which this resource belongs.")
ShareableProjectScopedDomainModel (ProjectScopedDomainModel)
pydantic-model
Base shareable project-scoped domain model.
Used as a base class for all domain models that are project-scoped and are shareable.
Source code in zenml/models/base_models.py
class ShareableProjectScopedDomainModel(ProjectScopedDomainModel):
    """Base shareable project-scoped domain model.

    Used as a base class for all domain models that are project-scoped and are
    shareable.
    """

    # Defaults to False: the resource is not shared with other users of the
    # same project unless this flag is set.
    is_shared: bool = Field(
        default=False,
        title=(
            "Flag describing if this resource is shared with other users in "
            "the same project."
        ),
    )
UserOwnedDomainModel (DomainModel)
pydantic-model
Base user-owned domain model.
Used as a base class for all domain models that are "owned" by a user.
Source code in zenml/models/base_models.py
class UserOwnedDomainModel(DomainModel):
    """Base user-owned domain model.

    Used as a base class for all domain models that are "owned" by a user.
    """

    # UUID of the user that created the resource. Hydrated subclasses replace
    # this field with a full user model.
    user: UUID = Field(
        title="The id of the user that created this resource.",
    )
component_model
Model definition for stack components.
ComponentModel (ShareableProjectScopedDomainModel, AnalyticsTrackedModelMixin)
pydantic-model
Domain Model describing the Stack Component.
Source code in zenml/models/component_model.py
class ComponentModel(
    ShareableProjectScopedDomainModel, AnalyticsTrackedModelMixin
):
    """Domain Model describing the Stack Component.

    The owning project and user are stored as UUID references; call
    `to_hydrated_model` to resolve them into full models.
    """

    # Field names reported to analytics for this model.
    ANALYTICS_FIELDS: ClassVar[List[str]] = [
        "id", "type", "flavor", "project", "user", "is_shared",
    ]

    id: UUID = Field(
        default_factory=uuid4, title="The unique id of the component."
    )
    name: str = Field(
        title="The name of the stack component.",
        max_length=MODEL_NAME_FIELD_MAX_LENGTH,
    )
    type: StackComponentType = Field(
        title="The type of the stack component.",
    )
    flavor: str = Field(
        title="The flavor of the stack component.",
    )
    # Json representation of the configuration.
    configuration: Dict[str, Any] = Field(
        title="The stack component configuration.",
    )

    class Config:
        """Example of a json-serialized instance."""

        schema_extra = {
            "example": {
                "id": "5e4286b5-51f4-4286-b1f8-b0143e9a27ce",
                "name": "vertex_prd_orchestrator",
                "type": "orchestrator",
                "flavor": "vertex",
                "configuration": {"location": "europe-west3"},
                "project": "da63ad01-9117-4082-8a99-557ca5a7d324",
                "user": "43d73159-04fe-418b-b604-b769dd5b771b",
                "created": "2022-08-12T07:12:44.931Z",
                "updated": "2022-08-12T07:12:44.931Z",
            }
        }

    def to_hydrated_model(self) -> "HydratedComponentModel":
        """Resolve the project and user references of this component.

        Returns:
            A `HydratedComponentModel` with the same data as this model, whose
            `project` and `user` were fetched from the active zen store.
        """
        store = GlobalConfiguration().zen_store
        hydrated_project = store.get_project(self.project)
        hydrated_user = store.get_user(self.user)
        return HydratedComponentModel(
            id=self.id,
            name=self.name,
            type=self.type,
            flavor=self.flavor,
            configuration=self.configuration,
            is_shared=self.is_shared,
            project=hydrated_project,
            user=hydrated_user,
            created=self.created,
            updated=self.updated,
        )
Config
Example of a json-serialized instance.
Source code in zenml/models/component_model.py
class Config:
    """Example of a json-serialized instance."""

    # Example payload attached to the model's JSON schema via pydantic's
    # `schema_extra`; `project` and `user` are given as UUID strings.
    schema_extra = {
        "example": {
            "id": "5e4286b5-51f4-4286-b1f8-b0143e9a27ce",
            "name": "vertex_prd_orchestrator",
            "type": "orchestrator",
            "flavor": "vertex",
            "configuration": {"location": "europe-west3"},
            "project": "da63ad01-9117-4082-8a99-557ca5a7d324",
            "user": "43d73159-04fe-418b-b604-b769dd5b771b",
            "created": "2022-08-12T07:12:44.931Z",
            "updated": "2022-08-12T07:12:44.931Z",
        }
    }
to_hydrated_model(self)
Converts the `ComponentModel` into a `HydratedComponentModel`.
Returns:
Type | Description |
---|---|
HydratedComponentModel |
The hydrated component model. |
Source code in zenml/models/component_model.py
def to_hydrated_model(self) -> "HydratedComponentModel":
    """Resolve the project and user references of this component.

    Returns:
        A `HydratedComponentModel` with the same data as this model, whose
        `project` and `user` were fetched from the active zen store.
    """
    store = GlobalConfiguration().zen_store
    hydrated_project = store.get_project(self.project)
    hydrated_user = store.get_user(self.user)
    return HydratedComponentModel(
        id=self.id,
        name=self.name,
        type=self.type,
        flavor=self.flavor,
        configuration=self.configuration,
        is_shared=self.is_shared,
        project=hydrated_project,
        user=hydrated_user,
        created=self.created,
        updated=self.updated,
    )
HydratedComponentModel (ComponentModel)
pydantic-model
Component model with User and Project fully hydrated.
Source code in zenml/models/component_model.py
class HydratedComponentModel(ComponentModel):
    """Component model with User and Project fully hydrated.

    Built, for example, by `ComponentModel.to_hydrated_model`, which replaces
    the `project` and `user` UUID references with full models.
    """

    # TODO: before ignoring the typing error, think of a better way to do this
    # These fields replace the parent's UUID-typed fields with full models,
    # which mypy reports as an incompatible override.
    project: ProjectModel = Field(  # type: ignore[assignment]
        title="The project that contains this component."
    )
    user: UserModel = Field(  # type: ignore[assignment]
        title="The user that created this component.",
    )

    class Config:
        """Example of a json-serialized instance."""

        schema_extra = {
            "example": {
                "id": "5e4286b5-51f4-4286-b1f8-b0143e9a27ce",
                "name": "vertex_prd_orchestrator",
                "type": "orchestrator",
                "flavor": "vertex",
                "configuration": {"location": "europe-west3"},
                "project": {
                    "id": "da63ad01-9117-4082-8a99-557ca5a7d324",
                    "name": "default",
                    "description": "Best project.",
                    "created": "2022-09-15T11:43:29.987627",
                    "updated": "2022-09-15T11:43:29.987627",
                },
                "user": {
                    "id": "43d73159-04fe-418b-b604-b769dd5b771b",
                    "name": "default",
                    "created": "2022-09-15T11:43:29.987627",
                    "updated": "2022-09-15T11:43:29.987627",
                },
                "created": "2022-09-15T11:43:29.987627",
                "updated": "2022-09-15T11:43:29.987627",
            }
        }
Config
Example of a json-serialized instance.
Source code in zenml/models/component_model.py
class Config:
    """Example of a json-serialized instance."""

    # Example payload attached to the model's JSON schema via pydantic's
    # `schema_extra`. NOTE(review): `project` and `user` are UUID strings
    # here, i.e. the non-hydrated form — confirm this listing belongs to the
    # model it is documented under.
    schema_extra = {
        "example": {
            "id": "5e4286b5-51f4-4286-b1f8-b0143e9a27ce",
            "name": "vertex_prd_orchestrator",
            "type": "orchestrator",
            "flavor": "vertex",
            "configuration": {"location": "europe-west3"},
            "project": "da63ad01-9117-4082-8a99-557ca5a7d324",
            "user": "43d73159-04fe-418b-b604-b769dd5b771b",
            "created": "2022-08-12T07:12:44.931Z",
            "updated": "2022-08-12T07:12:44.931Z",
        }
    }
constants
Constants used by ZenML domain models.
flavor_models
Model definitions for stack component flavors.
FlavorModel (ProjectScopedDomainModel, AnalyticsTrackedModelMixin)
pydantic-model
Domain model representing the custom implementation of a flavor.
Source code in zenml/models/flavor_models.py
class FlavorModel(ProjectScopedDomainModel, AnalyticsTrackedModelMixin):
    """Domain model representing the custom implementation of a flavor."""

    # Field names reported to analytics for this model.
    ANALYTICS_FIELDS: ClassVar[List[str]] = [
        "id",
        "type",
        "integration",
        "project",
        "user",
    ]
    name: str = Field(
        title="The name of the Flavor.",
    )
    # The stack component type this flavor implements.
    type: StackComponentType = Field(
        title="The type of the Flavor.",
    )
    # Stored as a plain string (a serialized JSON schema, per its title).
    config_schema: str = Field(
        title="The JSON schema of this flavor's corresponding configuration."
    )
    source: str = Field(
        title="The path to the module which contains this Flavor."
    )
    # No explicit default is given. NOTE(review): pydantic v1 treats an
    # `Optional` field declared this way as defaulting to None — confirm for
    # the pydantic version in use.
    integration: Optional[str] = Field(
        title="The name of the integration that the Flavor belongs to."
    )
pipeline_models
Model definitions for pipelines, runs, steps, and artifacts.
ArtifactModel (DomainModel)
pydantic-model
Domain Model representing an artifact.
Source code in zenml/models/pipeline_models.py
class ArtifactModel(DomainModel):
    """Domain Model representing an artifact.

    Step references are UUIDs of step runs; the `mlmd_*` fields mirror them
    with the corresponding integer IDs in MLMD.
    """

    name: str  # Name of the output in the parent step
    parent_step_id: UUID
    # NOTE(review): presumably differs from `parent_step_id` when the
    # artifact was cached from an earlier step run — confirm.
    producer_step_id: UUID
    type: ArtifactType
    uri: str
    # Stored as a string identifier of the materializer.
    materializer: str
    data_type: str
    is_cached: bool
    # IDs in MLMD - needed for some metadata store methods
    mlmd_id: Optional[int]
    mlmd_parent_step_id: Optional[int]
    mlmd_producer_step_id: Optional[int]
PipelineModel (ProjectScopedDomainModel, AnalyticsTrackedModelMixin)
pydantic-model
Domain model representing a pipeline.
Source code in zenml/models/pipeline_models.py
class PipelineModel(ProjectScopedDomainModel, AnalyticsTrackedModelMixin):
    """Domain model representing a pipeline."""

    # Field names reported to analytics for this model.
    ANALYTICS_FIELDS: ClassVar[List[str]] = ["id", "project", "user"]

    name: str = Field(
        title="The name of the pipeline.",
        max_length=MODEL_NAME_FIELD_MAX_LENGTH,
    )
    # Docstring of the pipeline, if any.
    docstring: Optional[str]
    spec: PipelineSpec
PipelineRunModel (ProjectScopedDomainModel, AnalyticsTrackedModelMixin)
pydantic-model
Domain Model representing a pipeline run.
Source code in zenml/models/pipeline_models.py
class PipelineRunModel(ProjectScopedDomainModel, AnalyticsTrackedModelMixin):
    """Domain Model representing a pipeline run.

    NOTE(review): unlike `PipelineModel`, no `ANALYTICS_FIELDS` are declared
    here, so whatever default the mixin provides applies — confirm intended.
    """

    name: str = Field(
        title="The name of the pipeline run.",
        max_length=MODEL_NAME_FIELD_MAX_LENGTH,
    )
    orchestrator_run_id: Optional[str] = None
    stack_id: Optional[UUID]  # Might become None if the stack is deleted.
    pipeline_id: Optional[UUID]  # Unlisted runs have this as None.
    status: ExecutionStatus
    pipeline_configuration: Dict[str, Any]
    num_steps: Optional[int]
    # Default evaluated once when the class is defined, not per instance.
    zenml_version: Optional[str] = current_zenml_version
    # When not supplied, computed per instance by calling `get_git_sha()`
    # (clean=True), which yields None outside a git repo or in a dirty one.
    git_sha: Optional[str] = Field(default_factory=get_git_sha)

    # ID in MLMD - needed for some metadata store methods.
    mlmd_id: Optional[int]  # Modeled as Optional, so we can remove it later.
StepRunModel (DomainModel)
pydantic-model
Domain Model representing a step in a pipeline run.
Source code in zenml/models/pipeline_models.py
class StepRunModel(DomainModel):
    """Domain Model representing a step in a pipeline run."""

    name: str = Field(
        title="The name of the pipeline run step.",
        max_length=MODEL_NAME_FIELD_MAX_LENGTH,
    )
    pipeline_run_id: UUID
    # UUIDs of upstream step runs; `mlmd_parent_step_ids` holds MLMD IDs.
    parent_step_ids: List[UUID]
    input_artifacts: Dict[str, UUID]  # mapping from input name to artifact ID
    status: ExecutionStatus
    entrypoint_name: str
    # Parameter values are stored as strings.
    parameters: Dict[str, str]
    step_configuration: Dict[str, Any]
    docstring: Optional[str]
    num_outputs: Optional[int]

    # IDs in MLMD - needed for some metadata store methods
    mlmd_id: Optional[int]
    mlmd_parent_step_ids: List[int]
get_git_sha(clean=True)
Returns the current git HEAD SHA.
If the current working directory is not inside a git repo, this will return `None`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
clean |
bool |
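If `True` and there are any untracked files or files in the index or working tree, this function will return `None`. |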
True |
Returns:
Type | Description |
---|---|
Optional[str] |
The current git HEAD SHA or `None` if the current working directory is not inside a git repo. |
Source code in zenml/models/pipeline_models.py
def get_git_sha(clean: bool = True) -> Optional[str]:
    """Returns the current git HEAD SHA.

    If the current working directory is not inside a git repo, this will return
    `None`.

    Args:
        clean: If `True` and there are any untracked files or files in the
            index or working tree, this function will return `None`.

    Returns:
        The current git HEAD SHA, or `None` if GitPython is unavailable, the
        current working directory is not inside a git repo, the repo has no
        commits yet, or `clean` is `True` and the repo is dirty.
    """
    try:
        # GitPython is optional; without it the SHA cannot be read.
        from git.exc import InvalidGitRepositoryError
        from git.repo.base import Repo
    except ImportError:
        return None

    try:
        # Search upwards so that subdirectories of a repo are recognized.
        repo = Repo(search_parent_directories=True)
    except InvalidGitRepositoryError:
        return None

    if clean and repo.is_dirty(untracked_files=True):
        return None

    try:
        head_sha = repo.head.object.hexsha
    except ValueError:
        # A freshly initialized repo has no commit for HEAD to reference;
        # GitPython raises ValueError when resolving the missing ref.
        return None
    return cast(str, head_sha)
project_models
Model definitions for code projects.
ProjectModel (DomainModel, AnalyticsTrackedModelMixin)
pydantic-model
Domain model for projects.
Source code in zenml/models/project_models.py
class ProjectModel(DomainModel, AnalyticsTrackedModelMixin):
    """Domain model for projects."""

    # Only the project id is reported to analytics.
    ANALYTICS_FIELDS: ClassVar[List[str]] = [
        "id",
    ]

    name: str = Field(
        title="The unique name of the project.",
        max_length=MODEL_NAME_FIELD_MAX_LENGTH,
    )
    # Optional free-text description; defaults to an empty string.
    description: str = Field(
        default="",
        title="The description of the project.",
        max_length=MODEL_DESCRIPTIVE_FIELD_MAX_LENGTH,
    )
server_models
Model definitions for ZenML servers.
ServerDatabaseType (StrEnum)
Enum for server database types.
Source code in zenml/models/server_models.py
class ServerDatabaseType(StrEnum):
    """Enum for server database types."""

    SQLITE = "sqlite"
    MYSQL = "mysql"
    # Fallback for any database not listed above; also the `ServerModel`
    # default.
    OTHER = "other"
ServerDeploymentType (StrEnum)
Enum for server deployment types.
Source code in zenml/models/server_models.py
class ServerDeploymentType(StrEnum):
    """Enum for server deployment types."""

    LOCAL = "local"
    DOCKER = "docker"
    KUBERNETES = "kubernetes"
    AWS = "aws"
    GCP = "gcp"
    AZURE = "azure"
    # Fallback for any deployment not listed above; also the `ServerModel`
    # default.
    OTHER = "other"
ServerModel (BaseModel)
pydantic-model
Domain model for ZenML servers.
Source code in zenml/models/server_models.py
class ServerModel(BaseModel):
    """Domain model for ZenML servers.

    Records the server id, the ZenML version it runs, and how it is deployed
    and backed by a database.
    """

    id: UUID = Field(default_factory=uuid4, title="The unique server id.")
    version: str = Field(
        title="The ZenML version that the server is running.",
    )
    deployment_type: ServerDeploymentType = Field(
        default=ServerDeploymentType.OTHER,
        title="The ZenML server deployment type.",
    )
    database_type: ServerDatabaseType = Field(
        default=ServerDatabaseType.OTHER,
        title="The database type that the server is using.",
    )

    def is_local(self) -> bool:
        """Tell whether this server runs on the local machine.

        A local ZenML server reuses the local client (user) ID as its server
        ID, so comparing the two identifies it.

        Returns:
            True if the server is running locally, False otherwise.
        """
        # Imported lazily, as before (presumably to avoid an import cycle).
        from zenml.config.global_config import GlobalConfiguration

        local_client_id = GlobalConfiguration().user_id
        return self.id == local_client_id
is_local(self)
Return whether the server is running locally.
Returns:
Type | Description |
---|---|
bool |
True if the server is running locally, False otherwise. |
Source code in zenml/models/server_models.py
def is_local(self) -> bool:
    """Tell whether this server runs on the local machine.

    A local ZenML server reuses the local client (user) ID as its server ID,
    so comparing the two identifies it.

    Returns:
        True if the server is running locally, False otherwise.
    """
    # Imported lazily, as before (presumably to avoid an import cycle).
    from zenml.config.global_config import GlobalConfiguration

    local_client_id = GlobalConfiguration().user_id
    return self.id == local_client_id
stack_models
Model definitions for stack.
HydratedStackModel (StackModel)
pydantic-model
Stack model with Components, User and Project fully hydrated.
Source code in zenml/models/stack_models.py
class HydratedStackModel(StackModel):
    """Stack model with Components, User and Project fully hydrated."""

    # Replaces the parent's component-id mapping with full component models.
    components: Dict[StackComponentType, List[ComponentModel]] = Field(  # type: ignore[assignment]
        title="A mapping of stack component types to the actual "
        "instances of components of this type."
    )
    project: ProjectModel = Field(title="The project that contains this stack.")  # type: ignore[assignment]
    user: UserModel = Field(  # type: ignore[assignment]
        title="The user that created this stack.",
    )

    class Config:
        """Example of a json-serialized instance."""

        schema_extra = {
            "example": {
                "id": "cbc7d4fd-8c88-49dd-ab12-d998e4fafe22",
                "name": "default",
                "description": "",
                "components": {
                    "artifact_store": [
                        {
                            "id": "55a32b96-7995-4622-8474-12e7c94f3054",
                            "name": "default",
                            "type": "artifact_store",
                            "flavor": "local",
                            "configuration": {
                                "path": "../zenml/local_stores/default_local_store"
                            },
                            "user": "ae1fd828-fb3b-48e8-a31a-f3ecb3cdb294",
                            "is_shared": "False",
                            "project": "c5600721-8432-436d-ac59-a47aec6dec0f",
                            "created": "2022-09-15T11:43:29.987627",
                            "updated": "2022-09-15T11:43:29.987627",
                        }
                    ],
                    "orchestrator": [
                        {
                            "id": "67441c8b-e4e7-439b-bad3-e5883659d387",
                            "name": "default",
                            "type": "orchestrator",
                            "flavor": "local",
                            "configuration": {},
                            "user": "ae1fd828-fb3b-48e8-a31a-f3ecb3cdb294",
                            "is_shared": "False",
                            "project": "c5600721-8432-436d-ac59-a47aec6dec0f",
                            "created": "2022-09-15T11:43:29.987627",
                            "updated": "2022-09-15T11:43:29.987627",
                        }
                    ],
                },
                "is_shared": "False",
                "project": {
                    "id": "c5600721-8432-436d-ac59-a47aec6dec0f",
                    "name": "default",
                    "description": "",
                    "created": "2022-09-15T11:43:29.987627",
                    "updated": "2022-09-15T11:43:29.987627",
                },
                "user": {
                    "id": "ae1fd828-fb3b-48e8-a31a-f3ecb3cdb294",
                    "name": "default",
                    "full_name": "",
                    "email": "",
                    "active": "True",
                    "created": "2022-09-15T11:43:29.987627",
                    "updated": "2022-09-15T11:43:29.987627",
                },
                "created": "2022-09-15T11:43:29.987627",
                "updated": "2022-09-15T11:43:29.987627",
            }
        }

    def to_yaml(self) -> Dict[str, Any]:
        """Create yaml representation of the Stack Model.

        Only the first component of each type is exported. Component types
        whose list is empty are skipped.

        Returns:
            The yaml representation of the Stack Model.
        """
        component_data: Dict[str, Any] = {}
        for component_type, components_list in self.components.items():
            if not components_list:
                # Nothing to export for this type; avoids an IndexError.
                continue
            component_dict = json.loads(components_list[0].json())
            # These fields are not needed in the yaml repr.
            for key in ("project", "created", "updated"):
                component_dict.pop(key)
            component_data[component_type.value] = component_dict

        # Stack name plus one serialized component per component type.
        return {
            "stack_name": self.name,
            "components": component_data,
        }

    def get_analytics_metadata(self) -> Dict[str, Any]:
        """Add the stack components to the stack analytics metadata.

        For each component type, the flavor of its first component is
        reported. Types with no components are skipped.

        Returns:
            Dict of analytics metadata.
        """
        metadata = super().get_analytics_metadata()
        metadata.update(
            {ct: c[0].flavor for ct, c in self.components.items() if c}
        )
        return metadata
Config
Example of a json-serialized instance.
Source code in zenml/models/stack_models.py
class Config:
    """Example of a json-serialized instance."""

    # Example payload attached to the model's JSON schema via pydantic's
    # `schema_extra`; components are given as lists of component UUIDs.
    # NOTE(review): `is_shared` is a string ("False") rather than a JSON
    # boolean — consider using False.
    schema_extra = {
        "example": {
            "id": "cbc7d4fd-8c88-49dd-ab12-d998e4fafe22",
            "name": "default",
            "description": "",
            "components": {
                "artifact_store": ["55a32b96-7995-4622-8474-12e7c94f3054"],
                "orchestrator": ["67441c8b-e4e7-439b-bad3-e5883659d387"],
            },
            "is_shared": "False",
            "project": "c5600721-8432-436d-ac59-a47aec6dec0f",
            "user": "ae1fd828-fb3b-48e8-a31a-f3ecb3cdb294",
            "created": "2022-09-15T11:43:29.994722",
            "updated": "2022-09-15T11:43:29.994722",
        }
    }
get_analytics_metadata(self)
Add the stack components to the stack analytics metadata.
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Dict of analytics metadata. |
Source code in zenml/models/stack_models.py
def get_analytics_metadata(self) -> Dict[str, Any]:
    """Add the stack components to the stack analytics metadata.

    For each component type, the flavor of its first component is reported.
    Types with no components are skipped instead of raising IndexError.

    Returns:
        Dict of analytics metadata.
    """
    metadata = super().get_analytics_metadata()
    metadata.update(
        {ct: c[0].flavor for ct, c in self.components.items() if c}
    )
    return metadata
to_yaml(self)
Create yaml representation of the Stack Model.
Returns:
Type | Description |
---|---|
Dict[str, Any] |
The yaml representation of the Stack Model. |
Source code in zenml/models/stack_models.py
def to_yaml(self) -> Dict[str, Any]:
    """Create yaml representation of the Stack Model.

    Only the first component of each type is exported. Component types whose
    list is empty are skipped.

    Returns:
        The yaml representation of the Stack Model.
    """
    component_data: Dict[str, Any] = {}
    for component_type, components_list in self.components.items():
        if not components_list:
            # Nothing to export for this type; avoids an IndexError.
            continue
        component_dict = json.loads(components_list[0].json())
        # These fields are not needed in the yaml repr.
        for key in ("project", "created", "updated"):
            component_dict.pop(key)
        component_data[component_type.value] = component_dict

    # Stack name plus one serialized component per component type.
    return {
        "stack_name": self.name,
        "components": component_data,
    }
StackModel (ShareableProjectScopedDomainModel, AnalyticsTrackedModelMixin)
pydantic-model
Domain Model describing the Stack.
Source code in zenml/models/stack_models.py
class StackModel(ShareableProjectScopedDomainModel, AnalyticsTrackedModelMixin):
    """Domain Model describing the Stack."""

    # Field names reported to analytics for this model.
    ANALYTICS_FIELDS: ClassVar[List[str]] = [
        "id",
        "project",
        "user",
        "is_shared",
    ]

    name: str = Field(
        title="The name of the stack.", max_length=MODEL_NAME_FIELD_MAX_LENGTH
    )
    description: str = Field(
        default="",
        title="The description of the stack",
        max_length=MODEL_DESCRIPTIVE_FIELD_MAX_LENGTH,
    )
    components: Dict[StackComponentType, List[UUID]] = Field(
        title=(
            "A mapping of stack component types to the id's of "
            "instances of components of this type."
        )
    )

    class Config:
        """Example of a json-serialized instance."""

        schema_extra = {
            "example": {
                "id": "cbc7d4fd-8c88-49dd-ab12-d998e4fafe22",
                "name": "default",
                "description": "",
                "components": {
                    "artifact_store": ["55a32b96-7995-4622-8474-12e7c94f3054"],
                    "orchestrator": ["67441c8b-e4e7-439b-bad3-e5883659d387"],
                },
                "is_shared": "False",
                "project": "c5600721-8432-436d-ac59-a47aec6dec0f",
                "user": "ae1fd828-fb3b-48e8-a31a-f3ecb3cdb294",
                "created": "2022-09-15T11:43:29.994722",
                "updated": "2022-09-15T11:43:29.994722",
            }
        }

    @property
    def is_valid(self) -> bool:
        """Check if the stack is valid.

        A stack is valid when its components mapping contains both an
        artifact store and an orchestrator entry.

        Returns:
            True if the stack is valid, False otherwise.
        """
        # Both memberships must be tested explicitly: a bare enum member in a
        # boolean context is always truthy.
        return (
            StackComponentType.ARTIFACT_STORE in self.components
            and StackComponentType.ORCHESTRATOR in self.components
        )

    def to_hydrated_model(self) -> "HydratedStackModel":
        """Create a hydrated version of the stack model.

        Component ids, the project and the user are resolved through the
        active zen store.

        Returns:
            A hydrated version of the stack model.
        """
        zen_store = GlobalConfiguration().zen_store

        components = {}
        for comp_type, comp_id_list in self.components.items():
            components[comp_type] = [
                zen_store.get_stack_component(c_id) for c_id in comp_id_list
            ]

        project = zen_store.get_project(self.project)
        user = zen_store.get_user(self.user)

        return HydratedStackModel(
            id=self.id,
            name=self.name,
            description=self.description,
            components=components,
            project=project,
            user=user,
            is_shared=self.is_shared,
            created=self.created,
            updated=self.updated,
        )

    def get_analytics_metadata(self) -> Dict[str, Any]:
        """Add the stack components to the stack analytics metadata.

        For each component type, the id of its first component is reported.
        Types with no components are skipped instead of raising IndexError.

        Returns:
            Dict of analytics metadata.
        """
        metadata = super().get_analytics_metadata()
        metadata.update({ct: c[0] for ct, c in self.components.items() if c})
        return metadata
is_valid: bool
property
readonly
Check if the stack is valid.
Returns:
Type | Description |
---|---|
bool |
True if the stack is valid, False otherwise. |
Config
Example of a json-serialized instance.
Source code in zenml/models/stack_models.py
class Config:
    """Example of a json-serialized instance."""

    # Example payload attached to the model's JSON schema via pydantic's
    # `schema_extra`; components are given as lists of component UUIDs.
    # NOTE(review): `is_shared` is a string ("False") rather than a JSON
    # boolean — consider using False.
    schema_extra = {
        "example": {
            "id": "cbc7d4fd-8c88-49dd-ab12-d998e4fafe22",
            "name": "default",
            "description": "",
            "components": {
                "artifact_store": ["55a32b96-7995-4622-8474-12e7c94f3054"],
                "orchestrator": ["67441c8b-e4e7-439b-bad3-e5883659d387"],
            },
            "is_shared": "False",
            "project": "c5600721-8432-436d-ac59-a47aec6dec0f",
            "user": "ae1fd828-fb3b-48e8-a31a-f3ecb3cdb294",
            "created": "2022-09-15T11:43:29.994722",
            "updated": "2022-09-15T11:43:29.994722",
        }
    }
get_analytics_metadata(self)
Add the stack components to the stack analytics metadata.
Returns:
Type | Description |
---|---|
Dict[str, Any] |
Dict of analytics metadata. |
Source code in zenml/models/stack_models.py
def get_analytics_metadata(self) -> Dict[str, Any]:
    """Add the stack components to the stack analytics metadata.

    For each component type, the id of its first component is reported.
    Types with no components are skipped instead of raising IndexError.

    Returns:
        Dict of analytics metadata.
    """
    metadata = super().get_analytics_metadata()
    metadata.update({ct: c[0] for ct, c in self.components.items() if c})
    return metadata
to_hydrated_model(self)
Create a hydrated version of the stack model.
Returns:
Type | Description |
---|---|
HydratedStackModel |
A hydrated version of the stack model. |
Source code in zenml/models/stack_models.py
def to_hydrated_model(self) -> "HydratedStackModel":
    """Resolve every reference of this stack into a full model.

    Returns:
        A `HydratedStackModel` whose components, project and user have been
        fetched from the active zen store.
    """
    store = GlobalConfiguration().zen_store
    resolved_components = {
        component_type: [store.get_stack_component(cid) for cid in ids]
        for component_type, ids in self.components.items()
    }
    # Keyword arguments are evaluated left to right, so the project is
    # fetched before the user, as before.
    return HydratedStackModel(
        id=self.id,
        name=self.name,
        description=self.description,
        components=resolved_components,
        project=store.get_project(self.project),
        user=store.get_user(self.user),
        is_shared=self.is_shared,
        created=self.created,
        updated=self.updated,
    )
user_management_models
Model definitions for users, teams, and roles.
JWTToken (BaseModel)
pydantic-model
Pydantic object representing a JWT token.
Attributes:
Name | Type | Description |
---|---|---|
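token_type |
JWTTokenType |
The type of token. |
user_id |
UUID |
The ID of the user, stored in the token's subject claim. |
permissions |
List[str] |
The permissions stored in the token's permissions claim. |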
Source code in zenml/models/user_management_models.py
class JWTToken(BaseModel):
    """Pydantic object representing a JWT token.

    Attributes:
        token_type: The type of token.
        user_id: The ID of the user, stored in the token's subject claim.
        permissions: The permissions stored in the token's permissions claim.
    """

    JWT_ALGORITHM: ClassVar[str] = "HS256"
    token_type: JWTTokenType
    user_id: UUID
    permissions: List[str]

    @classmethod
    def decode(cls, token_type: JWTTokenType, token: str) -> "JWTToken":
        """Decodes a JWT access token.

        Decodes a JWT access token and returns a `JWTToken` object with the
        information retrieved from its subject claim.

        Args:
            token_type: The type of token.
            token: The encoded JWT token.

        Returns:
            The decoded JWT access token.

        Raises:
            AuthorizationException: If the token is invalid.
        """
        # import here to keep these dependencies out of the client
        from jose import JWTError, jwt  # type: ignore[import]

        try:
            payload = jwt.decode(
                token,
                GlobalConfiguration().jwt_secret_key,
                algorithms=[cls.JWT_ALGORITHM],
            )
        except JWTError as e:
            raise AuthorizationException(f"Invalid JWT token: {e}") from e

        subject: Optional[str] = payload.get("sub")
        if subject is None:
            raise AuthorizationException(
                "Invalid JWT token: the subject claim is missing"
            )
        permissions: Optional[List[str]] = payload.get("permissions")
        if permissions is None:
            raise AuthorizationException(
                "Invalid JWT token: the permissions scope is missing"
            )

        try:
            return cls(
                token_type=token_type,
                user_id=UUID(subject),
                # Deduplicate while keeping the encoded order; a set would
                # give an arbitrary order once coerced into the list field.
                permissions=list(dict.fromkeys(permissions)),
            )
        except ValueError as e:
            raise AuthorizationException(
                f"Invalid JWT token: could not decode subject claim: {e}"
            ) from e

    def encode(self, expire_minutes: Optional[int] = None) -> str:
        """Creates a JWT access token.

        Generates and returns a JWT access token with the subject claim set to
        contain the information in this Pydantic object.

        Args:
            expire_minutes: Number of minutes the token should be valid. If not
                provided, the token will not be set to expire.

        Returns:
            The generated access token.
        """
        # import here to keep these dependencies out of the client
        from jose import jwt

        claims: Dict[str, Any] = {
            "sub": str(self.user_id),
            "permissions": list(self.permissions),
        }

        # Any falsy value (None or 0) leaves the token without an expiry.
        if expire_minutes:
            expire = datetime.utcnow() + timedelta(minutes=expire_minutes)
            claims["exp"] = expire

        token: str = jwt.encode(
            claims,
            GlobalConfiguration().jwt_secret_key,
            algorithm=self.JWT_ALGORITHM,
        )
        return token
decode(token_type, token)
classmethod
Decodes a JWT access token.
Decodes a JWT access token and returns a `JWTToken` object with the information retrieved from its subject claim.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
token_type |
JWTTokenType |
The type of token. |
required |
token |
str |
The encoded JWT token. |
required |
Returns:
Type | Description |
---|---|
JWTToken |
The decoded JWT access token. |
Exceptions:
Type | Description |
---|---|
AuthorizationException |
If the token is invalid. |
Source code in zenml/models/user_management_models.py
@classmethod
def decode(cls, token_type: JWTTokenType, token: str) -> "JWTToken":
    """Decodes a JWT access token.

    Decodes a JWT access token and returns a `JWTToken` object with the
    information retrieved from its subject claim.

    Args:
        token_type: The type of token.
        token: The encoded JWT token.

    Returns:
        The decoded JWT access token.

    Raises:
        AuthorizationException: If the token is invalid.
    """
    # import here to keep these dependencies out of the client
    from jose import JWTError, jwt  # type: ignore[import]

    try:
        payload = jwt.decode(
            token,
            GlobalConfiguration().jwt_secret_key,
            algorithms=[cls.JWT_ALGORITHM],
        )
    except JWTError as e:
        raise AuthorizationException(f"Invalid JWT token: {e}") from e

    subject: Optional[str] = payload.get("sub")
    if subject is None:
        raise AuthorizationException(
            "Invalid JWT token: the subject claim is missing"
        )
    permissions: Optional[List[str]] = payload.get("permissions")
    if permissions is None:
        raise AuthorizationException(
            "Invalid JWT token: the permissions scope is missing"
        )

    try:
        return cls(
            token_type=token_type,
            user_id=UUID(subject),
            # Deduplicate while keeping the encoded order; a set would give
            # an arbitrary order once coerced into the list field.
            permissions=list(dict.fromkeys(permissions)),
        )
    except ValueError as e:
        raise AuthorizationException(
            f"Invalid JWT token: could not decode subject claim: {e}"
        ) from e
encode(self, expire_minutes=None)
Creates a JWT access token.
Generates and returns a JWT access token with the subject claim set to contain the information in this Pydantic object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
expire_minutes |
Optional[int] |
Number of minutes the token should be valid. If not provided, the token will not be set to expire. |
None |
Returns:
Type | Description |
---|---|
str |
The generated access token. |
Source code in zenml/models/user_management_models.py
def encode(self, expire_minutes: Optional[int] = None) -> str:
    """Serialize this object into a signed JWT access token.

    The subject claim carries the user ID and a `permissions` claim carries
    the permissions list. The token is signed with the global JWT secret key.

    Args:
        expire_minutes: Lifetime of the token in minutes. When omitted (or
            falsy), no expiry claim is added.

    Returns:
        The encoded access token.
    """
    # import here to keep these dependencies out of the client
    from jose import jwt

    payload_claims: Dict[str, Any] = dict(
        sub=str(self.user_id),
        permissions=list(self.permissions),
    )
    if expire_minutes:
        payload_claims["exp"] = datetime.utcnow() + timedelta(
            minutes=expire_minutes
        )

    encoded: str = jwt.encode(
        payload_claims,
        GlobalConfiguration().jwt_secret_key,
        algorithm=self.JWT_ALGORITHM,
    )
    return encoded
JWTTokenType (StrEnum)
The type of JWT token.
Source code in zenml/models/user_management_models.py
class JWTTokenType(StrEnum):
    """The type of JWT token."""

    # Currently the only token type.
    ACCESS_TOKEN = "access_token"
PermissionModel (BaseModel)
pydantic-model
Domain model for permissions.
Source code in zenml/models/user_management_models.py
class PermissionModel(BaseModel):
    """Domain model for permissions."""

    # NOTE(review): declared although this model does not inherit
    # AnalyticsTrackedModelMixin here — confirm it is actually consumed.
    ANALYTICS_FIELDS: ClassVar[List[str]] = ["id"]

    # Unlike DomainModel subclasses, permissions use an int id, not a UUID.
    id: int = Field(title="Id of the specific permission")
    name: str = Field(
        title="The unique name of the permission.",
    )
RoleAssignmentModel (DomainModel)
pydantic-model
Domain model for role assignments.
Source code in zenml/models/user_management_models.py
class RoleAssignmentModel(DomainModel):
    """Domain model for role assignments.

    A role assignment grants a role to exactly one entity, either a user or a
    team. It can optionally be limited to a single project.
    """

    role: UUID = Field(title="The role.")
    project: Optional[UUID] = Field(
        None, title="The project that the role is limited to."
    )
    team: Optional[UUID] = Field(
        None, title="The team that the role is assigned to."
    )
    user: Optional[UUID] = Field(
        None, title="The user that the role is assigned to."
    )

    @root_validator
    def ensure_single_entity(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validates that exactly one of `user` or `team` is set.

        Args:
            values: The values to validate.

        Returns:
            The validated values.

        Raises:
            ValueError: If both `user` and `team` are set, or if neither is.
        """
        user = values.get("user", None)
        team = values.get("team", None)
        if user and team:
            raise ValueError("Only `user` or `team` is allowed.")

        if not (user or team):
            raise ValueError("Missing `user` or `team` for role assignment.")

        return values
ensure_single_entity(values)
classmethod
Validates that exactly one of `user` or `team` is set.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `values` | `Dict[str, Any]` | The values to validate. | required |
Returns:

| Type | Description |
|---|---|
| `Dict[str, Any]` | The validated values. |
Exceptions:

| Type | Description |
|---|---|
| `ValueError` | If neither `user` nor `team` is set, or if both are set. |
Source code in zenml/models/user_management_models.py
@root_validator
def ensure_single_entity(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Validates that exactly one of `user` or `team` is set.

    Args:
        values: The values to validate.

    Returns:
        The validated values.

    Raises:
        ValueError: If both `user` and `team` are set, or if neither is.
    """
    user = values.get("user", None)
    team = values.get("team", None)
    if user and team:
        raise ValueError("Only `user` or `team` is allowed.")

    if not (user or team):
        raise ValueError("Missing `user` or `team` for role assignment.")

    return values
RoleModel (DomainModel, AnalyticsTrackedModelMixin)
pydantic-model
Domain model for roles.
Source code in zenml/models/user_management_models.py
class RoleModel(DomainModel, AnalyticsTrackedModelMixin):
    """Domain model describing a named role and the permissions it grants."""

    # Field names presumably consumed by AnalyticsTrackedModelMixin when
    # reporting analytics metadata for this model.
    ANALYTICS_FIELDS: ClassVar[List[str]] = ["id"]

    name: str = Field(
        title="The unique name of the role.",
        max_length=MODEL_NAME_FIELD_MAX_LENGTH,
    )
    # Permissions granted by this role.
    permissions: Set[PermissionType]
TeamModel (DomainModel, AnalyticsTrackedModelMixin)
pydantic-model
Domain model for teams.
Source code in zenml/models/user_management_models.py
class TeamModel(DomainModel, AnalyticsTrackedModelMixin):
    """Domain model representing a named team."""

    # Field names presumably consumed by AnalyticsTrackedModelMixin when
    # reporting analytics metadata for this model.
    ANALYTICS_FIELDS: ClassVar[List[str]] = ["id"]

    name: str = Field(
        title="The unique name of the team.",
        max_length=MODEL_NAME_FIELD_MAX_LENGTH,
    )
UserModel (DomainModel, AnalyticsTrackedModelMixin)
pydantic-model
Domain model for user accounts.
Source code in zenml/models/user_management_models.py
class UserModel(DomainModel, AnalyticsTrackedModelMixin):
    """Domain model for a user account.

    Besides profile fields, the model holds the password and activation token
    as secrets (excluded from serialization). It also offers helpers to hash
    and verify those secrets and to issue and verify JWT access tokens.
    """

    ANALYTICS_FIELDS: ClassVar[List[str]] = [
        "id",
        "name",
        "full_name",
        "active",
        "email_opted_in",
    ]

    name: str = Field(
        default="",
        title="The unique username for the account.",
        max_length=MODEL_NAME_FIELD_MAX_LENGTH,
    )
    full_name: str = Field(
        default="",
        title="The full name for the account owner.",
        max_length=MODEL_NAME_FIELD_MAX_LENGTH,
    )
    email: Optional[str] = Field(
        default="",
        title="The email address associated with the account.",
        max_length=MODEL_NAME_FIELD_MAX_LENGTH,
    )
    email_opted_in: Optional[bool] = Field(
        title="Whether the user agreed to share their email.",
        description="`null` if not answered, `true` if agreed, "
        "`false` if skipped.",
    )
    active: bool = Field(default=False, title="Active account.")
    # Secrets are never serialized (`exclude=True`).
    password: Optional[SecretStr] = Field(default=None, exclude=True)
    activation_token: Optional[SecretStr] = Field(default=None, exclude=True)

    @classmethod
    def _get_crypt_context(cls) -> "CryptContext":
        """Builds the passlib context used to hash and verify secrets.

        Returns:
            A bcrypt-based password encryption context.
        """
        from passlib.context import CryptContext

        return CryptContext(schemes=["bcrypt"], deprecated="auto")

    @classmethod
    def verify_password(
        cls, plain_password: str, user: Optional["UserModel"] = None
    ) -> bool:
        """Checks a plain-text password against a user's stored password.

        Args:
            plain_password: Input password to be verified.
            user: User for which the password is to be verified.

        Returns:
            True if the passwords match.
        """
        # The hash verification runs even when there is no user, no stored
        # password or the account is inactive, to protect against response
        # discrepancy attacks (https://cwe.mitre.org/data/definitions/204.html).
        stored_hash: Optional[str] = (
            user.get_hashed_password()
            if user is not None and user.password is not None and user.active
            else None
        )
        crypt_context = cls._get_crypt_context()
        return cast(bool, crypt_context.verify(plain_password, stored_hash))

    def get_password(self) -> Optional[str]:
        """Returns the raw value of the stored password secret.

        Returns:
            The password as a plain string, or None when no password is set.
        """
        secret = self.password
        return None if secret is None else secret.get_secret_value()

    @classmethod
    def _is_hashed_secret(cls, secret: SecretStr) -> bool:
        """Tells whether a secret value already looks like a bcrypt hash.

        Args:
            secret: The secret value to check.

        Returns:
            True if the secret value is hashed, otherwise False.
        """
        # A "$2a$", "$2b$" or "$2y$" prefix followed by 56 characters.
        bcrypt_pattern = r"^\$2[ayb]\$.{56}$"
        return bool(re.match(bcrypt_pattern, secret.get_secret_value()))

    @classmethod
    def _get_hashed_secret(cls, secret: Optional[SecretStr]) -> Optional[str]:
        """Hashes a secret unless it is missing or already hashed.

        Args:
            secret: The secret value to hash.

        Returns:
            The secret hash value, or None if no secret was supplied.
        """
        if secret is None:
            return None
        raw_value = secret.get_secret_value()
        if cls._is_hashed_secret(secret):
            # Already a bcrypt hash: hand it back rather than hashing twice.
            return raw_value
        return cast(str, cls._get_crypt_context().hash(raw_value))

    def get_hashed_password(self) -> Optional[str]:
        """Returns the bcrypt hash of the configured password.

        Returns:
            The hashed password, or None when no password is set.
        """
        return self._get_hashed_secret(self.password)

    @classmethod
    def verify_access_token(cls, token: str) -> Optional["UserModel"]:
        """Resolves an access token to the user it was issued for.

        Args:
            token: The access token to verify.

        Returns:
            The user that generated the token if the token is valid and the
            user is active, None otherwise.
        """
        try:
            decoded = JWTToken.decode(
                token_type=JWTTokenType.ACCESS_TOKEN, token=token
            )
        except AuthorizationException:
            return None

        store = GlobalConfiguration().zen_store
        try:
            user = store.get_user(user_name_or_id=decoded.user_id)
        except KeyError:
            return None

        return user if decoded.user_id == user.id and user.active else None

    def generate_access_token(self, permissions: List[str]) -> str:
        """Issues an access token for this user.

        Args:
            permissions: Permissions to add to the token

        Returns:
            The generated access token.
        """
        jwt_token = JWTToken(
            token_type=JWTTokenType.ACCESS_TOKEN,
            user_id=self.id,
            permissions=permissions,
        )
        # No expiry is passed, so the encoded token carries no `exp` claim.
        return jwt_token.encode()

    def get_activation_token(self) -> Optional[str]:
        """Returns the raw value of the stored activation token secret.

        Returns:
            The activation token as a plain string, or None when unset.
        """
        secret = self.activation_token
        return None if secret is None else secret.get_secret_value()

    def get_hashed_activation_token(self) -> Optional[str]:
        """Returns the bcrypt hash of the configured activation token.

        Returns:
            The hashed activation token, or None when unset.
        """
        return self._get_hashed_secret(self.activation_token)

    @classmethod
    def verify_activation_token(
        cls, activation_token: str, user: Optional["UserModel"] = None
    ) -> bool:
        """Checks an activation token against a user's stored token.

        Args:
            activation_token: Input activation token to be verified.
            user: User for which the activation token is to be verified.

        Returns:
            True if the token is valid.
        """
        # The hash verification runs even when there is no user, no stored
        # token or the account is already active, to protect against response
        # discrepancy attacks (https://cwe.mitre.org/data/definitions/204.html).
        stored_hash: Optional[str] = (
            user.get_hashed_activation_token()
            if (
                user is not None
                and user.activation_token is not None
                and not user.active
            )
            else None
        )
        crypt_context = cls._get_crypt_context()
        return cast(bool, crypt_context.verify(activation_token, stored_hash))

    def generate_activation_token(self) -> SecretStr:
        """Creates a fresh random activation token and stores it on the user.

        Returns:
            The generated activation token.
        """
        # The assignment is validated because of `validate_assignment`.
        self.activation_token = SecretStr(token_hex(32))
        return self.activation_token

    class Config:
        """Pydantic configuration class."""

        # Re-validate fields whenever an attribute is assigned.
        validate_assignment = True
        # Reject unknown attributes instead of silently accepting them.
        extra = "forbid"
        underscore_attrs_are_private = True
email_opted_in: Optional[bool]
pydantic-field
`null` if not answered, `true` if agreed, `false` if skipped.
Config
Pydantic configuration class.
Source code in zenml/models/user_management_models.py
class Config:
    """Pydantic configuration class."""

    # Reject unknown attributes instead of silently accepting them.
    extra = "forbid"
    # Re-validate fields whenever an attribute is assigned.
    validate_assignment = True
    underscore_attrs_are_private = True
generate_access_token(self, permissions)
Generates an access token.
Generates an access token and returns it.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `permissions` | `List[str]` | Permissions to add to the token. | required |
Returns:

| Type | Description |
|---|---|
| `str` | The generated access token. |
Source code in zenml/models/user_management_models.py
def generate_access_token(self, permissions: List[str]) -> str:
    """Issues an access token for this user.

    Args:
        permissions: Permissions to add to the token

    Returns:
        The generated access token.
    """
    jwt_token = JWTToken(
        token_type=JWTTokenType.ACCESS_TOKEN,
        user_id=self.id,
        permissions=permissions,
    )
    # No expiry is passed, so the encoded token carries no `exp` claim.
    return jwt_token.encode()
generate_activation_token(self)
Generates and stores a new activation token.
Returns:

| Type | Description |
|---|---|
| `SecretStr` | The generated activation token. |
Source code in zenml/models/user_management_models.py
def generate_activation_token(self) -> SecretStr:
    """Creates a fresh random activation token and stores it on the user.

    Returns:
        The generated activation token.
    """
    # The assignment is validated because of `validate_assignment`.
    self.activation_token = SecretStr(token_hex(32))
    return self.activation_token
get_activation_token(self)
Get the activation token.
Returns:

| Type | Description |
|---|---|
| `Optional[str]` | The activation token as a plain string, if it exists. |
Source code in zenml/models/user_management_models.py
def get_activation_token(self) -> Optional[str]:
    """Returns the raw value of the stored activation token secret.

    Returns:
        The activation token as a plain string, or None when unset.
    """
    secret = self.activation_token
    return None if secret is None else secret.get_secret_value()
get_hashed_activation_token(self)
Returns the hashed activation token, if configured.
Returns:

| Type | Description |
|---|---|
| `Optional[str]` | The hashed activation token. |
Source code in zenml/models/user_management_models.py
def get_hashed_activation_token(self) -> Optional[str]:
    """Returns the bcrypt hash of the configured activation token.

    Returns:
        The hashed activation token, or None when unset.
    """
    token_secret = self.activation_token
    return self._get_hashed_secret(token_secret)
get_hashed_password(self)
Returns the hashed password, if configured.
Returns:

| Type | Description |
|---|---|
| `Optional[str]` | The hashed password. |
Source code in zenml/models/user_management_models.py
def get_hashed_password(self) -> Optional[str]:
    """Returns the bcrypt hash of the configured password.

    Returns:
        The hashed password, or None when no password is set.
    """
    password_secret = self.password
    return self._get_hashed_secret(password_secret)
get_password(self)
Get the password.
Returns:

| Type | Description |
|---|---|
| `Optional[str]` | The password as a plain string, if it exists. |
Source code in zenml/models/user_management_models.py
def get_password(self) -> Optional[str]:
    """Returns the raw value of the stored password secret.

    Returns:
        The password as a plain string, or None when no password is set.
    """
    secret = self.password
    return None if secret is None else secret.get_secret_value()
verify_access_token(token)
classmethod
Verifies an access token.
Verifies an access token and returns the user that was used to generate it if the token is valid and None otherwise.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `token` | `str` | The access token to verify. | required |
Returns:

| Type | Description |
|---|---|
| `Optional[UserModel]` | The user that generated the token if valid, None otherwise. |
Source code in zenml/models/user_management_models.py
@classmethod
def verify_access_token(cls, token: str) -> Optional["UserModel"]:
    """Resolves an access token to the user it was issued for.

    Args:
        token: The access token to verify.

    Returns:
        The user that generated the token if the token is valid and the user
        is active, None otherwise.
    """
    try:
        decoded = JWTToken.decode(
            token_type=JWTTokenType.ACCESS_TOKEN, token=token
        )
    except AuthorizationException:
        return None

    store = GlobalConfiguration().zen_store
    try:
        user = store.get_user(user_name_or_id=decoded.user_id)
    except KeyError:
        return None

    return user if decoded.user_id == user.id and user.active else None
verify_activation_token(activation_token, user=None)
classmethod
Verifies a given activation token against the stored activation token.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `activation_token` | `str` | Input activation token to be verified. | required |
| `user` | `Optional[UserModel]` | User for which the activation token is to be verified. | `None` |
Returns:

| Type | Description |
|---|---|
| `bool` | True if the token is valid. |
Source code in zenml/models/user_management_models.py
@classmethod
def verify_activation_token(
    cls, activation_token: str, user: Optional["UserModel"] = None
) -> bool:
    """Checks an activation token against a user's stored token.

    Args:
        activation_token: Input activation token to be verified.
        user: User for which the activation token is to be verified.

    Returns:
        True if the token is valid.
    """
    # The hash verification runs even when there is no user, no stored token
    # or the account is already active, to protect against response
    # discrepancy attacks (https://cwe.mitre.org/data/definitions/204.html).
    stored_hash: Optional[str] = (
        user.get_hashed_activation_token()
        if (
            user is not None
            and user.activation_token is not None
            and not user.active
        )
        else None
    )
    crypt_context = cls._get_crypt_context()
    return cast(bool, crypt_context.verify(activation_token, stored_hash))
verify_password(plain_password, user=None)
classmethod
Verifies a given plain password against the stored password.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `plain_password` | `str` | Input password to be verified. | required |
| `user` | `Optional[UserModel]` | User for which the password is to be verified. | `None` |
Returns:

| Type | Description |
|---|---|
| `bool` | True if the passwords match. |
Source code in zenml/models/user_management_models.py
@classmethod
def verify_password(
    cls, plain_password: str, user: Optional["UserModel"] = None
) -> bool:
    """Checks a plain-text password against a user's stored password.

    Args:
        plain_password: Input password to be verified.
        user: User for which the password is to be verified.

    Returns:
        True if the passwords match.
    """
    # The hash verification runs even when there is no user, no stored
    # password or the account is inactive, to protect against response
    # discrepancy attacks (https://cwe.mitre.org/data/definitions/204.html).
    stored_hash: Optional[str] = (
        user.get_hashed_password()
        if user is not None and user.password is not None and user.active
        else None
    )
    crypt_context = cls._get_crypt_context()
    return cast(bool, crypt_context.verify(plain_password, stored_hash))