Skip to content

Step Operators

zenml.step_operators special

Step Operator

While an orchestrator defines how and where your entire pipeline runs, a step operator defines how and where an individual step runs. This is useful in a variety of scenarios; for example, one step of a pipeline (such as a trainer step) may need to run in a separate environment equipped with a GPU.

base_step_operator

BaseStepOperator (StackComponent, ABC) pydantic-model

Base class for all ZenML step operators.

Source code in zenml/step_operators/base_step_operator.py
class BaseStepOperator(StackComponent, ABC):
    """Base class for all ZenML step operators.

    A step operator defines how and where an individual pipeline step runs.
    Concrete subclasses must implement the abstract `launch` method.
    """

    # Class Configuration
    # Marks every subclass as a stack component of the step operator type.
    TYPE: ClassVar[StackComponentType] = StackComponentType.STEP_OPERATOR

    @abstractmethod
    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> None:
        """Abstract method to execute a step.

        Concrete step operator subclasses must implement the following
        functionality in this method:
        - Prepare the execution environment and install all the necessary
          `requirements`
        - Launch a **synchronous** job that executes the `entrypoint_command`

        Because the job is synchronous, implementations should not return
        before the launched job has finished.

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            requirements: List of pip requirements that must be installed
                inside the step operator environment.
            entrypoint_command: Command that executes the step, given as a
                list of command-line arguments.
        """
launch(self, pipeline_name, run_name, requirements, entrypoint_command)

Abstract method to execute a step.

Concrete step operator subclasses must implement the following functionality in this method:

- Prepare the execution environment and install all the necessary requirements.
- Launch a synchronous job that executes the entrypoint_command.

Parameters:

Name Type Description Default
pipeline_name str

Name of the pipeline which the step to be executed is part of.

required
run_name str

Name of the pipeline run which the step to be executed is part of.

required
entrypoint_command List[str]

Command that executes the step.

required
requirements List[str]

List of pip requirements that must be installed inside the step operator environment.

required
Source code in zenml/step_operators/base_step_operator.py
@abstractmethod
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
) -> None:
    """Abstract method to execute a step.

    Concrete step operator subclasses must implement the following
    functionality in this method:
    - Prepare the execution environment and install all the necessary
      `requirements`
    - Launch a **synchronous** job that executes the `entrypoint_command`

    Because the job is synchronous, implementations should not return
    before the launched job has finished.

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.
        entrypoint_command: Command that executes the step, given as a
            list of command-line arguments.
    """

entrypoint

configure_executor(executor_class, execution_info)

Creates and configures an executor instance.

Parameters:

Name Type Description Default
executor_class Type[tfx.dsl.components.base.base_executor.BaseExecutor]

The class of the executor instance.

required
execution_info ExecutionInfo

Execution info for the executor.

required

Returns:

Type Description
BaseExecutor

A configured executor instance.

Source code in zenml/step_operators/entrypoint.py
def configure_executor(
    executor_class: Type[BaseExecutor], execution_info: ExecutionInfo
) -> BaseExecutor:
    """Builds an executor of the given class from the execution info.

    The executor context is filled from the fields of ``execution_info``;
    the execution ID is converted to a string and used as the context's
    unique ID.

    Args:
        executor_class: The class of the executor instance.
        execution_info: Execution info for the executor.

    Returns:
        A configured executor instance.
    """
    info = execution_info
    executor_context = BaseExecutor.Context(
        pipeline_node=info.pipeline_node,
        pipeline_info=info.pipeline_info,
        pipeline_run_id=info.pipeline_run_id,
        unique_id=str(info.execution_id),
        tmp_dir=info.tmp_dir,
        stateful_working_dir=info.stateful_working_dir,
        executor_output_uri=info.execution_output_uri,
    )
    executor = executor_class(context=executor_context)
    return executor

create_executor_class(step_source_path, input_artifact_type_mapping)

Creates an executor class for a given step.

Parameters:

Name Type Description Default
step_source_path str

Import path of the step to run.

required
input_artifact_type_mapping Dict[str, str]

A dictionary mapping input names to a string representation of their artifact classes.

