Skip to content

Steps

zenml.steps special

Initializer for ZenML steps.

A step is a single piece or stage of a ZenML pipeline. Think of each step as being one of the nodes of a Directed Acyclic Graph (or DAG). Steps are responsible for one aspect of processing or interacting with the data / artifacts in the pipeline.

Conceptually, a Step is a discrete and independent part of a pipeline that is responsible for one particular aspect of data manipulation inside a ZenML pipeline.

Steps can be defined either by subclassing the `BaseStep` class or by decorating a function with the `@step` decorator.

base_parameters

Step parameters.

BaseParameters (BaseModel) pydantic-model

Base class to pass parameters into a step.

Source code in zenml/steps/base_parameters.py
class BaseParameters(BaseModel):
    """Base class to pass parameters into a step.

    This is a pydantic model. When an instance is passed as `parameters` to
    `BaseStep.configure(...)`, it is converted to a plain dictionary via
    `.dict()` before being stored in the step configuration.
    """

base_step

Base Step for ZenML.

BaseStep

Abstract base class for all ZenML steps.

Source code in zenml/steps/base_step.py
class BaseStep(metaclass=BaseStepMeta):
    """Abstract base class for all ZenML steps."""

    def __init__(
        self,
        *args: Any,
        name: Optional[str] = None,
        enable_cache: Optional[bool] = None,
        enable_artifact_metadata: Optional[bool] = None,
        enable_artifact_visualization: Optional[bool] = None,
        enable_step_logs: Optional[bool] = None,
        experiment_tracker: Optional[str] = None,
        step_operator: Optional[str] = None,
        parameters: Optional["ParametersOrDict"] = None,
        output_materializers: Optional[
            "OutputMaterializersSpecification"
        ] = None,
        settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
        extra: Optional[Dict[str, Any]] = None,
        on_failure: Optional["HookSpecification"] = None,
        on_success: Optional["HookSpecification"] = None,
        **kwargs: Any,
    ) -> None:
        """Create a step and apply its initial configuration.

        The entrypoint is validated first. Then the base configuration (name
        and boolean feature toggles) is built, and the remaining options are
        applied through `configure(...)`. Finally, a legacy parameters object
        passed positionally or by keyword (if any) is verified and applied.

        Args:
            *args: Positional arguments; at most one legacy parameters object.
            name: Name of the step. Defaults to the class name.
            enable_cache: Whether caching is enabled for this step. If left as
                `None` and the entrypoint requires a step context, caching is
                disabled.
            enable_artifact_metadata: Whether artifact metadata is enabled for
                this step.
            enable_artifact_visualization: Whether artifact visualization is
                enabled for this step.
            enable_step_logs: Whether step logs are enabled for this step.
            experiment_tracker: Experiment tracker to use for this step.
            step_operator: Step operator to use for this step.
            parameters: Function parameters for this step.
            output_materializers: Output materializers for this step. A dict
                may only use output names of this step as keys; a single value
                (type or string) is used for every output.
            settings: Settings for this step.
            extra: Extra configuration for this step.
            on_failure: Failure callback: a function taking a single
                `BaseException` argument, or the source path of such a
                function (e.g. `module.my_function`).
            on_success: Success callback: a function taking no arguments, or
                the source path of such a function (e.g. `module.my_function`).
            **kwargs: Keyword arguments; at most one legacy parameters object,
                passed under its declared name.
        """
        self._upstream_steps: Set["BaseStep"] = set()
        self.entrypoint_definition = validate_entrypoint_function(
            self.entrypoint, reserved_arguments=["after", "id"]
        )

        step_name = name or self.__class__.__name__

        uses_context = self.entrypoint_definition.context is not None
        if enable_cache is None and uses_context:
            # A step context gives the step access to external resources that
            # may influence its result, so caching stays off unless the user
            # explicitly turned it on.
            enable_cache = False
            logger.debug(
                "Step '%s': Step context required and caching not "
                "explicitly enabled.",
                step_name,
            )

        # Log each toggle; anything other than an explicit False is reported
        # as enabled.
        toggle_messages = (
            ("Step '%s': Caching %s.", enable_cache),
            ("Step '%s': Artifact metadata %s.", enable_artifact_metadata),
            (
                "Step '%s': Artifact visualization %s.",
                enable_artifact_visualization,
            ),
            ("Step '%s': logs %s.", enable_step_logs),
        )
        for template, toggle in toggle_messages:
            logger.debug(
                template,
                step_name,
                "disabled" if toggle is False else "enabled",
            )

        self._configuration = PartialStepConfiguration(
            name=step_name,
            enable_cache=enable_cache,
            enable_artifact_metadata=enable_artifact_metadata,
            enable_artifact_visualization=enable_artifact_visualization,
            enable_step_logs=enable_step_logs,
        )
        self.configure(
            experiment_tracker=experiment_tracker,
            step_operator=step_operator,
            output_materializers=output_materializers,
            parameters=parameters,
            settings=settings,
            extra=extra,
            on_failure=on_failure,
            on_success=on_success,
        )
        self._verify_and_apply_init_params(*args, **kwargs)

    @abstractmethod
    def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
        """Abstract method for core step logic.

        Subclasses must implement this method. `__init__` inspects it with
        `validate_entrypoint_function` (reserving the argument names `after`
        and `id`), and `call_entrypoint` invokes it with validated keyword
        arguments.

        Args:
            *args: Positional arguments passed to the step.
            **kwargs: Keyword arguments passed to the step.

        Returns:
            The output of the step.
        """

    @classmethod
    def load_from_source(cls, source: Union[Source, str]) -> "BaseStep":
        """Load a step from its source.

        The loaded object may be either a step instance, which is returned
        as-is, or a `BaseStep` subclass, which is instantiated without
        arguments.

        Args:
            source: The path to the step source.

        Returns:
            The loaded step.

        Raises:
            ValueError: If the source resolves to neither a step instance nor
                a step class.
        """
        loaded = source_utils.load(source)

        if isinstance(loaded, BaseStep):
            return loaded

        is_step_class = isinstance(loaded, type) and issubclass(
            loaded, BaseStep
        )
        if not is_step_class:
            raise ValueError("Invalid step source.")

        return loaded()

    def resolve(self) -> Source:
        """Resolve the source of this step's class.

        Returns:
            The `Source` pointing at the class of this step.
        """
        step_class = self.__class__
        return source_utils.resolve(step_class)

    @property
    def upstream_steps(self) -> Set["BaseStep"]:
        """The upstream steps of this step.

        Steps are added to this set by calling `after(...)`. The set only
        contains the full set of upstream steps once its parent pipeline's
        `connect(...)` method was called.

        Returns:
            Set of upstream step instances (not step names).
        """
        return self._upstream_steps

    def after(self, step: "BaseStep") -> None:
        """Register a step that must finish before this one starts.

        After this call, this step only starts running once the given step
        has successfully finished executing.

        **Note**: This only has an effect when called inside the pipeline
        connect function decorated with `@pipeline`. Calls made anywhere else
        are ignored.

        Example:
        The pipeline below runs its steps sequentially in the order
        step_2 -> step_1 -> step_3

        ```python
        @pipeline
        def example_pipeline(step_1, step_2, step_3):
            step_1.after(step_2)
            step_3(step_1(), step_2())
        ```

        Args:
            step: The step that has to finish executing before this step is
                started.
        """
        # In-place set union; the stored set object is mutated, not replaced.
        self._upstream_steps |= {step}

    @property
    def source_object(self) -> Any:
        """The object whose source defines this step.

        Returns:
            The class of this step instance.
        """
        step_class = self.__class__
        return step_class

    @property
    def source_code(self) -> str:
        """The source code of this step.

        Returns:
            The source text of `source_object`, as returned by
            `inspect.getsource`.
        """
        defining_object = self.source_object
        return inspect.getsource(defining_object)

    @property
    def docstring(self) -> Optional[str]:
        """The docstring of this step.

        Returns:
            The `__doc__` attribute of this step, or `None` if it has none.
        """
        doc = self.__doc__
        return doc

    @property
    def caching_parameters(self) -> Dict[str, Any]:
        """Parameters that influence the cache key of this step.

        The result always contains the hashed source code of the step under
        `STEP_SOURCE_PARAMETER_NAME`. For every configured output with
        materializer sources, it also contains an MD5 digest over the hashed
        source code of each of those materializers (in order). That digest is
        stored under the key `<output_name>_materializer_source`.

        Returns:
            A dictionary containing the caching parameters.
        """

        def _materializer_digest(sources: Any) -> str:
            # Fold the source-code hash of each materializer, in order, into a
            # single digest.
            digest = hashlib.md5()
            for materializer_source in sources:
                materializer_class = source_utils.load(materializer_source)
                hashed_code = source_code_utils.get_hashed_source_code(
                    materializer_class
                )
                digest.update(hashed_code.encode())
            return digest.hexdigest()

        caching_params: Dict[str, Any] = {
            STEP_SOURCE_PARAMETER_NAME: (
                source_code_utils.get_hashed_source_code(self.source_object)
            )
        }

        for output_name, output in self.configuration.outputs.items():
            if not output.materializer_source:
                continue
            caching_params[
                f"{output_name}_materializer_source"
            ] = _materializer_digest(output.materializer_source)

        return caching_params

    def _verify_and_apply_init_params(self, *args: Any, **kwargs: Any) -> None:
        """Verifies the initialization args and kwargs of this step.

        Only steps whose entrypoint declares a legacy parameters object accept
        an init argument, and then at most one. It may be passed positionally
        or as a keyword argument using the declared parameter name, and it
        must be an instance of the declared annotation. A valid object is
        applied via `configure(parameters=...)`.

        Args:
            *args: The args passed to the init method of this step.
            **kwargs: The kwargs passed to the init method of this step.

        Raises:
            StepInterfaceError: If there are too many arguments or arguments
                with a wrong name/type.
        """
        legacy_params = self.entrypoint_definition.legacy_params

        maximum_arg_count = 1 if legacy_params else 0
        arg_count = len(args) + len(kwargs)
        if arg_count > maximum_arg_count:
            raise StepInterfaceError(
                f"Too many arguments ({arg_count}, expected: "
                f"{maximum_arg_count}) passed when creating a "
                f"'{self.name}' step."
            )

        if not legacy_params:
            return

        if args:
            config = args[0]
        elif kwargs:
            key, config = kwargs.popitem()
            if key != legacy_params.name:
                raise StepInterfaceError(
                    f"Unknown keyword argument '{key}' when creating a "
                    f"'{self.name}' step, only expected a single "
                    "argument with key "
                    f"'{legacy_params.name}'."
                )
        else:
            # No parameters object was passed. Values might still come from
            # defaults on the parameters class or from a configuration file,
            # so completeness is verified later, before the step runs.
            return

        if not isinstance(config, legacy_params.annotation):
            # Keep the class name tight against the closing backtick; the
            # previous message rendered as "`Foo ` instance.".
            raise StepInterfaceError(
                f"`{config}` object passed when creating a "
                f"'{self.name}' step is not a "
                f"`{legacy_params.annotation.__name__}` instance."
            )

        self.configure(parameters=config)

    def _parse_call_args(
        self, *args: Any, **kwargs: Any
    ) -> Tuple[
        Dict[str, "StepArtifact"],
        Dict[str, "ExternalArtifact"],
        Dict[str, Any],
    ]:
        """Parses the call args for the step entrypoint.

        Explicitly passed arguments are validated and sorted into step
        artifacts, external artifacts or plain parameters. Entrypoint default
        values are then validated and added as parameters, but only for
        arguments that were not passed and are not already configured on the
        step. This way, previously configured parameters are not overwritten
        by defaults.

        Args:
            *args: Entrypoint function arguments.
            **kwargs: Entrypoint function keyword arguments.

        Raises:
            StepInterfaceError: If invalid function arguments were passed.

        Returns:
            The artifacts, external artifacts and parameters for the step.
        """
        signature = get_step_entrypoint_signature(step=self)

        try:
            bound_args = signature.bind_partial(*args, **kwargs)
        except TypeError as e:
            raise StepInterfaceError(
                f"Wrong arguments when calling step '{self.name}': {e}"
            ) from e

        artifacts: Dict[str, "StepArtifact"] = {}
        external_artifacts: Dict[str, "ExternalArtifact"] = {}
        parameters: Dict[str, Any] = {}

        for key, value in bound_args.arguments.items():
            self.entrypoint_definition.validate_input(key=key, value=value)

            if isinstance(value, StepArtifact):
                artifacts[key] = value
                if key in self.configuration.parameters:
                    logger.warning(
                        "Got duplicate value for step input %s, using value "
                        "provided as artifact.",
                        key,
                    )
            elif isinstance(value, ExternalArtifact):
                external_artifacts[key] = value
                if not value._id:
                    # If the external artifact references a fixed artifact by
                    # ID, caching behaves as expected.
                    logger.warning(
                        "Using an external artifact as step input currently "
                        "invalidates caching for the step and all downstream "
                        "steps. Future releases will introduce hashing of "
                        "artifacts which will improve this behavior."
                    )
            else:
                parameters[key] = value

        # Arguments passed explicitly were already validated and classified
        # above; remember them so they are not validated a second time.
        explicitly_passed = set(bound_args.arguments)

        # Now fill in entrypoint defaults for the remaining arguments. Doing
        # this in the loop above would overwrite parameters that were
        # previously configured on the step instance with default values.
        bound_args.apply_defaults()
        for key, value in bound_args.arguments.items():
            if key in explicitly_passed:
                continue
            self.entrypoint_definition.validate_input(key=key, value=value)
            if key not in self.configuration.parameters:
                parameters[key] = value

        return artifacts, external_artifacts, parameters

    def __call__(
        self,
        *args: Any,
        id: Optional[str] = None,
        after: Union[str, Sequence[str], None] = None,
        **kwargs: Any,
    ) -> Any:
        """Handle a call of the step.

        Without an active pipeline, the entrypoint function is simply called.
        Inside an active pipeline context, an invocation of this step is
        registered with the pipeline instead. One `StepArtifact` is returned
        per entrypoint output.

        Args:
            *args: Entrypoint function arguments.
            id: Invocation ID to use.
            after: Upstream steps for the invocation.
            **kwargs: Entrypoint function keyword arguments.

        Returns:
            The entrypoint's return value when no pipeline is active.
            Otherwise, a single `StepArtifact` if the step has exactly one
            output, or a list of them.
        """
        from zenml.new.pipelines.pipeline import Pipeline

        if not Pipeline.ACTIVE_PIPELINE:
            # Called outside of a pipeline definition: just run the code.
            return self.call_entrypoint(*args, **kwargs)

        input_artifacts, external_artifacts, parameters = (
            self._parse_call_args(*args, **kwargs)
        )

        # Steps producing our input artifacts are implicit upstream steps;
        # explicit ones come from `after`.
        upstream_steps = {
            artifact.invocation_id for artifact in input_artifacts.values()
        }
        if isinstance(after, str):
            upstream_steps.add(after)
        elif isinstance(after, Sequence):
            upstream_steps.update(after)

        invocation_id = Pipeline.ACTIVE_PIPELINE.add_step_invocation(
            step=self,
            input_artifacts=input_artifacts,
            external_artifacts=external_artifacts,
            parameters=parameters,
            upstream_steps=upstream_steps,
            custom_id=id,
            allow_id_suffix=not id,
        )

        outputs = [
            StepArtifact(
                invocation_id=invocation_id,
                output_name=output_name,
                annotation=output_annotation,
                pipeline=Pipeline.ACTIVE_PIPELINE,
            )
            for output_name, output_annotation in (
                self.entrypoint_definition.outputs.items()
            )
        ]

        return outputs[0] if len(outputs) == 1 else outputs

    def call_entrypoint(self, *args: Any, **kwargs: Any) -> Any:
        """Validate the arguments and call the entrypoint function.

        The arguments are validated against the entrypoint signature with
        `pydantic_utils.validate_function_args`, allowing arbitrary types and
        smart unions. The entrypoint is then called with the validated
        keyword arguments.

        Args:
            *args: Entrypoint function arguments.
            **kwargs: Entrypoint function keyword arguments.

        Returns:
            The return value of the entrypoint function.

        Raises:
            StepInterfaceError: If the arguments to the entrypoint function are
                invalid.
        """
        validation_config = {
            "arbitrary_types_allowed": True,
            "smart_union": True,
        }
        try:
            validated_args = pydantic_utils.validate_function_args(
                self.entrypoint, validation_config, *args, **kwargs
            )
        except ValidationError as e:
            raise StepInterfaceError("Invalid entrypoint arguments.") from e

        return self.entrypoint(**validated_args)

    @property
    def name(self) -> str:
        """The name of the step.

        Returns:
            The `name` stored in the step configuration.
        """
        config = self.configuration
        return config.name

    @property
    def enable_cache(self) -> Optional[bool]:
        """Whether caching is enabled for the step.

        Returns:
            The `enable_cache` value of the step configuration, which may be
            `None` if it was never set.
        """
        config = self.configuration
        return config.enable_cache

    @property
    def configuration(self) -> PartialStepConfiguration:
        """The current (partial) configuration of the step.

        Returns:
            The step configuration object held by this instance.
        """
        current = self._configuration
        return current

    def configure(
        self: T,
        name: Optional[str] = None,
        enable_cache: Optional[bool] = None,
        enable_artifact_metadata: Optional[bool] = None,
        enable_artifact_visualization: Optional[bool] = None,
        enable_step_logs: Optional[bool] = None,
        experiment_tracker: Optional[str] = None,
        step_operator: Optional[str] = None,
        parameters: Optional["ParametersOrDict"] = None,
        output_materializers: Optional[
            "OutputMaterializersSpecification"
        ] = None,
        settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
        extra: Optional[Dict[str, Any]] = None,
        on_failure: Optional["HookSpecification"] = None,
        on_success: Optional["HookSpecification"] = None,
        merge: bool = True,
    ) -> T:
        """Configures the step.

        Configuration merging example:
        * `merge==True`:
            step.configure(extra={"key1": 1})
            step.configure(extra={"key2": 2}, merge=True)
            step.configuration.extra # {"key1": 1, "key2": 2}
        * `merge==False`:
            step.configure(extra={"key1": 1})
            step.configure(extra={"key2": 2}, merge=False)
            step.configuration.extra # {"key2": 2}

        Args:
            name: DEPRECATED: The name of the step. Passing a value only logs
                a deprecation warning; it is not applied to the configuration.
            enable_cache: If caching should be enabled for this step.
            enable_artifact_metadata: If artifact metadata should be enabled
                for this step.
            enable_artifact_visualization: If artifact visualization should be
                enabled for this step.
            enable_step_logs: If step logs should be enabled for this step.
            experiment_tracker: The experiment tracker to use for this step.
            step_operator: The step operator to use for this step.
            parameters: Function parameters for this step. A `BaseParameters`
                instance is converted to a dict before being applied.
            output_materializers: Output materializers for this step. If
                given as a dict, the keys must be a subset of the output names
                of this step. If a single value (type or string) is given, the
                materializer will be used for all outputs. Each materializer
                may be a class, a `Source`, an import-path string, or a
                sequence of those.
            settings: settings for this step.
            extra: Extra configurations for this step.
            on_failure: Callback function in event of failure of the step. Can
                be a function with a single argument of type `BaseException`, or
                a source path to such a function (e.g. `module.my_function`).
            on_success: Callback function in event of success of the step. Can
                be a function with no arguments, or a source path to such a
                function (e.g. `module.my_function`).
            merge: If `True`, will merge the given dictionary configurations
                like `parameters` and `settings` with existing
                configurations. If `False` the given configurations will
                overwrite all existing ones. See the general description of this
                method for an example.

        Returns:
            The step instance that this method was called on.
        """
        from zenml.hooks.hook_validators import resolve_and_validate_hook

        if name:
            # The name is not applied; it is only reported as deprecated.
            logger.warning("Configuring the name of a step is deprecated.")

        def _resolve_if_necessary(
            value: Union[str, Source, Type[Any]]
        ) -> Source:
            # Normalize one materializer spec into a `Source`: strings are
            # treated as import paths, `Source` objects pass through, and
            # anything else (e.g. a class) is resolved via `source_utils`.
            if isinstance(value, str):
                return Source.from_import_path(value)
            elif isinstance(value, Source):
                return value
            else:
                return source_utils.resolve(value)

        def _convert_to_tuple(value: Any) -> Tuple[Source, ...]:
            # A single spec (strings count as single values even though they
            # are sequences) becomes a 1-tuple; a sequence of specs is
            # resolved element-wise.
            if isinstance(value, str) or not isinstance(value, Sequence):
                return (_resolve_if_necessary(value),)
            else:
                return tuple(_resolve_if_necessary(v) for v in value)

        outputs: Dict[str, Dict[str, Tuple[Source, ...]]] = defaultdict(dict)
        allowed_output_names = set(self.entrypoint_definition.outputs)

        if output_materializers:
            if not isinstance(output_materializers, Mapping):
                # A non-mapping specification applies to every output of the
                # step's entrypoint.
                sources = _convert_to_tuple(output_materializers)
                output_materializers = {
                    output_name: sources
                    for output_name in allowed_output_names
                }

            # Output names are not checked here; `_apply_configuration` later
            # rejects unknown outputs via `_validate_outputs`.
            for output_name, materializer in output_materializers.items():
                sources = _convert_to_tuple(materializer)
                outputs[output_name]["materializer_source"] = sources

        failure_hook_source = None
        if on_failure:
            # Resolve (and validate) the on_failure hook for this step.
            failure_hook_source = resolve_and_validate_hook(on_failure)

        success_hook_source = None
        if on_success:
            # Resolve (and validate) the on_success hook for this step.
            success_hook_source = resolve_and_validate_hook(on_success)

        if isinstance(parameters, BaseParameters):
            # Legacy parameter objects are stored as plain dicts.
            parameters = parameters.dict()

        # Options left as `None` are presumably dropped here so they do not
        # overwrite existing configuration values (verify against
        # `dict_utils.remove_none_values`). An empty `outputs` mapping is
        # mapped to `None` for the same reason.
        values = dict_utils.remove_none_values(
            {
                "enable_cache": enable_cache,
                "enable_artifact_metadata": enable_artifact_metadata,
                "enable_artifact_visualization": enable_artifact_visualization,
                "enable_step_logs": enable_step_logs,
                "experiment_tracker": experiment_tracker,
                "step_operator": step_operator,
                "parameters": parameters,
                "settings": settings,
                "outputs": outputs or None,
                "extra": extra,
                "failure_hook_source": failure_hook_source,
                "success_hook_source": success_hook_source,
            }
        )
        config = StepConfigurationUpdate(**values)
        self._apply_configuration(config, merge=merge)
        return self

    def with_options(
        self,
        enable_cache: Optional[bool] = None,
        enable_artifact_metadata: Optional[bool] = None,
        enable_artifact_visualization: Optional[bool] = None,
        enable_step_logs: Optional[bool] = None,
        experiment_tracker: Optional[str] = None,
        step_operator: Optional[str] = None,
        parameters: Optional["ParametersOrDict"] = None,
        output_materializers: Optional[
            "OutputMaterializersSpecification"
        ] = None,
        settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
        extra: Optional[Dict[str, Any]] = None,
        on_failure: Optional["HookSpecification"] = None,
        on_success: Optional["HookSpecification"] = None,
        merge: bool = True,
    ) -> "BaseStep":
        """Return a configured copy of this step, leaving this step untouched.

        Args:
            enable_cache: If caching should be enabled for this step.
            enable_artifact_metadata: If artifact metadata should be enabled
                for this step.
            enable_artifact_visualization: If artifact visualization should be
                enabled for this step.
            enable_step_logs: If step logs should be enabled for this step.
            experiment_tracker: The experiment tracker to use for this step.
            step_operator: The step operator to use for this step.
            parameters: Function parameters for this step.
            output_materializers: Output materializers for this step. If
                given as a dict, the keys must be a subset of the output names
                of this step. If a single value (type or string) is given, the
                materializer will be used for all outputs.
            settings: Settings for this step.
            extra: Extra configurations for this step.
            on_failure: Callback function in event of failure of the step. Can
                be a function with a single argument of type `BaseException`, or
                a source path to such a function (e.g. `module.my_function`).
            on_success: Callback function in event of success of the step. Can
                be a function with no arguments, or a source path to such a
                function (e.g. `module.my_function`).
            merge: If `True`, the given dictionary configurations like
                `parameters` and `settings` are merged with the existing ones.
                If `False`, they replace all existing ones. See
                `configure(...)` for an example.

        Returns:
            The copied step instance.
        """
        # `configure` returns the instance it was called on, i.e. the copy.
        return self.copy().configure(
            enable_cache=enable_cache,
            enable_artifact_metadata=enable_artifact_metadata,
            enable_artifact_visualization=enable_artifact_visualization,
            enable_step_logs=enable_step_logs,
            experiment_tracker=experiment_tracker,
            step_operator=step_operator,
            parameters=parameters,
            output_materializers=output_materializers,
            settings=settings,
            extra=extra,
            on_failure=on_failure,
            on_success=on_success,
            merge=merge,
        )

    def copy(self) -> "BaseStep":
        """Create an independent deep copy of this step.

        Returns:
            The step copy, produced by `copy.deepcopy`.
        """
        duplicate = copy.deepcopy(self)
        return duplicate

    def _apply_configuration(
        self,
        config: StepConfigurationUpdate,
        merge: bool = True,
    ) -> None:
        """Validate a configuration update and apply it to this step.

        Args:
            config: The configuration update.
            merge: Whether the update is merged recursively into the existing
                configuration (`True`) or replaces it (`False`). See
                `BaseStep.configure(...)` for a detailed explanation.
        """
        self._validate_configuration(config)

        updated_configuration = pydantic_utils.update_model(
            self._configuration, update=config, recursive=merge
        )
        self._configuration = updated_configuration

        logger.debug("Updated step configuration:")
        logger.debug(self._configuration)

    def _validate_configuration(self, config: StepConfigurationUpdate) -> None:
        """Run all validations for a configuration update.

        Setting keys are checked first, then function parameters, then the
        output configuration.

        Args:
            config: The configuration update to validate.
        """
        setting_keys = list(config.settings)
        settings_utils.validate_setting_keys(setting_keys)
        self._validate_function_parameters(parameters=config.parameters)
        self._validate_outputs(outputs=config.outputs)

    def _validate_function_parameters(
        self, parameters: Dict[str, Any]
    ) -> None:
        """Validates step function parameters.

        Parameters whose key is an input of the step entrypoint are validated
        against that input. Any other key is only accepted when the step
        declares a legacy parameters class.

        Args:
            parameters: The parameters to validate.

        Raises:
            StepInterfaceError: If a parameter key is not an entrypoint input
                and the step has no legacy parameters class.
        """
        if not parameters:
            return

        entrypoint_definition = self.entrypoint_definition
        for key, value in parameters.items():
            if key in entrypoint_definition.inputs:
                entrypoint_definition.validate_input(key=key, value=value)
            elif not entrypoint_definition.legacy_params:
                # Name the offending key and step so the user can find it.
                raise StepInterfaceError(
                    f"Can't set parameter '{key}' for step '{self.name}': it "
                    "is not an input of the step entrypoint and the step has "
                    "no parameters class."
                )

    def _validate_outputs(
        self, outputs: Mapping[str, PartialArtifactConfiguration]
    ) -> None:
        """Check that configured outputs exist and use valid materializers.

        Args:
            outputs: The configured step outputs.

        Raises:
            StepInterfaceError: If an output is configured for a name that is
                not an output of the entrypoint, or if a materializer source
                does not resolve to a `BaseMaterializer` subclass.
        """
        permitted_names = set(self.entrypoint_definition.outputs)

        for output_name, output in outputs.items():
            if output_name not in permitted_names:
                raise StepInterfaceError(
                    f"Got unexpected materializers for non-existent "
                    f"output '{output_name}' in step '{self.name}'. "
                    f"Only materializers for the outputs "
                    f"{permitted_names} of this step can"
                    f" be registered."
                )

            # An unset or empty materializer source means nothing to check.
            for source in output.materializer_source or ():
                is_materializer = source_utils.validate_source_class(
                    source, expected_class=BaseMaterializer
                )
                if not is_materializer:
                    raise StepInterfaceError(
                        f"Materializer source `{source}` "
                        f"for output '{output_name}' of step '{self.name}' "
                        "does not resolve to a `BaseMaterializer` subclass."
                    )

    def _validate_inputs(
        self,
        input_artifacts: Dict[str, "StepArtifact"],
        external_artifacts: Dict[str, UUID],
    ) -> None:
        """Ensure every entrypoint input has a value.

        An input counts as provided if it is an input artifact, a configured
        parameter, or an external artifact.

        Args:
            input_artifacts: The input artifacts.
            external_artifacts: The external input artifacts.

        Raises:
            StepInterfaceError: If an entrypoint input is missing.
        """
        providers = (
            input_artifacts,
            self.configuration.parameters,
            external_artifacts,
        )
        for key in self.entrypoint_definition.inputs.keys():
            if not any(key in provider for provider in providers):
                raise StepInterfaceError(f"Missing entrypoint input {key}.")

    def _finalize_configuration(
        self,
        input_artifacts: Dict[str, "StepArtifact"],
        external_artifacts: Dict[str, UUID],
    ) -> StepConfiguration:
        """Finalizes the configuration after the step was called.

        Once the step was called, we know the outputs of previous steps
        and that no additional user configurations will be made. That means
        we can now collect the remaining artifact and materializer types
        as well as check for the completeness of the step function parameters.

        For every entrypoint output without an explicitly configured
        materializer, a materializer source is looked up from the
        materializer registry based on the output annotation (one per member
        type for `Union` annotations, the registry default for `Any`).

        Args:
            input_artifacts: The input artifacts of this step.
            external_artifacts: The external artifacts of this step.

        Returns:
            The finalized step configuration.
        """
        # Imported once here instead of on every iteration of the loop below.
        from pydantic.typing import (
            get_origin,
            is_none_type,
            is_union,
        )

        from zenml.steps.utils import get_args

        outputs: Dict[
            str, Dict[str, Union[Source, Tuple[Source, ...]]]
        ] = defaultdict(dict)

        for (
            output_name,
            output_annotation,
        ) in self.entrypoint_definition.outputs.items():
            output = self._configuration.outputs.get(
                output_name, PartialArtifactConfiguration()
            )
            if output.materializer_source:
                # An explicitly configured materializer is kept as is.
                continue

            if output_annotation is Any:
                # No concrete type to look up: fall back to the registry's
                # default materializer.
                outputs[output_name]["materializer_source"] = ()
                outputs[output_name][
                    "default_materializer_source"
                ] = source_utils.resolve(
                    materializer_registry.get_default_materializer()
                )
                continue

            if is_union(get_origin(output_annotation) or output_annotation):
                # Each member of a Union gets its own materializer; `None`
                # members are normalized to `type(None)` for the lookup.
                output_types = tuple(
                    type(None) if is_none_type(output_type) else output_type
                    for output_type in get_args(output_annotation)
                )
            else:
                output_types = (output_annotation,)

            outputs[output_name]["materializer_source"] = tuple(
                source_utils.resolve(materializer_registry[output_type])
                for output_type in output_types
            )

        parameters = self._finalize_parameters()
        self.configure(parameters=parameters, merge=False)
        self._validate_inputs(
            input_artifacts=input_artifacts,
            external_artifacts=external_artifacts,
        )

        values = dict_utils.remove_none_values({"outputs": outputs or None})
        config = StepConfigurationUpdate(**values)
        self._apply_configuration(config)

        self._configuration = self._configuration.copy(
            update={
                "caching_parameters": self.caching_parameters,
                "external_input_artifacts": external_artifacts,
            }
        )

        complete_configuration = StepConfiguration.parse_obj(
            self._configuration
        )
        return complete_configuration

    def _finalize_parameters(self) -> Dict[str, Any]:
        """Collects the final parameter values used to run this step.

        Configured parameters that do not match an entrypoint input are
        dropped. For inputs annotated with a pydantic model, the configured
        value is used to build the model and the model's `dict()` is stored.
        If the step uses a legacy parameter class, its finalized values are
        stored under the name of that entrypoint argument.

        Returns:
            All parameter values for running this step.
        """
        entrypoint_inputs = self.entrypoint_definition.inputs
        finalized: Dict[str, Any] = {}

        for param_name, param_value in self.configuration.parameters.items():
            if param_name not in entrypoint_inputs:
                continue

            param_type = resolve_type_annotation(
                entrypoint_inputs[param_name].annotation
            )
            is_model_type = inspect.isclass(param_type) and issubclass(
                param_type, BaseModel
            )
            if is_model_type:
                # Instantiate the model now so that we know all values needed
                # to build it again later are present.
                finalized[param_name] = param_type(**param_value).dict()
            else:
                finalized[param_name] = param_value

        legacy_params = self.entrypoint_definition.legacy_params
        if legacy_params:
            finalized[legacy_params.name] = self._finalize_legacy_parameters()

        return finalized

    def _finalize_legacy_parameters(self) -> Dict[str, Any]:
        """Verifies and prepares the config parameters for running this step.

        When the step requires config parameters, this method:
            - checks if config parameters were set via a config object or file
            - tries to set missing config parameters from the default values
              (or default factories) of the config class

        Returns:
            Values for the previously unconfigured function parameters, or an
            empty dict if the step has no legacy parameter class.

        Raises:
            MissingStepParameterError: If no value could be found for one or
                more config parameters.
            StepInterfaceError: If the parameter class validation failed.
        """
        legacy_params = self.entrypoint_definition.legacy_params
        if not legacy_params:
            return {}

        logger.warning(
            "The `BaseParameters` class to define step parameters is "
            "deprecated. Check out our docs "
            "https://docs.zenml.io/user-guide/advanced-guide/configure-steps-pipelines "
            "for information on how to parameterize your steps. As a quick "
            "fix to get rid of this warning, make sure your parameter class "
            "inherits from `pydantic.BaseModel` instead of the "
            "`BaseParameters` class."
        )

        configured_parameters = self.configuration.parameters
        params_class = legacy_params.annotation

        # parameters for the `BaseParameters` class specified in the "new" way
        # by specifying a dict of parameters for the corresponding key
        params_defined_in_new_way = (
            configured_parameters.get(legacy_params.name) or {}
        )

        values: Dict[str, Any] = {}
        missing_keys = []
        for name, field in params_class.__fields__.items():
            if name in configured_parameters:
                # a value for this parameter has been set already
                values[name] = configured_parameters[name]
            elif name in params_defined_in_new_way:
                # a value for this parameter has been set in the "new" way
                # already
                values[name] = params_defined_in_new_way[name]
            elif field.required:
                # this field has no default value set and therefore needs
                # to be passed via an initialized config object
                missing_keys.append(name)
            elif field.default_factory is not None:
                # pydantic stores `default=None` for fields declared with a
                # `default_factory`; call the factory to get the real default.
                values[name] = field.default_factory()
            else:
                # use default value from the pydantic config class
                values[name] = field.default

        if missing_keys:
            raise MissingStepParameterError(
                self.name,
                missing_keys,
                params_class,
            )

        if params_class.__config__.extra == Extra.allow:
            # Add all parameters for the config class for backwards
            # compatibility if the config class allows extra attributes
            values.update(configured_parameters)

        try:
            params_class(**values)
        except ValidationError as e:
            raise StepInterfaceError(
                "Failed to validate function parameters."
            ) from e

        return values
