Source code for bec_widgets.widgets.plots.heatmap.heatmap

from __future__ import annotations

import json
from dataclasses import dataclass
from typing import Literal

import numpy as np
import pyqtgraph as pg
from bec_lib import bec_logger, messages
from bec_lib.endpoints import MessageEndpoints
from pydantic import BaseModel, Field, field_validator
from qtpy.QtCore import QObject, Qt, QThread, QTimer, Signal
from qtpy.QtGui import QTransform
from scipy.interpolate import (
    CloughTocher2DInterpolator,
    LinearNDInterpolator,
    NearestNDInterpolator,
)
from scipy.spatial import cKDTree
from toolz import partition

from bec_widgets.utils import Colors
from bec_widgets.utils.bec_connector import ConnectionConfig
from bec_widgets.utils.error_popups import SafeProperty, SafeSlot
from bec_widgets.utils.settings_dialog import SettingsDialog
from bec_widgets.utils.toolbars.actions import MaterialIconAction
from bec_widgets.widgets.plots.heatmap.settings.heatmap_setting import HeatmapSettings
from bec_widgets.widgets.plots.image.image_base import ImageBase
from bec_widgets.widgets.plots.image.image_item import ImageItem
from bec_widgets.widgets.plots.plot_base import PlotBase

logger = bec_logger.logger


class HeatmapDeviceSignal(BaseModel):
    """The configuration of a signal in the scatter waveform widget."""

    device: str
    signal: str

    model_config: dict = {"validate_assignment": True}


class HeatmapConfig(ConnectionConfig):
    parent_id: str | None = Field(None, description="The parent plot of the curve.")
    color_map: str | None = Field(
        "plasma", description="The color palette of the heatmap widget.", validate_default=True
    )
    color_bar: Literal["full", "simple"] | None = Field(
        None, description="The type of the color bar."
    )
    interpolation: Literal["linear", "nearest", "clough"] = Field(
        "linear", description="The interpolation method for the heatmap."
    )
    oversampling_factor: float = Field(
        1.0,
        description="Factor to oversample the grid resolution (1.0 = no oversampling, 2.0 = 2x resolution).",
    )
    show_config_label: bool = Field(
        True, description="Whether to show the configuration label in the heatmap."
    )
    enforce_interpolation: bool = Field(
        False, description="Whether to use the interpolation mode even for grid scans."
    )
    lock_aspect_ratio: bool = Field(
        False, description="Whether to lock the aspect ratio of the image."
    )
    device_x: HeatmapDeviceSignal | None = Field(
        None, description="The x device signal of the heatmap."
    )
    device_y: HeatmapDeviceSignal | None = Field(
        None, description="The y device signal of the heatmap."
    )
    device_z: HeatmapDeviceSignal | None = Field(
        None, description="The z device signal of the heatmap."
    )

    model_config: dict = {"validate_assignment": True}
    _validate_color_palette = field_validator("color_map")(Colors.validate_color_map)


@dataclass
class _InterpolationRequest:
    """Immutable payload describing an interpolation request for the worker thread.

    Args:
        x_data: X coordinates collected so far.
        y_data: Y coordinates collected so far.
        z_data: Z values associated with x/y.
        data_version: Number of points at request time (len(z_data)); used to reject stale results.
        scan_id: Identifier for the scan that produced the data.
        interpolation: Interpolation method to apply.
        oversampling_factor: Oversampling factor for the interpolation grid.
    """

    x_data: list[float]
    y_data: list[float]
    z_data: list[float]
    data_version: int
    scan_id: str
    interpolation: str
    oversampling_factor: float


class _StepInterpolationWorker(QObject):
    """Worker for performing step-scan interpolation in a background thread.

    This worker computes the interpolated heatmap image using the provided data
    and settings, then emits the result or a failure signal.

    Signals:
        finished(image, transform, data_version, scan_id):
            Emitted when interpolation is successful.
            - image: The resulting image (numpy array or similar).
            - transform: The QTransform for the image.
            - data_version: The data version for the request.
            - scan_id: The scan identifier.
        failed(error_message, data_version, scan_id):
            Emitted when interpolation fails.
            - error_message: The error message string.
            - data_version: The data version for the request.
            - scan_id: The scan identifier.
    """

    finished = Signal(object, object, int, str)
    failed = Signal(str, int, str)

    def __init__(self, parent: QObject | None = None):
        super().__init__(parent=parent)
        self._active_request: _InterpolationRequest | None = None
        self._processing = False

    @property
    def is_processing(self) -> bool:
        """Return whether the worker is currently processing a request."""
        return self._processing

    @SafeSlot(object, int)
    def process(self, request: _InterpolationRequest, data_version: int):
        """
        Process an interpolation request in the worker thread.

        Args:
            request(_InterpolationRequest): The interpolation request payload.
            data_version(int): The data version for the request.
        """
        self._active_request = request
        self._processing = True
        try:
            image, transform = Heatmap.compute_step_scan_image(
                x_data=np.asarray(request.x_data, dtype=float),
                y_data=np.asarray(request.y_data, dtype=float),
                z_data=np.asarray(request.z_data, dtype=float),
                oversampling_factor=request.oversampling_factor,
                interpolation_method=request.interpolation,
            )
        except Exception as exc:  # pragma: no cover - defensive
            logger.warning(f"Step-scan interpolation failed with: {exc}")
            self.failed.emit(str(exc), data_version, request.scan_id)
            self._processing = False
            return
        self._processing = False
        self.finished.emit(image, transform, data_version, request.scan_id)


