Source code for subcell_pipeline.analysis.dimensionality_reduction.fiber_data

"""Methods for fiber data merging and alignment."""

from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from io_collection.keys.check_key import check_key
from io_collection.load.load_dataframe import load_dataframe
from io_collection.save.save_dataframe import save_dataframe
from io_collection.save.save_figure import save_figure
from io_collection.save.save_json import save_json


def get_merged_data(
    bucket: str,
    series_name: str,
    condition_keys: list[str],
    random_seeds: list[int],
    align: bool = True,
) -> pd.DataFrame:
    """
    Load merged samples for a series, building and saving them if missing.

    The merged file lives at
    ``{series_name}/analysis/{series_name}_all_samples_aligned.csv`` (or
    ``..._all_samples_unaligned.csv`` when ``align`` is False). When that key
    already exists it is loaded and returned directly. Otherwise the sample
    file of every condition / seed combination is loaded, tagged with its seed
    and condition key, optionally aligned, and the concatenated result is
    saved to the merged key and returned.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    condition_keys
        List of condition keys. An empty key selects sample files named with
        the bare series name.
    random_seeds
        Random seeds for simulations.
    align
        True if data should be aligned, False otherwise.

    Returns
    -------
    :
        Merged data.
    """

    suffix = "all_samples_aligned" if align else "all_samples_unaligned"
    merged_key = f"{series_name}/analysis/{series_name}_{suffix}.csv"

    # Short-circuit: reuse a previously merged file instead of rebuilding it.
    if check_key(bucket, merged_key):
        message = (
            f"Dataframe [ { merged_key } ] already exists. "
            "Loading existing merged data."
        )
        print(message)
        return load_dataframe(bucket, merged_key, dtype={"key": "str"})

    frames: list[pd.DataFrame] = []

    for condition_key in condition_keys:
        if condition_key:
            file_prefix = f"{series_name}_{condition_key}"
        else:
            file_prefix = series_name

        for seed in random_seeds:
            print(f"Loading samples for [ {condition_key} ] seed [ {seed} ]")

            frame = load_dataframe(
                bucket, f"{series_name}/samples/{file_prefix}_{seed:06d}.csv"
            )
            frame["seed"] = seed
            frame["key"] = condition_key

            # Alignment rewrites the position columns of the frame in place.
            if align:
                align_fibers(frame)

            frames.append(frame)

    merged = pd.concat(frames)
    save_dataframe(bucket, merged_key, merged, index=False)

    return merged
def align_fibers(data: pd.DataFrame) -> None:
    """
    Align fibers for each time point in the data.

    Rows are grouped by their ``time`` value. The group with ``time == 0`` is
    left as-is; every other group is rotated with ``align_fiber``. The
    ``xpos``, ``ypos`` and ``zpos`` columns of ``data`` are overwritten in
    place and nothing is returned.

    Aligned coordinates are written back to the exact row positions they were
    read from, so the rows of one time point do not have to be contiguous in
    ``data``. Empty input is a no-op.

    Parameters
    ----------
    data
        Simulated fiber data with ``time``, ``xpos``, ``ypos`` and ``zpos``
        columns.
    """

    coords = data[["xpos", "ypos", "zpos"]].to_numpy(dtype=float)
    aligned = coords.copy()

    # GroupBy.indices maps each time value to the *positional* indices of its
    # rows, which keeps the write-back correct for any row order or index.
    for time, rows in data.groupby("time", sort=False).indices.items():
        if time == 0:
            # The initial time point is kept in its original orientation.
            continue

        fiber, _ = align_fiber(coords[rows])
        aligned[rows] = fiber

    data["xpos"] = aligned[:, 0]
    data["ypos"] = aligned[:, 1]
    data["zpos"] = aligned[:, 2]
