Skip to content

Label Studio

zenml.integrations.label_studio special

Initialization of the Label Studio integration.

LabelStudioIntegration (Integration)

Definition of Label Studio integration for ZenML.

Source code in zenml/integrations/label_studio/__init__.py
class LabelStudioIntegration(Integration):
    """ZenML integration that provides the Label Studio annotator.

    Pins the `label-studio` server and `label-studio-sdk` client versions
    and exposes the Label Studio annotator flavor to ZenML stacks.
    """

    NAME = LABEL_STUDIO
    REQUIREMENTS = ["label-studio==1.6.0", "label-studio-sdk==0.0.15"]

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Return the stack component flavors contributed by this integration.

        Returns:
            A one-element list containing the Label Studio annotator flavor.
        """
        # Function-scoped import: the flavor module is only loaded when the
        # flavors are actually requested.
        from zenml.integrations.label_studio.flavors import (
            LabelStudioAnnotatorFlavor,
        )

        provided: List[Type[Flavor]] = [LabelStudioAnnotatorFlavor]
        return provided

flavors() classmethod

Declare the stack component flavors for the Label Studio integration.

Returns:

Type Description
List[Type[zenml.stack.flavor.Flavor]]

List of stack component flavors for this integration.

Source code in zenml/integrations/label_studio/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Return the stack component flavors contributed by this integration.

    Returns:
        A one-element list containing the Label Studio annotator flavor.
    """
    # Function-scoped import: the flavor module is only loaded when the
    # flavors are actually requested.
    from zenml.integrations.label_studio.flavors import (
        LabelStudioAnnotatorFlavor,
    )

    provided: List[Type[Flavor]] = [LabelStudioAnnotatorFlavor]
    return provided

annotators special

Initialization of the Label Studio annotators submodule.

label_studio_annotator

Implementation of the Label Studio annotation integration.

LabelStudioAnnotator (BaseAnnotator, AuthenticationMixin)

