# Source code for tsaug._augmenter.base

from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, overload

import numpy as np

from . import _default_seed


class _Augmenter(ABC):
    """
    Abstract base class of augmenters.

    A concrete augmenter implements `_get_param_name` and `_augment_core`.
    This base class supplies validation of the shared parameters (`repeats`,
    `prob`, `seed`), the public `augment` entry point, and the operators `*`,
    `@`, and `+`.

    Parameters
    ----------
    repeats : int, optional
        Number of copies made of every input series before augmentation; the
        output contains `repeats` times as many series as the input. It must
        be a positive integer. Default: 1.

    prob : float, optional
        Probability with which each (repeated) series is augmented. Series
        that are not selected are returned unchanged. It must be greater than
        0 and no greater than 1. Default: 1.0.

    seed : int, optional
        Seed handed to `numpy.random.RandomState`. Default: the package's
        `_default_seed`.

    """

    def __init__(
        self,
        repeats: int = 1,
        prob: float = 1.0,
        seed: Optional[int] = _default_seed,
    ) -> None:
        # Each assignment goes through the validating property setter below.
        self.repeats = repeats
        self.prob = prob
        self.seed = seed

    @classmethod
    @abstractmethod
    def _get_param_name(cls) -> Tuple[str, ...]:
        """
        Return the names of the attributes reported by `_get_params`.
        """
        return tuple()

    def _get_params(self) -> Dict[str, Any]:
        """
        Return a dict mapping every name from `_get_param_name` to the current
        value of the attribute of that name.
        """
        return {
            param_name: getattr(self, param_name)
            for param_name in self._get_param_name()
        }

    @property
    def repeats(self) -> int:
        return self._repeat

    @repeats.setter
    def repeats(self, M: int) -> None:
        # NumPy integer scalars (e.g. the result of an array computation) are
        # not instances of the builtin `int`; normalize them so that they are
        # accepted as well.
        if isinstance(M, np.integer):
            M = int(M)
        if not isinstance(M, int):
            raise TypeError("Parameter `repeats` must be a positive integer.")
        if M <= 0:
            raise ValueError("Parameter `repeats` must be a positive integer.")
        self._repeat = M

    @property
    def prob(self) -> float:
        return self._prob

    @prob.setter
    def prob(self, p: float) -> None:
        # NumPy real scalars such as `np.float32` are not instances of the
        # builtin `float`; normalize them so that they are accepted as well.
        if isinstance(p, (np.integer, np.floating)):
            p = float(p)
        if not isinstance(p, (int, float)):
            raise TypeError(
                "Parameter `prob` must be a positive number between 0 and 1."
            )
        # Written as a negated chained comparison so that NaN is rejected too:
        # every comparison against NaN is False, hence separate `p > 1.0` and
        # `p <= 0.0` tests would both let it through.
        # NOTE: TypeError (rather than ValueError) is raised for out-of-range
        # values to stay compatible with callers that already catch it.
        if not (0.0 < p <= 1.0):
            raise TypeError(
                "Parameter `prob` must be a positive number between 0 and 1."
            )
        self._prob = p

    @property
    def seed(self) -> Optional[int]:
        return self._seed

    @seed.setter
    def seed(self, s: Optional[int]) -> None:
        # Constructing a throw-away RandomState makes numpy validate the seed,
        # so an invalid seed fails here instead of later inside `_augment`.
        np.random.RandomState(s)
        self._seed = s

    def _augmented_series_length(self, T: int) -> int:
        """
        Return the length (2nd dimension) of augmented series.

        For most augmenters, the length of series will not be changed. If an
        augmenter does change the length of series, this method should be
        overwritten in the augmenter subclass.

        """
        return T

    @overload
    def augment(self, X: np.ndarray) -> np.ndarray:
        ...

    @overload
    def augment(self, X: np.ndarray, Y: None) -> np.ndarray:
        ...

    @overload
    def augment(
        self, X: np.ndarray, Y: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        ...

    def augment(
        self, X: np.ndarray, Y: Optional[np.ndarray] = None
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Augment time series.

        Parameters
        ----------
        X : numpy array
            Time series to be augmented. It must be a numpy array with shape
            (T,), (N, T), or (N, T, C), where T is the length of a series, N is
            the number of series, and C is the number of channels in a series.

        Y : numpy array, optional
            Segmentation mask of the original time series. It must be a binary
            numpy array with shape (T,), (N, T), or (N, T, L), where T is the
            length of a series, N is the number of series, and L is the number
            of segmentation classes. Default: None.

        Returns
        -------
        2-tuple (numpy array, numpy array), or numpy array
            The augmented time series, and (optionally) the augmented
            segmentation mask. Each output has the same number of dimensions
            as the corresponding input, except that a 1-D input yields a 2-D
            output with one row per repeat when `repeats` > 1.

        Raises
        ------
        TypeError
            If X, or a given Y, is not a numpy array.
        ValueError
            If X or Y has more than 3 (or 0) dimensions, or if X and Y
            disagree in the number or the length of series.
        RuntimeError
            If this augmenter changes the series length, more than one output
            series is expected, and `prob` is not 1.0.

        """
        X_ERROR_MSG = (
            "Input X must be a numpy array with shape (T,), (N, T), or (N, T, "
            "C), where T is the length of a series, N is the number of series, "
            "and C is the number of a channels in a series."
        )

        Y_ERROR_MSG = (
            "Input Y must be a numpy array with shape (T,), (N, T), or (N, T, "
            "L), where T is the length of a series, N is the number of series, "
            "and L is the number of a segmentation classes."
        )

        def to_3d(arr: np.ndarray, error_msg: str) -> np.ndarray:
            # Promote a (T,), (N, T) or (N, T, C) array to 3 dimensions.
            if not isinstance(arr, np.ndarray):
                raise TypeError(error_msg)
            if arr.ndim == 1:  # (T, )
                return arr.reshape(1, -1, 1)
            if arr.ndim == 2:  # (N, T)
                return np.expand_dims(arr, 2)
            if arr.ndim == 3:  # (N, T, C) or (N, T, L)
                return arr
            raise ValueError(error_msg)

        def restore_rank(arr: np.ndarray, ndim: Optional[int]) -> np.ndarray:
            # Undo `to_3d` on an augmented array, given the rank of the input.
            if ndim == 1:
                if self.repeats == 1:
                    return arr.flatten()
                return arr.reshape(self.repeats, -1)
            if ndim == 2:
                return arr.reshape(N * self.repeats, -1)
            return arr

        X_3d = to_3d(X, X_ERROR_MSG)
        ndim_x = X.ndim

        # `ndim_y` is always bound, so the post-processing below cannot hit an
        # unbound local even if `_augment` returns a mask for a mask-less call.
        Y_3d = None
        ndim_y = None
        if Y is not None:
            Y_3d = to_3d(Y, Y_ERROR_MSG)
            ndim_y = Y.ndim

        N, T, _ = X_3d.shape

        if Y_3d is not None:
            Ny, Ty, _ = Y_3d.shape
            # check consistency between X and Y
            if N != Ny:
                raise ValueError(
                    "The numbers of series in X and Y are different."
                )
            if T != Ty:
                raise ValueError(
                    "The length of series in X and Y are different."
                )

        # If this augmenter changes series length AND multiple outputs are
        # expected, it must have prob equal to 1, otherwise the outputs may
        # have different lengths.
        if (
            self._augmented_series_length(T) != T
            and ((self.repeats > 1) or (N > 1))
            and (self.prob != 1.0)
        ):
            raise RuntimeError(
                "This augmenter changes series length. "
                "When augmenting multiple series or multiple times, parameter "
                "`prob` must be 1.0, otherwise the output series may have "
                "different length."
            )

        # augment
        X_aug, Y_aug = self._augment(X_3d, Y_3d)

        X_aug = restore_rank(X_aug, ndim_x)
        if Y_aug is None:
            return X_aug
        return X_aug, restore_rank(Y_aug, ndim_y)

    def _augment(
        self, X: np.ndarray, Y: Optional[np.ndarray]
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        The main part of augmentation, without pre- and post-processing.

        This method calls _augment_core which is the core algorithmic part. The
        process in this method handles `repeats` and `prob`.

        1. If `repeats` > 1, we first concatenate `repeats` copies of input
           into a 'super' input.
        2. Select series from the (super) input to be augmented.
        3. Apply _augment_core to the selected series.

        The problems of this strategy include:

        1. The memory burden may be unnecessarily high for 'long-to-short'
           augmentation like cropping (say if cropping a 100-window from a
           100M-series, this strategy copies 100M-series `repeats` times).
        2. Some time-consuming computation may be duplicated, for example
           quantization with kmeans. Each series should only train a model once
           instead of `repeats` times.

        In those cases, the subclass of the augmenter should overwrite this
        method.

        Parameters
        ----------
        X : numpy array
            Time series with shape (N, T, C).

        Y : numpy array or None
            Segmentation mask with shape (N, T, L), or None.

        Returns
        -------
        2-tuple (numpy array, numpy array or None)
            The (repeated and) augmented series and mask. The inputs are not
            modified.

        """
        rand = np.random.RandomState(self.seed)
        N = X.shape[0]
        ind = (
            rand.uniform(size=self.repeats * N) <= self.prob
        )  # boolean mask of the (repeated) series to be augmented

        # Work on fresh arrays so that the caller's data is never modified.
        # np.repeat already allocates a new array, hence no extra copy there.
        if self.repeats > 1:
            X_aug = np.repeat(X, self.repeats, axis=0)
            Y_aug = None if Y is None else np.repeat(Y, self.repeats, axis=0)
        else:
            X_aug = X.copy()
            Y_aug = None if Y is None else Y.copy()

        # NOTE: the results are written back in place, which requires the
        # output of `_augment_core` to have the same shape as its input;
        # presumably augmenters that change the series length overwrite this
        # method.
        if ind.any():
            if Y_aug is None:
                X_aug[ind, :], Y_aug = self._augment_core(X_aug[ind, :], None)
            else:
                X_aug[ind, :], Y_aug[ind, :] = self._augment_core(
                    X_aug[ind, :], Y_aug[ind, :]
                )
        return X_aug, Y_aug

    @abstractmethod
    def _augment_core(
        self, X: np.ndarray, Y: Optional[np.ndarray]
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        The core of augmentation.

        Receives the series selected for augmentation (and the matching rows
        of the segmentation mask, or None) and returns their augmented
        versions as a 2-tuple.
        """
        pass

    def _copy(self) -> "_Augmenter":
        "Return a copy of this augmenter."
        return deepcopy(self)

    def __mul__(self, m: int) -> "_Augmenter":
        """
        Operator * creates an augmenter that is equivalent to running this
        augmenter for m times independently.

        Parameters
        ----------
        m : int
            The returned augmenter is equivalent to running this augmenter for
            m times independently.

        Returns
        -------
        augmenter
            An augmenter that is equivalent to running this augmenter for m
            times independently. This augmenter itself is left unchanged.

        """
        copy = self._copy()
        # The `repeats` setter validates the product.
        copy.repeats = copy.repeats * m
        return copy

    def __matmul__(self, p: float) -> "_Augmenter":
        """
        Operator @ creates an augmenter that is equivalent to running this
        augmenter with probability p.

        Parameters
        ----------
        p : float
            The returned augmenter is equivalent to running this augmenter with
            probability p.

        Returns
        -------
        augmenter
            An augmenter that is equivalent to running this augmenter with
            probability p. This augmenter itself is left unchanged.

        """
        copy = self._copy()
        # The `prob` setter validates the product.
        copy.prob = copy.prob * p
        return copy

    def __add__(
        self, a: Union["_Augmenter", "_AugmenterPipe"]
    ) -> "_AugmenterPipe":
        """
        Operator + connects this augmenter with another augmenter or an
        augmenter pipe to form a (new) augmenter pipe.

        Parameters
        ----------
        a : augmenter or augmenter pipe
            The augmenter or augmenter pipe to be connected with this augmenter.

        Returns
        -------
        augmenter pipe
            The output augmenter pipe. It holds copies of the augmenters, so
            later changes to the operands do not affect it.

        Raises
        ------
        TypeError
            If `a` is neither an augmenter nor an augmenter pipe.

        """
        if isinstance(a, _Augmenter):
            return _AugmenterPipe([self._copy(), a._copy()])
        elif isinstance(a, _AugmenterPipe):
            return _AugmenterPipe(
                [self._copy()] + [augmenter._copy() for augmenter in a]
            )
        else:
            raise TypeError(
                "An augmenter can only be connected by another augmenter or an "
                "augmenter pipeline."
            )

    def __len__(self) -> int:
        # A single augmenter counts as one step.
        return 1


class _AugmenterPipe:
    """
    A sequence of augmenters that are applied one after another.

    Parameters
    ----------
    pipe : list of augmenters
        The augmenters of this pipe, in the order they are applied.

    """

    def __init__(self, pipe: List[_Augmenter]) -> None:
        # Keep a private list so that later mutation of the caller's list
        # cannot silently change this pipe.
        self._pipe = list(pipe)

    def __getitem__(self, ind: int) -> _Augmenter:
        item = self._pipe[ind]
        # A slice index yields a list rather than a single augmenter.
        if isinstance(item, _Augmenter):
            return item
        raise NotImplementedError(
            "Getting multiple augmenters in an augmenter pipe is not "
            "supported yet."
        )

    def __setitem__(self, ind: int, value: _Augmenter) -> None:
        # Both the slot addressed by `ind` and the new value must be single
        # augmenters; note that a non-augmenter `value` is reported with the
        # same error as a slice index.
        if isinstance(self._pipe[ind], _Augmenter) and isinstance(
            value, _Augmenter
        ):
            self._pipe[ind] = value
        else:
            raise NotImplementedError(
                "Setting multiple augmenters in an augmenter pipe is not "
                "supported yet."
            )

    def __iter__(self) -> Iterator[_Augmenter]:
        return iter(self._pipe)

    def __len__(self) -> int:
        return len(self._pipe)

    def summary(self, show_params: bool = False) -> None:
        """
        Print summary of this augmenter pipe.

        One header line, one separator line, and then one line per augmenter
        with its index, class name, `repeats`, and `prob` are printed.

        Parameters
        ----------
        show_params : bool, optional
            Whether show parameters of each augmenter in the summary table. If
            True, the table may be too wide to be readable. Default: False.

        """

        print(
            "{ind}\t{name}\t{repeats}\t{prob}\t{param}".format(
                ind="idx",
                name="augmenter",
                repeats="repeats".rjust(8),
                prob="prob".rjust(5),
                param="params" if show_params else "",
            )
        )
        print("=" * (120 if show_params else 45))
        for i, augmenter in enumerate(self):
            # The index and `repeats` are integers and are formatted with `d`:
            # the `g` presentation type switches to scientific notation once
            # the value needs more digits than the precision (index 10 would
            # print as "1e+01", repeats 1000 as "1e+03").
            print(
                "{ind:3d}\t{name}\t{repeats:8d}\t{prob:5.3g}\t{param}".format(
                    ind=i,
                    name=augmenter.__class__.__name__.ljust(8),
                    repeats=augmenter.repeats,
                    prob=augmenter.prob,
                    param=augmenter._get_params() if show_params else "",
                )
            )

    @overload
    def augment(self, X: np.ndarray) -> np.ndarray:
        ...

    @overload
    def augment(self, X: np.ndarray, Y: None) -> np.ndarray:
        ...

    @overload
    def augment(
        self, X: np.ndarray, Y: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        ...

    def augment(
        self, X: np.ndarray, Y: Optional[np.ndarray] = None
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Augment time series.

        The augmenters of this pipe are applied in order; the output of each
        one is the input of the next.

        Parameters
        ----------
        X : numpy array
            Time series to be augmented. It must be a numpy array with shape
            (T,), (N, T), or (N, T, C), where T is the length of a series, N is
            the number of series, and C is the number of channels in a series.

        Y : numpy array, optional
            Segmentation mask of the original time series. It must be a binary
            numpy array with shape (T,), (N, T), or (N, T, L), where T is the
            length of a series, N is the number of series, and L is the number
            of segmentation classes. Default: None.

        Returns
        -------
        2-tuple (numpy array, numpy array), or numpy array
            The augmented time series, and (optionally) the augmented
            segmentation mask.

        """
        X_aug = X
        Y_aug = Y
        for augmenter in self._pipe:
            if Y_aug is None:
                X_aug = augmenter.augment(X_aug)
            else:
                X_aug, Y_aug = augmenter.augment(X_aug, Y_aug)
        if Y_aug is None:
            return X_aug
        return X_aug, Y_aug

    def __add__(
        self, a: Union["_Augmenter", "_AugmenterPipe"]
    ) -> "_AugmenterPipe":
        """
        Operator + connects this augmenter pipe with another augmenter or an
        augmenter pipe to form a new augmenter pipe.

        Parameters
        ----------
        a : augmenter or augmenter pipe
            The augmenter or augmenter pipe to be connected with this augmenter
            pipe.

        Returns
        -------
        augmenter pipe
            The output augmenter pipe. It holds copies of the augmenters, so
            later changes to the operands do not affect it.

        Raises
        ------
        TypeError
            If `a` is neither an augmenter nor an augmenter pipe.

        """
        if isinstance(a, _Augmenter):
            tail = [a._copy()]
        elif isinstance(a, _AugmenterPipe):
            tail = [augmenter._copy() for augmenter in a]
        else:
            raise TypeError(
                "An augmenter can only be connected by another augmenter or an "
                "augmenter pipeline."
            )
        return _AugmenterPipe([augmenter._copy() for augmenter in self] + tail)