from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, NamedTuple, Tuple, overload

from typing_extensions import Literal, get_args

if TYPE_CHECKING:
    from tree_sitter import Node, Query

from textual._cells import cell_len
from textual.geometry import Size

Newline = Literal["\r\n", "\n", "\r"]
"""The type representing valid line separators."""
VALID_NEWLINES = set(get_args(Newline))
"""The set of valid line separator strings."""


@dataclass
class EditResult:
    """Contains information about an edit that has occurred."""

    end_location: Location
    """The new end Location after the edit is complete."""
    replaced_text: str
    """The text that was replaced."""


@lru_cache(maxsize=1024)
def _utf8_encode(text: str) -> bytes:
    """Encode the input text as utf-8 bytes.

    The returned encoded bytes may be retrieved from a cache.

    Args:
        text: The text to encode.

    Returns:
        The utf-8 bytes representing the input string.
    """
    return text.encode("utf-8")


def _detect_newline_style(text: str) -> Newline:
    """Return the newline type used in this document.

    The check is by presence, not by position: any occurrence of `"\\r\\n"` wins
    over a lone `"\\n"`, which in turn wins over a lone `"\\r"`.

    Args:
        text: The text to inspect.

    Returns:
        The Newline used in the file.
    """
    if "\r\n" in text:  # Windows newline
        return "\r\n"
    elif "\n" in text:  # Unix/Linux/MacOS newline
        return "\n"
    elif "\r" in text:  # Old MacOS newline
        return "\r"
    else:
        return "\n"  # Default to Unix style newline


class DocumentBase(ABC):
    """Describes the minimum functionality a Document implementation must
    provide in order to be used by the TextArea widget."""

    @abstractmethod
    def replace_range(self, start: Location, end: Location, text: str) -> EditResult:
        """Replace the text at the given range.

        Args:
            start: A tuple (row, column) where the edit starts.
            end: A tuple (row, column) where the edit ends.
            text: The text to insert between start and end.

        Returns:
            An EditResult holding the new end location after the edit and the
            text that was replaced.
        """

    @property
    @abstractmethod
    def text(self) -> str:
        """The text from the document as a string."""

    @property
    @abstractmethod
    def newline(self) -> Newline:
        """Return the line separator used in the document."""

    @property
    @abstractmethod
    def lines(self) -> list[str]:
        """Get the lines of the document as a list of strings.

        The strings should *not* include newline characters. The newline
        character used for the document can be retrieved via the newline
        property.
        """

    @abstractmethod
    def get_line(self, index: int) -> str:
        """Returns the line with the given index from the document.

        This is used in rendering lines, and will be called by the
        TextArea for each line that is rendered.

        Args:
            index: The index of the line in the document.

        Returns:
            The str instance representing the line.
        """

    @abstractmethod
    def get_text_range(self, start: Location, end: Location) -> str:
        """Get the text that falls between the start and end locations.

        Args:
            start: The start location of the selection.
            end: The end location of the selection.

        Returns:
            The text between start (inclusive) and end (exclusive).
        """

    @abstractmethod
    def get_size(self, indent_width: int) -> Size:
        """Get the size of the document.

        The height is generally the number of lines, and the width
        is generally the maximum cell length of all the lines.

        Args:
            indent_width: The width to use for tab characters.

        Returns:
            The Size of the document bounding box.
        """

    def query_syntax_tree(
        self,
        query: "Query",
        start_point: tuple[int, int] | None = None,
        end_point: tuple[int, int] | None = None,
    ) -> dict[str, list["Node"]]:
        """Query the tree-sitter syntax tree.

        The default implementation always returns an empty dict.

        To support querying in a subclass, this must be implemented.

        Args:
            query: The tree-sitter Query to perform.
            start_point: The (row, column byte) to start the query at.
            end_point: The (row, column byte) to end the query at.

        Returns:
            A dict mapping captured node names to lists of Nodes with that name.
        """
        return {}

    def prepare_query(self, query: str) -> "Query | None":
        """Prepare a tree-sitter query from its string form.

        The default implementation always returns None.

        Args:
            query: The query string to prepare.

        Returns:
            The prepared Query, or None (always None in this base
            implementation).
        """
        return None

    @property
    @abstractmethod
    def line_count(self) -> int:
        """Returns the number of lines in the document."""

    @property
    @abstractmethod
    def start(self) -> Location:
        """Returns the location of the start of the document (0, 0)."""
        # Abstract, but with a body so subclasses can delegate via super().
        return (0, 0)

    @property
    @abstractmethod
    def end(self) -> Location:
        """Returns the location of the end of the document."""

    if TYPE_CHECKING:

        @overload
        def __getitem__(self, line_index: int) -> str: ...

        @overload
        def __getitem__(self, line_index: slice) -> list[str]: ...

    @abstractmethod
    def __getitem__(self, line_index: int | slice) -> str | list[str]:
        """Return the content of a line as a string, excluding newline characters.

        Args:
            line_index: The index or slice of the line(s) to retrieve.

        Returns:
            The line or list of lines requested.
        """


