from typing import List, Optional, Tuple, Union
import numpy as np
from scipy.ndimage.filters import convolve1d
from scipy.signal import get_window
from .base import _Augmenter, _default_seed
[docs]class Convolve(_Augmenter):
"""
Convolve time series with a kernel window.
Parameters
----------
window : str, tuple, or list, optional
The type of kernal window used for the convolution.
- If str or tuple, it is a window type that can be passed to
`scipy.signal.get_window`. See
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html
for more details.
- If list, it is a list of such object. The type of a kernel window
convolved with a time series is randomly sampled from this list.
Default: "hann".
size : int, list, tuple, optional
Length of kernel windows.
- If int, all series are convolved with windows of the same length.
- If list, each series is convolved with a window with a size sampled
from the list randomly.
- If 2-tuple, each series is convolved with a window with a size sampled
from the interval randomly.
Default: 7.
per_channel : bool, optional
Whether to sample a kernel window for each channel in a time series or
to use the same window for all channels in a time series. Only used if
the kernel window is not deterministic. Default: False.
repeats : int, optional
The number of times a series is augmented. If greater than one, a series
will be augmented so many times independently. This parameter can also
be set by operator `*`. Default: 1.
prob : float, optional
The probability of a series is augmented. It must be in (0.0, 1.0]. This
parameter can also be set by operator `@`. Default: 1.0.
seed : int, optional
The random seed. Default: None.
"""
def __init__(
self,
window: Union[str, Tuple, List[Union[str, Tuple]]] = "hann",
size: Union[int, Tuple[int, int], List[int]] = 7,
per_channel: bool = False,
repeats: int = 1,
prob: float = 1.0,
seed: Optional[int] = _default_seed,
):
self.window = window
self.size = size
self.per_channel = per_channel
super().__init__(repeats=repeats, prob=prob, seed=seed)
@classmethod
def _get_param_name(cls) -> Tuple[str, ...]:
return ("window", "size", "per_channel")
@property
def window(self) -> Union[str, Tuple, List[Union[str, Tuple]]]:
return self._window
@window.setter
def window(self, w: Union[str, Tuple, List[Union[str, Tuple]]]) -> None:
WINDOW_ERROR_MSG = (
"Parameter `window` must be a str or a tuple that can pass to "
"`scipy.signal.get_window`, or a list of such objects. See "
"https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html "
"for more details."
)
if not isinstance(w, list):
try:
get_window(w, 7)
except TypeError:
raise TypeError(WINDOW_ERROR_MSG)
except ValueError:
raise ValueError(WINDOW_ERROR_MSG)
except:
raise RuntimeError(WINDOW_ERROR_MSG)
else:
for ww in w:
try:
get_window(ww, 7)
except TypeError:
raise TypeError(WINDOW_ERROR_MSG)
except ValueError:
raise ValueError(WINDOW_ERROR_MSG)
except:
raise RuntimeError(WINDOW_ERROR_MSG)
self._window = w
@property
def size(self) -> Union[int, Tuple[int, int], List[int]]:
return self._size
@size.setter
def size(self, n: Union[int, Tuple[int, int], List[int]]) -> None:
SIZE_ERROR_MSG = (
"Parameter `size` must be a positive integer, "
"a 2-tuple of positive integers representing an interval, "
"or a list of positive integers."
)
if not isinstance(n, int):
if isinstance(n, list):
if len(n) == 0:
raise ValueError(SIZE_ERROR_MSG)
if not all([isinstance(nn, int) for nn in n]):
raise TypeError(SIZE_ERROR_MSG)
if not all([nn > 0 for nn in n]):
raise ValueError(SIZE_ERROR_MSG)
elif isinstance(n, tuple):
if len(n) != 2:
raise ValueError(SIZE_ERROR_MSG)
if (not isinstance(n[0], int)) or (not isinstance(n[1], int)):
raise TypeError(SIZE_ERROR_MSG)
if n[0] >= n[1]:
raise ValueError(SIZE_ERROR_MSG)
if (n[0] <= 0) or (n[1] <= 0):
raise ValueError(SIZE_ERROR_MSG)
else:
raise TypeError(SIZE_ERROR_MSG)
elif n <= 0:
raise ValueError(SIZE_ERROR_MSG)
self._size = n
@property
def per_channel(self) -> bool:
return self._per_channel
@per_channel.setter
def per_channel(self, p: bool) -> None:
if not isinstance(p, bool):
raise TypeError("Paremeter `per_channel` must be boolean.")
self._per_channel = p
def _augment_core(
self, X: np.ndarray, Y: Optional[np.ndarray]
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
N, T, C = X.shape
rand = np.random.RandomState(self.seed)
if isinstance(self.window, (str, tuple)):
window_type = [self.window for _ in range(N * C)]
else:
if self.per_channel:
window_type = [
self.window[i]
for i in rand.choice(len(self.window), N * C)
]
else:
window_type = [
self.window[i]
for i in rand.choice(len(self.window), N)
for _ in range(C)
]
if isinstance(self.size, int):
window_size = np.array([self.size for _ in range(N * C)])
elif isinstance(self.size, tuple):
if self.per_channel:
window_size = rand.choice(
range(self.size[0], self.size[1]), N * C
)
else:
window_size = rand.choice(range(self.size[0], self.size[1]), N)
window_size = np.repeat(window_size, C)
else:
if self.per_channel:
window_size = rand.choice(self.size, N * C)
else:
window_size = rand.choice(self.size, N)
window_size = np.repeat(window_size, C)
window_size = window_size.astype(int)
X_aug = X.copy()
X_aug = X_aug.swapaxes(1, 2).reshape(N * C, T)
for ws in np.unique(window_size):
for wt in set(window_type):
window = get_window(window=wt, Nx=ws, fftbins=False)
X_aug[
(window_size == ws) & [w == wt for w in window_type], :
] = (
convolve1d(
X_aug[
(window_size == ws)
& [w == wt for w in window_type],
:,
],
window,
axis=1,
)
/ window.sum()
)
X_aug = X_aug.reshape(N, C, T).swapaxes(1, 2)
if Y is not None:
Y_aug = Y.copy()
else:
Y_aug = None
return X_aug, Y_aug