Coverage for subcell_pipeline/simulation/readdy/post_processor.py: 32%

226 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-08-29 15:14 +0000

1"""Class for post processing ReaDDy trajectories.""" 

2 

3import math 

4from typing import Optional 

5 

6import numpy as np 

7from tqdm import tqdm 

8 

9from subcell_pipeline.analysis.compression_metrics.polymer_trace import ( 

10 get_contour_length_from_trace, 

11) 

12from subcell_pipeline.analysis.dimensionality_reduction.fiber_data import align_fiber 

13from subcell_pipeline.simulation.readdy.data_structures import FrameData 

14 

ACTIN_START_PARTICLE_PHRASE: list[str] = ["pointed"]
"""Phrases indicating actin start particle."""

# Each type name is "actin#" + a state prefix + an optional "ATP_" marker; the
# polymer number is appended later by the code that uses these names.
ACTIN_PARTICLE_TYPES: list[str] = [
    f"actin#{state}{nucleotide}"
    for state in ("", "mid_", "fixed_", "mid_fixed_", "barbed_", "fixed_barbed_")
    for nucleotide in ("", "ATP_")
]
"""Actin particle types from simularium/readdy-models."""

IDEAL_ACTIN_POSITIONS: np.ndarray = np.array(
    (
        (24.738, 20.881, 26.671),
        (27.609, 24.061, 27.598),
        (30.382, 21.190, 25.725),
    )
)
"""Ideal actin positions measured from crystal structure."""

IDEAL_ACTIN_VECTOR_TO_AXIS: np.ndarray = np.array(
    (-0.01056751, -1.47785105, -0.65833209)
)
"""Ideal actin vector to axis."""

47 

48 

49class ReaddyPostProcessor: 

50 """Get views of ReaDDy trajectory for different analysis purposes.""" 

51 

52 trajectory: list[FrameData] 

53 """ReaDDy data trajectory from ReaddyLoader(h5_file_path).trajectory().""" 

54 

55 box_size: np.ndarray 

56 """The size of the x,y,z dimensions of the simulation volume (shape = 3).""" 

57 

58 periodic_boundary: bool 

59 """True if simulation had periodic boundary, False otherwise.""" 

60 

def __init__(
    self,
    trajectory: list[FrameData],
    box_size: np.ndarray,
    periodic_boundary: bool = False,
):
    """
    Store the trajectory and simulation volume settings on the instance.

    Parameters
    ----------
    trajectory
        Frames of ReaDDy data to post process.
    box_size
        Size of the simulation volume in x,y,z.
    periodic_boundary
        True if positions should be unwrapped across the box boundary.
    """
    self.periodic_boundary = periodic_boundary
    self.box_size = box_size
    self.trajectory = trajectory

70 

def times(self) -> np.ndarray:
    """
    Collect the simulation time stored on every frame of the trajectory.

    Returns
    -------
    times
        Array with the ``time`` attribute of each frame, in trajectory
        order (shape = n_timesteps).
    """
    return np.array([frame.time for frame in self.trajectory])

83 

def _id_for_neighbor_of_types(
    self,
    time_ix: int,
    particle_id: int,
    neighbor_types: list[str],
    exclude_ids: Optional[list[int]] = None,
) -> int:
    """
    Find the first neighbor of a particle whose type is in neighbor_types.

    Neighbors are visited in the order of the particle's ``neighbor_ids``
    at the given time index; ids listed in exclude_ids are skipped.

    Returns
    -------
    :
        The id of the first matching neighbor, or -1 if there is none.
    """
    frame_particles = self.trajectory[time_ix].particles
    skipped_ids = () if exclude_ids is None else exclude_ids
    matching_ids = (
        candidate_id
        for candidate_id in frame_particles[particle_id].neighbor_ids
        if candidate_id not in skipped_ids
        and frame_particles[candidate_id].type_name in neighbor_types
    )
    return next(matching_ids, -1)

103 

