Source code for bec_lib.dap_plugin_objects

"""
This module contains the base classes for DAP plugin objects. These classes should be used to create custom DAP plugin objects.
"""

from __future__ import annotations

import builtins
import time
import uuid
from typing import TYPE_CHECKING

import lmfit
import numpy as np
from typeguard import typechecked

from bec_lib import messages
from bec_lib.device import DeviceBase
from bec_lib.endpoints import MessageEndpoints
from bec_lib.lmfit_serializer import serialize_param_object
from bec_lib.scan_items import ScanItem
from bec_lib.scan_report import ScanReport

if TYPE_CHECKING:
    from bec_lib.client import BECClient

try:
    import matplotlib.pyplot as plt

    plt.ion()
except ImportError:
    plt = None


class DAPPluginObjectBase:
    """
    Base class for DAP plugin objects. This class should not be used directly.
    Instead, use one of the derived classes.

    Class attributes:
        _result_cls (callable, optional): If callable, it is used to wrap the data of a
            DAP response (see _convert_result). Defaults to None, in which case the raw
            data is returned.
    """

    _result_cls = None

    def __init__(
        self,
        service_name: str,
        plugin_info: dict,
        client: BECClient = None,
        auto_run_supported: bool = False,
        service_info: dict = None,
    ) -> None:
        """
        Args:
            service_name (str): The name of the service.
            plugin_info (dict): Information about the plugin.
            client (BECClient, optional): The BEC client. Defaults to None.
            auto_run_supported (bool, optional): Whether the plugin supports auto run. Defaults to False.
            service_info (dict, optional): Information about the service. Defaults to None.
        """
        self._service_name = service_name
        self._plugin_info = plugin_info
        self._client = client
        self._auto_run_supported = auto_run_supported
        self._plugin_config = {}
        self._service_info = service_info

        # run must be an anonymous function to allow for multiple doc strings
        self._user_run = lambda *args, **kwargs: self._run(*args, **kwargs)

    @staticmethod
    def _scan_id_or_value(value):
        """Return the scan ID for ScanItem / ScanReport inputs, any other value unchanged."""
        if isinstance(value, ScanItem):
            return value.scan_id
        if isinstance(value, ScanReport):
            return value.scan.scan_id
        return value

    def _run(self, *args, **kwargs):
        """
        Send an on-demand DAP request and wait for its result.

        ScanItem and ScanReport objects (positional or keyword) are replaced by their
        scan IDs; lmfit.Parameter objects passed as keyword arguments are serialized.

        Returns:
            The converted result of the DAP response (see _convert_result).

        Raises:
            TimeoutError: If no response arrives in time.
            RuntimeError: If the response reports a failure.
        """
        converted_args = [self._scan_id_or_value(arg) for arg in args]

        # Each keyword value is converted on its own; the conversion must never
        # look at the loop variable of the positional arguments.
        converted_kwargs = {
            key: (
                serialize_param_object(val)
                if isinstance(val, lmfit.Parameter)
                else self._scan_id_or_value(val)
            )
            for key, val in kwargs.items()
        }

        request_id = str(uuid.uuid4())
        self._client.connector.set_and_publish(
            MessageEndpoints.dap_request(),
            messages.DAPRequestMessage(
                dap_cls=self._plugin_info["class"],
                dap_type="on_demand",
                config={
                    "args": converted_args,
                    "kwargs": converted_kwargs,
                    "class_args": self._plugin_info.get("class_args"),
                    "class_kwargs": self._plugin_info.get("class_kwargs"),
                },
                metadata={"RID": request_id},
            ),
        )

        response = self._wait_for_dap_response(request_id)
        return self._convert_result(response)

    def _convert_result(self, result: messages.BECMessage):
        """
        Convert a response message into the plugin's result type.

        Returns:
            None if the message carries no data; the raw data if _result_cls is not
            callable; otherwise _result_cls(data, user_friendly_name).
        """
        if not result.content["data"]:
            return None
        if not callable(self._result_cls):
            return result.content["data"]
        # pylint: disable=not-callable
        return self._result_cls(result.content["data"], self._plugin_info["user_friendly_name"])

    def _wait_for_dap_response(self, request_id: str, timeout: float = 5.0):
        """
        Poll the DAP response endpoint until a response for request_id is available.

        Args:
            request_id (str): The request ID the response must carry in its metadata.
            timeout (float, optional): Maximum time to wait in seconds. Defaults to 5.0.

        Returns:
            The response message, if it reports success.

        Raises:
            TimeoutError: If no matching response was received within the timeout.
            RuntimeError: If the response reports a failure; carries the reported error.
        """
        start_time = time.time()

        while True:
            if time.time() - start_time > timeout:
                raise TimeoutError("Timeout waiting for DAP response.")
            response = self._client.connector.get(MessageEndpoints.dap_response(request_id))
            if not response:
                time.sleep(0.005)
                continue

            # ignore responses that belong to a different request
            if response.metadata["RID"] != request_id:
                time.sleep(0.005)
                continue

            if response.content["success"]:
                return response
            raise RuntimeError(response.content["error"])

    def _update_dap_config(self, request_id: str = None):
        """
        Publish the current plugin config as a continuous DAP request.

        Nothing is sent as long as no device has been selected.

        Args:
            request_id (str, optional): The request ID to attach to the message metadata.
        """
        if not self._plugin_config.get("selected_device"):
            return
        self._plugin_config["class_args"] = self._plugin_info.get("class_args")
        self._plugin_config["class_kwargs"] = self._plugin_info.get("class_kwargs")
        self._client.connector.set_and_publish(
            MessageEndpoints.dap_request(),
            messages.DAPRequestMessage(
                dap_cls=self._plugin_info["class"],
                dap_type="continuous",
                config=self._plugin_config,
                metadata={"RID": request_id},
            ),
        )
