Source code for subcell_pipeline.visualization.dimensionality_reduction

"""Visualization methods for dimensionality reduction analysis."""

import os

import matplotlib.pyplot as plt
import numpy as np
from io_collection.load.load_buffer import load_buffer
from io_collection.load.load_dataframe import load_dataframe
from io_collection.load.load_pickle import load_pickle
from io_collection.save.save_buffer import save_buffer
from matplotlib.colors import Colormap
from simulariumio import DISPLAY_TYPE, CameraData, DisplayData, MetaData, UnitData
from sklearn.decomposition import PCA

from subcell_pipeline.visualization.fiber_points import (
    generate_trajectory_converter_for_fiber_points,
)

BOX_SIZE: np.ndarray = np.array(3 * [600.0])
"""Bounding box size for dimensionality reduction trajectory."""


def _rgb_to_hex_color(color: tuple[float, float, float]) -> str:
    """
    Convert RGB color to hexadecimal format.

    Parameters
    ----------
    color
        Red, green, and blue colors (between 0.0 and 1.0).

    Returns
    -------
    :
        Color in hexadecimal format.
    """
    rgb = (int(255 * color[0]), int(255 * color[1]), int(255 * color[2]))
    return f"#{rgb[0]:02X}{rgb[1]:02X}{rgb[2]:02X}"


def _pca_fiber_points_over_time(
    samples: list[np.ndarray],
    pca: PCA,
    pc_ix: int,
    simulator_name: str = "Combined",
    color: str = "#eaeaea",
) -> tuple[list[np.ndarray], list[str], dict[str, DisplayData]]:
    """
    Get fiber_points for samples of the PC distributions in order to visualize
    the samples over time.
    """
    if simulator_name == "Combined":
        simulator_name = ""
    if simulator_name:
        simulator_name += "#"
    fiber_points: list[np.ndarray] = []
    display_data: dict[str, DisplayData] = {}
    for sample_ix in range(len(samples[0])):
        if pc_ix < 1:
            data = [samples[0][sample_ix], 0]
        else:
            data = [0, samples[1][sample_ix]]
        fiber_points.append(pca.inverse_transform(data).reshape(-1, 3))
    fiber_points_arr: np.ndarray = np.array(fiber_points)
    type_name: str = f"{simulator_name}PC{pc_ix + 1}"
    display_data[type_name] = DisplayData(
        name=type_name,
        display_type=DISPLAY_TYPE.FIBER,
        color=color,
    )
    return [fiber_points_arr], [type_name], display_data


def _pca_fiber_points_one_timestep(
    samples: list[np.ndarray],
    pca: PCA,
    color_maps: dict[str, Colormap],
    pc_ix: int,
    simulator_name: str = "Combined",
) -> tuple[list[np.ndarray], list[str], dict[str, DisplayData]]:
    """
    Get fiber_points for samples of the PC distributions in order to visualize
    the samples together in one timestep.
    """
    color_map = color_maps[simulator_name]
    if simulator_name == "Combined":
        simulator_name = ""
    if simulator_name:
        simulator_name += "_"

    fiber_points = []
    type_names = []
    display_data = {}
    for sample_ix in range(len(samples[0])):
        data = [
            [samples[0][sample_ix], 0],
            [0, samples[1][sample_ix]],
        ]
        fiber_points.append(pca.inverse_transform(data[pc_ix]).reshape(1, -1, 3))
        sample = samples[pc_ix][sample_ix]
        sample_name = str(round(sample))
        type_name = f"{simulator_name}PC{pc_ix + 1}#{sample_name}"
        type_names.append(type_name)
        if type_name not in display_data:
            color_range = -samples[pc_ix][0]
            display_data[type_name] = DisplayData(
                name=type_name,
                display_type=DISPLAY_TYPE.FIBER,
                color=_rgb_to_hex_color(color_map(abs(sample) / color_range)),
            )
    return fiber_points, type_names, display_data


