Coverage for subcell_pipeline/visualization/dimensionality_reduction.py: 0%

96 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-08-29 15:14 +0000

1"""Visualization methods for dimensionality reduction analysis.""" 

2 

3import os 

4 

5import matplotlib.pyplot as plt 

6import numpy as np 

7from io_collection.load.load_buffer import load_buffer 

8from io_collection.load.load_dataframe import load_dataframe 

9from io_collection.load.load_pickle import load_pickle 

10from io_collection.save.save_buffer import save_buffer 

11from matplotlib.colors import Colormap 

12from simulariumio import DISPLAY_TYPE, CameraData, DisplayData, MetaData, UnitData 

13from sklearn.decomposition import PCA 

14 

15from subcell_pipeline.visualization.fiber_points import ( 

16 generate_trajectory_converter_for_fiber_points, 

17) 

18 

BOX_SIZE: np.ndarray = np.full(3, 600.0)
"""Bounding box size for dimensionality reduction trajectory."""

21 

22 

23def _rgb_to_hex_color(color: tuple[float, float, float]) -> str: 

24 """ 

25 Convert RGB color to hexadecimal format. 

26 

27 Parameters 

28 ---------- 

29 color 

30 Red, green, and blue colors (between 0.0 and 1.0). 

31 

32 Returns 

33 ------- 

34 : 

35 Color in hexadecimal format. 

36 """ 

37 rgb = (int(255 * color[0]), int(255 * color[1]), int(255 * color[2])) 

38 return f"#{rgb[0]:02X}{rgb[1]:02X}{rgb[2]:02X}" 

39 

40 

def _pca_fiber_points_over_time(
    samples: list[np.ndarray],
    pca: PCA,
    pc_ix: int,
    simulator_name: str = "Combined",
    color: str = "#eaeaea",
) -> tuple[list[np.ndarray], list[str], dict[str, DisplayData]]:
    """
    Get fiber_points for samples of the PC distributions in order to visualize
    the samples over time.

    Parameters
    ----------
    samples
        Sample values for the first and second PC (in that order).
    pca
        Fitted PCA object used to map PC values back to fiber points.
    pc_ix
        Index of the PC to sample. Values below 1 select the first PC, any
        other value selects the second PC.
    simulator_name
        Name of the simulator, used as a prefix of the type name. "Combined"
        produces no prefix.
    color
        Display color for the fiber.

    Returns
    -------
    :
        A single-element list holding the stacked fiber points (one entry per
        sample along the first axis), the matching single-element list of type
        names, and the display data keyed by type name.
    """
    # "Combined" (or an empty name) gets no prefix; others get "<name>#".
    prefix = "" if simulator_name == "Combined" else simulator_name
    if prefix:
        prefix += "#"
    type_name: str = f"{prefix}PC{pc_ix + 1}"

    # Iterate over the samples of the selected PC itself. The per-PC sample
    # arrays are not guaranteed to have the same length, so bounding the loop
    # by the first PC's length could overrun (or truncate) the second PC.
    component = 0 if pc_ix < 1 else 1
    fiber_points: list[np.ndarray] = []
    for sample in samples[component]:
        # Vary only the selected PC; the other PC is held at zero.
        data = [0, 0]
        data[component] = sample
        fiber_points.append(pca.inverse_transform(data).reshape(-1, 3))
    fiber_points_arr: np.ndarray = np.array(fiber_points)

    display_data: dict[str, DisplayData] = {
        type_name: DisplayData(
            name=type_name,
            display_type=DISPLAY_TYPE.FIBER,
            color=color,
        )
    }
    return [fiber_points_arr], [type_name], display_data

72 

73 

