Label Studio
zenml.integrations.label_studio
special
Initialization of the Label Studio integration.
LabelStudioIntegration (Integration)
Definition of Label Studio integration for ZenML.
Source code in zenml/integrations/label_studio/__init__.py
class LabelStudioIntegration(Integration):
    """ZenML integration that registers Label Studio as an annotator."""

    NAME = LABEL_STUDIO
    REQUIREMENTS = [
        "label-studio>=1.6.0,<=1.7.3",
        "label-studio-sdk>=0.0.17,<=0.0.24",
    ]

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Lists the stack component flavors provided by this integration.

        The flavor class is imported inside the method, so its module is
        only loaded when this method is called.

        Returns:
            A single-element list holding the Label Studio annotator flavor.
        """
        from zenml.integrations.label_studio.flavors import (
            LabelStudioAnnotatorFlavor as annotator_flavor,
        )

        supported: List[Type[Flavor]] = [annotator_flavor]
        return supported
flavors()
classmethod
Declare the stack component flavors for the Label Studio integration.
Returns:
| Type | Description |
|---|---|
| `List[Type[zenml.stack.flavor.Flavor]]` | List of stack component flavors for this integration. |
Source code in zenml/integrations/label_studio/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Lists the stack component flavors provided by this integration.

    The flavor class is imported inside the method, so its module is only
    loaded when this method is called.

    Returns:
        A single-element list holding the Label Studio annotator flavor.
    """
    from zenml.integrations.label_studio.flavors import (
        LabelStudioAnnotatorFlavor as annotator_flavor,
    )

    supported: List[Type[Flavor]] = [annotator_flavor]
    return supported
