DeviceMixin

Methods and Attributes

A mixin for consistent device and dtype management across all modules.

This should be inherited by all classes that are sensitive to device or dtype.
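A consuming class typically inherits from both DeviceMixin and nn.Module, with the mixin listed first so its cooperative __init__ consumes the device/dtype arguments before the chain reaches nn.Module. A minimal sketch (the EnergyNet class below is hypothetical, not part of torchebm):

import torch
import torch.nn as nn

from torchebm.core.device_mixin import DeviceMixin

class EnergyNet(DeviceMixin, nn.Module):
    """Toy module that tracks device/dtype via the mixin."""

    def __init__(self, dim: int, device=None, dtype=None):
        # DeviceMixin consumes device/dtype, then forwards the rest to nn.Module
        super().__init__(device=device, dtype=dtype)
        self.net = nn.Linear(dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x.to(device=self.device, dtype=self.dtype))

model = EnergyNet(dim=8, device="cpu")
print(model.device, model.dtype)  # cpu torch.float32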

Source code in torchebm/core/device_mixin.py
class DeviceMixin:
    """
    A mixin for consistent device and dtype management across all modules.

    This should be inherited by all classes that are sensitive to device or dtype.
    """

    def __init__(self, device: Union[str, torch.device, None] = None, dtype: Optional[torch.dtype] = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._device = normalize_device(device)
        self._dtype: Optional[torch.dtype] = dtype

    @property
    def device(self) -> torch.device:
        if self._device is not None:
            return normalize_device(self._device)

        # No explicit device: fall back to the first parameter or buffer, if any
        if hasattr(self, "parameters") and callable(getattr(self, "parameters")):
            try:
                return normalize_device(next(self.parameters()).device)
            except StopIteration:
                pass
        if hasattr(self, "buffers") and callable(getattr(self, "buffers")):
            try:
                return normalize_device(next(self.buffers()).device)
            except StopIteration:
                pass

        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @property
    def dtype(self) -> torch.dtype:
        if self._dtype is not None:
            return self._dtype
        # Try to infer from parameters/buffers, if available
        if hasattr(self, "parameters") and callable(getattr(self, "parameters")):
            try:
                param_dtype = next(self.parameters()).dtype
                return param_dtype
            except StopIteration:
                pass
        if hasattr(self, "buffers") and callable(getattr(self, "buffers")):
            try:
                buffer_dtype = next(self.buffers()).dtype
                return buffer_dtype
            except StopIteration:
                pass
        return torch.float32

    @dtype.setter
    def dtype(self, value: torch.dtype):
        self._dtype = value

    def to(self, *args, **kwargs):
        """Override to() to update internal device tracking."""
        # Call parent's to() if it exists (e.g., nn.Module); otherwise, operate in-place
        parent_to = getattr(super(), "to", None)
        result = self
        if callable(parent_to):
            result = parent_to(*args, **kwargs)

        # Update internal device tracking based on provided args/kwargs
        target_device = None
        target_dtype = None
        if args and isinstance(args[0], (str, torch.device)):
            target_device = normalize_device(args[0])
            # Also honor the positional to(device, dtype) form
            if len(args) > 1 and isinstance(args[1], torch.dtype):
                target_dtype = args[1]
        elif args and isinstance(args[0], torch.dtype):
            target_dtype = args[0]
        if "device" in kwargs:
            target_device = normalize_device(kwargs["device"])
        if "dtype" in kwargs and isinstance(kwargs["dtype"], torch.dtype):
            target_dtype = kwargs["dtype"]
        if target_device is not None:
            self._device = target_device
        if target_dtype is not None:
            self._dtype = target_dtype

        return result

    @staticmethod
    def safe_to(obj, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
        """
        Safely moves an object to a device and/or dtype, if it supports the `.to()` method.
        """
        if not hasattr(obj, "to") or not callable(getattr(obj, "to")):
            return obj
        try:
            if device is not None or dtype is not None:
                return obj.to(device=device, dtype=dtype)
            return obj
        except TypeError:
            # Fallbacks for custom signatures
            if device is not None:
                try:
                    return obj.to(device)
                except Exception:
                    pass
            if dtype is not None:
                try:
                    return obj.to(dtype)
                except Exception:
                    pass
            return obj

    # Mixed precision helpers
    def setup_mixed_precision(self, use_mixed_precision: bool) -> None:
        """Configures mixed precision settings."""
        self.use_mixed_precision = bool(use_mixed_precision)
        if self.use_mixed_precision:
            try:
                # Import lazily to avoid hard dependency when not used
                from torch.cuda.amp import autocast as _autocast  # noqa: F401
                self.autocast_available = True
                if not self.device.type.startswith("cuda"):
                    warnings.warn(
                        f"Mixed precision requested but device is {self.device}. Mixed precision requires CUDA. Falling back to full precision.",
                        UserWarning,
                    )
                    self.use_mixed_precision = False
                    self.autocast_available = False
            except ImportError:
                warnings.warn(
                    "Mixed precision requested but torch.cuda.amp not available. Falling back to full precision. Requires PyTorch 1.6+.",
                    UserWarning,
                )
                self.use_mixed_precision = False
                self.autocast_available = False
        else:
            self.autocast_available = False

    def autocast_context(self):
        """
        Returns a `torch.cuda.amp.autocast` context manager if mixed precision is enabled,
        otherwise a `nullcontext`.
        """
        if getattr(self, "use_mixed_precision", False) and getattr(self, "autocast_available", False):
            from torch.cuda.amp import autocast
            return autocast()
        return nullcontext()

_device instance-attribute

_device = normalize_device(device)

_dtype instance-attribute

_dtype: Optional[torch.dtype] = dtype

device property

device: torch.device

dtype property writable

dtype: torch.dtype
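
Both properties resolve lazily: an explicit value wins, otherwise the first parameter or buffer is consulted, and finally a default is used (the CUDA device if available, else CPU; torch.float32 for dtype). A rough sketch using a hypothetical Wrapper class:

import torch
import torch.nn as nn

from torchebm.core.device_mixin import DeviceMixin

class Wrapper(DeviceMixin, nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.layer = nn.Linear(2, 2)

w = Wrapper()                 # nothing explicit
print(w.device, w.dtype)      # inferred from self.layer: cpu torch.float32

w2 = Wrapper(device="cpu", dtype=torch.float64)
print(w2.device, w2.dtype)    # explicit values win: cpu torch.float64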

to

to(*args, **kwargs)

Override to() to update internal device tracking.

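In practice, a single .to() call both moves the module (through nn.Module.to) and refreshes the mixin's cached device/dtype. A brief sketch, reusing the hypothetical Wrapper from above:

w = Wrapper()
w.to(torch.float64)       # dtype-only call: parameters converted, _dtype updated
print(w.dtype)            # torch.float64

if torch.cuda.is_available():
    w.to("cuda")          # parameters moved, _device updated
    print(w.device)       # the normalized CUDA device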

safe_to staticmethod

safe_to(obj, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None)

Safely moves an object to a device and/or dtype, if it supports the .to() method.

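This is useful for heterogeneous state (tensors, modules, plain Python objects) where only some entries are movable. A small sketch:

import torch

t = DeviceMixin.safe_to(torch.zeros(3), dtype=torch.float64)
print(t.dtype)                       # torch.float64

state = {"step": 0}                  # plain dict: no .to() method
same = DeviceMixin.safe_to(state, device=torch.device("cpu"))
print(same is state)                 # True, returned unchanged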

setup_mixed_precision

setup_mixed_precision(use_mixed_precision: bool) -> None

Configures mixed precision settings.

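For instance, requesting mixed precision on a CPU-bound module emits a UserWarning and cleanly falls back to full precision. A sketch, again with the hypothetical Wrapper:

model = Wrapper(device="cpu")
model.setup_mixed_precision(True)    # warns: device is cpu, CUDA required
print(model.use_mixed_precision)     # False
print(model.autocast_available)      # False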

autocast_context

autocast_context()

Returns a torch.cuda.amp.autocast context manager if mixed precision is enabled, otherwise a nullcontext.

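Because the fallback is a nullcontext, the same forward code works with or without mixed precision. A usage sketch with the hypothetical Wrapper:

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Wrapper().to(device)
model.setup_mixed_precision(device == "cuda")

x = torch.randn(4, 2, device=model.device)
with model.autocast_context():
    out = model.layer(x)   # autocast on CUDA, full precision otherwise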