Skip to content

Dataset Class

Import Dataset from the dataset module:

from radburst.utils.dataset import Dataset

radburst.utils.dataset

Dataset

Dataset(data_dir, labels, preprocess=None, binary=True)

Bases: Dataset

Dataset class to manage loading, storing and processing data.

Initialize the dataset.

PARAMETER DESCRIPTION
data_dir

The root directory containing the FITS data files.

TYPE: str

labels

Path to csv file containing labels (paths and burst types) or labels dataframe

TYPE: str or DataFrame

preprocess

Function that takes a spectrogram array and returns a preprocessed array. Defaults to None.

TYPE: callable DEFAULT: None

binary

True for binary labels (0 for no burst, 1 for burst); False for type labels (the burst type number for a burst, 0 for no burst).

TYPE: bool DEFAULT: True

ATTRIBUTE DESCRIPTION
data_dir

The directory path for the dataset.

TYPE: str

data

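List intended to store the loaded data arrays from FITS files. Note: `__init__` does not set this attribute; FITS files are loaded on demand in `__getitem__`.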

TYPE: list

Source code in radburst/utils/dataset.py
def __init__(self, data_dir, labels, preprocess=None, binary=True):
    """Initialize the dataset.

    Args:
        data_dir (str): The root directory containing the FITS data files.
        labels (str, os.PathLike or pd.DataFrame): Path to a CSV file containing
            labels (paths and burst types), or an already-loaded labels DataFrame.
            The labels must contain a ``'path'`` column.
        preprocess (callable, optional): Function that takes a spectrogram array and
            returns a preprocessed array. Defaults to None.
        binary (bool, optional): True for binary labels (0 for no burst, 1 for burst);
            False for type labels (the burst type number for a burst, 0 for no burst).
            Defaults to True.

    Attributes:
        data_dir (str): The directory path for the dataset.
        binary (bool): The label mode passed as ``binary``.
        preprocess (callable or None): The preprocessing function, if any.
        labels_df (pd.DataFrame): The labels table, loaded from CSV or passed in directly.
        paths (pd.Series): The ``'path'`` column of ``labels_df``.

    Raises:
        TypeError: If ``labels`` is neither a path nor a ``pd.DataFrame``.
        KeyError: If the labels have no ``'path'`` column.
    """
    self.data_dir = data_dir
    self.binary = binary
    self.preprocess = preprocess

    # Load labels data. pathlib.Path (and any os.PathLike) is accepted alongside
    # plain strings, since pd.read_csv reads from either.
    if isinstance(labels, (str, os.PathLike)):
        self.labels_df = pd.read_csv(labels)
    elif isinstance(labels, pd.DataFrame):
        self.labels_df = labels
    else:
        raise TypeError('labels must be a str path or a pd.DataFrame')

    self.paths = self.labels_df['path']

data_dir instance-attribute

data_dir = data_dir

binary instance-attribute

binary = binary

preprocess instance-attribute

preprocess = preprocess

labels_df instance-attribute

labels_df = read_csv(labels)

paths instance-attribute

paths = labels_df['path']

__getitem__

__getitem__(idx)
Source code in radburst/utils/dataset.py
def __getitem__(self, idx):
    """Load one sample and its label as float tensors.

    Args:
        idx (int): Positional row index into ``labels_df``.

    Returns:
        tuple: ``(spect_tensor, label_tensor)``. ``spect_tensor`` is the
        (optionally preprocessed) spectrogram with a new leading axis of size 1
        added via ``np.expand_dims(..., axis=0)``. ``label_tensor`` is a
        one-element float tensor holding the ``'burst'`` value when
        ``self.binary`` is truthy, otherwise the ``'type'`` value.
    """
    relative_path = self.labels_df['path'].iloc[idx]
    spectrogram = utils.load_fits_file(os.path.join(self.data_dir, relative_path))

    # Pick the label column according to the label mode.
    label_column = 'burst' if self.binary else 'type'
    label = self.labels_df[label_column].iloc[idx]

    # Preprocess (skipped when no preprocess function was given).
    if self.preprocess:
        spectrogram = self.preprocess(spectrogram)

    # Add a leading singleton axis, then convert both pieces to float tensors.
    with_leading_axis = np.expand_dims(spectrogram, axis=0)
    return torch.FloatTensor(with_leading_axis), torch.FloatTensor([label])

__len__

__len__()
Source code in radburst/utils/dataset.py
def __len__(self):
    """Return the number of labelled samples, i.e. the row count of ``labels_df``."""
    n_rows, *_ = self.labels_df.shape
    return n_rows

get_filtered_dataset

get_filtered_dataset(condition)
Source code in radburst/utils/dataset.py
def get_filtered_dataset(self, condition):
    """Return a new dataset restricted to the label rows matching ``condition``.

    Args:
        condition (str): A ``pandas.DataFrame.query`` expression evaluated
            against ``labels_df``, e.g. ``'burst == 1'``.

    Returns:
        A new instance of the same class that shares ``data_dir``,
        ``preprocess`` and ``binary`` with this one. Its labels are the
        matching rows, re-indexed from 0.
    """
    new_labels = self.labels_df.query(condition).reset_index(drop=True)
    # Pass the label mode through as ``binary`` (not ``preprocess``) so the
    # filtered dataset yields the same kind of labels as its source.
    # ``type(self)`` keeps subclasses intact and equals the class itself for
    # plain instances.
    return type(self)(data_dir=self.data_dir,
                      labels=new_labels,
                      preprocess=self.preprocess,
                      binary=self.binary)

only_bursts

only_bursts()
Source code in radburst/utils/dataset.py
def only_bursts(self):
    """Return a dataset containing only the rows whose ``burst`` column equals 1."""
    burst_rows = 'burst == 1'
    return self.get_filtered_dataset(condition=burst_rows)

only_nonbursts

only_nonbursts()
Source code in radburst/utils/dataset.py
def only_nonbursts(self):
    """Return a dataset containing only the rows whose ``burst`` column equals 0."""
    non_burst_rows = 'burst == 0'
    return self.get_filtered_dataset(condition=non_burst_rows)

Resize

Resize(new_size)
Source code in radburst/utils/dataset.py
def __init__(self, new_size):
    """Remember the target shape for later resizing.

    Args:
        new_size: Output shape, passed unchanged to ``skimage.transform.resize``
            when the instance is called.
    """
    self.new_size = new_size

new_size instance-attribute

new_size = new_size

__call__

__call__(array)
Source code in radburst/utils/dataset.py
def __call__(self, array):
    """Resize ``array`` to ``self.new_size``.

    Args:
        array: Input array (e.g. a spectrogram).

    Returns:
        The result of ``skimage.transform.resize`` called with its default
        options and ``self.new_size`` as the output shape.
    """
    target_shape = self.new_size
    resized = skimage.transform.resize(array, target_shape)
    return resized

MinMaxNormalize

MinMaxNormalize(eps=1e-08)
Source code in radburst/utils/dataset.py
def __init__(self, eps=1e-8):
    """Store the epsilon added to the normalization denominator.

    Args:
        eps (float, optional): Small constant that keeps the denominator
            non-zero for constant arrays. Defaults to 1e-8.
    """
    self.eps = eps

eps instance-attribute

eps = eps

__call__

__call__(array)

Min-max normalize the array to approximately [0, 1]; a small epsilon is added to the denominator to prevent division by zero.

Source code in radburst/utils/dataset.py
def __call__(self, array):
    """Min-max scale ``array`` so its values fall in approximately [0, 1].

    ``self.eps`` is added to the denominator to prevent division by zero; for
    a constant array the result is therefore all zeros.

    Args:
        array: Numeric array to normalize.

    Returns:
        ``(array - min(array)) / (max(array) - min(array) + self.eps)``.
    """
    lo = np.min(array)
    denominator = np.max(array) - lo + self.eps
    return (array - lo) / denominator