annotators
special
Initialization of the Label Studio annotators submodule.
label_studio_annotator
Implementation of the Label Studio annotation integration.
LabelStudioAnnotator (BaseAnnotator, AuthenticationMixin)
Class to interact with the Label Studio annotation interface.
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
class LabelStudioAnnotator(BaseAnnotator, AuthenticationMixin):
    """Class to interact with the Label Studio annotation interface."""

    @property
    def config(self) -> LabelStudioAnnotatorConfig:
        """Returns the `LabelStudioAnnotatorConfig` config.

        Returns:
            The configuration.
        """
        return cast(LabelStudioAnnotatorConfig, self._config)

    def get_url(self) -> str:
        """Gets the top-level URL of the annotation interface.

        Returns:
            The URL of the annotation interface, built as
            `<instance_url>:<port>` from the annotator config.
        """
        return f"{self.config.instance_url}:{self.config.port}"

    def get_url_for_dataset(self, dataset_name: str) -> str:
        """Gets the URL of the annotation interface for the given dataset.

        Args:
            dataset_name: The name of the dataset.

        Returns:
            The URL of the annotation interface.
        """
        # NOTE(review): `get_id_from_name` returns None for an unknown name,
        # in which case the URL ends in `/projects/None/`.
        project_id = self.get_id_from_name(dataset_name)
        return f"{self.get_url()}/projects/{project_id}/"

    def get_id_from_name(self, dataset_name: str) -> Optional[int]:
        """Gets the ID of the given dataset.

        Args:
            dataset_name: The name (project title) of the dataset.

        Returns:
            The ID of the first project whose title equals `dataset_name`,
            or None if there is no such project.
        """
        for project in self.get_datasets():
            # Fetch the project params once per project instead of twice.
            project_params = project.get_params()
            if project_params["title"] == dataset_name:
                return cast(int, project_params["id"])
        return None

    def get_datasets(self) -> List[Any]:
        """Gets the datasets currently available for annotation.

        Returns:
            A list of datasets (Label Studio projects).
        """
        datasets = self._get_client().get_projects()
        return cast(List[Any], datasets)

    def get_dataset_names(self) -> List[str]:
        """Gets the names of the datasets.

        Returns:
            A list of dataset names (project titles).
        """
        return [
            dataset.get_params()["title"] for dataset in self.get_datasets()
        ]

    def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
        """Gets the statistics of the given dataset.

        Args:
            dataset_name: The name of the dataset.

        Returns:
            A tuple containing (labeled_task_count, unlabeled_task_count) for
            the dataset.

        Raises:
            IndexError: If the dataset does not exist.
        """
        for project in self.get_datasets():
            # Match the title exactly, consistent with `get_id_from_name`. A
            # substring test would let "cats" match "cats_and_dogs".
            if project.get_params()["title"] == dataset_name:
                labeled_task_count = len(project.get_labeled_tasks())
                unlabeled_task_count = len(project.get_unlabeled_tasks())
                return (labeled_task_count, unlabeled_task_count)
        raise IndexError(
            f"Dataset {dataset_name} not found. Please use "
            f"`zenml annotator dataset list` to list all available datasets."
        )

    def launch(self, url: Optional[str]) -> None:
        """Launches the annotation interface.

        Args:
            url: The URL of the annotation interface. Falls back to
                `get_url()` when empty or None.
        """
        if not url:
            url = self.get_url()
        if self._connection_available():
            webbrowser.open(url, new=1, autoraise=True)
        else:
            logger.warning(
                "Could not launch annotation interface "
                "because the connection could not be established."
            )

    def _get_client(self) -> Client:
        """Gets Label Studio client.

        Returns:
            Label Studio client.

        Raises:
            ValueError: when unable to access the Label Studio API key.
        """
        secret = self.get_authentication_secret(ArbitrarySecretSchema)
        if not secret:
            raise ValueError(
                "Unable to access predefined secret to access Label Studio API key."
            )
        api_key = secret.content.get("api_key")
        if not api_key:
            raise ValueError(
                "Unable to access Label Studio API key from secret."
            )
        return Client(url=self.get_url(), api_key=api_key)

    def _connection_available(self) -> bool:
        """Checks if the connection to the annotation server is available.

        Returns:
            True if the connection is available, False otherwise.
        """
        try:
            result = self._get_client().check_connection()
            return result.get("status") == "UP"  # type: ignore[no-any-return]
        # TODO: [HIGH] refactor to use a more specific exception
        except Exception:
            logger.error(
                "Connection error: No connection was able to be established to the Label Studio backend."
            )
            return False

    def add_dataset(self, **kwargs: Any) -> Any:
        """Registers a dataset for annotation.

        Args:
            **kwargs: Additional keyword arguments to pass to the Label Studio
                client. `dataset_name` and `label_config` are required.

        Returns:
            A Label Studio Project object.

        Raises:
            ValueError: if 'dataset_name' or 'label_config' is not provided.
        """
        dataset_name = kwargs.get("dataset_name")
        label_config = kwargs.get("label_config")
        if not dataset_name:
            raise ValueError("`dataset_name` keyword argument is required.")
        elif not label_config:
            raise ValueError("`label_config` keyword argument is required.")

        return self._get_client().start_project(
            title=dataset_name,
            label_config=label_config,
        )

    def _get_dataset_id_from_kwargs(self, kwargs: Dict[str, Any]) -> int:
        """Resolves the `dataset_name` keyword argument to a project ID.

        Args:
            kwargs: The keyword arguments passed to a public dataset method.

        Returns:
            The ID of the Label Studio project with the given name.

        Raises:
            ValueError: If `dataset_name` is missing or no project has that
                name.
        """
        dataset_name = kwargs.get("dataset_name")
        if not dataset_name:
            raise ValueError("`dataset_name` keyword argument is required.")
        dataset_id = self.get_id_from_name(dataset_name)
        if not dataset_id:
            raise ValueError(
                f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
            )
        return dataset_id

    def delete_dataset(self, **kwargs: Any) -> None:
        """Deletes a dataset from the annotation interface.

        Args:
            **kwargs: Additional keyword arguments to pass to the Label Studio
                client. `dataset_name` is required.

        Raises:
            ValueError: If the dataset name is not provided or if the dataset
                does not exist.
        """
        dataset_id = self._get_dataset_id_from_kwargs(kwargs)
        self._get_client().delete_project(dataset_id)

    def get_dataset(self, **kwargs: Any) -> Any:
        """Gets the dataset with the given name.

        Args:
            **kwargs: Additional keyword arguments to pass to the Label Studio
                client. `dataset_name` is required.

        Returns:
            The LabelStudio Dataset object (a 'Project') for the given name.

        Raises:
            ValueError: If the dataset name is not provided or if the dataset
                does not exist.
        """
        # TODO: check for and raise error if client unavailable
        dataset_id = self._get_dataset_id_from_kwargs(kwargs)
        return self._get_client().get_project(dataset_id)

    def get_converted_dataset(
        self, dataset_name: str, output_format: str
    ) -> Dict[Any, Any]:
        """Extract annotated tasks in a specific converted format.

        Args:
            dataset_name: Name of the dataset.
            output_format: Output format (Label Studio export type).

        Returns:
            A dictionary containing the converted dataset.
        """
        project = self.get_dataset(dataset_name=dataset_name)
        return project.export_tasks(export_type=output_format)  # type: ignore[no-any-return]

    def get_labeled_data(self, **kwargs: Any) -> Any:
        """Gets the labeled data for the given dataset.

        Args:
            **kwargs: Additional keyword arguments to pass to the Label Studio
                client. `dataset_name` is required.

        Returns:
            The labeled data.

        Raises:
            ValueError: If the dataset name is not provided or if the dataset
                does not exist.
        """
        dataset_id = self._get_dataset_id_from_kwargs(kwargs)
        return self._get_client().get_project(dataset_id).get_labeled_tasks()

    def get_unlabeled_data(self, **kwargs: str) -> Any:
        """Gets the unlabeled data for the given dataset.

        Args:
            **kwargs: Additional keyword arguments to pass to the Label Studio
                client. `dataset_name` is required.

        Returns:
            The unlabeled data.

        Raises:
            ValueError: If the dataset name is not provided or if the dataset
                does not exist.
        """
        dataset_id = self._get_dataset_id_from_kwargs(kwargs)
        return self._get_client().get_project(dataset_id).get_unlabeled_tasks()

    def register_dataset_for_annotation(
        self,
        params: LabelStudioDatasetRegistrationParameters,
    ) -> Any:
        """Registers a dataset for annotation.

        Reuses an existing project with the same name if there is one,
        otherwise creates a new project.

        Args:
            params: Parameters for the dataset.

        Returns:
            A Label Studio Project object.
        """
        project_id = self.get_id_from_name(params.dataset_name)
        if project_id:
            dataset = self._get_client().get_project(project_id)
        else:
            dataset = self.add_dataset(
                dataset_name=params.dataset_name,
                label_config=params.label_config,
            )

        return dataset

    def _get_import_storage_sources(
        self, storage_type: str, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets all import storage sources of one type for a dataset.

        Args:
            storage_type: The storage segment of the Label Studio API path
                (`azure`, `gcs` or `s3`).
            dataset_id: Id of the dataset.

        Returns:
            The list of import storage sources returned by the API.

        Raises:
            ConnectionError: If the Label Studio API does not answer with
                HTTP 200.
        """
        # TODO: check if client actually is connected etc
        query_url = f"/api/storages/{storage_type}?project={dataset_id}"
        response = self._get_client().make_request(method="GET", url=query_url)
        if response.status_code != 200:
            raise ConnectionError(
                "Unable to get list of import storage sources. Client raised "
                f"HTTP error {response.status_code}."
            )
        return cast(List[Dict[str, Any]], response.json())

    def _get_azure_import_storage_sources(
        self, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets a list of all Azure import storage sources.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A list of Azure import storage sources.
        """
        return self._get_import_storage_sources("azure", dataset_id)

    def _get_gcs_import_storage_sources(
        self, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets a list of all Google Cloud Storage import storage sources.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A list of Google Cloud Storage import storage sources.
        """
        return self._get_import_storage_sources("gcs", dataset_id)

    def _get_s3_import_storage_sources(
        self, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets a list of all AWS S3 import storage sources.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A list of AWS S3 import storage sources.
        """
        return self._get_import_storage_sources("s3", dataset_id)

    def _storage_source_already_exists(
        self,
        uri: str,
        params: LabelStudioDatasetSyncParameters,
        dataset: Project,
    ) -> bool:
        """Returns whether a storage source already exists.

        Args:
            uri: URI of the storage source.
            params: Parameters for the dataset.
            dataset: Label Studio dataset.

        Returns:
            True if the storage source already exists, False otherwise.

        Raises:
            NotImplementedError: If the storage source type is not supported.
        """
        # TODO: check we are already connected
        # Fetch the dataset params once rather than once per storage source.
        dataset_params = dataset.get_params()
        dataset_id = int(dataset_params["id"])
        dataset_title = dataset_params["title"]
        if params.storage_type == "azure":
            storage_sources = self._get_azure_import_storage_sources(
                dataset_id
            )
        elif params.storage_type == "gcs":
            storage_sources = self._get_gcs_import_storage_sources(dataset_id)
        elif params.storage_type == "s3":
            storage_sources = self._get_s3_import_storage_sources(dataset_id)
        else:
            raise NotImplementedError(
                f"Storage type '{params.storage_type}' not implemented."
            )
        return any(
            (
                source.get("presign") == params.presign
                and source.get("bucket") == uri
                and source.get("regex_filter") == params.regex_filter
                and source.get("use_blob_urls") == params.use_blob_urls
                and source.get("title") == dataset_title
                and source.get("description") == params.description
                and source.get("presign_ttl") == params.presign_ttl
                and source.get("project") == dataset_id
            )
            for source in storage_sources
        )

    def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
        """Returns the parsed Label Studio label config for a dataset.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A dictionary containing the parsed label config.

        Raises:
            ValueError: If no dataset is found for the given id.
        """
        # TODO: check if client actually is connected etc
        dataset = self._get_client().get_project(dataset_id)
        if dataset:
            return cast(Dict[str, Any], dataset.parsed_label_config)
        raise ValueError("No dataset found for the given id.")

    def populate_artifact_store_parameters(
        self,
        params: LabelStudioDatasetSyncParameters,
        artifact_store: BaseArtifactStore,
    ) -> None:
        """Populate the dataset sync parameters with the artifact store credentials.

        Args:
            params: The dataset sync parameters.
            artifact_store: The active artifact store.

        Raises:
            RuntimeError: if the artifact store credentials cannot be fetched,
                or the artifact store flavor is not supported.
        """
        if artifact_store.flavor == "s3":
            from zenml.integrations.s3.artifact_stores import S3ArtifactStore

            assert isinstance(artifact_store, S3ArtifactStore)

            params.storage_type = "s3"

            (
                aws_access_key_id,
                aws_secret_access_key,
                aws_session_token,
            ) = artifact_store.get_credentials()

            if aws_access_key_id and aws_secret_access_key:
                # Convert the credentials into the format expected by Label
                # Studio
                params.aws_access_key_id = aws_access_key_id
                params.aws_secret_access_key = aws_secret_access_key
                params.aws_session_token = aws_session_token

                if artifact_store.config.client_kwargs:
                    if "endpoint_url" in artifact_store.config.client_kwargs:
                        params.s3_endpoint = (
                            artifact_store.config.client_kwargs["endpoint_url"]
                        )
                    if "region_name" in artifact_store.config.client_kwargs:
                        params.s3_region_name = str(
                            artifact_store.config.client_kwargs["region_name"]
                        )

                return

            raise RuntimeError(
                "No credentials are configured for the active S3 artifact "
                "store. The Label Studio annotator needs explicit credentials "
                "to be configured for your artifact store to sync data "
                "artifacts."
            )

        elif artifact_store.flavor == "gcp":
            from zenml.integrations.gcp.artifact_stores import GCPArtifactStore

            assert isinstance(artifact_store, GCPArtifactStore)

            params.storage_type = "gcs"

            gcp_credentials = artifact_store.get_credentials()

            if gcp_credentials:
                # Save the credentials to a file in secure location, because
                # Label Studio will need to read it from a file
                secret_folder = Path(
                    GlobalConfiguration().config_directory,
                    "label-studio",
                    str(self.id),
                )
                fileio.makedirs(str(secret_folder))
                file_path = Path(
                    secret_folder, "google_application_credentials.json"
                )
                # Create the file with owner-only permissions up front so the
                # credentials are never written into a world-readable file.
                # The `os.open` mode only applies when the file is created, so
                # also chmod to tighten a pre-existing file before writing.
                file_descriptor = os.open(
                    file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600
                )
                with os.fdopen(file_descriptor, "w") as f:
                    file_path.chmod(0o600)
                    f.write(json.dumps(gcp_credentials))
                params.google_application_credentials = str(file_path)

                return

            raise RuntimeError(
                "No credentials are configured for the active GCS artifact "
                "store. The Label Studio annotator needs explicit credentials "
                "to be configured for your artifact store to sync data "
                "artifacts."
            )

        elif artifact_store.flavor == "azure":
            from zenml.integrations.azure.artifact_stores import (
                AzureArtifactStore,
            )

            assert isinstance(artifact_store, AzureArtifactStore)

            params.storage_type = "azure"

            azure_credentials = artifact_store.get_credentials()

            if azure_credentials:
                # Convert the credentials into the format expected by Label
                # Studio
                if azure_credentials.connection_string is not None:
                    try:
                        # We need to extract the account name and key from the
                        # connection string
                        tokens = azure_credentials.connection_string.split(";")
                        token_dict = dict(
                            [token.split("=", maxsplit=1) for token in tokens]
                        )
                        params.azure_account_name = token_dict["AccountName"]
                        params.azure_account_key = token_dict["AccountKey"]
                    except (KeyError, ValueError) as e:
                        raise RuntimeError(
                            "The Azure connection string configured for the "
                            "artifact store does not have the expected format."
                        ) from e
                    return

                if (
                    azure_credentials.account_name is not None
                    and azure_credentials.account_key is not None
                ):
                    params.azure_account_name = azure_credentials.account_name
                    params.azure_account_key = azure_credentials.account_key
                    return

                raise RuntimeError(
                    "The Label Studio annotator could not use the "
                    "credentials currently configured in the active Azure "
                    "artifact store because it only supports Azure storage "
                    "account credentials. "
                    "Please use Azure storage account credentials for your "
                    "artifact store."
                )

            raise RuntimeError(
                "No credentials are configured for the active Azure artifact "
                "store. The Label Studio annotator needs explicit credentials "
                "to be configured for your artifact store to sync data "
                "artifacts."
            )

        elif artifact_store.flavor == "local":
            from zenml.artifact_stores.local_artifact_store import (
                LocalArtifactStore,
            )

            assert isinstance(artifact_store, LocalArtifactStore)

            params.storage_type = "local"
            if params.prefix is None:
                params.prefix = artifact_store.path
            elif not params.prefix.lstrip("/").startswith(
                artifact_store.path.lstrip("/")
            ):
                # Both sides are compared without leading slashes so that an
                # absolute prefix is accepted as well as one given without
                # the leading "/".
                raise RuntimeError(
                    "The prefix for the local storage must be a subdirectory "
                    "of the local artifact store path."
                )
            return

        raise RuntimeError(
            f"The active artifact store type '{artifact_store.flavor}' is not "
            "supported by ZenML's Label Studio integration. "
            "Please use one of the supported artifact stores (S3, GCP, "
            "Azure or local)."
        )

    def connect_and_sync_external_storage(
        self,
        uri: str,
        params: LabelStudioDatasetSyncParameters,
        dataset: Project,
    ) -> Optional[Dict[str, Any]]:
        """Syncs the external storage for the given project.

        Args:
            uri: URI of the storage source.
            params: Parameters for the dataset.
            dataset: Label Studio dataset.

        Returns:
            A dictionary containing the sync result.

        Raises:
            ValueError: If the storage type is not supported, or if local
                storage is requested without a `prefix`.
        """
        # TODO: check if proposed storage source has differing / new data
        # if self._storage_source_already_exists(uri, params, dataset):
        #     return None

        # `get_params()` is called once and the title reused below.
        dataset_title = dataset.get_params()["title"]
        storage_connection_args = {
            "prefix": params.prefix,
            "regex_filter": params.regex_filter,
            "use_blob_urls": params.use_blob_urls,
            "presign": params.presign,
            "presign_ttl": params.presign_ttl,
            "title": dataset_title,
            "description": params.description,
        }
        if params.storage_type == "azure":
            if not params.azure_account_name or not params.azure_account_key:
                logger.warning(
                    "Authentication credentials for Azure aren't fully "
                    "provided. Please update the storage synchronization "
                    "settings in the Label Studio web UI as per your needs."
                )
            storage = dataset.connect_azure_import_storage(
                container=uri,
                account_name=params.azure_account_name,
                account_key=params.azure_account_key,
                **storage_connection_args,
            )
        elif params.storage_type == "gcs":
            if not params.google_application_credentials:
                logger.warning(
                    "Authentication credentials for Google Cloud Storage "
                    "aren't fully provided. Please update the storage "
                    "synchronization settings in the Label Studio web UI as "
                    "per your needs."
                )
            storage = dataset.connect_google_import_storage(
                bucket=uri,
                google_application_credentials=params.google_application_credentials,
                **storage_connection_args,
            )
        elif params.storage_type == "s3":
            if (
                not params.aws_access_key_id
                or not params.aws_secret_access_key
            ):
                logger.warning(
                    "Authentication credentials for S3 aren't fully provided. "
                    "Please update the storage synchronization settings in the "
                    "Label Studio web UI as per your needs."
                )

            # temporary fix using client method until LS supports
            # recursive_scan in their SDK
            # (https://github.com/heartexlabs/label-studio-sdk/pull/130)
            ls_client = self._get_client()
            payload = {
                "bucket": uri,
                "prefix": params.prefix,
                "regex_filter": params.regex_filter,
                "use_blob_urls": params.use_blob_urls,
                "aws_access_key_id": params.aws_access_key_id,
                "aws_secret_access_key": params.aws_secret_access_key,
                "aws_session_token": params.aws_session_token,
                "region_name": params.s3_region_name,
                "s3_endpoint": params.s3_endpoint,
                "presign": params.presign,
                "presign_ttl": params.presign_ttl,
                "title": dataset_title,
                "description": params.description,
                "project": dataset.id,
                "recursive_scan": True,
            }
            response = ls_client.make_request(
                "POST", "/api/storages/s3", json=payload
            )
            storage = response.json()
        elif params.storage_type == "local":
            if not params.prefix:
                raise ValueError(
                    "The 'prefix' parameter is required for local storage "
                    "synchronization."
                )

            # Drop arguments that are not used by the local storage
            storage_connection_args.pop("presign")
            storage_connection_args.pop("presign_ttl")
            storage_connection_args.pop("prefix")

            prefix = params.prefix
            if not prefix.startswith("/"):
                prefix = f"/{prefix}"
            root_path = Path(prefix).parent

            # Set the environment variables required by Label Studio
            # to allow local file serving (see https://labelstud.io/guide/storage.html#Prerequisites-2).
            # Remember any previous values and restore them afterwards, even
            # if connecting the storage fails.
            env_overrides = {
                "LABEL_STUDIO_LOCAL_FILES_SERVING_ENABLED": "true",
                "LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT": str(root_path),
            }
            previous_env = {
                key: os.environ.get(key) for key in env_overrides
            }
            os.environ.update(env_overrides)
            try:
                storage = dataset.connect_local_import_storage(
                    local_store_path=prefix,
                    **storage_connection_args,
                )
            finally:
                for key, previous_value in previous_env.items():
                    if previous_value is None:
                        os.environ.pop(key, None)
                    else:
                        os.environ[key] = previous_value
        else:
            raise ValueError(
                f"Invalid storage type. '{params.storage_type}' is not "
                "supported by ZenML's Label Studio integration. Please choose "
                "between 'azure', 'gcs', 's3' or 'local'."
            )

        synced_storage = self._get_client().sync_storage(
            storage_id=storage["id"], storage_type=storage["type"]
        )
        return cast(Dict[str, Any], synced_storage)
config: LabelStudioAnnotatorConfig
property
readonly
Returns the `LabelStudioAnnotatorConfig` config.
Returns:
| Type | Description |
|---|---|
| `LabelStudioAnnotatorConfig` | The configuration. |
add_dataset(self, **kwargs)
Registers a dataset for annotation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | `Any` | Additional keyword arguments to pass to the Label Studio client. | `{}` |
Returns:
| Type | Description |
|---|---|
| `Any` | A Label Studio Project object. |
Exceptions:
| Type | Description |
|---|---|
| `ValueError` | If `dataset_name` or `label_config` is not provided. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def add_dataset(self, **kwargs: Any) -> Any:
    """Creates a new Label Studio project to hold a dataset.

    Args:
        **kwargs: Keyword arguments. `dataset_name` and `label_config` are
            both required; they are forwarded to the client's
            `start_project` as `title` and `label_config`.

    Returns:
        The Label Studio Project object returned by `start_project`.

    Raises:
        ValueError: If `dataset_name` or `label_config` is missing or empty.
    """
    name = kwargs.get("dataset_name")
    config = kwargs.get("label_config")

    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")
    if not config:
        raise ValueError("`label_config` keyword argument is required.")

    client = self._get_client()
    return client.start_project(title=name, label_config=config)
connect_and_sync_external_storage(self, uri, params, dataset)
Syncs the external storage for the given project.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `uri` | `str` | URI of the storage source. | required |
| `params` | `LabelStudioDatasetSyncParameters` | Parameters for the dataset. | required |
| `dataset` | `Project` | Label Studio dataset. | required |
Returns:
| Type | Description |
|---|---|
| `Optional[Dict[str, Any]]` | A dictionary containing the sync result. |
Exceptions:
| Type | Description |
|---|---|
| `ValueError` | If the storage type is not supported. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def connect_and_sync_external_storage(
    self,
    uri: str,
    params: LabelStudioDatasetSyncParameters,
    dataset: Project,
) -> Optional[Dict[str, Any]]:
    """Syncs the external storage for the given project.

    Args:
        uri: URI of the storage source.
        params: Parameters for the dataset.
        dataset: Label Studio dataset.

    Returns:
        A dictionary containing the sync result.

    Raises:
        ValueError: If the storage type is not supported, or if local
            storage is requested without a `prefix`.
    """
    # TODO: check if proposed storage source has differing / new data
    # if self._storage_source_already_exists(uri, params, dataset):
    #     return None

    # `get_params()` is called once and the title reused below.
    dataset_title = dataset.get_params()["title"]
    storage_connection_args = {
        "prefix": params.prefix,
        "regex_filter": params.regex_filter,
        "use_blob_urls": params.use_blob_urls,
        "presign": params.presign,
        "presign_ttl": params.presign_ttl,
        "title": dataset_title,
        "description": params.description,
    }
    if params.storage_type == "azure":
        if not params.azure_account_name or not params.azure_account_key:
            logger.warning(
                "Authentication credentials for Azure aren't fully "
                "provided. Please update the storage synchronization "
                "settings in the Label Studio web UI as per your needs."
            )
        storage = dataset.connect_azure_import_storage(
            container=uri,
            account_name=params.azure_account_name,
            account_key=params.azure_account_key,
            **storage_connection_args,
        )
    elif params.storage_type == "gcs":
        if not params.google_application_credentials:
            logger.warning(
                "Authentication credentials for Google Cloud Storage "
                "aren't fully provided. Please update the storage "
                "synchronization settings in the Label Studio web UI as "
                "per your needs."
            )
        storage = dataset.connect_google_import_storage(
            bucket=uri,
            google_application_credentials=params.google_application_credentials,
            **storage_connection_args,
        )
    elif params.storage_type == "s3":
        if (
            not params.aws_access_key_id
            or not params.aws_secret_access_key
        ):
            logger.warning(
                "Authentication credentials for S3 aren't fully provided. "
                "Please update the storage synchronization settings in the "
                "Label Studio web UI as per your needs."
            )

        # temporary fix using client method until LS supports
        # recursive_scan in their SDK
        # (https://github.com/heartexlabs/label-studio-sdk/pull/130)
        ls_client = self._get_client()
        payload = {
            "bucket": uri,
            "prefix": params.prefix,
            "regex_filter": params.regex_filter,
            "use_blob_urls": params.use_blob_urls,
            "aws_access_key_id": params.aws_access_key_id,
            "aws_secret_access_key": params.aws_secret_access_key,
            "aws_session_token": params.aws_session_token,
            "region_name": params.s3_region_name,
            "s3_endpoint": params.s3_endpoint,
            "presign": params.presign,
            "presign_ttl": params.presign_ttl,
            "title": dataset_title,
            "description": params.description,
            "project": dataset.id,
            "recursive_scan": True,
        }
        response = ls_client.make_request(
            "POST", "/api/storages/s3", json=payload
        )
        storage = response.json()
    elif params.storage_type == "local":
        if not params.prefix:
            raise ValueError(
                "The 'prefix' parameter is required for local storage "
                "synchronization."
            )

        # Drop arguments that are not used by the local storage
        storage_connection_args.pop("presign")
        storage_connection_args.pop("presign_ttl")
        storage_connection_args.pop("prefix")

        prefix = params.prefix
        if not prefix.startswith("/"):
            prefix = f"/{prefix}"
        root_path = Path(prefix).parent

        # Set the environment variables required by Label Studio
        # to allow local file serving (see https://labelstud.io/guide/storage.html#Prerequisites-2).
        # Remember any previous values and restore them afterwards, even if
        # connecting the storage fails.
        env_overrides = {
            "LABEL_STUDIO_LOCAL_FILES_SERVING_ENABLED": "true",
            "LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT": str(root_path),
        }
        previous_env = {key: os.environ.get(key) for key in env_overrides}
        os.environ.update(env_overrides)
        try:
            storage = dataset.connect_local_import_storage(
                local_store_path=prefix,
                **storage_connection_args,
            )
        finally:
            for key, previous_value in previous_env.items():
                if previous_value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = previous_value
    else:
        raise ValueError(
            f"Invalid storage type. '{params.storage_type}' is not "
            "supported by ZenML's Label Studio integration. Please choose "
            "between 'azure', 'gcs', 's3' or 'local'."
        )

    synced_storage = self._get_client().sync_storage(
        storage_id=storage["id"], storage_type=storage["type"]
    )
    return cast(Dict[str, Any], synced_storage)
delete_dataset(self, **kwargs)
Deletes a dataset from the annotation interface.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | `Any` | Additional keyword arguments to pass to the Label Studio client. | `{}` |
Exceptions:
| Type | Description |
|---|---|
| `ValueError` | If the dataset name is not provided or if the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def delete_dataset(self, **kwargs: Any) -> None:
    """Removes a dataset (Label Studio project) from the annotation interface.

    Args:
        **kwargs: Keyword arguments; `dataset_name` is required and names
            the project to delete.

    Raises:
        ValueError: If `dataset_name` is missing or empty, or if no project
            with that name can be found.
    """
    # The client is acquired before any argument validation takes place.
    client = self._get_client()

    name = kwargs.get("dataset_name")
    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(name)
    if not project_id:
        raise ValueError(
            f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
        )

    client.delete_project(project_id)
get_converted_dataset(self, dataset_name, output_format)
Extract annotated tasks in a specific converted format.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dataset_name` | `str` | Name of the dataset. | required |
| `output_format` | `str` | Output format. | required |
Returns:
| Type | Description |
|---|---|
| `Dict[Any, Any]` | A dictionary containing the converted dataset. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_converted_dataset(
    self, dataset_name: str, output_format: str
) -> Dict[Any, Any]:
    """Exports a dataset's annotated tasks converted to a given format.

    Args:
        dataset_name: Name of the dataset.
        output_format: Export type passed to the project's `export_tasks`.

    Returns:
        The result of `export_tasks` for the named dataset.
    """
    exported = self.get_dataset(dataset_name=dataset_name).export_tasks(
        export_type=output_format
    )
    return cast(Dict[Any, Any], exported)
get_dataset(self, **kwargs)
Gets the dataset with the given name.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | `Any` | Additional keyword arguments to pass to the Label Studio client. | `{}` |
Returns:
| Type | Description |
|---|---|
| `Any` | The LabelStudio Dataset object (a 'Project') for the given name. |
Exceptions:
| Type | Description |
|---|---|
| `ValueError` | If the dataset name is not provided or if the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset(self, **kwargs: Any) -> Any:
    """Looks up a dataset (Label Studio project) by name.

    Args:
        **kwargs: Keyword arguments; `dataset_name` is required.

    Returns:
        The Label Studio 'Project' object for the given name.

    Raises:
        ValueError: If `dataset_name` is missing or empty, or if no project
            with that name can be found.
    """
    # TODO: check for and raise error if client unavailable
    name = kwargs.get("dataset_name")
    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(name)
    if project_id:
        return self._get_client().get_project(project_id)

    raise ValueError(
        f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
    )
get_dataset_names(self)
Gets the names of the datasets.
Returns:
| Type | Description |
|---|---|
| `List[str]` | A list of dataset names. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_names(self) -> List[str]:
    """List the titles of all datasets (Label Studio projects).

    Returns:
        A list of dataset names, in the order ``get_datasets`` returns them.
    """
    titles: List[str] = []
    for project in self.get_datasets():
        titles.append(project.get_params()["title"])
    return titles
get_dataset_stats(self, dataset_name)
Gets the statistics of the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
The name of the dataset. |
required |
Returns:
Type | Description |
---|---|
Tuple[int, int] |
A tuple containing (labeled_task_count, unlabeled_task_count) for the dataset. |
Exceptions:
Type | Description |
---|---|
IndexError |
If the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
    """Gets the statistics of the given dataset.

    The dataset is matched by its exact title, the same way
    ``get_id_from_name`` resolves dataset names.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        A tuple containing (labeled_task_count, unlabeled_task_count) for
        the dataset.

    Raises:
        IndexError: If the dataset does not exist.
    """
    for project in self.get_datasets():
        # Exact comparison: a substring test (`in`) would let "cats" match a
        # project titled "cats_v2" and return the wrong project's stats,
        # depending only on project ordering.
        if project.get_params()["title"] == dataset_name:
            labeled_task_count = len(project.get_labeled_tasks())
            unlabeled_task_count = len(project.get_unlabeled_tasks())
            return (labeled_task_count, unlabeled_task_count)
    raise IndexError(
        f"Dataset {dataset_name} not found. Please use "
        f"`zenml annotator dataset list` to list all available datasets."
    )
get_datasets(self)
Gets the datasets currently available for annotation.
Returns:
Type | Description |
---|---|
List[Any] |
A list of datasets. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_datasets(self) -> List[Any]:
    """Return every dataset (Label Studio project) known to the client.

    Returns:
        A list of datasets.
    """
    client = self._get_client()
    return cast(List[Any], client.get_projects())
get_id_from_name(self, dataset_name)
Gets the ID of the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
The name of the dataset. |
required |
Returns:
Type | Description |
---|---|
Optional[int] |
The ID of the dataset. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_id_from_name(self, dataset_name: str) -> Optional[int]:
    """Resolve a dataset title to its Label Studio project id.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        The ID of the first dataset whose title equals ``dataset_name``,
        or ``None`` if there is no such dataset.
    """
    for candidate in self.get_datasets():
        project_params = candidate.get_params()
        if project_params["title"] == dataset_name:
            return cast(int, project_params["id"])
    return None
get_labeled_data(self, **kwargs)
Gets the labeled data for the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
Any |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
The labeled data. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the dataset name is not provided or if the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_labeled_data(self, **kwargs: Any) -> Any:
    """Gets the labeled data for the given dataset.

    Args:
        **kwargs: Keyword arguments; ``dataset_name`` (the title of the
            Label Studio project) is required.

    Returns:
        The labeled tasks of the dataset, as returned by the project's
        ``get_labeled_tasks``.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    # `get_dataset` performs the `dataset_name` validation and the name ->
    # project lookup, so both failure modes raise the same ValueErrors.
    return self.get_dataset(**kwargs).get_labeled_tasks()
get_parsed_label_config(self, dataset_id)
Returns the parsed Label Studio label config for a dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_id |
int |
Id of the dataset. |
required |
Returns:
Type | Description |
---|---|
Dict[str, Any] |
A dictionary containing the parsed label config. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no dataset is found for the given id. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
    """Fetch a project by id and return its parsed label config.

    Args:
        dataset_id: Id of the dataset.

    Returns:
        A dictionary containing the parsed label config.

    Raises:
        ValueError: If the client returns a falsy project for the given id.
    """
    # TODO: check if client actually is connected etc
    project = self._get_client().get_project(dataset_id)
    if not project:
        raise ValueError("No dataset found for the given id.")
    return cast(Dict[str, Any], project.parsed_label_config)
get_unlabeled_data(self, **kwargs)
Gets the unlabeled data for the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
**kwargs |
str |
Additional keyword arguments to pass to the Label Studio client. |
{} |
Returns:
Type | Description |
---|---|
Any |
The unlabeled data. |
Exceptions:
Type | Description |
---|---|
ValueError |
If the dataset name is not provided or if the dataset does not exist. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_unlabeled_data(self, **kwargs: Any) -> Any:
    """Gets the unlabeled data for the given dataset.

    Args:
        **kwargs: Keyword arguments; ``dataset_name`` (the title of the
            Label Studio project) is required.

    Returns:
        The unlabeled tasks of the dataset, as returned by the project's
        ``get_unlabeled_tasks``.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    # `get_dataset` performs the `dataset_name` validation and the name ->
    # project lookup, so both failure modes raise the same ValueErrors.
    return self.get_dataset(**kwargs).get_unlabeled_tasks()
get_url(self)
Gets the top-level URL of the annotation interface.
Returns:
Type | Description |
---|---|
str |
The URL of the annotation interface. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url(self) -> str:
    """Build the top-level URL of the annotation interface.

    Returns:
        ``<instance_url>:<port>`` taken from the annotator config.
    """
    cfg = self.config
    return "{}:{}".format(cfg.instance_url, cfg.port)
get_url_for_dataset(self, dataset_name)
Gets the URL of the annotation interface for the given dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
The name of the dataset. |
required |
Returns:
Type | Description |
---|---|
str |
The URL of the annotation interface. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url_for_dataset(self, dataset_name: str) -> str:
    """Build the annotation-interface URL of a single dataset.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        ``<top-level URL>/projects/<id>/``. If no dataset with that name
        exists, the id is rendered as ``None``.
    """
    dataset_id = self.get_id_from_name(dataset_name)
    base_url = self.get_url()
    return f"{base_url}/projects/{dataset_id}/"
launch(self, url)
Launches the annotation interface.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
Optional[str] |
The URL of the annotation interface. |
required |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def launch(self, url: Optional[str]) -> None:
    """Launches the annotation interface.

    Opens the URL in a new browser window when ``_connection_available()``
    reports a connection; otherwise only a warning is logged.

    Args:
        url: The URL of the annotation interface. If empty or ``None``, the
            top-level URL returned by ``get_url`` is used.
    """
    if not url:
        url = self.get_url()
    if self._connection_available():
        webbrowser.open(url, new=1, autoraise=True)
    else:
        # The literals are implicitly concatenated, so the first one needs
        # a trailing space (previously rendered as "interfacebecause").
        logger.warning(
            "Could not launch annotation interface "
            "because the connection could not be established."
        )
populate_artifact_store_parameters(self, params, artifact_store)
Populate the dataset sync parameters with the artifact store credentials.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
LabelStudioDatasetSyncParameters |
The dataset sync parameters. |
required |
artifact_store |
BaseArtifactStore |
The active artifact store. |
required |
Exceptions:
Type | Description |
---|---|
RuntimeError |
if the artifact store credentials cannot be fetched. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def populate_artifact_store_parameters(
    self,
    params: LabelStudioDatasetSyncParameters,
    artifact_store: BaseArtifactStore,
) -> None:
    """Populate the dataset sync parameters with the artifact store credentials.

    Sets ``params.storage_type`` plus the credential/location fields Label
    Studio needs for the flavor of ``artifact_store`` (``s3``, ``gcp``,
    ``azure`` or ``local``). ``params`` is modified in place.

    Args:
        params: The dataset sync parameters.
        artifact_store: The active artifact store.

    Raises:
        RuntimeError: if the artifact store credentials cannot be fetched or
            are not in a format Label Studio supports, if a local ``prefix``
            lies outside the local artifact store path, or if the artifact
            store flavor is not supported.
    """
    if artifact_store.flavor == "s3":
        from zenml.integrations.s3.artifact_stores import S3ArtifactStore

        assert isinstance(artifact_store, S3ArtifactStore)

        params.storage_type = "s3"

        (
            aws_access_key_id,
            aws_secret_access_key,
            aws_session_token,
        ) = artifact_store.get_credentials()

        if aws_access_key_id and aws_secret_access_key:
            # Convert the credentials into the format expected by Label
            # Studio
            params.aws_access_key_id = aws_access_key_id
            params.aws_secret_access_key = aws_secret_access_key
            params.aws_session_token = aws_session_token

            if artifact_store.config.client_kwargs:
                if "endpoint_url" in artifact_store.config.client_kwargs:
                    params.s3_endpoint = artifact_store.config.client_kwargs[
                        "endpoint_url"
                    ]
                if "region_name" in artifact_store.config.client_kwargs:
                    params.s3_region_name = str(
                        artifact_store.config.client_kwargs["region_name"]
                    )

            return

        raise RuntimeError(
            "No credentials are configured for the active S3 artifact "
            "store. The Label Studio annotator needs explicit credentials "
            "to be configured for your artifact store to sync data "
            "artifacts."
        )

    elif artifact_store.flavor == "gcp":
        from zenml.integrations.gcp.artifact_stores import GCPArtifactStore

        assert isinstance(artifact_store, GCPArtifactStore)

        params.storage_type = "gcs"

        gcp_credentials = artifact_store.get_credentials()

        if gcp_credentials:
            import os

            # Save the credentials to a file in secure location, because
            # Label Studio will need to read it from a file
            secret_folder = Path(
                GlobalConfiguration().config_directory,
                "label-studio",
                str(self.id),
            )
            fileio.makedirs(str(secret_folder))
            file_path = Path(
                secret_folder, "google_application_credentials.json"
            )
            # Create the file owner-only (0o600) and tighten a pre-existing
            # file's permissions *before* writing, so the secret is never
            # readable by other users, not even briefly.
            fd = os.open(
                file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600
            )
            file_path.chmod(0o600)
            with os.fdopen(fd, "w") as f:
                f.write(json.dumps(gcp_credentials))

            params.google_application_credentials = str(file_path)
            return

        raise RuntimeError(
            "No credentials are configured for the active GCS artifact "
            "store. The Label Studio annotator needs explicit credentials "
            "to be configured for your artifact store to sync data "
            "artifacts."
        )

    elif artifact_store.flavor == "azure":
        from zenml.integrations.azure.artifact_stores import (
            AzureArtifactStore,
        )

        assert isinstance(artifact_store, AzureArtifactStore)

        params.storage_type = "azure"

        azure_credentials = artifact_store.get_credentials()

        if azure_credentials:
            # Convert the credentials into the format expected by Label
            # Studio
            if azure_credentials.connection_string is not None:
                try:
                    # We need to extract the account name and key from the
                    # connection string. Empty segments (e.g. produced by a
                    # trailing `;`) carry no key/value pair and are skipped.
                    tokens = azure_credentials.connection_string.split(";")
                    token_dict = dict(
                        token.split("=", maxsplit=1)
                        for token in tokens
                        if token
                    )
                    params.azure_account_name = token_dict["AccountName"]
                    params.azure_account_key = token_dict["AccountKey"]
                except (KeyError, ValueError) as e:
                    raise RuntimeError(
                        "The Azure connection string configured for the "
                        "artifact store does not have the expected format."
                    ) from e

                return

            if (
                azure_credentials.account_name is not None
                and azure_credentials.account_key is not None
            ):
                params.azure_account_name = azure_credentials.account_name
                params.azure_account_key = azure_credentials.account_key
                return

            raise RuntimeError(
                "The Label Studio annotator could not use the "
                "credentials currently configured in the active Azure "
                "artifact store because it only supports Azure storage "
                "account credentials. "
                "Please use Azure storage account credentials for your "
                "artifact store."
            )

        raise RuntimeError(
            "No credentials are configured for the active Azure artifact "
            "store. The Label Studio annotator needs explicit credentials "
            "to be configured for your artifact store to sync data "
            "artifacts."
        )

    elif artifact_store.flavor == "local":
        from zenml.artifact_stores.local_artifact_store import (
            LocalArtifactStore,
        )

        assert isinstance(artifact_store, LocalArtifactStore)

        params.storage_type = "local"
        if params.prefix is None:
            params.prefix = artifact_store.path
        # Compare both sides without leading slashes: previously only the
        # store path was stripped, so every absolute prefix (the documented
        # form) was rejected even when it was inside the store.
        elif not params.prefix.lstrip("/").startswith(
            artifact_store.path.lstrip("/")
        ):
            raise RuntimeError(
                "The prefix for the local storage must be a subdirectory "
                "of the local artifact store path."
            )
        return

    raise RuntimeError(
        f"The active artifact store type '{artifact_store.flavor}' is not "
        "supported by ZenML's Label Studio integration. "
        "Please use one of the supported artifact stores (S3, GCP, "
        "Azure or local)."
    )
register_dataset_for_annotation(self, params)
Registers a dataset for annotation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
LabelStudioDatasetRegistrationParameters |
Parameters for the dataset. |
required |
Returns:
Type | Description |
---|---|
Any |
A Label Studio Project object. |
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def register_dataset_for_annotation(
    self,
    params: LabelStudioDatasetRegistrationParameters,
) -> Any:
    """Return the project for ``params.dataset_name``, creating it if needed.

    An existing project with that name is reused as-is. Otherwise a new one
    is created through ``add_dataset`` with ``params.label_config``.

    Args:
        params: Parameters for the dataset.

    Returns:
        A Label Studio Project object.
    """
    existing_id = self.get_id_from_name(params.dataset_name)
    if not existing_id:
        return self.add_dataset(
            dataset_name=params.dataset_name,
            label_config=params.label_config,
        )
    return self._get_client().get_project(existing_id)
flavors
special
Label Studio integration flavors.
label_studio_annotator_flavor
Label Studio annotator flavor.
LabelStudioAnnotatorConfig (BaseAnnotatorConfig, AuthenticationConfigMixin)
pydantic-model
Config for the Label Studio annotator.
Attributes:
Name | Type | Description |
---|---|---|
instance_url |
str |
URL of the Label Studio instance. |
port |
int |
The port to use for the annotation interface. |
Source code in zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py
class LabelStudioAnnotatorConfig(
    BaseAnnotatorConfig, AuthenticationConfigMixin
):
    """Config for the Label Studio annotator.

    Authentication-related settings are inherited from
    ``AuthenticationConfigMixin``.

    Attributes:
        instance_url: URL of the Label Studio instance. The annotator builds
            its interface URL as ``f"{instance_url}:{port}"``, so this value
            should not itself contain a port.
        port: The port to use for the annotation interface.
    """

    # Both defaults are module-level constants defined outside this view;
    # presumably they point at a locally running Label Studio — TODO confirm.
    instance_url: str = DEFAULT_LOCAL_INSTANCE_URL
    port: int = DEFAULT_LOCAL_LABEL_STUDIO_PORT
LabelStudioAnnotatorFlavor (BaseAnnotatorFlavor)
Label Studio annotator flavor.
Source code in zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py
class LabelStudioAnnotatorFlavor(BaseAnnotatorFlavor):
    """Flavor that registers the Label Studio annotator stack component."""

    @property
    def name(self) -> str:
        """Identifier under which this flavor is registered.

        Returns:
            The ``LABEL_STUDIO_ANNOTATOR_FLAVOR`` constant.
        """
        return LABEL_STUDIO_ANNOTATOR_FLAVOR

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation of this flavor.

        Returns:
            The default docs URL generated by the base flavor.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK (API) documentation of this flavor.

        Returns:
            The default SDK docs URL generated by the base flavor.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """Image used to represent this flavor in the dashboard.

        Returns:
            The URL of the Label Studio logo.
        """
        logo_bucket = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com"
        return f"{logo_bucket}/annotator/label_studio.png"

    @property
    def config_class(self) -> Type[LabelStudioAnnotatorConfig]:
        """Configuration model used by this flavor.

        Returns:
            The ``LabelStudioAnnotatorConfig`` class.
        """
        return LabelStudioAnnotatorConfig

    @property
    def implementation_class(self) -> Type["LabelStudioAnnotator"]:
        """Stack component class implementing this flavor.

        Returns:
            The ``LabelStudioAnnotator`` class.
        """
        # Deferred import: the annotator module is only loaded when this
        # property is accessed.
        from zenml.integrations.label_studio.annotators import (
            LabelStudioAnnotator as annotator_cls,
        )

        return annotator_cls
config_class: Type[zenml.integrations.label_studio.flavors.label_studio_annotator_flavor.LabelStudioAnnotatorConfig]
property
readonly
Returns LabelStudioAnnotatorConfig
config class.
Returns:
Type | Description |
---|---|
Type[zenml.integrations.label_studio.flavors.label_studio_annotator_flavor.LabelStudioAnnotatorConfig] |
The config class. |
docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor docs url. |
implementation_class: Type[LabelStudioAnnotator]
property
readonly
Implementation class for this flavor.
Returns:
Type | Description |
---|---|
Type[LabelStudioAnnotator] |
The implementation class. |
logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:
Type | Description |
---|---|
str |
The flavor logo. |
name: str
property
readonly
Name of the flavor.
Returns:
Type | Description |
---|---|
str |
The name of the flavor. |
sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:
Type | Description |
---|---|
Optional[str] |
A flavor SDK docs url. |
label_config_generators
special
Initialization of the Label Studio config generators submodule.
label_config_generators
Implementation of label config generators for Label Studio.
generate_basic_object_detection_bounding_boxes_label_config(labels)
Generates a Label Studio config for object detection with bounding boxes.
This is based on the basic config example shown at https://labelstud.io/templates/image_bbox.html.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
List[str] |
A list of labels to be used in the label config. |
required |
Returns:
Type | Description |
---|---|
Tuple[str, str] |
A tuple of the generated label config and the label config type. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no labels are provided. |
Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_basic_object_detection_bounding_boxes_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio config for object detection with bounding boxes.

    This is based on the basic config example shown at
    https://labelstud.io/templates/image_bbox.html.

    Each label is XML-escaped before being embedded in the config, so labels
    containing quotes, ``&``, ``<`` or ``>`` still yield a well-formed
    config. Labels without such characters are embedded unchanged.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    from html import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.OBJECT_DETECTION_BOUNDING_BOXES
    label_config_start = """<View>
<Image name="image" value="$image"/>
<RectangleLabels name="label" toName="image">
"""
    # `quote=True` also escapes `'`, which would otherwise close the
    # single-quoted `value` attribute early and corrupt the XML.
    escaped_labels = (escape(f"{label}", quote=True) for label in labels)
    label_config_choices = "".join(
        f"<Label value='{label}' />\n" for label in escaped_labels
    )
    label_config_end = "</RectangleLabels>\n</View>"
    label_config = label_config_start + label_config_choices + label_config_end

    return (
        label_config,
        label_config_type,
    )
generate_basic_ocr_label_config(labels)
Generates a Label Studio config for an optical character recognition (OCR) labeling task.
This is based on the basic config example shown at https://labelstud.io/templates/optical_character_recognition.html
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
List[str] |
A list of labels to be used in the label config. |
required |
Returns:
Type | Description |
---|---|
Tuple[str, str] |
A tuple of the generated label config and the label config type. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no labels are provided. |
Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_basic_ocr_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio config for an optical character recognition (OCR) labeling task.

    This is based on the basic config example shown at
    https://labelstud.io/templates/optical_character_recognition.html

    Each label is XML-escaped before being embedded in the config, so labels
    containing quotes, ``&``, ``<`` or ``>`` still yield a well-formed
    config. Labels without such characters are embedded unchanged.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    from html import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.OCR

    label_config_start = """
<View>
<Image name="image" value="$ocr" zoom="true" zoomControl="true" rotateControl="true"/>
<View>
<Filter toName="label" minlength="0" name="filter"/>
<Labels name="label" toName="image">
"""
    # `quote=True` also escapes `'`, which would otherwise close the
    # single-quoted `value` attribute early and corrupt the XML.
    escaped_labels = (escape(f"{label}", quote=True) for label in labels)
    label_config_choices = "".join(
        f"<Label value='{label}' />\n" for label in escaped_labels
    )

    label_config_end = """
</Labels>
</View>
<Rectangle name="bbox" toName="image" strokeWidth="3"/>
<Polygon name="poly" toName="image" strokeWidth="3"/>
<TextArea name="transcription" toName="image" editable="true" perRegion="true" required="true" maxSubmissions="1" rows="5" placeholder="Recognized Text" displayMode="region-list"/>
</View>
"""

    label_config = label_config_start + label_config_choices + label_config_end
    return (
        label_config,
        label_config_type,
    )
generate_image_classification_label_config(labels)
Generates a Label Studio label config for image classification.
This is based on the basic config example shown at https://labelstud.io/templates/image_classification.html.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
labels |
List[str] |
A list of labels to be used in the label config. |
required |
Returns:
Type | Description |
---|---|
Tuple[str, str] |
A tuple of the generated label config and the label config type. |
Exceptions:
Type | Description |
---|---|
ValueError |
If no labels are provided. |
Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_image_classification_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio label config for image classification.

    This is based on the basic config example shown at
    https://labelstud.io/templates/image_classification.html.

    Each label is XML-escaped before being embedded in the config, so labels
    containing quotes, ``&``, ``<`` or ``>`` still yield a well-formed
    config. Labels without such characters are embedded unchanged.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    from html import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.IMAGE_CLASSIFICATION

    label_config_start = """<View>
<Image name="image" value="$image"/>
<Choices name="choice" toName="image">
"""
    # `quote=True` also escapes `'`, which would otherwise close the
    # single-quoted `value` attribute early and corrupt the XML.
    escaped_labels = (escape(f"{label}", quote=True) for label in labels)
    label_config_choices = "".join(
        f"<Choice value='{label}' />\n" for label in escaped_labels
    )
    label_config_end = "</Choices>\n</View>"

    label_config = label_config_start + label_config_choices + label_config_end
    return (
        label_config,
        label_config_type,
    )
label_studio_utils
Utility functions for the Label Studio annotator integration.
clean_url(url)
Remove extraneous parts of the URL prior to mapping.
Keeps only the path component of the URL (dropping the scheme, netloc, parameters, query and fragment) and strips any leading slashes from it. For example, a presigned URL like
'https://storage.googleapis.com/label-studio/load_image_data/images/fdbcd451-0c80-495c-a9c5-6b51776f5019/1/0/image_file.JPEG?X-Goog-Signature=abc'
would become
label-studio/load_image_data/images/fdbcd451-0c80-495c-a9c5-6b51776f5019/1/0/image_file.JPEG
.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
A URL string. |
required |
Returns:
Type | Description |
---|---|
str |
A cleaned URL string. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def clean_url(url: str) -> str:
    """Reduce a URL to its path, without leading slashes, prior to mapping.

    Only the path component computed by ``urlparse`` is kept, so the scheme,
    netloc, parameters, query and fragment are all discarded. For example,
    ``'https://storage.googleapis.com/bucket/dir/image.JPEG?X-Goog-Signature=abc'``
    becomes ``'bucket/dir/image.JPEG'``. A string without a ``scheme://``
    part (e.g. ``'gs%3A//bucket/x'``) is treated entirely as a path.

    Args:
        url: A URL string.

    Returns:
        A cleaned URL string.
    """
    path_only = urlparse(url).path
    return path_only.lstrip("/")
convert_pred_filenames_to_task_ids(preds, tasks, filename_reference, storage_type)
Converts a list of predictions from local file references to task id.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
preds |
List[Dict[str, Any]] |
List of predictions. |
required |
tasks |
List[Dict[str, Any]] |
List of tasks. |
required |
filename_reference |
str |
Name of the file reference in the predictions. |
required |
storage_type |
str |
Storage type of the predictions. |
required |
Returns:
Type | Description |
---|---|
List[Dict[str, Any]] |
List of predictions using task ids as reference. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def convert_pred_filenames_to_task_ids(
    preds: List[Dict[str, Any]],
    tasks: List[Dict[str, Any]],
    filename_reference: str,
    storage_type: str,
) -> List[Dict[str, Any]]:
    """Replace the file references of predictions by Label Studio task ids.

    Args:
        preds: List of predictions, each with ``filename`` and ``result``.
        tasks: List of tasks, each with ``id`` and ``data``.
        filename_reference: Key inside each task's ``data`` that holds the
            file reference.
        storage_type: Storage type of the predictions; ``"gcs"`` and
            ``"s3"`` filenames are URL-encoded before matching.

    Returns:
        List of predictions using task ids as reference.

    Raises:
        KeyError: If a prediction's filename matches no task.
    """
    task_id_by_file = {
        clean_url(task["data"][filename_reference]): task["id"]
        for task in tasks
    }

    def _gcs_key(filename: str) -> str:
        # URL-encode (Label Studio stores encoded names) and drop `gs://`.
        return quote(filename).split("//")[1]

    def _s3_key(filename: str) -> str:
        # S3 URLs look like s3://bucket-name/path/to/file: drop the scheme
        # and the bucket so only the encoded object path remains.
        return "/".join(quote(filename).split("//")[1].split("/")[1:])

    # GCS and S3 URL-encode filenames containing spaces, so those
    # predictions need the matching encoding step before lookup.
    encode_filename = {"gcs": _gcs_key, "s3": _s3_key}.get(storage_type)
    if encode_filename is not None:
        preds = [
            {
                "filename": encode_filename(pred["filename"]),
                "result": pred["result"],
            }
            for pred in preds
        ]

    converted: List[Dict[str, Any]] = []
    for pred in preds:
        parsed = urlparse(pred["filename"])
        converted.append(
            {
                "task": int(task_id_by_file[parsed.netloc + parsed.path]),
                "result": pred["result"],
            }
        )
    return converted
get_file_extension(path_str)
Return the file extension of the given filename.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path_str |
str |
Path to the file. |
required |
Returns:
Type | Description |
---|---|
str |
File extension. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def get_file_extension(path_str: str) -> str:
    """Extract the file extension from the path part of a URL or path.

    Query strings and fragments are ignored because only the parsed path
    is inspected.

    Args:
        path_str: Path to the file.

    Returns:
        File extension (including the leading dot), or an empty string.
    """
    url_path = urlparse(path_str).path
    _stem, extension = os.path.splitext(url_path)
    return extension
is_azure_url(url)
Return whether the given URL is an Azure URL.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
URL to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the URL is an Azure URL, False otherwise. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_azure_url(url: str) -> bool:
    """Check whether a URL points at Azure Blob Storage.

    Args:
        url: URL to check.

    Returns:
        True if the URL's network location contains
        ``blob.core.windows.net``, False otherwise.
    """
    host = urlparse(url).netloc
    return "blob.core.windows.net" in host
is_gcs_url(url)
Return whether the given URL is a GCS URL.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
URL to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the URL is a GCS URL, False otherwise. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_gcs_url(url: str) -> bool:
    """Check whether a URL points at Google Cloud Storage over HTTP(S).

    ``gs://`` URIs are not matched, because their network location is the
    bucket name rather than ``storage.googleapis.com``.

    Args:
        url: URL to check.

    Returns:
        True if the URL is a GCS URL, False otherwise.
    """
    host = urlparse(url).netloc
    return "storage.googleapis.com" in host
is_s3_url(url)
Return whether the given URL is an S3 URL.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
url |
str |
URL to check. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the URL is an S3 URL, False otherwise. |
Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_s3_url(url: str) -> bool:
    """Return whether the given URL is an S3 URL.

    Besides the global endpoint (``*.s3.amazonaws.com``), regional, legacy
    dash-style and dual-stack endpoints are recognized, e.g.
    ``bucket.s3.eu-central-1.amazonaws.com``, ``s3.us-east-1.amazonaws.com``
    and ``s3-eu-west-1.amazonaws.com``.

    Args:
        url: URL to check.

    Returns:
        True if the URL is an S3 URL, False otherwise.
    """
    import re

    parsed = urlparse(url)
    # Original check, kept so every URL accepted before is still accepted.
    if "s3.amazonaws" in parsed.netloc:
        return True
    # Regional endpoints put the region between `s3` and `amazonaws.com`,
    # which the substring check above does not catch.
    host = parsed.hostname or ""
    return (
        re.search(
            r"(?:^|\.)s3[.-](?:[a-z0-9-]+\.)*amazonaws\.com(?:\.cn)?$",
            host,
        )
        is not None
    )
steps
special
Standard steps to be used with the Label Studio annotator integration.
label_studio_standard_steps
Implementation of standard steps for the Label Studio annotator integration.
LabelStudioDatasetRegistrationParameters (BaseParameters)
pydantic-model
Step parameters when registering a dataset with Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
label_config |
str |
The label config to use for the annotation interface. |
dataset_name |
str |
Name of the dataset to register. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetRegistrationParameters(BaseParameters):
    """Step parameters when registering a dataset with Label Studio.

    ``register_dataset_for_annotation`` reuses an existing project titled
    ``dataset_name``, or creates one with ``label_config`` if none exists.

    Attributes:
        label_config: The label config to use for the annotation interface.
        dataset_name: Name of the dataset to register.
    """

    # Passed as `label_config` to `add_dataset` when a new project is created;
    # ignored when a project with `dataset_name` already exists.
    label_config: str
    # Project title; resolved through `get_id_from_name` before creating.
    dataset_name: str
LabelStudioDatasetSyncParameters (BaseParameters)
pydantic-model
Step parameters when syncing data to Label Studio.
Attributes:
Name | Type | Description |
---|---|---|
storage_type |
str |
The type of storage to sync to. Can be one of ["gcs", "s3", "azure", "local"]. Defaults to "local". |
label_config_type |
str |
The type of label config to use. |
prefix |
Optional[str] |
Specify the prefix within the cloud store to import your data from. For local storage, this is the full absolute path to the directory containing your data. |
regex_filter |
Optional[str] |
Specify a regex filter to filter the files to import. |
use_blob_urls |
Optional[bool] |
Specify whether your data is raw image or video data, or JSON tasks. |
presign |
Optional[bool] |
Specify whether or not to create presigned URLs. |
presign_ttl |
Optional[int] |
Specify how long to keep presigned URLs active. |
description |
Optional[str] |
Specify a description for the dataset. |
azure_account_name |
Optional[str] |
Specify the Azure account name to use for the storage. |
azure_account_key |
Optional[str] |
Specify the Azure account key to use for the storage. |
google_application_credentials |
Optional[str] |
Specify the file with Google application credentials to use for the storage. |
aws_access_key_id |
Optional[str] |
Specify the AWS access key ID to use for the storage. |
aws_secret_access_key |
Optional[str] |
Specify the AWS secret access key to use for the storage. |
aws_session_token |
Optional[str] |
Specify the AWS session token to use for the storage. |
s3_region_name |
Optional[str] |
Specify the S3 region name to use for the storage. |
s3_endpoint |
Optional[str] |
Specify the S3 endpoint to use for the storage. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetSyncParameters(BaseParameters):
    """Step parameters when syncing data to Label Studio.

    Attributes:
        storage_type: The type of storage to sync to. Can be one of
            ["gcs", "s3", "azure", "local"]. Defaults to "local".
        label_config_type: The type of label config to use.
        prefix: Specify the prefix within the cloud store to import your data
            from. For local storage, this is the full absolute path to the
            directory containing your data.
        regex_filter: Specify a regex filter to filter the files to import.
            Defaults to ".*" (all files).
        use_blob_urls: Specify whether your data consists of raw files (e.g.
            image or video data) rather than JSON tasks. Defaults to True.
        presign: Specify whether or not to create presigned URLs. Defaults
            to True.
        presign_ttl: Specify how long to keep presigned URLs active. Defaults
            to 1 (NOTE(review): presumably minutes, following the Label
            Studio storage API -- confirm).
        description: Specify a description for the dataset.
        azure_account_name: Specify the Azure account name to use for the
            storage.
        azure_account_key: Specify the Azure account key to use for the
            storage.
        google_application_credentials: Specify the file with Google application
            credentials to use for the storage.
        aws_access_key_id: Specify the AWS access key ID to use for the
            storage.
        aws_secret_access_key: Specify the AWS secret access key to use for the
            storage.
        aws_session_token: Specify the AWS session token to use for the
            storage.
        s3_region_name: Specify the S3 region name to use for the storage.
        s3_endpoint: Specify the S3 endpoint to use for the storage.
    """

    storage_type: str = "local"
    label_config_type: str

    prefix: Optional[str] = None
    regex_filter: Optional[str] = ".*"
    use_blob_urls: Optional[bool] = True
    presign: Optional[bool] = True
    presign_ttl: Optional[int] = 1
    description: Optional[str] = ""

    # Credentials specific to the main cloud providers. The `= None` defaults
    # are spelled out explicitly: pydantic v1 already applies them implicitly
    # for `Optional` fields, but pydantic v2 would otherwise treat these
    # fields as required.
    azure_account_name: Optional[str] = None
    azure_account_key: Optional[str] = None
    google_application_credentials: Optional[str] = None
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_session_token: Optional[str] = None
    s3_region_name: Optional[str] = None
    s3_endpoint: Optional[str] = None
get_labeled_data (_DecoratedStep)
Gets labeled data from the dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
Name of the dataset. |
required |
Returns:
Type | Description |
---|---|
List of labeled data. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
entrypoint(dataset_name)
staticmethod
Gets labeled data from the dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset_name |
str |
Name of the dataset. |
required |
Returns:
Type | Description |
---|---|
List |
List of labeled data. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_labeled_data(dataset_name: str) -> List:  # type: ignore[type-arg]
    """Gets labeled data from the dataset.

    Resolves the annotator of the active stack, makes sure it is the Label
    Studio annotator and that it is reachable, then returns the labeled tasks
    of the requested dataset.

    Args:
        dataset_name: Name of the dataset.

    Returns:
        List of labeled data.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio.
        StackComponentInterfaceError: If no active annotator could be found,
            or if the annotator cannot be reached.
    """
    # TODO [MEDIUM]: have this check for new data *since the last time this step ran*
    active_annotator = Client().active_stack.annotator
    if not active_annotator:
        raise StackComponentInterfaceError("No active annotator.")

    # Imported lazily so the Label Studio packages are only needed at runtime.
    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(active_annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    # Guard clause: bail out before issuing any request to the annotator.
    if not active_annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    labeled_dataset = active_annotator.get_dataset(dataset_name=dataset_name)
    return labeled_dataset.get_labeled_tasks()  # type: ignore[no-any-return]
get_or_create_dataset (_DecoratedStep)
Gets preexisting dataset or creates a new one.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
Step parameters. |
required |
Returns:
Type | Description |
---|---|
The dataset name. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
entrypoint(params)
staticmethod
Gets preexisting dataset or creates a new one.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params |
LabelStudioDatasetRegistrationParameters |
Step parameters. |
required |
Returns:
Type | Description |
---|---|
str |
The dataset name. |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
StackComponentInterfaceError |
If no active annotator could be found. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_or_create_dataset(
    params: LabelStudioDatasetRegistrationParameters,
) -> str:
    """Gets preexisting dataset or creates a new one.

    Searches the Label Studio datasets for one whose title equals
    `params.dataset_name` and returns that title. If none matches, a new
    dataset is registered from `params` and its title is returned.

    Args:
        params: Step parameters.

    Returns:
        The dataset name.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio.
        StackComponentInterfaceError: If no active annotator could be found,
            or if the annotator cannot be reached.
    """
    annotator = Client().active_stack.annotator
    # Check for a missing annotator before the type check; otherwise
    # `isinstance(None, ...)` fails and a misleading TypeError is raised.
    if not annotator:
        raise StackComponentInterfaceError("No active annotator.")

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    if not annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    for dataset in annotator.get_datasets():
        # Fetch the params once per dataset rather than once per access.
        title = dataset.get_params()["title"]
        if title == params.dataset_name:
            return cast(str, title)

    new_dataset = annotator.register_dataset_for_annotation(params)
    return cast(str, new_dataset.get_params()["title"])
sync_new_data_to_label_studio (_DecoratedStep)
Syncs new data to Label Studio.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
The URI of the data to sync. |
required | |
dataset_name |
The name of the dataset to sync to. |
required | |
predictions |
The predictions to sync. |
required | |
params |
The parameters for the sync. |
required |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
ValueError |
If you are trying to sync data from outside ZenML (i.e. the URI is not inside the active artifact store). |
StackComponentInterfaceError |
If no active annotator could be found. |
entrypoint(uri, dataset_name, predictions, params)
staticmethod
Syncs new data to Label Studio.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
The URI of the data to sync. |
required |
dataset_name |
str |
The name of the dataset to sync to. |
required |
predictions |
List[Dict[str, Any]] |
The predictions to sync. |
required |
params |
LabelStudioDatasetSyncParameters |
The parameters for the sync. |
required |
Exceptions:
Type | Description |
---|---|
TypeError |
If you are trying to use it with an annotator that is not Label Studio. |
ValueError |
If you are trying to sync data from outside ZenML (i.e. the URI is not inside the active artifact store). |
StackComponentInterfaceError |
If no active annotator could be found. |
Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def sync_new_data_to_label_studio(
    uri: str,
    dataset_name: str,
    predictions: List[Dict[str, Any]],
    params: LabelStudioDatasetSyncParameters,
) -> None:
    """Syncs new data to Label Studio.

    Connects the artifact-store location `uri` to the Label Studio dataset
    as external storage and syncs it. If predictions are given, they are
    mapped onto the dataset's tasks and uploaded as pre-annotations.

    Args:
        uri: The URI of the data to sync.
        dataset_name: The name of the dataset to sync to.
        predictions: The predictions to sync.
        params: The parameters for the sync. `params.prefix` is overwritten
            with the path component of `uri`, and the artifact store
            parameters are populated on it.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio.
        ValueError: If you are trying to sync data from outside ZenML (the
            URI is not located in the active artifact store).
        StackComponentInterfaceError: If no active annotator or artifact
            store could be found, or if the annotator cannot be reached.
    """
    stack = Client().active_stack
    annotator = stack.annotator
    artifact_store = stack.artifact_store
    if not annotator or not artifact_store:
        raise StackComponentInterfaceError(
            "An active annotator and artifact store are required to run this step."
        )

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    # Validate the input locally before making any request to the annotator.
    if not uri.startswith(artifact_store.path):
        raise ValueError(
            "ZenML only currently supports syncing data passed from other ZenML steps and via the Artifact Store."
        )

    # Check connectivity before the first query (`get_dataset`) is issued.
    if not annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    dataset = annotator.get_dataset(dataset_name=dataset_name)

    parsed_uri = urlparse(uri)
    # Strip the leading slash(es) so the prefix is relative to the bucket
    # (the netloc) rather than an absolute path.
    params.prefix = parsed_uri.path.lstrip("/")
    base_uri = parsed_uri.netloc

    annotator.populate_artifact_store_parameters(
        params=params, artifact_store=artifact_store
    )

    # TODO: get existing (CHECK!) or create the sync connection
    annotator.connect_and_sync_external_storage(
        uri=base_uri,
        params=params,
        dataset=dataset,
    )

    if predictions:
        filename_reference = TASK_TO_FILENAME_REFERENCE_MAPPING[
            params.label_config_type
        ]
        preds_with_task_ids = convert_pred_filenames_to_task_ids(
            predictions,
            dataset.tasks,
            filename_reference,
            params.storage_type,
        )
        # TODO: filter out any predictions that exist + have already been
        # made (maybe?). Only pass in preds for tasks without pre-annotations.
        dataset.create_predictions(preds_with_task_ids)