ai-station/.venv/lib/python3.12/site-packages/textual/widgets/_tree.py

1599 lines
50 KiB
Python

"""Provides a tree widget."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, Iterable, NewType, TypeVar, cast
import rich.repr
from rich.style import NULL_STYLE, Style
from rich.text import Text, TextType
from textual import events, on
from textual._immutable_sequence_view import ImmutableSequenceView
from textual._loop import loop_last
from textual._segment_tools import line_pad
from textual.binding import Binding, BindingType
from textual.cache import LRUCache
from textual.geometry import Region, Size, clamp
from textual.message import Message
from textual.reactive import reactive, var
from textual.scroll_view import ScrollView
from textual.strip import Strip
if TYPE_CHECKING:
from typing_extensions import Self, TypeAlias
NodeID = NewType("NodeID", int)
"""The type of an ID applied to a [TreeNode][textual.widgets._tree.TreeNode]."""
TreeDataType = TypeVar("TreeDataType")
"""The type of the data for a given instance of a [Tree][textual.widgets.Tree]."""
EventTreeDataType = TypeVar("EventTreeDataType")
"""The type of the data for a given instance of a [Tree][textual.widgets.Tree].
Similar to [TreeDataType][textual.widgets._tree.TreeDataType] but used for
``Tree`` messages.
"""
# Key type of the rendered-line cache that `Tree` keeps in `_line_cache`.
LineCacheKey: TypeAlias = "tuple[int | tuple, ...]"
# Style carrying `{"toggle": True}` metadata. `Tree.render_label` adds it to the
# style of the expand/collapse icon of nodes that allow expansion.
TOGGLE_STYLE = Style.from_meta({"toggle": True})
class RemoveRootError(Exception):
    """Exception raised when trying to remove the root node of a [`Tree`][textual.widgets.Tree].

    Raised by `TreeNode.remove` when it is called on the tree's root node.
    """
class UnknownNodeID(Exception):
    """Exception raised when referring to an unknown [`TreeNode`][textual.widgets.tree.TreeNode] ID.

    Raised by `Tree.get_node_by_id` when no node is registered under the given ID.
    """
class AddNodeError(Exception):
    """Exception raised when there is an error with a request to add a node.

    Raised by `TreeNode.add` when both `before` and `after` are given, or when
    the node given for either of them is not a child of the target node.
    """
@dataclass
class _TreeLine(Generic[TreeDataType]):
    """One displayed line of the tree.

    Attributes:
        path: The chain of nodes leading to this line; the final entry is the
            node shown on the line.
        last: Whether the line's guide is drawn as a terminator rather than
            as a branch.
    """

    path: list[TreeNode[TreeDataType]]
    last: bool

    @property
    def node(self) -> TreeNode[TreeDataType]:
        """The node associated with this line."""
        return self.path[-1]

    def _get_guide_width(self, guide_depth: int, show_root: bool) -> int:
        """Get the cell width of the guides that precede the label on this line.

        Args:
            guide_depth: The guide depth (cells in the indentation).
            show_root: Whether the tree shows its root. Accepted for interface
                compatibility; the width is one indent per path entry after
                the first whether or not the root is shown.

        Returns:
            Width in cells.
        """
        indent_levels = max(0, len(self.path) - 1)
        return indent_levels * guide_depth
class TreeNodes(ImmutableSequenceView["TreeNode[TreeDataType]"]):
    """An immutable collection of `TreeNode`.

    This is the type returned by `TreeNode.children` and `TreeNode.siblings`.
    """
@rich.repr.auto
class TreeNode(Generic[TreeDataType]):
    """An object that represents a "node" in a tree control."""

    def __init__(
        self,
        tree: Tree[TreeDataType],
        parent: TreeNode[TreeDataType] | None,
        id: NodeID,
        label: Text,
        data: TreeDataType | None = None,
        *,
        expanded: bool = True,
        allow_expand: bool = True,
    ) -> None:
        """Initialise the node.

        Args:
            tree: The tree that the node is being attached to.
            parent: The parent node that this node is being attached to.
            id: The ID of the node.
            label: The label for the node.
            data: Optional data to associate with the node.
            expanded: Should the node be attached in an expanded state?
            allow_expand: Should the node allow being expanded by the user?
        """
        self._tree = tree
        self._parent = parent
        self._id = id
        self._label = tree.process_label(label)
        self.data = data
        """Optional data associated with the tree node."""
        self._expanded = expanded
        self._children: list[TreeNode[TreeDataType]] = []
        self._hover_ = False
        self._selected_ = False
        self._allow_expand = allow_expand
        # Change counter; the tree includes it in the cache key of every line
        # whose path contains this node.
        self._updates: int = 0
        # Line assigned by the tree's most recent build, or -1 if never assigned.
        self._line: int = -1

    def __rich_repr__(self) -> rich.repr.Result:
        """Yield the plain label and the data for Rich's repr protocol."""
        yield self._label.plain
        yield self.data

    def _reset(self) -> None:
        """Clear the hover and selected flags and mark the node as changed."""
        self._hover_ = False
        self._selected_ = False
        self._updates += 1

    @property
    def tree(self) -> Tree[TreeDataType]:
        """The tree that this node is attached to."""
        return self._tree

    @property
    def children(self) -> TreeNodes[TreeDataType]:
        """The child nodes of a TreeNode."""
        return TreeNodes(self._children)

    @property
    def siblings(self) -> TreeNodes[TreeDataType]:
        """The siblings of this node (includes self)."""
        if self.parent is None:
            return TreeNodes([self])
        return self.parent.children

    @property
    def line(self) -> int:
        """The line number for this node, or -1 if it is not displayed."""
        return self._line

    @property
    def _hover(self) -> bool:
        """Check if the mouse is over the node."""
        return self._hover_

    @_hover.setter
    def _hover(self, hover: bool) -> None:
        """Set the hover flag and mark the node as changed."""
        self._updates += 1
        self._hover_ = hover

    @property
    def _selected(self) -> bool:
        """Check if the node is selected."""
        return self._selected_

    @_selected.setter
    def _selected(self, selected: bool) -> None:
        """Set the selected flag and mark the node as changed."""
        self._updates += 1
        self._selected_ = selected

    @property
    def id(self) -> NodeID:
        """The ID of the node."""
        return self._id

    @property
    def parent(self) -> TreeNode[TreeDataType] | None:
        """The parent of the node."""
        return self._parent

    @property
    def next_sibling(self) -> TreeNode[TreeDataType] | None:
        """The next sibling below the node, or `None` if there is none."""
        siblings = self.siblings
        following = siblings.index(self) + 1
        try:
            return siblings[following]
        except IndexError:
            return None

    @property
    def previous_sibling(self) -> TreeNode[TreeDataType] | None:
        """The previous sibling above the node, or `None` if there is none."""
        siblings = self.siblings
        preceding = siblings.index(self) - 1
        # A negative index would wrap around to the last sibling.
        if preceding < 0:
            return None
        return siblings[preceding]

    @property
    def is_expanded(self) -> bool:
        """Is the node expanded?"""
        return self._expanded

    @property
    def is_collapsed(self) -> bool:
        """Is the node collapsed?"""
        return not self._expanded

    @property
    def is_last(self) -> bool:
        """Is this the last child node of its parent?"""
        if self._parent is None:
            return True
        siblings = self._parent._children
        return bool(siblings and siblings[-1] is self)

    @property
    def is_root(self) -> bool:
        """Is this node the root of the tree?"""
        return self is self._tree.root

    @property
    def allow_expand(self) -> bool:
        """Is this node allowed to expand?"""
        return self._allow_expand

    @allow_expand.setter
    def allow_expand(self, allow_expand: bool) -> None:
        """Set whether the node may expand and mark the node as changed."""
        self._allow_expand = allow_expand
        self._updates += 1

    def _expand(self, expand_all: bool) -> None:
        """Mark the node as expanded (its children are shown).

        Posts a `Tree.NodeExpanded` message for every node that is marked.

        Args:
            expand_all: If `True` expand all offspring at all depths.
        """
        self._expanded = True
        self._updates += 1
        self._tree.post_message(Tree.NodeExpanded(self).set_sender(self._tree))
        if expand_all:
            # Walk the backing list directly; no need to wrap it per level.
            for child in self._children:
                child._expand(expand_all)

    def expand(self) -> Self:
        """Expand the node (show its children).

        Returns:
            The `TreeNode` instance.
        """
        self._expand(False)
        self._tree._invalidate()
        return self

    def expand_all(self) -> Self:
        """Expand the node (show its children) and all those below it.

        Returns:
            The `TreeNode` instance.
        """
        self._expand(True)
        self._tree._invalidate()
        return self

    def _collapse(self, collapse_all: bool) -> None:
        """Mark the node as collapsed (its children are hidden).

        Posts a `Tree.NodeCollapsed` message for every node that is marked.

        Args:
            collapse_all: If `True` collapse all offspring at all depths.
        """
        self._expanded = False
        self._updates += 1
        self._tree.post_message(Tree.NodeCollapsed(self).set_sender(self._tree))
        if collapse_all:
            # Walk the backing list directly; no need to wrap it per level.
            for child in self._children:
                child._collapse(collapse_all)

    def collapse(self) -> Self:
        """Collapse the node (hide its children).

        Returns:
            The `TreeNode` instance.
        """
        self._collapse(False)
        self._tree._invalidate()
        return self

    def collapse_all(self) -> Self:
        """Collapse the node (hide its children) and all those below it.

        Returns:
            The `TreeNode` instance.
        """
        self._collapse(True)
        self._tree._invalidate()
        return self

    def toggle(self) -> Self:
        """Toggle the node's expanded state.

        Returns:
            The `TreeNode` instance.
        """
        if self._expanded:
            self.collapse()
        else:
            self.expand()
        return self

    def toggle_all(self) -> Self:
        """Toggle the node's expanded state and make all those below it match.

        Returns:
            The `TreeNode` instance.
        """
        if self._expanded:
            self.collapse_all()
        else:
            self.expand_all()
        return self

    @property
    def label(self) -> TextType:
        """The label for the node."""
        return self._label

    @label.setter
    def label(self, new_label: TextType) -> None:
        """Set a new label; equivalent to calling `set_label`."""
        self.set_label(new_label)

    def set_label(self, label: TextType) -> None:
        """Set a new label for the node.

        The label is passed through the tree's `process_label` and a refresh
        of the node is scheduled with `call_later`.

        Args:
            label: A ``str`` or ``Text`` object with the new label.
        """
        self._updates += 1
        self._label = self._tree.process_label(label)
        self._tree.call_later(self._tree._refresh_node, self)

    def add(
        self,
        label: TextType,
        data: TreeDataType | None = None,
        *,
        before: int | TreeNode[TreeDataType] | None = None,
        after: int | TreeNode[TreeDataType] | None = None,
        expand: bool = False,
        allow_expand: bool = True,
    ) -> TreeNode[TreeDataType]:
        """Add a node to the sub-tree.

        Args:
            label: The new node's label.
            data: Data associated with the new node.
            before: Optional index or `TreeNode` to add the node before.
            after: Optional index or `TreeNode` to add the node after.
            expand: Node should be expanded.
            allow_expand: Allow user to expand the node via keyboard or mouse.

        Returns:
            A new Tree node

        Raises:
            AddNodeError: If there is a problem with the addition request.
            TypeError: If `before` or `after` is neither an index nor a `TreeNode`.

        Note:
            Only one of `before` or `after` can be provided. If both are
            provided a `AddNodeError` will be raised.
        """
        if before is not None and after is not None:
            raise AddNodeError("Unable to add a node both before and after a node")

        children = self._children
        # With neither `before` nor `after`, the node is appended.
        insert_index: int = len(children)

        if before is not None:
            if isinstance(before, int):
                insert_index = before
            elif isinstance(before, TreeNode):
                try:
                    insert_index = children.index(before)
                except ValueError:
                    # The list lookup failure is an implementation detail.
                    raise AddNodeError(
                        "The node specified for `before` is not a child of this node"
                    ) from None
            else:
                raise TypeError(
                    "`before` argument must be an index or a TreeNode object to add before"
                )

        if after is not None:
            if isinstance(after, int):
                insert_index = after + 1
                if after < 0:
                    # Convert a from-the-end index so `-1` means "after the last".
                    insert_index += len(children)
            elif isinstance(after, TreeNode):
                try:
                    insert_index = children.index(after) + 1
                except ValueError:
                    raise AddNodeError(
                        "The node specified for `after` is not a child of this node"
                    ) from None
            else:
                raise TypeError(
                    "`after` argument must be an index or a TreeNode object to add after"
                )

        text_label = self._tree.process_label(label)
        node = self._tree._add_node(self, text_label, data)
        node._expanded = expand
        node._allow_expand = allow_expand
        self._updates += 1
        children.insert(insert_index, node)
        self._tree._invalidate()
        return node

    def add_leaf(
        self,
        label: TextType,
        data: TreeDataType | None = None,
        *,
        before: int | TreeNode[TreeDataType] | None = None,
        after: int | TreeNode[TreeDataType] | None = None,
    ) -> TreeNode[TreeDataType]:
        """Add a 'leaf' node (a node that can not expand).

        Args:
            label: Label for the node.
            data: Optional data.
            before: Optional index or `TreeNode` to add the node before.
            after: Optional index or `TreeNode` to add the node after.

        Returns:
            New node.

        Raises:
            AddNodeError: If there is a problem with the addition request.

        Note:
            Only one of `before` or `after` can be provided. If both are
            provided a `AddNodeError` will be raised.
        """
        return self.add(
            label,
            data,
            before=before,
            after=after,
            expand=False,
            allow_expand=False,
        )

    def _remove_children(self) -> None:
        """Remove child nodes of this node.

        Note:
            This is the internal support method for `remove_children`. Call
            `remove_children` to ensure the tree gets refreshed.
        """
        # Each child deletes itself from `_children`, so walk a reversed view.
        for child in reversed(self._children):
            child._remove()

    def _remove(self) -> None:
        """Remove the current node and all its children.

        Note:
            This is the internal support method for `remove`. Call `remove`
            to ensure the tree gets refreshed.
        """
        self._remove_children()
        assert self._parent is not None
        self._parent._children.remove(self)
        # `Tree.clear` restarts the ID counter, so an ID can be handed out
        # again. Only drop the registry entry when it still refers to this
        # node, so removing a stale node cannot unregister a different one.
        registry = self._tree._tree_nodes
        if registry.get(self.id) is self:
            del registry[self.id]

    def remove(self) -> None:
        """Remove this node from the tree.

        Raises:
            RemoveRootError: If there is an attempt to remove the root.
        """
        if self.is_root:
            raise RemoveRootError("Attempt to remove the root node of a Tree.")
        self._remove()
        self._tree._invalidate()

    def remove_children(self) -> None:
        """Remove any child nodes of this node."""
        self._remove_children()
        self._tree._invalidate()

    def refresh(self) -> None:
        """Initiate a refresh (repaint) of this node."""
        self._updates += 1
        self._tree._refresh_line(self._line)
