Coverage for subcell_pipeline/visualization/combined_trajectory.py: 0%

75 statements  


1"""Visualization methods for combined simulators.""" 

2 

3import os 

4from typing import Optional 

5 

6import numpy as np 

7import pandas as pd 

8from io_collection.keys.check_key import check_key 

9from io_collection.load.load_buffer import load_buffer 

10from io_collection.load.load_dataframe import load_dataframe 

11from io_collection.save.save_buffer import save_buffer 

12from simulariumio import ( 

13 DISPLAY_TYPE, 

14 AgentData, 

15 CameraData, 

16 DisplayData, 

17 MetaData, 

18 TrajectoryConverter, 

19 TrajectoryData, 

20 UnitData, 

21) 

22 

23from subcell_pipeline.analysis.compression_metrics.compression_analysis import ( 

24 get_compression_metric_data, 

25) 

26from subcell_pipeline.analysis.compression_metrics.compression_metric import ( 

27 CompressionMetric, 

28) 

29from subcell_pipeline.analysis.dimensionality_reduction.fiber_data import align_fibers 

30from subcell_pipeline.visualization.scatter_plots import make_empty_scatter_plots 

31 

32BOX_SIZE: np.ndarray = np.array(3 * [600.0]) 

33"""Bounding box size for combined simulator trajectories.""" 

34 

35 

36def _load_fiber_points_from_dataframe( 

37 dataframe: pd.DataFrame, n_timepoints: int 

38) -> np.ndarray: 

39 """ 

40 Load and reshape fiber points from sampled dataframe. 

41 

42 Sampled dataframe is in the shape (n_timepoints x n_fiber_points, 3); method 

43 returns the dataframe reshaped to (n_timepoints, n_fiber_points x 3). If the 

44 sampled dataframe does not have the expected number of timepoints, method 

45 will raise an exception. 

46 """ 

47 

48 dataframe.sort_values(by=["time", "fiber_point"]) 

49 total_steps = dataframe.time.unique().shape[0] 

50 

51 if total_steps != n_timepoints: 

52 raise Exception( 

53 f"Requested number of timesteps [ {n_timepoints} ] does not match " 

54 f"number of timesteps in dataset [ {total_steps} ]." 

55 ) 

56 

57 align_fibers(dataframe) 

58 

59 fiber_points = [] 

60 for _, group in dataframe.groupby("time"): 

61 fiber_points.append(group[["xpos", "ypos", "zpos"]].values.flatten()) 

62 

63 return np.array(fiber_points) 
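
# Illustrative sketch (not part of the module): once rows are sorted by time
# and fiber point, the groupby-and-flatten above is equivalent to a plain
# numpy reshape. Values here are hypothetical: 2 timepoints with 3 sampled
# fiber points each.
#
#   coords = np.arange(18, dtype=float).reshape(6, 3)  # (2 * 3, 3) points
#   reshaped = coords.reshape(2, -1)                   # (2, 3 * 3) rows
#   assert reshaped.shape == (2, 9)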


def _get_combined_trajectory_converter(
    fiber_points: list[np.ndarray],
    type_names: list[str],
    display_data: dict[str, DisplayData],
) -> TrajectoryConverter:
    """
    Generate a TrajectoryConverter to visualize simulations from ReaDDy and
    Cytosim together.
    """

    total_conditions = len(fiber_points)
    total_steps = fiber_points[0].shape[0]
    total_subpoints = fiber_points[0].shape[1]

    traj_data = TrajectoryData(
        meta_data=MetaData(
            box_size=BOX_SIZE,
            camera_defaults=CameraData(
                position=np.array([75.0, 220.0, 15.0]),
                look_at_position=np.array([75.0, 75.0, 0.0]),
                fov_degrees=60.0,
            ),
            trajectory_title="Actin compression in Cytosim and ReaDDy",
        ),
        agent_data=AgentData(
            times=np.arange(total_steps),
            n_agents=total_conditions * np.ones(total_steps),
            viz_types=1001
            * np.ones((total_steps, total_conditions)),  # fiber viz type = 1001
            unique_ids=np.array(total_steps * [list(range(total_conditions))]),
            types=total_steps * [type_names],
            positions=np.zeros((total_steps, total_conditions, 3)),
            radii=np.ones((total_steps, total_conditions)),
            n_subpoints=total_subpoints * np.ones((total_steps, total_conditions)),
            subpoints=np.moveaxis(np.array(fiber_points), [0, 1], [1, 0]),
            display_data=display_data,
        ),
        time_units=UnitData("count"),  # frames
        spatial_units=UnitData("nm"),  # nanometers
    )
    return TrajectoryConverter(traj_data)
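
# Illustrative sketch (not part of the module) of the subpoints layout above:
# fiber_points is a list of per-condition arrays of shape (n_steps,
# n_subpoints), so np.array(fiber_points) stacks them into (n_conditions,
# n_steps, n_subpoints), and np.moveaxis swaps the first two axes into the
# (n_steps, n_agents, n_subpoints) layout AgentData expects. Shapes here are
# hypothetical.
#
#   stacked = np.zeros((4, 100, 150))  # 4 conditions, 100 steps, 50 points x 3
#   swapped = np.moveaxis(stacked, [0, 1], [1, 0])
#   assert swapped.shape == (100, 4, 150)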


def _add_combined_plots(
    converter: TrajectoryConverter,
    metrics: list[CompressionMetric],
    metrics_data: dict[str, pd.DataFrame],
    n_timepoints: int,
    plot_names: list[tuple[str, str, int]],
    type_names: list[str],
) -> None:
    """Add plots for combined trajectories with calculated metrics."""
    scatter_plots = make_empty_scatter_plots(metrics, total_steps=n_timepoints)

    for metric, plot in scatter_plots.items():
        for plot_name, type_name in zip(plot_names, type_names):
            simulator, key, seed = plot_name
            simulator_data = metrics_data[simulator]
            data = simulator_data[
                (simulator_data["key"] == key) & (simulator_data["seed"] == seed)
            ]
            plot.ytraces[type_name] = np.array(data[metric.value])
        converter.add_plot(plot, "scatter")
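
# Illustrative sketch (not part of the module) of the metrics_data layout
# consumed above: a map from simulator name to a dataframe with "key" and
# "seed" columns plus one column per metric (named by CompressionMetric.value),
# one row per timepoint. All names and values here are hypothetical.
#
#   metrics_data = {
#       "readdy": pd.DataFrame({
#           "key": 3 * ["0047"],
#           "seed": 3 * [1],
#           "example_metric": [0.0, 0.1, 0.2],  # hypothetical metric column
#       }),
#   }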


def visualize_combined_trajectories(
    buckets: dict[str, str],
    series_name: str,
    condition_keys: list[str],
    replicates: list[int],
    n_timepoints: int,
    simulator_colors: dict[str, str],
    temp_path: str,
    metrics: Optional[list[CompressionMetric]] = None,
    recalculate: bool = False,
) -> None:
    """
    Visualize combined simulations from ReaDDy and Cytosim for selected
    conditions and replicates.

    Parameters
    ----------
    buckets
        Names of S3 buckets for input and output files for each simulator and
        for the combined visualization.
    series_name
        Name of simulation series.
    condition_keys
        List of condition keys.
    replicates
        Simulation replicate ids.
    n_timepoints
        Number of equally spaced timepoints to visualize.
    simulator_colors
        Map of simulator name to color.
    temp_path
        Local path for saving visualization output files.
    metrics
        List of metrics to include in visualization plots.
    recalculate
        True to recalculate visualization files, False otherwise.
    """

    if metrics is None:
        metrics = []

    fiber_points = []
    type_names = []
    plot_names = []
    display_data = {}
    all_metrics_data = {}

    for simulator, color in simulator_colors.items():
        bucket = buckets[simulator]

        # Load calculated compression metric data, if metrics are requested.
        if metrics:
            all_metrics_data[simulator] = get_compression_metric_data(
                bucket,
                series_name,
                condition_keys,
                replicates,
                metrics,
                recalculate=recalculate,
            )
        else:
            all_metrics_data[simulator] = pd.DataFrame(columns=["key", "seed"])

        for condition_key in condition_keys:
            series_key = (
                f"{series_name}_{condition_key}" if condition_key else series_name
            )

            for replicate in replicates:
                dataframe_key = (
                    f"{series_name}/samples/{series_key}_{replicate:06d}.csv"
                )

                # Skip if the input dataframe does not exist.
                if not check_key(bucket, dataframe_key):
                    print(
                        f"Dataframe not available for {simulator} "
                        f"[ {dataframe_key} ]. Skipping."
                    )
                    continue

                print(
                    f"Loading data for [ {simulator} ] "
                    f"condition [ {dataframe_key} ] "
                    f"replicate [ {replicate} ]"
                )

                dataframe = load_dataframe(bucket, dataframe_key)
                fiber_points.append(
                    _load_fiber_points_from_dataframe(dataframe, n_timepoints)
                )

                # Condition keys encode compression velocity in 0.1 um/s units.
                condition = int(condition_key) / 10
                condition = round(condition) if condition_key[-1] == "0" else condition

                type_names.append(f"{simulator}#{condition} um/s {replicate}")
                plot_names.append((simulator, condition_key, replicate))
                display_data[type_names[-1]] = DisplayData(
                    name=type_names[-1],
                    display_type=DISPLAY_TYPE.FIBER,
                    color=color,
                )

    converter = _get_combined_trajectory_converter(
        fiber_points, type_names, display_data
    )

    if metrics:
        _add_combined_plots(
            converter, metrics, all_metrics_data, n_timepoints, plot_names, type_names
        )

    output_key = "actin_compression_cytosim_readdy.simularium"
    local_file_path = os.path.join(temp_path, output_key)

    # The converter appends the .simularium extension when saving, so strip it
    # from the local path first.
    converter.save(output_path=local_file_path.replace(".simularium", ""))

    output_bucket = buckets["combined"]
    save_buffer(output_bucket, output_key, load_buffer(temp_path, output_key))
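
# Example invocation (illustrative only; bucket names, series name, condition
# keys, and colors below are hypothetical placeholders):
#
#   visualize_combined_trajectories(
#       buckets={
#           "readdy": "readdy-working-bucket",
#           "cytosim": "cytosim-working-bucket",
#           "combined": "combined-working-bucket",
#       },
#       series_name="ACTIN_COMPRESSION_VELOCITY",
#       condition_keys=["0047", "0150"],
#       replicates=[1, 2],
#       n_timepoints=200,
#       simulator_colors={"readdy": "#ca562c", "cytosim": "#005f73"},
#       temp_path="/tmp/subcell",
#       metrics=None,  # or a list of CompressionMetric members to add plots
#   )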