Skip to content

Orchestrators

zenml.orchestrators special

Orchestrator

An orchestrator is a special kind of backend that manages the running of each step of the pipeline. Orchestrators administer the actual pipeline runs. You can think of an orchestrator as the 'root' of any pipeline job that you run during your experimentation.

ZenML supports a local orchestrator out of the box which allows you to run your pipelines in a local environment. We also support using Apache Airflow as the orchestrator to handle the steps of your pipeline.

base_orchestrator

BaseOrchestrator (StackComponent, ABC) pydantic-model

Base class for all ZenML orchestrators.

Source code in zenml/orchestrators/base_orchestrator.py
class BaseOrchestrator(StackComponent, ABC):
    """Base class for all ZenML orchestrators.

    Concrete orchestrators subclass this class and implement
    `run_pipeline`.
    """

    # Class Configuration
    # Identifies this stack component as being of the orchestrator type.
    TYPE: ClassVar[StackComponentType] = StackComponentType.ORCHESTRATOR

    @abstractmethod
    def run_pipeline(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Runs a pipeline.

        Args:
            pipeline: The pipeline to run.
            stack: The stack on which the pipeline is run.
            runtime_configuration: Runtime configuration of the pipeline run.

        Returns:
            An implementation-defined value; the base class only annotates
            it as `Any`.
        """
run_pipeline(self, pipeline, stack, runtime_configuration)

Runs a pipeline.

Parameters:

Name Type Description Default
pipeline BasePipeline

The pipeline to run.

required
stack Stack

The stack on which the pipeline is run.

required
runtime_configuration RuntimeConfiguration

Runtime configuration of the pipeline run.

required
Source code in zenml/orchestrators/base_orchestrator.py
@abstractmethod
def run_pipeline(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Runs a pipeline.

    This method is abstract: every concrete orchestrator has to provide
    its own implementation.

    Args:
        pipeline: The pipeline to run.
        stack: The stack on which the pipeline is run.
        runtime_configuration: Runtime configuration of the pipeline run.

    Returns:
        An implementation-defined value; the signature only annotates it
        as `Any`.
    """

context_utils

add_context_to_node(pipeline_node, type_, name, properties)

Add a new context to a TFX protobuf pipeline node.

Parameters:

Name Type Description Default
pipeline_node pipeline_pb2.PipelineNode

A tfx protobuf pipeline node

required
type_ str

The type name for the context to be added

required
name str

Unique key for the context

required
properties Dict[str, str]

dictionary of strings as properties of the context

required
Source code in zenml/orchestrators/context_utils.py
def add_context_to_node(
    pipeline_node: "pipeline_pb2.PipelineNode",
    type_: str,
    name: str,
    properties: Dict[str, str],
) -> None:
    """
    Attach a new context entry to a TFX protobuf pipeline node.

    Args:
        pipeline_node: The tfx protobuf pipeline node that receives the
            context
        type_: Type name of the context that is added
        name: Unique key of the context
        properties: String-to-string mapping that is stored as the
            properties of the context
    """
    # Create the new context entry on the node
    ctx: "pipeline_pb2.ContextSpec" = pipeline_node.contexts.contexts.add()

    # Type first, then the name of the context
    ctx.type.name = type_
    ctx.name.field_value.string_value = name

    # Each property is stored as a string-valued field under its own key
    for prop_key in properties:
        ctx.properties[prop_key].field_value.string_value = properties[
            prop_key
        ]

add_runtime_configuration_to_node(pipeline_node, runtime_config)

Add the runtime configuration of a pipeline run to a protobuf pipeline node.

Parameters:

Name Type Description Default
pipeline_node pipeline_pb2.PipelineNode

a tfx protobuf pipeline node

required
runtime_config RuntimeConfiguration

a ZenML RuntimeConfiguration

required
Source code in zenml/orchestrators/context_utils.py
def add_runtime_configuration_to_node(
    pipeline_node: "pipeline_pb2.PipelineNode",
    runtime_config: RuntimeConfiguration,
) -> None:
    """
    Add the runtime configuration of a pipeline run to a protobuf pipeline node.

    Every value of `runtime_config` that is a pydantic `BaseModel` is
    serialized and attached to the node as a context of its own. Values of
    any other type are left out.

    Args:
        pipeline_node: a tfx protobuf pipeline node
        runtime_config: a ZenML RuntimeConfiguration. Its optional
            `ignore_unserializable_fields` entry (default `False`) decides
            whether fields that cannot be serialized are skipped instead of
            raising.
    """
    skip_errors: bool = runtime_config.get(
        "ignore_unserializable_fields", False
    )

    # Determine the name of the context
    def _name(obj: "BaseModel") -> str:
        """Compute a context name for a pydantic BaseModel.

        The name is the hash of the sorted json representation of the
        object; if the object cannot be converted to json, the class name
        combined with a fresh uuid is used instead.
        """
        try:
            # NOTE(review): the built-in `hash` of a string is salted per
            # interpreter process unless PYTHONHASHSEED is fixed, so this
            # name is presumably not stable across runs - confirm that this
            # is acceptable for the metadata store.
            return str(hash(obj.json(sort_keys=True)))
        except TypeError as e:
            class_name = obj.__class__.__name__
            # Use the module logger (as the loop below does) rather than
            # the root `logging` module.
            logger.info(
                "Cannot convert %s to json, generating uuid instead. Error: %s",
                class_name,
                e,
            )
            return f"{class_name}_{uuid.uuid1()}"

    # iterate over all attributes of runtime context, serializing all pydantic
    # objects to node context.
    for key, obj in runtime_config.items():
        if isinstance(obj, BaseModel):
            logger.debug("Adding %s to context", key)
            add_context_to_node(
                pipeline_node,
                type_=obj.__repr_name__().lower(),
                name=_name(obj),
                properties=serialize_pydantic_object(
                    obj, skip_errors=skip_errors
                ),
            )

serialize_pydantic_object(obj, *, skip_errors=False)

Convert a pydantic object to a dict of strings

Source code in zenml/orchestrators/context_utils.py
def serialize_pydantic_object(
    obj: BaseModel, *, skip_errors: bool = False
) -> Dict[str, str]:
    """Convert a pydantic object to a dict of strings.

    Each field of `obj.dict()` is dumped to a json string on its own.

    Args:
        obj: The pydantic object to serialize.
        skip_errors: If `True`, fields that cannot be serialized are logged
            and left out of the result instead of raising.

    Returns:
        A dictionary that maps every serialized field name to its json
        string.

    Raises:
        TypeError: If a field cannot be serialized and `skip_errors` is
            `False`.
    """

    class PydanticEncoder(json.JSONEncoder):
        """Json encoder that tries the json encoder of `obj` first."""

        def default(self, o: Any) -> Any:
            try:
                return cast(Callable[[Any], str], obj.__json_encoder__)(o)
            except TypeError:
                # Fall back to the stock behaviour of `json.JSONEncoder`.
                return super().default(o)

    def _inner_generator(
        dictionary: Dict[str, Any]
    ) -> Iterator[Tuple[str, str]]:
        """Itemwise serialize each element in a dictionary."""
        for key, item in dictionary.items():
            try:
                yield key, json.dumps(item, cls=PydanticEncoder)
            except TypeError as e:
                if skip_errors:
                    # Use the module logger instead of the root `logging`
                    # module so the message follows the module's logging
                    # configuration.
                    logger.info(
                        "Skipping adding field '%s' to metadata context as "
                        "it cannot be serialized due to %s.",
                        key,
                        e,
                    )
                else:
                    raise TypeError(
                        f"Invalid type {type(item)} for key {key} can not be "
                        "serialized."
                    ) from e

    return dict(_inner_generator(obj.dict()))

local special

local_orchestrator

LocalOrchestrator (BaseOrchestrator) pydantic-model

Orchestrator responsible for running pipelines locally.

Source code in zenml/orchestrators/local/local_orchestrator.py
class LocalOrchestrator(BaseOrchestrator):
    """Orchestrator responsible for running pipelines locally."""

    # Class Configuration
    FLAVOR: ClassVar[str] = "local"

    def run_pipeline(
        self,
        pipeline: "BasePipeline",
        stack: "Stack",
        runtime_configuration: "RuntimeConfiguration",
    ) -> Any:
        """Runs a pipeline locally.

        The ZenML pipeline is converted to a tfx pipeline and compiled;
        afterwards every node of the compiled pipeline is enriched with
        metadata contexts and executed one after another.

        Args:
            pipeline: The pipeline to run.
            stack: The stack on which the pipeline is run.
            runtime_configuration: Runtime configuration of the pipeline
                run. A default `RuntimeConfiguration` is created if `None`
                is passed. A configured `schedule` is ignored with a
                warning.

        Raises:
            TypeError: If the pipeline root of the tfx pipeline is not a
                string.
        """

        tfx_pipeline: TfxPipeline = create_tfx_pipeline(pipeline, stack=stack)

        if runtime_configuration is None:
            runtime_configuration = RuntimeConfiguration()

        if runtime_configuration.schedule:
            logger.warning(
                "Local Orchestrator currently does not support the "
                "use of schedules. The `schedule` will be ignored "
                "and the pipeline will be run directly"
            )

        pipeline_root = tfx_pipeline.pipeline_info.pipeline_root
        if not isinstance(pipeline_root, str):
            raise TypeError(
                "TFX Pipeline root may not be a Placeholder, "
                "but must be a specific string."
            )

        for component in tfx_pipeline.components:
            if isinstance(component, base_component.BaseComponent):
                component._resolve_pip_dependencies(pipeline_root)

        pb2_pipeline: Pb2Pipeline = Compiler().compile(tfx_pipeline)

        # Substitute the runtime parameter to be a concrete run_id
        runtime_parameter_utils.substitute_runtime_parameter(
            pb2_pipeline,
            {
                PIPELINE_RUN_ID_PARAMETER_NAME: runtime_configuration.run_name,
            },
        )

        deployment_config = runner_utils.extract_local_deployment_config(
            pb2_pipeline
        )
        # NOTE(review): the metadata connection is read from the active
        # stack of the repository, not from the `stack` argument - confirm
        # that both always refer to the same stack.
        connection_config = (
            Repository().active_stack.metadata_store.get_tfx_metadata_config()
        )

        logger.debug(f"Using deployment config:\n {deployment_config}")
        logger.debug(f"Using connection config:\n {connection_config}")

        # These values do not depend on the individual node, so they are
        # looked up once instead of once per node.
        p_info = pb2_pipeline.pipeline_info
        r_spec = pb2_pipeline.runtime_spec
        steps = list(pipeline.steps.values())

        # Run each component. Note that the pipeline.components list is in
        # topological order.
        for node in pb2_pipeline.nodes:
            pipeline_node: PipelineNode = node.pipeline_node

            # fill out that context
            stack_properties = stack.dict()
            context_utils.add_context_to_node(
                pipeline_node,
                type_=MetadataContextTypes.STACK.value,
                name=str(hash(json.dumps(stack_properties, sort_keys=True))),
                properties=stack_properties,
            )

            # Add all pydantic objects from runtime_configuration to the context
            context_utils.add_runtime_configuration_to_node(
                pipeline_node, runtime_configuration
            )

            # Add pipeline requirements as a context
            requirements = " ".join(sorted(pipeline.requirements))
            context_utils.add_context_to_node(
                pipeline_node,
                type_=MetadataContextTypes.PIPELINE_REQUIREMENTS.value,
                name=str(hash(requirements)),
                properties={"pipeline_requirements": requirements},
            )

            node_id = pipeline_node.node_info.id
            executor_spec = runner_utils.extract_executor_spec(
                deployment_config, node_id
            )
            custom_driver_spec = runner_utils.extract_custom_driver_spec(
                deployment_config, node_id
            )

            # set custom executor operator to allow custom execution logic for
            # each step
            step = get_step_for_node(pipeline_node, steps=steps)
            custom_executor_operators = {
                executable_spec_pb2.PythonClassExecutableSpec: step.executor_operator
            }

            component_launcher = launcher.Launcher(
                pipeline_node=pipeline_node,
                mlmd_connection=metadata.Metadata(connection_config),
                pipeline_info=p_info,
                pipeline_runtime_spec=r_spec,
                executor_spec=executor_spec,
                custom_driver_spec=custom_driver_spec,
                custom_executor_operators=custom_executor_operators,
            )
            stack.prepare_step_run()
            try:
                execute_step(component_launcher)
            finally:
                # Always undo `prepare_step_run`, even if the step failed.
                stack.cleanup_step_run()
run_pipeline(self, pipeline, stack, runtime_configuration)

Runs a pipeline locally

Source code in zenml/orchestrators/local/local_orchestrator.py
def run_pipeline(
    self,
    pipeline: "BasePipeline",
    stack: "Stack",
    runtime_configuration: "RuntimeConfiguration",
) -> Any:
    """Runs a pipeline locally.

    The ZenML pipeline is converted to a tfx pipeline and compiled;
    afterwards every node of the compiled pipeline is enriched with
    metadata contexts and executed one after another.

    Args:
        pipeline: The pipeline to run.
        stack: The stack on which the pipeline is run.
        runtime_configuration: Runtime configuration of the pipeline run.
            A default `RuntimeConfiguration` is created if `None` is
            passed. A configured `schedule` is ignored with a warning.

    Raises:
        TypeError: If the pipeline root of the tfx pipeline is not a
            string.
    """

    tfx_pipeline: TfxPipeline = create_tfx_pipeline(pipeline, stack=stack)

    if runtime_configuration is None:
        runtime_configuration = RuntimeConfiguration()

    if runtime_configuration.schedule:
        logger.warning(
            "Local Orchestrator currently does not support the "
            "use of schedules. The `schedule` will be ignored "
            "and the pipeline will be run directly"
        )

    pipeline_root = tfx_pipeline.pipeline_info.pipeline_root
    if not isinstance(pipeline_root, str):
        raise TypeError(
            "TFX Pipeline root may not be a Placeholder, "
            "but must be a specific string."
        )

    for component in tfx_pipeline.components:
        if isinstance(component, base_component.BaseComponent):
            component._resolve_pip_dependencies(pipeline_root)

    pb2_pipeline: Pb2Pipeline = Compiler().compile(tfx_pipeline)

    # Substitute the runtime parameter to be a concrete run_id
    runtime_parameter_utils.substitute_runtime_parameter(
        pb2_pipeline,
        {
            PIPELINE_RUN_ID_PARAMETER_NAME: runtime_configuration.run_name,
        },
    )

    deployment_config = runner_utils.extract_local_deployment_config(
        pb2_pipeline
    )
    # NOTE(review): the metadata connection is read from the active stack
    # of the repository, not from the `stack` argument - confirm that both
    # always refer to the same stack.
    connection_config = (
        Repository().active_stack.metadata_store.get_tfx_metadata_config()
    )

    logger.debug(f"Using deployment config:\n {deployment_config}")
    logger.debug(f"Using connection config:\n {connection_config}")

    # These values do not depend on the individual node, so they are looked
    # up once instead of once per node.
    p_info = pb2_pipeline.pipeline_info
    r_spec = pb2_pipeline.runtime_spec
    steps = list(pipeline.steps.values())

    # Run each component. Note that the pipeline.components list is in
    # topological order.
    for node in pb2_pipeline.nodes:
        pipeline_node: PipelineNode = node.pipeline_node

        # fill out that context
        stack_properties = stack.dict()
        context_utils.add_context_to_node(
            pipeline_node,
            type_=MetadataContextTypes.STACK.value,
            name=str(hash(json.dumps(stack_properties, sort_keys=True))),
            properties=stack_properties,
        )

        # Add all pydantic objects from runtime_configuration to the context
        context_utils.add_runtime_configuration_to_node(
            pipeline_node, runtime_configuration
        )

        # Add pipeline requirements as a context
        requirements = " ".join(sorted(pipeline.requirements))
        context_utils.add_context_to_node(
            pipeline_node,
            type_=MetadataContextTypes.PIPELINE_REQUIREMENTS.value,
            name=str(hash(requirements)),
            properties={"pipeline_requirements": requirements},
        )

        node_id = pipeline_node.node_info.id
        executor_spec = runner_utils.extract_executor_spec(
            deployment_config, node_id
        )
        custom_driver_spec = runner_utils.extract_custom_driver_spec(
            deployment_config, node_id
        )

        # set custom executor operator to allow custom execution logic for
        # each step
        step = get_step_for_node(pipeline_node, steps=steps)
        custom_executor_operators = {
            executable_spec_pb2.PythonClassExecutableSpec: step.executor_operator
        }

        component_launcher = launcher.Launcher(
            pipeline_node=pipeline_node,
            mlmd_connection=metadata.Metadata(connection_config),
            pipeline_info=p_info,
            pipeline_runtime_spec=r_spec,
            executor_spec=executor_spec,
            custom_driver_spec=custom_driver_spec,
            custom_executor_operators=custom_executor_operators,
        )
        stack.prepare_step_run()
        try:
            execute_step(component_launcher)
        finally:
            # Always undo `prepare_step_run`, even if the step failed.
            stack.cleanup_step_run()

utils

create_tfx_pipeline(zenml_pipeline, stack)

Creates a tfx pipeline from a ZenML pipeline.

Source code in zenml/orchestrators/utils.py
def create_tfx_pipeline(
    zenml_pipeline: "BasePipeline", stack: "Stack"
) -> tfx_pipeline.Pipeline:
    """Builds the tfx pipeline that corresponds to a ZenML pipeline.

    Args:
        zenml_pipeline: The ZenML pipeline whose steps become the
            components of the tfx pipeline.
        stack: The stack whose artifact store provides the pipeline root
            and whose metadata store provides the metadata connection
            config.

    Returns:
        The tfx pipeline assembled from the step components.
    """
    # The inputs/outputs of all steps have to be connected before the
    # components of the steps are collected.
    zenml_pipeline.connect(**zenml_pipeline.steps)

    components = []
    for zenml_step in zenml_pipeline.steps.values():
        components.append(zenml_step.component)

    artifact_store, metadata_store = stack.artifact_store, stack.metadata_store

    name = zenml_pipeline.name
    root = artifact_store.path
    connection_config = metadata_store.get_tfx_metadata_config()
    caching_enabled = zenml_pipeline.enable_cache

    return tfx_pipeline.Pipeline(
        pipeline_name=name,
        components=components,  # type: ignore[arg-type]
        pipeline_root=root,
        metadata_connection_config=connection_config,
        enable_cache=caching_enabled,
    )

execute_step(tfx_launcher)

Executes a tfx component.

Parameters:

Name Type Description Default
tfx_launcher Launcher

A tfx launcher to execute the component.

required

Returns:

Type Description
Optional[tfx.orchestration.portable.data_types.ExecutionInfo]

Optional execution info returned by the launcher.

Source code in zenml/orchestrators/utils.py
def execute_step(
    tfx_launcher: launcher.Launcher,
) -> Optional[data_types.ExecutionInfo]:
    """Executes a tfx component.

    Logs when the step starts and how long it took to finish, and
    additionally logs when a cached version of the step was used.

    Args:
        tfx_launcher: A tfx launcher to execute the component.

    Raises:
        DuplicateRunNameError: If a `RuntimeError` whose message contains
            "execution has already succeeded" is raised while the step is
            launched or its cache status is checked.
        RuntimeError: Any other `RuntimeError` is re-raised unchanged.

    Returns:
        Optional execution info returned by the launcher.
    """
    step_name_param = (
        INTERNAL_EXECUTION_PARAMETER_PREFIX + PARAM_PIPELINE_PARAMETER_NAME
    )
    pipeline_step_name = tfx_launcher._pipeline_node.node_info.id
    start_time = time.time()
    logger.info(f"Step `{pipeline_step_name}` has started.")
    try:
        execution_info = tfx_launcher.launch()
        if execution_info and get_cache_status(execution_info):
            if execution_info.exec_properties:
                # The step name is stored json-encoded in the execution
                # properties.
                step_name = json.loads(
                    execution_info.exec_properties[step_name_param]
                )
                logger.info(
                    f"Using cached version of `{pipeline_step_name}` "
                    f"[`{step_name}`].",
                )
            else:
                logger.error(
                    f"No execution properties found for step "
                    f"`{pipeline_step_name}`."
                )
    except RuntimeError as e:
        if "execution has already succeeded" in str(e):
            # Hacky workaround to catch the error that a pipeline run with
            # this name already exists. Raise an error with a more descriptive
            # message instead, keeping the original error as explicit cause.
            raise DuplicateRunNameError() from e
        else:
            raise

    run_duration = time.time() - start_time
    logger.info(
        f"Step `{pipeline_step_name}` has finished in "
        f"{string_utils.get_human_readable_time(run_duration)}."
    )
    return execution_info

get_cache_status(execution_info)

Returns the caching status of a step.

Parameters:

Name Type Description Default
execution_info ExecutionInfo

The execution info of a tfx step.

required

Exceptions:

Type Description
RuntimeError

If no active stack is configured for the repository.

KeyError

If no pipeline info is found in the execution_info.

Returns:

Type Description
bool

The caching status of a tfx step as a boolean value.

Source code in zenml/orchestrators/utils.py
def get_cache_status(
    execution_info: data_types.ExecutionInfo,
) -> bool:
    """Tells whether a step was served from the cache.

    Args:
        execution_info: The execution info of a `tfx` step.

    Raises:
        RuntimeError: If no active stack is configured for the repository.
        KeyError: If no pipeline info is found in the `execution_info`.

    Returns:
        The caching status of a `tfx` step as a boolean value. `False` is
        returned if `execution_info` is `None` or if the pipeline is not
        found in the metadata store.
    """
    # Without execution info there is nothing to look up.
    if execution_info is None:
        logger.warning("No execution info found when checking cache status.")
        return False

    # TODO [ENG-706]: Get the current running stack instead of just the active
    #   stack
    active_stack = Repository().active_stack
    if not active_stack:
        raise RuntimeError(
            "No active stack is configured for the repository. Run "
            "`zenml stack set STACK_NAME` to update the active stack."
        )
    metadata_store = active_stack.metadata_store

    # The step name is stored json-encoded in the execution properties.
    step_name = json.loads(
        execution_info.exec_properties[
            INTERNAL_EXECUTION_PARAMETER_PREFIX + PARAM_PIPELINE_PARAMETER_NAME
        ]
    )

    if not execution_info.pipeline_info:
        raise KeyError(f"No pipeline info found for step `{step_name}`.")
    pipeline_name = execution_info.pipeline_info.id
    run_name = cast(str, execution_info.pipeline_run_id)

    stored_pipeline = metadata_store.get_pipeline(pipeline_name)
    if stored_pipeline is None:
        logger.error(f"Pipeline {pipeline_name} not found in Metadata Store.")
        return False

    return stored_pipeline.get_run(run_name).get_step(step_name).is_cached

get_step_for_node(node, steps)

Finds the matching step for a tfx pipeline node.

Source code in zenml/orchestrators/utils.py
def get_step_for_node(node: PipelineNode, steps: List[BaseStep]) -> BaseStep:
    """Finds the matching step for a tfx pipeline node.

    Args:
        node: The tfx pipeline node for which the step is looked up.
        steps: The candidate steps. The first step whose `name` equals the
            id of the node is returned.

    Returns:
        The step that matches the node.

    Raises:
        RuntimeError: If none of the steps matches the id of the node.
    """
    step_name = node.node_info.id
    # A plain loop avoids raising from inside an `except StopIteration`
    # handler, which would chain an irrelevant `StopIteration` into the
    # traceback of the error below.
    for step in steps:
        if step.name == step_name:
            return step
    raise RuntimeError(f"Unable to find step with name '{step_name}'.")