class Tree(Generic[TreeDataType], ScrollView, can_focus=True):
"""A widget for displaying and navigating data in a tree."""
# The icon glyphs had been lost, leaving empty strings: expandable nodes then
# render no expand/collapse indicator at all, which contradicts the
# documentation of these attributes. Each icon is a glyph plus a trailing space.
ICON_NODE = "▶ "
"""Unicode 'icon' to use for an expandable node."""
ICON_NODE_EXPANDED = "▼ "
"""Unicode 'icon' to use for an expanded node."""
BINDINGS: ClassVar[list[BindingType]] = [
Binding("shift+left", "cursor_parent", "Cursor to parent", show=False),
Binding(
"shift+right",
"cursor_parent_next_sibling",
"Cursor to next ancestor",
show=False,
),
Binding(
"shift+up",
"cursor_previous_sibling",
"Cursor to previous sibling",
show=False,
),
Binding(
"shift+down",
"cursor_next_sibling",
"Cursor to next sibling",
show=False,
),
Binding("enter", "select_cursor", "Select", show=False),
Binding("space", "toggle_node", "Toggle", show=False),
Binding(
"shift+space", "toggle_expand_all", "Expand or collapse all", show=False
),
Binding("up", "cursor_up", "Cursor Up", show=False),
Binding("down", "cursor_down", "Cursor Down", show=False),
]
"""
| Key(s) | Description |
| :- | :- |
| enter | Select the current item. |
| space | Toggle the expand/collapsed state of the current item. |
| up | Move the cursor up. |
| down | Move the cursor down. |
"""
COMPONENT_CLASSES: ClassVar[set[str]] = {
"tree--cursor",
"tree--guides",
"tree--guides-hover",
"tree--guides-selected",
"tree--highlight",
"tree--highlight-line",
"tree--label",
}
"""
| Class | Description |
| :- | :- |
| `tree--cursor` | Targets the cursor. |
| `tree--guides` | Targets the indentation guides. |
| `tree--guides-hover` | Targets the indentation guides under the cursor. |
| `tree--guides-selected` | Targets the indentation guides that are selected. |
| `tree--highlight` | Targets the highlighted items. |
| `tree--highlight-line` | Targets the lines under the cursor. |
| `tree--label` | Targets the (text) labels of the items. |
"""
DEFAULT_CSS = """
Tree {
background: $surface;
color: $foreground;
& > .tree--label {}
& > .tree--guides {
color: $surface-lighten-2;
}
& > .tree--guides-hover {
color: $surface-lighten-2;
}
& > .tree--guides-selected {
color: $block-cursor-blurred-background;
}
& > .tree--cursor {
text-style: $block-cursor-blurred-text-style;
background: $block-cursor-blurred-background;
}
& > .tree--highlight {}
& > .tree--highlight-line {
background: $block-hover-background;
}
&:focus {
background-tint: $foreground 5%;
& > .tree--cursor {
color: $block-cursor-foreground;
background: $block-cursor-background;
text-style: $block-cursor-text-style;
}
& > .tree--guides {
color: $surface-lighten-3;
}
& > .tree--guides-hover {
color: $surface-lighten-3;
}
& > .tree--guides-selected {
color: $block-cursor-background;
}
}
&:light {
/* In light mode the guides are darker*/
& > .tree--guides {
color: $surface-darken-1;
}
& > .tree--guides-hover {
color: $block-cursor-background;
}
& > .tree--guides-selected {
color: $block-cursor-background;
}
}
&:ansi {
color: ansi_default;
& > .tree--guides {
color: ansi_green;
}
&:nocolor > .tree--cursor{
text-style: reverse;
}
}
}
"""
show_root = reactive(True)
"""Show the root of the tree."""
hover_line = var(-1)
"""The line number under the mouse pointer, or -1 if not under the mouse pointer."""
cursor_line = var(-1, always_update=True)
"""The line with the cursor, or -1 if no cursor."""
show_guides = reactive(True)
"""Enable display of tree guide lines."""
guide_depth = reactive(4, init=False)
"""The indent depth of tree nodes."""
auto_expand = var(True)
"""Auto expand tree nodes when they are selected."""
center_scroll = var(False)
"""Keep selected node in the center of the control, where possible."""
# Guide strings per line style, in the order (space, vertical, terminator, cross).
# Every entry must be exactly two characters: the guide renderer reads
# `characters[0]` and `characters[1]` of each string to stretch it to the
# configured guide depth, so shorter strings raise `IndexError`. The "space"
# and "vertical" entries had lost characters and are restored here.
LINES: dict[str, tuple[str, str, str, str]] = {
    "default": (
        "  ",
        "│ ",
        "└─",
        "├─",
    ),
    "bold": (
        "  ",
        "┃ ",
        "┗━",
        "┣━",
    ),
    "double": (
        "  ",
        "║ ",
        "╚═",
        "╠═",
    ),
}
class NodeCollapsed(Generic[EventTreeDataType], Message):
    """Event sent when a node is collapsed.

    Can be handled using `on_tree_node_collapsed` in a subclass of `Tree` or in a
    parent node in the DOM.

    Posted by `TreeNode._collapse` for each node it marks as collapsed.
    """

    def __init__(self, node: TreeNode[EventTreeDataType]) -> None:
        """Initialise the message.

        Args:
            node: The node that was collapsed.
        """
        self.node: TreeNode[EventTreeDataType] = node
        """The node that was collapsed."""
        super().__init__()

    @property
    def control(self) -> Tree[EventTreeDataType]:
        """The tree that sent the message."""
        return self.node.tree
