Skip to content

Label Studio

zenml.integrations.label_studio special

Initialization of the Label Studio integration.

LabelStudioIntegration (Integration)

Definition of Label Studio integration for ZenML.

Source code in zenml/integrations/label_studio/__init__.py
class LabelStudioIntegration(Integration):
    """Label Studio integration for ZenML.

    Declares the integration name, the pip requirements it depends on and
    the stack component flavors it contributes.
    """

    NAME = LABEL_STUDIO
    REQUIREMENTS = [
        "label-studio>=1.6.0,<=1.7.3",
        "label-studio-sdk>=0.0.17,<=0.0.24",
    ]

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Return the stack component flavors provided by this integration.

        Returns:
            A one-element list holding the Label Studio annotator flavor.
        """
        # The flavor class is imported inside the method, not at module level.
        from zenml.integrations.label_studio.flavors import (
            LabelStudioAnnotatorFlavor,
        )

        provided_flavors: List[Type[Flavor]] = [LabelStudioAnnotatorFlavor]
        return provided_flavors

flavors() classmethod

Declare the stack component flavors for the Label Studio integration.

Returns:

Type Description
List[Type[zenml.stack.flavor.Flavor]]

List of stack component flavors for this integration.

Source code in zenml/integrations/label_studio/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Return the stack component flavors provided by this integration.

    Returns:
        A one-element list holding the Label Studio annotator flavor.
    """
    # The flavor class is imported inside the function, not at module level.
    from zenml.integrations.label_studio.flavors import (
        LabelStudioAnnotatorFlavor,
    )

    provided_flavors: List[Type[Flavor]] = [LabelStudioAnnotatorFlavor]
    return provided_flavors

annotators special

Initialization of the Label Studio annotators submodule.

label_studio_annotator

Implementation of the Label Studio annotation integration.

LabelStudioAnnotator (BaseAnnotator, AuthenticationMixin)