class Document(DocumentBase):
    """A document which can be opened in a TextArea."""

    def __init__(self, text: str) -> None:
        self._newline: Newline = _detect_newline_style(text)
        """The type of newline used in the text."""

        # NOTE: str.splitlines() splits on more boundaries than the three
        # `Newline` values (e.g. "\v", "\f", "\u2028"); any such separators are
        # replaced by `self._newline` when the text is re-joined.
        self._lines: list[str] = text.splitlines(keepends=False)
        """The lines of the document, excluding newline characters.

        If there's a newline at the end of the file, the final line is an empty string.
        """

        # splitlines() drops the empty line after a trailing newline and
        # returns [] for empty text; restore it so there is always >= 1 line.
        if text.endswith(tuple(VALID_NEWLINES)) or not text:
            self._lines.append("")

    @property
    def lines(self) -> list[str]:
        """Get the document as a list of strings, where each string represents a line.

        Newline characters are not included at the end of the strings.

        The newline character used in this document can be found via the
        `Document.newline` property.
        """
        return self._lines

    @property
    def text(self) -> str:
        """Get the text from the document."""
        return self._newline.join(self._lines)

    @property
    def newline(self) -> Newline:
        """Get the Newline used in this document (e.g. '\\r\\n', '\\n', etc.)"""
        return self._newline

    def get_size(self, tab_width: int) -> Size:
        """The Size of the document, taking into account the tab rendering width.

        Args:
            tab_width: The width to use for tab indents.

        Returns:
            The size (width, height) of the document.
        """
        lines = self._lines
        max_cell_length = max(
            (cell_len(line.expandtabs(tab_width)) for line in lines), default=0
        )
        return Size(max_cell_length, len(lines))

    def replace_range(self, start: Location, end: Location, text: str) -> EditResult:
        """Replace text at the given range.

        This is the only method by which a document may be updated.

        Args:
            start: A tuple (row, column) where the edit starts.
            end: A tuple (row, column) where the edit ends.
            text: The text to insert between start and end.

        Returns:
            The EditResult containing information about the completed
                replace operation.
        """
        # The range may be given in either order; normalise to top/bottom.
        top, bottom = sorted((start, end))
        top_row, top_column = top
        bottom_row, bottom_column = bottom

        insert_lines = text.splitlines()
        if text.endswith(tuple(VALID_NEWLINES)):
            # splitlines() drops the empty line that follows a trailing
            # newline; add it back so the inserted newline is preserved.
            insert_lines.append("")

        lines = self._lines

        # Capture what is about to be overwritten before mutating the lines.
        replaced_text = self.get_text_range(top, bottom)

        # Text on the bottom row that sits after the replaced range.
        if bottom_row >= len(lines):
            after_selection = ""
        else:
            after_selection = lines[bottom_row][bottom_column:]

        # Text on the top row that sits before the replaced range.
        if top_row >= len(lines):
            before_selection = ""
        else:
            before_selection = lines[top_row][:top_column]

        if insert_lines:
            insert_lines[0] = before_selection + insert_lines[0]
            # Measure the column before the trailing text is re-attached.
            destination_column = len(insert_lines[-1])
            insert_lines[-1] = insert_lines[-1] + after_selection
        else:
            # Pure deletion: stitch the surrounding text onto a single line.
            destination_column = len(before_selection)
            insert_lines = [before_selection + after_selection]

        lines[top_row : bottom_row + 1] = insert_lines
        destination_row = top_row + len(insert_lines) - 1

        end_location = (destination_row, destination_column)
        return EditResult(end_location, replaced_text)

    def get_text_range(self, start: Location, end: Location) -> str:
        """Get the text that falls between the start and end locations.

        Returns the text between `start` and `end`, including the appropriate
        line separator character as specified by `Document._newline`. Note that
        `_newline` is detected from the document text when the document is
        created.

        An end row beyond the last line of the document is treated as the end
        of the document.

        Args:
            start: The start location of the selection.
            end: The end location of the selection.

        Returns:
            The text between start (inclusive) and end (exclusive).
        """
        if start == end:
            return ""

        top, bottom = sorted((start, end))
        top_row, top_column = top
        bottom_row, bottom_column = bottom
        lines = self._lines

        if top_row == bottom_row:
            return lines[top_row][top_column:bottom_column]

        # Gather the pieces and join once, rather than growing a string with
        # `+=` for every row. Slicing also clamps an end row that lies past
        # the end of the document.
        pieces = [lines[top_row][top_column:]]
        pieces.extend(lines[top_row + 1 : bottom_row])
        if bottom_row < self.line_count:
            pieces.append(lines[bottom_row][:bottom_column])
        return self._newline.join(pieces)

    @property
    def line_count(self) -> int:
        """Returns the number of lines in the document."""
        return len(self._lines)

    @property
    def start(self) -> Location:
        """Returns the location of the start of the document (0, 0)."""
        return super().start

    @property
    def end(self) -> Location:
        """Returns the location of the end of the document."""
        last_line = self._lines[-1]
        return (self.line_count - 1, len(last_line))

    def get_index_from_location(self, location: Location) -> int:
        """Given a location, returns the index from the document's text.

        Args:
            location: The location in the document.

        Returns:
            The index in the document's text.
        """
        row, column = location
        # Every preceding row contributes its own length plus one newline.
        newline_total = row * len(self.newline)
        preceding_total = sum(
            len(self.get_line(line_index)) for line_index in range(row)
        )
        return newline_total + column + preceding_total

    def get_location_from_index(self, index: int) -> Location:
        """Given a codepoint index in the document's text, returns the corresponding location.

        Args:
            index: The index in the document's text.

        Returns:
            The corresponding location.

        Raises:
            ValueError: If the index doesn't correspond to a location in the document.
        """
        error_message = (
            f"Index {index!r} does not correspond to a location in the document."
        )
        if index < 0 or index > len(self.text):
            raise ValueError(error_message)

        # Index in the full text at which the current line starts.
        column_index = 0
        newline_length = len(self.newline)
        for line_index in range(self.line_count):
            next_column_index = (
                column_index + len(self.get_line(line_index)) + newline_length
            )
            if index < next_column_index:
                return (line_index, index - column_index)
            elif index == next_column_index:
                return (line_index + 1, 0)
            column_index = next_column_index
        raise ValueError(error_message)

    def get_line(self, index: int) -> str:
        """Returns the line with the given index from the document.

        Args:
            index: The index of the line in the document.

        Returns:
            The string representing the line.
        """
        line_string = self[index]
        return line_string

    @overload
    def __getitem__(self, line_index: int) -> str: ...

    @overload
    def __getitem__(self, line_index: slice) -> list[str]: ...

    def __getitem__(self, line_index: int | slice) -> str | list[str]:
        """Return the content of a line as a string, excluding newline characters.

        Args:
            line_index: The index or slice of the line(s) to retrieve.

        Returns:
            The line or list of lines requested.
        """
        return self._lines[line_index]