class NodeExpanded(Generic[EventTreeDataType], Message):
    """Event sent when a node is expanded.

    Can be handled using `on_tree_node_expanded` in a subclass of `Tree` or in a
    parent node in the DOM.

    Posted by `TreeNode._expand` for each node it marks as expanded.
    """

    def __init__(self, node: TreeNode[EventTreeDataType]) -> None:
        """Initialise the message.

        Args:
            node: The node that was expanded.
        """
        self.node: TreeNode[EventTreeDataType] = node
        """The node that was expanded."""
        super().__init__()

    @property
    def control(self) -> Tree[EventTreeDataType]:
        """The tree that sent the message."""
        return self.node.tree
class NodeHighlighted(Generic[EventTreeDataType], Message):
    """Event sent when a node is highlighted.

    Can be handled using `on_tree_node_highlighted` in a subclass of `Tree` or in a
    parent node in the DOM.

    Posted by `Tree.watch_cursor_line` when the cursor moves onto a different node.
    """

    def __init__(self, node: TreeNode[EventTreeDataType]) -> None:
        """Initialise the message.

        Args:
            node: The node that was highlighted.
        """
        self.node: TreeNode[EventTreeDataType] = node
        """The node that was highlighted."""
        super().__init__()

    @property
    def control(self) -> Tree[EventTreeDataType]:
        """The tree that sent the message."""
        return self.node.tree
