Source code for bec_lib.callback_handler

"""
This module contains the CallbackHandler class to handle callbacks. 
The CallbackHandler class is used to register and run callbacks for different 
event types. The CallbackRegister class is used to register callbacks 
in a with statement.
"""

import builtins
import enum
import threading
import traceback
from collections import deque
from collections.abc import Callable

from bec_lib.logger import bec_logger
from bec_lib.utils import threadlocked

# Module-level logger; CallbackEntry._run_cb uses it to report callbacks that raised.
logger = bec_logger.logger


class EventType(str, enum.Enum):
    """Names of the events that callbacks can subscribe to.

    Because the enum also derives from ``str``, each member compares equal to
    its plain string value (e.g. ``EventType.SCAN_STATUS == "scan_status"``).
    """

    SCAN_SEGMENT = "scan_segment"
    SCAN_STATUS = "scan_status"
    NAMESPACE_UPDATE = "namespace_update"
class CallbackEntry:
    """Callback entry class to store callback information.

    An entry wraps one callback function registered for one event type.
    Async entries (``sync=False``) invoke the callback directly from ``run``.
    Sync entries (``sync=True``) only queue the call arguments in ``run`` and
    invoke the callback later, one queued event per ``poll`` call.
    """

    def __init__(self, id: int, event_type: EventType, func: Callable, sync: bool) -> None:
        self.id = id
        self.func = func
        self.event_type = event_type
        self.sync = sync
        # Pending (args, kwargs) pairs for sync callbacks. The deque is bounded:
        # once 1000 events are pending, appending a new one drops the oldest.
        self.queue = deque(maxlen=1000)
        # Reentrant lock; presumably the lock acquired by @threadlocked
        # (NOTE(review): confirm in bec_lib.utils).
        self._lock = threading.RLock()

    @threadlocked
    def run(self, *args, **kwargs) -> None:
        """Run or queue the callback for one event.

        If sync is False, the callback is run immediately in the calling thread.
        Otherwise, the arguments are appended to the queue and the callback is
        executed by a later call to ``poll``.
        """
        if not self.sync:
            self._run_cb(*args, **kwargs)
            return
        self.queue.append((args, kwargs))

    def _run_cb(self, *args, **kwargs) -> None:
        """Run the callback function in a safe way.

        Any exception raised by the callback is caught and logged as a warning
        together with its traceback, so one failing callback cannot break the
        dispatcher.
        """
        try:
            self.func(*args, **kwargs)
        except Exception:
            content = traceback.format_exc()
            logger.warning(f"Failed to run callback function: {content}")

    def __str__(self) -> str:
        # Not every callable has __name__ (e.g. functools.partial objects or
        # instances with __call__), so fall back to its repr.
        func_name = getattr(self.func, "__name__", repr(self.func))
        return f"<CallbackEntry>: (event_type: {self.event_type}, function: {func_name}, sync: {self.sync}, pending events: {self.num_pending_events})"

    @property
    def num_pending_events(self):
        """number of pending events"""
        return len(self.queue)

    @threadlocked
    def poll(self) -> None:
        """Run the callback for the oldest pending event, if there is one.

        Does nothing when no event is pending.

        Raises:
            RuntimeError: Raises if attempt is made to run async callbacks manually.
        """
        if not self.sync:
            raise RuntimeError("Cannot poll on an async callback.")
        if not self.queue:
            # Nothing pending: return instead of letting popleft() raise IndexError.
            return
        args, kwargs = self.queue.popleft()
        self._run_cb(*args, **kwargs)
