Coverage for subcell_pipeline/simulation/readdy/post_processor.py: 32%
226 statements
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-29 15:14 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-29 15:14 +0000
1"""Class for post processing ReaDDy trajectories."""
3import math
4from typing import Optional
6import numpy as np
7from tqdm import tqdm
9from subcell_pipeline.analysis.compression_metrics.polymer_trace import (
10 get_contour_length_from_trace,
11)
12from subcell_pipeline.analysis.dimensionality_reduction.fiber_data import align_fiber
13from subcell_pipeline.simulation.readdy.data_structures import FrameData
ACTIN_START_PARTICLE_PHRASE: list[str] = ["pointed"]
"""Substrings of a particle type name that mark the first particle of a chain."""

ACTIN_PARTICLE_TYPES: list[str] = [
    "actin#",
    "actin#ATP_",
    "actin#mid_",
    "actin#mid_ATP_",
    "actin#fixed_",
    "actin#fixed_ATP_",
    "actin#mid_fixed_",
    "actin#mid_fixed_ATP_",
    "actin#barbed_",
    "actin#barbed_ATP_",
    "actin#fixed_barbed_",
    "actin#fixed_barbed_ATP_",
]
"""Actin particle type prefixes (from simularium/readdy-models); a polymer
number is appended to each prefix to form a full type name."""

IDEAL_ACTIN_POSITIONS: np.ndarray = np.array(
    [
        [24.738, 20.881, 26.671],
        [27.609, 24.061, 27.598],
        [30.382, 21.190, 25.725],
    ]
)
"""Ideal x,y,z positions of three consecutive actin particles, measured from
crystal structure (shape = 3 x 3)."""

