Source code for bec_lib.user_scripts_mixin
"""
This module provides a mixin class for the BEC class that allows the user to load and unload scripts from the `scripts` directory.
"""
from __future__ import annotations
import builtins
import glob
import importlib
import inspect
import os
import pathlib
from typing import TYPE_CHECKING
from pylint import lint
from pylint.reporters import CollectingReporter
from rich.console import Console
from rich.table import Table
from bec_lib.callback_handler import EventType
from bec_lib.logger import bec_logger
if TYPE_CHECKING:
from pylint.message import Message
logger = bec_logger.logger
try:
from bec_plugins import bec_ipython_client as client_plugins
except ImportError:
client_plugins = None
[docs]
class UserScriptsMixin:
def __init__(self) -> None:
super().__init__()
self._scripts = {}
[docs]
def load_all_user_scripts(self) -> None:
"""Load all scripts from the `scripts` directory.
Runs a callback of type `EventType.NAMESPACE_UPDATE`
to inform clients about added objects in the namesapce.
"""
self.forget_all_user_scripts()
# load all scripts from the scripts directory
current_path = pathlib.Path(__file__).parent.resolve()
script_files = glob.glob(os.path.abspath(os.path.join(current_path, "../scripts/*.py")))
# load all scripts from the user's script directory in the home directory
user_script_dir = os.path.join(os.path.expanduser("~"), "bec", "scripts")
if os.path.exists(user_script_dir):
script_files.extend(glob.glob(os.path.abspath(os.path.join(user_script_dir, "*.py"))))
# load scripts from the plugins
if client_plugins:
plugin_scripts_dir = os.path.join(client_plugins.__path__[0], "scripts")
if os.path.exists(plugin_scripts_dir):
script_files.extend(
glob.glob(os.path.abspath(os.path.join(plugin_scripts_dir, "*.py")))
)
for file in script_files:
self.load_user_script(file)
builtins.__dict__.update({name: v["cls"] for name, v in self._scripts.items()})
[docs]
def forget_all_user_scripts(self) -> None:
"""unload / remove loaded user scripts from builtins. Files will remain untouched.
Runs a callback of type `EventType.NAMESPACE_UPDATE`
to inform clients about removing objects from the namesapce.
"""
for name, obj in self._scripts.items():
builtins.__dict__.pop(name)
self.callbacks.run(
EventType.NAMESPACE_UPDATE, action="remove", ns_objects={name: obj["cls"]}
)
self._scripts.clear()
[docs]
def load_user_script(self, file: str) -> None:
"""load a user script file and import all its definitions
Args:
file (str): Full path to the script file.
"""
self._run_linter_on_file(file)
module_members = self._load_script_module(file)
for name, cls in module_members:
if not callable(cls):
continue
# ignore imported classes
if cls.__module__ != "scripts":
continue
if name in self._scripts:
logger.warning(f"Conflicting definitions for {name}.")
logger.info(f"Importing {name}")
self._scripts[name] = {"cls": cls, "fname": file}
self.callbacks.run(EventType.NAMESPACE_UPDATE, action="add", ns_objects={name: cls})
[docs]
def forget_user_script(self, name: str) -> None:
"""unload / remove a user scripts. The file will remain on disk."""
if name not in self._scripts:
logger.error(f"{name} is not a known user script.")
return
self.callbacks.run(
EventType.NAMESPACE_UPDATE,
action="remove",
ns_objects={name: self._scripts[name]["cls"]},
)
builtins.__dict__.pop(name)
self._scripts.pop(name)
[docs]
def list_user_scripts(self):
"""display all currently loaded user functions"""
console = Console()
table = Table(title="User scripts")
table.add_column("Name", justify="center")
table.add_column("Location", justify="center", overflow="fold")
for name, content in self._scripts.items():
table.add_row(name, content.get("fname"))
console.print(table)
def _load_script_module(self, file) -> list:
module_spec = importlib.util.spec_from_file_location("scripts", file)
plugin_module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(plugin_module)
module_members = inspect.getmembers(plugin_module)
return module_members
def _run_linter_on_file(self, file) -> None:
accepted_vars = ",".join([key for key in builtins.__dict__ if not key.startswith("_")])
reporter = CollectingReporter()
print(f"{accepted_vars}")
lint.Run(
[file, "--errors-only", f"--additional-builtins={accepted_vars}"],
exit=False,
reporter=reporter,
)
if not reporter.messages:
return
def _format_pylint_output(msg: Message):
return f"Line {msg.line}, column {msg.column}: {msg.msg}."
for msg in reporter.messages:
logger.error(
f"During the import of {file}, the following error was detected: \n{_format_pylint_output(msg)}.\nThe script was imported but may not work as expected."
)