class DAPPluginObject(DAPPluginObjectBase):
    """
    Default DAP plugin object for plugins that do not support auto run.

    To customize a plugin, derive a new class from this one and override
    the methods as needed.
    """

    def get_data(self):
        """
        Fetch the processed data of the last run.

        Returns:
            The converted result of the last processed-data message of this service,
            or None if there is no such message.
        """
        endpoint = MessageEndpoints.processed_data(self._service_name)
        last_message = self._client.connector.get_last(endpoint)
        return self._convert_result(last_message) if last_message else None
class DAPPluginObjectAutoRun(DAPPluginObject):
    """
    DAP plugin object for plugins that support auto run.

    To customize a plugin, derive a new class from this one and override
    the methods as needed.
    """

    @property
    def auto_run(self):
        """
        Set to True to start a continuously running worker.
        """
        return self._plugin_config.get("auto_run", False)

    @auto_run.setter
    @typechecked
    def auto_run(self, val: bool):
        # store the flag first, then push the updated config under a fresh request ID
        self._plugin_config["auto_run"] = val
        self._update_dap_config(request_id=str(uuid.uuid4()))
class LmfitService1DResult:
    """
    Result of fitting 1D data using lmfit.
    """

    def __init__(self, result: list[dict], model_name: str = None, client: BECClient = None):
        """
        Args:
            result (list[dict]): The fit data (first entry) and the fit report (second entry).
            model_name (str, optional): The name of the lmfit model. Defaults to None.
            client (BECClient, optional): The BEC client. If not given, the "bec" object
                registered in builtins is used, if any.
        """
        self._data = result[0]
        self._report = result[1]
        self._model = model_name
        if client:
            self._client = client
        else:
            self._client = builtins.__dict__.get("bec")

        # expose the most common fit parameters as attributes, if the fit provides them
        for name in ("amplitude", "center", "sigma"):
            if name in self.params:
                setattr(self, name, self.params[name])

    @property
    def params(self):
        """
        The parameters of the fit.
        """
        return self._report["fit_parameters"]

    @property
    def data(self):
        """
        The data from the fit.
        """
        return self._data

    @property
    def report(self):
        """
        The report of the fit.
        """
        return self._report

    def eval(self, x: np.ndarray):
        """
        Evaluate the fit at the given x values.

        Args:
            x (array_like): The x values to evaluate the fit at.

        Returns:
            dict: The x values ("x") and the y values of the fit at these x values ("y").
        """
        if not isinstance(x, np.ndarray):
            x = np.array(x)
        model = getattr(lmfit.models, self._model)()
        params = model.make_params(**self.params)
        return {"x": x, "y": model.eval(params=params, x=x)}

    @property
    def min(self):
        """
        Get the point of the fit data with the smallest y value.

        Returns:
            dict: The x ("x") and y ("y") value at the minimum.
        """
        # get the index of the minimum value
        min_index = np.argmin(self._data["y"])
        return {"x": self._data["x"][min_index], "y": self._data["y"][min_index]}

    @property
    def max(self):
        """
        Get the point of the fit data with the largest y value.

        Returns:
            dict: The x ("x") and y ("y") value at the maximum.
        """
        # get the index of the maximum value
        max_index = np.argmax(self._data["y"])
        return {"x": self._data["x"][max_index], "y": self._data["y"][max_index]}

    @property
    def input_data(self):
        """
        Get the input data used for the fit.

        Returns:
            dict: The x ("x") and y ("y") input data, or None if the report does not
            reference a scan or the scan cannot be found in the scan storage.
        """
        input_data = self._report.get("input")
        # the report may not carry any input information at all
        if not input_data:
            return None
        scan_id = input_data.get("scan_id")
        if not scan_id:
            return None
        scan_item = self._client.queue.scan_storage.find_scan_by_ID(scan_id)
        if not scan_item:
            return None
        x = scan_item.data[input_data["device_x"]][input_data["signal_x"]].val
        y = scan_item.data[input_data["device_y"]][input_data["signal_y"]].val
        return {"x": x, "y": y}

    def plot(self):
        """
        Plot the fit. The input data is added to the plot if it is available.

        Raises:
            ImportError: If matplotlib is not installed.
        """
        # move this to BECWidgets once it's available
        if not plt:
            raise ImportError(
                "matplotlib is not installed. Cannot plot. Please install matplotlib using 'pip install matplotlib'."
            )
        input_data = self.input_data
        plt.figure()
        # input_data is None if the scan is unknown; the fit itself can still be shown
        if input_data is not None:
            plt.plot(input_data["x"], input_data["y"], label="data", color="black", marker="o")
        plt.plot(self._data["x"], self._data["y"], label=f"{self._model}", color="red")
        plt.legend()
        plt.show()

    def __str__(self) -> str:
        return f"{self._model} fit result: \n Params: {self.params} \n Min: {self.min} \n Max: {self.max}"