def _ids_for_chain_of_types(
    self,
    time_ix: int,
    start_particle_id: int,
    chain_particle_types: list[list[str]],
    current_polymer_number: int,
    chain_length: int = 0,
    last_particle_id: Optional[int] = None,
    result: Optional[list[int]] = None,
) -> list[int]:
    """
    Get IDs for a chain of particles with chain_particle_types in the given
    frame of data, starting from the particle with start_particle_id and
    avoiding the particle with last_particle_id.

    The chain is walked iteratively so its length is not limited by the
    interpreter's recursion depth.

    Parameters
    ----------
    time_ix
        Index of the frame to search.
    start_particle_id
        ID of the particle the walk starts from.
    chain_particle_types
        For each polymer index, the type names accepted for the next
        particle in the chain.
    current_polymer_number
        Index into chain_particle_types for the first neighbor lookup. It is
        advanced by one for each particle found and wraps around at the end
        of chain_particle_types.
    chain_length
        Maximum number of IDs to collect. If chain_length = 0, return the
        entire chain.
    last_particle_id
        ID of the particle that must not be chosen as the first neighbor.
    result
        Existing list to append to; a new list holding start_particle_id is
        created when omitted.

    Returns
    -------
    :
        The list of particle IDs along the chain.

    Raises
    ------
    IndexError
        If a neighbor lookup is needed and chain_particle_types is empty.
    """
    if result is None:
        result = [start_particle_id]
    if chain_length != 1 and not chain_particle_types:
        raise IndexError(
            "chain_particle_types must contain at least one list of types."
        )
    n_types = len(chain_particle_types)
    current_id = start_particle_id
    previous_id = last_particle_id
    polymer_ix = current_polymer_number
    remaining = chain_length
    # remaining == 1 means the requested number of IDs has been collected;
    # any value <= 0 means "no limit".
    while remaining != 1:
        neighbor_id = self._id_for_neighbor_of_types(
            time_ix=time_ix,
            particle_id=current_id,
            # Wrap the index so a polymer number at the end of the range
            # continues with the first entry instead of raising IndexError.
            neighbor_types=chain_particle_types[polymer_ix % n_types],
            exclude_ids=[previous_id] if previous_id is not None else [],
        )
        if neighbor_id < 0:
            break
        result.append(neighbor_id)
        previous_id, current_id = current_id, neighbor_id
        polymer_ix = (polymer_ix + 1) % n_types
        if remaining > 0:
            remaining -= 1
    return result

145 

def _non_periodic_position(
    self, position1: np.ndarray, position2: np.ndarray
) -> np.ndarray:
    """
    Move the second position across the box so it is next to the first.

    For each dimension where the two positions are more than half of
    box_size apart, the second position is shifted by one box length
    towards the first. The direction of the shift is taken from the sign of
    the separation, so it does not depend on where the box is centered and
    is well defined when a coordinate is exactly zero.

    Parameters
    ----------
    position1
        Reference x,y,z position.
    position2
        The x,y,z position to unwrap relative to position1.

    Returns
    -------
    :
        position2 itself if the simulation is not periodic, otherwise a new
        array with the unwrapped position.
    """
    if not self.periodic_boundary:
        return position2
    box_size = np.asarray(self.box_size, dtype=float)
    separation = np.asarray(position2, dtype=float) - np.asarray(
        position1, dtype=float
    )
    crosses_boundary = np.abs(separation) > box_size / 2.0
    shift = np.where(crosses_boundary, np.sign(separation) * box_size, 0.0)
    return position2 - shift

160 

@staticmethod
def _vector_is_invalid(vector: np.ndarray) -> bool:
    """Return True if any of the first three components of vector is NaN."""
    return any(math.isnan(vector[axis]) for axis in range(3))

165 

@staticmethod
def _normalize(vector: np.ndarray) -> np.ndarray:
    """
    Scale a vector to unit length.

    A vector whose first three components are all zero is returned
    unchanged, since it has no direction to preserve.
    """
    is_zero_vector = all(vector[axis] == 0 for axis in range(3))
    if is_zero_vector:
        return vector
    length = np.linalg.norm(vector)
    return vector / length

172 