caching_parameters: Dict[str, Any] property readonly

Caching parameters for this step.

Returns:

Type Description
Dict[str, Any]

A dictionary containing the caching parameters

configuration: PartialStepConfiguration property readonly

The configuration of the step.

Returns:

Type Description
PartialStepConfiguration

The configuration of the step.

docstring: Optional[str] property readonly

The docstring of this step.

Returns:

Type Description
Optional[str]

The docstring of this step.

enable_cache: Optional[bool] property readonly

If caching is enabled for the step.

Returns:

Type Description
Optional[bool]

If caching is enabled for the step.

name: str property readonly

The name of the step.

Returns:

Type Description
str

The name of the step.

source_code: str property readonly

The source code of this step.

Returns:

Type Description
str

The source code of this step.

source_object: Any property readonly

The source object of this step.

Returns:

Type Description
Any

The source object of this step.

upstream_steps: Set[BaseStep] property readonly

The upstream steps of this step.

This property will only contain the full set of upstream steps once its parent pipeline's connect(...) method has been called.

Returns:

Type Description
Set[BaseStep]

Set of upstream steps.

__call__(self, *args, id=None, after=None, **kwargs) special

Handle a call of the step.

This method does one of two things:

- If there is an active pipeline context, it adds an invocation of the step instance to the pipeline.
- If no pipeline is active, it calls the step entrypoint function.