Location = Tuple[int, int]
"""A location (row, column) within the document. Indexing starts at 0."""


class Selection(NamedTuple):
    """A range of characters within a document from a start point to the end point.
    The location of the cursor is always considered to be the `end` point of the selection.
    The selection is inclusive of the minimum point and exclusive of the maximum point.
    """

    start: Location = (0, 0)
    """The start location of the selection.

    If you were to click and drag a selection inside a text-editor, this is where you *started* dragging.
    """
    end: Location = (0, 0)
    """The end location of the selection.

    If you were to click and drag a selection inside a text-editor, this is where you *finished* dragging.
    """

    @classmethod
    def cursor(cls, location: Location) -> "Selection":
        """Create a Selection with the same start and end point - a "cursor".

        Args:
            location: The location to create the zero-width Selection.

        Returns:
            A Selection whose start and end are both `location`.
        """
        return cls(location, location)

    @property
    def is_empty(self) -> bool:
        """Return True if the selection has 0 width, i.e. it's just a cursor."""
        start, end = self
        return start == end

    def contains_line(self, y: int) -> bool:
        """Check if the given line is within the selection.

        Args:
            y: The row index to check.

        Returns:
            True if `y` lies between the start and end rows (inclusive),
            regardless of which of the two comes first.
        """
        top, bottom = sorted((self.start[0], self.end[0]))
        return top <= y <= bottom