Coverage for subcell_pipeline/visualization/spatial_annotator.py: 0%
68 statements
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-29 15:14 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-29 15:14 +0000
1"""Methods for adding spatial annotations to visualizations."""
3import numpy as np
4from simulariumio import DISPLAY_TYPE, DimensionData, DisplayData, TrajectoryData
5from simulariumio.constants import VIZ_TYPE
def _added_dimensions_for_fibers(
    traj_data: TrajectoryData, data: list[list[np.ndarray]]
) -> DimensionData:
    """
    Get a DimensionData with deltas for each dimension of AgentData.

    Used when adding fiber annotation data.

    Data shape = [timesteps, fibers, np.array(points, 3)] (assumed to be jagged)

    The agent delta is the largest number of fibers at any single timestep.
    The subpoint delta is however much the existing subpoint buffer must grow
    to hold the longest fiber (3 values per control point). It is clamped at
    0, so a buffer that is already wide enough is left unchanged rather than
    being asked to shrink.
    """
    max_fibers = max((len(fibers) for fibers in data), default=0)
    max_points = max(
        (len(points) for fibers in data for points in fibers), default=0
    )
    current_dimensions = traj_data.agent_data.get_dimensions()
    # Fiber control points are stored flattened as x, y, z per point.
    required_subpoints = 3 * max_points
    # A negative delta would request a narrower buffer than the existing
    # agents use; presumably that truncates or fails to copy their subpoints,
    # so never go below 0.
    added_subpoints = max(0, required_subpoints - current_dimensions.max_subpoints)
    return DimensionData(
        total_steps=0,
        max_agents=max_fibers,
        max_subpoints=added_subpoints,
    )
def add_fiber_annotation_agents(
    traj_data: TrajectoryData,
    fiber_points: list[list[np.ndarray]],
    type_name: str = "fiber",
    fiber_width: float = 0.5,
    color: str = "#eaeaea",
) -> TrajectoryData:
    """
    Add agent data for fiber annotations.

    New fiber agents are appended after the existing agents at each timestep.
    The trajectory's agent data is replaced with an enlarged copy, and the
    same (mutated) trajectory object is returned.

    Parameters
    ----------
    traj_data
        Trajectory data to add the fibers to.
    fiber_points
        List (one entry per timestep) of lists of arrays (shape = n x 3)
        containing the x,y,z positions of control points for each fiber.
        The number of fibers per timestep and the number of points per fiber
        may vary (jagged).
    type_name
        Agent type name to use for the new fibers.
    fiber_width
        Width to draw the fibers.
    color
        Color for the new fibers.

    Returns
    -------
    :
        Updated trajectory data.
    """
    new_agent_data = traj_data.agent_data.get_copy_with_increased_buffer_size(
        _added_dimensions_for_fibers(traj_data, fiber_points)
    )
    # New fibers get IDs above every ID already present. The ID depends only
    # on the fiber's index, so fiber i keeps the same ID at every timestep.
    max_used_uid = np.max(traj_data.agent_data.unique_ids)
    for time_ix, fibers in enumerate(fiber_points):
        # Existing agents occupy [0, start_ix); new fibers go right after them.
        start_ix = int(traj_data.agent_data.n_agents[time_ix])
        n_fibers = len(fibers)
        end_ix = start_ix + n_fibers
        for fiber_ix, points in enumerate(fibers):
            agent_ix = start_ix + fiber_ix
            flat_points = np.asarray(points).flatten()
            n_subpoints = flat_points.size
            new_agent_data.unique_ids[time_ix][agent_ix] = max_used_uid + fiber_ix + 1
            new_agent_data.n_subpoints[time_ix][agent_ix] = n_subpoints
            # Fibers are jagged, so a fiber may be shorter than the subpoint
            # buffer row. Assign only the leading slice; assigning the whole
            # row would raise a numpy broadcast error for shorter fibers.
            new_agent_data.subpoints[time_ix][agent_ix][:n_subpoints] = flat_points
        new_agent_data.n_agents[time_ix] += n_fibers
        new_agent_data.viz_types[time_ix][start_ix:end_ix] = n_fibers * [VIZ_TYPE.FIBER]
        new_agent_data.types[time_ix] += n_fibers * [type_name]
        new_agent_data.radii[time_ix][start_ix:end_ix] = n_fibers * [fiber_width]
    new_agent_data.display_data[type_name] = DisplayData(
        name=type_name,
        display_type=DISPLAY_TYPE.FIBER,
        color=color,
    )
    traj_data.agent_data = new_agent_data
    return traj_data