Parameters:

Name Type Description Default
*args Any

Entrypoint function arguments.

()
id Optional[str]

Invocation ID to use.

None
after Union[str, Sequence[str]]

Upstream steps for the invocation.

None
**kwargs Any

Entrypoint function keyword arguments.

{}

Returns:

Type Description
Any

The outputs of the entrypoint function call.

Source code in zenml/steps/base_step.py
def __call__(
    self,
    *args: Any,
    id: Optional[str] = None,
    after: Union[str, Sequence[str], None] = None,
    **kwargs: Any,
) -> Any:
    """Handle a call of the step.

    The behavior depends on whether a pipeline is currently active:
    * With an active pipeline, an invocation of this step instance is
      registered with that pipeline and `StepArtifact` references to the
      step outputs are returned.
    * Without an active pipeline, the step entrypoint function is run
      directly and its result returned.

    Args:
        *args: Entrypoint function arguments.
        id: Invocation ID to use.
        after: Upstream steps for the invocation.
        **kwargs: Entrypoint function keyword arguments.

    Returns:
        The outputs of the entrypoint function call.
    """
    from zenml.new.pipelines.pipeline import Pipeline

    if not Pipeline.ACTIVE_PIPELINE:
        # Outside of a pipeline definition the step behaves like a plain
        # function call.
        return self.call_entrypoint(*args, **kwargs)

    input_artifacts, external_artifacts, parameters = self._parse_call_args(
        *args, **kwargs
    )

    # Steps producing our input artifacts, plus any explicitly requested
    # ones, must run before this invocation.
    upstream_steps = {
        artifact.invocation_id for artifact in input_artifacts.values()
    }
    if isinstance(after, str):
        upstream_steps.add(after)
    elif isinstance(after, Sequence):
        upstream_steps.update(after)

    invocation_id = Pipeline.ACTIVE_PIPELINE.add_step_invocation(
        step=self,
        input_artifacts=input_artifacts,
        external_artifacts=external_artifacts,
        parameters=parameters,
        upstream_steps=upstream_steps,
        custom_id=id,
        allow_id_suffix=not id,
    )

    step_outputs = [
        StepArtifact(
            invocation_id=invocation_id,
            output_name=output_name,
            annotation=output_annotation,
            pipeline=Pipeline.ACTIVE_PIPELINE,
        )
        for output_name, output_annotation in (
            self.entrypoint_definition.outputs.items()
        )
    ]
    # A single output is returned on its own rather than wrapped in a list.
    return step_outputs[0] if len(step_outputs) == 1 else step_outputs
__init__(self, *args, name=None, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, enable_step_logs=None, experiment_tracker=None, step_operator=None, parameters=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None, **kwargs) special

Initializes a step.

Parameters:

Name Type Description Default
*args Any

Positional arguments passed to the step.

()
name Optional[str]

The name of the step.

None
enable_cache Optional[bool]

If caching should be enabled for this step.

None
enable_artifact_metadata Optional[bool]

If artifact metadata should be enabled for this step.

None
enable_artifact_visualization Optional[bool]

If artifact visualization should be enabled for this step.

None
enable_step_logs Optional[bool]

Enable step logs for this step.

None
experiment_tracker Optional[str]

The experiment tracker to use for this step.

None
step_operator Optional[str]

The step operator to use for this step.

None
parameters Optional[ParametersOrDict]

Function parameters for this step

None
output_materializers Optional[OutputMaterializersSpecification]

Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs.

None
settings Optional[Mapping[str, SettingsOrDict]]

settings for this step.

None
extra Optional[Dict[str, Any]]

Extra configurations for this step.

None
on_failure Optional[HookSpecification]

Callback function in event of failure of the step. Can be a function with a single argument of type BaseException, or a source path to such a function (e.g. module.my_function).

None
on_success Optional[HookSpecification]

Callback function in event of success of the step. Can be a function with no arguments, or a source path to such a function (e.g. module.my_function).

None
**kwargs Any

Keyword arguments passed to the step.

{}
Source code in zenml/steps/base_step.py
def __init__(
    self,
    *args: Any,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    enable_step_logs: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional["ParametersOrDict"] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    **kwargs: Any,
) -> None:
    """Creates the step and applies its initial configuration.

    If `enable_cache` is not given and the entrypoint requests a step
    context, caching is disabled for this step.

    Args:
        *args: Positional arguments passed to the step.
        name: The name of the step. Defaults to the class name.
        enable_cache: If caching should be enabled for this step.
        enable_artifact_metadata: If artifact metadata should be enabled
            for this step.
        enable_artifact_visualization: If artifact visualization should be
            enabled for this step.
        enable_step_logs: Enable step logs for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        parameters: Function parameters for this step
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can
            be a function with a single argument of type `BaseException`, or
            a source path to such a function (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can
            be a function with no arguments, or a source path to such a
            function (e.g. `module.my_function`).
        **kwargs: Keyword arguments passed to the step.
    """
    self._upstream_steps: Set["BaseStep"] = set()
    self.entrypoint_definition = validate_entrypoint_function(
        self.entrypoint, reserved_arguments=["after", "id"]
    )

    step_name = name or self.__class__.__name__

    uses_step_context = self.entrypoint_definition.context is not None
    if uses_step_context and enable_cache is None:
        # Using the StepContext inside a step provides access to external
        # resources which might influence the step execution, so caching
        # stays off unless it was explicitly enabled.
        enable_cache = False
        logger.debug(
            "Step '%s': Step context required and caching not "
            "explicitly enabled.",
            step_name,
        )

    def _state(flag: Optional[bool]) -> str:
        # Only an explicit `False` counts as disabled; `None` means default.
        return "disabled" if flag is False else "enabled"

    logger.debug("Step '%s': Caching %s.", step_name, _state(enable_cache))
    logger.debug(
        "Step '%s': Artifact metadata %s.",
        step_name,
        _state(enable_artifact_metadata),
    )
    logger.debug(
        "Step '%s': Artifact visualization %s.",
        step_name,
        _state(enable_artifact_visualization),
    )
    logger.debug("Step '%s': logs %s.", step_name, _state(enable_step_logs))

    self._configuration = PartialStepConfiguration(
        name=step_name,
        enable_cache=enable_cache,
        enable_artifact_metadata=enable_artifact_metadata,
        enable_artifact_visualization=enable_artifact_visualization,
        enable_step_logs=enable_step_logs,
    )
    self.configure(
        experiment_tracker=experiment_tracker,
        step_operator=step_operator,
        output_materializers=output_materializers,
        parameters=parameters,
        settings=settings,
        extra=extra,
        on_failure=on_failure,
        on_success=on_success,
    )
    self._verify_and_apply_init_params(*args, **kwargs)
after(self, step)

Adds an upstream step to this step.

Calling this method makes sure this step only starts running once the given step has successfully finished executing.

Note: This can only be called inside the pipeline connect function which is decorated with the @pipeline decorator. Any calls outside this function will be ignored.

Examples:

The following pipeline will run its steps sequentially in the following order: step_2 -> step_1 -> step_3

@pipeline
def example_pipeline(step_1, step_2, step_3):
    step_1.after(step_2)
    step_3(step_1(), step_2())

Parameters:

Name Type Description Default
step BaseStep

A step which should finish executing before this step is started.

required
Source code in zenml/steps/base_step.py
def after(self, step: "BaseStep") -> None:
    """Registers `step` as an upstream dependency of this step.

    After this call, this step will only start running once the given step
    has finished executing successfully.

    **Note**: This can only be called inside the pipeline connect function
    which is decorated with the `@pipeline` decorator. Any calls outside
    this function will be ignored.

    Example:
    The steps of the following pipeline run sequentially in the order
    step_2 -> step_1 -> step_3

    ```python
    @pipeline
    def example_pipeline(step_1, step_2, step_3):
        step_1.after(step_2)
        step_3(step_1(), step_2())
    ```

    Args:
        step: A step which should finish executing before this step is
            started.
    """
    upstream = self._upstream_steps
    upstream.update((step,))
call_entrypoint(self, *args, **kwargs)

Calls the entrypoint function of the step.

Parameters:

Name Type Description Default
*args Any

Entrypoint function arguments.

()
**kwargs Any

Entrypoint function keyword arguments.

{}

Returns:

Type Description
Any

The return value of the entrypoint function.

Exceptions:

Type Description
StepInterfaceError

If the arguments to the entrypoint function are invalid.

Source code in zenml/steps/base_step.py
def call_entrypoint(self, *args: Any, **kwargs: Any) -> Any:
    """Validates the given arguments and runs the step entrypoint with them.

    Args:
        *args: Entrypoint function arguments.
        **kwargs: Entrypoint function keyword arguments.

    Returns:
        Whatever the entrypoint function returns.

    Raises:
        StepInterfaceError: If the arguments to the entrypoint function are
            invalid.
    """
    validation_config = {"arbitrary_types_allowed": True, "smart_union": True}
    try:
        bound_arguments = pydantic_utils.validate_function_args(
            self.entrypoint, validation_config, *args, **kwargs
        )
    except ValidationError as error:
        raise StepInterfaceError("Invalid entrypoint arguments.") from error

    return self.entrypoint(**bound_arguments)
