Skip to content

Step Operators

zenml.step_operators special

Step operators allow you to run steps on custom infrastructure.

While an orchestrator defines how and where your entire pipeline runs, a step operator defines how and where an individual step runs. This can be useful in a variety of scenarios. For example, one step within a pipeline (such as a trainer step) may need to run in a separate environment equipped with a GPU.

base_step_operator

Base class for ZenML step operators.

BaseStepOperator (StackComponent, ABC)

Base class for all ZenML step operators.

Source code in zenml/step_operators/base_step_operator.py
class BaseStepOperator(StackComponent, ABC):
    """Base class for all ZenML step operators.

    Concrete step operators subclass this class and implement `launch`,
    which is the only abstract method declared here.
    """

    @property
    def config(self) -> BaseStepOperatorConfig:
        """Returns the config of the step operator.

        Returns:
            The component's `_config`, typed as `BaseStepOperatorConfig`.
        """
        return cast(BaseStepOperatorConfig, self._config)

    @abstractmethod
    def launch(
        self,
        info: "StepRunInfo",
        entrypoint_command: List[str],
    ) -> None:
        """Abstract method to execute a step.

        Subclasses must implement this method and launch a **synchronous**
        job that executes the `entrypoint_command`. The base class provides
        no implementation.

        Args:
            info: Information about the step run.
            entrypoint_command: Command that executes the step, given as a
                list of strings.
        """
config: BaseStepOperatorConfig property readonly

Returns the config of the step operator.

Returns:

Type Description
BaseStepOperatorConfig

The config of the step operator.

launch(self, info, entrypoint_command)

Abstract method to execute a step.

Subclasses must implement this method and launch a synchronous job that executes the entrypoint_command.

Parameters:

Name Type Description Default
info StepRunInfo

Information about the step run.

required
entrypoint_command List[str]

Command that executes the step.

required
Source code in zenml/step_operators/base_step_operator.py
@abstractmethod
def launch(
    self,
    info: "StepRunInfo",
    entrypoint_command: List[str],
) -> None:
    """Abstract method to execute a step.

    Subclasses must implement this method and launch a **synchronous**
    job that executes the `entrypoint_command`. The base class provides
    no implementation.

    Args:
        info: Information about the step run.
        entrypoint_command: Command that executes the step, given as a
            list of strings.
    """

BaseStepOperatorConfig (StackComponentConfig) pydantic-model

Base config for step operators.