required
Source code in zenml/step_operators/entrypoint.py
def create_executor_class(
    step_source_path: str,
    input_artifact_type_mapping: Dict[str, str],
) -> Type[_FunctionExecutor]:
    """Creates an executor class for a given step.

    The step class is loaded from its source path and instantiated. A
    component class is then generated for it, and that component's executor
    class is returned.

    Args:
        step_source_path: Import path of the step to run.
        input_artifact_type_mapping: A dictionary mapping input names to
            a string representation of their artifact classes. NOTE: this
            mapping is currently not used when building the executor class;
            every input is declared as a generic `BaseArtifact`.

    Returns:
        The executor class of the component generated for the step.
    """
    step_class = cast(
        Type[BaseStep], source_utils.load_source_path_class(step_source_path)
    )
    step_instance = step_class()

    materializers = step_instance.get_materializers(ensure_complete=True)

    # We don't publish anything to the metadata store inside this environment,
    # so the specific artifact classes of the inputs don't matter: every
    # input is declared as a generic `BaseArtifact`.
    input_spec = {key: BaseArtifact for key in step_class.INPUT_SIGNATURE}

    # Each output uses the first artifact type that the type registry returns
    # for that output's declared type.
    output_spec = {
        key: type_registry.get_artifact_type(output_type)[0]
        for key, output_type in step_class.OUTPUT_SIGNATURE.items()
    }

    # Later entries win on key collisions, so internal execution parameters
    # override user parameters that share the same name.
    execution_parameters = {
        **step_instance.PARAM_SPEC,
        **step_instance._internal_execution_parameters,
    }

    component_class = generate_component_class(
        step_name=step_instance.name,
        step_module=step_class.__module__,
        input_spec=input_spec,
        output_spec=output_spec,
        execution_parameter_names=set(execution_parameters),
        step_function=step_instance.entrypoint,
        materializers=materializers,
    )

    return cast(
        Type[_FunctionExecutor], component_class.EXECUTOR_SPEC.executor_class
    )

load_execution_info(execution_info_path)

Loads the execution info from the given path.

Source code in zenml/step_operators/entrypoint.py
def load_execution_info(execution_info_path: str) -> ExecutionInfo:
    """Loads the execution info from the given path.

    Args:
        execution_info_path: Path of a file containing a serialized
            `ExecutionInvocation` protobuf message.

    Returns:
        The execution info parsed from that message.
    """
    with fileio.open(execution_info_path, "rb") as proto_file:
        serialized_invocation = proto_file.read()

    invocation = ExecutionInvocation.FromString(serialized_invocation)
    return ExecutionInfo.from_proto(invocation)

step_executor_operator

StepExecutorOperator (BaseExecutorOperator)

StepExecutorOperator extends TFX's BaseExecutorOperator.

This class can be passed as a custom executor operator during a pipeline run. It then calls the step's configured step operator to launch the step in some environment.

