Step Operators
zenml.step_operators
Step operators allow you to run steps on custom infrastructure.
While an orchestrator defines how and where your entire pipeline runs, a step operator defines how and where an individual step runs. This is useful in a variety of scenarios, for example when a single step within a pipeline (such as a trainer step) should run in a separate environment equipped with a GPU.
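To make a step use such an operator, a pipeline author selects it by name when defining the step. A minimal sketch, assuming the `custom_step_operator` parameter of `@step` (which corresponds to the `PARAM_CUSTOM_STEP_OPERATOR` constant used in the source below) and a step operator registered in the active stack under the hypothetical name `"gpu-operator"`:

```python
from zenml.steps import step


# "gpu-operator" is a hypothetical name; it must match the name of a step
# operator registered in the active stack, otherwise `run_executor` raises
# a RuntimeError when the pipeline runs.
@step(custom_step_operator="gpu-operator")
def trainer(gamma: float) -> float:
    """Illustrative step body; the heavy lifting would happen here."""
    return gamma * 2.0
```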
base_step_operator
Base class for ZenML step operators.
BaseStepOperator (StackComponent, ABC)
pydantic-model
Base class for all ZenML step operators.
Source code in zenml/step_operators/base_step_operator.py
class BaseStepOperator(StackComponent, ABC):
    """Base class for all ZenML step operators."""

    # Class Configuration
    TYPE: ClassVar[StackComponentType] = StackComponentType.STEP_OPERATOR

    @abstractmethod
    def launch(
        self,
        pipeline_name: str,
        run_name: str,
        requirements: List[str],
        entrypoint_command: List[str],
    ) -> None:
        """Abstract method to execute a step.

        Concrete step operator subclasses must implement the following
        functionality in this method:
        - Prepare the execution environment and install all the necessary
          `requirements`
        - Launch a **synchronous** job that executes the `entrypoint_command`

        Args:
            pipeline_name: Name of the pipeline which the step to be executed
                is part of.
            run_name: Name of the pipeline run which the step to be executed
                is part of.
            entrypoint_command: Command that executes the step.
            requirements: List of pip requirements that must be installed
                inside the step operator environment.
        """
launch(self, pipeline_name, run_name, requirements, entrypoint_command)
Abstract method to execute a step.
Concrete step operator subclasses must implement the following
functionality in this method:
- Prepare the execution environment and install all the necessary
requirements
- Launch a synchronous job that executes the entrypoint_command
Parameters:

Name | Type | Description | Default
---|---|---|---
pipeline_name | str | Name of the pipeline which the step to be executed is part of. | required
run_name | str | Name of the pipeline run which the step to be executed is part of. | required
entrypoint_command | List[str] | Command that executes the step. | required
requirements | List[str] | List of pip requirements that must be installed inside the step operator environment. | required
Source code in zenml/step_operators/base_step_operator.py
@abstractmethod
def launch(
    self,
    pipeline_name: str,
    run_name: str,
    requirements: List[str],
    entrypoint_command: List[str],
) -> None:
    """Abstract method to execute a step.

    Concrete step operator subclasses must implement the following
    functionality in this method:
    - Prepare the execution environment and install all the necessary
      `requirements`
    - Launch a **synchronous** job that executes the `entrypoint_command`

    Args:
        pipeline_name: Name of the pipeline which the step to be executed
            is part of.
        run_name: Name of the pipeline run which the step to be executed
            is part of.
        entrypoint_command: Command that executes the step.
        requirements: List of pip requirements that must be installed
            inside the step operator environment.
    """
entrypoint
Entrypoint for the step operator.
configure_executor(executor_class, execution_info)
Creates and configures an executor instance.
Parameters:

Name | Type | Description | Default
---|---|---|---
executor_class | Type[tfx.dsl.components.base.base_executor.BaseExecutor] | The class of the executor instance. | required
execution_info | ExecutionInfo | Execution info for the executor. | required

Returns:

Type | Description
---|---
BaseExecutor | A configured executor instance.
Source code in zenml/step_operators/entrypoint.py
def configure_executor(
    executor_class: Type[BaseExecutor], execution_info: ExecutionInfo
) -> BaseExecutor:
    """Creates and configures an executor instance.

    Args:
        executor_class: The class of the executor instance.
        execution_info: Execution info for the executor.

    Returns:
        A configured executor instance.
    """
    context = BaseExecutor.Context(
        tmp_dir=execution_info.tmp_dir,
        unique_id=str(execution_info.execution_id),
        executor_output_uri=execution_info.execution_output_uri,
        stateful_working_dir=execution_info.stateful_working_dir,
        pipeline_node=execution_info.pipeline_node,
        pipeline_info=execution_info.pipeline_info,
        pipeline_run_id=execution_info.pipeline_run_id,
    )
    return executor_class(context=context)
create_executor_class(step_source_path, input_artifact_type_mapping)
Creates an executor class for a given step.
Parameters:

Name | Type | Description | Default
---|---|---|---
step_source_path | str | Import path of the step to run. | required
input_artifact_type_mapping | Dict[str, str] | A dictionary mapping input names to a string representation of their artifact classes. | required

Returns:

Type | Description
---|---
Type[zenml.steps.utils._FunctionExecutor] | A class of an executor instance.
Source code in zenml/step_operators/entrypoint.py
def create_executor_class(
    step_source_path: str,
    input_artifact_type_mapping: Dict[str, str],
) -> Type[_FunctionExecutor]:
    """Creates an executor class for a given step.

    Args:
        step_source_path: Import path of the step to run.
        input_artifact_type_mapping: A dictionary mapping input names to
            a string representation of their artifact classes.

    Returns:
        A class of an executor instance.
    """
    step_class = cast(
        Type[BaseStep], source_utils.load_source_path_class(step_source_path)
    )
    step_instance = step_class()

    materializers = step_instance.get_materializers(ensure_complete=True)

    # We don't publish anything to the metadata store inside this environment,
    # so the specific artifact classes don't matter
    input_spec = {}
    for key, value in step_class.INPUT_SIGNATURE.items():
        input_spec[key] = BaseArtifact

    output_spec = {}
    for key, value in step_class.OUTPUT_SIGNATURE.items():
        output_spec[key] = type_registry.get_artifact_type(value)[0]

    execution_parameters = {
        **step_instance.PARAM_SPEC,
        **step_instance._internal_execution_parameters,
    }

    component_class = generate_component_class(
        step_name=step_instance.name,
        step_module=step_class.__module__,
        input_spec=input_spec,
        output_spec=output_spec,
        execution_parameter_names=set(execution_parameters),
        step_function=step_instance.entrypoint,
        materializers=materializers,
    )

    return cast(
        Type[_FunctionExecutor], component_class.EXECUTOR_SPEC.executor_class
    )
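A short usage sketch: the mapping is a plain dictionary from input names to import paths of artifact classes. All values here, including the step import path and the artifact class path, are illustrative assumptions:

```python
from zenml.step_operators.entrypoint import create_executor_class

# Hypothetical step import path and artifact class path.
executor_class = create_executor_class(
    step_source_path="my_project.steps.trainer",
    input_artifact_type_mapping={
        "dataset": "zenml.artifacts.data_artifact.DataArtifact",
    },
)
```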
load_execution_info(execution_info_path)
Loads the execution info from the given path.
Parameters:

Name | Type | Description | Default
---|---|---|---
execution_info_path | str | Path to the execution info file. | required

Returns:

Type | Description
---|---
ExecutionInfo | Execution info.
Source code in zenml/step_operators/entrypoint.py
def load_execution_info(execution_info_path: str) -> ExecutionInfo:
    """Loads the execution info from the given path.

    Args:
        execution_info_path: Path to the execution info file.

    Returns:
        Execution info.
    """
    with fileio.open(execution_info_path, "rb") as f:
        execution_info_proto = ExecutionInvocation.FromString(f.read())

    return ExecutionInfo.from_proto(execution_info_proto)
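Taken together, the three helpers above reconstruct and run a step inside the operator's environment. A minimal sketch of the flow, assuming the file paths that `StepExecutorOperator.run_executor` (below) passes to this entrypoint via command-line flags:

```python
import json

from zenml.io import fileio
from zenml.step_operators.entrypoint import (
    configure_executor,
    create_executor_class,
    load_execution_info,
)

# Hypothetical paths; in practice they arrive as the --execution_info_path
# and --input_artifact_types_path flags of the entrypoint module.
execution_info = load_execution_info("/tmp/zenml_execution_info.pb")
with fileio.open("/tmp/input_artifacts.json", "r") as f:
    input_artifact_type_mapping = json.loads(f.read())

# Rebuild the executor for the step and configure it with the execution info.
executor_class = create_executor_class(
    step_source_path="my_project.steps.trainer",  # hypothetical
    input_artifact_type_mapping=input_artifact_type_mapping,
)
executor = configure_executor(
    executor_class=executor_class, execution_info=execution_info
)
```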
step_executor_operator
Custom definition of a Step Executor Operator which delegates step execution to a configured step operator.
StepExecutorOperator (BaseExecutorOperator)
StepExecutorOperator extends TFX's BaseExecutorOperator.
This class can be passed as a custom executor operator during a pipeline run; it is then used to call the step's configured step operator, which launches the step in some environment.
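Concretely, the hand-off happens through a mapping of executable spec types to executor operator classes that TFX launchers accept. A minimal sketch of such a mapping (how the orchestrator passes it to the launcher is elided here):

```python
from tfx.proto.orchestration import executable_spec_pb2

from zenml.step_operators.step_executor_operator import StepExecutorOperator

# Map TFX's Python-class executable spec to ZenML's custom operator; a TFX
# launcher configured with this mapping routes step execution through
# StepExecutorOperator.run_executor instead of running the step in-process.
custom_executor_operators = {
    executable_spec_pb2.PythonClassExecutableSpec: StepExecutorOperator,
}
```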
Source code in zenml/step_operators/step_executor_operator.py
class StepExecutorOperator(BaseExecutorOperator):
    """StepExecutorOperator extends TFX's BaseExecutorOperator.

    This class can be passed as a custom executor operator during
    a pipeline run which will then be used to call the step's
    configured step operator to launch it in some environment.
    """

    SUPPORTED_EXECUTOR_SPEC_TYPE = [
        executable_spec_pb2.PythonClassExecutableSpec
    ]
    SUPPORTED_PLATFORM_CONFIG_TYPE: List[Any] = []

    @staticmethod
    def _collect_requirements(
        stack: "Stack",
        pipeline_node: pipeline_pb2.PipelineNode,
    ) -> List[str]:
        """Collects all requirements necessary to run a step.

        Args:
            stack: Stack on which the step is being executed.
            pipeline_node: Pipeline node info for a step.

        Returns:
            Alphabetically sorted list of pip requirements.
        """
        requirements = stack.requirements()

        # Add pipeline requirements from the corresponding node context
        for context in pipeline_node.contexts.contexts:
            if context.type.name == "pipeline_requirements":
                pipeline_requirements = context.properties[
                    "pipeline_requirements"
                ].field_value.string_value.split(" ")
                requirements.update(pipeline_requirements)
                break

        # TODO [ENG-696]: Find a nice way to set this if the running version of
        #  ZenML is not an official release (e.g. on a development branch)
        # Add the current ZenML version as a requirement
        requirements.add(f"zenml=={zenml.__version__}")

        return sorted(requirements)

    @staticmethod
    def _resolve_user_modules(
        pipeline_node: pipeline_pb2.PipelineNode,
    ) -> Tuple[str, str]:
        """Resolves the main and step module.

        Args:
            pipeline_node: Pipeline node info for a step.

        Returns:
            A tuple containing the path of the resolved main module and step
            class.
        """
        main_module_path = zenml.constants.USER_MAIN_MODULE
        if not main_module_path:
            main_module_path = source_utils.get_module_source_from_module(
                sys.modules["__main__"]
            )

        step_type = cast(str, pipeline_node.node_info.type.name)
        step_module_path, step_class = step_type.rsplit(".", maxsplit=1)
        if step_module_path == "__main__":
            step_module_path = main_module_path

        step_source_path = f"{step_module_path}.{step_class}"
        return main_module_path, step_source_path

    @staticmethod
    def _get_step_operator(
        stack: "Stack", execution_info: data_types.ExecutionInfo
    ) -> "BaseStepOperator":
        """Fetches the step operator specified in the execution info.

        Args:
            stack: Stack on which the step is being executed.
            execution_info: Execution info needed to run the step.

        Returns:
            The step operator to run a step.

        Raises:
            RuntimeError: If no active step operator is found.
        """
        step_operator = stack.step_operator

        # the two following errors should never happen as the stack gets
        # validated before running the pipeline
        if not step_operator:
            raise RuntimeError(
                f"No step operator specified for active stack '{stack.name}'."
            )

        step_operator_property_name = (
            INTERNAL_EXECUTION_PARAMETER_PREFIX + PARAM_CUSTOM_STEP_OPERATOR
        )
        required_step_operator = json.loads(
            execution_info.exec_properties[step_operator_property_name]
        )
        if required_step_operator != step_operator.name:
            raise RuntimeError(
                f"No step operator named '{required_step_operator}' in active "
                f"stack '{stack.name}'."
            )

        return step_operator

    def run_executor(
        self,
        execution_info: data_types.ExecutionInfo,
    ) -> execution_result_pb2.ExecutorOutput:
        """Invokes the executor with inputs provided by the Launcher.

        Args:
            execution_info: Necessary information to run the executor.

        Returns:
            The executor output.
        """
        # Pretty sure these attributes will always be not None, assert here so
        # mypy doesn't complain
        assert execution_info.pipeline_node
        assert execution_info.pipeline_info
        assert execution_info.pipeline_run_id
        assert execution_info.tmp_dir
        assert execution_info.execution_output_uri

        step_name = execution_info.pipeline_node.node_info.id
        stack = Repository().active_stack
        step_operator = self._get_step_operator(
            stack=stack, execution_info=execution_info
        )
        requirements = self._collect_requirements(
            stack=stack, pipeline_node=execution_info.pipeline_node
        )

        # Write the execution info to a temporary directory inside the artifact
        # store so the step operator entrypoint can load it
        execution_info_path = os.path.join(
            execution_info.tmp_dir, "zenml_execution_info.pb"
        )
        _write_execution_info(execution_info, path=execution_info_path)

        main_module, step_source_path = self._resolve_user_modules(
            pipeline_node=execution_info.pipeline_node
        )
        input_artifact_types_path = os.path.join(
            execution_info.tmp_dir, "input_artifacts.json"
        )
        input_artifact_type_mapping = {
            input_name: source_utils.resolve_class(artifacts[0].__class__)
            for input_name, artifacts in execution_info.input_dict.items()
        }
        yaml_utils.write_json(
            input_artifact_types_path, input_artifact_type_mapping
        )

        entrypoint_command = [
            "python",
            "-m",
            "zenml.step_operators.entrypoint",
            "--main_module",
            main_module,
            "--step_source_path",
            step_source_path,
            "--execution_info_path",
            execution_info_path,
            "--input_artifact_types_path",
            input_artifact_types_path,
        ]

        logger.info(
            "Using step operator `%s` to run step `%s`.",
            step_operator.name,
            step_name,
        )
        logger.debug(
            "Step operator requirements: %s, entrypoint command: %s.",
            requirements,
            entrypoint_command,
        )
        step_operator.launch(
            pipeline_name=execution_info.pipeline_info.id,
            run_name=execution_info.pipeline_run_id,
            requirements=requirements,
            entrypoint_command=entrypoint_command,
        )

        return _read_executor_output(execution_info.execution_output_uri)
run_executor(self, execution_info)
Invokes the executor with inputs provided by the Launcher.
Parameters:

Name | Type | Description | Default
---|---|---|---
execution_info | ExecutionInfo | Necessary information to run the executor. | required

Returns:

Type | Description
---|---
ExecutorOutput | The executor output.
Source code in zenml/step_operators/step_executor_operator.py
def run_executor(
    self,
    execution_info: data_types.ExecutionInfo,
) -> execution_result_pb2.ExecutorOutput:
    """Invokes the executor with inputs provided by the Launcher.

    Args:
        execution_info: Necessary information to run the executor.

    Returns:
        The executor output.
    """
    # Pretty sure these attributes will always be not None, assert here so
    # mypy doesn't complain
    assert execution_info.pipeline_node
    assert execution_info.pipeline_info
    assert execution_info.pipeline_run_id
    assert execution_info.tmp_dir
    assert execution_info.execution_output_uri

    step_name = execution_info.pipeline_node.node_info.id
    stack = Repository().active_stack
    step_operator = self._get_step_operator(
        stack=stack, execution_info=execution_info
    )
    requirements = self._collect_requirements(
        stack=stack, pipeline_node=execution_info.pipeline_node
    )

    # Write the execution info to a temporary directory inside the artifact
    # store so the step operator entrypoint can load it
    execution_info_path = os.path.join(
        execution_info.tmp_dir, "zenml_execution_info.pb"
    )
    _write_execution_info(execution_info, path=execution_info_path)

    main_module, step_source_path = self._resolve_user_modules(
        pipeline_node=execution_info.pipeline_node
    )
    input_artifact_types_path = os.path.join(
        execution_info.tmp_dir, "input_artifacts.json"
    )
    input_artifact_type_mapping = {
        input_name: source_utils.resolve_class(artifacts[0].__class__)
        for input_name, artifacts in execution_info.input_dict.items()
    }
    yaml_utils.write_json(
        input_artifact_types_path, input_artifact_type_mapping
    )

    entrypoint_command = [
        "python",
        "-m",
        "zenml.step_operators.entrypoint",
        "--main_module",
        main_module,
        "--step_source_path",
        step_source_path,
        "--execution_info_path",
        execution_info_path,
        "--input_artifact_types_path",
        input_artifact_types_path,
    ]

    logger.info(
        "Using step operator `%s` to run step `%s`.",
        step_operator.name,
        step_name,
    )
    logger.debug(
        "Step operator requirements: %s, entrypoint command: %s.",
        requirements,
        entrypoint_command,
    )
    step_operator.launch(
        pipeline_name=execution_info.pipeline_info.id,
        run_name=execution_info.pipeline_run_id,
        requirements=requirements,
        entrypoint_command=entrypoint_command,
    )

    return _read_executor_output(execution_info.execution_output_uri)