Source code for braindynamics_starprotocol.trialbraindata

import numpy as np
import os
import json

# CONSTANTS
# NumPy dtype of the recorded samples; TrialBrainData.load_from_files() reads
# the raw binary sample file with this dtype (32-bit floating point values).
SAMPLE_DTYPE = np.float32 # data type of samples
# NumPy dtype of event timestamps and event markers/codes.
# NOTE(review): not referenced by the code visible in this part of the module.
EVENT_DTYPE = np.int32 # data type of event timestamps & markers/codes

class TrialBrainData:
    """Abstract class for representing brain activity recordings from multiple experimental trials.

    Attributes:
        trial_nr (int): number of trials.
        chan_nr (int): number of channels.
        trials_len (int or list[int] or None): length of the trials in samples. A single
            int when all trials have the same length, a list with one entry per trial
            otherwise, None while no data is loaded.
        samp_freq (float): sampling frequency in Hz.
        samp_mat_list (list[np.ndarray]): list of 2D sample matrices (of shape chan_nr, trials_len).
        info_dict (dict): dictionary for storing meta data (e.g. channel labels).
    """

    def __init__(self) -> None:
        self.trial_nr = 0  # number of trials
        self.chan_nr = 0  # number of channels (variables, e.g. electrodes, ROIs)
        self.trials_len = None  # length of the trials (in samples)
        self.samp_freq = 0.0  # sampling frequency (in Hz!!)
        self.samp_mat_list = []  # list of trial sample matrices
        self.info_dict = {}  # dictionary of other infos (metadata)

    def __str__(self) -> str:
        return f"Brain activity recording from {self.trial_nr} trials of length {self.trials_len}; {self.chan_nr} channels and a {self.samp_freq}Hz sampling frequency."

    def load(self, samp_mat_list: list[np.ndarray], samp_freq: float, info_dict: dict = None) -> None:
        """Function that loads object from input parameters.

        All inputs are validated before any attribute is modified, so a failed
        call leaves the object in its previous state.

        Args:
            samp_mat_list (list[np.ndarray]): list of 2D sample matrices of shape
                (number of channels, trial length). Stored by reference, not copied.
            samp_freq (float): sampling frequency in Hz, must be > 0.
            info_dict (dict, optional): dictionary for storing meta data. Stored by
                reference. Defaults to None, in which case the current meta data is kept.

        Raises:
            ValueError: no trial sample matrices were provided
            ValueError: a sample matrix is not 2D
            ValueError: uneven number of channels in different trials
            ValueError: sampling frequency that is not strictly positive
        """
        trial_nr = len(samp_mat_list)
        if trial_nr == 0:
            raise ValueError("At least one trial sample matrix is required, got an empty list.")

        chan_nr = None
        trials_len = []
        for t, samp_mat in enumerate(samp_mat_list):
            shape = samp_mat.shape
            if len(shape) != 2:
                raise ValueError(f"Sample matrices must be 2D (channels x samples). Trial {t} has shape {shape}.")
            if chan_nr is None:
                chan_nr = shape[0]  # the first trial defines the expected channel count
            elif shape[0] != chan_nr:
                raise ValueError(f"Number of channels across the trials must be constant. Expected {chan_nr} number of channels for trial {t}, got {shape[0]} instead.")
            trials_len.append(shape[1])

        # written as "not >" so that NaN is rejected as well
        if not samp_freq > 0.0:
            raise ValueError(f"Sampling frequency must be > 0! (value provided {samp_freq})")

        self.trial_nr = trial_nr
        self.chan_nr = chan_nr
        # collapse to a single int when every trial has the same length
        self.trials_len = trials_len[0] if len(set(trials_len)) == 1 else trials_len
        self.samp_mat_list = samp_mat_list
        self.samp_freq = samp_freq
        if info_dict is not None:
            self.info_dict = info_dict

    def load_from_files(self, info_json_path: str, samp_bin_path: str) -> None:
        """Load brain activity recording data from two files.

        Previously loaded samples are replaced, not extended. Nothing is modified
        on the object unless both files were read successfully.

        Args:
            info_json_path (str): path to json file containing information about the
                recording. The keys "trial_nr", "chan_nr", "trials_len", "samp_freq_Hz"
                and "chan_name_list" are read.
            samp_bin_path (str): path to binary file storing 3D array of shape (number of
                trials, number of channels, trial length) of 32-bit floating point values.
                With uneven trial lengths the last axis has the length of the longest
                trial and shorter trials are padded at the end.

        Raises:
            KeyError: a required key is missing from the json file
            ValueError: "trials_len" is a list whose length differs from "trial_nr"
            ValueError: "trials_len" is neither an int nor a list
            ValueError: the binary file size does not match the described shape
                (raised by the numpy reshape)
        """
        with open(info_json_path, "r", encoding="utf-8") as file:
            info_dict = json.load(file)

        trial_nr = info_dict["trial_nr"]
        chan_nr = info_dict["chan_nr"]
        trials_len = info_dict["trials_len"]
        samp_freq = info_dict["samp_freq_Hz"]
        chan_name_list = info_dict["chan_name_list"]

        if isinstance(trials_len, list):  # case of uneven trial lengths
            if len(trials_len) != trial_nr:
                raise ValueError(f"Expected {trial_nr} trial lengths in 'trials_len' (one per trial), got {len(trials_len)}.")
            per_trial_len = trials_len
            stored_len = max(trials_len, default=0)  # trials are padded to the longest one
        elif isinstance(trials_len, int):  # all trials have the same length
            per_trial_len = [trials_len] * trial_nr
            stored_len = trials_len
        else:
            raise ValueError(f"'trials_len' must be an int or a list of ints, got {type(trials_len).__name__}.")

        trials_samp_mat = np.fromfile(samp_bin_path, dtype=SAMPLE_DTYPE).reshape(trial_nr, chan_nr, stored_len)
        # slicing drops the padding of shorter trials; full-length trials are unaffected
        samp_mat_list = [trials_samp_mat[t][:, :trial_len] for t, trial_len in enumerate(per_trial_len)]

        self.trial_nr = trial_nr
        self.chan_nr = chan_nr
        self.trials_len = trials_len
        self.samp_freq = samp_freq
        self.samp_mat_list = samp_mat_list
        self.info_dict["chan_name_list"] = chan_name_list

    def clear(self) -> None:
        """Function that clears all data.

        Resets every attribute to the value it has right after construction, so
        the object can be loaded again afterwards.
        """
        self.trial_nr = 0
        self.chan_nr = 0
        self.trials_len = None
        self.samp_freq = 0.0
        self.samp_mat_list = []
        # rebind instead of .clear(): the dict may be shared with the caller of load()
        self.info_dict = {}