Source code in zenml/step_operators/step_executor_operator.py
class StepExecutorOperator(BaseExecutorOperator):
    """StepExecutorOperator extends TFX's BaseExecutorOperator.

    This class can be passed as a custom executor operator during
    a pipeline run. It then calls the step's configured step operator
    to launch the step in some environment.
    """

    SUPPORTED_EXECUTOR_SPEC_TYPE = [
        executable_spec_pb2.PythonClassExecutableSpec
    ]
    SUPPORTED_PLATFORM_CONFIG_TYPE: List[Any] = []

    @staticmethod
    def _collect_requirements(
        stack: "Stack",
        pipeline_node: pipeline_pb2.PipelineNode,
    ) -> List[str]:
        """Collects all requirements necessary to run a step.

        The result combines the stack requirements, the whitespace-separated
        pipeline requirements stored in the node's `pipeline_requirements`
        context (if present), and the currently running ZenML version.

        Args:
            stack: Stack on which the step is being executed.
            pipeline_node: Pipeline node info for a step.

        Returns:
            Alphabetically sorted list of unique pip requirements.
        """
        # Copy so the set returned by the stack is never mutated in place.
        requirements = set(stack.requirements())

        # Add pipeline requirements from the corresponding node context
        for context in pipeline_node.contexts.contexts:
            if context.type.name == "pipeline_requirements":
                pipeline_requirements = context.properties[
                    "pipeline_requirements"
                ].field_value.string_value
                # `split()` without a separator drops empty strings, so an
                # empty property value or repeated spaces cannot produce
                # bogus "" requirements that pip would fail to install.
                requirements.update(pipeline_requirements.split())
                break

        # TODO [ENG-696]: Find a nice way to set this if the running version of
        #  ZenML is not an official release (e.g. on a development branch)
        # Add the current ZenML version as a requirement
        requirements.add(f"zenml=={zenml.__version__}")

        return sorted(requirements)

    @staticmethod
    def _resolve_user_modules(
        pipeline_node: pipeline_pb2.PipelineNode,
    ) -> Tuple[str, str]:
        """Resolves the main and step module.

        Steps defined in `__main__` are remapped to the resolved main module
        path so they can be imported from a separate process.

        Args:
            pipeline_node: Pipeline node info for a step.

        Returns:
            A tuple containing the path of the resolved main module and the
            source path of the step class.
        """
        main_module_path = zenml.constants.USER_MAIN_MODULE
        if not main_module_path:
            main_module_path = source_utils.get_module_source_from_module(
                sys.modules["__main__"]
            )

        # The node type name is the fully qualified step class path.
        step_type = cast(str, pipeline_node.node_info.type.name)
        step_module_path, step_class = step_type.rsplit(".", maxsplit=1)
        if step_module_path == "__main__":
            step_module_path = main_module_path

        step_source_path = f"{step_module_path}.{step_class}"

        return main_module_path, step_source_path

    @staticmethod
    def _get_step_operator(
        stack: "Stack", execution_info: data_types.ExecutionInfo
    ) -> "BaseStepOperator":
        """Fetches the step operator specified in the execution info.

        Args:
            stack: Stack on which the step is being executed.
            execution_info: Execution info needed to run the step.

        Returns:
            The step operator to run a step.

        Raises:
            RuntimeError: If the stack has no step operator, or if its step
                operator is not the one the step requires.
        """
        step_operator = stack.step_operator

        # the two following errors should never happen as the stack gets
        # validated before running the pipeline
        if not step_operator:
            raise RuntimeError(
                f"No step operator specified for active stack '{stack.name}'."
            )

        # The required step operator name is stored JSON-encoded in the
        # step's internal execution properties.
        step_operator_property_name = (
            INTERNAL_EXECUTION_PARAMETER_PREFIX + PARAM_CUSTOM_STEP_OPERATOR
        )
        required_step_operator = json.loads(
            execution_info.exec_properties[step_operator_property_name]
        )
        if required_step_operator != step_operator.name:
            raise RuntimeError(
                f"No step operator named '{required_step_operator}' in active "
                f"stack '{stack.name}'."
            )

        return step_operator

    def run_executor(
        self,
        execution_info: data_types.ExecutionInfo,
    ) -> execution_result_pb2.ExecutorOutput:
        """Invokes the executor with inputs provided by the Launcher.

        Writes the execution info and input artifact types to the
        execution's temporary directory, launches the step entrypoint through
        the stack's step operator, and then reads the executor output.

        Args:
            execution_info: Necessary information to run the executor.

        Returns:
            The executor output.
        """
        # These attributes are expected to always be set; the asserts narrow
        # their types so mypy doesn't complain.
        assert execution_info.pipeline_node
        assert execution_info.pipeline_info
        assert execution_info.pipeline_run_id
        assert execution_info.tmp_dir
        assert execution_info.execution_output_uri

        step_name = execution_info.pipeline_node.node_info.id
        stack = Repository().active_stack
        step_operator = self._get_step_operator(
            stack=stack, execution_info=execution_info
        )

        requirements = self._collect_requirements(
            stack=stack, pipeline_node=execution_info.pipeline_node
        )

        # Write the execution info to the execution's temporary directory
        # so the step operator entrypoint can load it
        execution_info_path = os.path.join(
            execution_info.tmp_dir, "zenml_execution_info.pb"
        )
        _write_execution_info(execution_info, path=execution_info_path)

        main_module, step_source_path = self._resolve_user_modules(
            pipeline_node=execution_info.pipeline_node
        )

        # Record the artifact class of the first artifact of each input.
        input_artifact_types_path = os.path.join(
            execution_info.tmp_dir, "input_artifacts.json"
        )
        input_artifact_type_mapping = {
            input_name: source_utils.resolve_class(artifacts[0].__class__)
            for input_name, artifacts in execution_info.input_dict.items()
        }
        yaml_utils.write_json(
            input_artifact_types_path, input_artifact_type_mapping
        )
        entrypoint_command = [
            "python",
            "-m",
            "zenml.step_operators.entrypoint",
            "--main_module",
            main_module,
            "--step_source_path",
            step_source_path,
            "--execution_info_path",
            execution_info_path,
            "--input_artifact_types_path",
            input_artifact_types_path,
        ]

        logger.info(
            "Using step operator `%s` to run step `%s`.",
            step_operator.name,
            step_name,
        )
        logger.debug(
            "Step operator requirements: %s, entrypoint command: %s.",
            requirements,
            entrypoint_command,
        )
        # `launch` is synchronous, so the executor output exists once it
        # returns.
        step_operator.launch(
            pipeline_name=execution_info.pipeline_info.id,
            run_name=execution_info.pipeline_run_id,
            requirements=requirements,
            entrypoint_command=entrypoint_command,
        )

        return _read_executor_output(execution_info.execution_output_uri)