@staticmethod
def _orientation_from_positions(positions: np.ndarray) -> np.ndarray:
    """
    Build a basis local to a particle from its two chain neighbors.

    The unit vector to the previous particle is the first axis. The vector
    to the next particle, with its component along the first axis removed
    and then normalized, is the second axis. Their cross product is the
    third axis.

    The positions array is structured as:
        [
            prev particle's position,
            this particle's position,
            next particle's position,
        ]

    Returns
    -------
    :
        Matrix (shape = 3 x 3) whose columns are the three axes.
    """
    normalize = ReaddyPostProcessor._normalize
    center = positions[1]
    to_prev = normalize(positions[0] - center)
    to_next = normalize(positions[2] - center)
    # Gram-Schmidt step: drop the part of to_next that lies along to_prev.
    along_prev = np.dot(to_prev, to_next) / np.dot(to_prev, to_prev)
    in_plane = normalize(to_next - along_prev * to_prev)
    out_of_plane = np.cross(in_plane, to_prev)
    return np.column_stack((to_prev, in_plane, out_of_plane))

193 

def _rotation(
    self, positions: np.ndarray, ideal_positions: np.ndarray
) -> np.ndarray:
    """
    Get the difference between the particle's current orientation and the
    orientation of the ideal positions as a rotation matrix.

    The neighbor positions are first unwrapped across the periodic boundary
    relative to the particle. This is done on a copy, so the caller's
    positions array is left unmodified.

    The positions array is structured as:
        [
            prev particle's position,
            this particle's position,
            next particle's position,
        ]

    Returns
    -------
    :
        The current orientation matrix multiplied by the inverse of the
        ideal orientation matrix.
    """
    unwrapped = np.array(positions, dtype=float)
    unwrapped[0] = self._non_periodic_position(unwrapped[1], unwrapped[0])
    unwrapped[2] = self._non_periodic_position(unwrapped[1], unwrapped[2])
    current_orientation = self._orientation_from_positions(unwrapped)
    ideal_orientation = self._orientation_from_positions(ideal_positions)
    return np.matmul(current_orientation, np.linalg.inv(ideal_orientation))

214 

def rotate_positions(
    self, positions: np.ndarray, rotation: np.ndarray
) -> np.ndarray:
    """
    Rotate an x,y,z position (or an array of them) around the x-axis.

    The x component is kept as is; the remaining components are multiplied
    on the right by the given rotation matrix.
    """
    if positions.ndim > 1:
        x_column = positions[:, :1]
        rotated = np.dot(positions[:, 1:], rotation)
        return np.concatenate((x_column, rotated), axis=1)
    rotated = np.dot(positions[1:], rotation)
    return np.concatenate((positions[:1], rotated), axis=0)

228 

def align_trajectory(
    self,
    fiber_points: list[list[np.ndarray]],
) -> tuple[np.ndarray, list[list[np.ndarray]]]:
    """
    Rotate all particle positions and the first fiber at each time with the
    rotation that align_fiber computes for the first fiber at the last time
    point.

    Particle positions in the trajectory and the entries of fiber_points
    are updated in place.

    Parameters
    ----------
    fiber_points
        List of lists of arrays (shape = n x 3) containing the x,y,z
        positions of control points for each fiber at each time.

    Returns
    -------
    positions
        Array built from the rotated x,y,z positions of every particle at
        each timestep.
    fiber_points
        The given fiber_points, with the first fiber at each time rotated.
    """
    _, rotation = align_fiber(fiber_points[-1][0])
    rotated_positions: list[list[np.ndarray]] = []
    for time_ix, frame in enumerate(self.trajectory):
        frame_positions = []
        for particle in frame.particles.values():
            particle.position = self.rotate_positions(particle.position, rotation)
            frame_positions.append(particle.position)
        rotated_positions.append(frame_positions)
        # Only the first fiber of each frame is rotated.
        fiber_points[time_ix][0] = self.rotate_positions(
            fiber_points[time_ix][0], rotation
        )
    return np.array(rotated_positions), fiber_points

270 