Source code in zenml/step_operators/base_step_operator.py
class BaseStepOperatorConfig(StackComponentConfig):
    """Configuration shared by all step operators."""

    @root_validator(pre=True)
    def _deprecations(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Strip deprecated fields from the incoming config values.

        The deprecated `base_image` key is removed from `values`. If it
        carried a truthy value, a deprecation warning is logged.

        Args:
            values: The values to validate.

        Returns:
            The same `values` dictionary, without the `base_image` key.
        """
        # `pop` with a default is a no-op when the key is absent, so no
        # separate membership check is needed.
        deprecated_image = values.pop("base_image", None)
        if deprecated_image:
            logger.warning(
                "The 'base_image' field has been deprecated. To use a "
                "custom base container image with your "
                "step operators, please use the DockerSettings in your "
                "pipeline (see https://docs.zenml.io/advanced-guide/pipelines/containerization)."
            )

        return values

BaseStepOperatorFlavor (Flavor)

Base class for all ZenML step operator flavors.

Source code in zenml/step_operators/base_step_operator.py
class BaseStepOperatorFlavor(Flavor):
    """Base class for all ZenML step operator flavors.

    Fixes the flavor type to `StackComponentType.STEP_OPERATOR` and uses
    `BaseStepOperatorConfig` as the config class. Subclasses must provide
    `implementation_class`, which is abstract here.
    """

    @property
    def type(self) -> StackComponentType:
        """Returns the flavor type.

        Returns:
            `StackComponentType.STEP_OPERATOR`.
        """
        return StackComponentType.STEP_OPERATOR

    @property
    def config_class(self) -> Type[BaseStepOperatorConfig]:
        """Returns the config class for this flavor.

        Returns:
            The `BaseStepOperatorConfig` class.
        """
        return BaseStepOperatorConfig

    @property
    @abstractmethod
    def implementation_class(self) -> Type[BaseStepOperator]:
        """Returns the implementation class for this flavor.

        Abstract: the base flavor provides no implementation.

        Returns:
            The implementation class for this flavor.
        """
config_class: Type[zenml.step_operators.base_step_operator.BaseStepOperatorConfig] property readonly

Returns the config class for this flavor.

Returns:

Type Description
Type[zenml.step_operators.base_step_operator.BaseStepOperatorConfig]

The config class for this flavor.

implementation_class: Type[zenml.step_operators.base_step_operator.BaseStepOperator] property readonly

Returns the implementation class for this flavor.

Returns:

Type Description
Type[zenml.step_operators.base_step_operator.BaseStepOperator]

The implementation class for this flavor.

type: StackComponentType property readonly

Returns the flavor type.

Returns:

Type Description
StackComponentType

The type of the flavor.

step_executor_operator

Custom StepExecutorOperator which can be passed to the step operator.

StepExecutorOperator (BaseExecutorOperator)

StepExecutorOperator extends TFX's BaseExecutorOperator.

This class can be passed as a custom executor operator during a pipeline run. It is then used to call the step's configured step operator, which launches the step in some environment.

Source code in zenml/step_operators/step_executor_operator.py
class StepExecutorOperator(BaseExecutorOperator):
    """Executor operator that hands step execution to a step operator.

    Extends TFX's BaseExecutorOperator. When passed as a custom executor
    operator during a pipeline run, it calls the step operator configured
    for the step so that the step is launched in some environment.
    """

    SUPPORTED_EXECUTOR_SPEC_TYPE = [
        executable_spec_pb2.PythonClassExecutableSpec
    ]
    SUPPORTED_PLATFORM_CONFIG_TYPE: List[Any] = []

    @staticmethod
    def _get_step_operator(
        stack: "Stack", step_operator_name: str
    ) -> "BaseStepOperator":
        """Returns the stack's step operator after checking its name.

        Args:
            stack: Stack on which the step is being executed.
            step_operator_name: Name of the step operator to get.

        Returns:
            The step operator of the given stack.

        Raises:
            RuntimeError: If the stack has no step operator, or if the
                name of its step operator differs from
                `step_operator_name`.
        """
        # Neither failure below should ever happen, because the stack gets
        # validated before the pipeline runs.
        active_operator = stack.step_operator
        if not active_operator:
            raise RuntimeError(
                f"No step operator specified for active stack '{stack.name}'."
            )

        if step_operator_name == active_operator.name:
            return active_operator

        raise RuntimeError(
            f"No step operator named '{step_operator_name}' in active "
            f"stack '{stack.name}'."
        )

    @staticmethod
    def _get_step_name_in_pipeline(
        execution_info: data_types.ExecutionInfo,
    ) -> str:
        """Reads the step's pipeline-level name from the exec properties.

        Args:
            execution_info: The step execution info.

        Returns:
            The JSON-decoded value of the pipeline parameter name property.
        """
        exec_property_key = (
            INTERNAL_EXECUTION_PARAMETER_PREFIX + PARAM_PIPELINE_PARAMETER_NAME
        )
        serialized_name = execution_info.exec_properties[exec_property_key]
        return cast(str, json.loads(serialized_name))

    def run_executor(
        self,
        execution_info: data_types.ExecutionInfo,
    ) -> execution_result_pb2.ExecutorOutput:
        """Launches the step through its step operator and reads the output.

        The execution info is written to a file in its `tmp_dir`, the
        entrypoint command for the step is assembled, and the step
        operator's `launch` is called with it.

        Args:
            execution_info: Necessary information to run the executor.

        Returns:
            The executor output read from the execution output URI.
        """
        # These attributes are expected to always be set; asserting them
        # keeps mypy from complaining about the optional types.
        assert execution_info.pipeline_node
        assert execution_info.pipeline_info
        assert execution_info.pipeline_run_id
        assert execution_info.tmp_dir
        assert execution_info.execution_output_uri

        pipeline_node = execution_info.pipeline_node
        step = proto_utils.get_step(pipeline_node=pipeline_node)
        pipeline_config = proto_utils.get_pipeline_config(
            pipeline_node=pipeline_node
        )
        assert step.config.step_operator

        step_operator = self._get_step_operator(
            stack=Client().active_stack,
            step_operator_name=step.config.step_operator,
        )

        # Persist the execution info in a temporary directory inside the
        # artifact store so the step operator entrypoint can load it.
        info_path = os.path.join(
            execution_info.tmp_dir, "zenml_execution_info.pb"
        )
        _write_execution_info(execution_info, path=info_path)

        step_name = self._get_step_name_in_pipeline(execution_info)

        command = StepOperatorEntrypointConfiguration.get_entrypoint_command()
        arguments = (
            StepOperatorEntrypointConfiguration.get_entrypoint_arguments(
                step_name=step_name,
                execution_info_path=info_path,
            )
        )
        entrypoint_command = command + arguments

        logger.info(
            "Using step operator `%s` to run step `%s`.",
            step_operator.name,
            step_name,
        )
        run_info = StepRunInfo(
            config=step.config,
            pipeline=pipeline_config,
            run_name=execution_info.pipeline_run_id,
        )
        step_operator.launch(
            info=run_info,
            entrypoint_command=entrypoint_command,
        )

        return _read_executor_output(execution_info.execution_output_uri)
run_executor(self, execution_info)

Invokes the executor with inputs provided by the Launcher.

Parameters:

Name Type Description Default
execution_info ExecutionInfo

Necessary information to run the executor.

required

Returns:

Type Description
ExecutorOutput

The executor output.

Source code in zenml/step_operators/step_executor_operator.py
def run_executor(
    self,
    execution_info: data_types.ExecutionInfo,
) -> execution_result_pb2.ExecutorOutput:
    """Launches the step through its step operator and reads the output.

    The execution info is written to a file in its `tmp_dir`, the
    entrypoint command for the step is assembled, and the step operator's
    `launch` is called with it.

    Args:
        execution_info: Necessary information to run the executor.

    Returns:
        The executor output read from the execution output URI.
    """
    # These attributes are expected to always be set; asserting them keeps
    # mypy from complaining about the optional types.
    assert execution_info.pipeline_node
    assert execution_info.pipeline_info
    assert execution_info.pipeline_run_id
    assert execution_info.tmp_dir
    assert execution_info.execution_output_uri

    pipeline_node = execution_info.pipeline_node
    step = proto_utils.get_step(pipeline_node=pipeline_node)
    pipeline_config = proto_utils.get_pipeline_config(
        pipeline_node=pipeline_node
    )
    assert step.config.step_operator

    step_operator = self._get_step_operator(
        stack=Client().active_stack,
        step_operator_name=step.config.step_operator,
    )

    # Persist the execution info in a temporary directory inside the
    # artifact store so the step operator entrypoint can load it.
    info_path = os.path.join(
        execution_info.tmp_dir, "zenml_execution_info.pb"
    )
    _write_execution_info(execution_info, path=info_path)

    step_name = self._get_step_name_in_pipeline(execution_info)

    command = StepOperatorEntrypointConfiguration.get_entrypoint_command()
    arguments = StepOperatorEntrypointConfiguration.get_entrypoint_arguments(
        step_name=step_name,
        execution_info_path=info_path,
    )
    entrypoint_command = command + arguments

    logger.info(
        "Using step operator `%s` to run step `%s`.",
        step_operator.name,
        step_name,
    )
    run_info = StepRunInfo(
        config=step.config,
        pipeline=pipeline_config,
        run_name=execution_info.pipeline_run_id,
    )
    step_operator.launch(
        info=run_info,
        entrypoint_command=entrypoint_command,
    )

    return _read_executor_output(execution_info.execution_output_uri)

step_operator_entrypoint_configuration

Abstract base class for entrypoint configurations that run a single step.

StepOperatorEntrypointConfiguration (StepEntrypointConfiguration)

Base class for step operator entrypoint configurations.

Source code in zenml/step_operators/step_operator_entrypoint_configuration.py
class StepOperatorEntrypointConfiguration(StepEntrypointConfiguration):
    """Entrypoint configuration used to run a step via a step operator.

    Adds the execution info path option to the superclass options and
    arguments.
    """

    @classmethod
    def get_entrypoint_options(cls) -> Set[str]:
        """Gets all options required for running with this configuration.

        Returns:
            The superclass options plus the option for the path to the
            execution info.
        """
        inherited_options = super().get_entrypoint_options()
        return inherited_options | {EXECUTION_INFO_PATH_OPTION}

    @classmethod
    def get_entrypoint_arguments(
        cls,
        **kwargs: Any,
    ) -> List[str]:
        """Gets all arguments that the entrypoint command should be called with.

        Args:
            **kwargs: Kwargs, must include the execution info path.

        Returns:
            The superclass arguments followed by the execution info path
            option and its value.
        """
        inherited_arguments = super().get_entrypoint_arguments(**kwargs)
        execution_info_arguments = [
            f"--{EXECUTION_INFO_PATH_OPTION}",
            kwargs[EXECUTION_INFO_PATH_OPTION],
        ]
        return inherited_arguments + execution_info_arguments

    def _run_step(
        self,
        step: "Step",
        deployment: "PipelineDeployment",
    ) -> Optional[data_types.ExecutionInfo]:
        """Runs a single step using the stored execution info.

        Args:
            step: The step to run.
            deployment: The deployment configuration.

        Raises:
            RuntimeError: If the step executor class does not exist.

        Returns:
            Step execution info.
        """
        # Fetch the active stack first so the artifact store is loaded
        # before the execution info is read.
        active_stack = Client().active_stack

        execution_info = self._load_execution_info(
            self.entrypoint_args[EXECUTION_INFO_PATH_OPTION]
        )

        executor_class = step_utils.get_executor_class(step.config.name)
        if not executor_class:
            raise RuntimeError(
                f"Unable to find executor class for step {step.config.name}."
            )

        executor = self._configure_executor(
            executor_class=executor_class, execution_info=execution_info
        )

        active_stack.orchestrator._ensure_artifact_classes_loaded(step.config)
        run_info = StepRunInfo(
            config=step.config,
            pipeline=deployment.pipeline,
            run_name=execution_info.pipeline_run_id,
        )

        active_stack.prepare_step_run(info=run_info)
        try:
            run_with_executor(execution_info=execution_info, executor=executor)
        finally:
            # Cleanup runs whether or not the executor raised.
            active_stack.cleanup_step_run(info=run_info)

        return execution_info

    @staticmethod
    def _load_execution_info(execution_info_path: str) -> ExecutionInfo:
        """Loads the execution info from the given path.

        Args:
            execution_info_path: Path to the execution info file.

        Returns:
            Execution info built from the serialized proto in the file.
        """
        with fileio.open(execution_info_path, "rb") as info_file:
            serialized_info = info_file.read()

        invocation_proto = ExecutionInvocation.FromString(serialized_info)
        return ExecutionInfo.from_proto(invocation_proto)

    @staticmethod
    def _configure_executor(
        executor_class: Type[BaseExecutor], execution_info: ExecutionInfo
    ) -> BaseExecutor:
        """Creates an executor instance with a context from the execution info.

        Args:
            executor_class: The class of the executor instance.
            execution_info: Execution info for the executor.

        Returns:
            A configured executor instance.
        """
        return executor_class(
            context=BaseExecutor.Context(
                tmp_dir=execution_info.tmp_dir,
                unique_id=str(execution_info.execution_id),
                executor_output_uri=execution_info.execution_output_uri,
                stateful_working_dir=execution_info.stateful_working_dir,
                pipeline_node=execution_info.pipeline_node,
                pipeline_info=execution_info.pipeline_info,
                pipeline_run_id=execution_info.pipeline_run_id,
            )
        )
get_entrypoint_arguments(**kwargs) classmethod

Gets all arguments that the entrypoint command should be called with.

Parameters:

Name Type Description Default
**kwargs Any

Kwargs, must include the execution info path.

{}

Returns:

Type Description
List[str]

The superclass arguments as well as arguments for the path to the execution info.

Source code in zenml/step_operators/step_operator_entrypoint_configuration.py
@classmethod
def get_entrypoint_arguments(
    cls,
    **kwargs: Any,
) -> List[str]:
    """Gets all arguments that the entrypoint command should be called with.

    Args:
        **kwargs: Kwargs, must include the execution info path.

    Returns:
        The superclass arguments followed by the execution info path
        option and its value.
    """
    inherited_arguments = super().get_entrypoint_arguments(**kwargs)
    execution_info_arguments = [
        f"--{EXECUTION_INFO_PATH_OPTION}",
        kwargs[EXECUTION_INFO_PATH_OPTION],
    ]
    return inherited_arguments + execution_info_arguments
get_entrypoint_options() classmethod

Gets all options required for running with this configuration.

Returns:

Type Description
Set[str]

The superclass options as well as an option for the path to the execution info.

Source code in zenml/step_operators/step_operator_entrypoint_configuration.py
@classmethod
def get_entrypoint_options(cls) -> Set[str]:
    """Gets all options required for running with this configuration.

    Returns:
        The superclass options plus the option for the path to the
        execution info.
    """
    inherited_options = super().get_entrypoint_options()
    return inherited_options | {EXECUTION_INFO_PATH_OPTION}