Coverage for subcell_pipeline/visualization/tomography.py: 0%

53 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-08-29 15:14 +0000

1"""Visualization methods for tomography data analysis.""" 

2 

3import os 

4from typing import Optional 

5 

6import numpy as np 

7import pandas as pd 

8from io_collection.load.load_buffer import load_buffer 

9from io_collection.load.load_dataframe import load_dataframe 

10from io_collection.save.save_buffer import save_buffer 

11from simulariumio import CameraData, MetaData, TrajectoryConverter, UnitData 

12 

13from subcell_pipeline.analysis.compression_metrics.compression_metric import ( 

14 CompressionMetric, 

15) 

16from subcell_pipeline.analysis.tomography_data.tomography_data import ( 

17 TOMOGRAPHY_SAMPLE_COLUMNS, 

18) 

19from subcell_pipeline.visualization.fiber_points import ( 

20 generate_trajectory_converter_for_fiber_points, 

21) 

22from subcell_pipeline.visualization.histogram_plots import make_empty_histogram_plots 

23 

24TOMOGRAPHY_VIZ_SCALE: float = 100.0 

25"""Spatial scaling factor for tomography visualization.""" 

26 

27 

def _add_tomography_plots(
    converter: TrajectoryConverter,
    metrics: list[CompressionMetric],
    fiber_points: list[np.ndarray],
) -> None:
    """
    Attach one histogram plot per metric to the tomography converter.

    Parameters
    ----------
    converter
        Trajectory converter that receives the plots.
    metrics
        Metrics to calculate; one histogram plot is added for each.
    fiber_points
        Per-fiber point arrays; the first entry along the leading axis of
        each array is passed to the metric as the polymer trace.
    """

    for metric, plot in make_empty_histogram_plots(metrics).items():
        metric_values = np.array(
            [
                metric.calculate_metric(polymer_trace=fiber[0, :, :])
                for fiber in fiber_points
            ]
        )

        # Compression ratio values are scaled by 100 before plotting
        # (presumably to display as a percentage -- confirm with plot labels).
        is_ratio = metric == CompressionMetric.COMPRESSION_RATIO
        plot.traces["actin"] = metric_values * 100 if is_ratio else metric_values

        converter.add_plot(plot, "histogram")

49 

50 

def _get_tomography_spatial_center_and_size(
    tomo_df: pd.DataFrame,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute the bounding-box center and extent of the tomography samples.

    Parameters
    ----------
    tomo_df
        Tomography data containing every column listed in
        ``TOMOGRAPHY_SAMPLE_COLUMNS``.

    Returns
    -------
    :
        Tuple of (center, size) arrays with one entry per sample column,
        where size is the per-column max minus min and center is the
        midpoint between the per-column min and max.
    """

    lower = np.array([tomo_df[column].min() for column in TOMOGRAPHY_SAMPLE_COLUMNS])
    upper = np.array([tomo_df[column].max() for column in TOMOGRAPHY_SAMPLE_COLUMNS])

    return lower + 0.5 * (upper - lower), upper - lower

67 

68 

def visualize_tomography(
    bucket: str,
    name: str,
    temp_path: str,
    metrics: Optional[list[CompressionMetric]] = None,
) -> None:
    """
    Create and upload a visualization of sampled tomography actin fibers.

    Loads ``{name}/{name}_coordinates_sampled.csv`` from the bucket, builds a
    trajectory converter from the per-fiber points (centered on the dataset
    bounding box and scaled by ``TOMOGRAPHY_VIZ_SCALE``), optionally adds
    metric histogram plots, then saves the result locally and copies it to
    ``{name}/{name}.simularium`` in the bucket.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    name
        Name of tomography dataset.
    temp_path
        Local path for saving visualization output files.
    metrics
        List of metrics to include in visualization plots.
    """

    # Sampled coordinates, ordered by fiber id and then by monomer id.
    sampled_df = (
        load_dataframe(bucket, f"{name}/{name}_coordinates_sampled.csv")
        .sort_values(by=["id", "monomer_ids"])
        .reset_index(drop=True)
    )

    time_units = UnitData("count")
    spatial_units = UnitData("um", 0.003)

    center, box_size = _get_tomography_spatial_center_and_size(sampled_df)

    all_fiber_points = []
    all_type_names = []

    for fiber_id, fiber_df in sampled_df.groupby("id"):
        # The id is split on its first underscore: the leading part is used
        # as the fiber index and the remainder as the dataset name.
        fiber_index, dataset = fiber_id.split("_", 1)

        # Wrapping the frame in a list adds a leading axis of length one.
        coordinates = np.array([fiber_df[TOMOGRAPHY_SAMPLE_COLUMNS]])
        all_fiber_points.append(TOMOGRAPHY_VIZ_SCALE * (coordinates - center))
        all_type_names.append(f"{dataset}#{fiber_index}")

    meta_data = MetaData(
        box_size=TOMOGRAPHY_VIZ_SCALE * box_size,
        camera_defaults=CameraData(position=np.array([0.0, 0.0, 70.0])),
    )

    converter = generate_trajectory_converter_for_fiber_points(
        all_fiber_points,
        all_type_names,
        meta_data,
        {},
        time_units,
        spatial_units,
    )

    if metrics:
        _add_tomography_plots(converter, metrics, all_fiber_points)

    # Save locally and copy to bucket. The saved file is read back as
    # "{name}.simularium" from temp_path, so the converter is expected to
    # append that extension to the output path.
    converter.save(output_path=os.path.join(temp_path, name))
    saved_buffer = load_buffer(temp_path, f"{name}.simularium")
    save_buffer(bucket, f"{name}/{name}.simularium", saved_buffer)