class NodeSelected(Generic[EventTreeDataType], Message):
    """Event sent when a node is selected.

    Can be handled using `on_tree_node_selected` in a subclass of `Tree` or in a
    parent node in the DOM.

    Posted by `Tree.select_node` when it is given a node.
    """

    def __init__(self, node: TreeNode[EventTreeDataType]) -> None:
        """Initialise the message.

        Args:
            node: The node that was selected.
        """
        self.node: TreeNode[EventTreeDataType] = node
        """The node that was selected."""
        super().__init__()

    @property
    def control(self) -> Tree[EventTreeDataType]:
        """The tree that sent the message."""
        return self.node.tree
def __init__(
    self,
    label: TextType,
    data: TreeDataType | None = None,
    *,
    name: str | None = None,
    id: str | None = None,
    classes: str | None = None,
    disabled: bool = False,
) -> None:
    """Initialise a Tree.

    Args:
        label: The label of the root node of the tree.
        data: The optional data to associate with the root node of the tree.
        name: The name of the Tree.
        id: The ID of the tree in the DOM.
        classes: The CSS classes of the tree.
        disabled: Whether the tree is disabled or not.
    """
    text_label = self.process_label(label)
    # The three attributes below must exist before `_add_node` is called:
    # it bumps `_updates`, writes to `_tree_nodes` and draws an ID from
    # `_current_id` (via `_new_id`).
    # Change counter that is part of every line-cache key.
    self._updates = 0
    # Registry of nodes by ID, used by `get_node_by_id`.
    self._tree_nodes: dict[NodeID, TreeNode[TreeDataType]] = {}
    # Next ID that `_new_id` will hand out.
    self._current_id = 0
    self.root = self._add_node(None, text_label, data)
    """The root node of the tree."""
    # Cache of rendered lines.
    self._line_cache: LRUCache[LineCacheKey, Strip] = LRUCache(1024)
    # Flattened list of displayed lines; `None` means it must be rebuilt.
    self._tree_lines_cached: list[_TreeLine[TreeDataType]] | None = None
    # Node under the cursor, maintained by `watch_cursor_line`.
    self._cursor_node: TreeNode[TreeDataType] | None = None
    super().__init__(name=name, id=id, classes=classes, disabled=disabled)
def add_json(self, json_data: object, node: TreeNode | None = None) -> None:
    """Adds JSON data to a node.

    Dictionaries and lists become nodes labelled ``{} name`` and ``[] name``
    with one child per item. Any other value becomes a node that cannot be
    expanded, labelled ``name=<repr of value>``.

    Args:
        json_data: An object decoded from JSON.
        node: Node to add data to. The root node is used when omitted.
    """
    if node is None:
        node = self.root
    from rich.highlighter import ReprHighlighter

    highlighter = ReprHighlighter()

    def add_node(name: str, node: TreeNode, data: object) -> None:
        """Adds a node to the tree.

        Args:
            name: Name of the node.
            node: Parent node.
            data: Data associated with the node.
        """
        if isinstance(data, dict):
            node.set_label(Text(f"{{}} {name}"))
            for key, value in data.items():
                new_node = node.add("")
                add_node(key, new_node, value)
        elif isinstance(data, list):
            node.set_label(Text(f"[] {name}"))
            for index, value in enumerate(data):
                new_node = node.add("")
                add_node(str(index), new_node, value)
        else:
            node.allow_expand = False
            if name:
                # Style the key directly instead of interpolating it into
                # console markup. Keys are arbitrary data; as markup, a key
                # such as "[red]" or ":smile:" would be interpreted rather
                # than displayed, and an unbalanced tag would raise.
                label = Text.assemble(
                    (str(name), "bold"), "=", highlighter(repr(data))
                )
            else:
                label = Text(repr(data))
            node.set_label(label)

    add_node("", node, json_data)
@property
def cursor_node(self) -> TreeNode[TreeDataType] | None:
    """The node the cursor is on, or ``None`` if there is no such node."""
    node = self._cursor_node
    return node
@property
def last_line(self) -> int:
    """The index of the final line of the tree (-1 when there are no lines)."""
    line_count = len(self._tree_lines)
    return line_count - 1
def process_label(self, label: TextType) -> Text:
    """Convert a `str` or `Text` value into the label used for a node.

    Strings are parsed as console markup. Only the first line of the
    resulting text is kept.

    May be overridden in a subclass to change how labels are rendered.

    Args:
        label: Label.

    Returns:
        A Rich Text object.
    """
    text = Text.from_markup(label) if isinstance(label, str) else label
    return text.split()[0]
def _add_node(
    self,
    parent: TreeNode[TreeDataType] | None,
    label: Text,
    data: TreeDataType | None,
    expand: bool = False,
) -> TreeNode[TreeDataType]:
    """Create a node with a fresh ID and register it with the tree.

    The node is not inserted into `parent`'s children here; that is left to
    the caller.

    Args:
        parent: The node's parent, or `None` for a root node.
        label: The node's label.
        data: Data to associate with the node.
        expand: Whether the node starts expanded.

    Returns:
        The new node.
    """
    node_id = self._new_id()
    new_node = TreeNode(self, parent, node_id, label, data, expanded=expand)
    self._tree_nodes[new_node._id] = new_node
    self._updates += 1
    return new_node
def render_label(
    self, node: TreeNode[TreeDataType], base_style: Style, style: Style
) -> Text:
    """Render a label for the given node. Override this to modify how labels are rendered.

    Nodes that allow expansion get an icon in front of the label, styled
    with `base_style` plus the toggle metadata style. Other nodes get an
    empty prefix.

    Args:
        node: A tree node.
        base_style: The base style of the widget.
        style: The additional style for the label.

    Returns:
        A Rich Text object containing the label.
    """
    styled_label = node._label.copy()
    styled_label.stylize(style)
    if not node._allow_expand:
        return Text.assemble(("", base_style), styled_label)
    icon = self.ICON_NODE_EXPANDED if node.is_expanded else self.ICON_NODE
    return Text.assemble((icon, base_style + TOGGLE_STYLE), styled_label)
