""" A `MessagePump` is a base class for any object which processes messages, which includes Widget, Screen, and App. !!! tip Most of the method here are useful in general app development. """ from __future__ import annotations import asyncio import threading from asyncio import CancelledError, QueueEmpty, Task, create_task from contextlib import contextmanager from functools import partial from time import perf_counter from typing import ( TYPE_CHECKING, Any, Awaitable, Callable, Generator, Iterable, Type, TypeVar, cast, ) from weakref import WeakSet from textual import Logger, events, log, messages from textual._callback import invoke from textual._compat import cached_property from textual._context import NoActiveAppError, active_app, active_message_pump from textual._context import message_hook as message_hook_context_var from textual._context import prevent_message_types_stack from textual._on import OnNoWidget from textual._queue import Queue from textual._time import time from textual.constants import SLOW_THRESHOLD from textual.css.match import match from textual.events import Event from textual.message import Message from textual.reactive import Reactive, TooManyComputesError from textual.signal import Signal from textual.timer import Timer, TimerCallback if TYPE_CHECKING: from typing_extensions import TypeAlias from textual.app import App from textual.css.model import SelectorSet Callback: TypeAlias = "Callable[..., Any] | Callable[..., Awaitable[Any]]" class CallbackError(Exception): pass class MessagePumpClosed(Exception): pass _MessagePumpMetaSub = TypeVar("_MessagePumpMetaSub", bound="_MessagePumpMeta") class _MessagePumpMeta(type): """Metaclass for message pump. This exists to populate a Message inner class of a Widget with the parent classes' name. """ def __new__( cls: Type[_MessagePumpMetaSub], name: str, bases: tuple[type, ...], class_dict: dict[str, Any], **kwargs: Any, ) -> _MessagePumpMetaSub: handlers: dict[ type[Message], list[tuple[Callable, dict[str, tuple[SelectorSet, ...]]]] ] = class_dict.get("_decorated_handlers", {}) class_dict["_decorated_handlers"] = handlers for value in class_dict.values(): if callable(value) and hasattr(value, "_textual_on"): textual_on: list[ tuple[type[Message], dict[str, tuple[SelectorSet, ...]]] ] = getattr(value, "_textual_on") for message_type, selectors in textual_on: handlers.setdefault(message_type, []).append((value, selectors)) # Look for reactives with public AND private compute methods. prefix = "compute_" prefix_len = len(prefix) for attr_name, value in class_dict.items(): if attr_name.startswith(prefix) and callable(value): reactive_name = attr_name[prefix_len:] if ( reactive_name in class_dict and isinstance(class_dict[reactive_name], Reactive) and f"_{attr_name}" in class_dict ): raise TooManyComputesError( f"reactive {reactive_name!r} can't have two computes." ) class_obj = super().__new__(cls, name, bases, class_dict, **kwargs) return class_obj class MessagePump(metaclass=_MessagePumpMeta): """Base class which supplies a message pump.""" def __init__(self, parent: MessagePump | None = None) -> None: self._parent = parent self._running: bool = False self._closing: bool = False self._closed: bool = False self._disabled_messages: set[type[Message]] = set() self._pending_message: Message | None = None self._task: Task | None = None self._timers: WeakSet[Timer] = WeakSet() self._last_idle: float = time() self._max_idle: float | None = None self._is_mounted = False """Having this explicit Boolean is an optimization. The same information could be retrieved from `self._mounted_event.is_set()`, but we need to access this frequently in the compositor and the attribute with the explicit Boolean value is faster than the two lookups and the function call. """ self._next_callbacks: list[events.Callback] = [] self._thread_id: int = threading.get_ident() self._prevented_messages_on_mount = self._prevent_message_types_stack[-1] self.message_signal: Signal[Message] = Signal(self, "messages") """Subscribe to this signal to be notified of all messages sent to this widget. This is a fairly low-level mechanism, and shouldn't replace regular message handling. """ @cached_property def _message_queue(self) -> Queue[Message | None]: return Queue() @cached_property def _mounted_event(self) -> asyncio.Event: return asyncio.Event() @property def _prevent_message_types_stack(self) -> list[set[type[Message]]]: """The stack that manages prevented messages.""" try: stack = prevent_message_types_stack.get() except LookupError: stack = [set()] prevent_message_types_stack.set(stack) return stack def _thread_init(self): """Initialize threading primitives for the current thread. Require for Python3.8 https://github.com/Textualize/textual/issues/5845 """ self._message_queue self._mounted_event def _get_prevented_messages(self) -> set[type[Message]]: """A set of all the prevented message types.""" return self._prevent_message_types_stack[-1] def _is_prevented(self, message_type: type[Message]) -> bool: """Check if a message type has been prevented via the [prevent][textual.message_pump.MessagePump.prevent] context manager. Args: message_type: A message type. Returns: `True` if the message has been prevented from sending, or `False` if it will be sent as normal. """ return message_type in self._prevent_message_types_stack[-1] @contextmanager def prevent(self, *message_types: type[Message]) -> Generator[None, None, None]: """A context manager to *temporarily* prevent the given message types from being posted. Example: ```python input = self.query_one(Input) with self.prevent(Input.Changed): input.value = "foo" ``` """ if message_types: prevent_stack = self._prevent_message_types_stack prevent_stack.append(prevent_stack[-1].union(message_types)) try: yield finally: prevent_stack.pop() else: yield @property def task(self) -> Task: assert self._task is not None return self._task @property def has_parent(self) -> bool: """Does this object have a parent?""" return self._parent is not None @property def message_queue_size(self) -> int: """The current size of the message queue.""" return self._message_queue.qsize() @property def is_dom_root(self): """Is this a root node (i.e. the App)?""" return False if TYPE_CHECKING: from textual import getters app = getters.app(App) else: @property def app(self) -> "App[object]": """ Get the current app. Returns: The current app. Raises: NoActiveAppError: if no active app could be found for the current asyncio context """ try: return active_app.get() except LookupError: from textual.app import App node: MessagePump | None = self while not isinstance(node, App): if node is None: raise NoActiveAppError() node = node._parent return node @property def is_attached(self) -> bool: """Is this node linked to the app through the DOM?""" try: if self.app._exit: return False except NoActiveAppError: return False node: MessagePump | None = self while (node := node._parent) is not None: if node.is_dom_root: return True return False @property def is_parent_active(self) -> bool: """Is the parent active?""" parent = self._parent return bool(parent is not None and not parent._closed and not parent._closing) @property def is_running(self) -> bool: """Is the message pump running (potentially processing messages)?""" return self._running @property def log(self) -> Logger: """Get a logger for this object. Returns: A logger. """ return self.app._logger def _attach(self, parent: MessagePump) -> None: """Set the parent, and therefore attach this node to the tree. Args: parent: Parent node. """ self._parent = parent def _detach(self) -> None: """Set the parent to None to remove the node from the tree.""" self._parent = None def check_message_enabled(self, message: Message) -> bool: """Check if a given message is enabled (allowed to be sent). Args: message: A message object. Returns: `True` if the message will be sent, or `False` if it is disabled. """ return type(message) not in self._disabled_messages def disable_messages(self, *messages: type[Message]) -> None: """Disable message types from being processed.""" self._disabled_messages.update(messages) def enable_messages(self, *messages: type[Message]) -> None: """Enable processing of messages types.""" self._disabled_messages.difference_update(messages) async def _get_message(self) -> Message: """Get the next event on the queue, or None if queue is closed. Returns: Event object or None. """ if self._closed: raise MessagePumpClosed("The message pump is closed") if self._pending_message is not None: try: return self._pending_message finally: self._pending_message = None message = await self._message_queue.get() if message is None: self._closed = True raise MessagePumpClosed("The message pump is now closed") return message def _peek_message(self) -> Message | None: """Peek the message at the head of the queue (does not remove it from the queue), or return None if the queue is empty. Returns: The message or None. """ if self._pending_message is None: try: message = self._message_queue.get_nowait() except QueueEmpty: pass else: if message is None: self._closed = True raise MessagePumpClosed("The message pump is now closed") self._pending_message = message if self._pending_message is not None: return self._pending_message return None def set_timer( self, delay: float, callback: TimerCallback | None = None, *, name: str | None = None, pause: bool = False, ) -> Timer: """Call a function after a delay. Example: ```python def ready(): self.notify("Your soft boiled egg is ready!") # Call ready() after 3 minutes self.set_timer(3 * 60, ready) ``` Args: delay: Time (in seconds) to wait before invoking callback. callback: Callback to call after time has expired. name: Name of the timer (for debug). pause: Start timer paused. Returns: A timer object. """ timer = Timer( self, delay, name=name or f"set_timer#{Timer._timer_count}", callback=None if callback is None else partial(self.call_next, callback), repeat=0, pause=pause, ) timer._start() self._timers.add(timer) return timer def set_interval( self, interval: float, callback: TimerCallback | None = None, *, name: str | None = None, repeat: int = 0, pause: bool = False, ) -> Timer: """Call a function at periodic intervals. Args: interval: Time (in seconds) between calls. callback: Function to call. name: Name of the timer object. repeat: Number of times to repeat the call or 0 for continuous. pause: Start the timer paused. Returns: A timer object. """ timer = Timer( self, interval, name=name or f"set_interval#{Timer._timer_count}", callback=callback, repeat=repeat or None, pause=pause, ) timer._start() self._timers.add(timer) return timer def call_after_refresh(self, callback: Callback, *args: Any, **kwargs: Any) -> bool: """Schedule a callback to run after all messages are processed and the screen has been refreshed. Positional and keyword arguments are passed to the callable. Args: callback: A callable. Returns: `True` if the callback was scheduled, or `False` if the callback could not be scheduled (may occur if the message pump was closed or closing). """ # We send the InvokeLater message to ourselves first, to ensure we've cleared # out anything already pending in our own queue. message = messages.InvokeLater(partial(callback, *args, **kwargs)) return self.post_message(message) async def wait_for_refresh(self) -> bool: """Wait for the next refresh. This method should only be called from a task other than the one running this widget. If called from the same task, it will return immediately to avoid blocking the event loop. Returns: `True` if waiting for refresh was successful, or `False` if the call was a null-op due to calling it within the node's own task. """ assert ( self._task is not None ), "Node must be running before calling wait_for_refresh" if asyncio.current_task() is self._task: return False refreshed_event = asyncio.Event() self.call_after_refresh(refreshed_event.set) await refreshed_event.wait() return True def call_later(self, callback: Callback, *args: Any, **kwargs: Any) -> bool: """Schedule a callback to run after all messages are processed in this object. Positional and keywords arguments are passed to the callable. Args: callback: Callable to call next. *args: Positional arguments to pass to the callable. **kwargs: Keyword arguments to pass to the callable. Returns: `True` if the callback was scheduled, or `False` if the callback could not be scheduled (may occur if the message pump was closed or closing). """ message = events.Callback(callback=partial(callback, *args, **kwargs)) return self.post_message(message) def call_next(self, callback: Callback, *args: Any, **kwargs: Any) -> None: """Schedule a callback to run immediately after processing the current message. Args: callback: Callable to run after current event. *args: Positional arguments to pass to the callable. **kwargs: Keyword arguments to pass to the callable. """ assert callback is not None, "Callback must not be None" callback_message = events.Callback(callback=partial(callback, *args, **kwargs)) callback_message._prevent.update(self._get_prevented_messages()) self._next_callbacks.append(callback_message) self.check_idle() def _on_invoke_later(self, message: messages.InvokeLater) -> None: # Forward InvokeLater message to the Screen if self.app._running: self.app.screen._invoke_later( message.callback, message._sender or active_message_pump.get() ) async def _close_messages(self, wait: bool = True) -> None: """Close message queue, and optionally wait for queue to finish processing.""" if self._closed or self._closing: return self._closing = True if self._timers: await Timer._stop_all(self._timers) self._timers.clear() Reactive._reset_object(self) self._message_queue.put_nowait(None) if wait and self._task is not None and asyncio.current_task() != self._task: try: running_widget = active_message_pump.get() except LookupError: running_widget = None if running_widget is None or running_widget is not self: try: await self._task except CancelledError: pass def _start_messages(self) -> None: """Start messages task.""" self._thread_init() if self.app._running: self._task = create_task( self._process_messages(), name=f"message pump {self}" ) else: self._closing = True self._closed = True async def _process_messages(self) -> None: self._running = True with self._context(): if not await self._pre_process(): self._running = False return try: await self._process_messages_loop() except CancelledError: pass finally: self._running = False try: if self._timers: await Timer._stop_all(self._timers) self._timers.clear() Reactive._clear_watchers(self) finally: await self._message_loop_exit() self._task = None async def _message_loop_exit(self) -> None: """Called when the message loop has completed.""" async def _pre_process(self) -> bool: """Procedure to run before processing messages. Returns: `True` if successful, or `False` if any exception occurred. """ # Dispatch compose and mount messages without going through loop # These events must occur in this order, and at the start. try: await self._dispatch_message(events.Compose()) if self._prevented_messages_on_mount: with self.prevent(*self._prevented_messages_on_mount): await self._dispatch_message(events.Mount()) else: await self._dispatch_message(events.Mount()) self._post_mount() except Exception as error: self.app._handle_exception(error) return False finally: # This is critical, mount may be waiting self._mounted_event.set() self._is_mounted = True return True def _post_mount(self): """Called after the object has been mounted.""" def _close_messages_no_wait(self) -> None: """Request the message queue to immediately exit.""" self._message_queue.put_nowait(messages.CloseMessages()) @contextmanager def _context(self) -> Generator[None, None, None]: """Context manager to set ContextVars.""" reset_token = active_message_pump.set(self) try: yield finally: active_message_pump.reset(reset_token) async def _on_close_messages(self, message: messages.CloseMessages) -> None: await self._close_messages() async def _process_messages_loop(self) -> None: """Process messages until the queue is closed.""" _rich_traceback_guard = True self._thread_id = threading.get_ident() await asyncio.sleep(0) while not self._closed: try: message = await self._get_message() except MessagePumpClosed: break except CancelledError: raise except Exception as error: raise error from None # Combine any pending messages that may supersede this one while not (self._closed or self._closing): try: pending = self._peek_message() except MessagePumpClosed: break if pending is None or not message.can_replace(pending): break try: message = await self._get_message() except MessagePumpClosed: break try: await self._dispatch_message(message) except CancelledError: raise except Exception as error: self._mounted_event.set() self._is_mounted = True self.app._handle_exception(error) break finally: self.message_signal.publish(message) self._message_queue.task_done() current_time = time() # Insert idle events if self._message_queue.empty() or ( self._max_idle is not None and current_time - self._last_idle > self._max_idle ): self._last_idle = current_time if not self._closed: event = events.Idle() for _cls, method in self._get_dispatch_methods( "on_idle", event ): try: await invoke(method, event) except Exception as error: self.app._handle_exception(error) break await self._flush_next_callbacks() async def _flush_next_callbacks(self) -> None: """Invoke pending callbacks in next callbacks queue.""" callbacks = self._next_callbacks.copy() self._next_callbacks.clear() for callback in callbacks: try: with self.prevent(*callback._prevent): await invoke(callback.callback) except Exception as error: self.app._handle_exception(error) break async def _dispatch_message(self, message: Message) -> None: """Dispatch a message received from the message queue. Args: message: A message object """ _rich_traceback_guard = True if message.no_dispatch: return try: message_hook = message_hook_context_var.get() except LookupError: pass else: message_hook(message) with self.prevent(*message._prevent): # Allow apps to treat events and messages separately if isinstance(message, Event): await self.on_event(message) elif "debug" in self.app.features: start = perf_counter() await self._on_message(message) if perf_counter() - start > SLOW_THRESHOLD / 1000: log.warning( f"method=<{self.__class__.__name__}." f"{message.handler_name}>", f"Took over {SLOW_THRESHOLD}ms to process.", "\nTo avoid screen freezes, consider using a worker.", ) else: await self._on_message(message) if self._next_callbacks: await self._flush_next_callbacks() def _get_dispatch_methods( self, method_name: str, message: Message ) -> Iterable[tuple[type, Callable[[Message], Awaitable]]]: """Gets handlers from the MRO Args: method_name: Handler method name. message: Message object. """ from textual.widget import Widget methods_dispatched: set[Callable] = set() message_mro = [ _type for _type in message.__class__.__mro__ if issubclass(_type, Message) ] for cls in self.__class__.__mro__: if message._no_default_action: break # Try decorated handlers first decorated_handlers = cast( "dict[type[Message], list[tuple[Callable, dict[str, tuple[SelectorSet, ...]]]]] | None", cls.__dict__.get("_decorated_handlers"), ) if decorated_handlers: for message_class in message_mro: handlers = decorated_handlers.get(message_class, []) for method, selectors in handlers: if method in methods_dispatched: continue if not selectors: yield cls, method.__get__(self, cls) methods_dispatched.add(method) else: if not message._sender: continue for attribute, selector in selectors.items(): node = getattr(message, attribute) if node is None: break if not isinstance(node, Widget): raise OnNoWidget( f"on decorator can't match against {attribute!r} as it is not a widget." ) if not match(selector, node): break else: yield cls, method.__get__(self, cls) methods_dispatched.add(method) # Fall back to the naming convention # But avoid calling the handler if it was decorated method = cls.__dict__.get(f"_{method_name}") or cls.__dict__.get( method_name ) if method is not None and not getattr(method, "_textual_on", None): yield cls, method.__get__(self, cls) async def on_event(self, event: events.Event) -> None: """Called to process an event. Args: event: An Event object. """ await self._on_message(event) async def _on_message(self, message: Message) -> None: """Called to process a message. Args: message: A Message object. """ _rich_traceback_guard = True handler_name = message.handler_name # Look through the MRO to find a handler dispatched = False for cls, method in self._get_dispatch_methods(handler_name, message): log.event.verbosity(message.verbose)( message, ">>>", self, f"method=<{cls.__name__}.{handler_name}>", ) dispatched = True await invoke(method, message) if not dispatched: log.event.verbosity(message.verbose)(message, ">>>", self, "method=None") # Bubble messages up the DOM (if enabled on the message) if message.bubble and self._parent and not message._stop_propagation: if message._sender is not None and message._sender == self._parent: # parent is sender, so we stop propagation after parent message.stop() if self.is_parent_active and self.is_attached: message._bubble_to(self._parent) def check_idle(self) -> None: """Prompt the message pump to call idle if the queue is empty.""" if self._running and self._message_queue.empty(): self.post_message(messages.Prompt()) async def _post_message(self, message: Message) -> bool: """Post a message or an event to this message pump. This is an internal method for use where a coroutine is required. Args: message: A message object. Returns: True if the messages was posted successfully, False if the message was not posted (because the message pump was in the process of closing). """ return self.post_message(message) def post_message(self, message: Message) -> bool: """Posts a message on to this widget's queue. Args: message: A message (including Event). Returns: `True` if the message was queued for processing, otherwise `False`. """ _rich_traceback_omit = True if not hasattr(message, "_prevent"): # Catch a common error (forgetting to call super) raise RuntimeError( "Message is missing attributes; did you forget to call super().__init__() ?" ) if self._closing or self._closed: return False if not self.check_message_enabled(message): return False # Add a copy of the prevented message types to the message # This is so that prevented messages are honoured by the event's handler message._prevent.update(self._get_prevented_messages()) if self._thread_id != threading.get_ident() and self.app._loop is not None: # If we're not calling from the same thread, make it threadsafe loop = self.app._loop loop.call_soon_threadsafe(self._message_queue.put_nowait, message) else: self._message_queue.put_nowait(message) return True async def on_callback(self, event: events.Callback) -> None: if self.app._closing: return try: self.app.screen except Exception: self.log.warning( f"Not invoking timer callback {event.callback!r} because there is no screen." ) return await invoke(event.callback) async def on_timer(self, event: events.Timer) -> None: if not self.app._running: return event.prevent_default() event.stop() if event.callback is not None: try: self.app.screen except Exception: self.log.warning( f"Not invoking timer callback {event.callback!r} because there is no screen." ) return try: await invoke(event.callback) except Exception as error: raise CallbackError( f"unable to run callback {event.callback!r}; {error}" )