
PyTorch

zenml.integrations.pytorch special

Initialization of the PyTorch integration.

PytorchIntegration (Integration)

Definition of PyTorch integration for ZenML.

Source code in zenml/integrations/pytorch/__init__.py
class PytorchIntegration(Integration):
    """Definition of PyTorch integration for ZenML."""

    NAME = PYTORCH
    REQUIREMENTS = ["torch"]

    @classmethod
    def activate(cls) -> None:
        """Activates the integration."""
        from zenml.integrations.pytorch import materializers  # noqa

activate() classmethod

Activates the integration.

Source code in zenml/integrations/pytorch/__init__.py
@classmethod
def activate(cls) -> None:
    """Activates the integration."""
    from zenml.integrations.pytorch import materializers  # noqa
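
A minimal usage sketch (not part of the ZenML source): the integration's requirements are normally installed and activated by ZenML itself, for example via the zenml integration install pytorch CLI command, so calling activate() by hand is only needed in custom setups. The values shown in the comments are taken from the class definition above.

# Minimal sketch, assuming ZenML activates installed integrations itself.
from zenml.integrations.pytorch import PytorchIntegration

print(PytorchIntegration.REQUIREMENTS)  # ["torch"], as declared above

# Importing zenml.integrations.pytorch.materializers registers the
# PyTorch materializers documented below.
PytorchIntegration.activate()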

materializers special

Initialization of the PyTorch Materializer.

pytorch_dataloader_materializer

Implementation of the PyTorch DataLoader materializer.

PyTorchDataLoaderMaterializer (BaseMaterializer)

Materializer to read/write PyTorch dataloaders.

Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
class PyTorchDataLoaderMaterializer(BaseMaterializer):
    """Materializer to read/write PyTorch dataloaders."""

    ASSOCIATED_TYPES = (DataLoader,)
    ASSOCIATED_ARTIFACT_TYPES = (DataArtifact,)

    def handle_input(self, data_type: Type[Any]) -> DataLoader[Any]:
        """Reads and returns a PyTorch dataloader.

        Args:
            data_type: The type of the dataloader to load.

        Returns:
            A loaded PyTorch dataloader.
        """
        super().handle_input(data_type)
        with fileio.open(
            os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
        ) as f:
            return cast(DataLoader[Any], torch.load(f))  # type: ignore[no-untyped-call]  # noqa

    def handle_return(self, dataloader: DataLoader[Any]) -> None:
        """Writes a PyTorch dataloader.

        Args:
            dataloader: The torch.utils.data.DataLoader to write.
        """
        super().handle_return(dataloader)

        # Save entire dataloader to artifact directory
        with fileio.open(
            os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
        ) as f:
            torch.save(dataloader, f)
handle_input(self, data_type)

Reads and returns a PyTorch dataloader.

Parameters:

Name       Type       Description                          Default
data_type  Type[Any]  The type of the dataloader to load.  required

Returns:

Type                                          Description
torch.utils.data.dataloader.DataLoader[Any]   A loaded PyTorch dataloader.

Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
def handle_input(self, data_type: Type[Any]) -> DataLoader[Any]:
    """Reads and returns a PyTorch dataloader.

    Args:
        data_type: The type of the dataloader to load.

    Returns:
        A loaded PyTorch dataloader.
    """
    super().handle_input(data_type)
    with fileio.open(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
    ) as f:
        return cast(DataLoader[Any], torch.load(f))  # type: ignore[no-untyped-call]  # noqa
handle_return(self, dataloader)

Writes a PyTorch dataloader.

Parameters:

Name        Type                                          Description                       Default
dataloader  torch.utils.data.dataloader.DataLoader[Any]   The PyTorch DataLoader to write.  required
Source code in zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py
def handle_return(self, dataloader: DataLoader[Any]) -> None:
    """Writes a PyTorch dataloader.

    Args:
        dataloader: The torch.utils.data.DataLoader to write.
    """
    super().handle_return(dataloader)

    # Save entire dataloader to artifact directory
    with fileio.open(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
    ) as f:
        torch.save(dataloader, f)
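
To show where this materializer is used, here is a hedged sketch of pipeline steps that produce and consume a DataLoader. Because DataLoader appears in ASSOCIATED_TYPES, ZenML routes the step output through PyTorchDataLoaderMaterializer; the step names and the decorator-based zenml.steps.step API are assumptions for illustration, not part of this module.

# Minimal sketch with hypothetical step names.
import torch
from torch.utils.data import DataLoader, TensorDataset
from zenml.steps import step


@step
def load_data() -> DataLoader:
    """The return annotation is what ZenML matches against ASSOCIATED_TYPES
    to select PyTorchDataLoaderMaterializer (handle_return writes the file)."""
    dataset = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
    return DataLoader(dataset, batch_size=4)


@step
def consume(dataloader: DataLoader) -> None:
    """Receives the DataLoader read back by handle_input()."""
    for batch_x, batch_y in dataloader:
        print(batch_x.shape, batch_y.shape)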

pytorch_module_materializer

Implementation of the PyTorch Module materializer.

PyTorchModuleMaterializer (BaseMaterializer)

Materializer to read/write PyTorch models.

Inspired by the guide: https://pytorch.org/tutorials/beginner/saving_loading_models.html

Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
class PyTorchModuleMaterializer(BaseMaterializer):
    """Materializer to read/write Pytorch models.

    Inspired by the guide:
    https://pytorch.org/tutorials/beginner/saving_loading_models.html
    """

    ASSOCIATED_TYPES = (Module,)
    ASSOCIATED_ARTIFACT_TYPES = (ModelArtifact,)

    def handle_input(self, data_type: Type[Any]) -> Module:
        """Reads and returns a PyTorch model.

        Only loads the model, not the checkpoint.

        Args:
            data_type: The type of the model to load.

        Returns:
            A loaded PyTorch model.
        """
        super().handle_input(data_type)
        with fileio.open(
            os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
        ) as f:
            return torch.load(f)  # type: ignore[no-untyped-call]  # noqa

    def handle_return(self, model: Module) -> None:
        """Writes a PyTorch model, as a model and a checkpoint.

        Args:
            model: The torch.nn.Module to write.
        """
        super().handle_return(model)

        # Save the entire model to the artifact directory. This is the default
        # behavior for loading the model during development (training, evaluation).
        with fileio.open(
            os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
        ) as f:
            torch.save(model, f)

        # Also save the model checkpoint (state_dict) to the artifact directory.
        # This is the default behavior for loading the model in production (inference).
        if isinstance(model, Module):
            with fileio.open(
                os.path.join(self.artifact.uri, CHECKPOINT_FILENAME), "wb"
            ) as f:
                torch.save(model.state_dict(), f)
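
Analogously, a hedged sketch of a step whose output is a torch.nn.Module: the Module annotation routes the output through this materializer, which writes both the full model and a state_dict checkpoint. The step name and the decorator-based zenml.steps.step API are assumptions for illustration.

# Minimal sketch with a hypothetical step name.
import torch.nn as nn
from zenml.steps import step


@step
def train_model() -> nn.Module:
    """handle_return() saves the whole model plus a state_dict checkpoint
    to this step's artifact directory."""
    model = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 2))
    # ... training loop would go here ...
    return model
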
handle_input(self, data_type)

Reads and returns a PyTorch model.

Only loads the model, not the checkpoint.

Parameters:

Name       Type       Description                     Default
data_type  Type[Any]  The type of the model to load.  required

Returns:

Type    Description
Module  A loaded PyTorch model.

Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def handle_input(self, data_type: Type[Any]) -> Module:
    """Reads and returns a PyTorch model.

    Only loads the model, not the checkpoint.

    Args:
        data_type: The type of the model to load.

    Returns:
        A loaded PyTorch model.
    """
    super().handle_input(data_type)
    with fileio.open(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME), "rb"
    ) as f:
        return torch.load(f)  # type: ignore[no-untyped-call]  # noqa
handle_return(self, model)

Writes a PyTorch model as both a full model and a checkpoint.

Parameters:

Name   Type    Description                    Default
model  Module  The torch.nn.Module to write.  required
Source code in zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
def handle_return(self, model: Module) -> None:
    """Writes a PyTorch model, as a model and a checkpoint.

    Args:
        model: The torch.nn.Module to write.
    """
    super().handle_return(model)

    # Save the entire model to the artifact directory. This is the default
    # behavior for loading the model during development (training, evaluation).
    with fileio.open(
        os.path.join(self.artifact.uri, DEFAULT_FILENAME), "wb"
    ) as f:
        torch.save(model, f)

    # Also save the model checkpoint (state_dict) to the artifact directory.
    # This is the default behavior for loading the model in production (inference).
    if isinstance(model, Module):
        with fileio.open(
            os.path.join(self.artifact.uri, CHECKPOINT_FILENAME), "wb"
        ) as f:
            torch.save(model.state_dict(), f)
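
Because handle_return() also writes a state_dict checkpoint intended for production use, the following sketch shows the standard PyTorch pattern for restoring it at inference time, per the saving/loading guide linked above. The artifact path and the checkpoint filename are placeholders, not values defined by this module; in a real deployment they come from the artifact store and the CHECKPOINT_FILENAME constant.

# Minimal sketch of state_dict loading; "artifact_uri" and "checkpoint.pt"
# are placeholders for the real artifact location and filename.
import os

import torch
import torch.nn as nn

artifact_uri = "/path/to/step/artifact"  # placeholder path
model = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 2))

state_dict = torch.load(os.path.join(artifact_uri, "checkpoint.pt"))
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode before serving predictions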