def get_label_width(self, node: TreeNode[TreeDataType]) -> int:
    """Get the width of the node's label.

    By default the label is rendered with `render_label` using null styles
    and its cell length is returned. A sub-class may override this if the
    width can be computed more efficiently.

    Args:
        node: A node.

    Returns:
        Width in cells.
    """
    return self.render_label(node, NULL_STYLE, NULL_STYLE).cell_len
def _clear_line_cache(self) -> None:
    """Drop the cached tree lines and every cached rendered line."""
    self._tree_lines_cached = None
    self._line_cache.clear()
def clear(self) -> Self:
    """Clear all nodes under root.

    The root is replaced by a new node that keeps the old root's label,
    data and expanded state. Node IDs start again from zero.

    Returns:
        The `Tree` instance.
    """
    self._clear_line_cache()
    self._current_id = 0
    # IDs restart from zero, so the registry must be emptied as well.
    # Otherwise `get_node_by_id` keeps returning nodes that are no longer
    # part of the tree.
    self._tree_nodes.clear()
    old_root = self.root
    self.root = TreeNode(
        self,
        None,
        self._new_id(),
        old_root._label,
        old_root.data,
        expanded=old_root.is_expanded,
    )
    # Register the replacement root so it can be looked up by its ID.
    self._tree_nodes[self.root.id] = self.root
    self._updates += 1
    self.refresh()
    return self
def reset(self, label: TextType, data: TreeDataType | None = None) -> Self:
    """Clear the tree, then give the root node a new label and new data.

    Args:
        label: The label for the root node.
        data: Optional data for the root node.

    Returns:
        The `Tree` instance.
    """
    self.clear()
    root = self.root
    root.label = label
    root.data = data
    return self
def move_cursor(
    self, node: TreeNode[TreeDataType] | None, animate: bool = False
) -> None:
    """Move the cursor to the given node, or reset cursor.

    When a node is given and the tree has a cursor node afterwards, that
    cursor node is scrolled into view. The scroll is only animated when
    `animate` is set and the cursor moved by more than one line.

    Args:
        node: A tree node, or None to reset cursor.
        animate: Enable animation
    """
    previous_line = self.cursor_line
    self.cursor_line = node._line if node is not None else -1
    if node is None or self.cursor_node is None:
        return
    self.scroll_to_node(
        self.cursor_node,
        animate=animate and abs(self.cursor_line - previous_line) > 1,
    )
def move_cursor_to_line(self, line: int, animate: bool = False) -> None:
    """Move the cursor to the given line.

    Does nothing if the cursor is already on `line`.

    Args:
        line: The line number (negative indexes are offsets from the last line).
        animate: Enable scrolling animation.

    Raises:
        IndexError: If the line doesn't exist.
    """
    if self.cursor_line == line:
        return
    try:
        node = self._tree_lines[line].node
    except IndexError:
        # The list lookup failure is an implementation detail; report only
        # the tree-level error.
        raise IndexError(f"No line no. {line} in the tree") from None
    self.move_cursor(node, animate=animate)
def select_node(self, node: TreeNode[TreeDataType] | None) -> None:
    """Move the cursor to the given node and select it, or reset cursor.

    A `Tree.NodeSelected` message is posted when a node is given.

    Args:
        node: A tree node to move the cursor to and select, or None to reset cursor.
    """
    self.move_cursor(node)
    if node is None:
        return
    self.post_message(Tree.NodeSelected(node))
def unselect(self) -> None:
    """Hide and reset the cursor.

    The cursor line is set to -1 with `set_reactive`, then the tree's
    caches are invalidated.
    """
    no_cursor = -1
    self.set_reactive(Tree.cursor_line, no_cursor)
    self._invalidate()
@on(NodeSelected)
def _expand_node_on_select(self, event: NodeSelected[TreeDataType]) -> None:
    """When a node is selected and `auto_expand` is True, pass it to `_toggle_node`."""
    if not self.auto_expand:
        return
    self._toggle_node(event.node)
def get_node_at_line(self, line_no: int) -> TreeNode[TreeDataType] | None:
    """Look up the node displayed on a given line.

    Args:
        line_no: A line number.

    Returns:
        A tree node, or ``None`` if there is no node at that line.
    """
    try:
        tree_line = self._tree_lines[line_no]
    except IndexError:
        return None
    return tree_line.node
def get_node_by_id(self, node_id: NodeID) -> TreeNode[TreeDataType]:
    """Look up a tree node by its ID.

    Args:
        node_id: The ID of the node to get.

    Returns:
        The node associated with that ID.

    Raises:
        UnknownNodeID: Raised if the `TreeNode` ID is unknown.
    """
    registry = self._tree_nodes
    try:
        found = registry[node_id]
    except KeyError:
        raise UnknownNodeID(f"Unknown NodeID ({node_id}) in tree") from None
    return found
def validate_cursor_line(self, value: int) -> int:
    """Clamp the cursor line between 0 and the index of the last tree line.

    Args:
        value: The value to test.

    Returns:
        A valid version of the given value.
    """
    final_line = len(self._tree_lines) - 1
    return clamp(value, 0, final_line)
def validate_guide_depth(self, value: int) -> int:
    """Clamp the guide depth to the range 2 to 10 (inclusive).

    Args:
        value: The value to test.

    Returns:
        A valid version of the given value.
    """
    shallowest, deepest = 2, 10
    return clamp(value, shallowest, deepest)
def _invalidate(self) -> None:
    """Invalidate caches.

    Clears the line caches, bumps the tree's change counter, resets the
    root node's hover/selected flags and requests a layout refresh.
    """
    self._clear_line_cache()
    self.root._reset()
    self._updates += 1
    self.refresh(layout=True)
def _on_mouse_move(self, event: events.MouseMove) -> None:
    """Track the hovered line using the "line" metadata of the style under the mouse."""
    meta = event.style.meta
    self.hover_line = meta["line"] if meta and "line" in meta else -1
def _on_leave(self, _: events.Leave) -> None:
    """Clear the hovered line when the mouse leaves the widget."""
    # Without this the hover effect would linger after the mouse has left.
    no_hover = -1
    self.hover_line = no_hover
def _new_id(self) -> NodeID:
    """Hand out the next node ID and advance the counter.

    Returns:
        A unique node ID.
    """
    allocated = self._current_id
    self._current_id = allocated + 1
    return NodeID(allocated)
def _get_node(self, line: int) -> TreeNode[TreeDataType] | None:
    """Get the node on a line.

    Args:
        line: A line number; -1 means "no line".

    Returns:
        The node on that line, or `None` for -1 or a line outside the tree.
    """
    if line == -1:
        return None
    try:
        found = self._tree_lines[line]
    except IndexError:
        return None
    return found.node
def _get_label_region(self, line: int) -> Region | None:
    """Returns the region occupied by the label of the node at line `line`.

    This can be used, e.g., when scrolling to that line such that the label
    is visible after the scroll.

    Args:
        line: A line number.

    Returns:
        The region occupied by the label, or `None` if the
        line is not in the tree.
    """
    try:
        target = self._tree_lines[line]
    except IndexError:
        return None
    # The label starts where the guides end.
    label_x = target._get_guide_width(self.guide_depth, self.show_root)
    label_width = self.get_label_width(target.node)
    return Region(label_x, line, label_width, 1)
