# Source code for simularium_metrics_calculator.metrics_service

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import json
from typing import Any, Dict, List

from simulariumio import HistogramPlotData, ScatterPlotData, TrajectoryData
from simulariumio.constants import CURRENT_VERSION
from simulariumio.plot_readers import HistogramPlotReader, ScatterPlotReader

from .exceptions import MetricNotFoundError
from .metric_info import MetricInfo
from .metrics_registry import metrics_list
from .plot_info import PlotInfo


class MetricsService:
    """
    Lists and calculates metrics that can be plotted
    in the Simularium Viewer.
    """

    def __init__(self) -> None:
        """
        Build the registry of available metrics for this service instance.
        """
        self._create_metrics_registry()

    def _create_metrics_registry(self) -> None:
        """
        Store ``self.metrics_registry``: a dict mapping each metric's
        index in ``metrics_list`` (used as its per-session unique ID)
        to the info object for that metric.
        """
        self.metrics_registry = dict(enumerate(metrics_list))

    def metric_info_for_id(self, metric_id: int) -> MetricInfo:
        """
        Look up the MetricInfo registered under a session ID.

        Parameters
        ----------
        metric_id: int
            The session ID for the requested metric.

        Returns
        -------
        MetricInfo
            Info about the requested metric.

        Raises
        ------
        MetricNotFoundError
            If ``metric_id`` is not a key of the registry.
        """
        registry = self.metrics_registry
        if metric_id in registry:
            return registry[metric_id]
        raise MetricNotFoundError(metric_id)

    def available_metrics(self) -> List[Dict[str, Any]]:
        """
        Describe every metric in the registry.

        Returns
        -------
        List[Dict[str, Any]]
            One dict per registered metric, in registry order:
            the result of the metric info's ``to_dict()`` with an
            added ``"uid"`` entry holding the metric's session ID.
        """
        entries = []
        for session_id, registered in self.metrics_registry.items():
            # "uid" is written into the dict returned by to_dict() itself
            entry = registered.to_dict()
            entry["uid"] = session_id
            entries.append(entry)
        return entries

    def plot_data(self, traj_data: TrajectoryData, plots: List[PlotInfo]) -> str:
        """
        Calculate the requested plots and serialize them to JSON.

        Parameters
        ----------
        traj_data : TrajectoryData
            A Simularium trajectory.
        plots: List[PlotInfo]
            A list of PlotInfo configuration for each plot.

        Returns
        -------
        str
            A JSON string with a "version" entry and a "data" entry
            that holds one dict per requested plot.
        """
        calculated = self._plot_dicts(traj_data, plots)
        payload = {
            "version": CURRENT_VERSION.PLOT_DATA,
            "data": calculated,
        }
        return json.dumps(payload)

    def _plot_dicts(
        self, traj_data: TrajectoryData, plots: List[PlotInfo]
    ) -> List[Dict[str, Any]]:
        """
        Calculate every configured plot, in order, and return
        the resulting dicts.
        """
        return [self._calculate_plot(traj_data, plot_info) for plot_info in plots]

    def _calculate_plot(
        self, traj_data: TrajectoryData, plot_info: PlotInfo
    ) -> Dict[str, Any]:
        """
        Calculate one plot with the given configuration.

        A histogram is produced when ``plot_info.metric_id_y`` is
        negative (no Y metric); otherwise a scatter plot of the
        Y metric against the X metric is produced.
        """
        # resolve the metric(s); a negative Y id means "no Y metric"
        x_info = self.metric_info_for_id(plot_info.metric_id_x)
        y_info = (
            self.metric_info_for_id(plot_info.metric_id_y)
            if plot_info.metric_id_y >= 0
            else None
        )

        # validate the configuration before doing any calculation
        plot_info.validate_plot_configuration(x_info, y_info)
        plot_info.set_display_title(x_info, y_info)

        # the X axis metric is needed for both plot types
        x_traces, x_units = x_info.calculator().calculate(traj_data)
        x_metric_title = x_info.display_name

        if y_info is None:
            # HISTOGRAM: every X trace is plotted
            histogram = HistogramPlotData(
                title=plot_info.display_title,
                xaxis_title=f"{x_metric_title}{x_units}",
                traces=x_traces,
            )
            return HistogramPlotReader().read(histogram)

        # SCATTER PLOT: only the first X trace is used,
        # since there can only be one
        first_key = list(x_traces.keys())[0]
        x_trace = x_traces[first_key]

        # Y axis metric
        y_traces, y_units = y_info.calculator().calculate(traj_data)
        y_metric_title = y_info.display_name

        scatter = ScatterPlotData(
            title=plot_info.display_title,
            xaxis_title=f"{x_metric_title}{x_units}",
            yaxis_title=f"{y_metric_title}{y_units}",
            xtrace=x_trace,
            ytraces=y_traces,
            render_mode=plot_info.scatter_plot_mode.value,
        )
        return ScatterPlotReader().read(scatter)