1"""Visualization methods for individual simulators.""" 

2 

3from typing import Optional 

4 

5import numpy as np 

6import pandas as pd 

7from io_collection.keys.check_key import check_key 

8from io_collection.load.load_buffer import load_buffer 

9from io_collection.load.load_text import load_text 

10from io_collection.save.save_buffer import save_buffer 

11from pint import UnitRegistry 

12from simulariumio import ( 

13 DISPLAY_TYPE, 

14 CameraData, 

15 DisplayData, 

16 InputFileData, 

17 MetaData, 

18 TrajectoryConverter, 

19 UnitData, 

20) 

21from simulariumio.cytosim import CytosimConverter, CytosimData, CytosimObjectInfo 

22from simulariumio.filters import EveryNthTimestepFilter 

23from simulariumio.readdy import ReaddyConverter, ReaddyData 

24 

25from subcell_pipeline.analysis.compression_metrics.compression_analysis import ( 

26 get_compression_metric_data, 

27) 

28from subcell_pipeline.analysis.compression_metrics.compression_metric import ( 

29 CompressionMetric, 

30) 

31from subcell_pipeline.analysis.dimensionality_reduction.fiber_data import align_fiber 

32from subcell_pipeline.simulation.cytosim.post_processing import CYTOSIM_SCALE_FACTOR 

33from subcell_pipeline.simulation.readdy.loader import ReaddyLoader 

34from subcell_pipeline.simulation.readdy.parser import BOX_SIZE as READDY_BOX_SIZE 

35from subcell_pipeline.simulation.readdy.parser import ( 

36 READDY_TIMESTEP, 

37 download_readdy_hdf5, 

38) 

39from subcell_pipeline.simulation.readdy.post_processor import ReaddyPostProcessor 

40from subcell_pipeline.visualization.display_data import get_readdy_display_data 

41from subcell_pipeline.visualization.scatter_plots import make_empty_scatter_plots 

42from subcell_pipeline.visualization.spatial_annotator import ( 

43 add_fiber_annotation_agents, 

44 add_sphere_annotation_agents, 

45) 

46 

47READDY_SAVED_FRAMES: int = 1000 

48"""Number of saved frames for ReaDDy simulations.""" 

49 

50BOX_SIZE: np.ndarray = np.array(3 * [600.0]) 

51"""Bounding box size for individual simulator trajectory.""" 

52 

53 



def _add_individual_plots(
    converter: TrajectoryConverter,
    metrics: list[CompressionMetric],
    metrics_data: pd.DataFrame,
    times: np.ndarray,
    time_units: UnitData,
) -> None:
    """Add plots to individual trajectory with calculated metrics."""
    scatter_plots = make_empty_scatter_plots(
        metrics, times=times, time_units=time_units
    )
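    # One empty scatter plot is created per metric; fill each y-trace with the
    # metric's column (keyed by the CompressionMetric enum value).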
    for metric, plot in scatter_plots.items():
        plot.ytraces["filament"] = np.array(metrics_data[metric.value])
        converter.add_plot(plot, "scatter")


def _add_readdy_spatial_annotations(
    converter: TrajectoryConverter,
    post_processor: ReaddyPostProcessor,
    n_monomer_points: int,
) -> None:
    """
    Add visualizations of edges, normals, and control points to the ReaDDy
    Simularium data.
    """
    fiber_chain_ids = post_processor.linear_fiber_chain_ids(polymer_number_range=5)
    axis_positions, _ = post_processor.linear_fiber_axis_positions(fiber_chain_ids)
    fiber_points = post_processor.linear_fiber_control_points(
        axis_positions=axis_positions,
        n_points=n_monomer_points,
    )
    converter._data.agent_data.positions, fiber_points = (
        post_processor.align_trajectory(fiber_points)
    )
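    # Recompute axis positions after alignment, presumably so the normals
    # added below reflect the aligned fiber positions.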
    axis_positions, _ = post_processor.linear_fiber_axis_positions(fiber_chain_ids)
    edges = post_processor.edge_positions()

    # edges
    converter._data = add_fiber_annotation_agents(
        converter._data,
        fiber_points=edges,
        type_name="edge",
        fiber_width=0.5,
        color="#eaeaea",
    )

    # normals
    normals = post_processor.linear_fiber_normals(
        fiber_chain_ids=fiber_chain_ids,
        axis_positions=axis_positions,
        normal_length=10.0,
    )
    converter._data = add_fiber_annotation_agents(
        converter._data,
        fiber_points=normals,
        type_name="normal",
        fiber_width=0.5,
        color="#685bf3",
    )

    # control points
    sphere_positions = []
    for time_ix in range(len(fiber_points)):
        sphere_positions.append(fiber_points[time_ix][0])
    converter._data = add_sphere_annotation_agents(
        converter._data,
        sphere_positions,
        type_name="fiber point",
        radius=0.8,
        rainbow_colors=True,
    )


def _get_readdy_simularium_converter(
    path_to_readdy_h5: str,
    total_steps: int,
    n_timepoints: int,
) -> TrajectoryConverter:
    """
    Load from ReaDDy outputs and generate a TrajectoryConverter to visualize an
    actin trajectory in Simularium.
    """
    converter = ReaddyConverter(
        ReaddyData(
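            # Time between saved frames. The 1e-6 factor converts the raw
            # ReaDDy time into the "ms" time units declared below (assuming
            # READDY_TIMESTEP is in ns).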
            timestep=1e-6 * (READDY_TIMESTEP * total_steps / READDY_SAVED_FRAMES),
            path_to_readdy_h5=path_to_readdy_h5,
            meta_data=MetaData(
                box_size=READDY_BOX_SIZE,
                camera_defaults=CameraData(
                    position=np.array([70.0, 70.0, 300.0]),
                    look_at_position=np.array([70.0, 70.0, 0.0]),
                    fov_degrees=60.0,
                ),
                scale_factor=1.0,
            ),
            display_data=get_readdy_display_data(),
            time_units=UnitData("ms"),
            spatial_units=UnitData("nm"),
        )
    )
    return _filter_time(converter, n_timepoints)


def visualize_individual_readdy_trajectory(
    bucket: str,
    series_name: str,
    series_key: str,
    rep_ix: int,
    n_timepoints: int,
    n_monomer_points: int,
    total_steps: int,
    temp_path: str,
    metrics: list[CompressionMetric],
    metrics_data: pd.DataFrame,
) -> None:
    """
    Save a Simularium file for a single ReaDDy trajectory with plots and
    spatial annotations.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    series_key
        Combination of series and condition names.
    rep_ix
        Replicate index.
    n_timepoints
        Number of equally spaced timepoints to visualize.
    n_monomer_points
        Number of equally spaced monomer points to visualize.
    total_steps
        Total number of steps for the simulation.
    temp_path
        Local path for saving visualization output files.
    metrics
        List of metrics to include in visualization plots.
    metrics_data
        Calculated compression metrics data.
    """

    h5_file_path = download_readdy_hdf5(
        bucket, series_name, series_key, rep_ix, temp_path
    )

    assert isinstance(h5_file_path, str)

    converter = _get_readdy_simularium_converter(
        h5_file_path, total_steps, n_timepoints
    )

    if metrics:
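        # Rescale metric times onto the visualized trajectory's time axis;
        # the stored "time" column appears normalized, hence the factor of 2.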
        times = 2 * metrics_data["time"].values  # "time" seems to range (0, 0.5)
        times *= 1e-6 * (READDY_TIMESTEP * total_steps / n_timepoints)
        _add_individual_plots(
            converter, metrics, metrics_data, times, converter._data.time_units
        )

    rep_id = rep_ix + 1
    pickle_key = f"{series_name}/data/{series_key}_{rep_id:06d}.pkl"
    time_inc = total_steps // n_timepoints

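    # Load the ReaDDy monomer trajectory, caching the parsed frames as a
    # pickle in the bucket (presumably so repeat runs can skip re-parsing).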
    readdy_loader = ReaddyLoader(
        h5_file_path=h5_file_path,
        time_inc=time_inc,
        timestep=READDY_TIMESTEP,
        pickle_location=bucket,
        pickle_key=pickle_key,
    )

    post_processor = ReaddyPostProcessor(
        readdy_loader.trajectory(), box_size=READDY_BOX_SIZE
    )

    _add_readdy_spatial_annotations(converter, post_processor, n_monomer_points)

    # Save simularium file. Turn off validate IDs for performance.
    converter.save(output_path=h5_file_path, validate_ids=False)


def visualize_individual_readdy_trajectories(
    bucket: str,
    series_name: str,
    condition_keys: list[str],
    n_replicates: int,
    n_timepoints: int,
    n_monomer_points: int,
    total_steps: dict[str, int],
    temp_path: str,
    metrics: Optional[list[CompressionMetric]] = None,
    recalculate: bool = True,
) -> None:
    """
    Visualize individual ReaDDy simulations for select conditions and
    replicates.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    condition_keys
        List of condition keys.
    n_replicates
        Number of simulation replicates.
    n_timepoints
        Number of equally spaced timepoints to visualize.
    n_monomer_points
        Number of equally spaced monomer points to visualize.
    total_steps
        Total number of steps for each simulation key.
    temp_path
        Local path for saving visualization output files.
    metrics
        List of metrics to include in visualization plots.
    recalculate
        True to recalculate visualization files, False otherwise.
    """

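    # Load metric data for all conditions and replicates up front
    # (recalculate=False reuses previously saved metrics); it is filtered per
    # condition and replicate inside the loop below.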
    if metrics is not None:
        all_metrics_data = get_compression_metric_data(
            bucket,
            series_name,
            condition_keys,
            list(range(1, n_replicates + 1)),
            metrics,
            recalculate=False,
        )
    else:
        metrics = []
        all_metrics_data = pd.DataFrame(columns=["key", "seed"])

    for condition_key in condition_keys:
        series_key = f"{series_name}_{condition_key}" if condition_key else series_name

        for rep_ix in range(n_replicates):
            rep_id = rep_ix + 1
            output_key = f"{series_name}/viz/{series_key}_{rep_id:06d}.simularium"

            # Skip if output file already exists.
            if not recalculate and check_key(bucket, output_key):
                print(
                    f"Simularium file for [ {output_key} ] already exists. Skipping."
                )
                continue

            print(f"Visualizing data for [ {condition_key} ] replicate [ {rep_ix} ]")

            # Filter metrics data for specific condition and replicate.
            if condition_key:
                metrics_data = all_metrics_data[
                    (all_metrics_data["key"] == condition_key)
                    & (all_metrics_data["seed"] == rep_id)
                ]
            else:
                metrics_data = all_metrics_data[(all_metrics_data["seed"] == rep_id)]

            visualize_individual_readdy_trajectory(
                bucket,
                series_name,
                series_key,
                rep_ix,
                n_timepoints,
                n_monomer_points,
                total_steps[condition_key],
                temp_path,
                metrics,
                metrics_data,
            )

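            # converter.save() appends ".simularium" to the HDF5 path it was
            # given, so the local file is <series_key>_<rep_ix>.h5.simularium.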
            # Upload saved file to S3.
            temp_key = f"{series_key}_{rep_ix}.h5.simularium"
            save_buffer(bucket, output_key, load_buffer(temp_path, temp_key))


def _find_time_units(raw_time: float, units: str = "s") -> tuple[str, float]:
    """Get compact time units and a multiplier to put the times in those units."""
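    # For example (illustrative): raw_time=0.002 with units="s" compacts to
    # 2.0 ms, returning ("ms", 1000.0).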
    time = UnitRegistry().Quantity(raw_time, units)
    time_compact = time.to_compact()
    return f"{time_compact.units:~}", time_compact.magnitude / raw_time


def _filter_time(
    converter: TrajectoryConverter, n_timepoints: int
) -> TrajectoryConverter:
    """Filter times using simulariumio time filter."""
    time_inc = int(converter._data.agent_data.times.shape[0] / n_timepoints)
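    # With fewer than twice the requested timepoints, sampling every Nth frame
    # would be a no-op (n < 2), so return the converter unchanged.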
    if time_inc < 2:
        return converter
    converter._data = converter.filter_data([EveryNthTimestepFilter(n=time_inc)])
    return converter


def _align_cytosim_fiber(converter: TrajectoryConverter) -> None:
    """
    Align the fiber subpoints so that the furthest point from the x-axis
    is aligned with the positive y-axis at the last time point.
    """
    fiber_points = converter._data.agent_data.subpoints[:, 0, :]
    n_timesteps = fiber_points.shape[0]
    n_points = int(fiber_points.shape[1] / 3)
    fiber_points = fiber_points.reshape((n_timesteps, n_points, 3))
    _, rotation = align_fiber(fiber_points[-1])
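    # Apply the rotation computed from the final frame to every frame,
    # rotating only the y and z components so x components are unchanged.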
    for time_ix in range(n_timesteps):
        rotated = np.dot(fiber_points[time_ix][:, 1:], rotation)
        converter._data.agent_data.subpoints[time_ix, 0, :] = np.concatenate(
            (fiber_points[time_ix][:, 0:1], rotated), axis=1
        ).reshape(n_points * 3)


def _get_cytosim_simularium_converter(
    fiber_points_data: str,
    singles_data: str,
    n_timepoints: int,
) -> TrajectoryConverter:
    """
    Load from Cytosim outputs and generate a TrajectoryConverter to visualize an
    actin trajectory in Simularium.
    """
    singles_display_data = DisplayData(
        name="linker",
        radius=0.004,
        display_type=DISPLAY_TYPE.SPHERE,
        color="#eaeaea",
    )
    converter = CytosimConverter(
        CytosimData(
            meta_data=MetaData(
                box_size=BOX_SIZE,
                camera_defaults=CameraData(
                    position=np.array([70.0, 70.0, 300.0]),
                    look_at_position=np.array([70.0, 70.0, 0.0]),
                    fov_degrees=60.0,
                ),
                scale_factor=1,
            ),
            object_info={
                "fibers": CytosimObjectInfo(
                    cytosim_file=InputFileData(
                        file_contents=fiber_points_data,
                    ),
                    display_data={
                        1: DisplayData(
                            name="actin",
                            radius=0.002,
                            display_type=DISPLAY_TYPE.FIBER,
                        )
                    },
                ),
                "singles": CytosimObjectInfo(
                    cytosim_file=InputFileData(
                        file_contents=singles_data,
                    ),
                    display_data={
                        1: singles_display_data,
                        2: singles_display_data,
                        3: singles_display_data,
                        4: singles_display_data,
                    },
                ),
            },
        )
    )
    _align_cytosim_fiber(converter)
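    # Rescale Cytosim spatial data (presumably um) into the pipeline's
    # visualization units using CYTOSIM_SCALE_FACTOR.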
    converter._data.agent_data.radii *= CYTOSIM_SCALE_FACTOR
    converter._data.agent_data.positions *= CYTOSIM_SCALE_FACTOR
    converter._data.agent_data.subpoints *= CYTOSIM_SCALE_FACTOR
    converter = _filter_time(converter, n_timepoints)
    time_units, time_multiplier = _find_time_units(converter._data.agent_data.times[-1])
    converter._data.agent_data.times *= time_multiplier
    converter._data.time_units = UnitData(time_units)
    return converter


def visualize_individual_cytosim_trajectory(
    bucket: str,
    series_name: str,
    series_key: str,
    index: int,
    n_timepoints: int,
    temp_path: str,
    metrics: list[CompressionMetric],
    metrics_data: pd.DataFrame,
) -> None:
    """
    Save a Simularium file for a single Cytosim trajectory with plots and
    spatial annotations.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    series_key
        Combination of series and condition names.
    index
        Simulation replicate index.
    n_timepoints
        Number of equally spaced timepoints to visualize.
    temp_path
        Local path for saving visualization output files.
    metrics
        List of metrics to include in visualization plots.
    metrics_data
        Calculated compression metrics data.
    """

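    # Cytosim outputs live under <series_name>/outputs/<series_key>_<index>/;
    # the %s placeholder selects individual report files.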
    output_key_template = f"{series_name}/outputs/{series_key}_{index}/%s"
    fiber_points_data = load_text(bucket, output_key_template % "fiber_points.txt")
    singles_data = load_text(bucket, output_key_template % "singles.txt")

    converter = _get_cytosim_simularium_converter(
        fiber_points_data, singles_data, n_timepoints
    )

    if metrics:
        times = 1e3 * metrics_data["time"].values  # s --> ms
        _add_individual_plots(
            converter, metrics, metrics_data, times, converter._data.time_units
        )

    # Save simularium file. Turn off validate IDs for performance.
    local_file_path = f"{temp_path}/{series_key}_{index}"
    converter.save(output_path=local_file_path, validate_ids=False)


def visualize_individual_cytosim_trajectories(
    bucket: str,
    series_name: str,
    condition_keys: list[str],
    random_seeds: list[int],
    n_timepoints: int,
    temp_path: str,
    metrics: Optional[list[CompressionMetric]] = None,
    recalculate: bool = True,
) -> None:
    """
    Visualize individual Cytosim simulations for select conditions and
    replicates.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    condition_keys
        List of condition keys.
    random_seeds
        Random seeds for simulations.
    n_timepoints
        Number of equally spaced timepoints to visualize.
    temp_path
        Local path for saving visualization output files.
    metrics
        List of metrics to include in visualization plots.
    recalculate
        True to recalculate visualization files, False otherwise.
    """

    if metrics is not None:
        all_metrics_data = get_compression_metric_data(
            bucket,
            series_name,
            condition_keys,
            random_seeds,
            metrics,
            recalculate=False,
        )
    else:
        metrics = []
        all_metrics_data = pd.DataFrame(columns=["key", "seed"])

    for condition_key in condition_keys:
        series_key = f"{series_name}_{condition_key}" if condition_key else series_name

        for index, seed in enumerate(random_seeds):
            output_key = f"{series_name}/viz/{series_key}_{seed:06d}.simularium"

            # Skip if output file already exists.
            if not recalculate and check_key(bucket, output_key):
                print(
                    f"Simularium file for [ {output_key} ] already exists. Skipping."
                )
                continue

            print(f"Visualizing data for [ {condition_key} ] seed [ {seed} ]")

            # Filter metrics data for specific condition and replicate.
            if condition_key:
                metrics_data = all_metrics_data[
                    (all_metrics_data["key"] == condition_key)
                    & (all_metrics_data["seed"] == seed)
                ]
            else:
                metrics_data = all_metrics_data[(all_metrics_data["seed"] == seed)]

            visualize_individual_cytosim_trajectory(
                bucket,
                series_name,
                series_key,
                index,
                n_timepoints,
                temp_path,
                metrics,
                metrics_data,
            )

            temp_key = f"{series_key}_{index}.simularium"
            save_buffer(bucket, output_key, load_buffer(temp_path, temp_key))
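

# Illustrative usage sketch; the bucket, series, condition, and seed values
# below are hypothetical, not taken from this module:
#
#     visualize_individual_cytosim_trajectories(
#         bucket="my-subcell-bucket",
#         series_name="ACTIN_COMPRESSION",
#         condition_keys=["0047"],
#         random_seeds=[1, 2, 3],
#         n_timepoints=200,
#         temp_path="/tmp/subcell",
#     )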