class CallbackHandler:
    """Callback handler class.

    Keeps a registry of CallbackEntry objects keyed by an integer id and
    dispatches events to the entries registered for that event type.
    """

    def __init__(self) -> None:
        # Maps callback id -> CallbackEntry.
        self.callbacks = {}
        # Last id handed out by new_id(); the first id issued is 1.
        self.id_counter = 0
        # Reentrant, so locked methods may call each other (register_many -> register);
        # presumably the lock acquired by @threadlocked (NOTE(review): confirm in bec_lib.utils).
        self._lock = threading.RLock()

    @threadlocked
    def register(self, event_type: str, callback: Callable, sync=False) -> int:
        """Register a callback to an event type

        Args:
            event_type (str): Event type
            callback (Callable): Callback function
            sync (bool, optional): Synchronous or async callback. Defaults to False.

        Returns:
            int: Callback id

        Raises:
            ValueError: If event_type is not a valid EventType value.
        """
        event_type = EventType(event_type)
        callback_id = self.new_id()
        self.callbacks[callback_id] = CallbackEntry(callback_id, event_type, callback, sync)
        return callback_id

    @threadlocked
    def register_many(self, event_type: str, callbacks: list[Callable], sync=False) -> list[int]:
        """Register multiple callbacks to an event type

        Args:
            event_type (str): Event type
            callbacks (list[Callable]): List of callback functions. A single
                callback that is not wrapped in a list is accepted as well.
            sync (bool, optional): Synchronous or async callback. Defaults to False.

        Returns:
            list: List of callback ids; -1 for every falsy entry (e.g. None),
            which is skipped rather than registered.
        """
        if not isinstance(callbacks, list):
            callbacks = [callbacks]
        ids = []
        for cbk in callbacks:
            if cbk:
                ids.append(self.register(event_type, cbk, sync))
            else:
                ids.append(-1)
        return ids

    @threadlocked
    def remove(self, id: int) -> int:
        """Remove a registered callback by its id

        Args:
            id (int): Callback id

        Returns:
            int: Returns the id of the removed callback. -1 if it failed.
        """
        try:
            self.callbacks.pop(id)
            return id
        except KeyError:
            return -1

    def new_id(self):
        """Generate a new callback id.

        Not locked itself; within this class it is only called from the locked
        ``register`` method.
        """
        self.id_counter += 1
        return self.id_counter

    @threadlocked
    def run(self, event_type: str, *args, **kwargs):
        """Run all callbacks for a given event type.

        Async callbacks are invoked immediately; sync callbacks only queue the
        event until ``poll`` is called.
        """
        # Iterate over a snapshot: a callback run from here may call register()
        # or remove() on this handler (the lock is reentrant), which would make
        # iterating the live dict raise "dictionary changed size during iteration".
        for cb_id, cb in list(self.callbacks.items()):
            if event_type != cb.event_type:
                continue
            if cb_id not in self.callbacks:
                # Removed by an earlier callback during this dispatch.
                continue
            cb.run(*args, **kwargs)

    @threadlocked
    def poll(self):
        """Run all pending callbacks (sync callbacks only; async ones are skipped)."""
        # Snapshot for the same reason as in run(): callbacks executed here may
        # register or remove callbacks on this handler.
        for cb_id, callback in list(self.callbacks.items()):
            if not callback.sync:
                continue
            if cb_id not in self.callbacks:
                # Removed by an earlier callback during this poll.
                continue
            while callback.num_pending_events:
                callback.poll()
class CallbackRegister:
    """Context manager that registers callbacks on entry and removes them on exit.

    Example::

        with CallbackRegister("scan_segment", my_callback):
            ...  # my_callback is registered only inside this block
    """

    def __init__(self, event_type, callbacks, sync=False, callback_handler=None) -> None:
        """Callback register class to register callbacks in a with statement

        Args:
            event_type: Event type the callbacks are registered for.
            callbacks: A single callback or a list of callbacks; falsy entries
                (e.g. None) are skipped.
            sync (bool, optional): Synchronous or async callbacks. Defaults to False.
            callback_handler (CallbackHandler): Callback handler. If not given,
                ``bec.callbacks`` of the global ``bec`` object in builtins is used.

        Raises:
            AttributeError: If no callback_handler is given and no global ``bec``
                object is available.
        """
        if not callback_handler:
            bec = builtins.__dict__.get("bec")
            if bec is None:
                raise AttributeError(
                    "No callback_handler was given and no global 'bec' object is available in builtins."
                )
            callback_handler = bec.callbacks
        self.callback_handler = callback_handler
        self.event_type = event_type
        if not isinstance(callbacks, list):
            callbacks = [callbacks]
        self.callbacks = callbacks
        self.sync = sync
        self.callback_ids = []

    def __enter__(self):
        try:
            for callback in self.callbacks:
                if not callback:
                    continue
                self.callback_ids.append(
                    self.callback_handler.register(self.event_type, callback, sync=self.sync)
                )
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo the
            # registrations that already succeeded before propagating.
            self._remove_registered()
            raise
        return self

    def __exit__(self, *exc):
        self._remove_registered()

    def _remove_registered(self) -> None:
        """Unregister every callback registered by this context and forget its id."""
        for cb_id in self.callback_ids:
            self.callback_handler.remove(cb_id)
        self.callback_ids.clear()