Coverage for subcell_pipeline/analysis/dimensionality_reduction/pacmap_dim_reduction.py: 0%

28 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-08-29 15:14 +0000

1"""Methods for dimensionality reduction using PaCMAP.""" 

2 

3from typing import Optional 

4 

5import matplotlib.pyplot as plt 

6import pandas as pd 

7from io_collection.save.save_figure import save_figure 

8from pacmap import PaCMAP 

9 

10from subcell_pipeline.analysis.dimensionality_reduction.fiber_data import reshape_fibers 

11 

12 

def run_pacmap(
    data: pd.DataFrame,
    *,
    n_components: int = 2,
    random_state: Optional[int] = None,
) -> tuple[pd.DataFrame, PaCMAP]:
    """
    Run Pairwise Controlled Manifold Approximation (PaCMAP) on simulation data.

    Parameters
    ----------
    data
        Simulated fiber data.
    n_components
        Number of embedding dimensions. Embedding columns are named
        ``PACMAP1`` ... ``PACMAP{n_components}``.
    random_state
        Seed forwarded to PaCMAP for a reproducible embedding. The default
        (None) leaves PaCMAP unseeded, as before.

    Returns
    -------
    :
        Dataframe with the PaCMAP embedding columns followed by the fiber
        features (rows shuffled with a fixed seed), and the fitted PaCMAP
        object.
    """

    # NOTE(review): assumes reshape_fibers returns row-aligned
    # (fiber matrix, per-fiber features) — confirm in fiber_data.
    all_fibers, all_features = reshape_fibers(data)

    # n_neighbors=None defers the neighbor count to PaCMAP's own default.
    reducer = PaCMAP(
        n_components=n_components,
        n_neighbors=None,
        MN_ratio=0.5,
        FP_ratio=2.0,
        random_state=random_state,
    )
    embedding = reducer.fit_transform(all_fibers)

    embedding_columns = [f"PACMAP{index + 1}" for index in range(n_components)]
    pacmap_results = pd.concat(
        [
            pd.DataFrame(embedding, columns=embedding_columns),
            pd.DataFrame(all_features),
        ],
        axis=1,
    )

    # Shuffle rows with a fixed seed; presumably so downstream scatter plots
    # do not draw points layered in input order.
    pacmap_results = pacmap_results.sample(frac=1, random_state=1)

    return pacmap_results, reducer

43 

44 

def plot_pacmap_feature_scatter(
    data: pd.DataFrame,
    features: dict,
    save_location: Optional[str] = None,
    save_key: str = "pacmap_feature_scatter.png",
) -> None:
    """
    Plot scatter of PaCMAP embedding colored by the given features.

    One panel is drawn per entry in ``features``, all sharing x and y axes.

    Parameters
    ----------
    data
        PaCMAP results data; must contain ``PACMAP1``, ``PACMAP2`` and every
        key of ``features`` as columns.
    features
        Map of feature name to coloring. Each value is one of:

        - a dict, used to map each feature value directly to a color;
        - a tuple ``(value_map, cmap)``, where feature values are mapped
          through ``value_map`` and then colored with colormap ``cmap``;
        - anything else, treated as a colormap applied to the raw values.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.

    Raises
    ------
    ValueError
        If ``features`` is empty.
    """

    if not features:
        raise ValueError("At least one feature is required to plot.")

    # squeeze=False keeps the axes array 2-D even for a single feature, so
    # indexing works for any number of panels (a bare Axes is not indexable).
    figure, axes = plt.subplots(
        1, len(features), figsize=(10, 3), sharey=True, sharex=True, squeeze=False
    )

    for ax, (feature, colors) in zip(axes[0], features.items()):
        if isinstance(colors, dict):
            color_values, cmap = data[feature].map(colors), None
        elif isinstance(colors, tuple):
            color_values, cmap = data[feature].map(colors[0]), colors[1]
        else:
            color_values, cmap = data[feature], colors

        ax.scatter(
            data["PACMAP1"],
            data["PACMAP2"],
            s=2,
            c=color_values,
            cmap=cmap,
        )
        ax.set_title(feature)
        ax.set_xlabel("PACMAP1")
        ax.set_ylabel("PACMAP2")

    plt.tight_layout()

    # Save before showing: a blocking plt.show() closes the figure when the
    # window is dismissed, so saving afterwards is not reliable.
    if save_location is not None:
        save_figure(save_location, save_key, figure)

    plt.show()