Skip to content

Utils

zenml.utils special

Initialization of the utils module.

The utils module contains utility functions handling analytics, reading and writing YAML data as well as other general purpose functions.

analytics_utils

Analytics code for ZenML.

AnalyticsContext

Context manager for analytics.

Source code in zenml/utils/analytics_utils.py
class AnalyticsContext:
    """Context manager for analytics.

    Wraps calls to the `analytics` client so that they are only made when
    the user has opted in, and so that any exception raised while sending
    telemetry is suppressed instead of reaching the caller.
    """

    def __init__(self) -> None:
        """Initialize the analytics context.

        Reads the opt-in flag and the user ID from the global configuration
        and, if the user opted in, configures the `analytics` client. Any
        error during initialization disables analytics for this context.
        """
        import analytics

        from zenml.config.global_config import GlobalConfiguration

        # Defaults, so that both attributes exist (and tracking is off) even
        # if loading the global configuration fails below.
        self.analytics_opt_in = False
        self.user_id = ""

        try:
            gc = GlobalConfiguration()

            self.analytics_opt_in = gc.analytics_opt_in
            self.user_id = str(gc.user_id)

            # That means user opted out of analytics
            if not gc.analytics_opt_in:
                return

            if analytics.write_key is None:
                analytics.write_key = get_segment_key()

            # Explicit check instead of `assert`, which is stripped when
            # Python runs with -O. The error is handled right below.
            if analytics.write_key is None:
                raise RuntimeError(
                    "Analytics key not set but trying to make telemetry call."
                )

            # Set this to 1 to avoid backoff loop
            analytics.max_retries = 1
        except Exception as e:
            self.analytics_opt_in = False
            logger.debug(f"Analytics initialization failed: {e}")

    def __enter__(self) -> "AnalyticsContext":
        """Enter context manager.

        Returns:
            Self.
        """
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        """Exit context manager.

        Args:
            exc_type: Exception type.
            exc_val: Exception value.
            exc_tb: Exception traceback.

        Returns:
            Always True, so that an exception raised inside the `with` block
            is suppressed after being logged at debug level.
        """
        if exc_val is not None:
            logger.debug(f"Sending telemetry data failed: {exc_val}")

        # We should never fail main thread
        return True

    def identify(self, traits: Optional[Dict[str, Any]] = None) -> bool:
        """Identify the user.

        Args:
            traits: Traits of the user.

        Returns:
            True if tracking information was sent, False otherwise.
        """
        import analytics

        logger.debug(
            f"Attempting to attach metadata to: User: {self.user_id}, "
            f"Metadata: {traits}"
        )

        if not self.analytics_opt_in:
            return False

        analytics.identify(self.user_id, traits)

        logger.debug(f"User data sent: User: {self.user_id},{traits}")

        return True

    def group(
        self,
        group: Union[str, AnalyticsGroup],
        group_id: str,
        traits: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Group the user.

        Args:
            group: Group to which the user belongs.
            group_id: Group ID.
            traits: Traits of the group. The dictionary is copied, so the
                caller's object is left untouched.

        Returns:
            True if tracking information was sent, False otherwise.
        """
        import analytics

        if isinstance(group, AnalyticsGroup):
            group = group.value

        # Build a new dictionary instead of updating the caller's one.
        traits = {**(traits or {}), "group_id": group_id}

        logger.debug(
            f"Attempting to attach metadata to: User: {self.user_id}, "
            f"Group: {group}, Group ID: {group_id}, Metadata: {traits}"
        )

        if not self.analytics_opt_in:
            return False
        analytics.group(self.user_id, group_id, traits=traits)

        logger.debug(
            f"Group data sent: User: {self.user_id}, Group: {group}, Group ID: "
            f"{group_id}, Metadata: {traits}"
        )
        return True

    def track(
        self,
        event: Union[str, AnalyticsEvent],
        properties: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Track an event.

        Args:
            event: Event to track.
            properties: Event properties. The dictionary is copied, so the
                caller's object is left untouched.

        Returns:
            True if tracking information was sent, False otherwise.
        """
        import analytics

        from zenml.config.global_config import GlobalConfiguration

        if isinstance(event, AnalyticsEvent):
            event = event.value

        # Copy so that the system/server info added below does not leak
        # into the dictionary owned by the caller.
        properties = dict(properties) if properties is not None else {}

        logger.debug(
            f"Attempting analytics: User: {self.user_id}, "
            f"Event: {event},"
            f"Metadata: {properties}"
        )

        # The opt-in/opt-out events themselves are not filtered out when
        # analytics are disabled.
        if not self.analytics_opt_in and event not in {
            AnalyticsEvent.OPT_OUT_ANALYTICS,
            AnalyticsEvent.OPT_IN_ANALYTICS,
        }:
            return False

        # add basics
        properties.update(Environment.get_system_info())
        properties.update(
            {
                "environment": get_environment(),
                "python_version": Environment.python_version(),
                "version": __version__,
            }
        )

        gc = GlobalConfiguration()
        # avoid initializing the store in the analytics, to not create an
        # infinite loop
        if gc._zen_store is not None:
            zen_store = gc.zen_store
            if (
                zen_store.type == StoreType.REST
                and "server_id" not in properties
            ):
                user = zen_store.active_user
                server_info = zen_store.get_store_info()
                properties.update(
                    {
                        "user_id": str(user.id),
                        "server_id": str(server_info.id),
                        "server_deployment": str(server_info.deployment_type),
                        "database_type": str(server_info.database_type),
                    }
                )

        analytics.track(self.user_id, event, properties)

        logger.debug(
            f"Analytics sent: User: {self.user_id}, Event: {event}, Metadata: "
            f"{properties}"
        )

        return True
__enter__(self) special

Enter context manager.

Returns:

Type Description
AnalyticsContext

Self.

Source code in zenml/utils/analytics_utils.py
def __enter__(self) -> "AnalyticsContext":
    """Start the `with` block.

    Returns:
        The very same context object, so it can be bound with `as`.
    """
    # Nothing to set up here: all the configuration happens in `__init__`.
    return self
__exit__(self, exc_type, exc_val, exc_tb) special

Exit context manager.

Parameters:

Name Type Description Default
exc_type Optional[Type[BaseException]]

Exception type.

required
exc_val Optional[BaseException]

Exception value.

required
exc_tb Optional[traceback]

Exception traceback.

required

Returns:

Type Description
bool

True if exception was handled, False otherwise.

Source code in zenml/utils/analytics_utils.py
def __exit__(
    self,
    exc_type: Optional[Type[BaseException]],
    exc_val: Optional[BaseException],
    exc_tb: Optional[TracebackType],
) -> bool:
    """Exit context manager.

    Args:
        exc_type: Exception type.
        exc_val: Exception value.
        exc_tb: Exception traceback.

    Returns:
        Always True, so that an exception raised inside the `with` block is
        suppressed after being logged at debug level.
    """
    if exc_val is not None:
        # f-string so that the actual error ends up in the log message.
        logger.debug(f"Sending telemetry data failed: {exc_val}")

    # We should never fail main thread
    return True
__init__(self) special

Context manager for analytics.

Use this as a context manager to ensure that analytics are initialized properly, only tracked when configured to do so and that any errors are handled gracefully.

Source code in zenml/utils/analytics_utils.py
def __init__(self) -> None:
    """Initialize the analytics context.

    Use this as a context manager to ensure that analytics are initialized
    properly, only tracked when configured to do so and that any errors
    are handled gracefully.

    Reads the opt-in flag and the user ID from the global configuration and,
    if the user opted in, configures the `analytics` client. Any error
    during initialization disables analytics for this context.
    """
    import analytics

    from zenml.config.global_config import GlobalConfiguration

    # Defaults, so that both attributes exist (and tracking is off) even if
    # loading the global configuration fails below.
    self.analytics_opt_in = False
    self.user_id = ""

    try:
        gc = GlobalConfiguration()

        self.analytics_opt_in = gc.analytics_opt_in
        self.user_id = str(gc.user_id)

        # That means user opted out of analytics
        if not gc.analytics_opt_in:
            return

        if analytics.write_key is None:
            analytics.write_key = get_segment_key()

        # Explicit check instead of `assert`, which is stripped when Python
        # runs with -O. The error is handled right below.
        if analytics.write_key is None:
            raise RuntimeError(
                "Analytics key not set but trying to make telemetry call."
            )

        # Set this to 1 to avoid backoff loop
        analytics.max_retries = 1
    except Exception as e:
        self.analytics_opt_in = False
        logger.debug(f"Analytics initialization failed: {e}")
group(self, group, group_id, traits=None)

Group the user.

Parameters:

Name Type Description Default
group Union[str, zenml.utils.analytics_utils.AnalyticsGroup]

Group to which the user belongs.

required
group_id str

Group ID.

required
traits Optional[Dict[str, Any]]

Traits of the group.

None

Returns:

Type Description
bool

True if tracking information was sent, False otherwise.

Source code in zenml/utils/analytics_utils.py
def group(
    self,
    group: Union[str, AnalyticsGroup],
    group_id: str,
    traits: Optional[Dict[str, Any]] = None,
) -> bool:
    """Group the user.

    Args:
        group: Group to which the user belongs.
        group_id: Group ID.
        traits: Traits of the group. The dictionary is copied, so the
            caller's object is left untouched.

    Returns:
        True if tracking information was sent, False otherwise.
    """
    import analytics

    if isinstance(group, AnalyticsGroup):
        group = group.value

    # Build a new dictionary instead of updating the caller's one.
    traits = {**(traits or {}), "group_id": group_id}

    logger.debug(
        f"Attempting to attach metadata to: User: {self.user_id}, "
        f"Group: {group}, Group ID: {group_id}, Metadata: {traits}"
    )

    if not self.analytics_opt_in:
        return False
    analytics.group(self.user_id, group_id, traits=traits)

    logger.debug(
        f"Group data sent: User: {self.user_id}, Group: {group}, Group ID: "
        f"{group_id}, Metadata: {traits}"
    )
    return True
identify(self, traits=None)

Identify the user.

Parameters:

Name Type Description Default
traits Optional[Dict[str, Any]]

Traits of the user.

None

Returns:

Type Description
bool

True if tracking information was sent, False otherwise.

Source code in zenml/utils/analytics_utils.py
def identify(self, traits: Optional[Dict[str, Any]] = None) -> bool:
    """Attach traits to the current user.

    Args:
        traits: Traits of the user.

    Returns:
        True if the traits were handed to the analytics client, False if
        analytics are disabled for this context.
    """
    import analytics

    logger.debug(
        f"Attempting to attach metadata to: User: {self.user_id}, "
        f"Metadata: {traits}"
    )

    if self.analytics_opt_in:
        analytics.identify(self.user_id, traits)
        logger.debug(f"User data sent: User: {self.user_id},{traits}")
        return True

    # Opted out: the attempt is logged above, but nothing is sent.
    return False
track(self, event, properties=None)

Track an event.

Parameters:

Name Type Description Default
event Union[str, zenml.utils.analytics_utils.AnalyticsEvent]

Event to track.

required
properties Optional[Dict[str, Any]]

Event properties.

None

Returns:

Type Description
bool

True if tracking information was sent, False otherwise.

Source code in zenml/utils/analytics_utils.py
def track(
    self,
    event: Union[str, AnalyticsEvent],
    properties: Optional[Dict[str, Any]] = None,
) -> bool:
    """Track an event.

    Args:
        event: Event to track.
        properties: Event properties. The dictionary is copied, so the
            caller's object is left untouched.

    Returns:
        True if tracking information was sent, False otherwise.
    """
    import analytics

    from zenml.config.global_config import GlobalConfiguration

    if isinstance(event, AnalyticsEvent):
        event = event.value

    # Copy so that the system/server info added below does not leak into
    # the dictionary owned by the caller.
    properties = dict(properties) if properties is not None else {}

    logger.debug(
        f"Attempting analytics: User: {self.user_id}, "
        f"Event: {event},"
        f"Metadata: {properties}"
    )

    # The opt-in/opt-out events themselves are not filtered out when
    # analytics are disabled.
    if not self.analytics_opt_in and event not in {
        AnalyticsEvent.OPT_OUT_ANALYTICS,
        AnalyticsEvent.OPT_IN_ANALYTICS,
    }:
        return False

    # add basics
    properties.update(Environment.get_system_info())
    properties.update(
        {
            "environment": get_environment(),
            "python_version": Environment.python_version(),
            "version": __version__,
        }
    )

    gc = GlobalConfiguration()
    # avoid initializing the store in the analytics, to not create an
    # infinite loop
    if gc._zen_store is not None:
        zen_store = gc.zen_store
        if (
            zen_store.type == StoreType.REST
            and "server_id" not in properties
        ):
            user = zen_store.active_user
            server_info = zen_store.get_store_info()
            properties.update(
                {
                    "user_id": str(user.id),
                    "server_id": str(server_info.id),
                    "server_deployment": str(server_info.deployment_type),
                    "database_type": str(server_info.database_type),
                }
            )

    analytics.track(self.user_id, event, properties)

    logger.debug(
        f"Analytics sent: User: {self.user_id}, Event: {event}, Metadata: "
        f"{properties}"
    )

    return True

AnalyticsEvent (str, Enum)

Enum of events to track in segment.

Source code in zenml/utils/analytics_utils.py
class AnalyticsEvent(str, Enum):
    """Enum of events to track in segment.

    Every member needs a distinct value: members sharing a value become
    aliases of the first one and can no longer be told apart.
    """

    # Pipelines
    RUN_PIPELINE = "Pipeline run"
    GET_PIPELINES = "Pipelines fetched"
    GET_PIPELINE = "Pipeline fetched"
    CREATE_PIPELINE = "Pipeline created"
    UPDATE_PIPELINE = "Pipeline updated"
    DELETE_PIPELINE = "Pipeline deleted"

    # Repo
    INITIALIZE_REPO = "ZenML initialized"
    CONNECT_REPOSITORY = "Repository connected"
    UPDATE_REPOSITORY = "Repository updated"
    DELETE_REPOSITORY = "Repository deleted"

    # Zen store
    INITIALIZED_STORE = "Store initialized"

    # Components
    REGISTERED_STACK_COMPONENT = "Stack component registered"
    UPDATED_STACK_COMPONENT = "Stack component updated"
    COPIED_STACK_COMPONENT = "Stack component copied"
    DELETED_STACK_COMPONENT = "Stack component deleted"

    # Stack
    REGISTERED_STACK = "Stack registered"
    REGISTERED_DEFAULT_STACK = "Default stack registered"
    SET_STACK = "Stack set"
    UPDATED_STACK = "Stack updated"
    COPIED_STACK = "Stack copied"
    IMPORT_STACK = "Stack imported"
    EXPORT_STACK = "Stack exported"
    DELETED_STACK = "Stack deleted"

    # Model Deployment
    MODEL_DEPLOYED = "Model deployed"

    # Analytics opt in and out
    OPT_IN_ANALYTICS = "Analytics opt-in"
    OPT_OUT_ANALYTICS = "Analytics opt-out"
    OPT_IN_OUT_EMAIL = "Response for Email prompt"

    # Examples
    RUN_ZENML_GO = "ZenML go"
    RUN_EXAMPLE = "Example run"
    PULL_EXAMPLE = "Example pull"

    # Integrations
    INSTALL_INTEGRATION = "Integration installed"

    # Users
    CREATED_USER = "User created"
    CREATED_DEFAULT_USER = "Default user created"
    UPDATED_USER = "User updated"
    DELETED_USER = "User deleted"

    # Teams
    CREATED_TEAM = "Team created"
    UPDATED_TEAM = "Team updated"
    DELETED_TEAM = "Team deleted"

    # Projects
    CREATED_PROJECT = "Project created"
    CREATED_DEFAULT_PROJECT = "Default project created"
    UPDATED_PROJECT = "Project updated"
    DELETED_PROJECT = "Project deleted"
    SET_PROJECT = "Project set"

    # Role
    CREATED_ROLE = "Role created"
    CREATED_DEFAULT_ROLES = "Default roles created"
    UPDATED_ROLE = "Role updated"
    DELETED_ROLE = "Role deleted"

    # Flavor
    CREATED_FLAVOR = "Flavor created"
    UPDATED_FLAVOR = "Flavor updated"
    DELETED_FLAVOR = "Flavor deleted"

    # Test event
    EVENT_TEST = "Test event"

    # Stack recipes
    PULL_STACK_RECIPE = "Stack recipes pulled"
    RUN_STACK_RECIPE = "Stack recipe created"
    DESTROY_STACK_RECIPE = "Stack recipe destroyed"

    # ZenML server events
    ZENML_SERVER_STARTED = "ZenML server started"
    ZENML_SERVER_STOPPED = "ZenML server stopped"
    ZENML_SERVER_CONNECTED = "ZenML server connected"
    ZENML_SERVER_DEPLOYED = "ZenML server deployed"
    ZENML_SERVER_DESTROYED = "ZenML server destroyed"

AnalyticsGroup (str, Enum)

Enum of event groups to track in segment.

Source code in zenml/utils/analytics_utils.py
class AnalyticsGroup(str, Enum):
    """Groups that analytics data can be attached to in segment."""

    # Currently the only group defined here.
    ZENML_SERVER_GROUP = "ZenML server group"

AnalyticsTrackedModelMixin (BaseModel) pydantic-model

Mixin for models that are tracked through analytics events.

Classes that have information tracked in analytics events can inherit from this mixin and implement the abstract methods. The @track decorator will detect function arguments and return values that inherit from this class and will include the ANALYTICS_FIELDS attributes as tracking metadata.

Source code in zenml/utils/analytics_utils.py
class AnalyticsTrackedModelMixin(BaseModel):
    """Mixin for models whose fields are reported in analytics events.

    Subclasses list the attribute names to report in `ANALYTICS_FIELDS`.
    The `@track` decorator is described as detecting function arguments and
    return values that inherit from this class and including those
    attributes as tracking metadata.
    """

    # Names of the attributes collected by `get_analytics_metadata`.
    ANALYTICS_FIELDS: ClassVar[List[str]]

    def get_analytics_metadata(self) -> Dict[str, Any]:
        """Collect the values of all `ANALYTICS_FIELDS` attributes.

        Returns:
            Dict mapping each field name to its value on this model, or to
            None when the model has no such attribute.
        """
        return {
            field_name: getattr(self, field_name, None)
            for field_name in self.ANALYTICS_FIELDS
        }

    def track_event(
        self,
        event: Union[str, AnalyticsEvent],
        tracker: Optional[AnalyticsTrackerMixin] = None,
    ) -> None:
        """Send an event carrying this model's analytics metadata.

        Args:
            event: Event to track.
            tracker: Optional tracker to use for analytics. When it is
                missing (or falsy), the module-level `track_event` function
                is used instead.
        """
        metadata = self.get_analytics_metadata()
        send = tracker.track_event if tracker else track_event
        send(event, metadata)
get_analytics_metadata(self)

Get the analytics metadata for the model.

Returns:

Type Description
Dict[str, Any]

Dict of analytics metadata.

Source code in zenml/utils/analytics_utils.py
def get_analytics_metadata(self) -> Dict[str, Any]:
    """Collect the values of all `ANALYTICS_FIELDS` attributes.

    Returns:
        Dict mapping each field name to its value on this model, or to None
        when the model has no such attribute.
    """
    return {
        field_name: getattr(self, field_name, None)
        for field_name in self.ANALYTICS_FIELDS
    }
track_event(self, event, tracker=None)

Track an event for the model.

Parameters:

Name Type Description Default
event Union[str, zenml.utils.analytics_utils.AnalyticsEvent]

Event to track.

required
tracker Optional[zenml.utils.analytics_utils.AnalyticsTrackerMixin]

Optional tracker to use for analytics.

None
Source code in zenml/utils/analytics_utils.py
def track_event(
    self,
    event: Union[str, AnalyticsEvent],
    tracker: Optional[AnalyticsTrackerMixin] = None,
) -> None:
    """Send an event carrying this model's analytics metadata.

    Args:
        event: Event to track.
        tracker: Optional tracker to use for analytics. When it is missing
            (or falsy), the module-level `track_event` function is used
            instead.
    """
    metadata = self.get_analytics_metadata()
    send = tracker.track_event if tracker else track_event
    send(event, metadata)

AnalyticsTrackerMixin (ABC)

Abstract base class for analytics trackers.

Use this as a mixin for classes that have methods decorated with @track to add global control over how analytics are tracked. The decorator will detect that the class has this mixin and will call the class track_event method.

Source code in zenml/utils/analytics_utils.py
class AnalyticsTrackerMixin(ABC):
    """Interface for classes that take control of analytics tracking.

    Mix this into classes that have methods decorated with `@track` to add
    global control over how analytics are tracked. The decorator is
    described as detecting this mixin and calling the class `track_event`
    method.
    """

    @abstractmethod
    def track_event(
        self,
        event: Union[str, AnalyticsEvent],
        metadata: Optional[Dict[str, Any]],
    ) -> None:
        """Track an event; concrete subclasses must implement this.

        Args:
            event: Event to track.
            metadata: Metadata to track.
        """
track_event(self, event, metadata)

Track an event.

Parameters:

Name Type Description Default
event Union[str, zenml.utils.analytics_utils.AnalyticsEvent]

Event to track.

required
metadata Optional[Dict[str, Any]]

Metadata to track.

required
Source code in zenml/utils/analytics_utils.py
@abstractmethod
def track_event(
    self,
    event: Union[str, AnalyticsEvent],
    metadata: Optional[Dict[str, Any]],
) -> None:
    """Track an event; concrete subclasses must implement this.

    Args:
        event: Event to track.
        metadata: Metadata to track.
    """

get_segment_key()

Get key for authorizing to Segment backend.

Returns:

Type Description
str

Segment key as a string.

Source code in zenml/utils/analytics_utils.py
def get_segment_key() -> str:
    """Select the key used to authorize against the Segment backend.

    Returns:
        The development key when `IS_DEBUG_ENV` is truthy, the production
        key otherwise.
    """
    return SEGMENT_KEY_DEV if IS_DEBUG_ENV else SEGMENT_KEY_PROD

identify_group(group, group_id, group_metadata=None)

Attach metadata to a segment group.

Parameters:

Name Type Description Default
group Union[str, zenml.utils.analytics_utils.AnalyticsGroup]

Group to track.

required
group_id str

ID of the group.

required
group_metadata Optional[Dict[str, Any]]

Metadata to attach to the group.

None

Returns:

Type Description
bool

True if the event is sent successfully, False if not.

Source code in zenml/utils/analytics_utils.py
def identify_group(
    group: Union[str, AnalyticsGroup],
    group_id: str,
    group_metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """Attach metadata to a segment group.

    Args:
        group: Group to track.
        group_id: ID of the group.
        group_metadata: Metadata to attach to the group.

    Returns:
        True if the event is sent successfully, False if not.
    """
    # Stays False when the context suppresses an error raised while sending.
    sent = False
    with AnalyticsContext() as context:
        sent = context.group(group, group_id, traits=group_metadata)

    return sent

identify_user(user_metadata=None)

Attach metadata to user directly.

Parameters:

Name Type Description Default
user_metadata Optional[Dict[str, Any]]

Dict of metadata to attach to the user.

None

Returns:

Type Description
bool

True if the event is sent successfully, False if not.

Source code in zenml/utils/analytics_utils.py
def identify_user(user_metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Attach metadata to user directly.

    Args:
        user_metadata: Dict of metadata to attach to the user.

    Returns:
        True if the event is sent successfully, False if not (including when
        no metadata is given).
    """
    # Stays False when there is no metadata, or when the context suppresses
    # an error raised while sending.
    sent = False
    with AnalyticsContext() as context:
        if user_metadata is not None:
            sent = context.identify(traits=user_metadata)

    return sent

parametrized(dec)

This is a meta-decorator, that is, a decorator for decorators.

As a decorator is a function, it actually works as a regular decorator with arguments.

Parameters:

Name Type Description Default
dec Callable[..., Callable[..., Any]]

Decorator to be applied to the function.

required

Returns:

Type Description
Callable[..., Callable[[Callable[..., Any]], Callable[..., Any]]]

Decorator that applies the given decorator to the function.

Source code in zenml/utils/analytics_utils.py
def parametrized(
    dec: Callable[..., Callable[..., Any]]
) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., Any]]]:
    """This is a meta-decorator, that is, a decorator for decorators.

    As a decorator is a function, it actually works as a regular decorator
    with arguments: `dec` is written as `dec(func, *args, **kwargs)` and,
    once wrapped, is used as `@dec(*args, **kwargs)`.

    Args:
        dec: Decorator to be applied to the function.

    Returns:
        Decorator that applies the given decorator to the function. It keeps
        the name and docstring of `dec`.
    """
    import functools

    # Without `wraps`, every parametrized decorator would show up as
    # `layer` with the "Internal layer" docstring.
    @functools.wraps(dec)
    def layer(
        *args: Any, **kwargs: Any
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        """Internal layer.

        Args:
            *args: Arguments to be passed to the decorator.
            **kwargs: Keyword arguments to be passed to the decorator.

        Returns:
            Decorator that applies the given decorator to the function.
        """

        def repl(f: Callable[..., Any]) -> Callable[..., Any]:
            """Internal REPL.

            Args:
                f: Function to be decorated.

            Returns:
                Decorated function.
            """
            return dec(f, *args, **kwargs)

        return repl

    return layer

track(*args, **kwargs)

Internal layer.

Parameters:

Name Type Description Default
*args Any

Arguments to be passed to the decorator.

()
**kwargs Any

Keyword arguments to be passed to the decorator.

{}

Returns:

Type Description
Callable[[Callable[..., Any]], Callable[..., Any]]

Decorator that applies the given decorator to the function.

Source code in zenml/utils/analytics_utils.py
def layer(
    *args: Any, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Capture the decorator arguments.

    Args:
        *args: Arguments to be passed to the decorator.
        **kwargs: Keyword arguments to be passed to the decorator.

    Returns:
        Decorator that applies the given decorator to the function.
    """

    def repl(f: Callable[..., Any]) -> Callable[..., Any]:
        """Apply the decorator, with the captured arguments, to `f`.

        Args:
            f: Function to be decorated.

        Returns:
            Decorated function.
        """
        decorated = dec(f, *args, **kwargs)
        return decorated

    return repl

track_event(event, metadata=None)

Track segment event if user opted-in.

Parameters:

Name Type Description Default
event Union[str, zenml.utils.analytics_utils.AnalyticsEvent]

Name of event to track in segment.

required
metadata Optional[Dict[str, Any]]

Dict of metadata to track.

None

Returns:

Type Description
bool

True if the event is sent successfully, False if not.

Source code in zenml/utils/analytics_utils.py
def track_event(
    event: Union[str, AnalyticsEvent],
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """Track segment event if user opted-in.

    Args:
        event: Name of event to track in segment.
        metadata: Dict of metadata to track.

    Returns:
        True if the event is sent successfully, False if not.
    """
    # Stays False when the context suppresses an error raised while sending.
    sent = False
    with AnalyticsContext() as context:
        sent = context.track(event, metadata)

    return sent

daemon

Utility functions to start/stop daemon processes.

This is only implemented for UNIX systems and therefore doesn't work on Windows. Based on https://www.jejik.com/articles/2007/02/a_simple_unix_linux_daemon_in_python/

check_if_daemon_is_running(pid_file)

Checks whether a daemon process indicated by the PID file is running.

Parameters:

Name Type Description Default
pid_file str

Path to file containing the PID of the daemon process to check.

required

Returns:

Type Description
bool

True if the daemon process is running, otherwise False.

Source code in zenml/utils/daemon.py
def check_if_daemon_is_running(pid_file: str) -> bool:
    """Tell whether the daemon process tracked by a PID file is running.

    Args:
        pid_file: Path to file containing the PID of the daemon
            process to check.

    Returns:
        True if the daemon process is running, otherwise False.
    """
    running_pid = get_daemon_pid_if_running(pid_file)
    return running_pid is not None

daemonize(pid_file, log_file=None, working_directory='/')

Decorator that executes the decorated function as a daemon process.

Use this decorator to easily transform any function into a daemon process.

Examples:

import time
from zenml.utils.daemon import daemonize


@daemonize(log_file='/tmp/daemon.log', pid_file='/tmp/daemon.pid')
def sleeping_daemon(period: int) -> None:
    print(f"I'm a daemon! I will sleep for {period} seconds.")
    time.sleep(period)
    print("Done sleeping, flying away.")

sleeping_daemon(period=30)

print("I'm the daemon's parent!.")
time.sleep(10) # just to prove that the daemon is running in parallel

Parameters:

Name Type Description Default
pid_file str

Path to the file where the PID of the daemon process will be stored.

required
log_file Optional[str]

file where stdout and stderr are redirected for the daemon process. If not supplied, the daemon will be silenced (i.e. have its stdout/stderr redirected to /dev/null).

None
working_directory str

working directory for the daemon process, defaults to the root directory.

'/'

Returns:

Type Description
Callable[[~F], ~F]

Decorated function that, when called, will detach from the current process and continue executing in the background, as a daemon process.

Source code in zenml/utils/daemon.py
def daemonize(
    pid_file: str,
    log_file: Optional[str] = None,
    working_directory: str = "/",
) -> Callable[[F], F]:
    """Decorator that executes the decorated function as a daemon process.

    Use this decorator to easily transform any function into a daemon
    process.

    Example:

    ```python
    import time
    from zenml.utils.daemon import daemonize


    @daemonize(log_file='/tmp/daemon.log', pid_file='/tmp/daemon.pid')
    def sleeping_daemon(period: int) -> None:
        print(f"I'm a daemon! I will sleep for {period} seconds.")
        time.sleep(period)
        print("Done sleeping, flying away.")

    sleeping_daemon(period=30)

    print("I'm the daemon's parent!.")
    time.sleep(10) # just to prove that the daemon is running in parallel
    ```

    Args:
        pid_file: path to the file where the PID of the daemon process will
            be stored.
        log_file: file where stdout and stderr are redirected for the daemon
            process. If not supplied, the daemon will be silenced (i.e. have
            its stdout/stderr redirected to /dev/null).
        working_directory: working directory for the daemon process,
            defaults to the root directory.

    Returns:
        Decorated function that, when called, will detach from the current
        process and continue executing in the background, as a daemon
        process.
    """
    import functools

    def inner_decorator(_func: F) -> F:
        # Keep the name and docstring of the decorated function.
        @functools.wraps(_func)
        def daemon(*args: Any, **kwargs: Any) -> None:
            """Standard daemonization of a process.

            Args:
                *args: Arguments to be passed to the decorated function.
                **kwargs: Keyword arguments to be passed to the decorated
                    function.
            """
            # flake8: noqa: C901
            if sys.platform == "win32":
                logger.error(
                    "Daemon functionality is currently not supported on Windows."
                )
            else:
                # Positional arguments go first; the call is equivalent to
                # unpacking them after the keyword arguments, but reads in
                # the order in which the arguments are actually bound.
                run_as_daemon(
                    _func,
                    *args,
                    log_file=log_file,
                    pid_file=pid_file,
                    working_directory=working_directory,
                    **kwargs,
                )

        return cast(F, daemon)

    return inner_decorator

get_daemon_pid_if_running(pid_file)

Read and return the PID value from a PID file.

It does this if the daemon process tracked by the PID file is running.

Parameters:

Name Type Description Default
pid_file str

Path to file containing the PID of the daemon process to check.

required

Returns:

Type Description
Optional[int]

The PID of the daemon process if it is running, otherwise None.

Source code in zenml/utils/daemon.py
def get_daemon_pid_if_running(pid_file: str) -> Optional[int]:
    """Read and return the PID value from a PID file.

    The PID is only returned if the daemon process tracked by the PID file
    is running.

    Args:
        pid_file: Path to file containing the PID of the daemon
            process to check.

    Returns:
        The PID of the daemon process if it is running, otherwise None.
        None is also returned if the PID file is missing, cannot be read or
        does not contain a valid integer.
    """
    try:
        with open(pid_file, "r") as f:
            pid = int(f.read().strip())
    except (IOError, FileNotFoundError):
        logger.debug(
            f"Daemon PID file '{pid_file}' does not exist or cannot be read."
        )
        return None
    except ValueError:
        # An empty or corrupted PID file cannot identify a process, so it is
        # treated the same way as a missing file instead of crashing.
        logger.debug(
            f"Daemon PID file '{pid_file}' does not contain a valid PID."
        )
        return None

    # A PID of 0 is falsy and is never reported as a running daemon.
    if not pid or not psutil.pid_exists(pid):
        logger.debug(f"Daemon with PID '{pid}' is no longer running.")
        return None

    logger.debug(f"Daemon with PID '{pid}' is running.")
    return pid

run_as_daemon(daemon_function, *args, pid_file, log_file=None, working_directory='/', **kwargs)

Runs a function as a daemon process.

Parameters:

Name Type Description Default
daemon_function ~F

The function to run as a daemon.

required
pid_file str

Path to file in which to store the PID of the daemon process.

required
log_file Optional[str]

Optional file to which the daemon's stdout/stderr will be redirected.

None
working_directory str

Working directory for the daemon process, defaults to the root directory.

'/'
args Any

Positional arguments to pass to the daemon function.

()
kwargs Any

Keyword arguments to pass to the daemon function.

{}

Exceptions:

Type Description
FileExistsError

If the PID file already exists.

Source code in zenml/utils/daemon.py
def run_as_daemon(
    daemon_function: F,
    *args: Any,
    pid_file: str,
    log_file: Optional[str] = None,
    working_directory: str = "/",
    **kwargs: Any,
) -> None:
    """Runs a function as a daemon process.

    The calling process forks and returns immediately. The forked process
    detaches (new session, second fork), redirects its standard streams,
    writes its PID to `pid_file`, registers cleanup handlers and then runs
    `daemon_function`.

    Args:
        daemon_function: The function to run as a daemon.
        pid_file: Path to file in which to store the PID of the daemon
            process.
        log_file: Optional file to which the daemon's stdout/stderr will be
            redirected. If not given, they are redirected to the null device.
        working_directory: Working directory for the daemon process,
            defaults to the root directory.
        args: Positional arguments to pass to the daemon function.
        kwargs: Keyword arguments to pass to the daemon function.

    Raises:
        FileExistsError: If the PID file already exists and the process it
            refers to is still running.
    """
    # convert to absolute path as we will change working directory later
    if pid_file:
        pid_file = os.path.abspath(pid_file)
    if log_file:
        log_file = os.path.abspath(log_file)

    # create parent directory if necessary
    if pid_file:
        dir_name = os.path.dirname(pid_file)
        if not os.path.exists(dir_name):
            # `exist_ok` avoids a crash if another process creates the
            # directory between the existence check and this call
            os.makedirs(dir_name, exist_ok=True)

    # check if PID file exists
    if pid_file and os.path.exists(pid_file):
        pid = get_daemon_pid_if_running(pid_file)
        if pid:
            raise FileExistsError(
                f"The PID file '{pid_file}' already exists and a daemon "
                f"process with the same PID '{pid}' is already running. "
                f"Please remove the PID file or kill the daemon process "
                f"before starting a new daemon."
            )
        logger.warning(
            f"Removing left over PID file '{pid_file}' from a previous "
            f"daemon process that didn't shut down correctly."
        )
        os.remove(pid_file)

    # first fork
    try:
        pid = os.fork()
        if pid > 0:
            # this is the process that called `run_as_daemon` so we
            # simply return so it can keep running
            return
    except OSError as e:
        logger.error("Unable to fork (error code: %d)", e.errno)
        sys.exit(1)

    # decouple from parent environment
    os.chdir(working_directory)
    os.setsid()
    os.umask(0o22)

    # second fork
    try:
        pid = os.fork()
        if pid > 0:
            # this is the parent of the future daemon process, kill it
            # so the daemon gets adopted by the init process
            sys.exit(0)
    except OSError as e:
        sys.stderr.write(f"Unable to fork (error code: {e.errno})")
        sys.exit(1)

    # redirect standard file descriptors to devnull (or the given logfile)
    devnull = "/dev/null"
    if hasattr(os, "devnull"):
        devnull = os.devnull

    devnull_fd = os.open(devnull, os.O_RDWR)
    log_fd = (
        os.open(log_file, os.O_CREAT | os.O_RDWR | os.O_APPEND)
        if log_file
        else None
    )
    out_fd = log_fd or devnull_fd

    os.dup2(devnull_fd, sys.stdin.fileno())
    os.dup2(out_fd, sys.stdout.fileno())
    os.dup2(out_fd, sys.stderr.fileno())

    if pid_file:
        # write the PID file
        with open(pid_file, "w+") as f:
            f.write(f"{os.getpid()}\n")

    # register actions in case this process exits/gets killed
    def cleanup() -> None:
        """Terminate child processes and remove the PID file."""
        sys.stderr.write("Cleanup: terminating children processes...\n")
        terminate_children()
        if pid_file and os.path.exists(pid_file):
            sys.stderr.write(f"Cleanup: removing PID file {pid_file}...\n")
            os.remove(pid_file)
        sys.stderr.flush()

    def sighndl(signum: int, frame: Optional[types.FrameType]) -> None:
        """Daemon signal handler.

        Args:
            signum: Signal number.
            frame: Frame object.
        """
        # NOTE(review): the handler only runs the cleanup and then returns;
        # it does not itself stop the process - confirm this is intended.
        sys.stderr.write(f"Handling signal {signum}...\n")
        cleanup()

    signal.signal(signal.SIGTERM, sighndl)
    signal.signal(signal.SIGINT, sighndl)
    atexit.register(cleanup)

    # finally run the actual daemon code
    daemon_function(*args, **kwargs)
    sys.exit(0)

stop_daemon(pid_file)

Stops a daemon process.

Parameters:

Name Type Description Default
pid_file str

Path to file containing the PID of the daemon process to kill.

required
Source code in zenml/utils/daemon.py
def stop_daemon(pid_file: str) -> None:
    """Stops a daemon process.

    The PID is read from the PID file and, if a process with that PID
    exists, it is asked to terminate. A missing, unreadable or invalid PID
    file only produces a warning.

    Args:
        pid_file: Path to file containing the PID of the daemon process to
            kill.
    """
    try:
        with open(pid_file, "r") as f:
            pid = int(f.read().strip())
    except (IOError, FileNotFoundError):
        logger.warning("Daemon PID file '%s' does not exist.", pid_file)
        return
    except ValueError:
        # An empty or corrupted PID file cannot identify a process to stop.
        logger.warning(
            "Daemon PID file '%s' does not contain a valid PID.", pid_file
        )
        return

    if psutil.pid_exists(pid):
        process = psutil.Process(pid)
        process.terminate()
    else:
        logger.warning("PID from '%s' does not exist.", pid_file)

terminate_children()

Terminate all processes that are children of the currently running process.

Source code in zenml/utils/daemon.py
def terminate_children() -> None:
    """Terminate the direct child processes of the current process.

    Every direct child is first asked to terminate. Children that are still
    alive after the wait timeout are killed, followed by a second wait on all
    children.
    """
    try:
        current_process = psutil.Process(os.getpid())
    except psutil.Error:
        # the current process could not be looked up, nothing to clean up
        return

    child_processes = current_process.children(recursive=False)
    for child in child_processes:
        sys.stderr.write(
            f"Terminating child process with PID {child.pid}...\n"
        )
        child.terminate()

    _, survivors = psutil.wait_procs(
        child_processes, timeout=CHILD_PROCESS_WAIT_TIMEOUT
    )
    for child in survivors:
        sys.stderr.write(f"Killing child process with PID {child.pid}...\n")
        child.kill()

    # wait once more so that killed children get a chance to exit
    psutil.wait_procs(child_processes, timeout=CHILD_PROCESS_WAIT_TIMEOUT)

dashboard_utils

Utility class to help with interacting with the dashboard.

get_run_url(run_name, pipeline_id=None)

Computes a dashboard url to directly view the run.

Parameters:

Name Type Description Default
run_name str

Name of the pipeline run.

required
pipeline_id Optional[uuid.UUID]

Optional pipeline_id, to be sent when available.

None

Returns:

Type Description
Optional[str]

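A direct url link to the pipeline run details page. If the client is not connected to a ZenML server (REST store), an empty string is returned; if the run does not exist, the link points to a runs overview page instead.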

Source code in zenml/utils/dashboard_utils.py
def get_run_url(
    run_name: str, pipeline_id: Optional[UUID] = None
) -> Optional[str]:
    """Computes a dashboard url to directly view the run.

    Args:
        run_name: Name of the pipeline run.
        pipeline_id: Optional pipeline_id, to be sent when available.

    Returns:
        An empty string if the client is not connected to a REST store.
        Otherwise the dashboard url: it points to the DAG page of the first
        run matching `run_name` if one exists, and to a runs overview page
        if no such run is found.
    """
    client = Client()

    # Dashboard links only exist when connected to a ZenML Server
    if client.zen_store.type != StoreType.REST:
        return ""

    base_url = client.zen_store.url
    matching_runs = client.zen_store.list_runs(run_name=run_name)

    # pick the overview page the link is rooted at
    if pipeline_id:
        path = f"/pipelines/{str(pipeline_id)}/runs"
    elif matching_runs:
        path = "/runs"
    else:
        path = "/pipelines/all-runs"

    # deep-link to the run itself when it is known
    if matching_runs:
        path += f"/{matching_runs[0].id}/dag"

    return base_url + path

print_run_url(run_name, pipeline_id=None)

Logs a dashboard url to directly view the run.

Parameters:

Name Type Description Default
run_name str

Name of the pipeline run.

required
pipeline_id Optional[uuid.UUID]

Optional pipeline_id, to be sent when available.

None
Source code in zenml/utils/dashboard_utils.py
def print_run_url(run_name: str, pipeline_id: Optional[UUID] = None) -> None:
    """Logs a dashboard url to directly view the run.

    For a REST store the url computed by `get_run_url` is logged (if it is
    not empty). For a SQL store a hint to start the local dashboard is
    logged instead. Nothing is logged for other store types.

    Args:
        run_name: Name of the pipeline run.
        pipeline_id: Optional pipeline_id, to be sent when available.
    """
    client = Client()

    if client.zen_store.type == StoreType.REST:
        dashboard_url = get_run_url(run_name, pipeline_id)
        if dashboard_url:
            logger.info(f"Dashboard URL: {dashboard_url}")
        return

    if client.zen_store.type == StoreType.SQL:
        # Connected to SQL Store Type, we're local
        logger.info(
            "Pipeline visualization can be seen in the ZenML Dashboard. "
            "Run `zenml up` to see your pipeline!"
        )

deprecation_utils

Deprecation utilities.

deprecate_pydantic_attributes(*attributes)

Utility function for deprecating and migrating pydantic attributes.

Usage: To use this, you can specify it on any pydantic BaseModel subclass like this (all the deprecated attributes need to be non-required):

from pydantic import BaseModel
from typing import Optional

class MyModel(BaseModel):
    deprecated: Optional[int] = None

    old_name: Optional[str] = None
    new_name: str

    _deprecation_validator = deprecate_pydantic_attributes(
        "deprecated", ("old_name", "new_name")
    )

Parameters:

Name Type Description Default
*attributes Union[str, Tuple[str, str]]

List of attributes to deprecate. This is either the name of the attribute to deprecate, or a tuple containing the name of the deprecated attribute and its replacement.

()

Returns:

Type Description
AnyClassMethod

Pydantic validator class method to be used on BaseModel subclasses to deprecate or migrate attributes.

Source code in zenml/utils/deprecation_utils.py
def deprecate_pydantic_attributes(
    *attributes: Union[str, Tuple[str, str]]
) -> "AnyClassMethod":
    """Utility function for deprecating and migrating pydantic attributes.

    **Usage**:
    To use this, you can specify it on any pydantic BaseModel subclass like
    this (all the deprecated attributes need to be non-required):

    ```python
    from pydantic import BaseModel
    from typing import Optional

    class MyModel(BaseModel):
        deprecated: Optional[int] = None

        old_name: Optional[str] = None
        new_name: str

        _deprecation_validator = deprecate_pydantic_attributes(
            "deprecated", ("old_name", "new_name")
        )
    ```

    Args:
        *attributes: List of attributes to deprecate. This is either the name
            of the attribute to deprecate, or a tuple containing the name of
            the deprecated attribute and its replacement.

    Returns:
        Pydantic validator class method to be used on BaseModel subclasses
        to deprecate or migrate attributes.
    """

    @root_validator(pre=True, allow_reuse=True)
    def _deprecation_validator(
        cls: Type[BaseModel], values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Pydantic validator function for deprecating pydantic attributes.

        Args:
            cls: The class on which the attributes are defined.
            values: All values passed at model initialization.

        Raises:
            AssertionError: If either the deprecated or replacement attribute
                don't exist.
            TypeError: If the deprecated attribute is a required attribute.
            ValueError: If the deprecated attribute and replacement attribute
                contain different values.

        Returns:
            Input values with potentially migrated values.
        """
        # attributes for which a warning was already logged on this class
        seen_warnings: Set[str] = getattr(
            cls, PREVIOUS_DEPRECATION_WARNINGS_ATTRIBUTE, set()
        )

        def _warn(message: str, attribute: str) -> None:
            """Logs and raises a warning for a deprecated attribute.

            The message is only logged the first time an attribute is seen,
            while the `DeprecationWarning` is emitted on every call.

            Args:
                message: The warning message.
                attribute: The name of the attribute.
            """
            if attribute not in seen_warnings:
                logger.warning(message)
                seen_warnings.add(attribute)

            warnings.warn(message, DeprecationWarning)

        for entry in attributes:
            old_name: str
            new_name: Optional[str]
            if isinstance(entry, str):
                old_name, new_name = entry, None
            else:
                old_name, new_name = entry

                assert (
                    new_name in cls.__fields__
                ), f"Unable to find attribute {new_name}."

            assert (
                old_name in cls.__fields__
            ), f"Unable to find attribute {old_name}."

            if cls.__fields__[old_name].required:
                raise TypeError(
                    f"Unable to deprecate attribute '{old_name}' "
                    f"of class {cls.__name__}. In order to deprecate an "
                    "attribute, it needs to be a non-required attribute. "
                    "To do so, mark the attribute with an `Optional[...] type "
                    "annotation."
                )

            # nothing to do if the deprecated attribute was not passed
            if values.get(old_name, None) is None:
                continue

            if new_name is None:
                warning_message = (
                    f"The attribute `{old_name}` of class "
                    f"`{cls.__name__}` will be deprecated soon."
                )
            else:
                warning_message = (
                    f"The attribute `{old_name}` of class "
                    f"`{cls.__name__}` will be deprecated soon. Use the "
                    f"attribute `{new_name}` instead."
                )
            _warn(message=warning_message, attribute=old_name)

            # without a replacement there is no value to migrate
            if new_name is None:
                continue

            if values.get(new_name, None) is None:
                logger.debug(
                    "Migrating value of deprecated attribute %s to "
                    "replacement attribute %s.",
                    old_name,
                    new_name,
                )
                values[new_name] = values.pop(old_name)
                continue

            # both attributes were passed: identical values are accepted
            # as-is, conflicting ones are an error
            if values[old_name] != values[new_name]:
                raise ValueError(
                    "Got different values for deprecated attribute "
                    f"{old_name} and replacement "
                    f"attribute {new_name}."
                )

        setattr(
            cls,
            PREVIOUS_DEPRECATION_WARNINGS_ATTRIBUTE,
            seen_warnings,
        )

        return values

    return _deprecation_validator

dict_utils

Util functions for dictionaries.

recursive_update(original, update)

Recursively updates a dictionary.

Parameters:

Name Type Description Default
original Dict[str, Any]

The dictionary to update.

required
update Dict[str, Any]

The dictionary containing the updated values.

required

Exceptions:

Type Description
TypeError

If the value types of original and update don't match.

Returns:

Type Description
Dict[str, Any]

The updated dictionary.

Source code in zenml/utils/dict_utils.py
def recursive_update(
    original: Dict[str, Any], update: Dict[str, Any]
) -> Dict[str, Any]:
    """Recursively updates a dictionary.

    Dictionary values in `update` are merged into the matching dictionary in
    `original`; all other values overwrite the original value. The
    `original` dictionary is modified in place.

    Args:
        original: The dictionary to update.
        update: The dictionary containing the updated values.

    Raises:
        TypeError: If the value types of original and update don't match.

    Returns:
        The updated dictionary.
    """
    for key, new_value in update.items():
        if not isinstance(new_value, dict):
            # plain values simply replace whatever was there before
            original[key] = new_value
            continue

        # a missing or falsy original value is treated as an empty dict
        current_value = original.get(key, None) or {}
        if not isinstance(current_value, dict):
            raise TypeError(
                f"Type of dictionary values for key {key} does not match "
                "in original and update dict (original="
                f"{type(current_value)}, update={type(new_value)})."
            )

        original[key] = recursive_update(current_value, new_value)

    return original

remove_none_values(dict_, recursive=False)

Removes all key-value pairs with None value.

Parameters:

Name Type Description Default
dict_ Dict[str, Any]

The dict from which the key-value pairs should be removed.

required
recursive bool

If True, will recursively remove None values in all child dicts.

False

Returns:

Type Description
Dict[str, Any]

The updated dictionary.

Source code in zenml/utils/dict_utils.py
def remove_none_values(
    dict_: Dict[str, Any], recursive: bool = False
) -> Dict[str, Any]:
    """Removes all key-value pairs with `None` value.

    A new dictionary is returned; the input dictionary is not modified.

    Args:
        dict_: The dict from which the key-value pairs should be removed.
        recursive: If `True`, will recursively remove `None` values in all
            child dicts.

    Returns:
        The updated dictionary.
    """
    cleaned: Dict[str, Any] = {}
    for key, value in dict_.items():
        if value is None:
            continue

        if recursive and isinstance(value, dict):
            # descend into child dictionaries only when requested
            value = remove_none_values(value, recursive=True)

        cleaned[key] = value

    return cleaned

docker_utils

Utility functions relating to Docker.

build_image(image_name, dockerfile, build_context_root=None, dockerignore=None, extra_files=(), **custom_build_options)

Builds a docker image.

Parameters:

Name Type Description Default
image_name str

The name to use for the built docker image.

required
dockerfile Union[str, List[str]]

Path to a dockerfile or a list of strings representing the Dockerfile lines/commands.

required
build_context_root Optional[str]

Optional path to a directory that will be sent to the Docker daemon as build context. If left empty, the Docker build context will be empty.

None
dockerignore Optional[str]

Optional path to a dockerignore file. If no value is given, the .dockerignore in the root of the build context will be used if it exists. Otherwise, all files inside build_context_root are included in the build context.

None
extra_files Sequence[Tuple[str, str]]

Additional files to include in the build context. The files should be passed as a tuple (filepath_inside_build_context, file_content) and will overwrite existing files in the build context if they share the same path.

()
**custom_build_options Any

Additional options that will be passed unmodified to the Docker build call when building the image. You can use this to for example specify build args or a target stage. See https://docker-py.readthedocs.io/en/stable/images.html#docker.models.images.ImageCollection.build for a full list of available options.

{}
Source code in zenml/utils/docker_utils.py
def build_image(
    image_name: str,
    dockerfile: Union[str, List[str]],
    build_context_root: Optional[str] = None,
    dockerignore: Optional[str] = None,
    extra_files: Sequence[Tuple[str, str]] = (),
    **custom_build_options: Any,
) -> None:
    """Builds a docker image.

    Args:
        image_name: The name to use for the built docker image.
        dockerfile: Path to a dockerfile or a list of strings representing the
            Dockerfile lines/commands.
        build_context_root: Optional path to a directory that will be sent to
            the Docker daemon as build context. If left empty, the Docker build
            context will be empty.
        dockerignore: Optional path to a dockerignore file. If no value is
            given, the .dockerignore in the root of the build context will be
            used if it exists. Otherwise, all files inside `build_context_root`
            are included in the build context.
        extra_files: Additional files to include in the build context. The
            files should be passed as a tuple
            (filepath_inside_build_context, file_content) and will overwrite
            existing files in the build context if they share the same path.
        **custom_build_options: Additional options that will be passed
            unmodified to the Docker build call when building the image. You
            can use this to for example specify build args or a target stage.
            See https://docker-py.readthedocs.io/en/stable/images.html#docker.models.images.ImageCollection.build
            for a full list of available options.
    """
    if not isinstance(dockerfile, str):
        # the Dockerfile was passed as a list of lines
        dockerfile_text = "\n".join(dockerfile)
    else:
        # the Dockerfile was passed as a path
        dockerfile_text = io_utils.read_file_contents_as_string(dockerfile)
        logger.info("Using Dockerfile `%s`.", os.path.abspath(dockerfile))

    context_archive = _create_custom_build_context(
        dockerfile_contents=dockerfile_text,
        build_context_root=build_context_root,
        dockerignore=dockerignore,
        extra_files=extra_files,
    )

    # defaults first so that custom options can override them:
    # - rm: don't remove intermediate containers to improve caching
    # - pull: always pull parent images
    build_options: Dict[str, Any] = {"rm": False, "pull": True}
    build_options.update(custom_build_options)

    logger.info("Building Docker image `%s`.", image_name)
    logger.debug("Docker build options: %s", build_options)

    logger.info("Building the image might take a while...")

    # The low-level client api is used directly, so the logs can be streamed
    low_level_api = DockerClient.from_env().images.client.api
    build_output = low_level_api.build(
        fileobj=context_archive,
        custom_context=True,
        tag=image_name,
        **build_options,
    )
    _process_stream(build_output)

    logger.info("Finished building Docker image `%s`.", image_name)

check_docker()

Checks if Docker is installed and running.

Returns:

Type Description
bool

True if Docker is installed and running, False otherwise.

Source code in zenml/utils/docker_utils.py
def check_docker() -> bool:
    """Checks if Docker is installed and running.

    Returns:
        `True` if a Docker client could be created from the environment and
        the Docker daemon answered a ping, `False` otherwise.
    """
    try:
        # Pinging the daemon fails if Docker is not reachable
        DockerClient.from_env().ping()
    except Exception:
        logger.debug("Docker is not running.", exc_info=True)
        return False

    return True

get_image_digest(image_name)

Gets the digest of an image.

Parameters:

Name Type Description Default
image_name str

Name of the image to get the digest for.

required

Returns:

Type Description
Optional[str]

Returns the repo digest for the given image if there exists exactly one. If there are zero or multiple repo digests, returns None.

Source code in zenml/utils/docker_utils.py
def get_image_digest(image_name: str) -> Optional[str]:
    """Gets the digest of an image.

    Args:
        image_name: Name of the image to get the digest for.

    Returns:
        Returns the repo digest for the given image if there exists exactly one.
        If there are zero or multiple repo digests, returns `None`.
    """
    image = DockerClient.from_env().images.get(image_name)
    repo_digests = image.attrs["RepoDigests"]

    # only an unambiguous (single) digest is reported
    if len(repo_digests) != 1:
        logger.debug(
            "Found zero or more repo digests for docker image '%s': %s",
            image_name,
            repo_digests,
        )
        return None

    return cast(str, repo_digests[0])

is_local_image(image_name)

Returns whether an image exists locally without having been pulled from a registry.

Parameters:

Name Type Description Default
image_name str

Name of the image to check.

required

Returns:

Type Description
bool

True if an image with this name exists locally and has no single repo digest (i.e. it was built locally rather than pulled from a registry), False otherwise.

Source code in zenml/utils/docker_utils.py
def is_local_image(image_name: str) -> bool:
    """Returns whether an image exists locally and was not pulled from a registry.

    Args:
        image_name: Name of the image to check.

    Returns:
        `True` if an image with this name is available locally and
        `get_image_digest` returns no repo digest for it (which is the case
        for locally built images), `False` otherwise.
    """
    docker_client = DockerClient.from_env()

    if not docker_client.images.list(name=image_name):
        # no image with this name found locally
        return False

    # An image with this name is available locally -> it counts as local if
    # it was built locally (in which case the repo digest is empty) instead
    # of being pulled from a repo
    return get_image_digest(image_name) is None

push_image(image_name)

Pushes an image to a container registry.

Parameters:

Name Type Description Default
image_name str

The full name (including a tag) of the image to push.

required

Returns:

Type Description
str

The Docker repository digest of the pushed image.

Exceptions:

Type Description
RuntimeError

If fetching the repository digest of the image failed.

Source code in zenml/utils/docker_utils.py
def push_image(image_name: str) -> str:
    """Pushes an image to a container registry.

    Args:
        image_name: The full name (including a tag) of the image to push.

    Returns:
        The Docker repository digest of the pushed image.

    Raises:
        RuntimeError: If fetching the repository digest of the image failed.
    """
    logger.info("Pushing Docker image `%s`.", image_name)
    docker_client = DockerClient.from_env()
    output_stream = docker_client.images.push(image_name, stream=True)
    aux_info = _process_stream(output_stream)
    logger.info("Finished pushing Docker image.")

    # Strip the tag from the image name. A tag never contains a slash, so a
    # colon followed by a slash belongs to a registry `host:port` prefix and
    # not to a tag. Names without a tag are used unchanged instead of
    # failing after the push has already happened.
    image_name_without_tag, separator, tag = image_name.rpartition(":")
    if not separator or "/" in tag:
        image_name_without_tag = image_name

    # the digest is searched starting from the most recent stream entry
    for info in reversed(aux_info):
        try:
            repo_digest = info["Digest"]
        except KeyError:
            continue
        return f"{image_name_without_tag}@{repo_digest}"

    raise RuntimeError(
        f"Unable to find repo digest after pushing image {image_name}."
    )

tag_image(image_name, target)

Tags an image.

Parameters:

Name Type Description Default
image_name str

The name of the image to tag.

required
target str

The full target name including a tag.

required
Source code in zenml/utils/docker_utils.py
def tag_image(image_name: str, target: str) -> None:
    """Tags an image.

    Args:
        image_name: The name of the image to tag.
        target: The full target name including a tag.
    """
    source_image = DockerClient.from_env().images.get(image_name)
    source_image.tag(target)

enum_utils

Util functions for enums.

StrEnum (str, Enum)

Base enum type for string enum values.

Source code in zenml/utils/enum_utils.py
class StrEnum(str, Enum):
    """Base enum type for string enum values."""

    def __str__(self) -> str:
        """Returns the enum string value.

        Returns:
            The enum string value.
        """
        string_value: str = self.value
        return string_value

    @classmethod
    def names(cls) -> List[str]:
        """Get all enum names as a list of strings.

        Returns:
            A list of all enum names, in iteration order of the enum.
        """
        member_names = [member.name for member in cls]
        return member_names

    @classmethod
    def values(cls) -> List[str]:
        """Get all enum values as a list of strings.

        Returns:
            A list of all enum values, in iteration order of the enum.
        """
        member_values = [member.value for member in cls]
        return member_values

filesync_model

Filesync utils for ZenML.

FileSyncModel (BaseModel) pydantic-model

Pydantic model synchronized with a configuration file.

Use this class as a base Pydantic model that is automatically synchronized with a configuration file on disk.

This class overrides the `__setattr__` and `__getattribute__` magic methods to ensure that the FileSyncModel instance acts as an in-memory cache of the information stored in the associated configuration file.

Source code in zenml/utils/filesync_model.py
class FileSyncModel(BaseModel):
    """Pydantic model synchronized with a configuration file.

    Use this class as a base Pydantic model that is automatically synchronized
    with a configuration file on disk.

    This class overrides the __setattr__ and __getattribute__ magic methods to
    ensure that the FileSyncModel instance acts as an in-memory cache of the
    information stored in the associated configuration file: setting a public
    attribute writes the file, and reading a public attribute reloads the
    file first if its modification time changed.
    """

    # path of the configuration file backing this model
    _config_file: str
    # modification time of the configuration file recorded at the last
    # write/load, or None if the file was never written/loaded
    _config_file_timestamp: Optional[float]

    def __init__(self, config_file: str, **kwargs: Any) -> None:
        """Create a FileSyncModel instance synchronized with a configuration file on disk.

        Args:
            config_file: configuration file path. If the file exists, the model
                will be initialized with the values from the file.
            **kwargs: additional keyword arguments to pass to the Pydantic model
                constructor. If supplied, these values will override those
                loaded from the configuration file.
        """
        config_dict = {}
        if fileio.exists(config_file):
            config_dict = yaml_utils.read_yaml(config_file)

        # the private attributes are assigned before the Pydantic constructor
        # is called
        self._config_file = config_file
        self._config_file_timestamp = None

        # explicitly passed values take precedence over the file contents
        config_dict.update(kwargs)
        super(FileSyncModel, self).__init__(**config_dict)

        # write the configuration file to disk, to reflect new attributes
        # and schema changes
        self.write_config()

    def __setattr__(self, key: str, value: Any) -> None:
        """Sets an attribute on the model and persists it in the configuration file.

        Attributes whose name starts with an underscore are set on the model
        but do not trigger a write of the configuration file.

        Args:
            key: attribute name.
            value: attribute value.
        """
        super(FileSyncModel, self).__setattr__(key, value)
        if key.startswith("_"):
            return
        self.write_config()

    def __getattribute__(self, key: str) -> Any:
        """Gets an attribute value for a specific key.

        Before a public attribute stored in the instance `__dict__` is
        returned, the configuration file is reloaded if necessary.

        Args:
            key: attribute name.

        Returns:
            attribute value.
        """
        # names starting with an underscore skip the reload; this also covers
        # the `self.__dict__` lookup itself, which goes through this method
        if not key.startswith("_") and key in self.__dict__:
            self.load_config()
        return super(FileSyncModel, self).__getattribute__(key)

    def write_config(self) -> None:
        """Writes the model to the configuration file.

        The modification time of the written file is recorded afterwards.
        """
        # round-trip through JSON to get a plain dictionary of the model
        config_dict = json.loads(self.json())
        yaml_utils.write_yaml(self._config_file, config_dict)
        self._config_file_timestamp = os.path.getmtime(self._config_file)

    def load_config(self) -> None:
        """Loads the model from the configuration file on disk.

        Nothing happens if the file does not exist or if its modification
        time equals the recorded timestamp.
        """
        if not fileio.exists(self._config_file):
            return

        # don't reload the configuration if the file hasn't
        # been updated since the last load
        file_timestamp = os.path.getmtime(self._config_file)
        if file_timestamp == self._config_file_timestamp:
            return

        # only log reloads, not the very first load
        if self._config_file_timestamp is not None:
            logger.info(f"Reloading configuration file {self._config_file}")

        # refresh the model from the configuration file values
        config_dict = yaml_utils.read_yaml(self._config_file)
        for key, value in config_dict.items():
            # the parent implementation is used so that refreshing a value
            # does not write the configuration file again
            super(FileSyncModel, self).__setattr__(key, value)

        self._config_file_timestamp = file_timestamp

    class Config:
        """Pydantic configuration class."""

        # all attributes with leading underscore are private and therefore
        # are mutable and not included in serialization
        underscore_attrs_are_private = True
Config

Pydantic configuration class.

Source code in zenml/utils/filesync_model.py
class Config:
    """Pydantic configuration class."""

    # Treat every attribute with a leading underscore as private: such
    # attributes stay mutable and are left out of serialization.
    underscore_attrs_are_private = True
__getattribute__(self, key) special

Gets an attribute value for a specific key.

Parameters:

Name Type Description Default
key str

attribute name.

required

Returns:

Type Description
Any

attribute value.

Source code in zenml/utils/filesync_model.py
def __getattribute__(self, key: str) -> Any:
    """Gets an attribute value for a specific key.

    For a public attribute that is stored on the instance, `load_config`
    is called first so that a changed configuration file on disk can be
    picked up before the value is returned.

    Args:
        key: attribute name.

    Returns:
        attribute value.
    """
    # The underscore check has to come first: evaluating `self.__dict__`
    # re-enters this method with the key "__dict__", and only the
    # short-circuit on the leading underscore keeps that from recursing.
    if not key.startswith("_") and key in self.__dict__:
        self.load_config()
    return super(FileSyncModel, self).__getattribute__(key)
__init__(self, config_file, **kwargs) special

Create a FileSyncModel instance synchronized with a configuration file on disk.

Parameters:

Name Type Description Default
config_file str

configuration file path. If the file exists, the model will be initialized with the values from the file.

required
**kwargs Any

additional keyword arguments to pass to the Pydantic model constructor. If supplied, these values will override those loaded from the configuration file.

{}
Source code in zenml/utils/filesync_model.py
def __init__(self, config_file: str, **kwargs: Any) -> None:
    """Create a FileSyncModel instance synchronized with a configuration file on disk.

    Args:
        config_file: configuration file path. If the file exists, the model
            will be initialized with the values from the file.
        **kwargs: additional keyword arguments to pass to the Pydantic model
            constructor. If supplied, these values will override those
            loaded from the configuration file.
    """
    # start from the values stored on disk, if there are any
    config_dict = {}
    if fileio.exists(config_file):
        config_dict = yaml_utils.read_yaml(config_file)

    # bookkeeping attributes; the leading underscore marks them as private
    # (see `Config.underscore_attrs_are_private`)
    self._config_file = config_file
    self._config_file_timestamp = None

    # explicitly passed keyword arguments take precedence over file values
    config_dict.update(kwargs)
    super(FileSyncModel, self).__init__(**config_dict)

    # write the configuration file to disk, to reflect new attributes
    # and schema changes
    self.write_config()
__setattr__(self, key, value) special

Sets an attribute on the model and persists it in the configuration file.

Parameters:

Name Type Description Default
key str

attribute name.

required
value Any

attribute value.

required
Source code in zenml/utils/filesync_model.py
def __setattr__(self, key: str, value: Any) -> None:
    """Assigns an attribute and syncs public attributes to the configuration file.

    Args:
        key: attribute name.
        value: attribute value.
    """
    super(FileSyncModel, self).__setattr__(key, value)
    # assignments to underscore-prefixed attributes do not trigger a write
    is_public = not key.startswith("_")
    if is_public:
        self.write_config()
load_config(self)

Loads the model from the configuration file on disk.

Source code in zenml/utils/filesync_model.py
def load_config(self) -> None:
    """Refreshes the model from the configuration file on disk.

    Nothing happens if the file does not exist or if its modification
    time equals the timestamp recorded on the model.
    """
    config_path = self._config_file
    if not fileio.exists(config_path):
        return

    # an unchanged modification time means there is nothing new to read
    current_mtime = os.path.getmtime(config_path)
    if current_mtime == self._config_file_timestamp:
        return

    if self._config_file_timestamp is not None:
        logger.info(f"Reloading configuration file {self._config_file}")

    # copy the values from the file onto the model through the parent
    # class setter
    set_attribute = super(FileSyncModel, self).__setattr__
    for name, new_value in yaml_utils.read_yaml(config_path).items():
        set_attribute(name, new_value)

    self._config_file_timestamp = current_mtime
write_config(self)

Writes the model to the configuration file.

Source code in zenml/utils/filesync_model.py
def write_config(self) -> None:
    """Persists the current model state to the configuration file.

    The model is serialized through its JSON representation, written out
    as YAML, and the file's modification time is recorded afterwards.
    """
    serialized = self.json()
    yaml_utils.write_yaml(self._config_file, json.loads(serialized))
    # remember the modification time of the file that was just written
    self._config_file_timestamp = os.path.getmtime(self._config_file)

io_utils

Various utility functions for the io module.

convert_to_str(path)

Converts a PathType to a str using UTF-8.

Parameters:

Name Type Description Default
path PathType

Path to convert.

required

Returns:

Type Description
str

Converted path.

Source code in zenml/utils/io_utils.py
def convert_to_str(path: "PathType") -> str:
    """Returns the given path as a `str`.

    Args:
        path: Path to convert; anything that is not already a `str` is
            decoded as UTF-8.

    Returns:
        Converted path.
    """
    return path if isinstance(path, str) else path.decode("utf-8")

copy_dir(source_dir, destination_dir, overwrite=False)

Copies dir from source to destination.

Parameters:

Name Type Description Default
source_dir str

Path to copy from.

required
destination_dir str

Path to copy to.

required
overwrite bool

Boolean. If false, function throws an error before overwrite.

False
Source code in zenml/utils/io_utils.py
def copy_dir(
    source_dir: str, destination_dir: str, overwrite: bool = False
) -> None:
    """Copies dir from source to destination.

    Directories are copied recursively. If the destination directory is itself
    a direct child of the source directory, that one entry is skipped and all
    other entries are still copied.

    Args:
        source_dir: Path to copy from.
        destination_dir: Path to copy to.
        overwrite: Boolean. If false, function throws an error before overwrite.
    """
    for source_file in listdir(source_dir):
        entry_name = convert_to_str(source_file)
        source_path = os.path.join(source_dir, entry_name)
        destination_path = os.path.join(destination_dir, entry_name)
        if isdir(source_path):
            if source_path == destination_dir:
                # the destination is a subdirectory of the source: skip only
                # this entry to avoid an infinite loop. Returning here would
                # silently drop every entry listed after it.
                continue
            copy_dir(source_path, destination_path, overwrite)
        else:
            # make sure the target directory exists before copying the file
            create_dir_recursive_if_not_exists(
                os.path.dirname(destination_path)
            )
            copy(str(source_path), str(destination_path), overwrite)

create_dir_if_not_exists(dir_path)

Creates directory if it does not exist.

Parameters:

Name Type Description Default
dir_path str

Local path in filesystem.

required
Source code in zenml/utils/io_utils.py
def create_dir_if_not_exists(dir_path: str) -> None:
    """Makes sure a directory exists, creating it when it is missing.

    Args:
        dir_path: Local path in filesystem.
    """
    if isdir(dir_path):
        return
    mkdir(dir_path)

create_dir_recursive_if_not_exists(dir_path)

Creates directory recursively if it does not exist.

Parameters:

Name Type Description Default
dir_path str

Local path in filesystem.

required
Source code in zenml/utils/io_utils.py
def create_dir_recursive_if_not_exists(dir_path: str) -> None:
    """Makes sure a directory exists, creating it via `makedirs` when missing.

    Args:
        dir_path: Local path in filesystem.
    """
    if isdir(dir_path):
        return
    makedirs(dir_path)

create_file_if_not_exists(file_path, file_contents='{}')

Creates file if it does not exist.

Parameters:

Name Type Description Default
file_path str

Local path in filesystem.

required
file_contents str

Contents of file.

'{}'
Source code in zenml/utils/io_utils.py
def create_file_if_not_exists(
    file_path: str, file_contents: str = "{}"
) -> None:
    """Writes a new file with the given contents unless the path already exists.

    The parent directory is created first when it is missing.

    Args:
        file_path: Local path in filesystem.
        file_contents: Contents of file.
    """
    if exists(file_path):
        return

    target = Path(file_path)
    create_dir_recursive_if_not_exists(str(target.parent))
    with open(str(target), "w") as handle:
        handle.write(file_contents)

find_files(dir_path, pattern)

Find files in a directory that match pattern.

Parameters:

Name Type Description Default
dir_path PathType

Path to directory.

required
pattern str

pattern like *.png.

required

Yields:

Type Description
Iterable[str]

All matching filenames if found.

Source code in zenml/utils/io_utils.py
def find_files(dir_path: "PathType", pattern: str) -> Iterable[str]:
    """Walks a directory tree and yields the files whose name matches a pattern.

    Args:
        dir_path: Path to directory.
        pattern: pattern like *.png.

    Yields:
        All matching filenames if found.
    """
    for root, _, files in walk(dir_path):
        for basename in files:
            # only the file name, not the full path, is matched
            if not fnmatch.fnmatch(convert_to_str(basename), pattern):
                continue
            yield os.path.join(convert_to_str(root), convert_to_str(basename))

get_global_config_directory()

Gets the global config directory for ZenML.

Returns:

Type Description
str

The global config directory for ZenML.

Source code in zenml/utils/io_utils.py
def get_global_config_directory() -> str:
    """Determines the global config directory for ZenML.

    A non-empty value of the `ENV_ZENML_CONFIG_PATH` environment variable
    takes precedence over the application directory reported by click.

    Returns:
        The global config directory for ZenML.
    """
    override = os.getenv(ENV_ZENML_CONFIG_PATH)
    if not override:
        return click.get_app_dir(APP_NAME)
    return str(Path(override).resolve())

get_grandparent(dir_path)

Get grandparent of dir.

Parameters:

Name Type Description Default
dir_path str

Path to directory.

required

Returns:

Type Description
str

The name (stem) of the input path's parent's parent.

Source code in zenml/utils/io_utils.py
def get_grandparent(dir_path: str) -> str:
    """Returns the stem of the grandparent of a path.

    Args:
        dir_path: Path to directory.

    Returns:
        The stem (final component without suffix) of the input path's
        parent's parent.
    """
    parent_dir = Path(dir_path).parent
    grandparent_dir = parent_dir.parent
    return grandparent_dir.stem

get_parent(dir_path)

Get parent of dir.

Parameters:

Name Type Description Default
dir_path str

Path to directory.

required

Returns:

Type Description
str

Parent (stem) of the dir as a string.

Source code in zenml/utils/io_utils.py
def get_parent(dir_path: str) -> str:
    """Returns the stem of the parent of a path.

    Args:
        dir_path: Path to directory.

    Returns:
        Parent (stem) of the dir as a string.
    """
    parent_dir = Path(dir_path).parent
    return parent_dir.stem

is_remote(path)

Returns True if the path starts with a remote filesystem prefix (the path itself is not checked for existence).

Parameters:

Name Type Description Default
path str

Any path as a string.

required

Returns:

Type Description
bool

True if remote path, else False.

Source code in zenml/utils/io_utils.py
def is_remote(path: str) -> bool:
    """Tells whether a path starts with one of the remote filesystem prefixes.

    Args:
        path: Any path as a string.

    Returns:
        True if remote path, else False.
    """
    for prefix in REMOTE_FS_PREFIX:
        if path.startswith(prefix):
            return True
    return False

is_root(path)

Returns true if path has no parent in local filesystem.

Parameters:

Name Type Description Default
path str

Local path in filesystem.

required

Returns:

Type Description
bool

True if root, else False.

Source code in zenml/utils/io_utils.py
def is_root(path: str) -> bool:
    """Tells whether a path is its own parent.

    Args:
        path: Local path in filesystem.

    Returns:
        True if root, else False.
    """
    candidate = Path(path)
    return candidate.parent == candidate

read_file_contents_as_string(file_path)

Reads contents of file.

Parameters:

Name Type Description Default
file_path str

Path to file.

required

Returns:

Type Description
str

Contents of file.

Exceptions:

Type Description
FileNotFoundError

If file does not exist.

Source code in zenml/utils/io_utils.py
def read_file_contents_as_string(file_path: str) -> str:
    """Returns the full contents of a file.

    Args:
        file_path: Path to file.

    Returns:
        Contents of file.

    Raises:
        FileNotFoundError: If file does not exist.
    """
    if exists(file_path):
        with open(file_path) as f:
            return f.read()  # type: ignore[no-any-return]
    raise FileNotFoundError(f"{file_path} does not exist!")

resolve_relative_path(path)

Takes relative path and resolves it absolutely.

Parameters:

Name Type Description Default
path str

Local path in filesystem.

required

Returns:

Type Description
str

Resolved path.

Source code in zenml/utils/io_utils.py
def resolve_relative_path(path: str) -> str:
    """Resolves a local path to an absolute one; remote paths pass through.

    Args:
        path: Local path in filesystem.

    Returns:
        Resolved path.
    """
    # remote paths are returned untouched
    return path if is_remote(path) else str(Path(path).resolve())

write_file_contents_as_string(file_path, content)

Writes contents of file.

Parameters:

Name Type Description Default
file_path str

Path to file.

required
content str

Contents of file.

required
Source code in zenml/utils/io_utils.py
def write_file_contents_as_string(file_path: str, content: str) -> None:
    """Opens a file in write mode and writes the given string to it.

    Args:
        file_path: Path to file.
        content: Contents of file.
    """
    with open(file_path, "w") as handle:
        handle.write(content)

materializer_utils

Util functions for models and materializers.

load_model_from_metadata(model_uri)

Load a zenml model artifact from a YAML metadata file.

This function is used to load information from a Yaml file that was created by the save_model_metadata function. The information in the Yaml file is used to load the model into memory in the inference environment.

The values involved are: model_uri, the URI of the model checkpoint/files to load; datatype, the model type, given as the path to the model class; and materializer, the materializer class, given as the path to the materializer class.

Parameters:

Name Type Description Default
model_uri str

the artifact to extract the metadata from.

required

Returns:

Type Description
Any

The ML model object loaded into memory.

Source code in zenml/utils/materializer_utils.py
def load_model_from_metadata(model_uri: str) -> Any:
    """Load a zenml model artifact using its YAML metadata file.

    This function is used to load information from a Yaml file that was created
    by the save_model_metadata function. The information in the Yaml file is
    used to load the model into memory in the inference environment.

    The metadata file is looked up under `model_uri` and provides two entries:
    datatype: the model type. This is the path to the model class.
    materializer: the materializer class. This is the path to the
        materializer class.

    Args:
        model_uri: the URI of the model checkpoint/files to load; the
            metadata file is read from this location.

    Returns:
        The ML model object loaded into memory.
    """
    # the file is opened only to obtain its name, which is then handed to
    # `read_yaml`
    with fileio.open(
        os.path.join(model_uri, MODEL_METADATA_YAML_FILE_NAME), "r"
    ) as f:
        metadata = read_yaml(f.name)
    # rebuild an artifact that points at the model files and carries the
    # datatype/materializer source paths read from the metadata
    model_artifact = Artifact()
    model_artifact.uri = model_uri
    model_artifact.properties[METADATA_DATATYPE].string_value = metadata[
        METADATA_DATATYPE
    ]
    model_artifact.properties[METADATA_MATERIALIZER].string_value = metadata[
        METADATA_MATERIALIZER
    ]
    # resolve both source paths to classes
    materializer_class = source_utils.load_source_path_class(
        model_artifact.properties[METADATA_MATERIALIZER].string_value
    )
    model_class = source_utils.load_source_path_class(
        model_artifact.properties[METADATA_DATATYPE].string_value
    )
    materializer_object: BaseMaterializer = materializer_class(model_artifact)
    model = materializer_object.handle_input(model_class)
    # torch is optional: when it is installed and the model class is a
    # torch module, `eval()` is called on the model; a missing torch
    # installation is ignored
    try:
        import torch.nn as nn

        if issubclass(model_class, nn.Module):  # type: ignore
            model.eval()
    except ImportError:
        pass
    logger.debug(f"Model loaded successfully :\n{model}")
    return model

model_from_model_artifact(model_artifact)

Load model to memory from a model artifact.

Parameters:

Name Type Description Default
model_artifact ModelArtifact

The model artifact to load.

required

Returns:

Type Description
Any

The ML model object loaded into memory.

Source code in zenml/utils/materializer_utils.py
def model_from_model_artifact(model_artifact: ModelArtifact) -> Any:
    """Loads the model referenced by a model artifact into memory.

    The materializer and model classes are resolved from the source paths
    stored on the artifact, and the materializer is asked to produce an
    object of the model class.

    Args:
        model_artifact: The model artifact to load.

    Returns:
        The ML model object loaded into memory.
    """
    materializer_source = model_artifact.materializer
    materializer_cls = source_utils.load_source_path_class(materializer_source)
    datatype_cls = source_utils.load_source_path_class(model_artifact.datatype)

    model = materializer_cls(model_artifact).handle_input(datatype_cls)
    logger.debug(f"Model loaded successfully :\n{model}")
    return model

save_model_metadata(model_artifact)

Save a zenml model artifact metadata to a YAML file.

This function is used to extract and save information from a zenml model artifact such as the model type and materializer. The extracted information will be the key to loading the model into memory in the inference environment.

The saved metadata has two entries: datatype, the model type, given as the path to the model class; and materializer, the materializer class, given as the path to the materializer class.

Parameters:

Name Type Description Default
model_artifact ArtifactModel

the artifact to extract the metadata from.

required

Returns:

Type Description
str

The path to the temporary file where the model metadata is saved

Source code in zenml/utils/materializer_utils.py
def save_model_metadata(model_artifact: ArtifactModel) -> str:
    """Writes the metadata of a zenml model artifact to a temporary YAML file.

    The model type and the materializer of the artifact are extracted and
    saved. This information is what is needed to load the model into memory
    in the inference environment.

    The saved entries are:
    datatype: the model type. This is the path to the model class.
    materializer: the materializer class. This is the path to the
        materializer class.

    Args:
        model_artifact: the artifact to extract the metadata from.

    Returns:
        The path to the temporary file where the model metadata is saved
    """
    metadata = {
        METADATA_DATATYPE: model_artifact.data_type,
        METADATA_MATERIALIZER: model_artifact.materializer,
    }

    # delete=False keeps the file around after the handle is closed
    temp_file = tempfile.NamedTemporaryFile(
        mode="w", suffix=".yaml", delete=False
    )
    with temp_file as f:
        metadata_path = f.name
        write_yaml(metadata_path, metadata)
    return metadata_path

networking_utils

Utility functions for networking.

find_available_port()

Finds a local random unoccupied TCP port.

Returns:

Type Description
int

A random unoccupied TCP port.

Source code in zenml/utils/networking_utils.py
def find_available_port() -> int:
    """Asks the OS for an unoccupied local TCP port.

    A socket is bound to port 0 on 127.0.0.1 and the port number that was
    actually assigned is read back before the socket is closed.

    Returns:
        A random unoccupied TCP port.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with sock:
        sock.bind(("127.0.0.1", 0))
        assigned_port = sock.getsockname()[1]

    return cast(int, assigned_port)

port_available(port, address='127.0.0.1')

Checks if a local port is available.

Parameters:

Name Type Description Default
port int

TCP port number

required
address str

IP address on the local machine

'127.0.0.1'

Returns:

Type Description
bool

True if the port is available, otherwise False

Source code in zenml/utils/networking_utils.py
def port_available(port: int, address: str = "127.0.0.1") -> bool:
    """Checks whether a local port can be bound.

    Args:
        port: TCP port number
        address: IP address on the local machine

    Returns:
        True if the port is available, otherwise False
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # The SO_REUSEPORT socket option is not supported on Windows.
            if sys.platform != "win32":
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            sock.bind((address, port))
    except socket.error as e:
        logger.debug("Port %d unavailable on %s: %s", port, address, e)
        return False
    else:
        return True

port_is_open(hostname, port)

Check if a TCP port is open on a remote host.

Parameters:

Name Type Description Default
hostname str

hostname of the remote machine

required
port int

TCP port number

required

Returns:

Type Description
bool

True if the port is open, False otherwise

Source code in zenml/utils/networking_utils.py
def port_is_open(hostname: str, port: int) -> bool:
    """Tries a TCP connection to find out whether a port is open on a host.

    Args:
        hostname: hostname of the remote machine
        port: TCP port number

    Returns:
        True if the port is open, False otherwise
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            # connect_ex reports failures as a non-zero error code
            return sock.connect_ex((hostname, port)) == 0
    except socket.error as e:
        logger.debug(
            f"Error checking TCP port {port} on host {hostname}: {str(e)}"
        )
        return False

replace_internal_hostname_with_localhost(hostname)

Replaces an internal Docker or K3D hostname with localhost.

Localhost URLs that are directly accessible on the host machine are not accessible from within a Docker or K3D container running on that same machine, but there are special hostnames featured by both Docker (host.docker.internal) and K3D (host.k3d.internal) that can be used to access host services from within the containers.

Use this method to replace one of these special hostnames with localhost if used outside a container or in a container where special hostnames are not available.

Parameters:

Name Type Description Default
hostname str

The hostname to replace.

required

Returns:

Type Description
str

The original or replaced hostname.

Source code in zenml/utils/networking_utils.py
def replace_internal_hostname_with_localhost(hostname: str) -> str:
    """Swaps an internal Docker or K3D hostname for one that is reachable.

    Localhost URLs that are directly accessible on the host machine are not
    accessible from within a Docker or K3D container running on that same
    machine, but there are special hostnames featured by both Docker
    (`host.docker.internal`) and K3D (`host.k3d.internal`) that can be used to
    access host services from within the containers.

    Any other hostname is returned unchanged. Inside a container, the first
    special hostname that resolves is returned; otherwise the result is
    `127.0.0.1`.

    Args:
        hostname: The hostname to replace.

    Returns:
        The original or replaced hostname.
    """
    internal_hostnames = ("host.docker.internal", "host.k3d.internal")
    if hostname not in internal_hostnames:
        return hostname

    if Environment.in_container():
        # Probe the special hostnames in order and settle for the first one
        # that can be resolved from inside the container.
        for internal_hostname in internal_hostnames:
            try:
                socket.gethostbyname(internal_hostname)
            except socket.gaierror:
                continue

            if internal_hostname != hostname:
                logger.debug(
                    f"Replacing internal hostname {hostname} with "
                    f"{internal_hostname}"
                )
            return internal_hostname

    logger.debug(f"Replacing internal hostname {hostname} with localhost.")

    return "127.0.0.1"

replace_localhost_with_internal_hostname(url)

Replaces the localhost with an internal Docker or K3D hostname in a given URL.

Localhost URLs that are directly accessible on the host machine are not accessible from within a Docker or K3D container running on that same machine, but there are special hostnames featured by both Docker (host.docker.internal) and K3D (host.k3d.internal) that can be used to access host services from within the containers.

Use this method to attempt to replace localhost in a URL with one of these special hostnames, if they are available inside a container.

Parameters:

Name Type Description Default
url str

The URL to update.

required

Returns:

Type Description
str

The updated URL.

Source code in zenml/utils/networking_utils.py
def replace_localhost_with_internal_hostname(url: str) -> str:
    """Replaces the localhost with an internal Docker or K3D hostname in a given URL.

    Localhost URLs that are directly accessible on the host machine are not
    accessible from within a Docker or K3D container running on that same
    machine, but there are special hostnames featured by both Docker
    (`host.docker.internal`) and K3D (`host.k3d.internal`) that can be used to
    access host services from within the containers.

    Use this method to attempt to replace `localhost` in a URL with one of these
    special hostnames, if they are available inside a container.

    Only the host part of the URL is rewritten; credentials and the port are
    kept as they are. The URL is returned unchanged outside of a container,
    when its host is not `localhost`/`127.0.0.1`, or when none of the special
    hostnames can be resolved.

    Args:
        url: The URL to update.

    Returns:
        The updated URL.
    """
    if not Environment.in_container():
        return url

    parsed_url = urlparse(url)
    hostname = parsed_url.hostname
    if hostname is None or hostname not in ("localhost", "127.0.0.1"):
        return url

    # `hostname` is lower-cased by urlparse and the netloc may also carry
    # `user:password@` credentials, so a plain string replace on the netloc
    # can miss the host or rewrite the credentials. Split the netloc
    # explicitly and swap only the host that follows the last "@".
    userinfo, separator, host_and_port = parsed_url.netloc.rpartition("@")
    port_suffix = host_and_port[len(hostname) :]

    for internal_hostname in (
        "host.docker.internal",
        "host.k3d.internal",
    ):
        try:
            socket.gethostbyname(internal_hostname)
        except socket.gaierror:
            # this special hostname is not available here, try the next one
            continue

        parsed_url = parsed_url._replace(
            netloc=f"{userinfo}{separator}{internal_hostname}{port_suffix}"
        )
        logger.debug(
            f"Replacing localhost with {internal_hostname} in URL: "
            f"{url}"
        )
        return parsed_url.geturl()

    return url

scan_for_available_port(start=8000, stop=65535)

Scan the local machine for an available TCP port in the given range.

Parameters:

Name Type Description Default
start int

the beginning of the port range value to scan

8000
stop int

the (inclusive) end of the port range value to scan

65535

Returns:

Type Description
Optional[int]

The first available port in the given range, or None if no available port is found.

Source code in zenml/utils/networking_utils.py
def scan_for_available_port(
    start: int = SCAN_PORT_RANGE[0], stop: int = SCAN_PORT_RANGE[1]
) -> Optional[int]:
    """Looks for the first available local port in the given range.

    Args:
        start: the beginning of the port range value to scan
        stop: the (inclusive) end of the port range value to scan

    Returns:
        The first available port in the given range, or None if no available
        port is found.
    """
    candidates = range(start, stop + 1)
    # the generator is lazy, so probing stops at the first available port
    first_free = next((p for p in candidates if port_available(p)), None)
    if first_free is None:
        logger.debug(
            "No free TCP ports found in the range %d - %d",
            start,
            stop,
        )
    return first_free

pipeline_docker_image_builder

Implementation of Docker image builds to run ZenML pipelines.

PipelineDockerImageBuilder

Builds Docker images to run a ZenML pipeline.

Usage:

class MyStackComponent(StackComponent, PipelineDockerImageBuilder):
    def method_that_requires_docker_image(self):
        # build the pipeline image and push it to the container registry
        image_identifier = self.build_and_push_docker_image(...)
        # use the returned image identifier to pull or run the pushed image
Source code in zenml/utils/pipeline_docker_image_builder.py
class PipelineDockerImageBuilder:
    """Builds Docker images to run a ZenML pipeline.

    **Usage**:
    ```python
    class MyStackComponent(StackComponent, PipelineDockerImageBuilder):
        def method_that_requires_docker_image(self):
            image_identifier = self.build_and_push_docker_image(...)
            # use the image ID
    ```
    """

    def build_and_push_docker_image(
        self,
        deployment: "PipelineDeployment",
        stack: "Stack",
        entrypoint: Optional[str] = None,
    ) -> str:
        """Builds and pushes a Docker image to run a pipeline.

        Use the image name returned by this method whenever you need to uniquely
        reference the pushed image in order to pull or run it.

        Args:
            deployment: The pipeline deployment for which the image should be
                built.
            stack: The stack on which the pipeline will be deployed.
            entrypoint: Entrypoint to use for the final image. If left empty,
                no entrypoint will be included in the image.

        Returns:
            The Docker repository digest of the pushed image.

        Raises:
            RuntimeError: If the stack doesn't contain a container registry.
        """
        container_registry = stack.container_registry
        if not container_registry:
            raise RuntimeError(
                "Unable to build and push Docker image because stack "
                f"`{stack.name}` has no container registry."
            )

        target_image_name = self.get_target_image_name(
            deployment=deployment, container_registry=container_registry
        )

        self.build_docker_image(
            target_image_name=target_image_name,
            deployment=deployment,
            stack=stack,
            entrypoint=entrypoint,
        )

        return container_registry.push_image(target_image_name)

    @staticmethod
    def get_target_image_name(
        deployment: "PipelineDeployment",
        container_registry: Optional["BaseContainerRegistry"] = None,
    ) -> str:
        """Returns the target image name.

        The name has the form `<target_repository>:<pipeline_name>`. If a
        container registry is given, the name is additionally prefixed with
        the registry URI.

        Args:
            deployment: The pipeline deployment for which the target image name
                should be returned.
            container_registry: Optional container registry to which this
                image will be pushed.

        Returns:
            The docker image name.
        """
        pipeline_name = deployment.pipeline.name
        docker_settings = (
            deployment.pipeline.docker_settings or DockerSettings()
        )

        target_image_name = (
            f"{docker_settings.target_repository}:{pipeline_name}"
        )
        if container_registry:
            target_image_name = (
                f"{container_registry.config.uri}/{target_image_name}"
            )

        return target_image_name

    def build_docker_image(
        self,
        target_image_name: str,
        deployment: "PipelineDeployment",
        stack: "Stack",
        entrypoint: Optional[str] = None,
    ) -> None:
        """Builds a Docker image to run a pipeline.

        Up to two builds are performed: an optional build from the
        user-supplied Dockerfile, followed by an optional ZenML build that
        installs requirements, copies files and sets the entrypoint. If
        neither is needed, a custom parent image is simply tagged with the
        target name.

        Args:
            target_image_name: The name of the image to build.
            deployment: The pipeline deployment for which the image should be
                built.
            stack: The stack on which the pipeline will be deployed.
            entrypoint: Entrypoint to use for the final image. If left empty,
                no entrypoint will be included in the image.

        Raises:
            ValueError: If no Dockerfile and/or custom parent image is
                specified and the Docker configuration doesn't require an
                image build.
        """
        pipeline_name = deployment.pipeline.name
        docker_settings = (
            deployment.pipeline.docker_settings or DockerSettings()
        )

        logger.info(
            "Building Docker image(s) for pipeline `%s`.", pipeline_name
        )
        # Any of these settings means ZenML has to generate its own
        # Dockerfile and run a build on top of the parent image.
        requires_zenml_build = any(
            [
                docker_settings.requirements,
                docker_settings.required_integrations,
                docker_settings.replicate_local_python_environment,
                docker_settings.install_stack_requirements,
                docker_settings.apt_packages,
                docker_settings.environment,
                docker_settings.copy_files,
                docker_settings.copy_global_config,
                entrypoint,
            ]
        )

        # Fall back to the default parent image if the Docker settings
        # don't specify one
        parent_image = (
            docker_settings.parent_image or DEFAULT_DOCKER_PARENT_IMAGE
        )

        if docker_settings.dockerfile:
            if parent_image != DEFAULT_DOCKER_PARENT_IMAGE:
                logger.warning(
                    "You've specified both a Dockerfile and a custom parent "
                    "image, ignoring the parent image."
                )

            if requires_zenml_build:
                # We will build an additional image on top of this one later
                # to include user files and/or install requirements. The image
                # we build now will be used as the parent for the next build.
                user_image_name = f"zenml-intermediate-build:{pipeline_name}"
                parent_image = user_image_name
            else:
                # The image we'll build from the custom Dockerfile will be
                # used directly, so we tag it with the requested target name.
                user_image_name = target_image_name

            docker_utils.build_image(
                image_name=user_image_name,
                dockerfile=docker_settings.dockerfile,
                build_context_root=docker_settings.build_context_root,
                **docker_settings.build_options,
            )
        elif not requires_zenml_build:
            if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
                raise ValueError(
                    "Unable to run a ZenML pipeline with the given Docker "
                    "settings: No Dockerfile or custom parent image "
                    "specified and no files will be copied or requirements "
                    "installed."
                )
            else:
                # The parent image will be used directly to run the pipeline and
                # needs to be tagged so it gets pushed later
                docker_utils.tag_image(parent_image, target=target_image_name)

        if requires_zenml_build:
            requirement_files = self._gather_requirements_files(
                docker_settings=docker_settings, stack=stack
            )
            requirements_file_names = [f[0] for f in requirement_files]

            # Work on a copy: extending the settings' own list in place would
            # leak the stack's apt packages into the Docker settings and
            # accumulate duplicates across repeated builds.
            apt_packages = list(docker_settings.apt_packages)
            if docker_settings.install_stack_requirements:
                apt_packages.extend(stack.apt_packages)

            if apt_packages:
                logger.info(
                    "Including apt packages: %s",
                    ", ".join(f"`{p}`" for p in apt_packages),
                )

            dockerfile = self._generate_zenml_pipeline_dockerfile(
                parent_image=parent_image,
                docker_settings=docker_settings,
                requirements_files=requirements_file_names,
                apt_packages=apt_packages,
                entrypoint=entrypoint,
            )

            if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
                # The default parent image is static and doesn't require a pull
                # each time
                pull_parent_image = False
            else:
                # If the image is local, we don't need to pull it. Otherwise
                # we play it safe and always pull in case the user pushed a new
                # image for the given name and tag
                pull_parent_image = not docker_utils.is_local_image(
                    parent_image
                )

            # The deployment configuration is shipped alongside the
            # requirements files as an extra file of the build.
            extra_files = requirement_files.copy()
            extra_files.append(
                (DOCKER_IMAGE_DEPLOYMENT_CONFIG_FILE, deployment.yaml())
            )

            # Leave the build context empty if we don't want to copy any files
            requires_build_context = (
                docker_settings.copy_files or docker_settings.copy_global_config
            )
            build_context_root = (
                source_utils.get_source_root_path()
                if requires_build_context
                else None
            )
            maybe_include_global_config = (
                _include_global_config(build_context_root=build_context_root)  # type: ignore[arg-type]
                if docker_settings.copy_global_config
                else contextlib.nullcontext()
            )
            with maybe_include_global_config:
                docker_utils.build_image(
                    image_name=target_image_name,
                    dockerfile=dockerfile,
                    build_context_root=build_context_root,
                    dockerignore=docker_settings.dockerignore,
                    extra_files=extra_files,
                    pull=pull_parent_image,
                )

    @staticmethod
    def _gather_requirements_files(
        docker_settings: DockerSettings, stack: "Stack"
    ) -> List[Tuple[str, str]]:
        """Gathers and/or generates pip requirements files.

        Args:
            docker_settings: Docker settings that specifies which
                requirements to install.
            stack: The stack on which the pipeline will run.

        Raises:
            RuntimeError: If the command to export the local python packages
                failed.

        Returns:
            List of tuples (filename, file_content) of all requirements files.
            The files will be in the following order:
            - Packages installed in the local Python environment
            - User-defined requirements
            - Requirements defined by user-defined and/or stack integrations
        """
        requirements_files: List[Tuple[str, str]] = []
        logger.info("Gathering requirements for Docker build:")

        # Generate requirements file for the local environment if configured
        if docker_settings.replicate_local_python_environment:
            if isinstance(
                docker_settings.replicate_local_python_environment,
                PythonEnvironmentExportMethod,
            ):
                command = (
                    docker_settings.replicate_local_python_environment.command
                )
            else:
                # Otherwise the setting is a sequence of command parts that
                # is joined into a single shell command.
                command = " ".join(
                    docker_settings.replicate_local_python_environment
                )

            # NOTE(review): the command is executed through the shell; it
            # comes from the Docker settings, which are assumed to be trusted
            # user configuration.
            try:
                local_requirements = subprocess.check_output(
                    command, shell=True
                ).decode()
            except subprocess.CalledProcessError as e:
                raise RuntimeError(
                    "Unable to export local python packages."
                ) from e

            requirements_files.append(
                (".zenml_local_requirements", local_requirements)
            )
            logger.info("\t- Including python packages from local environment")

        # Generate/Read requirements file for user-defined requirements
        if isinstance(docker_settings.requirements, str):
            # A string is treated as the path of a requirements file
            user_requirements = io_utils.read_file_contents_as_string(
                docker_settings.requirements
            )
            logger.info(
                "\t- Including user-defined requirements from file `%s`",
                os.path.abspath(docker_settings.requirements),
            )
        elif isinstance(docker_settings.requirements, list):
            user_requirements = "\n".join(docker_settings.requirements)
            logger.info(
                "\t- Including user-defined requirements: %s",
                ", ".join(f"`{r}`" for r in docker_settings.requirements),
            )
        else:
            user_requirements = None

        if user_requirements:
            requirements_files.append(
                (".zenml_user_requirements", user_requirements)
            )

        # Generate requirements file for all required integrations
        integration_requirements = set(
            itertools.chain.from_iterable(
                integration_registry.select_integration_requirements(
                    integration
                )
                for integration in docker_settings.required_integrations
            )
        )

        if docker_settings.install_stack_requirements:
            integration_requirements.update(stack.requirements())

        if integration_requirements:
            # Sort so the generated file content is deterministic
            integration_requirements_list = sorted(integration_requirements)
            integration_requirements_file = "\n".join(
                integration_requirements_list
            )
            requirements_files.append(
                (
                    ".zenml_integration_requirements",
                    integration_requirements_file,
                )
            )
            logger.info(
                "\t- Including integration requirements: %s",
                ", ".join(f"`{r}`" for r in integration_requirements_list),
            )

        return requirements_files

    @staticmethod
    def _generate_zenml_pipeline_dockerfile(
        parent_image: str,
        docker_settings: DockerSettings,
        requirements_files: Sequence[str] = (),
        apt_packages: Sequence[str] = (),
        entrypoint: Optional[str] = None,
    ) -> List[str]:
        """Generates a Dockerfile.

        Args:
            parent_image: The image to use as parent for the Dockerfile.
            docker_settings: Docker settings for this image build.
            requirements_files: Paths of requirements files to install.
            apt_packages: APT packages to install.
            entrypoint: The default entrypoint command that gets executed when
                running a container of an image created by this Dockerfile.

        Returns:
            Lines of the generated Dockerfile.
        """
        lines = [f"FROM {parent_image}", f"WORKDIR {DOCKER_IMAGE_WORKDIR}"]

        if docker_settings.copy_global_config:
            lines.append(
                f"ENV {ENV_ZENML_CONFIG_PATH}={DOCKER_IMAGE_ZENML_CONFIG_PATH}"
            )

        # Environment variable names are upper-cased in the Dockerfile
        lines.extend(
            f"ENV {key.upper()}={value}"
            for key, value in docker_settings.environment.items()
        )

        if apt_packages:
            quoted_apt_packages = " ".join(f"'{p}'" for p in apt_packages)

            lines.append(
                "RUN apt-get update && apt-get install -y "
                f"--no-install-recommends {quoted_apt_packages}"
            )

        # One COPY + install step per requirements file, in the given order
        for file in requirements_files:
            lines.append(f"COPY {file} .")
            lines.append(f"RUN pip install --no-cache-dir -r {file}")

        if docker_settings.copy_files:
            lines.append("COPY . .")
        elif docker_settings.copy_global_config:
            lines.append(f"COPY {DOCKER_IMAGE_ZENML_CONFIG_DIR} .")

        lines.append("RUN chmod -R a+rw .")

        if docker_settings.user:
            # Change ownership *before* switching users: once `USER` is in
            # effect, subsequent `RUN` instructions execute as that user,
            # who generally cannot chown files copied in by the previous
            # (typically root) user.
            lines.append(f"RUN chown -R {docker_settings.user} .")
            lines.append(f"USER {docker_settings.user}")

        if entrypoint:
            lines.append(f"ENTRYPOINT {entrypoint}")

        return lines
build_and_push_docker_image(self, deployment, stack, entrypoint=None)

Builds and pushes a Docker image to run a pipeline.

Use the image name returned by this method whenever you need to uniquely reference the pushed image in order to pull or run it.

Parameters:

Name Type Description Default
deployment PipelineDeployment

The pipeline deployment for which the image should be built.

required
stack Stack

The stack on which the pipeline will be deployed.

required
entrypoint Optional[str]

Entrypoint to use for the final image. If left empty, no entrypoint will be included in the image.

None

Returns:

Type Description
str

The Docker repository digest of the pushed image.

Exceptions:

Type Description
RuntimeError

If the stack doesn't contain a container registry.

Source code in zenml/utils/pipeline_docker_image_builder.py
def build_and_push_docker_image(
    self,
    deployment: "PipelineDeployment",
    stack: "Stack",
    entrypoint: Optional[str] = None,
) -> str:
    """Build the pipeline image and push it to the stack's registry.

    The value returned here is what should be used to pull or run the
    pushed image later on.

    Args:
        deployment: The pipeline deployment for which the image should be
            built.
        stack: The stack on which the pipeline will be deployed.
        entrypoint: Entrypoint to use for the final image. If left empty,
            no entrypoint will be included in the image.

    Returns:
        The Docker repository digest of the pushed image.

    Raises:
        RuntimeError: If the stack doesn't contain a container registry.
    """
    registry = stack.container_registry

    if registry:
        # Name the image for this registry, build it, then push it.
        image_name = self.get_target_image_name(
            deployment=deployment, container_registry=registry
        )
        self.build_docker_image(
            target_image_name=image_name,
            deployment=deployment,
            stack=stack,
            entrypoint=entrypoint,
        )
        return registry.push_image(image_name)

    # Without a registry there is nowhere to push the image to.
    raise RuntimeError(
        "Unable to build and push Docker image because stack "
        f"`{stack.name}` has no container registry."
    )
build_docker_image(self, target_image_name, deployment, stack, entrypoint=None)

Builds a Docker image to run a pipeline.

Parameters:

Name Type Description Default
target_image_name str

The name of the image to build.

required
deployment PipelineDeployment

The pipeline deployment for which the image should be built.

required
stack Stack

The stack on which the pipeline will be deployed.

required
entrypoint Optional[str]

Entrypoint to use for the final image. If left empty, no entrypoint will be included in the image.

None

Exceptions:

Type Description
ValueError

If no Dockerfile and/or custom parent image is specified and the Docker configuration doesn't require an image build.

Source code in zenml/utils/pipeline_docker_image_builder.py
def build_docker_image(
    self,
    target_image_name: str,
    deployment: "PipelineDeployment",
    stack: "Stack",
    entrypoint: Optional[str] = None,
) -> None:
    """Builds a Docker image to run a pipeline.

    Up to two builds are performed: an optional build from the
    user-supplied Dockerfile, followed by an optional ZenML build that
    installs requirements, copies files and sets the entrypoint. If
    neither is needed, a custom parent image is simply tagged with the
    target name.

    Args:
        target_image_name: The name of the image to build.
        deployment: The pipeline deployment for which the image should be
            built.
        stack: The stack on which the pipeline will be deployed.
        entrypoint: Entrypoint to use for the final image. If left empty,
            no entrypoint will be included in the image.

    Raises:
        ValueError: If no Dockerfile and/or custom parent image is
            specified and the Docker configuration doesn't require an
            image build.
    """
    pipeline_name = deployment.pipeline.name
    docker_settings = (
        deployment.pipeline.docker_settings or DockerSettings()
    )

    logger.info(
        "Building Docker image(s) for pipeline `%s`.", pipeline_name
    )
    # Any of these settings means ZenML has to generate its own Dockerfile
    # and run a build on top of the parent image.
    requires_zenml_build = any(
        [
            docker_settings.requirements,
            docker_settings.required_integrations,
            docker_settings.replicate_local_python_environment,
            docker_settings.install_stack_requirements,
            docker_settings.apt_packages,
            docker_settings.environment,
            docker_settings.copy_files,
            docker_settings.copy_global_config,
            entrypoint,
        ]
    )

    # Fall back to the default parent image if the Docker settings don't
    # specify one
    parent_image = (
        docker_settings.parent_image or DEFAULT_DOCKER_PARENT_IMAGE
    )

    if docker_settings.dockerfile:
        if parent_image != DEFAULT_DOCKER_PARENT_IMAGE:
            logger.warning(
                "You've specified both a Dockerfile and a custom parent "
                "image, ignoring the parent image."
            )

        if requires_zenml_build:
            # We will build an additional image on top of this one later
            # to include user files and/or install requirements. The image
            # we build now will be used as the parent for the next build.
            user_image_name = f"zenml-intermediate-build:{pipeline_name}"
            parent_image = user_image_name
        else:
            # The image we'll build from the custom Dockerfile will be
            # used directly, so we tag it with the requested target name.
            user_image_name = target_image_name

        docker_utils.build_image(
            image_name=user_image_name,
            dockerfile=docker_settings.dockerfile,
            build_context_root=docker_settings.build_context_root,
            **docker_settings.build_options,
        )
    elif not requires_zenml_build:
        if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
            raise ValueError(
                "Unable to run a ZenML pipeline with the given Docker "
                "settings: No Dockerfile or custom parent image "
                "specified and no files will be copied or requirements "
                "installed."
            )
        else:
            # The parent image will be used directly to run the pipeline and
            # needs to be tagged so it gets pushed later
            docker_utils.tag_image(parent_image, target=target_image_name)

    if requires_zenml_build:
        requirement_files = self._gather_requirements_files(
            docker_settings=docker_settings, stack=stack
        )
        requirements_file_names = [f[0] for f in requirement_files]

        # Work on a copy: extending the settings' own list in place would
        # leak the stack's apt packages into the Docker settings and
        # accumulate duplicates across repeated builds.
        apt_packages = list(docker_settings.apt_packages)
        if docker_settings.install_stack_requirements:
            apt_packages.extend(stack.apt_packages)

        if apt_packages:
            logger.info(
                "Including apt packages: %s",
                ", ".join(f"`{p}`" for p in apt_packages),
            )

        dockerfile = self._generate_zenml_pipeline_dockerfile(
            parent_image=parent_image,
            docker_settings=docker_settings,
            requirements_files=requirements_file_names,
            apt_packages=apt_packages,
            entrypoint=entrypoint,
        )

        if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
            # The default parent image is static and doesn't require a pull
            # each time
            pull_parent_image = False
        else:
            # If the image is local, we don't need to pull it. Otherwise
            # we play it safe and always pull in case the user pushed a new
            # image for the given name and tag
            pull_parent_image = not docker_utils.is_local_image(
                parent_image
            )

        # The deployment configuration is shipped alongside the requirements
        # files as an extra file of the build.
        extra_files = requirement_files.copy()
        extra_files.append(
            (DOCKER_IMAGE_DEPLOYMENT_CONFIG_FILE, deployment.yaml())
        )

        # Leave the build context empty if we don't want to copy any files
        requires_build_context = (
            docker_settings.copy_files or docker_settings.copy_global_config
        )
        build_context_root = (
            source_utils.get_source_root_path()
            if requires_build_context
            else None
        )
        maybe_include_global_config = (
            _include_global_config(build_context_root=build_context_root)  # type: ignore[arg-type]
            if docker_settings.copy_global_config
            else contextlib.nullcontext()
        )
        with maybe_include_global_config:
            docker_utils.build_image(
                image_name=target_image_name,
                dockerfile=dockerfile,
                build_context_root=build_context_root,
                dockerignore=docker_settings.dockerignore,
                extra_files=extra_files,
                pull=pull_parent_image,
            )
get_target_image_name(deployment, container_registry=None) staticmethod

Returns the target image name.

If a container registry is given, the image name will include the registry URI.

Parameters:

Name Type Description Default
deployment PipelineDeployment

The pipeline deployment for which the target image name should be returned.

required
container_registry Optional[BaseContainerRegistry]

Optional container registry to which this image will be pushed.

None

Returns:

Type Description
str

The docker image name.

Source code in zenml/utils/pipeline_docker_image_builder.py
@staticmethod
def get_target_image_name(
    deployment: "PipelineDeployment",
    container_registry: Optional["BaseContainerRegistry"] = None,
) -> str:
    """Returns the target image name.

    If a container registry is given, the image name will include the
    registry URI

    Args:
        deployment: The pipeline deployment for which the target image name
            should be returned.
        container_registry: Optional container registry to which this
            image will be pushed.

    Returns:
        The docker image name.
    """
    pipeline_name = deployment.pipeline.name
    docker_settings = (
        deployment.pipeline.docker_settings or DockerSettings()
    )

    target_image_name = (
        f"{docker_settings.target_repository}:{pipeline_name}"
    )
    if container_registry:
        target_image_name = (
            f"{container_registry.config.uri}/{target_image_name}"
        )

    return target_image_name

proto_utils

Utility functions for interacting with TFX contexts.

add_mlmd_contexts(pipeline_node, step, deployment, stack)

Adds context to each pipeline node of a pb2_pipeline.

Parameters:

Name Type Description Default
pipeline_node PipelineNode

The pipeline node to which the contexts should be added.

required
step Step

The corresponding step for the pipeline node.

required
deployment PipelineDeployment

The pipeline deployment to store in the contexts.

required
stack Stack

The stack the pipeline will run on.

required
Source code in zenml/utils/proto_utils.py
def add_mlmd_contexts(
    pipeline_node: pipeline_pb2.PipelineNode,
    step: "Step",
    deployment: "PipelineDeployment",
    stack: "Stack",
) -> None:
    """Attaches a ZenML context to a pipeline node of a pb2_pipeline.

    The context carries the serialized stack, pipeline and step
    configurations, the IDs of the active user/project and of the
    deployment's pipeline/stack, and the step and output counts. Its name
    is the MD5 hex digest of the JSON-serialized properties.

    Args:
        pipeline_node: The pipeline node to which the contexts should be
            added.
        step: The corresponding step for the pipeline node.
        deployment: The pipeline deployment to store in the contexts.
        stack: The stack the pipeline will run on.
    """
    from zenml.client import Client

    zenml_client = Client()

    id_mapping = {
        "user_id": zenml_client.active_user.id,
        "project_id": zenml_client.active_project.id,
        "pipeline_id": deployment.pipeline_id,
        "stack_id": deployment.stack_id,
    }
    # The pydantic encoder takes care of values that plain JSON can't
    # serialize.
    serialized_ids = json.dumps(
        id_mapping, sort_keys=True, default=pydantic_encoder
    )

    properties = {
        MLMD_CONTEXT_STACK_PROPERTY_NAME: json.dumps(
            stack.dict(), sort_keys=True
        ),
        MLMD_CONTEXT_PIPELINE_CONFIG_PROPERTY_NAME: deployment.pipeline.json(
            sort_keys=True
        ),
        MLMD_CONTEXT_STEP_CONFIG_PROPERTY_NAME: step.json(sort_keys=True),
        MLMD_CONTEXT_MODEL_IDS_PROPERTY_NAME: serialized_ids,
        MLMD_CONTEXT_NUM_STEPS_PROPERTY_NAME: str(len(deployment.steps)),
        MLMD_CONTEXT_NUM_OUTPUTS_PROPERTY_NAME: str(len(step.config.outputs)),
    }

    # Derive the context name from its content so identical properties
    # always map to the same name.
    digest_input = json.dumps(properties, sort_keys=True).encode()

    add_pipeline_node_context(
        pipeline_node,
        type_=ZENML_MLMD_CONTEXT_TYPE,
        name=hashlib.md5(digest_input).hexdigest(),
        properties=properties,
    )

add_pipeline_node_context(pipeline_node, type_, name, properties)

Adds a new context to a TFX protobuf pipeline node.

Parameters:

Name Type Description Default
pipeline_node PipelineNode

A tfx protobuf pipeline node

required
type_ str

The type name for the context to be added

required
name str

Unique key for the context

required
properties Dict[str, str]

dictionary of strings as properties of the context

required
Source code in zenml/utils/proto_utils.py
def add_pipeline_node_context(
    pipeline_node: pipeline_pb2.PipelineNode,
    type_: str,
    name: str,
    properties: Dict[str, str],
) -> None:
    """Appends a new context to a TFX protobuf pipeline node.

    Args:
        pipeline_node: A tfx protobuf pipeline node
        type_: The type name for the context to be added
        name: Unique key for the context
        properties: dictionary of strings as properties of the context
    """
    new_context: pipeline_pb2.ContextSpec = (
        pipeline_node.contexts.contexts.add()
    )
    new_context.type.name = type_
    new_context.name.field_value.string_value = name

    # Every property is stored as a string field value under its key.
    for property_name in properties:
        target = new_context.properties[property_name]
        target.field_value.string_value = properties[property_name]

get_pipeline_config(pipeline_node)

Fetches the pipeline configuration from a PipelineNode context.

Parameters:

Name Type Description Default
pipeline_node PipelineNode

Pipeline node info for a step.

required

Returns:

Type Description
PipelineConfiguration

The pipeline config.

Exceptions:

Type Description
RuntimeError

If no pipeline config was found.

Source code in zenml/utils/proto_utils.py
def get_pipeline_config(
    pipeline_node: pipeline_pb2.PipelineNode,
) -> "PipelineConfiguration":
    """Reads the pipeline configuration stored in a PipelineNode context.

    The first context of the ZenML context type is used.

    Args:
        pipeline_node: Pipeline node info for a step.

    Returns:
        The pipeline config.

    Raises:
        RuntimeError: If no pipeline config was found.
    """
    zenml_contexts = (
        candidate
        for candidate in pipeline_node.contexts.contexts
        if candidate.type.name == ZENML_MLMD_CONTEXT_TYPE
    )
    zenml_context = next(zenml_contexts, None)

    if zenml_context is None:
        raise RuntimeError("Unable to find pipeline config.")

    # The configuration is stored as a JSON string property.
    config_property = zenml_context.properties[
        MLMD_CONTEXT_PIPELINE_CONFIG_PROPERTY_NAME
    ]
    return PipelineConfiguration.parse_raw(
        config_property.field_value.string_value
    )

get_step(pipeline_node)

Fetches the step from a PipelineNode context.

Parameters:

Name Type Description Default
pipeline_node PipelineNode

Pipeline node info for a step.

required

Returns:

Type Description
Step

The step.

Exceptions:

Type Description
RuntimeError

If no step was found.

Source code in zenml/utils/proto_utils.py
def get_step(
    pipeline_node: pipeline_pb2.PipelineNode,
) -> "Step":
    """Reads the step stored in a PipelineNode context.

    The first context of the ZenML context type is used.

    Args:
        pipeline_node: Pipeline node info for a step.

    Returns:
        The step.

    Raises:
        RuntimeError: If no step was found.
    """
    zenml_contexts = (
        candidate
        for candidate in pipeline_node.contexts.contexts
        if candidate.type.name == ZENML_MLMD_CONTEXT_TYPE
    )
    zenml_context = next(zenml_contexts, None)

    if zenml_context is None:
        raise RuntimeError("Unable to find step.")

    # The step configuration is stored as a JSON string property.
    step_property = zenml_context.properties[
        MLMD_CONTEXT_STEP_CONFIG_PROPERTY_NAME
    ]
    return Step.parse_raw(step_property.field_value.string_value)

pydantic_utils

Utilities for pydantic models.

TemplateGenerator

Class to generate templates for pydantic models or classes.

Source code in zenml/utils/pydantic_utils.py
class TemplateGenerator:
    """Builds template dictionaries for pydantic models or model classes."""

    def __init__(
        self, instance_or_class: Union[BaseModel, Type[BaseModel]]
    ) -> None:
        """Stores the object for which a template will be generated.

        Args:
            instance_or_class: The pydantic model or model class for which to
                generate a template.
        """
        self.instance_or_class = instance_or_class

    def run(self) -> Dict[str, Any]:
        """Builds the template and returns it as JSON-compatible data.

        Returns:
            The template dictionary.
        """
        target = self.instance_or_class
        if isinstance(target, BaseModel):
            raw_template = self._generate_template_for_model(target)
        else:
            raw_template = self._generate_template_for_model_class(target)

        # Round-trip through JSON so pydantic's encoder can take care of
        # types like UUID and datetime.
        encoded = json.dumps(raw_template, default=pydantic_encoder)
        return cast(Dict[str, Any], json.loads(encoded))

    def _generate_template_for_model(self, model: BaseModel) -> Dict[str, Any]:
        """Builds a template for a pydantic model instance.

        Starts from the template of the model's class and overrides every
        field that was explicitly set on the instance.

        Args:
            model: The model for which to generate the template.

        Returns:
            The model template.
        """
        template = self._generate_template_for_model_class(model.__class__)
        template.update(
            (
                field_name,
                self._generate_template_for_value(getattr(model, field_name)),
            )
            for field_name in model.__fields_set__
        )
        return template

    def _generate_template_for_model_class(
        self,
        model_class: Type[BaseModel],
    ) -> Dict[str, Any]:
        """Builds a template for a pydantic model class.

        Fields whose type is itself a model class are expanded recursively,
        all other fields are represented by their type display string.

        Args:
            model_class: The model class for which to generate the template.

        Returns:
            The model class template.
        """
        template: Dict[str, Any] = {}

        for field_name, model_field in model_class.__fields__.items():
            outer_type = model_field.outer_type_

            if self._is_model_class(outer_type):
                nested_class = outer_type
            # NOTE(review): this compares against the bare `Optional` object
            # by identity -- confirm that this branch is reachable.
            elif outer_type is Optional and self._is_model_class(
                model_field.type_
            ):
                nested_class = model_field.type_
            else:
                nested_class = None

            if nested_class is None:
                template[field_name] = model_field._type_display()
            else:
                template[field_name] = self._generate_template_for_model_class(
                    nested_class
                )

        return template

    def _generate_template_for_value(self, value: Any) -> Any:
        """Builds a template for an arbitrary value.

        Dictionaries and sequence-like values are processed element by
        element, models are expanded, anything else is returned unchanged.

        Args:
            value: The value for which to generate the template.

        Returns:
            The value template.
        """
        if isinstance(value, Dict):
            return {
                key: self._generate_template_for_value(item)
                for key, item in value.items()
            }

        if sequence_like(value):
            return [self._generate_template_for_value(item) for item in value]

        if isinstance(value, BaseModel):
            return self._generate_template_for_model(value)

        return value

    @staticmethod
    def _is_model_class(value: Any) -> bool:
        """Tells whether the given value is a pydantic model class.

        Args:
            value: The value to check.

        Returns:
            If the value is a pydantic model class.
        """
        if not isinstance(value, type):
            return False
        return issubclass(value, BaseModel)
__init__(self, instance_or_class) special

Initializes the template generator.

Parameters:

Name Type Description Default
instance_or_class Union[pydantic.main.BaseModel, Type[pydantic.main.BaseModel]]

The pydantic model or model class for which to generate a template.

required
Source code in zenml/utils/pydantic_utils.py
def __init__(
    self, instance_or_class: Union[BaseModel, Type[BaseModel]]
) -> None:
    """Stores the object for which a template will be generated.

    Args:
        instance_or_class: The pydantic model or model class for which to
            generate a template.
    """
    self.instance_or_class = instance_or_class
run(self)

Generates the template.

Returns:

Type Description
Dict[str, Any]

The template dictionary.

Source code in zenml/utils/pydantic_utils.py
def run(self) -> Dict[str, Any]:
    """Builds the template and returns it as JSON-compatible data.

    Returns:
        The template dictionary.
    """
    target = self.instance_or_class
    if isinstance(target, BaseModel):
        raw_template = self._generate_template_for_model(target)
    else:
        raw_template = self._generate_template_for_model_class(target)

    # Round-trip through JSON so pydantic's encoder can take care of
    # types like UUID and datetime.
    encoded = json.dumps(raw_template, default=pydantic_encoder)
    return cast(Dict[str, Any], json.loads(encoded))

update_model(original, update, recursive=True, exclude_none=True)

Updates a pydantic model.

Parameters:

Name Type Description Default
original ~M

The model to update.

required
update Union[BaseModel, Dict[str, Any]]

The update values.

required
recursive bool

If True, dictionary values will be updated recursively.

True
exclude_none bool

If True, None values in the update dictionary will be removed.

True

Returns:

Type Description
~M

The updated model.

Source code in zenml/utils/pydantic_utils.py
def update_model(
    original: M,
    update: Union["BaseModel", Dict[str, Any]],
    recursive: bool = True,
    exclude_none: bool = True,
) -> M:
    """Creates an updated copy of a pydantic model.

    Only the explicitly set fields of `original` (and of `update`, if it is
    a model) are taken into account.

    Args:
        original: The model to update.
        update: The update values.
        recursive: If `True`, dictionary values will be updated recursively.
        exclude_none: If `True`, `None` values in the update dictionary
            will be removed.

    Returns:
        The updated model.
    """
    if not isinstance(update, Dict):
        changes = update.dict(exclude_unset=True)
    elif exclude_none:
        changes = dict_utils.remove_none_values(update, recursive=recursive)
    else:
        changes = update

    current_values = original.dict(exclude_unset=True)
    merged_values = (
        dict_utils.recursive_update(current_values, changes)
        if recursive
        else {**current_values, **changes}
    )

    return original.__class__(**merged_values)

secret_utils

Utility functions for secrets and secret references.

SecretReference (tuple)

Class representing a secret reference.

Attributes:

Name Type Description
name str

The secret name.

key str

The secret key.

Source code in zenml/utils/secret_utils.py
class SecretReference(NamedTuple):
    """Immutable `(name, key)` pair representing a secret reference.

    Attributes:
        name: The secret name.
        key: The secret key.
    """

    name: str  # The secret name.
    key: str  # The secret key.
__getnewargs__(self) special

Return self as a plain tuple. Used by copy and pickle.

Source code in zenml/utils/secret_utils.py
def __getnewargs__(self):
    """Return self as a plain tuple.

    Used by copy and pickle.
    """
    # NOTE(review): `_tuple` is not defined in this snippet; presumably it is
    # the builtin `tuple` captured by the code that generates this method.
    return _tuple(self)
__new__(_cls, name, key) special staticmethod

Create new instance of SecretReference(name, key)

__repr__(self) special

Return a nicely formatted representation string

Source code in zenml/utils/secret_utils.py
def __repr__(self):
    """Return a nicely formatted representation string."""
    # NOTE(review): `repr_fmt` is not defined in this snippet; presumably it
    # is a %-format template that is filled with the fields of `self`.
    return self.__class__.__name__ + repr_fmt % self

ClearTextField(*args, **kwargs)

Marks a pydantic field to prevent secret references.

Parameters:

Name Type Description Default
*args Any

Positional arguments which will be forwarded to pydantic.Field(...).

()
**kwargs Any

Keyword arguments which will be forwarded to pydantic.Field(...).

{}

Returns:

Type Description
Any

Pydantic field info.

Source code in zenml/utils/secret_utils.py
def ClearTextField(*args: Any, **kwargs: Any) -> Any:
    """Creates a pydantic field that is marked to prevent secret references.

    The clear text marker is always set to `True`, overriding any value
    passed for it in `kwargs`.

    Args:
        *args: Positional arguments which will be forwarded
            to `pydantic.Field(...)`.
        **kwargs: Keyword arguments which will be forwarded to
            `pydantic.Field(...)`.

    Returns:
        Pydantic field info.
    """
    field_kwargs = {**kwargs, PYDANTIC_CLEAR_TEXT_FIELD_MARKER: True}
    return Field(*args, **field_kwargs)

SecretField(*args, **kwargs)

Marks a pydantic field as something containing sensitive information.

Parameters:

Name Type Description Default
*args Any

Positional arguments which will be forwarded to pydantic.Field(...).

()
**kwargs Any

Keyword arguments which will be forwarded to pydantic.Field(...).

{}

Returns:

Type Description
Any

Pydantic field info.

Source code in zenml/utils/secret_utils.py
def SecretField(*args: Any, **kwargs: Any) -> Any:
    """Creates a pydantic field that is marked as sensitive information.

    The sensitive marker is always set to `True`, overriding any value
    passed for it in `kwargs`.

    Args:
        *args: Positional arguments which will be forwarded
            to `pydantic.Field(...)`.
        **kwargs: Keyword arguments which will be forwarded to
            `pydantic.Field(...)`.

    Returns:
        Pydantic field info.
    """
    field_kwargs = {**kwargs, PYDANTIC_SENSITIVE_FIELD_MARKER: True}
    return Field(*args, **field_kwargs)

is_clear_text_field(field)

Returns whether a pydantic field prevents secret references or not.

Parameters:

Name Type Description Default
field ModelField

The field to check.

required

Returns:

Type Description
bool

True if the field prevents secret references, False otherwise.

Source code in zenml/utils/secret_utils.py
def is_clear_text_field(field: "ModelField") -> bool:
    """Tells whether a pydantic field prevents secret references.

    Looks up the clear text marker in the extra attributes of the field
    info; a missing marker counts as `False`.

    Args:
        field: The field to check.

    Returns:
        `True` if the field prevents secret references, `False` otherwise.
    """
    extra_attributes = field.field_info.extra
    return extra_attributes.get(PYDANTIC_CLEAR_TEXT_FIELD_MARKER, False)

is_secret_field(field)

Returns whether a pydantic field contains sensitive information or not.

Parameters:

Name Type Description Default
field ModelField

The field to check.

required

Returns:

Type Description
bool

True if the field contains sensitive information, False otherwise.

Source code in zenml/utils/secret_utils.py
def is_secret_field(field: "ModelField") -> bool:
    """Tells whether a pydantic field contains sensitive information.

    Looks up the sensitive marker in the extra attributes of the field info;
    a missing marker counts as `False`.

    Args:
        field: The field to check.

    Returns:
        `True` if the field contains sensitive information, `False` otherwise.
    """
    extra_attributes = field.field_info.extra
    return extra_attributes.get(PYDANTIC_SENSITIVE_FIELD_MARKER, False)

is_secret_reference(value)

Checks whether any value is a secret reference.

Parameters:

Name Type Description Default
value Any

The value to check.

required

Returns:

Type Description
bool

True if the value is a secret reference, False otherwise.

Source code in zenml/utils/secret_utils.py
def is_secret_reference(value: Any) -> bool:
    """Tells whether an arbitrary value is a secret reference.

    Only strings that fully match the secret reference expression qualify.

    Args:
        value: The value to check.

    Returns:
        `True` if the value is a secret reference, `False` otherwise.
    """
    return isinstance(value, str) and bool(
        _secret_reference_expression.fullmatch(value)
    )

parse_secret_reference(reference)

Parses a secret reference.

This function assumes the input string is a valid secret reference and does not perform any additional checks. If you pass an invalid secret reference here, this will most likely crash.

Parameters:

Name Type Description Default
reference str

The string representing a valid secret reference.

required

Returns:

Type Description
SecretReference

The parsed secret reference.

Source code in zenml/utils/secret_utils.py
def parse_secret_reference(reference: str) -> SecretReference:
    """Splits a secret reference into secret name and secret key.

    The input is assumed to be a valid secret reference; **no** validation
    is performed. Passing an invalid secret reference will most likely
    crash.

    Args:
        reference: The string representing a **valid** secret reference.

    Returns:
        The parsed secret reference.
    """
    # Drop the two leading and the two trailing characters, then split at
    # the first `.` only, so the key may itself contain dots.
    name_and_key = reference[2:-2]
    secret_name, secret_key = name_and_key.split(".", 1)
    return SecretReference(name=secret_name, key=secret_key)

settings_utils

Utility functions for ZenML settings.

get_flavor_setting_key(flavor)

Gets the setting key for a flavor.

Parameters:

Name Type Description Default
flavor Flavor

The flavor for which to get the key.

required

Returns:

Type Description
str

The setting key for the flavor.

Source code in zenml/utils/settings_utils.py
def get_flavor_setting_key(flavor: "Flavor") -> str:
    """Builds the setting key for a flavor.

    The key has the form `<type>.<name>`.

    Args:
        flavor: The flavor for which to get the key.

    Returns:
        The setting key for the flavor.
    """
    flavor_type = flavor.type
    flavor_name = flavor.name
    return f"{flavor_type}.{flavor_name}"

get_general_settings()

Returns all general settings.

Returns:

Type Description
Dict[str, Type[BaseSettings]]

Dictionary mapping general settings keys to their type.

Source code in zenml/utils/settings_utils.py
def get_general_settings() -> Dict[str, Type["BaseSettings"]]:
    """Collects all general settings.

    Returns:
        Dictionary mapping general settings keys to their type.
    """
    from zenml.config import DockerSettings, ResourceSettings

    general_settings: Dict[str, Type["BaseSettings"]] = {}
    general_settings[DOCKER_SETTINGS_KEY] = DockerSettings
    general_settings[RESOURCE_SETTINGS_KEY] = ResourceSettings
    return general_settings

get_stack_component_for_settings_key(key, stack)

Gets the stack component of a stack for a given settings key.

Parameters:

Name Type Description Default
key str

The settings key for which to get the component.

required
stack Stack

The stack from which to get the component.

required

Exceptions:

Type Description
ValueError

If the key is invalid or the stack does not contain a component of the correct flavor.

Returns:

Type Description
StackComponent

The stack component.

Source code in zenml/utils/settings_utils.py
def get_stack_component_for_settings_key(
    key: str, stack: "Stack"
) -> "StackComponent":
    """Looks up the stack component that a settings key refers to.

    Args:
        key: The settings key for which to get the component.
        stack: The stack from which to get the component.

    Raises:
        ValueError: If the key is invalid or the stack does not contain a
            component of the correct flavor.

    Returns:
        The stack component.
    """
    if not is_stack_component_setting_key(key):
        raise ValueError(
            f"Settings key {key} does not refer to a stack component."
        )

    # Keys have the form `<component_type>.<flavor>`; split at the first dot.
    component_type, flavor = key.split(".", 1)
    component = stack.components.get(StackComponentType(component_type))
    if component and component.flavor == flavor:
        return component

    raise ValueError(
        f"Component of type {component_type} in stack {stack} is not "
        f"of the flavor {flavor} specified by the settings key {key}."
    )

get_stack_component_setting_key(stack_component)

Gets the setting key for a stack component.

Parameters:

Name Type Description Default
stack_component StackComponent

The stack component for which to get the key.

required

Returns:

Type Description
str

The setting key for the stack component.

Source code in zenml/utils/settings_utils.py
def get_stack_component_setting_key(stack_component: "StackComponent") -> str:
    """Builds the setting key for a stack component.

    The key has the form `<type>.<flavor>`.

    Args:
        stack_component: The stack component for which to get the key.

    Returns:
        The setting key for the stack component.
    """
    component_type = stack_component.type
    component_flavor = stack_component.flavor
    return f"{component_type}.{component_flavor}"

is_general_setting_key(key)

Checks whether the key refers to a general setting.

Parameters:

Name Type Description Default
key str

The key to check.

required

Returns:

Type Description
bool

If the key refers to a general setting.

Source code in zenml/utils/settings_utils.py
def is_general_setting_key(key: str) -> bool:
    """Tells whether the key refers to a general setting.

    Args:
        key: The key to check.

    Returns:
        If the key refers to a general setting.
    """
    general_settings = get_general_settings()
    return key in general_settings

is_stack_component_setting_key(key)

Checks whether a settings key refers to a stack component.

Parameters:

Name Type Description Default
key str

The key to check.

required

Returns:

Type Description
bool

If the key refers to a stack component.

Source code in zenml/utils/settings_utils.py
def is_stack_component_setting_key(key: str) -> bool:
    """Tells whether a settings key refers to a stack component.

    The whole key has to match the stack component expression.

    Args:
        key: The key to check.

    Returns:
        If the key refers to a stack component.
    """
    match = STACK_COMPONENT_REGEX.fullmatch(key)
    return bool(match)

is_valid_setting_key(key)

Checks whether a settings key is valid.

Parameters:

Name Type Description Default
key str

The key to check.

required

Returns:

Type Description
bool

If the key is valid.

Source code in zenml/utils/settings_utils.py
def is_valid_setting_key(key: str) -> bool:
    """Tells whether a settings key is valid.

    A key is valid if it refers to either a general setting or a stack
    component.

    Args:
        key: The key to check.

    Returns:
        If the key is valid.
    """
    if is_general_setting_key(key):
        return True
    return is_stack_component_setting_key(key)

validate_setting_keys(setting_keys)

Validates settings keys.

Parameters:

Name Type Description Default
setting_keys Sequence[str]

The keys to validate.

required

Exceptions:

Type Description
ValueError

If any key is invalid.

Source code in zenml/utils/settings_utils.py
def validate_setting_keys(setting_keys: Sequence[str]) -> None:
    """Makes sure that all given settings keys are valid.

    Validation stops at the first invalid key.

    Args:
        setting_keys: The keys to validate.

    Raises:
        ValueError: If any key is invalid.
    """
    for setting_key in setting_keys:
        if is_valid_setting_key(setting_key):
            continue

        raise ValueError(
            f"Invalid setting key `{setting_key}`. Setting keys can either "
            "refer to general settings (available keys: "
            f"{set(get_general_settings())}) or stack component specific "
            "settings. Stack component specific keys are of the format "
            "`<STACK_COMPONENT_TYPE>.<STACK_COMPONENT_FLAVOR>`."
        )

singleton

Utility class to turn classes into singleton classes.

SingletonMetaClass (type)

Singleton metaclass.

Use this metaclass to make any class into a singleton class:

# Example: a class using `SingletonMetaClass` is only instantiated once.
class OneRing(metaclass=SingletonMetaClass):
    def __init__(self, owner):
        self._owner = owner

    @property
    def owner(self):
        return self._owner

the_one_ring = OneRing('Sauron')  # first call creates the instance
the_lost_ring = OneRing('Frodo')  # existing instance is returned
print(the_lost_ring.owner)  # Sauron
OneRing._clear()  # ring destroyed
Source code in zenml/utils/singleton.py
class SingletonMetaClass(type):
    """Singleton metaclass.

    Use this metaclass to make any class into a singleton class:

    ```python
    class OneRing(metaclass=SingletonMetaClass):
        def __init__(self, owner):
            self._owner = owner

        @property
        def owner(self):
            return self._owner

    the_one_ring = OneRing('Sauron')
    the_lost_ring = OneRing('Frodo')
    print(the_lost_ring.owner)  # Sauron
    OneRing._clear() # ring destroyed
    ```

    Once an instance exists, the arguments of later calls are ignored and the
    existing instance is returned.
    """

    def __init__(cls, *args: Any, **kwargs: Any) -> None:
        """Initialize a singleton class.

        Args:
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(*args, **kwargs)
        # Every class created by this metaclass starts without an instance.
        cls.__singleton_instance: Optional["SingletonMetaClass"] = None

    def __call__(cls, *args: Any, **kwargs: Any) -> "SingletonMetaClass":
        """Create or return the singleton instance.

        Args:
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The singleton instance.
        """
        # Compare against `None` instead of relying on truthiness: an
        # instance that evaluates to `False` (e.g. because its class defines
        # `__len__` or `__bool__`) must not be re-created on every call.
        if cls.__singleton_instance is None:
            cls.__singleton_instance = cast(
                "SingletonMetaClass", super().__call__(*args, **kwargs)
            )

        return cls.__singleton_instance

    def _clear(cls) -> None:
        """Clear the singleton instance.

        The next call of the class creates a new instance.
        """
        cls.__singleton_instance = None
__call__(cls, *args, **kwargs) special

Create or return the singleton instance.

Parameters:

Name Type Description Default
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
SingletonMetaClass

The singleton instance.

Source code in zenml/utils/singleton.py
def __call__(cls, *args: Any, **kwargs: Any) -> "SingletonMetaClass":
    """Create or return the singleton instance.

    If an instance already exists, the given arguments are ignored.

    Args:
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        The singleton instance.
    """
    # Compare against `None` instead of relying on truthiness: an instance
    # that evaluates to `False` (e.g. because its class defines `__len__` or
    # `__bool__`) must not be re-created on every call.
    if cls.__singleton_instance is None:
        cls.__singleton_instance = cast(
            "SingletonMetaClass", super().__call__(*args, **kwargs)
        )

    return cls.__singleton_instance
__init__(cls, *args, **kwargs) special

Initialize a singleton class.

Parameters:

Name Type Description Default
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}
Source code in zenml/utils/singleton.py
def __init__(cls, *args: Any, **kwargs: Any) -> None:
    """Sets up a newly created singleton class.

    Args:
        *args: Additional arguments.
        **kwargs: Additional keyword arguments.
    """
    super().__init__(*args, **kwargs)
    # The class starts out without a singleton instance.
    cls.__singleton_instance: Optional["SingletonMetaClass"] = None

source_utils

Utility functions for source code.

These utils are predicated on the following definitions:

  • class_source: This is a python-import type path to a class, e.g. some.mod.class
  • module_source: This is a python-import type path to a module, e.g. some.mod
  • file_path, relative_path, absolute_path: These are file system paths.
  • source: This is a class_source or module_source. If it is a class_source, it can also be optionally pinned.
  • pin: Whatever comes after the @ symbol from a source, usually the git sha or the version of zenml as a string.

create_zenml_pin()

Creates a ZenML pin for source pinning from release version.

Returns:

Type Description
str

ZenML pin.

Source code in zenml/utils/source_utils.py
def create_zenml_pin() -> str:
    """Builds the ZenML pin used for source pinning from the release version.

    The pin has the form `<app name>_<version>`.

    Returns:
        ZenML pin.
    """
    app_name = constants.APP_NAME
    return f"{app_name}_{__version__}"

get_hashed_source(value)

Returns a hash of the object's source code.

Parameters:

Name Type Description Default
value Any

object to get source from.

required

Returns:

Type Description
str

Hash of source code.

Exceptions:

Type Description
TypeError

If unable to compute the hash.

Source code in zenml/utils/source_utils.py
def get_hashed_source(value: Any) -> str:
    """Returns a hash of the object's source code.

    The source code is UTF-8 encoded and hashed with SHA-256.

    Args:
        value: object to get source from.

    Returns:
        Hex digest of the hash of the source code.

    Raises:
        TypeError: If unable to compute the hash.
    """
    try:
        source_code = get_source(value)
    except TypeError as e:
        # Chain explicitly so the original error is reported as the direct
        # cause of this one.
        raise TypeError(
            f"Unable to compute the hash of source code of object: {value}."
        ) from e
    return hashlib.sha256(source_code.encode("utf-8")).hexdigest()

get_main_module_source()

Gets the source of the main module.

Returns:

Type Description
str

The main module source.

Source code in zenml/utils/source_utils.py
def get_main_module_source() -> str:
    """Resolves the source of the `__main__` module.

    Returns:
        The main module source.
    """
    return get_module_source_from_module(sys.modules["__main__"])

get_module_source_from_module(module)

Gets the source of the supplied module.

E.g.:

  • a /home/myrepo/src/run.py module running as the main module returns run if no repository root is specified.

  • a /home/myrepo/src/run.py module running as the main module returns src.run if the repository root is configured in /home/myrepo

  • a /home/myrepo/src/pipeline.py module not running as the main module returns src.pipeline if the repository root is configured in /home/myrepo

  • a /home/myrepo/src/pipeline.py module not running as the main module returns pipeline if no repository root is specified and the main module is also in /home/myrepo/src.

  • a /home/step.py module not running as the main module returns step if the CWD is /home and the repository root or the main module are in a different path (e.g. /home/myrepo/src).

Parameters:

Name Type Description Default
module module

the module to get the source of.

required

Returns:

Type Description
str

The source of the supplied module.

Exceptions:

Type Description
RuntimeError

If the main module is not loaded from a file, or if the module file is not a Python file inside the source root.

Source code in zenml/utils/source_utils.py
def get_module_source_from_module(module: ModuleType) -> str:
    """Gets the source of the supplied module.

    E.g.:

      * a `/home/myrepo/src/run.py` module running as the main module returns
      `run` if no repository root is specified.

      * a `/home/myrepo/src/run.py` module running as the main module returns
      `src.run` if the repository root is configured in `/home/myrepo`

      * a `/home/myrepo/src/pipeline.py` module not running as the main module
      returns `src.pipeline` if the repository root is configured in
      `/home/myrepo`

      * a `/home/myrepo/src/pipeline.py` module not running as the main module
      returns `pipeline` if no repository root is specified and the main
      module is also in `/home/myrepo/src`.

      * a `/home/step.py` module not running as the main module
      returns `step` if the CWD is /home and the repository root or the main
      module are in a different path (e.g. `/home/myrepo/src`).

    Modules without a file (other than `__main__`) are resolved to their
    `__name__`.

    Args:
        module: the module to get the source of.

    Returns:
        The source of the supplied module.

    Raises:
        RuntimeError: if the `__main__` module is not loaded from a file, or
            if the module file is not a `.py` file inside the source root
            (or the current working directory used as a fallback).
    """
    if not hasattr(module, "__file__") or not module.__file__:
        if module.__name__ == "__main__":
            raise RuntimeError(
                f"{module} module was not loaded from a file. Cannot "
                "determine the module root path."
            )
        return module.__name__
    module_path = os.path.abspath(module.__file__)

    root_path = os.path.abspath(get_source_root_path())

    # Compare whole path components. A plain string prefix check would treat
    # `/home/myrepo2/run.py` as being inside `/home/myrepo`.
    try:
        common_path = os.path.commonpath([module_path, root_path])
        in_source_root = os.path.normcase(common_path) == os.path.normcase(
            root_path
        )
    except ValueError:
        # Raised if the paths cannot be compared, e.g. for paths on
        # different drives on Windows.
        in_source_root = False

    if not in_source_root:
        logger.warning(
            "User module %s is not in the source root %s. Using current "
            "directory %s instead to resolve module source.",
            module,
            root_path,
            os.getcwd(),
        )
        root_path = os.path.abspath(os.getcwd())

    # Remove root_path from module_path to get relative path left over
    relative_path = os.path.relpath(module_path, root_path)

    if relative_path.startswith(os.pardir):
        raise RuntimeError(
            f"Unable to resolve source for module {module}. The module file "
            f"'{relative_path}' does not seem to be inside the source root "
            f"'{root_path}'."
        )

    # Remove the file extension and replace the os specific path separators
    # with `.` to get the module source
    path_without_extension, file_extension = os.path.splitext(relative_path)
    if file_extension != ".py":
        # Report the full file name, including the offending extension.
        raise RuntimeError(
            f"Unable to resolve source for module {module}. The module file "
            f"'{relative_path}' does not seem to be a python file."
        )

    module_source = path_without_extension.replace(os.path.sep, ".")

    logger.debug(
        f"Resolved module source for module {module} to: `{module_source}`"
    )

    return module_source

get_source(value)

Returns the source code of an object.

If executing within an IPython kernel environment, this temporarily monkey-patches the inspect module with a workaround to get the source from the cell.

Parameters:

Name Type Description Default
value Any

object to get source from.

required

Returns:

Type Description
str

Source code of object.

Source code in zenml/utils/source_utils.py
def get_source(value: Any) -> str:
    """Returns the source code of an object.

    If executing within an IPython kernel environment, `inspect.getfile` is
    temporarily monkey-patched with a workaround to get the source from the
    cell.

    Args:
        value: object to get source from.

    Returns:
        Source code of object.
    """
    if not Environment.in_notebook():
        # Use standard inspect if running outside a notebook
        return inspect.getsource(value)

    # Replacement for inspect.getfile that makes getsource work.
    # Source: https://stackoverflow.com/questions/51566497/
    def _notebook_getfile(
        object: Any,
        _old_getfile: Callable[..., str] = inspect.getfile,
    ) -> Any:
        # Only classes need special treatment; everything else is handled by
        # the original `inspect.getfile` captured as default argument.
        if not inspect.isclass(object):
            return _old_getfile(object)

        # Lookup by parent module (as in current inspect)
        if hasattr(object, "__module__"):
            parent_module = sys.modules.get(object.__module__)
            if hasattr(parent_module, "__file__"):
                return parent_module.__file__  # type: ignore[union-attr]

        # If parent module is __main__, lookup by methods
        for _, member in inspect.getmembers(object):
            if not inspect.isfunction(member):
                continue
            expected_qualname = object.__qualname__ + "." + member.__name__
            if expected_qualname == member.__qualname__:
                return inspect.getfile(member)

        raise TypeError(f"Source for {object!r} not found.")

    # Monkey patch, compute source, then revert monkey patch.
    unpatched_getfile = inspect.getfile
    inspect.getfile = _notebook_getfile
    try:
        return inspect.getsource(value)
    finally:
        inspect.getfile = unpatched_getfile

get_source_root_path()

Gets repository root path or the source root path of the current process.

E.g.:

  • if the process was started by running a run.py file under full/path/to/my/run.py, and the repository root is configured at full/path, the source root path is full/path.

  • same case as above, but when there is no repository root configured, the source root path is full/path/to/my.

Returns:

Type Description
str

The source root path of the current process.

Exceptions:

Type Description
RuntimeError

If the main module could not be determined or was not started from a file.

Source code in zenml/utils/source_utils.py
def get_source_root_path() -> str:
    """Determine the directory that source paths are resolved against.

    A ZenML repository root takes precedence when one is found. Otherwise
    the directory containing the file of the `__main__` module is used.

    For a process started from `full/path/to/my/run.py`:

      * with a repository root configured at `full/path`, the result is
      `full/path`.

      * without a configured repository root, the result is
      `full/path/to/my`.

    Returns:
        The source root path of the current process.

    Raises:
        RuntimeError: if there is no `__main__` module, or it has no file
            it was started from.
    """
    from zenml.client import Client

    repo_root = Client.find_repository()
    if repo_root:
        logger.debug("Using repository root as source root: %s", repo_root)
        return str(repo_root.resolve())

    main_module = sys.modules.get("__main__")
    if main_module is None:
        raise RuntimeError(
            "Could not determine the main module used to run the current "
            "process."
        )

    # A missing and an empty `__file__` are treated the same way.
    main_file = getattr(main_module, "__file__", None)
    if not main_file:
        raise RuntimeError(
            "Main module was not started from a file. Cannot "
            "determine the module root path."
        )

    path = pathlib.Path(main_file).resolve().parent
    logger.debug("Using main module location as source root: %s", path)
    return str(path)

import_class_by_path(class_path)

Imports a class based on a given path.

Parameters:

Name Type Description Default
class_path str

str, class_source e.g. this.module.Class

required

Returns:

Type Description
Type[Any]

the given class

Source code in zenml/utils/source_utils.py
def import_class_by_path(class_path: str) -> Type[Any]:
    """Import the attribute named by a dotted path.

    The path is split at its last dot: everything before it is imported as
    a module, and the part after it is looked up on that module.

    Args:
        class_path: dotted path such as `this.module.Class`.

    Returns:
        The object found under the given path.
    """
    module_path, attribute_name = class_path.rsplit(".", 1)
    containing_module = importlib.import_module(module_path)
    return getattr(containing_module, attribute_name)  # type: ignore[no-any-return]

import_python_file(file_path, zen_root)

Imports a python file in relationship to the zen root.

Parameters:

Name Type Description Default
file_path str

Path to python file that should be imported.

required
zen_root str

Path to current zenml root

required

Returns:

Type Description
types.ModuleType

The imported module.

Source code in zenml/utils/source_utils.py
def import_python_file(file_path: str, zen_root: str) -> types.ModuleType:
    """Imports a python file in relationship to the zen root.

    The module name is derived from the file's path relative to `zen_root`
    (extension stripped, path separators replaced by dots). Any module
    already registered under that name is discarded first, so the file is
    imported afresh.

    Args:
        file_path: Path to python file that should be imported.
        zen_root: Path to current zenml root

    Returns:
        The imported module.
    """
    file_path = os.path.abspath(file_path)
    module_path = os.path.relpath(file_path, zen_root)
    module_name = os.path.splitext(module_path)[0].replace(os.path.sep, ".")

    # Drop a previously imported copy (if any) so the import below does not
    # return the cached module.
    sys.modules.pop(module_name, None)

    # Put the zen root on sys.path so the dotted module name resolves.
    with prepend_python_path([zen_root]):
        return importlib.import_module(module_name)

is_inside_repository(file_path)

Returns whether a file is inside a zenml repository.

Parameters:

Name Type Description Default
file_path str

A file path.

required

Returns:

Type Description
bool

True if the file is inside a zenml repository, else False.

Source code in zenml/utils/source_utils.py
def is_inside_repository(file_path: str) -> bool:
    """Check whether a file lives beneath the ZenML repository root.

    Args:
        file_path: A file path.

    Returns:
        `True` if the resolved repository root is one of the parents of the
        resolved file path, `False` if it is not or if no repository is
        found.
    """
    from zenml.client import Client

    repo_path = Client.find_repository()
    if repo_path:
        resolved_root = repo_path.resolve()
        resolved_file = pathlib.Path(file_path).resolve()
        return resolved_root in resolved_file.parents

    return False

is_standard_pin(pin)

Returns True if pin is valid ZenML pin, else False.

Parameters:

Name Type Description Default
pin str

potential ZenML pin like 'zenml_0.1.1'

required

Returns:

Type Description
bool

True if pin is valid ZenML pin, else False.

Source code in zenml/utils/source_utils.py
def is_standard_pin(pin: str) -> bool:
    """Tell whether a pin carries the ZenML pin prefix.

    A pin qualifies when it starts with the application name followed by an
    underscore.

    Args:
        pin: potential ZenML pin like 'zenml_0.1.1'

    Returns:
        `True` if pin is valid ZenML pin, else False.
    """
    return pin.startswith(f"{constants.APP_NAME}_")

is_standard_source(source)

Returns True if source is a standard ZenML source.

Parameters:

Name Type Description Default
source str

class_source e.g. this.module.Class[@pin].

required

Returns:

Type Description
bool

True if source is a standard ZenML source, else False.

Source code in zenml/utils/source_utils.py
def is_standard_source(source: str) -> bool:
    """Tell whether a source string points into the `zenml` package.

    Only the part before the first dot is inspected; it must be exactly
    `zenml`.

    Args:
        source: class_source e.g. this.module.Class[@pin].

    Returns:
        `True` if source is a standard ZenML source, else `False`.
    """
    top_level_package = source.partition(".")[0]
    return top_level_package == "zenml"

is_third_party_module(file_path)

Returns whether a file belongs to a third party package.

Parameters:

Name Type Description Default
file_path str

A file path.

required

Returns:

Type Description
bool

True if the file belongs to a third party package, else False.

Source code in zenml/utils/source_utils.py
def is_third_party_module(file_path: str) -> bool:
    """Decide whether a file belongs to a third party package.

    A file counts as third party when it lies under one of the
    site-packages directories, the user site-packages directory or the
    standard library directory. Files that are under none of these are
    still treated as third party if they are not located beneath the
    source root.

    Args:
        file_path: A file path.

    Returns:
        `True` if the file belongs to a third party package, else `False`.
    """
    parent_dirs = pathlib.Path(file_path).resolve().parents

    package_dirs = [
        *site.getsitepackages(),
        site.getusersitepackages(),
        get_python_lib(standard_lib=True),
    ]
    if any(
        pathlib.Path(package_dir).resolve() in parent_dirs
        for package_dir in package_dirs
    ):
        return True

    # Anything outside the source root is not user code either.
    return pathlib.Path(get_source_root_path()) not in parent_dirs

load_and_validate_class(source, expected_class)

Loads a source class and validates its type.

Parameters:

Name Type Description Default
source str

The source string.

required
expected_class Type[Any]

The class that the source should resolve to.

required

Exceptions:

Type Description
TypeError

If the source does not resolve to the expected type.

Returns:

Type Description
Type[Any]

The resolved source class.

Source code in zenml/utils/source_utils.py
def load_and_validate_class(
    source: str, expected_class: Type[Any]
) -> Type[Any]:
    """Load the object behind a source string and check its type.

    Args:
        source: The source string.
        expected_class: The class that the source should resolve to.

    Raises:
        TypeError: If the loaded object is not a class, or is not a
            subclass of `expected_class`.

    Returns:
        The resolved source class.
    """
    class_ = load_source_path_class(source)

    is_expected = isinstance(class_, type) and issubclass(
        class_, expected_class
    )
    if not is_expected:
        raise TypeError(
            f"Error while loading `{source}`. Expected class "
            f"{expected_class.__name__}, got {class_} instead."
        )

    return class_

load_source_path_class(source, import_path=None)

Loads a Python class from the source.

Parameters:

Name Type Description Default
source str

class_source e.g. this.module.Class[@sha]

required
import_path Optional[str]

optional path to add to python path

None

Returns:

Type Description
Type[Any]

the given class

Source code in zenml/utils/source_utils.py
def load_source_path_class(
    source: str, import_path: Optional[str] = None
) -> Type[Any]:
    """Import the class that a source string refers to.

    Anything from the first `@` onwards (the pin) is ignored. When no
    import path is given and a repository root is found, the repository
    root is used as the import path.

    Args:
        source: class_source e.g. this.module.Class[@sha]
        import_path: optional path to add to python path

    Returns:
        the given class
    """
    from zenml.client import Client

    repo_root = Client.find_repository()
    if not import_path and repo_root:
        import_path = str(repo_root)

    # Strip the pin; a source without `@` is returned unchanged by split.
    source = source.split("@")[0]

    if import_path is None:
        return import_class_by_path(source)

    with prepend_python_path([import_path]):
        logger.debug(
            f"Loading class {source} with import path {import_path}"
        )
        return import_class_by_path(source)

prepend_python_path(paths)

Simple context manager to help import module within the repo.

Parameters:

Name Type Description Default
paths List[str]

paths to prepend to sys.path

required

Yields:

Type Description
Iterator[NoneType]

None

Source code in zenml/utils/source_utils.py
@contextmanager
def prepend_python_path(paths: List[str]) -> Iterator[None]:
    """Simple context manager to help import module within the repo.

    Each path is inserted at the front of `sys.path` on entry, so the last
    element of `paths` ends up first. On exit one occurrence of every path
    is removed again. Cleanup tolerates entries that the managed code
    already removed, so it never masks an exception raised inside the
    `with` block.

    Args:
        paths: paths to prepend to sys.path

    Yields:
        None
    """
    try:
        # Entering the with statement
        for path in paths:
            sys.path.insert(0, path)
        yield
    finally:
        # Exiting the with statement
        for path in paths:
            try:
                sys.path.remove(path)
            except ValueError:
                # The entry is already gone (removed by the managed code);
                # there is nothing left to undo for this path.
                pass

resolve_class(class_, replace_main_module=True)

Resolves a class into a serializable source string.

For classes that are not built-in nor imported from a Python package, the get_source_root_path function is used to determine the root path relative to which the class source is resolved.

Parameters:

Name Type Description Default
class_ Type[Any]

A Python Class reference.

required
replace_main_module bool

If True, classes in the main module will have the main module source replaced with the source relative to the ZenML source root.

True

Returns:

Type Description
str

source_path e.g. this.module.Class.

Source code in zenml/utils/source_utils.py
def resolve_class(class_: Type[Any], replace_main_module: bool = True) -> str:
    """Turn a class into a serializable source string.

    Resolution proceeds in order:

      * classes from the `zenml` package get a pinned standard source,
      * classes for which `inspect.getfile` raises `TypeError` (builtins)
      and classes from third party modules keep `module.ClassName`,
      * classes from `__main__` are either kept as is or re-rooted at the
      main module source, depending on `replace_main_module`,
      * everything else is resolved through the module source of the
      class's module.

    Args:
        class_: A Python Class reference.
        replace_main_module: If `True`, classes in the main module will have
            the __main__ module source replaced with the source relative to
            the ZenML source root.

    Returns:
        source_path e.g. this.module.Class.
    """
    initial_source = class_.__module__ + "." + class_.__name__
    if is_standard_source(initial_source):
        return resolve_standard_source(initial_source)

    try:
        defining_file = inspect.getfile(class_)
    except TypeError:
        # builtin file
        return initial_source

    if initial_source.startswith("__main__"):
        if replace_main_module:
            # Resolve the __main__ module to something relative to the
            # ZenML source root
            return f"{get_main_module_source()}.{class_.__name__}"
        return initial_source

    if is_third_party_module(defining_file):
        return initial_source

    # Regular user file -> get the full module path relative to the
    # source root.
    owning_module = sys.modules[class_.__module__]
    module_source = get_module_source_from_module(owning_module)

    source = module_source + "." + class_.__name__
    logger.debug(f"Resolved class {class_} to `{source}`.")
    return source

resolve_standard_source(source)

Creates a ZenML pin for source pinning from release version.

Parameters:

Name Type Description Default
source str

class_source e.g. this.module.Class.

required

Returns:

Type Description
str

ZenML pin.

Exceptions:

Type Description
AssertionError

If source is already pinned.

Source code in zenml/utils/source_utils.py
def resolve_standard_source(source: str) -> str:
    """Append a ZenML pin to an unpinned source string.

    Args:
        source: class_source e.g. this.module.Class.

    Returns:
        The source followed by `@` and the ZenML pin.

    Raises:
        AssertionError: If source is already pinned.
    """
    if "@" not in source:
        pin = create_zenml_pin()
        return f"{source}@{pin}"

    raise AssertionError(f"source {source} is already pinned.")

validate_config_source(source, component_type)

Validates a StackComponentConfig class from a given source.

Parameters:

Name Type Description Default
source str

source path of the implementation

required
component_type StackComponentType

the type of the stack component

required

Returns:

Type Description
Type[StackComponentConfig]

The validated config.

Exceptions:

Type Description
ValueError

If ZenML cannot import the config class.

TypeError

If the config class is not a subclass of StackComponentConfig.

Source code in zenml/utils/source_utils.py
def validate_config_source(
    source: str, component_type: StackComponentType
) -> Type["StackComponentConfig"]:
    """Load a config class from a source and check its base class.

    Args:
        source: source path of the implementation
        component_type: the type of the stack component (not used by the
            checks in this function)

    Returns:
        The validated config.

    Raises:
        ValueError: If ZenML cannot import the config class.
        TypeError: If the loaded class is not a subclass of
            `StackComponentConfig`.
    """
    from zenml.stack.stack_component import StackComponentConfig

    try:
        config_class = load_source_path_class(source)
    except (ValueError, AttributeError, ImportError) as e:
        raise ValueError(
            f"ZenML can not import the config class '{source}': {e}"
        )

    if issubclass(config_class, StackComponentConfig):
        return config_class  # noqa

    raise TypeError(
        f"The source path '{source}' does not point to a subclass of "
        f"the ZenML config_class."
    )

validate_flavor_source(source, component_type)

Import a StackComponent class from a given source and validate its type.

Parameters:

Name Type Description Default
source str

source path of the implementation

required
component_type StackComponentType

the type of the stack component

required

Returns:

Type Description
Type[Flavor]

the imported class

Exceptions:

Type Description
ValueError

If ZenML cannot find the given module path

TypeError

If the given module path does not point to a subclass of a StackComponent which has the right component type.

Source code in zenml/utils/source_utils.py
def validate_flavor_source(
    source: str, component_type: StackComponentType
) -> Type["Flavor"]:
    """Import a Flavor class from a given source and validate it.

    The flavor class is loaded and instantiated; its implementation class,
    its component type and its config class are then checked in turn.

    Args:
        source: source path of the implementation
        component_type: the type of the stack component

    Returns:
        the imported class

    Raises:
        ValueError: If ZenML cannot import the flavor class, or the
            implementation or config class declared by the flavor.
        TypeError: If the source is not a `Flavor` subclass, if its
            implementation class is not a `StackComponent` subclass, if its
            config class is not a `StackComponentConfig` subclass, or if
            the flavor has a different component type.
    """
    from zenml.stack.flavor import Flavor
    from zenml.stack.stack_component import StackComponent, StackComponentConfig

    try:
        flavor_class = load_source_path_class(source)
    except (ValueError, AttributeError, ImportError) as e:
        raise ValueError(
            f"ZenML can not import the flavor class '{source}': {e}"
        ) from e

    if not issubclass(flavor_class, Flavor):
        # Adjacent literals are concatenated, so the separating space must
        # be written explicitly.
        raise TypeError(
            f"The source '{source}' does not point to a subclass of the "
            f"ZenML Flavor."
        )

    flavor = flavor_class()
    try:
        impl_class = flavor.implementation_class
    except (ModuleNotFoundError, ImportError, NotImplementedError) as e:
        raise ValueError(
            f"The implementation class defined within the "
            f"'{flavor_class.__name__}' can not be imported."
        ) from e

    if not issubclass(impl_class, StackComponent):
        raise TypeError(
            f"The implementation class '{impl_class.__name__}' of a flavor "
            f"needs to be a subclass of the ZenML StackComponent."
        )

    if flavor.type != component_type:  # noqa
        # Report the value that was actually compared (the flavor's type).
        raise TypeError(
            f"The source points to a {flavor.type}, not a "  # noqa
            f"{component_type}."
        )

    try:
        conf_class = flavor.config_class
    except (ModuleNotFoundError, ImportError, NotImplementedError) as e:
        raise ValueError(
            f"The config class defined within the "
            f"'{flavor_class.__name__}' can not be imported."
        ) from e

    if not issubclass(conf_class, StackComponentConfig):
        raise TypeError(
            f"The config class '{conf_class.__name__}' of a flavor "
            f"needs to be a subclass of the ZenML StackComponentConfig."
        )

    return flavor_class  # noqa

validate_source_class(source, expected_class)

Validates that a source resolves to a certain type.

Parameters:

Name Type Description Default
source str

The source to validate.

required
expected_class Type[Any]

The class that the source should resolve to.

required

Returns:

Type Description
bool

If the source resolves to the expected class.

Source code in zenml/utils/source_utils.py
def validate_source_class(source: str, expected_class: Type[Any]) -> bool:
    """Check whether a source resolves to a subclass of a given class.

    Any exception raised while loading the source is treated as a failed
    validation.

    Args:
        source: The source to validate.
        expected_class: The class that the source should resolve to.

    Returns:
        `True` if the source loads to a class that is a subclass of
        `expected_class`, `False` otherwise.
    """
    try:
        value = load_source_path_class(source)
    except Exception:
        return False

    return isinstance(value, type) and issubclass(value, expected_class)

string_utils

Utils for strings.

b64_decode(input_)

Returns a decoded string of the base 64 encoded input string.

Parameters:

Name Type Description Default
input_ str

Base64 encoded string.

required

Returns:

Type Description
str

Decoded string.

Source code in zenml/utils/string_utils.py
def b64_decode(input_: str) -> str:
    """Decode a base 64 encoded string back into text.

    Args:
        input_: Base64 encoded string.

    Returns:
        Decoded string.
    """
    raw = base64.b64decode(input_.encode())
    return raw.decode()

b64_encode(input_)

Returns a base 64 encoded string of the input string.

Parameters:

Name Type Description Default
input_ str

The input to encode.

required

Returns:

Type Description
str

Base64 encoded string.

Source code in zenml/utils/string_utils.py
def b64_encode(input_: str) -> str:
    """Encode text as a base 64 string.

    Args:
        input_: The input to encode.

    Returns:
        Base64 encoded string.
    """
    encoded = base64.b64encode(input_.encode())
    return encoded.decode()

get_human_readable_filesize(bytes_)

Convert a file size in bytes into a human-readable string.

Parameters:

Name Type Description Default
bytes_ int

The number of bytes to convert.

required

Returns:

Type Description
str

A human-readable string.

Source code in zenml/utils/string_utils.py
def get_human_readable_filesize(bytes_: int) -> str:
    """Format a byte count using binary units.

    The absolute value of the input is divided by 1024 until it drops
    below 1024 or the largest supported unit (GiB) is reached.

    Args:
        bytes_: The number of bytes to convert.

    Returns:
        The size with two decimals followed by its unit, e.g. `1.50 KiB`.
    """
    units = ("B", "KiB", "MiB", "GiB")
    size = abs(float(bytes_))

    index = 0
    while index < len(units) - 1 and not size < 1024.0:
        size /= 1024.0
        index += 1

    unit = units[index]
    return f"{size:.2f} {unit}"

get_human_readable_time(seconds)

Convert seconds into a human-readable string.

Parameters:

Name Type Description Default
seconds float

The number of seconds to convert.

required

Returns:

Type Description
str

A human-readable string.

Source code in zenml/utils/string_utils.py
def get_human_readable_time(seconds: float) -> str:
    """Format a duration in seconds as days, hours, minutes and seconds.

    Durations of a minute or more are shown with whole seconds, starting
    at the largest non-zero unit. Shorter durations are shown in seconds
    with millisecond precision. Negative durations get a leading `-`.

    Args:
        seconds: The number of seconds to convert.

    Returns:
        A human-readable string.
    """
    prefix = "-" if seconds < 0 else ""
    seconds = abs(seconds)

    # Break the whole seconds down from the smallest unit upwards.
    minutes, int_seconds = divmod(int(seconds), 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)

    if days > 0:
        return prefix + f"{days}d{hours}h{minutes}m{int_seconds}s"
    if hours > 0:
        return prefix + f"{hours}h{minutes}m{int_seconds}s"
    if minutes > 0:
        return prefix + f"{minutes}m{int_seconds}s"
    return prefix + f"{seconds:.3f}s"

random_str(length)

Generate a random human readable string of given length.

Parameters:

Name Type Description Default
length int

Length of string

required

Returns:

Type Description
str

Random human-readable string.

Source code in zenml/utils/string_utils.py
def random_str(length: int) -> str:
    """Build a random string of ASCII letters.

    The module-level random generator is re-seeded on every call before
    the letters are drawn.

    Args:
        length: Length of string

    Returns:
        Random human-readable string.
    """
    random.seed()
    letters = random.choices(string.ascii_letters, k=length)
    return "".join(letters)

typed_model

Utility classes for adding type information to Pydantic models.

BaseTypedModel (BaseModel) pydantic-model

Typed Pydantic model base class.

Use this class as a base class instead of BaseModel to automatically add a type literal attribute to the model that stores the name of the class.

This can be useful when serializing models to JSON and then de-serializing them as part of a submodel union field, e.g.:


class BluePill(BaseTypedModel):
    ...

class RedPill(BaseTypedModel):
    ...

class TheMatrix(BaseTypedModel):
    choice: Union[BluePill, RedPill] = Field(..., discriminator='type')

matrix = TheMatrix(choice=RedPill())
d = matrix.dict()
new_matrix = TheMatrix.parse_obj(d)
assert isinstance(new_matrix.choice, RedPill)

It can also facilitate de-serializing objects when their type isn't known:

matrix = TheMatrix(choice=RedPill())
d = matrix.dict()
new_matrix = BaseTypedModel.from_dict(d)
assert isinstance(new_matrix.choice, RedPill)
Source code in zenml/utils/typed_model.py
class BaseTypedModel(BaseModel, metaclass=BaseTypedModelMeta):
    """Typed Pydantic model base class.

    Use this class as a base class instead of BaseModel to automatically
    add a `type` literal attribute to the model that stores the name of the
    class.

    This can be useful when serializing models to JSON and then de-serializing
    them as part of a submodel union field, e.g.:

    ```python

    class BluePill(BaseTypedModel):
        ...

    class RedPill(BaseTypedModel):
        ...

    class TheMatrix(BaseTypedModel):
        choice: Union[BluePill, RedPill] = Field(..., discriminator='type')

    matrix = TheMatrix(choice=RedPill())
    d = matrix.dict()
    new_matrix = TheMatrix.parse_obj(d)
    assert isinstance(new_matrix.choice, RedPill)
    ```

    It can also facilitate de-serializing objects when their type isn't known:

    ```python
    matrix = TheMatrix(choice=RedPill())
    d = matrix.dict()
    new_matrix = BaseTypedModel.from_dict(d)
    assert isinstance(new_matrix.choice, RedPill)
    ```
    """

    @classmethod
    def from_dict(
        cls,
        model_dict: Dict[str, Any],
    ) -> "BaseTypedModel":
        """Rebuild a model from its serialized dict form.

        The concrete class is taken from the `type` entry of the dict and
        loaded from that source string; the dict is then parsed with it.

        Args:
            model_dict: the model attributes serialized as JSON-able dict.

        Returns:
            A BaseTypedModel created from the serialized representation.

        Raises:
            RuntimeError: if the `type` entry is missing or empty, or does
                not resolve to a BaseTypedModel subclass.
        """
        model_type = model_dict.get("type")
        if not model_type:
            raise RuntimeError(
                "`type` information is missing from the serialized model dict."
            )

        # From here on `cls` is the concrete class named in the dict.
        cls = load_source_path_class(model_type)
        if issubclass(cls, BaseTypedModel):
            return cls.parse_obj(model_dict)

        raise RuntimeError(
            f"Class `{cls}` is not a ZenML BaseTypedModel subclass."
        )

    @classmethod
    def from_json(
        cls,
        json_str: str,
    ) -> "BaseTypedModel":
        """Rebuild a model from its serialized JSON form.

        The JSON string is decoded and handed to `from_dict`.

        Args:
            json_str: the model attributes serialized as JSON.

        Returns:
            A BaseTypedModel created from the serialized representation.
        """
        return cls.from_dict(json.loads(json_str))
from_dict(model_dict) classmethod

Instantiate a Pydantic model from a serialized JSON-able dict representation.

Parameters:

Name Type Description Default
model_dict Dict[str, Any]

the model attributes serialized as JSON-able dict.

required

Returns:

Type Description
BaseTypedModel

A BaseTypedModel created from the serialized representation.

Exceptions:

Type Description
RuntimeError

if the model_dict contains an invalid type.

Source code in zenml/utils/typed_model.py
@classmethod
def from_dict(
    cls,
    model_dict: Dict[str, Any],
) -> "BaseTypedModel":
    """Rebuild a model from its serialized dict form.

    The concrete class is taken from the `type` entry of the dict and
    loaded from that source string; the dict is then parsed with it.

    Args:
        model_dict: the model attributes serialized as JSON-able dict.

    Returns:
        A BaseTypedModel created from the serialized representation.

    Raises:
        RuntimeError: if the `type` entry is missing or empty, or does not
            resolve to a BaseTypedModel subclass.
    """
    model_type = model_dict.get("type")
    if not model_type:
        raise RuntimeError(
            "`type` information is missing from the serialized model dict."
        )

    # From here on `cls` is the concrete class named in the dict.
    cls = load_source_path_class(model_type)
    if issubclass(cls, BaseTypedModel):
        return cls.parse_obj(model_dict)

    raise RuntimeError(
        f"Class `{cls}` is not a ZenML BaseTypedModel subclass."
    )
from_json(json_str) classmethod

Instantiate a Pydantic model from a serialized JSON representation.

Parameters:

Name Type Description Default
json_str str

the model attributes serialized as JSON.

required

Returns:

Type Description
BaseTypedModel

A BaseTypedModel created from the serialized representation.

Source code in zenml/utils/typed_model.py
@classmethod
def from_json(
    cls,
    json_str: str,
) -> "BaseTypedModel":
    """Rebuild a model from its serialized JSON form.

    The JSON string is decoded and handed to `from_dict`.

    Args:
        json_str: the model attributes serialized as JSON.

    Returns:
        A BaseTypedModel created from the serialized representation.
    """
    return cls.from_dict(json.loads(json_str))

BaseTypedModelMeta (ModelMetaclass)

Metaclass responsible for adding type information to Pydantic models.

Source code in zenml/utils/typed_model.py
class BaseTypedModelMeta(ModelMetaclass):
    """Metaclass responsible for adding type information to Pydantic models."""

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "BaseTypedModelMeta":
        """Create the model class with an injected `type` field.

        The field is annotated as a `Literal` of `<module>.<qualname>` of
        the class being created and defaults to that same string.

        Args:
            name: The name of the class.
            bases: The base classes of the class.
            dct: The class dictionary.

        Returns:
            A Pydantic BaseModel class that includes a hidden attribute that
            reflects the full class identifier.

        Raises:
            TypeError: If the class dictionary already defines `type`.
        """
        if "type" in dct:
            raise TypeError(
                "`type` is a reserved attribute name for BaseTypedModel "
                "subclasses"
            )

        type_name = f"{dct['__module__']}.{dct['__qualname__']}"
        type_annotation = Literal[type_name]  # type: ignore [misc,valid-type]
        type_field = Field(type_name)

        # Register both the annotation and the default on the class dict.
        annotations = dct.setdefault("__annotations__", {})
        annotations["type"] = type_annotation
        dct["type"] = type_field

        return cast(
            Type["BaseTypedModel"], super().__new__(mcs, name, bases, dct)
        )
__new__(mcs, name, bases, dct) special staticmethod

Creates a Pydantic BaseModel class.

This includes a hidden attribute that reflects the full class identifier.

Parameters:

Name Type Description Default
name str

The name of the class.

required
bases Tuple[Type[Any], ...]

The base classes of the class.

required
dct Dict[str, Any]

The class dictionary.

required

Returns:

Type Description
BaseTypedModelMeta

A Pydantic BaseModel class that includes a hidden attribute that reflects the full class identifier.

Exceptions:

Type Description
TypeError

If the class dictionary already defines a `type` attribute (the name is reserved).

Source code in zenml/utils/typed_model.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "BaseTypedModelMeta":
    """Create the model class with an injected `type` field.

    The field is annotated as a `Literal` of `<module>.<qualname>` of the
    class being created and defaults to that same string.

    Args:
        name: The name of the class.
        bases: The base classes of the class.
        dct: The class dictionary.

    Returns:
        A Pydantic BaseModel class that includes a hidden attribute that
        reflects the full class identifier.

    Raises:
        TypeError: If the class dictionary already defines `type`.
    """
    if "type" in dct:
        raise TypeError(
            "`type` is a reserved attribute name for BaseTypedModel "
            "subclasses"
        )

    type_name = f"{dct['__module__']}.{dct['__qualname__']}"
    type_annotation = Literal[type_name]  # type: ignore [misc,valid-type]
    type_field = Field(type_name)

    # Register both the annotation and the default on the class dict.
    annotations = dct.setdefault("__annotations__", {})
    annotations["type"] = type_annotation
    dct["type"] = type_field

    return cast(
        Type["BaseTypedModel"], super().__new__(mcs, name, bases, dct)
    )

uuid_utils

Utility functions for handling UUIDs.

generate_uuid_from_string(value)

Deterministically generates a UUID from a string seed.

Parameters:

Name Type Description Default
value str

The string from which to generate the UUID.

required

Returns:

Type Description
UUID

The generated UUID.

Source code in zenml/utils/uuid_utils.py
def generate_uuid_from_string(value: str) -> UUID:
    """Derive a UUID from a string in a repeatable way.

    The UTF-8 bytes of the string are hashed with MD5 and the hex digest is
    used to construct a UUID whose version is set to 4.

    Args:
        value: The string from which to generate the UUID.

    Returns:
        The generated UUID.
    """
    digest = hashlib.md5(value.encode("utf-8")).hexdigest()
    return UUID(hex=digest, version=4)

is_valid_uuid(value, version=4)

Checks if a string is a valid UUID.

Parameters:

Name Type Description Default
value Any

String to check.

required
version int

Version of UUID to check for.

4

Returns:

Type Description
bool

True if string is a valid UUID, False otherwise.

Source code in zenml/utils/uuid_utils.py
def is_valid_uuid(value: Any, version: int = 4) -> bool:
    """Tell whether a value is a UUID or a string that parses as one.

    Args:
        value: The value to check. `UUID` instances are always accepted;
            anything that is neither a `UUID` nor a string is rejected.
        version: UUID version handed to the `UUID` constructor when parsing
            a string value.

    Returns:
        True if the value is a UUID or a string the `UUID` constructor
        accepts, False otherwise.
    """
    if isinstance(value, UUID):
        return True
    if not isinstance(value, str):
        return False
    try:
        UUID(value, version=version)
    except ValueError:
        # Raised for malformed strings as well as unsupported versions.
        return False
    return True

parse_name_or_uuid(name_or_id)

Convert a "name or id" string value to a string or UUID.

Parameters:

Name Type Description Default
name_or_id str

Name or id to convert.

required

Returns:

Type Description
Union[str, uuid.UUID]

A UUID if name_or_id is a UUID, string otherwise.

Source code in zenml/utils/uuid_utils.py
def parse_name_or_uuid(name_or_id: str) -> Union[str, UUID]:
    """Turn a "name or id" string into a UUID when it parses as one.

    Args:
        name_or_id: The name or id string to convert.

    Returns:
        A UUID if the string parses as a UUID; otherwise the input is
        returned unchanged (this includes empty values).
    """
    # Empty values can never be UUIDs, so hand them back untouched.
    if not name_or_id:
        return name_or_id
    try:
        return UUID(name_or_id)
    except ValueError:
        # Not a UUID, so treat the value as a plain name.
        return name_or_id

parse_optional_name_or_uuid(name_or_id)

Convert an optional "name or id" string value to an optional string or UUID.

Parameters:

Name Type Description Default
name_or_id Optional[str]

Name or id to convert.

required

Returns:

Type Description
Optional[Union[str, uuid.UUID]]

None if name_or_id is None; otherwise a UUID if name_or_id is a valid UUID string, or the unchanged string.

Source code in zenml/utils/uuid_utils.py
def parse_optional_name_or_uuid(
    name_or_id: Optional[str],
) -> Optional[Union[str, UUID]]:
    """Turn an optional "name or id" string into an optional string or UUID.

    Args:
        name_or_id: The name or id to convert, or None.

    Returns:
        None if no value was given; otherwise whatever `parse_name_or_uuid`
        returns for the value.
    """
    return None if name_or_id is None else parse_name_or_uuid(name_or_id)

yaml_utils

Utility functions to help with YAML files and data.

UUIDEncoder (JSONEncoder)

JSON encoder for UUID objects.

Source code in zenml/utils/yaml_utils.py
class UUIDEncoder(json.JSONEncoder):
    """JSON encoder that can serialize UUID objects."""

    def default(self, obj: Any) -> Any:
        """Encode objects the standard JSON encoder cannot handle.

        Args:
            obj: The object to encode.

        Returns:
            The hex string of the UUID if `obj` is a UUID; otherwise the
            result of the base `JSONEncoder.default` implementation.
        """
        if not isinstance(obj, UUID):
            # Everything else is left to the base class.
            return super().default(obj)
        # UUIDs are serialized as their 32-character hex representation.
        return obj.hex
default(self, obj)

Default UUID encoder for JSON.

Parameters:

Name Type Description Default
obj Any

Object to encode.

required

Returns:

Type Description
Any

Encoded object.

Source code in zenml/utils/yaml_utils.py
def default(self, obj: Any) -> Any:
    """Encode objects the standard JSON encoder cannot handle.

    Args:
        obj: The object to encode.

    Returns:
        The hex string of the UUID if `obj` is a UUID; otherwise the
        result of the base `JSONEncoder.default` implementation.
    """
    if not isinstance(obj, UUID):
        # Everything else is left to the base class.
        return json.JSONEncoder.default(self, obj)
    # UUIDs are serialized as their 32-character hex representation.
    return obj.hex

append_yaml(file_path, contents)

Append contents to a YAML file at file_path.

Parameters:

Name Type Description Default
file_path str

Path to YAML file.

required
contents Dict[Any, Any]

Contents of YAML file as dict.

required

Exceptions:

Type Description
FileNotFoundError

if directory does not exist.

Source code in zenml/utils/yaml_utils.py
def append_yaml(file_path: str, contents: Dict[Any, Any]) -> None:
    """Merge contents into the YAML file at file_path.

    The current file content is read, updated with `contents` (entries in
    `contents` win on duplicate keys) and written back to the same path.

    Args:
        file_path: Path to the YAML file.
        contents: Entries to merge into the file, as a dict.

    Raises:
        FileNotFoundError: if the parent directory of a non-remote path
            does not exist.
    """
    # An empty YAML file loads as None; start from an empty mapping then.
    merged = read_yaml(file_path) or {}
    merged.update(contents)

    targets_local_path = not io_utils.is_remote(file_path)
    if targets_local_path:
        dir_ = str(Path(file_path).parent)
        if not fileio.isdir(dir_):
            raise FileNotFoundError(f"Directory {dir_} does not exist.")

    io_utils.write_file_contents_as_string(file_path, yaml.dump(merged))

comment_out_yaml(yaml_string)

Comments out a yaml string.

Parameters:

Name Type Description Default
yaml_string str

The yaml string to comment out.

required

Returns:

Type Description
str

The commented out yaml string.

Source code in zenml/utils/yaml_utils.py
def comment_out_yaml(yaml_string: str) -> str:
    """Prefix every line of a YAML string with a comment marker.

    Args:
        yaml_string: The YAML string to comment out.

    Returns:
        The input with "# " prepended to each line; line endings are
        kept as they were.
    """
    return "".join(
        "# " + line for line in yaml_string.splitlines(keepends=True)
    )

is_yaml(file_path)

Returns True if file_path is YAML, else False.

Parameters:

Name Type Description Default
file_path str

Path to YAML file.

required

Returns:

Type Description
bool

True if file_path ends with "yaml" or "yml", else False.

Source code in zenml/utils/yaml_utils.py
def is_yaml(file_path: str) -> bool:
    """Tell whether a path looks like a YAML file.

    Args:
        file_path: Path to check.

    Returns:
        True if the path ends with "yaml" or "yml", else False.
    """
    # `str.endswith` accepts a tuple and matches any of its suffixes.
    return file_path.endswith(("yaml", "yml"))

read_json(file_path)

Reads the JSON file at file_path and returns its contents.

Parameters:

Name Type Description Default
file_path str

Path to JSON file.

required

Returns:

Type Description
Any

Contents of the file in a dict.

Exceptions:

Type Description
FileNotFoundError

if file does not exist.

Source code in zenml/utils/yaml_utils.py
def read_json(file_path: str) -> Any:
    """Load and parse the JSON file at file_path.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The parsed contents of the file.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    if not fileio.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")

    raw_text = io_utils.read_file_contents_as_string(file_path)
    return json.loads(raw_text)

read_yaml(file_path)

Reads the YAML file at file_path and returns its contents.

Parameters:

Name Type Description Default
file_path str

Path to YAML file.

required

Returns:

Type Description
Any

Contents of the file in a dict.

Exceptions:

Type Description
FileNotFoundError

if file does not exist.

Source code in zenml/utils/yaml_utils.py
def read_yaml(file_path: str) -> Any:
    """Load and parse the YAML file at file_path.

    Args:
        file_path: Path to the YAML file.

    Returns:
        The parsed contents of the file, as produced by `yaml.safe_load`.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    if not fileio.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")

    raw_text = io_utils.read_file_contents_as_string(file_path)
    # TODO: [LOW] consider adding a default empty dict to be returned
    #   instead of None
    return yaml.safe_load(raw_text)

write_json(file_path, contents, encoder=None)

Write contents as JSON format to file_path.

Parameters:

Name Type Description Default
file_path str

Path to JSON file.

required
contents Dict[str, Any]

Contents of JSON file as dict.

required
encoder Optional[Type[json.encoder.JSONEncoder]]

Custom JSON encoder to use when saving json.

None

Exceptions:

Type Description
FileNotFoundError

if directory does not exist.

Source code in zenml/utils/yaml_utils.py
def write_json(
    file_path: str,
    contents: Dict[str, Any],
    encoder: Optional[Type[json.JSONEncoder]] = None,
) -> None:
    """Serialize contents to JSON and write the result to file_path.

    Args:
        file_path: Path to the JSON file.
        contents: Data to store in the JSON file, as a dict.
        encoder: Optional custom JSON encoder class used for serialization.

    Raises:
        FileNotFoundError: if the parent directory of a non-remote path
            does not exist.
    """
    targets_local_path = not io_utils.is_remote(file_path)
    if targets_local_path:
        # The parent directory of a local path must already exist.
        dir_ = str(Path(file_path).parent)
        if not fileio.isdir(dir_):
            raise FileNotFoundError(f"Directory {dir_} does not exist.")

    serialized = json.dumps(contents, cls=encoder)
    io_utils.write_file_contents_as_string(file_path, serialized)

write_yaml(file_path, contents, sort_keys=True)

Write contents as YAML format to file_path.

Parameters:

Name Type Description Default
file_path str

Path to YAML file.

required
contents Union[Dict[Any, Any], List[Any]]

Contents of YAML file as dict or list.

required
sort_keys bool

If True, keys are sorted alphabetically. If False, the order in which the keys were inserted into the dict will be preserved.

True

Exceptions:

Type Description
FileNotFoundError

if directory does not exist.

Source code in zenml/utils/yaml_utils.py
def write_yaml(
    file_path: str,
    contents: Union[Dict[Any, Any], List[Any]],
    sort_keys: bool = True,
) -> None:
    """Serialize contents to YAML and write the result to file_path.

    Args:
        file_path: Path to the YAML file.
        contents: Data to store in the YAML file, as a dict or list.
        sort_keys: If `True`, keys are sorted alphabetically. If `False`,
            the insertion order of the dict keys is preserved.

    Raises:
        FileNotFoundError: if the parent directory of a non-remote path
            does not exist.
    """
    targets_local_path = not io_utils.is_remote(file_path)
    if targets_local_path:
        # The parent directory of a local path must already exist.
        dir_ = str(Path(file_path).parent)
        if not fileio.isdir(dir_):
            raise FileNotFoundError(f"Directory {dir_} does not exist.")

    serialized = yaml.dump(contents, sort_keys=sort_keys)
    io_utils.write_file_contents_as_string(file_path, serialized)