def watch_hover_line(self, previous_hover_line: int, hover_line: int) -> None:
    """Refresh and update the hover flag of the old and new hovered nodes.

    Args:
        previous_hover_line: The line that was hovered before the change.
        hover_line: The line that is hovered now.
    """
    transitions = ((previous_hover_line, False), (hover_line, True))
    for line_no, hovered in transitions:
        target = self._get_node(line_no)
        if target is not None:
            self._refresh_node(target)
            target._hover = hovered
def watch_cursor_line(self, previous_line: int, line: int) -> None:
    """Update selection state and repaint when the cursor line changes.

    Args:
        previous_line: The cursor line before the change.
        line: The cursor line after the change.
    """
    previous_node = self._get_node(previous_line)
    node = self._get_node(line)
    # Deselect both the remembered cursor node and whatever node is on the
    # previous line; these are looked up separately and need not be the
    # same object.
    if self.cursor_node is not None:
        self.cursor_node._selected = False
    if previous_node is not None:
        previous_node._selected = False
    # Select the node on the new line (if any) and remember it as the cursor node.
    if node is not None:
        node._selected = True
        self._cursor_node = node
    else:
        self._cursor_node = None
    if previous_line == line:
        # The line number did not change, so nothing needs repainting.
        return
    # Refresh previous cursor node
    if previous_node is not None:
        self._refresh_node(previous_node)
    # Refresh new node
    if node is not None:
        self._refresh_node(node)
        # Only announce a highlight when the cursor landed on a different node.
        if previous_node != node:
            self.post_message(self.NodeHighlighted(node))
def watch_guide_depth(self, guide_depth: int) -> None:
    """Invalidate the tree's caches when the guide depth changes.

    Args:
        guide_depth: The new guide depth (not used directly).
    """
    self._invalidate()
def watch_show_root(self, show_root: bool) -> None:
    """Reset the cursor line and invalidate caches when `show_root` changes.

    Args:
        show_root: The new value (not used directly).
    """
    no_cursor = -1
    self.cursor_line = no_cursor
    self._invalidate()
def scroll_to_line(self, line: int, animate: bool = True) -> None:
    """Scroll so that the label on the given line is visible.

    Nothing happens when the line is not in the tree.

    Args:
        line: A line number.
        animate: Enable animation.
    """
    label_region = self._get_label_region(line)
    if label_region is None:
        return
    self.scroll_to_region(
        label_region,
        animate=animate,
        force=True,
        center=self.center_scroll,
        origin_visible=False,
        # Scrolling the X axis is quite jarring, and rarely necessary
        x_axis=False,
    )
def scroll_to_node(
    self, node: TreeNode[TreeDataType], animate: bool = True
) -> None:
    """Scroll to the given node.

    Nothing happens when the node's line is -1.

    Args:
        node: Node to scroll into view.
        animate: Animate scrolling.
    """
    node_line = node._line
    if node_line == -1:
        return
    self.scroll_to_line(node_line, animate=animate)
def _refresh_line(self, line: int) -> None:
    """Refresh (repaint) a given line in the tree.

    The line number is offset by the vertical scroll position and the
    refreshed region spans the full width of the widget.

    Args:
        line: Line number.
    """
    row = line - self.scroll_offset.y
    self.refresh(Region(0, row, self.size.width, 1))
def _refresh_node_line(self, line: int) -> None:
    """Refresh the node on the given line, if there is one.

    Args:
        line: Line number.
    """
    target = self._get_node(line)
    if target is None:
        return
    self._refresh_node(target)
def _refresh_node(self, node: TreeNode[TreeDataType]) -> None:
    """Refresh a node and all its children.

    Only lines inside the current scroll window are considered; every such
    line whose path contains `node` is refreshed.

    Args:
        node: A tree node.
    """
    first_visible = self.scroll_offset.y
    window_end = first_visible + self.size.height
    window = self._tree_lines[first_visible:window_end]
    for line_no, tree_line in enumerate(window, first_visible):
        if node in tree_line.path:
            self._refresh_line(line_no)
@property
def _tree_lines(self) -> list[_TreeLine[TreeDataType]]:
    """The list of displayed tree lines, built on first access after invalidation."""
    cached = self._tree_lines_cached
    if cached is None:
        self._build()
        cached = self._tree_lines_cached
    assert cached is not None
    return cached
async def _on_idle(self, event: events.Idle) -> None:
    """Rebuild the tree lines on idle if they have been invalidated."""
    async with self.lock:
        # Reading the property triggers a build when one is required.
        self._tree_lines
def _build(self) -> None:
    """Builds the tree by traversing nodes, and creating tree lines.

    Stores the result in `_tree_lines_cached`, updates `virtual_size` and
    re-synchronises `cursor_line` with the new layout.
    """
    # Local aliases for names used inside the recursive helper.
    TreeLine = _TreeLine
    lines: list[_TreeLine[TreeDataType]] = []
    add_line = lines.append
    root = self.root

    def add_node(
        path: list[TreeNode[TreeDataType]], node: TreeNode[TreeDataType], last: bool
    ) -> None:
        """Append a line for `node`, then recurse into its children if it is expanded.

        Args:
            path: Nodes leading to `node` (not including it).
            node: The node to add a line for.
            last: Value stored as the line's `last` flag.
        """
        child_path = [*path, node]
        # Record which line the node ends up on. Nodes that are not visited
        # (for example under a collapsed node) keep their previous value.
        node._line = len(lines)
        add_line(TreeLine(child_path, last))
        if node._expanded:
            for last, child in loop_last(node._children):
                add_node(child_path, child, last)

    if self.show_root:
        add_node([], root, True)
    else:
        # With the root hidden its children become the top-level lines.
        for node in self.root._children:
            add_node([], node, True)
    # The cache must be populated before the reactive assignments below.
    self._tree_lines_cached = lines
    guide_depth = self.guide_depth
    show_root = self.show_root
    get_label_width = self.get_label_width

    def get_line_width(line: _TreeLine[TreeDataType]) -> int:
        """Get the width of a line: its label width plus its guide width."""
        return get_label_width(line.node) + line._get_guide_width(
            guide_depth, show_root
        )

    if lines:
        width = max([get_line_width(line) for line in lines])
    else:
        # No lines at all: fall back to the widget's own width.
        width = self.size.width
    self.virtual_size = Size(width, len(lines))
    if self.cursor_line != -1:
        # Follow the cursor node to whichever line it occupies now.
        if self.cursor_node is not None:
            self.cursor_line = self.cursor_node._line
        # Drop the cursor if it now points past the end of the tree.
        if self.cursor_line >= len(lines):
            self.cursor_line = -1
def render_lines(self, crop: Region) -> list[Strip]:
    """Record the current pseudo-class state, then render via the parent class.

    Args:
        crop: Region passed through to the parent implementation.

    Returns:
        The strips produced by the parent implementation.
    """
    # Stored for use in the line cache key.
    self._pseudo_class_state = self.get_pseudo_class_state()
    return super().render_lines(crop)