def linear_fiber_chain_ids(
    self,
    polymer_number_range: int,
    start_particle_phrases: list[str] = ACTIN_START_PARTICLE_PHRASE,
    other_particle_types: list[str] = ACTIN_PARTICLE_TYPES,
) -> list[list[list[int]]]:
    """
    Get particle IDs for particles in each linear fiber at each timestep.

    Parameters
    ----------
    polymer_number_range
        How many numbers are used to represent the relative identity of
        particles in the chain?
    start_particle_phrases
        List of phrases in particle type names for the first particles in
        the linear chain.
    other_particle_types
        List of particle type names (without polymer numbers at the end) for
        the particles other than the start particles.

    Returns
    -------
    :
        List of lists of lists of the particle IDs for each particle for
        each fiber at each time. Chains with fewer than two particles are
        left out.

    Raises
    ------
    ValueError
        If polymer_number_range is less than 1.
    """
    if polymer_number_range < 1:
        raise ValueError("polymer_number_range must be at least 1.")
    # Entry i holds the type names carrying polymer number i + 1.
    chain_particle_types = [
        [f"{type_name}{polymer_number}" for type_name in other_particle_types]
        for polymer_number in range(1, polymer_number_range + 1)
    ]
    result: list[list[list[int]]] = []
    for time_ix, frame in enumerate(self.trajectory):
        frame_chains: list[list[int]] = []
        for particle_id, particle in frame.particles.items():
            # only particles at the start of a chain are walked
            if not any(
                phrase in particle.type_name for phrase in start_particle_phrases
            ):
                continue
            # The last character of the type name is the start particle's
            # polymer number. Used as an index it selects the types for the
            # next number; it wraps so the highest number is followed by
            # the first one instead of indexing past the end of the list.
            next_polymer_ix = int(particle.type_name[-1]) % polymer_number_range
            chain_ids = self._ids_for_chain_of_types(
                time_ix=time_ix,
                start_particle_id=particle_id,
                chain_particle_types=chain_particle_types,
                current_polymer_number=next_polymer_ix,
            )
            if len(chain_ids) < 2:
                continue
            frame_chains.append(chain_ids)
        result.append(frame_chains)
    return result

327 

def linear_fiber_axis_positions(
    self,
    fiber_chain_ids: list[list[list[int]]],
    ideal_positions: np.ndarray = IDEAL_ACTIN_POSITIONS,
    ideal_vector_to_axis: np.ndarray = IDEAL_ACTIN_VECTOR_TO_AXIS,
) -> tuple[list[list[np.ndarray]], list[list[list[int]]]]:
    """
    Get x,y,z axis positions for each interior particle in each linear fiber
    at each timestep.

    The first and last particle of a fiber are skipped because each axis
    position needs both chain neighbors. A fiber is cut short at the first
    particle with a NaN position or a NaN axis position, and fibers with
    fewer than two axis positions are left out.

    Parameters
    ----------
    fiber_chain_ids
        List of list of lists of particle IDs for each particle in each
        fiber at each time.
    ideal_positions
        The x,y,z positions for 3 particles in an ideal chain (shape = 3 x
        3).
    ideal_vector_to_axis
        Vector from the second ideal position to the axis of the fiber
        (shape = 3).

    Returns
    -------
    axis_positions
        Lists of lists of arrays (shape = n x 3) containing the x,y,z
        positions of the point on the fiber axis computed for each particle
        in each fiber at each time.
    new_chain_ids
        List of lists of lists matching axis_positions. Each entry is the
        index of the particle within its fiber's list in fiber_chain_ids.
    """
    all_axis_positions: list[list[np.ndarray]] = []
    all_ids: list[list[list[int]]] = []
    for time_ix, frame_chains in enumerate(fiber_chain_ids):
        frame_axis_positions: list[np.ndarray] = []
        frame_ids: list[list[int]] = []
        for chain_ids in frame_chains:
            particles = self.trajectory[time_ix].particles
            fiber_axis_positions = []
            fiber_ids = []
            for particle_ix in range(1, len(chain_ids) - 1):
                # previous, current and next particle along the chain
                neighborhood = [
                    particles[chain_ids[chain_ix]].position
                    for chain_ix in (particle_ix - 1, particle_ix, particle_ix + 1)
                ]
                if any(self._vector_is_invalid(pos) for pos in neighborhood):
                    break
                rotation = self._rotation(np.array(neighborhood), ideal_positions)
                if rotation is None:
                    break
                # orient the ideal vector like the particle is oriented now
                offset_to_axis = np.squeeze(
                    np.array(np.dot(rotation, ideal_vector_to_axis))
                )
                axis_position = neighborhood[1] + offset_to_axis
                if self._vector_is_invalid(axis_position):
                    break
                fiber_axis_positions.append(axis_position)
                fiber_ids.append(particle_ix)
            if len(fiber_axis_positions) < 2:
                continue
            frame_axis_positions.append(np.array(fiber_axis_positions))
            frame_ids.append(fiber_ids)
        all_axis_positions.append(frame_axis_positions)
        all_ids.append(frame_ids)
    return all_axis_positions, all_ids