Class to interact with the Label Studio annotation interface.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
class LabelStudioAnnotator(BaseAnnotator, AuthenticationMixin):
    """Class to interact with the Label Studio annotation interface."""

    @property
    def config(self) -> LabelStudioAnnotatorConfig:
        """Expose the component configuration as `LabelStudioAnnotatorConfig`.

        Returns:
            The object stored in `self._config`, cast for type checkers.
        """
        annotator_config = cast(LabelStudioAnnotatorConfig, self._config)
        return annotator_config

    def get_url(self) -> str:
        """Build the base URL of the annotation interface.

        Returns:
            The configured instance URL and port, joined by a colon.
        """
        settings = self.config
        return "{}:{}".format(settings.instance_url, settings.port)

    def get_url_for_dataset(self, dataset_name: str) -> str:
        """Build the URL of the annotation interface for one dataset.

        The dataset name is resolved to its Label Studio project id first.
        If no project carries that name, the id is `None` and the returned
        URL contains the literal text `None`.

        Args:
            dataset_name: The name of the dataset.

        Returns:
            The base URL followed by `/projects/<id>/`.
        """
        dataset_id = self.get_id_from_name(dataset_name)
        base_url = self.get_url()
        return "{}/projects/{}/".format(base_url, dataset_id)

    def get_id_from_name(self, dataset_name: str) -> Optional[int]:
        """Look up the project id of the dataset with the given name.

        Args:
            dataset_name: The name of the dataset.

        Returns:
            The id of the first project whose title equals `dataset_name`,
            or `None` when no project has that title.
        """
        # Lazily walk the projects so the search stops at the first match.
        matching_ids = (
            params["id"]
            for params in (
                project.get_params() for project in self.get_datasets()
            )
            if params["title"] == dataset_name
        )
        return cast(Optional[int], next(matching_ids, None))

    def get_datasets(self) -> List[Any]:
        """Fetch the projects known to the Label Studio client.

        Returns:
            Whatever the client's `get_projects()` call returns, typed as a
            list of datasets.
        """
        client = self._get_client()
        return cast(List[Any], client.get_projects())

    def get_dataset_names(self) -> List[str]:
        """Collect the title of every available dataset.

        Returns:
            The `title` parameter of each dataset, in the order the datasets
            are returned by `get_datasets()`.
        """
        titles: List[str] = []
        for dataset in self.get_datasets():
            titles.append(dataset.get_params()["title"])
        return titles

    def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
        """Gets the statistics of the given dataset.

        Args:
            dataset_name: The exact name (title) of the dataset.

        Returns:
            A tuple containing (labeled_task_count, unlabeled_task_count) for
                the dataset.

        Raises:
            IndexError: If the dataset does not exist.
        """
        for project in self.get_datasets():
            # Compare titles for equality, as `get_id_from_name` does. A
            # substring test would report the stats of a different dataset
            # whose title merely contains `dataset_name` (and an empty name
            # would match the first project).
            if project.get_params()["title"] == dataset_name:
                labeled_task_count = len(project.get_labeled_tasks())
                unlabeled_task_count = len(project.get_unlabeled_tasks())
                return (labeled_task_count, unlabeled_task_count)
        raise IndexError(
            f"Dataset {dataset_name} not found. Please use "
            f"`zenml annotator dataset list` to list all available datasets."
        )

    def launch(self, url: Optional[str]) -> None:
        """Launches the annotation interface in a web browser.

        The browser is only opened when a connection to the Label Studio
        backend can be established; otherwise a warning is logged.

        Args:
            url: The URL of the annotation interface. When empty or `None`,
                the URL returned by `get_url()` is used.
        """
        if not url:
            url = self.get_url()
        if self._connection_available():
            webbrowser.open(url, new=1, autoraise=True)
        else:
            # The trailing space keeps the two adjacent literals from being
            # glued together as "interfacebecause".
            logger.warning(
                "Could not launch annotation interface "
                "because the connection could not be established."
            )

    def _get_client(self) -> Client:
        """Create a Label Studio client from the authentication secret.

        The API key is read from the `api_key` entry of the secret's content
        and the client is pointed at the URL returned by `get_url()`.

        Returns:
            Label Studio client.

        Raises:
            ValueError: when the authentication secret is unavailable or does
                not contain an API key.
        """
        auth_secret = self.get_authentication_secret(ArbitrarySecretSchema)
        if auth_secret:
            token = auth_secret.content.get("api_key")
            if token:
                return Client(url=self.get_url(), api_key=token)
            raise ValueError(
                "Unable to access Label Studio API key from secret."
            )
        raise ValueError(
            "Unable to access predefined secret to access Label Studio API key."
        )

    def _connection_available(self) -> bool:
        """Checks if the connection to the annotation server is available.

        Returns:
            True if the connection is available, False otherwise.
        """
        try:
            result = self._get_client().check_connection()
            return result.get("status") == "UP"  # type: ignore[no-any-return]
        # TODO: [HIGH] refactor to use a more specific exception
        except Exception:
            logger.error(
                "Connection error: No connection was able to be established to the Label Studio backend."
            )
            return False

    def add_dataset(self, **kwargs: Any) -> Any:
        """Create a new Label Studio project for a dataset.

        Args:
            **kwargs: Must contain `dataset_name` (used as the project title)
                and `label_config`; other keys are ignored.

        Returns:
            A Label Studio Project object.

        Raises:
            ValueError: if `dataset_name` or `label_config` is missing or
                empty.
        """
        title = kwargs.get("dataset_name")
        if not title:
            raise ValueError("`dataset_name` keyword argument is required.")

        labeling_config = kwargs.get("label_config")
        if not labeling_config:
            raise ValueError("`label_config` keyword argument is required.")

        client = self._get_client()
        return client.start_project(title=title, label_config=labeling_config)

    def delete_dataset(self, **kwargs: Any) -> None:
        """Delete the Label Studio project that backs a dataset.

        Args:
            **kwargs: Must contain `dataset_name`; other keys are ignored.

        Raises:
            ValueError: If the dataset name is not provided or if no project
                id can be found for it.
        """
        # The client is created up front, before the arguments are validated.
        client = self._get_client()

        name = kwargs.get("dataset_name")
        if not name:
            raise ValueError("`dataset_name` keyword argument is required.")

        project_id = self.get_id_from_name(name)
        if project_id:
            client.delete_project(project_id)
            return
        raise ValueError(
            f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
        )

    def get_dataset(self, **kwargs: Any) -> Any:
        """Fetch the Label Studio project that backs a dataset.

        Args:
            **kwargs: Must contain `dataset_name`; other keys are ignored.

        Returns:
            The LabelStudio Dataset object (a 'Project') for the given name.

        Raises:
            ValueError: If the dataset name is not provided or if no project
                id can be found for it.
        """
        # TODO: check for and raise error if client unavailable
        name = kwargs.get("dataset_name")
        if not name:
            raise ValueError("`dataset_name` keyword argument is required.")

        project_id = self.get_id_from_name(name)
        if project_id:
            client = self._get_client()
            return client.get_project(project_id)
        raise ValueError(
            f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
        )

    def get_converted_dataset(
        self, dataset_name: str, output_format: str
    ) -> Dict[Any, Any]:
        """Export the tasks of a dataset in the requested format.

        Args:
            dataset_name: Name of the dataset, resolved via `get_dataset`.
            output_format: Export type handed to the project's
                `export_tasks` call.

        Returns:
            A dictionary containing the converted dataset.
        """
        dataset = self.get_dataset(dataset_name=dataset_name)
        exported = dataset.export_tasks(export_type=output_format)
        return exported  # type: ignore[no-any-return]

    def get_labeled_data(self, **kwargs: Any) -> Any:
        """Fetch the labeled tasks of a dataset.

        Args:
            **kwargs: Must contain `dataset_name`; other keys are ignored.

        Returns:
            The result of the project's `get_labeled_tasks()` call.

        Raises:
            ValueError: If the dataset name is not provided or if no project
                id can be found for it.
        """
        name = kwargs.get("dataset_name")
        if not name:
            raise ValueError("`dataset_name` keyword argument is required.")

        project_id = self.get_id_from_name(name)
        if project_id:
            project = self._get_client().get_project(project_id)
            return project.get_labeled_tasks()
        raise ValueError(
            f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
        )

    def get_unlabeled_data(self, **kwargs: str) -> Any:
        """Fetch the unlabeled tasks of a dataset.

        Args:
            **kwargs: Must contain `dataset_name`; other keys are ignored.

        Returns:
            The result of the project's `get_unlabeled_tasks()` call.

        Raises:
            ValueError: If the dataset name is not provided or if no project
                id can be found for it.
        """
        name = kwargs.get("dataset_name")
        if not name:
            raise ValueError("`dataset_name` keyword argument is required.")

        project_id = self.get_id_from_name(name)
        if project_id:
            project = self._get_client().get_project(project_id)
            return project.get_unlabeled_tasks()
        raise ValueError(
            f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
        )

    def register_dataset_for_annotation(
        self,
        params: LabelStudioDatasetRegistrationParameters,
    ) -> Any:
        """Return the project for a dataset, creating it if necessary.

        When a project with the name `params.dataset_name` already has an id,
        that project is fetched; otherwise a new one is created from the
        dataset name and label config in `params`.

        Args:
            params: Parameters for the dataset.

        Returns:
            A Label Studio Project object.
        """
        existing_id = self.get_id_from_name(params.dataset_name)
        if not existing_id:
            return self.add_dataset(
                dataset_name=params.dataset_name,
                label_config=params.label_config,
            )
        return self._get_client().get_project(existing_id)

    def _get_azure_import_storage_sources(
        self, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets a list of all Azure import storage sources.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A list of Azure import storage sources.

        Raises:
            ConnectionError: If the connection to the Label Studio backend is unavailable.
        """
        # TODO: check if client actually is connected etc
        query_url = f"/api/storages/azure?project={dataset_id}"
        response = self._get_client().make_request(method="GET", url=query_url)
        if response.status_code == 200:
            return cast(List[Dict[str, Any]], response.json())
        else:
            raise ConnectionError(
                f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
            )

    def _get_gcs_import_storage_sources(
        self, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets a list of all Google Cloud Storage import storage sources.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A list of Google Cloud Storage import storage sources.

        Raises:
            ConnectionError: If the connection to the Label Studio backend is unavailable.
        """
        # TODO: check if client actually is connected etc
        query_url = f"/api/storages/gcs?project={dataset_id}"
        response = self._get_client().make_request(method="GET", url=query_url)
        if response.status_code == 200:
            return cast(List[Dict[str, Any]], response.json())
        else:
            raise ConnectionError(
                f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
            )

    def _get_s3_import_storage_sources(
        self, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets a list of all AWS S3 import storage sources.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A list of AWS S3 import storage sources.

        Raises:
            ConnectionError: If the connection to the Label Studio backend is unavailable.
        """
        # TODO: check if client actually is connected etc
        query_url = f"/api/storages/s3?project={dataset_id}"
        response = self._get_client().make_request(method="GET", url=query_url)
        if response.status_code == 200:
            return cast(List[Dict[str, Any]], response.json())
        else:
            raise ConnectionError(
                f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
            )

    def _storage_source_already_exists(
        self,
        uri: str,
        params: LabelStudioDatasetSyncParameters,
        dataset: Project,
    ) -> bool:
        """Check whether an equivalent import storage source is registered.

        The existing sources for `params.storage_type` are fetched for the
        dataset's project, and a source counts as a match when its presign,
        bucket, regex filter, blob-URL flag, title, description, presign TTL
        and project fields all equal the requested values.

        Args:
            uri: URI of the storage source; compared with the source's
                `bucket` field.
            params: Parameters for the dataset.
            dataset: Label Studio dataset.

        Returns:
            True if the storage source already exists, False otherwise.

        Raises:
            NotImplementedError: If the storage source type is not one of
                "azure", "gcs" or "s3".
        """
        # TODO: check we are already connected
        project_id = int(dataset.get_params()["id"])

        fetchers = {
            "azure": self._get_azure_import_storage_sources,
            "gcs": self._get_gcs_import_storage_sources,
            "s3": self._get_s3_import_storage_sources,
        }
        fetch_sources = fetchers.get(params.storage_type)
        if fetch_sources is None:
            raise NotImplementedError(
                f"Storage type '{params.storage_type}' not implemented."
            )

        for source in fetch_sources(project_id):
            if (
                source.get("presign") == params.presign
                and source.get("bucket") == uri
                and source.get("regex_filter") == params.regex_filter
                and source.get("use_blob_urls") == params.use_blob_urls
                and source.get("title") == dataset.get_params()["title"]
                and source.get("description") == params.description
                and source.get("presign_ttl") == params.presign_ttl
                and source.get("project") == project_id
            ):
                return True
        return False

    def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
        """Fetch a project and return its parsed label config.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            The project's `parsed_label_config` attribute.

        Raises:
            ValueError: If the client returns a falsy value for the given id.
        """
        # TODO: check if client actually is connected etc
        project = self._get_client().get_project(dataset_id)
        if not project:
            raise ValueError("No dataset found for the given id.")
        return cast(Dict[str, Any], project.parsed_label_config)

    def populate_artifact_store_parameters(
        self,
        params: LabelStudioDatasetSyncParameters,
        artifact_store: BaseArtifactStore,
    ) -> None:
        """Populate the dataset sync parameters with the artifact store credentials.

        Sets `params.storage_type` according to the artifact store flavor and
        fills the flavor-specific fields of `params` in place:

        * `s3`: access key id, secret access key and session token, plus the
          endpoint URL and region name when they are present in the store's
          `client_kwargs`.
        * `gcp`: the credentials are written as JSON to a file below the
          global configuration directory and the file path is stored.
        * `azure`: account name and key, taken from the connection string
          when one is configured, otherwise from the explicit account
          credentials.
        * `local`: no credentials; `params.prefix` defaults to the artifact
          store path when unset.

        Args:
            params: The dataset sync parameters.
            artifact_store: The active artifact store.

        Raises:
            RuntimeError: if the artifact store flavor is not supported, if
                the artifact store credentials cannot be fetched or used, if
                the Azure connection string cannot be parsed, or if a given
                local prefix does not start with the artifact store path.
        """
        if artifact_store.flavor == "s3":
            from zenml.integrations.s3.artifact_stores import S3ArtifactStore

            assert isinstance(artifact_store, S3ArtifactStore)

            params.storage_type = "s3"

            (
                aws_access_key_id,
                aws_secret_access_key,
                aws_session_token,
            ) = artifact_store.get_credentials()

            if aws_access_key_id and aws_secret_access_key:
                # Convert the credentials into the format expected by Label
                # Studio
                params.aws_access_key_id = aws_access_key_id
                params.aws_secret_access_key = aws_secret_access_key
                params.aws_session_token = aws_session_token

                if artifact_store.config.client_kwargs:
                    if "endpoint_url" in artifact_store.config.client_kwargs:
                        params.s3_endpoint = (
                            artifact_store.config.client_kwargs["endpoint_url"]
                        )
                    if "region_name" in artifact_store.config.client_kwargs:
                        params.s3_region_name = str(
                            artifact_store.config.client_kwargs["region_name"]
                        )

                return

            raise RuntimeError(
                "No credentials are configured for the active S3 artifact "
                "store. The Label Studio annotator needs explicit credentials "
                "to be configured for your artifact store to sync data "
                "artifacts."
            )

        elif artifact_store.flavor == "gcp":
            from zenml.integrations.gcp.artifact_stores import GCPArtifactStore

            assert isinstance(artifact_store, GCPArtifactStore)

            params.storage_type = "gcs"

            gcp_credentials = artifact_store.get_credentials()

            if gcp_credentials:
                # Save the credentials to a file in secure location, because
                # Label Studio will need to read it from a file
                secret_folder = Path(
                    GlobalConfiguration().config_directory,
                    "label-studio",
                    str(self.id),
                )
                fileio.makedirs(str(secret_folder))
                file_path = Path(
                    secret_folder, "google_application_credentials.json"
                )
                # Create the file with owner-only permissions from the very
                # start, so the secret is never on disk with the default
                # (umask-derived) mode, not even briefly.
                fd = os.open(
                    file_path,
                    os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                    0o600,
                )
                with os.fdopen(fd, "w") as f:
                    f.write(json.dumps(gcp_credentials))
                # The mode passed to `os.open` only applies to newly created
                # files, so also tighten a file left over from an earlier run.
                file_path.chmod(0o600)

                params.google_application_credentials = str(file_path)

                return

            raise RuntimeError(
                "No credentials are configured for the active GCS artifact "
                "store. The Label Studio annotator needs explicit credentials "
                "to be configured for your artifact store to sync data "
                "artifacts."
            )

        elif artifact_store.flavor == "azure":
            from zenml.integrations.azure.artifact_stores import (
                AzureArtifactStore,
            )

            assert isinstance(artifact_store, AzureArtifactStore)

            params.storage_type = "azure"

            azure_credentials = artifact_store.get_credentials()

            if azure_credentials:
                # Convert the credentials into the format expected by Label
                # Studio
                if azure_credentials.connection_string is not None:
                    try:
                        # We need to extract the account name and key from the
                        # connection string. Empty segments (e.g. from a
                        # trailing ';') are skipped so they do not make an
                        # otherwise valid string fail to parse.
                        tokens = azure_credentials.connection_string.split(";")
                        token_dict = dict(
                            token.split("=", maxsplit=1)
                            for token in tokens
                            if token.strip()
                        )
                        params.azure_account_name = token_dict["AccountName"]
                        params.azure_account_key = token_dict["AccountKey"]
                    except (KeyError, ValueError) as e:
                        raise RuntimeError(
                            "The Azure connection string configured for the "
                            "artifact store does not match the expected "
                            "format."
                        ) from e

                    return

                if (
                    azure_credentials.account_name is not None
                    and azure_credentials.account_key is not None
                ):
                    params.azure_account_name = azure_credentials.account_name
                    params.azure_account_key = azure_credentials.account_key

                    return

                raise RuntimeError(
                    "The Label Studio annotator could not use the "
                    "credentials currently configured in the active Azure "
                    "artifact store because it only supports Azure storage "
                    "account credentials. "
                    "Please use Azure storage account credentials for your "
                    "artifact store."
                )

            raise RuntimeError(
                "No credentials are configured for the active Azure artifact "
                "store. The Label Studio annotator needs explicit credentials "
                "to be configured for your artifact store to sync data "
                "artifacts."
            )

        elif artifact_store.flavor == "local":
            from zenml.artifact_stores.local_artifact_store import (
                LocalArtifactStore,
            )

            assert isinstance(artifact_store, LocalArtifactStore)

            params.storage_type = "local"
            if params.prefix is None:
                params.prefix = artifact_store.path
            # The prefix is compared against the store path with its leading
            # slashes removed.
            elif not params.prefix.startswith(artifact_store.path.lstrip("/")):
                raise RuntimeError(
                    "The prefix for the local storage must be a subdirectory "
                    "of the local artifact store path."
                )
            return

        raise RuntimeError(
            f"The active artifact store type '{artifact_store.flavor}' is not "
            "supported by ZenML's Label Studio integration. "
            "Please use one of the supported artifact stores (S3, GCP, "
            "Azure or local)."
        )

    def connect_and_sync_external_storage(
        self,
        uri: str,
        params: LabelStudioDatasetSyncParameters,
        dataset: Project,
    ) -> Optional[Dict[str, Any]]:
        """Syncs the external storage for the given project.

        Connects an import storage of type `params.storage_type` to the
        project and then asks the Label Studio client to sync it. Missing
        cloud credentials only produce a warning; the storage is still
        connected.

        Args:
            uri: URI of the storage source. Used as the Azure container or
                as the GCS / S3 bucket; not used for local storage.
            params: Parameters for the dataset.
            dataset: Label Studio dataset.

        Returns:
            A dictionary containing the sync result.

        Raises:
            ValueError: If the storage type is not supported, or if no
                prefix is given for local storage.
        """
        # TODO: check if proposed storage source has differing / new data
        # if self._storage_source_already_exists(uri, config, dataset):
        #     return None

        storage_connection_args = {
            "prefix": params.prefix,
            "regex_filter": params.regex_filter,
            "use_blob_urls": params.use_blob_urls,
            "presign": params.presign,
            "presign_ttl": params.presign_ttl,
            "title": dataset.get_params()["title"],
            "description": params.description,
        }
        if params.storage_type == "azure":
            if not params.azure_account_name or not params.azure_account_key:
                logger.warning(
                    "Authentication credentials for Azure aren't fully "
                    "provided. Please update the storage synchronization "
                    "settings in the Label Studio web UI as per your needs."
                )
            storage = dataset.connect_azure_import_storage(
                container=uri,
                account_name=params.azure_account_name,
                account_key=params.azure_account_key,
                **storage_connection_args,
            )
        elif params.storage_type == "gcs":
            if not params.google_application_credentials:
                logger.warning(
                    "Authentication credentials for Google Cloud Storage "
                    "aren't fully provided. Please update the storage "
                    "synchronization settings in the Label Studio web UI as "
                    "per your needs."
                )
            storage = dataset.connect_google_import_storage(
                bucket=uri,
                google_application_credentials=params.google_application_credentials,
                **storage_connection_args,
            )
        elif params.storage_type == "s3":
            if (
                not params.aws_access_key_id
                or not params.aws_secret_access_key
            ):
                # The trailing space keeps the sentences from being glued
                # together as "provided.Please".
                logger.warning(
                    "Authentication credentials for S3 aren't fully provided. "
                    "Please update the storage synchronization settings in the "
                    "Label Studio web UI as per your needs."
                )

            # temporary fix using client method until LS supports
            # recursive_scan in their SDK
            # (https://github.com/heartexlabs/label-studio-sdk/pull/130)
            ls_client = self._get_client()
            payload = {
                "bucket": uri,
                "prefix": params.prefix,
                "regex_filter": params.regex_filter,
                "use_blob_urls": params.use_blob_urls,
                "aws_access_key_id": params.aws_access_key_id,
                "aws_secret_access_key": params.aws_secret_access_key,
                "aws_session_token": params.aws_session_token,
                "region_name": params.s3_region_name,
                "s3_endpoint": params.s3_endpoint,
                "presign": params.presign,
                "presign_ttl": params.presign_ttl,
                "title": dataset.get_params()["title"],
                "description": params.description,
                "project": dataset.id,
                "recursive_scan": True,
            }
            response = ls_client.make_request(
                "POST", "/api/storages/s3", json=payload
            )
            storage = response.json()

        elif params.storage_type == "local":
            if not params.prefix:
                raise ValueError(
                    "The 'prefix' parameter is required for local storage "
                    "synchronization."
                )

            # Drop arguments that are not used by the local storage
            storage_connection_args.pop("presign")
            storage_connection_args.pop("presign_ttl")
            storage_connection_args.pop("prefix")

            prefix = params.prefix
            if not prefix.startswith("/"):
                prefix = f"/{prefix}"
            root_path = Path(prefix).parent

            # Set the environment variables required by Label Studio
            # to allow local file serving (see https://labelstud.io/guide/storage.html#Prerequisites-2)
            env_overrides = {
                "LABEL_STUDIO_LOCAL_FILES_SERVING_ENABLED": "true",
                "LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT": str(root_path),
            }
            # Remember what was set before so the caller's environment can be
            # put back exactly as it was.
            previous_env = {
                name: os.environ.get(name) for name in env_overrides
            }
            os.environ.update(env_overrides)
            try:
                storage = dataset.connect_local_import_storage(
                    local_store_path=prefix,
                    **storage_connection_args,
                )
            finally:
                # Runs even when connecting fails, so the temporary variables
                # never leak into the rest of the process.
                for name, old_value in previous_env.items():
                    if old_value is None:
                        os.environ.pop(name, None)
                    else:
                        os.environ[name] = old_value
        else:
            raise ValueError(
                f"Invalid storage type. '{params.storage_type}' is not "
                "supported by ZenML's Label Studio integration. Please choose "
                "between 'azure', 'gcs', 's3' or 'local'."
            )

        synced_storage = self._get_client().sync_storage(
            storage_id=storage["id"], storage_type=storage["type"]
        )
        return cast(Dict[str, Any], synced_storage)
config: LabelStudioAnnotatorConfig property readonly

Returns the LabelStudioAnnotatorConfig config.

Returns:

Type Description
LabelStudioAnnotatorConfig

The configuration.

add_dataset(self, **kwargs)

Registers a dataset for annotation.

Parameters:

Name Type Description Default
**kwargs Any

Additional keyword arguments to pass to the Label Studio client.

{}

Returns:

Type Description
Any

A Label Studio Project object.

Exceptions:

Type Description
ValueError

If 'dataset_name' or 'label_config' isn't provided.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def add_dataset(self, **kwargs: Any) -> Any:
    """Create a new Label Studio project for a dataset.

    Args:
        **kwargs: Must contain `dataset_name` (used as the project title)
            and `label_config`; other keys are ignored.

    Returns:
        A Label Studio Project object.

    Raises:
        ValueError: if `dataset_name` or `label_config` is missing or empty.
    """
    title = kwargs.get("dataset_name")
    if not title:
        raise ValueError("`dataset_name` keyword argument is required.")

    labeling_config = kwargs.get("label_config")
    if not labeling_config:
        raise ValueError("`label_config` keyword argument is required.")

    client = self._get_client()
    return client.start_project(title=title, label_config=labeling_config)
connect_and_sync_external_storage(self, uri, params, dataset)

Syncs the external storage for the given project.

Parameters:

Name Type Description Default
uri str

URI of the storage source.

required
params LabelStudioDatasetSyncParameters

Parameters for the dataset.

required
dataset Project

Label Studio dataset.

required

Returns:

Type Description
Optional[Dict[str, Any]]

A dictionary containing the sync result.

Exceptions:

Type Description
ValueError

If the storage type is not supported.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def connect_and_sync_external_storage(
    self,
    uri: str,
    params: LabelStudioDatasetSyncParameters,
    dataset: Project,
) -> Optional[Dict[str, Any]]:
    """Syncs the external storage for the given project.

    An import storage of the type named by `params.storage_type` is
    connected to the dataset, after which a sync of that storage is
    requested through the Label Studio client.

    Args:
        uri: URI of the storage source. Passed as the container (Azure) or
            bucket (GCS, S3); not used for local storage, where
            `params.prefix` is the path.
        params: Parameters for the dataset.
        dataset: Label Studio dataset.

    Returns:
        A dictionary containing the sync result.

    Raises:
        ValueError: If the storage type is not supported, or if the storage
            type is 'local' and no prefix is provided.
    """
    # TODO: check if proposed storage source has differing / new data
    # if self._storage_source_already_exists(uri, config, dataset):
    #     return None

    # Arguments shared by the SDK-based storage connectors below.
    storage_connection_args = {
        "prefix": params.prefix,
        "regex_filter": params.regex_filter,
        "use_blob_urls": params.use_blob_urls,
        "presign": params.presign,
        "presign_ttl": params.presign_ttl,
        "title": dataset.get_params()["title"],
        "description": params.description,
    }
    if params.storage_type == "azure":
        if not params.azure_account_name or not params.azure_account_key:
            logger.warning(
                "Authentication credentials for Azure aren't fully "
                "provided. Please update the storage synchronization "
                "settings in the Label Studio web UI as per your needs."
            )
        storage = dataset.connect_azure_import_storage(
            container=uri,
            account_name=params.azure_account_name,
            account_key=params.azure_account_key,
            **storage_connection_args,
        )
    elif params.storage_type == "gcs":
        if not params.google_application_credentials:
            logger.warning(
                "Authentication credentials for Google Cloud Storage "
                "aren't fully provided. Please update the storage "
                "synchronization settings in the Label Studio web UI as "
                "per your needs."
            )
        storage = dataset.connect_google_import_storage(
            bucket=uri,
            google_application_credentials=params.google_application_credentials,
            **storage_connection_args,
        )
    elif params.storage_type == "s3":
        if (
            not params.aws_access_key_id
            or not params.aws_secret_access_key
        ):
            logger.warning(
                "Authentication credentials for S3 aren't fully provided. "
                "Please update the storage synchronization settings in the "
                "Label Studio web UI as per your needs."
            )

        # temporary fix using client method until LS supports
        # recursive_scan in their SDK
        # (https://github.com/heartexlabs/label-studio-sdk/pull/130)
        ls_client = self._get_client()
        payload = {
            "bucket": uri,
            "prefix": params.prefix,
            "regex_filter": params.regex_filter,
            "use_blob_urls": params.use_blob_urls,
            "aws_access_key_id": params.aws_access_key_id,
            "aws_secret_access_key": params.aws_secret_access_key,
            "aws_session_token": params.aws_session_token,
            "region_name": params.s3_region_name,
            "s3_endpoint": params.s3_endpoint,
            "presign": params.presign,
            "presign_ttl": params.presign_ttl,
            "title": dataset.get_params()["title"],
            "description": params.description,
            "project": dataset.id,
            "recursive_scan": True,
        }
        response = ls_client.make_request(
            "POST", "/api/storages/s3", json=payload
        )
        storage = response.json()

    elif params.storage_type == "local":
        if not params.prefix:
            raise ValueError(
                "The 'prefix' parameter is required for local storage "
                "synchronization."
            )

        # Drop arguments that are not used by the local storage
        storage_connection_args.pop("presign")
        storage_connection_args.pop("presign_ttl")
        storage_connection_args.pop("prefix")

        # The local store path is always handed over with a leading slash.
        prefix = params.prefix
        if not prefix.startswith("/"):
            prefix = f"/{prefix}"
        root_path = Path(prefix).parent

        # Set the environment variables required by Label Studio
        # to allow local file serving (see https://labelstud.io/guide/storage.html#Prerequisites-2)
        os.environ["LABEL_STUDIO_LOCAL_FILES_SERVING_ENABLED"] = "true"
        os.environ["LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT"] = str(
            root_path
        )

        try:
            storage = dataset.connect_local_import_storage(
                local_store_path=prefix,
                **storage_connection_args,
            )
        finally:
            # Remove the temporary variables even when connecting fails, so
            # a failed call does not leave them set in the process.
            os.environ.pop("LABEL_STUDIO_LOCAL_FILES_SERVING_ENABLED", None)
            os.environ.pop("LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT", None)
    else:
        raise ValueError(
            f"Invalid storage type. '{params.storage_type}' is not "
            "supported by ZenML's Label Studio integration. Please choose "
            "between 'azure', 'gcs', 's3' or 'local'."
        )

    synced_storage = self._get_client().sync_storage(
        storage_id=storage["id"], storage_type=storage["type"]
    )
    return cast(Dict[str, Any], synced_storage)
delete_dataset(self, **kwargs)

Deletes a dataset from the annotation interface.

Parameters:

Name Type Description Default
**kwargs Any

Additional keyword arguments to pass to the Label Studio client.

{}

Exceptions:

Type Description
ValueError

If the dataset name is not provided or if the dataset does not exist.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def delete_dataset(self, **kwargs: Any) -> None:
    """Deletes a dataset from the annotation interface.

    Arguments are validated before the Label Studio client is requested, so
    a missing `dataset_name` is always reported as a `ValueError` rather
    than as whatever error acquiring the client might raise.

    Args:
        **kwargs: Keyword arguments identifying the dataset. `dataset_name`
            is required.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    dataset_name = kwargs.get("dataset_name")
    if not dataset_name:
        raise ValueError("`dataset_name` keyword argument is required.")

    dataset_id = self.get_id_from_name(dataset_name)
    if not dataset_id:
        raise ValueError(
            f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
        )

    # Only ask for the client once there is something to delete.
    self._get_client().delete_project(dataset_id)
get_converted_dataset(self, dataset_name, output_format)

Extract annotated tasks in a specific converted format.

Parameters:

Name Type Description Default
dataset_name str

Name of the dataset.

required
output_format str

Output format.

required

Returns:

Type Description
Dict[Any, Any]

A dictionary containing the converted dataset.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_converted_dataset(
    self, dataset_name: str, output_format: str
) -> Dict[Any, Any]:
    """Export a dataset's tasks in the requested format.

    Args:
        dataset_name: Name of the dataset, as accepted by `get_dataset`.
        output_format: Export type handed to the project's `export_tasks`.

    Returns:
        The export produced by the project's `export_tasks` call.
    """
    dataset = self.get_dataset(dataset_name=dataset_name)
    exported = dataset.export_tasks(export_type=output_format)
    return exported  # type: ignore[no-any-return]
get_dataset(self, **kwargs)

Gets the dataset with the given name.

Parameters:

Name Type Description Default
**kwargs Any

Additional keyword arguments to pass to the Label Studio client.

{}

Returns:

Type Description
Any

The LabelStudio Dataset object (a 'Project') for the given name.

Exceptions:

Type Description
ValueError

If the dataset name is not provided or if the dataset does not exist.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset(self, **kwargs: Any) -> Any:
    """Look up a dataset (a Label Studio 'Project') by its name.

    Args:
        **kwargs: Keyword arguments identifying the dataset. `dataset_name`
            is required.

    Returns:
        The project returned by the Label Studio client for the dataset's
        id.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    # TODO: check for and raise error if client unavailable
    dataset_name = kwargs.get("dataset_name")
    if not dataset_name:
        raise ValueError("`dataset_name` keyword argument is required.")

    # Resolve the name to an id; a falsy id means no project matched.
    project_id = self.get_id_from_name(dataset_name)
    if project_id:
        client = self._get_client()
        return client.get_project(project_id)

    raise ValueError(
        f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
    )
get_dataset_names(self)

Gets the names of the datasets.

Returns:

Type Description
List[str]

A list of dataset names.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_names(self) -> List[str]:
    """Collect the title of every available dataset.

    Returns:
        The `title` entry of each dataset's parameters, in the order the
        datasets are returned by `get_datasets`.
    """
    titles: List[str] = []
    for project in self.get_datasets():
        titles.append(project.get_params()["title"])
    return titles
get_dataset_stats(self, dataset_name)

Gets the statistics of the given dataset.

Parameters:

Name Type Description Default
dataset_name str

The name of the dataset.

required

Returns:

Type Description
Tuple[int, int]

A tuple containing (labeled_task_count, unlabeled_task_count) for the dataset.

Exceptions:

Type Description
IndexError

If the dataset does not exist.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
    """Gets the statistics of the given dataset.

    The dataset is selected by an exact title match, the same rule that
    `get_id_from_name` uses, so a name that is merely a substring of another
    dataset's title no longer picks up that other dataset's statistics.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        A tuple containing (labeled_task_count, unlabeled_task_count) for
            the dataset.

    Raises:
        IndexError: If the dataset does not exist.
    """
    for project in self.get_datasets():
        if project.get_params()["title"] == dataset_name:
            labeled_task_count = len(project.get_labeled_tasks())
            unlabeled_task_count = len(project.get_unlabeled_tasks())
            return (labeled_task_count, unlabeled_task_count)
    raise IndexError(
        f"Dataset {dataset_name} not found. Please use "
        "`zenml annotator dataset list` to list all available datasets."
    )
get_datasets(self)

Gets the datasets currently available for annotation.

Returns:

Type Description
List[Any]

A list of datasets.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_datasets(self) -> List[Any]:
    """Fetch every dataset known to the annotation interface.

    Returns:
        The projects reported by the Label Studio client's `get_projects`.
    """
    return cast(List[Any], self._get_client().get_projects())
get_id_from_name(self, dataset_name)

Gets the ID of the given dataset.

Parameters:

Name Type Description Default
dataset_name str

The name of the dataset.

required

Returns:

Type Description
Optional[int]

The ID of the dataset.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_id_from_name(self, dataset_name: str) -> Optional[int]:
    """Gets the ID of the given dataset.

    Args:
        dataset_name: The name of the dataset. It must match a project
            title exactly.

    Returns:
        The ID of the first dataset whose title equals `dataset_name`, or
        `None` if no dataset has that title.
    """
    for project in self.get_datasets():
        # Fetch the parameters once per project and reuse them for both the
        # title and the id.
        # NOTE(review): presumably each `get_params()` call is a request to
        # the Label Studio server - confirm against the SDK.
        project_params = project.get_params()
        if project_params["title"] == dataset_name:
            return cast(int, project_params["id"])
    return None
get_labeled_data(self, **kwargs)

Gets the labeled data for the given dataset.

Parameters:

Name Type Description Default
**kwargs Any

Additional keyword arguments to pass to the Label Studio client.

{}

Returns:

Type Description
Any

The labeled data.

Exceptions:

Type Description
ValueError

If the dataset name is not provided or if the dataset does not exist.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_labeled_data(self, **kwargs: Any) -> Any:
    """Fetch the tasks of a dataset that have already been labeled.

    Args:
        **kwargs: Keyword arguments identifying the dataset. `dataset_name`
            is required.

    Returns:
        The result of the project's `get_labeled_tasks` call.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    dataset_name = kwargs.get("dataset_name")
    if not dataset_name:
        raise ValueError("`dataset_name` keyword argument is required.")

    # A falsy id means no project carries this title.
    project_id = self.get_id_from_name(dataset_name)
    if not project_id:
        raise ValueError(
            f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
        )

    project = self._get_client().get_project(project_id)
    return project.get_labeled_tasks()
get_parsed_label_config(self, dataset_id)

Returns the parsed Label Studio label config for a dataset.

Parameters:

Name Type Description Default
dataset_id int

Id of the dataset.

required

Returns:

Type Description
Dict[str, Any]

A dictionary containing the parsed label config.

Exceptions:

Type Description
ValueError

If no dataset is found for the given id.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
    """Return the parsed label config of the dataset with the given id.

    Args:
        dataset_id: Id of the dataset.

    Returns:
        The `parsed_label_config` attribute of the project fetched for
        `dataset_id`.

    Raises:
        ValueError: If the client returns a falsy project for the given id.
    """
    # TODO: check if client actually is connected etc
    project = self._get_client().get_project(dataset_id)
    if not project:
        raise ValueError("No dataset found for the given id.")
    return cast(Dict[str, Any], project.parsed_label_config)
get_unlabeled_data(self, **kwargs)

Gets the unlabeled data for the given dataset.

Parameters:

Name Type Description Default
**kwargs str

Additional keyword arguments to pass to the Label Studio client.

{}

Returns:

Type Description
Any

The unlabeled data.

Exceptions:

Type Description
ValueError

If the dataset name is not provided or if the dataset does not exist.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_unlabeled_data(self, **kwargs: Any) -> Any:
    """Gets the unlabeled data for the given dataset.

    Args:
        **kwargs: Keyword arguments identifying the dataset. `dataset_name`
            is required.

    Returns:
        The unlabeled data.

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    dataset_name = kwargs.get("dataset_name")
    if not dataset_name:
        raise ValueError("`dataset_name` keyword argument is required.")

    # A falsy id means no project carries this title.
    dataset_id = self.get_id_from_name(dataset_name)
    if not dataset_id:
        raise ValueError(
            f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
        )
    return self._get_client().get_project(dataset_id).get_unlabeled_tasks()
get_url(self)

Gets the top-level URL of the annotation interface.

Returns:

Type Description
str

The URL of the annotation interface.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url(self) -> str:
    """Build the base URL of the annotation interface.

    Returns:
        The configured instance URL and port joined by a colon.
    """
    settings = self.config
    host = settings.instance_url
    port = settings.port
    return f"{host}:{port}"
get_url_for_dataset(self, dataset_name)

Gets the URL of the annotation interface for the given dataset.

Parameters:

Name Type Description Default
dataset_name str

The name of the dataset.

required

Returns:

Type Description
str

The URL of the annotation interface.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url_for_dataset(self, dataset_name: str) -> str:
    """Build the URL of the annotation interface for one dataset.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        The base URL followed by `/projects/<id>/`, where `<id>` is whatever
        `get_id_from_name` returned for `dataset_name`, interpolated as-is.
    """
    dataset_id = self.get_id_from_name(dataset_name)
    base_url = self.get_url()
    return f"{base_url}/projects/{dataset_id}/"
launch(self, url)

Launches the annotation interface.

Parameters:

Name Type Description Default
url Optional[str]

The URL of the annotation interface.

required
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def launch(self, url: Optional[str] = None) -> None:
    """Launches the annotation interface.

    The URL is opened in a web browser when a connection is available;
    otherwise a warning is logged and nothing is opened.

    Args:
        url: The URL of the annotation interface. When omitted or empty,
            the URL returned by `get_url` is used.
    """
    if not url:
        url = self.get_url()
    if self._connection_available():
        webbrowser.open(url, new=1, autoraise=True)
    else:
        logger.warning(
            "Could not launch annotation interface "
            "because the connection could not be established."
        )
populate_artifact_store_parameters(self, params, artifact_store)

Populate the dataset sync parameters with the artifact store credentials.

Parameters:

Name Type Description Default
params LabelStudioDatasetSyncParameters

The dataset sync parameters.

required
artifact_store BaseArtifactStore

The active artifact store.

required

Exceptions:

Type Description
RuntimeError

if the artifact store credentials cannot be fetched.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def populate_artifact_store_parameters(
    self,
    params: LabelStudioDatasetSyncParameters,
    artifact_store: BaseArtifactStore,
) -> None:
    """Populate the dataset sync parameters with the artifact store credentials.

    `params` is updated in place: its `storage_type` is set according to
    the artifact store flavor and the matching credential fields are filled
    in from the artifact store.

    Args:
        params: The dataset sync parameters.
        artifact_store: The active artifact store.

    Raises:
        RuntimeError: If the artifact store credentials cannot be fetched or
            used, if a local prefix does not match the artifact store path,
            or if the artifact store flavor is not supported.
    """
    if artifact_store.flavor == "s3":
        from zenml.integrations.s3.artifact_stores import S3ArtifactStore

        assert isinstance(artifact_store, S3ArtifactStore)

        params.storage_type = "s3"

        (
            aws_access_key_id,
            aws_secret_access_key,
            aws_session_token,
        ) = artifact_store.get_credentials()

        if aws_access_key_id and aws_secret_access_key:
            # Convert the credentials into the format expected by Label
            # Studio
            params.aws_access_key_id = aws_access_key_id
            params.aws_secret_access_key = aws_secret_access_key
            params.aws_session_token = aws_session_token

            # Carry over a custom endpoint / region when configured.
            if artifact_store.config.client_kwargs:
                if "endpoint_url" in artifact_store.config.client_kwargs:
                    params.s3_endpoint = (
                        artifact_store.config.client_kwargs["endpoint_url"]
                    )
                if "region_name" in artifact_store.config.client_kwargs:
                    params.s3_region_name = str(
                        artifact_store.config.client_kwargs["region_name"]
                    )

            return

        raise RuntimeError(
            "No credentials are configured for the active S3 artifact "
            "store. The Label Studio annotator needs explicit credentials "
            "to be configured for your artifact store to sync data "
            "artifacts."
        )

    elif artifact_store.flavor == "gcp":
        from zenml.integrations.gcp.artifact_stores import GCPArtifactStore

        assert isinstance(artifact_store, GCPArtifactStore)

        params.storage_type = "gcs"

        gcp_credentials = artifact_store.get_credentials()

        if gcp_credentials:
            # Save the credentials to a file in secure location, because
            # Label Studio will need to read it from a file
            secret_folder = Path(
                GlobalConfiguration().config_directory,
                "label-studio",
                str(self.id),
            )
            fileio.makedirs(str(secret_folder))
            file_path = Path(
                secret_folder, "google_application_credentials.json"
            )
            # Restrict the file to its owner before any secret is written,
            # so the credentials are never on disk with wider permissions.
            file_path.touch(mode=0o600, exist_ok=True)
            file_path.chmod(0o600)
            with open(file_path, "w") as f:
                f.write(json.dumps(gcp_credentials))

            params.google_application_credentials = str(file_path)

            return

        raise RuntimeError(
            "No credentials are configured for the active GCS artifact "
            "store. The Label Studio annotator needs explicit credentials "
            "to be configured for your artifact store to sync data "
            "artifacts."
        )

    elif artifact_store.flavor == "azure":
        from zenml.integrations.azure.artifact_stores import (
            AzureArtifactStore,
        )

        assert isinstance(artifact_store, AzureArtifactStore)

        params.storage_type = "azure"

        azure_credentials = artifact_store.get_credentials()

        if azure_credentials:
            # Convert the credentials into the format expected by Label
            # Studio
            if azure_credentials.connection_string is not None:
                try:
                    # We need to extract the account name and key from the
                    # connection string. Empty segments (e.g. from a
                    # trailing ';') are skipped.
                    tokens = azure_credentials.connection_string.split(";")
                    token_dict = dict(
                        token.split("=", maxsplit=1)
                        for token in tokens
                        if token.strip()
                    )
                    params.azure_account_name = token_dict["AccountName"]
                    params.azure_account_key = token_dict["AccountKey"]
                except (KeyError, ValueError) as e:
                    raise RuntimeError(
                        "The Azure connection string configured for the "
                        "artifact store is not in the expected format."
                    ) from e

                return

            if (
                azure_credentials.account_name is not None
                and azure_credentials.account_key is not None
            ):
                params.azure_account_name = azure_credentials.account_name
                params.azure_account_key = azure_credentials.account_key

                return

            raise RuntimeError(
                "The Label Studio annotator could not use the "
                "credentials currently configured in the active Azure "
                "artifact store because it only supports Azure storage "
                "account credentials. "
                "Please use Azure storage account credentials for your "
                "artifact store."
            )

        raise RuntimeError(
            "No credentials are configured for the active Azure artifact "
            "store. The Label Studio annotator needs explicit credentials "
            "to be configured for your artifact store to sync data "
            "artifacts."
        )

    elif artifact_store.flavor == "local":
        from zenml.artifact_stores.local_artifact_store import (
            LocalArtifactStore,
        )

        assert isinstance(artifact_store, LocalArtifactStore)

        params.storage_type = "local"
        if params.prefix is None:
            params.prefix = artifact_store.path
        # NOTE(review): the comparison strips leading slashes from the
        # artifact store path, so a prefix that itself starts with '/' is
        # rejected here - confirm this is intended.
        elif not params.prefix.startswith(artifact_store.path.lstrip("/")):
            raise RuntimeError(
                "The prefix for the local storage must be a subdirectory "
                "of the local artifact store path."
            )
        return

    raise RuntimeError(
        f"The active artifact store type '{artifact_store.flavor}' is not "
        "supported by ZenML's Label Studio integration. "
        "Please use one of the supported artifact stores (S3, GCP, "
        "Azure or local)."
    )
register_dataset_for_annotation(self, params)

Registers a dataset for annotation.

Parameters:

Name Type Description Default
params LabelStudioDatasetRegistrationParameters

Parameters for the dataset.

required

Returns:

Type Description
Any

A Label Studio Project object.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def register_dataset_for_annotation(
    self,
    params: LabelStudioDatasetRegistrationParameters,
) -> Any:
    """Return the dataset named in `params`, creating it if necessary.

    Args:
        params: Parameters for the dataset; `dataset_name` and
            `label_config` are read from it.

    Returns:
        The existing project when one with that name is found, otherwise
        the project newly created through `add_dataset`.
    """
    existing_id = self.get_id_from_name(params.dataset_name)
    if not existing_id:
        # No project has this name yet, so create one.
        return self.add_dataset(
            dataset_name=params.dataset_name,
            label_config=params.label_config,
        )

    return self._get_client().get_project(existing_id)

flavors special

Label Studio integration flavors.

label_studio_annotator_flavor

Label Studio annotator flavor.

LabelStudioAnnotatorConfig (BaseAnnotatorConfig, AuthenticationConfigMixin) pydantic-model

Config for the Label Studio annotator.

Attributes:

Name Type Description
instance_url str

URL of the Label Studio instance.

port int

The port to use for the annotation interface.

Source code in zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py
class LabelStudioAnnotatorConfig(
    BaseAnnotatorConfig, AuthenticationConfigMixin
):
    """Config for the Label Studio annotator.

    Both settings fall back to module-level defaults when not provided.

    Attributes:
        instance_url: URL of the Label Studio instance.
        port: The port to use for the annotation interface.
    """

    # Defaults to DEFAULT_LOCAL_INSTANCE_URL.
    instance_url: str = DEFAULT_LOCAL_INSTANCE_URL
    # Defaults to DEFAULT_LOCAL_LABEL_STUDIO_PORT.
    port: int = DEFAULT_LOCAL_LABEL_STUDIO_PORT
LabelStudioAnnotatorFlavor (BaseAnnotatorFlavor)

Label Studio annotator flavor.

Source code in zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py
class LabelStudioAnnotatorFlavor(BaseAnnotatorFlavor):
    """Flavor describing the Label Studio annotator stack component."""

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The Label Studio annotator flavor name.
        """
        return LABEL_STUDIO_ANNOTATOR_FLAVOR

    @property
    def config_class(self) -> Type[LabelStudioAnnotatorConfig]:
        """Configuration class used by this flavor.

        Returns:
            The `LabelStudioAnnotatorConfig` class.
        """
        return LabelStudioAnnotatorConfig

    @property
    def implementation_class(self) -> Type["LabelStudioAnnotator"]:
        """Class implementing this flavor.

        Returns:
            The `LabelStudioAnnotator` class.
        """
        # Imported here rather than at module level.
        from zenml.integrations.label_studio.annotators import (
            LabelStudioAnnotator,
        )

        return LabelStudioAnnotator

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the documentation for this flavor.

        Returns:
            The default docs URL generated for this flavor.
        """
        return self.generate_default_docs_url()

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation for this flavor.

        Returns:
            The default SDK docs URL generated for this flavor.
        """
        return self.generate_default_sdk_docs_url()

    @property
    def logo_url(self) -> str:
        """Logo shown for this flavor in the dashboard.

        Returns:
            The URL of the flavor logo.
        """
        return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/annotator/label_studio.png"
config_class: Type[zenml.integrations.label_studio.flavors.label_studio_annotator_flavor.LabelStudioAnnotatorConfig] property readonly

Returns LabelStudioAnnotatorConfig config class.

Returns:

Type Description
Type[zenml.integrations.label_studio.flavors.label_studio_annotator_flavor.LabelStudioAnnotatorConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[LabelStudioAnnotator] property readonly

Implementation class for this flavor.

Returns:

Type Description
Type[LabelStudioAnnotator]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

label_config_generators special

Initialization of the Label Studio config generators submodule.

label_config_generators

Implementation of label config generators for Label Studio.

generate_basic_object_detection_bounding_boxes_label_config(labels)

Generates a Label Studio config for object detection with bounding boxes.

This is based on the basic config example shown at https://labelstud.io/templates/image_bbox.html.

Parameters:

Name Type Description Default
labels List[str]

A list of labels to be used in the label config.

required

Returns:

Type Description
Tuple[str, str]

A tuple of the generated label config and the label config type.

Exceptions:

Type Description
ValueError

If no labels are provided.

Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_basic_object_detection_bounding_boxes_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio config for object detection with bounding boxes.

    This is based on the basic config example shown at
    https://labelstud.io/templates/image_bbox.html.

    Label values are XML-escaped before being placed in the config, so a
    label containing `&`, `<`, `>` or `'` still yields well-formed XML.
    Labels without those characters produce the same output as before.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    # Imported locally so the block does not depend on a module-level import.
    from xml.sax.saxutils import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.OBJECT_DETECTION_BOUNDING_BOXES

    label_config_start = """<View>
    <Image name="image" value="$image"/>
    <RectangleLabels name="label" toName="image">
    """
    # The attribute value is single-quoted, so the apostrophe has to be
    # escaped in addition to the characters `escape` handles by default.
    extra_entities = {"'": "&apos;"}
    label_config_choices = "".join(
        f"<Label value='{escape(label, extra_entities)}' />\n"
        for label in labels
    )
    label_config_end = "</RectangleLabels>\n</View>"
    label_config = label_config_start + label_config_choices + label_config_end

    return (
        label_config,
        label_config_type,
    )
generate_basic_ocr_label_config(labels)

Generates a Label Studio config for an optical character recognition (OCR) labeling task.

This is based on the basic config example shown at https://labelstud.io/templates/optical_character_recognition.html

Parameters:

Name Type Description Default
labels List[str]

A list of labels to be used in the label config.

required

Returns:

Type Description
Tuple[str, str]

A tuple of the generated label config and the label config type.

Exceptions:

Type Description
ValueError

If no labels are provided.

Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_basic_ocr_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio config for optical character recognition (OCR) labeling task.

    This is based on the basic config example shown at
    https://labelstud.io/templates/optical_character_recognition.html

    Every label is written into a single-quoted XML attribute, so XML markup
    characters (`&`, `<`, `>`) and single quotes inside a label are escaped to
    keep the generated config well-formed.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    # Imported locally so the function does not depend on a module-level
    # import being present.
    from xml.sax.saxutils import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.OCR

    label_config_start = """
    <View>
    <Image name="image" value="$ocr" zoom="true" zoomControl="true" rotateControl="true"/>
    <View>
    <Filter toName="label" minlength="0" name="filter"/>
    <Labels name="label" toName="image">
    """
    # `escape` covers &, < and >; the extra entity covers the quote character
    # that delimits the attribute value below.
    escaped_labels = [escape(label, {"'": "&apos;"}) for label in labels]
    label_config_choices = "".join(
        f"<Label value='{label}' />\n" for label in escaped_labels
    )

    label_config_end = """
    </Labels>
    </View>
    <Rectangle name="bbox" toName="image" strokeWidth="3"/>
    <Polygon name="poly" toName="image" strokeWidth="3"/>
    <TextArea name="transcription" toName="image" editable="true" perRegion="true" required="true" maxSubmissions="1" rows="5" placeholder="Recognized Text" displayMode="region-list"/>
    </View>
    """
    label_config = label_config_start + label_config_choices + label_config_end

    return (
        label_config,
        label_config_type,
    )
generate_image_classification_label_config(labels)

Generates a Label Studio label config for image classification.

This is based on the basic config example shown at https://labelstud.io/templates/image_classification.html.

Parameters:

Name Type Description Default
labels List[str]

A list of labels to be used in the label config.

required

Returns:

Type Description
Tuple[str, str]

A tuple of the generated label config and the label config type.

Exceptions:

Type Description
ValueError

If no labels are provided.

Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_image_classification_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio label config for image classification.

    This is based on the basic config example shown at
    https://labelstud.io/templates/image_classification.html.

    Every label is written into a single-quoted XML attribute, so XML markup
    characters (`&`, `<`, `>`) and single quotes inside a label are escaped to
    keep the generated config well-formed.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    # Imported locally so the function does not depend on a module-level
    # import being present.
    from xml.sax.saxutils import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.IMAGE_CLASSIFICATION

    label_config_start = """<View>
    <Image name="image" value="$image"/>
    <Choices name="choice" toName="image">
    """
    # `escape` covers &, < and >; the extra entity covers the quote character
    # that delimits the attribute value below.
    escaped_labels = [escape(label, {"'": "&apos;"}) for label in labels]
    label_config_choices = "".join(
        f"<Choice value='{label}' />\n" for label in escaped_labels
    )
    label_config_end = "</Choices>\n</View>"

    label_config = label_config_start + label_config_choices + label_config_end
    return (
        label_config,
        label_config_type,
    )

label_studio_utils

Utility functions for the Label Studio annotator integration.

clean_url(url)

Remove extraneous parts of the URL prior to mapping.

Removes the scheme, netloc and query parts of the URL, and strips any leading slashes from the path. For example, a string like 'gs://label-studio/load_image_data/images/fdbcd451-0c80-495c-a9c5-6b51776f5019/1/0/image_file.JPEG' would become load_image_data/images/fdbcd451-0c80-495c-a9c5-6b51776f5019/1/0/image_file.JPEG.

Parameters:

Name Type Description Default
url str

A URL string.

required

Returns:

Type Description
str

A cleaned URL string.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def clean_url(url: str) -> str:
    """Reduce a URL to its path component, without leading slashes.

    The URL is parsed with `urlparse` and only the path is kept; the scheme,
    netloc, query and fragment are all discarded. For example,
    `'s3://bucket/images/1/0/image_file.JPEG?x=1'` becomes
    `'images/1/0/image_file.JPEG'`. A string in which `urlparse` finds no
    scheme or netloc (e.g. `'gs%3A//bucket/image_file.JPEG'`) is treated as a
    bare path and is returned with only its leading slashes removed.

    Args:
        url: A URL string.

    Returns:
        The path portion of the URL with any leading `/` characters stripped.
    """
    path_only = urlparse(url).path
    return path_only.lstrip("/")

convert_pred_filenames_to_task_ids(preds, tasks, filename_reference, storage_type)

Converts a list of predictions from local file references to task id.

Parameters:

Name Type Description Default
preds List[Dict[str, Any]]

List of predictions.

required
tasks List[Dict[str, Any]]

List of tasks.

required
filename_reference str

Name of the file reference in the predictions.

required
storage_type str

Storage type of the predictions.

required

Returns:

Type Description
List[Dict[str, Any]]

List of predictions using task ids as reference.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def convert_pred_filenames_to_task_ids(
    preds: List[Dict[str, Any]],
    tasks: List[Dict[str, Any]],
    filename_reference: str,
    storage_type: str,
) -> List[Dict[str, Any]]:
    """Swap the filename in each prediction for the id of the matching task.

    Tasks are indexed by the cleaned value of `task["data"][filename_reference]`
    and every prediction's `"filename"` is then looked up in that index.

    Args:
        preds: List of predictions; each needs a `"filename"` and a
            `"result"` entry.
        tasks: List of tasks; each needs an `"id"` and a
            `"data"` mapping containing `filename_reference`.
        filename_reference: Name of the file reference in the predictions.
        storage_type: Storage type of the predictions. `"gcs"` and `"s3"`
            trigger URL-encoding of the prediction filenames before lookup;
            any other value leaves them untouched.

    Returns:
        List of predictions using task ids as reference.
    """
    task_id_by_file: Dict[str, Any] = {}
    for task in tasks:
        file_key = clean_url(task["data"][filename_reference])
        task_id_by_file[file_key] = task["id"]

    # GCS and S3 URL encodes filenames containing spaces, requiring this
    # separate encoding step
    if storage_type == "gcs":
        # Drop the scheme so that the prediction can be matched to the Label
        # Studio task.
        encoded_preds = []
        for pred in preds:
            without_scheme = quote(pred["filename"]).split("//")[1]
            encoded_preds.append(
                {"filename": without_scheme, "result": pred["result"]}
            )
        preds = encoded_preds
    elif storage_type == "s3":
        # S3 URLs look like s3://bucket-name/path/to/file; only the encoded
        # path (without the bucket name) is kept so that the prediction can
        # be matched to the Label Studio task.
        encoded_preds = []
        for pred in preds:
            without_scheme = quote(pred["filename"]).split("//")[1]
            path_segments = without_scheme.split("/")[1:]
            encoded_preds.append(
                {"filename": "/".join(path_segments), "result": pred["result"]}
            )
        preds = encoded_preds

    converted = []
    for pred in preds:
        location = urlparse(pred["filename"])
        # A filename that matches no task raises KeyError here.
        task_id = task_id_by_file[location.netloc + location.path]
        converted.append({"task": int(task_id), "result": pred["result"]})
    return converted

get_file_extension(path_str)

Return the file extension of the given filename.

Parameters:

Name Type Description Default
path_str str

Path to the file.

required

Returns:

Type Description
str

File extension.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def get_file_extension(path_str: str) -> str:
    """Extract the extension from a file path or URL.

    Only the path component of `path_str` is inspected, so a query string or
    fragment does not end up in the result.

    Args:
        path_str: Path to the file.

    Returns:
        The extension including its leading dot, or an empty string when the
        path has none.
    """
    url_path = urlparse(path_str).path
    _, extension = os.path.splitext(url_path)
    return extension

is_azure_url(url)

Return whether the given URL is an Azure URL.

Parameters:

Name Type Description Default
url str

URL to check.

required

Returns:

Type Description
bool

True if the URL is an Azure URL, False otherwise.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_azure_url(url: str) -> bool:
    """Check whether a URL points at Azure Blob Storage.

    Args:
        url: URL to check.

    Returns:
        True if the URL's network location contains
        `blob.core.windows.net`, False otherwise.
    """
    network_location = urlparse(url).netloc
    return "blob.core.windows.net" in network_location

is_gcs_url(url)

Return whether the given URL is a GCS URL.

Parameters:

Name Type Description Default
url str

URL to check.

required

Returns:

Type Description
bool

True if the URL is a GCS URL, False otherwise.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_gcs_url(url: str) -> bool:
    """Check whether a URL points at Google Cloud Storage.

    Args:
        url: URL to check.

    Returns:
        True if the URL's network location contains
        `storage.googleapis.com`, False otherwise.
    """
    network_location = urlparse(url).netloc
    return "storage.googleapis.com" in network_location

is_s3_url(url)

Return whether the given URL is an S3 URL.

Parameters:

Name Type Description Default
url str

URL to check.

required

Returns:

Type Description
bool

True if the URL is an S3 URL, False otherwise.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_s3_url(url: str) -> bool:
    """Check whether a URL points at Amazon S3.

    Args:
        url: URL to check.

    Returns:
        True if the URL's network location contains `s3.amazonaws`,
        False otherwise.
    """
    network_location = urlparse(url).netloc
    return "s3.amazonaws" in network_location

steps special

Standard steps to be used with the Label Studio annotator integration.

label_studio_standard_steps

Implementation of standard steps for the Label Studio annotator integration.

LabelStudioDatasetRegistrationParameters (BaseParameters) pydantic-model

Step parameters when registering a dataset with Label Studio.

Attributes:

Name Type Description
label_config str

The label config to use for the annotation interface.

dataset_name str

Name of the dataset to register.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetRegistrationParameters(BaseParameters):
    """Step parameters when registering a dataset with Label Studio.

    Both fields are required; neither declares a default value.

    Attributes:
        label_config: The label config to use for the annotation interface.
        dataset_name: Name of the dataset to register. The
            `get_or_create_dataset` step compares this against the `title`
            of existing datasets to decide whether a new one is needed.
    """

    # Label config string describing the annotation interface.
    label_config: str
    # Title under which the dataset is looked up or registered.
    dataset_name: str
LabelStudioDatasetSyncParameters (BaseParameters) pydantic-model

Step parameters when syncing data to Label Studio.

Attributes:

Name Type Description
storage_type str

The type of storage to sync to. Can be one of ["gcs", "s3", "azure", "local"]. Defaults to "local".

label_config_type str

The type of label config to use.

prefix Optional[str]

Specify the prefix within the cloud store to import your data from. For local storage, this is the full absolute path to the directory containing your data.

regex_filter Optional[str]

Specify a regex filter to filter the files to import.

use_blob_urls Optional[bool]

Specify whether your data is raw image or video data, or JSON tasks.

presign Optional[bool]

Specify whether or not to create presigned URLs.

presign_ttl Optional[int]

Specify how long to keep presigned URLs active.

description Optional[str]

Specify a description for the dataset.

azure_account_name Optional[str]

Specify the Azure account name to use for the storage.

azure_account_key Optional[str]

Specify the Azure account key to use for the storage.

google_application_credentials Optional[str]

Specify the file with Google application credentials to use for the storage.

aws_access_key_id Optional[str]

Specify the AWS access key ID to use for the storage.

aws_secret_access_key Optional[str]

Specify the AWS secret access key to use for the storage.

aws_session_token Optional[str]

Specify the AWS session token to use for the storage.

s3_region_name Optional[str]

Specify the S3 region name to use for the storage.

s3_endpoint Optional[str]

Specify the S3 endpoint to use for the storage.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetSyncParameters(BaseParameters):
    """Step parameters when syncing data to Label Studio.

    Only `label_config_type` is required. Every cloud credential field
    explicitly defaults to `None`, so the model does not rely on a
    validation-library convention to make an `Optional` field non-required.

    Attributes:
        storage_type: The type of storage to sync to. Can be one of
            ["gcs", "s3", "azure", "local"]. Defaults to "local".
        label_config_type: The type of label config to use.
        prefix: Specify the prefix within the cloud store to import your data
            from. For local storage, this is the full absolute path to the
            directory containing your data. Defaults to None.
        regex_filter: Specify a regex filter to filter the files to import.
            Defaults to ".*".
        use_blob_urls: Specify whether your data is raw image or video data, or
            JSON tasks. Defaults to True.
        presign: Specify whether or not to create presigned URLs. Defaults to
            True.
        presign_ttl: Specify how long to keep presigned URLs active. Defaults
            to 1.
        description: Specify a description for the dataset. Defaults to an
            empty string.

        azure_account_name: Specify the Azure account name to use for the
            storage.
        azure_account_key: Specify the Azure account key to use for the
            storage.
        google_application_credentials: Specify the file with Google application
            credentials to use for the storage.
        aws_access_key_id: Specify the AWS access key ID to use for the
            storage.
        aws_secret_access_key: Specify the AWS secret access key to use for the
            storage.
        aws_session_token: Specify the AWS session token to use for the
            storage.
        s3_region_name: Specify the S3 region name to use for the storage.
        s3_endpoint: Specify the S3 endpoint to use for the storage.
    """

    storage_type: str = "local"
    label_config_type: str

    prefix: Optional[str] = None
    regex_filter: Optional[str] = ".*"
    use_blob_urls: Optional[bool] = True
    presign: Optional[bool] = True
    presign_ttl: Optional[int] = 1
    description: Optional[str] = ""

    # credentials specific to the main cloud providers; each one is optional
    # and explicitly defaults to None, matching `prefix` above
    azure_account_name: Optional[str] = None
    azure_account_key: Optional[str] = None
    google_application_credentials: Optional[str] = None
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_session_token: Optional[str] = None
    s3_region_name: Optional[str] = None
    s3_endpoint: Optional[str] = None
get_labeled_data (_DecoratedStep)

Gets labeled data from the dataset.

Parameters:

Name Type Description Default
dataset_name

Name of the dataset.

required

Returns:

Type Description

List of labeled data.

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

StackComponentInterfaceError

If no active annotator could be found.

entrypoint(dataset_name) staticmethod

Gets labeled data from the dataset.

Parameters:

Name Type Description Default
dataset_name str

Name of the dataset.

required

Returns:

Type Description
List

List of labeled data.

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

StackComponentInterfaceError

If no active annotator could be found.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_labeled_data(dataset_name: str) -> List:  # type: ignore[type-arg]
    """Fetch the labeled tasks of a Label Studio dataset.

    Args:
        dataset_name: Name of the dataset.

    Returns:
        List of labeled data.

    Raises:
        TypeError: If the active annotator is not a Label Studio annotator.
        StackComponentInterfaceError: If the active stack has no annotator,
            or the annotator reports that no connection is available.
    """
    # TODO [MEDIUM]: have this check for new data *since the last time this step ran*
    active_annotator = Client().active_stack.annotator
    if not active_annotator:
        raise StackComponentInterfaceError("No active annotator.")
    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(active_annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    # Bail out early when the annotator cannot be reached.
    if not active_annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    labeled_dataset = active_annotator.get_dataset(dataset_name=dataset_name)
    return labeled_dataset.get_labeled_tasks()  # type: ignore[no-any-return]
get_or_create_dataset (_DecoratedStep)

Gets preexisting dataset or creates a new one.

Parameters:

Name Type Description Default
params

Step parameters.

required

Returns:

Type Description

The dataset name.

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

StackComponentInterfaceError

If no active annotator could be found.

entrypoint(params) staticmethod

Gets preexisting dataset or creates a new one.

Parameters:

Name Type Description Default
params LabelStudioDatasetRegistrationParameters

Step parameters.

required

Returns:

Type Description
str

The dataset name.

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

StackComponentInterfaceError

If no active annotator could be found.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_or_create_dataset(
    params: LabelStudioDatasetRegistrationParameters,
) -> str:
    """Gets preexisting dataset or creates a new one.

    Existing datasets are searched for one whose `title` equals
    `params.dataset_name`; if none is found a new dataset is registered with
    the annotator using `params`.

    Args:
        params: Step parameters.

    Returns:
        The dataset name.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio.
        StackComponentInterfaceError: If no active annotator could be found,
            or the annotator reports that no connection is available.
    """
    annotator = Client().active_stack.annotator
    # Check for a missing annotator first: `isinstance(None, ...)` is False,
    # so doing the type check first would report a misleading TypeError.
    if not annotator:
        raise StackComponentInterfaceError("No active annotator.")

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    if not annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    for dataset in annotator.get_datasets():
        # Read the title once and reuse it for both the check and the return.
        title = dataset.get_params()["title"]
        if title == params.dataset_name:
            return cast(str, title)

    dataset = annotator.register_dataset_for_annotation(params)
    return cast(str, dataset.get_params()["title"])
sync_new_data_to_label_studio (_DecoratedStep)

Syncs new data to Label Studio.

Parameters:

Name Type Description Default
uri

The URI of the data to sync.

required
dataset_name

The name of the dataset to sync to.

required
predictions

The predictions to sync.

required
params

The parameters for the sync.

required

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

ValueError

If you are trying to sync from outside ZenML.

StackComponentInterfaceError

If no active annotator could be found.

entrypoint(uri, dataset_name, predictions, params) staticmethod

Syncs new data to Label Studio.

Parameters:

Name Type Description Default
uri str

The URI of the data to sync.

required
dataset_name str

The name of the dataset to sync to.

required
predictions List[Dict[str, Any]]

The predictions to sync.

required
params LabelStudioDatasetSyncParameters

The parameters for the sync.

required

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

ValueError

If you are trying to sync from outside ZenML.

StackComponentInterfaceError

If no active annotator could be found.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def sync_new_data_to_label_studio(
    uri: str,
    dataset_name: str,
    predictions: List[Dict[str, Any]],
    params: LabelStudioDatasetSyncParameters,
) -> None:
    """Syncs new data to Label Studio.

    The step validates its inputs and the annotator connection before it
    queries the annotator, then connects the external storage described by
    `uri` to the dataset and, when predictions are supplied, uploads them
    keyed by task id.

    Note that `params.prefix` is overwritten with the path component of
    `uri` (without leading slashes).

    Args:
        uri: The URI of the data to sync.
        dataset_name: The name of the dataset to sync to.
        predictions: The predictions to sync.
        params: The parameters for the sync.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio.
        ValueError: If `uri` does not start with the path of the active
            artifact store, i.e. you are trying to sync from outside ZenML.
        StackComponentInterfaceError: If no active annotator or artifact
            store could be found, or the annotator reports that no connection
            is available.
    """
    stack = Client().active_stack
    annotator = stack.annotator
    artifact_store = stack.artifact_store
    if not annotator or not artifact_store:
        raise StackComponentInterfaceError(
            "An active annotator and artifact store are required to run this step."
        )

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    # Pure input validation comes first so that a bad URI fails before any
    # request is made to the annotator.
    if not uri.startswith(artifact_store.path):
        raise ValueError(
            "ZenML only currently supports syncing data passed from other ZenML steps and via the Artifact Store."
        )

    # Verify the connection before the annotator is queried for the dataset.
    if not annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    dataset = annotator.get_dataset(dataset_name=dataset_name)

    parsed_uri = urlparse(uri)
    # the prefix is the URI path without its leading forward slash(es)
    params.prefix = parsed_uri.path.lstrip("/")
    base_uri = parsed_uri.netloc

    annotator.populate_artifact_store_parameters(
        params=params, artifact_store=artifact_store
    )

    # TODO: get existing (CHECK!) or create the sync connection
    annotator.connect_and_sync_external_storage(
        uri=base_uri,
        params=params,
        dataset=dataset,
    )
    if not predictions:
        return

    filename_reference = TASK_TO_FILENAME_REFERENCE_MAPPING[
        params.label_config_type
    ]

    preds_with_task_ids = convert_pred_filenames_to_task_ids(
        predictions,
        dataset.tasks,
        filename_reference,
        params.storage_type,
    )
    # TODO: filter out any predictions that exist + have already been
    # made (maybe?). Only pass in preds for tasks without pre-annotations.
    dataset.create_predictions(preds_with_task_ids)