def _pca_fiber_points_one_timestep(
    samples: list[np.ndarray],
    pca: PCA,
    color_maps: dict[str, Colormap],
    pc_ix: int,
    simulator_name: str = "Combined",
) -> tuple[list[np.ndarray], list[str], dict[str, DisplayData]]:
    """
    Get fiber_points for samples of the PC distributions in order to visualize
    the samples together in one timestep.

    Parameters
    ----------
    samples
        Sample values for the first and second PC (in that order).
    pca
        Fitted PCA object used to map PC values back to fiber points.
    color_maps
        Color map for each simulator name (looked up before the name is turned
        into a prefix).
    pc_ix
        Index of the PC to sample (0 or 1).
    simulator_name
        Name of the simulator, used as a prefix of the type names. "Combined"
        produces no prefix.

    Returns
    -------
    :
        One fiber points array per sample (each with a leading axis of length
        one), the type name of each sample, and the display data keyed by type
        name.
    """
    color_map = color_maps[simulator_name]
    # "Combined" (or an empty name) gets no prefix; others get "<name>_".
    prefix = "" if simulator_name == "Combined" else simulator_name
    if prefix:
        prefix += "_"

    fiber_points: list[np.ndarray] = []
    type_names: list[str] = []
    display_data: dict[str, DisplayData] = {}

    # Iterate over the samples of the selected PC itself. The per-PC sample
    # arrays are not guaranteed to have the same length, so indexing both PCs
    # with the first PC's length could overrun (or truncate) the selected one.
    pc_samples = samples[pc_ix]
    if len(pc_samples) == 0:
        return fiber_points, type_names, display_data

    # Colors are normalized by the negated first sample.
    # NOTE(review): this presumably assumes the range starts at a negative
    # minimum and is roughly symmetric about zero; confirm with callers.
    color_range = -pc_samples[0]

    for sample in pc_samples:
        # Vary only the selected PC; the other PC is held at zero.
        data = [0, 0]
        data[pc_ix] = sample
        fiber_points.append(pca.inverse_transform(data).reshape(1, -1, 3))

        # Samples that round to the same integer share one type name (and
        # therefore one display entry, created on first encounter).
        type_name = f"{prefix}PC{pc_ix + 1}#{round(sample)}"
        type_names.append(type_name)
        if type_name not in display_data:
            display_data[type_name] = DisplayData(
                name=type_name,
                display_type=DISPLAY_TYPE.FIBER,
                color=_rgb_to_hex_color(color_map(abs(sample) / color_range)),
            )
    return fiber_points, type_names, display_data

112 

113 

def _generate_simularium_and_save(
    name: str,
    fiber_points: list[np.ndarray],
    type_names: list[str],
    display_data: dict[str, DisplayData],
    distribution_over_time: bool,
    simulator_detail: bool,
    bucket: str,
    temp_path: str,
    pc: str,
) -> None:
    """
    Generate a simulariumio object for the fiber points and save it.

    The trajectory is first written under ``temp_path`` and the resulting
    ``.simularium`` file is then read back and saved to ``bucket`` under the
    ``name`` prefix.

    Parameters
    ----------
    name
        Base name of the output file, also used as the key prefix in the
        bucket.
    fiber_points
        Fiber points passed to the trajectory converter.
    type_names
        Type names passed to the trajectory converter.
    display_data
        Display data passed to the trajectory converter.
    distribution_over_time
        True to append "_time" to the output name, False otherwise.
    simulator_detail
        True to append "_simulators" to the output name, False otherwise.
    bucket
        Name of the bucket the output file is saved to.
    temp_path
        Local path for saving the output file before it is copied.
    pc
        PC label appended to the output name as "_pc<label>"; an empty string
        appends nothing.
    """
    # Output file name: base name plus one optional suffix per enabled option.
    suffixes = (
        "_time" if distribution_over_time else "",
        "_simulators" if simulator_detail else "",
        f"_pc{pc}" if pc else "",
    )
    file_stem = name + "".join(suffixes)

    camera = CameraData(
        position=np.array([70.0, 70.0, 300.0]),
        look_at_position=np.array([70.0, 70.0, 0.0]),
        fov_degrees=60.0,
    )
    trajectory_meta = MetaData(
        box_size=BOX_SIZE,
        camera_defaults=camera,
        trajectory_title="Actin Compression Dimensionality Reduction",
    )
    converter = generate_trajectory_converter_for_fiber_points(
        fiber_points,
        type_names,
        trajectory_meta,
        display_data,
        UnitData("count"),  # time units: frames
        UnitData("nm"),  # spatial units: nanometers
        fiber_radius=1.0,
    )

    # Save locally and copy to bucket. The local file is read back with the
    # ".simularium" extension appended to the path that was passed to save().
    converter.save(output_path=os.path.join(temp_path, file_stem))
    file_name = f"{file_stem}.simularium"
    save_buffer(bucket, f"{name}/{file_name}", load_buffer(temp_path, file_name))

156 

157 

