Utils
zenml.utils
special
Initialization of the utils module.
The utils module contains utility functions handling analytics, reading and writing YAML data, as well as other general-purpose functions.
analytics_utils
Analytics code for ZenML.
AnalyticsContext
Context manager for analytics.
Source code in zenml/utils/analytics_utils.py
class AnalyticsContext:
"""Context manager for analytics."""
def __init__(self) -> None:
"""Context manager for analytics.
Use this as a context manager to ensure that analytics are initialized
properly, only tracked when configured to do so and that any errors
are handled gracefully.
"""
import analytics
from zenml.config.global_config import GlobalConfiguration
try:
gc = GlobalConfiguration()
self.analytics_opt_in = gc.analytics_opt_in
self.user_id = str(gc.user_id)
# That means user opted out of analytics
if not gc.analytics_opt_in:
return
if analytics.write_key is None:
analytics.write_key = get_segment_key()
assert (
analytics.write_key is not None
), "Analytics key not set but trying to make telemetry call."
# Set this to 1 to avoid backoff loop
analytics.max_retries = 1
except Exception as e:
self.analytics_opt_in = False
logger.debug(f"Analytics initialization failed: {e}")
def __enter__(self) -> "AnalyticsContext":
"""Enter context manager.
Returns:
Self.
"""
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
"""Exit context manager.
Args:
exc_type: Exception type.
exc_val: Exception value.
exc_tb: Exception traceback.
Returns:
True if exception was handled, False otherwise.
"""
if exc_val is not None:
logger.debug("Sending telemetry data failed: {exc_val}")
# We should never fail main thread
return True
def identify(self, traits: Optional[Dict[str, Any]] = None) -> bool:
"""Identify the user.
Args:
traits: Traits of the user.
Returns:
True if tracking information was sent, False otherwise.
"""
import analytics
logger.debug(
f"Attempting to attach metadata to: User: {self.user_id}, "
f"Metadata: {traits}"
)
if not self.analytics_opt_in:
return False
analytics.identify(self.user_id, traits)
logger.debug(f"User data sent: User: {self.user_id},{traits}")
return True
def group(
self,
group: Union[str, AnalyticsGroup],
group_id: str,
traits: Optional[Dict[str, Any]] = None,
) -> bool:
"""Group the user.
Args:
group: Group to which the user belongs.
group_id: Group ID.
traits: Traits of the group.
Returns:
True if tracking information was sent, False otherwise.
"""
import analytics
if isinstance(group, AnalyticsGroup):
group = group.value
if traits is None:
traits = {}
traits.update(
{
"group_id": group_id,
}
)
logger.debug(
f"Attempting to attach metadata to: User: {self.user_id}, "
f"Group: {group}, Group ID: {group_id}, Metadata: {traits}"
)
if not self.analytics_opt_in:
return False
analytics.group(self.user_id, group_id, traits=traits)
logger.debug(
f"Group data sent: User: {self.user_id}, Group: {group}, Group ID: "
f"{group_id}, Metadata: {traits}"
)
return True
def track(
self,
event: Union[str, AnalyticsEvent],
properties: Optional[Dict[str, Any]] = None,
) -> bool:
"""Track an event.
Args:
event: Event to track.
properties: Event properties.
Returns:
True if tracking information was sent, False otherwise.
"""
import analytics
from zenml.config.global_config import GlobalConfiguration
if isinstance(event, AnalyticsEvent):
event = event.value
if properties is None:
properties = {}
logger.debug(
f"Attempting analytics: User: {self.user_id}, "
f"Event: {event},"
f"Metadata: {properties}"
)
if not self.analytics_opt_in and event not in {
AnalyticsEvent.OPT_OUT_ANALYTICS,
AnalyticsEvent.OPT_IN_ANALYTICS,
}:
return False
# add basics
properties.update(Environment.get_system_info())
properties.update(
{
"environment": get_environment(),
"python_version": Environment.python_version(),
"version": __version__,
}
)
gc = GlobalConfiguration()
# avoid initializing the store in the analytics, to not create an
# infinite loop
if gc._zen_store is not None:
zen_store = gc.zen_store
if (
zen_store.type == StoreType.REST
and "server_id" not in properties
):
user = zen_store.active_user
server_info = zen_store.get_store_info()
properties.update(
{
"user_id": str(user.id),
"server_id": str(server_info.id),
"server_deployment": str(server_info.deployment_type),
"database_type": str(server_info.database_type),
}
)
analytics.track(self.user_id, event, properties)
logger.debug(
f"Analytics sent: User: {self.user_id}, Event: {event}, Metadata: "
f"{properties}"
)
return True
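Example: a minimal usage sketch of the context manager (the event and properties shown are illustrative, not part of the module source):
```python
from zenml.utils.analytics_utils import AnalyticsContext, AnalyticsEvent

# Any exception raised inside the block is swallowed by __exit__,
# so telemetry failures never propagate to the main thread.
with AnalyticsContext() as analytics:
    sent = analytics.track(
        AnalyticsEvent.EVENT_TEST, {"invocation": "docs-example"}
    )
    # `sent` is True only if the user opted in and the call went through.
```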
__enter__(self)
special
Enter context manager.
Returns:
Type | Description |
---|---|
AnalyticsContext | Self. |
Source code in zenml/utils/analytics_utils.py
def __enter__(self) -> "AnalyticsContext":
"""Enter context manager.
Returns:
Self.
"""
return self
__exit__(self, exc_type, exc_val, exc_tb)
special
Exit context manager.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
exc_type | Optional[Type[BaseException]] | Exception type. | required |
exc_val | Optional[BaseException] | Exception value. | required |
exc_tb | Optional[traceback] | Exception traceback. | required |
Returns:
Type | Description |
---|---|
bool | True if exception was handled, False otherwise. |
Source code in zenml/utils/analytics_utils.py
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
"""Exit context manager.
Args:
exc_type: Exception type.
exc_val: Exception value.
exc_tb: Exception traceback.
Returns:
True if exception was handled, False otherwise.
"""
if exc_val is not None:
logger.debug("Sending telemetry data failed: {exc_val}")
# We should never fail main thread
return True
__init__(self)
special
Context manager for analytics.
Use this as a context manager to ensure that analytics are initialized properly, only tracked when configured to do so and that any errors are handled gracefully.
Source code in zenml/utils/analytics_utils.py
def __init__(self) -> None:
"""Context manager for analytics.
Use this as a context manager to ensure that analytics are initialized
properly, only tracked when configured to do so and that any errors
are handled gracefully.
"""
import analytics
from zenml.config.global_config import GlobalConfiguration
try:
gc = GlobalConfiguration()
self.analytics_opt_in = gc.analytics_opt_in
self.user_id = str(gc.user_id)
# That means user opted out of analytics
if not gc.analytics_opt_in:
return
if analytics.write_key is None:
analytics.write_key = get_segment_key()
assert (
analytics.write_key is not None
), "Analytics key not set but trying to make telemetry call."
# Set this to 1 to avoid backoff loop
analytics.max_retries = 1
except Exception as e:
self.analytics_opt_in = False
logger.debug(f"Analytics initialization failed: {e}")
group(self, group, group_id, traits=None)
Group the user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
group | Union[str, zenml.utils.analytics_utils.AnalyticsGroup] | Group to which the user belongs. | required |
group_id | str | Group ID. | required |
traits | Optional[Dict[str, Any]] | Traits of the group. | None |
Returns:
Type | Description |
---|---|
bool | True if tracking information was sent, False otherwise. |
Source code in zenml/utils/analytics_utils.py
def group(
self,
group: Union[str, AnalyticsGroup],
group_id: str,
traits: Optional[Dict[str, Any]] = None,
) -> bool:
"""Group the user.
Args:
group: Group to which the user belongs.
group_id: Group ID.
traits: Traits of the group.
Returns:
True if tracking information was sent, False otherwise.
"""
import analytics
if isinstance(group, AnalyticsGroup):
group = group.value
if traits is None:
traits = {}
traits.update(
{
"group_id": group_id,
}
)
logger.debug(
f"Attempting to attach metadata to: User: {self.user_id}, "
f"Group: {group}, Group ID: {group_id}, Metadata: {traits}"
)
if not self.analytics_opt_in:
return False
analytics.group(self.user_id, group_id, traits=traits)
logger.debug(
f"Group data sent: User: {self.user_id}, Group: {group}, Group ID: "
f"{group_id}, Metadata: {traits}"
)
return True
identify(self, traits=None)
Identify the user.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
traits | Optional[Dict[str, Any]] | Traits of the user. | None |
Returns:
Type | Description |
---|---|
bool | True if tracking information was sent, False otherwise. |
Source code in zenml/utils/analytics_utils.py
def identify(self, traits: Optional[Dict[str, Any]] = None) -> bool:
"""Identify the user.
Args:
traits: Traits of the user.
Returns:
True if tracking information was sent, False otherwise.
"""
import analytics
logger.debug(
f"Attempting to attach metadata to: User: {self.user_id}, "
f"Metadata: {traits}"
)
if not self.analytics_opt_in:
return False
analytics.identify(self.user_id, traits)
logger.debug(f"User data sent: User: {self.user_id},{traits}")
return True
track(self, event, properties=None)
Track an event.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event | Union[str, zenml.utils.analytics_utils.AnalyticsEvent] | Event to track. | required |
properties | Optional[Dict[str, Any]] | Event properties. | None |
Returns:
Type | Description |
---|---|
bool | True if tracking information was sent, False otherwise. |
Source code in zenml/utils/analytics_utils.py
def track(
self,
event: Union[str, AnalyticsEvent],
properties: Optional[Dict[str, Any]] = None,
) -> bool:
"""Track an event.
Args:
event: Event to track.
properties: Event properties.
Returns:
True if tracking information was sent, False otherwise.
"""
import analytics
from zenml.config.global_config import GlobalConfiguration
if isinstance(event, AnalyticsEvent):
event = event.value
if properties is None:
properties = {}
logger.debug(
f"Attempting analytics: User: {self.user_id}, "
f"Event: {event},"
f"Metadata: {properties}"
)
if not self.analytics_opt_in and event not in {
AnalyticsEvent.OPT_OUT_ANALYTICS,
AnalyticsEvent.OPT_IN_ANALYTICS,
}:
return False
# add basics
properties.update(Environment.get_system_info())
properties.update(
{
"environment": get_environment(),
"python_version": Environment.python_version(),
"version": __version__,
}
)
gc = GlobalConfiguration()
# avoid initializing the store in the analytics, to not create an
# infinite loop
if gc._zen_store is not None:
zen_store = gc.zen_store
if (
zen_store.type == StoreType.REST
and "server_id" not in properties
):
user = zen_store.active_user
server_info = zen_store.get_store_info()
properties.update(
{
"user_id": str(user.id),
"server_id": str(server_info.id),
"server_deployment": str(server_info.deployment_type),
"database_type": str(server_info.database_type),
}
)
analytics.track(self.user_id, event, properties)
logger.debug(
f"Analytics sent: User: {self.user_id}, Event: {event}, Metadata: "
f"{properties}"
)
return True
AnalyticsEvent (str, Enum)
Enum of events to track in segment.
Source code in zenml/utils/analytics_utils.py
class AnalyticsEvent(str, Enum):
"""Enum of events to track in segment."""
# Pipelines
RUN_PIPELINE = "Pipeline run"
GET_PIPELINES = "Pipelines fetched"
GET_PIPELINE = "Pipeline fetched"
CREATE_PIPELINE = "Pipeline created"
UPDATE_PIPELINE = "Pipeline updated"
DELETE_PIPELINE = "Pipeline deleted"
# Repo
INITIALIZE_REPO = "ZenML initialized"
CONNECT_REPOSITORY = "Repository connected"
UPDATE_REPOSITORY = "Repository updated"
DELETE_REPOSITORY = "Repository deleted"
# Zen store
INITIALIZED_STORE = "Store initialized"
# Components
REGISTERED_STACK_COMPONENT = "Stack component registered"
UPDATED_STACK_COMPONENT = "Stack component updated"
COPIED_STACK_COMPONENT = "Stack component copied"
DELETED_STACK_COMPONENT = "Stack component deleted"
# Stack
REGISTERED_STACK = "Stack registered"
REGISTERED_DEFAULT_STACK = "Default stack registered"
SET_STACK = "Stack set"
UPDATED_STACK = "Stack updated"
COPIED_STACK = "Stack copied"
IMPORT_STACK = "Stack imported"
EXPORT_STACK = "Stack exported"
DELETED_STACK = "Stack deleted"
# Model Deployment
MODEL_DEPLOYED = "Model deployed"
# Analytics opt in and out
OPT_IN_ANALYTICS = "Analytics opt-in"
OPT_OUT_ANALYTICS = "Analytics opt-out"
OPT_IN_OUT_EMAIL = "Response for Email prompt"
# Examples
RUN_ZENML_GO = "ZenML go"
RUN_EXAMPLE = "Example run"
PULL_EXAMPLE = "Example pull"
# Integrations
INSTALL_INTEGRATION = "Integration installed"
# Users
CREATED_USER = "User created"
CREATED_DEFAULT_USER = "Default user created"
UPDATED_USER = "User updated"
DELETED_USER = "User deleted"
# Teams
CREATED_TEAM = "Team created"
UPDATED_TEAM = "Team updated"
DELETED_TEAM = "Team deleted"
# Projects
CREATED_PROJECT = "Project created"
CREATED_DEFAULT_PROJECT = "Default project created"
UPDATED_PROJECT = "Project updated"
DELETED_PROJECT = "Project deleted"
SET_PROJECT = "Project set"
# Role
CREATED_ROLE = "Role created"
CREATED_DEFAULT_ROLES = "Default roles created"
UPDATED_ROLE = "Role updated"
DELETED_ROLE = "Role deleted"
# Flavor
CREATED_FLAVOR = "Flavor created"
UPDATED_FLAVOR = "Flavor updated"
DELETED_FLAVOR = "Flavor deleted"
# Test event
EVENT_TEST = "Test event"
# Stack recipes
PULL_STACK_RECIPE = "Stack recipes pulled"
RUN_STACK_RECIPE = "Stack recipe created"
DESTROY_STACK_RECIPE = "Stack recipe destroyed"
# ZenML server events
ZENML_SERVER_STARTED = "ZenML server started"
ZENML_SERVER_STOPPED = "ZenML server stopped"
ZENML_SERVER_CONNECTED = "ZenML server connected"
ZENML_SERVER_DEPLOYED = "ZenML server deployed"
ZENML_SERVER_DESTROYED = "ZenML server destroyed"
AnalyticsGroup (str, Enum)
Enum of event groups to track in segment.
Source code in zenml/utils/analytics_utils.py
class AnalyticsGroup(str, Enum):
"""Enum of event groups to track in segment."""
ZENML_SERVER_GROUP = "ZenML server group"
AnalyticsTrackedModelMixin (BaseModel)
pydantic-model
Mixin for models that are tracked through analytics events.
Classes that have information tracked in analytics events can inherit from this mixin and implement the abstract methods. The @track decorator will detect function arguments and return values that inherit from this class and will include the ANALYTICS_FIELDS attributes as tracking metadata.
Source code in zenml/utils/analytics_utils.py
class AnalyticsTrackedModelMixin(BaseModel):
"""Mixin for models that are tracked through analytics events.
Classes that have information tracked in analytics events can inherit
from this mixin and implement the abstract methods. The `@track` decorator
will detect function arguments and return values that inherit from this
class and will include the `ANALYTICS_FIELDS` attributes as
tracking metadata.
"""
ANALYTICS_FIELDS: ClassVar[List[str]]
def get_analytics_metadata(self) -> Dict[str, Any]:
"""Get the analytics metadata for the model.
Returns:
Dict of analytics metadata.
"""
metadata = {}
for field_name in self.ANALYTICS_FIELDS:
metadata[field_name] = getattr(self, field_name, None)
return metadata
def track_event(
self,
event: Union[str, AnalyticsEvent],
tracker: Optional[AnalyticsTrackerMixin] = None,
) -> None:
"""Track an event for the model.
Args:
event: Event to track.
tracker: Optional tracker to use for analytics.
"""
metadata = self.get_analytics_metadata()
if tracker:
tracker.track_event(event, metadata)
else:
track_event(event, metadata)
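Example: a hypothetical model subclass that exposes two of its fields as tracking metadata (the class and field names are made up for illustration):
```python
from typing import ClassVar, List

from zenml.utils.analytics_utils import (
    AnalyticsEvent,
    AnalyticsTrackedModelMixin,
)


class ExampleStackModel(AnalyticsTrackedModelMixin):
    # Only the fields listed here end up in the tracking metadata.
    ANALYTICS_FIELDS: ClassVar[List[str]] = ["name", "is_shared"]

    name: str
    is_shared: bool = False
    description: str = ""  # never sent to analytics


model = ExampleStackModel(name="my_stack")
model.track_event(AnalyticsEvent.REGISTERED_STACK)
```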
get_analytics_metadata(self)
Get the analytics metadata for the model.
Returns:
Type | Description |
---|---|
Dict[str, Any] | Dict of analytics metadata. |
Source code in zenml/utils/analytics_utils.py
def get_analytics_metadata(self) -> Dict[str, Any]:
"""Get the analytics metadata for the model.
Returns:
Dict of analytics metadata.
"""
metadata = {}
for field_name in self.ANALYTICS_FIELDS:
metadata[field_name] = getattr(self, field_name, None)
return metadata
track_event(self, event, tracker=None)
Track an event for the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event | Union[str, zenml.utils.analytics_utils.AnalyticsEvent] | Event to track. | required |
tracker | Optional[zenml.utils.analytics_utils.AnalyticsTrackerMixin] | Optional tracker to use for analytics. | None |
Source code in zenml/utils/analytics_utils.py
def track_event(
self,
event: Union[str, AnalyticsEvent],
tracker: Optional[AnalyticsTrackerMixin] = None,
) -> None:
"""Track an event for the model.
Args:
event: Event to track.
tracker: Optional tracker to use for analytics.
"""
metadata = self.get_analytics_metadata()
if tracker:
tracker.track_event(event, metadata)
else:
track_event(event, metadata)
AnalyticsTrackerMixin (ABC)
Abstract base class for analytics trackers.
Use this as a mixin for classes that have methods decorated with @track to add global control over how analytics are tracked. The decorator will detect that the class has this mixin and will call the class track_event method.
Source code in zenml/utils/analytics_utils.py
class AnalyticsTrackerMixin(ABC):
"""Abstract base class for analytics trackers.
Use this as a mixin for classes that have methods decorated with
`@track` to add global control over how analytics are tracked. The decorator
will detect that the class has this mixin and will call the class
`track_event` method.
"""
@abstractmethod
def track_event(
self,
event: Union[str, AnalyticsEvent],
metadata: Optional[Dict[str, Any]],
) -> None:
"""Track an event.
Args:
event: Event to track.
metadata: Metadata to track.
"""
track_event(self, event, metadata)
Track an event.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event | Union[str, zenml.utils.analytics_utils.AnalyticsEvent] | Event to track. | required |
metadata | Optional[Dict[str, Any]] | Metadata to track. | required |
Source code in zenml/utils/analytics_utils.py
@abstractmethod
def track_event(
self,
event: Union[str, AnalyticsEvent],
metadata: Optional[Dict[str, Any]],
) -> None:
"""Track an event.
Args:
event: Event to track.
metadata: Metadata to track.
"""
get_segment_key()
Get key for authorizing to Segment backend.
Returns:
Type | Description |
---|---|
str | Segment key as a string. |
Source code in zenml/utils/analytics_utils.py
def get_segment_key() -> str:
"""Get key for authorizing to Segment backend.
Returns:
Segment key as a string.
"""
if IS_DEBUG_ENV:
return SEGMENT_KEY_DEV
else:
return SEGMENT_KEY_PROD
identify_group(group, group_id, group_metadata=None)
Attach metadata to a segment group.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
group | Union[str, zenml.utils.analytics_utils.AnalyticsGroup] | Group to track. | required |
group_id | str | ID of the group. | required |
group_metadata | Optional[Dict[str, Any]] | Metadata to attach to the group. | None |
Returns:
Type | Description |
---|---|
bool | True if the event is sent successfully, False if not. |
Source code in zenml/utils/analytics_utils.py
def identify_group(
group: Union[str, AnalyticsGroup],
group_id: str,
group_metadata: Optional[Dict[str, Any]] = None,
) -> bool:
"""Attach metadata to a segment group.
Args:
group: Group to track.
group_id: ID of the group.
group_metadata: Metadata to attach to the group.
Returns:
True if the event is sent successfully, False if not.
"""
with AnalyticsContext() as analytics:
return analytics.group(group, group_id, traits=group_metadata)
return False
identify_user(user_metadata=None)
Attach metadata to user directly.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
user_metadata | Optional[Dict[str, Any]] | Dict of metadata to attach to the user. | None |
Returns:
Type | Description |
---|---|
bool | True if the event is sent successfully, False if not. |
Source code in zenml/utils/analytics_utils.py
def identify_user(user_metadata: Optional[Dict[str, Any]] = None) -> bool:
"""Attach metadata to user directly.
Args:
user_metadata: Dict of metadata to attach to the user.
Returns:
True if the event is sent successfully, False if not.
"""
with AnalyticsContext() as analytics:
if user_metadata is None:
return False
return analytics.identify(traits=user_metadata)
return False
parametrized(dec)
This is a meta-decorator, that is, a decorator for decorators.
Since a decorator is itself a function, the result works as a regular decorator that accepts arguments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dec | Callable[..., Callable[..., Any]] | Decorator to be applied to the function. | required |
Returns:
Type | Description |
---|---|
Callable[..., Callable[[Callable[..., Any]], Callable[..., Any]]] | Decorator that applies the given decorator to the function. |
Source code in zenml/utils/analytics_utils.py
def parametrized(
dec: Callable[..., Callable[..., Any]]
) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., Any]]]:
"""This is a meta-decorator, that is, a decorator for decorators.
As a decorator is a function, it actually works as a regular decorator
with arguments.
Args:
dec: Decorator to be applied to the function.
Returns:
Decorator that applies the given decorator to the function.
"""
def layer(
*args: Any, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Internal layer.
Args:
*args: Arguments to be passed to the decorator.
**kwargs: Keyword arguments to be passed to the decorator.
Returns:
Decorator that applies the given decorator to the function.
"""
def repl(f: Callable[..., Any]) -> Callable[..., Any]:
"""Internal REPL.
Args:
f: Function to be decorated.
Returns:
Decorated function.
"""
return dec(f, *args, **kwargs)
return repl
return layer
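Example: a hypothetical decorator with arguments built on top of parametrized (the repeat decorator and its argument are illustrative):
```python
from zenml.utils.analytics_utils import parametrized


@parametrized
def repeat(func, times):
    """Calls the wrapped function `times` times, returning the last result."""

    def wrapper(*args, **kwargs):
        result = None
        for _ in range(times):
            result = func(*args, **kwargs)
        return result

    return wrapper


@repeat(times=3)
def greet(name: str) -> str:
    return f"Hello, {name}!"


greet("ZenML")  # runs three times, returns "Hello, ZenML!"
```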
track(*args, **kwargs)
Internal layer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | Any | Arguments to be passed to the decorator. | () |
**kwargs | Any | Keyword arguments to be passed to the decorator. | {} |
Returns:
Type | Description |
---|---|
Callable[[Callable[..., Any]], Callable[..., Any]] | Decorator that applies the given decorator to the function. |
Source code in zenml/utils/analytics_utils.py
def layer(
*args: Any, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Internal layer.
Args:
*args: Arguments to be passed to the decorator.
**kwargs: Keyword arguments to be passed to the decorator.
Returns:
Decorator that applies the given decorator to the function.
"""
def repl(f: Callable[..., Any]) -> Callable[..., Any]:
"""Internal REPL.
Args:
f: Function to be decorated.
Returns:
Decorated function.
"""
return dec(f, *args, **kwargs)
return repl
track_event(event, metadata=None)
Track segment event if user opted-in.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
event | Union[str, zenml.utils.analytics_utils.AnalyticsEvent] | Name of event to track in segment. | required |
metadata | Optional[Dict[str, Any]] | Dict of metadata to track. | None |
Returns:
Type | Description |
---|---|
bool | True if the event is sent successfully, False if not. |
Source code in zenml/utils/analytics_utils.py
def track_event(
event: Union[str, AnalyticsEvent],
metadata: Optional[Dict[str, Any]] = None,
) -> bool:
"""Track segment event if user opted-in.
Args:
event: Name of event to track in segment.
metadata: Dict of metadata to track.
Returns:
True if the event is sent successfully, False if not.
"""
with AnalyticsContext() as analytics:
return analytics.track(event, metadata)
return False
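Example: a minimal call using the built-in test event (the metadata dict is illustrative):
```python
from zenml.utils.analytics_utils import AnalyticsEvent, track_event

# Only sent if the user opted in; returns whether the call went through.
sent = track_event(AnalyticsEvent.EVENT_TEST, {"source": "docs-example"})
```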
daemon
Utility functions to start/stop daemon processes.
This is only implemented for UNIX systems and therefore doesn't work on Windows. Based on https://www.jejik.com/articles/2007/02/a_simple_unix_linux_daemon_in_python/
check_if_daemon_is_running(pid_file)
Checks whether a daemon process indicated by the PID file is running.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pid_file | str | Path to file containing the PID of the daemon process to check. | required |
Returns:
Type | Description |
---|---|
bool | True if the daemon process is running, otherwise False. |
Source code in zenml/utils/daemon.py
def check_if_daemon_is_running(pid_file: str) -> bool:
"""Checks whether a daemon process indicated by the PID file is running.
Args:
pid_file: Path to file containing the PID of the daemon
process to check.
Returns:
True if the daemon process is running, otherwise False.
"""
return get_daemon_pid_if_running(pid_file) is not None
daemonize(pid_file, log_file=None, working_directory='/')
Decorator that executes the decorated function as a daemon process.
Use this decorator to easily transform any function into a daemon process.
Examples:
```python
import time
from zenml.utils.daemon import daemonize


@daemonize(log_file='/tmp/daemon.log', pid_file='/tmp/daemon.pid')
def sleeping_daemon(period: int) -> None:
    print(f"I'm a daemon! I will sleep for {period} seconds.")
    time.sleep(period)
    print("Done sleeping, flying away.")


sleeping_daemon(period=30)

print("I'm the daemon's parent!.")
time.sleep(10)  # just to prove that the daemon is running in parallel
```
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pid_file | str | an optional file where the PID of the daemon process will be stored. | required |
log_file | Optional[str] | file where stdout and stderr are redirected for the daemon process. If not supplied, the daemon will be silenced (i.e. have its stdout/stderr redirected to /dev/null). | None |
working_directory | str | working directory for the daemon process, defaults to the root directory. | '/' |
Returns:
Type | Description |
---|---|
Callable[[~F], ~F] | Decorated function that, when called, will detach from the current process and continue executing in the background, as a daemon process. |
Source code in zenml/utils/daemon.py
def daemonize(
pid_file: str,
log_file: Optional[str] = None,
working_directory: str = "/",
) -> Callable[[F], F]:
"""Decorator that executes the decorated function as a daemon process.
Use this decorator to easily transform any function into a daemon
process.
Example:
```python
import time
from zenml.utils.daemon import daemonize


@daemonize(log_file='/tmp/daemon.log', pid_file='/tmp/daemon.pid')
def sleeping_daemon(period: int) -> None:
    print(f"I'm a daemon! I will sleep for {period} seconds.")
    time.sleep(period)
    print("Done sleeping, flying away.")

sleeping_daemon(period=30)

print("I'm the daemon's parent!.")
time.sleep(10)  # just to prove that the daemon is running in parallel
```
Args:
pid_file: an optional file where the PID of the daemon process will
be stored.
log_file: file where stdout and stderr are redirected for the daemon
process. If not supplied, the daemon will be silenced (i.e. have
its stdout/stderr redirected to /dev/null).
working_directory: working directory for the daemon process,
defaults to the root directory.
Returns:
Decorated function that, when called, will detach from the current
process and continue executing in the background, as a daemon
process.
"""
def inner_decorator(_func: F) -> F:
def daemon(*args: Any, **kwargs: Any) -> None:
"""Standard daemonization of a process.
Args:
*args: Arguments to be passed to the decorated function.
**kwargs: Keyword arguments to be passed to the decorated
function.
"""
# flake8: noqa: C901
if sys.platform == "win32":
logger.error(
"Daemon functionality is currently not supported on Windows."
)
else:
run_as_daemon(
_func,
log_file=log_file,
pid_file=pid_file,
working_directory=working_directory,
*args,
**kwargs,
)
return cast(F, daemon)
return inner_decorator
get_daemon_pid_if_running(pid_file)
Read and return the PID value from a PID file.
It does this if the daemon process tracked by the PID file is running.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pid_file | str | Path to file containing the PID of the daemon process to check. | required |
Returns:
Type | Description |
---|---|
Optional[int] | The PID of the daemon process if it is running, otherwise None. |
Source code in zenml/utils/daemon.py
def get_daemon_pid_if_running(pid_file: str) -> Optional[int]:
"""Read and return the PID value from a PID file.
It does this if the daemon process tracked by the PID file is running.
Args:
pid_file: Path to file containing the PID of the daemon
process to check.
Returns:
The PID of the daemon process if it is running, otherwise None.
"""
try:
with open(pid_file, "r") as f:
pid = int(f.read().strip())
except (IOError, FileNotFoundError):
logger.debug(
f"Daemon PID file '{pid_file}' does not exist or cannot be read."
)
return None
if not pid or not psutil.pid_exists(pid):
logger.debug(f"Daemon with PID '{pid}' is no longer running.")
return None
logger.debug(f"Daemon with PID '{pid}' is running.")
return pid
run_as_daemon(daemon_function, *args, pid_file, log_file=None, working_directory='/', **kwargs)
Runs a function as a daemon process.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
daemon_function | ~F | The function to run as a daemon. | required |
pid_file | str | Path to file in which to store the PID of the daemon process. | required |
log_file | Optional[str] | Optional file to which the daemon's stdout/stderr will be redirected. | None |
working_directory | str | Working directory for the daemon process, defaults to the root directory. | '/' |
args | Any | Positional arguments to pass to the daemon function. | () |
kwargs | Any | Keyword arguments to pass to the daemon function. | {} |
Exceptions:
Type | Description |
---|---|
FileExistsError | If the PID file already exists. |
Source code in zenml/utils/daemon.py
def run_as_daemon(
daemon_function: F,
*args: Any,
pid_file: str,
log_file: Optional[str] = None,
working_directory: str = "/",
**kwargs: Any,
) -> None:
"""Runs a function as a daemon process.
Args:
daemon_function: The function to run as a daemon.
pid_file: Path to file in which to store the PID of the daemon
process.
log_file: Optional file to which the daemon's stdout/stderr will be
redirected.
working_directory: Working directory for the daemon process,
defaults to the root directory.
args: Positional arguments to pass to the daemon function.
kwargs: Keyword arguments to pass to the daemon function.
Raises:
FileExistsError: If the PID file already exists.
"""
# convert to absolute path as we will change working directory later
if pid_file:
pid_file = os.path.abspath(pid_file)
if log_file:
log_file = os.path.abspath(log_file)
# create parent directory if necessary
dir_name = os.path.dirname(pid_file)
if not os.path.exists(dir_name):
os.makedirs(dir_name)
# check if PID file exists
if pid_file and os.path.exists(pid_file):
pid = get_daemon_pid_if_running(pid_file)
if pid:
raise FileExistsError(
f"The PID file '{pid_file}' already exists and a daemon "
f"process with the same PID '{pid}' is already running."
f"Please remove the PID file or kill the daemon process "
f"before starting a new daemon."
)
logger.warning(
f"Removing left over PID file '{pid_file}' from a previous "
f"daemon process that didn't shut down correctly."
)
os.remove(pid_file)
# first fork
try:
pid = os.fork()
if pid > 0:
# this is the process that called `run_as_daemon` so we
# simply return so it can keep running
return
except OSError as e:
logger.error("Unable to fork (error code: %d)", e.errno)
sys.exit(1)
# decouple from parent environment
os.chdir(working_directory)
os.setsid()
os.umask(0o22)
# second fork
try:
pid = os.fork()
if pid > 0:
# this is the parent of the future daemon process, kill it
# so the daemon gets adopted by the init process
sys.exit(0)
except OSError as e:
sys.stderr.write(f"Unable to fork (error code: {e.errno})")
sys.exit(1)
# redirect standard file descriptors to devnull (or the given logfile)
devnull = "/dev/null"
if hasattr(os, "devnull"):
devnull = os.devnull
devnull_fd = os.open(devnull, os.O_RDWR)
log_fd = (
os.open(log_file, os.O_CREAT | os.O_RDWR | os.O_APPEND)
if log_file
else None
)
out_fd = log_fd or devnull_fd
os.dup2(devnull_fd, sys.stdin.fileno())
os.dup2(out_fd, sys.stdout.fileno())
os.dup2(out_fd, sys.stderr.fileno())
if pid_file:
# write the PID file
with open(pid_file, "w+") as f:
f.write(f"{os.getpid()}\n")
# register actions in case this process exits/gets killed
def cleanup() -> None:
"""Daemon cleanup."""
sys.stderr.write("Cleanup: terminating children processes...\n")
terminate_children()
if pid_file and os.path.exists(pid_file):
sys.stderr.write(f"Cleanup: removing PID file {pid_file}...\n")
os.remove(pid_file)
sys.stderr.flush()
def sighndl(signum: int, frame: Optional[types.FrameType]) -> None:
"""Daemon signal handler.
Args:
signum: Signal number.
frame: Frame object.
"""
sys.stderr.write(f"Handling signal {signum}...\n")
cleanup()
signal.signal(signal.SIGTERM, sighndl)
signal.signal(signal.SIGINT, sighndl)
atexit.register(cleanup)
# finally run the actual daemon code
daemon_function(*args, **kwargs)
sys.exit(0)
stop_daemon(pid_file)
Stops a daemon process.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pid_file | str | Path to file containing the PID of the daemon process to kill. | required |
Source code in zenml/utils/daemon.py
def stop_daemon(pid_file: str) -> None:
"""Stops a daemon process.
Args:
pid_file: Path to file containing the PID of the daemon process to
kill.
"""
try:
with open(pid_file, "r") as f:
pid = int(f.read().strip())
except (IOError, FileNotFoundError):
logger.warning("Daemon PID file '%s' does not exist.", pid_file)
return
if psutil.pid_exists(pid):
process = psutil.Process(pid)
process.terminate()
else:
logger.warning("PID from '%s' does not exist.", pid_file)
terminate_children()
Terminate all processes that are children of the currently running process.
Source code in zenml/utils/daemon.py
def terminate_children() -> None:
"""Terminate all processes that are children of the currently running process."""
pid = os.getpid()
try:
parent = psutil.Process(pid)
except psutil.Error:
# could not find parent process id
return
children = parent.children(recursive=False)
for p in children:
sys.stderr.write(f"Terminating child process with PID {p.pid}...\n")
p.terminate()
_, alive = psutil.wait_procs(
children, timeout=CHILD_PROCESS_WAIT_TIMEOUT
)
for p in alive:
sys.stderr.write(f"Killing child process with PID {p.pid}...\n")
p.kill()
_, alive = psutil.wait_procs(
children, timeout=CHILD_PROCESS_WAIT_TIMEOUT
)
dashboard_utils
Utility functions to help with interacting with the dashboard.
get_run_url(run_name, pipeline_id=None)
Computes a dashboard url to directly view the run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_name | str | Name of the pipeline run. | required |
pipeline_id | Optional[uuid.UUID] | Optional pipeline_id, to be sent when available. | None |
Returns:
Type | Description |
---|---|
Optional[str] | A direct url link to the pipeline run details page. If run does not exist, returns None. |
Source code in zenml/utils/dashboard_utils.py
def get_run_url(
run_name: str, pipeline_id: Optional[UUID] = None
) -> Optional[str]:
"""Computes a dashboard url to directly view the run.
Args:
run_name: Name of the pipeline run.
pipeline_id: Optional pipeline_id, to be sent when available.
Returns:
A direct url link to the pipeline run details page. If run does not exist,
returns None.
"""
# Connected to ZenML Server
client = Client()
if client.zen_store.type != StoreType.REST:
return ""
url = client.zen_store.url
runs = client.zen_store.list_runs(run_name=run_name)
if pipeline_id:
url += f"/pipelines/{str(pipeline_id)}/runs"
elif runs:
url += "/runs"
else:
url += "/pipelines/all-runs"
if runs:
url += f"/{runs[0].id}/dag"
return url
print_run_url(run_name, pipeline_id=None)
Logs a dashboard url to directly view the run.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
run_name | str | Name of the pipeline run. | required |
pipeline_id | Optional[uuid.UUID] | Optional pipeline_id, to be sent when available. | None |
Source code in zenml/utils/dashboard_utils.py
def print_run_url(run_name: str, pipeline_id: Optional[UUID] = None) -> None:
"""Logs a dashboard url to directly view the run.
Args:
run_name: Name of the pipeline run.
pipeline_id: Optional pipeline_id, to be sent when available.
"""
client = Client()
if client.zen_store.type == StoreType.REST:
url = get_run_url(
run_name,
pipeline_id,
)
if url:
logger.info(f"Dashboard URL: {url}")
elif client.zen_store.type == StoreType.SQL:
# Connected to SQL Store Type, we're local
logger.info(
"Pipeline visualization can be seen in the ZenML Dashboard. "
"Run `zenml up` to see your pipeline!"
)
deprecation_utils
Deprecation utilities.
deprecate_pydantic_attributes(*attributes)
Utility function for deprecating and migrating pydantic attributes.
Usage: To use this, you can specify it on any pydantic BaseModel subclass like this (all the deprecated attributes need to be non-required):
```python
from pydantic import BaseModel
from typing import Optional


class MyModel(BaseModel):
    deprecated: Optional[int] = None
    old_name: Optional[str] = None
    new_name: str

    _deprecation_validator = deprecate_pydantic_attributes(
        "deprecated", ("old_name", "new_name")
    )
```
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*attributes | Union[str, Tuple[str, str]] | List of attributes to deprecate. This is either the name of the attribute to deprecate, or a tuple containing the name of the deprecated attribute and its replacement. | () |
Returns:
Type | Description |
---|---|
AnyClassMethod | Pydantic validator class method to be used on BaseModel subclasses to deprecate or migrate attributes. |
Source code in zenml/utils/deprecation_utils.py
def deprecate_pydantic_attributes(
*attributes: Union[str, Tuple[str, str]]
) -> "AnyClassMethod":
"""Utility function for deprecating and migrating pydantic attributes.
**Usage**:
To use this, you can specify it on any pydantic BaseModel subclass like
this (all the deprecated attributes need to be non-required):
```python
from pydantic import BaseModel
from typing import Optional


class MyModel(BaseModel):
    deprecated: Optional[int] = None
    old_name: Optional[str] = None
    new_name: str

    _deprecation_validator = deprecate_pydantic_attributes(
        "deprecated", ("old_name", "new_name")
    )
```
Args:
*attributes: List of attributes to deprecate. This is either the name
of the attribute to deprecate, or a tuple containing the name of
the deprecated attribute and its replacement.
Returns:
Pydantic validator class method to be used on BaseModel subclasses
to deprecate or migrate attributes.
"""
@root_validator(pre=True, allow_reuse=True)
def _deprecation_validator(
cls: Type[BaseModel], values: Dict[str, Any]
) -> Dict[str, Any]:
"""Pydantic validator function for deprecating pydantic attributes.
Args:
cls: The class on which the attributes are defined.
values: All values passed at model initialization.
Raises:
AssertionError: If either the deprecated or replacement attribute
don't exist.
TypeError: If the deprecated attribute is a required attribute.
ValueError: If the deprecated attribute and replacement attribute
contain different values.
Returns:
Input values with potentially migrated values.
"""
previous_deprecation_warnings: Set[str] = getattr(
cls, PREVIOUS_DEPRECATION_WARNINGS_ATTRIBUTE, set()
)
def _warn(message: str, attribute: str) -> None:
"""Logs and raises a warning for a deprecated attribute.
Args:
message: The warning message.
attribute: The name of the attribute.
"""
if attribute not in previous_deprecation_warnings:
logger.warning(message)
previous_deprecation_warnings.add(attribute)
warnings.warn(
message,
DeprecationWarning,
)
for attribute in attributes:
if isinstance(attribute, str):
deprecated_attribute = attribute
replacement_attribute = None
else:
deprecated_attribute, replacement_attribute = attribute
assert (
replacement_attribute in cls.__fields__
), f"Unable to find attribute {replacement_attribute}."
assert (
deprecated_attribute in cls.__fields__
), f"Unable to find attribute {deprecated_attribute}."
if cls.__fields__[deprecated_attribute].required:
raise TypeError(
f"Unable to deprecate attribute '{deprecated_attribute}' "
f"of class {cls.__name__}. In order to deprecate an "
"attribute, it needs to be a non-required attribute. "
"To do so, mark the attribute with an `Optional[...] type "
"annotation."
)
if values.get(deprecated_attribute, None) is None:
continue
if replacement_attribute is None:
_warn(
message=f"The attribute `{deprecated_attribute}` of class "
f"`{cls.__name__}` will be deprecated soon.",
attribute=deprecated_attribute,
)
continue
_warn(
message=f"The attribute `{deprecated_attribute}` of class "
f"`{cls.__name__}` will be deprecated soon. Use the "
f"attribute `{replacement_attribute}` instead.",
attribute=deprecated_attribute,
)
if values.get(replacement_attribute, None) is None:
logger.debug(
"Migrating value of deprecated attribute %s to "
"replacement attribute %s.",
deprecated_attribute,
replacement_attribute,
)
values[replacement_attribute] = values.pop(deprecated_attribute)
elif values[deprecated_attribute] != values[replacement_attribute]:
raise ValueError(
"Got different values for deprecated attribute "
f"{deprecated_attribute} and replacement "
f"attribute {replacement_attribute}."
)
else:
# Both values are identical, no need to do anything
pass
setattr(
cls,
PREVIOUS_DEPRECATION_WARNINGS_ATTRIBUTE,
previous_deprecation_warnings,
)
return values
return _deprecation_validator
dict_utils
Util functions for dictionaries.
recursive_update(original, update)
Recursively updates a dictionary.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
original | Dict[str, Any] | The dictionary to update. | required |
update | Dict[str, Any] | The dictionary containing the updated values. | required |
Exceptions:
Type | Description |
---|---|
TypeError | If the value types of original and update don't match. |
Returns:
Type | Description |
---|---|
Dict[str, Any] | The updated dictionary. |
Source code in zenml/utils/dict_utils.py
def recursive_update(
original: Dict[str, Any], update: Dict[str, Any]
) -> Dict[str, Any]:
"""Recursively updates a dictionary.
Args:
original: The dictionary to update.
update: The dictionary containing the updated values.
Raises:
TypeError: If the value types of original and update don't match.
Returns:
The updated dictionary.
"""
for key, value in update.items():
if isinstance(value, Dict):
original_value = original.get(key, None) or {}
if not isinstance(original_value, Dict):
raise TypeError(
f"Type of dictionary values for key {key} does not match "
"in original and update dict (original="
f"{type(original_value)}, update={type(value)})."
)
original[key] = recursive_update(original_value, value)
else:
original[key] = value
return original
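Example: nested keys are merged rather than replaced wholesale (the values are illustrative):
```python
from zenml.utils.dict_utils import recursive_update

original = {"resources": {"cpu": 1, "memory": "2Gi"}, "name": "step"}
update = {"resources": {"cpu": 4}}

recursive_update(original, update)
# {"resources": {"cpu": 4, "memory": "2Gi"}, "name": "step"}
```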
remove_none_values(dict_, recursive=False)
Removes all key-value pairs with None value.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dict_ | Dict[str, Any] | The dict from which the key-value pairs should be removed. | required |
recursive | bool | If True, will recursively remove None values in all child dicts. | False |
Returns:
Type | Description |
---|---|
Dict[str, Any] | The updated dictionary. |
Source code in zenml/utils/dict_utils.py
def remove_none_values(
dict_: Dict[str, Any], recursive: bool = False
) -> Dict[str, Any]:
"""Removes all key-value pairs with `None` value.
Args:
dict_: The dict from which the key-value pairs should be removed.
recursive: If `True`, will recursively remove `None` values in all
child dicts.
Returns:
The updated dictionary.
"""
def _maybe_recurse(value: Any) -> Any:
"""Calls `remove_none_values` recursively if required.
Args:
value: A dictionary value.
Returns:
The updated dictionary value.
"""
if recursive and isinstance(value, Dict):
return remove_none_values(value, recursive=True)
else:
return value
return {k: _maybe_recurse(v) for k, v in dict_.items() if v is not None}
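Example: illustrating the difference between the shallow and recursive modes (the values are illustrative):
```python
from zenml.utils.dict_utils import remove_none_values

config = {"a": 1, "b": None, "nested": {"c": None, "d": 2}}

remove_none_values(config)
# {"a": 1, "nested": {"c": None, "d": 2}}  -- child dict left untouched

remove_none_values(config, recursive=True)
# {"a": 1, "nested": {"d": 2}}
```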
docker_utils
Utility functions relating to Docker.
build_image(image_name, dockerfile, build_context_root=None, dockerignore=None, extra_files=(), **custom_build_options)
Builds a docker image.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name | str | The name to use for the built docker image. | required |
dockerfile | Union[str, List[str]] | Path to a dockerfile or a list of strings representing the Dockerfile lines/commands. | required |
build_context_root | Optional[str] | Optional path to a directory that will be sent to the Docker daemon as build context. If left empty, the Docker build context will be empty. | None |
dockerignore | Optional[str] | Optional path to a dockerignore file. If no value is given, the .dockerignore in the root of the build context will be used if it exists. Otherwise, all files inside build_context_root are included in the build context. | None |
extra_files | Sequence[Tuple[str, str]] | Additional files to include in the build context. The files should be passed as a tuple (filepath_inside_build_context, file_content) and will overwrite existing files in the build context if they share the same path. | () |
**custom_build_options | Any | Additional options that will be passed unmodified to the Docker build call when building the image. You can use this to for example specify build args or a target stage. See https://docker-py.readthedocs.io/en/stable/images.html#docker.models.images.ImageCollection.build for a full list of available options. | {} |
Source code in zenml/utils/docker_utils.py
def build_image(
image_name: str,
dockerfile: Union[str, List[str]],
build_context_root: Optional[str] = None,
dockerignore: Optional[str] = None,
extra_files: Sequence[Tuple[str, str]] = (),
**custom_build_options: Any,
) -> None:
"""Builds a docker image.
Args:
image_name: The name to use for the built docker image.
dockerfile: Path to a dockerfile or a list of strings representing the
Dockerfile lines/commands.
build_context_root: Optional path to a directory that will be sent to
the Docker daemon as build context. If left empty, the Docker build
context will be empty.
dockerignore: Optional path to a dockerignore file. If no value is
given, the .dockerignore in the root of the build context will be
used if it exists. Otherwise, all files inside `build_context_root`
are included in the build context.
extra_files: Additional files to include in the build context. The
files should be passed as a tuple
(filepath_inside_build_context, file_content) and will overwrite
existing files in the build context if they share the same path.
**custom_build_options: Additional options that will be passed
unmodified to the Docker build call when building the image. You
can use this to for example specify build args or a target stage.
See https://docker-py.readthedocs.io/en/stable/images.html#docker.models.images.ImageCollection.build
for a full list of available options.
"""
if isinstance(dockerfile, str):
dockerfile_contents = io_utils.read_file_contents_as_string(dockerfile)
logger.info("Using Dockerfile `%s`.", os.path.abspath(dockerfile))
else:
dockerfile_contents = "\n".join(dockerfile)
build_context = _create_custom_build_context(
dockerfile_contents=dockerfile_contents,
build_context_root=build_context_root,
dockerignore=dockerignore,
extra_files=extra_files,
)
build_options = {
"rm": False, # don't remove intermediate containers to improve caching
"pull": True, # always pull parent images
**custom_build_options,
}
logger.info("Building Docker image `%s`.", image_name)
logger.debug("Docker build options: %s", build_options)
logger.info("Building the image might take a while...")
docker_client = DockerClient.from_env()
# We use the client api directly here, so we can stream the logs
output_stream = docker_client.images.client.api.build(
fileobj=build_context,
custom_context=True,
tag=image_name,
**build_options,
)
_process_stream(output_stream)
logger.info("Finished building Docker image `%s`.", image_name)
check_docker()
Checks if Docker is installed and running.
Returns:
Type | Description |
---|---|
bool | True if Docker is installed and running, False otherwise. |
Source code in zenml/utils/docker_utils.py
def check_docker() -> bool:
"""Checks if Docker is installed and running.
Returns:
`True` if Docker is installed, `False` otherwise.
"""
# Try to ping Docker, to see if it's running
try:
docker_client = DockerClient.from_env()
docker_client.ping()
return True
except Exception:
logger.debug("Docker is not running.", exc_info=True)
return False
get_image_digest(image_name)
Gets the digest of an image.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name | str | Name of the image to get the digest for. | required |
Returns:
Type | Description |
---|---|
Optional[str] | The repo digest for the given image if there exists exactly one. If there are zero or multiple repo digests, returns None. |
Source code in zenml/utils/docker_utils.py
def get_image_digest(image_name: str) -> Optional[str]:
"""Gets the digest of an image.
Args:
image_name: Name of the image to get the digest for.
Returns:
Returns the repo digest for the given image if there exists exactly one.
If there are zero or multiple repo digests, returns `None`.
"""
docker_client = DockerClient.from_env()
image = docker_client.images.get(image_name)
repo_digests = image.attrs["RepoDigests"]
if len(repo_digests) == 1:
return cast(str, repo_digests[0])
else:
logger.debug(
"Found zero or more repo digests for docker image '%s': %s",
image_name,
repo_digests,
)
return None
is_local_image(image_name)
Returns whether an image was pulled from a registry or not.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name | str | Name of the image to check. | required |
Returns:
Type | Description |
---|---|
bool | True if the image was built locally and not pulled from a registry, False otherwise. |
Source code in zenml/utils/docker_utils.py
def is_local_image(image_name: str) -> bool:
"""Returns whether an image was pulled from a registry or not.
Args:
image_name: Name of the image to check.
Returns:
`True` if the image was built locally, `False` if it was pulled from a registry.
"""
docker_client = DockerClient.from_env()
images = docker_client.images.list(name=image_name)
if images:
# An image with this name is available locally -> now check whether it
# was pulled from a repo or built locally (in which case the repo
# digest is empty)
return get_image_digest(image_name) is None
else:
# no image with this name found locally
return False
push_image(image_name)
Pushes an image to a container registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name | str | The full name (including a tag) of the image to push. | required |
Returns:
Type | Description |
---|---|
str | The Docker repository digest of the pushed image. |
Exceptions:
Type | Description |
---|---|
RuntimeError | If fetching the repository digest of the image failed. |
Source code in zenml/utils/docker_utils.py
def push_image(image_name: str) -> str:
"""Pushes an image to a container registry.
Args:
image_name: The full name (including a tag) of the image to push.
Returns:
The Docker repository digest of the pushed image.
Raises:
RuntimeError: If fetching the repository digest of the image failed.
"""
logger.info("Pushing Docker image `%s`.", image_name)
docker_client = DockerClient.from_env()
output_stream = docker_client.images.push(image_name, stream=True)
aux_info = _process_stream(output_stream)
logger.info("Finished pushing Docker image.")
image_name_without_tag, _ = image_name.rsplit(":", maxsplit=1)
for info in reversed(aux_info):
try:
repo_digest = info["Digest"]
return f"{image_name_without_tag}@{repo_digest}"
except KeyError:
pass
else:
raise RuntimeError(
f"Unable to find repo digest after pushing image {image_name}."
)
tag_image(image_name, target)
Tags an image.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name | str | The name of the image to tag. | required |
target | str | The full target name including a tag. | required |
Source code in zenml/utils/docker_utils.py
def tag_image(image_name: str, target: str) -> None:
"""Tags an image.
Args:
image_name: The name of the image to tag.
target: The full target name including a tag.
"""
docker_client = DockerClient.from_env()
image = docker_client.images.get(image_name)
image.tag(target)
enum_utils
Util functions for enums.
StrEnum (str, Enum)
Base enum type for string enum values.
Source code in zenml/utils/enum_utils.py
class StrEnum(str, Enum):
"""Base enum type for string enum values."""
def __str__(self) -> str:
"""Returns the enum string value.
Returns:
The enum string value.
"""
return self.value # type: ignore
@classmethod
def names(cls) -> List[str]:
"""Get all enum names as a list of strings.
Returns:
A list of all enum names.
"""
return [c.name for c in cls]
@classmethod
def values(cls) -> List[str]:
"""Get all enum values as a list of strings.
Returns:
A list of all enum values.
"""
return [c.value for c in cls]
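Example: a hypothetical subclass demonstrating the helpers:
```python
from zenml.utils.enum_utils import StrEnum


class Color(StrEnum):
    RED = "red"
    BLUE = "blue"


str(Color.RED)  # "red"
Color.names()   # ["RED", "BLUE"]
Color.values()  # ["red", "blue"]
```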
filesync_model
Filesync utils for ZenML.
FileSyncModel (BaseModel)
pydantic-model
Pydantic model synchronized with a configuration file.
Use this class as a base Pydantic model that is automatically synchronized with a configuration file on disk.
This class overrides the `__setattr__` and `__getattribute__` magic methods to ensure that the FileSyncModel instance acts as an in-memory cache of the information stored in the associated configuration file.
Source code in zenml/utils/filesync_model.py
class FileSyncModel(BaseModel):
"""Pydantic model synchronized with a configuration file.
Use this class as a base Pydantic model that is automatically synchronized
with a configuration file on disk.
This class overrides the __setattr__ and __getattribute__ magic methods to
ensure that the FileSyncModel instance acts as an in-memory cache of the
information stored in the associated configuration file.
"""
_config_file: str
_config_file_timestamp: Optional[float]
def __init__(self, config_file: str, **kwargs: Any) -> None:
"""Create a FileSyncModel instance synchronized with a configuration file on disk.
Args:
config_file: configuration file path. If the file exists, the model
will be initialized with the values from the file.
**kwargs: additional keyword arguments to pass to the Pydantic model
constructor. If supplied, these values will override those
loaded from the configuration file.
"""
config_dict = {}
if fileio.exists(config_file):
config_dict = yaml_utils.read_yaml(config_file)
self._config_file = config_file
self._config_file_timestamp = None
config_dict.update(kwargs)
super(FileSyncModel, self).__init__(**config_dict)
# write the configuration file to disk, to reflect new attributes
# and schema changes
self.write_config()
def __setattr__(self, key: str, value: Any) -> None:
"""Sets an attribute on the model and persists it in the configuration file.
Args:
key: attribute name.
value: attribute value.
"""
super(FileSyncModel, self).__setattr__(key, value)
if key.startswith("_"):
return
self.write_config()
def __getattribute__(self, key: str) -> Any:
"""Gets an attribute value for a specific key.
Args:
key: attribute name.
Returns:
attribute value.
"""
if not key.startswith("_") and key in self.__dict__:
self.load_config()
return super(FileSyncModel, self).__getattribute__(key)
def write_config(self) -> None:
"""Writes the model to the configuration file."""
config_dict = json.loads(self.json())
yaml_utils.write_yaml(self._config_file, config_dict)
self._config_file_timestamp = os.path.getmtime(self._config_file)
def load_config(self) -> None:
"""Loads the model from the configuration file on disk."""
if not fileio.exists(self._config_file):
return
# don't reload the configuration if the file hasn't
# been updated since the last load
file_timestamp = os.path.getmtime(self._config_file)
if file_timestamp == self._config_file_timestamp:
return
if self._config_file_timestamp is not None:
logger.info(f"Reloading configuration file {self._config_file}")
# refresh the model from the configuration file values
config_dict = yaml_utils.read_yaml(self._config_file)
for key, value in config_dict.items():
super(FileSyncModel, self).__setattr__(key, value)
self._config_file_timestamp = file_timestamp
class Config:
"""Pydantic configuration class."""
# all attributes with leading underscore are private and therefore
# are mutable and not included in serialization
underscore_attrs_are_private = True
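Example: a hypothetical config model persisted to a YAML file (the class name, fields, and path are illustrative):
```python
from zenml.utils.filesync_model import FileSyncModel


class ServerConfig(FileSyncModel):
    host: str = "localhost"
    port: int = 8080


# Loads existing values from the file if it exists, then writes it back.
config = ServerConfig(config_file="/tmp/server.yaml")
config.port = 9090  # immediately persisted to /tmp/server.yaml
```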
Config
Pydantic configuration class.
Source code in zenml/utils/filesync_model.py
class Config:
"""Pydantic configuration class."""
# all attributes with leading underscore are private and therefore
# are mutable and not included in serialization
underscore_attrs_are_private = True
__getattribute__(self, key)
special
Gets an attribute value for a specific key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | str | attribute name. | required |
Returns:
Type | Description |
---|---|
Any | attribute value. |
Source code in zenml/utils/filesync_model.py
def __getattribute__(self, key: str) -> Any:
"""Gets an attribute value for a specific key.
Args:
key: attribute name.
Returns:
attribute value.
"""
if not key.startswith("_") and key in self.__dict__:
self.load_config()
return super(FileSyncModel, self).__getattribute__(key)
__init__(self, config_file, **kwargs)
special
Create a FileSyncModel instance synchronized with a configuration file on disk.
Parameters:
Name | Type | Description | Default
---|---|---|---
config_file | str | configuration file path. If the file exists, the model will be initialized with the values from the file. | required
**kwargs | Any | additional keyword arguments to pass to the Pydantic model constructor. If supplied, these values will override those loaded from the configuration file. | {}
Source code in zenml/utils/filesync_model.py
def __init__(self, config_file: str, **kwargs: Any) -> None:
"""Create a FileSyncModel instance synchronized with a configuration file on disk.
Args:
config_file: configuration file path. If the file exists, the model
will be initialized with the values from the file.
**kwargs: additional keyword arguments to pass to the Pydantic model
constructor. If supplied, these values will override those
loaded from the configuration file.
"""
config_dict = {}
if fileio.exists(config_file):
config_dict = yaml_utils.read_yaml(config_file)
self._config_file = config_file
self._config_file_timestamp = None
config_dict.update(kwargs)
super(FileSyncModel, self).__init__(**config_dict)
# write the configuration file to disk, to reflect new attributes
# and schema changes
self.write_config()
__setattr__(self, key, value)
special
Sets an attribute on the model and persists it in the configuration file.
Parameters:
Name | Type | Description | Default
---|---|---|---
key | str | attribute name. | required
value | Any | attribute value. | required
Source code in zenml/utils/filesync_model.py
def __setattr__(self, key: str, value: Any) -> None:
"""Sets an attribute on the model and persists it in the configuration file.
Args:
key: attribute name.
value: attribute value.
"""
super(FileSyncModel, self).__setattr__(key, value)
if key.startswith("_"):
return
self.write_config()
load_config(self)
Loads the model from the configuration file on disk.
Source code in zenml/utils/filesync_model.py
def load_config(self) -> None:
"""Loads the model from the configuration file on disk."""
if not fileio.exists(self._config_file):
return
# don't reload the configuration if the file hasn't
# been updated since the last load
file_timestamp = os.path.getmtime(self._config_file)
if file_timestamp == self._config_file_timestamp:
return
if self._config_file_timestamp is not None:
logger.info(f"Reloading configuration file {self._config_file}")
# refresh the model from the configuration file values
config_dict = yaml_utils.read_yaml(self._config_file)
for key, value in config_dict.items():
super(FileSyncModel, self).__setattr__(key, value)
self._config_file_timestamp = file_timestamp
write_config(self)
Writes the model to the configuration file.
Source code in zenml/utils/filesync_model.py
def write_config(self) -> None:
"""Writes the model to the configuration file."""
config_dict = json.loads(self.json())
yaml_utils.write_yaml(self._config_file, config_dict)
self._config_file_timestamp = os.path.getmtime(self._config_file)
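The net effect is a pydantic model whose state mirrors a YAML file in both directions. Below is a minimal sketch of that behavior, assuming a writable temporary directory; the `MyConfig` subclass and its field are hypothetical.
```python
import os
import tempfile

from zenml.utils.filesync_model import FileSyncModel


class MyConfig(FileSyncModel):
    """Hypothetical config model with a single field."""

    log_level: str = "INFO"


config_path = os.path.join(tempfile.mkdtemp(), "config.yaml")
config = MyConfig(config_file=config_path)  # __init__ writes the file to disk
config.log_level = "DEBUG"  # __setattr__ persists the change immediately

with open(config_path) as f:
    print(f.read())  # contains "log_level: DEBUG"
```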
io_utils
Various utility functions for the io module.
convert_to_str(path)
Converts a PathType to a str using UTF-8.
Parameters:
Name | Type | Description | Default
---|---|---|---
path | PathType | Path to convert. | required
Returns:
Type | Description
---|---
str | Converted path.
Source code in zenml/utils/io_utils.py
def convert_to_str(path: "PathType") -> str:
"""Converts a PathType to a str using UTF-8.
Args:
path: Path to convert.
Returns:
Converted path.
"""
if isinstance(path, str):
return path
else:
return path.decode("utf-8")
copy_dir(source_dir, destination_dir, overwrite=False)
Copies dir from source to destination.
Parameters:
Name | Type | Description | Default
---|---|---|---
source_dir | str | Path to copy from. | required
destination_dir | str | Path to copy to. | required
overwrite | bool | Boolean. If false, function throws an error before overwrite. | False
Source code in zenml/utils/io_utils.py
def copy_dir(
source_dir: str, destination_dir: str, overwrite: bool = False
) -> None:
"""Copies dir from source to destination.
Args:
source_dir: Path to copy from.
destination_dir: Path to copy to.
overwrite: Boolean. If false, function throws an error before overwrite.
"""
for source_file in listdir(source_dir):
source_path = os.path.join(source_dir, convert_to_str(source_file))
destination_path = os.path.join(
destination_dir, convert_to_str(source_file)
)
if isdir(source_path):
if source_path == destination_dir:
# if the destination is a subdirectory of the source, we skip
# copying it to avoid an infinite loop.
return
copy_dir(source_path, destination_path, overwrite)
else:
create_dir_recursive_if_not_exists(
os.path.dirname(destination_path)
)
copy(str(source_path), str(destination_path), overwrite)
create_dir_if_not_exists(dir_path)
Creates directory if it does not exist.
Parameters:
Name | Type | Description | Default
---|---|---|---
dir_path | str | Local path in filesystem. | required
Source code in zenml/utils/io_utils.py
def create_dir_if_not_exists(dir_path: str) -> None:
"""Creates directory if it does not exist.
Args:
dir_path: Local path in filesystem.
"""
if not isdir(dir_path):
mkdir(dir_path)
create_dir_recursive_if_not_exists(dir_path)
Creates directory recursively if it does not exist.
Parameters:
Name | Type | Description | Default
---|---|---|---
dir_path | str | Local path in filesystem. | required
Source code in zenml/utils/io_utils.py
def create_dir_recursive_if_not_exists(dir_path: str) -> None:
"""Creates directory recursively if it does not exist.
Args:
dir_path: Local path in filesystem.
"""
if not isdir(dir_path):
makedirs(dir_path)
create_file_if_not_exists(file_path, file_contents='{}')
Creates file if it does not exist.
Parameters:
Name | Type | Description | Default
---|---|---|---
file_path | str | Local path in filesystem. | required
file_contents | str | Contents of file. | '{}'
Source code in zenml/utils/io_utils.py
def create_file_if_not_exists(
file_path: str, file_contents: str = "{}"
) -> None:
"""Creates file if it does not exist.
Args:
file_path: Local path in filesystem.
file_contents: Contents of file.
"""
full_path = Path(file_path)
if not exists(file_path):
create_dir_recursive_if_not_exists(str(full_path.parent))
with open(str(full_path), "w") as f:
f.write(file_contents)
find_files(dir_path, pattern)
Find files in a directory that match pattern.
Parameters:
Name | Type | Description | Default
---|---|---|---
dir_path | PathType | Path to directory. | required
pattern | str | pattern like *.png. | required
Yields:
Type | Description
---|---
Iterable[str] | All matching filenames if found.
Source code in zenml/utils/io_utils.py
def find_files(dir_path: "PathType", pattern: str) -> Iterable[str]:
"""Find files in a directory that match pattern.
Args:
dir_path: Path to directory.
pattern: pattern like *.png.
Yields:
All matching filenames if found.
"""
for root, dirs, files in walk(dir_path):
for basename in files:
if fnmatch.fnmatch(convert_to_str(basename), pattern):
filename = os.path.join(
convert_to_str(root), convert_to_str(basename)
)
yield filename
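A short usage sketch: because matching is done with fnmatch on the basename, the pattern applies to file names rather than full paths. The directory and pattern below are illustrative.
```python
from zenml.utils.io_utils import find_files

# Lazily yields the full path of every YAML file under the current directory.
for path in find_files(".", "*.yaml"):
    print(path)
```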
get_global_config_directory()
Gets the global config directory for ZenML.
Returns:
Type | Description
---|---
str | The global config directory for ZenML.
Source code in zenml/utils/io_utils.py
def get_global_config_directory() -> str:
"""Gets the global config directory for ZenML.
Returns:
The global config directory for ZenML.
"""
env_var_path = os.getenv(ENV_ZENML_CONFIG_PATH)
if env_var_path:
return str(Path(env_var_path).resolve())
return click.get_app_dir(APP_NAME)
get_grandparent(dir_path)
Get grandparent of dir.
Parameters:
Name | Type | Description | Default
---|---|---|---
dir_path | str | Path to directory. | required
Returns:
Type | Description
---|---
str | The stem of the input path's parent's parent.
Source code in zenml/utils/io_utils.py
def get_grandparent(dir_path: str) -> str:
"""Get grandparent of dir.
Args:
dir_path: Path to directory.
Returns:
The input path's parent's parent.
"""
return Path(dir_path).parent.parent.stem
get_parent(dir_path)
Get parent of dir.
Parameters:
Name | Type | Description | Default
---|---|---|---
dir_path | str | Path to directory. | required
Returns:
Type | Description
---|---
str | Parent (stem) of the dir as a string.
Source code in zenml/utils/io_utils.py
def get_parent(dir_path: str) -> str:
"""Get parent of dir.
Args:
dir_path: Path to directory.
Returns:
Parent (stem) of the dir as a string.
"""
return Path(dir_path).parent.stem
is_remote(path)
Returns True if path exists remotely.
Parameters:
Name | Type | Description | Default
---|---|---|---
path | str | Any path as a string. | required
Returns:
Type | Description
---|---
bool | True if remote path, else False.
Source code in zenml/utils/io_utils.py
def is_remote(path: str) -> bool:
"""Returns True if path exists remotely.
Args:
path: Any path as a string.
Returns:
True if remote path, else False.
"""
return any(path.startswith(prefix) for prefix in REMOTE_FS_PREFIX)
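For illustration, assuming REMOTE_FS_PREFIX contains the usual cloud filesystem schemes such as `s3://` and `gs://`:
```python
from zenml.utils.io_utils import is_remote

print(is_remote("s3://my-bucket/artifacts"))  # True, assuming "s3://" is a known prefix
print(is_remote("/home/user/artifacts"))      # False: a plain local path
```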
is_root(path)
Returns true if path has no parent in local filesystem.
Parameters:
Name | Type | Description | Default
---|---|---|---
path | str | Local path in filesystem. | required
Returns:
Type | Description
---|---
bool | True if root, else False.
Source code in zenml/utils/io_utils.py
def is_root(path: str) -> bool:
"""Returns true if path has no parent in local filesystem.
Args:
path: Local path in filesystem.
Returns:
True if root, else False.
"""
return Path(path).parent == Path(path)
read_file_contents_as_string(file_path)
Reads contents of file.
Parameters:
Name | Type | Description | Default
---|---|---|---
file_path | str | Path to file. | required
Returns:
Type | Description
---|---
str | Contents of file.
Exceptions:
Type | Description
---|---
FileNotFoundError | If file does not exist.
Source code in zenml/utils/io_utils.py
def read_file_contents_as_string(file_path: str) -> str:
"""Reads contents of file.
Args:
file_path: Path to file.
Returns:
Contents of file.
Raises:
FileNotFoundError: If file does not exist.
"""
if not exists(file_path):
raise FileNotFoundError(f"{file_path} does not exist!")
with open(file_path) as f:
return f.read() # type: ignore[no-any-return]
resolve_relative_path(path)
Takes relative path and resolves it absolutely.
Parameters:
Name | Type | Description | Default
---|---|---|---
path | str | Local path in filesystem. | required
Returns:
Type | Description
---|---
str | Resolved path.
Source code in zenml/utils/io_utils.py
def resolve_relative_path(path: str) -> str:
"""Takes relative path and resolves it absolutely.
Args:
path: Local path in filesystem.
Returns:
Resolved path.
"""
if is_remote(path):
return path
return str(Path(path).resolve())
write_file_contents_as_string(file_path, content)
Writes contents of file.
Parameters:
Name | Type | Description | Default
---|---|---|---
file_path | str | Path to file. | required
content | str | Contents of file. | required
Source code in zenml/utils/io_utils.py
def write_file_contents_as_string(file_path: str, content: str) -> None:
"""Writes contents of file.
Args:
file_path: Path to file.
content: Contents of file.
"""
with open(file_path, "w") as f:
f.write(content)
materializer_utils
Util functions for models and materializers.
load_model_from_metadata(model_uri)
Load a ZenML model artifact using a YAML metadata file.
This function is used to load information from a YAML file that was created by the save_model_metadata function. The information in the YAML file is used to load the model into memory in the inference environment:
model_uri: the URI of the model checkpoint/files to load.
datatype: the model type. This is the path to the model class.
materializer: the materializer class. This is the path to the materializer class.
Parameters:
Name | Type | Description | Default
---|---|---|---
model_uri | str | the artifact to extract the metadata from. | required
Returns:
Type | Description
---|---
Any | The ML model object loaded into memory.
Source code in zenml/utils/materializer_utils.py
def load_model_from_metadata(model_uri: str) -> Any:
"""Load a zenml model artifact from a json file.
This function is used to load information from a Yaml file that was created
by the save_model_metadata function. The information in the Yaml file is
used to load the model into memory in the inference environment.
model_uri: the URI of the model checkpoint/files to load.
datatype: the model type. This is the path to the model class.
materializer: the materializer class. This is the path to the materializer class.
Args:
model_uri: the artifact to extract the metadata from.
Returns:
The ML model object loaded into memory.
"""
with fileio.open(
os.path.join(model_uri, MODEL_METADATA_YAML_FILE_NAME), "r"
) as f:
metadata = read_yaml(f.name)
model_artifact = Artifact()
model_artifact.uri = model_uri
model_artifact.properties[METADATA_DATATYPE].string_value = metadata[
METADATA_DATATYPE
]
model_artifact.properties[METADATA_MATERIALIZER].string_value = metadata[
METADATA_MATERIALIZER
]
materializer_class = source_utils.load_source_path_class(
model_artifact.properties[METADATA_MATERIALIZER].string_value
)
model_class = source_utils.load_source_path_class(
model_artifact.properties[METADATA_DATATYPE].string_value
)
materializer_object: BaseMaterializer = materializer_class(model_artifact)
model = materializer_object.handle_input(model_class)
try:
import torch.nn as nn
if issubclass(model_class, nn.Module): # type: ignore
model.eval()
except ImportError:
pass
logger.debug(f"Model loaded successfully :\n{model}")
return model
model_from_model_artifact(model_artifact)
Load model to memory from a model artifact.
Parameters:
Name | Type | Description | Default
---|---|---|---
model_artifact | ModelArtifact | The model artifact to load. | required
Returns:
Type | Description
---|---
Any | The ML model object loaded into memory.
Source code in zenml/utils/materializer_utils.py
def model_from_model_artifact(model_artifact: ModelArtifact) -> Any:
"""Load model to memory from a model artifact.
Args:
model_artifact: The model artifact to load.
Returns:
The ML model object loaded into memory.
"""
materializer_class = source_utils.load_source_path_class(
model_artifact.materializer
)
model_class = source_utils.load_source_path_class(model_artifact.datatype)
materializer_object: BaseMaterializer = materializer_class(model_artifact)
model = materializer_object.handle_input(model_class)
logger.debug(f"Model loaded successfully :\n{model}")
return model
save_model_metadata(model_artifact)
Save a zenml model artifact metadata to a YAML file.
This function is used to extract and save information from a ZenML model artifact, such as the model type and materializer. The extracted information will be the key to loading the model into memory in the inference environment:
datatype: the model type. This is the path to the model class.
materializer: the materializer class. This is the path to the materializer class.
Parameters:
Name | Type | Description | Default
---|---|---|---
model_artifact | ArtifactModel | the artifact to extract the metadata from. | required
Returns:
Type | Description
---|---
str | The path to the temporary file where the model metadata is saved.
Source code in zenml/utils/materializer_utils.py
def save_model_metadata(model_artifact: ArtifactModel) -> str:
"""Save a zenml model artifact metadata to a YAML file.
This function is used to extract and save information from a zenml model artifact
such as the model type and materializer. The extracted information will be
the key to loading the model into memory in the inference environment.
datatype: the model type. This is the path to the model class.
materializer: the materializer class. This is the path to the materializer class.
Args:
model_artifact: the artifact to extract the metadata from.
Returns:
The path to the temporary file where the model metadata is saved
"""
metadata = dict()
metadata[METADATA_DATATYPE] = model_artifact.data_type
metadata[METADATA_MATERIALIZER] = model_artifact.materializer
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
write_yaml(f.name, metadata)
return f.name
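Since the function only reads the `data_type` and `materializer` attributes, a duck-typed stand-in is enough to sketch the YAML it produces. The class paths below are hypothetical; a real call would pass an ArtifactModel.
```python
from types import SimpleNamespace

from zenml.utils.materializer_utils import save_model_metadata

stand_in = SimpleNamespace(
    data_type="my_project.models.MyModel",  # hypothetical class path
    materializer="my_project.materializers.MyMaterializer",  # hypothetical class path
)
metadata_path = save_model_metadata(stand_in)  # type: ignore[arg-type]
with open(metadata_path) as f:
    print(f.read())  # the datatype and materializer entries as YAML
```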
networking_utils
Utility functions for networking.
find_available_port()
Finds a local random unoccupied TCP port.
Returns:
Type | Description
---|---
int | A random unoccupied TCP port.
Source code in zenml/utils/networking_utils.py
def find_available_port() -> int:
"""Finds a local random unoccupied TCP port.
Returns:
A random unoccupied TCP port.
"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
_, port = s.getsockname()
return cast(int, port)
port_available(port, address='127.0.0.1')
Checks if a local port is available.
Parameters:
Name | Type | Description | Default
---|---|---|---
port | int | TCP port number | required
address | str | IP address on the local machine | '127.0.0.1'
Returns:
Type | Description
---|---
bool | True if the port is available, otherwise False
Source code in zenml/utils/networking_utils.py
def port_available(port: int, address: str = "127.0.0.1") -> bool:
"""Checks if a local port is available.
Args:
port: TCP port number
address: IP address on the local machine
Returns:
True if the port is available, otherwise False
"""
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if sys.platform != "win32":
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
else:
# The SO_REUSEPORT socket option is not supported on Windows.
# This if clause exists just for mypy to not complain about
# missing code paths.
pass
s.bind((address, port))
except socket.error as e:
logger.debug("Port %d unavailable on %s: %s", port, address, e)
return False
return True
port_is_open(hostname, port)
Check if a TCP port is open on a remote host.
Parameters:
Name | Type | Description | Default
---|---|---|---
hostname | str | hostname of the remote machine | required
port | int | TCP port number | required
Returns:
Type | Description
---|---
bool | True if the port is open, False otherwise
Source code in zenml/utils/networking_utils.py
def port_is_open(hostname: str, port: int) -> bool:
"""Check if a TCP port is open on a remote host.
Args:
hostname: hostname of the remote machine
port: TCP port number
Returns:
True if the port is open, False otherwise
"""
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
result = sock.connect_ex((hostname, port))
return result == 0
except socket.error as e:
logger.debug(
f"Error checking TCP port {port} on host {hostname}: {str(e)}"
)
return False
replace_internal_hostname_with_localhost(hostname)
Replaces an internal Docker or K3D hostname with localhost.
Localhost URLs that are directly accessible on the host machine are not accessible from within a Docker or K3D container running on that same machine, but there are special hostnames featured by both Docker (`host.docker.internal`) and K3D (`host.k3d.internal`) that can be used to access host services from within the containers.
Use this method to replace one of these special hostnames with localhost if used outside a container or in a container where special hostnames are not available.
Parameters:
Name | Type | Description | Default
---|---|---|---
hostname | str | The hostname to replace. | required
Returns:
Type | Description
---|---
str | The original or replaced hostname.
Source code in zenml/utils/networking_utils.py
def replace_internal_hostname_with_localhost(hostname: str) -> str:
"""Replaces an internal Docker or K3D hostname with localhost.
Localhost URLs that are directly accessible on the host machine are not
accessible from within a Docker or K3D container running on that same
machine, but there are special hostnames featured by both Docker
(`host.docker.internal`) and K3D (`host.k3d.internal`) that can be used to
access host services from within the containers.
Use this method to replace one of these special hostnames with localhost
if used outside a container or in a container where special hostnames are
not available.
Args:
hostname: The hostname to replace.
Returns:
The original or replaced hostname.
"""
if hostname not in ("host.docker.internal", "host.k3d.internal"):
return hostname
if Environment.in_container():
# Try to resolve one of the special hostnames to see if it is available
# inside the container and use that if it is.
for internal_hostname in (
"host.docker.internal",
"host.k3d.internal",
):
try:
socket.gethostbyname(internal_hostname)
if internal_hostname != hostname:
logger.debug(
f"Replacing internal hostname {hostname} with "
f"{internal_hostname}"
)
return internal_hostname
except socket.gaierror:
continue
logger.debug(f"Replacing internal hostname {hostname} with localhost.")
return "127.0.0.1"
replace_localhost_with_internal_hostname(url)
Replaces the localhost with an internal Docker or K3D hostname in a given URL.
Localhost URLs that are directly accessible on the host machine are not accessible from within a Docker or K3D container running on that same machine, but there are special hostnames featured by both Docker (`host.docker.internal`) and K3D (`host.k3d.internal`) that can be used to access host services from within the containers.
Use this method to attempt to replace `localhost` in a URL with one of these special hostnames, if they are available inside a container.
Parameters:
Name | Type | Description | Default
---|---|---|---
url | str | The URL to update. | required
Returns:
Type | Description
---|---
str | The updated URL.
Source code in zenml/utils/networking_utils.py
def replace_localhost_with_internal_hostname(url: str) -> str:
"""Replaces the localhost with an internal Docker or K3D hostname in a given URL.
Localhost URLs that are directly accessible on the host machine are not
accessible from within a Docker or K3D container running on that same
machine, but there are special hostnames featured by both Docker
(`host.docker.internal`) and K3D (`host.k3d.internal`) that can be used to
access host services from within the containers.
Use this method to attempt to replace `localhost` in a URL with one of these
special hostnames, if they are available inside a container.
Args:
url: The URL to update.
Returns:
The updated URL.
"""
if not Environment.in_container():
return url
parsed_url = urlparse(url)
if parsed_url.hostname in ("localhost", "127.0.0.1"):
for internal_hostname in (
"host.docker.internal",
"host.k3d.internal",
):
try:
socket.gethostbyname(internal_hostname)
parsed_url = parsed_url._replace(
netloc=parsed_url.netloc.replace(
parsed_url.hostname,
internal_hostname,
)
)
logger.debug(
f"Replacing localhost with {internal_hostname} in URL: "
f"{url}"
)
return parsed_url.geturl()
except socket.gaierror:
continue
return url
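The two functions are inverses of each other; a quick sketch, with an illustrative URL:
```python
from zenml.utils.networking_utils import (
    replace_internal_hostname_with_localhost,
    replace_localhost_with_internal_hostname,
)

# Inside a container, localhost is swapped for a resolvable special
# hostname; outside a container this is a no-op.
print(replace_localhost_with_internal_hostname("http://localhost:8237"))

# The reverse direction: outside a container the special hostname
# becomes 127.0.0.1.
print(replace_internal_hostname_with_localhost("host.docker.internal"))
```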
scan_for_available_port(start=8000, stop=65535)
Scan the local network for an available port in the given range.
Parameters:
Name | Type | Description | Default
---|---|---|---
start | int | the beginning of the port range value to scan | 8000
stop | int | the (inclusive) end of the port range value to scan | 65535
Returns:
Type | Description
---|---
Optional[int] | The first available port in the given range, or None if no available port is found.
Source code in zenml/utils/networking_utils.py
def scan_for_available_port(
start: int = SCAN_PORT_RANGE[0], stop: int = SCAN_PORT_RANGE[1]
) -> Optional[int]:
"""Scan the local network for an available port in the given range.
Args:
start: the beginning of the port range value to scan
stop: the (inclusive) end of the port range value to scan
Returns:
The first available port in the given range, or None if no available
port is found.
"""
for port in range(start, stop + 1):
if port_available(port):
return port
logger.debug(
"No free TCP ports found in the range %d - %d",
start,
stop,
)
return None
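A brief sketch combining the three port helpers; the scan range is illustrative:
```python
from zenml.utils.networking_utils import (
    find_available_port,
    port_available,
    scan_for_available_port,
)

port = find_available_port()  # let the OS pick a random free port
print(port, port_available(port))  # the port is typically still free

# First free port in the (inclusive) range 8000-8010, or None.
print(scan_for_available_port(start=8000, stop=8010))
```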
pipeline_docker_image_builder
Implementation of Docker image builds to run ZenML pipelines.
PipelineDockerImageBuilder
Builds Docker images to run a ZenML pipeline.
Usage:
class MyStackComponent(StackComponent, PipelineDockerImageBuilder):
def method_that_requires_docker_image(self):
image_identifier = self.build_and_push_docker_image(...)
# use the image ID
Source code in zenml/utils/pipeline_docker_image_builder.py
class PipelineDockerImageBuilder:
"""Builds Docker images to run a ZenML pipeline.
**Usage**:
```python
class MyStackComponent(StackComponent, PipelineDockerImageBuilder):
def method_that_requires_docker_image(self):
image_identifier = self.build_and_push_docker_image(...)
# use the image ID
```
"""
def build_and_push_docker_image(
self,
deployment: "PipelineDeployment",
stack: "Stack",
entrypoint: Optional[str] = None,
) -> str:
"""Builds and pushes a Docker image to run a pipeline.
Use the image name returned by this method whenever you need to uniquely
reference the pushed image in order to pull or run it.
Args:
deployment: The pipeline deployment for which the image should be
built.
stack: The stack on which the pipeline will be deployed.
entrypoint: Entrypoint to use for the final image. If left empty,
no entrypoint will be included in the image.
Returns:
The Docker repository digest of the pushed image.
Raises:
RuntimeError: If the stack doesn't contain a container registry.
"""
container_registry = stack.container_registry
if not container_registry:
raise RuntimeError(
"Unable to build and push Docker image because stack "
f"`{stack.name}` has no container registry."
)
target_image_name = self.get_target_image_name(
deployment=deployment, container_registry=container_registry
)
self.build_docker_image(
target_image_name=target_image_name,
deployment=deployment,
stack=stack,
entrypoint=entrypoint,
)
repo_digest = container_registry.push_image(target_image_name)
return repo_digest
@staticmethod
def get_target_image_name(
deployment: "PipelineDeployment",
container_registry: Optional["BaseContainerRegistry"] = None,
) -> str:
"""Returns the target image name.
If a container registry is given, the image name will include the
registry URI
Args:
deployment: The pipeline deployment for which the target image name
should be returned.
container_registry: Optional container registry to which this
image will be pushed.
Returns:
The docker image name.
"""
pipeline_name = deployment.pipeline.name
docker_settings = (
deployment.pipeline.docker_settings or DockerSettings()
)
target_image_name = (
f"{docker_settings.target_repository}:{pipeline_name}"
)
if container_registry:
target_image_name = (
f"{container_registry.config.uri}/{target_image_name}"
)
return target_image_name
def build_docker_image(
self,
target_image_name: str,
deployment: "PipelineDeployment",
stack: "Stack",
entrypoint: Optional[str] = None,
) -> None:
"""Builds a Docker image to run a pipeline.
Args:
target_image_name: The name of the image to build.
deployment: The pipeline deployment for which the image should be
built.
stack: The stack on which the pipeline will be deployed.
entrypoint: Entrypoint to use for the final image. If left empty,
no entrypoint will be included in the image.
Raises:
ValueError: If no Dockerfile and/or custom parent image is
specified and the Docker configuration doesn't require an
image build.
"""
pipeline_name = deployment.pipeline.name
docker_settings = (
deployment.pipeline.docker_settings or DockerSettings()
)
logger.info(
"Building Docker image(s) for pipeline `%s`.", pipeline_name
)
requires_zenml_build = any(
[
docker_settings.requirements,
docker_settings.required_integrations,
docker_settings.replicate_local_python_environment,
docker_settings.install_stack_requirements,
docker_settings.apt_packages,
docker_settings.environment,
docker_settings.copy_files,
docker_settings.copy_global_config,
entrypoint,
]
)
# Fallback to the value defined on the stack component if the
# pipeline configuration doesn't have a configured value
parent_image = (
docker_settings.parent_image or DEFAULT_DOCKER_PARENT_IMAGE
)
if docker_settings.dockerfile:
if parent_image != DEFAULT_DOCKER_PARENT_IMAGE:
logger.warning(
"You've specified both a Dockerfile and a custom parent "
"image, ignoring the parent image."
)
if requires_zenml_build:
# We will build an additional image on top of this one later
# to include user files and/or install requirements. The image
# we build now will be used as the parent for the next build.
user_image_name = f"zenml-intermediate-build:{pipeline_name}"
parent_image = user_image_name
else:
# The image we'll build from the custom Dockerfile will be
# used directly, so we tag it with the requested target name.
user_image_name = target_image_name
docker_utils.build_image(
image_name=user_image_name,
dockerfile=docker_settings.dockerfile,
build_context_root=docker_settings.build_context_root,
**docker_settings.build_options,
)
elif not requires_zenml_build:
if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
raise ValueError(
"Unable to run a ZenML pipeline with the given Docker "
"settings: No Dockerfile or custom parent image "
"specified and no files will be copied or requirements "
"installed."
)
else:
# The parent image will be used directly to run the pipeline and
# needs to be tagged so it gets pushed later
docker_utils.tag_image(parent_image, target=target_image_name)
if requires_zenml_build:
requirement_files = self._gather_requirements_files(
docker_settings=docker_settings, stack=stack
)
requirements_file_names = [f[0] for f in requirement_files]
apt_packages = docker_settings.apt_packages
if docker_settings.install_stack_requirements:
apt_packages += stack.apt_packages
if apt_packages:
logger.info(
"Including apt packages: %s",
", ".join(f"`{p}`" for p in apt_packages),
)
dockerfile = self._generate_zenml_pipeline_dockerfile(
parent_image=parent_image,
docker_settings=docker_settings,
requirements_files=requirements_file_names,
apt_packages=apt_packages,
entrypoint=entrypoint,
)
if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
# The default parent image is static and doesn't require a pull
# each time
pull_parent_image = False
else:
# If the image is local, we don't need to pull it. Otherwise
# we play it safe and always pull in case the user pushed a new
# image for the given name and tag
pull_parent_image = not docker_utils.is_local_image(
parent_image
)
extra_files = requirement_files.copy()
extra_files.append(
(DOCKER_IMAGE_DEPLOYMENT_CONFIG_FILE, deployment.yaml())
)
# Leave the build context empty if we don't want to copy any files
requires_build_context = (
docker_settings.copy_files or docker_settings.copy_global_config
)
build_context_root = (
source_utils.get_source_root_path()
if requires_build_context
else None
)
maybe_include_global_config = (
_include_global_config(build_context_root=build_context_root) # type: ignore[arg-type]
if docker_settings.copy_global_config
else contextlib.nullcontext()
)
with maybe_include_global_config:
docker_utils.build_image(
image_name=target_image_name,
dockerfile=dockerfile,
build_context_root=build_context_root,
dockerignore=docker_settings.dockerignore,
extra_files=extra_files,
pull=pull_parent_image,
)
@staticmethod
def _gather_requirements_files(
docker_settings: DockerSettings, stack: "Stack"
) -> List[Tuple[str, str]]:
"""Gathers and/or generates pip requirements files.
Args:
docker_settings: Docker settings that specifies which
requirements to install.
stack: The stack on which the pipeline will run.
Raises:
RuntimeError: If the command to export the local python packages
failed.
Returns:
List of tuples (filename, file_content) of all requirements files.
The files will be in the following order:
- Packages installed in the local Python environment
- User-defined requirements
- Requirements defined by user-defined and/or stack integrations
"""
requirements_files = []
logger.info("Gathering requirements for Docker build:")
# Generate requirements file for the local environment if configured
if docker_settings.replicate_local_python_environment:
if isinstance(
docker_settings.replicate_local_python_environment,
PythonEnvironmentExportMethod,
):
command = (
docker_settings.replicate_local_python_environment.command
)
else:
command = " ".join(
docker_settings.replicate_local_python_environment
)
try:
local_requirements = subprocess.check_output(
command, shell=True
).decode()
except subprocess.CalledProcessError as e:
raise RuntimeError(
"Unable to export local python packages."
) from e
requirements_files.append(
(".zenml_local_requirements", local_requirements)
)
logger.info("\t- Including python packages from local environment")
# Generate/Read requirements file for user-defined requirements
if isinstance(docker_settings.requirements, str):
user_requirements = io_utils.read_file_contents_as_string(
docker_settings.requirements
)
logger.info(
"\t- Including user-defined requirements from file `%s`",
os.path.abspath(docker_settings.requirements),
)
elif isinstance(docker_settings.requirements, List):
user_requirements = "\n".join(docker_settings.requirements)
logger.info(
"\t- Including user-defined requirements: %s",
", ".join(f"`{r}`" for r in docker_settings.requirements),
)
else:
user_requirements = None
if user_requirements:
requirements_files.append(
(".zenml_user_requirements", user_requirements)
)
# Generate requirements file for all required integrations
integration_requirements = set(
itertools.chain.from_iterable(
integration_registry.select_integration_requirements(
integration
)
for integration in docker_settings.required_integrations
)
)
if docker_settings.install_stack_requirements:
integration_requirements.update(stack.requirements())
if integration_requirements:
integration_requirements_list = sorted(integration_requirements)
integration_requirements_file = "\n".join(
integration_requirements_list
)
requirements_files.append(
(
".zenml_integration_requirements",
integration_requirements_file,
)
)
logger.info(
"\t- Including integration requirements: %s",
", ".join(f"`{r}`" for r in integration_requirements_list),
)
return requirements_files
@staticmethod
def _generate_zenml_pipeline_dockerfile(
parent_image: str,
docker_settings: DockerSettings,
requirements_files: Sequence[str] = (),
apt_packages: Sequence[str] = (),
entrypoint: Optional[str] = None,
) -> List[str]:
"""Generates a Dockerfile.
Args:
parent_image: The image to use as parent for the Dockerfile.
docker_settings: Docker settings for this image build.
requirements_files: Paths of requirements files to install.
apt_packages: APT packages to install.
entrypoint: The default entrypoint command that gets executed when
running a container of an image created by this Dockerfile.
Returns:
Lines of the generated Dockerfile.
"""
lines = [f"FROM {parent_image}", f"WORKDIR {DOCKER_IMAGE_WORKDIR}"]
if docker_settings.copy_global_config:
lines.append(
f"ENV {ENV_ZENML_CONFIG_PATH}={DOCKER_IMAGE_ZENML_CONFIG_PATH}"
)
for key, value in docker_settings.environment.items():
lines.append(f"ENV {key.upper()}={value}")
if apt_packages:
apt_packages = " ".join(f"'{p}'" for p in apt_packages)
lines.append(
"RUN apt-get update && apt-get install -y "
f"--no-install-recommends {apt_packages}"
)
for file in requirements_files:
lines.append(f"COPY {file} .")
lines.append(f"RUN pip install --no-cache-dir -r {file}")
if docker_settings.copy_files:
lines.append("COPY . .")
elif docker_settings.copy_global_config:
lines.append(f"COPY {DOCKER_IMAGE_ZENML_CONFIG_DIR} .")
lines.append("RUN chmod -R a+rw .")
if docker_settings.user:
lines.append(f"USER {docker_settings.user}")
lines.append(f"RUN chown -R {docker_settings.user} .")
if entrypoint:
lines.append(f"ENTRYPOINT {entrypoint}")
return lines
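To see what the generated Dockerfile looks like, one can call the private helper directly. This is a sketch for illustration only; it assumes `DockerSettings` is importable from `zenml.config`, and the package names and environment variable are made up.
```python
from zenml.config import DockerSettings
from zenml.utils.pipeline_docker_image_builder import (
    PipelineDockerImageBuilder,
)

settings = DockerSettings(
    apt_packages=["git"],  # illustrative package
    environment={"my_var": "value"},  # rendered as ENV MY_VAR=value
)
lines = PipelineDockerImageBuilder._generate_zenml_pipeline_dockerfile(
    parent_image="python:3.9-slim",
    docker_settings=settings,
    requirements_files=[".zenml_user_requirements"],
    apt_packages=settings.apt_packages,
)
print("\n".join(lines))  # FROM, WORKDIR, ENV, RUN apt-get ..., COPY/RUN pip ...
```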
build_and_push_docker_image(self, deployment, stack, entrypoint=None)
Builds and pushes a Docker image to run a pipeline.
Use the image name returned by this method whenever you need to uniquely reference the pushed image in order to pull or run it.
Parameters:
Name | Type | Description | Default
---|---|---|---
deployment | PipelineDeployment | The pipeline deployment for which the image should be built. | required
stack | Stack | The stack on which the pipeline will be deployed. | required
entrypoint | Optional[str] | Entrypoint to use for the final image. If left empty, no entrypoint will be included in the image. | None
Returns:
Type | Description
---|---
str | The Docker repository digest of the pushed image.
Exceptions:
Type | Description
---|---
RuntimeError | If the stack doesn't contain a container registry.
Source code in zenml/utils/pipeline_docker_image_builder.py
def build_and_push_docker_image(
self,
deployment: "PipelineDeployment",
stack: "Stack",
entrypoint: Optional[str] = None,
) -> str:
"""Builds and pushes a Docker image to run a pipeline.
Use the image name returned by this method whenever you need to uniquely
reference the pushed image in order to pull or run it.
Args:
deployment: The pipeline deployment for which the image should be
built.
stack: The stack on which the pipeline will be deployed.
entrypoint: Entrypoint to use for the final image. If left empty,
no entrypoint will be included in the image.
Returns:
The Docker repository digest of the pushed image.
Raises:
RuntimeError: If the stack doesn't contain a container registry.
"""
container_registry = stack.container_registry
if not container_registry:
raise RuntimeError(
"Unable to build and push Docker image because stack "
f"`{stack.name}` has no container registry."
)
target_image_name = self.get_target_image_name(
deployment=deployment, container_registry=container_registry
)
self.build_docker_image(
target_image_name=target_image_name,
deployment=deployment,
stack=stack,
entrypoint=entrypoint,
)
repo_digest = container_registry.push_image(target_image_name)
return repo_digest
build_docker_image(self, target_image_name, deployment, stack, entrypoint=None)
Builds a Docker image to run a pipeline.
Parameters:
Name | Type | Description | Default
---|---|---|---
target_image_name | str | The name of the image to build. | required
deployment | PipelineDeployment | The pipeline deployment for which the image should be built. | required
stack | Stack | The stack on which the pipeline will be deployed. | required
entrypoint | Optional[str] | Entrypoint to use for the final image. If left empty, no entrypoint will be included in the image. | None
Exceptions:
Type | Description
---|---
ValueError | If no Dockerfile and/or custom parent image is specified and the Docker configuration doesn't require an image build.
Source code in zenml/utils/pipeline_docker_image_builder.py
def build_docker_image(
self,
target_image_name: str,
deployment: "PipelineDeployment",
stack: "Stack",
entrypoint: Optional[str] = None,
) -> None:
"""Builds a Docker image to run a pipeline.
Args:
target_image_name: The name of the image to build.
deployment: The pipeline deployment for which the image should be
built.
stack: The stack on which the pipeline will be deployed.
entrypoint: Entrypoint to use for the final image. If left empty,
no entrypoint will be included in the image.
Raises:
ValueError: If no Dockerfile and/or custom parent image is
specified and the Docker configuration doesn't require an
image build.
"""
pipeline_name = deployment.pipeline.name
docker_settings = (
deployment.pipeline.docker_settings or DockerSettings()
)
logger.info(
"Building Docker image(s) for pipeline `%s`.", pipeline_name
)
requires_zenml_build = any(
[
docker_settings.requirements,
docker_settings.required_integrations,
docker_settings.replicate_local_python_environment,
docker_settings.install_stack_requirements,
docker_settings.apt_packages,
docker_settings.environment,
docker_settings.copy_files,
docker_settings.copy_global_config,
entrypoint,
]
)
# Fallback to the value defined on the stack component if the
# pipeline configuration doesn't have a configured value
parent_image = (
docker_settings.parent_image or DEFAULT_DOCKER_PARENT_IMAGE
)
if docker_settings.dockerfile:
if parent_image != DEFAULT_DOCKER_PARENT_IMAGE:
logger.warning(
"You've specified both a Dockerfile and a custom parent "
"image, ignoring the parent image."
)
if requires_zenml_build:
# We will build an additional image on top of this one later
# to include user files and/or install requirements. The image
# we build now will be used as the parent for the next build.
user_image_name = f"zenml-intermediate-build:{pipeline_name}"
parent_image = user_image_name
else:
# The image we'll build from the custom Dockerfile will be
# used directly, so we tag it with the requested target name.
user_image_name = target_image_name
docker_utils.build_image(
image_name=user_image_name,
dockerfile=docker_settings.dockerfile,
build_context_root=docker_settings.build_context_root,
**docker_settings.build_options,
)
elif not requires_zenml_build:
if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
raise ValueError(
"Unable to run a ZenML pipeline with the given Docker "
"settings: No Dockerfile or custom parent image "
"specified and no files will be copied or requirements "
"installed."
)
else:
# The parent image will be used directly to run the pipeline and
# needs to be tagged so it gets pushed later
docker_utils.tag_image(parent_image, target=target_image_name)
if requires_zenml_build:
requirement_files = self._gather_requirements_files(
docker_settings=docker_settings, stack=stack
)
requirements_file_names = [f[0] for f in requirement_files]
apt_packages = docker_settings.apt_packages
if docker_settings.install_stack_requirements:
apt_packages += stack.apt_packages
if apt_packages:
logger.info(
"Including apt packages: %s",
", ".join(f"`{p}`" for p in apt_packages),
)
dockerfile = self._generate_zenml_pipeline_dockerfile(
parent_image=parent_image,
docker_settings=docker_settings,
requirements_files=requirements_file_names,
apt_packages=apt_packages,
entrypoint=entrypoint,
)
if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
# The default parent image is static and doesn't require a pull
# each time
pull_parent_image = False
else:
# If the image is local, we don't need to pull it. Otherwise
# we play it safe and always pull in case the user pushed a new
# image for the given name and tag
pull_parent_image = not docker_utils.is_local_image(
parent_image
)
extra_files = requirement_files.copy()
extra_files.append(
(DOCKER_IMAGE_DEPLOYMENT_CONFIG_FILE, deployment.yaml())
)
# Leave the build context empty if we don't want to copy any files
requires_build_context = (
docker_settings.copy_files or docker_settings.copy_global_config
)
build_context_root = (
source_utils.get_source_root_path()
if requires_build_context
else None
)
maybe_include_global_config = (
_include_global_config(build_context_root=build_context_root) # type: ignore[arg-type]
if docker_settings.copy_global_config
else contextlib.nullcontext()
)
with maybe_include_global_config:
docker_utils.build_image(
image_name=target_image_name,
dockerfile=dockerfile,
build_context_root=build_context_root,
dockerignore=docker_settings.dockerignore,
extra_files=extra_files,
pull=pull_parent_image,
)
get_target_image_name(deployment, container_registry=None)
staticmethod
Returns the target image name.
If a container registry is given, the image name will include the registry URI.
Parameters:
Name | Type | Description | Default
---|---|---|---
deployment | PipelineDeployment | The pipeline deployment for which the target image name should be returned. | required
container_registry | Optional[BaseContainerRegistry] | Optional container registry to which this image will be pushed. | None
Returns:
Type | Description
---|---
str | The Docker image name.
Source code in zenml/utils/pipeline_docker_image_builder.py
@staticmethod
def get_target_image_name(
deployment: "PipelineDeployment",
container_registry: Optional["BaseContainerRegistry"] = None,
) -> str:
"""Returns the target image name.
If a container registry is given, the image name will include the
registry URI
Args:
deployment: The pipeline deployment for which the target image name
should be returned.
container_registry: Optional container registry to which this
image will be pushed.
Returns:
The docker image name.
"""
pipeline_name = deployment.pipeline.name
docker_settings = (
deployment.pipeline.docker_settings or DockerSettings()
)
target_image_name = (
f"{docker_settings.target_repository}:{pipeline_name}"
)
if container_registry:
target_image_name = (
f"{container_registry.config.uri}/{target_image_name}"
)
return target_image_name
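Because this is a static method that only reads `pipeline.name` and `pipeline.docker_settings`, a duck-typed stand-in can illustrate the naming scheme. The deployment object below is a mock, not a real PipelineDeployment, and assumes `DockerSettings` is importable from `zenml.config`.
```python
from types import SimpleNamespace

from zenml.config import DockerSettings
from zenml.utils.pipeline_docker_image_builder import (
    PipelineDockerImageBuilder,
)

deployment = SimpleNamespace(
    pipeline=SimpleNamespace(name="my_pipeline", docker_settings=DockerSettings())
)
name = PipelineDockerImageBuilder.get_target_image_name(deployment)  # type: ignore[arg-type]
print(name)  # "<target_repository>:my_pipeline"; a registry URI is prepended if one is passed
```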
proto_utils
Utility functions for interacting with TFX contexts.
add_mlmd_contexts(pipeline_node, step, deployment, stack)
Adds context to each pipeline node of a pb2_pipeline.
Parameters:
Name | Type | Description | Default
---|---|---|---
pipeline_node | PipelineNode | The pipeline node to which the contexts should be added. | required
step | Step | The corresponding step for the pipeline node. | required
deployment | PipelineDeployment | The pipeline deployment to store in the contexts. | required
stack | Stack | The stack the pipeline will run on. | required
Source code in zenml/utils/proto_utils.py
def add_mlmd_contexts(
pipeline_node: pipeline_pb2.PipelineNode,
step: "Step",
deployment: "PipelineDeployment",
stack: "Stack",
) -> None:
"""Adds context to each pipeline node of a pb2_pipeline.
Args:
pipeline_node: The pipeline node to which the contexts should be
added.
step: The corresponding step for the pipeline node.
deployment: The pipeline deployment to store in the contexts.
stack: The stack the pipeline will run on.
"""
from zenml.client import Client
client = Client()
model_ids = json.dumps(
{
"user_id": client.active_user.id,
"project_id": client.active_project.id,
"pipeline_id": deployment.pipeline_id,
"stack_id": deployment.stack_id,
},
sort_keys=True,
default=pydantic_encoder,
)
stack_json = json.dumps(stack.dict(), sort_keys=True)
pipeline_config = deployment.pipeline.json(sort_keys=True)
step_config = step.json(sort_keys=True)
context_properties = {
MLMD_CONTEXT_STACK_PROPERTY_NAME: stack_json,
MLMD_CONTEXT_PIPELINE_CONFIG_PROPERTY_NAME: pipeline_config,
MLMD_CONTEXT_STEP_CONFIG_PROPERTY_NAME: step_config,
MLMD_CONTEXT_MODEL_IDS_PROPERTY_NAME: model_ids,
MLMD_CONTEXT_NUM_STEPS_PROPERTY_NAME: str(len(deployment.steps)),
MLMD_CONTEXT_NUM_OUTPUTS_PROPERTY_NAME: str(len(step.config.outputs)),
}
properties_json = json.dumps(context_properties, sort_keys=True)
context_name = hashlib.md5(properties_json.encode()).hexdigest()
add_pipeline_node_context(
pipeline_node,
type_=ZENML_MLMD_CONTEXT_TYPE,
name=context_name,
properties=context_properties,
)
add_pipeline_node_context(pipeline_node, type_, name, properties)
Adds a new context to a TFX protobuf pipeline node.
Parameters:
Name | Type | Description | Default
---|---|---|---
pipeline_node | PipelineNode | A tfx protobuf pipeline node | required
type_ | str | The type name for the context to be added | required
name | str | Unique key for the context | required
properties | Dict[str, str] | dictionary of strings as properties of the context | required
Source code in zenml/utils/proto_utils.py
def add_pipeline_node_context(
pipeline_node: pipeline_pb2.PipelineNode,
type_: str,
name: str,
properties: Dict[str, str],
) -> None:
"""Adds a new context to a TFX protobuf pipeline node.
Args:
pipeline_node: A tfx protobuf pipeline node
type_: The type name for the context to be added
name: Unique key for the context
properties: dictionary of strings as properties of the context
"""
context: pipeline_pb2.ContextSpec = pipeline_node.contexts.contexts.add()
context.type.name = type_
context.name.field_value.string_value = name
for key, value in properties.items():
c_property = context.properties[key]
c_property.field_value.string_value = value
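A minimal sketch attaching a custom context to a fresh pipeline node proto; the type name, key, and property below are hypothetical:
```python
from tfx.proto.orchestration import pipeline_pb2

from zenml.utils.proto_utils import add_pipeline_node_context

node = pipeline_pb2.PipelineNode()
add_pipeline_node_context(
    node,
    type_="my_context_type",  # hypothetical type name
    name="my_context_instance",  # hypothetical unique key
    properties={"owner": "team-a"},
)
print(node.contexts.contexts[0].type.name)  # my_context_type
```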
get_pipeline_config(pipeline_node)
Fetches the pipeline configuration from a PipelineNode context.
Parameters:
Name | Type | Description | Default
---|---|---|---
pipeline_node | PipelineNode | Pipeline node info for a step. | required
Returns:
Type | Description
---|---
PipelineConfiguration | The pipeline config.
Exceptions:
Type | Description
---|---
RuntimeError | If no pipeline config was found.
Source code in zenml/utils/proto_utils.py
def get_pipeline_config(
pipeline_node: pipeline_pb2.PipelineNode,
) -> "PipelineConfiguration":
"""Fetches the pipeline configuration from a PipelineNode context.
Args:
pipeline_node: Pipeline node info for a step.
Returns:
The pipeline config.
Raises:
RuntimeError: If no pipeline config was found.
"""
for context in pipeline_node.contexts.contexts:
if context.type.name == ZENML_MLMD_CONTEXT_TYPE:
config_json = context.properties[
MLMD_CONTEXT_PIPELINE_CONFIG_PROPERTY_NAME
].field_value.string_value
return PipelineConfiguration.parse_raw(config_json)
raise RuntimeError("Unable to find pipeline config.")
get_step(pipeline_node)
Fetches the step from a PipelineNode context.
Parameters:
Name | Type | Description | Default
---|---|---|---
pipeline_node | PipelineNode | Pipeline node info for a step. | required
Returns:
Type | Description
---|---
Step | The step.
Exceptions:
Type | Description
---|---
RuntimeError | If no step was found.
Source code in zenml/utils/proto_utils.py
def get_step(
pipeline_node: pipeline_pb2.PipelineNode,
) -> "Step":
"""Fetches the step from a PipelineNode context.
Args:
pipeline_node: Pipeline node info for a step.
Returns:
The step.
Raises:
RuntimeError: If no step was found.
"""
for context in pipeline_node.contexts.contexts:
if context.type.name == ZENML_MLMD_CONTEXT_TYPE:
config_json = context.properties[
MLMD_CONTEXT_STEP_CONFIG_PROPERTY_NAME
].field_value.string_value
return Step.parse_raw(config_json)
raise RuntimeError("Unable to find step.")
pydantic_utils
Utilities for pydantic models.
TemplateGenerator
Class to generate templates for pydantic models or classes.
Source code in zenml/utils/pydantic_utils.py
class TemplateGenerator:
"""Class to generate templates for pydantic models or classes."""
def __init__(
self, instance_or_class: Union[BaseModel, Type[BaseModel]]
) -> None:
"""Initializes the template generator.
Args:
instance_or_class: The pydantic model or model class for which to
generate a template.
"""
self.instance_or_class = instance_or_class
def run(self) -> Dict[str, Any]:
"""Generates the template.
Returns:
The template dictionary.
"""
if isinstance(self.instance_or_class, BaseModel):
template = self._generate_template_for_model(self.instance_or_class)
else:
template = self._generate_template_for_model_class(
self.instance_or_class
)
# Convert to json in an intermediate step so we can leverage Pydantic's
# encoder to support types like UUID and datetime
json_string = json.dumps(template, default=pydantic_encoder)
return cast(Dict[str, Any], json.loads(json_string))
def _generate_template_for_model(self, model: BaseModel) -> Dict[str, Any]:
"""Generates a template for a pydantic model.
Args:
model: The model for which to generate the template.
Returns:
The model template.
"""
template = self._generate_template_for_model_class(model.__class__)
for name in model.__fields_set__:
value = getattr(model, name)
template[name] = self._generate_template_for_value(value)
return template
def _generate_template_for_model_class(
self,
model_class: Type[BaseModel],
) -> Dict[str, Any]:
"""Generates a template for a pydantic model class.
Args:
model_class: The model class for which to generate the template.
Returns:
The model class template.
"""
template: Dict[str, Any] = {}
for name, field in model_class.__fields__.items():
if self._is_model_class(field.outer_type_):
template[name] = self._generate_template_for_model_class(
field.outer_type_
)
elif field.outer_type_ is Optional and self._is_model_class(
field.type_
):
template[name] = self._generate_template_for_model_class(
field.type_
)
else:
template[name] = field._type_display()
return template
def _generate_template_for_value(self, value: Any) -> Any:
"""Generates a template for an arbitrary value.
Args:
value: The value for which to generate the template.
Returns:
The value template.
"""
if isinstance(value, Dict):
return {
k: self._generate_template_for_value(v)
for k, v in value.items()
}
elif sequence_like(value):
return [self._generate_template_for_value(v) for v in value]
elif isinstance(value, BaseModel):
return self._generate_template_for_model(value)
else:
return value
@staticmethod
def _is_model_class(value: Any) -> bool:
"""Checks if the given value is a pydantic model class.
Args:
value: The value to check.
Returns:
If the value is a pydantic model class.
"""
return isinstance(value, type) and issubclass(value, BaseModel)
__init__(self, instance_or_class)
special
Initializes the template generator.
Parameters:
Name | Type | Description | Default
---|---|---|---
instance_or_class | Union[pydantic.main.BaseModel, Type[pydantic.main.BaseModel]] | The pydantic model or model class for which to generate a template. | required
Source code in zenml/utils/pydantic_utils.py
def __init__(
self, instance_or_class: Union[BaseModel, Type[BaseModel]]
) -> None:
"""Initializes the template generator.
Args:
instance_or_class: The pydantic model or model class for which to
generate a template.
"""
self.instance_or_class = instance_or_class
run(self)
Generates the template.
Returns:
Type | Description
---|---
Dict[str, Any] | The template dictionary.
Source code in zenml/utils/pydantic_utils.py
def run(self) -> Dict[str, Any]:
"""Generates the template.
Returns:
The template dictionary.
"""
if isinstance(self.instance_or_class, BaseModel):
template = self._generate_template_for_model(self.instance_or_class)
else:
template = self._generate_template_for_model_class(
self.instance_or_class
)
# Convert to json in an intermediate step so we can leverage Pydantic's
# encoder to support types like UUID and datetime
json_string = json.dumps(template, default=pydantic_encoder)
return cast(Dict[str, Any], json.loads(json_string))
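A small sketch with a hypothetical pydantic model; passing the class yields a template of type names, while passing an instance additionally fills in any explicitly set values:
```python
from typing import Optional

from pydantic import BaseModel

from zenml.utils.pydantic_utils import TemplateGenerator


class ServerConfig(BaseModel):
    host: str = "127.0.0.1"
    port: int = 8237
    token: Optional[str] = None


print(TemplateGenerator(ServerConfig).run())
# {'host': 'str', 'port': 'int', 'token': 'Optional[str]'}

print(TemplateGenerator(ServerConfig(host="0.0.0.0")).run())
# {'host': '0.0.0.0', 'port': 'int', 'token': 'Optional[str]'}
```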
update_model(original, update, recursive=True, exclude_none=True)
Updates a pydantic model.
Parameters:
Name | Type | Description | Default
---|---|---|---
original | ~M | The model to update. | required
update | Union[BaseModel, Dict[str, Any]] | The update values. | required
recursive | bool | If `True`, dictionary values will be updated recursively. | True
exclude_none | bool | If `True`, `None` values in the update dictionary will be removed. | True
Returns:
Type | Description
---|---
~M | The updated model.
Source code in zenml/utils/pydantic_utils.py
def update_model(
original: M,
update: Union["BaseModel", Dict[str, Any]],
recursive: bool = True,
exclude_none: bool = True,
) -> M:
"""Updates a pydantic model.
Args:
original: The model to update.
update: The update values.
recursive: If `True`, dictionary values will be updated recursively.
exclude_none: If `True`, `None` values in the update dictionary
will be removed.
Returns:
The updated model.
"""
if isinstance(update, Dict):
if exclude_none:
update_dict = dict_utils.remove_none_values(
update, recursive=recursive
)
else:
update_dict = update
else:
update_dict = update.dict(exclude_unset=True)
original_dict = original.dict(exclude_unset=True)
if recursive:
values = dict_utils.recursive_update(original_dict, update_dict)
else:
values = {**original_dict, **update_dict}
return original.__class__(**values)
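A short sketch showing the recursive merge with a hypothetical model:
```python
from typing import Dict

from pydantic import BaseModel

from zenml.utils.pydantic_utils import update_model


class Settings(BaseModel):
    name: str = "default"
    options: Dict[str, str] = {}


original = Settings(name="prod", options={"region": "eu", "tier": "gold"})
updated = update_model(original, {"options": {"tier": "silver"}})
print(updated.options)  # {'region': 'eu', 'tier': 'silver'}: nested keys merge recursively
```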
secret_utils
Utility functions for secrets and secret references.
SecretReference (tuple)
Class representing a secret reference.
Attributes:
Name | Type | Description
---|---|---
name | str | The secret name.
key | str | The secret key.
Source code in zenml/utils/secret_utils.py
class SecretReference(NamedTuple):
"""Class representing a secret reference.
Attributes:
name: The secret name.
key: The secret key.
"""
name: str
key: str
__getnewargs__(self)
special
Return self as a plain tuple. Used by copy and pickle.
Source code in zenml/utils/secret_utils.py
def __getnewargs__(self):
'Return self as a plain tuple. Used by copy and pickle.'
return _tuple(self)
__new__(_cls, name, key)
special
staticmethod
Create new instance of SecretReference(name, key)
__repr__(self)
special
Return a nicely formatted representation string
Source code in zenml/utils/secret_utils.py
def __repr__(self):
'Return a nicely formatted representation string'
return self.__class__.__name__ + repr_fmt % self
ClearTextField(*args, **kwargs)
Marks a pydantic field to prevent secret references.
Parameters:
Name | Type | Description | Default
---|---|---|---
*args | Any | Positional arguments which will be forwarded to `pydantic.Field(...)`. | ()
**kwargs | Any | Keyword arguments which will be forwarded to `pydantic.Field(...)`. | {}
Returns:
Type | Description
---|---
Any | Pydantic field info.
Source code in zenml/utils/secret_utils.py
def ClearTextField(*args: Any, **kwargs: Any) -> Any:
"""Marks a pydantic field to prevent secret references.
Args:
*args: Positional arguments which will be forwarded
to `pydantic.Field(...)`.
**kwargs: Keyword arguments which will be forwarded to
`pydantic.Field(...)`.
Returns:
Pydantic field info.
"""
kwargs[PYDANTIC_CLEAR_TEXT_FIELD_MARKER] = True
return Field(*args, **kwargs)
SecretField(*args, **kwargs)
Marks a pydantic field as something containing sensitive information.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | Any | Positional arguments which will be forwarded to `pydantic.Field(...)`. | () |
**kwargs | Any | Keyword arguments which will be forwarded to `pydantic.Field(...)`. | {} |
Returns:
Type | Description |
---|---|
Any | Pydantic field info. |
Source code in zenml/utils/secret_utils.py
def SecretField(*args: Any, **kwargs: Any) -> Any:
"""Marks a pydantic field as something containing sensitive information.
Args:
*args: Positional arguments which will be forwarded
to `pydantic.Field(...)`.
**kwargs: Keyword arguments which will be forwarded to
`pydantic.Field(...)`.
Returns:
Pydantic field info.
"""
kwargs[PYDANTIC_SENSITIVE_FIELD_MARKER] = True
return Field(*args, **kwargs)
is_clear_text_field(field)
Returns whether a pydantic field prevents secret references or not.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
field | ModelField | The field to check. | required |
Returns:
Type | Description |
---|---|
bool | `True` if the field prevents secret references, `False` otherwise. |
Source code in zenml/utils/secret_utils.py
def is_clear_text_field(field: "ModelField") -> bool:
"""Returns whether a pydantic field prevents secret references or not.
Args:
field: The field to check.
Returns:
`True` if the field prevents secret references, `False` otherwise.
"""
return field.field_info.extra.get(PYDANTIC_CLEAR_TEXT_FIELD_MARKER, False)
is_secret_field(field)
Returns whether a pydantic field contains sensitive information or not.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
field | ModelField | The field to check. | required |
Returns:
Type | Description |
---|---|
bool | `True` if the field contains sensitive information, `False` otherwise. |
Source code in zenml/utils/secret_utils.py
def is_secret_field(field: "ModelField") -> bool:
"""Returns whether a pydantic field contains sensitive information or not.
Args:
field: The field to check.
Returns:
`True` if the field contains sensitive information, `False` otherwise.
"""
return field.field_info.extra.get(PYDANTIC_SENSITIVE_FIELD_MARKER, False)
is_secret_reference(value)
Checks whether any value is a secret reference.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | Any | The value to check. | required |
Returns:
Type | Description |
---|---|
bool | `True` if the value is a secret reference, `False` otherwise. |
Source code in zenml/utils/secret_utils.py
def is_secret_reference(value: Any) -> bool:
"""Checks whether any value is a secret reference.
Args:
value: The value to check.
Returns:
`True` if the value is a secret reference, `False` otherwise.
"""
if not isinstance(value, str):
return False
return bool(_secret_reference_expression.fullmatch(value))
parse_secret_reference(reference)
Parses a secret reference.
This function assumes the input string is a valid secret reference and does not perform any additional checks. If you pass an invalid secret reference here, this will most likely crash.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
reference | str | The string representing a valid secret reference. | required |
Returns:
Type | Description |
---|---|
SecretReference | The parsed secret reference. |
Source code in zenml/utils/secret_utils.py
def parse_secret_reference(reference: str) -> SecretReference:
"""Parses a secret reference.
This function assumes the input string is a valid secret reference and
**does not** perform any additional checks. If you pass an invalid secret
reference here, this will most likely crash.
Args:
reference: The string representing a **valid** secret reference.
Returns:
The parsed secret reference.
"""
reference = reference[2:]
reference = reference[:-2]
secret_name, secret_key = reference.split(".", 1)
return SecretReference(name=secret_name, key=secret_key)
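A short sketch of how these helpers fit together, assuming the `{{<name>.<key>}}` reference format implied by the two-character stripping above (the secret name and key are hypothetical; the exact accepted syntax is defined by `_secret_reference_expression`):
```python
from zenml.utils.secret_utils import is_secret_reference, parse_secret_reference

value = "{{my_secret.password}}"

if is_secret_reference(value):
    ref = parse_secret_reference(value)
    assert ref.name == "my_secret"
    assert ref.key == "password"
```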
settings_utils
Utility functions for ZenML settings.
get_flavor_setting_key(flavor)
Gets the setting key for a flavor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
flavor | Flavor | The flavor for which to get the key. | required |
Returns:
Type | Description |
---|---|
str | The setting key for the flavor. |
Source code in zenml/utils/settings_utils.py
def get_flavor_setting_key(flavor: "Flavor") -> str:
"""Gets the setting key for a flavor.
Args:
flavor: The flavor for which to get the key.
Returns:
The setting key for the flavor.
"""
return f"{flavor.type}.{flavor.name}"
get_general_settings()
Returns all general settings.
Returns:
Type | Description |
---|---|
Dict[str, Type[BaseSettings]] | Dictionary mapping general settings keys to their type. |
Source code in zenml/utils/settings_utils.py
def get_general_settings() -> Dict[str, Type["BaseSettings"]]:
"""Returns all general settings.
Returns:
Dictionary mapping general settings keys to their type.
"""
from zenml.config import DockerSettings, ResourceSettings
return {
DOCKER_SETTINGS_KEY: DockerSettings,
RESOURCE_SETTINGS_KEY: ResourceSettings,
}
get_stack_component_for_settings_key(key, stack)
Gets the stack component of a stack for a given settings key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | str | The settings key for which to get the component. | required |
stack | Stack | The stack from which to get the component. | required |
Exceptions:
Type | Description |
---|---|
ValueError | If the key is invalid or the stack does not contain a component of the correct flavor. |
Returns:
Type | Description |
---|---|
StackComponent | The stack component. |
Source code in zenml/utils/settings_utils.py
def get_stack_component_for_settings_key(
key: str, stack: "Stack"
) -> "StackComponent":
"""Gets the stack component of a stack for a given settings key.
Args:
key: The settings key for which to get the component.
stack: The stack from which to get the component.
Raises:
ValueError: If the key is invalid or the stack does not contain a
component of the correct flavor.
Returns:
The stack component.
"""
if not is_stack_component_setting_key(key):
raise ValueError(
f"Settings key {key} does not refer to a stack component."
)
component_type, flavor = key.split(".", 1)
stack_component = stack.components.get(StackComponentType(component_type))
if not stack_component or stack_component.flavor != flavor:
raise ValueError(
f"Component of type {component_type} in stack {stack} is not "
f"of the flavor {flavor} specified by the settings key {key}."
)
return stack_component
get_stack_component_setting_key(stack_component)
Gets the setting key for a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stack_component | StackComponent | The stack component for which to get the key. | required |
Returns:
Type | Description |
---|---|
str | The setting key for the stack component. |
Source code in zenml/utils/settings_utils.py
def get_stack_component_setting_key(stack_component: "StackComponent") -> str:
"""Gets the setting key for a stack component.
Args:
stack_component: The stack component for which to get the key.
Returns:
The setting key for the stack component.
"""
return f"{stack_component.type}.{stack_component.flavor}"
is_general_setting_key(key)
Checks whether the key refers to a general setting.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | str | The key to check. | required |
Returns:
Type | Description |
---|---|
bool | If the key refers to a general setting. |
Source code in zenml/utils/settings_utils.py
def is_general_setting_key(key: str) -> bool:
"""Checks whether the key refers to a general setting.
Args:
key: The key to check.
Returns:
If the key refers to a general setting.
"""
return key in get_general_settings()
is_stack_component_setting_key(key)
Checks whether a settings key refers to a stack component.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | str | The key to check. | required |
Returns:
Type | Description |
---|---|
bool | If the key refers to a stack component. |
Source code in zenml/utils/settings_utils.py
def is_stack_component_setting_key(key: str) -> bool:
"""Checks whether a settings key refers to a stack component.
Args:
key: The key to check.
Returns:
If the key refers to a stack component.
"""
return bool(STACK_COMPONENT_REGEX.fullmatch(key))
is_valid_setting_key(key)
Checks whether a settings key is valid.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | str | The key to check. | required |
Returns:
Type | Description |
---|---|
bool | If the key is valid. |
Source code in zenml/utils/settings_utils.py
def is_valid_setting_key(key: str) -> bool:
"""Checks whether a settings key is valid.
Args:
key: The key to check.
Returns:
If the key is valid.
"""
return is_general_setting_key(key) or is_stack_component_setting_key(key)
validate_setting_keys(setting_keys)
Validates settings keys.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
setting_keys | Sequence[str] | The keys to validate. | required |
Exceptions:
Type | Description |
---|---|
ValueError | If any key is invalid. |
Source code in zenml/utils/settings_utils.py
def validate_setting_keys(setting_keys: Sequence[str]) -> None:
"""Validates settings keys.
Args:
setting_keys: The keys to validate.
Raises:
ValueError: If any key is invalid.
"""
for key in setting_keys:
if not is_valid_setting_key(key):
raise ValueError(
f"Invalid setting key `{key}`. Setting keys can either refer "
"to general settings (available keys: "
f"{set(get_general_settings())}) or stack component specific "
"settings. Stack component specific keys are of the format "
"`<STACK_COMPONENT_TYPE>.<STACK_COMPONENT_FLAVOR>`."
)
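A quick sketch of the two key families; `orchestrator.kubeflow` is a hypothetical example of the `<STACK_COMPONENT_TYPE>.<STACK_COMPONENT_FLAVOR>` format named in the error message above:
```python
from zenml.utils.settings_utils import (
    get_general_settings,
    is_stack_component_setting_key,
    validate_setting_keys,
)

# Stack component keys follow "<STACK_COMPONENT_TYPE>.<STACK_COMPONENT_FLAVOR>".
assert is_stack_component_setting_key("orchestrator.kubeflow")

# General settings keys are exactly the keys of this dictionary
# (Docker and resource settings).
print(set(get_general_settings()))

# Passes silently for valid keys; raises a ValueError otherwise.
validate_setting_keys(["orchestrator.kubeflow"])
```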
singleton
Utility class to turn classes into singleton classes.
SingletonMetaClass (type)
Singleton metaclass.
Use this metaclass to make any class into a singleton class:
class OneRing(metaclass=SingletonMetaClass):
    def __init__(self, owner):
        self._owner = owner

    @property
    def owner(self):
        return self._owner

the_one_ring = OneRing('Sauron')
the_lost_ring = OneRing('Frodo')
print(the_lost_ring.owner)  # Sauron
OneRing._clear()  # ring destroyed
Source code in zenml/utils/singleton.py
class SingletonMetaClass(type):
"""Singleton metaclass.
Use this metaclass to make any class into a singleton class:
```python
class OneRing(metaclass=SingletonMetaClass):
def __init__(self, owner):
self._owner = owner
@property
def owner(self):
return self._owner
the_one_ring = OneRing('Sauron')
the_lost_ring = OneRing('Frodo')
print(the_lost_ring.owner) # Sauron
OneRing._clear() # ring destroyed
```
"""
def __init__(cls, *args: Any, **kwargs: Any) -> None:
"""Initialize a singleton class.
Args:
*args: Additional arguments.
**kwargs: Additional keyword arguments.
"""
super().__init__(*args, **kwargs)
cls.__singleton_instance: Optional["SingletonMetaClass"] = None
def __call__(cls, *args: Any, **kwargs: Any) -> "SingletonMetaClass":
"""Create or return the singleton instance.
Args:
*args: Additional arguments.
**kwargs: Additional keyword arguments.
Returns:
The singleton instance.
"""
if not cls.__singleton_instance:
cls.__singleton_instance = cast(
"SingletonMetaClass", super().__call__(*args, **kwargs)
)
return cls.__singleton_instance
def _clear(cls) -> None:
"""Clear the singleton instance."""
cls.__singleton_instance = None
__call__(cls, *args, **kwargs)
special
Create or return the singleton instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | Any | Additional arguments. | () |
**kwargs | Any | Additional keyword arguments. | {} |
Returns:
Type | Description |
---|---|
SingletonMetaClass | The singleton instance. |
Source code in zenml/utils/singleton.py
def __call__(cls, *args: Any, **kwargs: Any) -> "SingletonMetaClass":
"""Create or return the singleton instance.
Args:
*args: Additional arguments.
**kwargs: Additional keyword arguments.
Returns:
The singleton instance.
"""
if not cls.__singleton_instance:
cls.__singleton_instance = cast(
"SingletonMetaClass", super().__call__(*args, **kwargs)
)
return cls.__singleton_instance
__init__(cls, *args, **kwargs)
special
Initialize a singleton class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*args | Any | Additional arguments. | () |
**kwargs | Any | Additional keyword arguments. | {} |
Source code in zenml/utils/singleton.py
def __init__(cls, *args: Any, **kwargs: Any) -> None:
"""Initialize a singleton class.
Args:
*args: Additional arguments.
**kwargs: Additional keyword arguments.
"""
super().__init__(*args, **kwargs)
cls.__singleton_instance: Optional["SingletonMetaClass"] = None
source_utils
Utility functions for source code.
These utils are predicated on the following definitions:
- class_source: This is a python-import type path to a class, e.g. some.mod.class
- module_source: This is a python-import type path to a module, e.g. some.mod
- file_path, relative_path, absolute_path: These are file system paths.
- source: This is a class_source or module_source. If it is a class_source, it can also be optionally pinned.
- pin: Whatever comes after the `@` symbol from a source, usually the git sha or the version of zenml as a string.
create_zenml_pin()
Creates a ZenML pin for source pinning from release version.
Returns:
Type | Description |
---|---|
str | ZenML pin. |
Source code in zenml/utils/source_utils.py
def create_zenml_pin() -> str:
"""Creates a ZenML pin for source pinning from release version.
Returns:
ZenML pin.
"""
return f"{constants.APP_NAME}_{__version__}"
get_hashed_source(value)
Returns a hash of the object's source code.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | Any | object to get source from. | required |
Returns:
Type | Description |
---|---|
str | Hash of source code. |
Exceptions:
Type | Description |
---|---|
TypeError | If unable to compute the hash. |
Source code in zenml/utils/source_utils.py
def get_hashed_source(value: Any) -> str:
"""Returns a hash of the objects source code.
Args:
value: object to get source from.
Returns:
Hash of source code.
Raises:
TypeError: If unable to compute the hash.
"""
try:
source_code = get_source(value)
except TypeError:
raise TypeError(
f"Unable to compute the hash of source code of object: {value}."
)
return hashlib.sha256(source_code.encode("utf-8")).hexdigest()
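For illustration, a hypothetical step function hashed from a regular `.py` file (in a bare REPL, `inspect` cannot fetch the source and a `TypeError` is raised):
```python
from zenml.utils.source_utils import get_hashed_source


def my_step() -> int:
    return 42


# The digest covers the function's source text, so any edit to the body
# (even whitespace) changes the hash.
digest = get_hashed_source(my_step)
assert len(digest) == 64  # hex-encoded SHA-256
```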
get_main_module_source()
Gets the source of the main module.
Returns:
Type | Description |
---|---|
str | The main module source. |
Source code in zenml/utils/source_utils.py
def get_main_module_source() -> str:
"""Gets the source of the main module.
Returns:
The main module source.
"""
main_module = sys.modules["__main__"]
return get_module_source_from_module(main_module)
get_module_source_from_module(module)
Gets the source of the supplied module.
E.g.:
- a `/home/myrepo/src/run.py` module running as the main module returns `run` if no repository root is specified.
- a `/home/myrepo/src/run.py` module running as the main module returns `src.run` if the repository root is configured in `/home/myrepo`.
- a `/home/myrepo/src/pipeline.py` module not running as the main module returns `src.pipeline` if the repository root is configured in `/home/myrepo`.
- a `/home/myrepo/src/pipeline.py` module not running as the main module returns `pipeline` if no repository root is specified and the main module is also in `/home/myrepo/src`.
- a `/home/step.py` module not running as the main module returns `step` if the CWD is `/home` and the repository root or the main module are in a different path (e.g. `/home/myrepo/src`).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module | module | the module to get the source of. | required |
Returns:
Type | Description |
---|---|
str | The source of the supplied module. |
Exceptions:
Type | Description |
---|---|
RuntimeError | if the module is not loaded from a file |
Source code in zenml/utils/source_utils.py
def get_module_source_from_module(module: ModuleType) -> str:
"""Gets the source of the supplied module.
E.g.:
* a `/home/myrepo/src/run.py` module running as the main module returns
`run` if no repository root is specified.
* a `/home/myrepo/src/run.py` module running as the main module returns
`src.run` if the repository root is configured in `/home/myrepo`
* a `/home/myrepo/src/pipeline.py` module not running as the main module
returns `src.pipeline` if the repository root is configured in
`/home/myrepo`
* a `/home/myrepo/src/pipeline.py` module not running as the main module
returns `pipeline` if no repository root is specified and the main
module is also in `/home/myrepo/src`.
* a `/home/step.py` module not running as the main module
returns `step` if the CWD is /home and the repository root or the main
module are in a different path (e.g. `/home/myrepo/src`).
Args:
module: the module to get the source of.
Returns:
        The source of the supplied module.
Raises:
RuntimeError: if the module is not loaded from a file
"""
if not hasattr(module, "__file__") or not module.__file__:
if module.__name__ == "__main__":
raise RuntimeError(
f"{module} module was not loaded from a file. Cannot "
"determine the module root path."
)
return module.__name__
module_path = os.path.abspath(module.__file__)
root_path = get_source_root_path()
if not module_path.startswith(root_path):
logger.warning(
"User module %s is not in the source root %s. Using current "
"directory %s instead to resolve module source.",
module,
root_path,
os.getcwd(),
)
root_path = os.getcwd()
root_path = os.path.abspath(root_path)
# Remove root_path from module_path to get relative path left over
module_path = os.path.relpath(module_path, root_path)
if module_path.startswith(os.pardir):
raise RuntimeError(
f"Unable to resolve source for module {module}. The module file "
f"'{module_path}' does not seem to be inside the source root "
f"'{root_path}'."
)
# Remove the file extension and replace the os specific path separators
# with `.` to get the module source
module_path, file_extension = os.path.splitext(module_path)
if file_extension != ".py":
raise RuntimeError(
f"Unable to resolve source for module {module}. The module file "
f"'{module_path}' does not seem to be a python file."
)
module_source = module_path.replace(os.path.sep, ".")
logger.debug(
f"Resolved module source for module {module} to: `{module_source}`"
)
return module_source
get_source(value)
Returns the source code of an object.
If executing within an IPython kernel environment, this temporarily monkey-patches the `inspect` module with a workaround to get the source from the notebook cell.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | Any | object to get source from. | required |
Returns:
Type | Description |
---|---|
str | Source code of object. |
Source code in zenml/utils/source_utils.py
def get_source(value: Any) -> str:
"""Returns the source code of an object.
    If executing within an IPython kernel environment, then this monkey-patches
`inspect` module temporarily with a workaround to get source from the cell.
Args:
value: object to get source from.
Returns:
Source code of object.
"""
if Environment.in_notebook():
# Monkey patch inspect.getfile temporarily to make getsource work.
# Source: https://stackoverflow.com/questions/51566497/
def _new_getfile(
object: Any,
_old_getfile: Callable[
[
Union[
ModuleType,
Type[Any],
MethodType,
FunctionType,
TracebackType,
FrameType,
CodeType,
Callable[..., Any],
]
],
str,
] = inspect.getfile,
) -> Any:
if not inspect.isclass(object):
return _old_getfile(object)
# Lookup by parent module (as in current inspect)
if hasattr(object, "__module__"):
object_ = sys.modules.get(object.__module__)
if hasattr(object_, "__file__"):
return object_.__file__ # type: ignore[union-attr]
# If parent module is __main__, lookup by methods
for name, member in inspect.getmembers(object):
if (
inspect.isfunction(member)
and object.__qualname__ + "." + member.__name__
== member.__qualname__
):
return inspect.getfile(member)
else:
raise TypeError(f"Source for {object!r} not found.")
# Monkey patch, compute source, then revert monkey patch.
_old_getfile = inspect.getfile
inspect.getfile = _new_getfile
try:
src = inspect.getsource(value)
finally:
inspect.getfile = _old_getfile
else:
# Use standard inspect if running outside a notebook
src = inspect.getsource(value)
return src
get_source_root_path()
Gets repository root path or the source root path of the current process.
E.g.:
- if the process was started by running a `run.py` file under `full/path/to/my/run.py`, and the repository root is configured at `full/path`, the source root path is `full/path`.
- same case as above, but when there is no repository root configured, the source root path is `full/path/to/my`.
Returns:
Type | Description |
---|---|
str | The source root path of the current process. |
Exceptions:
Type | Description |
---|---|
RuntimeError | if the main module was not started or determined. |
Source code in zenml/utils/source_utils.py
def get_source_root_path() -> str:
"""Gets repository root path or the source root path of the current process.
E.g.:
* if the process was started by running a `run.py` file under
`full/path/to/my/run.py`, and the repository root is configured at
`full/path`, the source root path is `full/path`.
* same case as above, but when there is no repository root configured,
the source root path is `full/path/to/my`.
Returns:
The source root path of the current process.
Raises:
RuntimeError: if the main module was not started or determined.
"""
from zenml.client import Client
repo_root = Client.find_repository()
if repo_root:
logger.debug("Using repository root as source root: %s", repo_root)
return str(repo_root.resolve())
main_module = sys.modules.get("__main__")
if main_module is None:
raise RuntimeError(
"Could not determine the main module used to run the current "
"process."
)
if not hasattr(main_module, "__file__") or not main_module.__file__:
raise RuntimeError(
"Main module was not started from a file. Cannot "
"determine the module root path."
)
path = pathlib.Path(main_module.__file__).resolve().parent
logger.debug("Using main module location as source root: %s", path)
return str(path)
import_class_by_path(class_path)
Imports a class based on a given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
class_path | str | str, class_source e.g. this.module.Class | required |
Returns:
Type | Description |
---|---|
Type[Any] | the given class |
Source code in zenml/utils/source_utils.py
def import_class_by_path(class_path: str) -> Type[Any]:
"""Imports a class based on a given path.
Args:
class_path: str, class_source e.g. this.module.Class
Returns:
the given class
"""
module_name, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name) # type: ignore[no-any-return]
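This works for any class importable on the current `PYTHONPATH`, e.g. a stdlib class:
```python
from zenml.utils.source_utils import import_class_by_path

ordered_dict_cls = import_class_by_path("collections.OrderedDict")
assert ordered_dict_cls.__name__ == "OrderedDict"
```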
import_python_file(file_path, zen_root)
Imports a python file in relationship to the zen root.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path | str | Path to python file that should be imported. | required |
zen_root | str | Path to current zenml root | required |
Returns:
Type | Description |
---|---|
Module | The imported module. |
Source code in zenml/utils/source_utils.py
def import_python_file(file_path: str, zen_root: str) -> types.ModuleType:
"""Imports a python file in relationship to the zen root.
Args:
file_path: Path to python file that should be imported.
zen_root: Path to current zenml root
Returns:
imported module: Module
"""
file_path = os.path.abspath(file_path)
module_path = os.path.relpath(file_path, zen_root)
module_name = os.path.splitext(module_path)[0].replace(os.path.sep, ".")
    if module_name in sys.modules:
        # Remove the cached module so it is freshly imported below
        del sys.modules[module_name]
    # Add directory of python file to PYTHONPATH so we can import it
    with prepend_python_path([zen_root]):
        return importlib.import_module(module_name)
is_inside_repository(file_path)
Returns whether a file is inside a zenml repository.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path | str | A file path. | required |
Returns:
Type | Description |
---|---|
bool | `True` if the file is inside a zenml repository, else `False`. |
Source code in zenml/utils/source_utils.py
def is_inside_repository(file_path: str) -> bool:
"""Returns whether a file is inside a zenml repository.
Args:
file_path: A file path.
Returns:
`True` if the file is inside a zenml repository, else `False`.
"""
from zenml.client import Client
repo_path = Client.find_repository()
if not repo_path:
return False
repo_path = repo_path.resolve()
absolute_file_path = pathlib.Path(file_path).resolve()
return repo_path in absolute_file_path.parents
is_standard_pin(pin)
Returns `True` if pin is a valid ZenML pin, else `False`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pin | str | potential ZenML pin like 'zenml_0.1.1' | required |
Returns:
Type | Description |
---|---|
bool | `True` if pin is valid ZenML pin, else `False`. |
Source code in zenml/utils/source_utils.py
def is_standard_pin(pin: str) -> bool:
"""Returns `True` if pin is valid ZenML pin, else False.
Args:
pin: potential ZenML pin like 'zenml_0.1.1'
Returns:
`True` if pin is valid ZenML pin, else False.
"""
if pin.startswith(f"{constants.APP_NAME}_"):
return True
return False
is_standard_source(source)
Returns `True` if source is a standard ZenML source, else `False`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source | str | class_source e.g. this.module.Class[@pin]. | required |
Returns:
Type | Description |
---|---|
bool | `True` if source is a standard ZenML source, else `False`. |
Source code in zenml/utils/source_utils.py
def is_standard_source(source: str) -> bool:
"""Returns `True` if source is a standard ZenML source.
Args:
source: class_source e.g. this.module.Class[@pin].
Returns:
`True` if source is a standard ZenML source, else `False`.
"""
if source.split(".")[0] == "zenml":
return True
return False
is_third_party_module(file_path)
Returns whether a file belongs to a third party package.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path | str | A file path. | required |
Returns:
Type | Description |
---|---|
bool | `True` if the file belongs to a third party package, else `False`. |
Source code in zenml/utils/source_utils.py
def is_third_party_module(file_path: str) -> bool:
"""Returns whether a file belongs to a third party package.
Args:
file_path: A file path.
Returns:
`True` if the file belongs to a third party package, else `False`.
"""
absolute_file_path = pathlib.Path(file_path).resolve()
for path in site.getsitepackages() + [
site.getusersitepackages(),
get_python_lib(standard_lib=True),
]:
if pathlib.Path(path).resolve() in absolute_file_path.parents:
return True
return (
pathlib.Path(get_source_root_path()) not in absolute_file_path.parents
)
load_and_validate_class(source, expected_class)
Loads a source class and validates its type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source | str | The source string. | required |
expected_class | Type[Any] | The class that the source should resolve to. | required |
Exceptions:
Type | Description |
---|---|
TypeError | If the source does not resolve to the expected type. |
Returns:
Type | Description |
---|---|
Type[Any] | The resolved source class. |
Source code in zenml/utils/source_utils.py
def load_and_validate_class(
source: str, expected_class: Type[Any]
) -> Type[Any]:
"""Loads a source class and validates its type.
Args:
source: The source string.
expected_class: The class that the source should resolve to.
Raises:
TypeError: If the source does not resolve to the expected type.
Returns:
The resolved source class.
"""
class_ = load_source_path_class(source)
if isinstance(class_, type) and issubclass(class_, expected_class):
return class_
else:
raise TypeError(
f"Error while loading `{source}`. Expected class "
f"{expected_class.__name__}, got {class_} instead."
)
load_source_path_class(source, import_path=None)
Loads a Python class from the source.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source | str | class_source e.g. this.module.Class[@sha] | required |
import_path | Optional[str] | optional path to add to python path | None |
Returns:
Type | Description |
---|---|
Type[Any] | the given class |
Source code in zenml/utils/source_utils.py
def load_source_path_class(
source: str, import_path: Optional[str] = None
) -> Type[Any]:
"""Loads a Python class from the source.
Args:
source: class_source e.g. this.module.Class[@sha]
import_path: optional path to add to python path
Returns:
the given class
"""
from zenml.client import Client
repo_root = Client.find_repository()
if not import_path and repo_root:
import_path = str(repo_root)
if "@" in source:
source = source.split("@")[0]
if import_path is not None:
with prepend_python_path([import_path]):
logger.debug(
f"Loading class {source} with import path {import_path}"
)
return import_class_by_path(source)
return import_class_by_path(source)
prepend_python_path(paths)
Simple context manager to help import module within the repo.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
paths | List[str] | paths to prepend to sys.path | required |
Yields:
Type | Description |
---|---|
Iterator[NoneType] | None |
Source code in zenml/utils/source_utils.py
@contextmanager
def prepend_python_path(paths: List[str]) -> Iterator[None]:
"""Simple context manager to help import module within the repo.
Args:
paths: paths to prepend to sys.path
Yields:
None
"""
try:
# Entering the with statement
for path in paths:
sys.path.insert(0, path)
yield
finally:
# Exiting the with statement
for path in paths:
sys.path.remove(path)
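A usage sketch; the repository path below is a placeholder:
```python
from zenml.utils.source_utils import prepend_python_path

# Inside the `with` block, modules under the given paths are importable;
# the paths are removed from `sys.path` again on exit, even on error.
with prepend_python_path(["/path/to/my/repo"]):
    ...  # `import` statements here can resolve modules under the repo root
```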
resolve_class(class_, replace_main_module=True)
Resolves a class into a serializable source string.
For classes that are neither built-in nor imported from a Python package, the `get_source_root_path` function is used to determine the root path relative to which the class source is resolved.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
class_ | Type[Any] | A Python Class reference. | required |
replace_main_module | bool | If `True`, classes in the main module will have the `__main__` module source replaced with the source relative to the ZenML source root. | True |
Returns:
Type | Description |
---|---|
str | source_path e.g. this.module.Class. |
Source code in zenml/utils/source_utils.py
def resolve_class(class_: Type[Any], replace_main_module: bool = True) -> str:
"""Resolves a class into a serializable source string.
For classes that are not built-in nor imported from a Python package, the
`get_source_root_path` function is used to determine the root path
relative to which the class source is resolved.
Args:
class_: A Python Class reference.
replace_main_module: If `True`, classes in the main module will have
the __main__ module source replaced with the source relative to
the ZenML source root.
Returns:
source_path e.g. this.module.Class.
"""
initial_source = class_.__module__ + "." + class_.__name__
if is_standard_source(initial_source):
return resolve_standard_source(initial_source)
try:
file_path = inspect.getfile(class_)
except TypeError:
# builtin file
return initial_source
if initial_source.startswith("__main__"):
if not replace_main_module:
return initial_source
# Resolve the __main__ module to something relative to the ZenML source
# root
return f"{get_main_module_source()}.{class_.__name__}"
if is_third_party_module(file_path):
return initial_source
# Regular user file -> get the full module path relative to the
# source root.
module_source = get_module_source_from_module(
sys.modules[class_.__module__]
)
source = module_source + "." + class_.__name__
logger.debug(f"Resolved class {class_} to `{source}`.")
return source
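For example, resolving a class from the `zenml` package itself yields a pinned standard source (see `resolve_standard_source` below):
```python
from zenml.utils.secret_utils import SecretReference
from zenml.utils.source_utils import resolve_class

source = resolve_class(SecretReference)
# e.g. "zenml.utils.secret_utils.SecretReference@zenml_<version>"
assert source.startswith("zenml.utils.secret_utils.SecretReference@")
```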
resolve_standard_source(source)
Pins a standard source with the current ZenML release version.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source | str | class_source e.g. this.module.Class. | required |
Returns:
Type | Description |
---|---|
str | The pinned source. |
Exceptions:
Type | Description |
---|---|
AssertionError | If source is already pinned. |
Source code in zenml/utils/source_utils.py
def resolve_standard_source(source: str) -> str:
"""Creates a ZenML pin for source pinning from release version.
Args:
source: class_source e.g. this.module.Class.
Returns:
ZenML pin.
Raises:
AssertionError: If source is already pinned.
"""
if "@" in source:
raise AssertionError(f"source {source} is already pinned.")
pin = create_zenml_pin()
return f"{source}@{pin}"
validate_config_source(source, component_type)
Validates a StackComponentConfig class from a given source.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source | str | source path of the implementation | required |
component_type | StackComponentType | the type of the stack component | required |
Returns:
Type | Description |
---|---|
Type[StackComponentConfig] | The validated config. |
Exceptions:
Type | Description |
---|---|
ValueError | If ZenML cannot import the config class. |
TypeError | If the config class is not a subclass of the `config_class`. |
Source code in zenml/utils/source_utils.py
def validate_config_source(
source: str, component_type: StackComponentType
) -> Type["StackComponentConfig"]:
"""Validates a StackComponentConfig class from a given source.
Args:
source: source path of the implementation
component_type: the type of the stack component
Returns:
The validated config.
Raises:
ValueError: If ZenML cannot import the config class.
TypeError: If the config class is not a subclass of the `config_class`.
"""
from zenml.stack.stack_component import StackComponentConfig
try:
config_class = load_source_path_class(source)
except (ValueError, AttributeError, ImportError) as e:
raise ValueError(
f"ZenML can not import the config class '{source}': {e}"
)
if not issubclass(config_class, StackComponentConfig):
raise TypeError(
f"The source path '{source}' does not point to a subclass of "
f"the ZenML config_class."
)
return config_class # noqa
validate_flavor_source(source, component_type)
Import a StackComponent class from a given source and validate its type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source | str | source path of the implementation | required |
component_type | StackComponentType | the type of the stack component | required |
Returns:
Type | Description |
---|---|
Type[Flavor] | the imported class |
Exceptions:
Type | Description |
---|---|
ValueError | If ZenML cannot find the given module path |
TypeError | If the given module path does not point to a subclass of a StackComponent which has the right component type. |
Source code in zenml/utils/source_utils.py
def validate_flavor_source(
source: str, component_type: StackComponentType
) -> Type["Flavor"]:
"""Import a StackComponent class from a given source and validate its type.
Args:
source: source path of the implementation
component_type: the type of the stack component
Returns:
the imported class
Raises:
ValueError: If ZenML cannot find the given module path
TypeError: If the given module path does not point to a subclass of a
StackComponent which has the right component type.
"""
from zenml.stack.flavor import Flavor
from zenml.stack.stack_component import StackComponent, StackComponentConfig
try:
flavor_class = load_source_path_class(source)
except (ValueError, AttributeError, ImportError) as e:
raise ValueError(
f"ZenML can not import the flavor class '{source}': {e}"
)
if not issubclass(flavor_class, Flavor):
raise TypeError(
f"The source '{source}' does not point to a subclass of the ZenML"
f"Flavor."
)
flavor = flavor_class()
try:
impl_class = flavor.implementation_class
except (ModuleNotFoundError, ImportError, NotImplementedError):
raise ValueError(
f"The implementation class defined within the "
f"'{flavor_class.__name__}' can not be imported."
)
if not issubclass(impl_class, StackComponent):
raise TypeError(
f"The implementation class '{impl_class.__name__}' of a flavor "
f"needs to be a subclass of the ZenML StackComponent."
)
if flavor.type != component_type: # noqa
raise TypeError(
f"The source points to a {impl_class.type}, not a " # noqa
f"{component_type}."
)
try:
conf_class = flavor.config_class
except (ModuleNotFoundError, ImportError, NotImplementedError):
raise ValueError(
f"The config class defined within the "
f"'{flavor_class.__name__}' can not be imported."
)
if not issubclass(conf_class, StackComponentConfig):
raise TypeError(
f"The config class '{conf_class.__name__}' of a flavor "
f"needs to be a subclass of the ZenML StackComponentConfig."
)
return flavor_class # noqa
validate_source_class(source, expected_class)
Validates that a source resolves to a certain type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
source | str | The source to validate. | required |
expected_class | Type[Any] | The class that the source should resolve to. | required |
Returns:
Type | Description |
---|---|
bool | If the source resolves to the expected class. |
Source code in zenml/utils/source_utils.py
def validate_source_class(source: str, expected_class: Type[Any]) -> bool:
"""Validates that a source resolves to a certain type.
Args:
source: The source to validate.
expected_class: The class that the source should resolve to.
Returns:
If the source resolves to the expected class.
"""
try:
value = load_source_path_class(source)
except Exception:
return False
is_class = isinstance(value, type)
if is_class and issubclass(value, expected_class):
return True
else:
return False
string_utils
Utils for strings.
b64_decode(input_)
Returns a decoded string of the base 64 encoded input string.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_ | str | Base64 encoded string. | required |
Returns:
Type | Description |
---|---|
str | Decoded string. |
Source code in zenml/utils/string_utils.py
def b64_decode(input_: str) -> str:
"""Returns a decoded string of the base 64 encoded input string.
Args:
input_: Base64 encoded string.
Returns:
Decoded string.
"""
encoded_bytes = input_.encode()
decoded_bytes = base64.b64decode(encoded_bytes)
return decoded_bytes.decode()
b64_encode(input_)
Returns a base 64 encoded string of the input string.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_ | str | The input to encode. | required |
Returns:
Type | Description |
---|---|
str | Base64 encoded string. |
Source code in zenml/utils/string_utils.py
def b64_encode(input_: str) -> str:
"""Returns a base 64 encoded string of the input string.
Args:
input_: The input to encode.
Returns:
Base64 encoded string.
"""
input_bytes = input_.encode()
encoded_bytes = base64.b64encode(input_bytes)
return encoded_bytes.decode()
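A round trip, for illustration:
```python
from zenml.utils.string_utils import b64_decode, b64_encode

encoded = b64_encode("supersecret")
assert encoded == "c3VwZXJzZWNyZXQ="
assert b64_decode(encoded) == "supersecret"
```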
get_human_readable_filesize(bytes_)
Convert a file size in bytes into a human-readable string.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
bytes_ | int | The number of bytes to convert. | required |
Returns:
Type | Description |
---|---|
str | A human-readable string. |
Source code in zenml/utils/string_utils.py
def get_human_readable_filesize(bytes_: int) -> str:
"""Convert a file size in bytes into a human-readable string.
Args:
bytes_: The number of bytes to convert.
Returns:
A human-readable string.
"""
size = abs(float(bytes_))
for unit in ["B", "KiB", "MiB", "GiB"]:
if size < 1024.0 or unit == "GiB":
break
size /= 1024.0
return f"{size:.2f} {unit}"
get_human_readable_time(seconds)
Convert seconds into a human-readable string.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
seconds | float | The number of seconds to convert. | required |
Returns:
Type | Description |
---|---|
str | A human-readable string. |
Source code in zenml/utils/string_utils.py
def get_human_readable_time(seconds: float) -> str:
"""Convert seconds into a human-readable string.
Args:
seconds: The number of seconds to convert.
Returns:
A human-readable string.
"""
prefix = "-" if seconds < 0 else ""
seconds = abs(seconds)
int_seconds = int(seconds)
days, int_seconds = divmod(int_seconds, 86400)
hours, int_seconds = divmod(int_seconds, 3600)
minutes, int_seconds = divmod(int_seconds, 60)
if days > 0:
time_string = f"{days}d{hours}h{minutes}m{int_seconds}s"
elif hours > 0:
time_string = f"{hours}h{minutes}m{int_seconds}s"
elif minutes > 0:
time_string = f"{minutes}m{int_seconds}s"
else:
time_string = f"{seconds:.3f}s"
return prefix + time_string
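A few illustrative values, following the conversion rules above:
```python
from zenml.utils.string_utils import (
    get_human_readable_filesize,
    get_human_readable_time,
)

assert get_human_readable_filesize(1536) == "1.50 KiB"
assert get_human_readable_time(3661) == "1h1m1s"
assert get_human_readable_time(90061) == "1d1h1m1s"
assert get_human_readable_time(0.5) == "0.500s"
```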
random_str(length)
Generate a random human-readable string of a given length.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
length | int | Length of string | required |
Returns:
Type | Description |
---|---|
str | Random human-readable string. |
Source code in zenml/utils/string_utils.py
def random_str(length: int) -> str:
"""Generate a random human readable string of given length.
Args:
length: Length of string
Returns:
Random human-readable string.
"""
random.seed()
return "".join(random.choices(string.ascii_letters, k=length))
typed_model
Utility classes for adding type information to Pydantic models.
BaseTypedModel (BaseModel)
pydantic-model
Typed Pydantic model base class.
Use this class as a base class instead of BaseModel to automatically add a `type` literal attribute to the model that stores the name of the class.
This can be useful when serializing models to JSON and then de-serializing them as part of a submodel union field, e.g.:
class BluePill(BaseTypedModel):
    ...

class RedPill(BaseTypedModel):
    ...

class TheMatrix(BaseTypedModel):
    choice: Union[BluePill, RedPill] = Field(..., discriminator='type')

matrix = TheMatrix(choice=RedPill())
d = matrix.dict()
new_matrix = TheMatrix.parse_obj(d)
assert isinstance(new_matrix.choice, RedPill)
It can also facilitate de-serializing objects when their type isn't known:
matrix = TheMatrix(choice=RedPill())
d = matrix.dict()
new_matrix = BaseTypedModel.from_dict(d)
assert isinstance(new_matrix.choice, RedPill)
Source code in zenml/utils/typed_model.py
class BaseTypedModel(BaseModel, metaclass=BaseTypedModelMeta):
"""Typed Pydantic model base class.
Use this class as a base class instead of BaseModel to automatically
add a `type` literal attribute to the model that stores the name of the
class.
This can be useful when serializing models to JSON and then de-serializing
them as part of a submodel union field, e.g.:
```python
class BluePill(BaseTypedModel):
...
class RedPill(BaseTypedModel):
...
class TheMatrix(BaseTypedModel):
choice: Union[BluePill, RedPill] = Field(..., discriminator='type')
matrix = TheMatrix(choice=RedPill())
d = matrix.dict()
new_matrix = TheMatrix.parse_obj(d)
assert isinstance(new_matrix.choice, RedPill)
```
It can also facilitate de-serializing objects when their type isn't known:
```python
matrix = TheMatrix(choice=RedPill())
d = matrix.dict()
new_matrix = BaseTypedModel.from_dict(d)
assert isinstance(new_matrix.choice, RedPill)
```
"""
@classmethod
def from_dict(
cls,
model_dict: Dict[str, Any],
) -> "BaseTypedModel":
"""Instantiate a Pydantic model from a serialized JSON-able dict representation.
Args:
model_dict: the model attributes serialized as JSON-able dict.
Returns:
A BaseTypedModel created from the serialized representation.
Raises:
RuntimeError: if the model_dict contains an invalid type.
"""
model_type = model_dict.get("type")
if not model_type:
raise RuntimeError(
"`type` information is missing from the serialized model dict."
)
cls = load_source_path_class(model_type)
if not issubclass(cls, BaseTypedModel):
raise RuntimeError(
f"Class `{cls}` is not a ZenML BaseTypedModel subclass."
)
return cls.parse_obj(model_dict)
@classmethod
def from_json(
cls,
json_str: str,
) -> "BaseTypedModel":
"""Instantiate a Pydantic model from a serialized JSON representation.
Args:
json_str: the model attributes serialized as JSON.
Returns:
A BaseTypedModel created from the serialized representation.
"""
model_dict = json.loads(json_str)
return cls.from_dict(model_dict)
from_dict(model_dict)
classmethod
Instantiate a Pydantic model from a serialized JSON-able dict representation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_dict | Dict[str, Any] | the model attributes serialized as JSON-able dict. | required |
Returns:
Type | Description |
---|---|
BaseTypedModel | A BaseTypedModel created from the serialized representation. |
Exceptions:
Type | Description |
---|---|
RuntimeError | if the model_dict contains an invalid type. |
Source code in zenml/utils/typed_model.py
@classmethod
def from_dict(
cls,
model_dict: Dict[str, Any],
) -> "BaseTypedModel":
"""Instantiate a Pydantic model from a serialized JSON-able dict representation.
Args:
model_dict: the model attributes serialized as JSON-able dict.
Returns:
A BaseTypedModel created from the serialized representation.
Raises:
RuntimeError: if the model_dict contains an invalid type.
"""
model_type = model_dict.get("type")
if not model_type:
raise RuntimeError(
"`type` information is missing from the serialized model dict."
)
cls = load_source_path_class(model_type)
if not issubclass(cls, BaseTypedModel):
raise RuntimeError(
f"Class `{cls}` is not a ZenML BaseTypedModel subclass."
)
return cls.parse_obj(model_dict)
from_json(json_str)
classmethod
Instantiate a Pydantic model from a serialized JSON representation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
json_str | str | the model attributes serialized as JSON. | required |
Returns:
Type | Description |
---|---|
BaseTypedModel | A BaseTypedModel created from the serialized representation. |
Source code in zenml/utils/typed_model.py
@classmethod
def from_json(
cls,
json_str: str,
) -> "BaseTypedModel":
"""Instantiate a Pydantic model from a serialized JSON representation.
Args:
json_str: the model attributes serialized as JSON.
Returns:
A BaseTypedModel created from the serialized representation.
"""
model_dict = json.loads(json_str)
return cls.from_dict(model_dict)
BaseTypedModelMeta (ModelMetaclass)
Metaclass responsible for adding type information to Pydantic models.
Source code in zenml/utils/typed_model.py
class BaseTypedModelMeta(ModelMetaclass):
"""Metaclass responsible for adding type information to Pydantic models."""
def __new__(
mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "BaseTypedModelMeta":
"""Creates a Pydantic BaseModel class.
This includes a hidden attribute that reflects the full class
identifier.
Args:
name: The name of the class.
bases: The base classes of the class.
dct: The class dictionary.
Returns:
A Pydantic BaseModel class that includes a hidden attribute that
reflects the full class identifier.
Raises:
TypeError: If the class is not a Pydantic BaseModel class.
"""
if "type" in dct:
raise TypeError(
"`type` is a reserved attribute name for BaseTypedModel "
"subclasses"
)
type_name = f"{dct['__module__']}.{dct['__qualname__']}"
type_ann = Literal[type_name] # type: ignore [misc,valid-type]
type = Field(type_name)
dct.setdefault("__annotations__", dict())["type"] = type_ann
dct["type"] = type
cls = cast(
Type["BaseTypedModel"], super().__new__(mcs, name, bases, dct)
)
return cls
__new__(mcs, name, bases, dct)
special
staticmethod
Creates a Pydantic BaseModel class.
This includes a hidden attribute that reflects the full class identifier.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the class. | required |
bases | Tuple[Type[Any], ...] | The base classes of the class. | required |
dct | Dict[str, Any] | The class dictionary. | required |
Returns:
Type | Description |
---|---|
BaseTypedModelMeta | A Pydantic BaseModel class that includes a hidden attribute that reflects the full class identifier. |
Exceptions:
Type | Description |
---|---|
TypeError | If the class is not a Pydantic BaseModel class. |
Source code in zenml/utils/typed_model.py
def __new__(
mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "BaseTypedModelMeta":
"""Creates a Pydantic BaseModel class.
This includes a hidden attribute that reflects the full class
identifier.
Args:
name: The name of the class.
bases: The base classes of the class.
dct: The class dictionary.
Returns:
A Pydantic BaseModel class that includes a hidden attribute that
reflects the full class identifier.
Raises:
TypeError: If the class is not a Pydantic BaseModel class.
"""
if "type" in dct:
raise TypeError(
"`type` is a reserved attribute name for BaseTypedModel "
"subclasses"
)
type_name = f"{dct['__module__']}.{dct['__qualname__']}"
type_ann = Literal[type_name] # type: ignore [misc,valid-type]
type = Field(type_name)
dct.setdefault("__annotations__", dict())["type"] = type_ann
dct["type"] = type
cls = cast(
Type["BaseTypedModel"], super().__new__(mcs, name, bases, dct)
)
return cls
uuid_utils
Utility functions for handling UUIDs.
generate_uuid_from_string(value)
Deterministically generates a UUID from a string seed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | str | The string from which to generate the UUID. | required |
Returns:
Type | Description |
---|---|
UUID | The generated UUID. |
Source code in zenml/utils/uuid_utils.py
def generate_uuid_from_string(value: str) -> UUID:
"""Deterministically generates a UUID from a string seed.
Args:
value: The string from which to generate the UUID.
Returns:
The generated UUID.
"""
hash_ = hashlib.md5()
hash_.update(value.encode("utf-8"))
return UUID(hex=hash_.hexdigest(), version=4)
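Because the UUID is derived from an MD5 digest of the seed, the mapping is stable across processes; the seed strings below are hypothetical:
```python
from zenml.utils.uuid_utils import generate_uuid_from_string

assert generate_uuid_from_string("my-pipeline") == generate_uuid_from_string("my-pipeline")
assert generate_uuid_from_string("my-pipeline") != generate_uuid_from_string("other-pipeline")
```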
is_valid_uuid(value, version=4)
Checks if a string is a valid UUID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | Any | String to check. | required |
version | int | Version of UUID to check for. | 4 |
Returns:
Type | Description |
---|---|
bool | True if string is a valid UUID, False otherwise. |
Source code in zenml/utils/uuid_utils.py
def is_valid_uuid(value: Any, version: int = 4) -> bool:
"""Checks if a string is a valid UUID.
Args:
value: String to check.
version: Version of UUID to check for.
Returns:
True if string is a valid UUID, False otherwise.
"""
if isinstance(value, UUID):
return True
if isinstance(value, str):
try:
UUID(value, version=version)
return True
except ValueError:
return False
return False
parse_name_or_uuid(name_or_id)
Convert a "name or id" string value to a string or UUID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name_or_id | str | Name or id to convert. | required |
Returns:
Type | Description |
---|---|
Union[str, uuid.UUID] | A UUID if name_or_id is a UUID, string otherwise. |
Source code in zenml/utils/uuid_utils.py
def parse_name_or_uuid(name_or_id: str) -> Union[str, UUID]:
"""Convert a "name or id" string value to a string or UUID.
Args:
name_or_id: Name or id to convert.
Returns:
A UUID if name_or_id is a UUID, string otherwise.
"""
if name_or_id:
try:
return UUID(name_or_id)
except ValueError:
return name_or_id
else:
return name_or_id
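For illustration (the stack name is hypothetical):
```python
from uuid import UUID

from zenml.utils.uuid_utils import parse_name_or_uuid

assert parse_name_or_uuid("my-stack") == "my-stack"
assert isinstance(
    parse_name_or_uuid("12345678-1234-5678-1234-567812345678"), UUID
)
```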
parse_optional_name_or_uuid(name_or_id)
Convert an optional "name or id" string value to an optional string or UUID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name_or_id | Optional[str] | Name or id to convert. | required |
Returns:
Type | Description |
---|---|
Optional[Union[str, uuid.UUID]] | A UUID if name_or_id is a UUID, string otherwise. |
Source code in zenml/utils/uuid_utils.py
def parse_optional_name_or_uuid(
name_or_id: Optional[str],
) -> Optional[Union[str, UUID]]:
"""Convert an optional "name or id" string value to an optional string or UUID.
Args:
name_or_id: Name or id to convert.
Returns:
A UUID if name_or_id is a UUID, string otherwise.
"""
if name_or_id is None:
return None
return parse_name_or_uuid(name_or_id)
yaml_utils
Utility functions to help with YAML files and data.
UUIDEncoder (JSONEncoder)
JSON encoder for UUID objects.
Source code in zenml/utils/yaml_utils.py
class UUIDEncoder(json.JSONEncoder):
"""JSON encoder for UUID objects."""
def default(self, obj: Any) -> Any:
"""Default UUID encoder for JSON.
Args:
obj: Object to encode.
Returns:
Encoded object.
"""
if isinstance(obj, UUID):
# if the obj is uuid, we simply return the value of uuid
return obj.hex
return json.JSONEncoder.default(self, obj)
default(self, obj)
Default UUID encoder for JSON.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj | Any | Object to encode. | required |
Returns:
Type | Description |
---|---|
Any | Encoded object. |
Source code in zenml/utils/yaml_utils.py
def default(self, obj: Any) -> Any:
"""Default UUID encoder for JSON.
Args:
obj: Object to encode.
Returns:
Encoded object.
"""
if isinstance(obj, UUID):
# if the obj is uuid, we simply return the value of uuid
return obj.hex
return json.JSONEncoder.default(self, obj)
append_yaml(file_path, contents)
Append contents to a YAML file at file_path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path | str | Path to YAML file. | required |
contents | Dict[Any, Any] | Contents of YAML file as dict. | required |
Exceptions:
Type | Description |
---|---|
FileNotFoundError | if directory does not exist. |
Source code in zenml/utils/yaml_utils.py
def append_yaml(file_path: str, contents: Dict[Any, Any]) -> None:
"""Append contents to a YAML file at file_path.
Args:
file_path: Path to YAML file.
contents: Contents of YAML file as dict.
Raises:
FileNotFoundError: if directory does not exist.
"""
file_contents = read_yaml(file_path) or {}
file_contents.update(contents)
if not io_utils.is_remote(file_path):
dir_ = str(Path(file_path).parent)
if not fileio.isdir(dir_):
raise FileNotFoundError(f"Directory {dir_} does not exist.")
io_utils.write_file_contents_as_string(file_path, yaml.dump(file_contents))
comment_out_yaml(yaml_string)
Comments out a yaml string.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
yaml_string | str | The yaml string to comment out. | required |
Returns:
Type | Description |
---|---|
str | The commented out yaml string. |
Source code in zenml/utils/yaml_utils.py
def comment_out_yaml(yaml_string: str) -> str:
"""Comments out a yaml string.
Args:
yaml_string: The yaml string to comment out.
Returns:
The commented out yaml string.
"""
lines = yaml_string.splitlines(keepends=True)
lines = ["# " + line for line in lines]
return "".join(lines)
is_yaml(file_path)
Returns True if file_path is YAML, else False.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path | str | Path to YAML file. | required |
Returns:
Type | Description |
---|---|
bool | True if is yaml, else False. |
Source code in zenml/utils/yaml_utils.py
def is_yaml(file_path: str) -> bool:
"""Returns True if file_path is YAML, else False.
Args:
file_path: Path to YAML file.
Returns:
True if is yaml, else False.
"""
if file_path.endswith("yaml") or file_path.endswith("yml"):
return True
return False
read_json(file_path)
Read JSON on file path and returns contents as dict.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path | str | Path to JSON file. | required |
Returns:
Type | Description |
---|---|
Any | Contents of the file in a dict. |
Exceptions:
Type | Description |
---|---|
FileNotFoundError | if file does not exist. |
Source code in zenml/utils/yaml_utils.py
def read_json(file_path: str) -> Any:
"""Read JSON on file path and returns contents as dict.
Args:
file_path: Path to JSON file.
Returns:
Contents of the file in a dict.
Raises:
FileNotFoundError: if file does not exist.
"""
if fileio.exists(file_path):
contents = io_utils.read_file_contents_as_string(file_path)
return json.loads(contents)
else:
raise FileNotFoundError(f"{file_path} does not exist.")
read_yaml(file_path)
Read YAML on file path and returns contents as dict.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path | str | Path to YAML file. | required |
Returns:
Type | Description |
---|---|
Any | Contents of the file in a dict. |
Exceptions:
Type | Description |
---|---|
FileNotFoundError | if file does not exist. |
Source code in zenml/utils/yaml_utils.py
def read_yaml(file_path: str) -> Any:
"""Read YAML on file path and returns contents as dict.
Args:
file_path: Path to YAML file.
Returns:
Contents of the file in a dict.
Raises:
FileNotFoundError: if file does not exist.
"""
if fileio.exists(file_path):
contents = io_utils.read_file_contents_as_string(file_path)
# TODO: [LOW] consider adding a default empty dict to be returned
# instead of None
return yaml.safe_load(contents)
else:
raise FileNotFoundError(f"{file_path} does not exist.")
write_json(file_path, contents, encoder=None)
Write contents as JSON format to file_path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path | str | Path to JSON file. | required |
contents | Dict[str, Any] | Contents of JSON file as dict. | required |
encoder | Optional[Type[json.encoder.JSONEncoder]] | Custom JSON encoder to use when saving json. | None |
Exceptions:
Type | Description |
---|---|
FileNotFoundError | if directory does not exist. |
Source code in zenml/utils/yaml_utils.py
def write_json(
file_path: str,
contents: Dict[str, Any],
encoder: Optional[Type[json.JSONEncoder]] = None,
) -> None:
"""Write contents as JSON format to file_path.
Args:
file_path: Path to JSON file.
contents: Contents of JSON file as dict.
encoder: Custom JSON encoder to use when saving json.
Raises:
FileNotFoundError: if directory does not exist.
"""
if not io_utils.is_remote(file_path):
dir_ = str(Path(file_path).parent)
if not fileio.isdir(dir_):
# Check if it is a local path, if it doesn't exist, raise Exception.
raise FileNotFoundError(f"Directory {dir_} does not exist.")
io_utils.write_file_contents_as_string(
file_path,
json.dumps(
contents,
cls=encoder,
),
)
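A sketch combining `write_json` with the `UUIDEncoder` above; the output path is a placeholder whose parent directory must already exist:
```python
from uuid import uuid4

from zenml.utils.yaml_utils import UUIDEncoder, write_json

# Plain `json.dumps` raises a TypeError on UUID values; passing `UUIDEncoder`
# serializes them as hex strings instead.
write_json("/tmp/example.json", {"run_id": uuid4()}, encoder=UUIDEncoder)
```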
write_yaml(file_path, contents, sort_keys=True)
Write contents as YAML format to file_path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
file_path | str | Path to YAML file. | required |
contents | Union[Dict[Any, Any], List[Any]] | Contents of YAML file as dict or list. | required |
sort_keys | bool | If `True`, keys are sorted alphabetically. If `False`, the order in which the keys were inserted into the dict will be preserved. | True |
Exceptions:
Type | Description |
---|---|
FileNotFoundError | if directory does not exist. |
Source code in zenml/utils/yaml_utils.py
def write_yaml(
file_path: str,
contents: Union[Dict[Any, Any], List[Any]],
sort_keys: bool = True,
) -> None:
"""Write contents as YAML format to file_path.
Args:
file_path: Path to YAML file.
contents: Contents of YAML file as dict or list.
sort_keys: If `True`, keys are sorted alphabetically. If `False`,
the order in which the keys were inserted into the dict will
be preserved.
Raises:
FileNotFoundError: if directory does not exist.
"""
if not io_utils.is_remote(file_path):
dir_ = str(Path(file_path).parent)
if not fileio.isdir(dir_):
raise FileNotFoundError(f"Directory {dir_} does not exist.")
io_utils.write_file_contents_as_string(
file_path, yaml.dump(contents, sort_keys=sort_keys)
)