IDEAL_ACTIN_VECTOR_TO_AXIS: np.ndarray = np.array(
    [-0.01056751, -1.47785105, -0.65833209]
)
"""Ideal vector from the middle actin particle to the fiber axis (shape = 3)."""
class ReaddyPostProcessor:
    """Get views of ReaDDy trajectory for different analysis purposes."""

    trajectory: list[FrameData]
    """ReaDDy data trajectory from ReaddyLoader(h5_file_path).trajectory()."""

    box_size: np.ndarray
    """The size of the x,y,z dimensions of the simulation volume (shape = 3)."""

    periodic_boundary: bool
    """True if simulation had periodic boundary, False otherwise."""

    def __init__(
        self,
        trajectory: list[FrameData],
        box_size: np.ndarray,
        periodic_boundary: bool = False,
    ):
        self.trajectory = trajectory
        self.box_size = box_size
        self.periodic_boundary = periodic_boundary

    def times(self) -> np.ndarray:
        """
        Get simulation time at each timestep.

        Returns
        -------
        times
            Array of time stamps in simulation time for each timestep (shape =
            n_timesteps).
        """
        return np.array([frame.time for frame in self.trajectory])

    def _id_for_neighbor_of_types(
        self,
        time_ix: int,
        particle_id: int,
        neighbor_types: list[str],
        exclude_ids: Optional[list[int]] = None,
    ) -> int:
        """
        Get the id for the first neighbor with a type_name in neighbor_types at
        the given time index.

        Neighbors listed in exclude_ids are skipped. Returns -1 if no neighbor
        matches.
        """
        particles = self.trajectory[time_ix].particles
        excluded = exclude_ids if exclude_ids is not None else []
        for neighbor_id in particles[particle_id].neighbor_ids:
            if neighbor_id in excluded:
                continue
            if particles[neighbor_id].type_name in neighbor_types:
                return neighbor_id
        return -1

    def _ids_for_chain_of_types(
        self,
        time_ix: int,
        start_particle_id: int,
        chain_particle_types: list[list[str]],
        current_polymer_number: int,
        chain_length: int = 0,
        last_particle_id: Optional[int] = None,
        result: Optional[list[int]] = None,
    ) -> list[int]:
        """
        Get IDs for a chain of particles with chain_particle_types in the given
        frame of data, starting from the particle with start_particle_id and
        avoiding the particle with last_particle_id.

        If chain_length = 0, return entire chain.

        The chain is walked iteratively so that long chains cannot exceed the
        interpreter's recursion limit. current_polymer_number is used as an
        index into chain_particle_types and is wrapped with a modulo on every
        lookup (including the first), so an index equal to
        len(chain_particle_types) selects the first entry instead of raising.
        The walk also stops if the next neighbor is already part of the chain,
        which prevents an endless loop on a closed ring of particles.
        """
        if result is None:
            result = [start_particle_id]
        n_types = len(chain_particle_types)
        if n_types == 0:
            # no types to match against, so the chain cannot be extended
            return result
        visited = set(result)
        current_id = start_particle_id
        type_ix = current_polymer_number % n_types
        while chain_length != 1:
            neighbor_id = self._id_for_neighbor_of_types(
                time_ix=time_ix,
                particle_id=current_id,
                neighbor_types=chain_particle_types[type_ix],
                exclude_ids=[last_particle_id] if last_particle_id is not None else [],
            )
            if neighbor_id < 0 or neighbor_id in visited:
                break
            result.append(neighbor_id)
            visited.add(neighbor_id)
            last_particle_id = current_id
            current_id = neighbor_id
            type_ix = (type_ix + 1) % n_types
            # a non-positive chain_length means "no limit", so only count down
            # when a positive limit was requested
            if chain_length > 0:
                chain_length -= 1
        return result

    def _non_periodic_position(
        self, position1: np.ndarray, position2: np.ndarray
    ) -> np.ndarray:
        """
        If the distance between two positions along any dimension is greater
        than half of box_size, move the second position across the box toward
        the first position.

        Returns position2 itself when there is no periodic boundary, otherwise
        a shifted copy.
        """
        if not self.periodic_boundary:
            return position2
        result = np.copy(position2)
        for dim in range(3):
            delta = position2[dim] - position1[dim]
            if abs(delta) > self.box_size[dim] / 2.0:
                # shift against the direction of the displacement; using the
                # sign of the displacement (rather than of the coordinate)
                # avoids 0/0 when the coordinate is exactly zero
                result[dim] -= np.sign(delta) * self.box_size[dim]
        return result

    @staticmethod
    def _vector_is_invalid(vector: np.ndarray) -> bool:
        """Check if any of a 3D vector's components are NaN."""
        return math.isnan(vector[0]) or math.isnan(vector[1]) or math.isnan(vector[2])

    @staticmethod
    def _normalize(vector: np.ndarray) -> np.ndarray:
        """
        Normalize a vector.

        A vector whose first three components are all zero is returned
        unchanged to avoid dividing by zero.
        """
        if vector[0] == 0 and vector[1] == 0 and vector[2] == 0:
            return vector
        return vector / np.linalg.norm(vector)

    @staticmethod
    def _orientation_from_positions(positions: np.ndarray) -> np.ndarray:
        """
        Orthonormalize and cross the vectors from a particle position to prev
        and next particle positions to get a basis local to the particle.

        The positions array is structured as:
            [
                prev particle's position,
                this particle's position,
                next particle's position,
            ]

        The three basis vectors are returned as the columns of a 3 x 3 array.
        """
        v1 = ReaddyPostProcessor._normalize(positions[0] - positions[1])
        v2 = ReaddyPostProcessor._normalize(positions[2] - positions[1])
        # remove the component of v2 along v1 (Gram-Schmidt step)
        v2 = ReaddyPostProcessor._normalize(v2 - (np.dot(v1, v2) / np.dot(v1, v1)) * v1)
        v3 = np.cross(v2, v1)
        return np.column_stack((v1, v2, v3))

    def _rotation(
        self, positions: np.ndarray, ideal_positions: np.ndarray
    ) -> np.ndarray:
        """
        Get the difference in the particles's current orientation compared to
        the initial orientation as a rotation matrix.

        The positions array is structured as:
            [
                prev particle's position,
                this particle's position,
                next particle's position,
            ]

        The prev and next positions are first moved across the periodic
        boundary (if any) relative to this particle's position. The input
        array is not modified.
        """
        # work on a copy so the caller's array is left untouched
        positions = np.array(positions)
        positions[0] = self._non_periodic_position(positions[1], positions[0])
        positions[2] = self._non_periodic_position(positions[1], positions[2])
        return np.matmul(
            self._orientation_from_positions(positions),
            np.linalg.inv(self._orientation_from_positions(ideal_positions)),
        )

    def rotate_positions(
        self, positions: np.ndarray, rotation: np.ndarray
    ) -> np.ndarray:
        """
        Rotate an x,y,z position (or an array of them) around the x-axis with
        the given rotation matrix.

        The first component along the last axis (x) is kept as is and the
        remaining components (y,z) are multiplied by the rotation matrix.
        """
        rotated = np.dot(positions[..., 1:], rotation)
        return np.concatenate((positions[..., 0:1], rotated), axis=-1)

    def align_trajectory(
        self,
        fiber_points: list[list[np.ndarray]],
    ) -> tuple[np.ndarray, list[list[np.ndarray]]]:
        """
        Align the positions of particles in the trajectory using the rotation
        that align_fiber computes for the first fiber at the last time point.

        The same rotation is applied at every timestep. Particle positions in
        the trajectory and the first fiber in fiber_points at each timestep are
        both updated in place.

        Parameters
        ----------
        fiber_points
            List of lists of arrays containing the x,y,z positions of control
            points for each fiber at each time. Only the first fiber at each
            timestep is rotated.

        Returns
        -------
        positions
            Array of the rotated x,y,z positions of every particle at each
            timestep (shape = timesteps x n x 3 when every timestep has the
            same number of particles).
        fiber_points
            The input fiber_points with the first fiber at each timestep
            rotated.
        """
        _, rotation = align_fiber(fiber_points[-1][0])
        result: list[list[np.ndarray]] = []
        for time_ix, frame in enumerate(self.trajectory):
            frame_positions = []
            for particle in frame.particles.values():
                particle.position = self.rotate_positions(particle.position, rotation)
                frame_positions.append(particle.position)
            result.append(frame_positions)
            fiber_points[time_ix][0] = self.rotate_positions(
                fiber_points[time_ix][0], rotation
            )
        return np.array(result), fiber_points

    def linear_fiber_chain_ids(
        self,
        polymer_number_range: int,
        start_particle_phrases: list[str] = ACTIN_START_PARTICLE_PHRASE,
        other_particle_types: list[str] = ACTIN_PARTICLE_TYPES,
    ) -> list[list[list[int]]]:
        """
        Get particle IDs for particles in each linear fiber at each timestep.

        Parameters
        ----------
        polymer_number_range
            How many numbers are used to represent the relative identity of
            particles in the chain?
        start_particle_phrases
            List of phrases in particle type names for the first particles in
            the linear chain.
        other_particle_types
            List of particle type names (without polymer numbers at the end) for
            the particles other than the start particles.

        Returns
        -------
        :
            List of lists of lists of the particle IDs for each particle for
            each fiber at each time. Chains with fewer than two particles are
            left out.
        """
        # entry i holds the type names carrying polymer number i + 1
        chain_particle_types = [
            [f"{type_name}{i + 1}" for type_name in other_particle_types]
            for i in range(polymer_number_range)
        ]
        result: list[list[list[int]]] = []
        for time_ix, frame in enumerate(self.trajectory):
            particles = frame.particles
            fibers: list[list[int]] = []
            for particle_id in particles:
                type_name = particles[particle_id].type_name
                # only particles whose type marks the start of a chain
                if not any(phrase in type_name for phrase in start_particle_phrases):
                    continue
                # The last character of the type name is the polymer number n
                # of the start particle; entry n of chain_particle_types holds
                # the types numbered n + 1, i.e. those of the next particle.
                # The helper wraps the index, so n == polymer_number_range
                # selects entry 0 instead of raising IndexError.
                chain_ids = self._ids_for_chain_of_types(
                    time_ix=time_ix,
                    start_particle_id=particle_id,
                    chain_particle_types=chain_particle_types,
                    current_polymer_number=int(type_name[-1]),
                )
                if len(chain_ids) < 2:
                    continue
                fibers.append(chain_ids)
            result.append(fibers)
        return result

    def linear_fiber_axis_positions(
        self,
        fiber_chain_ids: list[list[list[int]]],
        ideal_positions: np.ndarray = IDEAL_ACTIN_POSITIONS,
        ideal_vector_to_axis: np.ndarray = IDEAL_ACTIN_VECTOR_TO_AXIS,
    ) -> tuple[list[list[np.ndarray]], list[list[list[int]]]]:
        """
        Get x,y,z axis positions for each particle in each linear fiber at each
        timestep.

        The first and last particle of each chain are skipped because both a
        previous and a next particle are needed. Processing of a chain stops at
        the first particle with a NaN position or NaN axis position, and fibers
        with fewer than two axis positions are left out.

        Parameters
        ----------
        fiber_chain_ids
            List of list of lists of particle IDs for each particle in each
            fiber at each time.
        ideal_positions
            The x,y,z positions for 3 particles in an ideal chain (shape = 3 x
            3).
        ideal_vector_to_axis
            Vector from the second ideal position to the axis of the fiber
            (shape = 3).

        Returns
        -------
        axis_positions
            Lists of lists of arrays (shape = n x 3) containing the x,y,z
            positions of the closest point on the fiber axis to the position of
            each particle in each fiber at each time.
        new_chain_ids
            List of lists of lists matching the axis_positions for each
            particle in each fiber at each time. NOTE(review): the values
            stored are the index of each particle within its chain in
            fiber_chain_ids, not the particle ID itself -- confirm which one
            callers expect.
        """
        result: list[list[np.ndarray]] = []
        ids: list[list[list[int]]] = []
        for time_ix, fibers in enumerate(fiber_chain_ids):
            particles = self.trajectory[time_ix].particles
            time_positions: list[np.ndarray] = []
            time_ids: list[list[int]] = []
            for chain_ids in fibers:
                axis_positions = []
                new_ids = []
                for particle_ix in range(1, len(chain_ids) - 1):
                    positions = [
                        particles[chain_ids[particle_ix - 1]].position,
                        particles[chain_ids[particle_ix]].position,
                        particles[chain_ids[particle_ix + 1]].position,
                    ]
                    if any(self._vector_is_invalid(pos) for pos in positions):
                        break
                    rotation = self._rotation(np.array(positions), ideal_positions)
                    vector_to_axis_local = np.squeeze(
                        np.array(np.dot(rotation, ideal_vector_to_axis))
                    )
                    axis_pos = positions[1] + vector_to_axis_local
                    if self._vector_is_invalid(axis_pos):
                        break
                    axis_positions.append(axis_pos)
                    new_ids.append(particle_ix)
                if len(axis_positions) < 2:
                    continue
                time_positions.append(np.array(axis_positions))
                time_ids.append(new_ids)
            result.append(time_positions)
            ids.append(time_ids)
        return result, ids

    def linear_fiber_normals(
        self,
        fiber_chain_ids: list[list[list[int]]],
        axis_positions: list[list[np.ndarray]],
        normal_length: float = 5,
    ) -> list[list[np.ndarray]]:
        """
        Get x,y,z positions defining start and end points for normals for each
        particle in each fiber at each timestep.

        The first and last particle of each chain are skipped, so the axis
        position for the particle at index i in a chain is read from index
        i - 1 of the matching axis_positions array.

        Parameters
        ----------
        fiber_chain_ids
            List of lists of lists of particle IDs for particles in each fiber
            at each time.
        axis_positions
            List of lists of arrays (shape = n x 3) containing the x,y,z
            positions of the closest point on the fiber axis to the position of
            each particle in each fiber at each time.
        normal_length
            Length of the resulting normal vectors in the trajectory's spatial
            units.

        Returns
        -------
        :
            List of lists of arrays (shape = 2 x 3) containing the x,y,z normals
            of each particle in each fiber at each time.
        """
        result: list[list[np.ndarray]] = []
        for time_ix, fibers in enumerate(fiber_chain_ids):
            particles = self.trajectory[time_ix].particles
            normals: list[np.ndarray] = []
            for chain_ix, chain_ids in enumerate(fibers):
                last_ix = len(chain_ids) - 1
                for particle_ix, particle_id in enumerate(chain_ids):
                    # Skip first and last particle
                    if particle_ix == 0 or particle_ix == last_ix:
                        continue
                    position = particles[particle_id].position
                    axis_position = axis_positions[time_ix][chain_ix][particle_ix - 1]
                    direction = ReaddyPostProcessor._normalize(position - axis_position)
                    normals.append(
                        np.array(
                            [axis_position, axis_position + normal_length * direction]
                        )
                    )
            result.append(normals)
        return result

    @staticmethod
    def linear_fiber_control_points(
        axis_positions: list[list[np.ndarray]],
        n_points: int,
    ) -> list[list[np.ndarray]]:
        """
        Resample the fiber line defined by each array of axis positions to get
        the requested number of points between x,y,z control points for each
        linear fiber at each timestep.

        The contour length and the running control point index are computed
        separately for every fiber, so timesteps with more than one fiber are
        resampled correctly.

        Parameters
        ----------
        axis_positions
            List of lists of arrays (shape = n x 3) containing the x,y,z
            positions of the closest point on the fiber axis to the position of
            each particle in each fiber at each time.
        n_points
            Number of control points (spaced evenly) on resulting fibers.

        Returns
        -------
        :
            List of lists of arrays (shape = n x 3) containing the x,y,z
            positions of control points for each fiber at each time.

        Raises
        ------
        ValueError
            If n_points is less than 2.
        """
        if n_points < 2:
            raise ValueError("n_points must be > 1 to define a fiber.")
        result: list[list[np.ndarray]] = []
        for time_ix in tqdm(range(len(axis_positions))):
            fibers_control_points: list[np.ndarray] = []
            for positions in axis_positions[time_ix]:
                contour_length = get_contour_length_from_trace(positions)
                segment_length = contour_length / float(n_points - 1)
                control_points = np.zeros((n_points, 3))
                control_points[0] = positions[0]
                current_position = np.copy(positions[0])
                leftover_length: float = 0
                pt_ix = 1
                for pos_ix in range(1, len(positions)):
                    v_segment = positions[pos_ix] - positions[pos_ix - 1]
                    direction = ReaddyPostProcessor._normalize(v_segment)
                    remaining_length = (
                        np.linalg.norm(v_segment).item() + leftover_length
                    )
                    # Rounding to 9 decimal places to avoid floating point error
                    # where the remaining length is very close to the segment
                    # length, causing the final control point to be skipped.
                    # The pt_ix bound keeps writes inside control_points even
                    # when the segment length is zero or rounding would place
                    # one point too many.
                    while pt_ix < n_points and round(remaining_length, 9) >= round(
                        segment_length, 9
                    ):
                        current_position += (
                            segment_length - leftover_length
                        ) * direction
                        control_points[pt_ix, :] = current_position[:]
                        pt_ix += 1
                        leftover_length = 0
                        remaining_length -= segment_length
                    # advance to the end of this axis segment and carry the
                    # unused length over to the next one
                    current_position += (remaining_length - leftover_length) * direction
                    leftover_length = remaining_length
                fibers_control_points.append(control_points)
            result.append(fibers_control_points)
        return result

    def fiber_bond_energies(
        self,
        fiber_chain_ids: list[list[list[int]]],
        ideal_lengths: dict[int, float],
        ks: dict[int, float],
        stride: int = 1,
    ) -> tuple[dict[int, np.ndarray], np.ndarray]:
        """
        Get the strain energy using the harmonic spring equation and the bond
        distance between particles with a given polymer number offset.

        Bonds are skipped when either particle has "fixed" in its type name or
        when the offset would reach outside of the fiber. NaN energies are
        stored as 0. The last two particles of each fiber are never used as the
        first particle of a bond.

        Parameters
        ----------
        fiber_chain_ids
            List of lists of lists of particle IDs for particles in each fiber
            at each time.
        ideal_lengths
            Ideal bond length for each of the polymer number offsets.
        ks
            Bond energy constant for each of the polymer number offsets.
        stride
            Calculate bond energy every stride timesteps.

        Returns
        -------
        bond_energies
            Map of polymer number offset to array (shape = time x bonds) of bond
            energy for each bond at each time.
        filament_positions
            Array (shape = time x bonds) of position in the filament from the
            starting end for the first particle in each bond at each time. One
            entry is added per non-fixed first particle, independent of how
            many offsets produced an energy for it.
        """
        energies: dict[int, list[list[float]]] = {
            offset: [] for offset in ideal_lengths
        }
        filament_positions: list[list[int]] = []
        for time_ix in range(0, len(self.trajectory), stride):
            # start a new row for this sampled timestep
            for offset in ideal_lengths:
                energies[offset].append([])
            filament_positions.append([])
            particles = self.trajectory[time_ix].particles
            for fiber_ids in fiber_chain_ids[time_ix]:
                n_ids = len(fiber_ids)
                for ix in range(n_ids - 2):
                    particle = particles[fiber_ids[ix]]
                    if "fixed" in particle.type_name:
                        continue
                    for offset in ideal_lengths:
                        # skip offsets pointing outside of this fiber instead
                        # of raising IndexError or wrapping around to its end
                        if not 0 <= ix + offset < n_ids:
                            continue
                        offset_particle = particles[fiber_ids[ix + offset]]
                        if "fixed" in offset_particle.type_name:
                            continue
                        offset_pos = self._non_periodic_position(
                            particle.position, offset_particle.position
                        )
                        bond_stretch = (
                            np.linalg.norm(offset_pos - particle.position).item()
                            - ideal_lengths[offset]
                        )
                        energy = 0.5 * ks[offset] * bond_stretch * bond_stretch
                        if math.isnan(energy):
                            energy = 0.0
                        energies[offset][-1].append(energy)
                    filament_positions[-1].append(ix)
        return (
            {offset: np.array(energy) for offset, energy in energies.items()},
            np.array(filament_positions),
        )

    def edge_positions(self) -> list[list[np.ndarray]]:
        """
        Get the edges between particles as start and end positions.

        Returns
        -------
        :
            List of list of edges as position of each of the two connected
            particles for each edge at each time.
        """
        return [
            [
                np.array(
                    [
                        frame.particles[edge[0]].position,
                        frame.particles[edge[1]].position,
                    ]
                )
                for edge in frame.edge_ids
            ]
            for frame in self.trajectory
        ]