Class to interact with the Label Studio annotation interface.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
class LabelStudioAnnotator(BaseAnnotator, AuthenticationMixin):
    """Class to interact with the Label Studio annotation interface.

    Datasets are represented as Label Studio projects whose title is the
    dataset name. API calls go through a Label Studio `Client` authenticated
    with the `api_key` stored in the component's authentication secret.
    When `instance_url` equals `DEFAULT_LOCAL_INSTANCE_URL`, the annotator can
    also start/stop a local `label-studio` server as a daemon (daemon mode is
    not supported on Windows).
    """

    @property
    def config(self) -> LabelStudioAnnotatorConfig:
        """Returns the `LabelStudioAnnotatorConfig` config.

        Returns:
            The configuration.
        """
        return cast(LabelStudioAnnotatorConfig, self._config)

    @property
    def validator(self) -> Optional["StackValidator"]:
        """Validates the stack has a secrets manager and a cloud artifact store.

        Returns:
            StackValidator: Validator for the stack.
        """

        def _ensure_cloud_artifact_stores(stack: Stack) -> Tuple[bool, str]:
            # For now this only works on cloud artifact stores.
            return (
                stack.artifact_store.flavor
                in [
                    AZURE_ARTIFACT_STORE_FLAVOR,
                    GCP_ARTIFACT_STORE_FLAVOR,
                    S3_ARTIFACT_STORE_FLAVOR,
                ],
                "Only cloud artifact stores are currently supported",
            )

        return StackValidator(
            required_components={StackComponentType.SECRETS_MANAGER},
            custom_validation_function=_ensure_cloud_artifact_stores,
        )

    def get_url(self) -> str:
        """Gets the top-level URL of the annotation interface.

        Returns:
            The URL of the annotation interface (`<instance_url>:<port>`).
        """
        return f"{self.config.instance_url}:{self.config.port}"

    def get_url_for_dataset(self, dataset_name: str) -> str:
        """Gets the URL of the annotation interface for the given dataset.

        Args:
            dataset_name: The name of the dataset.

        Returns:
            The URL of the dataset's project page. Note: if no project has
            this title, the ID segment of the URL is the string "None".
        """
        project_id = self.get_id_from_name(dataset_name)
        return f"{self.get_url()}/projects/{project_id}/"

    def get_id_from_name(self, dataset_name: str) -> Optional[int]:
        """Gets the ID of the given dataset.

        Args:
            dataset_name: The name of the dataset (exact project title).

        Returns:
            The ID of the first project whose title equals `dataset_name`,
            or None if there is no such project.
        """
        for project in self.get_datasets():
            project_params = project.get_params()
            if project_params["title"] == dataset_name:
                return cast(int, project_params["id"])
        return None

    def get_datasets(self) -> List[Any]:
        """Gets the datasets currently available for annotation.

        Returns:
            A list of datasets (Label Studio projects).
        """
        datasets = self._get_client().get_projects()
        return cast(List[Any], datasets)

    def get_dataset_names(self) -> List[str]:
        """Gets the names of the datasets.

        Returns:
            A list of dataset names (project titles).
        """
        return [
            dataset.get_params()["title"] for dataset in self.get_datasets()
        ]

    def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
        """Gets the statistics of the given dataset.

        Args:
            dataset_name: The name of the dataset (exact project title).

        Returns:
            A tuple containing (labeled_task_count, unlabeled_task_count) for
                the dataset.

        Raises:
            IndexError: If the dataset does not exist.
        """
        for project in self.get_datasets():
            # Exact title match, consistent with `get_id_from_name`; a
            # substring test would wrongly match e.g. "foo" against "foobar".
            if project.get_params()["title"] == dataset_name:
                labeled_task_count = len(project.get_labeled_tasks())
                unlabeled_task_count = len(project.get_unlabeled_tasks())
                return (labeled_task_count, unlabeled_task_count)
        raise IndexError(
            f"Dataset {dataset_name} not found. Please use "
            f"`zenml annotator dataset list` to list all available datasets."
        )

    def launch(self, url: Optional[str]) -> None:
        """Launches the annotation interface in a web browser.

        Args:
            url: The URL to open; defaults to `get_url()` when empty.
        """
        if not url:
            url = self.get_url()
        if self._connection_available():
            webbrowser.open(url, new=1, autoraise=True)
        else:
            logger.warning(
                "Could not launch annotation interface "
                "because the connection could not be established."
            )

    def _get_client(self) -> Client:
        """Gets Label Studio client.

        Returns:
            Label Studio client pointed at `get_url()`.

        Raises:
            ValueError: when unable to access the Label Studio API key.
        """
        secret = self.get_authentication_secret(ArbitrarySecretSchema)
        if not secret:
            # `secret` is falsy here, so report the configured secret name
            # instead. NOTE(review): assumes the name lives on
            # `config.authentication_secret` (AuthenticationMixin) — confirm.
            secret_name = getattr(self.config, "authentication_secret", None)
            raise ValueError(
                f"Unable to access predefined secret '{secret_name}' to "
                "access Label Studio API key."
            )
        api_key = secret.content["api_key"]
        return Client(url=self.get_url(), api_key=api_key)

    def _connection_available(self) -> bool:
        """Checks if the connection to the annotation server is available.

        Returns:
            True if the server reports status "UP", False otherwise
            (including on any error while connecting).
        """
        try:
            result = self._get_client().check_connection()
            return result.get("status") == "UP"  # type: ignore[no-any-return]
        # TODO: [HIGH] refactor to use a more specific exception
        except Exception:
            logger.error(
                "Connection error: No connection was able to be established to the Label Studio backend."
            )
            return False

    def _get_dataset_id(self, dataset_name: Optional[str]) -> int:
        """Resolves a required dataset name to its Label Studio project ID.

        Args:
            dataset_name: The name of the dataset.

        Returns:
            The ID of the matching project.

        Raises:
            ValueError: If no name is given or no project has that title.
        """
        if not dataset_name:
            raise ValueError("`dataset_name` keyword argument is required.")

        dataset_id = self.get_id_from_name(dataset_name)
        if dataset_id is None:
            raise ValueError(
                f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
            )
        return dataset_id

    def add_dataset(self, **kwargs: Any) -> Any:
        """Registers a dataset for annotation by creating a new project.

        Args:
            **kwargs: Must contain `dataset_name` (used as the project title)
                and `label_config`; other keys are ignored.

        Returns:
            A Label Studio Project object.

        Raises:
            ValueError: If `dataset_name` or `label_config` isn't provided.
        """
        dataset_name = kwargs.get("dataset_name")
        label_config = kwargs.get("label_config")
        if not dataset_name:
            raise ValueError("`dataset_name` keyword argument is required.")
        if not label_config:
            raise ValueError("`label_config` keyword argument is required.")

        return self._get_client().start_project(
            title=dataset_name,
            label_config=label_config,
        )

    def delete_dataset(self, **kwargs: Any) -> None:
        """Deletes a dataset from the annotation interface.

        Args:
            **kwargs: Additional keyword arguments to pass to the Label Studio
                client.

        Raises:
            NotImplementedError: Always; deletion is not supported yet.
        """
        # TODO: Awaiting a new Label Studio version to be released with this
        # method. Planned implementation: resolve `dataset_name` with
        # `self._get_dataset_id(...)` and call the client's `delete_project`.
        raise NotImplementedError("Awaiting Label Studio release.")

    def get_dataset(self, **kwargs: Any) -> Any:
        """Gets the dataset with the given name.

        Args:
            **kwargs: Must contain `dataset_name`.

        Returns:
            The LabelStudio Dataset object (a 'Project') for the given name.

        Raises:
            ValueError: If the dataset name is not provided or if the dataset
                does not exist.
        """
        # TODO: check for and raise error if client unavailable
        dataset_id = self._get_dataset_id(kwargs.get("dataset_name"))
        return self._get_client().get_project(dataset_id)

    def get_converted_dataset(
        self, dataset_name: str, output_format: str
    ) -> Dict[Any, Any]:
        """Extract annotated tasks in a specific converted format.

        Args:
            dataset_name: Name of the dataset.
            output_format: Output format, passed to the project's
                `export_tasks` as `export_type`.

        Returns:
            A dictionary containing the converted dataset.
        """
        project = self.get_dataset(dataset_name=dataset_name)
        return project.export_tasks(export_type=output_format)  # type: ignore[no-any-return]

    def get_labeled_data(self, **kwargs: Any) -> Any:
        """Gets the labeled data for the given dataset.

        Args:
            **kwargs: Must contain `dataset_name`.

        Returns:
            The labeled data.

        Raises:
            ValueError: If the dataset name is not provided or if the dataset
                does not exist.
        """
        dataset_id = self._get_dataset_id(kwargs.get("dataset_name"))
        return self._get_client().get_project(dataset_id).get_labeled_tasks()

    def get_unlabeled_data(self, **kwargs: str) -> Any:
        """Gets the unlabeled data for the given dataset.

        Args:
            **kwargs: Must contain `dataset_name`.

        Returns:
            The unlabeled data.

        Raises:
            ValueError: If the dataset name is not provided or if the dataset
                does not exist.
        """
        dataset_id = self._get_dataset_id(kwargs.get("dataset_name"))
        return self._get_client().get_project(dataset_id).get_unlabeled_tasks()

    def register_dataset_for_annotation(
        self,
        params: LabelStudioDatasetRegistrationParameters,
    ) -> Any:
        """Registers a dataset for annotation, reusing an existing project.

        Args:
            params: Parameters for the dataset.

        Returns:
            The existing Label Studio Project with this name, or a newly
            created one.
        """
        project_id = self.get_id_from_name(params.dataset_name)
        if project_id is not None:
            return self._get_client().get_project(project_id)
        return self.add_dataset(
            dataset_name=params.dataset_name,
            label_config=params.label_config,
        )

    def _get_import_storage_sources(
        self, storage_route: str, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Lists import storage sources of one type for a project.

        Args:
            storage_route: Path segment after `/api/storages/` (e.g. "s3").
            dataset_id: Id of the dataset.

        Returns:
            The decoded JSON list of storage sources.

        Raises:
            ConnectionError: If the request does not return HTTP 200.
        """
        # TODO: check if client actually is connected etc
        query_url = f"/api/storages/{storage_route}?project={dataset_id}"
        response = self._get_client().make_request(method="GET", url=query_url)
        if response.status_code != 200:
            raise ConnectionError(
                f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}."
            )
        return cast(List[Dict[str, Any]], response.json())

    def _get_azure_import_storage_sources(
        self, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets a list of all Azure import storage sources.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A list of Azure import storage sources.

        Raises:
            ConnectionError: If the connection to the Label Studio backend is unavailable.
        """
        return self._get_import_storage_sources("azure", dataset_id)

    def _get_gcs_import_storage_sources(
        self, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets a list of all Google Cloud Storage import storage sources.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A list of Google Cloud Storage import storage sources.

        Raises:
            ConnectionError: If the connection to the Label Studio backend is unavailable.
        """
        return self._get_import_storage_sources("gcs", dataset_id)

    def _get_s3_import_storage_sources(
        self, dataset_id: int
    ) -> List[Dict[str, Any]]:
        """Gets a list of all AWS S3 import storage sources.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A list of AWS S3 import storage sources.

        Raises:
            ConnectionError: If the connection to the Label Studio backend is unavailable.
        """
        return self._get_import_storage_sources("s3", dataset_id)

    def _storage_source_already_exists(
        self,
        uri: str,
        params: LabelStudioDatasetSyncParameters,
        dataset: Project,
    ) -> bool:
        """Returns whether an identically configured storage source exists.

        Args:
            uri: URI of the storage source.
            params: Parameters for the dataset.
            dataset: Label Studio dataset.

        Returns:
            True if the storage source already exists, False otherwise.

        Raises:
            NotImplementedError: If the storage source type is not supported.
        """
        # TODO: check we are already connected
        dataset_params = dataset.get_params()
        dataset_id = int(dataset_params["id"])
        fetchers = {
            "azure": self._get_azure_import_storage_sources,
            "gcs": self._get_gcs_import_storage_sources,
            "s3": self._get_s3_import_storage_sources,
        }
        fetch_sources = fetchers.get(params.storage_type)
        if fetch_sources is None:
            raise NotImplementedError(
                f"Storage type '{params.storage_type}' not implemented."
            )
        storage_sources = fetch_sources(dataset_id)
        dataset_title = dataset_params["title"]
        return any(
            (
                source.get("presign") == params.presign
                and source.get("bucket") == uri
                and source.get("regex_filter") == params.regex_filter
                and source.get("use_blob_urls") == params.use_blob_urls
                and source.get("title") == dataset_title
                and source.get("description") == params.description
                and source.get("presign_ttl") == params.presign_ttl
                and source.get("project") == dataset_id
            )
            for source in storage_sources
        )

    def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
        """Returns the parsed Label Studio label config for a dataset.

        Args:
            dataset_id: Id of the dataset.

        Returns:
            A dictionary containing the parsed label config.

        Raises:
            ValueError: If no dataset is found for the given id.
        """
        # TODO: check if client actually is connected etc
        dataset = self._get_client().get_project(dataset_id)
        if dataset:
            return cast(Dict[str, Any], dataset.parsed_label_config)
        raise ValueError("No dataset found for the given id.")

    def connect_and_sync_external_storage(
        self,
        uri: str,
        params: LabelStudioDatasetSyncParameters,
        dataset: Project,
    ) -> Optional[Dict[str, Any]]:
        """Connects an import storage to the project and syncs it.

        Args:
            uri: URI of the storage source (bucket / container).
            params: Parameters for the dataset.
            dataset: Label Studio dataset.

        Returns:
            A dictionary containing the sync result.

        Raises:
            ValueError: If the storage type is not supported.
        """
        # TODO: check if proposed storage source has differing / new data
        # if self._storage_source_already_exists(uri, config, dataset):
        #     return None

        storage_connection_args = {
            "prefix": params.prefix,
            "regex_filter": params.regex_filter,
            "use_blob_urls": params.use_blob_urls,
            "presign": params.presign,
            "presign_ttl": params.presign_ttl,
            "title": dataset.get_params()["title"],
            "description": params.description,
        }
        if params.storage_type == "azure":
            if not params.azure_account_name or not params.azure_account_key:
                logger.warning(
                    "Authentication credentials for Azure aren't fully "
                    "provided. Please update the storage synchronization "
                    "settings in the Label Studio web UI as per your needs."
                )
            storage = dataset.connect_azure_import_storage(
                container=uri,
                account_name=params.azure_account_name,
                account_key=params.azure_account_key,
                **storage_connection_args,
            )
        elif params.storage_type == "gcs":
            if not params.google_application_credentials:
                logger.warning(
                    "Authentication credentials for Google Cloud Storage "
                    "aren't fully provided. Please update the storage "
                    "synchronization settings in the Label Studio web UI as "
                    "per your needs."
                )
            storage = dataset.connect_google_import_storage(
                bucket=uri,
                google_application_credentials=params.google_application_credentials,
                **storage_connection_args,
            )
        elif params.storage_type == "s3":
            if not params.aws_access_key_id or not params.aws_secret_access_key:
                logger.warning(
                    "Authentication credentials for S3 aren't fully provided. "
                    "Please update the storage synchronization settings in the "
                    "Label Studio web UI as per your needs."
                )
            storage = dataset.connect_s3_import_storage(
                bucket=uri,
                aws_access_key_id=params.aws_access_key_id,
                aws_secret_access_key=params.aws_secret_access_key,
                aws_session_token=params.aws_session_token,
                region_name=params.s3_region_name,
                s3_endpoint=params.s3_endpoint,
                **storage_connection_args,
            )
        else:
            # The accepted values are exactly the branches above.
            raise ValueError(
                f"Invalid storage type. '{params.storage_type}' is not supported by ZenML's Label Studio integration. Please choose between 'azure', 'gcs' and 's3'."
            )

        synced_storage = self._get_client().sync_storage(
            storage_id=storage["id"], storage_type=storage["type"]
        )
        return cast(Dict[str, Any], synced_storage)

    @property
    def root_directory(self) -> str:
        """Returns path to the root directory.

        Returns:
            `<global config dir>/annotators/<component id>`.
        """
        return os.path.join(
            io_utils.get_global_config_directory(),
            "annotators",
            str(self.id),
        )

    @property
    def _pid_file_path(self) -> str:
        """Returns path to the daemon PID file.

        Returns:
            Path to the daemon PID file.
        """
        return os.path.join(self.root_directory, "label_studio_daemon.pid")

    @property
    def _log_file(self) -> str:
        """Path of the daemon log file.

        Returns:
            Path to the daemon log file.
        """
        return os.path.join(self.root_directory, "label_studio_daemon.log")

    @property
    def is_provisioned(self) -> bool:
        """If the component provisioned resources to run locally.

        Returns:
            True if the root directory exists.
        """
        return fileio.exists(self.root_directory)

    @property
    def is_running(self) -> bool:
        """If the component is running locally.

        Returns:
            True for remote instances and on Windows (no daemon tracking);
            otherwise whether the local daemon is running.
        """
        if not self.is_local_instance:
            return True

        if sys.platform != "win32":
            from zenml.utils.daemon import check_if_daemon_is_running

            if not check_if_daemon_is_running(self._pid_file_path):
                return False
        # On Windows, daemon functionality is not supported so there is no
        # PID file to check; the platform guard also keeps mypy happy.
        return True

    @property
    def is_local_instance(self) -> bool:
        """Determines if the Label Studio instance is running locally.

        Returns:
            True if `instance_url` equals `DEFAULT_LOCAL_INSTANCE_URL`.
        """
        return self.config.instance_url == DEFAULT_LOCAL_INSTANCE_URL

    def provision(self) -> None:
        """Creates the local root directory for the annotator."""
        fileio.makedirs(self.root_directory)

    def deprovision(self) -> None:
        """Removes the daemon log file if it exists."""
        if fileio.exists(self._log_file):
            fileio.remove(self._log_file)

    def resume(self) -> None:
        """Resumes the annotation interface (starts the local daemon)."""
        if self.is_running:
            logger.info("Local annotation deployment already running.")
            return

        if self.is_local_instance:
            self.start_annotator_daemon()

    def suspend(self) -> None:
        """Suspends the annotation interface (stops the local daemon)."""
        if not self.is_running:
            logger.info("Local annotation server is not running.")
            return

        if self.is_local_instance:
            self.stop_annotator_daemon()

    def start_annotator_daemon(self) -> None:
        """Starts the annotation server backend.

        On Windows only a warning with the manual command is logged.

        Raises:
            ProvisioningError: If the configured port is already occupied.
        """
        command = [
            "label-studio",
            "start",
            "--no-browser",
            "--port",
            f"{self.config.port}",
        ]

        if sys.platform == "win32":
            # One `%s` placeholder, so exactly one argument: passing the port
            # as well made logging fail with a formatting error.
            logger.warning(
                "Daemon functionality not supported on Windows. "
                "In order to access the Label Studio server locally, "
                "please run '%s' in a separate command line shell.",
                " ".join(command),
            )
        elif not networking_utils.port_available(self.config.port):
            raise ProvisioningError(
                f"Unable to port-forward Label Studio to local "
                f"port {self.config.port} because the port is occupied. In order to "
                f"access Label Studio locally, please "
                f"change the configuration to use an available "
                f"port or stop the other process currently using the port."
            )
        else:
            from zenml.utils import daemon

            def _daemon_function() -> None:
                """Runs the `label-studio start` command (blocks until exit)."""
                subprocess.check_call(command)

            daemon.run_as_daemon(
                _daemon_function,
                pid_file=self._pid_file_path,
                log_file=self._log_file,
            )
            logger.info(
                "Started Label Studio daemon (check the daemon logs at `%s` "
                "in case you're not able to access the annotation "
                "interface). Please visit `%s/` to use the Label Studio "
                "interface.",
                self._log_file,
                self.get_url(),
            )

    def stop_annotator_daemon(self) -> None:
        """Stops the annotation server backend and removes its PID file."""
        if fileio.exists(self._pid_file_path):
            if sys.platform == "win32":
                # Daemon functionality is not supported on Windows, so the PID
                # file won't exist. This if clause exists just for mypy to not
                # complain about missing functions
                pass
            else:
                from zenml.utils import daemon

                daemon.stop_daemon(self._pid_file_path)
                fileio.remove(self._pid_file_path)
config: LabelStudioAnnotatorConfig property readonly

Returns the LabelStudioAnnotatorConfig config.

Returns:

Type Description
LabelStudioAnnotatorConfig

The configuration.

is_local_instance: bool property readonly

Determines if the Label Studio instance is running locally.

Returns:

Type Description
bool

True if the component is running locally, False otherwise.

is_provisioned: bool property readonly

If the component provisioned resources to run locally.

Returns:

Type Description
bool

True if the component provisioned resources to run locally.

is_running: bool property readonly

If the component is running locally.

Returns:

Type Description
bool

True if the component is running locally, False otherwise.

root_directory: str property readonly

Returns path to the root directory.

Returns:

Type Description
str

Path to the root directory.

validator: Optional[StackValidator] property readonly

Validates that the stack contains a secrets manager and a cloud (Azure, GCP or S3) artifact store.

Returns:

Type Description
StackValidator

Validator for the stack.

add_dataset(self, **kwargs)

Registers a dataset for annotation.

Parameters:

Name Type Description Default
**kwargs Any

Additional keyword arguments to pass to the Label Studio client.

{}

Returns:

Type Description
Any

A Label Studio Project object.

Exceptions:

Type Description
ValueError

If either 'dataset_name' or 'label_config' isn't provided.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def add_dataset(self, **kwargs: Any) -> Any:
    """Create a new Label Studio project that holds a dataset for annotation.

    Args:
        **kwargs: Must contain ``dataset_name`` (used as the project title)
            and ``label_config`` (the Label Studio label config). Any other
            keys are ignored.

    Returns:
        The Label Studio Project object returned by the client's
        ``start_project`` call.

    Raises:
        ValueError: If ``dataset_name`` or ``label_config`` is missing or
            empty (``dataset_name`` is checked first).
    """
    title = kwargs.get("dataset_name")
    if not title:
        raise ValueError("`dataset_name` keyword argument is required.")

    config = kwargs.get("label_config")
    if not config:
        raise ValueError("`label_config` keyword argument is required.")

    client = self._get_client()
    return client.start_project(title=title, label_config=config)
connect_and_sync_external_storage(self, uri, params, dataset)

Syncs the external storage for the given project.

Parameters:

Name Type Description Default
uri str

URI of the storage source.

required
params LabelStudioDatasetSyncParameters

Parameters for the dataset.

required
dataset Project

Label Studio dataset.

required

Returns:

Type Description
Optional[Dict[str, Any]]

A dictionary containing the sync result.

Exceptions:

Type Description
ValueError

If the storage type is not supported.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def connect_and_sync_external_storage(
    self,
    uri: str,
    params: LabelStudioDatasetSyncParameters,
    dataset: Project,
) -> Optional[Dict[str, Any]]:
    """Connects an import storage to the project and syncs it.

    Args:
        uri: URI of the storage source (bucket / container).
        params: Parameters for the dataset.
        dataset: Label Studio dataset.

    Returns:
        A dictionary containing the sync result.

    Raises:
        ValueError: If the storage type is not one of 'azure', 'gcs', 's3'.
    """
    # TODO: check if proposed storage source has differing / new data
    # if self._storage_source_already_exists(uri, config, dataset):
    #     return None

    storage_connection_args = {
        "prefix": params.prefix,
        "regex_filter": params.regex_filter,
        "use_blob_urls": params.use_blob_urls,
        "presign": params.presign,
        "presign_ttl": params.presign_ttl,
        "title": dataset.get_params()["title"],
        "description": params.description,
    }
    if params.storage_type == "azure":
        if not params.azure_account_name or not params.azure_account_key:
            logger.warning(
                "Authentication credentials for Azure aren't fully "
                "provided. Please update the storage synchronization "
                "settings in the Label Studio web UI as per your needs."
            )
        storage = dataset.connect_azure_import_storage(
            container=uri,
            account_name=params.azure_account_name,
            account_key=params.azure_account_key,
            **storage_connection_args,
        )
    elif params.storage_type == "gcs":
        if not params.google_application_credentials:
            logger.warning(
                "Authentication credentials for Google Cloud Storage "
                "aren't fully provided. Please update the storage "
                "synchronization settings in the Label Studio web UI as "
                "per your needs."
            )
        storage = dataset.connect_google_import_storage(
            bucket=uri,
            google_application_credentials=params.google_application_credentials,
            **storage_connection_args,
        )
    elif params.storage_type == "s3":
        if not params.aws_access_key_id or not params.aws_secret_access_key:
            # Trailing/leading spaces fixed so the concatenated message reads
            # "provided. Please ... the Label Studio web UI".
            logger.warning(
                "Authentication credentials for S3 aren't fully provided. "
                "Please update the storage synchronization settings in the "
                "Label Studio web UI as per your needs."
            )
        storage = dataset.connect_s3_import_storage(
            bucket=uri,
            aws_access_key_id=params.aws_access_key_id,
            aws_secret_access_key=params.aws_secret_access_key,
            aws_session_token=params.aws_session_token,
            region_name=params.s3_region_name,
            s3_endpoint=params.s3_endpoint,
            **storage_connection_args,
        )
    else:
        # The accepted values are exactly the branches above ('s3', not 'aws').
        raise ValueError(
            f"Invalid storage type. '{params.storage_type}' is not supported by ZenML's Label Studio integration. Please choose between 'azure', 'gcs' and 's3'."
        )

    synced_storage = self._get_client().sync_storage(
        storage_id=storage["id"], storage_type=storage["type"]
    )
    return cast(Dict[str, Any], synced_storage)
delete_dataset(self, **kwargs)

Deletes a dataset from the annotation interface.

Parameters:

Name Type Description Default
**kwargs Any

Additional keyword arguments to pass to the Label Studio client.

{}

Exceptions:

Type Description
NotImplementedError

If the deletion of a dataset is not supported.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def delete_dataset(self, **kwargs: Any) -> None:
    """Delete a dataset from the annotation interface.

    Deletion is not available yet, so this method always raises.

    Args:
        **kwargs: Additional keyword arguments to pass to the Label Studio
            client (the planned implementation reads `dataset_name`).

    Raises:
        NotImplementedError: Always, until a Label Studio release with
            project deletion support is available.
    """
    # TODO: Once a Label Studio release supports it, resolve `dataset_name`
    # with `self.get_id_from_name(...)` (raising ValueError when missing)
    # and call `self._get_client().delete_project(dataset_id)`.
    message = "Awaiting Label Studio release."
    raise NotImplementedError(message)
deprovision(self)

Spins down the annotation server backend.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def deprovision(self) -> None:
    """Spin down the annotation server backend by removing its log file."""
    log_path = self._log_file
    if not fileio.exists(log_path):
        return
    fileio.remove(log_path)
get_converted_dataset(self, dataset_name, output_format)

Extract annotated tasks in a specific converted format.

Parameters:

Name Type Description Default
dataset_name str

Name of the dataset.

required
output_format str

Output format.

required

Returns:

Type Description
Dict[Any, Any]

A dictionary containing the converted dataset.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_converted_dataset(
    self, dataset_name: str, output_format: str
) -> Dict[Any, Any]:
    """Extract annotated tasks in a specific converted format.

    Args:
        dataset_name: Name of the dataset (Label Studio project) to export.
        output_format: Output format, forwarded to Label Studio as the
            `export_type` of the task export.

    Returns:
        A dictionary containing the converted dataset.

    Raises:
        ValueError: If no dataset with the given name exists (raised by
            `get_dataset`).
    """
    project = self.get_dataset(dataset_name=dataset_name)
    return project.export_tasks(export_type=output_format)  # type: ignore[no-any-return]
get_dataset(self, **kwargs)

Gets the dataset with the given name.

Parameters:

Name Type Description Default
**kwargs Any

Additional keyword arguments to pass to the Label Studio client.

{}

Returns:

Type Description
Any

The LabelStudio Dataset object (a 'Project') for the given name.

Exceptions:

Type Description
ValueError

If the dataset name is not provided or if the dataset does not exist.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset(self, **kwargs: Any) -> Any:
    """Look up a dataset (Label Studio project) by its name.

    Args:
        **kwargs: Additional keyword arguments to pass to the Label Studio client.
            The `dataset_name` entry is required.

    Returns:
        The LabelStudio Dataset object (a 'Project') for the given name.

    Raises:
        ValueError: If `dataset_name` is missing/empty, or if no project
            with that title exists.
    """
    # TODO: check for and raise error if client unavailable
    name = kwargs.get("dataset_name")
    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(name)
    if project_id:
        client = self._get_client()
        return client.get_project(project_id)
    raise ValueError(
        f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
    )
get_dataset_names(self)

Gets the names of the datasets.

Returns:

Type Description
List[str]

A list of dataset names.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_names(self) -> List[str]:
    """Return the titles of all datasets (Label Studio projects).

    Returns:
        A list of dataset names, in the order the datasets are returned.
    """
    projects = self.get_datasets()
    titles = [proj.get_params()["title"] for proj in projects]
    return titles
get_dataset_stats(self, dataset_name)

Gets the statistics of the given dataset.

Parameters:

Name Type Description Default
dataset_name str

The name of the dataset.

required

Returns:

Type Description
Tuple[int, int]

A tuple containing (labeled_task_count, unlabeled_task_count) for the dataset.

Exceptions:

Type Description
IndexError

If the dataset does not exist.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]:
    """Gets the statistics of the given dataset.

    The dataset is matched by exact title, the same way `get_id_from_name`
    resolves names, so e.g. "cats" does not match a project titled
    "cats_and_dogs".

    Args:
        dataset_name: The name of the dataset.

    Returns:
        A tuple containing (labeled_task_count, unlabeled_task_count) for
            the dataset.

    Raises:
        IndexError: If the dataset does not exist.
    """
    for project in self.get_datasets():
        # Exact comparison: a substring check would pick the first project
        # whose title merely contains the requested name.
        if project.get_params()["title"] == dataset_name:
            labeled_task_count = len(project.get_labeled_tasks())
            unlabeled_task_count = len(project.get_unlabeled_tasks())
            return (labeled_task_count, unlabeled_task_count)
    raise IndexError(
        f"Dataset {dataset_name} not found. Please use "
        f"`zenml annotator dataset list` to list all available datasets."
    )
get_datasets(self)

Gets the datasets currently available for annotation.

Returns:

Type Description
List[Any]

A list of datasets.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_datasets(self) -> List[Any]:
    """List the datasets (Label Studio projects) available for annotation.

    Returns:
        A list of datasets, as returned by the client's `get_projects()`.
    """
    return cast(List[Any], self._get_client().get_projects())
get_id_from_name(self, dataset_name)

Gets the ID of the given dataset.

Parameters:

Name Type Description Default
dataset_name str

The name of the dataset.

required

Returns:

Type Description
Optional[int]

The ID of the dataset.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_id_from_name(self, dataset_name: str) -> Optional[int]:
    """Gets the ID of the given dataset.

    Args:
        dataset_name: The name (project title) of the dataset.

    Returns:
        The ID of the first dataset whose title equals `dataset_name`, or
        `None` if there is no such dataset.
    """
    for project in self.get_datasets():
        # Fetch the params once per project instead of once for the title
        # check and again for the id.
        params = project.get_params()
        if params["title"] == dataset_name:
            return cast(int, params["id"])
    return None
get_labeled_data(self, **kwargs)

Gets the labeled data for the given dataset.

Parameters:

Name Type Description Default
**kwargs Any

Additional keyword arguments to pass to the Label Studio client.

{}

Returns:

Type Description
Any

The labeled data.

Exceptions:

Type Description
ValueError

If the dataset name is not provided or if the dataset does not exist.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_labeled_data(self, **kwargs: Any) -> Any:
    """Fetch the labeled tasks of a dataset.

    Args:
        **kwargs: Additional keyword arguments to pass to the Label Studio client.
            The `dataset_name` entry is required.

    Returns:
        The labeled data (the project's labeled tasks).

    Raises:
        ValueError: If `dataset_name` is missing/empty, or if no project
            with that title exists.
    """
    name = kwargs.get("dataset_name")
    if not name:
        raise ValueError("`dataset_name` keyword argument is required.")

    project_id = self.get_id_from_name(name)
    if not project_id:
        raise ValueError(
            f"Dataset name '{name}' has no corresponding `dataset_id` in Label Studio."
        )
    project = self._get_client().get_project(project_id)
    return project.get_labeled_tasks()
get_parsed_label_config(self, dataset_id)

Returns the parsed Label Studio label config for a dataset.

Parameters:

Name Type Description Default
dataset_id int

Id of the dataset.

required

Returns:

Type Description
Dict[str, Any]

A dictionary containing the parsed label config.

Exceptions:

Type Description
ValueError

If no dataset is found for the given id.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]:
    """Return the parsed Label Studio label config of a dataset.

    Args:
        dataset_id: Id of the dataset (Label Studio project).

    Returns:
        The project's `parsed_label_config` as a dictionary.

    Raises:
        ValueError: If the client returns no (or a falsy) project for the id.
    """
    # TODO: check if client actually is connected etc
    project = self._get_client().get_project(dataset_id)
    if not project:
        raise ValueError("No dataset found for the given id.")
    return cast(Dict[str, Any], project.parsed_label_config)
get_unlabeled_data(self, **kwargs)

Gets the unlabeled data for the given dataset.

Parameters:

Name Type Description Default
**kwargs str

Additional keyword arguments to pass to the Label Studio client.

{}

Returns:

Type Description
Any

The unlabeled data.

Exceptions:

Type Description
ValueError

If the dataset name is not provided or if the dataset does not exist.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_unlabeled_data(self, **kwargs: str) -> Any:
    """Gets the unlabeled data for the given dataset.

    Args:
        **kwargs: Additional keyword arguments to pass to the Label Studio
            client. Must include `dataset_name`, the title of the dataset.

    Returns:
        The unlabeled data (the project's unlabeled tasks).

    Raises:
        ValueError: If the dataset name is not provided or if the dataset
            does not exist.
    """
    dataset_name = kwargs.get("dataset_name")
    if not dataset_name:
        raise ValueError("`dataset_name` keyword argument is required.")

    dataset_id = self.get_id_from_name(dataset_name)
    if not dataset_id:
        raise ValueError(
            f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
        )
    return self._get_client().get_project(dataset_id).get_unlabeled_tasks()
get_url(self)

Gets the top-level URL of the annotation interface.

Returns:

Type Description
str

The URL of the annotation interface.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url(self) -> str:
    """Build the top-level URL of the annotation interface.

    Returns:
        `<instance_url>:<port>` taken from the component config.
    """
    cfg = self.config
    return "{}:{}".format(cfg.instance_url, cfg.port)
get_url_for_dataset(self, dataset_name)

Gets the URL of the annotation interface for the given dataset.

Parameters:

Name Type Description Default
dataset_name str

The name of the dataset.

required

Returns:

Type Description
str

The URL of the annotation interface.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def get_url_for_dataset(self, dataset_name: str) -> str:
    """Gets the URL of the annotation interface for the given dataset.

    Args:
        dataset_name: The name of the dataset.

    Returns:
        The URL of the dataset's project page in the annotation interface.

    Raises:
        ValueError: If no dataset with the given name exists. Previously a
            bogus `.../projects/None/` URL was returned in that case.
    """
    project_id = self.get_id_from_name(dataset_name)
    if project_id is None:
        raise ValueError(
            f"Dataset name '{dataset_name}' has no corresponding `dataset_id` in Label Studio."
        )
    return f"{self.get_url()}/projects/{project_id}/"
launch(self, url)

Launches the annotation interface.

Parameters:

Name Type Description Default
url Optional[str]

The URL of the annotation interface.

required
Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def launch(self, url: Optional[str]) -> None:
    """Launches the annotation interface.

    Opens the URL with `webbrowser.open(..., new=1)`, i.e. in a new browser
    window where possible, but only if a connection to the annotator can be
    established. Otherwise a warning is logged.

    Args:
        url: The URL of the annotation interface. If empty or `None`, the
            top-level URL from `get_url()` is used.
    """
    if not url:
        url = self.get_url()
    if self._connection_available():
        webbrowser.open(url, new=1, autoraise=True)
    else:
        # Note the trailing space: the two literals are concatenated.
        logger.warning(
            "Could not launch annotation interface "
            "because the connection could not be established."
        )
provision(self)

Spins up the annotation server backend.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def provision(self) -> None:
    """Spin up the annotation server backend by creating its root directory."""
    target_dir = self.root_directory
    fileio.makedirs(target_dir)
register_dataset_for_annotation(self, params)

Registers a dataset for annotation.

Parameters:

Name Type Description Default
params LabelStudioDatasetRegistrationParameters

Parameters for the dataset.

required

Returns:

Type Description
Any

A Label Studio Project object.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def register_dataset_for_annotation(
    self,
    params: LabelStudioDatasetRegistrationParameters,
) -> Any:
    """Register a dataset for annotation, reusing an existing one if present.

    If a Label Studio project titled `params.dataset_name` already exists, it
    is returned unchanged. Otherwise a new dataset is created with
    `params.label_config`.

    Args:
        params: Parameters for the dataset.

    Returns:
        A Label Studio Project object.
    """
    existing_id = self.get_id_from_name(params.dataset_name)
    if not existing_id:
        return self.add_dataset(
            dataset_name=params.dataset_name,
            label_config=params.label_config,
        )
    return self._get_client().get_project(existing_id)
resume(self)

Resumes the annotation interface.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def resume(self) -> None:
    """Resume the annotation interface, starting the local daemon if needed."""
    if not self.is_running:
        # Only local deployments are managed as a daemon by ZenML.
        if self.is_local_instance:
            self.start_annotator_daemon()
        return
    logger.info("Local annotation deployment already running.")
start_annotator_daemon(self)

Starts the annotation server backend.

Exceptions:

Type Description
ProvisioningError

If the configured port is already occupied.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def start_annotator_daemon(self) -> None:
    """Starts the annotation server backend.

    Runs `label-studio start --no-browser --port <port>` as a background
    daemon. On Windows, daemons are not supported, so only a warning with the
    command to run manually is logged.

    Raises:
        ProvisioningError: If the configured port is already occupied.
    """
    command = [
        "label-studio",
        "start",
        "--no-browser",
        "--port",
        f"{self.config.port}",
    ]

    if sys.platform == "win32":
        # The message has exactly one placeholder, so pass exactly one arg;
        # an extra argument makes the logging module fail to format it.
        logger.warning(
            "Daemon functionality not supported on Windows. "
            "In order to access the Label Studio server locally, "
            "please run '%s' in a separate command line shell.",
            " ".join(command),
        )
    elif not networking_utils.port_available(self.config.port):
        raise ProvisioningError(
            f"Unable to port-forward Label Studio to local "
            f"port {self.config.port} because the port is occupied. In order to "
            f"access Label Studio locally, please "
            f"change the configuration to use an available "
            f"port or stop the other process currently using the port."
        )
    else:
        from zenml.utils import daemon

        def _daemon_function() -> None:
            """Runs the Label Studio server, blocking until it exits."""
            subprocess.check_call(command)

        daemon.run_as_daemon(
            _daemon_function,
            pid_file=self._pid_file_path,
            log_file=self._log_file,
        )
        logger.info(
            "Started Label Studio daemon (check the daemon "
            "logs at `%s` in case you're not able to access the annotation "
            "interface). Please visit `%s/` to use the Label Studio interface.",
            self._log_file,
            self.get_url(),
        )
stop_annotator_daemon(self)

Stops the annotation server backend.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def stop_annotator_daemon(self) -> None:
    """Stop the annotation server daemon and delete its PID file."""
    pid_file = self._pid_file_path
    if not fileio.exists(pid_file):
        return
    # No daemon (and hence no PID file) is ever created on Windows; the
    # platform check is kept so mypy does not flag the POSIX-only daemon
    # helpers on that platform.
    if sys.platform != "win32":
        from zenml.utils import daemon

        daemon.stop_daemon(pid_file)
        fileio.remove(pid_file)
suspend(self)

Suspends the annotation interface.

Source code in zenml/integrations/label_studio/annotators/label_studio_annotator.py
def suspend(self) -> None:
    """Suspend the annotation interface, stopping the local daemon if any."""
    if self.is_running:
        # Only local deployments are managed as a daemon by ZenML.
        if self.is_local_instance:
            self.stop_annotator_daemon()
        return
    logger.info("Local annotation server is not running.")

flavors special

Label Studio integration flavors.

label_studio_annotator_flavor

Label Studio annotator flavor.

LabelStudioAnnotatorConfig (BaseAnnotatorConfig, AuthenticationConfigMixin) pydantic-model

Config for the Label Studio annotator.

Attributes:

Name Type Description
instance_url str

URL of the Label Studio instance.

port int

The port to use for the annotation interface.

Source code in zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py
class LabelStudioAnnotatorConfig(
    BaseAnnotatorConfig, AuthenticationConfigMixin
):
    """Config for the Label Studio annotator.

    The annotator builds its top-level URL as `<instance_url>:<port>`, and a
    locally started server is launched with `--port <port>`.

    Attributes:
        instance_url: URL of the Label Studio instance. Defaults to
            `DEFAULT_LOCAL_INSTANCE_URL`.
        port: The port to use for the annotation interface. Defaults to
            `DEFAULT_LOCAL_LABEL_STUDIO_PORT`.
    """

    # Presumably given without a port, since `:<port>` is appended to it
    # when the interface URL is built. Confirm for remote deployments.
    instance_url: str = DEFAULT_LOCAL_INSTANCE_URL
    # Port used both in the interface URL and for the local server process.
    port: int = DEFAULT_LOCAL_LABEL_STUDIO_PORT
LabelStudioAnnotatorFlavor (BaseAnnotatorFlavor)

Label Studio annotator flavor.

Source code in zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py
class LabelStudioAnnotatorFlavor(BaseAnnotatorFlavor):
    """Flavor that plugs the Label Studio annotator into ZenML stacks."""

    @property
    def name(self) -> str:
        """Return the identifier under which this flavor is registered.

        Returns:
            The name of the flavor.
        """
        return LABEL_STUDIO_ANNOTATOR_FLAVOR

    @property
    def implementation_class(self) -> Type["LabelStudioAnnotator"]:
        """Return the stack component class that implements this flavor.

        Returns:
            The implementation class.
        """
        # Imported lazily so the flavor can be loaded without importing the
        # annotator implementation module up front.
        from zenml.integrations.label_studio.annotators import (
            LabelStudioAnnotator,
        )

        return LabelStudioAnnotator

    @property
    def config_class(self) -> Type[LabelStudioAnnotatorConfig]:
        """Return the configuration model used by this flavor.

        Returns:
            The config class.
        """
        return LabelStudioAnnotatorConfig
config_class: Type[zenml.integrations.label_studio.flavors.label_studio_annotator_flavor.LabelStudioAnnotatorConfig] property readonly

Returns LabelStudioAnnotatorConfig config class.

Returns:

Type Description
Type[zenml.integrations.label_studio.flavors.label_studio_annotator_flavor.LabelStudioAnnotatorConfig]

The config class.

implementation_class: Type[LabelStudioAnnotator] property readonly

Implementation class for this flavor.

Returns:

Type Description
Type[LabelStudioAnnotator]

The implementation class.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

label_config_generators special

Initialization of the Label Studio config generators submodule.

label_config_generators

Implementation of label config generators for Label Studio.

generate_basic_object_detection_bounding_boxes_label_config(labels)

Generates a Label Studio config for object detection with bounding boxes.

This is based on the basic config example shown at https://labelstud.io/templates/image_bbox.html.

Parameters:

Name Type Description Default
labels List[str]

A list of labels to be used in the label config.

required

Returns:

Type Description
Tuple[str, str]

A tuple of the generated label config and the label config type.

Exceptions:

Type Description
ValueError

If no labels are provided.

Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_basic_object_detection_bounding_boxes_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio config for object detection with bounding boxes.

    This is based on the basic config example shown at
    https://labelstud.io/templates/image_bbox.html.

    Labels are XML-escaped before being placed in the `value` attribute, so
    labels containing characters such as `'`, `&` or `<` still produce a
    well-formed config. Plain alphanumeric labels are emitted unchanged.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    from html import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.OBJECT_DETECTION_BOUNDING_BOXES

    label_config_start = """<View>
    <Image name="image" value="$image"/>
    <RectangleLabels name="label" toName="image">
    """
    # quote=True also escapes the single quote used to delimit the attribute.
    label_config_choices = "".join(
        f"<Label value='{escape(label, quote=True)}' />\n" for label in labels
    )
    label_config_end = "</RectangleLabels>\n</View>"
    label_config = label_config_start + label_config_choices + label_config_end

    return (
        label_config,
        label_config_type,
    )
generate_basic_ocr_label_config(labels)

Generates a Label Studio config for an optical character recognition (OCR) labeling task.

This is based on the basic config example shown at https://labelstud.io/templates/optical_character_recognition.html

Parameters:

Name Type Description Default
labels List[str]

A list of labels to be used in the label config.

required

Returns:

Type Description
Tuple[str, str]

A tuple of the generated label config and the label config type.

Exceptions:

Type Description
ValueError

If no labels are provided.

Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_basic_ocr_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio config for an optical character recognition (OCR) labeling task.

    This is based on the basic config example shown at
    https://labelstud.io/templates/optical_character_recognition.html

    Labels are XML-escaped before being placed in the `value` attribute, so
    labels containing characters such as `'`, `&` or `<` still produce a
    well-formed config. Plain alphanumeric labels are emitted unchanged.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    from html import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.OCR

    label_config_start = """
    <View>
    <Image name="image" value="$ocr" zoom="true" zoomControl="true" rotateControl="true"/>
    <View>
    <Filter toName="label" minlength="0" name="filter"/>
    <Labels name="label" toName="image">
    """
    # quote=True also escapes the single quote used to delimit the attribute.
    label_config_choices = "".join(
        f"<Label value='{escape(label, quote=True)}' />\n" for label in labels
    )

    label_config_end = """
    </Labels>
    </View>
    <Rectangle name="bbox" toName="image" strokeWidth="3"/>
    <Polygon name="poly" toName="image" strokeWidth="3"/>
    <TextArea name="transcription" toName="image" editable="true" perRegion="true" required="true" maxSubmissions="1" rows="5" placeholder="Recognized Text" displayMode="region-list"/>
    </View>
    """
    label_config = label_config_start + label_config_choices + label_config_end

    return (
        label_config,
        label_config_type,
    )
generate_image_classification_label_config(labels)

Generates a Label Studio label config for image classification.

This is based on the basic config example shown at https://labelstud.io/templates/image_classification.html.

Parameters:

Name Type Description Default
labels List[str]

A list of labels to be used in the label config.

required

Returns:

Type Description
Tuple[str, str]

A tuple of the generated label config and the label config type.

Exceptions:

Type Description
ValueError

If no labels are provided.

Source code in zenml/integrations/label_studio/label_config_generators/label_config_generators.py
def generate_image_classification_label_config(
    labels: List[str],
) -> Tuple[str, str]:
    """Generates a Label Studio label config for image classification.

    This is based on the basic config example shown at
    https://labelstud.io/templates/image_classification.html.

    Labels are XML-escaped before being placed in the `value` attribute, so
    labels containing characters such as `'`, `&` or `<` still produce a
    well-formed config. Plain alphanumeric labels are emitted unchanged.

    Args:
        labels: A list of labels to be used in the label config.

    Returns:
        A tuple of the generated label config and the label config type.

    Raises:
        ValueError: If no labels are provided.
    """
    from html import escape

    if not labels:
        raise ValueError("No labels provided")

    label_config_type = AnnotationTasks.IMAGE_CLASSIFICATION

    label_config_start = """<View>
    <Image name="image" value="$image"/>
    <Choices name="choice" toName="image">
    """
    # quote=True also escapes the single quote used to delimit the attribute.
    label_config_choices = "".join(
        f"<Choice value='{escape(label, quote=True)}' />\n" for label in labels
    )
    label_config_end = "</Choices>\n</View>"

    label_config = label_config_start + label_config_choices + label_config_end
    return (
        label_config,
        label_config_type,
    )

label_studio_utils

Utility functions for the Label Studio annotator integration.

convert_pred_filenames_to_task_ids(preds, tasks, filename_reference, storage_type)

Converts a list of predictions from local file references to task id.

Parameters:

Name Type Description Default
preds List[Dict[str, Any]]

List of predictions.

required
tasks List[Dict[str, Any]]

List of tasks.

required
filename_reference str

Name of the file reference in the predictions.

required
storage_type str

Storage type of the predictions.

required

Returns:

Type Description
List[Dict[str, Any]]

List of predictions using task ids as reference.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def convert_pred_filenames_to_task_ids(
    preds: List[Dict[str, Any]],
    tasks: List[Dict[str, Any]],
    filename_reference: str,
    storage_type: str,
) -> List[Dict[str, Any]]:
    """Re-key predictions from file references to Label Studio task ids.

    Each task is indexed by the basename of the URL path stored under
    `task["data"][filename_reference]`. Each prediction is then matched to a
    task by the basename of its `"filename"`.

    Args:
        preds: List of predictions, each with `"filename"` and `"result"`.
        tasks: List of tasks, each with `"id"` and a `"data"` mapping.
        filename_reference: Name of the file reference in the predictions.
        storage_type: Storage type of the predictions.

    Returns:
        List of predictions using task ids as reference, each of the form
        `{"task": <int id>, "result": <result>}`.
    """
    task_id_by_filename = {}
    for task in tasks:
        task_path = urlparse(task["data"][filename_reference]).path
        task_id_by_filename[os.path.basename(task_path)] = task["id"]

    # GCS and S3 task URLs carry URL-encoded filenames (e.g. spaces become
    # %20), so the prediction filenames must be encoded the same way.
    needs_encoding = storage_type in ("gcs", "s3")
    normalized = (
        [{"filename": quote(p["filename"]), "result": p["result"]} for p in preds]
        if needs_encoding
        else preds
    )

    converted: List[Dict[str, Any]] = []
    for pred in normalized:
        key = os.path.basename(pred["filename"])
        converted.append(
            {"task": int(task_id_by_filename[key]), "result": pred["result"]}
        )
    return converted

get_file_extension(path_str)

Return the file extension of the given filename.

Parameters:

Name Type Description Default
path_str str

Path to the file.

required

Returns:

Type Description
str

File extension.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def get_file_extension(path_str: str) -> str:
    """Return the extension of the file a path or URL points to.

    Only the URL path component is considered, so query strings and
    fragments are ignored.

    Args:
        path_str: Path to the file.

    Returns:
        File extension including the leading dot, or "" if there is none.
    """
    parsed_path = urlparse(path_str).path
    _, extension = os.path.splitext(parsed_path)
    return extension

is_azure_url(url)

Return whether the given URL is an Azure URL.

Parameters:

Name Type Description Default
url str

URL to check.

required

Returns:

Type Description
bool

True if the URL is an Azure URL, False otherwise.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_azure_url(url: str) -> bool:
    """Check whether a URL points at Azure Blob Storage.

    Args:
        url: URL to check.

    Returns:
        True if the URL's host contains "blob.core.windows.net", False
        otherwise.
    """
    host = urlparse(url).netloc
    return host.find("blob.core.windows.net") != -1

is_gcs_url(url)

Return whether the given URL is a GCS URL.

Parameters:

Name Type Description Default
url str

URL to check.

required

Returns:

Type Description
bool

True if the URL is a GCS URL, False otherwise.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_gcs_url(url: str) -> bool:
    """Check whether a URL is a Google Cloud Storage (GCS) URL.

    Args:
        url: URL to check.

    Returns:
        True if the URL's host contains "storage.googleapis.com", False
        otherwise.
    """
    host = urlparse(url).netloc
    return host.find("storage.googleapis.com") != -1

is_s3_url(url)

Return whether the given URL is an S3 URL.

Parameters:

Name Type Description Default
url str

URL to check.

required

Returns:

Type Description
bool

True if the URL is an S3 URL, False otherwise.

Source code in zenml/integrations/label_studio/label_studio_utils.py
def is_s3_url(url: str) -> bool:
    """Check whether a URL is an Amazon S3 URL.

    Args:
        url: URL to check.

    Returns:
        True if the URL's host contains "s3.amazonaws", False otherwise.
    """
    host = urlparse(url).netloc
    return host.find("s3.amazonaws") != -1

steps special

Standard steps to be used with the Label Studio annotator integration.

label_studio_standard_steps

Implementation of standard steps for the Label Studio annotator integration.

LabelStudioDatasetRegistrationParameters (BaseParameters) pydantic-model

Step parameters when registering a dataset with Label Studio.

Attributes:

Name Type Description
label_config str

The label config to use for the annotation interface.

dataset_name str

Name of the dataset to register.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetRegistrationParameters(BaseParameters):
    """Step parameters when registering a dataset with Label Studio.

    When a dataset is registered, an existing Label Studio project whose
    title equals `dataset_name` is reused. Otherwise a new one is created
    with `label_config`.

    Attributes:
        label_config: The label config to use for the annotation interface
            (e.g. a string produced by one of the label config generators).
        dataset_name: Name of the dataset to register; used as the Label
            Studio project title.
    """

    # Only used when a new project has to be created.
    label_config: str
    # Matched exactly against existing project titles.
    dataset_name: str
LabelStudioDatasetSyncParameters (BaseParameters) pydantic-model

Step parameters when syncing data to Label Studio.

Attributes:

Name Type Description
storage_type str

The type of storage to sync to.

label_config_type str

The type of label config to use.

prefix Optional[str]

Specify the prefix within the cloud store to import your data from.

regex_filter Optional[str]

Specify a regex filter to filter the files to import.

use_blob_urls Optional[bool]

Specify whether your data is raw image or video data, or JSON tasks.

presign Optional[bool]

Specify whether or not to create presigned URLs.

presign_ttl Optional[int]

Specify how long to keep presigned URLs active.

description Optional[str]

Specify a description for the dataset.

azure_account_name Optional[str]

Specify the Azure account name to use for the storage.

azure_account_key Optional[str]

Specify the Azure account key to use for the storage.

google_application_credentials Optional[str]

Specify the Google application credentials to use for the storage.

aws_access_key_id Optional[str]

Specify the AWS access key ID to use for the storage.

aws_secret_access_key Optional[str]

Specify the AWS secret access key to use for the storage.

aws_session_token Optional[str]

Specify the AWS session token to use for the storage.

s3_region_name Optional[str]

Specify the S3 region name to use for the storage.

s3_endpoint Optional[str]

Specify the S3 endpoint to use for the storage.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetSyncParameters(BaseParameters):
    """Step parameters when syncing data to Label Studio.

    Attributes:
        storage_type: The type of storage to sync to (the sync step handles
            credentials for `azure`, `gcs` and `s3`).
        label_config_type: The type of label config to use.
        prefix: Specify the prefix within the cloud store to import your data
            from.
        regex_filter: Specify a regex filter to filter the files to import.
        use_blob_urls: Specify whether your data is raw image or video data, or
            JSON tasks.
        presign: Specify whether or not to create presigned URLs.
        presign_ttl: Specify how long to keep presigned URLs active.
        description: Specify a description for the dataset.
        azure_account_name: Specify the Azure account name to use for the
            storage.
        azure_account_key: Specify the Azure account key to use for the
            storage.
        google_application_credentials: Specify the Google application
            credentials to use for the storage.
        aws_access_key_id: Specify the AWS access key ID to use for the
            storage.
        aws_secret_access_key: Specify the AWS secret access key to use for the
            storage.
        aws_session_token: Specify the AWS session token to use for the
            storage.
        s3_region_name: Specify the S3 region name to use for the storage.
        s3_endpoint: Specify the S3 endpoint to use for the storage.
    """

    storage_type: str
    label_config_type: str

    prefix: Optional[str] = None
    regex_filter: Optional[str] = ".*"
    use_blob_urls: Optional[bool] = True
    presign: Optional[bool] = True
    presign_ttl: Optional[int] = 1
    description: Optional[str] = ""

    # Credentials specific to the main cloud providers. They default to None
    # explicitly: pydantic v1 already treats a bare `Optional[...]` field as
    # defaulting to None, but pydantic v2 would make such fields required.
    # The sync step fills in the ones matching `storage_type` at runtime.
    azure_account_name: Optional[str] = None
    azure_account_key: Optional[str] = None
    google_application_credentials: Optional[str] = None
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_session_token: Optional[str] = None
    s3_region_name: Optional[str] = None
    s3_endpoint: Optional[str] = None
get_labeled_data (BaseStep)

Gets labeled data from the dataset.

Parameters:

Name Type Description Default
dataset_name

Name of the dataset.

required
context

The StepContext.

required

Returns:

Type Description

List of labeled data.

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

StackComponentInterfaceError

If no active annotator could be found.

entrypoint(dataset_name, context) staticmethod

Gets labeled data from the dataset.

Parameters:

Name Type Description Default
dataset_name str

Name of the dataset.

required
context StepContext

The StepContext.

required

Returns:

Type Description
List

List of labeled data.

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

StackComponentInterfaceError

If no active annotator could be found.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_labeled_data(dataset_name: str, context: StepContext) -> List:  # type: ignore[type-arg]
    """Fetch the labeled tasks of a Label Studio dataset.

    Args:
        dataset_name: Name of the Label Studio dataset to read from.
        context: The StepContext, used to reach the active stack's annotator.

    Returns:
        List of labeled data, as returned by the dataset's
        `get_labeled_tasks()`.

    Raises:
        TypeError: If the active annotator is not a Label Studio annotator.
        StackComponentInterfaceError: If there is no active annotator, or if
            the annotator reports that its connection is unavailable.
    """
    # TODO [MEDIUM]: have this check for new data *since the last time this step ran*
    active_annotator = context.stack.annotator  # type: ignore[union-attr]
    if not active_annotator:
        raise StackComponentInterfaceError("No active annotator.")

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(active_annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    is_connected = active_annotator._connection_available()
    if not is_connected:
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    labeled_dataset = active_annotator.get_dataset(dataset_name=dataset_name)
    return labeled_dataset.get_labeled_tasks()  # type: ignore[no-any-return]
get_or_create_dataset (BaseStep)

Gets preexisting dataset or creates a new one.

Parameters:

Name Type Description Default
params

Step parameters.

required
context

Step context.

required

Returns:

Type Description

The dataset name.

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

StackComponentInterfaceError

If no active annotator could be found.

PARAMETERS_CLASS (BaseParameters) pydantic-model

Step parameters when registering a dataset with Label Studio.

Attributes:

Name Type Description
label_config str

The label config to use for the annotation interface.

dataset_name str

Name of the dataset to register.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetRegistrationParameters(BaseParameters):
    """Step parameters when registering a dataset with Label Studio.

    Attributes:
        label_config: The label config to use for the annotation interface.
        dataset_name: Name of the dataset to register. `get_or_create_dataset`
            compares this against the `title` of existing Label Studio
            datasets to decide whether a new one must be registered.
    """

    # Label config for the annotation interface of the dataset.
    label_config: str
    # Dataset name; used as the lookup key against existing dataset titles.
    dataset_name: str
entrypoint(params, context) staticmethod

Gets preexisting dataset or creates a new one.

Parameters:

Name Type Description Default
params LabelStudioDatasetRegistrationParameters

Step parameters.

required
context StepContext

Step context.

required

Returns:

Type Description
str

The dataset name.

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

StackComponentInterfaceError

If no active annotator could be found.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
@step(enable_cache=False)
def get_or_create_dataset(
    params: LabelStudioDatasetRegistrationParameters,
    context: StepContext,
) -> str:
    """Gets preexisting dataset or creates a new one.

    Looks for a Label Studio dataset whose `title` equals
    `params.dataset_name`. If none exists, a new dataset is registered with
    the annotator using `params`.

    Args:
        params: Step parameters (dataset name and label config).
        context: Step context, used to reach the active stack's annotator.

    Returns:
        The dataset name (the `title` of the found or newly registered
        dataset).

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio.
        StackComponentInterfaceError: If no active annotator could be found,
            or if the annotator's connection is unavailable.
    """
    annotator = context.stack.annotator  # type: ignore[union-attr]
    # Check for a missing annotator before the type check so that this case
    # raises the documented StackComponentInterfaceError, not a TypeError.
    if not annotator:
        raise StackComponentInterfaceError("No active annotator.")

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    if not annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )

    for existing_dataset in annotator.get_datasets():
        # Fetch the params once per dataset instead of once for the
        # comparison and again for the return value.
        title = existing_dataset.get_params()["title"]
        if title == params.dataset_name:
            return cast(str, title)

    registered_dataset = annotator.register_dataset_for_annotation(params)
    return cast(str, registered_dataset.get_params()["title"])
sync_new_data_to_label_studio (BaseStep)

Syncs new data to Label Studio.

Parameters:

Name Type Description Default
uri

The URI of the data to sync.

required
dataset_name

The name of the dataset to sync to.

required
predictions

The predictions to sync.

required
params

The parameters for the sync.

required
context

The StepContext.

required

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

ValueError

If you are trying to sync data from outside ZenML (i.e. from outside the artifact store).

StackComponentInterfaceError

If no active annotator could be found.

PARAMETERS_CLASS (BaseParameters) pydantic-model

Step parameters when syncing data to Label Studio.

Attributes:

Name Type Description
storage_type str

The type of storage to sync to.

label_config_type str

The type of label config to use.

prefix Optional[str]

Specify the prefix within the cloud store to import your data from.

regex_filter Optional[str]

Specify a regex filter to filter the files to import.

use_blob_urls Optional[bool]

Specify whether your data is raw image or video data, or JSON tasks.

presign Optional[bool]

Specify whether or not to create presigned URLs.

presign_ttl Optional[int]

Specify how long to keep presigned URLs active.

description Optional[str]

Specify a description for the dataset.

azure_account_name Optional[str]

Specify the Azure account name to use for the storage.

azure_account_key Optional[str]

Specify the Azure account key to use for the storage.

google_application_credentials Optional[str]

Specify the Google application credentials to use for the storage.

aws_access_key_id Optional[str]

Specify the AWS access key ID to use for the storage.

aws_secret_access_key Optional[str]

Specify the AWS secret access key to use for the storage.

aws_session_token Optional[str]

Specify the AWS session token to use for the storage.

s3_region_name Optional[str]

Specify the S3 region name to use for the storage.

s3_endpoint Optional[str]

Specify the S3 endpoint to use for the storage.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
class LabelStudioDatasetSyncParameters(BaseParameters):
    """Step parameters when syncing data to Label Studio.

    Attributes:
        storage_type: The type of storage to sync to (the sync step handles
            credentials for `azure`, `gcs` and `s3`).
        label_config_type: The type of label config to use.
        prefix: Specify the prefix within the cloud store to import your data
            from.
        regex_filter: Specify a regex filter to filter the files to import.
        use_blob_urls: Specify whether your data is raw image or video data, or
            JSON tasks.
        presign: Specify whether or not to create presigned URLs.
        presign_ttl: Specify how long to keep presigned URLs active.
        description: Specify a description for the dataset.
        azure_account_name: Specify the Azure account name to use for the
            storage.
        azure_account_key: Specify the Azure account key to use for the
            storage.
        google_application_credentials: Specify the Google application
            credentials to use for the storage.
        aws_access_key_id: Specify the AWS access key ID to use for the
            storage.
        aws_secret_access_key: Specify the AWS secret access key to use for the
            storage.
        aws_session_token: Specify the AWS session token to use for the
            storage.
        s3_region_name: Specify the S3 region name to use for the storage.
        s3_endpoint: Specify the S3 endpoint to use for the storage.
    """

    storage_type: str
    label_config_type: str

    prefix: Optional[str] = None
    regex_filter: Optional[str] = ".*"
    use_blob_urls: Optional[bool] = True
    presign: Optional[bool] = True
    presign_ttl: Optional[int] = 1
    description: Optional[str] = ""

    # Credentials specific to the main cloud providers. They default to None
    # explicitly: pydantic v1 already treats a bare `Optional[...]` field as
    # defaulting to None, but pydantic v2 would make such fields required.
    # The sync step fills in the ones matching `storage_type` at runtime.
    azure_account_name: Optional[str] = None
    azure_account_key: Optional[str] = None
    google_application_credentials: Optional[str] = None
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_session_token: Optional[str] = None
    s3_region_name: Optional[str] = None
    s3_endpoint: Optional[str] = None
entrypoint(uri, dataset_name, predictions, params, context) staticmethod

Syncs new data to Label Studio.

Parameters:

Name Type Description Default
uri str

The URI of the data to sync.

required
dataset_name str

The name of the dataset to sync to.

required
predictions List[Dict[str, Any]]

The predictions to sync.

required
params LabelStudioDatasetSyncParameters

The parameters for the sync.

required
context StepContext

The StepContext.

required

Exceptions:

Type Description
TypeError

If you are trying to use it with an annotator that is not Label Studio.

ValueError

If you are trying to sync data from outside ZenML (i.e. from outside the artifact store).

StackComponentInterfaceError

If no active annotator could be found.

Source code in zenml/integrations/label_studio/steps/label_studio_standard_steps.py
def _get_artifact_store_secret(
    artifact_store: Any, storage_type: str, expected_schema_type: Any
) -> Any:
    """Fetch the artifact store's authentication secret for Label Studio.

    Args:
        artifact_store: The active artifact store.
        storage_type: The Label Studio storage type (only used in the error
            message).
        expected_schema_type: The secret schema class the secret must match.

    Returns:
        The secret returned by the artifact store's
        `get_authentication_secret`.

    Raises:
        TypeError: If the artifact store does not inherit from
            `AuthenticationMixin`.
        ValueError: If the artifact store returned no secret.
    """
    if not isinstance(artifact_store, AuthenticationMixin):
        raise TypeError(
            "The artifact store must inherit from "
            f"{AuthenticationMixin.__name__} to work with a Label Studio "
            f"`{storage_type}` storage."
        )

    secret = artifact_store.get_authentication_secret(
        expected_schema_type=expected_schema_type
    )
    if not secret:
        raise ValueError(
            "Missing secret to authenticate cloud storage for Label Studio."
        )
    return secret


@step(enable_cache=False)
def sync_new_data_to_label_studio(
    uri: str,
    dataset_name: str,
    predictions: List[Dict[str, Any]],
    params: LabelStudioDatasetSyncParameters,
    context: StepContext,
) -> None:
    """Syncs new data to Label Studio.

    Connects the Label Studio dataset to the external storage holding `uri`,
    syncs it, and uploads `predictions` (if any) for the matching tasks.
    `params` is updated in place with the storage prefix and the credentials
    for `params.storage_type` before syncing.

    Args:
        uri: The URI of the data to sync. Must lie inside the active artifact
            store.
        dataset_name: The name of the dataset to sync to.
        predictions: The predictions to sync.
        params: The parameters for the sync.
        context: The StepContext.

    Raises:
        TypeError: If you are trying to use it with an annotator that is not
            Label Studio, or if the artifact store / secret needed for the
            chosen storage type has the wrong type.
        ValueError: If you are trying to sync data from outside the ZenML
            artifact store, or if the artifact store has no secret to
            authenticate the cloud storage.
        StackComponentInterfaceError: If no active annotator, artifact store
            or secrets manager could be found, or if the annotator's
            connection is unavailable.
    """
    annotator = context.stack.annotator  # type: ignore[union-attr]
    artifact_store = context.stack.artifact_store  # type: ignore[union-attr]
    secrets_manager = context.stack.secrets_manager  # type: ignore[union-attr]
    if not annotator or not artifact_store or not secrets_manager:
        raise StackComponentInterfaceError(
            "An active annotator, artifact store and secrets manager are required to run this step."
        )

    from zenml.integrations.label_studio.annotators.label_studio_annotator import (
        LabelStudioAnnotator,
    )

    if not isinstance(annotator, LabelStudioAnnotator):
        raise TypeError(
            "This step can only be used with the Label Studio annotator."
        )

    # `uri` must be the artifact store root or lie strictly below it. A bare
    # string-prefix check would also accept sibling locations such as
    # `s3://bucket-other/...` for an artifact store at `s3://bucket`.
    store_root = artifact_store.path.rstrip("/")
    if uri != store_root and not uri.startswith(f"{store_root}/"):
        raise ValueError(
            "ZenML only currently supports syncing data passed from other ZenML steps and via the Artifact Store."
        )

    # Check the connection before querying the annotator for the dataset.
    if not annotator._connection_available():
        raise StackComponentInterfaceError(
            "Unable to connect to annotator stack component."
        )
    dataset = annotator.get_dataset(dataset_name=dataset_name)

    # Split the URI into the bucket/container (`netloc`, passed to the sync
    # call below) and the in-bucket prefix with its leading slash(es) stripped.
    parsed_uri = urlparse(uri)
    params.prefix = parsed_uri.path.lstrip("/")
    base_uri = parsed_uri.netloc

    # Fill in the credentials for the selected storage type.
    if params.storage_type == "azure":
        azure_secret = _get_artifact_store_secret(
            artifact_store, params.storage_type, AzureSecretSchema
        )
        params.azure_account_name = azure_secret.account_name
        params.azure_account_key = azure_secret.account_key
    elif params.storage_type == "gcs":
        gcp_secret = _get_artifact_store_secret(
            artifact_store, params.storage_type, GCPSecretSchema
        )
        params.google_application_credentials = gcp_secret.token
    elif params.storage_type == "s3":
        aws_secret = secrets_manager.get_secret(LABEL_STUDIO_AWS_SECRET_NAME)
        if not isinstance(aws_secret, AWSSecretSchema):
            raise TypeError(
                f"The secret `{LABEL_STUDIO_AWS_SECRET_NAME}` needs to be "
                "an `aws` schema secret."
            )

        params.aws_access_key_id = aws_secret.aws_access_key_id
        params.aws_secret_access_key = aws_secret.aws_secret_access_key
        params.aws_session_token = aws_secret.aws_session_token

    # TODO: get existing (CHECK!) or create the sync connection
    annotator.connect_and_sync_external_storage(
        uri=base_uri,
        params=params,
        dataset=dataset,
    )
    if predictions:
        filename_reference = TASK_TO_FILENAME_REFERENCE_MAPPING[
            params.label_config_type
        ]
        preds_with_task_ids = convert_pred_filenames_to_task_ids(
            predictions,
            dataset.tasks,
            filename_reference,
            params.storage_type,
        )
        # TODO: filter out any predictions that exist + have already been
        # made (maybe?). Only pass in preds for tasks without pre-annotations.
        dataset.create_predictions(preds_with_task_ids)