Steps
zenml.steps
special
A step is a single piece or stage of a ZenML pipeline. Think of each step as being one of the nodes of a Directed Acyclic Graph (or DAG). Steps are responsible for one aspect of processing or interacting with the data / artifacts in the pipeline.
ZenML currently implements a basic step interface, but there will be other more customized interfaces (layered in a hierarchy) for specialized implementations. Conceptually, a Step is a discrete and independent part of a pipeline that is responsible for one particular aspect of data manipulation inside a ZenML pipeline.
Steps can be subclassed from the `BaseStep` class, or used via our `@step` decorator.
base_step
BaseStep
Abstract base class for all ZenML steps.
Attributes:
Name | Type | Description |
---|---|---|
name |
str |
The name of this step. |
pipeline_parameter_name |
Optional[str] |
The name of the pipeline parameter for which this step was passed as an argument. |
enable_cache |
bool |
A boolean indicating if caching is enabled for this step. |
custom_step_operator |
Optional[str] |
Optional name of a custom step operator to use for this step. |
requires_context |
bool |
A boolean indicating if this step requires a `StepContext` object during execution. |
Source code in zenml/steps/base_step.py
class BaseStep(metaclass=BaseStepMeta):
    """Abstract base class for all ZenML steps.

    The class-level signature attributes (`INPUT_SIGNATURE`,
    `OUTPUT_SIGNATURE`, `CONFIG_PARAMETER_NAME`, `CONFIG_CLASS` and
    `CONTEXT_PARAMETER_NAME`) are filled in by `BaseStepMeta` when a step
    class is created. The `*_SPEC` dictionaries start out as class-level
    dictionaries and are replaced by per-instance dictionaries whenever an
    instance needs to modify them.

    Attributes:
        name: The name of this step.
        pipeline_parameter_name: The name of the pipeline parameter for which
            this step was passed as an argument.
        enable_cache: A boolean indicating if caching is enabled for this step.
        custom_step_operator: Optional name of a custom step operator to use
            for this step.
        requires_context: A boolean indicating if this step requires a
            `StepContext` object during execution.
    """

    # TODO [ENG-156]: Ensure these are ordered
    INPUT_SIGNATURE: ClassVar[Dict[str, Type[Any]]] = None  # type: ignore[assignment] # noqa
    OUTPUT_SIGNATURE: ClassVar[Dict[str, Type[Any]]] = None  # type: ignore[assignment] # noqa
    CONFIG_PARAMETER_NAME: ClassVar[Optional[str]] = None
    CONFIG_CLASS: ClassVar[Optional[Type[BaseStepConfig]]] = None
    CONTEXT_PARAMETER_NAME: ClassVar[Optional[str]] = None

    PARAM_SPEC: Dict[str, Any] = {}
    INPUT_SPEC: Dict[str, Type[BaseArtifact]] = {}
    OUTPUT_SPEC: Dict[str, Type[BaseArtifact]] = {}

    INSTANCE_CONFIGURATION: Dict[str, Any] = {}

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initializes a step instance.

        Args:
            *args: Positional arguments; at most a single config object is
                accepted (see `_verify_init_arguments`).
            **kwargs: Keyword arguments; internal options (functional api
                flag, custom step operator, cache flag) are popped here, the
                remainder is validated by `_verify_init_arguments`.

        Raises:
            StepInterfaceError: Propagated from `_verify_init_arguments` or
                `_verify_output_spec` if the arguments or the explicit
                output artifact types are invalid.
        """
        self.name = self.__class__.__name__
        self.pipeline_parameter_name: Optional[str] = None

        # `INSTANCE_CONFIGURATION` here is the module-level name (class
        # attributes are not in scope inside methods); it is used as the
        # attribute name to look up. Values stored on the class take
        # precedence over the kwargs passed by the caller.
        kwargs.update(getattr(self, INSTANCE_CONFIGURATION))

        self.requires_context = bool(self.CONTEXT_PARAMETER_NAME)
        self._created_by_functional_api = kwargs.pop(
            PARAM_CREATED_BY_FUNCTIONAL_API, False
        )
        self.custom_step_operator = kwargs.pop(PARAM_CUSTOM_STEP_OPERATOR, None)

        enable_cache = kwargs.pop(PARAM_ENABLE_CACHE, None)
        if enable_cache is None:
            if self.requires_context:
                # Using the StepContext inside a step provides access to
                # external resources which might influence the step execution.
                # We therefore disable caching unless it is explicitly enabled
                enable_cache = False
                logger.debug(
                    "Step '%s': Step context required and caching not "
                    "explicitly enabled.",
                    self.name,
                )
            else:
                # Default to cache enabled if not explicitly set
                enable_cache = True

        logger.debug(
            "Step '%s': Caching %s.",
            self.name,
            "enabled" if enable_cache else "disabled",
        )
        self.enable_cache = enable_cache

        self._explicit_materializers: Dict[str, Type[BaseMaterializer]] = {}
        self._component: Optional[_ZenMLSimpleComponent] = None
        self._has_been_called = False

        self._verify_init_arguments(*args, **kwargs)
        self._verify_output_spec()

    @abstractmethod
    def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
        """Abstract method for core step logic."""

    def get_materializers(
        self, ensure_complete: bool = False
    ) -> Dict[str, Type[BaseMaterializer]]:
        """Returns available materializers for the outputs of this step.

        Args:
            ensure_complete: If set to `True`, this method will raise a
                `StepInterfaceError` if no materializer can be found for an
                output.

        Returns:
            A dictionary mapping output names to `BaseMaterializer` subclasses.
            If no explicit materializer was set using
            `step.with_return_materializers(...)`, this checks the
            default materializer registry to find a materializer for the
            type of the output. If no materializer is registered, the
            output of this method will not contain an entry for this output.

        Raises:
            StepInterfaceError: (Only if `ensure_complete` is set to `True`)
                If an output does not have an explicit materializer assigned
                to it and there is no default materializer registered for
                the output type.
        """
        # Copy so that registry defaults never end up in the explicit mapping
        materializers = self._explicit_materializers.copy()

        for output_name, output_type in self.OUTPUT_SIGNATURE.items():
            if output_name in materializers:
                # Materializer for this output was set explicitly
                pass
            elif default_materializer_registry.is_registered(output_type):
                materializer = default_materializer_registry[output_type]
                materializers[output_name] = materializer
            else:
                if ensure_complete:
                    raise StepInterfaceError(
                        f"Unable to find materializer for output "
                        f"'{output_name}' of type `{output_type}` in step "
                        f"'{self.name}'. Please make sure to either "
                        f"explicitly set a materializer for step outputs "
                        f"using `step.with_return_materializers(...)` or "
                        f"registering a default materializer for specific "
                        f"types by subclassing `BaseMaterializer` and setting "
                        f"its `ASSOCIATED_TYPES` class variable.",
                        url="https://docs.zenml.io/guides/index/custom-materializer",
                    )

        return materializers

    @property
    def _internal_execution_parameters(self) -> Dict[str, Any]:
        """ZenML internal execution parameters for this step.

        Returns:
            A dictionary whose keys are all prefixed with
            `INTERNAL_EXECUTION_PARAMETER_PREFIX`.
        """
        parameters = {
            PARAM_PIPELINE_PARAMETER_NAME: self.pipeline_parameter_name,
            PARAM_CUSTOM_STEP_OPERATOR: self.custom_step_operator,
        }

        if self.enable_cache:
            # Caching is enabled so we compute a hash of the step function code
            # and materializers to catch changes in the step behavior

            # If the step was defined using the functional api, only track
            # changes to the entrypoint function. Otherwise track changes to
            # the entire step class.
            source_object = (
                self.entrypoint
                if self._created_by_functional_api
                else self.__class__
            )
            parameters["step_source"] = get_hashed_source(source_object)

            for name, materializer in self.get_materializers().items():
                key = f"{name}_materializer_source"
                parameters[key] = get_hashed_source(materializer)
        else:
            # Add a random string to the execution properties to disable caching
            random_string = f"{random.getrandbits(128):032x}"
            parameters["disable_cache"] = random_string

        return {
            INTERNAL_EXECUTION_PARAMETER_PREFIX + key: value
            for key, value in parameters.items()
        }

    def _verify_init_arguments(self, *args: Any, **kwargs: Any) -> None:
        """Verifies the initialization args and kwargs of this step.

        This method makes sure that there is only a config object passed at
        initialization and that it was passed using the correct name and
        type specified in the step declaration.
        If the correct config object was found, additionally saves the
        config parameters to `self.PARAM_SPEC`.

        Args:
            *args: The args passed to the init method of this step.
            **kwargs: The kwargs passed to the init method of this step.

        Raises:
            StepInterfaceError: If there are too many arguments or arguments
                with a wrong name/type.
        """
        maximum_arg_count = 1 if self.CONFIG_CLASS else 0
        arg_count = len(args) + len(kwargs)
        if arg_count > maximum_arg_count:
            raise StepInterfaceError(
                f"Too many arguments ({arg_count}, expected: "
                f"{maximum_arg_count}) passed when creating a "
                f"'{self.name}' step."
            )

        if self.CONFIG_PARAMETER_NAME and self.CONFIG_CLASS:
            if args:
                config = args[0]
            elif kwargs:
                # The count check above guarantees a single remaining kwarg
                key, config = kwargs.popitem()

                if key != self.CONFIG_PARAMETER_NAME:
                    raise StepInterfaceError(
                        f"Unknown keyword argument '{key}' when creating a "
                        f"'{self.name}' step, only expected a single "
                        f"argument with key '{self.CONFIG_PARAMETER_NAME}'."
                    )
            else:
                # This step requires configuration parameters but no config
                # object was passed as an argument. The parameters might be
                # set via default values in the config class or in a
                # configuration file, so we continue for now and verify
                # that all parameters are set before running the step
                return

            if not isinstance(config, self.CONFIG_CLASS):
                raise StepInterfaceError(
                    f"`{config}` object passed when creating a "
                    f"'{self.name}' step is not a "
                    f"`{self.CONFIG_CLASS.__name__}` instance."
                )

            self.PARAM_SPEC = config.dict()

    def _verify_output_spec(self) -> None:
        """Verifies the explicitly set output artifact types of this step.

        Raises:
            StepInterfaceError: If an output artifact type is specified for a
                non-existent step output or the artifact type is not allowed
                for the corresponding output type.
        """
        for output_name, artifact_type in self.OUTPUT_SPEC.items():
            if output_name not in self.OUTPUT_SIGNATURE:
                raise StepInterfaceError(
                    f"Found explicit artifact type for unrecognized output "
                    f"'{output_name}' in step '{self.name}'. Output "
                    f"artifact types can only be specified for the outputs "
                    f"of this step: {set(self.OUTPUT_SIGNATURE)}."
                )

            if not issubclass(artifact_type, BaseArtifact):
                raise StepInterfaceError(
                    f"Invalid artifact type ({artifact_type}) for output "
                    f"'{output_name}' of step '{self.name}'. Only "
                    f"`BaseArtifact` subclasses are allowed as artifact types."
                )

            output_type = self.OUTPUT_SIGNATURE[output_name]
            allowed_artifact_types = set(
                type_registry.get_artifact_type(output_type)
            )

            if artifact_type not in allowed_artifact_types:
                raise StepInterfaceError(
                    f"Artifact type `{artifact_type}` for output "
                    f"'{output_name}' of step '{self.name}' is not an "
                    f"allowed artifact type for the defined output type "
                    f"`{output_type}`. Allowed artifact types: "
                    f"{allowed_artifact_types}. If you want to extend the "
                    f"allowed artifact types, implement a custom "
                    f"`BaseMaterializer` subclass and set its "
                    f"`ASSOCIATED_ARTIFACT_TYPES` and `ASSOCIATED_TYPES` "
                    f"accordingly."
                )

    def _update_and_verify_parameter_spec(self) -> None:
        """Verifies and prepares the config parameters for running this step.

        When the step requires config parameters, this method:
            - checks if config parameters were set via a config object or file
            - tries to set missing config parameters from default values of the
              config class

        The defaults are written to a per-instance copy of `PARAM_SPEC` so
        that the class-level dictionary shared by all instances is never
        modified.

        Raises:
            MissingStepParameterError: If no value could be found for one or
                more config parameters.
        """
        if not self.CONFIG_CLASS:
            return

        # `self.PARAM_SPEC` may still be the class-level dictionary (when no
        # config object was passed at init), so fill in defaults on a copy.
        param_spec = dict(self.PARAM_SPEC)

        # we need to store a value for all config keys inside the
        # metadata store to make sure caching works as expected
        missing_keys = []
        for name, field in self.CONFIG_CLASS.__fields__.items():
            if name in param_spec:
                # a value for this parameter has been set already
                continue

            if field.required:
                # this field has no default value set and therefore needs
                # to be passed via an initialized config object
                missing_keys.append(name)
            else:
                # use default value from the pydantic config class
                param_spec[name] = field.default

        self.PARAM_SPEC = param_spec

        if missing_keys:
            raise MissingStepParameterError(
                self.name, missing_keys, self.CONFIG_CLASS
            )

    def _prepare_input_artifacts(
        self, *artifacts: Channel, **kw_artifacts: Channel
    ) -> Dict[str, Channel]:
        """Verifies and prepares the input artifacts for running this step.

        Args:
            *artifacts: Positional input artifacts passed to
                the __call__ method.
            **kw_artifacts: Keyword input artifacts passed to
                the __call__ method.

        Returns:
            Dictionary containing both the positional and keyword input
            artifacts.

        Raises:
            StepInterfaceError: If there are too many or too few artifacts,
                or if an argument is not a `Channel`.
        """
        input_artifact_keys = list(self.INPUT_SIGNATURE.keys())
        if len(artifacts) > len(input_artifact_keys):
            raise StepInterfaceError(
                f"Too many input artifacts for step '{self.name}'. "
                f"This step expects {len(input_artifact_keys)} artifact(s) "
                f"but got {len(artifacts) + len(kw_artifacts)}."
            )

        combined_artifacts = {}

        # Positional artifacts are matched to input names by position
        for i, artifact in enumerate(artifacts):
            if not isinstance(artifact, Channel):
                raise StepInterfaceError(
                    f"Wrong argument type (`{type(artifact)}`) for positional "
                    f"argument {i} of step '{self.name}'. Only outputs "
                    f"from previous steps can be used as arguments when "
                    f"connecting steps."
                )

            key = input_artifact_keys[i]
            combined_artifacts[key] = artifact

        for key, artifact in kw_artifacts.items():
            if key in combined_artifacts:
                # an artifact for this key was already set by
                # the positional input artifacts
                raise StepInterfaceError(
                    f"Unexpected keyword argument '{key}' for step "
                    f"'{self.name}'. An artifact for this key was "
                    f"already passed as a positional argument."
                )

            if not isinstance(artifact, Channel):
                raise StepInterfaceError(
                    f"Wrong argument type (`{type(artifact)}`) for argument "
                    f"'{key}' of step '{self.name}'. Only outputs from "
                    f"previous steps can be used as arguments when "
                    f"connecting steps."
                )

            combined_artifacts[key] = artifact

        # check if there are any missing or unexpected artifacts
        expected_artifacts = set(self.INPUT_SIGNATURE.keys())
        actual_artifacts = set(combined_artifacts.keys())
        missing_artifacts = expected_artifacts - actual_artifacts
        unexpected_artifacts = actual_artifacts - expected_artifacts

        if missing_artifacts:
            raise StepInterfaceError(
                f"Missing input artifact(s) for step "
                f"'{self.name}': {missing_artifacts}."
            )

        if unexpected_artifacts:
            raise StepInterfaceError(
                f"Unexpected input artifact(s) for step "
                f"'{self.name}': {unexpected_artifacts}. This step "
                f"only requires the following artifacts: {expected_artifacts}."
            )

        return combined_artifacts

    # TODO [ENG-157]: replaces Channels with ZenML class (BaseArtifact?)
    def __call__(
        self, *artifacts: Channel, **kw_artifacts: Channel
    ) -> Union[Channel, List[Channel]]:
        """Generates a component when called.

        Args:
            *artifacts: Positional input artifacts for this step.
            **kw_artifacts: Keyword input artifacts for this step.

        Returns:
            The output channel if the step has exactly one output, otherwise
            a list of output channels in the order of `OUTPUT_SIGNATURE`.

        Raises:
            StepInterfaceError: If the step instance was already called or
                if an execution parameter value is not json serializable.
                Errors of this type raised while preparing the inputs or
                resolving materializers are propagated as well.
            MissingStepParameterError: Propagated from
                `_update_and_verify_parameter_spec`.
        """
        if self._has_been_called:
            raise StepInterfaceError(
                f"Step {self.name} has already been called. A ZenML step "
                f"instance can only be called once per pipeline run."
            )
        self._has_been_called = True

        self._update_and_verify_parameter_spec()

        # Prepare the input artifacts and spec
        input_artifacts = self._prepare_input_artifacts(
            *artifacts, **kw_artifacts
        )

        self.INPUT_SPEC = {
            arg_name: artifact_type.type
            for arg_name, artifact_type in input_artifacts.items()
        }

        # make sure we have registered materializers for each output
        materializers = self.get_materializers(ensure_complete=True)

        # Prepare the output artifacts and spec. Defaults are filled in on a
        # per-instance copy so that the class-level `OUTPUT_SPEC` dictionary
        # shared by all instances is never modified.
        output_spec = dict(self.OUTPUT_SPEC)
        for key, value in self.OUTPUT_SIGNATURE.items():
            verified_types = type_registry.get_artifact_type(value)
            if key not in output_spec:
                output_spec[key] = verified_types[0]
        self.OUTPUT_SPEC = output_spec

        # Internal parameters override config parameters with the same key
        execution_parameters = {
            **self.PARAM_SPEC,
            **self._internal_execution_parameters,
        }

        # Convert execution parameter values to strings
        try:
            execution_parameters = {
                k: json.dumps(v) for k, v in execution_parameters.items()
            }
        except TypeError as e:
            raise StepInterfaceError(
                f"Failed to serialize execution parameters for step "
                f"'{self.name}'. Please make sure to only use "
                f"json serializable parameter values."
            ) from e

        component_class = generate_component_class(
            step_name=self.name,
            step_module=self.__module__,
            input_spec=self.INPUT_SPEC,
            output_spec=self.OUTPUT_SPEC,
            execution_parameter_names=set(execution_parameters),
            step_function=self.entrypoint,
            materializers=materializers,
        )
        self._component = component_class(
            **input_artifacts, **execution_parameters
        )

        # Resolve the returns in the right order.
        returns = [self.component.outputs[key] for key in self.OUTPUT_SIGNATURE]

        # If its one return we just return the one channel not as a list
        if len(returns) == 1:
            return returns[0]
        else:
            return returns

    @property
    def component(self) -> _ZenMLSimpleComponent:
        """Returns a TFX component.

        Raises:
            StepInterfaceError: If the component is accessed before it was
                created by calling the step.
        """
        if not self._component:
            raise StepInterfaceError(
                "Trying to access the step component "
                "before creating it via calling the step."
            )
        return self._component

    @property
    def executor_operator(self) -> Type[BaseExecutorOperator]:
        """Executor operator class that should be used to run this step."""
        if self.custom_step_operator:
            return StepExecutorOperator
        else:
            return PythonExecutorOperator

    def with_return_materializers(
        self: T,
        materializers: Union[
            Type[BaseMaterializer], Dict[str, Type[BaseMaterializer]]
        ],
    ) -> T:
        """Register materializers for step outputs.

        If a single materializer is passed, it will be used for all step
        outputs. Otherwise, the dictionary keys specify the output names
        for which the materializers will be used. A dictionary is validated
        completely before anything is registered, so a failing call leaves
        the previously registered materializers untouched.

        Args:
            materializers: The materializers for the outputs of this step.

        Returns:
            The object that this method was called on.

        Raises:
            StepInterfaceError: If a materializer is not a `BaseMaterializer`
                subclass or a materializer for a non-existent output is given.
        """

        def _is_materializer_class(value: Any) -> bool:
            """Checks whether the given object is a `BaseMaterializer`
            subclass."""
            is_class = isinstance(value, type)
            return is_class and issubclass(value, BaseMaterializer)

        if isinstance(materializers, dict):
            allowed_output_names = set(self.OUTPUT_SIGNATURE)

            for output_name, materializer in materializers.items():
                if output_name not in allowed_output_names:
                    raise StepInterfaceError(
                        f"Got unexpected materializers for non-existent "
                        f"output '{output_name}' in step '{self.name}'. "
                        f"Only materializers for the outputs "
                        f"{allowed_output_names} of this step can"
                        f" be registered."
                    )

                if not _is_materializer_class(materializer):
                    raise StepInterfaceError(
                        f"Got unexpected object `{materializer}` as "
                        f"materializer for output '{output_name}' of step "
                        f"'{self.name}'. Only `BaseMaterializer` "
                        f"subclasses are allowed."
                    )

            # Every entry is valid, register them all at once
            self._explicit_materializers.update(materializers)
        elif _is_materializer_class(materializers):
            # Set the materializer for all outputs of this step
            self._explicit_materializers = {
                key: materializers for key in self.OUTPUT_SIGNATURE
            }
        else:
            raise StepInterfaceError(
                f"Got unexpected object `{materializers}` as output "
                f"materializer for step '{self.name}'. Only "
                f"`BaseMaterializer` subclasses or dictionaries mapping "
                f"output names to `BaseMaterializer` subclasses are allowed "
                f"as input when specifying return materializers."
            )

        return self
component: _ZenMLSimpleComponent
property
readonly
Returns a TFX component.
executor_operator: Type[tfx.orchestration.portable.base_executor_operator.BaseExecutorOperator]
property
readonly
Executor operator class that should be used to run this step.
__call__(self, *artifacts, **kw_artifacts)
special
Generates a component when called.
Source code in zenml/steps/base_step.py
def __call__(
    self, *artifacts: Channel, **kw_artifacts: Channel
) -> Union[Channel, List[Channel]]:
    """Generates a component when called.

    Args:
        *artifacts: Positional input artifacts for this step.
        **kw_artifacts: Keyword input artifacts for this step.

    Returns:
        The output channel if the step has exactly one output, otherwise
        a list of output channels in the order of `OUTPUT_SIGNATURE`.

    Raises:
        StepInterfaceError: If the step instance was already called or
            if an execution parameter value is not json serializable.
            Errors of this type raised while preparing the inputs or
            resolving materializers are propagated as well.
    """
    if self._has_been_called:
        raise StepInterfaceError(
            f"Step {self.name} has already been called. A ZenML step "
            f"instance can only be called once per pipeline run."
        )
    self._has_been_called = True

    self._update_and_verify_parameter_spec()

    # Prepare the input artifacts and spec
    input_artifacts = self._prepare_input_artifacts(
        *artifacts, **kw_artifacts
    )

    self.INPUT_SPEC = {
        arg_name: artifact_type.type
        for arg_name, artifact_type in input_artifacts.items()
    }

    # make sure we have registered materializers for each output
    materializers = self.get_materializers(ensure_complete=True)

    # Prepare the output artifacts and spec. `OUTPUT_SPEC` is declared on
    # the class, so defaults are filled in on a copy that is then bound to
    # this instance instead of modifying the dictionary shared by all
    # instances of the step class.
    output_spec = dict(self.OUTPUT_SPEC)
    for key, value in self.OUTPUT_SIGNATURE.items():
        verified_types = type_registry.get_artifact_type(value)
        if key not in output_spec:
            output_spec[key] = verified_types[0]
    self.OUTPUT_SPEC = output_spec

    # Internal parameters override config parameters with the same key
    execution_parameters = {
        **self.PARAM_SPEC,
        **self._internal_execution_parameters,
    }

    # Convert execution parameter values to strings
    try:
        execution_parameters = {
            k: json.dumps(v) for k, v in execution_parameters.items()
        }
    except TypeError as e:
        raise StepInterfaceError(
            f"Failed to serialize execution parameters for step "
            f"'{self.name}'. Please make sure to only use "
            f"json serializable parameter values."
        ) from e

    component_class = generate_component_class(
        step_name=self.name,
        step_module=self.__module__,
        input_spec=self.INPUT_SPEC,
        output_spec=self.OUTPUT_SPEC,
        execution_parameter_names=set(execution_parameters),
        step_function=self.entrypoint,
        materializers=materializers,
    )
    self._component = component_class(
        **input_artifacts, **execution_parameters
    )

    # Resolve the returns in the right order.
    returns = [self.component.outputs[key] for key in self.OUTPUT_SIGNATURE]

    # If its one return we just return the one channel not as a list
    if len(returns) == 1:
        return returns[0]
    else:
        return returns
entrypoint(self, *args, **kwargs)
Abstract method for core step logic.
Source code in zenml/steps/base_step.py
@abstractmethod
def entrypoint(self, *args: Any, **kwargs: Any) -> Any:
    """Abstract method for core step logic.

    Concrete steps override this method with a fixed argument list.
    `BaseStepMeta` inspects the signature of the override when the step
    class is created: every argument needs a type annotation, and variable
    `*args` / `**kwargs` are rejected.

    Args:
        *args: Positional arguments of the step logic.
        **kwargs: Keyword arguments of the step logic.

    Returns:
        The output(s) of the step.
    """
get_materializers(self, ensure_complete=False)
Returns available materializers for the outputs of this step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ensure_complete |
bool |
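If set to `True`, this method will raise a `StepInterfaceError` if no materializer can be found for an output. |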
False |
Returns:
Type | Description |
---|---|
Dict[str, Type[zenml.materializers.base_materializer.BaseMaterializer]] |
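A dictionary mapping output names to `BaseMaterializer` subclasses. If no explicit materializer was set using `step.with_return_materializers(...)`, this checks the default materializer registry to find a materializer for the type of the output. If no materializer is registered, the output of this method will not contain an entry for this output. |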
Exceptions:
Type | Description |
---|---|
StepInterfaceError |
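(Only if `ensure_complete` is set to `True`) If an output does not have an explicit materializer assigned to it and there is no default materializer registered for the output type. |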
Source code in zenml/steps/base_step.py
def get_materializers(
    self, ensure_complete: bool = False
) -> Dict[str, Type[BaseMaterializer]]:
    """Collects the materializers for the outputs of this step.

    Explicitly registered materializers (see
    `step.with_return_materializers(...)`) take precedence. For every
    other output the default materializer registry is asked for a
    materializer matching the output type.

    Args:
        ensure_complete: Whether an output without any materializer should
            be treated as an error instead of being left out.

    Returns:
        A dictionary mapping output names to `BaseMaterializer` subclasses.
        Outputs for which neither an explicit nor a default materializer
        exists are missing from the dictionary.

    Raises:
        StepInterfaceError: (Only if `ensure_complete` is set to `True`)
            If neither an explicit nor a default materializer exists for
            one of the outputs.
    """
    # Start from a copy of the explicit mapping so it is never modified here
    resolved = dict(self._explicit_materializers)

    for output_name, output_type in self.OUTPUT_SIGNATURE.items():
        if output_name in resolved:
            # Materializer for this output was set explicitly
            continue

        if default_materializer_registry.is_registered(output_type):
            resolved[output_name] = default_materializer_registry[output_type]
            continue

        if not ensure_complete:
            # Incomplete results are acceptable, skip this output
            continue

        raise StepInterfaceError(
            f"Unable to find materializer for output "
            f"'{output_name}' of type `{output_type}` in step "
            f"'{self.name}'. Please make sure to either "
            f"explicitly set a materializer for step outputs "
            f"using `step.with_return_materializers(...)` or "
            f"registering a default materializer for specific "
            f"types by subclassing `BaseMaterializer` and setting "
            f"its `ASSOCIATED_TYPES` class variable.",
            url="https://docs.zenml.io/guides/index/custom-materializer",
        )

    return resolved
with_return_materializers(self, materializers)
Register materializers for step outputs.
If a single materializer is passed, it will be used for all step outputs. Otherwise, the dictionary keys specify the output names for which the materializers will be used.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
materializers |
Union[Type[zenml.materializers.base_materializer.BaseMaterializer], Dict[str, Type[zenml.materializers.base_materializer.BaseMaterializer]]] |
The materializers for the outputs of this step. |
required |
Returns:
Type | Description |
---|---|
~T |
The object that this method was called on. |
Exceptions:
Type | Description |
---|---|
StepInterfaceError |
If a materializer is not a `BaseMaterializer` subclass or a materializer for a non-existent output is given. |
Source code in zenml/steps/base_step.py
def with_return_materializers(
    self: T,
    materializers: Union[
        Type[BaseMaterializer], Dict[str, Type[BaseMaterializer]]
    ],
) -> T:
    """Register materializers for step outputs.

    If a single materializer is passed, it will be used for all step
    outputs. Otherwise, the dictionary keys specify the output names
    for which the materializers will be used. A dictionary is validated
    completely before anything is registered, so a failing call leaves
    the previously registered materializers untouched.

    Args:
        materializers: The materializers for the outputs of this step.

    Returns:
        The object that this method was called on.

    Raises:
        StepInterfaceError: If a materializer is not a `BaseMaterializer`
            subclass or a materializer for a non-existent output is given.
    """

    def _is_materializer_class(value: Any) -> bool:
        """Checks whether the given object is a `BaseMaterializer`
        subclass."""
        is_class = isinstance(value, type)
        return is_class and issubclass(value, BaseMaterializer)

    if isinstance(materializers, dict):
        allowed_output_names = set(self.OUTPUT_SIGNATURE)

        for output_name, materializer in materializers.items():
            if output_name not in allowed_output_names:
                raise StepInterfaceError(
                    f"Got unexpected materializers for non-existent "
                    f"output '{output_name}' in step '{self.name}'. "
                    f"Only materializers for the outputs "
                    f"{allowed_output_names} of this step can"
                    f" be registered."
                )

            if not _is_materializer_class(materializer):
                raise StepInterfaceError(
                    f"Got unexpected object `{materializer}` as "
                    f"materializer for output '{output_name}' of step "
                    f"'{self.name}'. Only `BaseMaterializer` "
                    f"subclasses are allowed."
                )

        # Every entry is valid, register them all at once
        self._explicit_materializers.update(materializers)
    elif _is_materializer_class(materializers):
        # Set the materializer for all outputs of this step
        self._explicit_materializers = {
            key: materializers for key in self.OUTPUT_SIGNATURE
        }
    else:
        raise StepInterfaceError(
            f"Got unexpected object `{materializers}` as output "
            f"materializer for step '{self.name}'. Only "
            f"`BaseMaterializer` subclasses or dictionaries mapping "
            f"output names to `BaseMaterializer` subclasses are allowed "
            f"as input when specifying return materializers."
        )

    return self
BaseStepMeta (type)
Metaclass for `BaseStep`.
Checks whether everything passed in:
* Has a matching materializer.
* Is a subclass of the Config class
Source code in zenml/steps/base_step.py
class BaseStepMeta(type):
    """Metaclass for `BaseStep`.

    Inspects the `entrypoint` signature of every step class that gets
    created and derives the class-level step interface from it:
    * `INPUT_SIGNATURE` and `OUTPUT_SIGNATURE` from the annotated arguments
      and the return annotation,
    * `CONFIG_PARAMETER_NAME` and `CONFIG_CLASS` from an argument annotated
      with a `BaseStepConfig` subclass,
    * `CONTEXT_PARAMETER_NAME` from an argument annotated with a
      `StepContext` class.
    """

    def __new__(
        mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> "BaseStepMeta":
        """Set up a new class with a qualified spec.

        Args:
            mcs: This metaclass.
            name: Name of the class that is being created.
            bases: Base classes of the class that is being created.
            dct: Namespace of the class that is being created.

        Returns:
            The newly created step class.

        Raises:
            StepInterfaceError: If the entrypoint uses variable arguments,
                has an argument without type annotation, declares more
                than one config or context argument, or if input, output
                and config parameter names overlap.
        """
        # Every class gets its own spec dictionaries unless it defines them
        for spec_attribute in ("PARAM_SPEC", "INPUT_SPEC", "OUTPUT_SPEC"):
            dct.setdefault(spec_attribute, {})

        cls = cast(Type["BaseStep"], super().__new__(mcs, name, bases, dct))

        cls.INPUT_SIGNATURE = {}
        cls.OUTPUT_SIGNATURE = {}
        cls.CONFIG_PARAMETER_NAME = None
        cls.CONFIG_CLASS = None
        cls.CONTEXT_PARAMETER_NAME = None

        # Get the signature of the step function
        signature = inspect.getfullargspec(inspect.unwrap(cls.entrypoint))
        annotations = signature.annotations

        # Only concrete implementations (classes with bases) are checked,
        # the abstract `BaseStep` class itself may use *args and **kwargs
        if bases and (signature.varargs or signature.varkw):
            if signature.varargs:
                variable_arguments = f"*{signature.varargs}"
            else:
                variable_arguments = f"**{signature.varkw}"
            raise StepInterfaceError(
                f"Unable to create step '{name}' with variable arguments "
                f"'{variable_arguments}'. Please make sure your step "
                f"functions are defined with a fixed amount of arguments."
            )

        argument_names = signature.args + signature.kwonlyargs

        # Remove 'self' from the signature if it exists
        if argument_names[:1] == ["self"]:
            argument_names = argument_names[1:]

        # Verify the input arguments of the step function
        for arg in argument_names:
            arg_type = resolve_type_annotation(annotations.get(arg))

            if not arg_type:
                raise StepInterfaceError(
                    f"Missing type annotation for argument '{arg}' when "
                    f"trying to create step '{name}'. Please make sure to "
                    f"include type annotations for all your step inputs "
                    f"and outputs."
                )

            if issubclass(arg_type, BaseStepConfig):
                # Raise an error if we already found a config in the signature
                if cls.CONFIG_CLASS is not None:
                    raise StepInterfaceError(
                        f"Found multiple configuration arguments "
                        f"('{cls.CONFIG_PARAMETER_NAME}' and '{arg}') when "
                        f"trying to create step '{name}'. Please make sure to "
                        f"only have one `BaseStepConfig` subclass as input "
                        f"argument for a step."
                    )
                cls.CONFIG_PARAMETER_NAME = arg
                cls.CONFIG_CLASS = arg_type
                continue

            if issubclass(arg_type, StepContext):
                if cls.CONTEXT_PARAMETER_NAME is not None:
                    raise StepInterfaceError(
                        f"Found multiple context arguments "
                        f"('{cls.CONTEXT_PARAMETER_NAME}' and '{arg}') when "
                        f"trying to create step '{name}'. Please make sure to "
                        f"only have one `StepContext` as input "
                        f"argument for a step."
                    )
                cls.CONTEXT_PARAMETER_NAME = arg
                continue

            # Can't do any check for existing materializers right now
            # as they might get be defined later, so we simply store the
            # argument name and type for later use.
            cls.INPUT_SIGNATURE[arg] = arg_type

        # Parse the returns of the step function
        return_type = annotations.get("return")
        if isinstance(return_type, Output):
            cls.OUTPUT_SIGNATURE = {
                output_name: resolve_type_annotation(output_type)
                for output_name, output_type in return_type.items()
            }
        elif return_type is not None:
            cls.OUTPUT_SIGNATURE[
                SINGLE_RETURN_OUT_NAME
            ] = resolve_type_annotation(return_type)

        # Raise an exception if input and output names of a step overlap as
        # tfx requires them to be unique
        # TODO [ENG-155]: Can we prefix inputs and outputs to avoid this
        #  restriction?
        # The signature dictionaries are converted to lists of their keys
        # because `Counter.update` would interpret a dictionary as a mapping
        # of elements to counts.
        name_counts: Counter[str] = collections.Counter()
        name_counts.update(list(cls.INPUT_SIGNATURE))
        name_counts.update(list(cls.OUTPUT_SIGNATURE))
        if cls.CONFIG_CLASS:
            name_counts.update(list(cls.CONFIG_CLASS.__fields__.keys()))

        shared_keys = {
            key for key, count in name_counts.items() if count > 1
        }
        if shared_keys:
            raise StepInterfaceError(
                f"The following keys are overlapping in the input, output and "
                f"config parameter names of step '{name}': {shared_keys}. "
                f"Please make sure that your input, output and config "
                f"parameter names are unique."
            )

        return cls
__new__(mcs, name, bases, dct)
special
staticmethod
Set up a new class with a qualified spec.
Source code in zenml/steps/base_step.py
def __new__(
    mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
) -> "BaseStepMeta":
    """Set up a new class with a qualified spec.

    Args:
        mcs: This metaclass.
        name: Name of the class that is being created.
        bases: Base classes of the class that is being created.
        dct: Namespace of the class that is being created.

    Returns:
        The newly created step class.

    Raises:
        StepInterfaceError: If the entrypoint uses variable arguments,
            has an argument without type annotation, declares more
            than one config or context argument, or if input, output
            and config parameter names overlap.
    """
    # Every class gets its own spec dictionaries unless it defines them
    for spec_attribute in ("PARAM_SPEC", "INPUT_SPEC", "OUTPUT_SPEC"):
        dct.setdefault(spec_attribute, {})

    cls = cast(Type["BaseStep"], super().__new__(mcs, name, bases, dct))

    cls.INPUT_SIGNATURE = {}
    cls.OUTPUT_SIGNATURE = {}
    cls.CONFIG_PARAMETER_NAME = None
    cls.CONFIG_CLASS = None
    cls.CONTEXT_PARAMETER_NAME = None

    # Get the signature of the step function
    signature = inspect.getfullargspec(inspect.unwrap(cls.entrypoint))
    annotations = signature.annotations

    # Only concrete implementations (classes with bases) are checked,
    # the abstract `BaseStep` class itself may use *args and **kwargs
    if bases and (signature.varargs or signature.varkw):
        if signature.varargs:
            variable_arguments = f"*{signature.varargs}"
        else:
            variable_arguments = f"**{signature.varkw}"
        raise StepInterfaceError(
            f"Unable to create step '{name}' with variable arguments "
            f"'{variable_arguments}'. Please make sure your step "
            f"functions are defined with a fixed amount of arguments."
        )

    argument_names = signature.args + signature.kwonlyargs

    # Remove 'self' from the signature if it exists
    if argument_names[:1] == ["self"]:
        argument_names = argument_names[1:]

    # Verify the input arguments of the step function
    for arg in argument_names:
        arg_type = resolve_type_annotation(annotations.get(arg))

        if not arg_type:
            raise StepInterfaceError(
                f"Missing type annotation for argument '{arg}' when "
                f"trying to create step '{name}'. Please make sure to "
                f"include type annotations for all your step inputs "
                f"and outputs."
            )

        if issubclass(arg_type, BaseStepConfig):
            # Raise an error if we already found a config in the signature
            if cls.CONFIG_CLASS is not None:
                raise StepInterfaceError(
                    f"Found multiple configuration arguments "
                    f"('{cls.CONFIG_PARAMETER_NAME}' and '{arg}') when "
                    f"trying to create step '{name}'. Please make sure to "
                    f"only have one `BaseStepConfig` subclass as input "
                    f"argument for a step."
                )
            cls.CONFIG_PARAMETER_NAME = arg
            cls.CONFIG_CLASS = arg_type
            continue

        if issubclass(arg_type, StepContext):
            if cls.CONTEXT_PARAMETER_NAME is not None:
                raise StepInterfaceError(
                    f"Found multiple context arguments "
                    f"('{cls.CONTEXT_PARAMETER_NAME}' and '{arg}') when "
                    f"trying to create step '{name}'. Please make sure to "
                    f"only have one `StepContext` as input "
                    f"argument for a step."
                )
            cls.CONTEXT_PARAMETER_NAME = arg
            continue

        # Can't do any check for existing materializers right now
        # as they might get be defined later, so we simply store the
        # argument name and type for later use.
        cls.INPUT_SIGNATURE[arg] = arg_type

    # Parse the returns of the step function
    return_type = annotations.get("return")
    if isinstance(return_type, Output):
        cls.OUTPUT_SIGNATURE = {
            output_name: resolve_type_annotation(output_type)
            for output_name, output_type in return_type.items()
        }
    elif return_type is not None:
        cls.OUTPUT_SIGNATURE[
            SINGLE_RETURN_OUT_NAME
        ] = resolve_type_annotation(return_type)

    # Raise an exception if input and output names of a step overlap as
    # tfx requires them to be unique
    # TODO [ENG-155]: Can we prefix inputs and outputs to avoid this
    #  restriction?
    # The signature dictionaries are converted to lists of their keys
    # because `Counter.update` would interpret a dictionary as a mapping
    # of elements to counts.
    name_counts: Counter[str] = collections.Counter()
    name_counts.update(list(cls.INPUT_SIGNATURE))
    name_counts.update(list(cls.OUTPUT_SIGNATURE))
    if cls.CONFIG_CLASS:
        name_counts.update(list(cls.CONFIG_CLASS.__fields__.keys()))

    shared_keys = {key for key, count in name_counts.items() if count > 1}
    if shared_keys:
        raise StepInterfaceError(
            f"The following keys are overlapping in the input, output and "
            f"config parameter names of step '{name}': {shared_keys}. "
            f"Please make sure that your input, output and config "
            f"parameter names are unique."
        )

    return cls
base_step_config
BaseStepConfig (BaseModel)
pydantic-model
Base configuration class to pass execution params into a step.
Source code in zenml/steps/base_step_config.py
class BaseStepConfig(BaseModel):
    """Base configuration class to pass execution params into a step.

    The class declares no fields of its own. Step-specific configurations
    subclass it and add their parameters as model fields.
    """
builtin_steps
special
pandas_analyzer
PandasAnalyzer (BaseAnalyzerStep)
Simple step implementation which analyzes a given pd.DataFrame
Source code in zenml/steps/builtin_steps/pandas_analyzer.py
class PandasAnalyzer(BaseAnalyzerStep):
    """Analyzer step that derives statistics and a schema from a DataFrame."""

    # The output artifact types are spelled out by hand here.
    OUTPUT_SPEC = {"statistics": StatisticsArtifact, "schema": SchemaArtifact}

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        config: PandasAnalyzerConfig,
    ) -> Output(  # type:ignore[valid-type]
        statistics=pd.DataFrame, schema=pd.DataFrame
    ):
        """Compute descriptive statistics and the dtype schema of a dataset.

        Args:
            dataset: The dataframe to analyze.
            config: The step configuration; its `percentiles`, `include`
                and `exclude` values are forwarded to `DataFrame.describe`.

        Returns:
            A `(statistics, schema)` tuple: the transposed result of
            `describe` and a transposed frame of the column dtypes cast to
            strings.
        """
        described = dataset.describe(
            percentiles=config.percentiles,
            include=config.include,
            exclude=config.exclude,
        )
        dtype_frame = dataset.dtypes.to_frame().transpose()
        return described.transpose(), dtype_frame.astype(str)
CONFIG_CLASS (BaseAnalyzerConfig)
pydantic-model
Config class for the PandasAnalyzer Config
Source code in zenml/steps/builtin_steps/pandas_analyzer.py
class PandasAnalyzerConfig(BaseAnalyzerConfig):
    """Config class for the PandasAnalyzer step.

    `PandasAnalyzer.entrypoint` forwards every field to the keyword argument
    of the same name of `pd.DataFrame.describe`.
    """

    # Forwarded as `percentiles`; the default lists the three quartiles.
    percentiles: List[float] = [0.25, 0.5, 0.75]
    # Forwarded as `include`; `None` is passed through unchanged.
    include: Optional[Union[str, List[Type[Any]]]] = None
    # Forwarded as `exclude`; `None` is passed through unchanged.
    exclude: Optional[Union[str, List[Type[Any]]]] = None
entrypoint(self, dataset, config)
Main entrypoint function for the pandas analyzer
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataFrame |
pd.DataFrame, the given dataset |
required |
config |
PandasAnalyzerConfig |
the configuration of the step |
required |
Returns:
Type | Description |
---|---|
Output(statistics=pd.DataFrame, schema=pd.DataFrame) |
the statistics and the schema of the given dataframe |
Source code in zenml/steps/builtin_steps/pandas_analyzer.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    config: PandasAnalyzerConfig,
) -> Output(  # type:ignore[valid-type]
    statistics=pd.DataFrame, schema=pd.DataFrame
):
    """Compute descriptive statistics and the dtype schema of a dataset.

    Args:
        dataset: The dataframe to analyze.
        config: The step configuration; its `percentiles`, `include` and
            `exclude` values are forwarded to `DataFrame.describe`.

    Returns:
        A `(statistics, schema)` tuple: the transposed result of `describe`
        and a transposed frame of the column dtypes cast to strings.
    """
    describe_options = {
        "percentiles": config.percentiles,
        "include": config.include,
        "exclude": config.exclude,
    }
    statistics = dataset.describe(**describe_options).transpose()
    schema = dataset.dtypes.to_frame().transpose().astype(str)
    return statistics, schema
PandasAnalyzerConfig (BaseAnalyzerConfig)
pydantic-model
Config class for the PandasAnalyzer Config
Source code in zenml/steps/builtin_steps/pandas_analyzer.py
class PandasAnalyzerConfig(BaseAnalyzerConfig):
    """Config class for the PandasAnalyzer step.

    `PandasAnalyzer.entrypoint` forwards every field to the keyword argument
    of the same name of `pd.DataFrame.describe`.
    """

    # Forwarded as `percentiles`; the default lists the three quartiles.
    percentiles: List[float] = [0.25, 0.5, 0.75]
    # Forwarded as `include`; `None` is passed through unchanged.
    include: Optional[Union[str, List[Type[Any]]]] = None
    # Forwarded as `exclude`; `None` is passed through unchanged.
    exclude: Optional[Union[str, List[Type[Any]]]] = None
pandas_datasource
PandasDatasource (BaseDatasourceStep)
Simple step implementation to ingest from a csv file using pandas
Source code in zenml/steps/builtin_steps/pandas_datasource.py
class PandasDatasource(BaseDatasourceStep):
    """Datasource step which loads a csv file into a pandas DataFrame."""

    def entrypoint(  # type: ignore[override]
        self,
        config: PandasDatasourceConfig,
    ) -> pd.DataFrame:
        """Read the configured csv file with `pd.read_csv`.

        Args:
            config: The step configuration; `path` is used as the file to
                read and `sep`, `header`, `names` and `index_col` are
                forwarded unchanged.

        Returns:
            The dataframe produced by `pd.read_csv`.
        """
        read_options = {
            "filepath_or_buffer": config.path,
            "sep": config.sep,
            "header": config.header,
            "names": config.names,
            "index_col": config.index_col,
        }
        return pd.read_csv(**read_options)
CONFIG_CLASS (BaseDatasourceConfig)
pydantic-model
Config class for the pandas csv datasource
Source code in zenml/steps/builtin_steps/pandas_datasource.py
class PandasDatasourceConfig(BaseDatasourceConfig):
    """Config class for the pandas csv datasource.

    `PandasDatasource.entrypoint` hands these values to `pd.read_csv`.
    """

    # Required; passed as `filepath_or_buffer`.
    path: str
    # Passed as `sep`; defaults to a comma.
    sep: str = ","
    # Passed as `header`; defaults to the string "infer".
    header: Union[int, List[int], str] = "infer"
    # Passed as `names`.
    names: Optional[List[str]] = None
    # Passed as `index_col`.
    index_col: Optional[Union[int, str, List[Union[int, str]], bool]] = None
entrypoint(self, config)
Main entrypoint method for the PandasDatasource
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
PandasDatasourceConfig |
the configuration of the step |
required |
Returns:
Type | Description |
---|---|
DataFrame |
the resulting dataframe |
Source code in zenml/steps/builtin_steps/pandas_datasource.py
def entrypoint(  # type: ignore[override]
    self,
    config: PandasDatasourceConfig,
) -> pd.DataFrame:
    """Read the configured csv file with `pd.read_csv`.

    Args:
        config: The step configuration; `path` is used as the file to read
            and `sep`, `header`, `names` and `index_col` are forwarded
            unchanged.

    Returns:
        The dataframe produced by `pd.read_csv`.
    """
    dataframe = pd.read_csv(
        filepath_or_buffer=config.path,
        sep=config.sep,
        header=config.header,
        names=config.names,
        index_col=config.index_col,
    )
    return dataframe
PandasDatasourceConfig (BaseDatasourceConfig)
pydantic-model
Config class for the pandas csv datasource
Source code in zenml/steps/builtin_steps/pandas_datasource.py
class PandasDatasourceConfig(BaseDatasourceConfig):
    """Config class for the pandas csv datasource.

    `PandasDatasource.entrypoint` hands these values to `pd.read_csv`.
    """

    # Required; passed as `filepath_or_buffer`.
    path: str
    # Passed as `sep`; defaults to a comma.
    sep: str = ","
    # Passed as `header`; defaults to the string "infer".
    header: Union[int, List[int], str] = "infer"
    # Passed as `names`.
    names: Optional[List[str]] = None
    # Passed as `index_col`.
    index_col: Optional[Union[int, str, List[Union[int, str]], bool]] = None
restrict_step_access_decorator
restrict_step_access(_func)
Decorator to restrict this function from running inside a step.
Apply this decorator to a ZenML function to prevent it from being run inside the context of a step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
_func |
Any |
The function to restrict access to, inside a step. |
required |
Returns:
Type | Description |
---|---|
Any |
The same function without any enhancements but checks the Environment to see if it's running inside a step. |
Exceptions:
Type | Description |
---|---|
ForbiddenRepositoryAccessError |
If trying to create a `Repository` instance while a ZenML step is being executed. |
Source code in zenml/steps/restrict_step_access_decorator.py
def restrict_step_access(_func: Any) -> Any:
    """Decorator to restrict this function from running inside a step.

    Apply this decorator to a ZenML function to prevent it from being run
    inside the context of a step. The environment is inspected on every call
    of the decorated function rather than once when the decorator is applied,
    so the restriction also holds for functions that were decorated (for
    example at import time) before any step started running.

    Args:
        _func: The function to restrict access to, inside a step.

    Returns:
        A wrapper that keeps the metadata of `_func`, checks the Environment
        to see if it's running inside a step and otherwise forwards all
        arguments to `_func` and returns its result.

    Raises:
        ForbiddenRepositoryAccessError: Raised by the returned wrapper if it
            is called while a ZenML step is being executed.
    """
    # Local import keeps this block independent of the module-level imports.
    import functools

    @functools.wraps(_func)
    def _restricted(*args: Any, **kwargs: Any) -> Any:
        # The check has to live here: doing it in the decorator body would
        # only look at the environment at decoration time.
        if Environment().step_is_running:
            raise ForbiddenRepositoryAccessError(
                "Unable to access repository during step execution. If you "
                "require access to the artifact or metadata store, please use "
                "a `StepContext` inside your step instead.",
                url="https://docs.zenml.io/features/step-fixtures#using-the-stepcontext",
            )
        return _func(*args, **kwargs)

    return _restricted
step_context
StepContext
Provides additional context inside a step function.
This class is used to access the metadata store, materializers and
artifacts inside a step function. To use it, add a `StepContext` object
to the signature of your step function like this:
```python
@step
def my_step(context: StepContext, ...)
    context.get_output_materializer(...)
```
You do not need to create a `StepContext` object yourself and pass it
when creating the step, as long as you specify it in the signature ZenML
will create the `StepContext` and automatically pass it when executing your
step.
**Note**: When using a `StepContext` inside a step, ZenML disables caching
for this step by default as the context provides access to external
resources which might influence the result of your step execution. To
enable caching anyway, explicitly enable it in the `@step` decorator or when
initializing your custom step class.
Source code in zenml/steps/step_context.py
class StepContext:
    """Provides additional context inside a step function.

    A `StepContext` gives a step function access to the metadata store, the
    active stack and the materializers and artifacts of its outputs. To use
    it, add a `StepContext` parameter to the signature of your step function:

    ```python
    @step
    def my_step(context: StepContext, ...)
        context.get_output_materializer(...)
    ```

    You do not need to create a `StepContext` object yourself and pass it
    when creating the step, as long as you specify it in the signature ZenML
    will create the `StepContext` and automatically pass it when executing
    your step.

    **Note**: When using a `StepContext` inside a step, ZenML disables caching
    for this step by default as the context provides access to external
    resources which might influence the result of your step execution. To
    enable caching anyway, explicitly enable it in the `@step` decorator or
    when initializing your custom step class.
    """

    def __init__(
        self,
        step_name: str,
        output_materializers: Dict[str, Type["BaseMaterializer"]],
        output_artifacts: Dict[str, "BaseArtifact"],
    ):
        """Initializes a StepContext instance.

        Args:
            step_name: The name of the step that this context is used in.
            output_materializers: The output materializers of the step that
                this context is used in.
            output_artifacts: The output artifacts of the step that this
                context is used in.

        Raises:
            StepContextError: If the keys of the output materializers and
                output artifacts do not match.
        """
        materializer_keys = output_materializers.keys()
        artifact_keys = output_artifacts.keys()
        if materializer_keys != artifact_keys:
            raise StepContextError(
                f"Mismatched keys in output materializers and output "
                f"artifacts for step '{step_name}'. Output materializer "
                f"keys: {set(output_materializers)}, output artifact "
                f"keys: {set(output_artifacts)}"
            )

        self.step_name = step_name
        # Pair every materializer class with the artifact of the same name.
        self._outputs = {
            name: StepContextOutput(materializer, output_artifacts[name])
            for name, materializer in output_materializers.items()
        }
        self._metadata_store = Repository().active_stack.metadata_store
        self._stack = Repository().active_stack

    def _get_output(
        self, output_name: Optional[str] = None
    ) -> StepContextOutput:
        """Returns the materializer class and artifact of a step output.

        Args:
            output_name: Optional name of the output to look up. May be
                omitted if the step has exactly one output.

        Returns:
            The `StepContextOutput` (materializer class and artifact) stored
            for the requested output.

        Raises:
            StepContextError: If the step has no outputs, no output for
                the given `output_name` or if no `output_name`
                was given but the step has multiple outputs.
        """
        outputs = self._outputs

        if not outputs:
            raise StepContextError(
                f"Unable to get step output for step '{self.step_name}': "
                f"This step does not have any outputs."
            )

        if not output_name:
            # Without a name the lookup is only unambiguous for one output.
            if len(outputs) > 1:
                raise StepContextError(
                    f"Unable to get step output for step '{self.step_name}': "
                    f"This step has multiple outputs ({set(outputs)}), "
                    f"please specify which output to return."
                )
            return next(iter(outputs.values()))

        if output_name not in outputs:
            raise StepContextError(
                f"Unable to get step output '{output_name}' for "
                f"step '{self.step_name}'. This step does not have an "
                f"output with the given name, please specify one of the "
                f"available outputs: {set(outputs)}."
            )
        return outputs[output_name]

    @property
    def metadata_store(self) -> "BaseMetadataStore":
        """Returns the metadata store of the active stack.

        This is the store that holds metadata about the step (and the
        corresponding pipeline) which is being executed.
        """
        return self._metadata_store

    @property
    def stack(self) -> Optional["Stack"]:
        """Returns the current active stack."""
        return self._stack

    def get_output_materializer(
        self,
        output_name: Optional[str] = None,
        custom_materializer_class: Optional[Type["BaseMaterializer"]] = None,
    ) -> "BaseMaterializer":
        """Returns a materializer for a given step output.

        Args:
            output_name: Optional name of the output for which to get the
                materializer. If no name is given and the step only has a
                single output, the materializer of this output will be
                returned. If the step has multiple outputs, an exception
                will be raised.
            custom_materializer_class: If given, this `BaseMaterializer`
                subclass will be initialized with the output artifact instead
                of the materializer that was registered for this step output.

        Returns:
            A materializer initialized with the output artifact for
            the given output.

        Raises:
            StepContextError: If the step has no outputs, no output for
                the given `output_name` or if no `output_name`
                was given but the step has multiple outputs.
        """
        registered_class, artifact = self._get_output(output_name)
        # A custom class, when supplied, takes precedence over the
        # materializer registered for the output.
        if custom_materializer_class:
            return custom_materializer_class(artifact)
        return registered_class(artifact)

    def get_output_artifact_uri(self, output_name: Optional[str] = None) -> str:
        """Returns the artifact URI for a given step output.

        Args:
            output_name: Optional name of the output for which to get the URI.
                If no name is given and the step only has a single output,
                the URI of this output will be returned. If the step has
                multiple outputs, an exception will be raised.

        Returns:
            Artifact URI for the given output.

        Raises:
            StepContextError: If the step has no outputs, no output for
                the given `output_name` or if no `output_name`
                was given but the step has multiple outputs.
        """
        output = self._get_output(output_name)
        return cast(str, output.artifact.uri)
metadata_store: BaseMetadataStore
property
readonly
Returns an instance of the metadata store that is used to store metadata about the step (and the corresponding pipeline) which is being executed.
stack: Optional[Stack]
property
readonly
Returns the current active stack.
__init__(self, step_name, output_materializers, output_artifacts)
special
Initializes a StepContext instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
The name of the step that this context is used in. |
required |
output_materializers |
Dict[str, Type[BaseMaterializer]] |
The output materializers of the step that this context is used in. |
required |
output_artifacts |
Dict[str, BaseArtifact] |
The output artifacts of the step that this context is used in. |
required |
Exceptions:
Type | Description |
---|---|
StepContextError |
If the keys of the output materializers and output artifacts do not match. |
Source code in zenml/steps/step_context.py
def __init__(
    self,
    step_name: str,
    output_materializers: Dict[str, Type["BaseMaterializer"]],
    output_artifacts: Dict[str, "BaseArtifact"],
):
    """Initializes a StepContext instance.

    Args:
        step_name: The name of the step that this context is used in.
        output_materializers: The output materializers of the step that
            this context is used in.
        output_artifacts: The output artifacts of the step that this
            context is used in.

    Raises:
        StepContextError: If the keys of the output materializers and
            output artifacts do not match.
    """
    materializer_keys = output_materializers.keys()
    artifact_keys = output_artifacts.keys()
    if materializer_keys != artifact_keys:
        raise StepContextError(
            f"Mismatched keys in output materializers and output "
            f"artifacts for step '{step_name}'. Output materializer "
            f"keys: {set(output_materializers)}, output artifact "
            f"keys: {set(output_artifacts)}"
        )

    self.step_name = step_name
    # Pair every materializer class with the artifact of the same name.
    self._outputs = {
        name: StepContextOutput(materializer, output_artifacts[name])
        for name, materializer in output_materializers.items()
    }
    self._metadata_store = Repository().active_stack.metadata_store
    self._stack = Repository().active_stack
get_output_artifact_uri(self, output_name=None)
Returns the artifact URI for a given step output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_name |
Optional[str] |
Optional name of the output for which to get the URI. If no name is given and the step only has a single output, the URI of this output will be returned. If the step has multiple outputs, an exception will be raised. |
None |
Returns:
Type | Description |
---|---|
str |
Artifact URI for the given output. |
Exceptions:
Type | Description |
---|---|
StepContextError |
If the step has no outputs, no output for the given `output_name` or if no `output_name` was given but the step has multiple outputs. |
Source code in zenml/steps/step_context.py
def get_output_artifact_uri(self, output_name: Optional[str] = None) -> str:
    """Returns the artifact URI for a given step output.

    Args:
        output_name: Optional name of the output for which to get the URI.
            If no name is given and the step only has a single output,
            the URI of this output will be returned. If the step has
            multiple outputs, an exception will be raised.

    Returns:
        Artifact URI for the given output.

    Raises:
        StepContextError: If the step has no outputs, no output for
            the given `output_name` or if no `output_name`
            was given but the step has multiple outputs.
    """
    output = self._get_output(output_name)
    return cast(str, output.artifact.uri)
get_output_materializer(self, output_name=None, custom_materializer_class=None)
Returns a materializer for a given step output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
output_name |
Optional[str] |
Optional name of the output for which to get the materializer. If no name is given and the step only has a single output, the materializer of this output will be returned. If the step has multiple outputs, an exception will be raised. |
None |
custom_materializer_class |
Optional[Type[BaseMaterializer]] |
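If given, this `BaseMaterializer` subclass will be initialized with the output artifact instead of the materializer that was registered for this step output. |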
None |
Returns:
Type | Description |
---|---|
BaseMaterializer |
A materializer initialized with the output artifact for the given output. |
Exceptions:
Type | Description |
---|---|
StepContextError |
If the step has no outputs, no output for the given `output_name` or if no `output_name` was given but the step has multiple outputs. |
Source code in zenml/steps/step_context.py
def get_output_materializer(
    self,
    output_name: Optional[str] = None,
    custom_materializer_class: Optional[Type["BaseMaterializer"]] = None,
) -> "BaseMaterializer":
    """Returns a materializer for a given step output.

    Args:
        output_name: Optional name of the output for which to get the
            materializer. If no name is given and the step only has a
            single output, the materializer of this output will be
            returned. If the step has multiple outputs, an exception
            will be raised.
        custom_materializer_class: If given, this `BaseMaterializer`
            subclass will be initialized with the output artifact instead
            of the materializer that was registered for this step output.

    Returns:
        A materializer initialized with the output artifact for
        the given output.

    Raises:
        StepContextError: If the step has no outputs, no output for
            the given `output_name` or if no `output_name`
            was given but the step has multiple outputs.
    """
    registered_class, artifact = self._get_output(output_name)
    # A custom class, when supplied, takes precedence over the materializer
    # registered for the output.
    if custom_materializer_class:
        return custom_materializer_class(artifact)
    return registered_class(artifact)
StepContextOutput (tuple)
Tuple containing materializer class and artifact for a step output.
Source code in zenml/steps/step_context.py
class StepContextOutput(NamedTuple):
    """Tuple containing materializer class and artifact for a step output."""

    # Materializer class stored for the output (first tuple element).
    materializer_class: Type["BaseMaterializer"]
    # Artifact stored for the output (second tuple element).
    artifact: "BaseArtifact"
__getnewargs__(self)
special
Return self as a plain tuple. Used by copy and pickle.
Source code in zenml/steps/step_context.py
def __getnewargs__(self):
    'Return self as a plain tuple. Used by copy and pickle.'
    # NOTE(review): `_tuple` is not defined in this listing; presumably it is
    # the builtin `tuple` bound by the generated NamedTuple code -- confirm.
    return _tuple(self)
__new__(_cls, materializer_class, artifact)
special
staticmethod
Create new instance of StepContextOutput(materializer_class, artifact)
__repr__(self)
special
Return a nicely formatted representation string
Source code in zenml/steps/step_context.py
def __repr__(self):
    'Return a nicely formatted representation string'
    # NOTE(review): `repr_fmt` is not defined in this listing; presumably it
    # is the format string built by the generated NamedTuple code -- confirm.
    return self.__class__.__name__ + repr_fmt % self
step_decorator
step(_func=None, *, name=None, enable_cache=None, output_types=None, custom_step_operator=None)
Outer decorator function for the creation of a ZenML step
In order to be able to work with parameters such as `name`, it features a
nested decorator structure.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
_func |
Optional[~F] |
The decorated function. |
None |
name |
Optional[str] |
The name of the step. If left empty, the name of the decorated function will be used as a fallback. |
None |
enable_cache |
Optional[bool] |
Specify whether caching is enabled for this step. If no value is passed, caching is enabled by default unless the step requires a `StepContext` (see `zenml.steps.step_context.StepContext` for more information). |
None |
output_types |
Optional[Dict[str, Type[BaseArtifact]]] |
A dictionary which sets different outputs to non-default artifact types |
None |
custom_step_operator |
Optional[str] |
Optional name of a `zenml.step_operators.BaseStepOperator` to use for this step. |
None |
Returns:
Type | Description |
---|---|
Union[Type[zenml.steps.base_step.BaseStep], Callable[[~F], Type[zenml.steps.base_step.BaseStep]]] |
the inner decorator which creates the step class based on the ZenML BaseStep |
Source code in zenml/steps/step_decorator.py
def step(
    _func: Optional[F] = None,
    *,
    name: Optional[str] = None,
    enable_cache: Optional[bool] = None,
    output_types: Optional[Dict[str, Type["BaseArtifact"]]] = None,
    custom_step_operator: Optional[str] = None
) -> Union[Type[BaseStep], Callable[[F], Type[BaseStep]]]:
    """Outer decorator function for the creation of a ZenML step

    The decorator can be applied bare (`@step`) or with keyword arguments
    (`@step(name=...)`), which is why it wraps a nested decorator.

    Args:
        _func: The decorated function.
        name: The name of the step. If left empty, the name of the decorated
            function will be used as a fallback.
        enable_cache: Specify whether caching is enabled for this step. If no
            value is passed, caching is enabled by default unless the step
            requires a `StepContext` (see
            :class:`zenml.steps.step_context.StepContext` for more information).
        output_types: A dictionary which sets different outputs to non-default
            artifact types
        custom_step_operator: Optional name of a
            `zenml.step_operators.BaseStepOperator` to use for this step.

    Returns:
        The generated step class when used as a bare decorator, otherwise
        the inner decorator which creates the step class based on the
        ZenML BaseStep
    """

    def inner_decorator(func: F) -> Type[BaseStep]:
        """Inner decorator function for the creation of a ZenML Step

        Args:
            func: types.FunctionType, this function will be used as the
                "process" method of the generated Step

        Returns:
            The class of a newly generated ZenML Step.
        """
        class_name = name or func.__name__
        instance_configuration = {
            PARAM_ENABLE_CACHE: enable_cache,
            PARAM_CREATED_BY_FUNCTIONAL_API: True,
            PARAM_CUSTOM_STEP_OPERATOR: custom_step_operator,
        }
        # Namespace of the dynamically created `BaseStep` subclass.
        class_namespace = {
            STEP_INNER_FUNC_NAME: staticmethod(func),
            INSTANCE_CONFIGURATION: instance_configuration,
            OUTPUT_SPEC: output_types or {},
            "__module__": func.__module__,
        }
        return type(class_name, (BaseStep,), class_namespace)  # noqa

    # Bare usage (`@step`) hands over the function directly; parametrized
    # usage (`@step(...)`) needs the decorator itself.
    if _func is not None:
        return inner_decorator(_func)
    return inner_decorator
step_environment
StepEnvironment (BaseEnvironmentComponent)
Provides additional information about a step runtime inside a step function in the form of an Environment component.
This class can be used from within a pipeline step implementation to access additional information about the runtime parameters of a pipeline step, such as the pipeline name, pipeline run ID and other pipeline runtime information. To use it, access it inside your step function like this:
from zenml.environment import Environment
@step
def my_step(...)
env = Environment().step_environment
do_something_with(env.pipeline_name, env.pipeline_run_id, env.step_name)
Source code in zenml/steps/step_environment.py
class StepEnvironment(BaseEnvironmentComponent):
    """Provides additional information about a step runtime inside a step
    function in the form of an Environment component.

    This class can be used from within a pipeline step implementation to
    access additional information about the runtime parameters of a pipeline
    step, such as the pipeline name, pipeline run ID and other pipeline runtime
    information. To use it, access it inside your step function like this:

    ```python
    from zenml.environment import Environment

    @step
    def my_step(...)
        env = Environment().step_environment
        do_something_with(env.pipeline_name, env.pipeline_run_id, env.step_name)
    ```
    """

    # Name of this environment component.
    NAME = STEP_ENVIRONMENT_NAME

    def __init__(
        self,
        pipeline_name: str,
        pipeline_run_id: str,
        step_name: str,
    ):
        """Initialize the environment of the currently running
        step.

        The three values are stored on private attributes and exposed
        through the read-only properties of the same name.

        Args:
            pipeline_name: the name of the currently running pipeline
            pipeline_run_id: the ID of the currently running pipeline
            step_name: the name of the currently running step
        """
        super().__init__()
        self._pipeline_name = pipeline_name
        self._pipeline_run_id = pipeline_run_id
        self._step_name = step_name

    @property
    def pipeline_name(self) -> str:
        """The name of the currently running pipeline."""
        return self._pipeline_name

    @property
    def pipeline_run_id(self) -> str:
        """The ID of the current pipeline run."""
        return self._pipeline_run_id

    @property
    def step_name(self) -> str:
        """The name of the currently running step."""
        return self._step_name
pipeline_name: str
property
readonly
The name of the currently running pipeline.
pipeline_run_id: str
property
readonly
The ID of the current pipeline run.
step_name: str
property
readonly
The name of the currently running step.
__init__(self, pipeline_name, pipeline_run_id, step_name)
special
Initialize the environment of the currently running step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name |
str |
the name of the currently running pipeline |
required |
pipeline_run_id |
str |
the ID of the currently running pipeline |
required |
step_name |
str |
the name of the currently running step |
required |
Source code in zenml/steps/step_environment.py
def __init__(
    self,
    pipeline_name: str,
    pipeline_run_id: str,
    step_name: str,
):
    """Initialize the environment of the currently running
    step.

    The three values are stored on the private attributes
    `_pipeline_name`, `_pipeline_run_id` and `_step_name`.

    Args:
        pipeline_name: the name of the currently running pipeline
        pipeline_run_id: the ID of the currently running pipeline
        step_name: the name of the currently running step
    """
    super().__init__()
    self._pipeline_name = pipeline_name
    self._pipeline_run_id = pipeline_run_id
    self._step_name = step_name
step_interfaces
special
base_analyzer_step
BaseAnalyzerConfig (BaseStepConfig)
pydantic-model
Base class for analyzer step configurations
Source code in zenml/steps/step_interfaces/base_analyzer_step.py
class BaseAnalyzerConfig(BaseStepConfig):
    """Base class for analyzer step configurations.

    Declares no fields itself; it is the config type accepted by
    `BaseAnalyzerStep.entrypoint`.
    """
BaseAnalyzerStep (BaseStep)
Base step implementation for any analyzer step implementation on ZenML
Source code in zenml/steps/step_interfaces/base_analyzer_step.py
class BaseAnalyzerStep(BaseStep):
    """Base step implementation for any analyzer step implementation on ZenML.

    Subclasses have to override the abstract `entrypoint` method.
    """

    @abstractmethod
    def entrypoint(  # type: ignore[override]
        self,
        dataset: DataArtifact,
        config: BaseAnalyzerConfig,
        context: StepContext,
    ) -> Output(  # type:ignore[valid-type]
        statistics=StatisticsArtifact, schema=SchemaArtifact
    ):
        """Base entrypoint for any analyzer implementation.

        Args:
            dataset: The `DataArtifact` handed to the analyzer.
            config: The analyzer configuration.
            context: The `StepContext` of the step.

        Returns:
            An `Output` declaring a `statistics` entry of type
            `StatisticsArtifact` and a `schema` entry of type
            `SchemaArtifact`.
        """
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Base class for analyzer step configurations
Source code in zenml/steps/step_interfaces/base_analyzer_step.py
class BaseAnalyzerConfig(BaseStepConfig):
    """Base class for analyzer step configurations.

    Declares no fields itself; it is the config type accepted by
    `BaseAnalyzerStep.entrypoint`.
    """
entrypoint(self, dataset, config, context)
Base entrypoint for any analyzer implementation
Source code in zenml/steps/step_interfaces/base_analyzer_step.py
@abstractmethod
def entrypoint(  # type: ignore[override]
    self,
    dataset: DataArtifact,
    config: BaseAnalyzerConfig,
    context: StepContext,
) -> Output(  # type:ignore[valid-type]
    statistics=StatisticsArtifact, schema=SchemaArtifact
):
    """Base entrypoint for any analyzer implementation.

    Args:
        dataset: The `DataArtifact` handed to the analyzer.
        config: The analyzer configuration.
        context: The `StepContext` of the step.

    Returns:
        An `Output` declaring a `statistics` entry of type
        `StatisticsArtifact` and a `schema` entry of type `SchemaArtifact`.
    """
base_datasource_step
BaseDatasourceConfig (BaseStepConfig)
pydantic-model
Base class for datasource configs to inherit from
Source code in zenml/steps/step_interfaces/base_datasource_step.py
class BaseDatasourceConfig(BaseStepConfig):
    """Base class for datasource configs to inherit from.

    Declares no fields itself; it is the config type accepted by
    `BaseDatasourceStep.entrypoint`.
    """
BaseDatasourceStep (BaseStep)
Base step implementation for any datasource step implementation on ZenML
Source code in zenml/steps/step_interfaces/base_datasource_step.py
class BaseDatasourceStep(BaseStep):
    """Base step implementation for any datasource step implementation on ZenML.

    Subclasses have to override the abstract `entrypoint` method.
    """

    @abstractmethod
    def entrypoint(  # type: ignore[override]
        self,
        config: BaseDatasourceConfig,
        context: StepContext,
    ) -> DataArtifact:
        """Base entrypoint for any datasource implementation.

        Args:
            config: The datasource configuration.
            context: The `StepContext` of the step.

        Returns:
            Annotated as a `DataArtifact`.
        """
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Base class for datasource configs to inherit from
Source code in zenml/steps/step_interfaces/base_datasource_step.py
class BaseDatasourceConfig(BaseStepConfig):
    """Base class for datasource configs to inherit from.

    Declares no fields itself; it is the config type accepted by
    `BaseDatasourceStep.entrypoint`.
    """
entrypoint(self, config, context)
Base entrypoint for any datasource implementation
Source code in zenml/steps/step_interfaces/base_datasource_step.py
@abstractmethod
def entrypoint(  # type: ignore[override]
    self,
    config: BaseDatasourceConfig,
    context: StepContext,
) -> DataArtifact:
    """Base entrypoint for any datasource implementation.

    Args:
        config: The datasource configuration.
        context: The `StepContext` of the step.

    Returns:
        Annotated as a `DataArtifact`.
    """
base_drift_detection_step
BaseDriftDetectionConfig (BaseStepConfig)
pydantic-model
Base class for drift detection step configurations
Source code in zenml/steps/step_interfaces/base_drift_detection_step.py
class BaseDriftDetectionConfig(BaseStepConfig):
    """Base class for drift detection step configurations.

    Declares no fields itself; it is the config type accepted by
    `BaseDriftDetectionStep.entrypoint`.
    """
BaseDriftDetectionStep (BaseStep)
Base step implementation for any drift detection step implementation on ZenML
Source code in zenml/steps/step_interfaces/base_drift_detection_step.py
class BaseDriftDetectionStep(BaseStep):
    """Base step implementation for any drift detection step implementation
    on ZenML.

    Subclasses have to override the abstract `entrypoint` method.
    """

    @abstractmethod
    def entrypoint(  # type: ignore[override]
        self,
        reference_dataset: DataArtifact,
        comparison_dataset: DataArtifact,
        config: BaseDriftDetectionConfig,
        context: StepContext,
    ) -> Any:
        """Base entrypoint for any drift detection implementation.

        Args:
            reference_dataset: The `DataArtifact` used as reference.
            comparison_dataset: The `DataArtifact` compared against the
                reference.
            config: The drift detection configuration.
            context: The `StepContext` of the step.

        Returns:
            Left open (`Any`) by this interface.
        """
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Base class for drift detection step configurations
Source code in zenml/steps/step_interfaces/base_drift_detection_step.py
class BaseDriftDetectionConfig(BaseStepConfig):
    """Base class for drift detection step configurations.

    Declares no fields itself; it is the config type accepted by
    `BaseDriftDetectionStep.entrypoint`.
    """
entrypoint(self, reference_dataset, comparison_dataset, config, context)
Base entrypoint for any drift detection implementation
Source code in zenml/steps/step_interfaces/base_drift_detection_step.py
@abstractmethod
def entrypoint(  # type: ignore[override]
    self,
    reference_dataset: DataArtifact,
    comparison_dataset: DataArtifact,
    config: BaseDriftDetectionConfig,
    context: StepContext,
) -> Any:
    """Base entrypoint for any drift detection implementation.

    Args:
        reference_dataset: The `DataArtifact` used as reference.
        comparison_dataset: The `DataArtifact` compared against the
            reference.
        config: The drift detection configuration.
        context: The `StepContext` of the step.

    Returns:
        Left open (`Any`) by this interface.
    """
base_evaluator_step
BaseEvaluatorConfig (BaseStepConfig)
pydantic-model
Base class for evaluator step configurations
Source code in zenml/steps/step_interfaces/base_evaluator_step.py
class BaseEvaluatorConfig(BaseStepConfig):
    """Base class for evaluator step configurations.

    Declares no fields itself; it is the config type accepted by
    `BaseEvaluatorStep.entrypoint`.
    """
BaseEvaluatorStep (BaseStep)
Base step implementation for any evaluator step implementation on ZenML
Source code in zenml/steps/step_interfaces/base_evaluator_step.py
class BaseEvaluatorStep(BaseStep):
    """Base step implementation for any evaluator step implementation on ZenML.

    Subclasses have to override the abstract `entrypoint` method.
    """

    @abstractmethod
    def entrypoint(  # type: ignore[override]
        self,
        dataset: DataArtifact,
        model: ModelArtifact,
        config: BaseEvaluatorConfig,
        context: StepContext,
    ) -> DataArtifact:
        """Base entrypoint for any evaluator implementation.

        Args:
            dataset: The `DataArtifact` handed to the evaluator.
            model: The `ModelArtifact` handed to the evaluator.
            config: The evaluator configuration.
            context: The `StepContext` of the step.

        Returns:
            Annotated as a `DataArtifact`.
        """
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Base class for evaluator step configurations
Source code in zenml/steps/step_interfaces/base_evaluator_step.py
class BaseEvaluatorConfig(BaseStepConfig):
    """Base class for evaluator step configurations.

    Declares no fields itself; it is the config type accepted by
    `BaseEvaluatorStep.entrypoint`.
    """
entrypoint(self, dataset, model, config, context)
Base entrypoint for any evaluator implementation
Source code in zenml/steps/step_interfaces/base_evaluator_step.py
@abstractmethod
def entrypoint(  # type: ignore[override]
    self,
    dataset: DataArtifact,
    model: ModelArtifact,
    config: BaseEvaluatorConfig,
    context: StepContext,
) -> DataArtifact:
    """Base entrypoint for any evaluator implementation.

    Args:
        dataset: The `DataArtifact` handed to the evaluator.
        model: The `ModelArtifact` handed to the evaluator.
        config: The evaluator configuration.
        context: The `StepContext` of the step.

    Returns:
        Annotated as a `DataArtifact`.
    """
base_preprocessor_step
BasePreprocessorConfig (BaseStepConfig)
pydantic-model
Base class for Preprocessor step configurations
Source code in zenml/steps/step_interfaces/base_preprocessor_step.py
class BasePreprocessorConfig(BaseStepConfig):
    """Base class for Preprocessor step configurations.

    Declares no fields itself; it is the config type accepted by
    `BasePreprocessorStep.entrypoint`.
    """
BasePreprocessorStep (BaseStep)
Base step implementation for any Preprocessor step implementation on ZenML
Source code in zenml/steps/step_interfaces/base_preprocessor_step.py
class BasePreprocessorStep(BaseStep):
    """Abstract step interface for preprocessor steps on ZenML."""

    @abstractmethod
    def entrypoint(  # type: ignore[override]
        self,
        train_dataset: DataArtifact,
        test_dataset: DataArtifact,
        validation_dataset: DataArtifact,
        statistics: StatisticsArtifact,
        schema: SchemaArtifact,
        config: BasePreprocessorConfig,
        context: StepContext,
    ) -> Output(  # type:ignore[valid-type]
        train_transformed=DataArtifact,
        test_transformed=DataArtifact,
        validation_transformed=DataArtifact,
    ):
        """Abstract entrypoint that every preprocessor implementation overrides.

        Args:
            train_dataset: Data artifact for the train split.
            test_dataset: Data artifact for the test split.
            validation_dataset: Data artifact for the validation split.
            statistics: Statistics artifact passed to the preprocessor.
            schema: Schema artifact passed to the preprocessor.
            config: Preprocessor step configuration.
            context: The `StepContext` object passed to the step.

        Returns:
            Three data artifacts named `train_transformed`,
            `test_transformed` and `validation_transformed`.
        """
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Base class for Preprocessor step configurations
Source code in zenml/steps/step_interfaces/base_preprocessor_step.py
class BasePreprocessorConfig(BaseStepConfig):
    """Configuration base class for preprocessor steps.

    Declares no fields of its own; it is the config type accepted by
    `BasePreprocessorStep.entrypoint`.
    """
entrypoint(self, train_dataset, test_dataset, validation_dataset, statistics, schema, config, context)
Base entrypoint for any Preprocessor implementation
Source code in zenml/steps/step_interfaces/base_preprocessor_step.py
@abstractmethod
def entrypoint(  # type: ignore[override]
    self,
    train_dataset: DataArtifact,
    test_dataset: DataArtifact,
    validation_dataset: DataArtifact,
    statistics: StatisticsArtifact,
    schema: SchemaArtifact,
    config: BasePreprocessorConfig,
    context: StepContext,
) -> Output(  # type:ignore[valid-type]
    train_transformed=DataArtifact,
    test_transformed=DataArtifact,
    validation_transformed=DataArtifact,
):
    """Abstract entrypoint that every preprocessor implementation overrides.

    Args:
        train_dataset: Data artifact for the train split.
        test_dataset: Data artifact for the test split.
        validation_dataset: Data artifact for the validation split.
        statistics: Statistics artifact passed to the preprocessor.
        schema: Schema artifact passed to the preprocessor.
        config: Preprocessor step configuration.
        context: The `StepContext` object passed to the step.

    Returns:
        Three data artifacts named `train_transformed`,
        `test_transformed` and `validation_transformed`.
    """
base_split_step
BaseSplitStep (BaseStep)
Base step implementation for any split step implementation on ZenML
Source code in zenml/steps/step_interfaces/base_split_step.py
class BaseSplitStep(BaseStep):
    """Abstract step interface for split steps on ZenML."""

    @abstractmethod
    def entrypoint(  # type: ignore[override]
        self,
        dataset: DataArtifact,
        config: BaseSplitStepConfig,
        context: StepContext,
    ) -> Output(  # type:ignore[valid-type]
        train=DataArtifact,
        test=DataArtifact,
        validation=DataArtifact,
    ):
        """Abstract entrypoint that every split step implementation overrides.

        Args:
            dataset: Data artifact to be split.
            config: Split step configuration.
            context: The `StepContext` object passed to the step.

        Returns:
            Three data artifacts named `train`, `test` and `validation`.
        """
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Base class for split configs to inherit from
Source code in zenml/steps/step_interfaces/base_split_step.py
class BaseSplitStepConfig(BaseStepConfig):
    """Configuration base class that split step configs inherit from.

    Declares no fields of its own; it is the config type accepted by
    `BaseSplitStep.entrypoint`.
    """
entrypoint(self, dataset, config, context)
Entrypoint for a function for the split steps to run
Source code in zenml/steps/step_interfaces/base_split_step.py
@abstractmethod
def entrypoint(  # type: ignore[override]
    self,
    dataset: DataArtifact,
    config: BaseSplitStepConfig,
    context: StepContext,
) -> Output(  # type:ignore[valid-type]
    train=DataArtifact,
    test=DataArtifact,
    validation=DataArtifact,
):
    """Abstract entrypoint that every split step implementation overrides.

    Args:
        dataset: Data artifact to be split.
        config: Split step configuration.
        context: The `StepContext` object passed to the step.

    Returns:
        Three data artifacts named `train`, `test` and `validation`.
    """
BaseSplitStepConfig (BaseStepConfig)
pydantic-model
Base class for split configs to inherit from
Source code in zenml/steps/step_interfaces/base_split_step.py
class BaseSplitStepConfig(BaseStepConfig):
    """Configuration base class that split step configs inherit from.

    Declares no fields of its own; it is the config type accepted by
    `BaseSplitStep.entrypoint`.
    """
base_trainer_step
BaseTrainerConfig (BaseStepConfig)
pydantic-model
Base class for Trainer step configurations
Source code in zenml/steps/step_interfaces/base_trainer_step.py
class BaseTrainerConfig(BaseStepConfig):
    """Configuration base class for trainer steps.

    Declares no fields of its own; it is the config type accepted by
    `BaseTrainerStep.entrypoint`.
    """
BaseTrainerStep (BaseStep)
Base step implementation for any Trainer step implementation on ZenML
Source code in zenml/steps/step_interfaces/base_trainer_step.py
class BaseTrainerStep(BaseStep):
    """Abstract step interface for trainer steps on ZenML."""

    @abstractmethod
    def entrypoint(  # type: ignore[override]
        self,
        train_dataset: DataArtifact,
        validation_dataset: DataArtifact,
        config: BaseTrainerConfig,
        context: StepContext,
    ) -> ModelArtifact:
        """Abstract entrypoint that every trainer implementation overrides.

        Args:
            train_dataset: Data artifact for the train split.
            validation_dataset: Data artifact for the validation split.
            config: Trainer step configuration.
            context: The `StepContext` object passed to the step.

        Returns:
            A model artifact.
        """
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Base class for Trainer step configurations
Source code in zenml/steps/step_interfaces/base_trainer_step.py
class BaseTrainerConfig(BaseStepConfig):
    """Configuration base class for trainer steps.

    Declares no fields of its own; it is the config type accepted by
    `BaseTrainerStep.entrypoint`.
    """
entrypoint(self, train_dataset, validation_dataset, config, context)
Base entrypoint for any Trainer implementation
Source code in zenml/steps/step_interfaces/base_trainer_step.py
@abstractmethod
def entrypoint(  # type: ignore[override]
    self,
    train_dataset: DataArtifact,
    validation_dataset: DataArtifact,
    config: BaseTrainerConfig,
    context: StepContext,
) -> ModelArtifact:
    """Abstract entrypoint that every trainer implementation overrides.

    Args:
        train_dataset: Data artifact for the train split.
        validation_dataset: Data artifact for the validation split.
        config: Trainer step configuration.
        context: The `StepContext` object passed to the step.

    Returns:
        A model artifact.
    """
step_output
Output
A named tuple with a default name that cannot be overridden.
Source code in zenml/steps/step_output.py
class Output:
    """A named tuple with a default name that cannot be overridden.

    The keyword arguments given to the constructor become the fields of a
    `NamedTuple` class that is always called `ZenOutput`; that class is
    stored in `self.outputs`.
    """

    def __init__(self, **kwargs: Type[Any]):
        """Creates the `ZenOutput` named tuple class from the given outputs.

        Args:
            **kwargs: Mapping of output name to output type, in the order in
                which the outputs should appear.
        """
        # TODO [ENG-161]: do we even need the named tuple here or is
        #  a list of tuples (name, Type) sufficient?
        # The fields are passed as a list of (name, type) pairs rather than
        # as keyword arguments: the keyword form of `NamedTuple` is
        # deprecated since Python 3.13.
        self.outputs = NamedTuple(  # type: ignore[misc]
            "ZenOutput", list(kwargs.items())
        )

    def items(self) -> Iterator[Tuple[str, Type[Any]]]:
        """Yields a tuple of type (output_name, output_type)."""
        yield from self.outputs.__annotations__.items()
items(self)
Yields a tuple of type (output_name, output_type).
Source code in zenml/steps/step_output.py
def items(self) -> Iterator[Tuple[str, Type[Any]]]:
"""Yields a tuple of type (output_name, output_type)."""
yield from self.outputs.__annotations__.items()
utils
This collection of utility functions and classes is inspired by the original implementation by the TensorFlow Extended team, which can be found here:
https://github.com/tensorflow/tfx/blob/master/tfx/dsl/component/experimental/decorators.py
This version is heavily adjusted to work with the Pipeline-Step paradigm which is proposed by ZenML.
do_types_match(type_a, type_b)
Check whether type_a and type_b match.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
type_a |
Type[Any] |
First Type to check. |
required |
type_b |
Type[Any] |
Second Type to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if types match, otherwise False. |
Source code in zenml/steps/utils.py
def do_types_match(type_a: Type[Any], type_b: Type[Any]) -> bool:
    """Tells whether two types are considered a match.

    Two types currently match only when they compare equal; subtype
    relationships are not taken into account.

    Args:
        type_a: First Type to check.
        type_b: Second Type to check.

    Returns:
        True if types match, otherwise False.
    """
    # TODO [ENG-158]: Check more complicated cases where type_a can be a
    #  sub-type of type_b
    types_are_equal = type_a == type_b
    return types_are_equal
generate_component_class(step_name, step_module, input_spec, output_spec, execution_parameter_names, step_function, materializers)
Generates a TFX component class for a ZenML step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
Name of the step for which the component will be created. |
required |
step_module |
str |
Module in which the step class is defined. |
required |
input_spec |
Dict[str, Type[zenml.artifacts.base_artifact.BaseArtifact]] |
Input artifacts of the step. |
required |
output_spec |
Dict[str, Type[zenml.artifacts.base_artifact.BaseArtifact]] |
Output artifacts of the step |
required |
execution_parameter_names |
Set[str] |
Execution parameter names of the step. |
required |
step_function |
Callable[..., Any] |
The actual function to execute when running the step. |
required |
materializers |
Dict[str, Type[zenml.materializers.base_materializer.BaseMaterializer]] |
Materializer classes for all outputs of the step. |
required |
Returns:
Type | Description |
---|---|
Type[_ZenMLSimpleComponent] |
A TFX component class. |
Source code in zenml/steps/utils.py
def generate_component_class(
    step_name: str,
    step_module: str,
    input_spec: Dict[str, Type[BaseArtifact]],
    output_spec: Dict[str, Type[BaseArtifact]],
    execution_parameter_names: Set[str],
    step_function: Callable[..., Any],
    materializers: Dict[str, Type[BaseMaterializer]],
) -> Type["_ZenMLSimpleComponent"]:
    """Builds the TFX component class that wraps a ZenML step.

    Besides returning the component class, this registers a dynamically
    created `<step_name>_Executor` class as an attribute of the module in
    which the step was defined.

    Args:
        step_name: Name of the step for which the component will be created.
        step_module: Module in which the step class is defined.
        input_spec: Input artifacts of the step.
        output_spec: Output artifacts of the step
        execution_parameter_names: Execution parameter names of the step.
        step_function: The actual function to execute when running the step.
        materializers: Materializer classes for all outputs of the step.

    Returns:
        A TFX component class.
    """
    spec_class = generate_component_spec_class(
        step_name=step_name,
        input_spec=input_spec,
        output_spec=output_spec,
        execution_parameter_names=execution_parameter_names,
    )

    # Assemble the executor class; the step function is stored as a static
    # method so it is not bound to the executor when accessed.
    executor_name = f"{step_name}_Executor"
    executor_namespace = {
        "_FUNCTION": staticmethod(step_function),
        "__module__": step_module,
        "materializers": materializers,
        PARAM_STEP_NAME: step_name,
    }
    executor_cls = type(executor_name, (_FunctionExecutor,), executor_namespace)

    # Expose the executor class on the module in which the step was defined
    defining_module = sys.modules[step_module]
    setattr(defining_module, executor_name, executor_cls)

    component_namespace = {
        "SPEC_CLASS": spec_class,
        "EXECUTOR_SPEC": ExecutorClassSpec(executor_class=executor_cls),
        "__module__": step_module,
    }
    return type(step_name, (_ZenMLSimpleComponent,), component_namespace)
generate_component_spec_class(step_name, input_spec, output_spec, execution_parameter_names)
Generates a TFX component spec class for a ZenML step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
step_name |
str |
Name of the step for which the component will be created. |
required |
input_spec |
Dict[str, Type[zenml.artifacts.base_artifact.BaseArtifact]] |
Input artifacts of the step. |
required |
output_spec |
Dict[str, Type[zenml.artifacts.base_artifact.BaseArtifact]] |
Output artifacts of the step |
required |
execution_parameter_names |
Set[str] |
Execution parameter names of the step. |
required |
Returns:
Type | Description |
---|---|
Type[tfx.types.component_spec.ComponentSpec] |
A TFX component spec class. |
Source code in zenml/steps/utils.py
def generate_component_spec_class(
    step_name: str,
    input_spec: Dict[str, Type[BaseArtifact]],
    output_spec: Dict[str, Type[BaseArtifact]],
    execution_parameter_names: Set[str],
) -> Type[component_spec.ComponentSpec]:
    """Builds the TFX component spec class that describes a ZenML step.

    The resulting class is named `<step_name>_Spec`. Every input and output
    artifact becomes a channel parameter, and every execution parameter
    becomes a string-typed execution parameter.

    Args:
        step_name: Name of the step for which the component will be created.
        input_spec: Input artifacts of the step.
        output_spec: Output artifacts of the step
        execution_parameter_names: Execution parameter names of the step.

    Returns:
        A TFX component spec class.
    """
    inputs: Dict[str, Any] = {}
    for input_name, input_type in input_spec.items():
        inputs[input_name] = component_spec.ChannelParameter(type=input_type)

    outputs: Dict[str, Any] = {}
    for output_name, output_type in output_spec.items():
        outputs[output_name] = component_spec.ChannelParameter(type=output_type)

    parameters: Dict[str, Any] = {}
    for parameter_name in execution_parameter_names:
        parameters[parameter_name] = component_spec.ExecutionParameter(type=str)  # type: ignore[no-untyped-call] # noqa

    spec_namespace = {
        "INPUTS": inputs,
        "OUTPUTS": outputs,
        "PARAMETERS": parameters,
    }
    return type(
        f"{step_name}_Spec",
        (component_spec.ComponentSpec,),
        spec_namespace,
    )
resolve_type_annotation(obj)
Returns the non-generic class for generic aliases of the typing module.
If the input is not a generic typing alias, the input itself is returned.
Example: if the input object is `typing.Dict`, this method will return the concrete class `dict`.
Source code in zenml/steps/utils.py
def resolve_type_annotation(obj: Any) -> Any:
    """Returns the non-generic class for generic aliases of the typing module.

    If the input is not a generic typing alias, the input itself is returned.

    Example: if the input object is `typing.Dict` or `typing.Dict[str, int]`,
    this method will return the concrete class `dict`.

    Args:
        obj: The type annotation to resolve.

    Returns:
        The `__origin__` of `obj` if it is a generic alias of the typing
        module, otherwise `obj` itself.
    """
    # Since Python 3.9, bare aliases such as `typing.Dict` are instances of
    # `typing._SpecialGenericAlias`, which is not a subclass of
    # `typing._GenericAlias`, so both classes have to be checked. They are
    # looked up with `getattr` because these private classes do not exist on
    # every Python version.
    alias_types = tuple(
        alias_type
        for alias_type in (
            getattr(typing, "_GenericAlias", None),
            getattr(typing, "_SpecialGenericAlias", None),
        )
        if alias_type is not None
    )
    if isinstance(obj, alias_types):
        return obj.__origin__
    return obj