def render_line(self, y: int) -> Strip:
    """Render one line of the widget.

    Args:
        y: Line offset within the widget (before the scroll offset is applied).

    Returns:
        The rendered strip.
    """
    visible_width = self.size.width
    scroll_x, scroll_y = self.scroll_offset
    base_style = self.rich_style
    tree_y = y + scroll_y
    return self._render_line(tree_y, scroll_x, scroll_x + visible_width, base_style)
def _render_line(self, y: int, x1: int, x2: int, base_style: Style) -> Strip:
    """Render a single tree line and crop it to a horizontal span.

    Rendered lines are stored in `self._line_cache`. The cached strip is
    padded out to the larger of the virtual width and the widget width,
    and is cropped to `x1`..`x2` on every call.

    Args:
        y: Index into `self._tree_lines` of the line to render.
        x1: Start of the horizontal span passed to `Strip.crop`.
        x2: End of the horizontal span passed to `Strip.crop`.
        base_style: Style used for blank lines and for lines that are
            not hovered.

    Returns:
        The cropped strip for the line, or a blank strip of the widget's
        width when `y` is at or past the end of the tree lines.
    """
    tree_lines = self._tree_lines
    width = self.size.width
    # Nothing to draw at or below the end of the tree.
    if y >= len(tree_lines):
        return Strip.blank(width, base_style)
    line = tree_lines[y]
    # The line is treated as hovered when a hover line is set and any node
    # on this line's path is flagged as hovered.
    is_hover = self.hover_line >= 0 and any(node._hover for node in line.path)
    # Everything the cached render is keyed on; a change in any element
    # results in a cache miss and a fresh render.
    cache_key = (
        y,
        is_hover,
        width,
        self._updates,
        self._pseudo_class_state,
        tuple(node._updates for node in line.path),
    )
    if cache_key in self._line_cache:
        strip = self._line_cache[cache_key]
    else:
        # Allow tree guides to be explicitly disabled by setting color to transparent
        base_hidden = self.get_component_styles("tree--guides").color.a == 0
        hover_hidden = self.get_component_styles("tree--guides-hover").color.a == 0
        selected_hidden = (
            self.get_component_styles("tree--guides-selected").color.a == 0
        )
        # The hover and selected guide styles are layered on top of the base guide style.
        base_guide_style = self.get_component_rich_style(
            "tree--guides", partial=True
        )
        guide_hover_style = base_guide_style + self.get_component_rich_style(
            "tree--guides-hover", partial=True
        )
        guide_selected_style = base_guide_style + self.get_component_rich_style(
            "tree--guides-selected", partial=True
        )
        # Seed the running hover / selected state from the first node on the path.
        hover = line.path[0]._hover
        selected = line.path[0]._selected and self.has_focus

        def get_guides(style: Style, hidden: bool) -> tuple[str, str, str, str]:
            """Get the guide strings for a given style.

            Args:
                style: A Style object.
                hidden: Switch to hide guides (make them invisible).

            Returns:
                Strings for space, vertical, terminator and cross.
            """
            lines: tuple[Iterable[str], Iterable[str], Iterable[str], Iterable[str]]
            if self.show_guides and not hidden:
                # Pick the guide character set from the style: bold first,
                # then double underline, otherwise the default set.
                lines = self.LINES["default"]
                if style.bold:
                    lines = self.LINES["bold"]
                elif style.underline2:
                    lines = self.LINES["double"]
            else:
                # Guides are off or hidden: use blanks in place of line characters.
                lines = (" ", " ", " ", " ")
            # Each guide is its first character, then its second character
            # repeated `guide_depth - 2` times (never negative), then a space.
            guide_depth = max(0, self.guide_depth - 2)
            guide_lines = tuple(
                f"{characters[0]}{characters[1] * guide_depth} "
                for characters in lines
            )
            return cast("tuple[str, str, str, str]", guide_lines)

        if is_hover:
            line_style = self.get_component_rich_style("tree--highlight-line")
        else:
            line_style = base_style
        # Record the line index in the style meta; `_on_click` reads the
        # "line" key from the clicked style's meta.
        line_style += Style(meta={"line": y})
        guides = Text(style=line_style)
        guides_append = guides.append
        guide_style = base_guide_style
        # Initial value only; reassigned at the top of every loop iteration.
        hidden = True
        # Walk the path below its first node. A guide is appended for every
        # node except the line's own node (the last path entry), which gets
        # a terminator or cross after the loop instead.
        for node in line.path[1:]:
            hidden = base_hidden
            if hover:
                guide_style = guide_hover_style
                hidden = hover_hidden
            # Selected is checked after hover, so it wins when both apply.
            if selected:
                guide_style = guide_selected_style
                hidden = selected_hidden
            space, vertical, _, _ = get_guides(guide_style, hidden)
            # A node that is the last child gets a blank instead of a vertical line.
            guide = space if node.is_last else vertical
            if node != line.path[-1]:
                guides_append(guide, style=guide_style)
            # Carry hover / selected state down to the guides of deeper nodes.
            hover = hover or node._hover
            selected = (selected or node._selected) and self.has_focus
        if len(line.path) > 1:
            # Join the line's own node to its guides, using the style and
            # hidden flag left over from the final loop iteration.
            _, _, terminator, cross = get_guides(guide_style, hidden)
            if line.last:
                guides.append(terminator, style=guide_style)
            else:
                guides.append(cross, style=guide_style)
        label_style = self.get_component_rich_style("tree--label", partial=True)
        if self.hover_line == y:
            label_style += self.get_component_rich_style(
                "tree--highlight", partial=True
            )
        if self.cursor_line == y:
            label_style += self.get_component_rich_style(
                "tree--cursor", partial=False
            )
        # Work on a copy of the rendered label so stylizing it does not
        # modify the object returned by `render_label`.
        label = self.render_label(line.path[-1], line_style, label_style).copy()
        # Attach the node id as style meta on the label.
        label.stylize(Style(meta={"node": line.node._id}))
        guides.append(label)
        segments = list(guides.render(self.app.console))
        # Pad to the wider of the virtual width and the widget width before caching.
        pad_width = max(self.virtual_size.width, width)
        segments = line_pad(segments, 0, pad_width - guides.cell_len, line_style)
        strip = self._line_cache[cache_key] = Strip(segments)
    # The cache holds the padded line; the crop is applied per call.
    strip = strip.crop(x1, x2)
    return strip
def _on_resize(self, event: events.Resize) -> None:
    """Handle a resize by growing the line cache and invalidating the tree.

    Args:
        event: The resize event; its `size.height` is passed to the line
            cache's `grow` method.
    """
    line_cache = self._line_cache
    line_cache.grow(event.size.height)
    self._invalidate()
def _toggle_node(self, node: TreeNode[TreeDataType]) -> None:
    """Flip a node between expanded and collapsed.

    Nodes whose `allow_expand` is false are left untouched.

    Args:
        node: The node to toggle.
    """
    if node.allow_expand:
        # Choose the operation that reverses the node's current state.
        flip = node.collapse if node.is_expanded else node.expand
        flip()
