Skip to content

Utils

zenml.utils special

Initialization of the utils module.

The utils module contains utility functions for handling analytics, reading and writing YAML data, and other general-purpose tasks.

analytics_utils

Analytics code for ZenML.

AnalyticsContext

Context manager for analytics.

Source code in zenml/utils/analytics_utils.py
class AnalyticsContext:
    """Context manager for analytics.

    Wraps the Segment ``analytics`` client so that telemetry is only sent
    when the user opted in, and so that ordinary errors raised while sending
    are logged at debug level instead of propagating to the caller.
    """

    def __init__(self) -> None:
        """Context manager for analytics.

        Use this as a context manager to ensure that analytics are initialized
        properly, only tracked when configured to do so and that any errors
        are handled gracefully.
        """
        # Safe defaults: the tracking methods read these attributes, so they
        # must exist even if loading the global configuration fails below.
        self.analytics_opt_in: bool = False
        self.user_id: Optional[str] = None

        try:
            # Imported inside the ``try`` so that an import failure disables
            # analytics instead of raising out of the constructor, where
            # ``__exit__`` cannot suppress it.
            import analytics

            from zenml.config.global_config import GlobalConfiguration

            gc = GlobalConfiguration()

            self.analytics_opt_in = gc.analytics_opt_in
            self.user_id = str(gc.user_id)

            # That means user opted out of analytics
            if not gc.analytics_opt_in:
                return

            if analytics.write_key is None:
                analytics.write_key = get_segment_key()

            # Explicit check rather than ``assert`` so it still runs when
            # Python is started with ``-O``.
            if analytics.write_key is None:
                raise RuntimeError(
                    "Analytics key not set but trying to make telemetry call."
                )

            # Set this to 1 to avoid backoff loop
            analytics.max_retries = 1
        except Exception as e:
            self.analytics_opt_in = False
            logger.debug(f"Analytics initialization failed: {e}")

    def __enter__(self) -> "AnalyticsContext":
        """Enter context manager.

        Returns:
            Self.
        """
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        """Exit context manager.

        Args:
            exc_type: Exception type.
            exc_val: Exception value.
            exc_tb: Exception traceback.

        Returns:
            True when no exception occurred or when an ordinary ``Exception``
            was raised while sending telemetry (it is logged and suppressed so
            analytics never fail the main thread). False for other
            ``BaseException`` subclasses such as ``KeyboardInterrupt`` or
            ``SystemExit``, which must keep propagating.
        """
        if exc_val is None:
            return True

        if not isinstance(exc_val, Exception):
            # Never swallow interpreter-level signals like Ctrl-C.
            return False

        logger.debug(f"Sending telemetry data failed: {exc_val}")
        return True

    def identify(self, traits: Optional[Dict[str, Any]] = None) -> bool:
        """Identify the user.

        Args:
            traits: Traits of the user.

        Returns:
            True if tracking information was sent, False otherwise.
        """
        import analytics

        logger.debug(
            f"Attempting to attach metadata to: User: {self.user_id}, "
            f"Metadata: {traits}"
        )

        if not self.analytics_opt_in:
            return False

        analytics.identify(self.user_id, traits)

        logger.debug(f"User data sent: User: {self.user_id},{traits}")

        return True

    def group(
        self,
        group: Union[str, AnalyticsGroup],
        group_id: str,
        traits: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Group the user.

        The caller's ``traits`` dict is not modified; ``group_id`` is added
        to a private copy.

        Args:
            group: Group to which the user belongs.
            group_id: Group ID.
            traits: Traits of the group.

        Returns:
            True if tracking information was sent, False otherwise.
        """
        import analytics

        if isinstance(group, AnalyticsGroup):
            group = group.value

        # Copy so the caller's dict is left untouched.
        traits = dict(traits or {})
        traits["group_id"] = group_id

        logger.debug(
            f"Attempting to attach metadata to: User: {self.user_id}, "
            f"Group: {group}, Group ID: {group_id}, Metadata: {traits}"
        )

        if not self.analytics_opt_in:
            return False
        analytics.group(self.user_id, group_id, traits=traits)

        logger.debug(
            f"Group data sent: User: {self.user_id}, Group: {group}, Group ID: "
            f"{group_id}, Metadata: {traits}"
        )
        return True

    def track(
        self,
        event: Union[str, AnalyticsEvent],
        properties: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Track an event.

        The caller's ``properties`` dict is not modified. System information,
        version data and (when a store is initialized) user/server identifiers
        are added to a private copy before sending.

        Args:
            event: Event to track.
            properties: Event properties.

        Returns:
            True if tracking information was sent, False otherwise.
        """
        import analytics

        from zenml.config.global_config import GlobalConfiguration

        if isinstance(event, AnalyticsEvent):
            event = event.value

        # Work on a copy: the payload is enriched below, and those additions
        # must not leak back into the caller's dict (which may be reused for
        # another tracking call).
        properties = dict(properties or {})

        logger.debug(
            f"Attempting analytics: User: {self.user_id}, "
            f"Event: {event}, "
            f"Metadata: {properties}"
        )

        # Opt-in/opt-out events bypass the opt-in check (presumably so the
        # preference change itself is recorded). ``event`` is a plain string
        # at this point, so compare against the enum values.
        if not self.analytics_opt_in and event not in {
            AnalyticsEvent.OPT_OUT_ANALYTICS.value,
            AnalyticsEvent.OPT_IN_ANALYTICS.value,
        }:
            return False

        # add basics
        properties.update(Environment.get_system_info())
        properties.update(
            {
                "environment": get_environment(),
                "python_version": Environment.python_version(),
                "version": __version__,
            }
        )

        gc = GlobalConfiguration()
        # avoid initializing the store in the analytics, to not create an
        # infinite loop
        if gc._zen_store is not None:
            zen_store = gc.zen_store
            user = zen_store.get_user()

            properties.setdefault("client_id", self.user_id)
            properties.setdefault("user_id", str(user.id))

            if (
                zen_store.type == StoreType.REST
                and "server_id" not in properties
            ):
                server_info = zen_store.get_store_info()
                properties.update(
                    {
                        "user_id": str(user.id),
                        "server_id": str(server_info.id),
                        "server_deployment": str(server_info.deployment_type),
                        "database_type": str(server_info.database_type),
                        "secrets_store_type": str(
                            server_info.secrets_store_type
                        ),
                    }
                )

        # Send UUID values as strings.
        properties = {
            key: str(value) if isinstance(value, UUID) else value
            for key, value in properties.items()
        }

        analytics.track(self.user_id, event, properties)

        logger.debug(
            f"Analytics sent: User: {self.user_id}, Event: {event}, Metadata: "
            f"{properties}"
        )

        return True
__enter__(self) special

Enter context manager.

Returns:

Type Description
AnalyticsContext

Self.

Source code in zenml/utils/analytics_utils.py
def __enter__(self) -> "AnalyticsContext":
    """Start the analytics context.

    Returns:
        The context instance itself, so it can be bound with ``as``.
    """
    context = self
    return context
__exit__(self, exc_type, exc_val, exc_tb) special

Exit context manager.

Parameters:

Name Type Description Default
exc_type Optional[Type[BaseException]]

Exception type.

required
exc_val Optional[BaseException]

Exception value.

required
exc_tb Optional[traceback]

Exception traceback.

required

Returns:

Type Description
bool

True, we should never fail main thread.

Source code in zenml/utils/analytics_utils.py
def __exit__(
    self,
    exc_type: Optional[Type[BaseException]],
    exc_val: Optional[BaseException],
    exc_tb: Optional[TracebackType],
) -> bool:
    """Exit context manager.

    Args:
        exc_type: Exception type.
        exc_val: Exception value.
        exc_tb: Exception traceback.

    Returns:
        True when no exception occurred or when an ordinary ``Exception``
        was raised while sending telemetry (it is logged and suppressed so
        analytics never fail the main thread). False for other
        ``BaseException`` subclasses such as ``KeyboardInterrupt`` or
        ``SystemExit``, which must keep propagating.
    """
    if exc_val is None:
        return True

    if not isinstance(exc_val, Exception):
        # Never swallow interpreter-level signals like Ctrl-C.
        return False

    logger.debug(f"Sending telemetry data failed: {exc_val}")
    return True
__init__(self) special

Context manager for analytics.

Use this as a context manager to ensure that analytics are initialized properly, only tracked when configured to do so and that any errors are handled gracefully.

Source code in zenml/utils/analytics_utils.py
def __init__(self) -> None:
    """Context manager for analytics.

    Use this as a context manager to ensure that analytics are initialized
    properly, only tracked when configured to do so and that any errors
    are handled gracefully.
    """
    # Safe defaults: the tracking methods read these attributes, so they
    # must exist even if loading the global configuration fails below.
    self.analytics_opt_in: bool = False
    self.user_id: Optional[str] = None

    try:
        # Imported inside the ``try`` so that an import failure disables
        # analytics instead of raising out of the constructor, where
        # ``__exit__`` cannot suppress it.
        import analytics

        from zenml.config.global_config import GlobalConfiguration

        gc = GlobalConfiguration()

        self.analytics_opt_in = gc.analytics_opt_in
        self.user_id = str(gc.user_id)

        # That means user opted out of analytics
        if not gc.analytics_opt_in:
            return

        if analytics.write_key is None:
            analytics.write_key = get_segment_key()

        # Explicit check rather than ``assert`` so it still runs when
        # Python is started with ``-O``.
        if analytics.write_key is None:
            raise RuntimeError(
                "Analytics key not set but trying to make telemetry call."
            )

        # Set this to 1 to avoid backoff loop
        analytics.max_retries = 1
    except Exception as e:
        self.analytics_opt_in = False
        logger.debug(f"Analytics initialization failed: {e}")
group(self, group, group_id, traits=None)

Group the user.

Parameters:

Name Type Description Default
group Union[str, zenml.utils.analytics_utils.AnalyticsGroup]

Group to which the user belongs.

required
group_id str

Group ID.

required
traits Optional[Dict[str, Any]]

Traits of the group.

None

Returns:

Type Description
bool

True if tracking information was sent, False otherwise.

Source code in zenml/utils/analytics_utils.py
def group(
    self,
    group: Union[str, AnalyticsGroup],
    group_id: str,
    traits: Optional[Dict[str, Any]] = None,
) -> bool:
    """Group the user.

    The caller's ``traits`` dict is not modified; ``group_id`` is added to
    a private copy.

    Args:
        group: Group to which the user belongs.
        group_id: Group ID.
        traits: Traits of the group.

    Returns:
        True if tracking information was sent, False otherwise.
    """
    import analytics

    if isinstance(group, AnalyticsGroup):
        group = group.value

    # Copy so the caller's dict is left untouched.
    traits = dict(traits or {})
    traits["group_id"] = group_id

    logger.debug(
        f"Attempting to attach metadata to: User: {self.user_id}, "
        f"Group: {group}, Group ID: {group_id}, Metadata: {traits}"
    )

    if not self.analytics_opt_in:
        return False
    analytics.group(self.user_id, group_id, traits=traits)

    logger.debug(
        f"Group data sent: User: {self.user_id}, Group: {group}, Group ID: "
        f"{group_id}, Metadata: {traits}"
    )
    return True
identify(self, traits=None)

Identify the user.

Parameters:

Name Type Description Default
traits Optional[Dict[str, Any]]

Traits of the user.

None

Returns:

Type Description
bool

True if tracking information was sent, False otherwise.

Source code in zenml/utils/analytics_utils.py
def identify(self, traits: Optional[Dict[str, Any]] = None) -> bool:
    """Attach traits to the current user.

    Args:
        traits: Traits of the user.

    Returns:
        True if the traits were handed to the analytics client, False when
        the user has not opted in.
    """
    import analytics

    uid = self.user_id
    logger.debug(
        f"Attempting to attach metadata to: User: {uid}, "
        f"Metadata: {traits}"
    )

    if self.analytics_opt_in:
        analytics.identify(uid, traits)
        logger.debug(f"User data sent: User: {uid},{traits}")
        return True

    return False
track(self, event, properties=None)

Track an event.

Parameters:

Name Type Description Default
event Union[str, zenml.utils.analytics_utils.AnalyticsEvent]

Event to track.

required
properties Optional[Dict[str, Any]]

Event properties.

None

Returns:

Type Description
bool

True if tracking information was sent, False otherwise.

Source code in zenml/utils/analytics_utils.py
def track(
    self,
    event: Union[str, AnalyticsEvent],
    properties: Optional[Dict[str, Any]] = None,
) -> bool:
    """Track an event.

    The caller's ``properties`` dict is not modified. System information,
    version data and (when a store is initialized) user/server identifiers
    are added to a private copy before sending.

    Args:
        event: Event to track.
        properties: Event properties.

    Returns:
        True if tracking information was sent, False otherwise.
    """
    import analytics

    from zenml.config.global_config import GlobalConfiguration

    if isinstance(event, AnalyticsEvent):
        event = event.value

    # Work on a copy: the payload is enriched below, and those additions must
    # not leak back into the caller's dict (which may be reused for another
    # tracking call).
    properties = dict(properties or {})

    logger.debug(
        f"Attempting analytics: User: {self.user_id}, "
        f"Event: {event}, "
        f"Metadata: {properties}"
    )

    # Opt-in/opt-out events bypass the opt-in check (presumably so the
    # preference change itself is recorded). ``event`` is a plain string at
    # this point, so compare against the enum values.
    if not self.analytics_opt_in and event not in {
        AnalyticsEvent.OPT_OUT_ANALYTICS.value,
        AnalyticsEvent.OPT_IN_ANALYTICS.value,
    }:
        return False

    # add basics
    properties.update(Environment.get_system_info())
    properties.update(
        {
            "environment": get_environment(),
            "python_version": Environment.python_version(),
            "version": __version__,
        }
    )

    gc = GlobalConfiguration()
    # avoid initializing the store in the analytics, to not create an
    # infinite loop
    if gc._zen_store is not None:
        zen_store = gc.zen_store
        user = zen_store.get_user()

        properties.setdefault("client_id", self.user_id)
        properties.setdefault("user_id", str(user.id))

        if (
            zen_store.type == StoreType.REST
            and "server_id" not in properties
        ):
            server_info = zen_store.get_store_info()
            properties.update(
                {
                    "user_id": str(user.id),
                    "server_id": str(server_info.id),
                    "server_deployment": str(server_info.deployment_type),
                    "database_type": str(server_info.database_type),
                    "secrets_store_type": str(
                        server_info.secrets_store_type
                    ),
                }
            )

    # Send UUID values as strings.
    properties = {
        key: str(value) if isinstance(value, UUID) else value
        for key, value in properties.items()
    }

    analytics.track(self.user_id, event, properties)

    logger.debug(
        f"Analytics sent: User: {self.user_id}, Event: {event}, Metadata: "
        f"{properties}"
    )

    return True

AnalyticsEvent (str, Enum)

Enum of events to track in segment.

Source code in zenml/utils/analytics_utils.py
class AnalyticsEvent(str, Enum):
    """Enum of events to track in segment.

    Every member must have a distinct value: an ``Enum`` member that repeats
    an earlier value silently becomes an alias of that earlier member.
    """

    # Pipelines
    RUN_PIPELINE = "Pipeline run"
    GET_PIPELINES = "Pipelines fetched"
    GET_PIPELINE = "Pipeline fetched"
    CREATE_PIPELINE = "Pipeline created"
    UPDATE_PIPELINE = "Pipeline updated"
    DELETE_PIPELINE = "Pipeline deleted"
    BUILD_PIPELINE = "Pipeline built"

    # Repo
    INITIALIZE_REPO = "ZenML initialized"
    CONNECT_REPOSITORY = "Repository connected"
    UPDATE_REPOSITORY = "Repository updated"
    DELETE_REPOSITORY = "Repository deleted"

    # Template
    GENERATE_TEMPLATE = "Template generated"

    # Zen store
    INITIALIZED_STORE = "Store initialized"

    # Components
    REGISTERED_STACK_COMPONENT = "Stack component registered"
    UPDATED_STACK_COMPONENT = "Stack component updated"
    COPIED_STACK_COMPONENT = "Stack component copied"
    # Previously "Stack component copied", which made this member an alias
    # of COPIED_STACK_COMPONENT and reported deletions as copies.
    DELETED_STACK_COMPONENT = "Stack component deleted"
    CONNECTED_STACK_COMPONENT = "Stack component connected"

    # Stack
    REGISTERED_STACK = "Stack registered"
    REGISTERED_DEFAULT_STACK = "Default stack registered"
    SET_STACK = "Stack set"
    UPDATED_STACK = "Stack updated"
    COPIED_STACK = "Stack copied"
    IMPORT_STACK = "Stack imported"
    EXPORT_STACK = "Stack exported"
    DELETED_STACK = "Stack deleted"

    # Model Deployment
    MODEL_DEPLOYED = "Model deployed"

    # Analytics opt in and out
    OPT_IN_ANALYTICS = "Analytics opt-in"
    OPT_OUT_ANALYTICS = "Analytics opt-out"
    OPT_IN_OUT_EMAIL = "Response for Email prompt"

    # Examples
    RUN_ZENML_GO = "ZenML go"
    RUN_EXAMPLE = "Example run"
    PULL_EXAMPLE = "Example pull"

    # Integrations
    INSTALL_INTEGRATION = "Integration installed"

    # Users
    CREATED_USER = "User created"
    CREATED_DEFAULT_USER = "Default user created"
    UPDATED_USER = "User updated"
    DELETED_USER = "User deleted"

    # Teams
    CREATED_TEAM = "Team created"
    UPDATED_TEAM = "Team updated"
    DELETED_TEAM = "Team deleted"

    # Workspaces
    CREATED_WORKSPACE = "Workspace created"
    CREATED_DEFAULT_WORKSPACE = "Default workspace created"
    UPDATED_WORKSPACE = "Workspace updated"
    DELETED_WORKSPACE = "Workspace deleted"
    SET_WORKSPACE = "Workspace set"

    # Role
    CREATED_ROLE = "Role created"
    CREATED_DEFAULT_ROLES = "Default roles created"
    UPDATED_ROLE = "Role updated"
    DELETED_ROLE = "Role deleted"

    # Flavor
    CREATED_FLAVOR = "Flavor created"
    UPDATED_FLAVOR = "Flavor updated"
    DELETED_FLAVOR = "Flavor deleted"

    # Secret
    CREATED_SECRET = "Secret created"
    UPDATED_SECRET = "Secret updated"
    DELETED_SECRET = "Secret deleted"

    # Service connector
    CREATED_SERVICE_CONNECTOR = "Service connector created"
    UPDATED_SERVICE_CONNECTOR = "Service connector updated"
    DELETED_SERVICE_CONNECTOR = "Service connector deleted"

    # Test event
    EVENT_TEST = "Test event"

    # Stack recipes
    PULL_STACK_RECIPE = "Stack recipes pulled"
    RUN_STACK_RECIPE = "Stack recipe ran"
    DESTROY_STACK_RECIPE = "Stack recipe destroyed"
    GET_STACK_RECIPE_OUTPUTS = "Stack recipe outputs fetched"

    # Stack component deploy
    DEPLOY_STACK_COMPONENT = "Stack component deployed"
    DESTROY_STACK_COMPONENT = "Stack component destroyed"

    # ZenML server events
    ZENML_SERVER_STARTED = "ZenML server started"
    ZENML_SERVER_STOPPED = "ZenML server stopped"
    ZENML_SERVER_CONNECTED = "ZenML server connected"
    ZENML_SERVER_DEPLOYED = "ZenML server deployed"
    ZENML_SERVER_DESTROYED = "ZenML server destroyed"

    # ZenML Hub events
    ZENML_HUB_PLUGIN_INSTALL = "ZenML Hub plugin installed"
    ZENML_HUB_PLUGIN_UNINSTALL = "ZenML Hub plugin uninstalled"
    ZENML_HUB_PLUGIN_CLONE = "ZenML Hub plugin pulled"
    ZENML_HUB_PLUGIN_SUBMIT = "ZenML Hub plugin pushed"

AnalyticsGroup (str, Enum)

Enum of event groups to track in segment.

Source code in zenml/utils/analytics_utils.py
class AnalyticsGroup(str, Enum):
    """Segment groups that users can be attached to.

    Members are strings as well, so they compare equal to their raw values.
    """

    # Group used for ZenML server deployments.
    ZENML_SERVER_GROUP = "ZenML server group"

AnalyticsTrackedModelMixin (BaseModel) pydantic-model

Mixin for models that are tracked through analytics events.

Classes that have information tracked in analytics events can inherit from this mixin and implement the abstract methods. The @track decorator will detect function arguments and return values that inherit from this class and will include the ANALYTICS_FIELDS attributes as tracking metadata.

Source code in zenml/utils/analytics_utils.py
class AnalyticsTrackedModelMixin(BaseModel):
    """Mixin for pydantic models whose attributes are reported in analytics.

    Subclasses list the attribute names to report in ``ANALYTICS_FIELDS``.
    The ``@track`` decorator detects function arguments and return values
    that inherit from this mixin and adds those attributes to the event's
    tracking metadata.
    """

    # Names of the attributes to include in the analytics metadata.
    ANALYTICS_FIELDS: ClassVar[List[str]] = []

    def get_analytics_metadata(self) -> Dict[str, Any]:
        """Collect the values of the model's analytics fields.

        Returns:
            Mapping from each name in ``ANALYTICS_FIELDS`` to the matching
            attribute value, or None when the attribute does not exist.
        """
        return {
            field_name: getattr(self, field_name, None)
            for field_name in self.ANALYTICS_FIELDS
        }
get_analytics_metadata(self)

Get the analytics metadata for the model.

Returns:

Type Description
Dict[str, Any]

Dict of analytics metadata.

Source code in zenml/utils/analytics_utils.py
def get_analytics_metadata(self) -> Dict[str, Any]:
    """Collect the values of the model's analytics fields.

    Returns:
        Mapping from each name in ``ANALYTICS_FIELDS`` to the matching
        attribute value, or None when the attribute does not exist.
    """
    return {name: getattr(self, name, None) for name in self.ANALYTICS_FIELDS}

AnalyticsTrackerMixin (ABC)

Abstract base class for analytics trackers.

Use this as a mixin for classes that have methods decorated with @track to add global control over how analytics are tracked. The decorator will detect that the class has this mixin and will call the class's track_event method.

Source code in zenml/utils/analytics_utils.py
class AnalyticsTrackerMixin(ABC):
    """Abstract base for objects that route their own analytics events.

    Classes whose methods are decorated with ``@track`` can inherit from this
    mixin to control, in one place, how their analytics are tracked: the
    decorator detects the mixin and calls the instance's ``track_event``
    method.
    """

    @abstractmethod
    def track_event(
        self,
        event: AnalyticsEvent,
        metadata: Optional[Dict[str, Any]],
    ) -> None:
        """Track an event on behalf of this object.

        Args:
            event: Event to track.
            metadata: Metadata to track.
        """
track_event(self, event, metadata)

Track an event.

Parameters:

Name Type Description Default
event AnalyticsEvent

Event to track.

required
metadata Optional[Dict[str, Any]]

Metadata to track.

required
Source code in zenml/utils/analytics_utils.py
@abstractmethod
def track_event(
    self,
    event: AnalyticsEvent,
    metadata: Optional[Dict[str, Any]],
) -> None:
    """Track an event on behalf of this object.

    Args:
        event: Event to track.
        metadata: Metadata to track.
    """

event_handler

Context manager that tracks an event together with its success status.

Source code in zenml/utils/analytics_utils.py
class event_handler(object):
    """Context manager that tracks an event together with its success status.

    On exit it records whether the ``with`` body raised and sends the event
    to the v1 and/or v2 analytics backends selected at construction time.
    """

    def __init__(
        self,
        event: AnalyticsEvent,
        metadata: Optional[Dict[str, Any]] = None,
        v1: Optional[bool] = True,
        v2: Optional[bool] = False,
    ):
        """Initialization of the context manager.

        Args:
            event: The type of the analytics event
            metadata: The metadata of the event.
            v1: Flag to determine whether analytics v1 is included.
            v2: Flag to determine whether analytics v2 is included.
        """
        self.event: AnalyticsEvent = event
        self.metadata: Dict[str, Any] = metadata or {}
        # Optional custom tracker; when set, it receives the v1 event instead
        # of the module-level ``track_event``.
        self.tracker: Optional[AnalyticsTrackerMixin] = None
        self.v1: Optional[bool] = v1
        self.v2: Optional[bool] = v2

    def __enter__(self) -> "event_handler":
        """Enter function of the event handler.

        Returns:
            the handler instance.
        """
        return self

    def __exit__(
        self,
        type_: Optional[Any],
        value: Optional[Any],
        traceback: Optional[Any],
    ) -> Any:
        """Exit function of the event handler.

        Records in the metadata whether the ``with`` body succeeded (and, if
        an exception type was given, its class name), then dispatches the
        event to the selected analytics backends.

        Args:
            type_: The class of the exception
            value: The instance of the exception
            traceback: The traceback of the exception

        Returns:
            None, so any exception raised inside the ``with`` body
            propagates.
        """
        self.metadata.update({"event_success": traceback is None})

        if type_ is not None:
            self.metadata.update({"event_error_type": type_.__name__})

        # Each dispatch gets its own copy. Tracking enriches the dict it
        # receives (system info, ids, ...), and those additions must not leak
        # into the v2 event or back into ``self.metadata``.
        if self.v1:
            if self.tracker:
                self.tracker.track_event(self.event, dict(self.metadata))
            else:
                track_event(
                    self.event, dict(self.metadata), v1=True, v2=False
                )

        if self.v2:
            track_event(self.event, dict(self.metadata), v1=False, v2=True)
__enter__(self) special

Enter function of the event handler.

Returns:

Type Description
event_handler

the handler instance.

Source code in zenml/utils/analytics_utils.py
def __enter__(self) -> "event_handler":
    """Begin the tracked block.

    Returns:
        This handler, so callers can adjust ``metadata`` inside the block.
    """
    handler = self
    return handler
__exit__(self, type_, value, traceback) special

Exit function of the event handler.

Checks whether there was a traceback and updates the metadata accordingly. Following the check, it calls the function to track the event.

Parameters:

Name Type Description Default
type_ Optional[Any]

The class of the exception

required
value Optional[Any]

The instance of the exception

required
traceback Optional[Any]

The traceback of the exception

required
Source code in zenml/utils/analytics_utils.py
def __exit__(
    self,
    type_: Optional[Any],
    value: Optional[Any],
    traceback: Optional[Any],
) -> Any:
    """Exit function of the event handler.

    Records in the metadata whether the ``with`` body succeeded (and, if an
    exception type was given, its class name), then dispatches the event to
    the selected analytics backends.

    Args:
        type_: The class of the exception
        value: The instance of the exception
        traceback: The traceback of the exception

    Returns:
        None, so any exception raised inside the ``with`` body propagates.
    """
    self.metadata.update({"event_success": traceback is None})

    if type_ is not None:
        self.metadata.update({"event_error_type": type_.__name__})

    # Each dispatch gets its own copy. Tracking enriches the dict it receives
    # (system info, ids, ...), and those additions must not leak into the v2
    # event or back into ``self.metadata``.
    if self.v1:
        if self.tracker:
            self.tracker.track_event(self.event, dict(self.metadata))
        else:
            track_event(self.event, dict(self.metadata), v1=True, v2=False)

    if self.v2:
        track_event(self.event, dict(self.metadata), v1=False, v2=True)
__init__(self, event, metadata=None, v1=True, v2=False) special

Initialization of the context manager.

Parameters:

Name Type Description Default
event AnalyticsEvent

The type of the analytics event

required
metadata Optional[Dict[str, Any]]

The metadata of the event.

None
v1 Optional[bool]

Flag to determine whether analytics v1 is included.

True
v2 Optional[bool]

Flag to determine whether analytics v2 is included.

False
Source code in zenml/utils/analytics_utils.py
def __init__(
    self,
    event: AnalyticsEvent,
    metadata: Optional[Dict[str, Any]] = None,
    v1: Optional[bool] = True,
    v2: Optional[bool] = False,
):
    """Set up the handler for a single analytics event.

    Args:
        event: The type of the analytics event
        metadata: The metadata of the event; an empty dict is used when it
            is missing or empty.
        v1: Flag to determine whether analytics v1 is included.
        v2: Flag to determine whether analytics v2 is included.
    """
    # Which analytics backends receive the event on exit.
    self.v1: Optional[bool] = v1
    self.v2: Optional[bool] = v2
    # Optional custom tracker; when set, it handles the v1 event.
    self.tracker: Optional[AnalyticsTrackerMixin] = None
    self.metadata: Dict[str, Any] = {} if not metadata else metadata
    self.event: AnalyticsEvent = event

get_segment_key()

Get key for authorizing to Segment backend.

Returns:

Type Description
str

Segment key as a string.

Source code in zenml/utils/analytics_utils.py
def get_segment_key() -> str:
    """Return the Segment write key for the current environment.

    Returns:
        ``SEGMENT_KEY_DEV`` when ``IS_DEBUG_ENV`` is truthy, otherwise
        ``SEGMENT_KEY_PROD``.
    """
    return SEGMENT_KEY_DEV if IS_DEBUG_ENV else SEGMENT_KEY_PROD

identify_group(group, group_id, group_metadata=None, v1=True, v2=False)

Attach metadata to a segment group.

Parameters:

Name Type Description Default
group Union[str, zenml.utils.analytics_utils.AnalyticsGroup]

Group to track.

required
group_id UUID

ID of the group.

required
group_metadata Optional[Dict[str, Any]]

Metadata to attach to the group.

None
v1 Optional[bool]

Flag to determine whether analytics v1 is included.

True
v2 Optional[bool]

Flag to determine whether analytics v2 is included.

False

Returns:

Type Description
bool

True if the event is sent successfully, False if not.

Source code in zenml/utils/analytics_utils.py
def identify_group(
    group: Union[str, AnalyticsGroup],
    group_id: UUID,
    group_metadata: Optional[Dict[str, Any]] = None,
    v1: Optional[bool] = True,
    v2: Optional[bool] = False,
) -> bool:
    """Attach metadata to a segment group.

    Args:
        group: Group to track.
        group_id: ID of the group.
        group_metadata: Metadata to attach to the group.
        v1: Flag to determine whether analytics v1 is included.
        v2: Flag to determine whether analytics v2 is included.

    Returns:
        True if every selected backend reports success, False if not
        (including when the call raised inside the analytics context).
    """
    success = True

    if v1:
        # Pre-set to False: AnalyticsContext.__exit__ suppresses errors, so if
        # the call raises, the assignment below never runs.
        success_v1 = False
        with AnalyticsContext() as analytics:
            success_v1 = analytics.group(
                group=group, group_id=str(group_id), traits=group_metadata
            )
        success = success and success_v1

    if v2:
        # Same guard, assuming AnalyticsContextV2 also suppresses errors.
        success_v2 = False
        with AnalyticsContextV2() as analytics:
            success_v2 = analytics.group(
                group_id=group_id, traits=group_metadata
            )
        success = success and success_v2

    return success

identify_user(user_metadata=None, v1=True, v2=False)

Attach metadata to user directly.

Parameters:

Name Type Description Default
user_metadata Optional[Dict[str, Any]]

Dict of metadata to attach to the user.

None
v1 Optional[bool]

Flag to determine whether analytics v1 is included.

True
v2 Optional[bool]

Flag to determine whether analytics v2 is included.

False

Returns:

Type Description
bool

True if the event is sent successfully, False if not.

Source code in zenml/utils/analytics_utils.py
def identify_user(
    user_metadata: Optional[Dict[str, Any]] = None,
    v1: Optional[bool] = True,
    v2: Optional[bool] = False,
) -> bool:
    """Attach metadata to user directly.

    Args:
        user_metadata: Dict of metadata to attach to the user.
        v1: Flag to determine whether analytics v1 is included.
        v2: Flag to determine whether analytics v2 is included.

    Returns:
        True if every selected backend reports success, False if not
        (including when no metadata is given or the call raised inside the
        analytics context).
    """
    if user_metadata is None:
        return False

    success = True

    if v1:
        # Pre-set to False: AnalyticsContext.__exit__ suppresses errors, so if
        # the call raises, the assignment below never runs.
        success_v1 = False
        with AnalyticsContext() as analytics:
            success_v1 = analytics.identify(traits=user_metadata)
        success = success and success_v1

    if v2:
        # Same guard, assuming AnalyticsContextV2 also suppresses errors.
        success_v2 = False
        with AnalyticsContextV2() as analytics:
            success_v2 = analytics.identify(traits=user_metadata)
        success = success and success_v2

    return success

parametrized(dec)

A meta-decorator, that is, a decorator for decorators.

As a decorator is a function, it actually works as a regular decorator with arguments.

Parameters:

Name Type Description Default
dec Callable[..., Callable[..., Any]]

Decorator to be applied to the function.

required

Returns:

Type Description
Callable[..., Callable[[Callable[..., Any]], Callable[..., Any]]]

Decorator that applies the given decorator to the function.

Source code in zenml/utils/analytics_utils.py
def parametrized(
    dec: Callable[..., Callable[..., Any]]
) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., Any]]]:
    """Let a decorator accept its own arguments (a decorator for decorators).

    ``dec`` receives the decorated function first, followed by any extra
    arguments, i.e. ``dec(f, *args, **kwargs)``. Wrapping it with
    ``parametrized`` lets it be used as ``@dec(*args, **kwargs)``.

    Args:
        dec: Decorator to be applied to the function.

    Returns:
        Decorator that applies the given decorator to the function.
    """

    def layer(
        *args: Any, **kwargs: Any
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        """Capture the decorator arguments.

        Args:
            *args: Arguments to be passed to the decorator.
            **kwargs: Keyword arguments to be passed to the decorator.

        Returns:
            Decorator that applies the given decorator to the function.
        """

        def repl(f: Callable[..., Any]) -> Callable[..., Any]:
            """Apply ``dec`` to ``f`` with the captured arguments.

            Args:
                f: Function to be decorated.

            Returns:
                Decorated function.
            """
            decorated = dec(f, *args, **kwargs)
            return decorated

        return repl

    return layer

track(*args, **kwargs)

Internal layer produced by `parametrized`: it takes the decorator's arguments and returns the actual decorator.

Parameters:

Name Type Description Default
*args Any

Arguments to be passed to the decorator.

()
**kwargs Any

Keyword arguments to be passed to the decorator.

{}

Returns:

Type Description
Callable[[Callable[..., Any]], Callable[..., Any]]

Decorator that applies the given decorator to the function.

Source code in zenml/utils/analytics_utils.py
def layer(
    *args: Any, **kwargs: Any
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Bind the decorator arguments and return the actual decorator.

    The returned callable applies the enclosing ``dec`` decorator to the
    function it receives, forwarding the arguments bound here.

    Args:
        *args: Arguments to be passed to the decorator.
        **kwargs: Keyword arguments to be passed to the decorator.

    Returns:
        Decorator that applies the given decorator to the function.
    """
    bound_args, bound_kwargs = args, kwargs

    def repl(f: Callable[..., Any]) -> Callable[..., Any]:
        """Apply ``dec`` to ``f`` together with the bound arguments.

        Args:
            f: Function to be decorated.

        Returns:
            Decorated function.
        """
        decorated = dec(f, *bound_args, **bound_kwargs)
        return decorated

    return repl

track_event(event, metadata=None, v1=True, v2=False)

Track segment event if user opted-in.

Parameters:

Name Type Description Default
event AnalyticsEvent

Name of event to track in segment.

required
metadata Optional[Dict[str, Any]]

Dict of metadata to track.

None
v1 Optional[bool]

Flag to determine whether analytics v1 is included.

True
v2 Optional[bool]

Flag to determine whether analytics v2 is included.

False

Returns:

Type Description
bool

True if the event was sent successfully, False if not.

Source code in zenml/utils/analytics_utils.py
def track_event(
    event: AnalyticsEvent,
    metadata: Optional[Dict[str, Any]] = None,
    v1: Optional[bool] = True,
    v2: Optional[bool] = False,
) -> bool:
    """Track segment event if user opted-in.

    The given metadata is copied before an ``event_success`` default is added,
    so the caller's dictionary is never modified.

    Args:
        event: Name of event to track in segment.
        metadata: Dict of metadata to track.
        v1: Flag to determine whether analytics v1 is included.
        v2: Flag to determine whether analytics v2 is included.

    Returns:
        True if the event was sent successfully, False if not.
    """
    success = True

    # Work on a copy: `setdefault` below must not leak into the caller's dict.
    properties: Dict[str, Any] = dict(metadata) if metadata is not None else {}
    properties.setdefault("event_success", True)

    if v1:
        with AnalyticsContext() as analytics:
            success_v1 = analytics.track(event=event, properties=properties)
            success = success and success_v1

    if v2:
        with AnalyticsContextV2() as analytics:
            success_v2 = analytics.track(event=event, properties=properties)
            success = success and success_v2

    return success

artifact_utils

Util functions for artifact handling.

load_artifact(artifact)

Load the given artifact into memory.

Parameters:

Name Type Description Default
artifact ArtifactResponseModel

The artifact to load.

required

Returns:

Type Description
Any

The artifact loaded into memory.

Source code in zenml/utils/artifact_utils.py
def load_artifact(artifact: "ArtifactResponseModel") -> Any:
    """Load the given artifact into memory.

    If the artifact records an artifact store ID, that stack component is
    fetched through the client and instantiated before loading. Presumably
    this makes the store's filesystem available to the loader; that effect
    is not visible here. A ``KeyError`` while doing so is tolerated and only
    leads to a warning.

    Args:
        artifact: The artifact to load.

    Returns:
        The artifact loaded into memory.
    """
    store_restored = False
    store_id = artifact.artifact_store_id
    if store_id:
        try:
            store_model = Client().get_stack_component(
                component_type=StackComponentType.ARTIFACT_STORE,
                name_id_or_prefix=store_id,
            )
            StackComponent.from_model(store_model)
        except KeyError:
            pass
        else:
            store_restored = True

    if not store_restored:
        logger.warning(
            "Unable to restore artifact store while trying to load artifact "
            "`%s`. If this artifact is stored in a remote artifact store, "
            "this might lead to issues when trying to load the artifact.",
            artifact.id,
        )

    return _load_artifact(
        materializer=artifact.materializer,
        data_type=artifact.data_type,
        uri=artifact.uri,
    )

load_artifact_visualization(artifact, index=0, zen_store=None, encode_image=False)

Load a visualization of the given artifact.

Parameters:

Name Type Description Default
artifact ArtifactResponseModel

The artifact to visualize.

required
index int

The index of the visualization to load.

0
zen_store Optional[BaseZenStore]

The ZenStore to use for finding the artifact store. If not provided, the ZenStore of the client will be used.

None
encode_image bool

Whether to base64 encode image visualizations.

False

Returns:

Type Description
LoadedVisualizationModel

The loaded visualization.

Exceptions:

Type Description
DoesNotExistException

If the artifact does not have the requested visualization or if the visualization was not found in the artifact store.

Source code in zenml/utils/artifact_utils.py
def load_artifact_visualization(
    artifact: "ArtifactResponseModel",
    index: int = 0,
    zen_store: Optional["BaseZenStore"] = None,
    encode_image: bool = False,
) -> LoadedVisualizationModel:
    """Load one visualization of an artifact from its artifact store.

    Image visualizations are read in binary mode, all others in text mode.

    Args:
        artifact: The artifact to visualize.
        index: The index of the visualization to load.
        zen_store: The ZenStore to use for finding the artifact store. If not
            provided, the ZenStore of the client will be used.
        encode_image: Whether to base64 encode image visualizations.

    Returns:
        The loaded visualization.

    Raises:
        DoesNotExistException: If the artifact does not have the requested
            visualization or if the visualization was not found in the artifact
            store.
    """
    available = artifact.visualizations
    if not available:
        raise DoesNotExistException(
            f"Artifact '{artifact.id}' has no visualizations."
        )

    count = len(available)
    if not 0 <= index < count:
        raise DoesNotExistException(
            f"Artifact '{artifact.id}' only has {count} "
            f"visualizations, but index {index} was requested."
        )

    visualization = available[index]
    is_image = visualization.type == VisualizationType.IMAGE

    store = _load_artifact_store_of_artifact(
        artifact=artifact, zen_store=zen_store
    )
    content = _load_file_from_artifact_store(
        uri=visualization.uri,
        artifact_store=store,
        mode="rb" if is_image else "r",
    )

    if is_image and encode_image:
        content = base64.b64encode(bytes(content))

    return LoadedVisualizationModel(type=visualization.type, value=content)

load_model_from_metadata(model_uri)

Load a ZenML model artifact using its YAML metadata file.

This function is used to load information from a YAML file in the format created by the save_model_metadata function. The information in the YAML file is used to load the model into memory in the inference environment.

Parameters:

Name Type Description Default
model_uri str

URI of the model artifact directory that contains the metadata YAML file.

required

Returns:

Type Description
Any

The ML model object loaded into memory.

Source code in zenml/utils/artifact_utils.py
def load_model_from_metadata(model_uri: str) -> Any:
    """Load a ZenML model artifact using its YAML metadata file.

    This function reads the file named ``MODEL_METADATA_YAML_FILE_NAME``
    inside ``model_uri``. That file has the format written by the
    `save_model_metadata` function. The data type and materializer recorded
    in it are used to load the model into memory in the inference
    environment.

    Args:
        model_uri: URI of the model artifact directory that contains the
            metadata YAML file.

    Returns:
        The ML model object loaded into memory. PyTorch ``nn.Module`` models
        are switched to evaluation mode.
    """
    # `read_yaml` takes a path and reads the file itself, so there is no
    # need to open a separate handle just to obtain its `name`.
    metadata = read_yaml(os.path.join(model_uri, MODEL_METADATA_YAML_FILE_NAME))
    data_type = metadata[METADATA_DATATYPE]
    materializer = metadata[METADATA_MATERIALIZER]
    model = _load_artifact(
        materializer=materializer, data_type=data_type, uri=model_uri
    )

    # Switch to eval mode if the model is a torch model; torch is optional.
    try:
        import torch.nn as nn

        if isinstance(model, nn.Module):
            model.eval()
    except ImportError:
        pass

    return model

save_model_metadata(model_artifact)

Save a zenml model artifact metadata to a YAML file.

This function is used to extract and save information from a zenml model artifact such as the model type and materializer. The extracted information will be the key to loading the model into memory in the inference environment.

The following keys are stored:

- datatype: the model type. This is the path to the model class.
- materializer: the materializer class. This is the path to the materializer class.

Parameters:

Name Type Description Default
model_artifact ArtifactResponseModel

the artifact to extract the metadata from.

required

Returns:

Type Description
str

The path to the temporary file where the model metadata is saved

Source code in zenml/utils/artifact_utils.py
def save_model_metadata(model_artifact: "ArtifactResponseModel") -> str:
    """Write the metadata needed to reload a model artifact to a YAML file.

    This function is used to extract and save information from a zenml model
    artifact such as the model type and materializer. The extracted
    information will be the key to loading the model into memory in the
    inference environment.

    datatype: the model type. This is the path to the model class.
    materializer: the materializer class. This is the path to the materializer class.

    Args:
        model_artifact: the artifact to extract the metadata from.

    Returns:
        The path to the temporary file where the model metadata is saved.
        The file is not deleted automatically.
    """
    metadata = {
        METADATA_DATATYPE: model_artifact.data_type,
        METADATA_MATERIALIZER: model_artifact.materializer,
    }

    # Only reserve a persistent temporary path here; the YAML content is
    # written to that path by `write_yaml`.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".yaml", delete=False
    ) as tmp_file:
        metadata_path = tmp_file.name

    write_yaml(metadata_path, metadata)
    return metadata_path

upload_artifact(name, data, materializer, artifact_store_id, extract_metadata, include_visualizations)

Upload and publish an artifact.

Parameters:

Name Type Description Default
name str

The name of the artifact.

required
data Any

The artifact data.

required
materializer BaseMaterializer

The materializer to store the artifact.

required
artifact_store_id UUID

ID of the artifact store in which the artifact should be stored.

required
extract_metadata bool

If artifact metadata should be extracted and returned.

required
include_visualizations bool

If artifact visualizations should be generated.

required

Returns:

Type Description
UUID

The ID of the published artifact.

Source code in zenml/utils/artifact_utils.py
def upload_artifact(
    name: str,
    data: Any,
    materializer: "BaseMaterializer",
    artifact_store_id: "UUID",
    extract_metadata: bool,
    include_visualizations: bool,
) -> "UUID":
    """Store ``data`` with ``materializer`` and register it as an artifact.

    Failures while generating visualizations or extracting metadata are
    logged as warnings and do not abort the upload.

    Args:
        name: The name of the artifact.
        data: The artifact data.
        materializer: The materializer to store the artifact.
        artifact_store_id: ID of the artifact store in which the artifact should
            be stored.
        extract_metadata: If artifact metadata should be extracted and returned.
        include_visualizations: If artifact visualizations should be generated.

    Returns:
        The ID of the published artifact.
    """
    data_type = type(data)
    materializer.validate_type_compatibility(data_type)
    materializer.save(data)

    # Best effort: visualizations collected before a failure are kept.
    visualizations: List[VisualizationModel] = []
    if include_visualizations:
        try:
            saved = materializer.save_visualizations(data)
            for vis_uri, vis_type in saved.items():
                visualizations.append(
                    VisualizationModel(type=vis_type, uri=vis_uri)
                )
        except Exception as e:
            logger.warning(
                f"Failed to save visualization for output artifact '{name}': "
                f"{e}"
            )

    artifact_metadata = {}
    if extract_metadata:
        try:
            artifact_metadata = materializer.extract_full_metadata(data)
        except Exception as e:
            logger.warning(
                f"Failed to extract metadata for output artifact '{name}': {e}"
            )

    request = ArtifactRequestModel(
        name=name,
        type=materializer.ASSOCIATED_ARTIFACT_TYPE,
        uri=materializer.uri,
        materializer=source_utils.resolve(materializer.__class__),
        data_type=source_utils.resolve(data_type),
        user=Client().active_user.id,
        workspace=Client().active_workspace.id,
        artifact_store_id=artifact_store_id,
        visualizations=visualizations,
    )
    artifact_id = Client().zen_store.create_artifact(artifact=request).id

    if artifact_metadata:
        Client().create_run_metadata(
            metadata=artifact_metadata, artifact_id=artifact_id
        )

    return artifact_id

code_repository_utils

Utilities for code repositories.

find_active_code_repository(path=None)

Find the active code repository for a given path.

Parameters:

Name Type Description Default
path Optional[str]

Path at which to look for the code repository. If not given, the source root will be used.

None

Returns:

Type Description
Optional[LocalRepositoryContext]

The local repository context active at that path or None.

Source code in zenml/utils/code_repository_utils.py
def find_active_code_repository(
    path: Optional[str] = None,
) -> Optional["LocalRepositoryContext"]:
    """Find the active code repository for a given path.

    Successful lookups are memoized in the module-level
    ``_CODE_REPOSITORY_CACHE``, keyed by absolute path. Unsuccessful lookups
    are not cached, so they query the registered repositories again on the
    next call.

    Args:
        path: Path at which to look for the code repository. If not given, the
            source root will be used.

    Returns:
        The local repository context active at that path or None.
    """
    from zenml.client import Client
    from zenml.code_repositories import BaseCodeRepository

    lookup_path = os.path.abspath(path or source_utils.get_source_root())

    if lookup_path in _CODE_REPOSITORY_CACHE:
        return _CODE_REPOSITORY_CACHE[lookup_path]

    repository_models = depaginate(list_method=Client().list_code_repositories)
    for repository_model in repository_models:
        try:
            code_repository = BaseCodeRepository.from_model(repository_model)
        except Exception:
            # A repository whose implementation cannot be loaded is skipped.
            logger.debug(
                "Failed to instantiate code repository class.", exc_info=True
            )
            continue

        local_context = code_repository.get_local_context(lookup_path)
        if not local_context:
            continue

        _CODE_REPOSITORY_CACHE[lookup_path] = local_context
        return local_context

    return None

set_custom_local_repository(root, commit, repo)

Manually defines a local repository for a path.

To explain what this function does, we need to take a dive into source resolving and what happens inside the Docker image entrypoint:

* When trying to resolve an object to a source, we first determine whether the file is a user file or not.
* If the file is a user file, we check if that user file is inside a clean code repository using the code_repository_utils.find_active_code_repository(...) function. If that is the case, the object will be resolved to a CodeRepositorySource which includes additional information about the current commit and the ID of the code repository.
* The code_repository_utils.find_active_code_repository(...) function uses the code repository implementation classes to check whether the code repository "exists" at that local path. For git repositories, this check might look as follows: The code repository first checks if there is a git repository at that path or in any parent directory. If there is, the remote URLs of this git repository will be checked to see if one matches the URL defined for the code repository.
* When running a step inside a Docker image, ZenML potentially downloads files from a code repository. This usually does not download the entire repository (and in the case of git might not download a .git directory which defines a local git repository) but only specific files. If we now try to resolve any object while running in this container, it will not get resolved to a CodeRepositorySource as code_repository_utils.find_active_code_repository(...) won't find an active repository. As we downloaded these files, we however know that they belong to a certain code repository at a specific commit, and that's what we can define using this function.

Parameters:

Name Type Description Default
root str

The repository root.

required
commit str

The commit of the repository.

required
repo BaseCodeRepository

The code repository associated with the local repository.

required
Source code in zenml/utils/code_repository_utils.py
def set_custom_local_repository(
    root: str, commit: str, repo: "BaseCodeRepository"
) -> None:
    """Manually defines a local repository for a path.

    To explain what this function does we need to take a dive into source
    resolving and what happens inside the Docker image entrypoint:

    * When trying to resolve an object to a source, we first determine
      whether the file is a user file or not.
    * If the file is a user file, we check if that user file is inside a
      clean code repository using the
      `code_repository_utils.find_active_code_repository(...)` function. If
      that is the case, the object will be resolved to a
      `CodeRepositorySource` which includes additional information about the
      current commit and the ID of the code repository.
    * The `code_repository_utils.find_active_code_repository(...)` uses the
      code repository implementation classes to check whether the code
      repository "exists" at that local path. For git repositories, this
      check might look as follows: The code repository first checks if there
      is a git repository at that path or in any parent directory. If there
      is, the remote URLs of this git repository will be checked to see if
      one matches the URL defined for the code repository.
    * When running a step inside a Docker image, ZenML potentially downloads
      files from a code repository. This usually does not download the
      entire repository (and in the case of git might not download a .git
      directory which defines a local git repository) but only specific
      files. If we now try to resolve any object while running in this
      container, it will not get resolved to a `CodeRepositorySource` as
      `code_repository_utils.find_active_code_repository(...)` won't find an
      active repository. As we downloaded these files, we however know that
      they belong to a certain code repository at a specific commit, and
      that's what we can define using this function.

    Note that the cache entry is keyed by the absolute source root, not by
    ``root``.

    Args:
        root: The repository root.
        commit: The commit of the repository.
        repo: The code repository associated with the local repository.
    """
    from zenml.utils.downloaded_repository_context import (
        _DownloadedRepositoryContext,
    )

    source_root = os.path.abspath(source_utils.get_source_root())
    downloaded_context = _DownloadedRepositoryContext(
        code_repository_id=repo.id, root=root, commit=commit
    )
    _CODE_REPOSITORY_CACHE[source_root] = downloaded_context

daemon

Utility functions to start/stop daemon processes.

This is only implemented for UNIX systems and therefore doesn't work on Windows. Based on https://www.jejik.com/articles/2007/02/a_simple_unix_linux_daemon_in_python/

check_if_daemon_is_running(pid_file)

Checks whether a daemon process indicated by the PID file is running.

Parameters:

Name Type Description Default
pid_file str

Path to file containing the PID of the daemon process to check.

required

Returns:

Type Description
bool

True if the daemon process is running, otherwise False.

Source code in zenml/utils/daemon.py
def check_if_daemon_is_running(pid_file: str) -> bool:
    """Report whether the daemon referenced by a PID file is alive.

    Args:
        pid_file: Path to file containing the PID of the daemon
            process to check.

    Returns:
        True if the daemon process is running, otherwise False.
    """
    running_pid = get_daemon_pid_if_running(pid_file)
    return running_pid is not None

daemonize(pid_file, log_file=None, working_directory='/')

Decorator that executes the decorated function as a daemon process.

Use this decorator to easily transform any function into a daemon process.

For example,

import time
from zenml.utils.daemon import daemonize


# Calling the decorated function (on non-Windows platforms) forks a detached
# daemon process; the call itself returns in the calling (parent) process.
@daemonize(log_file='/tmp/daemon.log', pid_file='/tmp/daemon.pid')
def sleeping_daemon(period: int) -> None:
    print(f"I'm a daemon! I will sleep for {period} seconds.")
    time.sleep(period)
    print("Done sleeping, flying away.")

sleeping_daemon(period=30)

print("I'm the daemon's parent!.")
time.sleep(10) # just to prove that the daemon is running in parallel

Parameters:

Name Type Description Default
pid_file str

a file where the PID of the daemon process will be stored.

required
log_file Optional[str]

file where stdout and stderr are redirected for the daemon process. If not supplied, the daemon will be silenced (i.e. have its stdout/stderr redirected to /dev/null).

None
working_directory str

working directory for the daemon process, defaults to the root directory.

'/'

Returns:

Type Description
Callable[[~F], ~F]

Decorated function that, when called, will detach from the current process and continue executing in the background, as a daemon process.

Source code in zenml/utils/daemon.py
def daemonize(
    pid_file: str,
    log_file: Optional[str] = None,
    working_directory: str = "/",
) -> Callable[[F], F]:
    """Decorator that executes the decorated function as a daemon process.

    Use this decorator to easily transform any function into a daemon
    process. The wrapper keeps the decorated function's name and docstring.

    For example,

    ```python
    import time
    from zenml.utils.daemon import daemonize


    @daemonize(log_file='/tmp/daemon.log', pid_file='/tmp/daemon.pid')
    def sleeping_daemon(period: int) -> None:
        print(f"I'm a daemon! I will sleep for {period} seconds.")
        time.sleep(period)
        print("Done sleeping, flying away.")

    sleeping_daemon(period=30)

    print("I'm the daemon's parent!.")
    time.sleep(10) # just to prove that the daemon is running in parallel
    ```

    Args:
        pid_file: a file where the PID of the daemon process will
            be stored.
        log_file: file where stdout and stderr are redirected for the daemon
            process. If not supplied, the daemon will be silenced (i.e. have
            its stdout/stderr redirected to /dev/null).
        working_directory: working directory for the daemon process,
            defaults to the root directory.

    Returns:
        Decorated function that, when called, will detach from the current
        process and continue executing in the background, as a daemon
        process. On Windows, calling it only logs an error.
    """
    import functools

    def inner_decorator(_func: F) -> F:
        """Wrap ``_func`` so that calling it starts a daemon process.

        Args:
            _func: The function to run as a daemon.

        Returns:
            The wrapped function.
        """

        @functools.wraps(_func)
        def daemon(*args: Any, **kwargs: Any) -> None:
            """Standard daemonization of a process.

            Args:
                *args: Arguments to be passed to the decorated function.
                **kwargs: Keyword arguments to be passed to the decorated
                    function.
            """
            if sys.platform == "win32":
                logger.error(
                    "Daemon functionality is currently not supported on Windows."
                )
                return

            run_as_daemon(
                _func,
                *args,
                pid_file=pid_file,
                log_file=log_file,
                working_directory=working_directory,
                **kwargs,
            )

        return cast(F, daemon)

    return inner_decorator

get_daemon_pid_if_running(pid_file)

Read and return the PID value from a PID file.

It does this if the daemon process tracked by the PID file is running.

Parameters:

Name Type Description Default
pid_file str

Path to file containing the PID of the daemon process to check.

required

Returns:

Type Description
Optional[int]

The PID of the daemon process if it is running, otherwise None.

Source code in zenml/utils/daemon.py
def get_daemon_pid_if_running(pid_file: str) -> Optional[int]:
    """Read and return the PID value from a PID file.

    It does this if the daemon process tracked by the PID file is running.
    A missing, unreadable, empty or otherwise non-numeric PID file is treated
    as "not running".

    Args:
        pid_file: Path to file containing the PID of the daemon
            process to check.

    Returns:
        The PID of the daemon process if it is running, otherwise None.
    """
    try:
        with open(pid_file, "r") as f:
            pid = int(f.read().strip())
    except OSError:
        # IOError and FileNotFoundError are both OSError.
        logger.debug(
            f"Daemon PID file '{pid_file}' does not exist or cannot be read."
        )
        return None
    except ValueError:
        # e.g. an empty file left behind by a daemon that died mid-write
        logger.debug(
            f"Daemon PID file '{pid_file}' does not contain a valid PID."
        )
        return None

    if not pid or not psutil.pid_exists(pid):
        logger.debug(f"Daemon with PID '{pid}' is no longer running.")
        return None

    logger.debug(f"Daemon with PID '{pid}' is running.")
    return pid

run_as_daemon(daemon_function, *args, *, pid_file, log_file=None, working_directory='/', **kwargs)

Runs a function as a daemon process.

Parameters:

Name Type Description Default
daemon_function ~F

The function to run as a daemon.

required
pid_file str

Path to file in which to store the PID of the daemon process.

required
log_file Optional[str]

Optional file to which the daemon's stdout/stderr will be redirected.

None
working_directory str

Working directory for the daemon process, defaults to the root directory.

'/'
args Any

Positional arguments to pass to the daemon function.

()
kwargs Any

Keyword arguments to pass to the daemon function.

{}

Exceptions:

Type Description
FileExistsError

If the PID file already exists and the daemon process it refers to is still running.

Source code in zenml/utils/daemon.py
def run_as_daemon(
    daemon_function: F,
    *args: Any,
    pid_file: str,
    log_file: Optional[str] = None,
    working_directory: str = "/",
    **kwargs: Any,
) -> None:
    """Runs a function as a daemon process.

    The calling process forks once, waits for that first child to exit and
    then returns. The first child starts a new session and forks again, then
    exits. The second child becomes the daemon: it redirects its standard
    streams, writes its PID to ``pid_file``, registers cleanup handlers and
    runs ``daemon_function``.

    Args:
        daemon_function: The function to run as a daemon.
        pid_file: Path to file in which to store the PID of the daemon
            process.
        log_file: Optional file to which the daemon's stdout/stderr will be
            redirected.
        working_directory: Working directory for the daemon process,
            defaults to the root directory.
        args: Positional arguments to pass to the daemon function.
        kwargs: Keyword arguments to pass to the daemon function.

    Raises:
        FileExistsError: If the PID file already exists and the daemon
            process it refers to is still running.
    """
    # convert to absolute paths as we will change working directory later
    if pid_file:
        pid_file = os.path.abspath(pid_file)
        # create parent directory if necessary (exist_ok avoids a race with
        # concurrent creators)
        os.makedirs(os.path.dirname(pid_file), exist_ok=True)
    if log_file:
        log_file = os.path.abspath(log_file)

    # check if PID file exists
    if pid_file and os.path.exists(pid_file):
        pid = get_daemon_pid_if_running(pid_file)
        if pid:
            raise FileExistsError(
                f"The PID file '{pid_file}' already exists and a daemon "
                f"process with the same PID '{pid}' is already running. "
                f"Please remove the PID file or kill the daemon process "
                f"before starting a new daemon."
            )
        logger.warning(
            f"Removing left over PID file '{pid_file}' from a previous "
            f"daemon process that didn't shut down correctly."
        )
        os.remove(pid_file)

    # first fork
    try:
        pid = os.fork()
    except OSError as e:
        logger.error("Unable to fork (error code: %d)", e.errno)
        sys.exit(1)
    if pid > 0:
        # This is the process that called `run_as_daemon`. Reap exactly the
        # child we just forked (it exits right after the second fork) so it
        # does not linger as a zombie; `os.wait()` could reap an unrelated
        # child of the caller instead. Then return so the current process
        # can continue what it was doing.
        try:
            os.waitpid(pid, 0)
        except ChildProcessError:
            # already reaped elsewhere, e.g. when SIGCHLD is ignored
            pass
        return

    # decouple from parent environment
    os.chdir(working_directory)
    os.setsid()
    os.umask(0o22)

    # second fork
    try:
        pid = os.fork()
        if pid > 0:
            # this is the parent of the future daemon process, kill it
            # so the daemon gets adopted by the init process.
            # we use os._exit here to prevent the inherited code from
            # catching the SystemExit exception and doing something else.
            os._exit(0)
    except OSError as e:
        sys.stderr.write(f"Unable to fork (error code: {e.errno})")
        # we use os._exit here to prevent the inherited code from
        # catching the SystemExit exception and doing something else.
        os._exit(1)

    # redirect standard file descriptors to devnull (or the given logfile)
    devnull = "/dev/null"
    if hasattr(os, "devnull"):
        devnull = os.devnull

    devnull_fd = os.open(devnull, os.O_RDWR)
    log_fd = (
        os.open(log_file, os.O_CREAT | os.O_RDWR | os.O_APPEND)
        if log_file
        else None
    )
    # compare against None: a valid file descriptor can be 0
    out_fd = log_fd if log_fd is not None else devnull_fd

    try:
        os.dup2(devnull_fd, sys.stdin.fileno())
    except io.UnsupportedOperation:
        # stdin is not a file descriptor
        pass
    try:
        os.dup2(out_fd, sys.stdout.fileno())
    except io.UnsupportedOperation:
        # stdout is not a file descriptor
        pass
    try:
        os.dup2(out_fd, sys.stderr.fileno())
    except io.UnsupportedOperation:
        # stderr is not a file descriptor
        pass

    if pid_file:
        # write the PID file
        with open(pid_file, "w+") as f:
            f.write(f"{os.getpid()}\n")

    # register actions in case this process exits/gets killed
    def cleanup() -> None:
        """Daemon cleanup: terminate child processes and remove the PID file."""
        sys.stderr.write("Cleanup: terminating children processes...\n")
        terminate_children()
        if pid_file and os.path.exists(pid_file):
            sys.stderr.write(f"Cleanup: removing PID file {pid_file}...\n")
            os.remove(pid_file)
        sys.stderr.flush()

    def sighndl(signum: int, frame: Optional[types.FrameType]) -> None:
        """Daemon signal handler.

        Args:
            signum: Signal number.
            frame: Frame object.
        """
        sys.stderr.write(f"Handling signal {signum}...\n")
        cleanup()
        # NOTE(review): this handler returns without exiting, so unless
        # `daemon_function` installs its own handlers, SIGTERM/SIGINT only
        # run cleanup and the daemon keeps running. Confirm this is intended.

    signal.signal(signal.SIGTERM, sighndl)
    signal.signal(signal.SIGINT, sighndl)
    atexit.register(cleanup)

    # finally run the actual daemon code
    daemon_function(*args, **kwargs)
    sys.exit(0)

stop_daemon(pid_file)

Stops a daemon process.

Parameters:

Name Type Description Default
pid_file str

Path to file containing the PID of the daemon process to kill.

required
Source code in zenml/utils/daemon.py
def stop_daemon(pid_file: str) -> None:
    """Stops a daemon process.

    Reads the PID from ``pid_file`` and asks that process to terminate.
    Problems with the PID file or an already-exited process are logged as
    warnings instead of raised.

    Args:
        pid_file: Path to file containing the PID of the daemon process to
            kill.
    """
    try:
        with open(pid_file, "r") as f:
            pid = int(f.read().strip())
    except OSError:
        # IOError and FileNotFoundError are both OSError.
        logger.warning("Daemon PID file '%s' does not exist.", pid_file)
        return
    except ValueError:
        # e.g. an empty or truncated PID file
        logger.warning(
            "Daemon PID file '%s' does not contain a valid PID.", pid_file
        )
        return

    if not psutil.pid_exists(pid):
        logger.warning("PID from '%s' does not exist.", pid_file)
        return

    try:
        psutil.Process(pid).terminate()
    except psutil.NoSuchProcess:
        # The process exited between the existence check and the signal.
        logger.warning("PID from '%s' does not exist.", pid_file)

terminate_children()

Terminate all processes that are children of the currently running process.

Source code in zenml/utils/daemon.py
def terminate_children() -> None:
    """Terminate all processes that are children of the currently running process.

    Only direct children are targeted (non-recursive). Each child is first
    asked to terminate. Children still alive after
    `CHILD_PROCESS_WAIT_TIMEOUT` are killed and waited on once more. Children
    that exit on their own while this runs are silently skipped, because
    this function is also used from signal/exit cleanup handlers where an
    exception would abort the cleanup.
    """
    try:
        parent = psutil.Process(os.getpid())
    except psutil.Error:
        # could not find parent process id
        return
    children = parent.children(recursive=False)

    for child in children:
        sys.stderr.write(
            f"Terminating child process with PID {child.pid}...\n"
        )
        try:
            child.terminate()
        except psutil.NoSuchProcess:
            # The child already exited after `children()` was listed.
            pass
    _, alive = psutil.wait_procs(
        children, timeout=CHILD_PROCESS_WAIT_TIMEOUT
    )
    for child in alive:
        sys.stderr.write(f"Killing child process with PID {child.pid}...\n")
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass
    # Only the processes that survived the terminate request need another
    # wait; the others are already gone.
    psutil.wait_procs(alive, timeout=CHILD_PROCESS_WAIT_TIMEOUT)

dashboard_utils

Utility functions to help with interacting with the dashboard.

get_run_url(run_name, pipeline_id=None)

Computes a dashboard url to directly view the run.

Parameters:

Name Type Description Default
run_name str

Name of the pipeline run.

required
pipeline_id Optional[uuid.UUID]

Optional pipeline_id, to be sent when available.

None

Returns:

Type Description
Optional[str]

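A direct URL to the pipeline run details page. If no run with the given name exists, a URL to the corresponding runs overview page is returned instead. If the client is not connected to a ZenML server, an empty string is returned.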

Source code in zenml/utils/dashboard_utils.py
def get_run_url(
    run_name: str, pipeline_id: Optional[UUID] = None
) -> Optional[str]:
    """Computes a dashboard url to directly view the run.

    Args:
        run_name: Name of the pipeline run.
        pipeline_id: Optional pipeline_id, to be sent when available.

    Returns:
        A direct url link to the pipeline run details page. If no run with
        the given name exists, the URL points to the runs overview instead:
        the pipeline's runs page if `pipeline_id` is given, otherwise the
        list of all pipeline runs. If the client is not connected to a ZenML
        server (REST store), an empty string is returned.
    """
    client = Client()

    # Dashboard URLs only exist when connected to a ZenML server.
    if client.zen_store.type != StoreType.REST:
        return ""

    url = client.zen_store.url
    runs = depaginate(partial(client.list_runs, name=run_name))

    if pipeline_id:
        url += (
            f"/workspaces/{client.active_workspace.name}"
            f"/pipelines/{pipeline_id}/runs"
        )
    elif runs:
        url += "/runs"
    else:
        url += "/pipelines/all-runs"

    if runs:
        # Link to the DAG view of the first matching run.
        url += f"/{runs[0].id}/dag"

    return url

print_run_url(run_name, pipeline_id=None)

Logs a dashboard url to directly view the run.

Parameters:

Name Type Description Default
run_name str

Name of the pipeline run.

required
pipeline_id Optional[uuid.UUID]

Optional pipeline_id, to be sent when available.

None
Source code in zenml/utils/dashboard_utils.py
def print_run_url(run_name: str, pipeline_id: Optional[UUID] = None) -> None:
    """Logs a link to the run in the ZenML dashboard.

    With a ZenML server (REST store) connection, the run's dashboard URL is
    logged if one could be computed. With a local SQL store, a hint on how
    to start the dashboard is logged instead. Nothing is logged for other
    store types.

    Args:
        run_name: Name of the pipeline run.
        pipeline_id: Optional pipeline_id, to be sent when available.
    """
    store_type = Client().zen_store.type

    if store_type == StoreType.REST:
        dashboard_url = get_run_url(run_name, pipeline_id)
        if not dashboard_url:
            return
        logger.info(f"Dashboard URL: {dashboard_url}")
        return

    if store_type != StoreType.SQL:
        return

    # Local SQL store: the dashboard has to be started separately.
    logger.info(
        "Pipeline visualization can be seen in the ZenML Dashboard. "
        "Run `zenml up` to see your pipeline!"
    )

show_dashboard(url)

Show the ZenML dashboard at the given URL.

In native environments, the dashboard is opened in the default browser. In notebook environments, the dashboard is embedded in an iframe.

Parameters:

Name Type Description Default
url str

URL of the ZenML dashboard.

required
Source code in zenml/utils/dashboard_utils.py
def show_dashboard(url: str) -> None:
    """Show the ZenML dashboard at the given URL.

    In notebook environments (notebook, Colab), the dashboard is embedded in
    an iframe. In native and WSL environments, it is opened in the default
    browser unless auto-opening is disabled via the `AUTO_OPEN_DASHBOARD`
    env variable. In any other environment, the URL is only logged.
    Failures while opening the browser are logged, not raised.

    Args:
        url: URL of the ZenML dashboard.
    """
    environment = get_environment()
    if environment in (EnvironmentType.NOTEBOOK, EnvironmentType.COLAB):
        # `IPython.display` is the public home of both helpers; importing
        # `display` from `IPython.core.display` is deprecated.
        from IPython.display import IFrame, display

        display(IFrame(src=url, width="100%", height=720))

    elif environment in (EnvironmentType.NATIVE, EnvironmentType.WSL):
        if handle_bool_env_var(ENV_AUTO_OPEN_DASHBOARD, default=True):
            try:
                import webbrowser

                if environment == EnvironmentType.WSL:
                    # Use `wslview` to open the URL in the Windows host browser.
                    webbrowser.get("wslview %s").open(url)
                else:
                    webbrowser.open(url)
                logger.info(
                    "Automatically opening the dashboard in your "
                    "browser. To disable this, set the env variable "
                    "AUTO_OPEN_DASHBOARD=false."
                )
            except Exception as e:
                # Best effort: failing to open a browser must not fail the caller.
                logger.error(
                    "Failed to open the dashboard in a browser: %s", e
                )
        else:
            logger.info(
                "To open the dashboard in a browser automatically, "
                "set the env variable AUTO_OPEN_DASHBOARD=true."
            )

    else:
        logger.info(f"The ZenML dashboard is available at {url}.")

deprecation_utils

Deprecation utilities.

deprecate_pydantic_attributes(*attributes)

Utility function for deprecating and migrating pydantic attributes.

Usage: To use this, you can specify it on any pydantic BaseModel subclass like this (all the deprecated attributes need to be non-required):

from pydantic import BaseModel
from typing import Optional

class MyModel(BaseModel):
    deprecated: Optional[int] = None

    old_name: Optional[str] = None
    new_name: str

    _deprecation_validator = deprecate_pydantic_attributes(
        "deprecated", ("old_name", "new_name")
    )

Parameters:

Name Type Description Default
*attributes Union[str, Tuple[str, str]]

List of attributes to deprecate. This is either the name of the attribute to deprecate, or a tuple containing the name of the deprecated attribute and its replacement.

()

Returns:

Type Description
AnyClassMethod

Pydantic validator class method to be used on BaseModel subclasses to deprecate or migrate attributes.

Source code in zenml/utils/deprecation_utils.py
def deprecate_pydantic_attributes(
    *attributes: Union[str, Tuple[str, str]]
) -> "AnyClassMethod":
    """Utility function for deprecating and migrating pydantic attributes.

    **Usage**:
    To use this, you can specify it on any pydantic BaseModel subclass like
    this (all the deprecated attributes need to be non-required):

    ```python
    from pydantic import BaseModel
    from typing import Optional

    class MyModel(BaseModel):
        deprecated: Optional[int] = None

        old_name: Optional[str] = None
        new_name: str

        _deprecation_validator = deprecate_pydantic_attributes(
            "deprecated", ("old_name", "new_name")
        )
    ```

    Whenever a deprecated attribute receives a non-`None` value, a
    `DeprecationWarning` is emitted. The message is additionally logged
    only once per attribute, tracked on the class. For `(old, new)` pairs,
    the value of `old` is moved to `new` if `new` was not given. If both are
    given with different values, validation fails.

    Args:
        *attributes: List of attributes to deprecate. This is either the name
            of the attribute to deprecate, or a tuple containing the name of
            the deprecated attribute and its replacement.

    Returns:
        Pydantic validator class method to be used on BaseModel subclasses
        to deprecate or migrate attributes.
    """

    # `pre=True`: the root validator receives the raw input values before
    # field validation, so deprecated values can still be migrated.
    @root_validator(pre=True, allow_reuse=True)
    def _deprecation_validator(
        cls: Type[BaseModel], values: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Pydantic validator function for deprecating pydantic attributes.

        Args:
            cls: The class on which the attributes are defined.
            values: All values passed at model initialization.

        Raises:
            AssertionError: If either the deprecated or replacement attribute
                don't exist.
            TypeError: If the deprecated attribute is a required attribute.
            ValueError: If the deprecated attribute and replacement attribute
                contain different values.

        Returns:
            Input values with potentially migrated values.
        """
        # Names of attributes whose deprecation message was already logged.
        # The set is stored on the class (see `setattr` below), so it
        # persists across model instantiations.
        previous_deprecation_warnings: Set[str] = getattr(
            cls, PREVIOUS_DEPRECATION_WARNINGS_ATTRIBUTE, set()
        )

        def _warn(message: str, attribute: str) -> None:
            """Logs and raises a warning for a deprecated attribute.

            The message is logged only the first time for each attribute,
            but `warnings.warn` is called every time.

            Args:
                message: The warning message.
                attribute: The name of the attribute.
            """
            if attribute not in previous_deprecation_warnings:
                logger.warning(message)
                previous_deprecation_warnings.add(attribute)

            warnings.warn(
                message,
                DeprecationWarning,
            )

        for attribute in attributes:
            if isinstance(attribute, str):
                deprecated_attribute = attribute
                replacement_attribute = None
            else:
                deprecated_attribute, replacement_attribute = attribute

                # NOTE(review): `assert` checks are skipped under `python -O`.
                assert (
                    replacement_attribute in cls.__fields__
                ), f"Unable to find attribute {replacement_attribute}."

            assert (
                deprecated_attribute in cls.__fields__
            ), f"Unable to find attribute {deprecated_attribute}."

            if cls.__fields__[deprecated_attribute].required:
                raise TypeError(
                    f"Unable to deprecate attribute '{deprecated_attribute}' "
                    f"of class {cls.__name__}. In order to deprecate an "
                    "attribute, it needs to be a non-required attribute. "
                    "To do so, mark the attribute with an `Optional[...] type "
                    "annotation."
                )

            # Deprecated attribute not provided (or explicitly `None`):
            # nothing to warn about or migrate.
            if values.get(deprecated_attribute, None) is None:
                continue

            if replacement_attribute is None:
                _warn(
                    message=f"The attribute `{deprecated_attribute}` of class "
                    f"`{cls.__name__}` will be deprecated soon.",
                    attribute=deprecated_attribute,
                )
                continue

            _warn(
                message=f"The attribute `{deprecated_attribute}` of class "
                f"`{cls.__name__}` will be deprecated soon. Use the "
                f"attribute `{replacement_attribute}` instead.",
                attribute=deprecated_attribute,
            )

            if values.get(replacement_attribute, None) is None:
                logger.debug(
                    "Migrating value of deprecated attribute %s to "
                    "replacement attribute %s.",
                    deprecated_attribute,
                    replacement_attribute,
                )
                # `pop` moves the value, so the deprecated key is removed
                # from the input values.
                values[replacement_attribute] = values.pop(
                    deprecated_attribute
                )
            elif values[deprecated_attribute] != values[replacement_attribute]:
                raise ValueError(
                    "Got different values for deprecated attribute "
                    f"{deprecated_attribute} and replacement "
                    f"attribute {replacement_attribute}."
                )
            else:
                # Both values are identical, no need to do anything
                pass

        # Persist the set of already-logged warnings on the class.
        setattr(
            cls,
            PREVIOUS_DEPRECATION_WARNINGS_ATTRIBUTE,
            previous_deprecation_warnings,
        )

        return values

    return _deprecation_validator

dict_utils

Util functions for dictionaries.

recursive_update(original, update)

Recursively updates a dictionary.

Parameters:

Name Type Description Default
original Dict[str, Any]

The dictionary to update.

required
update Dict[str, Any]

The dictionary containing the updated values.

required

Returns:

Type Description
Dict[str, Any]

The updated dictionary.

Source code in zenml/utils/dict_utils.py
def recursive_update(
    original: Dict[str, Any], update: Dict[str, Any]
) -> Dict[str, Any]:
    """Deep-merge `update` into `original`, modifying `original` in place.

    Nested dictionaries present on both sides are merged recursively. Any
    other value from `update` replaces the value in `original`.

    Args:
        original: The dictionary to update.
        update: The dictionary containing the updated values.

    Returns:
        The updated dictionary (the same object as `original`).
    """
    for key, new_value in update.items():
        if not isinstance(new_value, dict):
            original[key] = new_value
            continue

        # A missing or falsy existing value is treated like an empty dict.
        existing = original.get(key, None) or {}
        original[key] = (
            recursive_update(existing, new_value)
            if isinstance(existing, dict)
            else new_value
        )
    return original

remove_none_values(dict_, recursive=False)

Removes all key-value pairs with None value.

Parameters:

Name Type Description Default
dict_ Dict[str, Any]

The dict from which the key-value pairs should be removed.

required
recursive bool

If True, will recursively remove None values in all child dicts.

False

Returns:

Type Description
Dict[str, Any]

The updated dictionary.

Source code in zenml/utils/dict_utils.py
def remove_none_values(
    dict_: Dict[str, Any], recursive: bool = False
) -> Dict[str, Any]:
    """Return a copy of `dict_` without the entries whose value is `None`.

    Args:
        dict_: The dict from which the key-value pairs should be removed.
        recursive: If `True`, nested dict values are cleaned the same way.
            Other containers such as lists are left untouched.

    Returns:
        A new dictionary without `None` values.
    """
    cleaned: Dict[str, Any] = {}
    for key, value in dict_.items():
        if value is None:
            continue
        if recursive and isinstance(value, dict):
            value = remove_none_values(value, recursive=True)
        cleaned[key] = value
    return cleaned

docker_utils

Utility functions relating to Docker.

build_image(image_name, dockerfile, build_context_root=None, dockerignore=None, extra_files=(), **custom_build_options)

Builds a docker image.

Parameters:

Name Type Description Default
image_name str

The name to use for the built docker image.

required
dockerfile Union[str, List[str]]

Path to a dockerfile or a list of strings representing the Dockerfile lines/commands.

required
build_context_root Optional[str]

Optional path to a directory that will be sent to the Docker daemon as build context. If left empty, the Docker build context will be empty.

None
dockerignore Optional[str]

Optional path to a dockerignore file. If no value is given, the .dockerignore in the root of the build context will be used if it exists. Otherwise, all files inside build_context_root are included in the build context.

None
extra_files Sequence[Tuple[str, str]]

Additional files to include in the build context. The files should be passed as a tuple (filepath_inside_build_context, file_content) and will overwrite existing files in the build context if they share the same path.

()
**custom_build_options Any

Additional options that will be passed unmodified to the Docker build call when building the image. You can use this, for example, to specify build args or a target stage. See https://docker-py.readthedocs.io/en/stable/images.html#docker.models.images.ImageCollection.build for a full list of available options.

{}
Source code in zenml/utils/docker_utils.py
def build_image(
    image_name: str,
    dockerfile: Union[str, List[str]],
    build_context_root: Optional[str] = None,
    dockerignore: Optional[str] = None,
    extra_files: Sequence[Tuple[str, str]] = (),
    **custom_build_options: Any,
) -> None:
    """Build a Docker image and stream the build output to the logs.

    Args:
        image_name: Name (tag) given to the built image.
        dockerfile: Either a path to a Dockerfile, or the Dockerfile given
            as a list of lines/commands.
        build_context_root: Optional directory sent to the Docker daemon as
            build context. When omitted, the build context is empty.
        dockerignore: Optional path to a dockerignore file. If not given,
            the `.dockerignore` in the root of the build context is used if
            present; otherwise every file in `build_context_root` is
            included.
        extra_files: Additional `(filepath_inside_build_context,
            file_content)` pairs to add to the build context. They overwrite
            existing files at the same path.
        **custom_build_options: Extra options forwarded unchanged to the
            Docker build call, e.g. build args or a target stage. See
            https://docker-py.readthedocs.io/en/stable/images.html#docker.models.images.ImageCollection.build
            for all available options. They override the defaults below.
    """
    if not isinstance(dockerfile, str):
        dockerfile_contents = "\n".join(dockerfile)
    else:
        dockerfile_contents = io_utils.read_file_contents_as_string(dockerfile)
        logger.info("Using Dockerfile `%s`.", os.path.abspath(dockerfile))

    context_archive = _create_custom_build_context(
        dockerfile_contents=dockerfile_contents,
        build_context_root=build_context_root,
        dockerignore=dockerignore,
        extra_files=extra_files,
    )

    # Defaults: keep intermediate containers (better caching) and always
    # pull parent images. User-supplied options take precedence.
    build_options = {"rm": False, "pull": True}
    build_options.update(custom_build_options)

    logger.info("Building Docker image `%s`.", image_name)
    logger.debug("Docker build options: %s", build_options)

    logger.info("Building the image might take a while...")

    # The low-level API client is used so the build logs can be streamed.
    low_level_api = DockerClient.from_env().images.client.api
    build_log_stream = low_level_api.build(
        fileobj=context_archive,
        custom_context=True,
        tag=image_name,
        **build_options,
    )
    _process_stream(build_log_stream)

    logger.info("Finished building Docker image `%s`.", image_name)

check_docker()

Checks if Docker is installed and running.

Returns:

Type Description
bool

True if Docker is installed and running, False otherwise.

Source code in zenml/utils/docker_utils.py
def check_docker() -> bool:
    """Check whether a Docker daemon is reachable.

    Returns:
        `True` if Docker is installed and responds to a ping, `False`
        otherwise.
    """
    try:
        DockerClient.from_env().ping()
    except Exception:
        # Any failure (missing installation, daemon down, ...) means "no".
        logger.debug("Docker is not running.", exc_info=True)
        return False
    return True

get_image_digest(image_name)

Gets the digest of an image.

Parameters:

Name Type Description Default
image_name str

Name of the image to get the digest for.

required

Returns:

Type Description
Optional[str]

Returns the repo digest for the given image if there exists exactly one. If there are zero or multiple repo digests, returns None.

Source code in zenml/utils/docker_utils.py
def get_image_digest(image_name: str) -> Optional[str]:
    """Gets the digest of an image.

    Args:
        image_name: Name of the image to get the digest for.

    Returns:
        Returns the repo digest for the given image if there exists exactly one.
        If there are zero or multiple repo digests, returns `None`.
    """
    docker_client = DockerClient.from_env()
    image = docker_client.images.get(image_name)
    # Treat a missing/null `RepoDigests` entry like an empty list.
    repo_digests = image.attrs.get("RepoDigests") or []
    if len(repo_digests) == 1:
        return cast(str, repo_digests[0])

    logger.debug(
        "Found zero or multiple repo digests for docker image '%s': %s",
        image_name,
        repo_digests,
    )
    return None

is_local_image(image_name)

Returns whether an image exists locally and was built locally (i.e. not pulled from a registry).

Parameters:

Name Type Description Default
image_name str

Name of the image to check.

required

Returns:

Type Description
bool

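True if the image exists locally and was not pulled from a registry (it does not have exactly one repo digest), False if it was pulled from a registry or no image with this name exists locally.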

Source code in zenml/utils/docker_utils.py
def is_local_image(image_name: str) -> bool:
    """Returns whether an image exists locally and was built locally.

    An image counts as locally built when it is available in the local
    Docker daemon but does not have exactly one repo digest. Images pulled
    from a registry carry a repo digest.

    Args:
        image_name: Name of the image to check.

    Returns:
        `True` if the image exists locally and was not pulled from a
        registry, `False` if it was pulled from a registry or no image with
        this name exists locally.
    """
    docker_client = DockerClient.from_env()
    if not docker_client.images.list(name=image_name):
        # no image with this name found locally
        return False

    # An image with this name is available locally -> check whether it was
    # pulled from a repo or built locally (in which case the repo digest is
    # empty).
    return get_image_digest(image_name) is None

push_image(image_name, docker_client=None)

Pushes an image to a container registry.

Parameters:

Name Type Description Default
image_name str

The full name (including a tag) of the image to push.

required
docker_client Optional[docker.client.DockerClient]

Optional Docker client to use for pushing the image. If no client is given, a new client will be created using the default Docker environment.

None

Returns:

Type Description
str

The Docker repository digest of the pushed image.

Exceptions:

Type Description
RuntimeError

If fetching the repository digest of the image failed.

Source code in zenml/utils/docker_utils.py
def push_image(
    image_name: str, docker_client: Optional[DockerClient] = None
) -> str:
    """Pushes an image to a container registry.

    Args:
        image_name: The full name (including a tag) of the image to push.
        docker_client: Optional Docker client to use for pushing the image. If
            no client is given, a new client will be created using the default
            Docker environment.

    Returns:
        The Docker repository digest of the pushed image, in the form
        `<repository>@<digest>`.

    Raises:
        RuntimeError: If fetching the repository digest of the image failed.
    """
    logger.info("Pushing Docker image `%s`.", image_name)
    docker_client = docker_client or DockerClient.from_env()
    output_stream = docker_client.images.push(image_name, stream=True)
    aux_info = _process_stream(output_stream)
    logger.info("Finished pushing Docker image.")

    # Strip the tag (if any) to get the repository name. Only a ':' after
    # the last '/' separates a tag, so registry ports such as
    # `localhost:5000/image` are kept intact.
    repository = image_name.split("@", 1)[0]
    tag_separator = repository.rfind(":")
    if tag_separator > repository.rfind("/"):
        repository = repository[:tag_separator]

    # Use the most recent aux entry that reports a digest.
    for info in reversed(aux_info):
        if "Digest" in info:
            return f"{repository}@{info['Digest']}"

    raise RuntimeError(
        f"Unable to find repo digest after pushing image {image_name}."
    )

tag_image(image_name, target)

Tags an image.

Parameters:

Name Type Description Default
image_name str

The name of the image to tag.

required
target str

The full target name including a tag.

required
Source code in zenml/utils/docker_utils.py
def tag_image(image_name: str, target: str) -> None:
    """Apply the name/tag `target` to the local image `image_name`.

    Args:
        image_name: The name of the image to tag.
        target: The full target name including a tag.
    """
    local_image = DockerClient.from_env().images.get(image_name)
    local_image.tag(target)

downloaded_repository_context

Downloaded code repository.

enum_utils

Util functions for enums.

StrEnum (str, Enum)

Base enum type for string enum values.

Source code in zenml/utils/enum_utils.py
class StrEnum(str, Enum):
    """Enum base class whose members are strings and print as their value."""

    def __str__(self) -> str:
        """Render the member as its raw value instead of `Class.MEMBER`.

        Returns:
            The value of this enum member.
        """
        member_value = self.value
        return member_value  # type: ignore

    @classmethod
    def names(cls) -> List[str]:
        """Collect the names of all members, in definition order.

        Returns:
            A list of all enum names.
        """
        return list(map(lambda member: member.name, cls))

    @classmethod
    def values(cls) -> List[str]:
        """Collect the values of all members, in definition order.

        Returns:
            A list of all enum values.
        """
        collected: List[str] = []
        for member in cls:
            collected.append(member.value)
        return collected

filesync_model

Filesync utils for ZenML.

FileSyncModel (BaseModel) pydantic-model

Pydantic model synchronized with a configuration file.

Use this class as a base Pydantic model that is automatically synchronized with a configuration file on disk.

This class overrides the `__setattr__` and `__getattribute__` magic methods to ensure that the FileSyncModel instance acts as an in-memory cache of the information stored in the associated configuration file.

Source code in zenml/utils/filesync_model.py
class FileSyncModel(BaseModel):
    """Pydantic model synchronized with a configuration file.

    Use this class as a base Pydantic model that is automatically synchronized
    with a configuration file on disk.

    This class overrides the __setattr__ and __getattribute__ magic methods to
    ensure that the FileSyncModel instance acts as an in-memory cache of the
    information stored in the associated configuration file:

    * assigning a public attribute writes the whole model back to the file;
    * reading a public field first reloads the file if its modification
      time changed since the last load/write.

    NOTE(review): every public field access checks the file's mtime (and may
    re-read it), so attribute access is comparatively expensive.
    """

    # Path of the backing YAML configuration file.
    _config_file: str
    # Modification time of the configuration file as of the last load or
    # write; `None` before the first load/write.
    _config_file_timestamp: Optional[float] = None

    def __init__(self, config_file: str, **kwargs: Any) -> None:
        """Create a FileSyncModel instance synchronized with a configuration file on disk.

        Args:
            config_file: configuration file path. If the file exists, the model
                will be initialized with the values from the file.
            **kwargs: additional keyword arguments to pass to the Pydantic model
                constructor. If supplied, these values will override those
                loaded from the configuration file.
        """
        config_dict = {}
        if fileio.exists(config_file):
            config_dict = yaml_utils.read_yaml(config_file)

        # `_`-prefixed attributes do not trigger a file write in
        # `__setattr__`, so these assignments are memory-only.
        self._config_file = config_file
        self._config_file_timestamp = None

        # Explicit keyword arguments take precedence over file values.
        config_dict.update(kwargs)
        super(FileSyncModel, self).__init__(**config_dict)

        # write the configuration file to disk, to reflect new attributes
        # and schema changes
        self.write_config()

    def __setattr__(self, key: str, value: Any) -> None:
        """Sets an attribute on the model and persists it in the configuration file.

        Private (`_`-prefixed) attributes are only set in memory.

        Args:
            key: attribute name.
            value: attribute value.
        """
        super(FileSyncModel, self).__setattr__(key, value)
        if key.startswith("_"):
            return
        self.write_config()

    def __getattribute__(self, key: str) -> Any:
        """Gets an attribute value for a specific key.

        Public attributes stored on the instance (the model's field values)
        are refreshed from the configuration file before being returned.

        Args:
            key: attribute name.

        Returns:
            attribute value.
        """
        # Methods live on the class, not in the instance `__dict__`, so they
        # don't trigger a reload. The `self.__dict__` lookup itself starts
        # with "_" and therefore does not recurse into a reload.
        if not key.startswith("_") and key in self.__dict__:
            self.load_config()
        return super(FileSyncModel, self).__getattribute__(key)

    def write_config(self) -> None:
        """Writes the model to the configuration file."""
        # Round-trip through pydantic's JSON serialization so that only
        # JSON-compatible values end up in the YAML file.
        config_dict = json.loads(self.json())
        yaml_utils.write_yaml(self._config_file, config_dict)
        # Remember the mtime of our own write so `load_config` does not
        # reload the file we just wrote.
        self._config_file_timestamp = os.path.getmtime(self._config_file)

    def load_config(self) -> None:
        """Loads the model from the configuration file on disk."""
        if not fileio.exists(self._config_file):
            return

        # don't reload the configuration if the file hasn't
        # been updated since the last load
        # NOTE(review): changes made within the filesystem's mtime
        # resolution of the last load/write are not detected.
        file_timestamp = os.path.getmtime(self._config_file)
        if file_timestamp == self._config_file_timestamp:
            return

        if self._config_file_timestamp is not None:
            logger.info(f"Reloading configuration file {self._config_file}")

        # refresh the model from the configuration file values
        # `super().__setattr__` bypasses `FileSyncModel.__setattr__`, so the
        # reload does not write the file back.
        config_dict = yaml_utils.read_yaml(self._config_file)
        for key, value in config_dict.items():
            super(FileSyncModel, self).__setattr__(key, value)

        self._config_file_timestamp = file_timestamp

    class Config:
        """Pydantic configuration class."""

        # all attributes with leading underscore are private and therefore
        # are mutable and not included in serialization
        underscore_attrs_are_private = True
Config

Pydantic configuration class.

Source code in zenml/utils/filesync_model.py
class Config:
    """Pydantic settings for `FileSyncModel`."""

    # Every `_`-prefixed attribute is private: it can be mutated freely and
    # is left out of serialization.
    underscore_attrs_are_private = True
__getattribute__(self, key) special

Gets an attribute value for a specific key.

Parameters:

Name Type Description Default
key str

attribute name.

required

Returns:

Type Description
Any

attribute value.

Source code in zenml/utils/filesync_model.py
def __getattribute__(self, key: str) -> Any:
    """Gets an attribute value for a specific key.

    Public attributes stored on the instance (the model's field values) are
    refreshed from the configuration file before being returned.

    Args:
        key: attribute name.

    Returns:
        attribute value.
    """
    # Methods live on the class, not in the instance `__dict__`, so they don't
    # trigger a reload. The `self.__dict__` lookup itself starts with "_" and
    # therefore does not recurse into a reload.
    if not key.startswith("_") and key in self.__dict__:
        self.load_config()
    return super(FileSyncModel, self).__getattribute__(key)
__init__(self, config_file, **kwargs) special

Create a FileSyncModel instance synchronized with a configuration file on disk.

Parameters:

Name Type Description Default
config_file str

configuration file path. If the file exists, the model will be initialized with the values from the file.

required
**kwargs Any

additional keyword arguments to pass to the Pydantic model constructor. If supplied, these values will override those loaded from the configuration file.

{}
Source code in zenml/utils/filesync_model.py
def __init__(self, config_file: str, **kwargs: Any) -> None:
    """Create a FileSyncModel instance synchronized with a configuration file on disk.

    Args:
        config_file: configuration file path. If the file exists, the model
            will be initialized with the values from the file.
        **kwargs: additional keyword arguments to pass to the Pydantic model
            constructor. If supplied, these values will override those
            loaded from the configuration file.
    """
    config_dict = {}
    if fileio.exists(config_file):
        config_dict = yaml_utils.read_yaml(config_file)

    # `_`-prefixed attributes do not trigger a file write in `__setattr__`,
    # so these assignments are memory-only.
    self._config_file = config_file
    self._config_file_timestamp = None

    # Explicit keyword arguments take precedence over file values.
    config_dict.update(kwargs)
    super(FileSyncModel, self).__init__(**config_dict)

    # write the configuration file to disk, to reflect new attributes
    # and schema changes
    self.write_config()
__setattr__(self, key, value) special

Sets an attribute on the model and persists it in the configuration file.

Parameters:

Name Type Description Default
key str

attribute name.

required
value Any

attribute value.

required
Source code in zenml/utils/filesync_model.py
def __setattr__(self, key: str, value: Any) -> None:
    """Assign an attribute and, for public attributes, persist the model.

    Private (`_`-prefixed) attributes are only set in memory; any other
    assignment is immediately written back to the configuration file.

    Args:
        key: attribute name.
        value: attribute value.
    """
    super(FileSyncModel, self).__setattr__(key, value)
    is_private = key.startswith("_")
    if not is_private:
        self.write_config()
load_config(self)

Loads the model from the configuration file on disk.

Source code in zenml/utils/filesync_model.py
def load_config(self) -> None:
    """Loads the model from the configuration file on disk.

    Does nothing if the file does not exist or its modification time is
    unchanged since the last load/write.
    """
    if not fileio.exists(self._config_file):
        return

    # don't reload the configuration if the file hasn't
    # been updated since the last load
    # NOTE(review): changes made within the filesystem's mtime resolution of
    # the last load/write are not detected.
    file_timestamp = os.path.getmtime(self._config_file)
    if file_timestamp == self._config_file_timestamp:
        return

    if self._config_file_timestamp is not None:
        logger.info(f"Reloading configuration file {self._config_file}")

    # refresh the model from the configuration file values
    # `super().__setattr__` bypasses `FileSyncModel.__setattr__`, so the
    # reload does not write the file back.
    config_dict = yaml_utils.read_yaml(self._config_file)
    for key, value in config_dict.items():
        super(FileSyncModel, self).__setattr__(key, value)

    self._config_file_timestamp = file_timestamp
write_config(self)

Writes the model to the configuration file.

Source code in zenml/utils/filesync_model.py
def write_config(self) -> None:
    """Serialize the model to the configuration file and record its mtime."""
    # Round-trip through JSON so that values are reduced to plain types
    # before being dumped as YAML.
    serialized = self.json()
    yaml_utils.write_yaml(self._config_file, json.loads(serialized))
    # Remember the mtime so load_config can skip re-reading our own write.
    self._config_file_timestamp = os.path.getmtime(self._config_file)

git_utils

Utility function to clone a Git repository.

clone_git_repository(url, to_path, branch=None, commit=None)

Clone a Git repository.

Parameters:

Name Type Description Default
url str

URL of the repository to clone.

required
to_path str

Path to clone the repository to.

required
branch Optional[str]

Branch to clone. Defaults to "main".

None
commit Optional[str]

Commit to checkout. If specified, the branch argument is ignored.

None

Returns:

Type Description
Repo

The cloned repository.

Exceptions:

Type Description
RuntimeError

If the repository could not be cloned.

Source code in zenml/utils/git_utils.py
def clone_git_repository(
    url: str,
    to_path: str,
    branch: Optional[str] = None,
    commit: Optional[str] = None,
) -> Repo:
    """Clone a Git repository.

    Args:
        url: URL of the repository to clone.
        to_path: Path to clone the repository to.
        branch: Branch to clone. Defaults to "main".
        commit: Commit to checkout. If specified, the branch argument is
            ignored.

    Returns:
        The cloned repository.

    Raises:
        RuntimeError: If the repository could not be cloned.
    """
    # Ensure the directory that will *contain* the clone exists. `abspath`
    # strips trailing separators and guarantees a non-empty parent, so
    # `makedirs` never receives "" (which would raise FileNotFoundError).
    parent_dir = os.path.dirname(os.path.abspath(to_path))
    os.makedirs(parent_dir, exist_ok=True)
    try:
        if commit:
            # Clone without a checkout, then check out the exact commit.
            repo = Repo.clone_from(
                url=url,
                to_path=to_path,
                no_checkout=True,
            )
            repo.git.checkout(commit)
        else:
            repo = Repo.clone_from(
                url=url,
                to_path=to_path,
                branch=branch or "main",
            )
    except GitCommandError as e:
        # The URL is deliberately left out of this message because it may
        # embed credentials; the Git error is chained for debugging.
        raise RuntimeError(
            f"Failed to clone Git repository into '{to_path}'."
        ) from e
    return repo

io_utils

Various utility functions for the io module.

copy_dir(source_dir, destination_dir, overwrite=False)

Copies dir from source to destination.

Parameters:

Name Type Description Default
source_dir str

Path to copy from.

required
destination_dir str

Path to copy to.

required
overwrite bool

If False, an error is raised instead of overwriting existing files.

False
Source code in zenml/utils/io_utils.py
def copy_dir(
    source_dir: str, destination_dir: str, overwrite: bool = False
) -> None:
    """Recursively copy the contents of one directory into another.

    Args:
        source_dir: Path to copy from.
        destination_dir: Path to copy to.
        overwrite: Forwarded to the per-file copy; if False, copying onto an
            existing file raises an error instead of overwriting it.
    """
    for entry in listdir(source_dir):
        entry_name = convert_to_str(entry)
        src = os.path.join(source_dir, entry_name)
        dst = os.path.join(destination_dir, entry_name)

        if not isdir(src):
            # Plain file: make sure its target folder exists, then copy.
            create_dir_recursive_if_not_exists(os.path.dirname(dst))
            copy(str(src), str(dst), overwrite)
        elif src != destination_dir:
            copy_dir(src, dst, overwrite)
        # Otherwise the destination is a direct child of the source; it is
        # skipped so the copy does not recurse into its own output.

create_dir_if_not_exists(dir_path)

Creates directory if it does not exist.

Parameters:

Name Type Description Default
dir_path str

Local path in filesystem.

required
Source code in zenml/utils/io_utils.py
def create_dir_if_not_exists(dir_path: str) -> None:
    """Create a single directory unless it already exists.

    Args:
        dir_path: Local path in filesystem.
    """
    if isdir(dir_path):
        return
    mkdir(dir_path)

create_dir_recursive_if_not_exists(dir_path)

Creates directory recursively if it does not exist.

Parameters:

Name Type Description Default
dir_path str

Local path in filesystem.

required
Source code in zenml/utils/io_utils.py
def create_dir_recursive_if_not_exists(dir_path: str) -> None:
    """Create a directory and any missing parents unless it already exists.

    Args:
        dir_path: Local path in filesystem.
    """
    if isdir(dir_path):
        return
    makedirs(dir_path)

create_file_if_not_exists(file_path, file_contents='{}')

Creates file if it does not exist.

Parameters:

Name Type Description Default
file_path str

Local path in filesystem.

required
file_contents str

Contents of file.

'{}'
Source code in zenml/utils/io_utils.py
def create_file_if_not_exists(
    file_path: str, file_contents: str = "{}"
) -> None:
    """Create a file with the given contents unless it already exists.

    Missing parent directories are created first.

    Args:
        file_path: Local path in filesystem.
        file_contents: Contents of file.
    """
    target = Path(file_path)
    if exists(file_path):
        return
    create_dir_recursive_if_not_exists(str(target.parent))
    with open(str(target), "w") as handle:
        handle.write(file_contents)

find_files(dir_path, pattern)

Find files in a directory that match pattern.

Parameters:

Name Type Description Default
dir_path PathType

The path to directory.

required
pattern str

pattern like *.png.

required

Yields:

Type Description
Iterable[str]

All matching filenames in the directory.

Source code in zenml/utils/io_utils.py
def find_files(dir_path: "PathType", pattern: str) -> Iterable[str]:
    """Walk a directory tree and yield files whose names match a pattern.

    Args:
        dir_path: The path to directory.
        pattern: A shell-style glob such as ``*.png`` matched against the
            file's base name.

    Yields:
        All matching filenames in the directory.
    """
    for root, _, files in walk(dir_path):
        for basename in files:
            name = convert_to_str(basename)
            if not fnmatch.fnmatch(name, pattern):
                continue
            yield os.path.join(convert_to_str(root), name)

get_global_config_directory()

Gets the global config directory for ZenML.

Returns:

Type Description
str

The global config directory for ZenML.

Source code in zenml/utils/io_utils.py
def get_global_config_directory() -> str:
    """Return the directory where ZenML keeps its global configuration.

    A non-empty ``ENV_ZENML_CONFIG_PATH`` environment variable takes
    precedence (resolved to an absolute path); otherwise the platform's
    application directory for ``APP_NAME`` is used.

    Returns:
        The global config directory for ZenML.
    """
    override = os.getenv(ENV_ZENML_CONFIG_PATH)
    if not override:
        return click.get_app_dir(APP_NAME)
    return str(Path(override).resolve())

get_grandparent(dir_path)

Get grandparent of dir.

Parameters:

Name Type Description Default
dir_path str

The path to directory.

required

Returns:

Type Description
str

The stem of the input path's grandparent directory.

Exceptions:

Type Description
ValueError

If dir_path does not exist.

Source code in zenml/utils/io_utils.py
def get_grandparent(dir_path: str) -> str:
    """Return the stem of the directory two levels above ``dir_path``.

    Args:
        dir_path: The path to directory.

    Returns:
        The stem (name without suffix) of the path's grandparent directory.

    Raises:
        ValueError: If dir_path does not exist.
    """
    if os.path.exists(dir_path):
        grandparent = Path(dir_path).parent.parent
        return grandparent.stem
    raise ValueError(f"Path '{dir_path}' does not exist.")

get_parent(dir_path)

Get parent of dir.

Parameters:

Name Type Description Default
dir_path str

The path to directory.

required

Returns:

Type Description
str

Parent (stem) of the dir as a string.

Exceptions:

Type Description
ValueError

If dir_path does not exist.

Source code in zenml/utils/io_utils.py
def get_parent(dir_path: str) -> str:
    """Return the stem of the directory directly above ``dir_path``.

    Args:
        dir_path: The path to directory.

    Returns:
        Parent (stem) of the dir as a string.

    Raises:
        ValueError: If dir_path does not exist.
    """
    if os.path.exists(dir_path):
        parent = Path(dir_path).parent
        return parent.stem
    raise ValueError(f"Path '{dir_path}' does not exist.")

is_remote(path)

Returns True if the path refers to a remote filesystem, i.e. it starts with a known remote prefix.

Parameters:

Name Type Description Default
path str

Any path as a string.

required

Returns:

Type Description
bool

True if remote path, else False.

Source code in zenml/utils/io_utils.py
def is_remote(path: str) -> bool:
    """Tell whether a path points to a remote filesystem.

    The decision is purely lexical: the path is remote when it starts with
    one of the prefixes listed in ``REMOTE_FS_PREFIX``.

    Args:
        path: Any path as a string.

    Returns:
        True if remote path, else False.
    """
    return path.startswith(tuple(REMOTE_FS_PREFIX))

is_root(path)

Returns true if path has no parent in local filesystem.

Parameters:

Name Type Description Default
path str

Local path in filesystem.

required

Returns:

Type Description
bool

True if root, else False.

Source code in zenml/utils/io_utils.py
def is_root(path: str) -> bool:
    """Tell whether a local path is a filesystem root.

    A path is considered a root when taking its parent yields the same path.

    Args:
        path: Local path in filesystem.

    Returns:
        True if root, else False.
    """
    candidate = Path(path)
    return candidate == candidate.parent

move(source, destination, overwrite=False)

Moves dir or file from source to destination. Can be used to rename.

Parameters:

Name Type Description Default
source str

Local path to move from.

required
destination str

Local path to move to.

required
overwrite bool

If False, an error is raised instead of overwriting an existing destination.

False
Source code in zenml/utils/io_utils.py
def move(source: str, destination: str, overwrite: bool = False) -> None:
    """Move (or rename) a file or directory.

    Args:
        source: Local path to move from.
        destination: Local path to move to.
        overwrite: If False, an error is raised instead of replacing an
            existing destination.
    """
    # Moving is implemented as a rename on the underlying filesystem.
    rename(source, destination, overwrite)

read_file_contents_as_string(file_path)

Reads contents of file.

Parameters:

Name Type Description Default
file_path str

Path to file.

required

Returns:

Type Description
str

Contents of file.

Exceptions:

Type Description
FileNotFoundError

If file does not exist.

Source code in zenml/utils/io_utils.py
def read_file_contents_as_string(file_path: str) -> str:
    """Return the full contents of a file as a string.

    Args:
        file_path: Path to file.

    Returns:
        Contents of file.

    Raises:
        FileNotFoundError: If file does not exist.
    """
    if exists(file_path):
        with open(file_path) as handle:
            return handle.read()  # type: ignore[no-any-return]
    raise FileNotFoundError(f"{file_path} does not exist!")

resolve_relative_path(path)

Takes relative path and resolves it absolutely.

Parameters:

Name Type Description Default
path str

Local path in filesystem.

required

Returns:

Type Description
str

Resolved path.

Source code in zenml/utils/io_utils.py
def resolve_relative_path(path: str) -> str:
    """Turn a local relative path into an absolute one.

    Remote paths are returned untouched.

    Args:
        path: Local path in filesystem.

    Returns:
        Resolved path.
    """
    return path if is_remote(path) else str(Path(path).resolve())

write_file_contents_as_string(file_path, content)

Writes contents of file.

Parameters:

Name Type Description Default
file_path str

Path to file.

required
content str

Contents of file.

required

Exceptions:

Type Description
ValueError

If content is not of type str.

Source code in zenml/utils/io_utils.py
def write_file_contents_as_string(file_path: str, content: str) -> None:
    """Write a string to a file, replacing any existing contents.

    Args:
        file_path: Path to file.
        content: Contents of file.

    Raises:
        ValueError: If content is not of type str.
    """
    if isinstance(content, str):
        with open(file_path, "w") as handle:
            handle.write(content)
        return
    raise ValueError(f"Content must be of type str, got {type(content)}")

materializer_utils

Util functions for materializers.

select_materializer(data_type, materializer_classes)

Select a materializer for a given data type.

Parameters:

Name Type Description Default
data_type Type[Any]

The data type for which to select the materializer.

required
materializer_classes Sequence[Type[BaseMaterializer]]

Available materializer classes.

required

Exceptions:

Type Description
RuntimeError

If no materializer can handle the given data type.

Returns:

Type Description
Type[BaseMaterializer]

The first materializer that can handle the given data type.

Source code in zenml/utils/materializer_utils.py
def select_materializer(
    data_type: Type[Any],
    materializer_classes: Sequence[Type["BaseMaterializer"]],
) -> Type["BaseMaterializer"]:
    """Pick the materializer best suited to a data type.

    The data type's MRO is walked from most to least specific. The first
    materializer that lists one of those classes in ``ASSOCIATED_TYPES`` wins
    immediately. If none does, the first materializer whose
    ``can_handle_type`` accepted any class along the way is used.

    Args:
        data_type: The data type for which to select the materializer.
        materializer_classes: Available materializer classes.

    Raises:
        RuntimeError: If no materializer can handle the given data type.

    Returns:
        The first materializer that can handle the given data type.
    """
    fallback: Optional[Type["BaseMaterializer"]] = None

    for candidate_type in data_type.__mro__:
        for materializer_class in materializer_classes:
            if candidate_type in materializer_class.ASSOCIATED_TYPES:
                return materializer_class
            # Remember only the first materializer that can handle the type.
            if not fallback and materializer_class.can_handle_type(
                candidate_type
            ):
                fallback = materializer_class

    if not fallback:
        raise RuntimeError(f"No materializer found for type {data_type}.")
    return fallback

networking_utils

Utility functions for networking.

find_available_port()

Finds a local random unoccupied TCP port.

Returns:

Type Description
int

A random unoccupied TCP port.

Source code in zenml/utils/networking_utils.py
def find_available_port() -> int:
    """Ask the OS for a currently unused local TCP port.

    Returns:
        A random unoccupied TCP port.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        # Binding to port 0 lets the OS assign any free port on loopback.
        probe.bind(("127.0.0.1", 0))
        bound_address = probe.getsockname()

    return cast(int, bound_address[1])

get_or_create_ngrok_tunnel(ngrok_token, port)

Get or create an ngrok tunnel at the given port.

Parameters:

Name Type Description Default
ngrok_token str

The ngrok auth token.

required
port int

The port to tunnel.

required

Returns:

Type Description
str

The public URL of the ngrok tunnel.

Exceptions:

Type Description
ImportError

If the pyngrok package is not installed.

Source code in zenml/utils/networking_utils.py
def get_or_create_ngrok_tunnel(ngrok_token: str, port: int) -> str:
    """Get or create an ngrok tunnel at the given port.

    An already running HTTP(S) tunnel that forwards to ``port`` is reused.
    Otherwise the auth token is registered and a new tunnel is opened.

    Args:
        ngrok_token: The ngrok auth token.
        port: The port to tunnel.

    Returns:
        The public URL of the ngrok tunnel.

    Raises:
        ImportError: If the `pyngrok` package is not installed.
    """
    try:
        from pyngrok import ngrok as ngrok_client
    except ImportError as e:
        raise ImportError(
            "The `pyngrok` package is required to create ngrok tunnels. "
            "Please install it by running `pip install pyngrok`."
        ) from e

    # Check if ngrok is already tunneling the port. The ngrok agent reports
    # the protocol on the tunnel itself (`proto`) and the forwarding address
    # in its config as a string such as "http://localhost:8000". The port
    # therefore has to be parsed out of `addr`; comparing the raw value to
    # the int `port` never matches.
    for tunnel in ngrok_client.get_tunnels():
        config = tunnel.config if isinstance(tunnel.config, dict) else {}
        protocol = getattr(tunnel, "proto", None) or config.get("proto")
        if protocol not in ("http", "https"):
            continue
        addr = config.get("addr")
        if addr is None:
            continue
        # Works for "8000", "localhost:8000" and "http://localhost:8000".
        addr_port = str(addr).rstrip("/").rsplit(":", 1)[-1]
        if addr_port == str(port):
            return str(tunnel.public_url)

    # Create new tunnel
    ngrok_client.set_auth_token(ngrok_token)
    return str(ngrok_client.connect(port).public_url)

port_available(port, address='127.0.0.1')

Checks if a local port is available.

Parameters:

Name Type Description Default
port int

TCP port number

required
address str

IP address on the local machine

'127.0.0.1'

Returns:

Type Description
bool

True if the port is available, otherwise False

Source code in zenml/utils/networking_utils.py
def port_available(port: int, address: str = "127.0.0.1") -> bool:
    """Check whether a TCP port can be bound on a local address.

    Args:
        port: TCP port number
        address: IP address on the local machine

    Returns:
        True if the port is available, otherwise False
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # SO_REUSEPORT does not exist on every platform (e.g. Windows),
            # so it is only applied when the socket module exposes it.
            reuse_port_option = getattr(socket, "SO_REUSEPORT", None)
            if reuse_port_option is not None:
                sock.setsockopt(socket.SOL_SOCKET, reuse_port_option, 1)
            sock.bind((address, port))
    except socket.error as err:
        logger.debug("Port %d unavailable on %s: %s", port, address, err)
        return False
    else:
        return True

port_is_open(hostname, port)

Check if a TCP port is open on a remote host.

Parameters:

Name Type Description Default
hostname str

hostname of the remote machine

required
port int

TCP port number

required

Returns:

Type Description
bool

True if the port is open, False otherwise

Source code in zenml/utils/networking_utils.py
def port_is_open(hostname: str, port: int) -> bool:
    """Probe whether a TCP connection to ``hostname:port`` can be made.

    Args:
        hostname: hostname of the remote machine
        port: TCP port number

    Returns:
        True if the port is open, False otherwise
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
            # connect_ex returns an errno (0 on success) instead of raising.
            return conn.connect_ex((hostname, port)) == 0
    except socket.error as err:
        logger.debug(
            f"Error checking TCP port {port} on host {hostname}: {str(err)}"
        )
        return False

replace_internal_hostname_with_localhost(hostname)

Replaces an internal Docker or K3D hostname with localhost.

Localhost URLs that are directly accessible on the host machine are not accessible from within a Docker or K3D container running on that same machine, but there are special hostnames featured by both Docker (host.docker.internal) and K3D (host.k3d.internal) that can be used to access host services from within the containers.

Use this method to replace one of these special hostnames with localhost if used outside a container or in a container where special hostnames are not available.

Parameters:

Name Type Description Default
hostname str

The hostname to replace.

required

Returns:

Type Description
str

The original or replaced hostname.

Source code in zenml/utils/networking_utils.py
def replace_internal_hostname_with_localhost(hostname: str) -> str:
    """Replaces an internal Docker or K3D hostname with localhost.

    Localhost URLs that are directly accessible on the host machine are not
    accessible from within a Docker or K3D container running on that same
    machine, but there are special hostnames featured by both Docker
    (`host.docker.internal`) and K3D (`host.k3d.internal`) that can be used to
    access host services from within the containers.

    Any hostname that is not one of those special names is returned as is.
    Inside a container, the first special hostname that resolves is returned
    (which may be the other special name). In every remaining case the
    loopback address ``127.0.0.1`` is returned.

    Args:
        hostname: The hostname to replace.

    Returns:
        The original or replaced hostname.
    """
    special_hostnames = ("host.docker.internal", "host.k3d.internal")
    if hostname not in special_hostnames:
        return hostname

    if Environment.in_container():
        # Use whichever special hostname actually resolves in this container.
        for candidate in special_hostnames:
            try:
                socket.gethostbyname(candidate)
            except socket.gaierror:
                continue
            if candidate != hostname:
                logger.debug(
                    f"Replacing internal hostname {hostname} with "
                    f"{candidate}"
                )
            return candidate

    logger.debug(f"Replacing internal hostname {hostname} with localhost.")

    return "127.0.0.1"

replace_localhost_with_internal_hostname(url)

Replaces the localhost with an internal Docker or K3D hostname in a given URL.

Localhost URLs that are directly accessible on the host machine are not accessible from within a Docker or K3D container running on that same machine, but there are special hostnames featured by both Docker (host.docker.internal) and K3D (host.k3d.internal) that can be used to access host services from within the containers.

Use this method to attempt to replace localhost in a URL with one of these special hostnames, if they are available inside a container.

Parameters:

Name Type Description Default
url str

The URL to update.

required

Returns:

Type Description
str

The updated URL.

Source code in zenml/utils/networking_utils.py
def replace_localhost_with_internal_hostname(url: str) -> str:
    """Replaces the localhost with an internal Docker or K3D hostname in a given URL.

    Localhost URLs that are directly accessible on the host machine are not
    accessible from within a Docker or K3D container running on that same
    machine, but there are special hostnames featured by both Docker
    (`host.docker.internal`) and K3D (`host.k3d.internal`) that can be used to
    access host services from within the containers.

    Outside a container, or when the URL's host is neither ``localhost`` nor
    ``127.0.0.1``, or when no special hostname resolves, the URL is returned
    unchanged.

    Args:
        url: The URL to update.

    Returns:
        The updated URL.
    """
    if not Environment.in_container():
        return url

    parsed_url = urlparse(url)
    local_host = parsed_url.hostname
    if local_host in ("localhost", "127.0.0.1"):
        for candidate in ("host.docker.internal", "host.k3d.internal"):
            try:
                socket.gethostbyname(candidate)
            except socket.gaierror:
                continue
            rewritten = parsed_url._replace(
                netloc=parsed_url.netloc.replace(local_host, candidate)
            )
            logger.debug(
                f"Replacing localhost with {candidate} in URL: "
                f"{url}"
            )
            return rewritten.geturl()

    return url

scan_for_available_port(start=8000, stop=65535)

Scan the local network for an available port in the given range.

Parameters:

Name Type Description Default
start int

the beginning of the port range value to scan

8000
stop int

the (inclusive) end of the port range value to scan

65535

Returns:

Type Description
Optional[int]

The first available port in the given range, or None if no available port is found.

Source code in zenml/utils/networking_utils.py
def scan_for_available_port(
    start: int = SCAN_PORT_RANGE[0], stop: int = SCAN_PORT_RANGE[1]
) -> Optional[int]:
    """Return the lowest free local TCP port in ``[start, stop]``.

    Args:
        start: the beginning of the port range value to scan
        stop: the (inclusive) end of the port range value to scan

    Returns:
        The first available port in the given range, or None if no available
        port is found.
    """
    first_free = next(
        (
            candidate
            for candidate in range(start, stop + 1)
            if port_available(candidate)
        ),
        None,
    )
    if first_free is None:
        logger.debug(
            "No free TCP ports found in the range %d - %d",
            start,
            stop,
        )
    return first_free

pagination_utils

Pagination utilities.

Page[AnyResponseModel] (Page) pydantic-model

Config

Pydantic configuration class.

Source code in zenml/utils/pagination_utils.py
class Config:
    """Pydantic configuration class.

    Registers a custom JSON encoder for ``SecretStr`` fields on the model
    this configuration is attached to.
    """

    # This is needed to allow the REST API server to unpack SecretStr
    # values correctly before sending them to the client.
    # The encoder emits the plain secret via `get_secret_value()`; a value
    # that is falsy is encoded as None instead.
    json_encoders = {
        SecretStr: lambda v: v.get_secret_value() if v else None
    }
__json_encoder__(obj) special staticmethod

partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.

depaginate(list_method)

Depaginate the results from a client or store method that returns pages.

Parameters:

Name Type Description Default
list_method Callable[..., zenml.utils.pagination_utils.Page[AnyResponseModel]]

The list method to wrap around.

required

Returns:

Type Description
List[~AnyResponseModel]

A list of the corresponding Response Models.

Source code in zenml/utils/pagination_utils.py
def depaginate(
    list_method: Callable[..., Page[AnyResponseModel]],
) -> List[AnyResponseModel]:
    """Collect every item from a paginated list method into one list.

    The method is first called without arguments. It is then called again
    with ``page=<next index>`` until the last page has been fetched.

    Args:
        list_method: The list method to wrap around.

    Returns:
        A list of the corresponding Response Models.
    """
    current = list_method()
    collected: List[AnyResponseModel] = list(current.items)
    while current.index < current.total_pages:
        current = list_method(page=current.index + 1)
        collected.extend(current.items)

    return collected

pipeline_docker_image_builder

Implementation of Docker image builds to run ZenML pipelines.

PipelineDockerImageBuilder

Builds Docker images to run a ZenML pipeline.

Source code in zenml/utils/pipeline_docker_image_builder.py
class PipelineDockerImageBuilder:
    """Builds Docker images to run a ZenML pipeline."""

    def build_docker_image(
        self,
        docker_settings: "DockerSettings",
        tag: str,
        stack: "Stack",
        include_files: bool,
        download_files: bool,
        entrypoint: Optional[str] = None,
        extra_files: Optional[Dict[str, str]] = None,
        code_repository: Optional["BaseCodeRepository"] = None,
    ) -> Tuple[str, Optional[str], Optional[str]]:
        """Builds (and optionally pushes) a Docker image to run a pipeline.

        Use the image name returned by this method whenever you need to uniquely
        reference the pushed image in order to pull or run it.

        Up to two images are built:
        - If `docker_settings.dockerfile` is set, that Dockerfile is built
          first. It is used as the final image when nothing ZenML-specific
          needs to be added; otherwise it becomes the parent of a second
          build.
        - If requirements, apt packages, environment variables, files, an
          entrypoint or extra files need to be added, a Dockerfile is
          generated and built on top of the parent image.

        Args:
            docker_settings: The settings for the image build.
            tag: The tag to use for the image.
            stack: The stack on which the pipeline will be deployed.
            include_files: Whether to include the files of the source root in
                the build context.
            download_files: Whether files need to be downloaded when the image
                is run. If set, the generated Dockerfile flags this through an
                environment variable and the requirements of
                `code_repository` may be included.
            entrypoint: Entrypoint to use for the final image. If left empty,
                no entrypoint will be included in the image.
            extra_files: Extra files to add to the build context. Keys are the
                path inside the build context, values are either the file
                content or a file path.
            code_repository: The code repository from which files will be
                downloaded.

        Returns:
            A tuple (image_digest, dockerfile, requirements):
            - The Docker image repo digest or local name, depending on whether
            the image was pushed or is just stored locally.
            - Dockerfile will contain the contents of the Dockerfile used to
            build the image.
            - Requirements is a string with a single pip requirement per line.

        Raises:
            RuntimeError: If the stack does not contain an image builder.
            ValueError: If no Dockerfile and/or custom parent image is
                specified and the Docker configuration doesn't require an
                image build.
        """
        requirements: Optional[str] = None
        dockerfile: Optional[str] = None

        if docker_settings.skip_build:
            assert (
                docker_settings.parent_image
            )  # checked via validator already

            # Should we tag this here and push it to the container registry of
            # the stack to make sure it's always accessible when running the
            # pipeline?
            return docker_settings.parent_image, dockerfile, requirements

        image_builder = stack.image_builder
        if not image_builder:
            raise RuntimeError(
                "Unable to build Docker images without an image builder in the "
                f"stack `{stack.name}`."
            )

        container_registry = stack.container_registry

        build_context_class = image_builder.build_context_class
        target_image_name = self._get_target_image_name(
            docker_settings=docker_settings,
            tag=tag,
            container_registry=container_registry,
        )

        # A ZenML build (with a generated Dockerfile) is needed whenever
        # anything has to be installed, copied or configured on top of the
        # parent image.
        requires_zenml_build = any(
            [
                docker_settings.requirements,
                docker_settings.required_integrations,
                docker_settings.required_hub_plugins,
                docker_settings.replicate_local_python_environment,
                docker_settings.install_stack_requirements,
                docker_settings.apt_packages,
                docker_settings.environment,
                include_files,
                download_files,
                entrypoint,
                extra_files,
            ]
        )

        # Fall back to the default ZenML parent image if the pipeline
        # configuration doesn't specify a custom one
        parent_image = (
            docker_settings.parent_image or DEFAULT_DOCKER_PARENT_IMAGE
        )

        if docker_settings.dockerfile:
            if parent_image != DEFAULT_DOCKER_PARENT_IMAGE:
                logger.warning(
                    "You've specified both a Dockerfile and a custom parent "
                    "image, ignoring the parent image."
                )

            # Push the custom image if it is the final image, or if a remote
            # image builder has to pull it as the parent of the second build.
            push = (
                not image_builder.is_building_locally
                or not requires_zenml_build
            )

            if requires_zenml_build:
                # We will build an additional image on top of this one later
                # to include user files and/or install requirements. The image
                # we build now will be used as the parent for the next build.
                user_image_name = (
                    f"{docker_settings.target_repository}:"
                    f"{tag}-intermediate-build"
                )
                if push and container_registry:
                    user_image_name = (
                        f"{container_registry.config.uri}/{user_image_name}"
                    )

                parent_image = user_image_name
            else:
                # The image we'll build from the custom Dockerfile will be
                # used directly, so we tag it with the requested target name.
                user_image_name = target_image_name

            build_context = build_context_class(
                root=docker_settings.build_context_root
            )
            build_context.add_file(
                source=docker_settings.dockerfile, destination="Dockerfile"
            )
            logger.info("Building Docker image `%s`.", user_image_name)
            image_name_or_digest = image_builder.build(
                image_name=user_image_name,
                build_context=build_context,
                docker_build_options=docker_settings.build_options,
                container_registry=container_registry if push else None,
            )

        elif not requires_zenml_build:
            if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
                raise ValueError(
                    "Unable to run a ZenML pipeline with the given Docker "
                    "settings: No Dockerfile or custom parent image "
                    "specified and no files will be copied or requirements "
                    "installed."
                )
            else:
                # The parent image will be used directly to run the pipeline and
                # needs to be tagged/pushed
                docker_utils.tag_image(parent_image, target=target_image_name)
                if container_registry:
                    image_name_or_digest = container_registry.push_image(
                        target_image_name
                    )
                else:
                    image_name_or_digest = target_image_name

        if requires_zenml_build:
            logger.info("Building Docker image `%s`.", target_image_name)
            # Leave the build context empty if we don't want to include any files
            build_context_root = (
                source_utils.get_source_root() if include_files else None
            )
            build_context = build_context_class(
                root=build_context_root,
                dockerignore_file=docker_settings.dockerignore,
            )

            requirements_files = self.gather_requirements_files(
                docker_settings=docker_settings,
                stack=stack,
                # Only pass code repo to include its dependencies if we actually
                # need to download code
                code_repository=code_repository if download_files else None,
            )

            self._add_requirements_files(
                requirements_files=requirements_files,
                build_context=build_context,
            )
            requirements = (
                "\n".join(
                    file_content for _, file_content, _ in requirements_files
                )
                or None
            )

            # Copy the list so that adding the stack's apt packages below does
            # not mutate the user's `DockerSettings` in place (which would make
            # packages accumulate across repeated builds).
            apt_packages = list(docker_settings.apt_packages)
            if docker_settings.install_stack_requirements:
                apt_packages += stack.apt_packages

            if apt_packages:
                logger.info(
                    "Including apt packages: %s",
                    ", ".join(f"`{p}`" for p in apt_packages),
                )

            if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
                # The default parent image is static and doesn't require a pull
                # each time
                pull_parent_image = False
            elif docker_settings.dockerfile and not container_registry:
                # We built a custom parent image and there was no container
                # registry in the stack to push to, this is a local image
                pull_parent_image = False
            elif not image_builder.is_building_locally:
                # Remote image builders always need to pull the image
                pull_parent_image = True
            else:
                # If the image is local, we don't need to pull it. Otherwise
                # we play it safe and always pull in case the user pushed a new
                # image for the given name and tag
                pull_parent_image = not docker_utils.is_local_image(
                    parent_image
                )

            build_options = {"pull": pull_parent_image, "rm": False}

            dockerfile = self._generate_zenml_pipeline_dockerfile(
                parent_image=parent_image,
                docker_settings=docker_settings,
                download_files=download_files,
                requirements_files=requirements_files,
                apt_packages=apt_packages,
                entrypoint=entrypoint,
            )
            build_context.add_file(destination="Dockerfile", source=dockerfile)

            if extra_files:
                for destination, source in extra_files.items():
                    build_context.add_file(
                        destination=destination, source=source
                    )

            image_name_or_digest = image_builder.build(
                image_name=target_image_name,
                build_context=build_context,
                docker_build_options=build_options,
                container_registry=container_registry,
            )

        return image_name_or_digest, dockerfile, requirements

    @staticmethod
    def _get_target_image_name(
        docker_settings: "DockerSettings",
        tag: str,
        container_registry: Optional["BaseContainerRegistry"] = None,
    ) -> str:
        """Returns the target image name.

        The name has the form `<target_repository>:<tag>`. If a container
        registry is given, the registry URI is prepended as
        `<registry_uri>/<target_repository>:<tag>`.

        Args:
            docker_settings: The settings for the image build.
            tag: The tag to use for the image.
            container_registry: Optional container registry to which this
                image will be pushed.

        Returns:
            The docker image name.
        """
        target_image_name = f"{docker_settings.target_repository}:{tag}"
        if container_registry:
            target_image_name = (
                f"{container_registry.config.uri}/{target_image_name}"
            )

        return target_image_name

    @classmethod
    def _add_requirements_files(
        cls,
        requirements_files: List[Tuple[str, str, List[str]]],
        build_context: "BuildContext",
    ) -> None:
        """Adds requirements files to the build context.

        Each file is added under its filename with its content as the source;
        the pip options are not used here.

        Args:
            requirements_files: List of tuples
                (filename, file_content, pip_options).
            build_context: Build context to add the requirements files to.
        """
        for filename, file_content, _ in requirements_files:
            build_context.add_file(source=file_content, destination=filename)

    @staticmethod
    def gather_requirements_files(
        docker_settings: "DockerSettings",
        stack: "Stack",
        code_repository: Optional["BaseCodeRepository"] = None,
        log: bool = True,
    ) -> List[Tuple[str, str, List[str]]]:
        """Gathers and/or generates pip requirements files.

        This method is called in `PipelineDockerImageBuilder.build_docker_image`
        but it is also called by other parts of the codebase, e.g. the
        `AzureMLStepOperator`, which needs to upload the requirements files to
        AzureML where the step image is then built.

        Args:
            docker_settings: Docker settings that specify which requirements
                to install.
            stack: The stack on which the pipeline will run.
            code_repository: The code repository from which files will be
                downloaded. Its requirements are only included if
                `docker_settings.install_stack_requirements` is set.
            log: If True, will log the requirements.

        Raises:
            RuntimeError: If the command to export the local python packages
                failed.

        Returns:
            List of tuples (filename, file_content, pip_options) of all
            requirements files.
            The files will be in the following order:
            - Packages installed in the local Python environment
            - User-defined requirements
            - Requirements defined by user-defined and/or stack integrations
            - Requirements of ZenML Hub plugins (the plugin packages
              themselves, installed with `--no-deps`, followed by their
              PyPI dependencies)
        """
        requirements_files: List[Tuple[str, str, List[str]]] = []

        # Generate requirements file for the local environment if configured
        if docker_settings.replicate_local_python_environment:
            if isinstance(
                docker_settings.replicate_local_python_environment,
                PythonEnvironmentExportMethod,
            ):
                command = (
                    docker_settings.replicate_local_python_environment.command
                )
            else:
                command = " ".join(
                    docker_settings.replicate_local_python_environment
                )

            try:
                # The command is user-configured (local environment export),
                # hence it is run through the shell.
                local_requirements = subprocess.check_output(
                    command, shell=True
                ).decode()
            except subprocess.CalledProcessError as e:
                raise RuntimeError(
                    "Unable to export local python packages."
                ) from e

            requirements_files.append(
                (".zenml_local_requirements", local_requirements, [])
            )
            if log:
                logger.info(
                    "- Including python packages from local environment"
                )

        # Generate/Read requirements file for user-defined requirements
        if isinstance(docker_settings.requirements, str):
            user_requirements = io_utils.read_file_contents_as_string(
                docker_settings.requirements
            )
            if log:
                logger.info(
                    "- Including user-defined requirements from file `%s`",
                    os.path.abspath(docker_settings.requirements),
                )
        elif isinstance(docker_settings.requirements, list):
            user_requirements = "\n".join(docker_settings.requirements)
            if log:
                logger.info(
                    "- Including user-defined requirements: %s",
                    ", ".join(f"`{r}`" for r in docker_settings.requirements),
                )
        else:
            user_requirements = None

        if user_requirements:
            requirements_files.append(
                (".zenml_user_requirements", user_requirements, [])
            )

        # Generate requirements file for all required integrations
        integration_requirements = set(
            itertools.chain.from_iterable(
                integration_registry.select_integration_requirements(
                    integration_name=integration,
                    target_os=OperatingSystemType.LINUX,
                )
                for integration in docker_settings.required_integrations
            )
        )

        if docker_settings.install_stack_requirements:
            integration_requirements.update(stack.requirements())
            if code_repository:
                integration_requirements.update(code_repository.requirements)

        if integration_requirements:
            # Sorted so the generated file is deterministic
            integration_requirements_list = sorted(integration_requirements)
            integration_requirements_file = "\n".join(
                integration_requirements_list
            )
            requirements_files.append(
                (
                    ".zenml_integration_requirements",
                    integration_requirements_file,
                    [],
                )
            )
            if log:
                logger.info(
                    "- Including integration requirements: %s",
                    ", ".join(f"`{r}`" for r in integration_requirements_list),
                )

        # Generate requirements files for all ZenML Hub plugins
        if docker_settings.required_hub_plugins:
            (
                hub_internal_requirements,
                hub_pypi_requirements,
            ) = PipelineDockerImageBuilder._get_hub_requirements(
                docker_settings.required_hub_plugins
            )

            # Plugin packages themselves: one file per custom index, installed
            # without dependencies (those come from PyPI below)
            for i, (index, packages) in enumerate(
                hub_internal_requirements.items()
            ):
                file_name = f".zenml_hub_internal_requirements_{i}"
                file_lines = [f"-i {index}", *packages]
                file_contents = "\n".join(file_lines)
                requirements_files.append(
                    (file_name, file_contents, ["--no-deps"])
                )
                if log:
                    logger.info(
                        "- Including internal hub packages from index `%s`: %s",
                        index,
                        ", ".join(f"`{r}`" for r in packages),
                    )

            # PyPI requirements of plugin packages
            if hub_pypi_requirements:
                file_name = ".zenml_hub_pypi_requirements"
                file_contents = "\n".join(hub_pypi_requirements)
                requirements_files.append((file_name, file_contents, []))
                if log:
                    logger.info(
                        "- Including hub requirements from PyPI: %s",
                        ", ".join(f"`{r}`" for r in hub_pypi_requirements),
                    )

        return requirements_files

    @staticmethod
    def _get_hub_requirements(
        required_hub_plugins: List[str],
    ) -> Tuple[Dict[str, List[str]], List[str]]:
        """Get package requirements for ZenML Hub plugins.

        Plugins that cannot be found, or that have no index URL or package
        name, are skipped with a warning.

        Args:
            required_hub_plugins: List of hub plugin names in the format
                `(<author_username>/)<plugin_name>(==<version>)`.

        Returns:
            - A dict of the hub plugin packages themselves (which need to be
                installed from a custom index), mapping index URLs to lists of
                package names.
            - A sorted list of all unique dependencies of the required hub
                plugins (which can be installed from PyPI).
        """
        from zenml._hub.client import HubClient
        from zenml._hub.utils import parse_plugin_name, plugin_display_name

        client = HubClient()

        internal_requirements: DefaultDict[str, List[str]] = defaultdict(list)
        pypi_requirements: List[str] = []

        for plugin_str in required_hub_plugins:
            author, name, version = parse_plugin_name(
                plugin_str, version_separator="=="
            )

            plugin = client.get_plugin(
                name=name,
                version=version,
                author=author,
            )

            if plugin and plugin.index_url and plugin.package_name:
                internal_requirements[plugin.index_url].append(
                    plugin.package_name
                )
                if plugin.requirements:
                    pypi_requirements.extend(plugin.requirements)
            else:
                display_name = plugin_display_name(name, version, author)
                logger.warning(
                    "Hub plugin `%s` does not exist or cannot be installed. "
                    "Skipping installation of this plugin.",
                    display_name,
                )

        pypi_requirements = sorted(set(pypi_requirements))
        return dict(internal_requirements), pypi_requirements

    @staticmethod
    def _generate_zenml_pipeline_dockerfile(
        parent_image: str,
        docker_settings: "DockerSettings",
        download_files: bool,
        requirements_files: Sequence[Tuple[str, str, List[str]]] = (),
        apt_packages: Sequence[str] = (),
        entrypoint: Optional[str] = None,
    ) -> str:
        """Generates a Dockerfile.

        Args:
            parent_image: The image to use as parent for the Dockerfile.
            docker_settings: Docker settings for this image build.
            download_files: Whether files need to be downloaded when the image
                is run; if set, an environment variable flagging this is added.
            requirements_files: List of tuples that contain three items:
                - the name of a requirements file,
                - the content of that file,
                - options that should be passed to pip when installing the
                    requirements file.
            apt_packages: APT packages to install.
            entrypoint: The default entrypoint command that gets executed when
                running a container of an image created by this Dockerfile.

        Returns:
            The generated Dockerfile.
        """
        lines = [f"FROM {parent_image}", f"WORKDIR {DOCKER_IMAGE_WORKDIR}"]

        if apt_packages:
            # Quote each package name so it is passed as a single argument
            apt_packages_string = " ".join(f"'{p}'" for p in apt_packages)

            lines.append(
                "RUN apt-get update && apt-get install -y "
                f"--no-install-recommends {apt_packages_string}"
            )

        for file, _, options in requirements_files:
            lines.append(f"COPY {file} .")

            option_string = " ".join(options)
            lines.append(
                f"RUN pip install --default-timeout=60 --no-cache-dir "
                f"{option_string} -r {file}"
            )

        lines.append(f"ENV {ENV_ZENML_ENABLE_REPO_INIT_WARNINGS}=False")
        if download_files:
            lines.append(f"ENV {ENV_ZENML_REQUIRES_CODE_DOWNLOAD}=True")

        lines.append(
            f"ENV {ENV_ZENML_CONFIG_PATH}={DOCKER_IMAGE_ZENML_CONFIG_PATH}"
        )

        for key, value in docker_settings.environment.items():
            lines.append(f"ENV {key.upper()}={value}")

        lines.append("COPY . .")
        # Make the copied files readable and writable by all users
        lines.append("RUN chmod -R a+rw .")

        if docker_settings.user:
            # Change ownership *before* switching users: after `USER`, the
            # `RUN chown` would execute as that (typically non-root) user and
            # fail for files it doesn't own.
            lines.append(f"RUN chown -R {docker_settings.user} .")
            lines.append(f"USER {docker_settings.user}")

        if entrypoint:
            lines.append(f"ENTRYPOINT {entrypoint}")

        return "\n".join(lines)
build_docker_image(self, docker_settings, tag, stack, include_files, download_files, entrypoint=None, extra_files=None, code_repository=None)

Builds (and optionally pushes) a Docker image to run a pipeline.

Use the image name returned by this method whenever you need to uniquely reference the pushed image in order to pull or run it.

Parameters:

Name Type Description Default
docker_settings DockerSettings

The settings for the image build.

required
tag str

The tag to use for the image.

required
stack Stack

The stack on which the pipeline will be deployed.

required
include_files bool

Whether to include files in the build context.

required
download_files bool

Whether to download files in the build context.

required
entrypoint Optional[str]

Entrypoint to use for the final image. If left empty, no entrypoint will be included in the image.

None
extra_files Optional[Dict[str, str]]

Extra files to add to the build context. Keys are the path inside the build context, values are either the file content or a file path.

None
code_repository Optional[BaseCodeRepository]

The code repository from which files will be downloaded.

None

Returns:

Type Description
A tuple (image_digest, dockerfile, requirements)
  • The Docker image repo digest or local name, depending on whether the image was pushed or is just stored locally.
  • Dockerfile will contain the contents of the Dockerfile used to build the image.
  • Requirements is a string with a single pip requirement per line.

Exceptions:

Type Description
RuntimeError

If the stack does not contain an image builder.

ValueError

If no Dockerfile and/or custom parent image is specified and the Docker configuration doesn't require an image build.

Source code in zenml/utils/pipeline_docker_image_builder.py
def build_docker_image(
    self,
    docker_settings: "DockerSettings",
    tag: str,
    stack: "Stack",
    include_files: bool,
    download_files: bool,
    entrypoint: Optional[str] = None,
    extra_files: Optional[Dict[str, str]] = None,
    code_repository: Optional["BaseCodeRepository"] = None,
) -> Tuple[str, Optional[str], Optional[str]]:
    """Builds (and optionally pushes) a Docker image to run a pipeline.

    Use the image name returned by this method whenever you need to uniquely
    reference the pushed image in order to pull or run it.

    Up to two images are built:
    - If `docker_settings.dockerfile` is set, that Dockerfile is built first.
      It is used as the final image when nothing ZenML-specific needs to be
      added; otherwise it becomes the parent of a second build.
    - If requirements, apt packages, environment variables, files, an
      entrypoint or extra files need to be added, a Dockerfile is generated
      and built on top of the parent image.

    Args:
        docker_settings: The settings for the image build.
        tag: The tag to use for the image.
        stack: The stack on which the pipeline will be deployed.
        include_files: Whether to include the files of the source root in the
            build context.
        download_files: Whether files need to be downloaded when the image is
            run. If set, the requirements of `code_repository` may be
            included.
        entrypoint: Entrypoint to use for the final image. If left empty,
            no entrypoint will be included in the image.
        extra_files: Extra files to add to the build context. Keys are the
            path inside the build context, values are either the file
            content or a file path.
        code_repository: The code repository from which files will be
            downloaded.

    Returns:
        A tuple (image_digest, dockerfile, requirements):
        - The Docker image repo digest or local name, depending on whether
        the image was pushed or is just stored locally.
        - Dockerfile will contain the contents of the Dockerfile used to
        build the image.
        - Requirements is a string with a single pip requirement per line.

    Raises:
        RuntimeError: If the stack does not contain an image builder.
        ValueError: If no Dockerfile and/or custom parent image is
            specified and the Docker configuration doesn't require an
            image build.
    """
    requirements: Optional[str] = None
    dockerfile: Optional[str] = None

    if docker_settings.skip_build:
        assert (
            docker_settings.parent_image
        )  # checked via validator already

        # Should we tag this here and push it to the container registry of
        # the stack to make sure it's always accessible when running the
        # pipeline?
        return docker_settings.parent_image, dockerfile, requirements

    image_builder = stack.image_builder
    if not image_builder:
        raise RuntimeError(
            "Unable to build Docker images without an image builder in the "
            f"stack `{stack.name}`."
        )

    container_registry = stack.container_registry

    build_context_class = image_builder.build_context_class
    target_image_name = self._get_target_image_name(
        docker_settings=docker_settings,
        tag=tag,
        container_registry=container_registry,
    )

    # A ZenML build (with a generated Dockerfile) is needed whenever anything
    # has to be installed, copied or configured on top of the parent image.
    requires_zenml_build = any(
        [
            docker_settings.requirements,
            docker_settings.required_integrations,
            docker_settings.required_hub_plugins,
            docker_settings.replicate_local_python_environment,
            docker_settings.install_stack_requirements,
            docker_settings.apt_packages,
            docker_settings.environment,
            include_files,
            download_files,
            entrypoint,
            extra_files,
        ]
    )

    # Fall back to the default ZenML parent image if the pipeline
    # configuration doesn't specify a custom one
    parent_image = (
        docker_settings.parent_image or DEFAULT_DOCKER_PARENT_IMAGE
    )

    if docker_settings.dockerfile:
        if parent_image != DEFAULT_DOCKER_PARENT_IMAGE:
            logger.warning(
                "You've specified both a Dockerfile and a custom parent "
                "image, ignoring the parent image."
            )

        # Push the custom image if it is the final image, or if a remote
        # image builder has to pull it as the parent of the second build.
        push = (
            not image_builder.is_building_locally
            or not requires_zenml_build
        )

        if requires_zenml_build:
            # We will build an additional image on top of this one later
            # to include user files and/or install requirements. The image
            # we build now will be used as the parent for the next build.
            user_image_name = (
                f"{docker_settings.target_repository}:"
                f"{tag}-intermediate-build"
            )
            if push and container_registry:
                user_image_name = (
                    f"{container_registry.config.uri}/{user_image_name}"
                )

            parent_image = user_image_name
        else:
            # The image we'll build from the custom Dockerfile will be
            # used directly, so we tag it with the requested target name.
            user_image_name = target_image_name

        build_context = build_context_class(
            root=docker_settings.build_context_root
        )
        build_context.add_file(
            source=docker_settings.dockerfile, destination="Dockerfile"
        )
        logger.info("Building Docker image `%s`.", user_image_name)
        image_name_or_digest = image_builder.build(
            image_name=user_image_name,
            build_context=build_context,
            docker_build_options=docker_settings.build_options,
            container_registry=container_registry if push else None,
        )

    elif not requires_zenml_build:
        if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
            raise ValueError(
                "Unable to run a ZenML pipeline with the given Docker "
                "settings: No Dockerfile or custom parent image "
                "specified and no files will be copied or requirements "
                "installed."
            )
        else:
            # The parent image will be used directly to run the pipeline and
            # needs to be tagged/pushed
            docker_utils.tag_image(parent_image, target=target_image_name)
            if container_registry:
                image_name_or_digest = container_registry.push_image(
                    target_image_name
                )
            else:
                image_name_or_digest = target_image_name

    if requires_zenml_build:
        logger.info("Building Docker image `%s`.", target_image_name)
        # Leave the build context empty if we don't want to include any files
        build_context_root = (
            source_utils.get_source_root() if include_files else None
        )
        build_context = build_context_class(
            root=build_context_root,
            dockerignore_file=docker_settings.dockerignore,
        )

        requirements_files = self.gather_requirements_files(
            docker_settings=docker_settings,
            stack=stack,
            # Only pass code repo to include its dependencies if we actually
            # need to download code
            code_repository=code_repository if download_files else None,
        )

        self._add_requirements_files(
            requirements_files=requirements_files,
            build_context=build_context,
        )
        requirements = (
            "\n".join(
                file_content for _, file_content, _ in requirements_files
            )
            or None
        )

        # Copy the list so that adding the stack's apt packages below does not
        # mutate the user's `DockerSettings` in place (which would make
        # packages accumulate across repeated builds).
        apt_packages = list(docker_settings.apt_packages)
        if docker_settings.install_stack_requirements:
            apt_packages += stack.apt_packages

        if apt_packages:
            logger.info(
                "Including apt packages: %s",
                ", ".join(f"`{p}`" for p in apt_packages),
            )

        if parent_image == DEFAULT_DOCKER_PARENT_IMAGE:
            # The default parent image is static and doesn't require a pull
            # each time
            pull_parent_image = False
        elif docker_settings.dockerfile and not container_registry:
            # We built a custom parent image and there was no container
            # registry in the stack to push to, this is a local image
            pull_parent_image = False
        elif not image_builder.is_building_locally:
            # Remote image builders always need to pull the image
            pull_parent_image = True
        else:
            # If the image is local, we don't need to pull it. Otherwise
            # we play it safe and always pull in case the user pushed a new
            # image for the given name and tag
            pull_parent_image = not docker_utils.is_local_image(
                parent_image
            )

        build_options = {"pull": pull_parent_image, "rm": False}

        dockerfile = self._generate_zenml_pipeline_dockerfile(
            parent_image=parent_image,
            docker_settings=docker_settings,
            download_files=download_files,
            requirements_files=requirements_files,
            apt_packages=apt_packages,
            entrypoint=entrypoint,
        )
        build_context.add_file(destination="Dockerfile", source=dockerfile)

        if extra_files:
            for destination, source in extra_files.items():
                build_context.add_file(
                    destination=destination, source=source
                )

        image_name_or_digest = image_builder.build(
            image_name=target_image_name,
            build_context=build_context,
            docker_build_options=build_options,
            container_registry=container_registry,
        )

    return image_name_or_digest, dockerfile, requirements
gather_requirements_files(docker_settings, stack, code_repository=None, log=True) staticmethod

Gathers and/or generates pip requirements files.

This method is called in PipelineDockerImageBuilder.build_docker_image but it is also called by other parts of the codebase, e.g. the AzureMLStepOperator, which needs to upload the requirements files to AzureML where the step image is then built.

Parameters:

Name Type Description Default
docker_settings DockerSettings

Docker settings that specify which requirements to install.

required
stack Stack

The stack on which the pipeline will run.

required
code_repository Optional[BaseCodeRepository]

The code repository from which files will be downloaded.

None
log bool

If True, will log the requirements.

True

Exceptions:

Type Description
RuntimeError

If the command to export the local python packages failed.

Returns:

Type Description
List of tuples (filename, file_content, pip_options) of all requirements files. The files will be in the following order:
  • Packages installed in the local Python environment
  • User-defined requirements
  • Requirements defined by user-defined and/or stack integrations
  • Requirements of ZenML Hub plugins
Source code in zenml/utils/pipeline_docker_image_builder.py
@staticmethod
def gather_requirements_files(
    docker_settings: DockerSettings,
    stack: "Stack",
    code_repository: Optional["BaseCodeRepository"] = None,
    log: bool = True,
) -> List[Tuple[str, str, List[str]]]:
    """Collects every pip requirements file needed for a pipeline image.

    Besides `PipelineDockerImageBuilder.build_docker_image`, other parts of
    the codebase call this as well, e.g. the `AzureMLStepOperator`, which
    uploads the requirements files to AzureML where the step image is then
    built.

    Args:
        docker_settings: Docker settings that specify which requirements to
            install.
        stack: The stack on which the pipeline will run.
        code_repository: The code repository from which files will be
            downloaded.
        log: If True, will log the requirements.

    Raises:
        RuntimeError: If the command to export the local python packages
            failed.

    Returns:
        List of `(filename, file_content, pip_options)` tuples, one per
        requirements file, in this order:
        - Packages installed in the local Python environment
        - User-defined requirements
        - Requirements defined by user-defined and/or stack integrations
        - ZenML Hub plugin packages and their PyPI requirements
    """
    files: List[Tuple[str, str, List[str]]] = []

    # 1. Packages of the local Python environment (only when configured).
    export_setting = docker_settings.replicate_local_python_environment
    if export_setting:
        export_command = (
            export_setting.command
            if isinstance(export_setting, PythonEnvironmentExportMethod)
            else " ".join(export_setting)
        )
        try:
            exported = subprocess.check_output(export_command, shell=True)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                "Unable to export local python packages."
            ) from e

        files.append((".zenml_local_requirements", exported.decode(), []))
        if log:
            logger.info(
                "- Including python packages from local environment"
            )

    # 2. User-defined requirements: either a file path or a list of lines.
    requested = docker_settings.requirements
    user_requirements: Optional[str] = None
    if isinstance(requested, str):
        user_requirements = io_utils.read_file_contents_as_string(requested)
        if log:
            logger.info(
                "- Including user-defined requirements from file `%s`",
                os.path.abspath(requested),
            )
    elif isinstance(requested, List):
        user_requirements = "\n".join(requested)
        if log:
            logger.info(
                "- Including user-defined requirements: %s",
                ", ".join(f"`{r}`" for r in requested),
            )

    if user_requirements:
        files.append((".zenml_user_requirements", user_requirements, []))

    # 3. Integration (and optionally stack/code repository) requirements,
    #    deduplicated through a set and written out sorted.
    integration_reqs = {
        requirement
        for integration in docker_settings.required_integrations
        for requirement in integration_registry.select_integration_requirements(
            integration_name=integration,
            target_os=OperatingSystemType.LINUX,
        )
    }
    if docker_settings.install_stack_requirements:
        integration_reqs.update(stack.requirements())
        if code_repository:
            integration_reqs.update(code_repository.requirements)

    if integration_reqs:
        sorted_reqs = sorted(integration_reqs)
        files.append(
            (".zenml_integration_requirements", "\n".join(sorted_reqs), [])
        )
        if log:
            logger.info(
                "- Including integration requirements: %s",
                ", ".join(f"`{r}`" for r in sorted_reqs),
            )

    # 4. ZenML Hub plugins.
    if not docker_settings.required_hub_plugins:
        return files

    hub_requirements = PipelineDockerImageBuilder._get_hub_requirements(
        docker_settings.required_hub_plugins
    )
    internal_by_index, pypi_requirements = hub_requirements

    # The plugin packages themselves: one file per package index, installed
    # with the `--no-deps` pip option.
    for position, (index, packages) in enumerate(internal_by_index.items()):
        contents = "\n".join([f"-i {index}", *packages])
        files.append(
            (
                f".zenml_hub_internal_requirements_{position}",
                contents,
                ["--no-deps"],
            )
        )
        if log:
            logger.info(
                "- Including internal hub packages from index `%s`: %s",
                index,
                ", ".join(f"`{r}`" for r in packages),
            )

    # The PyPI requirements of the plugin packages go into a single file.
    if pypi_requirements:
        files.append(
            (".zenml_hub_pypi_requirements", "\n".join(pypi_requirements), [])
        )
        if log:
            logger.info(
                "- Including hub requirements from PyPI: %s",
                ", ".join(f"`{r}`" for r in pypi_requirements),
            )

    return files

proxy_utils

Proxy design pattern utils.

make_proxy_class(interface, attribute)

Proxy class decorator.

Use this decorator to transform the decorated class into a proxy that forwards all calls defined in the `interface` interface to the `attribute` class attribute that implements the same interface.

This decorator is useful in cases where you need to have a base class that acts as a proxy or facade for one or more other classes. Both the decorated class and the class attribute must inherit from the same ABC interface for this to work. Only regular methods are supported, not class methods or attributes.

Example: Let's say you have an interface called BodyBuilder, a base class called FatBob and another class called BigJim. BigJim implements the BodyBuilder interface, but FatBob does not. And let's say you want FatBob to look as if it implements the BodyBuilder interface, but in fact it just forwards all calls to BigJim. You could do this:

from abc import ABC, abstractmethod

class BodyBuilder(ABC):

    @abstractmethod
    def build_body(self):
        pass

class BigJim(BodyBuilder):

    def build_body(self):
        print("Looks fit!")

class FatBob(BodyBuilder):

    def __init__(self):
        self.big_jim = BigJim()

    def build_body(self):
        self.big_jim.build_body()

fat_bob = FatBob()
fat_bob.build_body()

But this leads to a lot of boilerplate code with bigger interfaces and makes everything harder to maintain. This is where the proxy class decorator comes in handy. Here's how to use it:

from zenml.utils.proxy_utils import make_proxy_class
from typing import Optional

@make_proxy_class(BodyBuilder, "big_jim")
class FatBob(BodyBuilder):
    big_jim: Optional[BodyBuilder] = None

    def __init__(self):
        self.big_jim = BigJim()

fat_bob = FatBob()
fat_bob.build_body()

This is the same as implementing FatBob to call BigJim explicitly, but it has the advantage that you don't need to write a lot of boilerplate code or modify the FatBob class every time you change something in the BodyBuilder interface.

This proxy decorator also allows extending classes dynamically at runtime: if the `attribute` class attribute is set to None, the proxy class will assume that the interface is not implemented by the class and will raise a NotImplementedError:

@make_proxy_class(BodyBuilder, "big_jim")
class FatBob(BodyBuilder):
    big_jim: Optional[BodyBuilder] = None

    def __init__(self):
        self.big_jim = None

fat_bob = FatBob()

# Raises NotImplementedError, class not extended yet:
fat_bob.build_body()

fat_bob.big_jim = BigJim()
# Now it works:
fat_bob.build_body()

Parameters:

Name Type Description Default
interface Type[abc.ABC]

The interface to implement.

required
attribute str

The attribute of the base class to forward calls to.

required

Returns:

Type Description
Callable[[~C], ~C]

The proxy class.

Source code in zenml/utils/proxy_utils.py
def make_proxy_class(interface: Type[ABC], attribute: str) -> Callable[[C], C]:
    """Proxy class decorator.

    Turns the decorated class into a proxy: every method declared abstract in
    `interface` is replaced by a method that forwards the call to the object
    stored under `attribute` on the instance. That object must implement the
    same interface.

    This is useful when a base class should act as a proxy or facade for one
    or more other classes. Both the decorated class and the object stored
    under `attribute` must inherit from the same ABC `interface`. Only regular
    methods are supported, not class methods or attributes.

    Example: Say there is an interface `BodyBuilder`, a base class `FatBob`
    and another class `BigJim`. `BigJim` implements the `BodyBuilder`
    interface, but `FatBob` does not. To make `FatBob` look as if it
    implemented `BodyBuilder` while really forwarding all calls to `BigJim`,
    one could write:

    ```python
    from abc import ABC, abstractmethod

    class BodyBuilder(ABC):

        @abstractmethod
        def build_body(self):
            pass

    class BigJim(BodyBuilder):

        def build_body(self):
            print("Looks fit!")

    class FatBob(BodyBuilder):

        def __init__(self):
            self.big_jim = BigJim()

        def build_body(self):
            self.big_jim.build_body()

    fat_bob = FatBob()
    fat_bob.build_body()
    ```

    With bigger interfaces this means a lot of boilerplate code that is hard
    to maintain. The proxy class decorator removes it:

    ```python
    from zenml.utils.proxy_utils import make_proxy_class
    from typing import Optional

    @make_proxy_class(BodyBuilder, "big_jim")
    class FatBob(BodyBuilder):
        big_jim: Optional[BodyBuilder] = None

        def __init__(self):
            self.big_jim = BigJim()

    fat_bob = FatBob()
    fat_bob.build_body()
    ```

    The result is equivalent to implementing `FatBob` to call `BigJim`
    explicitly, but without the boilerplate and without having to modify
    `FatBob` every time the `BodyBuilder` interface changes.

    The decorator also allows extending classes dynamically at runtime: while
    the `attribute` attribute is `None`, the proxy assumes the interface is
    not implemented and raises a `NotImplementedError`:

    ```python
    @make_proxy_class(BodyBuilder, "big_jim")
    class FatBob(BodyBuilder):
        big_jim: Optional[BodyBuilder] = None

        def __init__(self):
            self.big_jim = None

    fat_bob = FatBob()

    # Raises NotImplementedError, class not extended yet:
    fat_bob.build_body()

    fat_bob.big_jim = BigJim()
    # Now it works:
    fat_bob.build_body()
    ```

    Args:
        interface: The interface to implement.
        attribute: The attribute of the base class to forward calls to.

    Returns:
        The proxy class.
    """

    def make_proxy_method(cls: C, _method: F) -> F:
        """Wraps `_method` so that calls are forwarded to the proxied object.

        Args:
            cls: The decorated class (its name is used in error messages).
            _method: The method to replace.

        Returns:
            The proxy method.
        """

        def _resolve_target(instance: Any) -> Any:
            """Returns the object stored under `attribute` after validation.

            Args:
                instance: The proxy instance the call was made on.

            Returns:
                The proxied object.

            Raises:
                TypeError: If the instance lacks `attribute` or the stored
                    object does not implement `interface`.
                NotImplementedError: If the stored object is `None`.
            """
            if not hasattr(instance, attribute):
                raise TypeError(
                    f"Class '{cls.__name__}' does not have a '{attribute}' "
                    f"as specified in the 'make_proxy_class' decorator."
                )
            target = getattr(instance, attribute)
            if target is None:
                raise NotImplementedError(
                    f"This '{cls.__name__}' instance does not implement the "
                    f"'{interface.__name__}' interface."
                )
            if not isinstance(target, interface):
                raise TypeError(
                    f"Interface '{interface.__name__}' must be implemented by "
                    f"the '{cls.__name__}' '{attribute}' attribute."
                )
            return target

        @wraps(_method)
        def proxy_method(*args: Any, **kw: Any) -> Any:
            """Forwards the call to the same-named method of the proxied object.

            Args:
                *args: The arguments to pass to the method; the first one is
                    the proxy instance itself.
                **kw: The keyword arguments to pass to the method.

            Returns:
                The return value of the proxied method.
            """
            target = _resolve_target(args[0])
            forwarded = getattr(target, _method.__name__)
            return forwarded(*args[1:], **kw)

        return cast(F, proxy_method)

    def _inner_decorator(_cls: C) -> C:
        """Installs proxy methods for all abstract methods of `interface`.

        Args:
            _cls: The class to decorate.

        Returns:
            The decorated class.

        Raises:
            TypeError: If the decorated class does not implement the specified
                interface.
        """
        if not issubclass(_cls, interface):
            raise TypeError(
                f"Interface '{interface.__name__}' must be implemented by "
                f"the '{_cls.__name__}' class."
            )

        proxied_names = interface.__abstractmethods__
        for method_name in proxied_names:
            proxy = make_proxy_method(_cls, getattr(_cls, method_name))
            # `functools.wraps` copies the wrapped function's `__dict__`, which
            # may include `__isabstractmethod__ = True`; reset it so the proxy
            # counts as a concrete implementation.
            proxy.__isabstractmethod__ = False
            setattr(_cls, method_name, proxy)

        # The interface's abstract methods are now backed by proxies, so they
        # must no longer keep the decorated class abstract.
        _cls.__abstractmethods__ = frozenset(
            name
            for name in _cls.__abstractmethods__
            if name not in proxied_names
        )

        return cast(C, _cls)

    return _inner_decorator

pydantic_utils

Utilities for pydantic models.

TemplateGenerator

Class to generate templates for pydantic models or classes.

Source code in zenml/utils/pydantic_utils.py
class TemplateGenerator:
    """Class to generate templates for pydantic models or classes."""

    def __init__(
        self, instance_or_class: Union[BaseModel, Type[BaseModel]]
    ) -> None:
        """Initializes the template generator.

        Args:
            instance_or_class: The pydantic model or model class for which to
                generate a template.
        """
        self.instance_or_class = instance_or_class

    def run(self) -> Dict[str, Any]:
        """Generates the template.

        Returns:
            The template dictionary.
        """
        if isinstance(self.instance_or_class, BaseModel):
            template = self._generate_template_for_model(
                self.instance_or_class
            )
        else:
            template = self._generate_template_for_model_class(
                self.instance_or_class
            )

        # Convert to json in an intermediate step so we can leverage Pydantic's
        # encoder to support types like UUID and datetime
        json_string = json.dumps(template, default=pydantic_encoder)
        return cast(Dict[str, Any], json.loads(json_string))

    def _generate_template_for_model(self, model: BaseModel) -> Dict[str, Any]:
        """Generates a template for a pydantic model instance.

        The template of the model's class is used as the base. Every field
        that was explicitly set on the instance is then replaced by a
        template of its actual value.

        Args:
            model: The model for which to generate the template.

        Returns:
            The model template.
        """
        template = self._generate_template_for_model_class(model.__class__)

        for name in model.__fields_set__:
            value = getattr(model, name)
            template[name] = self._generate_template_for_value(value)

        return template

    def _generate_template_for_model_class(
        self,
        model_class: Type[BaseModel],
    ) -> Dict[str, Any]:
        """Generates a template for a pydantic model class.

        Fields whose outer type is itself a pydantic model class are expanded
        recursively. All other fields are represented by pydantic's display
        string for their type.

        Args:
            model_class: The model class for which to generate the template.

        Returns:
            The model class template.
        """
        template: Dict[str, Any] = {}

        for name, field in model_class.__fields__.items():
            # Pydantic (v1) unwraps `Optional[Model]` so that `outer_type_`
            # is `Model` itself. This single check therefore also covers
            # optional nested models; no separate `Optional` branch is needed.
            if self._is_model_class(field.outer_type_):
                template[name] = self._generate_template_for_model_class(
                    field.outer_type_
                )
            else:
                template[name] = field._type_display()

        return template

    def _generate_template_for_value(self, value: Any) -> Any:
        """Generates a template for an arbitrary value.

        Dictionaries and sequences are traversed recursively. Pydantic model
        instances are turned into model templates. Any other value is
        returned unchanged.

        Args:
            value: The value for which to generate the template.

        Returns:
            The value template.
        """
        if isinstance(value, Dict):
            return {
                k: self._generate_template_for_value(v)
                for k, v in value.items()
            }
        elif sequence_like(value):
            return [self._generate_template_for_value(v) for v in value]
        elif isinstance(value, BaseModel):
            return self._generate_template_for_model(value)
        else:
            return value

    @staticmethod
    def _is_model_class(value: Any) -> bool:
        """Checks if the given value is a pydantic model class.

        Args:
            value: The value to check.

        Returns:
            If the value is a pydantic model class.
        """
        return isinstance(value, type) and issubclass(value, BaseModel)
__init__(self, instance_or_class) special

Initializes the template generator.

Parameters:

Name Type Description Default
instance_or_class Union[pydantic.main.BaseModel, Type[pydantic.main.BaseModel]]

The pydantic model or model class for which to generate a template.

required
Source code in zenml/utils/pydantic_utils.py
def __init__(
    self, instance_or_class: Union[BaseModel, Type[BaseModel]]
) -> None:
    """Stores the model (or model class) to build a template for.

    Args:
        instance_or_class: Pydantic model instance or model class from which
            the template will be generated.
    """
    self.instance_or_class = instance_or_class
run(self)

Generates the template.

Returns:

Type Description
Dict[str, Any]

The template dictionary.

Source code in zenml/utils/pydantic_utils.py
def run(self) -> Dict[str, Any]:
    """Builds the template for the stored model instance or model class.

    Returns:
        The template as a JSON-compatible dictionary.
    """
    target = self.instance_or_class
    template = (
        self._generate_template_for_model(target)
        if isinstance(target, BaseModel)
        else self._generate_template_for_model_class(target)
    )

    # Round-trip through JSON so pydantic's encoder can turn values such as
    # UUIDs and datetimes into plain JSON types.
    serialized = json.dumps(template, default=pydantic_encoder)
    return cast(Dict[str, Any], json.loads(serialized))

YAMLSerializationMixin (BaseModel) pydantic-model

Class to serialize/deserialize pydantic models to/from YAML.

Source code in zenml/utils/pydantic_utils.py
class YAMLSerializationMixin(BaseModel):
    """Mixin that adds YAML (de)serialization to pydantic models."""

    def yaml(self, sort_keys: bool = False, **kwargs: Any) -> str:
        """Serializes the model to a YAML string.

        The model is first dumped with pydantic's `json(...)`. That JSON is
        then parsed back into plain Python data, which is passed to
        `yaml.dump`.

        Args:
            sort_keys: Whether to sort the keys in the YAML representation.
            **kwargs: Kwargs to pass to the pydantic json(...) method.

        Returns:
            YAML string representation.
        """
        json_payload = self.json(**kwargs, sort_keys=sort_keys)
        return yaml.dump(json.loads(json_payload), sort_keys=sort_keys)

    @classmethod
    def from_yaml(cls: Type[M], path: str) -> M:
        """Loads a model instance from a YAML file.

        Args:
            path: Path to a YAML file.

        Returns:
            The model instance parsed from the file contents.
        """
        return cls.parse_obj(yaml_utils.read_yaml(path))
from_yaml(path) classmethod

Creates an instance from a YAML file.

Parameters:

Name Type Description Default
path str

Path to a YAML file.

required

Returns:

Type Description
~M

The model instance.

Source code in zenml/utils/pydantic_utils.py
@classmethod
def from_yaml(cls: Type[M], path: str) -> M:
    """Loads a model instance from a YAML file.

    Args:
        path: Path to a YAML file.

    Returns:
        The model instance parsed from the file contents.
    """
    return cls.parse_obj(yaml_utils.read_yaml(path))
yaml(self, sort_keys=False, **kwargs)

YAML string representation.

Parameters:

Name Type Description Default
sort_keys bool

Whether to sort the keys in the YAML representation.

False
**kwargs Any

Kwargs to pass to the pydantic json(...) method.

{}

Returns:

Type Description
str

YAML string representation.

Source code in zenml/utils/pydantic_utils.py
def yaml(self, sort_keys: bool = False, **kwargs: Any) -> str:
    """Serializes the model to a YAML string.

    The model is first dumped with pydantic's `json(...)`. That JSON is then
    parsed back into plain Python data, which is passed to `yaml.dump`.

    Args:
        sort_keys: Whether to sort the keys in the YAML representation.
        **kwargs: Kwargs to pass to the pydantic json(...) method.

    Returns:
        YAML string representation.
    """
    json_payload = self.json(**kwargs, sort_keys=sort_keys)
    plain_data = json.loads(json_payload)
    return yaml.dump(plain_data, sort_keys=sort_keys)

update_model(original, update, recursive=True, exclude_none=True)

Updates a pydantic model.

Parameters:

Name Type Description Default
original ~M

The model to update.

required
update Union[BaseModel, Dict[str, Any]]

The update values.

required
recursive bool

If True, dictionary values will be updated recursively.

True
exclude_none bool

If True, None values in the update dictionary will be removed.

True

Returns:

Type Description
~M

The updated model.

Source code in zenml/utils/pydantic_utils.py
def update_model(
    original: M,
    update: Union["BaseModel", Dict[str, Any]],
    recursive: bool = True,
    exclude_none: bool = True,
) -> M:
    """Creates a new instance of a pydantic model with updated values.

    Args:
        original: The model to update. Only its explicitly set fields are
            used as the base values.
        update: The update values, either as a dictionary or as a model. For
            a model, only its explicitly set fields are applied.
        recursive: If `True`, dictionary values will be updated recursively.
            The flag is also forwarded to the `None`-removal step.
        exclude_none: If `True`, `None` values in the update dictionary
            will be removed. Only applies when `update` is a dictionary.

    Returns:
        The updated model.
    """
    if not isinstance(update, Dict):
        update_values = update.dict(exclude_unset=True)
    elif exclude_none:
        update_values = dict_utils.remove_none_values(
            update, recursive=recursive
        )
    else:
        update_values = update

    current_values = original.dict(exclude_unset=True)
    merged = (
        dict_utils.recursive_update(current_values, update_values)
        if recursive
        else {**current_values, **update_values}
    )
    return original.__class__(**merged)

validate_function_args(__func, __config, *args, **kwargs)

Validates arguments passed to a function.

This function validates that all arguments to call the function exist and that the types match.

It raises a pydantic.ValidationError if the validation fails.

Parameters:

Name Type Description Default
__func Callable[..., Any]

The function for which the arguments are passed.

required
__config Dict[str, Any]

The pydantic config for the underlying model that is created to validate the types of the arguments.

required
*args Any

Function arguments.

()
**kwargs Any

Function keyword arguments.

{}

Returns:

Type Description
Dict[str, Any]

The validated arguments.

Source code in zenml/utils/pydantic_utils.py
def validate_function_args(
    __func: Callable[..., Any],
    __config: Dict[str, Any],
    *args: Any,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Validates arguments passed to a function.

    Checks that every given argument exists in the signature of `__func` and
    that its value matches the annotated type. A `pydantic.ValidationError`
    is raised if the validation fails.

    Args:
        __func: The function for which the arguments are passed.
        __config: The pydantic config for the underlying model that is created
            to validate the types of the arguments.
        *args: Function arguments.
        **kwargs: Function keyword arguments.

    Returns:
        The validated arguments, keyed by the original parameter names. An
        argument is included only if it was explicitly passed, or if its
        field has a default factory or a truthy default.
    """
    prefix = "zenml__"

    original_signature = inspect.signature(__func)
    renamed_signature = original_signature.replace(
        parameters=[
            parameter.replace(name=f"{prefix}{parameter.name}")
            for parameter in original_signature.parameters.values()
        ]
    )

    # Validation runs against a stand-in function that keeps the original
    # signature but prefixes every parameter name, so no argument can clash
    # with an attribute of pydantic's BaseModel.
    def f() -> None:
        pass

    f.__signature__ = renamed_signature  # type: ignore[attr-defined]
    f.__annotations__ = {
        f"{prefix}{name}": annotation
        for name, annotation in __func.__annotations__.items()
    }

    validator = ValidatedFunction(f, config=__config)
    prefixed_kwargs = {
        f"{prefix}{key}": value for key, value in kwargs.items()
    }
    model = validator.init_model_instance(*args, **prefixed_kwargs)

    validated_args: Dict[str, Any] = {}
    for field_name, value in model._iter():
        # NOTE(review): the default checks rely on truthiness, so fields whose
        # default is falsy (e.g. None, 0, "") are skipped unless passed.
        if (
            field_name in model.__fields_set__
            or model.__fields__[field_name].default_factory
            or model.__fields__[field_name].default
        ):
            validated_args[field_name[len(prefix) :]] = value

    return validated_args

secret_utils

Utility functions for secrets and secret references.

SecretReference (tuple)

Class representing a secret reference.

Attributes:

Name Type Description
name str

The secret name.

key str

The secret key.

Source code in zenml/utils/secret_utils.py
class SecretReference(NamedTuple):
    """Reference to a single key inside a named secret.

    Attributes:
        name: The secret name.
        key: The secret key.
    """

    name: str
    key: str
__getnewargs__(self) special

Return self as a plain tuple. Used by copy and pickle.

Source code in zenml/utils/secret_utils.py
def __getnewargs__(self):
    """Return ``self`` as a plain tuple; used by ``copy`` and ``pickle``."""
    as_plain_tuple = _tuple(self)
    return as_plain_tuple
__new__(_cls, name, key) special staticmethod

Create new instance of SecretReference(name, key)

__repr__(self) special

Return a nicely formatted representation string

Source code in zenml/utils/secret_utils.py
def __repr__(self):
    """Return a representation string: class name followed by the fields."""
    class_name = self.__class__.__name__
    return class_name + repr_fmt % self

ClearTextField(*args, **kwargs)

Marks a pydantic field to prevent secret references.

Parameters:

Name Type Description Default
*args Any

Positional arguments which will be forwarded to pydantic.Field(...).

()
**kwargs Any

Keyword arguments which will be forwarded to pydantic.Field(...).

{}

Returns:

Type Description
Any

Pydantic field info.

Source code in zenml/utils/secret_utils.py
def ClearTextField(*args: Any, **kwargs: Any) -> Any:
    """Creates a pydantic field marked to prevent secret references.

    Args:
        *args: Positional arguments which will be forwarded
            to `pydantic.Field(...)`.
        **kwargs: Keyword arguments which will be forwarded to
            `pydantic.Field(...)`. The clear-text marker is added to them.

    Returns:
        Pydantic field info.
    """
    field_kwargs = {**kwargs, PYDANTIC_CLEAR_TEXT_FIELD_MARKER: True}
    return Field(*args, **field_kwargs)  # type: ignore[pydantic-field]

SecretField(*args, **kwargs)

Marks a pydantic field as something containing sensitive information.

Parameters:

Name Type Description Default
*args Any

Positional arguments which will be forwarded to pydantic.Field(...).

()
**kwargs Any

Keyword arguments which will be forwarded to pydantic.Field(...).

{}

Returns:

Type Description
Any

Pydantic field info.

Source code in zenml/utils/secret_utils.py
def SecretField(*args: Any, **kwargs: Any) -> Any:
    """Creates a pydantic field marked as containing sensitive information.

    Args:
        *args: Positional arguments which will be forwarded
            to `pydantic.Field(...)`.
        **kwargs: Keyword arguments which will be forwarded to
            `pydantic.Field(...)`. The sensitive-field marker is added to
            them.

    Returns:
        Pydantic field info.
    """
    field_kwargs = {**kwargs, PYDANTIC_SENSITIVE_FIELD_MARKER: True}
    return Field(*args, **field_kwargs)  # type: ignore[pydantic-field]

is_clear_text_field(field)

Returns whether a pydantic field prevents secret references or not.

Parameters:

Name Type Description Default
field ModelField

The field to check.

required

Returns:

Type Description
bool

True if the field prevents secret references, False otherwise.

Source code in zenml/utils/secret_utils.py
def is_clear_text_field(field: "ModelField") -> bool:
    """Checks whether a pydantic field prevents secret references.

    Args:
        field: The field to check.

    Returns:
        `True` if the field carries the clear-text marker, `False` otherwise.
    """
    extra_info = field.field_info.extra
    return extra_info.get(PYDANTIC_CLEAR_TEXT_FIELD_MARKER, False)

is_secret_field(field)

Returns whether a pydantic field contains sensitive information or not.

Parameters:

Name Type Description Default
field ModelField

The field to check.

required

Returns:

Type Description
bool

True if the field contains sensitive information, False otherwise.

Source code in zenml/utils/secret_utils.py
def is_secret_field(field: "ModelField") -> bool:
    """Checks whether a pydantic field contains sensitive information.

    Args:
        field: The field to check.

    Returns:
        `True` if the field carries the sensitive-field marker, `False`
        otherwise.
    """
    extra_info = field.field_info.extra
    return extra_info.get(PYDANTIC_SENSITIVE_FIELD_MARKER, False)

is_secret_reference(value)

Checks whether any value is a secret reference.

Parameters:

Name Type Description Default
value Any

The value to check.

required

Returns:

Type Description
bool

True if the value is a secret reference, False otherwise.

Source code in zenml/utils/secret_utils.py
def is_secret_reference(value: Any) -> bool:
    """Tell whether an arbitrary value is a secret reference.

    Args:
        value: The value to inspect.

    Returns:
        `True` if the value is a secret reference, `False` otherwise.
    """
    # Only strings can be secret references; anything else short-circuits.
    return isinstance(value, str) and bool(
        _secret_reference_expression.fullmatch(value)
    )

parse_secret_reference(reference)

Parses a secret reference.

This function assumes the input string is a valid secret reference and does not perform any additional checks. If you pass an invalid secret reference here, this will most likely crash.

Parameters:

Name Type Description Default
reference str

The string representing a valid secret reference.

required

Returns:

Type Description
SecretReference

The parsed secret reference.

Source code in zenml/utils/secret_utils.py
def parse_secret_reference(reference: str) -> SecretReference:
    """Turn a secret reference string into a `SecretReference`.

    No validation happens here: the input must already be known to be a
    **valid** secret reference. Passing anything else will most likely raise
    an error.

    Args:
        reference: The string representing a **valid** secret reference.

    Returns:
        The parsed secret reference.
    """
    # Drop the two leading and two trailing delimiter characters.
    inner = reference[2:-2]
    # Text before the first "." is the secret name; the remainder is the key.
    name, key = (part.strip() for part in inner.split(".", 1))
    return SecretReference(name=name, key=key)

settings_utils

Utility functions for ZenML settings.

get_flavor_setting_key(flavor)

Gets the setting key for a flavor.

Parameters:

Name Type Description Default
flavor Flavor

The flavor for which to get the key.

required

Returns:

Type Description
str

The setting key for the flavor.

Source code in zenml/utils/settings_utils.py
def get_flavor_setting_key(flavor: "Flavor") -> str:
    """Build the settings key that belongs to a flavor.

    Args:
        flavor: The flavor whose settings key should be built.

    Returns:
        The key in the form `<flavor type>.<flavor name>`.
    """
    component_type, flavor_name = flavor.type, flavor.name
    return f"{component_type}.{flavor_name}"

get_general_settings()

Returns all general settings.

Returns:

Type Description
Dict[str, Type[BaseSettings]]

Dictionary mapping general settings keys to their type.

Source code in zenml/utils/settings_utils.py
def get_general_settings() -> Dict[str, Type["BaseSettings"]]:
    """Collect all general (non stack component) settings classes.

    Returns:
        Dictionary mapping each general settings key to its settings class.
    """
    from zenml.config import DockerSettings, ResourceSettings

    general_settings: Dict[str, Type["BaseSettings"]] = {}
    general_settings[DOCKER_SETTINGS_KEY] = DockerSettings
    general_settings[RESOURCE_SETTINGS_KEY] = ResourceSettings
    return general_settings

get_stack_component_for_settings_key(key, stack)

Gets the stack component of a stack for a given settings key.

Parameters:

Name Type Description Default
key str

The settings key for which to get the component.

required
stack Stack

The stack from which to get the component.

required

Exceptions:

Type Description
ValueError

If the key is invalid or the stack does not contain a component of the correct flavor.

Returns:

Type Description
StackComponent

The stack component.

Source code in zenml/utils/settings_utils.py
def get_stack_component_for_settings_key(
    key: str, stack: "Stack"
) -> "StackComponent":
    """Look up the stack component targeted by a settings key.

    Args:
        key: The settings key, expected in the form
            `<component type>.<flavor>`.
        stack: The stack to search for the component.

    Raises:
        ValueError: If the key does not refer to a stack component, or the
            stack has no component of that type with the matching flavor.

    Returns:
        The matching stack component.
    """
    if not is_stack_component_setting_key(key):
        raise ValueError(
            f"Settings key {key} does not refer to a stack component."
        )

    component_type, flavor = key.split(".", 1)
    component = stack.components.get(StackComponentType(component_type))
    if component and component.flavor == flavor:
        return component

    raise ValueError(
        f"Component of type {component_type} in stack {stack} is not "
        f"of the flavor {flavor} specified by the settings key {key}."
    )

get_stack_component_setting_key(stack_component)

Gets the setting key for a stack component.

Parameters:

Name Type Description Default
stack_component StackComponent

The stack component for which to get the key.

required

Returns:

Type Description
str

The setting key for the stack component.

Source code in zenml/utils/settings_utils.py
def get_stack_component_setting_key(stack_component: "StackComponent") -> str:
    """Build the settings key that belongs to a stack component.

    Args:
        stack_component: The stack component whose settings key is wanted.

    Returns:
        The key in the form `<component type>.<component flavor>`.
    """
    component_type = stack_component.type
    component_flavor = stack_component.flavor
    return f"{component_type}.{component_flavor}"

is_general_setting_key(key)

Checks whether the key refers to a general setting.

Parameters:

Name Type Description Default
key str

The key to check.

required

Returns:

Type Description
bool

Whether the key refers to a general setting.

Source code in zenml/utils/settings_utils.py
def is_general_setting_key(key: str) -> bool:
    """Tell whether a settings key names one of the general settings.

    Args:
        key: The key to check.

    Returns:
        Whether the key refers to a general setting.
    """
    general_settings = get_general_settings()
    return key in general_settings

is_stack_component_setting_key(key)

Checks whether a settings key refers to a stack component.

Parameters:

Name Type Description Default
key str

The key to check.

required

Returns:

Type Description
bool

Whether the key refers to a stack component.

Source code in zenml/utils/settings_utils.py
def is_stack_component_setting_key(key: str) -> bool:
    """Tell whether a settings key targets a stack component.

    Args:
        key: The key to check.

    Returns:
        Whether the whole key matches the stack component key pattern.
    """
    match = STACK_COMPONENT_REGEX.fullmatch(key)
    return match is not None

is_valid_setting_key(key)

Checks whether a settings key is valid.

Parameters:

Name Type Description Default
key str

The key to check.

required

Returns:

Type Description
bool

Whether the key is valid.

Source code in zenml/utils/settings_utils.py
def is_valid_setting_key(key: str) -> bool:
    """Tell whether a settings key is valid.

    A key is valid if it names a general setting or a stack component.

    Args:
        key: The key to check.

    Returns:
        Whether the key is valid.
    """
    if is_general_setting_key(key):
        return True
    return is_stack_component_setting_key(key)

validate_setting_keys(setting_keys)

Validates settings keys.

Parameters:

Name Type Description Default
setting_keys Sequence[str]

The keys to validate.

required

Exceptions:

Type Description
ValueError

If any key is invalid.

Source code in zenml/utils/settings_utils.py
def validate_setting_keys(setting_keys: Sequence[str]) -> None:
    """Ensure that every given settings key is valid.

    Args:
        setting_keys: The keys to validate.

    Raises:
        ValueError: On the first key that is not valid.
    """
    for key in setting_keys:
        if is_valid_setting_key(key):
            continue

        available_keys = set(get_general_settings())
        raise ValueError(
            f"Invalid setting key `{key}`. Setting keys can either refer "
            "to general settings (available keys: "
            f"{available_keys}) or stack component specific "
            "settings. Stack component specific keys are of the format "
            "`<STACK_COMPONENT_TYPE>.<STACK_COMPONENT_FLAVOR>`."
        )

singleton

Utility class to turn classes into singleton classes.

SingletonMetaClass (type)

Singleton metaclass.

Use this metaclass to make any class into a singleton class:

# Example: every instantiation after the first returns the cached instance.
class OneRing(metaclass=SingletonMetaClass):
    def __init__(self, owner):
        self._owner = owner

    @property
    def owner(self):
        return self._owner

the_one_ring = OneRing('Sauron')
the_lost_ring = OneRing('Frodo')  # cached instance returned; __init__ not re-run
print(the_lost_ring.owner)  # Sauron
OneRing._clear() # ring destroyed
Source code in zenml/utils/singleton.py
class SingletonMetaClass(type):
    """Singleton metaclass.

    Use this metaclass to make any class into a singleton class:

    ```python
    class OneRing(metaclass=SingletonMetaClass):
        def __init__(self, owner):
            self._owner = owner

        @property
        def owner(self):
            return self._owner

    the_one_ring = OneRing('Sauron')
    the_lost_ring = OneRing('Frodo')
    print(the_lost_ring.owner)  # Sauron
    OneRing._clear() # ring destroyed
    ```

    The instance is created lazily on the first call. Later calls return it
    unchanged, ignoring their arguments, until `_clear()` is called.
    """

    def __init__(cls, *args: Any, **kwargs: Any) -> None:
        """Initialize a singleton class.

        Args:
            *args: Additional arguments forwarded to `type.__init__`.
            **kwargs: Additional keyword arguments forwarded to
                `type.__init__`.
        """
        super().__init__(*args, **kwargs)
        # No instance has been created yet for this class.
        cls.__singleton_instance: Optional["SingletonMetaClass"] = None

    def __call__(cls, *args: Any, **kwargs: Any) -> "SingletonMetaClass":
        """Create or return the singleton instance.

        Args:
            *args: Additional arguments, only used for the first creation.
            **kwargs: Additional keyword arguments, only used for the first
                creation.

        Returns:
            The singleton instance.
        """
        # Compare against `None` rather than testing truthiness: an instance
        # whose `__bool__`/`__len__` makes it falsy would otherwise be
        # re-created on every call, breaking the singleton guarantee.
        if cls.__singleton_instance is None:
            cls.__singleton_instance = cast(
                "SingletonMetaClass", super().__call__(*args, **kwargs)
            )

        return cls.__singleton_instance

    def _clear(cls) -> None:
        """Clear the singleton instance so the next call creates a new one."""
        cls.__singleton_instance = None
__call__(cls, *args, **kwargs) special

Create or return the singleton instance.

Parameters:

Name Type Description Default
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
SingletonMetaClass

The singleton instance.

Source code in zenml/utils/singleton.py
def __call__(cls, *args: Any, **kwargs: Any) -> "SingletonMetaClass":
    """Create or return the singleton instance.

    Args:
        *args: Additional arguments, only used for the first creation.
        **kwargs: Additional keyword arguments, only used for the first
            creation.

    Returns:
        The singleton instance.
    """
    # Compare against `None` rather than testing truthiness: an instance
    # whose `__bool__`/`__len__` makes it falsy would otherwise be re-created
    # on every call, breaking the singleton guarantee.
    if cls.__singleton_instance is None:
        cls.__singleton_instance = cast(
            "SingletonMetaClass", super().__call__(*args, **kwargs)
        )

    return cls.__singleton_instance
__init__(cls, *args, **kwargs) special

Initialize a singleton class.

Parameters:

Name Type Description Default
*args Any

Additional arguments.

()
**kwargs Any

Additional keyword arguments.

{}
Source code in zenml/utils/singleton.py
def __init__(cls, *args: Any, **kwargs: Any) -> None:
    """Initialize a singleton class.

    Runs the normal `type` initialization, then resets the cached instance
    so the class starts without one; it is created on the first call.

    Args:
        *args: Additional arguments forwarded to `type.__init__`.
        **kwargs: Additional keyword arguments forwarded to `type.__init__`.
    """
    super().__init__(*args, **kwargs)
    # No instance has been created yet for this class.
    cls.__singleton_instance: Optional["SingletonMetaClass"] = None

source_code_utils

Utilities for getting the source code of objects.

get_hashed_source_code(value)

Returns a hash of the object's source code.

Parameters:

Name Type Description Default
value Any

object to get source from.

required

Returns:

Type Description
str

Hash of source code.

Exceptions:

Type Description
TypeError

If unable to compute the hash.

Source code in zenml/utils/source_code_utils.py
def get_hashed_source_code(value: Any) -> str:
    """Returns a hash of the object's source code.

    Args:
        value: Object to get the source code from.

    Returns:
        Hex-encoded SHA-256 digest of the UTF-8 encoded source code.

    Raises:
        TypeError: If unable to compute the hash.
    """
    try:
        source_code = get_source_code(value)
    except TypeError as e:
        # Chain the original error so its cause stays visible in tracebacks.
        raise TypeError(
            f"Unable to compute the hash of source code of object: {value}."
        ) from e
    return hashlib.sha256(source_code.encode("utf-8")).hexdigest()

get_source_code(value)

Returns the source code of an object.

If executing within an IPython kernel environment, this temporarily monkey-patches the inspect module with a workaround to get the source from the cell.

Parameters:

Name Type Description Default
value Any

object to get source from.

required

Returns:

Type Description
str

Source code of object.

Source code in zenml/utils/source_code_utils.py
def get_source_code(value: Any) -> str:
    """Returns the source code of an object.

    If executing within an IPython kernel environment, this temporarily
    monkey-patches `inspect.getfile` with a workaround so that the source of
    classes defined in notebook cells can be retrieved. The original
    `inspect.getfile` is always restored afterwards.

    Args:
        value: Object to get the source code from.

    Returns:
        Source code of the object.

    Raises:
        TypeError: In a notebook, if no file can be found for a class.
            Errors raised by `inspect.getsource` itself are propagated.
    """
    if Environment.in_notebook():
        # Monkey patch inspect.getfile temporarily to make getsource work.
        # Source: https://stackoverflow.com/questions/51566497/
        # NOTE: the parameter name `object` shadows the builtin inside this
        # helper.
        def _new_getfile(
            object: Any,
            # The default is evaluated when this helper is defined, i.e.
            # before the patch below, so it holds the unpatched function.
            _old_getfile: Callable[
                [
                    Union[
                        ModuleType,
                        Type[Any],
                        MethodType,
                        FunctionType,
                        TracebackType,
                        FrameType,
                        CodeType,
                        Callable[..., Any],
                    ]
                ],
                str,
            ] = inspect.getfile,
        ) -> Any:
            # Only classes need the workaround; everything else is delegated.
            if not inspect.isclass(object):
                return _old_getfile(object)

            # Lookup by parent module (as in current inspect)
            if hasattr(object, "__module__"):
                object_ = sys.modules.get(object.__module__)
                if hasattr(object_, "__file__"):
                    return object_.__file__  # type: ignore[union-attr]

            # If parent module is __main__, lookup by methods
            for _, member in inspect.getmembers(object):
                # A function whose qualname is `<class qualname>.<name>` was
                # defined in the class body. While patched, `inspect.getfile`
                # is this helper, which delegates functions to the original.
                if (
                    inspect.isfunction(member)
                    and object.__qualname__ + "." + member.__name__
                    == member.__qualname__
                ):
                    return inspect.getfile(member)
            else:
                # The loop never breaks, so reaching here means no method of
                # the class revealed its file.
                raise TypeError(f"Source for {object!r} not found.")

        # Monkey patch, compute source, then revert monkey patch.
        _old_getfile = inspect.getfile
        inspect.getfile = _new_getfile
        try:
            src = inspect.getsource(value)
        finally:
            # Restore even if getsource fails, so the patch never leaks.
            inspect.getfile = _old_getfile
    else:
        # Use standard inspect if running outside a notebook
        src = inspect.getsource(value)
    return src

source_utils

Utilities for loading/resolving objects.

get_source_root()

Get the source root.

The source root will be determined in the following order:

- The manually specified custom source root if it was set.
- The ZenML repository directory if one exists in the current working directory or any parent directories.
- The parent directory of the main module file.

Returns:

Type Description
str

The source root.

Exceptions:

Type Description
RuntimeError

If the main module file can't be found.

Source code in zenml/utils/source_utils.py
def get_source_root() -> str:
    """Determine the source root.

    Candidates are tried in this order:
    - The manually specified custom source root, if one was set.
    - The ZenML repository directory, if one exists in the current working
      directory or any parent directory.
    - The parent directory of the main module file.

    Returns:
        The source root.

    Raises:
        RuntimeError: If the main module, or its file, can't be found.
    """
    # 1. An explicitly configured source root always wins.
    if _CUSTOM_SOURCE_ROOT:
        logger.debug("Using custom source root: %s", _CUSTOM_SOURCE_ROOT)
        return _CUSTOM_SOURCE_ROOT

    # 2. Fall back to the enclosing ZenML repository, if there is one.
    from zenml.client import Client

    repository_root = Client.find_repository()
    if repository_root:
        logger.debug(
            "Using repository root as source root: %s", repository_root
        )
        return str(repository_root.resolve())

    # 3. Finally, use the directory containing the `__main__` module's file.
    main_module = sys.modules.get("__main__")
    if main_module is None:
        raise RuntimeError(
            "Unable to determine source root because the main module could not "
            "be found."
        )

    main_file = getattr(main_module, "__file__", None)
    if not main_file:
        raise RuntimeError(
            "Unable to determine source root because the main module does not "
            "have an associated file. This could be because you're running in "
            "an interactive Python environment."
        )

    source_root = Path(main_file).resolve().parent
    logger.debug(
        "Using main module parent directory as source root: %s", source_root
    )
    return str(source_root)

get_source_type(module)

Get the type of a source.

Parameters:

Name Type Description Default
module module

The module for which to get the source type.

required

Returns:

Type Description
SourceType

The source type.

Source code in zenml/utils/source_utils.py
def get_source_type(module: ModuleType) -> SourceType:
    """Classify the kind of source a module comes from.

    Modules without a backing file are builtin, except for the `__main__`
    module inside a notebook, which counts as user code. File-backed modules
    are checked, in order, for being internal (part of `zenml`), standard
    library, user code, or a distribution package. Anything else is
    `UNKNOWN`.

    Args:
        module: The module to classify.

    Returns:
        The source type of the module.
    """
    try:
        file_path = inspect.getfile(module)
    except (TypeError, OSError):
        is_notebook_main = (
            module.__name__ == "__main__" and Environment.in_notebook()
        )
        return SourceType.USER if is_notebook_main else SourceType.BUILTIN

    # Checks are evaluated lazily and in priority order; the first hit wins.
    classifiers = (
        (
            lambda: is_internal_module(module_name=module.__name__),
            SourceType.INTERNAL,
        ),
        (
            lambda: is_standard_lib_file(file_path=file_path),
            SourceType.BUILTIN,
        ),
        (
            lambda: is_user_file(file_path=file_path),
            SourceType.USER,
        ),
        (
            lambda: is_distribution_package_file(
                file_path=file_path, module_name=module.__name__
            ),
            SourceType.DISTRIBUTION_PACKAGE,
        ),
    )
    for matches, source_type in classifiers:
        if matches():
            return source_type

    return SourceType.UNKNOWN

is_distribution_package_file(file_path, module_name)

Checks if a file/module belongs to a distribution package.

Parameters:

Name Type Description Default
file_path str

The file path to check.

required
module_name str

The module name.

required

Returns:

Type Description
bool

True if the file/module belongs to a distribution package, False otherwise.

Source code in zenml/utils/source_utils.py
def is_distribution_package_file(file_path: str, module_name: str) -> bool:
    """Tell whether a file/module is part of an installed distribution.

    A file counts if it lives inside a (user) site-packages directory, or if
    a distribution package can be found for its module name.

    Args:
        file_path: The file path to check.
        module_name: The module name.

    Returns:
        True if the file/module belongs to a distribution package, False
        otherwise.
    """
    resolved_file = Path(file_path).resolve()
    site_dirs = [*site.getsitepackages(), site.getusersitepackages()]

    if any(
        Path(site_dir).resolve() in resolved_file.parents
        for site_dir in site_dirs
    ):
        return True

    # TODO: Both of these checks don't detect editable installs because
    # the site packages dir only contains a reference to the source files,
    # not the actual files, and importlib_metadata doesn't detect it as a valid
    # distribution package. That means currently editable installs get a
    # source type UNKNOWN which might or might not lead to issues.
    return bool(_get_package_for_module(module_name=module_name))

is_internal_module(module_name)

Checks if a module is internal (=part of the zenml package).

Parameters:

Name Type Description Default
module_name str

Name of the module to check.

required

Returns:

Type Description
bool

True if the module is internal, False otherwise.

Source code in zenml/utils/source_utils.py
def is_internal_module(module_name: str) -> bool:
    """Tell whether a module belongs to the `zenml` package itself.

    Args:
        module_name: Dotted name of the module to check.

    Returns:
        True if the top-level package of the module is `zenml`, False
        otherwise.
    """
    top_level_package, _, _ = module_name.partition(".")
    return top_level_package == "zenml"

is_standard_lib_file(file_path)

Checks if a file belongs to the Python standard library.

Parameters:

Name Type Description Default
file_path str

The file path to check.

required

Returns:

Type Description
bool

True if the file belongs to the Python standard library, False otherwise.

Source code in zenml/utils/source_utils.py
def is_standard_lib_file(file_path: str) -> bool:
    """Checks if a file belongs to the Python standard library.

    Args:
        file_path: The file path to check.

    Returns:
        True if the file lives (at any depth) inside the standard library
        directory, False otherwise.
    """
    import sysconfig

    # `distutils.sysconfig.get_python_lib` is deprecated (PEP 632) and gone
    # in Python 3.12; the `sysconfig` "stdlib" path is its supported
    # replacement.
    stdlib_root = sysconfig.get_paths()["stdlib"]
    return Path(stdlib_root).resolve() in Path(file_path).resolve().parents

is_user_file(file_path)

Checks if a file is a user file.

Parameters:

Name Type Description Default
file_path str

The file path to check.

required

Returns:

Type Description
bool

True if the file is a user file, False otherwise.

Source code in zenml/utils/source_utils.py
def is_user_file(file_path: str) -> bool:
    """Checks if a file is a user file.

    A user file is one located (at any depth) inside the source root
    returned by `get_source_root()`.

    Args:
        file_path: The file path to check.

    Returns:
        True if the file is a user file, False otherwise.
    """
    # Resolve the source root too. A custom source root is returned verbatim
    # and may be relative or contain symlinks, in which case it would never
    # equal one of the fully resolved parents of `file_path`.
    source_root = Path(get_source_root()).resolve()
    return source_root in Path(file_path).resolve().parents

load(source)

Load a source or import path.

Parameters:

Name Type Description Default
source Union[zenml.config.source.Source, str]

The source to load.

required

Returns:

Type Description
Any

The loaded object.

Source code in zenml/utils/source_utils.py
def load(source: Union[Source, str]) -> Any:
    """Load the object referenced by a source or import path.

    Args:
        source: The source, or its import path string, to load.

    Returns:
        The loaded object: the module itself, or the attribute of the module
        named by the source.
    """
    if isinstance(source, str):
        source = Source.from_import_path(source)

    # `NoneType` is not exposed by the `builtins` module, so it is
    # special-cased instead of imported.
    if source.import_path == NoneTypeSource.import_path:
        return NoneType

    source_type = source.type
    import_root: Optional[str] = None

    if source_type == SourceType.CODE_REPOSITORY:
        source = CodeRepositorySource.parse_obj(source)
        _warn_about_potential_source_loading_issues(source=source)
        import_root = get_source_root()
    elif source_type == SourceType.DISTRIBUTION_PACKAGE:
        source = DistributionPackageSource.parse_obj(source)
        expected_version = source.version
        if expected_version:
            installed_version = _get_package_version(
                package_name=source.package_name
            )
            if installed_version != expected_version:
                logger.warning(
                    "The currently installed version `%s` of package `%s` "
                    "does not match the source version `%s`. This might lead "
                    "to unexpected behavior when using the source object `%s`.",
                    installed_version,
                    source.package_name,
                    expected_version,
                    source.import_path,
                )
    elif source_type in {SourceType.USER, SourceType.UNKNOWN}:
        # An unknown source might still be a user file, so make the source
        # root importable just in case.
        import_root = get_source_root()

    module = _load_module(module_name=source.module, import_root=import_root)
    return getattr(module, source.attribute) if source.attribute else module

load_and_validate_class(source, expected_class)

Loads a source and validates that it resolves to the expected class (or a subclass of it).

Parameters:

Name Type Description Default
source Union[str, zenml.config.source.Source]

The source.

required
expected_class Type[Any]

The class that the source should resolve to.

required

Exceptions:

Type Description
TypeError

If the source does not resolve to the expected class.

Returns:

Type Description
Type[Any]

The resolved source class.

Source code in zenml/utils/source_utils.py
def load_and_validate_class(
    source: Union[str, Source], expected_class: Type[Any]
) -> Type[Any]:
    """Load a source and make sure it is the expected class or a subclass.

    Args:
        source: The source to load.
        expected_class: The class that the source should resolve to.

    Raises:
        TypeError: If the source does not resolve to a subclass of
            `expected_class`.

    Returns:
        The resolved source class.
    """
    loaded = load(source)
    is_expected_class = isinstance(loaded, type) and issubclass(
        loaded, expected_class
    )
    if not is_expected_class:
        raise TypeError(
            f"Error while loading `{source}`. Expected class "
            f"{expected_class.__name__}, got {loaded} instead."
        )
    return loaded

prepend_python_path(path)

Context manager to temporarily prepend a path to the python path.

Parameters:

Name Type Description Default
path str

Path that will be prepended to sys.path for the duration of the context manager.

required

Yields:

Type Description
Iterator[NoneType]

None

Source code in zenml/utils/source_utils.py
@contextlib.contextmanager
def prepend_python_path(path: str) -> Iterator[None]:
    """Context manager to temporarily prepend a path to the python path.

    Args:
        path: Path that will be prepended to sys.path for the duration of
            the context manager.

    Yields:
        None
    """
    # Insert outside the `try`: cleanup only makes sense once the insert has
    # actually happened.
    sys.path.insert(0, path)
    try:
        yield
    finally:
        # The body may already have removed the entry. Raising `ValueError`
        # here would mask any exception coming from the body, and the
        # desired end state (entry gone) is already reached.
        # NOTE: `list.remove` drops the first matching entry.
        with contextlib.suppress(ValueError):
            sys.path.remove(path)

resolve(obj, skip_validation=False)

Resolve an object.

Parameters:

Name Type Description Default
obj Union[Type[Any], Callable[..., Any], module]

The object to resolve.

required
skip_validation bool

If True, the validation that the object exists in the module is skipped.

False

Exceptions:

Type Description
RuntimeError

If the object can't be resolved.

Returns:

Type Description
Source

The source of the resolved object.

Source code in zenml/utils/source_utils.py
def resolve(
    obj: Union[Type[Any], Callable[..., Any], ModuleType, NoneType],
    skip_validation: bool = False,
) -> Source:
    """Resolve an object.

    Builds a `Source` describing where `obj` can be loaded from. The result
    is a `CodeRepositorySource` for user code inside an active code
    repository without local changes, and a `DistributionPackageSource` for
    code from an installed package whose distribution can be found.
    Otherwise it is a plain `Source`.

    Args:
        obj: The object to resolve.
        skip_validation: If True, the validation that the object exists in
            the module is skipped.

    Raises:
        RuntimeError: If the object can't be resolved.

    Returns:
        The source of the resolved object.
    """
    if obj is NoneType:  # type: ignore[comparison-overlap]
        # The class of the `None` object doesn't exist in the `builtin` module
        # so we need to manually handle it here
        return NoneTypeSource
    elif isinstance(obj, ModuleType):
        # Modules are referenced directly, without an attribute.
        module = obj
        attribute_name = None
    else:
        # Classes/functions: the defining module plus the object's name.
        # NOTE(review): a `KeyError` (missing module) or `AttributeError`
        # (callable without `__name__`) would escape here, not RuntimeError.
        module = sys.modules[obj.__module__]
        attribute_name = obj.__name__  # type: ignore[union-attr]

    # `module.<attribute_name>` must be the very same object; otherwise
    # loading this source later would not give back `obj`.
    if (
        not skip_validation
        and attribute_name
        and getattr(module, attribute_name, None) is not obj
    ):
        raise RuntimeError(
            f"Unable to resolve object `{obj}`. For the resolving to work, the "
            "class or function must be defined as top-level code (= it must "
            "get defined when importing the module) and not inside a function/"
            f"if-condition. Please make sure that your `{module.__name__}` "
            f"module has a top-level attribute `{attribute_name}` that "
            "holds the object you want to resolve."
        )

    module_name = module.__name__
    if module_name == "__main__":
        # Presumably `_resolve_module` derives an importable module name for
        # the main module (see that helper to confirm).
        module_name = _resolve_module(module)

    source_type = get_source_type(module=module)

    if source_type == SourceType.USER:
        from zenml.utils import code_repository_utils

        local_repo_context = (
            code_repository_utils.find_active_code_repository()
        )

        # Only reference the code repository when the checked-out commit
        # fully describes the code (no local changes).
        if local_repo_context and not local_repo_context.has_local_changes:
            module_name = _resolve_module(module)

            source_root = get_source_root()
            # NOTE: `PurePath.relative_to` raises ValueError if the source
            # root is not located inside the repository root.
            subdir = PurePath(source_root).relative_to(local_repo_context.root)

            return CodeRepositorySource(
                repository_id=local_repo_context.code_repository_id,
                commit=local_repo_context.current_commit,
                subdirectory=subdir.as_posix(),
                module=module_name,
                attribute=attribute_name,
            )

        module_name = _resolve_module(module)
    elif source_type == SourceType.DISTRIBUTION_PACKAGE:
        package_name = _get_package_for_module(module_name=module_name)
        if package_name:
            package_version = _get_package_version(package_name=package_name)
            return DistributionPackageSource(
                module=module_name,
                attribute=attribute_name,
                package_name=package_name,
                version=package_version,
                type=source_type,
            )
        else:
            # Fallback to an unknown source if we can't find the package
            source_type = SourceType.UNKNOWN

    return Source(
        module=module_name, attribute=attribute_name, type=source_type
    )

set_custom_source_root(source_root)

Sets a custom source root.

If set, this has the highest priority and will always be used as the source root.

Parameters:

Name Type Description Default
source_root Optional[str]

The source root to use.

required
Source code in zenml/utils/source_utils.py
def set_custom_source_root(source_root: Optional[str]) -> None:
    """Override the source root with a custom value.

    While set, the custom source root takes precedence over every other way
    of determining the source root. Passing `None` removes the override.

    Args:
        source_root: The source root to use, or `None` to clear it.
    """
    global _CUSTOM_SOURCE_ROOT

    logger.debug("Setting custom source root: %s", source_root)
    _CUSTOM_SOURCE_ROOT = source_root

validate_source_class(source, expected_class)

Validates that a source resolves to a certain class.

Parameters:

Name Type Description Default
source Union[zenml.config.source.Source, str]

The source to validate.

required
expected_class Type[Any]

The class that the source should resolve to.

required

Returns:

Type Description
bool

True if the source resolves to the expected class, False otherwise.

Source code in zenml/utils/source_utils.py
def validate_source_class(
    source: Union[Source, str], expected_class: Type[Any]
) -> bool:
    """Check whether a source loads to the expected class or a subclass.

    Args:
        source: The source to validate.
        expected_class: The class that the source should resolve to.

    Returns:
        True if the source resolves to the expected class, False otherwise
        (including when loading the source fails).
    """
    try:
        loaded = load(source)
    except Exception:
        # A source that can't be loaded can't be the expected class.
        return False

    return isinstance(loaded, type) and issubclass(loaded, expected_class)

string_utils

Utils for strings.

b64_decode(input_)

Returns a decoded string of the base 64 encoded input string.

Parameters:

Name Type Description Default
input_ str

Base64 encoded string.

required

Returns:

Type Description
str

Decoded string.

Source code in zenml/utils/string_utils.py
def b64_decode(input_: str) -> str:
    """Decode a base64 encoded string.

    Args:
        input_: Base64 encoded string.

    Returns:
        The decoded string.
    """
    raw = base64.b64decode(input_.encode())
    return raw.decode()

b64_encode(input_)

Returns a base 64 encoded string of the input string.

Parameters:

Name Type Description Default
input_ str

The input to encode.

required

Returns:

Type Description
str

Base64 encoded string.

Source code in zenml/utils/string_utils.py
def b64_encode(input_: str) -> str:
    """Encode a string as base64.

    Args:
        input_: The input to encode.

    Returns:
        The base64 encoded string.
    """
    return base64.b64encode(input_.encode()).decode()

get_human_readable_filesize(bytes_)

Convert a file size in bytes into a human-readable string.

Parameters:

Name Type Description Default
bytes_ int

The number of bytes to convert.

required

Returns:

Type Description
str

A human-readable string.

Source code in zenml/utils/string_utils.py
def get_human_readable_filesize(bytes_: int) -> str:
    """Render a byte count as a short string using binary (1024-based) units.

    The sign of the input is dropped, and GiB is the largest unit used, so
    very large sizes are shown as a large GiB value.

    Args:
        bytes_: The number of bytes to convert.

    Returns:
        A human-readable string.
    """
    units = ("B", "KiB", "MiB", "GiB")
    size = abs(float(bytes_))
    unit_index = 0
    # Scale down until the value fits below 1024 or the last unit is reached.
    while size >= 1024.0 and unit_index < len(units) - 1:
        size /= 1024.0
        unit_index += 1

    return f"{size:.2f} {units[unit_index]}"

get_human_readable_time(seconds)

Convert seconds into a human-readable string.

Parameters:

Name Type Description Default
seconds float

The number of seconds to convert.

required

Returns:

Type Description
str

A human-readable string.

Source code in zenml/utils/string_utils.py
def get_human_readable_time(seconds: float) -> str:
    """Format a duration in seconds as a compact string such as ``1h2m3s``.

    Durations of at least one minute use whole units (days, hours, minutes,
    seconds). Shorter durations are shown with millisecond precision.
    Negative durations are prefixed with ``-``.

    Args:
        seconds: The number of seconds to convert.

    Returns:
        A human-readable string.
    """
    sign = "-" if seconds < 0 else ""
    magnitude = abs(seconds)

    days, remainder = divmod(int(magnitude), 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, secs = divmod(remainder, 60)

    if days:
        body = f"{days}d{hours}h{minutes}m{secs}s"
    elif hours:
        body = f"{hours}h{minutes}m{secs}s"
    elif minutes:
        body = f"{minutes}m{secs}s"
    else:
        # Sub-minute durations keep their fractional part.
        body = f"{magnitude:.3f}s"

    return sign + body

random_str(length)

Generate a random human readable string of given length.

Parameters:

Name Type Description Default
length int

Length of string

required

Returns:

Type Description
str

Random human-readable string.

Source code in zenml/utils/string_utils.py
def random_str(length: int) -> str:
    """Generate a random human readable string of given length.

    Characters are drawn from ``string.ascii_letters`` (``a-z`` and ``A-Z``).

    A private ``random.Random`` instance is used instead of the module-level
    generator. Reseeding the global generator, as a call to
    ``random.seed()`` would, silently breaks reproducibility for any caller
    that seeded ``random`` itself.

    The result is not intended for security-sensitive tokens; use the
    ``secrets`` module for those.

    Args:
        length: Length of string

    Returns:
        Random human-readable string.
    """
    # A fresh Random() is seeded from OS entropy (or the time) on creation,
    # which matches the old `random.seed()` behavior without touching
    # global state.
    rng = random.Random()
    return "".join(rng.choices(string.ascii_letters, k=length))

typed_model

Utility classes for adding type information to Pydantic models.

BaseTypedModel (BaseModel) pydantic-model

Typed Pydantic model base class.

Use this class as a base class instead of BaseModel to automatically add a `type` literal attribute to the model that stores the name of the class.

This can be useful when serializing models to JSON and then de-serializing them as part of a submodel union field, e.g.:


class BluePill(BaseTypedModel):
    ...

class RedPill(BaseTypedModel):
    ...

class TheMatrix(BaseTypedModel):
    choice: Union[BluePill, RedPill] = Field(..., discriminator='type')

matrix = TheMatrix(choice=RedPill())
d = matrix.dict()
new_matrix = TheMatrix.parse_obj(d)
assert isinstance(new_matrix.choice, RedPill)

It can also facilitate de-serializing objects when their type isn't known:

matrix = TheMatrix(choice=RedPill())
d = matrix.dict()
new_matrix = BaseTypedModel.from_dict(d)
assert isinstance(new_matrix.choice, RedPill)
Source code in zenml/utils/typed_model.py
class BaseTypedModel(BaseModel, metaclass=BaseTypedModelMeta):
    """Typed Pydantic model base class.

    Use this class as a base class instead of BaseModel to automatically
    add a `type` literal attribute to the model that stores the name of the
    class.

    This can be useful when serializing models to JSON and then de-serializing
    them as part of a submodel union field, e.g.:

    ```python

    class BluePill(BaseTypedModel):
        ...

    class RedPill(BaseTypedModel):
        ...

    class TheMatrix(BaseTypedModel):
        choice: Union[BluePill, RedPill] = Field(..., discriminator='type')

    matrix = TheMatrix(choice=RedPill())
    d = matrix.dict()
    new_matrix = TheMatrix.parse_obj(d)
    assert isinstance(new_matrix.choice, RedPill)
    ```

    It can also facilitate de-serializing objects when their type isn't known:

    ```python
    matrix = TheMatrix(choice=RedPill())
    d = matrix.dict()
    new_matrix = BaseTypedModel.from_dict(d)
    assert isinstance(new_matrix.choice, RedPill)
    ```
    """

    @classmethod
    def from_dict(
        cls,
        model_dict: Dict[str, Any],
    ) -> "BaseTypedModel":
        """Instantiate a Pydantic model from a serialized JSON-able dict representation.

        The concrete class is resolved from the dict's `type` entry via
        `source_utils.load`, so the returned object may be a subclass of the
        class this method is called on.

        Args:
            model_dict: the model attributes serialized as JSON-able dict.

        Returns:
            A BaseTypedModel created from the serialized representation.

        Raises:
            RuntimeError: if the model_dict contains an invalid type.
        """
        model_type = model_dict.get("type")
        if not model_type:
            raise RuntimeError(
                "`type` information is missing from the serialized model dict."
            )
        model_class = source_utils.load(model_type)
        # `issubclass` raises TypeError for non-class objects (e.g. a function
        # the source happens to resolve to), so check for a class first to
        # keep the documented RuntimeError contract.
        if not (
            isinstance(model_class, type)
            and issubclass(model_class, BaseTypedModel)
        ):
            raise RuntimeError(
                f"Class `{model_class}` is not a ZenML BaseTypedModel subclass."
            )

        return model_class.parse_obj(model_dict)

    @classmethod
    def from_json(
        cls,
        json_str: str,
    ) -> "BaseTypedModel":
        """Instantiate a Pydantic model from a serialized JSON representation.

        Args:
            json_str: the model attributes serialized as JSON.

        Returns:
            A BaseTypedModel created from the serialized representation.
        """
        model_dict = json.loads(json_str)
        return cls.from_dict(model_dict)
from_dict(model_dict) classmethod

Instantiate a Pydantic model from a serialized JSON-able dict representation.

Parameters:

Name Type Description Default
model_dict Dict[str, Any]

the model attributes serialized as JSON-able dict.

required

Returns:

Type Description
BaseTypedModel

A BaseTypedModel created from the serialized representation.

Exceptions:

Type Description
RuntimeError

if the model_dict contains an invalid type.

Source code in zenml/utils/typed_model.py
@classmethod
def from_dict(
    cls,
    model_dict: Dict[str, Any],
) -> "BaseTypedModel":
    """Instantiate a Pydantic model from a serialized JSON-able dict representation.

    The concrete class is resolved from the dict's `type` entry via
    `source_utils.load`, so the returned object may be a subclass of the
    class this method is called on.

    Args:
        model_dict: the model attributes serialized as JSON-able dict.

    Returns:
        A BaseTypedModel created from the serialized representation.

    Raises:
        RuntimeError: if the model_dict contains an invalid type.
    """
    model_type = model_dict.get("type")
    if not model_type:
        raise RuntimeError(
            "`type` information is missing from the serialized model dict."
        )
    model_class = source_utils.load(model_type)
    # `issubclass` raises TypeError for non-class objects, so check for a
    # class first to keep the documented RuntimeError contract.
    if not (
        isinstance(model_class, type)
        and issubclass(model_class, BaseTypedModel)
    ):
        raise RuntimeError(
            f"Class `{model_class}` is not a ZenML BaseTypedModel subclass."
        )

    return model_class.parse_obj(model_dict)
from_json(json_str) classmethod

Instantiate a Pydantic model from a serialized JSON representation.

Parameters:

Name Type Description Default
json_str str

the model attributes serialized as JSON.

required

Returns:

Type Description
BaseTypedModel

A BaseTypedModel created from the serialized representation.

Source code in zenml/utils/typed_model.py
@classmethod
def from_json(
    cls,
    json_str: str,
) -> "BaseTypedModel":
    """Build a typed model from its JSON string form.

    The string is parsed into a dict, and construction is delegated to
    `cls.from_dict`.

    Args:
        json_str: the model attributes serialized as JSON.

    Returns:
        A BaseTypedModel created from the serialized representation.
    """
    return cls.from_dict(json.loads(json_str))

BaseTypedModelMeta (ModelMetaclass)

Metaclass responsible for adding type information to Pydantic models.

Source code in zenml/utils/typed_model.py
class BaseTypedModelMeta(ModelMetaclass):
    """Metaclass responsible for adding type information to Pydantic models."""

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "BaseTypedModelMeta":
        """Creates a Pydantic BaseModel class.

        This includes a hidden attribute that reflects the full class
        identifier.

        Args:
            name: The name of the class.
            bases: The base classes of the class.
            dct: The class dictionary.

        Returns:
            A Pydantic BaseModel class that includes a hidden attribute that
            reflects the full class identifier.

        Raises:
            TypeError: If the class body already defines a `type` attribute,
                which is reserved for the injected type literal.
        """
        if "type" in dct:
            raise TypeError(
                "`type` is a reserved attribute name for BaseTypedModel "
                "subclasses"
            )
        # Fully-qualified identifier (`<module>.<qualname>`) of the class
        # being created; it becomes both the Literal annotation and the
        # default value of the injected `type` field.
        type_name = f"{dct['__module__']}.{dct['__qualname__']}"
        type_annotation = Literal[type_name]  # type: ignore[valid-type]
        type_field = Field(type_name)
        dct.setdefault("__annotations__", dict())["type"] = type_annotation
        dct["type"] = type_field
        cls = cast(
            Type["BaseTypedModel"], super().__new__(mcs, name, bases, dct)
        )
        return cls
__new__(mcs, name, bases, dct) special staticmethod

Creates a Pydantic BaseModel class.

This includes a hidden attribute that reflects the full class identifier.

Parameters:

Name Type Description Default
name str

The name of the class.

required
bases Tuple[Type[Any], ...]

The base classes of the class.

required
dct Dict[str, Any]

The class dictionary.

required

Returns:

Type Description
BaseTypedModelMeta

A Pydantic BaseModel class that includes a hidden attribute that reflects the full class identifier.

Exceptions:

Type Description
TypeError

If the class dictionary already defines a `type` attribute, which is reserved for BaseTypedModel subclasses.

Source code in zenml/utils/typed_model.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "BaseTypedModelMeta":
    """Creates a Pydantic BaseModel class.

    This includes a hidden attribute that reflects the full class
    identifier.

    Args:
        name: The name of the class.
        bases: The base classes of the class.
        dct: The class dictionary.

    Returns:
        A Pydantic BaseModel class that includes a hidden attribute that
        reflects the full class identifier.

    Raises:
        TypeError: If the class body already defines a `type` attribute,
            which is reserved for the injected type literal.
    """
    if "type" in dct:
        raise TypeError(
            "`type` is a reserved attribute name for BaseTypedModel "
            "subclasses"
        )
    # Fully-qualified identifier (`<module>.<qualname>`) of the class being
    # created; it becomes both the Literal annotation and the default value
    # of the injected `type` field.
    type_name = f"{dct['__module__']}.{dct['__qualname__']}"
    type_annotation = Literal[type_name]  # type: ignore[valid-type]
    type_field = Field(type_name)
    dct.setdefault("__annotations__", dict())["type"] = type_annotation
    dct["type"] = type_field
    cls = cast(
        Type["BaseTypedModel"], super().__new__(mcs, name, bases, dct)
    )
    return cls

uuid_utils

Utility functions for handling UUIDs.

generate_uuid_from_string(value)

Deterministically generates a UUID from a string seed.

Parameters:

Name Type Description Default
value str

The string from which to generate the UUID.

required

Returns:

Type Description
UUID

The generated UUID.

Source code in zenml/utils/uuid_utils.py
def generate_uuid_from_string(value: str) -> UUID:
    """Derive a stable UUID from an arbitrary string.

    The UTF-8 bytes of ``value`` are hashed with MD5. The 128-bit digest is
    then stamped with the RFC 4122 variant and version-4 bits, so the same
    input always yields the same UUID.

    Args:
        value: The string from which to generate the UUID.

    Returns:
        The generated UUID.
    """
    digest = hashlib.md5(value.encode("utf-8")).hexdigest()
    return UUID(hex=digest, version=4)

is_valid_uuid(value, version=4)

Checks if a string is a valid UUID.

Parameters:

Name Type Description Default
value Any

String to check.

required
version int

Version of UUID to check for.

4

Returns:

Type Description
bool

True if string is a valid UUID, False otherwise.

Source code in zenml/utils/uuid_utils.py
def is_valid_uuid(value: Any, version: int = 4) -> bool:
    """Checks if a value is a UUID object or a string that parses as one.

    Args:
        value: String to check.
        version: Version of UUID to check for. It is forwarded to the
            `uuid.UUID` constructor. That constructor overwrites the version
            bits rather than rejecting other versions, so any well-formed
            UUID string is accepted.

    Returns:
        True if string is a valid UUID, False otherwise.
    """
    if isinstance(value, UUID):
        return True
    if not isinstance(value, str):
        return False
    try:
        UUID(value, version=version)
    except ValueError:
        return False
    return True

parse_name_or_uuid(name_or_id)

Convert a "name or id" string value to a string or UUID.

Parameters:

Name Type Description Default
name_or_id Optional[str]

Name or id to convert.

required

Returns:

Type Description
Union[str, uuid.UUID]

A UUID if name_or_id is a UUID, string otherwise.

Source code in zenml/utils/uuid_utils.py
def parse_name_or_uuid(
    name_or_id: Optional[str],
) -> Optional[Union[str, UUID]]:
    """Turn a "name or id" value into a UUID when it parses as one.

    Empty or ``None`` inputs are passed through unchanged.

    Args:
        name_or_id: Name or id to convert.

    Returns:
        A UUID if name_or_id is a UUID, string otherwise.
    """
    if not name_or_id:
        return name_or_id
    try:
        return UUID(name_or_id)
    except ValueError:
        # Not a UUID, so treat it as a name.
        return name_or_id

visualization_utils

Utility functions for dashboard visualizations.

format_csv_visualization_as_html(csv_visualization, max_rows=10, max_cols=10)

Formats a CSV visualization as an HTML table.

Parameters:

Name Type Description Default
csv_visualization str

CSV visualization as a string.

required
max_rows int

Maximum number of rows to display. Remaining rows will be replaced by an ellipsis in the middle of the table.

10
max_cols int

Maximum number of columns to display. Remaining columns will be replaced by an ellipsis at the end of each row.

10

Returns:

Type Description
str

HTML table as a string.

Source code in zenml/utils/visualization_utils.py
def format_csv_visualization_as_html(
    csv_visualization: str, max_rows: int = 10, max_cols: int = 10
) -> str:
    """Formats a CSV visualization as an HTML table.

    When the CSV has more than ``max_rows`` lines, only the first and last
    ``max_rows // 2`` lines are rendered, with a single ellipsis row between
    them.

    Args:
        csv_visualization: CSV visualization as a string.
        max_rows: Maximum number of rows to display. Remaining rows will be
            replaced by an ellipsis in the middle of the table.
        max_cols: Maximum number of columns to display. Remaining columns will
            be replaced by an ellipsis at the end of each row.

    Returns:
        HTML table as a string.
    """
    rows = csv_visualization.splitlines()

    if len(rows) <= max_rows:
        # Few enough rows: render all of them.
        parts = [
            _format_csv_row_as_html(row, max_cols=max_cols) for row in rows
        ]
    else:
        # Negative limits are treated like 0 (ellipsis only).
        half_max_rows = max(max_rows, 0) // 2
        head = rows[:half_max_rows]
        # Use an explicit start index: `rows[-0:]` would return *every* row
        # when `max_rows` is 0 or 1.
        tail = rows[len(rows) - half_max_rows :]

        parts = [
            _format_csv_row_as_html(row, max_cols=max_cols) for row in head
        ]
        parts.append("<tr><td>...</td></tr>")
        parts.extend(
            _format_csv_row_as_html(row, max_cols=max_cols) for row in tail
        )

    return "<table>" + "".join(parts) + "</table>"

yaml_utils

Utility functions to help with YAML files and data.

UUIDEncoder (JSONEncoder)

JSON encoder for UUID objects.

Source code in zenml/utils/yaml_utils.py
class UUIDEncoder(json.JSONEncoder):
    """JSON encoder that serializes `uuid.UUID` values as hex strings."""

    def default(self, obj: Any) -> Any:
        """Encode UUIDs, deferring everything else to the base encoder.

        Args:
            obj: Object to encode.

        Returns:
            Encoded object.
        """
        if not isinstance(obj, UUID):
            # The base implementation raises TypeError for unsupported types.
            return json.JSONEncoder.default(self, obj)
        # 32-character hex form, without dashes.
        return obj.hex
default(self, obj)

Default UUID encoder for JSON.

Parameters:

Name Type Description Default
obj Any

Object to encode.

required

Returns:

Type Description
Any

Encoded object.

Source code in zenml/utils/yaml_utils.py
def default(self, obj: Any) -> Any:
    """Encode UUIDs, deferring everything else to the base encoder.

    Args:
        obj: Object to encode.

    Returns:
        Encoded object.
    """
    if not isinstance(obj, UUID):
        # The base implementation raises TypeError for unsupported types.
        return json.JSONEncoder.default(self, obj)
    # 32-character hex form, without dashes.
    return obj.hex

append_yaml(file_path, contents)

Append contents to a YAML file at file_path.

Parameters:

Name Type Description Default
file_path str

Path to YAML file.

required
contents Dict[Any, Any]

Contents of YAML file as dict.

required

Exceptions:

Type Description
FileNotFoundError

if directory does not exist.

Source code in zenml/utils/yaml_utils.py
def append_yaml(file_path: str, contents: Dict[Any, Any]) -> None:
    """Merge `contents` into the YAML mapping stored at `file_path`.

    The file is read with `read_yaml`, and an empty document counts as an
    empty mapping. It is then updated with `contents`, so keys in `contents`
    win, and written back in full.

    Args:
        file_path: Path to YAML file.
        contents: Contents of YAML file as dict.

    Raises:
        FileNotFoundError: if directory does not exist. `read_yaml` also
            raises it when the file itself does not exist yet.
    """
    merged = read_yaml(file_path) or {}
    merged.update(contents)

    if not io_utils.is_remote(file_path):
        parent_dir = str(Path(file_path).parent)
        if not fileio.isdir(parent_dir):
            raise FileNotFoundError(f"Directory {parent_dir} does not exist.")

    io_utils.write_file_contents_as_string(file_path, yaml.dump(merged))

comment_out_yaml(yaml_string)

Comments out a yaml string.

Parameters:

Name Type Description Default
yaml_string str

The yaml string to comment out.

required

Returns:

Type Description
str

The commented out yaml string.

Source code in zenml/utils/yaml_utils.py
def comment_out_yaml(yaml_string: str) -> str:
    """Prefix every line of a YAML string with ``# ``.

    Original line endings are kept intact.

    Args:
        yaml_string: The yaml string to comment out.

    Returns:
        The commented out yaml string.
    """
    return "".join(
        f"# {line}" for line in yaml_string.splitlines(keepends=True)
    )

is_json_serializable(obj)

Checks whether an object is JSON serializable.

Parameters:

Name Type Description Default
obj Any

The object to check.

required

Returns:

Type Description
bool

Whether the object is JSON serializable using pydantic's `pydantic_encoder` as the fallback encoder.

Source code in zenml/utils/yaml_utils.py
def is_json_serializable(obj: Any) -> bool:
    """Checks whether an object is JSON serializable.

    Serialization is attempted with `json.dumps`, using pydantic's
    `pydantic_encoder` as the fallback for non-native types.

    Args:
        obj: The object to check.

    Returns:
        Whether the object is JSON serializable using pydantics encoder class.
    """
    from pydantic.json import pydantic_encoder

    try:
        json.dumps(obj, default=pydantic_encoder)
    except (TypeError, ValueError, RecursionError):
        # TypeError: no encoder for some value.
        # ValueError: circular reference detected by `json.dumps`.
        # RecursionError: structure nested too deeply to encode.
        return False
    return True

is_yaml(file_path)

Returns True if file_path ends with a YAML suffix (yaml or yml), else False.

Parameters:

Name Type Description Default
file_path str

Path to YAML file.

required

Returns:

Type Description
bool

True if is yaml, else False.

Source code in zenml/utils/yaml_utils.py
def is_yaml(file_path: str) -> bool:
    """Tell whether `file_path` looks like a YAML file, judging by its suffix.

    This is a plain suffix match against ``yaml`` and ``yml``; no leading
    dot is required and the file is not opened.

    Args:
        file_path: Path to YAML file.

    Returns:
        True if is yaml, else False.
    """
    return file_path.endswith(("yaml", "yml"))

read_json(file_path)

Read the JSON file at file_path and return its parsed contents.

Parameters:

Name Type Description Default
file_path str

Path to JSON file.

required

Returns:

Type Description
Any

Contents of the file in a dict.

Exceptions:

Type Description
FileNotFoundError

if file does not exist.

Source code in zenml/utils/yaml_utils.py
def read_json(file_path: str) -> Any:
    """Load and parse the JSON document stored at `file_path`.

    Args:
        file_path: Path to JSON file.

    Returns:
        Contents of the file in a dict.

    Raises:
        FileNotFoundError: if file does not exist.
    """
    if not fileio.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")

    raw_text = io_utils.read_file_contents_as_string(file_path)
    return json.loads(raw_text)

read_yaml(file_path)

Read the YAML file at file_path and return its parsed contents.

Parameters:

Name Type Description Default
file_path str

Path to YAML file.

required

Returns:

Type Description
Any

Contents of the file in a dict.

Exceptions:

Type Description
FileNotFoundError

if file does not exist.

Source code in zenml/utils/yaml_utils.py
def read_yaml(file_path: str) -> Any:
    """Load and parse the YAML document stored at `file_path`.

    Parsing uses `yaml.safe_load`, so an empty document yields ``None``.

    Args:
        file_path: Path to YAML file.

    Returns:
        Contents of the file in a dict.

    Raises:
        FileNotFoundError: if file does not exist.
    """
    if not fileio.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")

    raw_text = io_utils.read_file_contents_as_string(file_path)
    # TODO: [LOW] consider adding a default empty dict to be returned
    #   instead of None
    return yaml.safe_load(raw_text)

write_json(file_path, contents, encoder=None)

Write contents as JSON format to file_path.

Parameters:

Name Type Description Default
file_path str

Path to JSON file.

required
contents Any

Contents of JSON file.

required
encoder Optional[Type[json.encoder.JSONEncoder]]

Custom JSON encoder to use when saving json.

None

Exceptions:

Type Description
FileNotFoundError

if directory does not exist.

Source code in zenml/utils/yaml_utils.py
def write_json(
    file_path: str,
    contents: Any,
    encoder: Optional[Type[json.JSONEncoder]] = None,
) -> None:
    """Serialize `contents` as JSON and write it to `file_path`.

    The parent-directory check is only performed for local (non-remote)
    paths.

    Args:
        file_path: Path to JSON file.
        contents: Contents of JSON file.
        encoder: Custom JSON encoder to use when saving json.

    Raises:
        FileNotFoundError: if directory does not exist.
    """
    if not io_utils.is_remote(file_path):
        parent_dir = str(Path(file_path).parent)
        if not fileio.isdir(parent_dir):
            raise FileNotFoundError(f"Directory {parent_dir} does not exist.")

    serialized = json.dumps(contents, cls=encoder)
    io_utils.write_file_contents_as_string(file_path, serialized)

write_yaml(file_path, contents, sort_keys=True)

Write contents as YAML format to file_path.

Parameters:

Name Type Description Default
file_path str

Path to YAML file.

required
contents Union[Dict[Any, Any], List[Any]]

Contents of YAML file as dict or list.

required
sort_keys bool

If True, keys are sorted alphabetically. If False, the order in which the keys were inserted into the dict will be preserved.

True

Exceptions:

Type Description
FileNotFoundError

if directory does not exist.

Source code in zenml/utils/yaml_utils.py
def write_yaml(
    file_path: str,
    contents: Union[Dict[Any, Any], List[Any]],
    sort_keys: bool = True,
) -> None:
    """Serialize `contents` as YAML and write it to `file_path`.

    The parent-directory check is only performed for local (non-remote)
    paths.

    Args:
        file_path: Path to YAML file.
        contents: Contents of YAML file as dict or list.
        sort_keys: If `True`, keys are sorted alphabetically. If `False`,
            the order in which the keys were inserted into the dict will
            be preserved.

    Raises:
        FileNotFoundError: if directory does not exist.
    """
    if not io_utils.is_remote(file_path):
        parent_dir = str(Path(file_path).parent)
        if not fileio.isdir(parent_dir):
            raise FileNotFoundError(f"Directory {parent_dir} does not exist.")

    serialized = yaml.dump(contents, sort_keys=sort_keys)
    io_utils.write_file_contents_as_string(file_path, serialized)