def visualize_dimensionality_reduction(
    bucket: str,
    pca_results_key: str,
    pca_pickle_key: str,
    distribution_over_time: bool,
    simulator_detail: bool,
    sample_ranges: dict[str, list[list[float]]],
    separate_pcs: bool,
    sample_resolution: int,
    temp_path: str,
) -> None:
    """
    Visualize PCA space for actin fibers.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    pca_results_key
        File key for PCA results dataframe.
    pca_pickle_key
        File key for PCA object pickle.
    distribution_over_time
        True to scroll through the PC distributions over time, False otherwise.
    simulator_detail
        True to show individual simulator ranges, False otherwise.
    sample_ranges
        Min and max values to visualize for each PC (and each simulator if
        simulator_detail). Samples start at the min and stop one step short of
        the max.
    separate_pcs
        True to visualize PCs in separate files, False otherwise.
    sample_resolution
        Number of samples for each PC distribution.
    temp_path
        Local path for saving visualization output files.

    Raises
    ------
    ValueError
        If sample_resolution is less than one.
    """
    if sample_resolution < 1:
        raise ValueError(
            f"sample_resolution must be at least 1, got {sample_resolution}"
        )

    pca_results = load_dataframe(bucket, pca_results_key)
    # NOTE(review): load_pickle presumably unpickles the object, which can
    # execute arbitrary code; the bucket contents must be trusted.
    pca = load_pickle(bucket, pca_pickle_key)

    # NOTE(review): only the keys of this mapping are used below; the
    # per-simulator dataframes themselves are never read in this function.
    pca_results_simulators = {
        "Combined": pca_results,
    }
    if simulator_detail:
        pca_results_simulators["ReaDDy"] = pca_results.loc[
            pca_results["SIMULATOR"] == "READDY"
        ]
        pca_results_simulators["Cytosim"] = pca_results.loc[
            pca_results["SIMULATOR"] == "CYTOSIM"
        ]
    color_maps = {
        "Combined": plt.colormaps.get_cmap("RdPu"),
        "ReaDDy": plt.colormaps.get_cmap("YlOrRd"),
        "Cytosim": plt.colormaps.get_cmap("GnBu"),
    }
    over_time_colors = {
        "Combined": "#ffffff",
        "ReaDDy": "#ff8f52",
        "Cytosim": "#1cbfaa",
    }
    dataset_name = os.path.splitext(pca_pickle_key)[0]
    pc_ixs = list(range(2))

    # One accumulator slot per PC; only slot 0 is used when the PCs are
    # combined into a single output file.
    fiber_points: list[list[np.ndarray]] = [[] for _ in pc_ixs]
    type_names: list[list[str]] = [[] for _ in pc_ixs]
    display_data: list[dict[str, DisplayData]] = [{} for _ in pc_ixs]

    for simulator in pca_results_simulators:
        # linspace with endpoint=False yields the same min + i * step grid as
        # arange with a fractional step, but always returns exactly
        # sample_resolution values, so both PCs get equally many samples.
        samples = [
            np.linspace(
                sample_ranges[simulator][pc_ix][0],
                sample_ranges[simulator][pc_ix][1],
                sample_resolution,
                endpoint=False,
            )
            for pc_ix in pc_ixs
        ]
        for pc_ix in pc_ixs:
            if distribution_over_time:
                _fiber_points, _type_names, _display_data = _pca_fiber_points_over_time(
                    samples, pca, pc_ix, simulator, over_time_colors[simulator]
                )
            else:
                _fiber_points, _type_names, _display_data = (
                    _pca_fiber_points_one_timestep(
                        samples, pca, color_maps, pc_ix, simulator
                    )
                )
            target_ix = pc_ix if separate_pcs else 0
            fiber_points[target_ix] += _fiber_points
            type_names[target_ix] += _type_names
            display_data[target_ix] = {**display_data[target_ix], **_display_data}

    # One file per PC (labelled "1", "2") or a single unlabelled file.
    outputs = [(pc_ix, str(pc_ix + 1)) for pc_ix in pc_ixs] if separate_pcs else [(0, "")]
    for output_ix, pc_label in outputs:
        _generate_simularium_and_save(
            dataset_name,
            fiber_points[output_ix],
            type_names[output_ix],
            display_data[output_ix],
            distribution_over_time,
            simulator_detail,
            bucket,
            temp_path,
            pc_label,
        )