[docs] class Heatmap(ImageBase): """ Heatmap widget for visualizing 2d grid data with color mapping for the z-axis. """ USER_ACCESS = [ *PlotBase.USER_ACCESS, # ImageView Specific Settings "color_map", "color_map.setter", "v_range", "v_range.setter", "v_min", "v_min.setter", "v_max", "v_max.setter", "autorange", "autorange.setter", "autorange_mode", "autorange_mode.setter", "enable_colorbar", "enable_simple_colorbar", "enable_simple_colorbar.setter", "enable_full_colorbar", "enable_full_colorbar.setter", "interpolation_method", "interpolation_method.setter", "oversampling_factor", "oversampling_factor.setter", "enforce_interpolation", "enforce_interpolation.setter", "fft", "fft.setter", "log", "log.setter", "main_image", "add_roi", "remove_roi", "rois", "plot", # Device properties "device_x", "device_x.setter", "signal_x", "signal_x.setter", "device_y", "device_y.setter", "signal_y", "signal_y.setter", "device_z", "device_z.setter", "signal_z", "signal_z.setter", ] PLUGIN = True RPC = True ICON_NAME = "dataset" new_scan = Signal() new_scan_id = Signal(str) sync_signal_update = Signal() heatmap_property_changed = Signal() interpolation_requested = Signal(object, int) def __init__(self, parent=None, config: HeatmapConfig | None = None, **kwargs): if config is None: config = HeatmapConfig( widget_class=self.__class__.__name__, parent_id=None, color_map="plasma", color_bar=None, interpolation="linear", oversampling_factor=1.0, lock_aspect_ratio=False, device_x=None, device_y=None, device_z=None, ) super().__init__(parent=parent, config=config, theme_update=True, **kwargs) self._image_config = config self.scan_id = None self.old_scan_id = None self.scan_item = None self.status_message = None self._grid_index = None # Highest data_version we have dispatched for the current scan; used to drop stale results. # Initialized to -1 so the first real request (len(z_data) >= 0) always supersedes it. self._latest_interpolation_version = -1 self._interpolation_thread: QThread | None = None self._interpolation_worker: _StepInterpolationWorker | None = None self._pending_interpolation_request: _InterpolationRequest | None = None self.heatmap_dialog = None bg_color = pg.mkColor((240, 240, 240, 150)) self.config_label = pg.LegendItem( labelTextColor=(0, 0, 0), offset=(-30, 1), brush=pg.mkBrush(bg_color), horSpacing=0 ) self.config_label.setParentItem(self.plot_item.vb) self.config_label.setVisible(False) self.reload = False self.bec_dispatcher.connect_slot(self.on_scan_status, MessageEndpoints.scan_status()) self.bec_dispatcher.connect_slot(self.on_scan_progress, MessageEndpoints.scan_progress()) self.heatmap_property_changed.connect(lambda: self.sync_signal_update.emit()) self.proxy_update_sync = pg.SignalProxy( self.sync_signal_update, rateLimit=5, slot=self.update_plot ) self._init_toolbar_heatmap() self.toolbar.show_bundles( [ "heatmap_settings", "plot_export", "image_crosshair", "mouse_interaction", "image_autorange", "image_colorbar", "image_processing", "axis_popup", "interpolation_info", ] ) @property def main_image(self) -> ImageItem: """Access the main image item.""" return self.layer_manager["main"].image ################################################################################ # Widget Specific GUI interactions ################################################################################
[docs] @SafeSlot(str) def apply_theme(self, theme: str): """ Apply the current theme to the heatmap widget. """ super().apply_theme(theme) if theme == "dark": brush = pg.mkBrush(pg.mkColor(50, 50, 50, 150)) color = pg.mkColor(255, 255, 255) else: brush = pg.mkBrush(pg.mkColor(240, 240, 240, 150)) color = pg.mkColor(0, 0, 0) if hasattr(self, "config_label"): self.config_label.setBrush(brush) self.config_label.setLabelTextColor(color) self.redraw_config_label()
[docs] @SafeSlot(popup_error=True) def plot( self, device_x: str, device_y: str, device_z: str, signal_x: None | str = None, signal_y: None | str = None, signal_z: None | str = None, color_map: str | None = "plasma", validate_bec: bool = True, interpolation: Literal["linear", "nearest"] | None = None, enforce_interpolation: bool | None = None, oversampling_factor: float | None = None, lock_aspect_ratio: bool | None = None, show_config_label: bool | None = None, reload: bool = False, ): """ Plot the heatmap with the given x, y, and z data. Args: device_x (str): The name of the x-axis device signal. device_y (str): The name of the y-axis device signal. device_z (str): The name of the z-axis device signal. signal_x (str | None): The entry for the x-axis device signal. signal_y (str | None): The entry for the y-axis device signal. signal_z (str | None): The entry for the z-axis device signal. color_map (str | None): The color map to use for the heatmap. validate_bec (bool): Whether to validate the entries against BEC signals. interpolation (Literal["linear", "nearest"] | None): The interpolation method to use. enforce_interpolation (bool | None): Whether to enforce interpolation even for grid scans. oversampling_factor (float | None): Factor to oversample the grid resolution. lock_aspect_ratio (bool | None): Whether to lock the aspect ratio of the image. show_config_label (bool | None): Whether to show the configuration label in the heatmap. reload (bool): Whether to reload the heatmap with new data. """ if validate_bec: signal_x = self.entry_validator.validate_signal(device_x, signal_x) signal_y = self.entry_validator.validate_signal(device_y, signal_y) signal_z = self.entry_validator.validate_signal(device_z, signal_z) if signal_x is None or signal_y is None or signal_z is None: raise ValueError("x, y, and z entries must be provided.") if device_x is None or device_y is None or device_z is None: raise ValueError("x, y, and z names must be provided.") if interpolation is None: interpolation = self._image_config.interpolation if oversampling_factor is None: oversampling_factor = self._image_config.oversampling_factor if enforce_interpolation is None: enforce_interpolation = self._image_config.enforce_interpolation if lock_aspect_ratio is None: lock_aspect_ratio = self._image_config.lock_aspect_ratio if show_config_label is None: show_config_label = self._image_config.show_config_label def _device_key(device: HeatmapDeviceSignal | None) -> tuple[str | None, str | None]: return (device.device if device else None, device.signal if device else None) prev_cfg = getattr(self, "_image_config", None) config_changed = False if prev_cfg and prev_cfg.device_x and prev_cfg.device_y and prev_cfg.device_z: config_changed = any( ( _device_key(prev_cfg.device_x) != (device_x, signal_x), _device_key(prev_cfg.device_y) != (device_y, signal_y), _device_key(prev_cfg.device_z) != (device_z, signal_z), ) ) self._image_config = HeatmapConfig( parent_id=self.gui_id, device_x=HeatmapDeviceSignal(device=device_x, signal=signal_x), device_y=HeatmapDeviceSignal(device=device_y, signal=signal_y), device_z=HeatmapDeviceSignal(device=device_z, signal=signal_z), color_map=color_map, color_bar=None, interpolation=interpolation, oversampling_factor=oversampling_factor, enforce_interpolation=enforce_interpolation, lock_aspect_ratio=lock_aspect_ratio, show_config_label=show_config_label, ) self.color_map = color_map self.reload = reload or config_changed if config_changed: self._grid_index = None self.main_image.clear() self.update_labels() self._fetch_running_scan() self.sync_signal_update.emit()
def _fetch_running_scan(self): scan = self.client.queue.scan_storage.current_scan if scan is not None: self.scan_item = scan self.scan_id = scan.scan_id elif self.client.history and len(self.client.history) > 0: self.scan_item = self.client.history[-1] self.scan_id = self.client.history._scan_ids[-1] self.old_scan_id = None
[docs] def update_labels(self): """ Update the labels of the x, y, and z axes. """ if self._image_config is None: return # Safely get device names (might be None if not yet configured) device_x = self._image_config.device_x device_y = self._image_config.device_y device_z = self._image_config.device_z device_x_name = device_x.device if device_x else None device_y_name = device_y.device if device_y else None device_z_name = device_z.device if device_z else None if device_x_name is not None: self.x_label = device_x_name # type: ignore x_dev = self.dev.get(device_x_name) if x_dev and hasattr(x_dev, "egu"): self.x_label_units = x_dev.egu() if device_y_name is not None: self.y_label = device_y_name # type: ignore y_dev = self.dev.get(device_y_name) if y_dev and hasattr(y_dev, "egu"): self.y_label_units = y_dev.egu() if device_z_name is not None: self.title = device_z_name
def _init_toolbar_heatmap(self): """ Initialize the toolbar for the heatmap widget, adding actions for heatmap settings. """ self.toolbar.add_action( "heatmap_settings", MaterialIconAction( icon_name="scatter_plot", tooltip="Show Heatmap Settings", checkable=True, parent=self, ), ) self.toolbar.components.get_action("heatmap_settings").action.triggered.connect( self.show_heatmap_settings ) # disable all processing actions except for the fft and log bundle = self.toolbar.get_bundle("image_processing") for name, action in bundle.bundle_actions.items(): if name not in ["image_processing_fft", "image_processing_log"]: action().action.setVisible(False) self.toolbar.add_action( "interpolation_info", MaterialIconAction( icon_name="info", tooltip="Show Interpolation Info", checkable=True, parent=self ), ) self.toolbar.components.get_action("interpolation_info").action.triggered.connect( self.toggle_interpolation_info ) self.toolbar.components.get_action("interpolation_info").action.setChecked( self._image_config.show_config_label )
[docs] def show_heatmap_settings(self): """ Show the heatmap settings dialog. """ heatmap_settings_action = self.toolbar.components.get_action("heatmap_settings").action if self.heatmap_dialog is None or not self.heatmap_dialog.isVisible(): heatmap_settings = HeatmapSettings(parent=self, target_widget=self, popup=True) self.heatmap_dialog = SettingsDialog( self, settings_widget=heatmap_settings, window_title="Heatmap Settings", modal=False ) self.heatmap_dialog.resize(700, 350) # When the dialog is closed, update the toolbar icon and clear the reference self.heatmap_dialog.finished.connect(self._heatmap_dialog_closed) self.heatmap_dialog.show() heatmap_settings_action.setChecked(True) else: # If already open, bring it to the front self.heatmap_dialog.raise_() self.heatmap_dialog.activateWindow() heatmap_settings_action.setChecked(True) # keep it toggled
[docs] def toggle_interpolation_info(self): """ Toggle the visibility of the interpolation info label. """ self._image_config.show_config_label = not self._image_config.show_config_label self.toolbar.components.get_action("interpolation_info").action.setChecked( self._image_config.show_config_label ) self.redraw_config_label()
def _heatmap_dialog_closed(self): """ Slot for when the heatmap settings dialog is closed. """ self.heatmap_dialog = None self.toolbar.components.get_action("heatmap_settings").action.setChecked(False)
[docs] @SafeSlot(dict, dict) def on_scan_status(self, msg: dict, meta: dict): """ Initial scan status message handler, which is triggered at the begging and end of scan. Args: msg(dict): The message content. meta(dict): The message metadata. """ current_scan_id = msg.get("scan_id", None) if current_scan_id is None: return if current_scan_id != self.scan_id: self._invalidate_interpolation_generation() # Invalidate any pending interpolation work when a new scan starts self.reset() self.new_scan.emit() self.new_scan_id.emit(current_scan_id) self.old_scan_id = self.scan_id self.scan_id = current_scan_id self.scan_item = self.queue.scan_storage.find_scan_by_ID(self.scan_id) # type: ignore # First trigger to update the scan curves self.sync_signal_update.emit()
@SafeSlot(dict, dict) def on_scan_progress(self, msg: dict, meta: dict): self.sync_signal_update.emit() status = msg.get("done") if status: QTimer.singleShot(100, self.update_plot) QTimer.singleShot(300, self.update_plot)
[docs] @SafeSlot(verify_sender=True) def update_plot(self, _=None) -> None: """ Update the plot with the current data. """ if self.scan_item is None: logger.info("No scan executed so far; skipping update.") return data, access_key = self._fetch_scan_data_and_access() if data == "none": logger.info("No scan executed so far; skipping update.") return if self._image_config is None: return try: device_x = self._image_config.device_x.device signal_x = self._image_config.device_x.signal device_y = self._image_config.device_y.device signal_y = self._image_config.device_y.signal device_z = self._image_config.device_z.device signal_z = self._image_config.device_z.signal except AttributeError: return if access_key == "val": x_data = data.get(device_x, {}).get(signal_x, {}).get(access_key, None) y_data = data.get(device_y, {}).get(signal_y, {}).get(access_key, None) z_data = data.get(device_z, {}).get(signal_z, {}).get(access_key, None) else: x_data = data.get(device_x, {}).get(signal_x, {}).read().get("value", None) y_data = data.get(device_y, {}).get(signal_y, {}).read().get("value", None) z_data = data.get(device_z, {}).get(signal_z, {}).read().get("value", None) if not isinstance(x_data, list): x_data = x_data.tolist() if isinstance(x_data, np.ndarray) else None if not isinstance(y_data, list): y_data = y_data.tolist() if isinstance(y_data, np.ndarray) else None if not isinstance(z_data, list): z_data = z_data.tolist() if isinstance(z_data, np.ndarray) else None if x_data is None or y_data is None or z_data is None: logger.warning("x, y, or z data is None; skipping update.") return if len(x_data) != len(y_data) or len(x_data) != len(z_data): logger.warning( "x, y, and z data lengths do not match; skipping update. " f"Lengths: x={len(x_data)}, y={len(y_data)}, z={len(z_data)}" ) return if hasattr(self.scan_item, "status_message"): scan_msg = self.scan_item.status_message elif hasattr(self.scan_item, "metadata"): metadata = self.scan_item.metadata["bec"] status = metadata["exit_status"] scan_id = metadata["scan_id"] scan_name = metadata["scan_name"] scan_type = metadata["scan_type"] scan_number = metadata["scan_number"] request_inputs = metadata["request_inputs"] if "arg_bundle" in request_inputs and isinstance(request_inputs["arg_bundle"], str): # Convert the arg_bundle from a JSON string to a dictionary request_inputs["arg_bundle"] = json.loads(request_inputs["arg_bundle"]) positions = metadata.get("positions", []) positions = positions.tolist() if isinstance(positions, np.ndarray) else positions scan_msg = messages.ScanStatusMessage( status=status, scan_id=scan_id, scan_name=scan_name, scan_number=scan_number, scan_type=scan_type, request_inputs=request_inputs, info={"positions": positions}, ) else: scan_msg = None if scan_msg is None: logger.warning("Scan message is None; skipping update.") return self.status_message = scan_msg if self._image_config.show_config_label: self.redraw_config_label() if self._is_grid_scan_supported(scan_msg): img, transform = self.get_grid_scan_image(z_data, scan_msg) self._apply_image_update(img, transform) return if len(z_data) < 4: # LinearNDInterpolator requires at least 4 points to interpolate logger.warning("Not enough data points to interpolate; skipping update.") return self._request_step_scan_interpolation(x_data, y_data, z_data, scan_msg)
def _apply_image_update(self, img: np.ndarray | None, transform: QTransform | None): """Apply interpolated image and transform to the heatmap display. This method updates the main image with the computed data and emits the image_updated signal. Color bar signals are temporarily blocked during the update to prevent cascading updates. Args: img(np.ndarray): The interpolated image data, or None if unavailable transform(QTransform): QTransform mapping pixel to world coordinates, or None if unavailable """ if img is None or transform is None: logger.warning("Image data is None; skipping update.") return if self._color_bar is not None: self._color_bar.blockSignals(True) if self.main_image is None: logger.warning("Main image item is None; cannot update image.") return self.main_image.set_data(img, transform=transform) if self._color_bar is not None: self._color_bar.blockSignals(False) self.image_updated.emit() if self.crosshair is not None: self.crosshair.update_markers_on_image_change() def _request_step_scan_interpolation( self, x_data: list[float], y_data: list[float], z_data: list[float], msg: messages.ScanStatusMessage, ): """Request step-scan interpolation in a background thread. If a thread is already running, the request is queued as a pending request and will be processed when the current interpolation completes. Args: x_data(list[float]): X coordinates of data points y_data(list[float]): Y coordinates of data points z_data(list[float]): Z values at each point msg(messages.ScanStatusMessage): Scan status message containing scan metadata """ request = _InterpolationRequest( x_data=list(x_data), y_data=list(y_data), z_data=list(z_data), data_version=len(z_data), scan_id=msg.scan_id, interpolation=self._image_config.interpolation, oversampling_factor=self._image_config.oversampling_factor, ) if self._interpolation_worker is not None and self._interpolation_worker.is_processing: self._pending_interpolation_request = request return self._start_step_scan_interpolation(request) def _ensure_interpolation_thread(self): if self._interpolation_thread is None: self._interpolation_thread = QThread() self._interpolation_worker = _StepInterpolationWorker() self._interpolation_worker.moveToThread(self._interpolation_thread) self.interpolation_requested.connect( self._interpolation_worker.process, Qt.ConnectionType.QueuedConnection ) self._interpolation_worker.finished.connect( self._on_interpolation_finished, Qt.ConnectionType.QueuedConnection ) self._interpolation_worker.failed.connect( self._on_interpolation_failed, Qt.ConnectionType.QueuedConnection ) if self._interpolation_thread is not None and not self._interpolation_thread.isRunning(): self._interpolation_thread.start() def _start_step_scan_interpolation(self, request: _InterpolationRequest): # data_version = len(z_data) at the time of the request; keep the latest to gate results. self._ensure_interpolation_thread() if self._interpolation_thread is not None and not self._interpolation_thread.isRunning(): self._interpolation_thread.start() self._latest_interpolation_version = request.data_version self.interpolation_requested.emit(request, request.data_version) def _on_interpolation_finished( self, img: np.ndarray, transform: QTransform, data_version: int, scan_id: str ): # Only accept results that match the latest dispatched version for the active scan. if data_version == self._latest_interpolation_version and scan_id == self.scan_id: self._apply_image_update(img, transform) else: logger.info("Discarding outdated interpolation result.") self._maybe_start_pending_interpolation() def _on_interpolation_failed(self, error: str, data_version: int, scan_id: str): logger.warning(f"Interpolation failed for scan {scan_id} (version {data_version}): {error}") self._maybe_start_pending_interpolation() def _finish_interpolation_thread(self): self._pending_interpolation_request = None if self._interpolation_worker is not None: try: self.interpolation_requested.disconnect(self._interpolation_worker.process) except (TypeError, RuntimeError) as ext: logger.warning(f"Processing thread already disconnected: {ext}") pass self._interpolation_worker.deleteLater() self._interpolation_worker = None if self._interpolation_thread is not None: if self._interpolation_thread.isRunning(): self._interpolation_thread.quit() if not self._interpolation_thread.wait(3000): # 3s timeout logger.error( f"Interpolation thread of widget {self.gui_id} did not stop within timeout 3s; leaving it dangling." ) self._interpolation_thread.deleteLater() self._interpolation_thread = None logger.info(f"Interpolation thread finished of widget {self.gui_id}") def _maybe_start_pending_interpolation(self): if self._pending_interpolation_request is None: return if self._pending_interpolation_request.scan_id != self.scan_id: self._pending_interpolation_request = None return if self._interpolation_worker is not None and self._interpolation_worker.is_processing: return pending = self._pending_interpolation_request self._pending_interpolation_request = None self._start_step_scan_interpolation(pending) def _cancel_interpolation(self): """Cancel any pending interpolation request without invalidating in-flight work. This clears the pending request queue but does not invalidate in-flight work, allowing any currently running interpolation to complete and update the display if it matches the current scan. """ self._pending_interpolation_request = None # Do not change the active data version so an in-flight worker can still deliver. def _invalidate_interpolation_generation(self): """Invalidate all pending interpolation results and ignore in-flight updates.""" self._pending_interpolation_request = None self._latest_interpolation_version = -1 def redraw_config_label(self): scan_msg = self.status_message if scan_msg is None: return if not self._image_config.show_config_label: self.config_label.setVisible(False) return self.config_label.setOffset((-30, 1)) self.config_label.setVisible(True) self.config_label.clear() self.config_label.addItem(self.plot_item, f"Scan: {scan_msg.scan_number}") self.config_label.addItem(self.plot_item, f"Scan Name: {scan_msg.scan_name}") if scan_msg.scan_name != "grid_scan" or self._image_config.enforce_interpolation: self.config_label.addItem( self.plot_item, f"Interpolation: {self._image_config.interpolation}" ) self.config_label.addItem( self.plot_item, f"Oversampling: {self._image_config.oversampling_factor}x" )
[docs] def get_image_data( self, x_data: list[float] | None = None, y_data: list[float] | None = None, z_data: list[float] | None = None, ) -> tuple[np.ndarray | None, QTransform | None]: """ Get the image data for the heatmap. Depending on the scan type, it will either pre-allocate the grid (grid_scan) or interpolate the data (step scan). Args: x_data (np.ndarray): The x data. y_data (np.ndarray): The y data. z_data (np.ndarray): The z data. Returns: tuple[np.ndarray, QTransform]: The image data and the QTransform. """ msg = self.status_message if x_data is None or y_data is None or z_data is None or msg is None: logger.warning("x, y, or z data is None; skipping update.") return None, None if self._is_grid_scan_supported(msg): return self.get_grid_scan_image(z_data, msg) if len(z_data) < 4: # LinearNDInterpolator requires at least 4 points to interpolate return None, None return self.get_step_scan_image(x_data, y_data, z_data)
def _is_grid_scan_supported(self, msg: messages.ScanStatusMessage) -> bool: """Check if the scan can use optimized grid_scan rendering. Grid scans can avoid interpolation if both X and Y devices match the configured devices and interpolation is not enforced. Args: msg(messages.ScanStatusMessage): Scan status message containing scan metadata Returns: True if grid_scan optimization is applicable, False otherwise """ if msg.scan_name != "grid_scan" or self._image_config.enforce_interpolation: return False signal_x = self._image_config.device_x.signal signal_y = self._image_config.device_y.signal return ( signal_x in msg.request_inputs["arg_bundle"] and signal_y in msg.request_inputs["arg_bundle"] )
[docs] def get_grid_scan_image( self, z_data: list[float], msg: messages.ScanStatusMessage ) -> tuple[np.ndarray, QTransform]: """ Get the image data for a grid scan. Args: z_data (np.ndarray): The z data. msg (messages.ScanStatusMessage): The scan status message. Returns: tuple[np.ndarray, QTransform]: The image data and the QTransform. """ args = self.arg_bundle_to_dict(4, msg.request_inputs["arg_bundle"]) signal_x = self._image_config.device_x.signal signal_y = self._image_config.device_y.signal shape = (args[signal_x][-1], args[signal_y][-1]) data = self.main_image.raw_data if data is None or data.shape != shape: data = np.empty(shape) data.fill(np.nan) elif self.reload: data.fill(np.nan) snaked = msg.request_inputs["kwargs"].get("snaked", True) slow_entry, fast_entry = ( msg.request_inputs["arg_bundle"][0], msg.request_inputs["arg_bundle"][4], ) scan_pos = np.asarray(msg.info["positions"], dtype=float) relative = bool(msg.request_inputs["kwargs"].get("relative", False)) def _axis_column(entry: str) -> int: return 0 if entry == slow_entry else 1 def _axis_levels(entry: str, npts: int) -> np.ndarray: start, stop = args[entry][:2] if relative: origin = float(scan_pos[0, _axis_column(entry)] - start) return origin + np.linspace(start, stop, npts) return np.linspace(start, stop, npts) x_levels = _axis_levels(signal_x, shape[0]) y_levels = _axis_levels(signal_y, shape[1]) pixel_size_x = ( float(x_levels[-1] - x_levels[0]) / max(shape[0] - 1, 1) if shape[0] > 1 else 1.0 ) pixel_size_y = ( float(y_levels[-1] - y_levels[0]) / max(shape[1] - 1, 1) if shape[1] > 1 else 1.0 ) transform = QTransform() transform.scale(pixel_size_x, pixel_size_y) transform.translate(x_levels[0] / pixel_size_x - 0.5, y_levels[0] / pixel_size_y - 0.5) # Fill the data array with the z values if self._grid_index is None or self.reload: self._grid_index = 0 self.reload = False for i in range(self._grid_index, len(z_data)): slow_i, fast_i = divmod(i, args[fast_entry][-1]) if snaked and (slow_i % 2 == 1): fast_i = args[fast_entry][-1] - 1 - fast_i if signal_x == fast_entry: x_i, y_i = fast_i, slow_i else: x_i, y_i = slow_i, fast_i data[x_i, y_i] = z_data[i] self._grid_index = len(z_data) return data, transform
[docs] def get_step_scan_image( self, x_data: list[float], y_data: list[float], z_data: list[float] ) -> tuple[np.ndarray, QTransform]: """ Get the image data for an arbitrary step scan. Args: x_data (list[float]): The x data. y_data (list[float]): The y data. z_data (list[float]): The z data. Returns: tuple[np.ndarray, QTransform]: The image data and the QTransform. """ return self.compute_step_scan_image( x_data=x_data, y_data=y_data, z_data=z_data, oversampling_factor=self._image_config.oversampling_factor, interpolation_method=self._image_config.interpolation, )
[docs] @staticmethod def compute_step_scan_image( x_data: list[float] | np.ndarray, y_data: list[float] | np.ndarray, z_data: list[float] | np.ndarray, oversampling_factor: float, interpolation_method: str, ) -> tuple[np.ndarray, QTransform]: """Compute interpolated heatmap image from step-scan data. This static method is suitable for execution in a background thread as it doesn't access any instance state. Args: x_data(list[float]): X coordinates of data points y_data(list[float]): Y coordinates of data points z_data(list[float]): Z values at each point oversampling_factor(float): Grid resolution multiplier (>1.0 for higher resolution) interpolation_method(str): One of 'linear', 'nearest', or 'clough' Returns: (tuple[np.ndarray, QTransform]):Tuple of (interpolated_grid, transform) where transform maps pixel to world coordinates """ xy_data = np.column_stack((x_data, y_data)) grid_x, grid_y, transform = Heatmap.build_image_grid( positions=xy_data, oversampling_factor=oversampling_factor ) if interpolation_method == "linear": interp = LinearNDInterpolator(xy_data, z_data) elif interpolation_method == "nearest": interp = NearestNDInterpolator(xy_data, z_data) elif interpolation_method == "clough": interp = CloughTocher2DInterpolator(xy_data, z_data) else: # pragma: no cover - guarded by validation raise ValueError( "Interpolation method must be either 'linear', 'nearest', or 'clough'." ) grid_z = interp(grid_x, grid_y) return grid_z, transform
[docs] def get_image_grid(self, positions) -> tuple[np.ndarray, np.ndarray, QTransform]: """ LRU-cached calculation of the grid for the image. The lru cache is indexed by the scan_id to avoid recalculating the grid for the same scan. Args: positions: positions of the data points. Returns: tuple[np.ndarray, np.ndarray, QTransform]: The grid x and y coordinates and the QTransform. """ return self.build_image_grid( positions=positions, oversampling_factor=self._image_config.oversampling_factor )
[docs] @staticmethod def build_image_grid( positions: np.ndarray, oversampling_factor: float ) -> tuple[np.ndarray, np.ndarray, QTransform]: """Build an interpolation grid covering the data positions. Args: positions: (N, 2) array of (x, y) coordinates oversampling_factor: Grid resolution multiplier (>1.0 for higher resolution) Returns: Tuple of (grid_x, grid_y, transform) where grid_x/grid_y are meshgrids for interpolation and transform maps pixel to world coordinates """ base_width, base_height = Heatmap.estimate_image_resolution(positions) width = max(1, int(base_width * oversampling_factor)) height = max(1, int(base_height * oversampling_factor)) grid_x, grid_y = np.mgrid[ min(positions[:, 0]) : max(positions[:, 0]) : width * 1j, min(positions[:, 1]) : max(positions[:, 1]) : height * 1j, ] x_min, x_max = min(positions[:, 0]), max(positions[:, 0]) y_min, y_max = min(positions[:, 1]), max(positions[:, 1]) x_range = x_max - x_min y_range = y_max - y_min x_scale = x_range / width y_scale = y_range / height transform = QTransform() transform.scale(x_scale, y_scale) transform.translate(x_min / x_scale - 0.5, y_min / y_scale - 0.5) return grid_x, grid_y, transform
[docs] @staticmethod def estimate_image_resolution(coords: np.ndarray) -> tuple[int, int]: """ Estimate the number of pixels needed for the image based on the coordinates. Args: coords (np.ndarray): The coordinates of the points. Returns: tuple[int, int]: The estimated width and height of the image.""" if coords.ndim != 2 or coords.shape[1] != 2: raise ValueError("Input must be an (m x 2) array of (x, y) coordinates.") x_min, x_max = coords[:, 0].min(), coords[:, 0].max() y_min, y_max = coords[:, 1].min(), coords[:, 1].max() tree = cKDTree(coords) distances, _ = tree.query(coords, k=2) distances = distances[:, 1] # Get the second nearest neighbor distance avg_distance = np.mean(distances) width_extent = x_max - x_min height_extent = y_max - y_min # Calculate the number of pixels needed based on the average distance width_pixels = int(np.ceil(width_extent / avg_distance)) height_pixels = int(np.ceil(height_extent / avg_distance)) return max(1, width_pixels), max(1, height_pixels)
[docs] @staticmethod def arg_bundle_to_dict(bundle_size: int, args: list) -> dict: """ Convert the argument bundle to a dictionary. Args: bundle_size (int): The size of each argument bundle. args (list): The argument bundle. Returns: dict: The dictionary representation of the argument bundle. """ params = {} for cmds in partition(bundle_size, args): params[cmds[0]] = list(cmds[1:]) return params
def _fetch_scan_data_and_access(self): """ Decide whether the widget is in live or historical mode and return the appropriate data dict and access key. Returns: data_dict (dict): The data structure for the current scan. access_key (str): Either 'val' (live) or 'value' (history). """ if self.scan_item is None: # Optionally fetch the latest from history if nothing is set # self.update_with_scan_history(-1) if self.scan_item is None: logger.info("No scan executed so far; skipping update.") return "none", "none" if hasattr(self.scan_item, "live_data"): # Live scan return self.scan_item.live_data, "val" # Historical scan_devices = self.scan_item.devices return scan_devices, "value"
[docs] def reset(self): self._cancel_interpolation() self._grid_index = None self.main_image.clear() if self.crosshair is not None: self.crosshair.reset() super().reset()
################################################################################ # Widget Specific Properties ################################################################################ @SafeProperty(str) def device_x(self) -> str: """Device name for the X axis.""" if self._image_config.device_x is None: return "" return self._image_config.device_x.device or "" @device_x.setter def device_x(self, device_name: str) -> None: """ Set the X device name. Args: device_name(str): Device name for the X axis """ device_name = device_name or "" # Get current entry or validate if device_name: try: signal = self.entry_validator.validate_signal(device_name, None) self._image_config.device_x = HeatmapDeviceSignal(device=device_name, signal=signal) self.property_changed.emit("device_x", device_name) self.update_labels() # Update axis labels self._try_auto_plot() except Exception: pass # Silently fail if device is not available yet else: self._image_config.device_x = None self.property_changed.emit("device_x", "") self.update_labels() # Clear axis labels @SafeProperty(str) def signal_x(self) -> str: """Signal entry for the X axis device.""" if self._image_config.device_x is None: return "" return self._image_config.device_x.signal or "" @signal_x.setter def signal_x(self, entry: str) -> None: """ Set the X device entry. Args: entry(str): Signal entry for the X axis device """ if not entry: return if self._image_config.device_x is None: logger.warning("Cannot set signal_x without device_x set first.") return device_name = self._image_config.device_x.device try: # Validate the entry for this device validated_signal = self.entry_validator.validate_signal(device_name, entry) self._image_config.device_x = HeatmapDeviceSignal( device=device_name, signal=validated_signal ) self.property_changed.emit("signal_x", validated_signal) self.update_labels() # Update axis labels self._try_auto_plot() except Exception: pass # Silently fail if validation fails @SafeProperty(str) def device_y(self) -> str: """Device name for the Y axis.""" if self._image_config.device_y is None: return "" return self._image_config.device_y.device or "" @device_y.setter def device_y(self, device_name: str) -> None: """ Set the Y device name. Args: device_name(str): Device name for the Y axis """ device_name = device_name or "" # Get current entry or validate if device_name: try: signal = self.entry_validator.validate_signal(device_name, None) self._image_config.device_y = HeatmapDeviceSignal(device=device_name, signal=signal) self.property_changed.emit("device_y", device_name) self.update_labels() # Update axis labels self._try_auto_plot() except Exception: pass # Silently fail if device is not available yet else: self._image_config.device_y = None self.property_changed.emit("device_y", "") self.update_labels() # Clear axis labels @SafeProperty(str) def signal_y(self) -> str: """Signal entry for the Y axis device.""" if self._image_config.device_y is None: return "" return self._image_config.device_y.signal or "" @signal_y.setter def signal_y(self, entry: str) -> None: """ Set the Y device entry. Args: entry(str): Signal entry for the Y axis device """ if not entry: return if self._image_config.device_y is None: logger.warning("Cannot set signal_y without device_y set first.") return device_name = self._image_config.device_y.device try: # Validate the entry for this device validated_signal = self.entry_validator.validate_signal(device_name, entry) self._image_config.device_y = HeatmapDeviceSignal( device=device_name, signal=validated_signal ) self.property_changed.emit("signal_y", validated_signal) self.update_labels() # Update axis labels self._try_auto_plot() except Exception as e: logger.debug(f"Y device entry validation failed: {e}") pass # Silently fail if validation fails @SafeProperty(str) def device_z(self) -> str: """Device name for the Z (color) axis.""" if self._image_config.device_z is None: return "" return self._image_config.device_z.device or "" @device_z.setter def device_z(self, device_name: str) -> None: """ Set the Z device name. Args: device_name(str): Device name for the Z axis """ device_name = device_name or "" # Get current entry or validate if device_name: try: signal = self.entry_validator.validate_signal(device_name, None) self._image_config.device_z = HeatmapDeviceSignal(device=device_name, signal=signal) self.property_changed.emit("device_z", device_name) self.update_labels() # Update axis labels (title) self._try_auto_plot() except Exception as e: logger.debug(f"Z device name validation failed: {e}") pass # Silently fail if device is not available yet else: self._image_config.device_z = None self.property_changed.emit("device_z", "") self.update_labels() # Clear axis labels @SafeProperty(str) def signal_z(self) -> str: """Signal entry for the Z (color) axis device.""" if self._image_config.device_z is None: return "" return self._image_config.device_z.signal or "" @signal_z.setter def signal_z(self, entry: str) -> None: """ Set the Z device entry. Args: entry(str): Signal entry for the Z axis device """ if not entry: return if self._image_config.device_z is None: logger.warning("Cannot set signal_z without device_z set first.") return device_name = self._image_config.device_z.device try: # Validate the entry for this device validated_signal = self.entry_validator.validate_signal(device_name, entry) self._image_config.device_z = HeatmapDeviceSignal( device=device_name, signal=validated_signal ) self.property_changed.emit("signal_z", validated_signal) self.update_labels() # Update axis labels (title) self._try_auto_plot() except Exception as e: logger.debug(f"Z device entry validation failed: {e}") pass # Silently fail if validation fails def _try_auto_plot(self) -> None: """ Attempt to automatically call plot() if all three devices are set. Similar to waveform's approach but requires all three devices. """ has_x = self._image_config.device_x is not None has_y = self._image_config.device_y is not None has_z = self._image_config.device_z is not None if has_x and has_y and has_z: device_x = self._image_config.device_x.device signal_x = self._image_config.device_x.signal device_y = self._image_config.device_y.device signal_y = self._image_config.device_y.signal device_z = self._image_config.device_z.device signal_z = self._image_config.device_z.signal try: self.plot( device_x=device_x, device_y=device_y, device_z=device_z, signal_x=signal_x, signal_y=signal_y, signal_z=signal_z, validate_bec=False, # Don't validate - entries already validated ) except Exception as e: logger.debug(f"Auto-plot failed: {e}") pass # Silently fail if plot cannot be called yet @SafeProperty(str) def interpolation_method(self) -> str: """ The interpolation method used for the heatmap. """ return self._image_config.interpolation @interpolation_method.setter def interpolation_method(self, value: str): """ Set the interpolation method for the heatmap. Args: value(str): The interpolation method, either 'linear' or 'nearest'. """ if value not in ["linear", "nearest"]: raise ValueError("Interpolation method must be either 'linear' or 'nearest'.") self._image_config.interpolation = value self.heatmap_property_changed.emit() @SafeProperty(float) def oversampling_factor(self) -> float: """ The oversampling factor for grid resolution. """ return self._image_config.oversampling_factor @oversampling_factor.setter def oversampling_factor(self, value: float): """ Set the oversampling factor for grid resolution. Args: value(float): The oversampling factor (1.0 = no oversampling, 2.0 = 2x resolution). """ if value <= 0: raise ValueError("Oversampling factor must be greater than 0.") self._image_config.oversampling_factor = value self.heatmap_property_changed.emit() @SafeProperty(bool) def enforce_interpolation(self) -> bool: """ Whether to enforce interpolation even for grid scans. """ return self._image_config.enforce_interpolation @enforce_interpolation.setter def enforce_interpolation(self, value: bool): """ Set whether to enforce interpolation even for grid scans. Args: value(bool): Whether to enforce interpolation. """ self._image_config.enforce_interpolation = value self.heatmap_property_changed.emit() ################################################################################ # Post Processing ################################################################################ @SafeProperty(bool, auto_emit=True) def fft(self) -> bool: """ Whether FFT postprocessing is enabled. """ return self.main_image.fft @fft.setter def fft(self, enable: bool): """ Set FFT postprocessing. Args: enable(bool): Whether to enable FFT postprocessing. """ self.main_image.fft = enable @SafeProperty(bool, auto_emit=True) def log(self) -> bool: """ Whether logarithmic scaling is applied. """ return self.main_image.log @log.setter def log(self, enable: bool): """ Set logarithmic scaling. Args: enable(bool): Whether to enable logarithmic scaling. """ self.main_image.log = enable @SafeProperty(int) def num_rotation_90(self) -> int: """ The number of 90° rotations to apply counterclockwise. """ return self.main_image.num_rotation_90 @num_rotation_90.setter def num_rotation_90(self, value: int): """ Set the number of 90° rotations to apply counterclockwise. Args: value(int): The number of 90° rotations to apply. """ self.main_image.num_rotation_90 = value @SafeProperty(bool, auto_emit=True) def transpose(self) -> bool: """ Whether the image is transposed. """ return self.main_image.transpose @transpose.setter def transpose(self, enable: bool): """ Set the image to be transposed. Args: enable(bool): Whether to enable transposing the image. """ self.main_image.transpose = enable
[docs] def cleanup(self): self._finish_interpolation_thread() super().cleanup()
if __name__ == "__main__": # pragma: no cover import sys from qtpy.QtWidgets import QApplication app = QApplication(sys.argv) heatmap = Heatmap() heatmap.plot(device_x="samx", device_y="samy", device_z="bpm4i", oversampling_factor=5.0) heatmap.show() sys.exit(app.exec_())