Skip to content

Orchestrators

zenml.orchestrators special

An orchestrator is a special kind of backend that manages the running of each step of the pipeline. Orchestrators administer the actual pipeline runs. You can think of it as the 'root' of any pipeline job that you run during your experimentation.

ZenML supports a local orchestrator out of the box which allows you to run your pipelines in a local environment. We also support using Apache Airflow as the orchestrator to handle the steps of your pipeline.

base_orchestrator

BaseOrchestrator (StackComponent, ABC) pydantic-model

Base class for all ZenML orchestrators.

Source code in zenml/orchestrators/base_orchestrator.py
class BaseOrchestrator(StackComponent, ABC):
    """Abstract base class that every ZenML orchestrator derives from.

    Concrete orchestrators must provide a ``flavor`` and implement
    ``run_pipeline``; the component ``type`` is fixed for all of them.
    """

    @property
    def type(self) -> StackComponentType:
        """Always ``StackComponentType.ORCHESTRATOR`` for orchestrators."""
        return StackComponentType.ORCHESTRATOR

    @property
    @abstractmethod
    def flavor(self) -> OrchestratorFlavor:
        """The flavor of this orchestrator, supplied by each subclass."""

    @abstractmethod
    def run_pipeline(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Executes ``pipeline`` on ``stack``; overridden by subclasses.

        Args:
            pipeline: The pipeline to run.
            stack: The stack on which the pipeline is run.
            runtime_configuration: Runtime configuration of the pipeline run.

        Returns:
            Implementation-defined; the declared return type is ``Any``.
        """
flavor: OrchestratorFlavor property readonly

The orchestrator flavor.

type: StackComponentType property readonly

The component type.

run_pipeline(self, pipeline, stack, runtime_configuration)

Runs a pipeline.

Parameters:

Name Type Description Default
pipeline BasePipeline

The pipeline to run.

required
stack Stack

The stack on which the pipeline is run.

required
runtime_configuration RuntimeConfiguration

Runtime configuration of the pipeline run.

required
Source code in zenml/orchestrators/base_orchestrator.py
@abstractmethod
def run_pipeline(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Executes ``pipeline`` on ``stack``; must be overridden by subclasses.

    Args:
        pipeline: The pipeline to run.
        stack: The stack on which the pipeline is run.
        runtime_configuration: Runtime configuration of the pipeline run.

    Returns:
        Implementation-defined; the declared return type is ``Any``.
    """

context_utils

add_pydantic_object_as_metadata_context(obj, context)

Adds a pydantic object to the context of a pipeline node as a metadata context.

Parameters:

Name Type Description Default
obj BaseModel

an instance of a pydantic object

required
context pipeline_pb2.ContextSpec

a context proto message within a pipeline node

required
Source code in zenml/orchestrators/context_utils.py
def add_pydantic_object_as_metadata_context(
    obj: "BaseModel",
    context: "pipeline_pb2.ContextSpec",  # type: ignore[valid-type]
) -> None:
    """Adds a pydantic object to the context of a pipeline node.

    The context type is the lowercased model name, the context name is a
    deterministic digest of the object's sorted JSON representation, and each
    field of the object is stored as a context property.

    Args:
        obj: an instance of a pydantic object
        context: a context proto message within a pipeline node
    """
    import hashlib

    context.type.name = (  # type: ignore[attr-defined]
        obj.__repr_name__().lower()
    )
    # Setting the name of the context. The builtin `hash()` of a `str` is
    # salted per interpreter process (PYTHONHASHSEED), so equal objects would
    # get different context names in different runs; a sha256 digest is
    # stable across processes.
    serialized = obj.json(sort_keys=True)
    name = hashlib.sha256(serialized.encode("utf-8")).hexdigest()
    context.name.field_value.string_value = name  # type:ignore[attr-defined]

    # Setting the properties of the context
    for k, v in obj.dict().items():
        c_property = context.properties[k]  # type:ignore[attr-defined]
        # NOTE: `bool` is a subclass of `int`, so booleans are stored as ints.
        if isinstance(v, int):
            c_property.field_value.int_value = v
        elif isinstance(v, float):
            c_property.field_value.double_value = v
        elif isinstance(v, str):
            c_property.field_value.string_value = v
        else:
            # Anything else (None, lists, nested models, ...) is stringified.
            c_property.field_value.string_value = str(v)

add_stack_as_metadata_context(stack, context)

Adds the given stack to the context of a pipeline node as a metadata context.

Parameters:

Name Type Description Default
stack Stack

an instance of a ZenML Stack object

required
context pipeline_pb2.ContextSpec

a context proto message within a pipeline node

required
Source code in zenml/orchestrators/context_utils.py
def add_stack_as_metadata_context(
    stack: "Stack",
    context: "pipeline_pb2.ContextSpec",  # type: ignore[valid-type]
) -> None:
    """Adds a stack to the context of a pipeline node as a metadata context.

    The context name is a deterministic digest of the stack's sorted JSON
    representation, and every entry of ``stack.dict()`` becomes a string
    context property.

    Args:
        stack: an instance of a ZenML Stack object
        context: a context proto message within a pipeline node
    """
    import hashlib

    # Adding the type of context
    context.type.name = (  # type:ignore[attr-defined]
        MetadataContextTypes.STACK.value
    )

    # Converting the stack into a dict to prepare for hashing
    stack_dict = stack.dict()

    # Setting the name of the context. The builtin `hash()` of a `str` is
    # salted per interpreter process (PYTHONHASHSEED), so the same stack would
    # get a different context name in every run; a sha256 digest is stable.
    serialized = json.dumps(stack_dict, sort_keys=True)
    name = hashlib.sha256(serialized.encode("utf-8")).hexdigest()
    context.name.field_value.string_value = name  # type:ignore[attr-defined]

    # Setting the properties of the context
    for k, v in stack_dict.items():
        c_property = context.properties[k]  # type:ignore[attr-defined]
        # Protobuf string fields reject non-str values; coerce explicitly,
        # matching add_pydantic_object_as_metadata_context. (Values are
        # presumably component names already — str() is then a no-op.)
        c_property.field_value.string_value = str(v)

local special

local_orchestrator

LocalOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator responsible for running pipelines locally.

Source code in zenml/orchestrators/local/local_orchestrator.py
class LocalOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines locally."""

    supports_local_execution = True
    supports_remote_execution = False

    @property
    def flavor(self) -> OrchestratorFlavor:
        """The orchestrator flavor."""
        return OrchestratorFlavor.LOCAL

    def run_pipeline(
        self,
        pipeline_proto: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Runs a pipeline locally.

        The ZenML pipeline is converted into a TFX pipeline, compiled into its
        IR proto, and then every node is launched one after another.

        Args:
            pipeline_proto: The ZenML pipeline to run. Despite the name this
                is a `BasePipeline`, not a proto; the name is kept so keyword
                callers keep working.
            stack: The stack on which the pipeline is run. Its metadata store
                is used for the MLMD connection of every launched step.
            runtime_configuration: Runtime configuration of the pipeline run.
                If `None`, an empty `RuntimeConfiguration` is used.

        Returns:
            Nothing (`None`); the `Any` annotation mirrors the base class.
        """
        tfx_pipeline = create_tfx_pipeline(pipeline_proto, stack=stack)

        if runtime_configuration is None:
            runtime_configuration = RuntimeConfiguration()

        if runtime_configuration.schedule:
            logger.warning(
                "Local Orchestrator currently does not support the "
                "use of schedules. The `schedule` will be ignored "
                "and the pipeline will be run directly"
            )

        for component in tfx_pipeline.components:
            if isinstance(component, base_component.BaseComponent):
                component._resolve_pip_dependencies(
                    tfx_pipeline.pipeline_info.pipeline_root
                )

        # Compile into the IR proto. A separate name is used so the
        # `pipeline_proto` argument is not silently rebound to a new type.
        compiled_pipeline = compiler.Compiler().compile(tfx_pipeline)

        # Substitute the runtime parameter to be a concrete run_id
        runtime_parameter_utils.substitute_runtime_parameter(
            compiled_pipeline,
            {PIPELINE_RUN_ID_PARAMETER_NAME: runtime_configuration.run_name},
        )

        deployment_config = runner_utils.extract_local_deployment_config(
            compiled_pipeline
        )
        # Use the metadata store of the stack we were asked to run on (the
        # same store `create_tfx_pipeline` configured for the TFX pipeline),
        # not whichever stack happens to be active in the repository.
        connection_config = stack.metadata_store.get_tfx_metadata_config()

        logger.debug("Using deployment config:\n %s", deployment_config)
        logger.debug("Using connection config:\n %s", connection_config)

        # Pipeline-level info is identical for every node; read it once.
        p_info = compiled_pipeline.pipeline_info
        r_spec = compiled_pipeline.runtime_spec

        # Run each component. Note that the pipeline.components list is in
        # topological order.
        for node in compiled_pipeline.nodes:
            pipeline_node = node.pipeline_node

            context = pipeline_node.contexts.contexts.add()
            context_utils.add_stack_as_metadata_context(
                context=context, stack=stack
            )

            # Add all pydantic objects from runtime_configuration to the
            # context
            for key, value in runtime_configuration.items():
                if value and isinstance(value, BaseModel):
                    context = pipeline_node.contexts.contexts.add()
                    logger.debug("Adding %s to context", key)
                    context_utils.add_pydantic_object_as_metadata_context(
                        context=context, obj=value
                    )

            node_id = pipeline_node.node_info.id
            executor_spec = runner_utils.extract_executor_spec(
                deployment_config, node_id
            )
            custom_driver_spec = runner_utils.extract_custom_driver_spec(
                deployment_config, node_id
            )

            component_launcher = launcher.Launcher(
                pipeline_node=pipeline_node,
                mlmd_connection=metadata.Metadata(connection_config),
                pipeline_info=p_info,
                pipeline_runtime_spec=r_spec,
                executor_spec=executor_spec,
                custom_driver_spec=custom_driver_spec,
            )
            execute_step(component_launcher)
flavor: OrchestratorFlavor property readonly

The orchestrator flavor.

run_pipeline(self, pipeline_proto, stack, runtime_configuration)

Runs a pipeline locally

Source code in zenml/orchestrators/local/local_orchestrator.py
def run_pipeline(
    self,
    pipeline_proto: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Runs a pipeline locally.

    The ZenML pipeline is converted into a TFX pipeline, compiled into its
    IR proto, and then every node is launched one after another.

    Args:
        pipeline_proto: The ZenML pipeline to run. Despite the name this is a
            `BasePipeline`, not a proto; the name is kept so keyword callers
            keep working.
        stack: The stack on which the pipeline is run. Its metadata store is
            used for the MLMD connection of every launched step.
        runtime_configuration: Runtime configuration of the pipeline run.
            If `None`, an empty `RuntimeConfiguration` is used.

    Returns:
        Nothing (`None`); the `Any` annotation mirrors the base class.
    """
    tfx_pipeline = create_tfx_pipeline(pipeline_proto, stack=stack)

    if runtime_configuration is None:
        runtime_configuration = RuntimeConfiguration()

    if runtime_configuration.schedule:
        logger.warning(
            "Local Orchestrator currently does not support the "
            "use of schedules. The `schedule` will be ignored "
            "and the pipeline will be run directly"
        )

    for component in tfx_pipeline.components:
        if isinstance(component, base_component.BaseComponent):
            component._resolve_pip_dependencies(
                tfx_pipeline.pipeline_info.pipeline_root
            )

    # Compile into the IR proto. A separate name is used so the
    # `pipeline_proto` argument is not silently rebound to a new type.
    compiled_pipeline = compiler.Compiler().compile(tfx_pipeline)

    # Substitute the runtime parameter to be a concrete run_id
    runtime_parameter_utils.substitute_runtime_parameter(
        compiled_pipeline,
        {PIPELINE_RUN_ID_PARAMETER_NAME: runtime_configuration.run_name},
    )

    deployment_config = runner_utils.extract_local_deployment_config(
        compiled_pipeline
    )
    # Use the metadata store of the stack we were asked to run on (the same
    # store `create_tfx_pipeline` configured for the TFX pipeline), not
    # whichever stack happens to be active in the repository.
    connection_config = stack.metadata_store.get_tfx_metadata_config()

    logger.debug("Using deployment config:\n %s", deployment_config)
    logger.debug("Using connection config:\n %s", connection_config)

    # Pipeline-level info is identical for every node; read it once.
    p_info = compiled_pipeline.pipeline_info
    r_spec = compiled_pipeline.runtime_spec

    # Run each component. Note that the pipeline.components list is in
    # topological order.
    for node in compiled_pipeline.nodes:
        pipeline_node = node.pipeline_node

        context = pipeline_node.contexts.contexts.add()
        context_utils.add_stack_as_metadata_context(
            context=context, stack=stack
        )

        # Add all pydantic objects from runtime_configuration to the
        # context
        for key, value in runtime_configuration.items():
            if value and isinstance(value, BaseModel):
                context = pipeline_node.contexts.contexts.add()
                logger.debug("Adding %s to context", key)
                context_utils.add_pydantic_object_as_metadata_context(
                    context=context, obj=value
                )

        node_id = pipeline_node.node_info.id
        executor_spec = runner_utils.extract_executor_spec(
            deployment_config, node_id
        )
        custom_driver_spec = runner_utils.extract_custom_driver_spec(
            deployment_config, node_id
        )

        component_launcher = launcher.Launcher(
            pipeline_node=pipeline_node,
            mlmd_connection=metadata.Metadata(connection_config),
            pipeline_info=p_info,
            pipeline_runtime_spec=r_spec,
            executor_spec=executor_spec,
            custom_driver_spec=custom_driver_spec,
        )
        execute_step(component_launcher)

utils

create_tfx_pipeline(zenml_pipeline, stack)

Creates a tfx pipeline from a ZenML pipeline.

Source code in zenml/orchestrators/utils.py
def create_tfx_pipeline(
    zenml_pipeline: "BasePipeline", stack: "Stack"
) -> tfx_pipeline.Pipeline:
    """Converts a ZenML pipeline into an equivalent tfx pipeline.

    Args:
        zenml_pipeline: The ZenML pipeline whose steps become tfx components.
        stack: Stack supplying the artifact store (used as pipeline root) and
            the metadata store (used for the MLMD connection config).

    Returns:
        A tfx ``Pipeline`` with one component per ZenML step, in the order in
        which ``zenml_pipeline.steps`` yields them.
    """
    # Wire up the inputs/outputs of every step before collecting components.
    zenml_pipeline.connect(**zenml_pipeline.steps)

    components = []
    for zenml_step in zenml_pipeline.steps.values():
        components.append(zenml_step.component)

    artifact_store, metadata_store = stack.artifact_store, stack.metadata_store

    return tfx_pipeline.Pipeline(
        pipeline_name=zenml_pipeline.name,
        components=components,  # type: ignore[arg-type]
        pipeline_root=artifact_store.path,
        metadata_connection_config=metadata_store.get_tfx_metadata_config(),
        enable_cache=zenml_pipeline.enable_cache,
    )

execute_step(tfx_launcher)

Executes a tfx component.

Parameters:

Name Type Description Default
tfx_launcher Launcher

A tfx launcher to execute the component.

required

Returns:

Type Description
Optional[tfx.orchestration.portable.data_types.ExecutionInfo]

Optional execution info returned by the launcher.

Source code in zenml/orchestrators/utils.py
def execute_step(
    tfx_launcher: launcher.Launcher,
) -> Optional[data_types.ExecutionInfo]:
    """Executes a tfx component.

    Args:
        tfx_launcher: A tfx launcher to execute the component.

    Returns:
        Optional execution info returned by the launcher.

    Raises:
        DuplicateRunNameError: If the launcher reports that the execution has
            already succeeded, i.e. a pipeline run with this name already
            exists.
        RuntimeError: Any other runtime error raised by the launcher.
    """
    # Reads a private attribute of the tfx Launcher to get the node id; this
    # may break on tfx upgrades.
    step_name = tfx_launcher._pipeline_node.node_info.id  # type: ignore[attr-defined] # noqa
    # Monotonic clock: the measured duration is immune to wall-clock jumps.
    start_time = time.monotonic()
    logger.info("Step `%s` has started.", step_name)
    try:
        execution_info = tfx_launcher.launch()
    except RuntimeError as e:
        if "execution has already succeeded" in str(e):
            # Hacky workaround to catch the error that a pipeline run with
            # this name already exists. Raise an error with a more descriptive
            # message instead, keeping the original error as the cause.
            raise DuplicateRunNameError() from e
        raise

    run_duration = time.monotonic() - start_time
    logger.info(
        "Step `%s` has finished in %s.",
        step_name,
        string_utils.get_human_readable_time(run_duration),
    )
    return execution_info