async def _on_click(self, event: events.Click) -> None:
    """Handle a click on the tree.

    The clicked style's meta decides what happens: a click carrying a
    truthy "toggle" entry toggles the node on the clicked line, any other
    click on a line moves the cursor there and runs the "select_cursor"
    action. Clicks whose meta has no "line" entry are ignored.

    Args:
        event: The click event.
    """
    async with self.lock:
        click_meta = event.style.meta
        if "line" not in click_meta:
            return
        clicked_line = click_meta["line"]
        if not click_meta.get("toggle", False):
            # Plain click on a line: move the cursor, then select.
            self.cursor_line = clicked_line
            await self.run_action("select_cursor")
            return
        # Click on a toggle: expand or collapse without moving the cursor.
        target = self.get_node_at_line(clicked_line)
        if target is not None:
            self._toggle_node(target)
def notify_style_update(self) -> None:
    """Run the parent's style-update handling, then call `_invalidate`."""
    super().notify_style_update()
    # NOTE(review): presumably this discards cached line renders so that
    # the new styles are picked up — confirm against `_invalidate`.
    self._invalidate()
def action_cursor_up(self) -> None:
    """Move the cursor up by one line.

    With no cursor (a cursor line of -1) the cursor goes to the last line.
    The view is then scrolled to the cursor line without animation.
    """
    current_line = self.cursor_line
    self.cursor_line = self.last_line if current_line == -1 else current_line - 1
    # Re-read the attribute so the scroll follows the value actually stored.
    self.scroll_to_line(self.cursor_line, animate=False)
def action_cursor_down(self) -> None:
    """Move the cursor down by one line.

    With no cursor (a cursor line of -1) the cursor goes to line 0.
    The view is then scrolled to the cursor line without animation.
    """
    current_line = self.cursor_line
    self.cursor_line = 0 if current_line == -1 else current_line + 1
    # Re-read the attribute so the scroll follows the value actually stored.
    self.scroll_to_line(self.cursor_line, animate=False)
def action_page_down(self) -> None:
    """Move the cursor down by one page of lines.

    A page is the height of the scrollable content region minus one. With
    no cursor (a cursor line of -1) the move starts from line 0.
    """
    no_cursor = self.cursor_line == -1
    if no_cursor:
        self.cursor_line = 0
    page_step = self.scrollable_content_region.height - 1
    self.cursor_line = self.cursor_line + page_step
    self.scroll_to_line(self.cursor_line)
def action_page_up(self) -> None:
    """Move the cursor up by one page of lines.

    A page is the height of the scrollable content region minus one. With
    no cursor (a cursor line of -1) the move starts from the last line.
    """
    no_cursor = self.cursor_line == -1
    if no_cursor:
        self.cursor_line = self.last_line
    page_step = self.scrollable_content_region.height - 1
    self.cursor_line = self.cursor_line - page_step
    self.scroll_to_line(self.cursor_line)
def action_scroll_home(self) -> None:
    """Put the cursor on the first line of the tree and scroll to it."""
    self.cursor_line = 0
    # Scroll to the value read back from the attribute after assignment.
    target_line = self.cursor_line
    self.scroll_to_line(target_line)
def action_scroll_end(self) -> None:
    """Put the cursor on the last line of the tree and scroll to it.

    Note:
        "Last" refers to vertical position, not to depth within a branch.
    """
    self.cursor_line = self.last_line
    # Scroll to the value read back from the attribute after assignment.
    target_line = self.cursor_line
    self.scroll_to_line(target_line)
def action_toggle_node(self) -> None:
    """Toggle the expanded state of the node under the cursor.

    Does nothing when there is no cursor (a negative cursor line) or when
    the cursor line is beyond the end of the tree lines.
    """
    cursor_line = self.cursor_line
    # A negative index would wrap around in the list lookup below and
    # toggle a line counted from the end of the tree, so treat "no cursor"
    # as a no-op, as `action_select_cursor` does.
    if cursor_line < 0:
        return
    try:
        line = self._tree_lines[cursor_line]
    except IndexError:
        # The cursor is past the last line; there is nothing to toggle.
        return
    self._toggle_node(line.path[-1])
def action_select_cursor(self) -> None:
    """Post a `Tree.NodeSelected` message for the node under the cursor.

    Nothing is posted when the cursor line is negative or lies beyond the
    end of the tree lines.

    Note:
        If `auto_expand` is `True` use of this action on a non-leaf node
        will cause both an expand/collapse event to occur, as well as a
        selected event.
    """
    if self.cursor_line < 0:
        return
    try:
        cursor_tree_line = self._tree_lines[self.cursor_line]
    except IndexError:
        # Cursor is past the last line; there is no node to select.
        return
    self.post_message(Tree.NodeSelected(cursor_tree_line.path[-1]))
def action_cursor_parent(self) -> None:
    """Move the cursor to the parent of the current node.

    Does nothing when there is no cursor node or it has no parent.
    """
    current = self.cursor_node
    if current is None:
        return
    parent = current.parent
    if parent is None:
        return
    self.move_cursor(parent, animate=True)
def action_cursor_parent_next_sibling(self) -> None:
    """Move the cursor to the next sibling of the current node's parent.

    Does nothing when there is no cursor node or it has no parent. The
    parent's `next_sibling` is handed to `move_cursor` as-is, even when it
    is `None`.
    """
    current = self.cursor_node
    if current is None:
        return
    parent = current.parent
    if parent is None:
        return
    self.move_cursor(parent.next_sibling, animate=True)
def action_cursor_previous_sibling(self) -> None:
    """Move the cursor to the previous sibling, or to the parent if there is none.

    Does nothing when there is no cursor node.
    """
    current = self.cursor_node
    if current is None:
        return
    target = current.previous_sibling
    if target is None:
        # No earlier sibling: fall back to the parent (which may be None).
        target = current.parent
    self.move_cursor(target, animate=True)
def action_cursor_next_sibling(self) -> None:
    """Move the cursor to the next sibling, or to the parent's next sibling if there is none.

    Does nothing when there is no cursor node, or when the node has
    neither a next sibling nor a parent.
    """
    current = self.cursor_node
    if current is None:
        return
    target = current.next_sibling
    if target is None:
        parent = current.parent
        if parent is None:
            return
        # No later sibling: continue from the parent's next sibling
        # (passed on to `move_cursor` even when it is None).
        target = parent.next_sibling
    self.move_cursor(target, animate=True)
def action_toggle_expand_all(self) -> None:
    """Expand or collapse every sibling of the cursor node.

    When all of the siblings are collapsed they are expanded; otherwise
    they are all collapsed. Siblings whose `allow_expand` is false are
    skipped. Does nothing when there is no cursor node or it has no parent.
    """
    anchor = self.cursor_node
    if anchor is None or anchor.parent is None:
        return
    siblings = anchor.siblings
    # Expand only when nothing is currently open; otherwise collapse all.
    expanding = all(sibling.is_collapsed for sibling in siblings)
    for sibling in siblings:
        if not sibling.allow_expand:
            continue
        if expanding:
            sibling.expand()
        else:
            sibling.collapse()
    # Return the cursor to the node it was on, once the refresh has run.
    self.call_after_refresh(self.move_cursor, anchor, animate=False)