class LmfitService1D(DAPPluginObjectAutoRun):
    """
    Plugin for fitting 1D data using lmfit.
    """

    _result_cls = LmfitService1DResult

    def __init__(
        self,
        service_name: str,
        plugin_info: dict,
        client: BECClient = None,
        auto_run_supported: bool = False,
        service_info: dict = None,
    ) -> None:
        """
        Args:
            service_name (str): The name of the service.
            plugin_info (dict): Information about the plugin.
            client (BECClient, optional): The BEC client. Defaults to None.
            auto_run_supported (bool, optional): Whether the plugin supports auto run. Defaults to False.
            service_info (dict, optional): Information about the service. Defaults to None.
        """
        super().__init__(
            service_name,
            plugin_info,
            client=client,
            auto_run_supported=auto_run_supported,
            service_info=service_info,
        )
        # model parameters, created on demand by get_params
        self._params = None

    def select(self, device: DeviceBase | str, signal: str = None):
        """
        Select the device and signal to use for fitting.

        Args:
            device (DeviceBase | str): The device to use for fitting. Can be either a DeviceBase object or the name of the device.
            signal (str, optional): The signal to use for fitting. If not provided, the device
                must have exactly one hint, which is then used as signal.

        Raises:
            AttributeError: If the device cannot be found, or if no signal is given and the
                device has no hints or more than one hint.
        """
        bec_device = (
            device
            if isinstance(device, DeviceBase)
            else self._client.device_manager.devices.get(device)
        )
        if not bec_device:
            raise AttributeError(f"Device {device} not found.")
        if signal:
            self._plugin_config["selected_device"] = [bec_device.name, signal]
        else:
            # pylint: disable=protected-access
            hints = bec_device._hints
            if not hints:
                raise AttributeError(
                    f"Device {bec_device.name} has no hints. Cannot select device without signal."
                )
            if len(hints) > 1:
                raise AttributeError(
                    f"Device {bec_device.name} has multiple hints. Please specify a signal."
                )
            self._plugin_config["selected_device"] = [bec_device.name, hints[0]]
        request_id = str(uuid.uuid4())
        self._update_dap_config(request_id=request_id)

    def get_params(self) -> lmfit.Parameters:
        """
        Create a set of parameters for the model.

        The parameters are created on first access and reused afterwards, until
        reset_params is called.

        Returns:
            lmfit.Parameters: The parameters available for the model.
        """
        if not self._params:
            model = getattr(lmfit.models, self._plugin_info["user_friendly_name"])()
            self._params = model.make_params()
        return self._params

    def reset_params(self):
        """
        Reset the parameters to the default values.
        """
        self._params = None

    def _run(self, *args, **kwargs):
        """
        Run the fit, passing the stored model parameters (if any) as keyword arguments.

        Explicitly given keyword arguments take precedence over stored parameters of
        the same name.
        """
        if self._params:
            # merge instead of expanding both mappings into the call: a keyword that
            # is also a stored parameter would otherwise raise a TypeError
            merged_kwargs = {**self._params, **kwargs}
            return super()._run(*args, **merged_kwargs)
        return super()._run(*args, **kwargs)