# Source code for bec_lib.bec_service

"""
This module provides the BECService class, which is the base class for all BEC services.
"""

from __future__ import annotations

import getpass
import socket
import threading
import time
import uuid
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any

import psutil
from rich.console import Console
from rich.table import Table

from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from bec_lib.messages import BECStatus
from bec_lib.service_config import ServiceConfig

if TYPE_CHECKING:
    from bec_lib.connector import ConnectorBase


# ServiceConfig most recently loaded by a BECService instance; it is assigned in
# BECService._import_config and stays None until a service has been constructed.
SERVICE_CONFIG: ServiceConfig | None = None

# Module logger, taken from the shared BEC logger wrapper.
logger = bec_logger.logger


class BECService:
    """Base class for all BEC services.

    On construction the service loads its configuration, instantiates the
    connector, configures logging, optionally makes sure it is the only running
    instance, and starts two daemon threads: one that periodically publishes
    the service status and one that periodically emits process metrics.
    """

    def __init__(
        self,
        config: str | ServiceConfig,
        connector_cls: type[ConnectorBase],
        unique_service=False,
        wait_for_server=False,
    ) -> None:
        """
        Args:
            config (str | ServiceConfig): Path to a service config file or an
                already loaded ServiceConfig.
            connector_cls (type[ConnectorBase]): Connector class; it is
                instantiated with the bootstrap server taken from the config.
            unique_service (bool): If True, the service is registered under its
                bare class name and startup fails with RuntimeError if another
                instance with that name is still registered after the timeout
                of ``_check_services``.
            wait_for_server (bool): If True, block until ScanServer,
                ScanBundler, DeviceServer and SciHub report RUNNING.
        """
        super().__init__()
        self._import_config(config)
        self._connector_cls = connector_cls
        self.connector = connector_cls(self.bootstrap_server)
        self._unique_service = unique_service
        self.wait_for_server = wait_for_server
        self.__service_id = str(uuid.uuid4())
        self._user = getpass.getuser()
        self._hostname = socket.gethostname()
        self._service_info_thread = None
        self._service_info_event = threading.Event()
        self._metrics_emitter_thread = None
        self._metrics_emitter_event = threading.Event()
        self._services_info = {}
        self._initialize_logger()
        # must run before the status thread starts, otherwise this instance
        # would see its own status entry
        self._check_services()
        self._status = BECStatus.BUSY
        self._start_update_service_info()
        self._start_metrics_emitter()
        self._wait_for_server()

    @property
    def _service_name(self):
        """Name under which the status is published.

        The bare class name for unique services, otherwise ``"<ClassName>/<service id>"``.
        """
        return (
            self.__class__.__name__
            if self._unique_service
            else f"{self.__class__.__name__}/{self._service_id}"
        )

    @property
    def _service_id(self):
        """Random UUID4 string generated for this service instance."""
        return self.__service_id

    def _import_config(self, config: str | ServiceConfig) -> None:
        """Load the service configuration and publish it module-wide.

        Args:
            config (str | ServiceConfig): Config file path or ServiceConfig instance.

        Raises:
            TypeError: If ``config`` is neither a str nor a ServiceConfig.
        """
        if isinstance(config, str):
            self._service_config = ServiceConfig(config_path=config)
        elif isinstance(config, ServiceConfig):
            self._service_config = config
        else:
            raise TypeError("config must be of type str or ServiceConfig")
        global SERVICE_CONFIG
        SERVICE_CONFIG = self._service_config
        self.bootstrap_server = self._service_config.redis

    def _check_services(self, timeout_time=8, sleep_time=0.5) -> None:
        """Ensure no other instance of a unique service is running.

        Polls every ``sleep_time`` seconds. If another instance is still
        registered once more than ``timeout_time`` seconds have elapsed,
        a RuntimeError is raised. Non-unique services return immediately.
        """
        if not self._unique_service:
            return
        elapsed_time = 0
        while self._run_service_check(timeout_time, elapsed_time):
            elapsed_time += sleep_time
            time.sleep(sleep_time)

    def _run_service_check(self, timeout_time: float, elapsed_time: float) -> bool:
        """Run a single check for a competing instance of this service.

        Args:
            timeout_time (float): Time in seconds after which a conflict is fatal.
            elapsed_time (float): Time in seconds already spent waiting.

        Returns:
            bool: True if another instance is registered and the caller should
            keep waiting, False if no conflicting instance exists.

        Raises:
            RuntimeError: If another instance is still registered after the timeout.
        """
        self._update_existing_services()
        msg = self._services_info.get(self.__class__.__name__)
        if msg is None:
            return False
        if elapsed_time > timeout_time:
            raise RuntimeError(
                f"Another instance of {self.__class__.__name__} launched by user"
                f" {msg.content['info']['user']} is already running on"
                f" {msg.content['info']['hostname']}"
            )
        return True

    def _initialize_logger(self) -> None:
        """Configure the shared BEC logger for this service."""
        bec_logger.configure(
            self.bootstrap_server,
            self._connector_cls,
            service_name=self.__class__.__name__,
            service_config=self._service_config.config["service_config"],
        )

    def _update_existing_services(self):
        """Rebuild ``self._services_info`` from the service status keys.

        The cache is rebuilt on every call, even when no keys are found. This
        drops services whose status key has disappeared; status keys are
        written with an expiry, see ``_send_service_status``.
        """
        service_keys = self.connector.keys(MessageEndpoints.service_status("*").endpoint)
        # Keys arrive as bytes; only the part before the first ":" is used as
        # the key to read.
        # NOTE(review): confirm the ":" convention against the connector.
        services = [key.decode().split(":", maxsplit=1)[0] for key in service_keys or []]
        msgs = [self.connector.get(service) for service in services]
        self._services_info = {msg.content["name"]: msg for msg in msgs if msg is not None}

    def _update_service_info(self):
        """Thread target: republish the service status every 3 seconds until stopped."""
        while not self._service_info_event.is_set():
            logger.trace("Updating service info")
            try:
                self._send_service_status()
            except Exception:
                # exception is not explicitly specified,
                # because it depends on the underlying connector
                pass
            self._service_info_event.wait(timeout=3)

    def _send_service_status(self):
        """Publish the current status; the key expires after 6 seconds unless refreshed."""
        self.connector.set_and_publish(
            topic=MessageEndpoints.service_status(self._service_id),
            msg=messages.StatusMessage(
                name=self._service_name,
                status=self.status,
                info={"user": self._user, "hostname": self._hostname, "timestamp": time.time()},
            ),
            expire=6,
        )

    @property
    def status(self) -> BECStatus:
        """get the current BECService status"""
        return self._status

    @status.setter
    def status(self, val: BECStatus):
        # publish immediately instead of waiting for the next periodic update
        self._status = val
        self._send_service_status()

    def _start_update_service_info(self):
        """Start the daemon thread that periodically publishes the service status."""
        self._service_info_thread = threading.Thread(
            target=self._update_service_info, daemon=True, name="update_service_info"
        )
        self._service_info_thread.start()

    def _start_metrics_emitter(self):
        """Start the daemon thread that periodically emits process metrics."""
        self._metrics_emitter_thread = threading.Thread(
            target=self._get_metrics, daemon=True, name="metrics_emitter"
        )
        self._metrics_emitter_thread.start()

    def _get_metrics(self):
        """Thread target: send a ServiceMetric snapshot about once per second until stopped."""
        proc = psutil.Process()
        while not self._metrics_emitter_event.is_set():
            res = proc.as_dict(
                attrs=[
                    "name",
                    "num_threads",
                    "pid",
                    "cpu_percent",
                    "memory_info",
                    "cmdline",
                    "cpu_times",
                    "create_time",
                    "memory_percent",
                ]
            )
            data = asdict(
                ServiceMetric(
                    process_name=res["name"],
                    username=self._user,
                    hostname=self._hostname,
                    cpu_percent=res["cpu_percent"],
                    cpu_times=res["cpu_times"].user,
                    cmdline=res["cmdline"],
                    num_threads=res["num_threads"],
                    pid=res["pid"],
                    memory_in_mb=res["memory_info"].rss / 1024 / 1024,
                    memory_used_percent=res["memory_percent"],
                    create_time=res["create_time"],
                )
            )
            msg = messages.ServiceMetricMessage(name=self.__class__.__name__, metrics=data)
            try:
                self.connector.send(MessageEndpoints.metrics(self._service_id), msg)
            except Exception:
                # exception is not explicitly specified,
                # because it depends on the underlying connector
                pass
            self._metrics_emitter_event.wait(timeout=1)

    def set_global_var(self, name: str, val: Any) -> None:
        """Set a global variable through Redis

        Args:
            name (str): Name of the variable
            val (Any): Value of the variable
        """
        self.connector.set(MessageEndpoints.global_vars(name), messages.VariableMessage(value=val))

    def get_global_var(self, name: str) -> Any:
        """Get a global variable from Redis

        Args:
            name (str): Name of the variable

        Returns:
            Any: Value of the variable, or None if it is not set
        """
        msg = self.connector.get(MessageEndpoints.global_vars(name))
        if msg:
            return msg.content.get("value")
        return None

    def delete_global_var(self, name: str) -> None:
        """Delete a global variable from Redis

        Args:
            name (str): Name of the variable
        """
        self.connector.delete(MessageEndpoints.global_vars(name))

    def global_vars(self) -> None:
        """Print and log a table of all available global variables.

        Values longer than 40 characters are abbreviated to their first and
        last 20 characters. Nothing is returned.
        """
        # sadly, this cannot be a property as it causes side effects with IPython's tab completion
        available_keys = self.connector.keys(MessageEndpoints.global_vars("*"))
        # The endpoint factories return an object exposing ``.endpoint`` (as with
        # service_status above); str.split needs the raw string prefix.
        prefix_info = MessageEndpoints.global_vars("")
        prefix = getattr(prefix_info, "endpoint", prefix_info)

        def get_endpoint_from_topic(topic: bytes) -> str:
            return topic.decode().split(prefix)[-1]

        endpoints = [get_endpoint_from_topic(k) for k in available_keys or []]

        console = Console()
        table = Table(title="Global variables")
        table.add_column("Variable", justify="center")
        table.add_column("Content", justify="center")

        for endpoint in endpoints:
            var = str(self.get_global_var(endpoint))
            if len(var) > 40:
                var = var[0:20] + "..., " + var[-20:]
            table.add_row(endpoint, var)
        with console.capture() as capture:
            console.print(table)
        out = capture.get()
        logger.info(out)
        print(out)

    def shutdown(self):
        """shutdown the BECService

        Shuts down the connector, then signals and joins the status and metrics
        threads. An AttributeError (presumably from an instance whose
        construction did not complete) is reported instead of raised.
        """
        try:
            self.connector.shutdown()
            self._service_info_event.set()
            if self._service_info_thread:
                self._service_info_thread.join()
            self._metrics_emitter_event.set()
            if self._metrics_emitter_thread:
                self._metrics_emitter_thread.join()
        except AttributeError:
            print("Failed to shutdown BECService.")

    @property
    def service_status(self):
        """get the status of active services, keyed by service name"""
        self._update_existing_services()
        return self._services_info

    def wait_for_service(self, name, status=BECStatus.RUNNING):
        """Block until service ``name`` reports ``status``, polling every 50 ms.

        There is no timeout; the call blocks indefinitely.
        """
        logger.info(f"Waiting for {name}.")
        while True:
            service_status_msg = self.service_status.get(name)
            if service_status_msg is not None:
                service_status = BECStatus(service_status_msg.content["status"])
                if service_status == status:
                    break
            time.sleep(0.05)
        logger.success(f"{name} is running.")

    def _wait_for_server(self):
        """Wait for the core BEC services if ``wait_for_server`` was requested.

        A KeyboardInterrupt aborts the wait and is logged as a warning.
        """
        if not self.wait_for_server:
            return
        try:
            self.wait_for_service("ScanServer", BECStatus.RUNNING)
            self.wait_for_service("ScanBundler", BECStatus.RUNNING)
            self.wait_for_service("DeviceServer", BECStatus.RUNNING)
            self.wait_for_service("SciHub", BECStatus.RUNNING)
            logger.success("All BEC services are running.")
        except KeyboardInterrupt:
            logger.warning("KeyboardInterrupt received. Stopped waiting for BEC services.")
@dataclass
class ServiceMetric:
    """Container for keeping performance metrics.

    Instances are built by ``BECService._get_metrics`` from a
    ``psutil.Process.as_dict`` snapshot and serialized with
    ``dataclasses.asdict``. The field order defines the positional
    constructor signature and must be kept stable.
    """

    # Process ID of the service process.
    pid: int
    # Command line of the process, as returned by psutil (list of strings).
    cmdline: list[str]
    # Process name as reported by psutil.
    process_name: str
    # Login name of the user running the service.
    username: str
    # Host name of the machine running the service.
    hostname: str
    # CPU utilisation of the process in percent.
    cpu_percent: float
    # User-mode CPU time in seconds (``cpu_times().user``), not the full
    # psutil cpu_times tuple, so this is a float rather than a dict.
    cpu_times: float
    # Number of threads in the process.
    num_threads: int
    # Resident set size converted to MiB.
    memory_in_mb: float
    # Memory utilisation of the process in percent, as reported by psutil.
    memory_used_percent: float
    # Process creation time as reported by psutil (seconds since the epoch).
    create_time: float