def _added_dimensions_for_spheres(data: list[np.ndarray]) -> DimensionData:
    """
    Build the AgentData buffer deltas needed to add sphere annotation data.

    Data shape = [timesteps, np.array(spheres, 3)] (assumed to be jagged)

    Only the agent dimension grows, by the largest number of spheres found at
    any single timestep. No timesteps are added. An empty input yields an
    agent delta of 0.
    """
    largest_frame = max((len(frame) for frame in data), default=0)
    return DimensionData(total_steps=0, max_agents=largest_frame)
def add_sphere_annotation_agents(
    traj_data: TrajectoryData,
    sphere_positions: list[np.ndarray],
    type_name: str = "sphere",
    radius: float = 1.0,
    rainbow_colors: bool = False,
    color: str = "#eaeaea",
) -> TrajectoryData:
    """
    Add agent data for sphere annotations.

    New sphere agents are appended after the existing agents at each timestep.
    Sphere i at each timestep gets the agent type "{type_name}#{i}" and the
    same unique ID at every timestep. The trajectory's agent data is replaced
    with an enlarged copy, and the same (mutated) trajectory object is
    returned.

    Parameters
    ----------
    traj_data
        Trajectory data to add the spheres to.
    sphere_positions
        List (one entry per timestep) of arrays (shape = n x 3) containing
        the x,y,z positions of the spheres to visualize at that time. The
        number of spheres may vary between timesteps.
    type_name
        Base agent type name to use for the new spheres.
    radius
        Radius to draw the spheres.
    rainbow_colors
        True to color new spheres in rainbow order, False otherwise.
    color
        Color for the new spheres (if rainbow_colors is False).

    Returns
    -------
    :
        Updated trajectory data.
    """
    new_agent_data = traj_data.agent_data.get_copy_with_increased_buffer_size(
        _added_dimensions_for_spheres(sphere_positions)
    )
    # New spheres get IDs above every ID already present; the ID depends only
    # on the sphere's index within its timestep.
    max_used_uid = np.max(traj_data.agent_data.unique_ids)
    max_spheres = 0
    for time_ix, positions in enumerate(sphere_positions):
        # Existing agents occupy [0, start_ix); new spheres go right after them.
        start_ix = int(traj_data.agent_data.n_agents[time_ix])
        n_spheres = len(positions)
        max_spheres = max(max_spheres, n_spheres)
        end_ix = start_ix + n_spheres
        new_agent_data.unique_ids[time_ix][start_ix:end_ix] = np.arange(
            max_used_uid + 1, max_used_uid + 1 + n_spheres
        )
        new_agent_data.n_agents[time_ix] += n_spheres
        new_agent_data.viz_types[time_ix][start_ix:end_ix] = n_spheres * [
            VIZ_TYPE.DEFAULT
        ]
        new_agent_data.types[time_ix] += [
            f"{type_name}#{ix}" for ix in range(n_spheres)
        ]
        new_agent_data.positions[time_ix][start_ix:end_ix] = positions
        new_agent_data.radii[time_ix][start_ix:end_ix] = n_spheres * [radius]
    # One display entry per sphere index ever used. The palette cycles when
    # there are more spheres than colors.
    colors = ["#0000ff", "#00ff00", "#ffff00", "#ff0000", "#ff00ff"]
    for ix in range(max_spheres):
        tn = f"{type_name}#{ix}"
        new_agent_data.display_data[tn] = DisplayData(
            name=tn,
            display_type=DISPLAY_TYPE.SPHERE,
            color=colors[ix % len(colors)] if rainbow_colors else color,
        )
    traj_data.agent_data = new_agent_data
    return traj_data