399 

def linear_fiber_normals(
    self,
    fiber_chain_ids: list[list[list[int]]],
    axis_positions: list[list[np.ndarray]],
    normal_length: float = 5,
) -> list[list[np.ndarray]]:
    """
    Get x,y,z positions defining start and end points for normals for each
    interior particle in each fiber at each timestep.

    Each normal starts at the particle's axis position and points towards
    the particle. The first and last particle of each fiber are skipped.

    Parameters
    ----------
    fiber_chain_ids
        List of lists of lists of particle IDs for particles in each fiber
        at each time.
    axis_positions
        List of lists of arrays (shape = n x 3) containing the x,y,z
        positions of the closest point on the fiber axis to the position of
        each particle in each fiber at each time. The entry for the particle
        at chain index i is read from index i - 1.
    normal_length
        Length of the resulting normal vectors in the trajectory's spatial
        units.

    Returns
    -------
    :
        List of lists of arrays (shape = 2 x 3) containing the start and end
        point of the normal of each particle at each time. The normals of
        all fibers at a time are collected in one list.
    """
    normals: list[list[np.ndarray]] = []
    for time_ix, frame_chains in enumerate(fiber_chain_ids):
        frame_normals: list[np.ndarray] = []
        particles = self.trajectory[time_ix].particles
        for chain_ix, chain_ids in enumerate(frame_chains):
            # interior particles only: indices 1 .. len - 2
            for particle_ix in range(1, len(chain_ids) - 1):
                particle_position = particles[chain_ids[particle_ix]].position
                axis_position = axis_positions[time_ix][chain_ix][particle_ix - 1]
                direction = ReaddyPostProcessor._normalize(
                    particle_position - axis_position
                )
                normal_end = axis_position + normal_length * direction
                frame_normals.append(np.array([axis_position, normal_end]))
        normals.append(frame_normals)
    return normals

450 

@staticmethod
def linear_fiber_control_points(
    axis_positions: list[list[np.ndarray]],
    n_points: int,
) -> list[list[np.ndarray]]:
    """
    Resample the fiber line defined by each array of axis positions to get
    the requested number of evenly spaced x,y,z control points for each
    linear fiber at each timestep.

    The spacing and the control point counter are computed separately for
    every fiber, so timesteps with several fibers, or with none, are
    handled.

    Parameters
    ----------
    axis_positions
        List of lists of arrays (shape = n x 3) containing the x,y,z
        positions of the closest point on the fiber axis to the position of
        each particle in each fiber at each time.
    n_points
        Number of control points (spaced evenly) on resulting fibers.

    Returns
    -------
    :
        List of lists of arrays (shape = n_points x 3) containing the x,y,z
        positions of control points for each fiber at each time.

    Raises
    ------
    ValueError
        If n_points is less than 2.
    """
    if n_points < 2:
        raise ValueError("n_points must be > 1 to define a fiber.")
    result: list[list[np.ndarray]] = []
    for time_ix in tqdm(range(len(axis_positions))):
        frame_control_points: list[np.ndarray] = []
        for positions in axis_positions[time_ix]:
            # distance between consecutive control points on this fiber
            contour_length = get_contour_length_from_trace(positions)
            segment_length = contour_length / float(n_points - 1)
            control_points = np.zeros((n_points, 3))
            control_points[0] = positions[0]
            current_position = np.copy(positions[0])
            # length walked since the last control point was placed
            leftover_length: float = 0
            pt_ix = 1
            for pos_ix in range(1, len(positions)):
                v_segment = positions[pos_ix] - positions[pos_ix - 1]
                direction = ReaddyPostProcessor._normalize(v_segment)
                remaining_length = (
                    np.linalg.norm(v_segment).item() + leftover_length
                )
                # Rounding to 9 decimal places to avoid floating point error
                # where the remaining length is very close to the segment
                # length, causing the final control point to be skipped.
                # The pt_ix bound keeps rounding or a zero-length contour
                # from writing past the end of control_points.
                while pt_ix < n_points and round(remaining_length, 9) >= round(
                    segment_length, 9
                ):
                    current_position += (
                        segment_length - leftover_length
                    ) * direction
                    control_points[pt_ix, :] = current_position[:]
                    pt_ix += 1
                    leftover_length = 0
                    remaining_length -= segment_length
                # walk the rest of this trace segment
                current_position += (remaining_length - leftover_length) * direction
                leftover_length = remaining_length
            frame_control_points.append(control_points)
        result.append(frame_control_points)
    return result