configure(self, name=None, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, enable_step_logs=None, experiment_tracker=None, step_operator=None, parameters=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None, merge=True)

Configures the step.

Configuration merging example:

- `merge==True`: calling `step.configure(extra={"key1": 1})` followed by `step.configure(extra={"key2": 2}, merge=True)` results in `step.configuration.extra == {"key1": 1, "key2": 2}`.
- `merge==False`: calling `step.configure(extra={"key1": 1})` followed by `step.configure(extra={"key2": 2}, merge=False)` results in `step.configuration.extra == {"key2": 2}`.

Parameters:

Name Type Description Default
name Optional[str]

DEPRECATED: The name of the step.

None
enable_cache Optional[bool]

If caching should be enabled for this step.

None
enable_artifact_metadata Optional[bool]

If artifact metadata should be enabled for this step.

None
enable_artifact_visualization Optional[bool]

If artifact visualization should be enabled for this step.

None
enable_step_logs Optional[bool]

If step logs should be enabled for this step.

None
experiment_tracker Optional[str]

The experiment tracker to use for this step.

None
step_operator Optional[str]

The step operator to use for this step.

None
parameters Optional[ParametersOrDict]

Function parameters for this step

None
output_materializers Optional[OutputMaterializersSpecification]

Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs.

None
settings Optional[Mapping[str, SettingsOrDict]]

settings for this step.

None
extra Optional[Dict[str, Any]]

Extra configurations for this step.

None
on_failure Optional[HookSpecification]

Callback function in event of failure of the step. Can be a function with a single argument of type BaseException, or a source path to such a function (e.g. module.my_function).

None
on_success Optional[HookSpecification]

Callback function in event of success of the step. Can be a function with no arguments, or a source path to such a function (e.g. module.my_function).

None
merge bool

If True, will merge the given dictionary configurations like parameters and settings with existing configurations. If False the given configurations will overwrite all existing ones. See the general description of this method for an example.

True

Returns:

Type Description
~T

The step instance that this method was called on.

Source code in zenml/steps/base_step.py
def configure(
    self: T,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    enable_step_logs: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional["ParametersOrDict"] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    merge: bool = True,
) -> T:
    """Applies configuration options to this step.

    Arguments left as `None` are ignored. Dictionary-valued options are
    merged into or replace the existing configuration depending on `merge`:
    * `merge==True`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=True)
        step.configuration.extra # {"key1": 1, "key2": 2}
    * `merge==False`:
        step.configure(extra={"key1": 1})
        step.configure(extra={"key2": 2}, merge=False)
        step.configuration.extra # {"key2": 2}

    Args:
        name: DEPRECATED: The name of the step. Only triggers a warning.
        enable_cache: If caching should be enabled for this step.
        enable_artifact_metadata: If artifact metadata should be enabled
            for this step.
        enable_artifact_visualization: If artifact visualization should be
            enabled for this step.
        enable_step_logs: If step logs should be enabled for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        parameters: Function parameters for this step
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can
            be a function with a single argument of type `BaseException`, or
            a source path to such a function (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can
            be a function with no arguments, or a source path to such a
            function (e.g. `module.my_function`).
        merge: If `True`, will merge the given dictionary configurations
            like `parameters` and `settings` with existing
            configurations. If `False` the given configurations will
            overwrite all existing ones. See the general description of this
            method for an example.

    Returns:
        The step instance that this method was called on.
    """
    from zenml.hooks.hook_validators import resolve_and_validate_hook

    if name:
        logger.warning("Configuring the name of a step is deprecated.")

    def _to_source(value: Union[str, Source, Type[Any]]) -> Source:
        # Import paths are parsed, Source objects kept, anything else resolved.
        if isinstance(value, str):
            return Source.from_import_path(value)
        if isinstance(value, Source):
            return value
        return source_utils.resolve(value)

    def _to_source_tuple(value: Any) -> Tuple[Source, ...]:
        # Strings are sequences too, but denote a single materializer.
        is_single = isinstance(value, str) or not isinstance(value, Sequence)
        if is_single:
            return (_to_source(value),)
        return tuple(map(_to_source, value))

    outputs: Dict[str, Dict[str, Tuple[Source, ...]]] = defaultdict(dict)
    output_names = set(self.entrypoint_definition.outputs)

    if output_materializers:
        materializers_by_output = output_materializers
        if not isinstance(materializers_by_output, Mapping):
            # A single specification applies to every output of the step.
            shared_sources = _to_source_tuple(materializers_by_output)
            materializers_by_output = dict.fromkeys(
                output_names, shared_sources
            )

        for output_name, spec in materializers_by_output.items():
            outputs[output_name]["materializer_source"] = _to_source_tuple(
                spec
            )

    failure_hook_source = (
        resolve_and_validate_hook(on_failure) if on_failure else None
    )
    success_hook_source = (
        resolve_and_validate_hook(on_success) if on_success else None
    )

    if isinstance(parameters, BaseParameters):
        parameters = parameters.dict()

    requested_values = {
        "enable_cache": enable_cache,
        "enable_artifact_metadata": enable_artifact_metadata,
        "enable_artifact_visualization": enable_artifact_visualization,
        "enable_step_logs": enable_step_logs,
        "experiment_tracker": experiment_tracker,
        "step_operator": step_operator,
        "parameters": parameters,
        "settings": settings,
        "outputs": outputs or None,
        "extra": extra,
        "failure_hook_source": failure_hook_source,
        "success_hook_source": success_hook_source,
    }
    config = StepConfigurationUpdate(
        **dict_utils.remove_none_values(requested_values)
    )
    self._apply_configuration(config, merge=merge)
    return self
copy(self)

Copies the step.

Returns:

Type Description
BaseStep

The step copy.

Source code in zenml/steps/base_step.py
def copy(self) -> "BaseStep":
    """Creates a deep copy of this step.

    Returns:
        A deep copy of this step instance.
    """
    duplicate = copy.deepcopy(self)
    return duplicate
entrypoint(self, *args, **kwargs)

Abstract method for core step logic.

Parameters:

Name Type Description Default
*args Any

Positional arguments passed to the step.

()
**kwargs Any

Keyword arguments passed to the step.

{}

Returns:

Type Description
Any

The output of the step.

Source code in zenml/steps/base_step.py
@abstractmethod
def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
    """Abstract hook holding the core logic of the step.

    Args:
        *args: Positional arguments passed to the step.
        **kwargs: Keyword arguments passed to the step.

    Returns:
        The output produced by the step.
    """
    ...
load_from_source(source) classmethod

Loads a step from source.

Parameters:

Name Type Description Default
source Union[zenml.config.source.Source, str]

The path to the step source.

required

Returns:

Type Description
BaseStep

The loaded step.

Exceptions:

Type Description
ValueError

If the source is not a valid step source.

Source code in zenml/steps/base_step.py
@classmethod
def load_from_source(cls, source: Union[Source, str]) -> "BaseStep":
    """Loads a step from source.

    A source pointing at a step instance yields that instance; a source
    pointing at a step class yields a new instance created without
    arguments.

    Args:
        source: The path to the step source.

    Returns:
        The loaded step.

    Raises:
        ValueError: If the source resolves to neither a step instance nor a
            step class.
    """
    loaded = source_utils.load(source)

    if isinstance(loaded, BaseStep):
        return loaded

    is_step_class = isinstance(loaded, type) and issubclass(loaded, BaseStep)
    if not is_step_class:
        raise ValueError("Invalid step source.")
    return loaded()
resolve(self)

Resolves the step.

Returns:

Type Description
Source

The step source.

Source code in zenml/steps/base_step.py
def resolve(self) -> Source:
    """Resolves the source of this step's class.

    Returns:
        The step source.
    """
    step_class = self.__class__
    return source_utils.resolve(step_class)
with_options(self, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, enable_step_logs=None, experiment_tracker=None, step_operator=None, parameters=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None, merge=True)

Copies the step and applies the given configurations.

Parameters:

Name Type Description Default
enable_cache Optional[bool]

If caching should be enabled for this step.

None
enable_artifact_metadata Optional[bool]

If artifact metadata should be enabled for this step.

None
enable_artifact_visualization Optional[bool]

If artifact visualization should be enabled for this step.

None
enable_step_logs Optional[bool]

If step logs should be enabled for this step.

None
experiment_tracker Optional[str]

The experiment tracker to use for this step.

None
step_operator Optional[str]

The step operator to use for this step.

None
parameters Optional[ParametersOrDict]

Function parameters for this step.

None
output_materializers Optional[OutputMaterializersSpecification]

Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs.

None
settings Optional[Mapping[str, SettingsOrDict]]

Settings for this step.

None
extra Optional[Dict[str, Any]]

Extra configurations for this step.

None
on_failure Optional[HookSpecification]

Callback function in event of failure of the step. Can be a function with a single argument of type BaseException, or a source path to such a function (e.g. module.my_function).

None
on_success Optional[HookSpecification]

Callback function in event of success of the step. Can be a function with no arguments, or a source path to such a function (e.g. module.my_function).

None
merge bool

If True, will merge the given dictionary configurations like parameters and settings with existing configurations. If False, the given configurations will overwrite all existing ones. See the general description of this method for an example.

True

Returns:

Type Description
BaseStep

The copied step instance.