def align_fiber(coords: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Rotate a fiber about the x axis so its furthest point lies on positive y.

    The point with the greatest distance from the origin in the yz-plane is
    located, and the angle that carries it onto the positive y axis is
    computed. That rotation is applied to the y and z coordinates of every
    point; x coordinates are passed through untouched.

    For example, if the furthest point is (0.5, 0, 1), it is rotated to
    (0.5, 1, 0) with an angle of pi / 2.

    Parameters
    ----------
    coords
        Array of x, y, and z positions, one point per row.

    Returns
    -------
    :
        Aligned coordinates and the 2x2 matrix that the (y, z) columns were
        right-multiplied by.
    """

    yz = coords[:, 1:]

    # The point furthest from (0, 0) in the yz-plane defines the rotation.
    radial_distance = np.sqrt((yz**2).sum(axis=1))
    furthest = np.argmax(radial_distance)
    theta = np.arctan2(coords[furthest, 2], coords[furthest, 1])

    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    rotation = np.array(
        [
            [cos_theta, -sin_theta],
            [sin_theta, cos_theta],
        ]
    )

    # Row vectors times the matrix: (y, z) -> (y*c + z*s, -y*s + z*c).
    aligned_yz = np.dot(yz, rotation)

    return np.hstack((coords[:, :1], aligned_yz)), rotation
def reshape_fibers(data: pd.DataFrame) -> tuple[np.ndarray, pd.DataFrame]:
    """
    Reshape data from tidy data format to array of fibers and fiber features.

    Rows are grouped by ``time``, ``velocity``, ``repeat`` and ``simulator``.
    Each group becomes one fiber: its ``xpos``, ``ypos`` and ``zpos`` values
    flattened row by row (x0, y0, z0, x1, y1, z1, ...). The matching feature
    record holds the group keys, with the simulator name upper-cased.

    The fiber array has one row per fiber, including when the data contains
    only a single fiber.

    Parameters
    ----------
    data
        Simulated fiber data.

    Returns
    -------
    :
        Array of fibers and dataframe of fiber features.
    """

    group_columns = ["time", "velocity", "repeat", "simulator"]

    all_features = []
    all_fibers = []

    for (time, velocity, repeat, simulator), group in data.groupby(group_columns):
        # Flatten to 1-D here rather than squeezing the stacked array later:
        # a bare squeeze would also drop the fiber axis for a single fiber.
        fiber = group[["xpos", "ypos", "zpos"]].to_numpy().reshape(-1)
        all_fibers.append(fiber)
        all_features.append(
            {
                "TIME": time,
                "VELOCITY": velocity,
                "REPEAT": repeat,
                "SIMULATOR": simulator.upper(),
            }
        )

    return np.array(all_fibers), pd.DataFrame(all_features)
def save_aligned_fibers(
    data: pd.DataFrame, time_map: dict, save_location: str, save_key: str
) -> None:
    """
    Save aligned fiber data.

    Rows are grouped by simulator, repeat, key and time. Only the group whose
    time equals ``time_map[(simulator, key)]`` is kept; each kept group is
    written as one record holding its x, y and z coordinate lists. All records
    are saved together as a single JSON document.

    Parameters
    ----------
    data
        Aligned fiber data.
    time_map
        Map of selected aligned time for each simulator and condition.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    """

    records = []
    grouped = data.groupby(["simulator", "repeat", "key", "time"])

    for (simulator, repeat, key, time), rows in grouped:
        if time == time_map[(simulator, key)]:
            # Transpose so the three coordinate columns unpack directly.
            x_values, y_values, z_values = rows[["xpos", "ypos", "zpos"]].to_numpy().T
            records.append(
                {
                    "simulator": simulator.upper(),
                    "repeat": int(repeat),
                    "key": key,
                    "x": x_values.tolist(),
                    "y": y_values.tolist(),
                    "z": z_values.tolist(),
                }
            )

    save_json(save_location, save_key, records)
def plot_fibers_by_key_and_seed(
    data: pd.DataFrame,
    save_location: Optional[str] = None,
    save_key: str = "aligned_fibers_by_key_and_seed.png",
) -> None:
    """
    Plot simulated fiber data for each condition key and random seed.

    The figure is a grid with one row per unique ``key`` and one column per
    unique ``seed``. In each panel, every (time, simulator) group is drawn as
    a line of ``zpos`` against ``ypos``; lines are red for the ``readdy``
    simulator and blue for any other. The figure is shown, and saved as well
    when ``save_location`` is given.

    Parameters
    ----------
    data
        Simulated fiber data.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    """

    rows = data["key"].unique()
    cols = data["seed"].unique()

    # squeeze=False keeps the axes array two-dimensional, so the
    # [row, column] indexing below also works for a single key or seed.
    figure, ax = plt.subplots(
        len(rows),
        len(cols),
        figsize=(10, 6),
        sharey=True,
        sharex=True,
        squeeze=False,
    )

    for row_index, row in enumerate(rows):
        for col_index, col in enumerate(cols):
            panel = ax[row_index, col_index]

            # Column titles go on the top row, row labels on the first column.
            if row_index == 0:
                panel.set_title(f"REPEAT = {col}")
            if col_index == 0:
                panel.set_ylabel(f"KEY = {row}")

            subset = data[(data["key"] == row) & (data["seed"] == col)]

            for (_, simulator), group in subset.groupby(["time", "simulator"]):
                color = "red" if simulator == "readdy" else "blue"
                coords = group[["xpos", "ypos", "zpos"]].values
                panel.plot(coords[:, 1], coords[:, 2], lw=0.5, color=color, alpha=0.5)

    plt.tight_layout()
    plt.show()

    if save_location is not None:
        save_figure(save_location, save_key, figure)