510 

def fiber_bond_energies(
    self,
    fiber_chain_ids: list[list[list[int]]],
    ideal_lengths: dict[int, float],
    ks: dict[int, float],
    stride: int = 1,
) -> tuple[dict[int, np.ndarray], np.ndarray]:
    """
    Get the strain energy using the harmonic spring equation and the bond
    distance between particles with a given polymer number offset.

    Bonds are measured from every particle except the last two of each
    fiber. Particles with "fixed" in their type name are skipped, both as
    the first and as the second particle of a bond. A NaN energy is stored
    as 0.

    Parameters
    ----------
    fiber_chain_ids
        List of lists of lists of particle IDs for particles in each fiber
        at each time.
    ideal_lengths
        Ideal bond length for each of the polymer number offsets.
    ks
        Bond energy constant for each of the polymer number offsets.
    stride
        Calculate bond energy every stride timesteps.

    Returns
    -------
    bond_energies
        Map of polymer number offset to array (shape = time x bonds) of bond
        energy for each bond at each time.
    filament_positions
        Array (shape = time x bonds) of position in the filament from the
        starting end for the first particle in each bond at each time.
    """
    offsets = list(ideal_lengths)
    energies: dict[int, list[list[float]]] = {offset: [] for offset in offsets}
    filament_positions: list[list[int]] = []
    for time_ix in range(0, len(self.trajectory), stride):
        # one row per sampled timestep, filled below
        frame_energies: dict[int, list[float]] = {offset: [] for offset in offsets}
        frame_positions: list[int] = []
        for offset in offsets:
            energies[offset].append(frame_energies[offset])
        filament_positions.append(frame_positions)
        particles = self.trajectory[time_ix].particles
        for fiber_ids in fiber_chain_ids[time_ix]:
            for ix in range(len(fiber_ids) - 2):
                particle = particles[fiber_ids[ix]]
                if "fixed" in particle.type_name:
                    continue
                for offset in offsets:
                    offset_particle = particles[fiber_ids[ix + offset]]
                    if "fixed" in offset_particle.type_name:
                        continue
                    offset_pos = self._non_periodic_position(
                        particle.position, offset_particle.position
                    )
                    bond_length = np.linalg.norm(
                        offset_pos - particle.position
                    ).item()
                    bond_stretch = bond_length - ideal_lengths[offset]
                    # harmonic spring: E = 1/2 * k * stretch^2
                    energy = 0.5 * ks[offset] * bond_stretch * bond_stretch
                    frame_energies[offset].append(
                        0.0 if math.isnan(energy) else energy
                    )
                frame_positions.append(ix)
    bond_energies = {
        offset: np.array(offset_energies)
        for offset, offset_energies in energies.items()
    }
    return bond_energies, np.array(filament_positions)

579 

def edge_positions(self) -> list[list[np.ndarray]]:
    """
    Get the edges between particles as start and end positions.

    Returns
    -------
    :
        List of list of arrays, one per edge at each time, each holding the
        positions of the two particles the edge connects.
    """
    return [
        [
            np.array([frame.particles[edge[end]].position for end in (0, 1)])
            for edge in frame.edge_ids
        ]
        for frame in self.trajectory
    ]

602 return edges