Coverage for subcell_pipeline/visualization/spatial_annotator.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-08-29 15:14 +0000

1"""Methods for adding spatial annotations to visualizations.""" 

2 

3import numpy as np 

4from simulariumio import DISPLAY_TYPE, DimensionData, DisplayData, TrajectoryData 

5from simulariumio.constants import VIZ_TYPE 

6 

7 

def _added_dimensions_for_fibers(
    traj_data: TrajectoryData, data: list[list[np.ndarray]]
) -> DimensionData:
    """
    Get a DimensionData with deltas for each dimension of AgentData.

    Used when adding fiber annotation data.

    Data shape = [timesteps, fibers, np.array(points, 3)] (assumed to be jagged)

    The agent delta is the largest number of fibers at any single timestep.
    The subpoint delta is how much the existing subpoint buffer must grow to
    hold the longest fiber (3 values per control point). It is clamped at 0,
    so the delta never requests a smaller buffer than the existing agents
    already use.
    """
    max_fibers = max((len(fibers) for fibers in data), default=0)
    max_points = max(
        (len(points) for fibers in data for points in fibers), default=0
    )
    current_dimensions = traj_data.agent_data.get_dimensions()
    required_subpoints = 3 * max_points
    return DimensionData(
        total_steps=0,
        max_agents=max_fibers,
        # A negative delta would shrink the buffer below what existing agents
        # need; only grow when the new fibers need more room.
        max_subpoints=max(0, required_subpoints - current_dimensions.max_subpoints),
    )


def add_fiber_annotation_agents(
    traj_data: TrajectoryData,
    fiber_points: list[list[np.ndarray]],
    type_name: str = "fiber",
    fiber_width: float = 0.5,
    color: str = "#eaeaea",
) -> TrajectoryData:
    """
    Add agent data for fiber annotations.

    The agent data of ``traj_data`` is replaced by an enlarged copy holding
    the new fiber agents. ``traj_data`` is modified in place and also
    returned. Fibers are appended after the existing agents at each timestep.
    The i-th fiber gets the same unique id at every timestep.

    Parameters
    ----------
    traj_data
        Trajectory data to add the fibers to.
    fiber_points
        List of lists of arrays (shape = n x 3) containing the x,y,z positions
        of control points for each fiber at each time. Fibers may have
        different numbers of control points.
    type_name
        Agent type name to use for the new fibers.
    fiber_width
        Width to draw the fibers.
    color
        Color for the new fibers.

    Returns
    -------
    :
        Updated trajectory data.
    """

    total_steps = len(fiber_points)
    new_agent_data = traj_data.agent_data.get_copy_with_increased_buffer_size(
        _added_dimensions_for_fibers(traj_data, fiber_points)
    )

    # New ids start after the largest existing one. An agent buffer with no
    # slots has no ids at all, so fall back to -1 (new fibers numbered from 0)
    # instead of calling max() on an empty sequence.
    existing_uids = np.asarray(traj_data.agent_data.unique_ids)
    max_used_uid = existing_uids.max() if existing_uids.size else -1

    for time_ix in range(total_steps):
        start_ix = int(traj_data.agent_data.n_agents[time_ix])
        fibers = fiber_points[time_ix]
        n_fibers = len(fibers)
        end_ix = start_ix + n_fibers
        for fiber_ix, points in enumerate(fibers):
            agent_ix = start_ix + fiber_ix
            flat_points = np.asarray(points).flatten()
            n_subpoints = flat_points.size
            new_agent_data.unique_ids[time_ix][agent_ix] = max_used_uid + fiber_ix + 1
            new_agent_data.n_subpoints[time_ix][agent_ix] = n_subpoints
            # The subpoint row is padded to the buffer width. Write only the
            # leading part, because assigning a shorter fiber to the whole row
            # fails to broadcast.
            new_agent_data.subpoints[time_ix][agent_ix][:n_subpoints] = flat_points
        new_agent_data.n_agents[time_ix] += n_fibers
        new_agent_data.viz_types[time_ix][start_ix:end_ix] = n_fibers * [VIZ_TYPE.FIBER]
        new_agent_data.types[time_ix] += n_fibers * [type_name]
        new_agent_data.radii[time_ix][start_ix:end_ix] = n_fibers * [fiber_width]

    new_agent_data.display_data[type_name] = DisplayData(
        name=type_name,
        display_type=DISPLAY_TYPE.FIBER,
        color=color,
    )
    traj_data.agent_data = new_agent_data
    return traj_data