def _generate_simularium_and_save(
    name: str,
    fiber_points: list[np.ndarray],
    type_names: list[str],
    display_data: dict[str, DisplayData],
    distribution_over_time: bool,
    simulator_detail: bool,
    bucket: str,
    temp_path: str,
    pc: str,
) -> None:
    """Generate a simulariumio object for the fiber points and save it."""
    meta_data = MetaData(
        box_size=BOX_SIZE,
        camera_defaults=CameraData(
            position=np.array([70.0, 70.0, 300.0]),
            look_at_position=np.array([70.0, 70.0, 0.0]),
            fov_degrees=60.0,
        ),
        trajectory_title="Actin Compression Dimensionality Reduction",
    )
    time_units = UnitData("count")  # frames
    spatial_units = UnitData("nm")  # nanometers
    converter = generate_trajectory_converter_for_fiber_points(
        fiber_points,
        type_names,
        meta_data,
        display_data,
        time_units,
        spatial_units,
        fiber_radius=1.0,
    )

    # Save locally and copy to bucket.
    output_key = name
    output_key += "_time" if distribution_over_time else ""
    output_key += "_simulators" if simulator_detail else ""
    output_key += f"_pc{pc}" if pc else ""
    local_file_path = os.path.join(temp_path, output_key)
    converter.save(output_path=local_file_path)
    output_key = f"{output_key}.simularium"
    save_buffer(bucket, f"{name}/{output_key}", load_buffer(temp_path, output_key))


[docs] def visualize_dimensionality_reduction( bucket: str, pca_results_key: str, pca_pickle_key: str, distribution_over_time: bool, simulator_detail: bool, sample_ranges: dict[str, list[list[float]]], separate_pcs: bool, sample_resolution: int, temp_path: str, ) -> None: """ Visualize PCA space for actin fibers. Parameters ---------- bucket Name of S3 bucket for input and output files. pca_results_key File key for PCA results dataframe. pca_pickle_key File key for PCA object pickle. distribution_over_time True to scroll through the PC distributions over time, False otherwise. simulator_detail True to show individual simulator ranges, False otherwise. sample_ranges Min and max values to visualize for each PC (and each simulator if simulator_detail). separate_pcs True to Visualize PCs in separate files, False otherwise. sample_resolution Number of samples for each PC distribution. temp_path Local path for saving visualization output files. """ pca_results = load_dataframe(bucket, pca_results_key) pca = load_pickle(bucket, pca_pickle_key) fiber_points: list[list[np.ndarray]] = [[], []] type_names: list[list[str]] = [[], []] display_data: list[dict[str, DisplayData]] = [{}, {}] pca_results_simulators = { "Combined": pca_results, } if simulator_detail: pca_results_simulators["ReaDDy"] = pca_results.loc[ pca_results["SIMULATOR"] == "READDY" ] pca_results_simulators["Cytosim"] = pca_results.loc[ pca_results["SIMULATOR"] == "CYTOSIM" ] color_maps = { "Combined": plt.colormaps.get_cmap("RdPu"), "ReaDDy": plt.colormaps.get_cmap("YlOrRd"), "Cytosim": plt.colormaps.get_cmap("GnBu"), } over_time_colors = { "Combined": "#ffffff", "ReaDDy": "#ff8f52", "Cytosim": "#1cbfaa", } dataset_name = os.path.splitext(pca_pickle_key)[0] pc_ixs = list(range(2)) for simulator in pca_results_simulators: samples = [ np.arange( sample_ranges[simulator][0][0], sample_ranges[simulator][0][1], (sample_ranges[simulator][0][1] - sample_ranges[simulator][0][0]) / float(sample_resolution), ), np.arange( sample_ranges[simulator][1][0], sample_ranges[simulator][1][1], (sample_ranges[simulator][1][1] - sample_ranges[simulator][1][0]) / float(sample_resolution), ), ] for pc_ix in pc_ixs: if distribution_over_time: _fiber_points, _type_names, _display_data = _pca_fiber_points_over_time( samples, pca, pc_ix, simulator, over_time_colors[simulator] ) else: _fiber_points, _type_names, _display_data = ( _pca_fiber_points_one_timestep( samples, pca, color_maps, pc_ix, simulator ) ) if separate_pcs: fiber_points[pc_ix] += _fiber_points type_names[pc_ix] += _type_names display_data[pc_ix] = {**display_data[pc_ix], **_display_data} else: fiber_points[0] += _fiber_points type_names[0] += _type_names display_data[0] = {**display_data[0], **_display_data} if separate_pcs: for pc_ix in pc_ixs: _generate_simularium_and_save( dataset_name, fiber_points[pc_ix], type_names[pc_ix], display_data[pc_ix], distribution_over_time, simulator_detail, bucket, temp_path, str(pc_ix + 1), ) else: _generate_simularium_and_save( dataset_name, fiber_points[0], type_names[0], display_data[0], distribution_over_time, simulator_detail, bucket, temp_path, "", )