run_executor(self, execution_info)

Invokes the executor with inputs provided by the Launcher.

Parameters:

Name Type Description Default
execution_info ExecutionInfo

Necessary information to run the executor.

required

Returns:

Type Description
ExecutorOutput

The executor output.

Source code in zenml/step_operators/step_executor_operator.py
def run_executor(
    self,
    execution_info: data_types.ExecutionInfo,
) -> execution_result_pb2.ExecutorOutput:
    """Runs the step through the active stack's step operator.

    The execution info and the input artifact types are first written to
    files in the execution's temporary directory. The step operator is then
    asked to launch the step entrypoint with those file paths. Once the
    launch returns, the executor output is read back from the execution's
    output URI.

    Args:
        execution_info: Necessary information to run the executor.

    Returns:
        The executor output.
    """
    # These attributes should always be set; asserting narrows their types
    # for mypy.
    assert execution_info.pipeline_node
    assert execution_info.pipeline_info
    assert execution_info.pipeline_run_id
    assert execution_info.tmp_dir
    assert execution_info.execution_output_uri

    node = execution_info.pipeline_node
    scratch_dir = execution_info.tmp_dir

    active_stack = Repository().active_stack
    operator = self._get_step_operator(
        stack=active_stack, execution_info=execution_info
    )
    pip_requirements = self._collect_requirements(
        stack=active_stack, pipeline_node=node
    )

    # Persist the execution info so the entrypoint can load it inside the
    # step operator environment.
    info_path = os.path.join(scratch_dir, "zenml_execution_info.pb")
    _write_execution_info(execution_info, path=info_path)

    main_module, step_source_path = self._resolve_user_modules(
        pipeline_node=node
    )

    artifact_types_path = os.path.join(scratch_dir, "input_artifacts.json")
    artifact_types = {}
    for input_name, artifacts in execution_info.input_dict.items():
        artifact_types[input_name] = source_utils.resolve_class(
            artifacts[0].__class__
        )
    yaml_utils.write_json(artifact_types_path, artifact_types)

    # Flag/value pairs are appended in insertion order.
    cli_args = {
        "--main_module": main_module,
        "--step_source_path": step_source_path,
        "--execution_info_path": info_path,
        "--input_artifact_types_path": artifact_types_path,
    }
    entrypoint_command = ["python", "-m", "zenml.step_operators.entrypoint"]
    for flag, flag_value in cli_args.items():
        entrypoint_command.extend((flag, flag_value))

    step_name = node.node_info.id
    logger.info(
        "Using step operator `%s` to run step `%s`.",
        operator.name,
        step_name,
    )
    logger.debug(
        "Step operator requirements: %s, entrypoint command: %s.",
        pip_requirements,
        entrypoint_command,
    )
    operator.launch(
        pipeline_name=execution_info.pipeline_info.id,
        run_name=execution_info.pipeline_run_id,
        requirements=pip_requirements,
        entrypoint_command=entrypoint_command,
    )

    return _read_executor_output(execution_info.execution_output_uri)