96 

97 

def _added_dimensions_for_spheres(data: list[np.ndarray]) -> DimensionData:
    """
    Compute the AgentData buffer growth needed to add sphere annotations.

    Data shape = [timesteps, np.array(spheres, 3)] (assumed to be jagged)

    No extra timesteps and no extra subpoints are requested. The agent delta
    is the sphere count of the busiest timestep.
    """
    busiest = 0
    for spheres in data:
        count = len(spheres)
        busiest = count if count > busiest else busiest
    return DimensionData(total_steps=0, max_agents=busiest)

116 

117 

def add_sphere_annotation_agents(
    traj_data: TrajectoryData,
    sphere_positions: list[np.ndarray],
    type_name: str = "sphere",
    radius: float = 1.0,
    rainbow_colors: bool = False,
    color: str = "#eaeaea",
) -> TrajectoryData:
    """
    Add agent data for sphere annotations.

    The agent data of ``traj_data`` is replaced by an enlarged copy holding
    the new sphere agents. ``traj_data`` is modified in place and also
    returned. Spheres are appended after the existing agents at each timestep.
    The i-th sphere at every timestep gets the same unique id and the agent
    type ``f"{type_name}#{i}"``, so each sphere index can be colored
    separately.

    Parameters
    ----------
    traj_data
        Trajectory data to add the spheres to.
    sphere_positions
        List of x,y,z positions of spheres to visualize at each time.
    type_name
        Agent type name prefix to use for the new spheres.
    radius
        Radius to draw the spheres.
    rainbow_colors
        True to color new spheres in rainbow order, False otherwise.
    color
        Color for the new spheres (if rainbow_colors is False).

    Returns
    -------
    :
        Updated trajectory data.
    """

    total_steps = len(sphere_positions)
    new_agent_data = traj_data.agent_data.get_copy_with_increased_buffer_size(
        _added_dimensions_for_spheres(sphere_positions)
    )

    # New ids start after the largest existing one. An agent buffer with no
    # slots has no ids at all, so fall back to -1 (new spheres numbered from 0)
    # instead of calling max() on an empty sequence.
    existing_uids = np.asarray(traj_data.agent_data.unique_ids)
    max_used_uid = existing_uids.max() if existing_uids.size else -1

    max_spheres = 0
    for time_ix in range(total_steps):
        start_ix = int(traj_data.agent_data.n_agents[time_ix])
        positions = sphere_positions[time_ix]
        n_spheres = len(positions)
        max_spheres = max(max_spheres, n_spheres)
        end_ix = start_ix + n_spheres
        new_agent_data.unique_ids[time_ix][start_ix:end_ix] = np.arange(
            max_used_uid + 1, max_used_uid + 1 + n_spheres
        )
        new_agent_data.n_agents[time_ix] += n_spheres
        new_agent_data.viz_types[time_ix][start_ix:end_ix] = n_spheres * [
            VIZ_TYPE.DEFAULT
        ]
        new_agent_data.types[time_ix] += [
            f"{type_name}#{ix}" for ix in range(n_spheres)
        ]
        new_agent_data.positions[time_ix][start_ix:end_ix] = positions[:n_spheres]
        new_agent_data.radii[time_ix][start_ix:end_ix] = n_spheres * [radius]

    # One display entry per sphere index seen at any timestep. Rainbow mode
    # cycles through the palette when there are more spheres than colors.
    colors = ["#0000ff", "#00ff00", "#ffff00", "#ff0000", "#ff00ff"]
    for ix in range(max_spheres):
        tn = f"{type_name}#{ix}"
        new_agent_data.display_data[tn] = DisplayData(
            name=tn,
            display_type=DISPLAY_TYPE.SPHERE,
            color=colors[ix % len(colors)] if rainbow_colors else color,
        )
    traj_data.agent_data = new_agent_data
    return traj_data