Integrations
zenml.integrations
special
ZenML integrations module.
The ZenML integrations module contains sub-modules for each integration that we
support. This includes orchestrators like Apache Airflow, visualization tools
like the facets library, and deep learning libraries like PyTorch.
airflow
special
Airflow integration for ZenML.
The Airflow integration sub-module powers an alternative to the local
orchestrator. You can enable it by registering the Airflow orchestrator with
the CLI tool and then bootstrapping it with the `zenml orchestrator up`
command.
AirflowIntegration (Integration)
Definition of Airflow Integration for ZenML.
Source code in zenml/integrations/airflow/__init__.py
class AirflowIntegration(Integration):
"""Definition of Airflow Integration for ZenML."""
NAME = AIRFLOW
REQUIREMENTS = ["apache-airflow==2.2.0"]
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
"""Declare the stack component flavors for the Airflow integration.
Returns:
List of stack component flavors for this integration.
"""
return [
FlavorWrapper(
name=AIRFLOW_ORCHESTRATOR_FLAVOR,
source="zenml.integrations.airflow.orchestrators.AirflowOrchestrator",
type=StackComponentType.ORCHESTRATOR,
integration=cls.NAME,
)
]
flavors()
classmethod
Declare the stack component flavors for the Airflow integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] | List of stack component flavors for this integration. |
Source code in zenml/integrations/airflow/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
"""Declare the stack component flavors for the Airflow integration.
Returns:
List of stack component flavors for this integration.
"""
return [
FlavorWrapper(
name=AIRFLOW_ORCHESTRATOR_FLAVOR,
source="zenml.integrations.airflow.orchestrators.AirflowOrchestrator",
type=StackComponentType.ORCHESTRATOR,
integration=cls.NAME,
)
]
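The flavor declared here is what the ZenML CLI resolves when you register an Airflow orchestrator. As a quick illustration (not part of the module above), the declared flavors can be inspected directly, assuming ZenML is installed:

```python
# Minimal sketch: list the stack component flavors the Airflow integration declares.
from zenml.integrations.airflow import AirflowIntegration

for flavor in AirflowIntegration.flavors():
    # Each FlavorWrapper records the flavor name, the implementing class
    # source path, the stack component type and the owning integration.
    print(flavor.name, flavor.type, flavor.source, flavor.integration)
```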
orchestrators
special
The Airflow integration enables the use of Airflow as a pipeline orchestrator.
airflow_orchestrator
Implementation of Airflow orchestrator integration.
AirflowOrchestrator (BaseOrchestrator)
pydantic-model
Orchestrator responsible for running pipelines using Airflow.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
class AirflowOrchestrator(BaseOrchestrator):
"""Orchestrator responsible for running pipelines using Airflow."""
airflow_home: str = ""
# Class Configuration
FLAVOR: ClassVar[str] = AIRFLOW_ORCHESTRATOR_FLAVOR
def __init__(self, **values: Any):
"""Sets environment variables to configure airflow.
Args:
**values: Values to set in the orchestrator.
"""
super().__init__(**values)
self._set_env()
@staticmethod
def _translate_schedule(
schedule: Optional[Schedule] = None,
) -> Dict[str, Any]:
"""Convert ZenML schedule into Airflow schedule.
The Airflow schedule uses slightly different naming and needs some
default entries for execution without a schedule.
Args:
schedule: Containing the interval, start and end date and
a boolean flag that defines if past runs should be caught up
on
Returns:
Airflow configuration dict.
"""
if schedule:
if schedule.cron_expression:
return {
"schedule_interval": schedule.cron_expression,
}
else:
return {
"schedule_interval": schedule.interval_second,
"start_date": schedule.start_time,
"end_date": schedule.end_time,
"catchup": schedule.catchup,
}
return {
"schedule_interval": "@once",
# set a start time in the past and disable catchup so airflow runs the dag immediately
"start_date": datetime.datetime.now() - datetime.timedelta(7),
"catchup": False,
}
def prepare_or_run_pipeline(
self,
sorted_steps: List[BaseStep],
pipeline: "BasePipeline",
pb2_pipeline: Pb2Pipeline,
stack: "Stack",
runtime_configuration: "RuntimeConfiguration",
) -> Any:
"""Creates an Airflow DAG as the intermediate representation for the pipeline.
This DAG will be loaded by airflow in the target environment
and used for orchestration of the pipeline.
How it works:
-------------
A new airflow_dag is instantiated with the pipeline name and among
other things the run schedule.
For each step of the pipeline a callable is created. This callable
uses the run_step() method to execute the step. The parameters of
this callable are pre-filled and an airflow step_operator is created
within the dag. The dependencies to upstream steps are then
configured.
Finally, the dag is fully complete and can be returned.
Args:
sorted_steps: List of steps in the pipeline.
pipeline: The pipeline to be executed.
pb2_pipeline: The pipeline as a protobuf message.
stack: The stack on which the pipeline will be deployed.
runtime_configuration: The runtime configuration.
Returns:
The Airflow DAG.
"""
import airflow
from airflow.operators import python as airflow_python
# Instantiate and configure airflow Dag with name and schedule
airflow_dag = airflow.DAG(
dag_id=pipeline.name,
is_paused_upon_creation=False,
**self._translate_schedule(runtime_configuration.schedule),
)
# Dictionary mapping step names to airflow_operators. This will be needed
# to configure airflow operator dependencies
step_name_to_airflow_operator = {}
for step in sorted_steps:
# Create callable that will be used by airflow to execute the step
# within the orchestrated environment
def _step_callable(step_instance: "BaseStep", **kwargs):
if self.requires_resources_in_orchestration_environment(step):
logger.warning(
"Specifying step resources is not yet supported for "
"the Airflow orchestrator, ignoring resource "
"configuration for step %s.",
step.name,
)
# Extract run name for the kwargs that will be passed to the
# callable
run_name = kwargs["ti"].get_dagrun().run_id
self.run_step(
step=step_instance,
run_name=run_name,
pb2_pipeline=pb2_pipeline,
)
# Create airflow python operator that contains the step callable
airflow_operator = airflow_python.PythonOperator(
dag=airflow_dag,
task_id=step.name,
provide_context=True,
python_callable=functools.partial(
_step_callable, step_instance=step
),
)
# Configure the current airflow operator to run after all upstream
# operators finished executing
step_name_to_airflow_operator[step.name] = airflow_operator
upstream_step_names = self.get_upstream_step_names(
step=step, pb2_pipeline=pb2_pipeline
)
for upstream_step_name in upstream_step_names:
airflow_operator.set_upstream(
step_name_to_airflow_operator[upstream_step_name]
)
# Return the finished airflow dag
return airflow_dag
@root_validator(skip_on_failure=True)
def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Sets Airflow home according to orchestrator UUID.
Args:
values: Dictionary containing all orchestrator attributes values.
Returns:
Dictionary containing all orchestrator attributes values and the airflow home.
Raises:
ValueError: If the orchestrator UUID is not set.
"""
if "uuid" not in values:
raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")
values["airflow_home"] = os.path.join(
io_utils.get_global_config_directory(),
AIRFLOW_ROOT_DIR,
str(values["uuid"]),
)
return values
@property
def dags_directory(self) -> str:
"""Returns path to the airflow dags directory.
Returns:
Path to the airflow dags directory.
"""
return os.path.join(self.airflow_home, "dags")
@property
def pid_file(self) -> str:
"""Returns path to the daemon PID file.
Returns:
Path to the daemon PID file.
"""
return os.path.join(self.airflow_home, "airflow_daemon.pid")
@property
def log_file(self) -> str:
"""Returns path to the airflow log file.
Returns:
str: Path to the airflow log file.
"""
return os.path.join(self.airflow_home, "airflow_orchestrator.log")
@property
def password_file(self) -> str:
"""Returns path to the webserver password file.
Returns:
Path to the webserver password file.
"""
return os.path.join(self.airflow_home, "standalone_admin_password.txt")
def _set_env(self) -> None:
"""Sets environment variables to configure airflow."""
os.environ["AIRFLOW_HOME"] = self.airflow_home
os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = self.dags_directory
os.environ["AIRFLOW__CORE__DAG_DISCOVERY_SAFE_MODE"] = "false"
os.environ["AIRFLOW__CORE__LOAD_EXAMPLES"] = "false"
# check the DAG folder every 10 seconds for new files
os.environ["AIRFLOW__SCHEDULER__DAG_DIR_LIST_INTERVAL"] = "10"
def _copy_to_dag_directory_if_necessary(self, dag_filepath: str) -> None:
"""Copies DAG module to the Airflow DAGs directory if not already present.
Args:
dag_filepath: Path to the file in which the DAG is defined.
"""
dags_directory = io_utils.resolve_relative_path(self.dags_directory)
if dags_directory == os.path.dirname(dag_filepath):
logger.debug("File is already in airflow DAGs directory.")
else:
logger.debug(
"Copying dag file '%s' to DAGs directory.", dag_filepath
)
destination_path = os.path.join(
dags_directory, os.path.basename(dag_filepath)
)
if fileio.exists(destination_path):
logger.info(
"File '%s' already exists, overwriting with new DAG file",
destination_path,
)
fileio.copy(dag_filepath, destination_path, overwrite=True)
def _log_webserver_credentials(self) -> None:
"""Logs URL and credentials to log in to the airflow webserver.
Raises:
FileNotFoundError: If the password file does not exist.
"""
if fileio.exists(self.password_file):
with open(self.password_file) as file:
password = file.read().strip()
else:
raise FileNotFoundError(
f"Can't find password file '{self.password_file}'"
)
logger.info(
"To inspect your DAGs, login to http://0.0.0.0:8080 "
"with username: admin password: %s",
password,
)
def runtime_options(self) -> Dict[str, Any]:
"""Runtime options for the airflow orchestrator.
Returns:
Runtime options dictionary.
"""
return {DAG_FILEPATH_OPTION_KEY: None}
def prepare_pipeline_deployment(
self,
pipeline: "BasePipeline",
stack: "Stack",
runtime_configuration: "RuntimeConfiguration",
) -> None:
"""Checks Airflow is running and copies DAG file to the Airflow DAGs directory.
Args:
pipeline: Pipeline to be deployed.
stack: Stack to be deployed.
runtime_configuration: Runtime configuration for the pipeline.
Raises:
RuntimeError: If Airflow is not running or no DAG filepath runtime
option is provided.
"""
if not self.is_running:
raise RuntimeError(
"Airflow orchestrator is currently not running. Run `zenml "
"stack up` to provision resources for the active stack."
)
if Environment.in_notebook():
raise RuntimeError(
"Unable to run the Airflow orchestrator from within a "
"notebook. Airflow requires a python file which contains a "
"global Airflow DAG object and therefore does not work with "
"notebooks. Please copy your ZenML pipeline code in a python "
"file and try again."
)
try:
dag_filepath = runtime_configuration[DAG_FILEPATH_OPTION_KEY]
except KeyError:
raise RuntimeError(
f"No DAG filepath found in runtime configuration. Make sure "
f"to add the filepath to your airflow DAG file as a runtime "
f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
)
self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)
@property
def is_running(self) -> bool:
"""Returns whether the airflow daemon is currently running.
Returns:
True if the daemon is running, False otherwise.
Raises:
RuntimeError: If port 8080 is occupied.
"""
from airflow.cli.commands.standalone_command import StandaloneCommand
from airflow.jobs.triggerer_job import TriggererJob
daemon_running = daemon.check_if_daemon_is_running(self.pid_file)
command = StandaloneCommand()
webserver_port_open = command.port_open(8080)
if not daemon_running:
if webserver_port_open:
raise RuntimeError(
"The airflow daemon does not seem to be running but "
"local port 8080 is occupied. Make sure the port is "
"available and try again."
)
# exit early so we don't check non-existing airflow databases
return False
# we can't use StandaloneCommand().is_ready() here as the
# Airflow SequentialExecutor apparently does not send a heartbeat
# while running a task which would result in this returning `False`
# even if Airflow is running.
airflow_running = webserver_port_open and command.job_running(
TriggererJob
)
return airflow_running
@property
def is_provisioned(self) -> bool:
"""Returns whether the airflow daemon is currently running.
Returns:
True if the airflow daemon is running, False otherwise.
"""
return self.is_running
def provision(self) -> None:
"""Ensures that Airflow is running."""
if self.is_running:
logger.info("Airflow is already running.")
self._log_webserver_credentials()
return
if not fileio.exists(self.dags_directory):
io_utils.create_dir_recursive_if_not_exists(self.dags_directory)
from airflow.cli.commands.standalone_command import StandaloneCommand
try:
command = StandaloneCommand()
# Run the daemon with a working directory inside the current
# zenml repo so the same repo will be used to run the DAGs
daemon.run_as_daemon(
command.run,
pid_file=self.pid_file,
log_file=self.log_file,
working_directory=get_source_root_path(),
)
while not self.is_running:
# Wait until the daemon started all the relevant airflow
# processes
time.sleep(0.1)
self._log_webserver_credentials()
except Exception as e:
logger.error(e)
logger.error(
"An error occurred while starting the Airflow daemon. If you "
"want to start it manually, use the commands described in the "
"official Airflow quickstart guide for running Airflow locally."
)
self.deprovision()
def deprovision(self) -> None:
"""Stops the airflow daemon if necessary and tears down resources."""
if self.is_running:
daemon.stop_daemon(self.pid_file)
fileio.rmtree(self.airflow_home)
logger.info("Airflow spun down.")
dags_directory: str
property
readonly
Returns path to the airflow dags directory.
Returns:
Type | Description |
---|---|
str | Path to the airflow dags directory. |
is_provisioned: bool
property
readonly
Returns whether the airflow daemon is currently running.
Returns:
Type | Description |
---|---|
bool | True if the airflow daemon is running, False otherwise. |
is_running: bool
property
readonly
Returns whether the airflow daemon is currently running.
Returns:
Type | Description |
---|---|
bool | True if the daemon is running, False otherwise. |
Exceptions:
Type | Description |
---|---|
RuntimeError | If port 8080 is occupied. |
log_file: str
property
readonly
Returns path to the airflow log file.
Returns:
Type | Description |
---|---|
str | Path to the airflow log file. |
password_file: str
property
readonly
Returns path to the webserver password file.
Returns:
Type | Description |
---|---|
str | Path to the webserver password file. |
pid_file: str
property
readonly
Returns path to the daemon PID file.
Returns:
Type | Description |
---|---|
str | Path to the daemon PID file. |
__init__(self, **values)
special
Sets environment variables to configure airflow.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**values | Any | Values to set in the orchestrator. | {} |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def __init__(self, **values: Any):
"""Sets environment variables to configure airflow.
Args:
**values: Values to set in the orchestrator.
"""
super().__init__(**values)
self._set_env()
deprovision(self)
Stops the airflow daemon if necessary and tears down resources.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def deprovision(self) -> None:
"""Stops the airflow daemon if necessary and tears down resources."""
if self.is_running:
daemon.stop_daemon(self.pid_file)
fileio.rmtree(self.airflow_home)
logger.info("Airflow spun down.")
prepare_or_run_pipeline(self, sorted_steps, pipeline, pb2_pipeline, stack, runtime_configuration)
Creates an Airflow DAG as the intermediate representation for the pipeline.
This DAG will be loaded by airflow in the target environment and used for orchestration of the pipeline.
How it works:
A new airflow_dag is instantiated with the pipeline name and, among other things, the run schedule.
For each step of the pipeline a callable is created. This callable uses the run_step() method to execute the step. The parameters of this callable are pre-filled and an airflow step_operator is created within the dag. The dependencies to upstream steps are then configured.
Finally, the dag is fully complete and can be returned.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sorted_steps | List[zenml.steps.base_step.BaseStep] | List of steps in the pipeline. | required |
pipeline | BasePipeline | The pipeline to be executed. | required |
pb2_pipeline | Pipeline | The pipeline as a protobuf message. | required |
stack | Stack | The stack on which the pipeline will be deployed. | required |
runtime_configuration | RuntimeConfiguration | The runtime configuration. | required |
Returns:
Type | Description |
---|---|
Any | The Airflow DAG. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_or_run_pipeline(
self,
sorted_steps: List[BaseStep],
pipeline: "BasePipeline",
pb2_pipeline: Pb2Pipeline,
stack: "Stack",
runtime_configuration: "RuntimeConfiguration",
) -> Any:
"""Creates an Airflow DAG as the intermediate representation for the pipeline.
This DAG will be loaded by airflow in the target environment
and used for orchestration of the pipeline.
How it works:
-------------
A new airflow_dag is instantiated with the pipeline name and among
other things the run schedule.
For each step of the pipeline a callable is created. This callable
uses the run_step() method to execute the step. The parameters of
this callable are pre-filled and an airflow step_operator is created
within the dag. The dependencies to upstream steps are then
configured.
Finally, the dag is fully complete and can be returned.
Args:
sorted_steps: List of steps in the pipeline.
pipeline: The pipeline to be executed.
pb2_pipeline: The pipeline as a protobuf message.
stack: The stack on which the pipeline will be deployed.
runtime_configuration: The runtime configuration.
Returns:
The Airflow DAG.
"""
import airflow
from airflow.operators import python as airflow_python
# Instantiate and configure airflow Dag with name and schedule
airflow_dag = airflow.DAG(
dag_id=pipeline.name,
is_paused_upon_creation=False,
**self._translate_schedule(runtime_configuration.schedule),
)
# Dictionary mapping step names to airflow_operators. This will be needed
# to configure airflow operator dependencies
step_name_to_airflow_operator = {}
for step in sorted_steps:
# Create callable that will be used by airflow to execute the step
# within the orchestrated environment
def _step_callable(step_instance: "BaseStep", **kwargs):
if self.requires_resources_in_orchestration_environment(step):
logger.warning(
"Specifying step resources is not yet supported for "
"the Airflow orchestrator, ignoring resource "
"configuration for step %s.",
step.name,
)
# Extract run name for the kwargs that will be passed to the
# callable
run_name = kwargs["ti"].get_dagrun().run_id
self.run_step(
step=step_instance,
run_name=run_name,
pb2_pipeline=pb2_pipeline,
)
# Create airflow python operator that contains the step callable
airflow_operator = airflow_python.PythonOperator(
dag=airflow_dag,
task_id=step.name,
provide_context=True,
python_callable=functools.partial(
_step_callable, step_instance=step
),
)
# Configure the current airflow operator to run after all upstream
# operators finished executing
step_name_to_airflow_operator[step.name] = airflow_operator
upstream_step_names = self.get_upstream_step_names(
step=step, pb2_pipeline=pb2_pipeline
)
for upstream_step_name in upstream_step_names:
airflow_operator.set_upstream(
step_name_to_airflow_operator[upstream_step_name]
)
# Return the finished airflow dag
return airflow_dag
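The orchestration pattern described above, one `PythonOperator` per step wired to its upstream operators, can be seen in isolation in the following sketch. It uses hypothetical step names and a plain print in place of `run_step()`, so it illustrates the pattern rather than the ZenML implementation, and assumes Airflow 2.x is installed.

```python
# Standalone sketch of the DAG-building pattern used above (hypothetical
# step names; not the ZenML implementation).
import datetime
import functools

import airflow
from airflow.operators import python as airflow_python

dag = airflow.DAG(
    dag_id="my_pipeline",  # placeholder pipeline name
    schedule_interval="@once",
    start_date=datetime.datetime.now() - datetime.timedelta(7),
    catchup=False,
)

def _step_callable(step_name: str, **kwargs):
    # The orchestrator reads the run name from the Airflow task instance;
    # here we only print it instead of calling run_step().
    run_name = kwargs["ti"].get_dagrun().run_id
    print(f"Running step '{step_name}' in run '{run_name}'")

operators = {}
for step_name, upstream_steps in [("load_data", []), ("train", ["load_data"])]:
    operator = airflow_python.PythonOperator(
        dag=dag,
        task_id=step_name,
        python_callable=functools.partial(_step_callable, step_name=step_name),
    )
    operators[step_name] = operator
    # Run only after all upstream operators have finished.
    for upstream_name in upstream_steps:
        operator.set_upstream(operators[upstream_name])
```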
prepare_pipeline_deployment(self, pipeline, stack, runtime_configuration)
Checks Airflow is running and copies DAG file to the Airflow DAGs directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline | BasePipeline | Pipeline to be deployed. | required |
stack | Stack | Stack to be deployed. | required |
runtime_configuration | RuntimeConfiguration | Runtime configuration for the pipeline. | required |
Exceptions:
Type | Description |
---|---|
RuntimeError | If Airflow is not running or no DAG filepath runtime option is provided. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def prepare_pipeline_deployment(
self,
pipeline: "BasePipeline",
stack: "Stack",
runtime_configuration: "RuntimeConfiguration",
) -> None:
"""Checks Airflow is running and copies DAG file to the Airflow DAGs directory.
Args:
pipeline: Pipeline to be deployed.
stack: Stack to be deployed.
runtime_configuration: Runtime configuration for the pipeline.
Raises:
RuntimeError: If Airflow is not running or no DAG filepath runtime
option is provided.
"""
if not self.is_running:
raise RuntimeError(
"Airflow orchestrator is currently not running. Run `zenml "
"stack up` to provision resources for the active stack."
)
if Environment.in_notebook():
raise RuntimeError(
"Unable to run the Airflow orchestrator from within a "
"notebook. Airflow requires a python file which contains a "
"global Airflow DAG object and therefore does not work with "
"notebooks. Please copy your ZenML pipeline code in a python "
"file and try again."
)
try:
dag_filepath = runtime_configuration[DAG_FILEPATH_OPTION_KEY]
except KeyError:
raise RuntimeError(
f"No DAG filepath found in runtime configuration. Make sure "
f"to add the filepath to your airflow DAG file as a runtime "
f"option (key: '{DAG_FILEPATH_OPTION_KEY}')."
)
self._copy_to_dag_directory_if_necessary(dag_filepath=dag_filepath)
provision(self)
Ensures that Airflow is running.
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def provision(self) -> None:
"""Ensures that Airflow is running."""
if self.is_running:
logger.info("Airflow is already running.")
self._log_webserver_credentials()
return
if not fileio.exists(self.dags_directory):
io_utils.create_dir_recursive_if_not_exists(self.dags_directory)
from airflow.cli.commands.standalone_command import StandaloneCommand
try:
command = StandaloneCommand()
# Run the daemon with a working directory inside the current
# zenml repo so the same repo will be used to run the DAGs
daemon.run_as_daemon(
command.run,
pid_file=self.pid_file,
log_file=self.log_file,
working_directory=get_source_root_path(),
)
while not self.is_running:
# Wait until the daemon started all the relevant airflow
# processes
time.sleep(0.1)
self._log_webserver_credentials()
except Exception as e:
logger.error(e)
logger.error(
"An error occurred while starting the Airflow daemon. If you "
"want to start it manually, use the commands described in the "
"official Airflow quickstart guide for running Airflow locally."
)
self.deprovision()
runtime_options(self)
Runtime options for the airflow orchestrator.
Returns:
Type | Description |
---|---|
Dict[str, Any] | Runtime options dictionary. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
def runtime_options(self) -> Dict[str, Any]:
"""Runtime options for the airflow orchestrator.
Returns:
Runtime options dictionary.
"""
return {DAG_FILEPATH_OPTION_KEY: None}
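This is the option that `prepare_pipeline_deployment` later reads from the runtime configuration. Purely as an illustration (the concrete value of `DAG_FILEPATH_OPTION_KEY` and the path below are assumptions), the runtime configuration is expected to carry something like:

```python
# Hypothetical shape of the runtime configuration entry the orchestrator
# expects: the key is whatever DAG_FILEPATH_OPTION_KEY resolves to and the
# value is the path of the Python file that defines the Airflow DAG.
runtime_options = {"dag_filepath": "/path/to/my_zenml_pipeline.py"}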
set_airflow_home(values)
classmethod
Sets Airflow home according to orchestrator UUID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
values | Dict[str, Any] | Dictionary containing all orchestrator attributes values. | required |
Returns:
Type | Description |
---|---|
Dict[str, Any] | Dictionary containing all orchestrator attributes values and the airflow home. |
Exceptions:
Type | Description |
---|---|
ValueError | If the orchestrator UUID is not set. |
Source code in zenml/integrations/airflow/orchestrators/airflow_orchestrator.py
@root_validator(skip_on_failure=True)
def set_airflow_home(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Sets Airflow home according to orchestrator UUID.
Args:
values: Dictionary containing all orchestrator attributes values.
Returns:
Dictionary containing all orchestrator attributes values and the airflow home.
Raises:
ValueError: If the orchestrator UUID is not set.
"""
if "uuid" not in values:
raise ValueError("`uuid` needs to exist for AirflowOrchestrator.")
values["airflow_home"] = os.path.join(
io_utils.get_global_config_directory(),
AIRFLOW_ROOT_DIR,
str(values["uuid"]),
)
return values
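In other words, each Airflow orchestrator gets its own Airflow home directory underneath the global ZenML configuration directory, keyed by the orchestrator's UUID. A small sketch of that path composition follows; the configuration directory and the `airflow` subdirectory name are placeholders standing in for `io_utils.get_global_config_directory()` and `AIRFLOW_ROOT_DIR`.

```python
# Sketch of the per-orchestrator Airflow home layout; paths are placeholders.
import os
import uuid

global_config_dir = os.path.expanduser("~/.config/zenml")  # placeholder
orchestrator_uuid = uuid.uuid4()
airflow_home = os.path.join(global_config_dir, "airflow", str(orchestrator_uuid))
print(airflow_home)  # e.g. /home/me/.config/zenml/airflow/<uuid>
```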
aws
special
Integrates multiple AWS Tools as Stack Components.
The AWS integration provides a way for our users to manage their secrets through AWS, a way to use the aws container registry. Additionally, the Sagemaker integration submodule provides a way to run ZenML steps in Sagemaker.
AWSIntegration (Integration)
Definition of AWS integration for ZenML.
Source code in zenml/integrations/aws/__init__.py
class AWSIntegration(Integration):
"""Definition of AWS integration for ZenML."""
NAME = AWS
REQUIREMENTS = ["boto3==1.21.0", "sagemaker==2.82.2"]
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
"""Declare the stack component flavors for the AWS integration.
Returns:
List of stack component flavors for this integration.
"""
return [
FlavorWrapper(
name=AWS_SECRET_MANAGER_FLAVOR,
source="zenml.integrations.aws.secrets_managers"
".AWSSecretsManager",
type=StackComponentType.SECRETS_MANAGER,
integration=cls.NAME,
),
FlavorWrapper(
name=AWS_CONTAINER_REGISTRY_FLAVOR,
source="zenml.integrations.aws.container_registries"
".AWSContainerRegistry",
type=StackComponentType.CONTAINER_REGISTRY,
integration=cls.NAME,
),
FlavorWrapper(
name=AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR,
source="zenml.integrations.aws.step_operators"
".SagemakerStepOperator",
type=StackComponentType.STEP_OPERATOR,
integration=cls.NAME,
),
]
flavors()
classmethod
Declare the stack component flavors for the AWS integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] | List of stack component flavors for this integration. |
Source code in zenml/integrations/aws/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
"""Declare the stack component flavors for the AWS integration.
Returns:
List of stack component flavors for this integration.
"""
return [
FlavorWrapper(
name=AWS_SECRET_MANAGER_FLAVOR,
source="zenml.integrations.aws.secrets_managers"
".AWSSecretsManager",
type=StackComponentType.SECRETS_MANAGER,
integration=cls.NAME,
),
FlavorWrapper(
name=AWS_CONTAINER_REGISTRY_FLAVOR,
source="zenml.integrations.aws.container_registries"
".AWSContainerRegistry",
type=StackComponentType.CONTAINER_REGISTRY,
integration=cls.NAME,
),
FlavorWrapper(
name=AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR,
source="zenml.integrations.aws.step_operators"
".SagemakerStepOperator",
type=StackComponentType.STEP_OPERATOR,
integration=cls.NAME,
),
]
container_registries
special
Initialization of AWS Container Registry integration.
aws_container_registry
Implementation of the AWS container registry integration.
AWSContainerRegistry (BaseContainerRegistry)
pydantic-model
Class for AWS Container Registry.
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
class AWSContainerRegistry(BaseContainerRegistry):
"""Class for AWS Container Registry."""
# Class Configuration
FLAVOR: ClassVar[str] = AWS_CONTAINER_REGISTRY_FLAVOR
@validator("uri")
def validate_aws_uri(cls, uri: str) -> str:
"""Validates that the URI is in the correct format.
Args:
uri: URI to validate.
Returns:
URI in the correct format.
Raises:
ValueError: If the URI contains a slash character.
"""
if "/" in uri:
raise ValueError(
"Property `uri` can not contain a `/`. An example of a valid "
"URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
)
return uri
def _get_region(self) -> str:
"""Parses the AWS region from the registry URI.
Raises:
RuntimeError: If the region parsing fails due to an invalid URI.
Returns:
The region string.
"""
match = re.fullmatch(r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", self.uri)
if not match:
raise RuntimeError(
f"Unable to parse region from ECR URI {self.uri}."
)
return match.group(1)
def prepare_image_push(self, image_name: str) -> None:
"""Logs warning message if trying to push an image for which no repository exists.
Args:
image_name: Name of the docker image that will be pushed.
Raises:
ValueError: If the docker image name is invalid.
"""
response = boto3.client(
"ecr", region_name=self._get_region()
).describe_repositories()
try:
repo_uris: List[str] = [
repository["repositoryUri"]
for repository in response["repositories"]
]
except (KeyError, ClientError) as e:
# invalid boto response, let's hope for the best and just push
logger.debug("Error while trying to fetch ECR repositories: %s", e)
return
repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
if not repo_exists:
match = re.search(f"{self.uri}/(.*):.*", image_name)
if not match:
raise ValueError(f"Invalid docker image name '{image_name}'.")
repo_name = match.group(1)
logger.warning(
"Amazon ECR requires you to create a repository before you can "
f"push an image to it. ZenML is trying to push the image "
f"{image_name} but could only detect the following "
f"repositories: {repo_uris}. We will try to push anyway, but "
f"in case it fails you need to create a repository named "
f"`{repo_name}`."
)
@property
def post_registration_message(self) -> Optional[str]:
"""Optional message printed after the stack component is registered.
Returns:
Info message regarding docker repositories in AWS.
"""
return (
"Amazon ECR requires you to create a repository before you can "
"push an image to it. If you want to for example run a pipeline "
"using our Kubeflow orchestrator, ZenML will automatically build a "
f"docker image called `{self.uri}/zenml-kubeflow:<PIPELINE_NAME>` "
f"and try to push it. This will fail unless you create the "
f"repository `zenml-kubeflow` inside your amazon registry."
)
post_registration_message: Optional[str]
property
readonly
Optional message printed after the stack component is registered.
Returns:
Type | Description |
---|---|
Optional[str] | Info message regarding docker repositories in AWS. |
prepare_image_push(self, image_name)
Logs warning message if trying to push an image for which no repository exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
image_name | str | Name of the docker image that will be pushed. | required |
Exceptions:
Type | Description |
---|---|
ValueError | If the docker image name is invalid. |
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
def prepare_image_push(self, image_name: str) -> None:
"""Logs warning message if trying to push an image for which no repository exists.
Args:
image_name: Name of the docker image that will be pushed.
Raises:
ValueError: If the docker image name is invalid.
"""
response = boto3.client(
"ecr", region_name=self._get_region()
).describe_repositories()
try:
repo_uris: List[str] = [
repository["repositoryUri"]
for repository in response["repositories"]
]
except (KeyError, ClientError) as e:
# invalid boto response, let's hope for the best and just push
logger.debug("Error while trying to fetch ECR repositories: %s", e)
return
repo_exists = any(image_name.startswith(f"{uri}:") for uri in repo_uris)
if not repo_exists:
match = re.search(f"{self.uri}/(.*):.*", image_name)
if not match:
raise ValueError(f"Invalid docker image name '{image_name}'.")
repo_name = match.group(1)
logger.warning(
"Amazon ECR requires you to create a repository before you can "
f"push an image to it. ZenML is trying to push the image "
f"{image_name} but could only detect the following "
f"repositories: {repo_uris}. We will try to push anyway, but "
f"in case it fails you need to create a repository named "
f"`{repo_name}`."
)
validate_aws_uri(uri)
classmethod
Validates that the URI is in the correct format.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri | str | URI to validate. | required |
Returns:
Type | Description |
---|---|
str | URI in the correct format. |
Exceptions:
Type | Description |
---|---|
ValueError | If the URI contains a slash character. |
Source code in zenml/integrations/aws/container_registries/aws_container_registry.py
@validator("uri")
def validate_aws_uri(cls, uri: str) -> str:
"""Validates that the URI is in the correct format.
Args:
uri: URI to validate.
Returns:
URI in the correct format.
Raises:
ValueError: If the URI contains a slash character.
"""
if "/" in uri:
raise ValueError(
"Property `uri` can not contain a `/`. An example of a valid "
"URI is: `715803424592.dkr.ecr.us-east-1.amazonaws.com`"
)
return uri
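The `_get_region()` helper shown in the class source extracts the AWS region from the registry URI with a regular expression. A minimal sketch of that parsing, using the example URI from the validator's error message:

```python
# Sketch of the region parsing used by the container registry; the account
# id comes from the docstring example above.
import re

uri = "715803424592.dkr.ecr.us-east-1.amazonaws.com"
match = re.fullmatch(r".*\.dkr\.ecr\.(.*)\.amazonaws\.com", uri)
assert match is not None
print(match.group(1))  # -> "us-east-1"
```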
secrets_managers
special
AWS Secrets Manager.
aws_secrets_manager
Implementation of the AWS Secrets Manager integration.
AWSSecretsManager (BaseSecretsManager)
pydantic-model
Class to interact with the AWS secrets manager.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
class AWSSecretsManager(BaseSecretsManager):
"""Class to interact with the AWS secrets manager."""
region_name: str
# Class configuration
FLAVOR: ClassVar[str] = AWS_SECRET_MANAGER_FLAVOR
SUPPORTS_SCOPING: ClassVar[bool] = True
CLIENT: ClassVar[Any] = None
@classmethod
def _validate_scope(
cls,
scope: SecretsManagerScope,
namespace: Optional[str],
) -> None:
"""Validate the scope and namespace value.
Args:
scope: Scope value.
namespace: Optional namespace value.
"""
if namespace:
cls.validate_secret_name_or_namespace(namespace)
@classmethod
def _ensure_client_connected(cls, region_name: str) -> None:
"""Ensure that the client is connected to the AWS secrets manager.
Args:
region_name: the AWS region name
"""
if cls.CLIENT is None:
# Create a Secrets Manager client
session = boto3.session.Session()
cls.CLIENT = session.client(
service_name="secretsmanager", region_name=region_name
)
@classmethod
def validate_secret_name_or_namespace(cls, name: str) -> None:
"""Validate a secret name or namespace.
AWS secret names must contain only alphanumeric characters and the
characters /_+=.@-. The `/` character is only used internally to delimit
scopes.
Args:
name: the secret name or namespace
Raises:
ValueError: if the secret name or namespace is invalid
"""
if not re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name):
raise ValueError(
f"Invalid secret name or namespace '{name}'. Must contain "
f"only alphanumeric characters and the characters _+=.@-."
)
def _get_secret_tags(
self, secret: BaseSecretSchema
) -> List[Dict[str, str]]:
"""Return a list of AWS secret tag values for a given secret.
Args:
secret: the secret object
Returns:
A list of AWS secret tag values
"""
metadata = self._get_secret_metadata(secret)
return [{"Key": k, "Value": v} for k, v in metadata.items()]
def _get_secret_scope_filters(
self,
secret_name: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Return a list of AWS filters for the entire scope or just a scoped secret.
These filters can be used when querying the AWS Secrets Manager
for all secrets or for a single secret available in the configured
scope. For more information see: https://docs.aws.amazon.com/secretsmanager/latest/userguide/manage_search-secret.html
Example AWS filters for all secrets in the current (namespace) scope:
```python
[
    {
        "Key": "tag-key",
        "Values": ["zenml_scope"],
    },
    {
        "Key": "tag-value",
        "Values": ["namespace"],
    },
    {
        "Key": "tag-key",
        "Values": ["zenml_namespace"],
    },
    {
        "Key": "tag-value",
        "Values": ["my_namespace"],
    },
]
```
Example AWS filters for a particular secret in the current (namespace)
scope:
```python
[
    {
        "Key": "tag-key",
        "Values": ["zenml_secret_name"],
    },
    {
        "Key": "tag-value",
        "Values": ["my_secret"],
    },
    {
        "Key": "tag-key",
        "Values": ["zenml_scope"],
    },
    {
        "Key": "tag-value",
        "Values": ["namespace"],
    },
    {
        "Key": "tag-key",
        "Values": ["zenml_namespace"],
    },
    {
        "Key": "tag-value",
        "Values": ["my_namespace"],
    },
]
```
Args:
secret_name: Optional secret name to filter for.
Returns:
A list of AWS filters uniquely identifying all secrets
or a named secret within the configured scope.
"""
metadata = self._get_secret_scope_metadata(secret_name)
filters: List[Dict[str, Any]] = []
for k, v in metadata.items():
filters.append(
{
"Key": "tag-key",
"Values": [
k,
],
}
)
filters.append(
{
"Key": "tag-value",
"Values": [
str(v),
],
}
)
return filters
def _list_secrets(self, secret_name: Optional[str] = None) -> List[str]:
"""List all secrets matching a name.
This method lists all the secrets in the current scope without loading
their contents. An optional secret name can be supplied to filter out
all but a single secret identified by name.
Args:
secret_name: Optional secret name to filter for.
Returns:
A list of secret names in the current scope and the optional
secret name.
"""
self._ensure_client_connected(self.region_name)
filters: List[Dict[str, Any]] = []
prefix: Optional[str] = None
if self.scope == SecretsManagerScope.NONE:
# unscoped (legacy) secrets don't have tags. We want to filter out
# non-legacy secrets
filters = [
{
"Key": "tag-key",
"Values": [
"!zenml_scope",
],
},
]
if secret_name:
prefix = secret_name
else:
filters = self._get_secret_scope_filters()
if secret_name:
prefix = self._get_scoped_secret_name(secret_name)
else:
# add the name prefix to the filters to account for the fact
# that AWS does not do exact matching but prefix-matching on the
# filters
prefix = self._get_scoped_secret_name_prefix()
if prefix:
filters.append(
{
"Key": "name",
"Values": [
f"{prefix}",
],
}
)
# TODO [ENG-720]: Deal with pagination in the aws secret manager when
# listing all secrets
# TODO [ENG-721]: take out this magic maxresults number
response = self.CLIENT.list_secrets(MaxResults=100, Filters=filters)
results = []
for secret in response["SecretList"]:
name = self._get_unscoped_secret_name(secret["Name"])
# keep only the names that are in scope and filter by secret name,
# if one was given
if name and (not secret_name or secret_name == name):
results.append(name)
return results
def register_secret(self, secret: BaseSecretSchema) -> None:
"""Registers a new secret.
Args:
secret: the secret to register
Raises:
SecretExistsError: if the secret already exists
"""
self.validate_secret_name_or_namespace(secret.name)
self._ensure_client_connected(self.region_name)
if self._list_secrets(secret.name):
raise SecretExistsError(
f"A Secret with the name {secret.name} already exists"
)
secret_value = json.dumps(secret_to_dict(secret, encode=False))
kwargs: Dict[str, Any] = {
"Name": self._get_scoped_secret_name(secret.name),
"SecretString": secret_value,
"Tags": self._get_secret_tags(secret),
}
self.CLIENT.create_secret(**kwargs)
logger.debug("Created AWS secret: %s", kwargs["Name"])
def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Gets a secret.
Args:
secret_name: the name of the secret to get
Returns:
The secret.
Raises:
KeyError: if the secret does not exist
"""
self.validate_secret_name_or_namespace(secret_name)
self._ensure_client_connected(self.region_name)
if not self._list_secrets(secret_name):
raise KeyError(f"Can't find the specified secret '{secret_name}'")
get_secret_value_response = self.CLIENT.get_secret_value(
SecretId=self._get_scoped_secret_name(secret_name)
)
if "SecretString" not in get_secret_value_response:
get_secret_value_response = None
return secret_from_dict(
json.loads(get_secret_value_response["SecretString"]),
secret_name=secret_name,
decode=False,
)
def get_all_secret_keys(self) -> List[str]:
"""Get all secret keys.
Returns:
A list of all secret keys
"""
return self._list_secrets()
def update_secret(self, secret: BaseSecretSchema) -> None:
"""Update an existing secret.
Args:
secret: the secret to update
Raises:
KeyError: if the secret does not exist
"""
self.validate_secret_name_or_namespace(secret.name)
self._ensure_client_connected(self.region_name)
if not self._list_secrets(secret.name):
raise KeyError(f"Can't find the specified secret '{secret.name}'")
secret_value = json.dumps(secret_to_dict(secret))
kwargs = {
"SecretId": self._get_scoped_secret_name(secret.name),
"SecretString": secret_value,
}
self.CLIENT.put_secret_value(**kwargs)
def delete_secret(self, secret_name: str) -> None:
"""Delete an existing secret.
Args:
secret_name: the name of the secret to delete
Raises:
KeyError: if the secret does not exist
"""
self._ensure_client_connected(self.region_name)
if not self._list_secrets(secret_name):
raise KeyError(f"Can't find the specified secret '{secret_name}'")
self.CLIENT.delete_secret(
SecretId=self._get_scoped_secret_name(secret_name),
ForceDeleteWithoutRecovery=True,
)
def delete_all_secrets(self) -> None:
"""Delete all existing secrets.
This method will force delete all your secrets. You will not be able to
recover them once this method is called.
"""
self._ensure_client_connected(self.region_name)
for secret_name in self._list_secrets():
self.CLIENT.delete_secret(
SecretId=self._get_scoped_secret_name(secret_name),
ForceDeleteWithoutRecovery=True,
)
delete_all_secrets(self)
Delete all existing secrets.
This method will force delete all your secrets. You will not be able to recover them once this method is called.
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_all_secrets(self) -> None:
"""Delete all existing secrets.
This method will force delete all your secrets. You will not be able to
recover them once this method is called.
"""
self._ensure_client_connected(self.region_name)
for secret_name in self._list_secrets():
self.CLIENT.delete_secret(
SecretId=self._get_scoped_secret_name(secret_name),
ForceDeleteWithoutRecovery=True,
)
delete_secret(self, secret_name)
Delete an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name | str | the name of the secret to delete | required |
Exceptions:
Type | Description |
---|---|
KeyError | if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
"""Delete an existing secret.
Args:
secret_name: the name of the secret to delete
Raises:
KeyError: if the secret does not exist
"""
self._ensure_client_connected(self.region_name)
if not self._list_secrets(secret_name):
raise KeyError(f"Can't find the specified secret '{secret_name}'")
self.CLIENT.delete_secret(
SecretId=self._get_scoped_secret_name(secret_name),
ForceDeleteWithoutRecovery=True,
)
get_all_secret_keys(self)
Get all secret keys.
Returns:
Type | Description |
---|---|
List[str] | A list of all secret keys |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
"""Get all secret keys.
Returns:
A list of all secret keys
"""
return self._list_secrets()
get_secret(self, secret_name)
Gets a secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name | str | the name of the secret to get | required |
Returns:
Type | Description |
---|---|
BaseSecretSchema | The secret. |
Exceptions:
Type | Description |
---|---|
KeyError | if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Gets a secret.
Args:
secret_name: the name of the secret to get
Returns:
The secret.
Raises:
KeyError: if the secret does not exist
"""
self.validate_secret_name_or_namespace(secret_name)
self._ensure_client_connected(self.region_name)
if not self._list_secrets(secret_name):
raise KeyError(f"Can't find the specified secret '{secret_name}'")
get_secret_value_response = self.CLIENT.get_secret_value(
SecretId=self._get_scoped_secret_name(secret_name)
)
if "SecretString" not in get_secret_value_response:
get_secret_value_response = None
return secret_from_dict(
json.loads(get_secret_value_response["SecretString"]),
secret_name=secret_name,
decode=False,
)
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret | BaseSecretSchema | the secret to register | required |
Exceptions:
Type | Description |
---|---|
SecretExistsError | if the secret already exists |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
"""Registers a new secret.
Args:
secret: the secret to register
Raises:
SecretExistsError: if the secret already exists
"""
self.validate_secret_name_or_namespace(secret.name)
self._ensure_client_connected(self.region_name)
if self._list_secrets(secret.name):
raise SecretExistsError(
f"A Secret with the name {secret.name} already exists"
)
secret_value = json.dumps(secret_to_dict(secret, encode=False))
kwargs: Dict[str, Any] = {
"Name": self._get_scoped_secret_name(secret.name),
"SecretString": secret_value,
"Tags": self._get_secret_tags(secret),
}
self.CLIENT.create_secret(**kwargs)
logger.debug("Created AWS secret: %s", kwargs["Name"])
update_secret(self, secret)
Update an existing secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret | BaseSecretSchema | the secret to update | required |
Exceptions:
Type | Description |
---|---|
KeyError | if the secret does not exist |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
"""Update an existing secret.
Args:
secret: the secret to update
Raises:
KeyError: if the secret does not exist
"""
self.validate_secret_name_or_namespace(secret.name)
self._ensure_client_connected(self.region_name)
if not self._list_secrets(secret.name):
raise KeyError(f"Can't find the specified secret '{secret.name}'")
secret_value = json.dumps(secret_to_dict(secret))
kwargs = {
"SecretId": self._get_scoped_secret_name(secret.name),
"SecretString": secret_value,
}
self.CLIENT.put_secret_value(**kwargs)
validate_secret_name_or_namespace(name)
classmethod
Validate a secret name or namespace.
AWS secret names must contain only alphanumeric characters and the
characters /_+=.@-. The `/` character is only used internally to delimit
scopes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | the secret name or namespace | required |
Exceptions:
Type | Description |
---|---|
ValueError | if the secret name or namespace is invalid |
Source code in zenml/integrations/aws/secrets_managers/aws_secrets_manager.py
@classmethod
def validate_secret_name_or_namespace(cls, name: str) -> None:
"""Validate a secret name or namespace.
AWS secret names must contain only alphanumeric characters and the
characters /_+=.@-. The `/` character is only used internally to delimit
scopes.
Args:
name: the secret name or namespace
Raises:
ValueError: if the secret name or namespace is invalid
"""
if not re.fullmatch(r"[a-zA-Z0-9_+=\.@\-]*", name):
raise ValueError(
f"Invalid secret name or namespace '{name}'. Must contain "
f"only alphanumeric characters and the characters _+=.@-."
)
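A quick sketch of the validation rule above, exercising the same regular expression on a valid and an invalid name:

```python
# Sketch of the secret name validation: alphanumerics plus _+=.@- are
# accepted, anything else (including '/') is rejected.
import re

pattern = r"[a-zA-Z0-9_+=\.@\-]*"
print(bool(re.fullmatch(pattern, "my_secret-1.0")))    # True: valid name
print(bool(re.fullmatch(pattern, "scope/my_secret")))  # False: '/' not allowed
```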
step_operators
special
Initialization of the Sagemaker Step Operator.
sagemaker_step_operator
Implementation of the Sagemaker Step Operator.
SagemakerStepOperator (BaseStepOperator, PipelineDockerImageBuilder)
pydantic-model
Step operator to run a step on Sagemaker.
This class defines code that builds an image with the ZenML entrypoint to run using Sagemaker's Estimator.
Attributes:
Name | Type | Description |
---|---|---|
role | str | The role that has to be assigned to the jobs which are running in Sagemaker. |
instance_type | str | The type of the compute instance where jobs will run. |
base_image | Optional[str] | The base image to use for building the docker image that will be executed. |
bucket | Optional[str] | Name of the S3 bucket to use for storing artifacts from the job run. If not provided, a default bucket will be created based on the following format: "sagemaker-{region}-{aws-account-id}". |
experiment_name | Optional[str] | The name for the experiment to which the job will be associated. If not provided, the job runs would be independent. |
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
class SagemakerStepOperator(BaseStepOperator, PipelineDockerImageBuilder):
"""Step operator to run a step on Sagemaker.
This class defines code that builds an image with the ZenML entrypoint
to run using Sagemaker's Estimator.
Attributes:
role: The role that has to be assigned to the jobs which are
running in Sagemaker.
instance_type: The type of the compute instance where jobs will run.
base_image: The base image to use for building the docker
image that will be executed.
bucket: Name of the S3 bucket to use for storing artifacts
from the job run. If not provided, a default bucket will be created
based on the following format: "sagemaker-{region}-{aws-account-id}".
experiment_name: The name for the experiment to which the job
will be associated. If not provided, the job runs would be
independent.
"""
role: str
instance_type: str
base_image: Optional[str] = None
bucket: Optional[str] = None
experiment_name: Optional[str] = None
# Class Configuration
FLAVOR: ClassVar[str] = AWS_SAGEMAKER_STEP_OPERATOR_FLAVOR
_deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
("base_image", "docker_parent_image")
)
@property
def validator(self) -> Optional[StackValidator]:
"""Validates that the stack contains a container registry.
Returns:
A validator that checks that the stack contains a container registry.
"""
def _ensure_local_orchestrator(stack: Stack) -> Tuple[bool, str]:
return (
stack.orchestrator.FLAVOR == "local",
"Local orchestrator is required",
)
return StackValidator(
required_components={StackComponentType.CONTAINER_REGISTRY},
custom_validation_function=_ensure_local_orchestrator,
)
def launch(
self,
pipeline_name: str,
run_name: str,
docker_configuration: "DockerConfiguration",
entrypoint_command: List[str],
resource_configuration: "ResourceConfiguration",
) -> None:
"""Launches a step on Sagemaker.
Args:
pipeline_name: Name of the pipeline which the step to be executed
is part of.
run_name: Name of the pipeline run which the step to be executed
is part of.
docker_configuration: The Docker configuration for this step.
entrypoint_command: Command that executes the step.
resource_configuration: The resource configuration for this step.
"""
image_name = self.build_and_push_docker_image(
pipeline_name=pipeline_name,
docker_configuration=docker_configuration,
stack=Repository().active_stack,
runtime_configuration=RuntimeConfiguration(),
entrypoint=" ".join(entrypoint_command),
)
if not resource_configuration.empty:
logger.warning(
"Specifying custom step resources is not supported for "
"the SageMaker step operator. If you want to run this step "
"operator on specific resources, you can do so by configuring "
"a different instance type like this: "
"`zenml step-operator update %s "
"--instance_type=<INSTANCE_TYPE>`",
self.name,
)
session = sagemaker.Session(default_bucket=self.bucket)
estimator = sagemaker.estimator.Estimator(
image_name,
self.role,
instance_count=1,
instance_type=self.instance_type,
sagemaker_session=session,
)
# Sagemaker doesn't allow any underscores in job/experiment/trial names
sanitized_run_name = run_name.replace("_", "-")
experiment_config = {}
if self.experiment_name:
experiment_config = {
"ExperimentName": self.experiment_name,
"TrialName": sanitized_run_name,
}
estimator.fit(
wait=True,
experiment_config=experiment_config,
job_name=sanitized_run_name,
)
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates that the stack contains a container registry.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] | A validator that checks that the stack contains a container registry. |
launch(self, pipeline_name, run_name, docker_configuration, entrypoint_command, resource_configuration)
Launches a step on Sagemaker.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name | str | Name of the pipeline which the step to be executed is part of. | required |
run_name | str | Name of the pipeline run which the step to be executed is part of. | required |
docker_configuration | DockerConfiguration | The Docker configuration for this step. | required |
entrypoint_command | List[str] | Command that executes the step. | required |
resource_configuration | ResourceConfiguration | The resource configuration for this step. | required |
Source code in zenml/integrations/aws/step_operators/sagemaker_step_operator.py
def launch(
self,
pipeline_name: str,
run_name: str,
docker_configuration: "DockerConfiguration",
entrypoint_command: List[str],
resource_configuration: "ResourceConfiguration",
) -> None:
"""Launches a step on Sagemaker.
Args:
pipeline_name: Name of the pipeline which the step to be executed
is part of.
run_name: Name of the pipeline run which the step to be executed
is part of.
docker_configuration: The Docker configuration for this step.
entrypoint_command: Command that executes the step.
resource_configuration: The resource configuration for this step.
"""
image_name = self.build_and_push_docker_image(
pipeline_name=pipeline_name,
docker_configuration=docker_configuration,
stack=Repository().active_stack,
runtime_configuration=RuntimeConfiguration(),
entrypoint=" ".join(entrypoint_command),
)
if not resource_configuration.empty:
logger.warning(
"Specifying custom step resources is not supported for "
"the SageMaker step operator. If you want to run this step "
"operator on specific resources, you can do so by configuring "
"a different instance type like this: "
"`zenml step-operator update %s "
"--instance_type=<INSTANCE_TYPE>`",
self.name,
)
session = sagemaker.Session(default_bucket=self.bucket)
estimator = sagemaker.estimator.Estimator(
image_name,
self.role,
instance_count=1,
instance_type=self.instance_type,
sagemaker_session=session,
)
# Sagemaker doesn't allow any underscores in job/experiment/trial names
sanitized_run_name = run_name.replace("_", "-")
experiment_config = {}
if self.experiment_name:
experiment_config = {
"ExperimentName": self.experiment_name,
"TrialName": sanitized_run_name,
}
estimator.fit(
wait=True,
experiment_config=experiment_config,
job_name=sanitized_run_name,
)
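Stripped of the ZenML-specific image build, `launch()` reduces to handing a container image to a SageMaker `Estimator` and starting a training job. A trimmed sketch of that SageMaker side follows, with placeholder image URI, role ARN and instance type.

```python
# Trimmed sketch of the SageMaker calls made by launch(); the image URI,
# role ARN and instance type are placeholders.
import sagemaker

session = sagemaker.Session(default_bucket=None)  # None -> SageMaker default bucket
estimator = sagemaker.estimator.Estimator(
    "123456789012.dkr.ecr.us-east-1.amazonaws.com/zenml:my-pipeline",  # placeholder
    "arn:aws:iam::123456789012:role/my-sagemaker-role",  # placeholder
    instance_count=1,
    instance_type="ml.m5.large",
    sagemaker_session=session,
)
# SageMaker job names must not contain underscores, hence the sanitization above.
estimator.fit(wait=True, job_name="my-pipeline-run-1")
```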
azure
special
Initialization of the ZenML Azure integration.
The Azure integration submodule provides a way to run ZenML pipelines in a cloud
environment. Specifically, it allows the use of cloud artifact stores and an
`io` module to handle file operations on Azure Blob Storage.
The Azure Step Operator integration submodule provides a way to run ZenML steps
in AzureML.
AzureIntegration (Integration)
Definition of Azure integration for ZenML.
Source code in zenml/integrations/azure/__init__.py
class AzureIntegration(Integration):
"""Definition of Azure integration for ZenML."""
NAME = AZURE
REQUIREMENTS = [
"adlfs==2021.10.0",
"azure-keyvault-keys",
"azure-keyvault-secrets",
"azure-identity",
"azureml-core==1.42.0.post1",
]
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
"""Declares the flavors for the integration.
Returns:
List of stack component flavors for this integration.
"""
return [
FlavorWrapper(
name=AZURE_ARTIFACT_STORE_FLAVOR,
source="zenml.integrations.azure.artifact_stores"
".AzureArtifactStore",
type=StackComponentType.ARTIFACT_STORE,
integration=cls.NAME,
),
FlavorWrapper(
name=AZURE_SECRETS_MANAGER_FLAVOR,
source="zenml.integrations.azure.secrets_managers"
".AzureSecretsManager",
type=StackComponentType.SECRETS_MANAGER,
integration=cls.NAME,
),
FlavorWrapper(
name=AZUREML_STEP_OPERATOR_FLAVOR,
source="zenml.integrations.azure.step_operators"
".AzureMLStepOperator",
type=StackComponentType.STEP_OPERATOR,
integration=cls.NAME,
),
]
flavors()
classmethod
Declares the flavors for the integration.
Returns:
Type | Description |
---|---|
List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] | List of stack component flavors for this integration. |
Source code in zenml/integrations/azure/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
"""Declares the flavors for the integration.
Returns:
List of stack component flavors for this integration.
"""
return [
FlavorWrapper(
name=AZURE_ARTIFACT_STORE_FLAVOR,
source="zenml.integrations.azure.artifact_stores"
".AzureArtifactStore",
type=StackComponentType.ARTIFACT_STORE,
integration=cls.NAME,
),
FlavorWrapper(
name=AZURE_SECRETS_MANAGER_FLAVOR,
source="zenml.integrations.azure.secrets_managers"
".AzureSecretsManager",
type=StackComponentType.SECRETS_MANAGER,
integration=cls.NAME,
),
FlavorWrapper(
name=AZUREML_STEP_OPERATOR_FLAVOR,
source="zenml.integrations.azure.step_operators"
".AzureMLStepOperator",
type=StackComponentType.STEP_OPERATOR,
integration=cls.NAME,
),
]
artifact_stores
special
Initialization of the Azure Artifact Store integration.
azure_artifact_store
Implementation of the Azure Artifact Store integration.
AzureArtifactStore (BaseArtifactStore, AuthenticationMixin)
pydantic-model
Artifact Store for Microsoft Azure based artifacts.
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
class AzureArtifactStore(BaseArtifactStore, AuthenticationMixin):
"""Artifact Store for Microsoft Azure based artifacts."""
_filesystem: Optional[adlfs.AzureBlobFileSystem] = None
# Class Configuration
FLAVOR: ClassVar[str] = AZURE_ARTIFACT_STORE_FLAVOR
SUPPORTED_SCHEMES: ClassVar[Set[str]] = {"abfs://", "az://"}
@property
def filesystem(self) -> adlfs.AzureBlobFileSystem:
"""The adlfs filesystem to access this artifact store.
Returns:
The adlfs filesystem to access this artifact store.
"""
if not self._filesystem:
secret = self.get_authentication_secret(
expected_schema_type=AzureSecretSchema
)
credentials = secret.content if secret else {}
self._filesystem = adlfs.AzureBlobFileSystem(
**credentials,
anon=False,
use_listings_cache=False,
)
return self._filesystem
@classmethod
def _split_path(cls, path: PathType) -> Tuple[str, str]:
"""Splits a path into the filesystem prefix and remainder.
Example:
```python
prefix, remainder = ZenAzure._split_path("az://my_container/test.txt")
print(prefix, remainder) # "az://" "my_container/test.txt"
```
Args:
path: The path to split.
Returns:
A tuple of the filesystem prefix and the remainder.
"""
path = convert_to_str(path)
prefix = ""
for potential_prefix in cls.SUPPORTED_SCHEMES:
if path.startswith(potential_prefix):
prefix = potential_prefix
path = path[len(potential_prefix) :]
break
return prefix, path
def open(self, path: PathType, mode: str = "r") -> Any:
"""Open a file at the given path.
Args:
path: Path of the file to open.
mode: Mode in which to open the file. Currently, only
'rb' and 'wb' to read and write binary files are supported.
Returns:
A file-like object.
"""
return self.filesystem.open(path=path, mode=mode)
def copyfile(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Copy a file.
Args:
src: The path to copy from.
dst: The path to copy to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to copy to destination '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to copy anyway."
)
# TODO [ENG-151]: Check if it works with overwrite=True or if we need to
# manually remove it first
self.filesystem.copy(path1=src, path2=dst)
def exists(self, path: PathType) -> bool:
"""Check whether a path exists.
Args:
path: The path to check.
Returns:
True if the path exists, False otherwise.
"""
return self.filesystem.exists(path=path) # type: ignore[no-any-return]
def glob(self, pattern: PathType) -> List[PathType]:
"""Return all paths that match the given glob pattern.
The glob pattern may include:
- '*' to match any number of characters
- '?' to match a single character
- '[...]' to match one of the characters inside the brackets
- '**' as the full name of a path component to match to search
in subdirectories of any depth (e.g. '/some_dir/**/some_file)
Args:
pattern: The glob pattern to match, see details above.
Returns:
A list of paths that match the given glob pattern.
"""
prefix, _ = self._split_path(pattern)
return [
f"{prefix}{path}" for path in self.filesystem.glob(path=pattern)
]
def isdir(self, path: PathType) -> bool:
"""Check whether a path is a directory.
Args:
path: The path to check.
Returns:
True if the path is a directory, False otherwise.
"""
return self.filesystem.isdir(path=path) # type: ignore[no-any-return]
def listdir(self, path: PathType) -> List[PathType]:
"""Return a list of files in a directory.
Args:
path: The path to list.
Returns:
A list of files in the given directory.
"""
_, path = self._split_path(path)
def _extract_basename(file_dict: Dict[str, Any]) -> str:
"""Extracts the basename from a dictionary returned by the Azure filesystem.
Args:
file_dict: A dictionary returned by the Azure filesystem.
Returns:
The basename of the file.
"""
file_path = cast(str, file_dict["name"])
base_name = file_path[len(path) :]
return base_name.lstrip("/")
return [
_extract_basename(dict_)
for dict_ in self.filesystem.listdir(path=path)
]
def makedirs(self, path: PathType) -> None:
"""Create a directory at the given path.
If needed also create missing parent directories.
Args:
path: The path to create.
"""
self.filesystem.makedirs(path=path, exist_ok=True)
def mkdir(self, path: PathType) -> None:
"""Create a directory at the given path.
Args:
path: The path to create.
"""
self.filesystem.makedir(path=path)
def remove(self, path: PathType) -> None:
"""Remove the file at the given path.
Args:
path: The path to remove.
"""
self.filesystem.rm_file(path=path)
def rename(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Rename source file to destination file.
Args:
src: The path of the file to rename.
dst: The path to rename the source file to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to rename file to '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to rename anyway."
)
# TODO [ENG-152]: Check if it works with overwrite=True or if we need
# to manually remove it first
self.filesystem.rename(path1=src, path2=dst)
def rmtree(self, path: PathType) -> None:
"""Remove the given directory.
Args:
path: The path of the directory to remove.
"""
self.filesystem.delete(path=path, recursive=True)
def stat(self, path: PathType) -> Dict[str, Any]:
"""Return stat info for the given path.
Args:
path: The path to get stat info for.
Returns:
Stat info.
"""
return self.filesystem.stat(path=path) # type: ignore[no-any-return]
def walk(
self,
top: PathType,
topdown: bool = True,
onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
"""Return an iterator that walks the contents of the given directory.
Args:
top: Path of directory to walk.
topdown: Unused argument to conform to interface.
onerror: Unused argument to conform to interface.
Yields:
An Iterable of Tuples, each of which contain the path of the current
directory path, a list of directories inside the current directory
and a list of files inside the current directory.
"""
# TODO [ENG-153]: Additional params
prefix, _ = self._split_path(top)
for (
directory,
subdirectories,
files,
) in self.filesystem.walk(path=top):
yield f"{prefix}{directory}", subdirectories, files
filesystem: AzureBlobFileSystem
property
readonly
The adlfs filesystem to access this artifact store.
Returns:
Type | Description |
---|---|
AzureBlobFileSystem | The adlfs filesystem to access this artifact store. |
copyfile(self, src, dst, overwrite=False)
Copy a file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src | Union[bytes, str] | The path to copy from. | required |
dst | Union[bytes, str] | The path to copy to. | required |
overwrite | bool | If a file already exists at the destination, this method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. | False |
Exceptions:
Type | Description |
---|---|
FileExistsError | If a file already exists at the destination and overwrite is not set to `True`. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def copyfile(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Copy a file.
Args:
src: The path to copy from.
dst: The path to copy to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to copy to destination '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to copy anyway."
)
# TODO [ENG-151]: Check if it works with overwrite=True or if we need to
# manually remove it first
self.filesystem.copy(path1=src, path2=dst)
exists(self, path)
Check whether a path exists.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Union[bytes, str] | The path to check. | required |
Returns:
Type | Description |
---|---|
bool | True if the path exists, False otherwise. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def exists(self, path: PathType) -> bool:
"""Check whether a path exists.
Args:
path: The path to check.
Returns:
True if the path exists, False otherwise.
"""
return self.filesystem.exists(path=path) # type: ignore[no-any-return]
glob(self, pattern)
Return all paths that match the given glob pattern.
The glob pattern may include:
- '*' to match any number of characters
- '?' to match a single character
- '[...]' to match one of the characters inside the brackets
- '**' as the full name of a path component to match to search in subdirectories of any depth (e.g. '/some_dir/**/some_file')
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pattern | Union[bytes, str] | The glob pattern to match, see details above. | required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] | A list of paths that match the given glob pattern. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def glob(self, pattern: PathType) -> List[PathType]:
"""Return all paths that match the given glob pattern.
The glob pattern may include:
- '*' to match any number of characters
- '?' to match a single character
- '[...]' to match one of the characters inside the brackets
- '**' as the full name of a path component to match to search
in subdirectories of any depth (e.g. '/some_dir/**/some_file)
Args:
pattern: The glob pattern to match, see details above.
Returns:
A list of paths that match the given glob pattern.
"""
prefix, _ = self._split_path(pattern)
return [
f"{prefix}{path}" for path in self.filesystem.glob(path=pattern)
]
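A short usage sketch (the container name, file layout and helper function are hypothetical): because the scheme prefix stripped off by `_split_path` is re-attached to each match, results come back with the `az://` scheme intact.
```python
from zenml.integrations.azure.artifact_stores import AzureArtifactStore

def list_csv_artifacts(store: AzureArtifactStore, container: str = "my-container"):
    """Return every CSV under the given (hypothetical) container, scheme included."""
    # Matches are returned with the "az://" prefix re-attached, e.g.
    # "az://my-container/data/2022/train.csv".
    return store.glob(f"az://{container}/**/*.csv")
```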
isdir(self, path)
Check whether a path is a directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Union[bytes, str] | The path to check. | required |
Returns:
Type | Description |
---|---|
bool | True if the path is a directory, False otherwise. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def isdir(self, path: PathType) -> bool:
"""Check whether a path is a directory.
Args:
path: The path to check.
Returns:
True if the path is a directory, False otherwise.
"""
return self.filesystem.isdir(path=path) # type: ignore[no-any-return]
listdir(self, path)
Return a list of files in a directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Union[bytes, str] | The path to list. | required |
Returns:
Type | Description |
---|---|
List[Union[bytes, str]] | A list of files in the given directory. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def listdir(self, path: PathType) -> List[PathType]:
"""Return a list of files in a directory.
Args:
path: The path to list.
Returns:
A list of files in the given directory.
"""
_, path = self._split_path(path)
def _extract_basename(file_dict: Dict[str, Any]) -> str:
"""Extracts the basename from a dictionary returned by the Azure filesystem.
Args:
file_dict: A dictionary returned by the Azure filesystem.
Returns:
The basename of the file.
"""
file_path = cast(str, file_dict["name"])
base_name = file_path[len(path) :]
return base_name.lstrip("/")
return [
_extract_basename(dict_)
for dict_ in self.filesystem.listdir(path=path)
]
makedirs(self, path)
Create a directory at the given path.
If needed also create missing parent directories.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Union[bytes, str] | The path to create. | required |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def makedirs(self, path: PathType) -> None:
"""Create a directory at the given path.
If needed also create missing parent directories.
Args:
path: The path to create.
"""
self.filesystem.makedirs(path=path, exist_ok=True)
mkdir(self, path)
Create a directory at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Union[bytes, str] | The path to create. | required |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def mkdir(self, path: PathType) -> None:
"""Create a directory at the given path.
Args:
path: The path to create.
"""
self.filesystem.makedir(path=path)
open(self, path, mode='r')
Open a file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Union[bytes, str] | Path of the file to open. | required |
mode | str | Mode in which to open the file. Currently, only 'rb' and 'wb' to read and write binary files are supported. | 'r' |
Returns:
Type | Description |
---|---|
Any | A file-like object. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def open(self, path: PathType, mode: str = "r") -> Any:
"""Open a file at the given path.
Args:
path: Path of the file to open.
mode: Mode in which to open the file. Currently, only
'rb' and 'wb' to read and write binary files are supported.
Returns:
A file-like object.
"""
return self.filesystem.open(path=path, mode=mode)
remove(self, path)
Remove the file at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Union[bytes, str] | The path to remove. | required |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def remove(self, path: PathType) -> None:
"""Remove the file at the given path.
Args:
path: The path to remove.
"""
self.filesystem.rm_file(path=path)
rename(self, src, dst, overwrite=False)
Rename source file to destination file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src | Union[bytes, str] | The path of the file to rename. | required |
dst | Union[bytes, str] | The path to rename the source file to. | required |
overwrite | bool | If a file already exists at the destination, this method will overwrite it if overwrite=`True` and raise a FileExistsError otherwise. | False |
Exceptions:
Type | Description |
---|---|
FileExistsError | If a file already exists at the destination and overwrite is not set to `True`. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def rename(
self, src: PathType, dst: PathType, overwrite: bool = False
) -> None:
"""Rename source file to destination file.
Args:
src: The path of the file to rename.
dst: The path to rename the source file to.
overwrite: If a file already exists at the destination, this
method will overwrite it if overwrite=`True` and
raise a FileExistsError otherwise.
Raises:
FileExistsError: If a file already exists at the destination
and overwrite is not set to `True`.
"""
if not overwrite and self.filesystem.exists(dst):
raise FileExistsError(
f"Unable to rename file to '{convert_to_str(dst)}', "
f"file already exists. Set `overwrite=True` to rename anyway."
)
# TODO [ENG-152]: Check if it works with overwrite=True or if we need
# to manually remove it first
self.filesystem.rename(path1=src, path2=dst)
rmtree(self, path)
Remove the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Union[bytes, str] | The path of the directory to remove. | required |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def rmtree(self, path: PathType) -> None:
"""Remove the given directory.
Args:
path: The path of the directory to remove.
"""
self.filesystem.delete(path=path, recursive=True)
stat(self, path)
Return stat info for the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Union[bytes, str] | The path to get stat info for. | required |
Returns:
Type | Description |
---|---|
Dict[str, Any] | Stat info. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def stat(self, path: PathType) -> Dict[str, Any]:
"""Return stat info for the given path.
Args:
path: The path to get stat info for.
Returns:
Stat info.
"""
return self.filesystem.stat(path=path) # type: ignore[no-any-return]
walk(self, top, topdown=True, onerror=None)
Return an iterator that walks the contents of the given directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
top | Union[bytes, str] | Path of directory to walk. | required |
topdown | bool | Unused argument to conform to interface. | True |
onerror | Optional[Callable[..., NoneType]] | Unused argument to conform to interface. | None |
Yields:
Type | Description |
---|---|
Iterable[Tuple[Union[bytes, str], List[Union[bytes, str]], List[Union[bytes, str]]]] | An Iterable of Tuples, each of which contain the path of the current directory path, a list of directories inside the current directory and a list of files inside the current directory. |
Source code in zenml/integrations/azure/artifact_stores/azure_artifact_store.py
def walk(
self,
top: PathType,
topdown: bool = True,
onerror: Optional[Callable[..., None]] = None,
) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]:
"""Return an iterator that walks the contents of the given directory.
Args:
top: Path of directory to walk.
topdown: Unused argument to conform to interface.
onerror: Unused argument to conform to interface.
Yields:
An Iterable of Tuples, each of which contain the path of the current
directory path, a list of directories inside the current directory
and a list of files inside the current directory.
"""
# TODO [ENG-153]: Additional params
prefix, _ = self._split_path(top)
for (
directory,
subdirectories,
files,
) in self.filesystem.walk(path=top):
yield f"{prefix}{directory}", subdirectories, files
secrets_managers
special
Initialization of the Azure Secrets Manager integration.
azure_secrets_manager
Implementation of the Azure Secrets Manager integration.
AzureSecretsManager (BaseSecretsManager)
pydantic-model
Class to interact with the Azure secrets manager.
Attributes:
Name | Type | Description |
---|---|---|
key_vault_name | str | Name of an Azure Key Vault that this secrets manager will use to store secrets. |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
class AzureSecretsManager(BaseSecretsManager):
"""Class to interact with the Azure secrets manager.
Attributes:
key_vault_name: Name of an Azure Key Vault that this secrets manager
will use to store secrets.
"""
key_vault_name: str
# Class configuration
FLAVOR: ClassVar[str] = AZURE_SECRETS_MANAGER_FLAVOR
SUPPORTS_SCOPING: ClassVar[bool] = True
CLIENT: ClassVar[Any] = None
@classmethod
def _ensure_client_connected(cls, vault_name: str) -> None:
if cls.CLIENT is None:
KVUri = f"https://{vault_name}.vault.azure.net"
credential = DefaultAzureCredential()
cls.CLIENT = SecretClient(vault_url=KVUri, credential=credential)
@classmethod
def _validate_scope(
cls,
scope: SecretsManagerScope,
namespace: Optional[str],
) -> None:
"""Validate the scope and namespace value.
Args:
scope: Scope value.
namespace: Optional namespace value.
"""
if namespace:
cls.validate_secret_name_or_namespace(namespace, scope)
@classmethod
def validate_secret_name_or_namespace(
cls,
name: str,
scope: SecretsManagerScope,
) -> None:
"""Validate a secret name or namespace.
Azure secret names must contain only alphanumeric characters and the
character `-`.
Given that we also save secret names and namespaces as labels, we are
also limited by the 256 maximum size limitation that Azure imposes on
label values. An arbitrary length of 100 characters is used here for
the maximum size for the secret name and namespace.
Args:
name: the secret name or namespace
scope: the current scope
Raises:
ValueError: if the secret name or namespace is invalid
"""
if scope == SecretsManagerScope.NONE:
# to preserve backwards compatibility, we don't validate the
# secret name for unscoped secrets.
return
if not re.fullmatch(r"[0-9a-zA-Z-]+", name):
raise ValueError(
f"Invalid secret name or namespace '{name}'. Must contain "
f"only alphanumeric characters and the character -."
)
if len(name) > 100:
raise ValueError(
f"Invalid secret name or namespace '{name}'. The length is "
f"limited to maximum 100 characters."
)
def validate_secret_name(self, name: str) -> None:
"""Validate a secret name.
Args:
name: the secret name
"""
self.validate_secret_name_or_namespace(name, self.scope)
def _create_or_update_secret(self, secret: BaseSecretSchema) -> None:
"""Creates a new secret or updated an existing one.
Args:
secret: the secret to register or update
"""
if self.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
for key, value in secret.content.items():
encoded_key = base64.b64encode(
f"{secret.name}-{key}".encode()
).hex()
azure_secret_name = f"zenml-{encoded_key}"
self.CLIENT.set_secret(azure_secret_name, value)
self.CLIENT.update_secret_properties(
azure_secret_name,
tags={
ZENML_GROUP_KEY: secret.name,
ZENML_KEY_NAME: key,
ZENML_SCHEMA_NAME: secret.TYPE,
},
)
logger.debug(
"Secret `%s` written to the Azure Key Vault.",
azure_secret_name,
)
else:
azure_secret_name = self._get_scoped_secret_name(
secret.name,
separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
)
self.CLIENT.set_secret(
azure_secret_name,
json.dumps(secret_to_dict(secret)),
)
self.CLIENT.update_secret_properties(
azure_secret_name,
tags=self._get_secret_metadata(secret),
)
def register_secret(self, secret: BaseSecretSchema) -> None:
"""Registers a new secret.
Args:
secret: the secret to register
Raises:
SecretExistsError: if the secret already exists
"""
self.validate_secret_name(secret.name)
self._ensure_client_connected(self.key_vault_name)
if secret.name in self.get_all_secret_keys():
raise SecretExistsError(
f"A Secret with the name '{secret.name}' already exists."
)
self._create_or_update_secret(secret)
def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Get a secret by its name.
Args:
secret_name: the name of the secret to get
Returns:
The secret.
Raises:
KeyError: if the secret does not exist
ValueError: if the secret is named 'name'
"""
self.validate_secret_name(secret_name)
self._ensure_client_connected(self.key_vault_name)
zenml_secret: Optional[BaseSecretSchema] = None
if self.scope == SecretsManagerScope.NONE:
# Legacy secrets are mapped to multiple Azure secrets, one for
# each secret key
secret_contents = {}
zenml_schema_name = ""
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
secret_key = tags.get(ZENML_KEY_NAME)
if not secret_key:
raise ValueError("Missing secret key tag.")
if secret_key == "name":
raise ValueError("The secret's key cannot be 'name'.")
response = self.CLIENT.get_secret(secret_property.name)
secret_contents[secret_key] = response.value
zenml_schema_name = tags.get(ZENML_SCHEMA_NAME)
if secret_contents:
secret_contents["name"] = secret_name
secret_schema = SecretSchemaClassRegistry.get_class(
secret_schema=zenml_schema_name
)
zenml_secret = secret_schema(**secret_contents)
else:
# Scoped secrets are mapped 1-to-1 with Azure secrets
try:
response = self.CLIENT.get_secret(
self._get_scoped_secret_name(
secret_name,
separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
),
)
scope_tags = self._get_secret_scope_metadata(secret_name)
# all scope tags need to be included in the Azure secret tags,
# otherwise the secret does not belong to the current scope,
# even if it has the same name
if scope_tags.items() <= response.properties.tags.items():
zenml_secret = secret_from_dict(
json.loads(response.value), secret_name=secret_name
)
except ResourceNotFoundError:
pass
if not zenml_secret:
raise KeyError(f"Can't find the specified secret '{secret_name}'")
return zenml_secret
def get_all_secret_keys(self) -> List[str]:
"""Get all secret keys.
Returns:
A list of all secret keys
"""
self._ensure_client_connected(self.key_vault_name)
set_of_secrets = set()
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if not tags:
continue
if self.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
if ZENML_GROUP_KEY in tags:
set_of_secrets.add(tags.get(ZENML_GROUP_KEY))
continue
scope_tags = self._get_secret_scope_metadata()
# all scope tags need to be included in the Azure secret tags,
# otherwise the secret does not belong to the current scope
if scope_tags.items() <= tags.items():
set_of_secrets.add(tags.get(ZENML_SECRET_NAME_LABEL))
return list(set_of_secrets)
def update_secret(self, secret: BaseSecretSchema) -> None:
"""Update an existing secret by creating new versions of the existing secrets.
Args:
secret: the secret to update
Raises:
KeyError: if the secret does not exist
"""
self.validate_secret_name(secret.name)
self._ensure_client_connected(self.key_vault_name)
if secret.name not in self.get_all_secret_keys():
raise KeyError(f"Can't find the specified secret '{secret.name}'")
self._create_or_update_secret(secret)
def delete_secret(self, secret_name: str) -> None:
"""Delete an existing secret. by name.
Args:
secret_name: the name of the secret to delete
Raises:
KeyError: if the secret no longer exists
"""
self.validate_secret_name(secret_name)
self._ensure_client_connected(self.key_vault_name)
if self.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
# Go through all Azure secrets and delete the ones with the
# secret_name as label.
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
self.CLIENT.begin_delete_secret(
secret_property.name
).result()
else:
if secret_name not in self.get_all_secret_keys():
raise KeyError(
f"Can't find the specified secret '{secret_name}'"
)
self.CLIENT.begin_delete_secret(
self._get_scoped_secret_name(
secret_name,
separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
),
).result()
def delete_all_secrets(self) -> None:
"""Delete all existing secrets."""
self._ensure_client_connected(self.key_vault_name)
# List all secrets.
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if not tags:
continue
if self.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
if ZENML_GROUP_KEY in tags:
logger.info(
"Deleted key-value pair {`%s`, `***`} from secret "
"`%s`",
secret_property.name,
tags.get(ZENML_GROUP_KEY),
)
self.CLIENT.begin_delete_secret(
secret_property.name
).result()
continue
scope_tags = self._get_secret_scope_metadata()
# all scope tags need to be included in the Azure secret tags,
# otherwise the secret does not belong to the current scope
if scope_tags.items() <= tags.items():
self.CLIENT.begin_delete_secret(secret_property.name).result()
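The following sketch mirrors how `_create_or_update_secret` above derives an Azure Key Vault secret name for a legacy (unscoped) ZenML secret key; the secret and key names are placeholders.
```python
import base64

def legacy_azure_secret_name(secret_name: str, key: str) -> str:
    """Mirror the unscoped name mapping used above (illustrative only)."""
    # base64-encode "<secret>-<key>", then hex-encode the result so the name
    # contains only characters that Key Vault accepts.
    encoded_key = base64.b64encode(f"{secret_name}-{key}".encode()).hex()
    return f"zenml-{encoded_key}"

print(legacy_azure_secret_name("postgres_creds", "password"))
```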
delete_all_secrets(self)
Delete all existing secrets.
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def delete_all_secrets(self) -> None:
"""Delete all existing secrets."""
self._ensure_client_connected(self.key_vault_name)
# List all secrets.
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if not tags:
continue
if self.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
if ZENML_GROUP_KEY in tags:
logger.info(
"Deleted key-value pair {`%s`, `***`} from secret "
"`%s`",
secret_property.name,
tags.get(ZENML_GROUP_KEY),
)
self.CLIENT.begin_delete_secret(
secret_property.name
).result()
continue
scope_tags = self._get_secret_scope_metadata()
# all scope tags need to be included in the Azure secret tags,
# otherwise the secret does not belong to the current scope
if scope_tags.items() <= tags.items():
self.CLIENT.begin_delete_secret(secret_property.name).result()
delete_secret(self, secret_name)
Delete an existing secret by name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name | str | the name of the secret to delete | required |
Exceptions:
Type | Description |
---|---|
KeyError | if the secret no longer exists |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def delete_secret(self, secret_name: str) -> None:
"""Delete an existing secret. by name.
Args:
secret_name: the name of the secret to delete
Raises:
KeyError: if the secret no longer exists
"""
self.validate_secret_name(secret_name)
self._ensure_client_connected(self.key_vault_name)
if self.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
# Go through all Azure secrets and delete the ones with the
# secret_name as label.
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
self.CLIENT.begin_delete_secret(
secret_property.name
).result()
else:
if secret_name not in self.get_all_secret_keys():
raise KeyError(
f"Can't find the specified secret '{secret_name}'"
)
self.CLIENT.begin_delete_secret(
self._get_scoped_secret_name(
secret_name,
separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
),
).result()
get_all_secret_keys(self)
Get all secret keys.
Returns:
Type | Description |
---|---|
List[str] | A list of all secret keys |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def get_all_secret_keys(self) -> List[str]:
"""Get all secret keys.
Returns:
A list of all secret keys
"""
self._ensure_client_connected(self.key_vault_name)
set_of_secrets = set()
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if not tags:
continue
if self.scope == SecretsManagerScope.NONE:
# legacy, non-scoped secrets
if ZENML_GROUP_KEY in tags:
set_of_secrets.add(tags.get(ZENML_GROUP_KEY))
continue
scope_tags = self._get_secret_scope_metadata()
# all scope tags need to be included in the Azure secret tags,
# otherwise the secret does not belong to the current scope
if scope_tags.items() <= tags.items():
set_of_secrets.add(tags.get(ZENML_SECRET_NAME_LABEL))
return list(set_of_secrets)
get_secret(self, secret_name)
Get a secret by its name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret_name | str | the name of the secret to get | required |
Returns:
Type | Description |
---|---|
BaseSecretSchema | The secret. |
Exceptions:
Type | Description |
---|---|
KeyError | if the secret does not exist |
ValueError | if the secret is named 'name' |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def get_secret(self, secret_name: str) -> BaseSecretSchema:
"""Get a secret by its name.
Args:
secret_name: the name of the secret to get
Returns:
The secret.
Raises:
KeyError: if the secret does not exist
ValueError: if the secret is named 'name'
"""
self.validate_secret_name(secret_name)
self._ensure_client_connected(self.key_vault_name)
zenml_secret: Optional[BaseSecretSchema] = None
if self.scope == SecretsManagerScope.NONE:
# Legacy secrets are mapped to multiple Azure secrets, one for
# each secret key
secret_contents = {}
zenml_schema_name = ""
for secret_property in self.CLIENT.list_properties_of_secrets():
tags = secret_property.tags
if tags and tags.get(ZENML_GROUP_KEY) == secret_name:
secret_key = tags.get(ZENML_KEY_NAME)
if not secret_key:
raise ValueError("Missing secret key tag.")
if secret_key == "name":
raise ValueError("The secret's key cannot be 'name'.")
response = self.CLIENT.get_secret(secret_property.name)
secret_contents[secret_key] = response.value
zenml_schema_name = tags.get(ZENML_SCHEMA_NAME)
if secret_contents:
secret_contents["name"] = secret_name
secret_schema = SecretSchemaClassRegistry.get_class(
secret_schema=zenml_schema_name
)
zenml_secret = secret_schema(**secret_contents)
else:
# Scoped secrets are mapped 1-to-1 with Azure secrets
try:
response = self.CLIENT.get_secret(
self._get_scoped_secret_name(
secret_name,
separator=ZENML_AZURE_SECRET_SCOPE_PATH_SEPARATOR,
),
)
scope_tags = self._get_secret_scope_metadata(secret_name)
# all scope tags need to be included in the Azure secret tags,
# otherwise the secret does not belong to the current scope,
# even if it has the same name
if scope_tags.items() <= response.properties.tags.items():
zenml_secret = secret_from_dict(
json.loads(response.value), secret_name=secret_name
)
except ResourceNotFoundError:
pass
if not zenml_secret:
raise KeyError(f"Can't find the specified secret '{secret_name}'")
return zenml_secret
register_secret(self, secret)
Registers a new secret.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret | BaseSecretSchema | the secret to register | required |
Exceptions:
Type | Description |
---|---|
SecretExistsError | if the secret already exists |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def register_secret(self, secret: BaseSecretSchema) -> None:
"""Registers a new secret.
Args:
secret: the secret to register
Raises:
SecretExistsError: if the secret already exists
"""
self.validate_secret_name(secret.name)
self._ensure_client_connected(self.key_vault_name)
if secret.name in self.get_all_secret_keys():
raise SecretExistsError(
f"A Secret with the name '{secret.name}' already exists."
)
self._create_or_update_secret(secret)
update_secret(self, secret)
Update an existing secret by creating new versions of the existing secrets.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
secret | BaseSecretSchema | the secret to update | required |
Exceptions:
Type | Description |
---|---|
KeyError | if the secret does not exist |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def update_secret(self, secret: BaseSecretSchema) -> None:
"""Update an existing secret by creating new versions of the existing secrets.
Args:
secret: the secret to update
Raises:
KeyError: if the secret does not exist
"""
self.validate_secret_name(secret.name)
self._ensure_client_connected(self.key_vault_name)
if secret.name not in self.get_all_secret_keys():
raise KeyError(f"Can't find the specified secret '{secret.name}'")
self._create_or_update_secret(secret)
validate_secret_name(self, name)
Validate a secret name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | the secret name | required |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
def validate_secret_name(self, name: str) -> None:
"""Validate a secret name.
Args:
name: the secret name
"""
self.validate_secret_name_or_namespace(name, self.scope)
validate_secret_name_or_namespace(name, scope)
classmethod
Validate a secret name or namespace.
Azure secret names must contain only alphanumeric characters and the character `-`.
Given that we also save secret names and namespaces as labels, we are also limited by the 256-character maximum size that Azure imposes on label values. An arbitrary limit of 100 characters is used here as the maximum length for the secret name and namespace.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | the secret name or namespace | required |
scope | SecretsManagerScope | the current scope | required |
Exceptions:
Type | Description |
---|---|
ValueError | if the secret name or namespace is invalid |
Source code in zenml/integrations/azure/secrets_managers/azure_secrets_manager.py
@classmethod
def validate_secret_name_or_namespace(
cls,
name: str,
scope: SecretsManagerScope,
) -> None:
"""Validate a secret name or namespace.
Azure secret names must contain only alphanumeric characters and the
character `-`.
Given that we also save secret names and namespaces as labels, we are
also limited by the 256 maximum size limitation that Azure imposes on
label values. An arbitrary length of 100 characters is used here for
the maximum size for the secret name and namespace.
Args:
name: the secret name or namespace
scope: the current scope
Raises:
ValueError: if the secret name or namespace is invalid
"""
if scope == SecretsManagerScope.NONE:
# to preserve backwards compatibility, we don't validate the
# secret name for unscoped secrets.
return
if not re.fullmatch(r"[0-9a-zA-Z-]+", name):
raise ValueError(
f"Invalid secret name or namespace '{name}'. Must contain "
f"only alphanumeric characters and the character -."
)
if len(name) > 100:
raise ValueError(
f"Invalid secret name or namespace '{name}'. The length is "
f"limited to maximum 100 characters."
)
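For illustration, the rule above can be checked directly with the same regular expression; the example names below are hypothetical.
```python
import re

# Only alphanumeric characters and "-", at most 100 characters.
for name in ["mysql-creds", "mysql_creds"]:
    valid = bool(re.fullmatch(r"[0-9a-zA-Z-]+", name)) and len(name) <= 100
    print(name, "is valid" if valid else "is invalid")
# mysql-creds is valid
# mysql_creds is invalid (underscores are not allowed)
```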
step_operators
special
Initialization of AzureML Step Operator integration.
azureml_step_operator
Implementation of the ZenML AzureML Step Operator.
AzureMLStepOperator (BaseStepOperator, PipelineDockerImageBuilder)
pydantic-model
Step operator to run a step on AzureML.
This class defines code that can set up an AzureML environment and run the ZenML entrypoint command in it.
Attributes:
Name | Type | Description |
---|---|---|
subscription_id | str | The Azure account's subscription ID |
resource_group | str | The resource group to which the AzureML workspace is deployed. |
workspace_name | str | The name of the AzureML Workspace. |
compute_target_name | str | The name of the configured ComputeTarget. An instance of it has to be created on the portal if it doesn't exist already. |
environment_name | Optional[str] | The name of the environment if there already exists one. |
docker_base_image | Optional[str] | The custom docker base image that the environment should use. |
tenant_id | Optional[str] | The Azure Tenant ID. |
service_principal_id | Optional[str] | The ID for the service principal that is created to allow apps to access secure resources. |
service_principal_password | Optional[str] | Password for the service principal. |
Source code in zenml/integrations/azure/step_operators/azureml_step_operator.py
class AzureMLStepOperator(BaseStepOperator, PipelineDockerImageBuilder):
"""Step operator to run a step on AzureML.
This class defines code that can set up an AzureML environment and run the
ZenML entrypoint command in it.
Attributes:
subscription_id: The Azure account's subscription ID
resource_group: The resource group to which the AzureML workspace
is deployed.
workspace_name: The name of the AzureML Workspace.
compute_target_name: The name of the configured ComputeTarget.
An instance of it has to be created on the portal if it doesn't
exist already.
environment_name: The name of the environment if there
already exists one.
docker_base_image: The custom docker base image that the
environment should use.
tenant_id: The Azure Tenant ID.
service_principal_id: The ID for the service principal that is created
to allow apps to access secure resources.
service_principal_password: Password for the service principal.
"""
subscription_id: str
resource_group: str
workspace_name: str
compute_target_name: str
# Environment
environment_name: Optional[str] = None
docker_base_image: Optional[str] = None
# Service principal authentication
# https://docs.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication#configure-a-service-principal
tenant_id: Optional[str] = SecretField()
service_principal_id: Optional[str] = SecretField()
service_principal_password: Optional[str] = SecretField()
# Class Configuration
FLAVOR: ClassVar[str] = AZUREML_STEP_OPERATOR_FLAVOR
_deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
("docker_base_image", "docker_parent_image")
)
def _get_authentication(self) -> Optional[AbstractAuthentication]:
"""Returns the authentication object for the AzureML environment.
Returns:
The authentication object for the AzureML environment.
"""
if (
self.tenant_id
and self.service_principal_id
and self.service_principal_password
):
return ServicePrincipalAuthentication(
tenant_id=self.tenant_id,
service_principal_id=self.service_principal_id,
service_principal_password=self.service_principal_password,
)
return None
def _prepare_environment(
self,
workspace: Workspace,
docker_configuration: "DockerConfiguration",
run_name: str,
) -> Environment:
"""Prepares the environment in which Azure will run all jobs.
Args:
workspace: The AzureML Workspace that has configuration
for a storage account, container registry among other
things.
docker_configuration: The Docker configuration for this step.
run_name: The name of the pipeline run that can be used
for naming environments and runs.
Returns:
The AzureML Environment object.
"""
requirements_files = self._gather_requirements_files(
docker_configuration=docker_configuration,
stack=Repository().active_stack,
)
requirements = list(
itertools.chain.from_iterable(
r[1].split("\n") for r in requirements_files
)
)
requirements.append(f"zenml=={zenml.__version__}")
logger.info(
"Using requirements for AzureML step operator environment: %s",
requirements,
)
if self.environment_name:
environment = Environment.get(
workspace=workspace, name=self.environment_name
)
if not environment.python.conda_dependencies:
environment.python.conda_dependencies = (
CondaDependencies.create(
python_version=ZenMLEnvironment.python_version()
)
)
for requirement in requirements:
environment.python.conda_dependencies.add_pip_package(
requirement
)
else:
environment = Environment(name=f"zenml-{run_name}")
environment.python.conda_dependencies = CondaDependencies.create(
pip_packages=requirements,
python_version=ZenMLEnvironment.python_version(),
)
parent_image = (
docker_configuration.parent_image or self.docker_parent_image
)
if parent_image:
# replace the default azure base image
environment.docker.base_image = parent_image
environment_variables = {
"ENV_ZENML_PREVENT_PIPELINE_EXECUTION": "True",
}
# set credentials to access azure storage
for key in [
"AZURE_STORAGE_ACCOUNT_KEY",
"AZURE_STORAGE_ACCOUNT_NAME",
"AZURE_STORAGE_CONNECTION_STRING",
"AZURE_STORAGE_SAS_TOKEN",
]:
value = os.getenv(key)
if value:
environment_variables[key] = value
environment_variables[
ENV_ZENML_CONFIG_PATH
] = f"./{DOCKER_IMAGE_ZENML_CONFIG_DIR}"
environment_variables.update(docker_configuration.environment)
environment.environment_variables = environment_variables
return environment
def launch(
self,
pipeline_name: str,
run_name: str,
docker_configuration: "DockerConfiguration",
entrypoint_command: List[str],
resource_configuration: "ResourceConfiguration",
) -> None:
"""Launches a step on AzureML.
Args:
pipeline_name: Name of the pipeline which the step to be executed
is part of.
run_name: Name of the pipeline run which the step to be executed
is part of.
docker_configuration: The Docker configuration for this step.
entrypoint_command: Command that executes the step.
resource_configuration: The resource configuration for this step.
"""
if not resource_configuration.empty:
logger.warning(
"Specifying custom step resources is not supported for "
"the AzureML step operator. If you want to run this step "
"operator on specific resources, you can do so by creating an "
"Azure compute target (https://docs.microsoft.com/en-us/azure/machine-learning/concept-compute-target) "
"with a specific machine type and then updating this step "
"operator: `zenml step-operator update %s "
"--compute_target_name=<COMPUTE_TARGET_NAME>`",
self.name,
)
unused_docker_fields = [
"dockerfile",
"build_context_root",
"build_options",
"docker_target_repository",
"dockerignore",
"copy_files",
"copy_profile",
]
ignored_docker_fields = (
docker_configuration.__fields_set__.intersection(
unused_docker_fields
)
)
if ignored_docker_fields:
logger.warning(
"The AzureML step operator currently does not support all "
"options defined in your Docker configuration. Ignoring all "
"values set for the attributes: %s",
ignored_docker_fields,
)
workspace = Workspace.get(
subscription_id=self.subscription_id,
resource_group=self.resource_group,
name=self.workspace_name,
auth=self._get_authentication(),
)
source_directory = get_source_root_path()
with _include_active_profile(
build_context_root=source_directory,
load_config_path=PurePosixPath(
f"./{DOCKER_IMAGE_ZENML_CONFIG_DIR}"
),
):
environment = self._prepare_environment(
workspace=workspace,
docker_configuration=docker_configuration,
run_name=run_name,
)
compute_target = ComputeTarget(
workspace=workspace, name=self.compute_target_name
)
run_config = ScriptRunConfig(
source_directory=source_directory,
environment=environment,
compute_target=compute_target,
command=entrypoint_command,
)
experiment = Experiment(workspace=workspace, name=pipeline_name)
run = experiment.submit(config=run_config)
run.display_name = run_name
run.wait_for_completion(show_output=True)
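As a minimal sketch of the AzureML objects the operator relies on (all IDs and names below are placeholders, and service principal credentials are assumed to be configured):
```python
from azureml.core import Workspace
from azureml.core.authentication import ServicePrincipalAuthentication

# Authenticate the same way `_get_authentication` does when all three
# service principal fields are set.
auth = ServicePrincipalAuthentication(
    tenant_id="<TENANT_ID>",
    service_principal_id="<SERVICE_PRINCIPAL_ID>",
    service_principal_password="<SERVICE_PRINCIPAL_PASSWORD>",
)

# Resolve the workspace that `launch` submits experiments to.
workspace = Workspace.get(
    name="<WORKSPACE_NAME>",
    subscription_id="<SUBSCRIPTION_ID>",
    resource_group="<RESOURCE_GROUP>",
    auth=auth,
)

# The compute target referenced by `compute_target_name` must already exist.
print(list(workspace.compute_targets))
```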
launch(self, pipeline_name, run_name, docker_configuration, entrypoint_command, resource_configuration)
Launches a step on AzureML.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pipeline_name | str | Name of the pipeline which the step to be executed is part of. | required |
run_name | str | Name of the pipeline run which the step to be executed is part of. | required |
docker_configuration | DockerConfiguration | The Docker configuration for this step. | required |
entrypoint_command | List[str] | Command that executes the step. | required |
resource_configuration | ResourceConfiguration | The resource configuration for this step. | required |
Source code in zenml/integrations/azure/step_operators/azureml_step_operator.py
def launch(
self,
pipeline_name: str,
run_name: str,
docker_configuration: "DockerConfiguration",
entrypoint_command: List[str],
resource_configuration: "ResourceConfiguration",
) -> None:
"""Launches a step on AzureML.
Args:
pipeline_name: Name of the pipeline which the step to be executed
is part of.
run_name: Name of the pipeline run which the step to be executed
is part of.
docker_configuration: The Docker configuration for this step.
entrypoint_command: Command that executes the step.
resource_configuration: The resource configuration for this step.
"""
if not resource_configuration.empty:
logger.warning(
"Specifying custom step resources is not supported for "
"the AzureML step operator. If you want to run this step "
"operator on specific resources, you can do so by creating an "
"Azure compute target (https://docs.microsoft.com/en-us/azure/machine-learning/concept-compute-target) "
"with a specific machine type and then updating this step "
"operator: `zenml step-operator update %s "
"--compute_target_name=<COMPUTE_TARGET_NAME>`",
self.name,
)
unused_docker_fields = [
"dockerfile",
"build_context_root",
"build_options",
"docker_target_repository",
"dockerignore",
"copy_files",
"copy_profile",
]
ignored_docker_fields = (
docker_configuration.__fields_set__.intersection(
unused_docker_fields
)
)
if ignored_docker_fields:
logger.warning(
"The AzureML step operator currently does not support all "
"options defined in your Docker configuration. Ignoring all "
"values set for the attributes: %s",
ignored_docker_fields,
)
workspace = Workspace.get(
subscription_id=self.subscription_id,
resource_group=self.resource_group,
name=self.workspace_name,
auth=self._get_authentication(),
)
source_directory = get_source_root_path()
with _include_active_profile(
build_context_root=source_directory,
load_config_path=PurePosixPath(
f"./{DOCKER_IMAGE_ZENML_CONFIG_DIR}"
),
):
environment = self._prepare_environment(
workspace=workspace,
docker_configuration=docker_configuration,
run_name=run_name,
)
compute_target = ComputeTarget(
workspace=workspace, name=self.compute_target_name
)
run_config = ScriptRunConfig(
source_directory=source_directory,
environment=environment,
compute_target=compute_target,
command=entrypoint_command,
)
experiment = Experiment(workspace=workspace, name=pipeline_name)
run = experiment.submit(config=run_config)
run.display_name = run_name
run.wait_for_completion(show_output=True)
constants
Constants for ZenML integrations.
dash
special
Initialization of the Dash integration.
DashIntegration (Integration)
Definition of Dash integration for ZenML.
Source code in zenml/integrations/dash/__init__.py
class DashIntegration(Integration):
"""Definition of Dash integration for ZenML."""
NAME = DASH
REQUIREMENTS = [
"dash>=2.0.0",
"dash-cytoscape>=0.3.0",
"dash-bootstrap-components>=1.0.1",
"jupyter-dash>=0.4.2",
]
visualizers
special
Initialization of the Pipeline Run Visualizer.
pipeline_run_lineage_visualizer
Implementation of the pipeline run lineage visualizer.
PipelineRunLineageVisualizer (BaseVisualizer)
Implementation of a lineage diagram via the dash and dash-cytoscape libraries.
Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
class PipelineRunLineageVisualizer(BaseVisualizer):
"""Implementation of a lineage diagram via the dash and dash-cytoscape libraries."""
ARTIFACT_PREFIX = "artifact_"
STEP_PREFIX = "step_"
STATUS_CLASS_MAPPING = {
ExecutionStatus.CACHED: "green",
ExecutionStatus.FAILED: "red",
ExecutionStatus.RUNNING: "yellow",
ExecutionStatus.COMPLETED: "blue",
}
def visualize(
self,
object: PipelineRunView,
magic: bool = False,
*args: Any,
**kwargs: Any,
) -> dash.Dash:
"""Method to visualize pipeline runs via the Dash library.
The layout puts every layer of the dag in a column.
Args:
object: The pipeline run to visualize.
magic: If True, the visualization is rendered in a magic mode.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
The Dash application.
"""
external_stylesheets = [
dbc.themes.BOOTSTRAP,
dbc.icons.BOOTSTRAP,
]
if magic:
if Environment.in_notebook:
# Only import jupyter_dash in this case
from jupyter_dash import JupyterDash # noqa
JupyterDash.infer_jupyter_proxy_config()
app = JupyterDash(
__name__,
external_stylesheets=external_stylesheets,
)
mode = "inline"
else:
cli_utils.warning(
"Cannot set magic flag in non-notebook environments."
)
else:
app = dash.Dash(
__name__,
external_stylesheets=external_stylesheets,
)
mode = None
nodes, edges, first_step_id = [], [], None
first_step_id = None
for step in object.steps:
step_output_artifacts = list(step.outputs.values())
execution_id = (
step_output_artifacts[0].producer_step.id
if step_output_artifacts
else step.id
)
step_id = self.STEP_PREFIX + str(step.id)
if first_step_id is None:
first_step_id = step_id
nodes.append(
{
"data": {
"id": step_id,
"execution_id": execution_id,
"label": f"{execution_id} / {step.entrypoint_name}",
"entrypoint_name": step.entrypoint_name, # redundant for consistency
"name": step.name, # redundant for consistency
"type": "step",
"parameters": step.parameters,
"inputs": {k: v.uri for k, v in step.inputs.items()},
"outputs": {k: v.uri for k, v in step.outputs.items()},
},
"classes": self.STATUS_CLASS_MAPPING[step.status],
}
)
for artifact_name, artifact in step.outputs.items():
nodes.append(
{
"data": {
"id": self.ARTIFACT_PREFIX + str(artifact.id),
"execution_id": artifact.id,
"label": f"{artifact.id} / {artifact_name} ("
f"{artifact.data_type})",
"type": "artifact",
"name": artifact_name,
"is_cached": artifact.is_cached,
"artifact_type": artifact.type,
"artifact_data_type": artifact.data_type,
"parent_step_id": artifact.parent_step_id,
"producer_step_id": artifact.producer_step.id,
"uri": artifact.uri,
},
"classes": f"rectangle "
f"{self.STATUS_CLASS_MAPPING[step.status]}",
}
)
edges.append(
{
"data": {
"source": self.STEP_PREFIX + str(step.id),
"target": self.ARTIFACT_PREFIX + str(artifact.id),
},
"classes": f"edge-arrow "
f"{self.STATUS_CLASS_MAPPING[step.status]}"
+ (" dashed" if artifact.is_cached else " solid"),
}
)
for artifact_name, artifact in step.inputs.items():
edges.append(
{
"data": {
"source": self.ARTIFACT_PREFIX + str(artifact.id),
"target": self.STEP_PREFIX + str(step.id),
},
"classes": "edge-arrow "
+ (
f"{self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]} dashed"
if artifact.is_cached
else f"{self.STATUS_CLASS_MAPPING[step.status]} solid"
),
}
)
app.layout = dbc.Row(
[
dbc.Container(f"Run: {object.name}", class_name="h1"),
dbc.Row(
[
dbc.Col(
[
dbc.Row(
[
html.Span(
[
html.Span(
[
html.I(
className="bi bi-circle-fill me-1"
),
"Step",
],
className="me-2",
),
html.Span(
[
html.I(
className="bi bi-square-fill me-1"
),
"Artifact",
],
className="me-4",
),
dbc.Badge(
"Completed",
color=COLOR_BLUE,
className="me-1",
),
dbc.Badge(
"Cached",
color=COLOR_GREEN,
className="me-1",
),
dbc.Badge(
"Running",
color=COLOR_YELLOW,
className="me-1",
),
dbc.Badge(
"Failed",
color=COLOR_RED,
className="me-1",
),
]
),
]
),
dbc.Row(
[
cyto.Cytoscape(
id="cytoscape",
layout={
"name": "breadthfirst",
"roots": f'[id = "{first_step_id}"]',
},
elements=edges + nodes,
stylesheet=STYLESHEET,
style={
"width": "100%",
"height": "800px",
},
zoom=1,
)
]
),
dbc.Row(
[
dbc.Button(
"Reset",
id="bt-reset",
color="primary",
className="me-1",
)
]
),
]
),
dbc.Col(
[
dcc.Markdown(id="markdown-selected-node-data"),
]
),
]
),
],
className="p-5",
)
@app.callback( # type: ignore[misc]
Output("markdown-selected-node-data", "children"),
Input("cytoscape", "selectedNodeData"),
)
def display_data(data_list: List[Dict[str, Any]]) -> str:
"""Callback for the text area below the graph.
Args:
data_list: The selected node data.
Returns:
str: The selected node data.
"""
if data_list is None:
return "Click on a node in the diagram."
text = ""
for data in data_list:
text += f'## {data["execution_id"]} / {data["name"]}' + "\n\n"
if data["type"] == "artifact":
for item in [
"artifact_data_type",
"is_cached",
"producer_step_id",
"parent_step_id",
"uri",
]:
text += f"**{item}**: {data[item]}" + "\n\n"
elif data["type"] == "step":
text += "### Inputs:" + "\n\n"
for k, v in data["inputs"].items():
text += f"**{k}**: {v}" + "\n\n"
text += "### Outputs:" + "\n\n"
for k, v in data["outputs"].items():
text += f"**{k}**: {v}" + "\n\n"
text += "### Params:"
for k, v in data["parameters"].items():
text += f"**{k}**: {v}" + "\n\n"
return text
@app.callback( # type: ignore[misc]
[Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
[Input("bt-reset", "n_clicks")],
)
def reset_layout(
n_clicks: int,
) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
"""Resets the layout.
Args:
n_clicks: The number of clicks on the reset button.
Returns:
The zoom and the elements.
"""
logger.debug(n_clicks, "clicked in reset button.")
return [1, edges + nodes]
if mode is not None:
app.run_server(mode=mode)
app.run_server()
return app
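A hypothetical post-execution usage sketch, assuming a pipeline named `my_pipeline` with at least one run and that the `Repository` post-execution API of this release is available; the pipeline name is a placeholder.
```python
from zenml.integrations.dash.visualizers.pipeline_run_lineage_visualizer import (
    PipelineRunLineageVisualizer,
)
from zenml.repository import Repository

# Fetch the latest run of a pipeline and render its lineage diagram.
latest_run = Repository().get_pipeline("my_pipeline").runs[-1]
PipelineRunLineageVisualizer().visualize(latest_run)  # starts a local Dash server
```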
visualize(self, object, magic=False, *args, **kwargs)
Method to visualize pipeline runs via the Dash library.
The layout puts every layer of the dag in a column.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
object | PipelineRunView | The pipeline run to visualize. | required |
magic | bool | If True, the visualization is rendered in a magic mode. | False |
*args | Any | Additional positional arguments. | () |
**kwargs | Any | Additional keyword arguments. | {} |
Returns:
Type | Description |
---|---|
Dash | The Dash application. |
Source code in zenml/integrations/dash/visualizers/pipeline_run_lineage_visualizer.py
def visualize(
self,
object: PipelineRunView,
magic: bool = False,
*args: Any,
**kwargs: Any,
) -> dash.Dash:
"""Method to visualize pipeline runs via the Dash library.
The layout puts every layer of the dag in a column.
Args:
object: The pipeline run to visualize.
magic: If True, the visualization is rendered in a magic mode.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
The Dash application.
"""
external_stylesheets = [
dbc.themes.BOOTSTRAP,
dbc.icons.BOOTSTRAP,
]
if magic:
if Environment.in_notebook:
# Only import jupyter_dash in this case
from jupyter_dash import JupyterDash # noqa
JupyterDash.infer_jupyter_proxy_config()
app = JupyterDash(
__name__,
external_stylesheets=external_stylesheets,
)
mode = "inline"
else:
cli_utils.warning(
"Cannot set magic flag in non-notebook environments."
)
else:
app = dash.Dash(
__name__,
external_stylesheets=external_stylesheets,
)
mode = None
nodes, edges, first_step_id = [], [], None
first_step_id = None
for step in object.steps:
step_output_artifacts = list(step.outputs.values())
execution_id = (
step_output_artifacts[0].producer_step.id
if step_output_artifacts
else step.id
)
step_id = self.STEP_PREFIX + str(step.id)
if first_step_id is None:
first_step_id = step_id
nodes.append(
{
"data": {
"id": step_id,
"execution_id": execution_id,
"label": f"{execution_id} / {step.entrypoint_name}",
"entrypoint_name": step.entrypoint_name, # redundant for consistency
"name": step.name, # redundant for consistency
"type": "step",
"parameters": step.parameters,
"inputs": {k: v.uri for k, v in step.inputs.items()},
"outputs": {k: v.uri for k, v in step.outputs.items()},
},
"classes": self.STATUS_CLASS_MAPPING[step.status],
}
)
for artifact_name, artifact in step.outputs.items():
nodes.append(
{
"data": {
"id": self.ARTIFACT_PREFIX + str(artifact.id),
"execution_id": artifact.id,
"label": f"{artifact.id} / {artifact_name} ("
f"{artifact.data_type})",
"type": "artifact",
"name": artifact_name,
"is_cached": artifact.is_cached,
"artifact_type": artifact.type,
"artifact_data_type": artifact.data_type,
"parent_step_id": artifact.parent_step_id,
"producer_step_id": artifact.producer_step.id,
"uri": artifact.uri,
},
"classes": f"rectangle "
f"{self.STATUS_CLASS_MAPPING[step.status]}",
}
)
edges.append(
{
"data": {
"source": self.STEP_PREFIX + str(step.id),
"target": self.ARTIFACT_PREFIX + str(artifact.id),
},
"classes": f"edge-arrow "
f"{self.STATUS_CLASS_MAPPING[step.status]}"
+ (" dashed" if artifact.is_cached else " solid"),
}
)
for artifact_name, artifact in step.inputs.items():
edges.append(
{
"data": {
"source": self.ARTIFACT_PREFIX + str(artifact.id),
"target": self.STEP_PREFIX + str(step.id),
},
"classes": "edge-arrow "
+ (
f"{self.STATUS_CLASS_MAPPING[ExecutionStatus.CACHED]} dashed"
if artifact.is_cached
else f"{self.STATUS_CLASS_MAPPING[step.status]} solid"
),
}
)
app.layout = dbc.Row(
[
dbc.Container(f"Run: {object.name}", class_name="h1"),
dbc.Row(
[
dbc.Col(
[
dbc.Row(
[
html.Span(
[
html.Span(
[
html.I(
className="bi bi-circle-fill me-1"
),
"Step",
],
className="me-2",
),
html.Span(
[
html.I(
className="bi bi-square-fill me-1"
),
"Artifact",
],
className="me-4",
),
dbc.Badge(
"Completed",
color=COLOR_BLUE,
className="me-1",
),
dbc.Badge(
"Cached",
color=COLOR_GREEN,
className="me-1",
),
dbc.Badge(
"Running",
color=COLOR_YELLOW,
className="me-1",
),
dbc.Badge(
"Failed",
color=COLOR_RED,
className="me-1",
),
]
),
]
),
dbc.Row(
[
cyto.Cytoscape(
id="cytoscape",
layout={
"name": "breadthfirst",
"roots": f'[id = "{first_step_id}"]',
},
elements=edges + nodes,
stylesheet=STYLESHEET,
style={
"width": "100%",
"height": "800px",
},
zoom=1,
)
]
),
dbc.Row(
[
dbc.Button(
"Reset",
id="bt-reset",
color="primary",
className="me-1",
)
]
),
]
),
dbc.Col(
[
dcc.Markdown(id="markdown-selected-node-data"),
]
),
]
),
],
className="p-5",
)
@app.callback( # type: ignore[misc]
Output("markdown-selected-node-data", "children"),
Input("cytoscape", "selectedNodeData"),
)
def display_data(data_list: List[Dict[str, Any]]) -> str:
"""Callback for the text area below the graph.
Args:
data_list: The selected node data.
Returns:
str: The selected node data.
"""
if data_list is None:
return "Click on a node in the diagram."
text = ""
for data in data_list:
text += f'## {data["execution_id"]} / {data["name"]}' + "\n\n"
if data["type"] == "artifact":
for item in [
"artifact_data_type",
"is_cached",
"producer_step_id",
"parent_step_id",
"uri",
]:
text += f"**{item}**: {data[item]}" + "\n\n"
elif data["type"] == "step":
text += "### Inputs:" + "\n\n"
for k, v in data["inputs"].items():
text += f"**{k}**: {v}" + "\n\n"
text += "### Outputs:" + "\n\n"
for k, v in data["outputs"].items():
text += f"**{k}**: {v}" + "\n\n"
text += "### Params:"
for k, v in data["parameters"].items():
text += f"**{k}**: {v}" + "\n\n"
return text
@app.callback( # type: ignore[misc]
[Output("cytoscape", "zoom"), Output("cytoscape", "elements")],
[Input("bt-reset", "n_clicks")],
)
def reset_layout(
n_clicks: int,
) -> List[Union[int, List[Dict[str, Collection[str]]]]]:
"""Resets the layout.
Args:
n_clicks: The number of clicks on the reset button.
Returns:
The zoom and the elements.
"""
logger.debug(n_clicks, "clicked in reset button.")
return [1, edges + nodes]
if mode is not None:
app.run_server(mode=mode)
app.run_server()
return app
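For orientation, here is a minimal, hedged usage sketch of the visualizer. The pipeline name is a placeholder and the run-lookup call (`Repository().get_pipeline`) is assumed to match your ZenML version; only the visualizer itself comes from the module documented above.

```python
# Hedged usage sketch: render the lineage graph of the latest run of a
# pipeline. "my_pipeline" is a placeholder; the repository lookup API may
# differ slightly between ZenML versions.
from zenml.integrations.dash.visualizers.pipeline_run_lineage_visualizer import (
    PipelineRunLineageVisualizer,
)
from zenml.repository import Repository

latest_run = Repository().get_pipeline("my_pipeline").runs[-1]

# Starts a Dash server; pass magic=True inside a notebook for inline rendering.
PipelineRunLineageVisualizer().visualize(latest_run)
```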
deepchecks
special
Deepchecks integration for ZenML.
The Deepchecks integration provides a way to validate the data used in your pipelines: it can detect data anomalies and run user-defined checks to ensure data quality.
The integration ships custom materializers to store Deepchecks SuiteResults and a visualizer to display the results conveniently in a notebook or in your browser.
DeepchecksIntegration (Integration)
Definition of Deepchecks integration for ZenML.
Source code in zenml/integrations/deepchecks/__init__.py
class DeepchecksIntegration(Integration):
"""Definition of [Deepchecks](https://github.com/deepchecks/deepchecks) integration for ZenML."""
NAME = DEEPCHECKS
REQUIREMENTS = ["deepchecks[vision]==0.8.0", "torchvision==0.11.2"]
@staticmethod
def activate() -> None:
"""Activate the Deepchecks integration."""
from zenml.integrations.deepchecks import materializers # noqa
from zenml.integrations.deepchecks import visualizers # noqa
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
"""Declare the stack component flavors for the Deepchecks integration.
Returns:
List of stack component flavors for this integration.
"""
return [
FlavorWrapper(
name=DEEPCHECKS_DATA_VALIDATOR_FLAVOR,
source="zenml.integrations.deepchecks.data_validators.DeepchecksDataValidator",
type=StackComponentType.DATA_VALIDATOR,
integration=cls.NAME,
),
]
activate()
staticmethod
Activate the Deepchecks integration.
Source code in zenml/integrations/deepchecks/__init__.py
@staticmethod
def activate() -> None:
"""Activate the Deepchecks integration."""
from zenml.integrations.deepchecks import materializers # noqa
from zenml.integrations.deepchecks import visualizers # noqa
flavors()
classmethod
Declare the stack component flavors for the Deepchecks integration.
Returns:

| Type | Description |
|---|---|
| List[zenml.zen_stores.models.flavor_wrapper.FlavorWrapper] | List of stack component flavors for this integration. |
Source code in zenml/integrations/deepchecks/__init__.py
@classmethod
def flavors(cls) -> List[FlavorWrapper]:
"""Declare the stack component flavors for the Deepchecks integration.
Returns:
List of stack component flavors for this integration.
"""
return [
FlavorWrapper(
name=DEEPCHECKS_DATA_VALIDATOR_FLAVOR,
source="zenml.integrations.deepchecks.data_validators.DeepchecksDataValidator",
type=StackComponentType.DATA_VALIDATOR,
integration=cls.NAME,
),
]
data_validators
special
Initialization of the Deepchecks data validator for ZenML.
deepchecks_data_validator
Implementation of the Deepchecks data validator.
DeepchecksDataValidator (BaseDataValidator)
pydantic-model
Deepchecks data validator stack component.
Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
class DeepchecksDataValidator(BaseDataValidator):
"""Deepchecks data validator stack component."""
# Class Configuration
FLAVOR: ClassVar[str] = DEEPCHECKS_DATA_VALIDATOR_FLAVOR
NAME: ClassVar[str] = "Deepchecks"
@staticmethod
def _split_checks(
check_list: Sequence[str],
) -> Tuple[Sequence[str], Sequence[str]]:
"""Split a list of check identifiers in two lists, one for tabular and one for computer vision checks.
Args:
check_list: A list of check identifiers.
Returns:
List of tabular check identifiers and list of computer vision
check identifiers.
"""
tabular_checks = list(
filter(
lambda check: DeepchecksValidationCheck.is_tabular_check(check),
check_list,
)
)
vision_checks = list(
filter(
lambda check: DeepchecksValidationCheck.is_vision_check(check),
check_list,
)
)
return tabular_checks, vision_checks
# flake8: noqa: C901
@classmethod
def _create_and_run_check_suite(
cls,
check_enum: Type[DeepchecksValidationCheck],
reference_dataset: Union[pd.DataFrame, DataLoader[Any]],
comparison_dataset: Optional[
Union[pd.DataFrame, DataLoader[Any]]
] = None,
model: Optional[Union[ClassifierMixin, Module]] = None,
check_list: Optional[Sequence[str]] = None,
dataset_kwargs: Dict[str, Any] = {},
check_kwargs: Dict[str, Dict[str, Any]] = {},
run_kwargs: Dict[str, Any] = {},
) -> SuiteResult:
"""Create and run a Deepchecks check suite corresponding to the input parameters.
This method contains generic logic common to all Deepchecks data
validator methods that validates the input arguments and uses them to
generate and run a Deepchecks check suite.
Args:
check_enum: ZenML enum type grouping together Deepchecks checks with
the same characteristics. This is used to generate a default
list of checks, if a custom list isn't provided via the
`check_list` argument.
reference_dataset: Primary (reference) dataset argument used during
validation.
comparison_dataset: Optional secondary (comparison) dataset argument
used during comparison checks.
model: Optional model argument used during validation.
check_list: Optional list of ZenML Deepchecks check identifiers
specifying the list of Deepchecks checks to be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks tabular.Dataset or vision.VisionData constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
Returns:
Deepchecks SuiteResult object with the Suite run results.
Raises:
TypeError: If the datasets, model and check list arguments combine
data types and/or checks from different categories (tabular and
computer vision).
"""
# Detect what type of check to perform (tabular or computer vision) from
# the dataset/model datatypes and the check list. At the same time,
# validate the combination of data types used for dataset and model
# arguments and the check list.
is_tabular = False
is_vision = False
for dataset in [reference_dataset, comparison_dataset]:
if dataset is None:
continue
if isinstance(dataset, pd.DataFrame):
is_tabular = True
elif isinstance(dataset, DataLoader):
is_vision = True
else:
raise TypeError(
f"Unsupported dataset data type found: {type(dataset)}. "
f"Supported data types are {str(pd.DataFrame)} for tabular "
f"data and {str(DataLoader)} for computer vision data."
)
if model:
if isinstance(model, ClassifierMixin):
is_tabular = True
elif isinstance(model, Module):
is_vision = True
else:
raise TypeError(
f"Unsupported model data type found: {type(model)}. "
f"Supported data types are {str(ClassifierMixin)} for "
f"tabular data and {str(Module)} for computer vision "
f"data."
)
if is_tabular and is_vision:
raise TypeError(
f"Tabular and computer vision data types used for datasets and "
f"models cannot be mixed. They must all belong to the same "
f"category. Supported data types for tabular data are "
f"{str(pd.DataFrame)} for datasets and {str(ClassifierMixin)} "
f"for models. Supported data types for computer vision data "
f"are {str(pd.DataFrame)} for datasets and and {str(Module)} "
f"for models."
)
if not check_list:
# default to executing all the checks listed in the supplied
# checks enum type if a custom check list is not supplied
tabular_checks, vision_checks = cls._split_checks(
check_enum.values()
)
if is_tabular:
check_list = tabular_checks
vision_checks = []
else:
check_list = vision_checks
tabular_checks = []
else:
tabular_checks, vision_checks = cls._split_checks(check_list)
if tabular_checks and vision_checks:
raise TypeError(
f"The check list cannot mix tabular checks "
f"({tabular_checks}) and computer vision checks ("
f"{vision_checks})."
)
if is_tabular and vision_checks:
raise TypeError(
f"Tabular data types used for datasets and models can only "
f"be used with tabular validation checks. The following "
f"computer vision checks included in the check list are "
f"not valid: {vision_checks}."
)
if is_vision and tabular_checks:
raise TypeError(
f"Computer vision data types used for datasets and models "
f"can only be used with computer vision validation checks. "
f"The following tabular checks included in the check list "
f"are not valid: {tabular_checks}."
)
check_classes = map(
lambda check: (
check,
check_enum.get_check_class(check),
),
check_list,
)
# use the pipeline name and the step name to generate a unique suite
# name
try:
# get pipeline name and step name
step_env = cast(
StepEnvironment, Environment()[STEP_ENVIRONMENT_NAME]
)
suite_name = f"{step_env.pipeline_name}_{step_env.step_name}"
except KeyError:
# if not running inside a pipeline step, use random values
suite_name = f"suite_{random_str(5)}"
if is_tabular:
dataset_class = TabularData
suite_class = TabularSuite
full_suite = full_tabular_suite()
else:
dataset_class = VisionData
suite_class = VisionSuite
full_suite = full_vision_suite()
train_dataset = dataset_class(reference_dataset, **dataset_kwargs)
test_dataset = None
if comparison_dataset is not None:
test_dataset = dataset_class(comparison_dataset, **dataset_kwargs)
suite = suite_class(name=suite_name)
# Some Deepchecks checks require a minimum configuration such as
# conditions to be configured (see https://docs.deepchecks.com/stable/user-guide/general/customizations/examples/plot_configure_check_conditions.html#sphx-glr-user-guide-general-customizations-examples-plot-configure-check-conditions-py)
# for their execution to have meaning. For checks that don't have
# custom configuration attributes explicitly specified in the
# `check_kwargs` input parameter, we use the default check
# instances extracted from the full suite shipped with Deepchecks.
default_checks = {
check.__class__: check for check in full_suite.checks.values()
}
for check_name, check_class in check_classes:
extra_kwargs = check_kwargs.get(check_name, {})
default_check = default_checks.get(check_class)
check: BaseCheck
if extra_kwargs or not default_check:
check = check_class(**check_kwargs)
else:
check = default_check
# extract the condition kwargs from the check kwargs
for arg_name, condition_kwargs in extra_kwargs.items():
if not arg_name.startswith("condition_") or not isinstance(
condition_kwargs, dict
):
continue
condition_method = getattr(check, f"add_{arg_name}", None)
if not condition_method or not callable(condition_method):
logger.warning(
f"Deepchecks check type {check.__class__} has no "
f"condition named {arg_name}. Ignoring the check "
f"argument."
)
continue
condition_method(**condition_kwargs)
suite.add(check)
return suite.run(
train_dataset=train_dataset,
test_dataset=test_dataset,
model=model,
**run_kwargs,
)
def data_validation(
self,
dataset: Union[pd.DataFrame, DataLoader[Any]],
comparison_dataset: Optional[Any] = None,
check_list: Optional[Sequence[str]] = None,
dataset_kwargs: Dict[str, Any] = {},
check_kwargs: Dict[str, Dict[str, Any]] = {},
run_kwargs: Dict[str, Any] = {},
**kwargs: Any,
) -> SuiteResult:
"""Run one or more Deepchecks data validation checks on a dataset.
Call this method to analyze and identify potential integrity problems
with a single dataset (e.g. missing values, conflicting labels, mixed
data types etc.) and dataset comparison checks (e.g. data drift
checks). Dataset comparison checks require that a second dataset be
supplied via the `comparison_dataset` argument.
The `check_list` argument may be used to specify a custom set of
Deepchecks data integrity checks to perform, identified by
`DeepchecksDataIntegrityCheck` and `DeepchecksDataDriftCheck` enum
values. If omitted:
* if the `comparison_dataset` is omitted, a suite with all available
data integrity checks will be performed on the input data. See
`DeepchecksDataIntegrityCheck` for a list of Deepchecks builtin
checks that are compatible with this method.
* if the `comparison_dataset` is supplied, a suite with all
available data drift checks will be performed on the input
data. See `DeepchecksDataDriftCheck` for a list of Deepchecks
builtin checks that are compatible with this method.
Args:
dataset: Target dataset to be validated.
comparison_dataset: Optional second dataset to be used for data
comparison checks (e.g data drift checks).
check_list: Optional list of ZenML Deepchecks check identifiers
specifying the data validation checks to be performed.
`DeepchecksDataIntegrityCheck` enum values should be used for
single data validation checks and `DeepchecksDataDriftCheck`
enum values for data comparison checks. If not supplied, the
entire set of checks applicable to the input dataset(s)
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
kwargs: Additional keyword arguments (unused).
Returns:
A Deepchecks SuiteResult with the results of the validation.
"""
check_enum: Type[DeepchecksValidationCheck]
if comparison_dataset is None:
check_enum = DeepchecksDataIntegrityCheck
else:
check_enum = DeepchecksDataDriftCheck
return self._create_and_run_check_suite(
check_enum=check_enum,
reference_dataset=dataset,
comparison_dataset=comparison_dataset,
check_list=check_list,
dataset_kwargs=dataset_kwargs,
check_kwargs=check_kwargs,
run_kwargs=run_kwargs,
)
def model_validation(
self,
dataset: Union[pd.DataFrame, DataLoader[Any]],
model: Union[ClassifierMixin, Module],
comparison_dataset: Optional[Any] = None,
check_list: Optional[Sequence[str]] = None,
dataset_kwargs: Dict[str, Any] = {},
check_kwargs: Dict[str, Dict[str, Any]] = {},
run_kwargs: Dict[str, Any] = {},
**kwargs: Any,
) -> Any:
"""Run one or more Deepchecks model validation checks.
Call this method to perform model validation checks (e.g. confusion
matrix validation, performance reports, model error analyses, etc).
A second dataset is required for model performance comparison tests
(i.e. tests that identify changes in a model behavior by comparing how
it performs on two different datasets).
The `check_list` argument may be used to specify a custom set of
Deepchecks model validation checks to perform, identified by
`DeepchecksModelValidationCheck` and `DeepchecksModelDriftCheck` enum
values. If omitted:
* if the `comparison_dataset` is omitted, a suite with all available
model validation checks will be performed on the input data. See
`DeepchecksModelValidationCheck` for a list of Deepchecks builtin
checks that are compatible with this method.
* if the `comparison_dataset` is supplied, a suite with all
available model comparison checks will be performed on the input
data. See `DeepchecksModelValidationCheck` for a list of Deepchecks
builtin checks that are compatible with this method.
Args:
dataset: Target dataset to be validated.
model: Target model to be validated.
comparison_dataset: Optional second dataset to be used for model
comparison checks.
check_list: Optional list of ZenML Deepchecks check identifiers
specifying the model validation checks to be performed.
`DeepchecksModelValidationCheck` enum values should be used for
model validation checks and `DeepchecksModelDriftCheck` enum
values for model comparison checks. If not supplied, the
entire set of checks applicable to the input dataset(s)
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks tabular.Dataset or vision.VisionData constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
kwargs: Additional keyword arguments (unused).
Returns:
A Deepchecks SuiteResult with the results of the validation.
"""
check_enum: Type[DeepchecksValidationCheck]
if comparison_dataset is None:
check_enum = DeepchecksModelValidationCheck
else:
check_enum = DeepchecksModelDriftCheck
return self._create_and_run_check_suite(
check_enum=check_enum,
reference_dataset=dataset,
comparison_dataset=comparison_dataset,
model=model,
check_list=check_list,
dataset_kwargs=dataset_kwargs,
check_kwargs=check_kwargs,
run_kwargs=run_kwargs,
)
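To make the `check_kwargs` grouping and the `condition_*` convention used in `_create_and_run_check_suite` above more concrete, here is a hedged sketch; the check identifier, constructor argument and condition name are hypothetical placeholders, not real Deepchecks names.

```python
# Hedged sketch of the check_kwargs structure expected by the data validator.
# All identifiers below are hypothetical placeholders.
check_kwargs = {
    "SOME_DATA_INTEGRITY_CHECK": {          # key: check enum value or class name
        "some_constructor_arg": 0.01,       # forwarded to the check constructor
        # A nested dict whose key starts with "condition_" triggers a call to
        # add_condition_<rest>(**kwargs) on the check instance (see the loop
        # over extra_kwargs in the source above).
        "condition_some_threshold": {"max_ratio": 0.007},
    },
}
```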
data_validation(self, dataset, comparison_dataset=None, check_list=None, dataset_kwargs={}, check_kwargs={}, run_kwargs={}, **kwargs)
Run one or more Deepchecks data validation checks on a dataset.
Call this method to analyze and identify potential integrity problems
with a single dataset (e.g. missing values, conflicting labels, mixed
data types etc.) and dataset comparison checks (e.g. data drift
checks). Dataset comparison checks require that a second dataset be
supplied via the `comparison_dataset` argument.
The `check_list` argument may be used to specify a custom set of
Deepchecks data integrity checks to perform, identified by
`DeepchecksDataIntegrityCheck` and `DeepchecksDataDriftCheck` enum
values. If omitted:
* if the `comparison_dataset` is omitted, a suite with all available
  data integrity checks will be performed on the input data. See
  `DeepchecksDataIntegrityCheck` for a list of Deepchecks builtin
  checks that are compatible with this method.
* if the `comparison_dataset` is supplied, a suite with all available
  data drift checks will be performed on the input data. See
  `DeepchecksDataDriftCheck` for a list of Deepchecks builtin checks
  that are compatible with this method.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataset | Union[pandas.core.frame.DataFrame, torch.utils.data.dataloader.DataLoader[Any]] | Target dataset to be validated. | required |
| comparison_dataset | Optional[Any] | Optional second dataset to be used for data comparison checks (e.g. data drift checks). | None |
| check_list | Optional[Sequence[str]] | Optional list of ZenML Deepchecks check identifiers specifying the data validation checks to be performed. `DeepchecksDataIntegrityCheck` enum values should be used for single data validation checks and `DeepchecksDataDriftCheck` enum values for data comparison checks. If not supplied, the entire set of checks applicable to the input dataset(s) will be performed. | None |
| dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. | {} |
| check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. | {} |
| run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. | {} |
| kwargs | Any | Additional keyword arguments (unused). | {} |
Returns:

| Type | Description |
|---|---|
| SuiteResult | A Deepchecks SuiteResult with the results of the validation. |
Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
def data_validation(
self,
dataset: Union[pd.DataFrame, DataLoader[Any]],
comparison_dataset: Optional[Any] = None,
check_list: Optional[Sequence[str]] = None,
dataset_kwargs: Dict[str, Any] = {},
check_kwargs: Dict[str, Dict[str, Any]] = {},
run_kwargs: Dict[str, Any] = {},
**kwargs: Any,
) -> SuiteResult:
"""Run one or more Deepchecks data validation checks on a dataset.
Call this method to analyze and identify potential integrity problems
with a single dataset (e.g. missing values, conflicting labels, mixed
data types etc.) and dataset comparison checks (e.g. data drift
checks). Dataset comparison checks require that a second dataset be
supplied via the `comparison_dataset` argument.
The `check_list` argument may be used to specify a custom set of
Deepchecks data integrity checks to perform, identified by
`DeepchecksDataIntegrityCheck` and `DeepchecksDataDriftCheck` enum
values. If omitted:
* if the `comparison_dataset` is omitted, a suite with all available
data integrity checks will be performed on the input data. See
`DeepchecksDataIntegrityCheck` for a list of Deepchecks builtin
checks that are compatible with this method.
* if the `comparison_dataset` is supplied, a suite with all
available data drift checks will be performed on the input
data. See `DeepchecksDataDriftCheck` for a list of Deepchecks
builtin checks that are compatible with this method.
Args:
dataset: Target dataset to be validated.
comparison_dataset: Optional second dataset to be used for data
comparison checks (e.g data drift checks).
check_list: Optional list of ZenML Deepchecks check identifiers
specifying the data validation checks to be performed.
`DeepchecksDataIntegrityCheck` enum values should be used for
single data validation checks and `DeepchecksDataDriftCheck`
enum values for data comparison checks. If not supplied, the
entire set of checks applicable to the input dataset(s)
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
kwargs: Additional keyword arguments (unused).
Returns:
A Deepchecks SuiteResult with the results of the validation.
"""
check_enum: Type[DeepchecksValidationCheck]
if comparison_dataset is None:
check_enum = DeepchecksDataIntegrityCheck
else:
check_enum = DeepchecksDataDriftCheck
return self._create_and_run_check_suite(
check_enum=check_enum,
reference_dataset=dataset,
comparison_dataset=comparison_dataset,
check_list=check_list,
dataset_kwargs=dataset_kwargs,
check_kwargs=check_kwargs,
run_kwargs=run_kwargs,
)
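As a quick illustration of the two modes described above, here is a hedged sketch of calling the data validator directly; it assumes a Deepchecks data validator is registered in the active stack and uses toy data.

```python
# Hedged sketch, assuming a Deepchecks data validator in the active stack.
import pandas as pd

from zenml.integrations.deepchecks.data_validators import (
    DeepchecksDataValidator,
)

reference = pd.DataFrame({"value": [1.0, 2.0, 3.0], "label": [0, 1, 0]})
current = pd.DataFrame({"value": [1.1, 2.2, 9.9], "label": [0, 1, 1]})

data_validator = DeepchecksDataValidator.get_active_data_validator()

# Single dataset: the full data integrity suite runs.
integrity_results = data_validator.data_validation(dataset=reference)

# With a comparison dataset: the full data drift suite runs instead.
drift_results = data_validator.data_validation(
    dataset=reference,
    comparison_dataset=current,
)
```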
model_validation(self, dataset, model, comparison_dataset=None, check_list=None, dataset_kwargs={}, check_kwargs={}, run_kwargs={}, **kwargs)
Run one or more Deepchecks model validation checks.
Call this method to perform model validation checks (e.g. confusion matrix validation, performance reports, model error analyses, etc.). A second dataset is required for model performance comparison tests (i.e. tests that identify changes in a model's behavior by comparing how it performs on two different datasets).
The `check_list` argument may be used to specify a custom set of
Deepchecks model validation checks to perform, identified by
`DeepchecksModelValidationCheck` and `DeepchecksModelDriftCheck` enum
values. If omitted:
* if the `comparison_dataset` is omitted, a suite with all available
  model validation checks will be performed on the input data. See
  `DeepchecksModelValidationCheck` for a list of Deepchecks builtin
  checks that are compatible with this method.
* if the `comparison_dataset` is supplied, a suite with all available
  model comparison checks will be performed on the input data. See
  `DeepchecksModelDriftCheck` for a list of Deepchecks builtin checks
  that are compatible with this method.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataset | Union[pandas.core.frame.DataFrame, torch.utils.data.dataloader.DataLoader[Any]] | Target dataset to be validated. | required |
| model | Union[sklearn.base.ClassifierMixin, torch.nn.modules.module.Module] | Target model to be validated. | required |
| comparison_dataset | Optional[Any] | Optional second dataset to be used for model comparison checks. | None |
| check_list | Optional[Sequence[str]] | Optional list of ZenML Deepchecks check identifiers specifying the model validation checks to be performed. `DeepchecksModelValidationCheck` enum values should be used for model validation checks and `DeepchecksModelDriftCheck` enum values for model comparison checks. If not supplied, the entire set of checks applicable to the input dataset(s) will be performed. | None |
| dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. | {} |
| check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. | {} |
| run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. | {} |
| kwargs | Any | Additional keyword arguments (unused). | {} |
Returns:

| Type | Description |
|---|---|
| Any | A Deepchecks SuiteResult with the results of the validation. |
Source code in zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py
def model_validation(
self,
dataset: Union[pd.DataFrame, DataLoader[Any]],
model: Union[ClassifierMixin, Module],
comparison_dataset: Optional[Any] = None,
check_list: Optional[Sequence[str]] = None,
dataset_kwargs: Dict[str, Any] = {},
check_kwargs: Dict[str, Dict[str, Any]] = {},
run_kwargs: Dict[str, Any] = {},
**kwargs: Any,
) -> Any:
"""Run one or more Deepchecks model validation checks.
Call this method to perform model validation checks (e.g. confusion
matrix validation, performance reports, model error analyses, etc).
A second dataset is required for model performance comparison tests
(i.e. tests that identify changes in a model behavior by comparing how
it performs on two different datasets).
The `check_list` argument may be used to specify a custom set of
Deepchecks model validation checks to perform, identified by
`DeepchecksModelValidationCheck` and `DeepchecksModelDriftCheck` enum
values. If omitted:
* if the `comparison_dataset` is omitted, a suite with all available
model validation checks will be performed on the input data. See
`DeepchecksModelValidationCheck` for a list of Deepchecks builtin
checks that are compatible with this method.
* if the `comparison_dataset` is supplied, a suite with all
available model comparison checks will be performed on the input
data. See `DeepchecksModelValidationCheck` for a list of Deepchecks
builtin checks that are compatible with this method.
Args:
dataset: Target dataset to be validated.
model: Target model to be validated.
comparison_dataset: Optional second dataset to be used for model
comparison checks.
check_list: Optional list of ZenML Deepchecks check identifiers
specifying the model validation checks to be performed.
`DeepchecksModelValidationCheck` enum values should be used for
model validation checks and `DeepchecksModelDriftCheck` enum
values for model comparison checks. If not supplied, the
entire set of checks applicable to the input dataset(s)
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks tabular.Dataset or vision.VisionData constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
kwargs: Additional keyword arguments (unused).
Returns:
A Deepchecks SuiteResult with the results of the validation.
"""
check_enum: Type[DeepchecksValidationCheck]
if comparison_dataset is None:
check_enum = DeepchecksModelValidationCheck
else:
check_enum = DeepchecksModelDriftCheck
return self._create_and_run_check_suite(
check_enum=check_enum,
reference_dataset=dataset,
comparison_dataset=comparison_dataset,
model=model,
check_list=check_list,
dataset_kwargs=dataset_kwargs,
check_kwargs=check_kwargs,
run_kwargs=run_kwargs,
)
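Similarly, here is a hedged sketch of a direct `model_validation` call with a fitted scikit-learn classifier; the data, the model, and the `label` dataset kwarg are illustrative assumptions.

```python
# Hedged sketch of model validation / model drift checks with a scikit-learn
# classifier. Data, model and the "label" dataset kwarg are illustrative.
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

from zenml.integrations.deepchecks.data_validators import (
    DeepchecksDataValidator,
)

train = pd.DataFrame({"x": [0, 1, 2, 3], "label": [0, 0, 1, 1]})
test = pd.DataFrame({"x": [0, 2, 3, 5], "label": [0, 1, 1, 1]})
model = DecisionTreeClassifier().fit(train[["x"]], train["label"])

data_validator = DeepchecksDataValidator.get_active_data_validator()

# Dataset + model: model validation suite.
validation_results = data_validator.model_validation(
    dataset=train,
    model=model,
    dataset_kwargs={"label": "label"},  # forwarded to deepchecks.tabular.Dataset
)

# Two datasets + model: model comparison (drift) suite.
drift_results = data_validator.model_validation(
    dataset=train,
    comparison_dataset=test,
    model=model,
    dataset_kwargs={"label": "label"},
)
```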
materializers
special
Deepchecks materializers.
deepchecks_dataset_materializer
Implementation of Deepchecks dataset materializer.
DeepchecksDatasetMaterializer (BaseMaterializer)
Materializer to read data to and from Deepchecks dataset.
Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
class DeepchecksDatasetMaterializer(BaseMaterializer):
"""Materializer to read data to and from Deepchecks dataset."""
ASSOCIATED_TYPES = (Dataset,)
ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)
def handle_input(self, data_type: Type[Any]) -> Dataset:
"""Reads pandas dataframes and creates deepchecks.Dataset from it.
Args:
data_type: The type of the data to read.
Returns:
A Deepchecks Dataset.
"""
super().handle_input(data_type)
# Outsource to pandas
pandas_materializer = PandasMaterializer(self.artifact)
df = pandas_materializer.handle_input(data_type)
# Recreate from pandas dataframe
return Dataset(df)
def handle_return(self, df: Dataset) -> None:
"""Serializes pandas dataframe within a Dataset object.
Args:
df: A deepchecks.Dataset object.
"""
super().handle_return(df)
# Outsource to pandas
pandas_materializer = PandasMaterializer(self.artifact)
pandas_materializer.handle_return(df.data)
handle_input(self, data_type)
Reads a pandas dataframe and creates a deepchecks.Dataset from it.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data_type | Type[Any] | The type of the data to read. | required |

Returns:

| Type | Description |
|---|---|
| Dataset | A Deepchecks Dataset. |
Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
def handle_input(self, data_type: Type[Any]) -> Dataset:
"""Reads pandas dataframes and creates deepchecks.Dataset from it.
Args:
data_type: The type of the data to read.
Returns:
A Deepchecks Dataset.
"""
super().handle_input(data_type)
# Outsource to pandas
pandas_materializer = PandasMaterializer(self.artifact)
df = pandas_materializer.handle_input(data_type)
# Recreate from pandas dataframe
return Dataset(df)
handle_return(self, df)
Serializes the pandas dataframe contained in a Dataset object.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| df | Dataset | A deepchecks.Dataset object. | required |
Source code in zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py
def handle_return(self, df: Dataset) -> None:
"""Serializes pandas dataframe within a Dataset object.
Args:
df: A deepchecks.Dataset object.
"""
super().handle_return(df)
# Outsource to pandas
pandas_materializer = PandasMaterializer(self.artifact)
pandas_materializer.handle_return(df.data)
deepchecks_results_materializer
Implementation of Deepchecks suite results materializer.
DeepchecksResultMaterializer (BaseMaterializer)
Materializer to read data to and from CheckResult and SuiteResult objects.
Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
class DeepchecksResultMaterializer(BaseMaterializer):
"""Materializer to read data to and from CheckResult and SuiteResult objects."""
ASSOCIATED_TYPES = (
CheckResult,
SuiteResult,
)
ASSOCIATED_ARTIFACT_TYPES = (DataAnalysisArtifact,)
def handle_input(
self, data_type: Type[Any]
) -> Union[CheckResult, SuiteResult]:
"""Reads a Deepchecks check or suite result from a serialized JSON file.
Args:
data_type: The type of the data to read.
Returns:
A Deepchecks CheckResult or SuiteResult.
Raises:
RuntimeError: if the input data type is not supported.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)
json_res = io_utils.read_file_contents_as_string(filepath)
if data_type == SuiteResult:
res = SuiteResult.from_json(json_res)
elif data_type == CheckResult:
res = CheckResult.from_json(json_res)
else:
raise RuntimeError(f"Unknown data type: {data_type}")
return res
def handle_return(self, result: Union[CheckResult, SuiteResult]) -> None:
"""Creates a JSON serialization for a CheckResult or SuiteResult.
Args:
result: A Deepchecks CheckResult or SuiteResult.
"""
super().handle_return(result)
filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)
serialized_json = result.to_json(True)
io_utils.write_file_contents_as_string(filepath, serialized_json)
handle_input(self, data_type)
Reads a Deepchecks check or suite result from a serialized JSON file.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data_type | Type[Any] | The type of the data to read. | required |

Returns:

| Type | Description |
|---|---|
| Union[deepchecks.core.check_result.CheckResult, deepchecks.core.suite.SuiteResult] | A Deepchecks CheckResult or SuiteResult. |

Exceptions:

| Type | Description |
|---|---|
| RuntimeError | if the input data type is not supported. |
Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
def handle_input(
self, data_type: Type[Any]
) -> Union[CheckResult, SuiteResult]:
"""Reads a Deepchecks check or suite result from a serialized JSON file.
Args:
data_type: The type of the data to read.
Returns:
A Deepchecks CheckResult or SuiteResult.
Raises:
RuntimeError: if the input data type is not supported.
"""
super().handle_input(data_type)
filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)
json_res = io_utils.read_file_contents_as_string(filepath)
if data_type == SuiteResult:
res = SuiteResult.from_json(json_res)
elif data_type == CheckResult:
res = CheckResult.from_json(json_res)
else:
raise RuntimeError(f"Unknown data type: {data_type}")
return res
handle_return(self, result)
Creates a JSON serialization for a CheckResult or SuiteResult.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| result | Union[deepchecks.core.check_result.CheckResult, deepchecks.core.suite.SuiteResult] | A Deepchecks CheckResult or SuiteResult. | required |
Source code in zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py
def handle_return(self, result: Union[CheckResult, SuiteResult]) -> None:
"""Creates a JSON serialization for a CheckResult or SuiteResult.
Args:
result: A Deepchecks CheckResult or SuiteResult.
"""
super().handle_return(result)
filepath = os.path.join(self.artifact.uri, RESULTS_FILENAME)
serialized_json = result.to_json(True)
io_utils.write_file_contents_as_string(filepath, serialized_json)
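The materializer above delegates serialization entirely to Deepchecks' own JSON support; here is a minimal sketch of the same round trip outside ZenML, assuming the tabular `full_suite` is available in your installed Deepchecks version.

```python
# Hedged sketch of the to_json / from_json round trip the materializer relies on.
import pandas as pd
from deepchecks.core.suite import SuiteResult
from deepchecks.tabular import Dataset
from deepchecks.tabular.suites import full_suite

df = pd.DataFrame({"a": [1, 2, 3, None], "b": [0, 1, 1, 0]})

result = full_suite().run(train_dataset=Dataset(df, label="b"))
serialized = result.to_json(True)             # what handle_return writes to the artifact
restored = SuiteResult.from_json(serialized)  # what handle_input reconstructs
```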
steps
special
Initialization of the Deepchecks Standard Steps.
deepchecks_data_drift
Implementation of the Deepchecks data drift validation step.
DeepchecksDataDriftCheckStep (BaseStep)
Deepchecks data drift validator step.
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStep(BaseStep):
"""Deepchecks data drift validator step."""
def entrypoint( # type: ignore[override]
self,
reference_dataset: pd.DataFrame,
target_dataset: pd.DataFrame,
config: DeepchecksDataDriftCheckStepConfig,
) -> SuiteResult:
"""Main entrypoint for the Deepchecks data drift validator step.
Args:
reference_dataset: Reference dataset for the data drift check.
target_dataset: Target dataset to be used for the data drift check.
config: the configuration for the step
Returns:
A Deepchecks suite result with the validation results.
"""
data_validator = cast(
DeepchecksDataValidator,
DeepchecksDataValidator.get_active_data_validator(),
)
return data_validator.data_validation(
dataset=reference_dataset,
comparison_dataset=target_dataset,
check_list=cast(Optional[Sequence[str]], config.check_list),
dataset_kwargs=config.dataset_kwargs,
check_kwargs=config.check_kwargs,
run_kwargs=config.run_kwargs,
)
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Config class for the Deepchecks data drift validator step.
Attributes:

| Name | Type | Description |
|---|---|---|
| check_list | Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataDriftCheck]] | Optional list of DeepchecksDataDriftCheck identifiers specifying the subset of Deepchecks data drift checks to be performed. If not supplied, the entire set of data drift checks will be performed. |
| dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
| check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
| run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStepConfig(BaseStepConfig):
"""Config class for the Deepchecks data drift validator step.
Attributes:
check_list: Optional list of DeepchecksDataDriftCheck identifiers
specifying the subset of Deepchecks data drift checks to be
performed. If not supplied, the entire set of data drift checks will
be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
check_list: Optional[Sequence[DeepchecksDataDriftCheck]] = None
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, reference_dataset, target_dataset, config)
Main entrypoint for the Deepchecks data drift validator step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| reference_dataset | DataFrame | Reference dataset for the data drift check. | required |
| target_dataset | DataFrame | Target dataset to be used for the data drift check. | required |
| config | DeepchecksDataDriftCheckStepConfig | the configuration for the step | required |

Returns:

| Type | Description |
|---|---|
| SuiteResult | A Deepchecks suite result with the validation results. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
def entrypoint( # type: ignore[override]
self,
reference_dataset: pd.DataFrame,
target_dataset: pd.DataFrame,
config: DeepchecksDataDriftCheckStepConfig,
) -> SuiteResult:
"""Main entrypoint for the Deepchecks data drift validator step.
Args:
reference_dataset: Reference dataset for the data drift check.
target_dataset: Target dataset to be used for the data drift check.
config: the configuration for the step
Returns:
A Deepchecks suite result with the validation results.
"""
data_validator = cast(
DeepchecksDataValidator,
DeepchecksDataValidator.get_active_data_validator(),
)
return data_validator.data_validation(
dataset=reference_dataset,
comparison_dataset=target_dataset,
check_list=cast(Optional[Sequence[str]], config.check_list),
dataset_kwargs=config.dataset_kwargs,
check_kwargs=config.check_kwargs,
run_kwargs=config.run_kwargs,
)
DeepchecksDataDriftCheckStepConfig (BaseStepConfig)
pydantic-model
Config class for the Deepchecks data drift validator step.
Attributes:

| Name | Type | Description |
|---|---|---|
| check_list | Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataDriftCheck]] | Optional list of DeepchecksDataDriftCheck identifiers specifying the subset of Deepchecks data drift checks to be performed. If not supplied, the entire set of data drift checks will be performed. |
| dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
| check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
| run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
class DeepchecksDataDriftCheckStepConfig(BaseStepConfig):
"""Config class for the Deepchecks data drift validator step.
Attributes:
check_list: Optional list of DeepchecksDataDriftCheck identifiers
specifying the subset of Deepchecks data drift checks to be
performed. If not supplied, the entire set of data drift checks will
be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
check_list: Optional[Sequence[DeepchecksDataDriftCheck]] = None
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_data_drift_check_step(step_name, config)
Shortcut function to create a new instance of the DeepchecksDataDriftCheckStep step.
The returned DeepchecksDataDriftCheckStep can be used in a pipeline to run data drift checks on two input pd.DataFrame datasets and return the results as a Deepchecks SuiteResult object.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| step_name | str | The name of the step | required |
| config | DeepchecksDataDriftCheckStepConfig | The configuration for the step | required |

Returns:

| Type | Description |
|---|---|
| BaseStep | a DeepchecksDataDriftCheckStep step instance |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_drift.py
def deepchecks_data_drift_check_step(
step_name: str,
config: DeepchecksDataDriftCheckStepConfig,
) -> BaseStep:
"""Shortcut function to create a new instance of the DeepchecksDataDriftCheckStep step.
The returned DeepchecksDataDriftCheckStep can be used in a pipeline to
run data drift checks on two input pd.DataFrame and return the results
as a Deepchecks SuiteResult object.
Args:
step_name: The name of the step
config: The configuration for the step
Returns:
a DeepchecksDataDriftCheckStep step instance
"""
return clone_step(DeepchecksDataDriftCheckStep, step_name)(config=config)
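A hedged sketch of wiring this shortcut step into a pipeline; the importer steps and the pipeline itself are placeholders, and it assumes the factory and config are re-exported from `zenml.integrations.deepchecks.steps` (otherwise import them from the module shown above).

```python
# Hedged sketch: a placeholder pipeline that feeds two DataFrames into the
# data drift check step created by the shortcut function.
from zenml.integrations.deepchecks.steps import (
    DeepchecksDataDriftCheckStepConfig,
    deepchecks_data_drift_check_step,
)
from zenml.pipelines import pipeline

drift_detector = deepchecks_data_drift_check_step(
    step_name="drift_detector",
    config=DeepchecksDataDriftCheckStepConfig(),
)

@pipeline
def drift_detection_pipeline(reference_importer, target_importer, drift_detector):
    """Placeholder pipeline: two importer steps feed the drift check step."""
    reference = reference_importer()
    target = target_importer()
    drift_detector(reference_dataset=reference, target_dataset=target)
```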
deepchecks_data_integrity
Implementation of the Deepchecks data integrity validation step.
DeepchecksDataIntegrityCheckStep (BaseStep)
Deepchecks data integrity validator step.
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStep(BaseStep):
"""Deepchecks data integrity validator step."""
def entrypoint( # type: ignore[override]
self,
dataset: pd.DataFrame,
config: DeepchecksDataIntegrityCheckStepConfig,
) -> SuiteResult:
"""Main entrypoint for the Deepchecks data integrity validator step.
Args:
dataset: a Pandas DataFrame to validate
config: the configuration for the step
Returns:
A Deepchecks suite result with the validation results.
"""
data_validator = cast(
DeepchecksDataValidator,
DeepchecksDataValidator.get_active_data_validator(),
)
return data_validator.data_validation(
dataset=dataset,
check_list=cast(Optional[Sequence[str]], config.check_list),
dataset_kwargs=config.dataset_kwargs,
check_kwargs=config.check_kwargs,
run_kwargs=config.run_kwargs,
)
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Config class for the Deepchecks data integrity validator step.
Attributes:

| Name | Type | Description |
|---|---|---|
| check_list | Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataIntegrityCheck]] | Optional list of DeepchecksDataIntegrityCheck identifiers specifying the subset of Deepchecks data integrity checks to be performed. If not supplied, the entire set of data integrity checks will be performed. |
| dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
| check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
| run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStepConfig(BaseStepConfig):
"""Config class for the Deepchecks data integrity validator step.
Attributes:
check_list: Optional list of DeepchecksDataIntegrityCheck identifiers
specifying the subset of Deepchecks data integrity checks to be
performed. If not supplied, the entire set of data integrity checks
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
check_list: Optional[Sequence[DeepchecksDataIntegrityCheck]] = None
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, dataset, config)
Main entrypoint for the Deepchecks data integrity validator step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataset | DataFrame | a Pandas DataFrame to validate | required |
| config | DeepchecksDataIntegrityCheckStepConfig | the configuration for the step | required |

Returns:

| Type | Description |
|---|---|
| SuiteResult | A Deepchecks suite result with the validation results. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
def entrypoint( # type: ignore[override]
self,
dataset: pd.DataFrame,
config: DeepchecksDataIntegrityCheckStepConfig,
) -> SuiteResult:
"""Main entrypoint for the Deepchecks data integrity validator step.
Args:
dataset: a Pandas DataFrame to validate
config: the configuration for the step
Returns:
A Deepchecks suite result with the validation results.
"""
data_validator = cast(
DeepchecksDataValidator,
DeepchecksDataValidator.get_active_data_validator(),
)
return data_validator.data_validation(
dataset=dataset,
check_list=cast(Optional[Sequence[str]], config.check_list),
dataset_kwargs=config.dataset_kwargs,
check_kwargs=config.check_kwargs,
run_kwargs=config.run_kwargs,
)
DeepchecksDataIntegrityCheckStepConfig (BaseStepConfig)
pydantic-model
Config class for the Deepchecks data integrity validator step.
Attributes:

| Name | Type | Description |
|---|---|---|
| check_list | Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksDataIntegrityCheck]] | Optional list of DeepchecksDataIntegrityCheck identifiers specifying the subset of Deepchecks data integrity checks to be performed. If not supplied, the entire set of data integrity checks will be performed. |
| dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
| check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
| run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
class DeepchecksDataIntegrityCheckStepConfig(BaseStepConfig):
"""Config class for the Deepchecks data integrity validator step.
Attributes:
check_list: Optional list of DeepchecksDataIntegrityCheck identifiers
specifying the subset of Deepchecks data integrity checks to be
performed. If not supplied, the entire set of data integrity checks
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
check_list: Optional[Sequence[DeepchecksDataIntegrityCheck]] = None
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_data_integrity_check_step(step_name, config)
Shortcut function to create a new instance of the DeepchecksDataIntegrityCheckStep step.
The returned DeepchecksDataIntegrityCheckStep can be used in a pipeline to run data integrity checks on an input pd.DataFrame and return the results as a Deepchecks SuiteResult object.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| step_name | str | The name of the step | required |
| config | DeepchecksDataIntegrityCheckStepConfig | The configuration for the step | required |

Returns:

| Type | Description |
|---|---|
| BaseStep | a DeepchecksDataIntegrityCheckStep step instance |
Source code in zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py
def deepchecks_data_integrity_check_step(
step_name: str,
config: DeepchecksDataIntegrityCheckStepConfig,
) -> BaseStep:
"""Shortcut function to create a new instance of the DeepchecksDataIntegrityCheckStep step.
The returned DeepchecksDataIntegrityCheckStep can be used in a pipeline to
run data integrity checks on an input pd.DataFrame and return the results
as a Deepchecks SuiteResult object.
Args:
step_name: The name of the step
config: The configuration for the step
Returns:
a DeepchecksDataIntegrityCheckStep step instance
"""
return clone_step(DeepchecksDataIntegrityCheckStep, step_name)(
config=config
)
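Configuring the integrity step follows the same pattern; here is a hedged sketch in which the `dataset_kwargs` value is an illustrative assumption forwarded to the Deepchecks `tabular.Dataset` constructor, and a `check_list` of `DeepchecksDataIntegrityCheck` enum values could be added to restrict the suite.

```python
# Hedged sketch: create the data integrity step with custom dataset kwargs.
from zenml.integrations.deepchecks.steps import (
    DeepchecksDataIntegrityCheckStepConfig,
    deepchecks_data_integrity_check_step,
)

data_integrity_check = deepchecks_data_integrity_check_step(
    step_name="data_integrity_check",
    config=DeepchecksDataIntegrityCheckStepConfig(
        # Forwarded to the deepchecks tabular.Dataset constructor; the column
        # name is an illustrative assumption.
        dataset_kwargs={"label": "label"},
    ),
)
```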
deepchecks_model_drift
Implementation of the Deepchecks model drift validation step.
DeepchecksModelDriftCheckStep (BaseStep)
Deepchecks model drift step.
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStep(BaseStep):
"""Deepchecks model drift step."""
def entrypoint( # type: ignore[override]
self,
reference_dataset: pd.DataFrame,
target_dataset: pd.DataFrame,
model: ClassifierMixin,
config: DeepchecksModelDriftCheckStepConfig,
) -> SuiteResult:
"""Main entrypoint for the Deepchecks model drift step.
Args:
reference_dataset: Reference dataset for the model drift check.
target_dataset: Target dataset to be used for the model drift check.
model: a scikit-learn model to validate
config: the configuration for the step
Returns:
A Deepchecks suite result with the validation results.
"""
data_validator = cast(
DeepchecksDataValidator,
DeepchecksDataValidator.get_active_data_validator(),
)
return data_validator.model_validation(
dataset=reference_dataset,
comparison_dataset=target_dataset,
model=model,
check_list=cast(Optional[Sequence[str]], config.check_list),
dataset_kwargs=config.dataset_kwargs,
check_kwargs=config.check_kwargs,
run_kwargs=config.run_kwargs,
)
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Config class for the Deepchecks model drift validator step.
Attributes:

| Name | Type | Description |
|---|---|---|
| check_list | Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelDriftCheck]] | Optional list of DeepchecksModelDriftCheck identifiers specifying the subset of Deepchecks model drift checks to be performed. If not supplied, the entire set of model drift checks will be performed. |
| dataset_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
| check_kwargs | Dict[str, Dict[str, Any]] | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
| run_kwargs | Dict[str, Any] | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStepConfig(BaseStepConfig):
"""Config class for the Deepchecks model drift validator step.
Attributes:
check_list: Optional list of DeepchecksModelDriftCheck identifiers
specifying the subset of Deepchecks model drift checks to be
performed. If not supplied, the entire set of model drift checks
will be performed.
dataset_kwargs: Additional keyword arguments to be passed to the
Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
check_kwargs: Additional keyword arguments to be passed to the
Deepchecks check object constructors. Arguments are grouped for
each check and indexed using the full check class name or
check enum value as dictionary keys.
run_kwargs: Additional keyword arguments to be passed to the
Deepchecks Suite `run` method.
"""
check_list: Optional[Sequence[DeepchecksModelDriftCheck]] = None
dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, reference_dataset, target_dataset, model, config)
Main entrypoint for the Deepchecks model drift step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| reference_dataset | DataFrame | Reference dataset for the model drift check. | required |
| target_dataset | DataFrame | Target dataset to be used for the model drift check. | required |
| model | ClassifierMixin | a scikit-learn model to validate | required |
| config | DeepchecksModelDriftCheckStepConfig | the configuration for the step | required |

Returns:

| Type | Description |
|---|---|
| SuiteResult | A Deepchecks suite result with the validation results. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
def entrypoint(  # type: ignore[override]
    self,
    reference_dataset: pd.DataFrame,
    target_dataset: pd.DataFrame,
    model: ClassifierMixin,
    config: DeepchecksModelDriftCheckStepConfig,
) -> SuiteResult:
    """Main entrypoint for the Deepchecks model drift step.

    Args:
        reference_dataset: Reference dataset for the model drift check.
        target_dataset: Target dataset to be used for the model drift check.
        model: a scikit-learn model to validate
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    data_validator = cast(
        DeepchecksDataValidator,
        DeepchecksDataValidator.get_active_data_validator(),
    )
    return data_validator.model_validation(
        dataset=reference_dataset,
        comparison_dataset=target_dataset,
        model=model,
        check_list=cast(Optional[Sequence[str]], config.check_list),
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
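The returned SuiteResult can then be consumed downstream; a minimal sketch, assuming a Deepchecks release where `save_as_html` and `passed` are available on `SuiteResult`:

from deepchecks.core import SuiteResult


def report_and_gate(suite_result: SuiteResult) -> None:
    # Write an HTML report to the current working directory (assumed to be
    # writable) and fail loudly if any check in the suite did not pass.
    suite_result.save_as_html("model_drift_report.html")
    if not suite_result.passed():
        raise RuntimeError("Deepchecks model drift suite reported failures.")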
DeepchecksModelDriftCheckStepConfig (BaseStepConfig)
pydantic-model
Config class for the Deepchecks model drift validator step.
Attributes:
| Name | Type | Description |
|---|---|---|
| `check_list` | `Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelDriftCheck]]` | Optional list of DeepchecksModelDriftCheck identifiers specifying the subset of Deepchecks model drift checks to be performed. If not supplied, the entire set of model drift checks will be performed. |
| `dataset_kwargs` | `Dict[str, Any]` | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
| `check_kwargs` | `Dict[str, Dict[str, Any]]` | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
| `run_kwargs` | `Dict[str, Any]` | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
class DeepchecksModelDriftCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks model drift validator step.

    Attributes:
        check_list: Optional list of DeepchecksModelDriftCheck identifiers
            specifying the subset of Deepchecks model drift checks to be
            performed. If not supplied, the entire set of model drift checks
            will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksModelDriftCheck]] = None
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_model_drift_check_step(step_name, config)
Shortcut function to create a new instance of the DeepchecksModelDriftCheckStep step.
The returned DeepchecksModelDriftCheckStep can be used in a pipeline to run model drift checks on two input pd.DataFrame datasets and an input scikit-learn ClassifierMixin model and return the results as a Deepchecks SuiteResult object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `step_name` | `str` | The name of the step | required |
| `config` | `DeepchecksModelDriftCheckStepConfig` | The configuration for the step | required |
Returns:
| Type | Description |
|---|---|
| `BaseStep` | a DeepchecksModelDriftCheckStep step instance |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_drift.py
def deepchecks_model_drift_check_step(
    step_name: str,
    config: DeepchecksModelDriftCheckStepConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the DeepchecksModelDriftCheckStep step.

    The returned DeepchecksModelDriftCheckStep can be used in a pipeline to
    run model drift checks on two input pd.DataFrame datasets and an input
    scikit-learn ClassifierMixin model and return the results as a Deepchecks
    SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksModelDriftCheckStep step instance
    """
    return clone_step(DeepchecksModelDriftCheckStep, step_name)(config=config)
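A minimal usage sketch follows; the step name, pipeline wiring, and upstream artifacts (train_df, test_df, trained_model) are hypothetical and assume a tabular scikit-learn workflow.

from zenml.integrations.deepchecks.steps.deepchecks_model_drift import (
    DeepchecksModelDriftCheckStepConfig,
    deepchecks_model_drift_check_step,
)

# Instantiate the step with the default (full) set of model drift checks.
model_drift_checker = deepchecks_model_drift_check_step(
    step_name="model_drift_checker",
    config=DeepchecksModelDriftCheckStepConfig(),
)

# Inside a @pipeline definition the step would then be wired up roughly as:
#     suite_result = model_drift_checker(
#         reference_dataset=train_df,  # pd.DataFrame from an upstream step
#         target_dataset=test_df,      # pd.DataFrame from an upstream step
#         model=trained_model,         # scikit-learn ClassifierMixin
#     )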
deepchecks_model_validation
Implementation of the Deepchecks model validation step.
DeepchecksModelValidationCheckStep (BaseStep)
Deepchecks model validation step.
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStep(BaseStep):
    """Deepchecks model validation step."""

    def entrypoint(  # type: ignore[override]
        self,
        dataset: pd.DataFrame,
        model: ClassifierMixin,
        config: DeepchecksModelValidationCheckStepConfig,
    ) -> SuiteResult:
        """Main entrypoint for the Deepchecks model validation step.

        Args:
            dataset: a Pandas DataFrame to use for the validation
            model: a scikit-learn model to validate
            config: the configuration for the step

        Returns:
            A Deepchecks suite result with the validation results.
        """
        data_validator = cast(
            DeepchecksDataValidator,
            DeepchecksDataValidator.get_active_data_validator(),
        )
        return data_validator.model_validation(
            dataset=dataset,
            model=model,
            check_list=cast(Optional[Sequence[str]], config.check_list),
            dataset_kwargs=config.dataset_kwargs,
            check_kwargs=config.check_kwargs,
            run_kwargs=config.run_kwargs,
        )
CONFIG_CLASS (BaseStepConfig)
pydantic-model
Config class for the Deepchecks model validation validator step.
Attributes:
| Name | Type | Description |
|---|---|---|
| `check_list` | `Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelValidationCheck]]` | Optional list of DeepchecksModelValidationCheck identifiers specifying the subset of Deepchecks model validation checks to be performed. If not supplied, the entire set of model validation checks will be performed. |
| `dataset_kwargs` | `Dict[str, Any]` | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
| `check_kwargs` | `Dict[str, Dict[str, Any]]` | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
| `run_kwargs` | `Dict[str, Any]` | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks model validation validator step.

    Attributes:
        check_list: Optional list of DeepchecksModelValidationCheck identifiers
            specifying the subset of Deepchecks model validation checks to be
            performed. If not supplied, the entire set of model validation
            checks will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksModelValidationCheck]] = None
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
entrypoint(self, dataset, model, config)
Main entrypoint for the Deepchecks model validation step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dataset` | `DataFrame` | a Pandas DataFrame to use for the validation | required |
| `model` | `ClassifierMixin` | a scikit-learn model to validate | required |
| `config` | `DeepchecksModelValidationCheckStepConfig` | the configuration for the step | required |
Returns:
| Type | Description |
|---|---|
| `SuiteResult` | A Deepchecks suite result with the validation results. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
def entrypoint(  # type: ignore[override]
    self,
    dataset: pd.DataFrame,
    model: ClassifierMixin,
    config: DeepchecksModelValidationCheckStepConfig,
) -> SuiteResult:
    """Main entrypoint for the Deepchecks model validation step.

    Args:
        dataset: a Pandas DataFrame to use for the validation
        model: a scikit-learn model to validate
        config: the configuration for the step

    Returns:
        A Deepchecks suite result with the validation results.
    """
    data_validator = cast(
        DeepchecksDataValidator,
        DeepchecksDataValidator.get_active_data_validator(),
    )
    return data_validator.model_validation(
        dataset=dataset,
        model=model,
        check_list=cast(Optional[Sequence[str]], config.check_list),
        dataset_kwargs=config.dataset_kwargs,
        check_kwargs=config.check_kwargs,
        run_kwargs=config.run_kwargs,
    )
DeepchecksModelValidationCheckStepConfig (BaseStepConfig)
pydantic-model
Config class for the Deepchecks model validation validator step.
Attributes:
| Name | Type | Description |
|---|---|---|
| `check_list` | `Optional[Sequence[zenml.integrations.deepchecks.validation_checks.DeepchecksModelValidationCheck]]` | Optional list of DeepchecksModelValidationCheck identifiers specifying the subset of Deepchecks model validation checks to be performed. If not supplied, the entire set of model validation checks will be performed. |
| `dataset_kwargs` | `Dict[str, Any]` | Additional keyword arguments to be passed to the Deepchecks `tabular.Dataset` or `vision.VisionData` constructor. |
| `check_kwargs` | `Dict[str, Dict[str, Any]]` | Additional keyword arguments to be passed to the Deepchecks check object constructors. Arguments are grouped for each check and indexed using the full check class name or check enum value as dictionary keys. |
| `run_kwargs` | `Dict[str, Any]` | Additional keyword arguments to be passed to the Deepchecks Suite `run` method. |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
class DeepchecksModelValidationCheckStepConfig(BaseStepConfig):
    """Config class for the Deepchecks model validation validator step.

    Attributes:
        check_list: Optional list of DeepchecksModelValidationCheck identifiers
            specifying the subset of Deepchecks model validation checks to be
            performed. If not supplied, the entire set of model validation
            checks will be performed.
        dataset_kwargs: Additional keyword arguments to be passed to the
            Deepchecks `tabular.Dataset` or `vision.VisionData` constructor.
        check_kwargs: Additional keyword arguments to be passed to the
            Deepchecks check object constructors. Arguments are grouped for
            each check and indexed using the full check class name or
            check enum value as dictionary keys.
        run_kwargs: Additional keyword arguments to be passed to the
            Deepchecks Suite `run` method.
    """

    check_list: Optional[Sequence[DeepchecksModelValidationCheck]] = None
    dataset_kwargs: Dict[str, Any] = Field(default_factory=dict)
    check_kwargs: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
    run_kwargs: Dict[str, Any] = Field(default_factory=dict)
deepchecks_model_validation_check_step(step_name, config)
Shortcut function to create a new instance of the DeepchecksModelValidationCheckStep step.
The returned DeepchecksModelValidationCheckStep can be used in a pipeline to run model validation checks on an input pd.DataFrame dataset and an input scikit-learn ClassifierMixin model and return the results as a Deepchecks SuiteResult object.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `step_name` | `str` | The name of the step | required |
| `config` | `DeepchecksModelValidationCheckStepConfig` | The configuration for the step | required |
Returns:
| Type | Description |
|---|---|
| `BaseStep` | a DeepchecksModelValidationCheckStep step instance |
Source code in zenml/integrations/deepchecks/steps/deepchecks_model_validation.py
def deepchecks_model_validation_check_step(
    step_name: str,
    config: DeepchecksModelValidationCheckStepConfig,
) -> BaseStep:
    """Shortcut function to create a new instance of the DeepchecksModelValidationCheckStep step.

    The returned DeepchecksModelValidationCheckStep can be used in a pipeline to
    run model validation checks on an input pd.DataFrame dataset and an input
    scikit-learn ClassifierMixin model and return the results as a Deepchecks
    SuiteResult object.

    Args:
        step_name: The name of the step
        config: The configuration for the step

    Returns:
        a DeepchecksModelValidationCheckStep step instance
    """
    return clone_step(DeepchecksModelValidationCheckStep, step_name)(
        config=config
    )
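A minimal usage sketch, again with hypothetical upstream artifacts providing the dataset and the trained model:

from zenml.integrations.deepchecks.steps.deepchecks_model_validation import (
    DeepchecksModelValidationCheckStepConfig,
    deepchecks_model_validation_check_step,
)

# With no check_list supplied, the full set of model validation checks runs.
model_validator = deepchecks_model_validation_check_step(
    step_name="model_validator",
    config=DeepchecksModelValidationCheckStepConfig(
        # Forwarded to the Deepchecks tabular.Dataset constructor; "target"
        # is an assumed label column name.
        dataset_kwargs={"label": "target"},
    ),
)

# Inside a @pipeline definition:
#     suite_result = model_validator(
#         dataset=test_df,      # pd.DataFrame from an upstream step
#         model=trained_model,  # scikit-learn ClassifierMixin
#     )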
validation_checks
Definition of the Deepchecks validation check types.
DeepchecksDataDriftCheck (DeepchecksValidationCheck)
Categories of Deepchecks data drift checks.
This list reflects the set of train-test validation checks provided by Deepchecks, both for tabular data and for computer vision. All these checks inherit from `deepchecks.tabular.TrainTestCheck` or `deepchecks.vision.TrainTestCheck` and require two datasets as input.
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksDataDriftCheck(DeepchecksValidationCheck):
    """Categories of Deepchecks data drift checks.

    This list reflects the set of train-test validation checks provided by
    Deepchecks:

    * [for tabular data](https://docs.deepchecks.com/stable/checks_gallery/tabular.html#train-test-validation)
    * [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#train-test-validation)

    All these checks inherit from `deepchecks.tabular.TrainTestCheck` or
    `deepchecks.vision.TrainTestCheck` and require two datasets as input.
    """

    TABULAR_CATEGORY_MISMATCH_TRAIN_TEST = resolve_class(
        tabular_checks.CategoryMismatchTrainTest
    )
    TABULAR_DATASET_SIZE_COMPARISON = resolve_class(
        tabular_checks.DatasetsSizeComparison
    )
    TABULAR_DATE_TRAIN_TEST_LEAKAGE_DUPLICATES = resolve_class(
        tabular_checks.DateTrainTestLeakageDuplicates
    )
    TABULAR_DATE_TRAIN_TEST_LEAKAGE_OVERLAP = resolve_class(
        tabular_checks.DateTrainTestLeakageOverlap
    )
    TABULAR_DOMINANT_FREQUENCY_CHANGE = resolve_class(
        tabular_checks.DominantFrequencyChange
    )
    TABULAR_FEATURE_LABEL_CORRELATION_CHANGE = resolve_class(
        tabular_checks.FeatureLabelCorrelationChange
    )
    TABULAR_INDEX_LEAKAGE = resolve_class(tabular_checks.IndexTrainTestLeakage)
    TABULAR_NEW_LABEL_TRAIN_TEST = resolve_class(
        tabular_checks.NewLabelTrainTest
    )
    TABULAR_STRING_MISMATCH_COMPARISON = resolve_class(
        tabular_checks.StringMismatchComparison
    )
    TABULAR_TRAIN_TEST_FEATURE_DRIFT = resolve_class(
        tabular_checks.TrainTestFeatureDrift
    )
    TABULAR_TRAIN_TEST_LABEL_DRIFT = resolve_class(
        tabular_checks.TrainTestLabelDrift
    )
    TABULAR_TRAIN_TEST_SAMPLES_MIX = resolve_class(
        tabular_checks.TrainTestSamplesMix
    )
    TABULAR_WHOLE_DATASET_DRIFT = resolve_class(
        tabular_checks.WholeDatasetDrift
    )
    VISION_FEATURE_LABEL_CORRELATION_CHANGE = resolve_class(
        vision_checks.FeatureLabelCorrelationChange
    )
    VISION_HEATMAP_COMPARISON = resolve_class(vision_checks.HeatmapComparison)
    VISION_IMAGE_DATASET_DRIFT = resolve_class(vision_checks.ImageDatasetDrift)
    VISION_IMAGE_PROPERTY_DRIFT = resolve_class(
        vision_checks.ImagePropertyDrift
    )
    VISION_NEW_LABELS = resolve_class(vision_checks.NewLabels)
    VISION_SIMILAR_IMAGE_LEAKAGE = resolve_class(
        vision_checks.SimilarImageLeakage
    )
    VISION_TRAIN_TEST_LABEL_DRIFT = resolve_class(
        vision_checks.TrainTestLabelDrift
    )
DeepchecksDataIntegrityCheck (DeepchecksValidationCheck)
Categories of Deepchecks data integrity checks.
This list reflects the set of data integrity checks provided by Deepchecks, both for tabular data and for computer vision. All these checks inherit from `deepchecks.tabular.SingleDatasetCheck` or `deepchecks.vision.SingleDatasetCheck` and require a single dataset as input.
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksDataIntegrityCheck(DeepchecksValidationCheck):
    """Categories of Deepchecks data integrity checks.

    This list reflects the set of data integrity checks provided by Deepchecks:

    * [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#data-integrity)
    * [for computer vision](https://docs.deepchecks.com/en/stable/checks_gallery/vision.html#data-integrity)

    All these checks inherit from `deepchecks.tabular.SingleDatasetCheck` or
    `deepchecks.vision.SingleDatasetCheck` and require a single dataset as input.
    """

    TABULAR_COLUMNS_INFO = resolve_class(tabular_checks.ColumnsInfo)
    TABULAR_CONFLICTING_LABELS = resolve_class(tabular_checks.ConflictingLabels)
    TABULAR_DATA_DUPLICATES = resolve_class(tabular_checks.DataDuplicates)
    TABULAR_FEATURE_FEATURE_CORRELATION = resolve_class(
        FeatureFeatureCorrelation
    )
    TABULAR_FEATURE_LABEL_CORRELATION = resolve_class(
        tabular_checks.FeatureLabelCorrelation
    )
    TABULAR_IDENTIFIER_LEAKAGE = resolve_class(tabular_checks.IdentifierLeakage)
    TABULAR_IS_SINGLE_VALUE = resolve_class(tabular_checks.IsSingleValue)
    TABULAR_MIXED_DATA_TYPES = resolve_class(tabular_checks.MixedDataTypes)
    TABULAR_MIXED_NULLS = resolve_class(tabular_checks.MixedNulls)
    TABULAR_OUTLIER_SAMPLE_DETECTION = resolve_class(
        tabular_checks.OutlierSampleDetection
    )
    TABULAR_SPECIAL_CHARS = resolve_class(tabular_checks.SpecialCharacters)
    TABULAR_STRING_LENGTH_OUT_OF_BOUNDS = resolve_class(
        tabular_checks.StringLengthOutOfBounds
    )
    TABULAR_STRING_MISMATCH = resolve_class(tabular_checks.StringMismatch)
    VISION_IMAGE_PROPERTY_OUTLIERS = resolve_class(
        vision_checks.ImagePropertyOutliers
    )
    VISION_LABEL_PROPERTY_OUTLIERS = resolve_class(
        vision_checks.LabelPropertyOutliers
    )
DeepchecksModelDriftCheck (DeepchecksValidationCheck)
Categories of Deepchecks model drift checks.
This list includes a subset of the model evaluation checks provided by Deepchecks, both for tabular data and for computer vision, that require two datasets and a mandatory model as input. All these checks inherit from `deepchecks.tabular.TrainTestCheck` or `deepchecks.vision.TrainTestCheck`.
Source code in zenml/integrations/deepchecks/validation_checks.py
class DeepchecksModelDriftCheck(DeepchecksValidationCheck):
    """Categories of Deepchecks model drift checks.

    This list includes a subset of the model evaluation checks provided by
    Deepchecks that require two datasets and a mandatory model as input:

    * [for tabular data](https://docs.deepchecks.com/en/stable/checks_gallery/tabular.html#model-evaluation)
    * [for computer vision](https://docs.deepchecks.com/stable/checks_gallery/vision.html#model-evaluation)

    All these checks inherit from `deepchecks.tabular.TrainTestCheck` or
    `deepchecks.vision.TrainTestCheck` and require two datasets and a mandatory
    model as input.
    """

    TABULAR_BOOSTING_OVERFIT = resolve_class(tabular_checks.BoostingOverfit)
    TABULAR_MODEL_ERROR_ANALYSIS = resolve_class(
        tabular_checks.ModelErrorAnalysis
    )
    TABULAR_PERFORMANCE_REPORT = resolve_class(tabular_checks.PerformanceReport)
    TABULAR_SIMPLE_MODEL_COMPARISON = resolve_class(
        tabular_checks.SimpleModelComparison
    )
    TABULAR_TRAIN_TEST_PREDICTION_DRIFT = resolve_class(
        tabular_checks.TrainTestPredictionDrift
    )
    TABULAR_UNUSED_FEATURES = resolve_class(tabular_checks.UnusedFeatures)
    VISION_CLASS_PERFORMANCE = resolve_class(vision_checks.ClassPerformance)
    VISION_MODEL_ERROR_ANALYSIS = resolve_class(
        vision_checks.ModelErrorAnalysis
    )
    VISION_SIMPLE_MODEL_COMPARISON = resolve_class(
        vision_checks.SimpleModelComparison
    )
    VISION_TRAIN_TEST_PREDICTION_DRIFT = resolve_class(
        vision_checks.TrainTestPredictionDrift
    )
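Because each enum member's value is the fully qualified path of the underlying Deepchecks check class, the members can both select checks and key per-check keyword arguments. A small sketch follows; the exact path printed depends on the installed Deepchecks version, and "n_samples" is only a placeholder kwarg.

from zenml.integrations.deepchecks.validation_checks import (
    DeepchecksModelDriftCheck,
)

check = DeepchecksModelDriftCheck.TABULAR_TRAIN_TEST_PREDICTION_DRIFT
# Prints the full Deepchecks check class path backing this enum member.
print(check.value)

# The same value can serve as a key in a step config's check_kwargs dict.
check_kwargs = {check.value: {"n_samples": 10_000}}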