Source code in zenml/steps/base_step.py
def with_options(
    self,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    enable_step_logs: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    parameters: Optional["ParametersOrDict"] = None,
    output_materializers: Optional[
        "OutputMaterializersSpecification"
    ] = None,
    settings: Optional[Mapping[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
    merge: bool = True,
) -> "BaseStep":
    """Returns a configured copy of this step.

    The step is duplicated with `self.copy()`, and every option below is
    forwarded unchanged to `configure(...)` on that duplicate.

    Args:
        enable_cache: Whether caching should be enabled for this step.
        enable_artifact_metadata: Whether artifact metadata should be
            enabled for this step.
        enable_artifact_visualization: Whether artifact visualization
            should be enabled for this step.
        enable_step_logs: Whether step logs should be enabled for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        parameters: Function parameters for this step.
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: Settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can
            be a function with a single argument of type `BaseException`, or
            a source path to such a function (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can
            be a function with no arguments, or a source path to such a
            function (e.g. `module.my_function`).
        merge: If `True`, dictionary configurations like `parameters` and
            `settings` are merged with the existing configurations. If
            `False`, they overwrite all existing ones.

    Returns:
        The copied and configured step instance.
    """
    options: Dict[str, Any] = {
        "enable_cache": enable_cache,
        "enable_artifact_metadata": enable_artifact_metadata,
        "enable_artifact_visualization": enable_artifact_visualization,
        "enable_step_logs": enable_step_logs,
        "experiment_tracker": experiment_tracker,
        "step_operator": step_operator,
        "parameters": parameters,
        "output_materializers": output_materializers,
        "settings": settings,
        "extra": extra,
        "on_failure": on_failure,
        "on_success": on_success,
    }
    configured = self.copy()
    configured.configure(merge=merge, **options)
    return configured

BaseStepMeta (type)

Metaclass for BaseStep.

Makes sure that the entrypoint function has valid parameters and type annotations.

Source code in zenml/steps/base_step.py
class BaseStepMeta(type):
    """Metaclass for `BaseStep`.

    Every class created through this metaclass, except `BaseStep` and
    `_DecoratedStep`, has its `entrypoint` checked with
    `validate_entrypoint_function`.
    """

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "BaseStepMeta":
        """Creates the class and validates its entrypoint function.

        Args:
            name: The name of the class.
            bases: The base classes of the class.
            dct: The attributes of the class.

        Returns:
            The new class.
        """
        new_class = cast(
            Type["BaseStep"], super().__new__(mcs, name, bases, dct)
        )
        # `BaseStep.entrypoint` is the abstract `*args, **kwargs` signature,
        # which entrypoint validation rejects; `_DecoratedStep` is presumably
        # a similar generic base (not visible here).
        skip_validation = name in {"BaseStep", "_DecoratedStep"}
        if not skip_validation:
            validate_entrypoint_function(new_class.entrypoint)

        return new_class
__new__(mcs, name, bases, dct) special staticmethod

Creates the new class and validates its entrypoint function (validation is skipped for BaseStep and _DecoratedStep).

Parameters:

Name Type Description Default
name str

The name of the class.

required
bases Tuple[Type[Any], ...]

The base classes of the class.

required
dct Dict[str, Any]

The attributes of the class.

required

Returns:

Type Description
BaseStepMeta

The new class.

Source code in zenml/steps/base_step.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "BaseStepMeta":
    """Creates the class and validates its entrypoint function.

    Args:
        name: The name of the class.
        bases: The base classes of the class.
        dct: The attributes of the class.

    Returns:
        The new class.
    """
    new_class = cast(
        Type["BaseStep"], super().__new__(mcs, name, bases, dct)
    )
    # `BaseStep.entrypoint` is the abstract `*args, **kwargs` signature,
    # which entrypoint validation rejects; `_DecoratedStep` is presumably
    # a similar generic base (not visible here).
    skip_validation = name in {"BaseStep", "_DecoratedStep"}
    if not skip_validation:
        validate_entrypoint_function(new_class.entrypoint)

    return new_class

entrypoint_function_utils

Util functions for step and pipeline entrypoint functions.

EntrypointFunctionDefinition (tuple)

Class representing a step entrypoint function.

Attributes:

Name Type Description
inputs Dict[str, inspect.Parameter]

The entrypoint function inputs.

outputs Dict[str, Any]

The entrypoint function outputs. This dictionary maps output names to output annotations.

context Optional[inspect.Parameter]

Optional parameter representing the StepContext input.

legacy_params Optional[inspect.Parameter]

Optional parameter representing the BaseParameters input.

Source code in zenml/steps/entrypoint_function_utils.py
class EntrypointFunctionDefinition(NamedTuple):
    """Class representing a step entrypoint function.

    Attributes:
        inputs: The entrypoint function inputs.
        outputs: The entrypoint function outputs. This dictionary maps output
            names to output annotations.
        context: Optional parameter representing the `StepContext` input.
        legacy_params: Optional parameter representing the `BaseParameters`
            input.
    """

    inputs: Dict[str, inspect.Parameter]
    outputs: Dict[str, Any]
    context: Optional[inspect.Parameter]
    legacy_params: Optional[inspect.Parameter]

    def validate_input(self, key: str, value: Any) -> None:
        """Validates an input to the step entrypoint function.

        Artifact inputs (`StepArtifact` or `ExternalArtifact`) are accepted
        without type checks. Any other value is treated as a parameter: it
        must be valid for the input's type annotation and JSON serializable.

        Args:
            key: The key for which the input was passed.
            value: The input value.

        Raises:
            KeyError: If the function has no input for the given key.
            RuntimeError: If a parameter is passed for an input that is
                annotated as an `UnmaterializedArtifact`, or if a parameter
                value is not valid for the input's type annotation.
            StepInterfaceError: If the input is a parameter and not JSON
                serializable.
        """
        from zenml.materializers import UnmaterializedArtifact

        if key not in self.inputs:
            raise KeyError(
                f"Received step entrypoint input for invalid key {key}."
            )

        parameter = self.inputs[key]

        if isinstance(value, (StepArtifact, ExternalArtifact)):
            # If we were to do any type validation for artifacts here, we
            # would not be able to leverage pydantics type coercion (e.g.
            # providing an `int` artifact for a `float` input)
            return

        # Not an artifact -> This is a parameter
        if parameter.annotation is UnmaterializedArtifact:
            raise RuntimeError(
                "Passing parameter for input of type `UnmaterializedArtifact` "
                "is not allowed."
            )

        self._validate_input_value(parameter=parameter, value=value)

        if not yaml_utils.is_json_serializable(value):
            raise StepInterfaceError(
                f"Argument type (`{type(value)}`) for argument "
                f"'{key}' is not JSON "
                "serializable."
            )

    def _validate_input_value(
        self, parameter: inspect.Parameter, value: Any
    ) -> None:
        """Validates an input value against its parameter's type annotation.

        A throwaway pydantic model with one required field, typed with the
        parameter's annotation, is instantiated with the value. Pydantic's
        type coercion therefore decides whether the value is acceptable.

        Args:
            parameter: The function parameter for which the value was provided.
            value: The input value.

        Raises:
            RuntimeError: If the input value is not valid for the type
                annotation provided for the function parameter. The message
                names the offending input and the expected annotation.
        """

        class ModelConfig(BaseConfig):
            arbitrary_types_allowed = False

        # Create a pydantic model with just a single required field with the
        # type annotation of the parameter to verify the input type including
        # pydantics type coercion
        validation_model_class = create_model(
            "input_validation_model",
            __config__=ModelConfig,
            value=(parameter.annotation, ...),
        )

        try:
            validation_model_class(value=value)
        except ValidationError as e:
            # Name the failing input so users can locate the bad parameter;
            # the pydantic details remain available via the chained cause.
            raise RuntimeError(
                f"Input validation failed for input '{parameter.name}': "
                f"value of type `{type(value).__name__}` is not valid for "
                f"the annotation `{parameter.annotation}`."
            ) from e
__getnewargs__(self) special

Return self as a plain tuple. Used by copy and pickle.

Source code in zenml/steps/entrypoint_function_utils.py
def __getnewargs__(self):
    """Return self as a plain tuple. Used by copy and pickle."""
    # NOTE: this listing is the stdlib-generated `namedtuple` helper; `_tuple`
    # is bound in that generated namespace (presumably the builtin `tuple`),
    # not defined in this module.
    return _tuple(self)
__new__(_cls, inputs, outputs, context, legacy_params) special staticmethod

Create new instance of EntrypointFunctionDefinition(inputs, outputs, context, legacy_params)

__repr__(self) special

Return a nicely formatted representation string

Source code in zenml/steps/entrypoint_function_utils.py
def __repr__(self) -> str:
    """Return a nicely formatted representation string."""
    # NOTE: this listing is the stdlib-generated `namedtuple` helper;
    # `repr_fmt` is a format string from that generated namespace
    # (presumably "(inputs=%r, outputs=%r, context=%r, legacy_params=%r)"),
    # not a name defined in this module.
    return self.__class__.__name__ + repr_fmt % self
validate_input(self, key, value)

Validates an input to the step entrypoint function.

Parameters:

Name Type Description Default
key str

The key for which the input was passed

required
value Any

The input value.

required

Exceptions:

Type Description
KeyError

If the function has no input for the given key.

RuntimeError

If a parameter is passed for an input that is annotated as an UnmaterializedArtifact.

StepInterfaceError

If the input is a parameter and not JSON serializable.

Source code in zenml/steps/entrypoint_function_utils.py
def validate_input(self, key: str, value: Any) -> None:
    """Validates an input to the step entrypoint function.

    Artifact inputs (`StepArtifact` or `ExternalArtifact`) are accepted
    without type checks. Any other value is treated as a parameter: it
    must be valid for the input's type annotation and JSON serializable.

    Args:
        key: The key for which the input was passed.
        value: The input value.

    Raises:
        KeyError: If the function has no input for the given key.
        RuntimeError: If a parameter is passed for an input that is
            annotated as an `UnmaterializedArtifact`, or if a parameter
            value is not valid for the input's type annotation.
        StepInterfaceError: If the input is a parameter and not JSON
            serializable.
    """
    from zenml.materializers import UnmaterializedArtifact

    if key not in self.inputs:
        raise KeyError(
            f"Received step entrypoint input for invalid key {key}."
        )

    parameter = self.inputs[key]

    if isinstance(value, (StepArtifact, ExternalArtifact)):
        # If we were to do any type validation for artifacts here, we
        # would not be able to leverage pydantics type coercion (e.g.
        # providing an `int` artifact for a `float` input)
        return

    # Not an artifact -> This is a parameter
    if parameter.annotation is UnmaterializedArtifact:
        raise RuntimeError(
            "Passing parameter for input of type `UnmaterializedArtifact` "
            "is not allowed."
        )

    # Raises `RuntimeError` if the value does not fit the annotation.
    self._validate_input_value(parameter=parameter, value=value)

    if not yaml_utils.is_json_serializable(value):
        raise StepInterfaceError(
            f"Argument type (`{type(value)}`) for argument "
            f"'{key}' is not JSON "
            "serializable."
        )

StepArtifact

Class to represent step output artifacts.

Source code in zenml/steps/entrypoint_function_utils.py
class StepArtifact:
    """A reference to one named output of a step invocation in a pipeline.

    Instances only store the identifying information passed to the
    constructor; they do not hold the artifact data itself.
    """

    def __init__(
        self,
        invocation_id: str,
        output_name: str,
        annotation: Any,
        pipeline: "Pipeline",
    ) -> None:
        """Initialize a step artifact.

        Args:
            invocation_id: ID of the step invocation that produces this
                artifact.
            output_name: Name of the invocation output this artifact
                corresponds to.
            annotation: Type annotation of that output.
            pipeline: The pipeline the producing invocation belongs to.
        """
        self.pipeline = pipeline
        self.invocation_id = invocation_id
        self.output_name = output_name
        self.annotation = annotation
__init__(self, invocation_id, output_name, annotation, pipeline) special

Initialize a step artifact.

Parameters:

Name Type Description Default
invocation_id str

The ID of the invocation that produces this artifact.

required
output_name str

The name of the output that produces this artifact.

required
annotation Any

The output type annotation.

required
pipeline Pipeline

The pipeline which the invocation is part of.

required
Source code in zenml/steps/entrypoint_function_utils.py
def __init__(
    self,
    invocation_id: str,
    output_name: str,
    annotation: Any,
    pipeline: "Pipeline",
) -> None:
    """Initialize a step artifact.

    Args:
        invocation_id: ID of the step invocation that produces this
            artifact.
        output_name: Name of the invocation output this artifact
            corresponds to.
        annotation: Type annotation of that output.
        pipeline: The pipeline the producing invocation belongs to.
    """
    self.pipeline = pipeline
    self.invocation_id = invocation_id
    self.output_name = output_name
    self.annotation = annotation

get_step_entrypoint_signature(step)

Get the entrypoint signature of a step.

Parameters:

Name Type Description Default
step BaseStep

The step for which to get the entrypoint signature.

required

Returns:

Type Description
Signature

The entrypoint function signature.

Source code in zenml/steps/entrypoint_function_utils.py
def get_step_entrypoint_signature(step: "BaseStep") -> inspect.Signature:
    """Returns the signature of a step's entrypoint without deprecated args.

    Parameters whose resolved annotation is a `BaseParameters` or
    `StepContext` subclass are dropped; all other parameters are kept in
    their original order.

    Args:
        step: The step for which to get the entrypoint signature.

    Returns:
        The entrypoint function signature.
    """
    from zenml.steps import BaseParameters, StepContext

    deprecated_classes = (BaseParameters, StepContext)

    def _is_deprecated(param: inspect.Parameter) -> bool:
        resolved = resolve_type_annotation(param.annotation)
        return inspect.isclass(resolved) and issubclass(
            resolved, deprecated_classes
        )

    full_signature = inspect.signature(step.entrypoint, follow_wrapped=True)
    kept = [
        param
        for param in full_signature.parameters.values()
        if not _is_deprecated(param)
    ]
    return full_signature.replace(parameters=kept)

validate_entrypoint_function(func, reserved_arguments=())

Validates a step entrypoint function.

Parameters:

Name Type Description Default
func Callable[..., Any]

The step entrypoint function to validate.

required
reserved_arguments Sequence[str]

The reserved arguments for the entrypoint function.

()

Exceptions:

Type Description
StepInterfaceError

If the entrypoint function has variable positional (*args) or variable keyword (**kwargs) arguments.

StepInterfaceError

If the entrypoint function has multiple BaseParameters arguments.

StepInterfaceError

If the entrypoint function has multiple StepContext arguments.

RuntimeError

If type annotations should be enforced and a type annotation is missing.

Returns:

Type Description
EntrypointFunctionDefinition

A validated definition of the entrypoint function.

Source code in zenml/steps/entrypoint_function_utils.py
def validate_entrypoint_function(
    func: Callable[..., Any], reserved_arguments: Sequence[str] = ()
) -> EntrypointFunctionDefinition:
    """Validates a step entrypoint function.

    Every parameter of `func` (except a leading `self`) is sorted into one of
    three buckets: a single legacy `BaseParameters` argument, a single
    `StepContext` argument, or a regular step input. Outputs are derived via
    `parse_return_type_annotations`.

    Args:
        func: The step entrypoint function to validate.
        reserved_arguments: The reserved arguments for the entrypoint function.

    Raises:
        StepInterfaceError: If the entrypoint function has variable positional
            (`*args`) or variable keyword (`**kwargs`) arguments.
        StepInterfaceError: If the entrypoint function has multiple
            `BaseParameters` arguments.
        StepInterfaceError: If the entrypoint function has multiple
            `StepContext` arguments.
        RuntimeError: If type annotations should be enforced and a type
            annotation is missing.
        RuntimeError: If the signature uses one of the reserved argument
            names (raised by `validate_reserved_arguments`).

    Returns:
        A validated definition of the entrypoint function.
    """
    from zenml.steps import BaseParameters, StepContext

    signature = inspect.signature(func, follow_wrapped=True)
    validate_reserved_arguments(
        signature=signature, reserved_arguments=reserved_arguments
    )

    inputs = {}
    context: Optional[inspect.Parameter] = None
    legacy_params: Optional[inspect.Parameter] = None

    signature_parameters = list(signature.parameters.items())
    if signature_parameters and signature_parameters[0][0] == "self":
        # TODO: Once we get rid of the old step decorator, we can also remove
        # the `BaseStepMeta` class which right now calls this function on an
        # unbound instance method when using the class-based API. If we get rid
        # of that, this check and removal of the `self` parameter is not
        # necessary anymore
        signature_parameters = signature_parameters[1:]

    for key, parameter in signature_parameters:
        if parameter.kind in {parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD}:
            raise StepInterfaceError(
                f"Variable args or kwargs not allowed for function "
                f"{func.__name__}."
            )

        annotation = parameter.annotation
        if annotation is parameter.empty:
            if ENFORCE_TYPE_ANNOTATIONS:
                raise RuntimeError(
                    f"Missing type annotation for input '{key}' of step "
                    f"function '{func.__name__}'."
                )

            # If a type annotation is missing, use `Any` instead. The local
            # `annotation` is updated too, so the `Parameter.empty` sentinel
            # is never handed to `resolve_type_annotation`.
            annotation = Any
            parameter = parameter.replace(annotation=annotation)

        annotation = resolve_type_annotation(annotation)
        if inspect.isclass(annotation) and issubclass(
            annotation, BaseParameters
        ):
            if legacy_params is not None:
                raise StepInterfaceError(
                    f"Found multiple parameter arguments "
                    f"('{legacy_params.name}' and '{key}') "
                    f"for function {func.__name__}."
                )
            legacy_params = parameter

        elif inspect.isclass(annotation) and issubclass(
            annotation, StepContext
        ):
            if context is not None:
                raise StepInterfaceError(
                    f"Found multiple context arguments "
                    f"('{context.name}' and '{key}') "
                    f"for function {func.__name__}."
                )
            context = parameter
        else:
            inputs[key] = parameter

    outputs = parse_return_type_annotations(
        func=func, enforce_type_annotations=ENFORCE_TYPE_ANNOTATIONS
    )

    return EntrypointFunctionDefinition(
        inputs=inputs,
        outputs=outputs,
        context=context,
        legacy_params=legacy_params,
    )

validate_reserved_arguments(signature, reserved_arguments)

Validates that the signature does not contain any reserved arguments.

Parameters:

Name Type Description Default
signature Signature

The signature to validate.

required
reserved_arguments Sequence[str]

The reserved arguments for the signature.

required

Exceptions:

Type Description
RuntimeError

If the signature contains a reserved argument.

Source code in zenml/steps/entrypoint_function_utils.py
def validate_reserved_arguments(
    signature: inspect.Signature, reserved_arguments: Sequence[str]
) -> None:
    """Ensures no parameter of the signature uses a reserved name.

    Reserved names are checked in the order given; the first one found in
    the signature is reported.

    Args:
        signature: The signature to validate.
        reserved_arguments: The reserved arguments for the signature.

    Raises:
        RuntimeError: If the signature contains a reserved argument.
    """
    clashes = [
        name for name in reserved_arguments if name in signature.parameters
    ]
    if clashes:
        raise RuntimeError(f"Reserved argument name '{clashes[0]}'.")

external_artifact

External artifact definition.

ExternalArtifact

External artifacts can be used to provide values as input to ZenML steps.

ZenML steps accept artifacts (the outputs of other steps), parameters (raw, JSON-serializable values), or external artifacts. External artifacts can be used to provide any value as input to a step without needing to write an additional step that returns this value.

Examples:

from zenml import step, pipeline, ExternalArtifact
import numpy as np

# A step that consumes a NumPy array as an input artifact.
@step
def my_step(value: np.ndarray) -> None:
  print(value)

# A value created outside of any step...
my_array = np.array([1, 2, 3])

@pipeline
def my_pipeline():
  # ...is passed into the step by wrapping it in an `ExternalArtifact`.
  my_step(value=ExternalArtifact(my_array))
Source code in zenml/steps/external_artifact.py
class ExternalArtifact:
    """External artifacts can be used to provide values as input to ZenML steps.

    ZenML steps accept artifacts (outputs of other steps), parameters
    (raw, JSON-serializable values), or external artifacts. External
    artifacts can be used to provide any value as input to a step without
    needing to write an additional step that returns this value.

    Example:
    ```
    from zenml import step, pipeline, ExternalArtifact
    import numpy as np

    @step
    def my_step(value: np.ndarray) -> None:
      print(value)

    my_array = np.array([1, 2, 3])

    @pipeline
    def my_pipeline():
      my_step(value=ExternalArtifact(my_array))
    ```
    """

    def __init__(
        self,
        value: Any = None,
        id: Optional[UUID] = None,
        materializer: Optional["MaterializerClassOrSource"] = None,
        store_artifact_metadata: bool = True,
        store_artifact_visualizations: bool = True,
    ) -> None:
        """Initializes an external artifact instance.

        The external artifact needs to have either a value associated with it
        that will be uploaded to the artifact store, or reference an artifact
        that is already registered in ZenML. This could be either from a
        previous pipeline run or a previously uploaded external artifact.

        Args:
            value: The artifact value. Either this or an artifact ID must be
                provided.
            id: The ID of an artifact that should be referenced by this external
                artifact. Either this or an artifact value must be provided.
            materializer: The materializer to use for saving the artifact value
                to the artifact store. Only used when `value` is provided.
            store_artifact_metadata: Whether metadata for the artifact should
                be stored. Only used when `value` is provided.
            store_artifact_visualizations: Whether visualizations for the
                artifact should be stored. Only used when `value` is provided.

        Raises:
            ValueError: If neither or both of the `value` and `id` arguments
                are provided.
        """
        has_value = value is not None
        has_id = id is not None
        if has_value and has_id:
            raise ValueError(
                "Only a value or an ID can be provided when creating an "
                "external artifact."
            )
        if not has_value and not has_id:
            raise ValueError(
                "Either a value or an ID must be provided when creating an "
                "external artifact."
            )

        # The materializer and the `store_*` flags are stored regardless but
        # are only consulted when a value is uploaded.
        self._id = id
        self._value = value
        self._materializer = materializer
        self._store_artifact_visualizations = store_artifact_visualizations
        self._store_artifact_metadata = store_artifact_metadata

    def upload_if_necessary(self) -> UUID:
        """Uploads the artifact if necessary.

        This method does one of two things:
        - If an artifact is referenced by ID, it will verify that the artifact
          exists and is in the correct artifact store.
        - Otherwise, the artifact value will be uploaded and published. The
          instance then remembers the new artifact ID, so later calls take
          the first path instead of uploading again.

        Raises:
            RuntimeError: If the artifact store of the referenced artifact
                is not the same as the one in the active stack.
            RuntimeError: If the URI of the artifact already exists.

        Returns:
            The artifact ID.
        """
        client = Client()
        # Resolve the active artifact store once; both its ID and its path
        # are needed on the upload path.
        artifact_store = client.active_stack.artifact_store
        artifact_store_id = artifact_store.id

        if self._id is not None:
            response = client.get_artifact(artifact_id=self._id)
            if response.artifact_store_id != artifact_store_id:
                raise RuntimeError(
                    f"The artifact {response.name} (ID: {response.id}) "
                    "referenced by an external artifact is not stored in the "
                    "artifact store of the active stack. This will lead to "
                    "issues loading the artifact. Please make sure to only "
                    "reference artifacts stored in your active artifact store."
                )
            return self._id

        # `__init__` guarantees a value whenever no ID was given.
        assert self._value is not None

        logger.info("Uploading external artifact...")
        artifact_name = f"external_{uuid4()}"
        materializer_class = self._get_materializer_class(value=self._value)

        uri = os.path.join(
            artifact_store.path,
            "external_artifacts",
            artifact_name,
        )
        if fileio.exists(uri):
            raise RuntimeError(f"Artifact URI '{uri}' already exists.")
        fileio.makedirs(uri)

        materializer = materializer_class(uri)

        artifact_id = artifact_utils.upload_artifact(
            name=artifact_name,
            data=self._value,
            materializer=materializer,
            artifact_store_id=artifact_store_id,
            extract_metadata=self._store_artifact_metadata,
            include_visualizations=self._store_artifact_visualizations,
        )

        # To avoid duplicate uploads, switch to referencing the uploaded
        # artifact by ID
        self._id = artifact_id
        logger.info("Finished uploading external artifact %s.", artifact_id)

        return self._id

    def _get_materializer_class(self, value: Any) -> Type["BaseMaterializer"]:
        """Gets a materializer class for a value.

        A custom materializer given as a class is returned directly; one given
        as a source is loaded and validated to be a `BaseMaterializer`
        subclass. Otherwise the class is looked up in `materializer_registry`
        by the value's type. That lookup presumably falls back to a default
        (e.g. cloudpickle) materializer for unregistered types; this is not
        visible here.

        Args:
            value: The value for which to get the materializer class.

        Returns:
            The materializer class.
        """
        if isinstance(self._materializer, type):
            return self._materializer
        elif self._materializer:
            return source_utils.load_and_validate_class(
                self._materializer, expected_class=BaseMaterializer
            )
        else:
            return materializer_registry[type(value)]
__init__(self, value=None, id=None, materializer=None, store_artifact_metadata=True, store_artifact_visualizations=True) special

Initializes an external artifact instance.

The external artifact needs to have either a value associated with it that will be uploaded to the artifact store, or reference an artifact that is already registered in ZenML. This could be either from a previous pipeline run or a previously uploaded external artifact.

Parameters:

Name Type Description Default
value Any

The artifact value. Either this or an artifact ID must be provided.

None
id Optional[uuid.UUID]

The ID of an artifact that should be referenced by this external artifact. Either this or an artifact value must be provided.

None
materializer Optional[MaterializerClassOrSource]

The materializer to use for saving the artifact value to the artifact store. Only used when value is provided.

None
store_artifact_metadata bool

Whether metadata for the artifact should be stored. Only used when value is provided.

True
store_artifact_visualizations bool

Whether visualizations for the artifact should be stored. Only used when value is provided.

True

Exceptions:

Type Description
ValueError

If neither or both of the value and id arguments are provided.

Source code in zenml/steps/external_artifact.py
def __init__(
    self,
    value: Any = None,
    id: Optional[UUID] = None,
    materializer: Optional["MaterializerClassOrSource"] = None,
    store_artifact_metadata: bool = True,
    store_artifact_visualizations: bool = True,
) -> None:
    """Creates an external artifact from a value or an existing artifact ID.

    Exactly one of `value` and `id` must be given. A value will later be
    uploaded to the artifact store; an ID references an artifact that is
    already registered in ZenML (e.g. produced by a previous pipeline run or
    uploaded earlier as an external artifact).

    Args:
        value: The artifact value. Either this or an artifact ID must be
            provided.
        id: The ID of an already registered artifact that this external
            artifact should reference. Either this or a value must be
            provided.
        materializer: The materializer used to save `value` to the artifact
            store. Only relevant when `value` is provided.
        store_artifact_metadata: Whether metadata for the artifact should be
            stored. Only relevant when `value` is provided.
        store_artifact_visualizations: Whether visualizations for the
            artifact should be stored. Only relevant when `value` is
            provided.

    Raises:
        ValueError: If both or neither of `value` and `id` are provided.
    """
    has_value = value is not None
    has_id = id is not None

    if has_value and has_id:
        raise ValueError(
            "Only a value or an ID can be provided when creating an "
            "external artifact."
        )
    if not (has_value or has_id):
        raise ValueError(
            "Either a value or an ID must be provided when creating an "
            "external artifact."
        )

    self._value, self._id = value, id
    self._materializer = materializer
    self._store_artifact_metadata = store_artifact_metadata
    self._store_artifact_visualizations = store_artifact_visualizations
upload_if_necessary(self)

Uploads the artifact if necessary.

This method does one of two things:

- If an artifact is referenced by ID, it will verify that the artifact exists and is in the correct artifact store.
- Otherwise, the artifact value will be uploaded and published.

Exceptions:

Type Description
RuntimeError

If the artifact store of the referenced artifact is not the same as the one in the active stack.

RuntimeError

If the URI of the artifact already exists.

Returns:

Type Description
UUID

The artifact ID.

Source code in zenml/steps/external_artifact.py
def upload_if_necessary(self) -> UUID:
    """Uploads the artifact if necessary.

    Two cases are handled:
    - The artifact is referenced by ID: verify that it is stored in the
      artifact store of the active stack.
    - Otherwise: upload and publish the value under a fresh URI inside the
      active artifact store, then remember the resulting artifact ID.

    Raises:
        RuntimeError: If the artifact store of the referenced artifact
            is not the same as the one in the active stack.
        RuntimeError: If the URI of the artifact already exists.

    Returns:
        The artifact ID.
    """
    artifact_store_id = Client().active_stack.artifact_store.id

    if not self._id:
        # Without an ID, the constructor checks mean a value was provided.
        assert self._value is not None

        logger.info("Uploading external artifact...")
        artifact_name = f"external_{uuid4()}"
        materializer_class = self._get_materializer_class(value=self._value)

        # Every upload gets its own directory below `external_artifacts`.
        target_uri = os.path.join(
            Client().active_stack.artifact_store.path,
            "external_artifacts",
            artifact_name,
        )
        if fileio.exists(target_uri):
            raise RuntimeError(f"Artifact URI '{target_uri}' already exists.")
        fileio.makedirs(target_uri)

        uploaded_id = artifact_utils.upload_artifact(
            name=artifact_name,
            data=self._value,
            materializer=materializer_class(target_uri),
            artifact_store_id=artifact_store_id,
            extract_metadata=self._store_artifact_metadata,
            include_visualizations=self._store_artifact_visualizations,
        )

        # Reference the uploaded artifact by ID from now on so that repeated
        # calls do not upload the value a second time.
        self._id = uploaded_id
        logger.info("Finished uploading external artifact %s.", uploaded_id)
        return self._id

    response = Client().get_artifact(artifact_id=self._id)
    if response.artifact_store_id != artifact_store_id:
        raise RuntimeError(
            f"The artifact {response.name} (ID: {response.id}) "
            "referenced by an external artifact is not stored in the "
            "artifact store of the active stack. This will lead to "
            "issues loading the artifact. Please make sure to only "
            "reference artifacts stored in your active artifact store."
        )
    return self._id

step_decorator

Step decorator function.

step(_func=None, *, name=None, enable_cache=None, enable_artifact_metadata=None, enable_artifact_visualization=None, enable_step_logs=None, experiment_tracker=None, step_operator=None, output_materializers=None, settings=None, extra=None, on_failure=None, on_success=None)

Outer decorator function for the creation of a ZenML step.

In order to be able to work with parameters such as name, it features a nested decorator structure.

Parameters:

Name Type Description Default
_func Optional[~F]

The decorated function.

None
name Optional[str]

The name of the step. If left empty, the name of the decorated function will be used as a fallback.

None
enable_cache Optional[bool]

Specify whether caching is enabled for this step. If no value is passed, caching is enabled by default.

None
enable_artifact_metadata Optional[bool]

Specify whether metadata is enabled for this step. If no value is passed, metadata is enabled by default.

None
enable_artifact_visualization Optional[bool]

Specify whether visualization is enabled for this step. If no value is passed, visualization is enabled by default.

None
enable_step_logs Optional[bool]

Specify whether step logs are enabled for this step.

None
experiment_tracker Optional[str]

The experiment tracker to use for this step.

None
step_operator Optional[str]

The step operator to use for this step.

None
output_materializers Optional[OutputMaterializersSpecification]

Output materializers for this step. If given as a dict, the keys must be a subset of the output names of this step. If a single value (type or string) is given, the materializer will be used for all outputs.

None
settings Optional[Dict[str, SettingsOrDict]]

Settings for this step.

None
extra Optional[Dict[str, Any]]

Extra configurations for this step.

None
on_failure Optional[HookSpecification]

Callback function in event of failure of the step. Can be a function with a single argument of type BaseException, or a source path to such a function (e.g. module.my_function).

None
on_success Optional[HookSpecification]

Callback function in event of success of the step. Can be a function with no arguments, or a source path to such a function (e.g. module.my_function).

None

Returns:

Type Description
Union[Type[zenml.steps.base_step.BaseStep], Callable[[~F], Type[zenml.steps.base_step.BaseStep]]]

The inner decorator which creates the step class based on the ZenML BaseStep

Source code in zenml/steps/step_decorator.py
def step(
    _func: Optional[F] = None,
    *,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    enable_artifact_metadata: Optional[bool] = None,
    enable_artifact_visualization: Optional[bool] = None,
    enable_step_logs: Optional[bool] = None,
    experiment_tracker: Optional[str] = None,
    step_operator: Optional[str] = None,
    output_materializers: Optional["OutputMaterializersSpecification"] = None,
    settings: Optional[Dict[str, "SettingsOrDict"]] = None,
    extra: Optional[Dict[str, Any]] = None,
    on_failure: Optional["HookSpecification"] = None,
    on_success: Optional["HookSpecification"] = None,
) -> Union[Type[BaseStep], Callable[[F], Type[BaseStep]]]:
    """Outer decorator function for the creation of a ZenML step.

    Supports both `@step` and `@step(...)` usage: when called with the
    function directly it returns the generated step class, otherwise it
    returns a decorator that produces that class.

    Args:
        _func: The decorated function.
        name: The name of the step. If left empty, the name of the decorated
            function will be used as a fallback.
        enable_cache: Specify whether caching is enabled for this step. If no
            value is passed, caching is enabled by default.
        enable_artifact_metadata: Specify whether metadata is enabled for this
            step. If no value is passed, metadata is enabled by default.
        enable_artifact_visualization: Specify whether visualization is enabled
            for this step. If no value is passed, visualization is enabled by
            default.
        enable_step_logs: Specify whether step logs are enabled for this step.
        experiment_tracker: The experiment tracker to use for this step.
        step_operator: The step operator to use for this step.
        output_materializers: Output materializers for this step. If
            given as a dict, the keys must be a subset of the output names
            of this step. If a single value (type or string) is given, the
            materializer will be used for all outputs.
        settings: Settings for this step.
        extra: Extra configurations for this step.
        on_failure: Callback function in event of failure of the step. Can be a
            function with a single argument of type `BaseException`, or a source
            path to such a function (e.g. `module.my_function`).
        on_success: Callback function in event of success of the step. Can be a
            function with no arguments, or a source path to such a function
            (e.g. `module.my_function`).

    Returns:
        The inner decorator which creates the step class based on the
        ZenML BaseStep.
    """

    def inner_decorator(func: F) -> Type[BaseStep]:
        """Builds a `_DecoratedStep` subclass around `func`.

        Args:
            func: types.FunctionType, this function will be used as the
                "process" method of the generated Step.

        Returns:
            The class of a newly generated ZenML Step.
        """
        display_name = name or func.__name__
        logger.warning(
            "The `@step` decorator that you used to define your "
            f"{display_name} step is deprecated. Check out our docs "
            "https://docs.zenml.io/user-guide/advanced-guide/migrate-your-old-pipelines-and-steps "
            "for information on how to migrate your steps to the new syntax."
        )

        # Note: the raw `name` (possibly None) is stored, not `display_name`.
        class_configuration = {
            PARAM_STEP_NAME: name,
            PARAM_ENABLE_CACHE: enable_cache,
            PARAM_ENABLE_ARTIFACT_METADATA: enable_artifact_metadata,
            PARAM_ENABLE_ARTIFACT_VISUALIZATION: enable_artifact_visualization,
            PARAM_ENABLE_STEP_LOGS: enable_step_logs,
            PARAM_EXPERIMENT_TRACKER: experiment_tracker,
            PARAM_STEP_OPERATOR: step_operator,
            PARAM_OUTPUT_MATERIALIZERS: output_materializers,
            PARAM_SETTINGS: settings,
            PARAM_EXTRA_OPTIONS: extra,
            PARAM_ON_FAILURE: on_failure,
            PARAM_ON_SUCCESS: on_success,
        }
        namespace = {
            STEP_INNER_FUNC_NAME: staticmethod(func),
            CLASS_CONFIGURATION: class_configuration,
            # Make the generated class look like it was defined alongside
            # the decorated function.
            "__module__": func.__module__,
            "__doc__": func.__doc__,
        }
        return type(func.__name__, (_DecoratedStep,), namespace)  # noqa

    return inner_decorator if _func is None else inner_decorator(_func)

step_environment

Step environment class.

StepEnvironment (BaseEnvironmentComponent)

(Deprecated) Added information about a run inside a step function.

This takes the form of an Environment component. This class can be used from within a pipeline step implementation to access additional information about the runtime parameters of a pipeline step, such as the pipeline name, pipeline run ID and other pipeline runtime information. To use it, access it inside your step function like this:

from zenml.environment import Environment

@step
def my_step() -> None:
    env = Environment().step_environment
    do_something_with(env.pipeline_name, env.run_name, env.step_name)
Source code in zenml/steps/step_environment.py
class StepEnvironment(BaseEnvironmentComponent):
    """(Deprecated) Added information about a run inside a step function.

    This takes the form of an Environment component. This class can be used from
    within a pipeline step implementation to access additional information about
    the runtime parameters of a pipeline step, such as the pipeline name,
    pipeline run ID and other pipeline runtime information. To use it, access it
    inside your step function like this:

    ```python
    from zenml.environment import Environment

    @step
    def my_step() -> None:
        env = Environment().step_environment
        do_something_with(env.pipeline_name, env.run_name, env.step_name)
    ```
    """

    NAME = STEP_ENVIRONMENT_NAME

    def __init__(
        self,
        step_run_info: "StepRunInfo",
        cache_enabled: bool,
    ):
        """Initialize the environment of the currently running step.

        Args:
            step_run_info: Info about the currently running step.
            cache_enabled: Whether caching is enabled for the current step run.
        """
        super().__init__()
        self._step_run_info = step_run_info
        self._cache_enabled = cache_enabled

    @property
    def pipeline_name(self) -> str:
        """The name of the currently running pipeline.

        Returns:
            The name of the currently running pipeline.
        """
        return self._step_run_info.pipeline.name

    @property
    def run_name(self) -> str:
        """The name of the current pipeline run.

        Returns:
            The name of the current pipeline run.
        """
        return self._step_run_info.run_name

    @property
    def step_name(self) -> str:
        """The name of the currently running step.

        Returns:
            The name of the step inside the pipeline (the step run info's
            `pipeline_step_name`).
        """
        return self._step_run_info.pipeline_step_name

    @property
    def step_run_info(self) -> "StepRunInfo":
        """Info about the currently running step.

        Returns:
            Info about the currently running step.
        """
        return self._step_run_info

    @property
    def cache_enabled(self) -> bool:
        """Returns whether cache is enabled for the step.

        Returns:
            True if cache is enabled for the step, otherwise False.
        """
        return self._cache_enabled
cache_enabled: bool property readonly

Returns whether cache is enabled for the step.

Returns:

Type Description
bool

True if cache is enabled for the step, otherwise False.

pipeline_name: str property readonly

The name of the currently running pipeline.

Returns:

Type Description
str

The name of the currently running pipeline.

run_name: str property readonly

The name of the current pipeline run.

Returns:

Type Description
str

The name of the current pipeline run.

step_name: str property readonly

The name of the currently running step.

Returns:

Type Description
str

The name of the currently running step.

step_run_info: StepRunInfo property readonly

Info about the currently running step.

Returns:

Type Description
StepRunInfo

Info about the currently running step.

__init__(self, step_run_info, cache_enabled) special

Initialize the environment of the currently running step.

Parameters:

Name Type Description Default
step_run_info StepRunInfo

Info about the currently running step.

required
cache_enabled bool

Whether caching is enabled for the current step run.

required
Source code in zenml/steps/step_environment.py
def __init__(
    self,
    step_run_info: "StepRunInfo",
    cache_enabled: bool,
):
    """Store the runtime information of the step that is currently running.

    Args:
        step_run_info: Info about the currently running step.
        cache_enabled: Whether caching is enabled for the current step run.
    """
    super().__init__()
    self._step_run_info, self._cache_enabled = step_run_info, cache_enabled

step_invocation

Step invocation class definition.

StepInvocation

Step invocation class.

Source code in zenml/steps/step_invocation.py
class StepInvocation:
    """Step invocation class."""

    def __init__(
        self,
        id: str,
        step: "BaseStep",
        input_artifacts: Dict[str, "StepArtifact"],
        external_artifacts: Dict[str, "ExternalArtifact"],
        parameters: Dict[str, Any],
        upstream_steps: Set[str],
        pipeline: "Pipeline",
    ) -> None:
        """Initialize a step invocation.

        Args:
            id: The invocation ID.
            step: The step that is represented by the invocation.
            input_artifacts: The input artifacts for the invocation.
            external_artifacts: The external artifacts for the invocation.
            parameters: The parameters for the invocation.
            upstream_steps: The upstream steps for the invocation.
            pipeline: The parent pipeline of the invocation.
        """
        self.id = id
        self.step = step
        self.input_artifacts = input_artifacts
        self.external_artifacts = external_artifacts
        self.parameters = parameters
        self.invocation_upstream_steps = upstream_steps
        self.pipeline = pipeline

    @property
    def upstream_steps(self) -> Set[str]:
        """The upstream steps of the invocation.

        Combines the upstream steps passed when the invocation was created
        with the ones defined on the step instance via `step.after(...)`.

        Returns:
            The upstream steps of the invocation.
        """
        return self.invocation_upstream_steps.union(
            self._get_and_validate_step_upstream_steps()
        )

    def _get_and_validate_step_upstream_steps(self) -> Set[str]:
        """Validates the upstream steps defined on the step instance.

        This is only allowed in legacy pipelines when calling `step.after(...)`
        and we need to make sure that both the upstream and downstream steps
        of such a relationship are only invoked once inside a pipeline.

        Raises:
            RuntimeError: If a step involved in a `step.after(...)`
                relationship is invoked more than once, or is not invoked
                at all, inside the pipeline.

        Returns:
            The invocation IDs of the upstream steps defined on the step
            instance.
        """

        def _verify_single_invocation(step: "BaseStep") -> str:
            """Returns the ID of the only invocation of `step` in the pipeline.

            Args:
                step: The step whose invocation should be looked up.

            Raises:
                RuntimeError: If the step is invoked more than once, or not
                    at all, inside the pipeline.

            Returns:
                The ID of the single invocation of the step.
            """
            invocations = {
                invocation
                for invocation in self.pipeline.invocations.values()
                if invocation.step is step
            }
            if len(invocations) > 1:
                raise RuntimeError(
                    "Setting upstream steps for a step using "
                    "`step_1.after(step_2)` is not allowed in combination "
                    "with calling one of the two steps multiple times."
                )
            if not invocations:
                # Without this check `invocations.pop()` would fail with an
                # opaque `KeyError: 'pop from an empty set'`.
                raise RuntimeError(
                    "Setting upstream steps for a step using "
                    "`step_1.after(step_2)` requires both steps to be "
                    "called inside the pipeline, but one of them was "
                    "never invoked."
                )
            return invocations.pop().id

        if self.step.upstream_steps:
            # If the step has upstream steps, make sure it only got invoked once
            _verify_single_invocation(step=self.step)

        return {
            _verify_single_invocation(step=upstream_step)
            for upstream_step in self.step.upstream_steps
        }

    def finalize(self, parameters_to_ignore: Set[str]) -> "StepConfiguration":
        """Finalizes a step invocation.

        This will validate the upstream steps and run final configurations on
        the step that is represented by the invocation.

        Args:
            parameters_to_ignore: Set of parameters that should not be applied
                to the step instance.

        Returns:
            The finalized step configuration.
        """
        # Validate the upstream steps for legacy .after() calls
        self._get_and_validate_step_upstream_steps()

        parameters_to_apply = {
            key: value
            for key, value in self.parameters.items()
            if key not in parameters_to_ignore
        }
        self.step.configure(parameters=parameters_to_apply)

        # Upload external artifacts (unless already uploaded) so that the
        # step configuration can reference them by artifact ID.
        external_artifact_ids = {
            key: artifact.upload_if_necessary()
            for key, artifact in self.external_artifacts.items()
        }

        return self.step._finalize_configuration(
            input_artifacts=self.input_artifacts,
            external_artifacts=external_artifact_ids,
        )
upstream_steps: Set[str] property readonly

The upstream steps of the invocation.

Returns:

Type Description
Set[str]

The upstream steps of the invocation.

__init__(self, id, step, input_artifacts, external_artifacts, parameters, upstream_steps, pipeline) special

Initialize a step invocation.

Parameters:

Name Type Description Default
id str

The invocation ID.

required
step BaseStep

The step that is represented by the invocation.

required
input_artifacts Dict[str, StepArtifact]

The input artifacts for the invocation.

required
external_artifacts Dict[str, ExternalArtifact]

The external artifacts for the invocation.

required
parameters Dict[str, Any]

The parameters for the invocation.

required
upstream_steps Set[str]

The upstream steps for the invocation.

required
pipeline Pipeline

The parent pipeline of the invocation.

required
Source code in zenml/steps/step_invocation.py
def __init__(
    self,
    id: str,
    step: "BaseStep",
    input_artifacts: Dict[str, "StepArtifact"],
    external_artifacts: Dict[str, "ExternalArtifact"],
    parameters: Dict[str, Any],
    upstream_steps: Set[str],
    pipeline: "Pipeline",
) -> None:
    """Store everything that describes one call of a step in a pipeline.

    Args:
        id: The invocation ID.
        step: The step that is represented by the invocation.
        input_artifacts: The input artifacts for the invocation.
        external_artifacts: The external artifacts for the invocation.
        parameters: The parameters for the invocation.
        upstream_steps: The upstream steps for the invocation.
        pipeline: The parent pipeline of the invocation.
    """
    # Identity of this invocation and the step it represents.
    self.id, self.step = id, step
    # Inputs wired into the step for this particular call.
    self.input_artifacts = input_artifacts
    self.external_artifacts = external_artifacts
    self.parameters = parameters
    # Upstream steps given explicitly at invocation time.
    self.invocation_upstream_steps = upstream_steps
    self.pipeline = pipeline
finalize(self, parameters_to_ignore)

Finalizes a step invocation.

This will validate the upstream steps and run final configurations on the step that is represented by the invocation.

Parameters:

Name Type Description Default
parameters_to_ignore Set[str]

Set of parameters that should not be applied to the step instance.

required

Returns:

Type Description
StepConfiguration

The finalized step configuration.

Source code in zenml/steps/step_invocation.py
def finalize(self, parameters_to_ignore: Set[str]) -> "StepConfiguration":
    """Finalizes a step invocation.

    This will validate the upstream steps and run final configurations on
    the step that is represented by the invocation.

    Args:
        parameters_to_ignore: Set of parameters that should not be applied
            to the step instance.

    Returns:
        The finalized step configuration.
    """
    # Validate the upstream steps for legacy .after() calls
    self._get_and_validate_step_upstream_steps()

    parameters_to_apply = {
        key: value
        for key, value in self.parameters.items()
        if key not in parameters_to_ignore
    }
    self.step.configure(parameters=parameters_to_apply)

    # Upload external artifacts (unless already uploaded) so that the step
    # configuration can reference them by artifact ID.
    external_artifact_ids = {
        key: artifact.upload_if_necessary()
        for key, artifact in self.external_artifacts.items()
    }

    return self.step._finalize_configuration(
        input_artifacts=self.input_artifacts,
        external_artifacts=external_artifact_ids,
    )

step_output

Step output class.

Output

A named tuple with a default name that cannot be overridden.

Source code in zenml/steps/step_output.py
class Output(object):
    """A named tuple with a default name that cannot be overridden."""

    def __init__(self, **kwargs: Type[Any]):
        """Initializes the output.

        Args:
            **kwargs: Output names mapped to their output types.
        """
        # Pass the fields as a list of (name, type) pairs: the keyword
        # argument form `NamedTuple("X", a=int)` is deprecated since Python
        # 3.13, and omitting the fields entirely also warns there.
        self.outputs = NamedTuple("ZenOutput", list(kwargs.items()))  # type: ignore[misc]

    def items(self) -> Iterator[Tuple[str, Type[Any]]]:
        """Yields a tuple of type (output_name, output_type).

        Yields:
            A tuple of type (output_name, output_type), in declaration order.
        """
        yield from self.outputs.__annotations__.items()
__init__(self, **kwargs) special

Initializes the output.

Parameters:

Name Type Description Default
**kwargs Type[Any]

The output values.

{}
Source code in zenml/steps/step_output.py
def __init__(self, **kwargs: Type[Any]):
    """Initializes the output.

    Args:
        **kwargs: Output names mapped to their output types.
    """
    # Pass the fields as a list of (name, type) pairs: the keyword argument
    # form `NamedTuple("X", a=int)` is deprecated since Python 3.13, and
    # omitting the fields entirely also warns there.
    self.outputs = NamedTuple("ZenOutput", list(kwargs.items()))  # type: ignore[misc]
items(self)

Yields a tuple of type (output_name, output_type).

Yields:

Type Description
Iterator[Tuple[str, Type[Any]]]

A tuple of type (output_name, output_type).

Source code in zenml/steps/step_output.py
def items(self) -> Iterator[Tuple[str, Type[Any]]]:
    """Iterates over the declared outputs.

    Yields:
        A tuple of type (output_name, output_type).
    """
    for output_name, output_type in self.outputs.__annotations__.items():
        yield output_name, output_type

utils

Utility functions and classes to run ZenML steps.

OnlyNoneReturnsVisitor (ReturnVisitor)

Checks whether a function AST contains only None returns.

Source code in zenml/steps/utils.py
class OnlyNoneReturnsVisitor(ReturnVisitor):
    """Checks whether a function AST contains only `None` returns."""

    def __init__(self) -> None:
        """Initializes a visitor instance."""
        super().__init__()
        self.has_only_none_returns = True

    def visit_Return(self, node: ast.Return) -> None:
        """Visit a return statement.

        A bare `return` and an explicit `return None` both count as `None`
        returns; any other returned expression clears
        `has_only_none_returns`.

        Args:
            node: The return statement to visit.
        """
        if node.value is None:
            # Bare `return`.
            return
        # Since Python 3.8 the parser emits `ast.Constant` for the `None`
        # literal. The deprecated `ast.NameConstant` alias was removed in
        # Python 3.14, so referencing it would raise `AttributeError`.
        if isinstance(node.value, ast.Constant) and node.value.value is None:
            return
        self.has_only_none_returns = False
__init__(self) special

Initializes a visitor instance.

Source code in zenml/steps/utils.py
def __init__(self) -> None:
    """Initializes a visitor instance.

    The flag starts as `True` and is cleared once a return statement with a
    non-`None` value is visited.
    """
    super().__init__()
    self.has_only_none_returns = True
visit_Return(self, node)

Visit a return statement.

Parameters:

Name Type Description Default
node Return

The return statement to visit.

required
Source code in zenml/steps/utils.py
def visit_Return(self, node: ast.Return) -> None:
    """Visit a return statement.

    A bare `return` and an explicit `return None` both count as `None`
    returns; any other returned expression clears `has_only_none_returns`.

    Args:
        node: The return statement to visit.
    """
    if node.value is None:
        # Bare `return`.
        return
    # Since Python 3.8 the parser emits `ast.Constant` for the `None`
    # literal. The deprecated `ast.NameConstant` alias was removed in Python
    # 3.14, so referencing it would raise `AttributeError`.
    if isinstance(node.value, ast.Constant) and node.value.value is None:
        return
    self.has_only_none_returns = False

ReturnVisitor (NodeVisitor)

AST visitor class that can be subclassed to visit function returns.

Source code in zenml/steps/utils.py
class ReturnVisitor(ast.NodeVisitor):
    """AST visitor class that can be subclassed to visit function returns."""

    def __init__(self, ignore_nested_functions: bool = True) -> None:
        """Initializes a return visitor instance.

        Args:
            ignore_nested_functions: If `True`, will skip visiting nested
                functions.
        """
        self._ignore_nested_functions = ignore_nested_functions
        self._inside_function = False

    def _visit_function(
        self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
    ) -> None:
        """Visit a function or async function definition node.

        Args:
            node: The node to visit.
        """
        if self._ignore_nested_functions and self._inside_function:
            # We're already inside a function definition and should ignore
            # nested functions so we don't want to recurse any further
            return

        # Restore the previous state afterwards so the flag only covers the
        # body of *this* function. Otherwise sibling functions visited later
        # (or a later `visit()` call on the same visitor) would wrongly be
        # treated as nested and skipped.
        was_inside_function = self._inside_function
        self._inside_function = True
        try:
            self.generic_visit(node)
        finally:
            self._inside_function = was_inside_function

    visit_FunctionDef = _visit_function
    visit_AsyncFunctionDef = _visit_function
__init__(self, ignore_nested_functions=True) special

Initializes a return visitor instance.

Parameters:

Name Type Description Default
ignore_nested_functions bool

If True, will skip visiting nested functions.

True
Source code in zenml/steps/utils.py
def __init__(self, ignore_nested_functions: bool = True) -> None:
    """Sets up a return visitor.

    Args:
        ignore_nested_functions: If `True`, nested function definitions are
            not visited.
    """
    # Tracks whether the traversal has already entered a function body.
    self._inside_function = False
    self._ignore_nested_functions = ignore_nested_functions
visit_AsyncFunctionDef(self, node)

Visit a function or async function definition node.

Parameters:

Name Type Description Default
node Union[_ast.FunctionDef, _ast.AsyncFunctionDef]

The node to visit.

required
Source code in zenml/steps/utils.py
def _visit_function(
    self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> None:
    """Visit a function or async function definition node.

    Args:
        node: The node to visit.
    """
    if self._ignore_nested_functions and self._inside_function:
        # We're already inside a function definition and should ignore
        # nested functions so we don't want to recurse any further
        return

    # Restore the previous state afterwards so the flag only covers the body
    # of *this* function. Otherwise sibling functions visited later (or a
    # later `visit()` call on the same visitor) would wrongly be treated as
    # nested and skipped.
    was_inside_function = self._inside_function
    self._inside_function = True
    try:
        self.generic_visit(node)
    finally:
        self._inside_function = was_inside_function
visit_FunctionDef(self, node)

Visit a function or async function definition node.

Parameters:

Name Type Description Default
node Union[_ast.FunctionDef, _ast.AsyncFunctionDef]

The node to visit.

required
Source code in zenml/steps/utils.py
def _visit_function(
    self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> None:
    """Visit a function or async function definition node.

    Args:
        node: The node to visit.
    """
    if self._ignore_nested_functions and self._inside_function:
        # We're already inside a function definition and should ignore
        # nested functions so we don't want to recurse any further
        return

    # Restore the previous state afterwards so the flag only covers the body
    # of *this* function. Otherwise sibling functions visited later (or a
    # later `visit()` call on the same visitor) would wrongly be treated as
    # nested and skipped.
    was_inside_function = self._inside_function
    self._inside_function = True
    try:
        self.generic_visit(node)
    finally:
        self._inside_function = was_inside_function

TupleReturnVisitor (ReturnVisitor)

Checks whether a function AST contains tuple returns.

Source code in zenml/steps/utils.py
class TupleReturnVisitor(ReturnVisitor):
    """Detects whether a function AST returns a literal multi-element tuple."""

    def __init__(self) -> None:
        """Initializes a visitor instance with no tuple return seen yet."""
        super().__init__()
        self.has_tuple_return = False

    def visit_Return(self, node: ast.Return) -> None:
        """Visit a return statement.

        Args:
            node: The return statement to visit.
        """
        returned = node.value
        # Only literal tuples with at least two elements count; `return (x,)`
        # and returning a tuple-valued name do not.
        if not isinstance(returned, ast.Tuple):
            return
        if len(returned.elts) > 1:
            self.has_tuple_return = True
__init__(self) special

Initializes a visitor instance.

Source code in zenml/steps/utils.py
def __init__(self) -> None:
    """Initializes a visitor instance.

    The flag starts as `False` and is set once a return statement with a
    literal tuple of more than one element is visited.
    """
    super().__init__()
    self.has_tuple_return = False
visit_Return(self, node)

Visit a return statement.

Parameters:

Name Type Description Default
node Return

The return statement to visit.

required
Source code in zenml/steps/utils.py
def visit_Return(self, node: ast.Return) -> None:
    """Flag the visitor if this return statement yields multiple values.

    Single-element tuples (`return (x,)`), bare `return` and non-tuple
    expressions leave the flag untouched.

    Args:
        node: The return statement to visit.
    """
    returned = node.value
    returns_many = isinstance(returned, ast.Tuple) and len(returned.elts) > 1
    if returns_many:
        self.has_tuple_return = True

get_args(obj)

Get arguments of a type annotation.

Examples:

get_args(Union[int, str]) == (int, str)

Parameters:

Name Type Description Default
obj Any

The annotation.

required

Returns:

Type Description
Tuple[Any, ...]

The args of the annotation.

Source code in zenml/steps/utils.py
def get_args(obj: Any) -> Tuple[Any, ...]:
    """Get the arguments of a type annotation, reduced to their origins.

    Each argument that is itself a generic alias is replaced by its origin
    class; all other arguments are returned unchanged.

    Example:
        `get_args(Union[int, str]) == (int, str)`
        `get_args(Union[List[int], str]) == (list, str)`

    Args:
        obj: The annotation.

    Returns:
        The args of the annotation.
    """
    raw_args = pydantic_typing.get_args(obj)
    return tuple([pydantic_typing.get_origin(arg) or arg for arg in raw_args])

get_output_name_from_annotation_metadata(annotation)

Get the output name from a type annotation.

Examples:

get_output_name_from_annotation_metadata(int)  # None
get_output_name_from_annotation_metadata(Annotated[int, "name"])  # name

Parameters:

Name Type Description Default
annotation Any

The type annotation.

required

Exceptions:

Type Description
ValueError

If the annotation contains multiple metadata fields or a single non-string metadata field.

Returns:

Type Description
Optional[str]

The annotation metadata.

Source code in zenml/steps/utils.py
def get_output_name_from_annotation_metadata(annotation: Any) -> Optional[str]:
    """Get the output name from a type annotation.

    Only `Annotated[...]` annotations can carry an output name. Any other
    annotation yields `None`.

    Example:
    ```python
    get_output_name_from_annotation_metadata(int)  # None
    get_output_name_from_annotation_metadata(Annotated[int, "name"])  # name
    ```

    Args:
        annotation: The type annotation.

    Raises:
        ValueError: If the annotation contains multiple metadata fields or a
            single non-string metadata field.

    Returns:
        The output name stored as the annotation metadata, or `None` if the
        annotation is not `Annotated`.
    """
    if (pydantic_typing.get_origin(annotation) or annotation) is not Annotated:
        return None

    # The first argument is the wrapped type, which is irrelevant here;
    # everything after it is the metadata. Unpack into a throwaway name so
    # the `annotation` parameter is not rebound.
    _, *metadata = pydantic_typing.get_args(annotation)

    if len(metadata) != 1:
        raise ValueError(
            "Annotation metadata can only contain a single element which must "
            "be the output name."
        )

    output_name = metadata[0]

    if not isinstance(output_name, str):
        raise ValueError(
            "Annotation metadata must be a string which will be used as the "
            "output name."
        )

    return output_name

has_only_none_returns(func)

Checks whether a function contains only None returns.

A None return could be either an explicit return None or an empty return statement.

Examples:

def f1():
  return None

def f2():
  return

def f3(condition):
  if condition:
    return None
  else:
    return 1

has_only_none_returns(f1)  # True
has_only_none_returns(f2)  # True
has_only_none_returns(f3)  # False

Parameters:

Name Type Description Default
func Callable[..., Any]

The function to check.

required

Returns:

Type Description
bool

Whether the function contains only None returns.

Source code in zenml/steps/utils.py
def has_only_none_returns(func: Callable[..., Any]) -> bool:
    """Check whether every `return` in a function returns `None`.

    Both an explicit `return None` and a bare `return` count as a `None`
    return. The check is done statically by parsing the function's
    (dedented) source code into an AST.

    Example:
    ```python
    def f1():
      return None

    def f2():
      return

    def f3(condition):
      if condition:
        return None
      else:
        return 1

    has_only_none_returns(f1)  # True
    has_only_none_returns(f2)  # True
    has_only_none_returns(f3)  # False
    ```

    Args:
        func: The function to check.

    Returns:
        `True` if the function contains only `None` returns, else `False`.
    """
    tree = ast.parse(textwrap.dedent(source_code_utils.get_source_code(func)))
    checker = OnlyNoneReturnsVisitor()
    checker.visit(tree)
    return checker.has_only_none_returns

has_tuple_return(func)

Checks whether a function returns multiple values.

Multiple values means that the return statement is followed by a tuple literal with more than one element (with or without parentheses).

Examples:

def f1():
  return 1, 2

def f2():
  return (1, 2)

def f3():
  var = (1, 2)
  return var

has_tuple_return(f1)  # True
has_tuple_return(f2)  # True
has_tuple_return(f3)  # False

Parameters:

Name Type Description Default
func Callable[..., Any]

The function to check.

required

Returns:

Type Description
bool

Whether the function returns multiple values.

Source code in zenml/steps/utils.py
def has_tuple_return(func: Callable[..., Any]) -> bool:
    """Check whether a function returns multiple values.

    A function returns multiple values when a `return` statement is directly
    followed by a tuple literal with more than one element, with or without
    parentheses. Returning a variable that happens to hold a tuple does not
    count, because the check is done statically on the function's parsed
    source code.

    Example:
    ```python
    def f1():
      return 1, 2

    def f2():
      return (1, 2)

    def f3():
      var = (1, 2)
      return var

    has_tuple_return(f1)  # True
    has_tuple_return(f2)  # True
    has_tuple_return(f3)  # False
    ```

    Args:
        func: The function to check.

    Returns:
        `True` if the function returns multiple values, else `False`.
    """
    tree = ast.parse(textwrap.dedent(source_code_utils.get_source_code(func)))
    checker = TupleReturnVisitor()
    checker.visit(tree)
    return checker.has_tuple_return

parse_return_type_annotations(func, enforce_type_annotations=False)

Parse the return type annotation of a step function.

Parameters:

Name Type Description Default
func Callable[..., Any]

The step function.

required
enforce_type_annotations bool

If True, raises an exception if a type annotation is missing.

False

Exceptions:

Type Description
RuntimeError

If the output annotation has variable length or contains duplicate output names.

RuntimeError

If type annotations should be enforced and a type annotation is missing.

Returns:

Type Description
Dict[str, Any]

The function output artifacts.

Source code in zenml/steps/utils.py
def parse_return_type_annotations(
    func: Callable[..., Any], enforce_type_annotations: bool = False
) -> Dict[str, Any]:
    """Parse the return type annotation of a step function.

    The result maps each output name to its resolved type:

    - An explicit `-> None` annotation yields no outputs.
    - A missing annotation raises if annotations are enforced. Otherwise it
      yields no outputs if the function only returns `None`, and is treated
      as `Any` in every other case.
    - A (deprecated) `Output(...)` annotation yields one output per entry.
    - A parametrized `Tuple[...]` annotation on a function that returns a
      tuple literal with more than one element yields one output per tuple
      element. Each output is named from its `Annotated` metadata, or
      `output_{i}` if it has none.
    - Anything else, including a bare unparametrized `Tuple`, yields a single
      output. It is named from its `Annotated` metadata, or
      `SINGLE_RETURN_OUT_NAME` if it has none.

    Args:
        func: The step function.
        enforce_type_annotations: If `True`, raises an exception if a type
            annotation is missing.

    Raises:
        RuntimeError: If the output annotation has variable length or contains
            duplicate output names.
        RuntimeError: If type annotations should be enforced and a type
            annotation is missing.

    Returns:
        The function output artifacts.
    """
    signature = inspect.signature(func, follow_wrapped=True)
    return_annotation = signature.return_annotation

    if return_annotation is None:
        return {}

    if return_annotation is signature.empty:
        if enforce_type_annotations:
            raise RuntimeError(
                "Missing return type annotation for step function "
                f"'{func.__name__}'."
            )
        elif has_only_none_returns(func):
            return {}
        else:
            return_annotation = Any

    if isinstance(return_annotation, Output):
        # Note the trailing space after the URL: without it the adjacent
        # string literals would be glued together ("...namesfor more").
        logger.warning(
            "Using the `Output` class to define the outputs of your steps is "
            "deprecated. You should instead use the standard Python way of "
            "type annotating your functions. Check out our documentation "
            "https://docs.zenml.io/user-guide/advanced-guide/configure-steps-pipelines#step-output-names "
            "for more information on how to assign custom names to your step "
            "outputs."
        )
        return {
            output_name: resolve_type_annotation(output_type)
            for output_name, output_type in return_annotation.items()
        }
    elif pydantic_typing.get_origin(return_annotation) is tuple:
        requires_multiple_artifacts = has_tuple_return(func)
        args = pydantic_typing.get_args(return_annotation)

        # A bare `Tuple` has no element types, so the return value cannot be
        # split into separate artifacts, and indexing `args[-1]` would raise
        # IndexError. Fall through and treat it like an unparametrized
        # `tuple`, i.e. as a single output.
        if requires_multiple_artifacts and args:
            if args[-1] is Ellipsis:
                raise RuntimeError(
                    "Variable length output annotations are not allowed."
                )

            output_signature: Dict[str, Any] = {}
            for i, annotation in enumerate(args):
                resolved_annotation = resolve_type_annotation(annotation)
                output_name = (
                    get_output_name_from_annotation_metadata(annotation)
                    or f"output_{i}"
                )
                if output_name in output_signature:
                    raise RuntimeError(f"Duplicate output name {output_name}.")

                output_signature[output_name] = resolved_annotation

            return output_signature

    resolved_annotation = resolve_type_annotation(return_annotation)
    output_name = (
        get_output_name_from_annotation_metadata(return_annotation)
        or SINGLE_RETURN_OUT_NAME
    )

    return {output_name: resolved_annotation}

resolve_type_annotation(obj)

Returns the non-generic class for generic aliases of the typing module.

Example: if the input object is typing.Dict, this method will return the concrete class dict.

Parameters:

Name Type Description Default
obj Any

The object to resolve.

required

Returns:

Type Description
Any

The non-generic class for generic aliases of the typing module.

Source code in zenml/steps/utils.py
def resolve_type_annotation(obj: Any) -> Any:
    """Returns the non-generic class for generic aliases of the typing module.

    Any `Annotated[...]` wrappers are peeled off first. Then:

    - union annotations are returned unchanged;
    - any other annotation is reduced to its origin class if it has one
      (e.g. `typing.Dict` becomes the concrete class `dict`), or returned
      as-is otherwise.

    Args:
        obj: The object to resolve.

    Returns:
        The non-generic class for generic aliases of the typing module.
    """
    current = obj
    while True:
        origin = pydantic_typing.get_origin(current) or current
        if origin is not Annotated:
            break
        # Drop the metadata and keep resolving the wrapped type.
        current, *_ = pydantic_typing.get_args(current)

    return current if pydantic_typing.is_union(origin) else origin