from __future__ import annotations

from collections import defaultdict
from itertools import product
from typing import Generic, Iterable, TypeVar

from typing_extensions import TypeAlias

from textual.geometry import Offset, Region

ValueType = TypeVar("ValueType")
GridCoordinate: TypeAlias = "tuple[int, int]"


class SpatialMap(Generic[ValueType]):
    """A spatial map allows for data to be associated with rectangular regions
    in Euclidean space, and efficiently queried.

    When the SpatialMap is populated, a reference to each value is placed into
    one or more buckets associated with a regular grid that covers 2D space.

    The SpatialMap is able to quickly retrieve the values under a given "window"
    region by combining the values in the grid squares under the visible area.
    """

    def __init__(self, grid_width: int = 100, grid_height: int = 20) -> None:
        """Create a spatial map with the given grid size.

        Args:
            grid_width: Width of a grid square.
            grid_height: Height of a grid square.
        """
        # Size of one grid square, as (width, height).
        self._grid_size = (grid_width, grid_height)
        # Union of the (un-offset) regions of every non-fixed, non-overlay
        # value passed to `insert`.
        self.total_region = Region()
        # Grid square -> values whose offset region touches that square.
        self._map: defaultdict[GridCoordinate, list[ValueType]] = defaultdict(list)
        # Values included in the result of every query, whatever the region.
        self._fixed: list[ValueType] = []

    def _region_to_grid_coordinates(self, region: Region) -> Iterable[GridCoordinate]:
        """Get the grid squares under a region.

        Args:
            region: A region.

        Returns:
            Iterable of grid coordinates (tuple of 2 values), x varying slowest.
        """
        # (x1, y1) is the coordinate of the top left cell
        # (x2, y2) is the coordinate of the bottom right cell
        x1, y1, width, height = region
        x2 = x1 + width - 1
        y2 = y1 + height - 1
        grid_width, grid_height = self._grid_size

        # Floor division maps cell coordinates (including negative ones) to
        # grid-square indices; `+ 1` makes the last square inclusive.
        # NOTE(review): for a zero-width/height region this yields one square
        # unless the region starts exactly on a grid boundary (then none), so
        # empty regions can produce false positives in queries.
        return product(
            range(x1 // grid_width, x2 // grid_width + 1),
            range(y1 // grid_height, y2 // grid_height + 1),
        )

    def insert(
        self, regions_and_values: Iterable[tuple[Region, Offset, bool, bool, ValueType]]
    ) -> None:
        """Insert values into the Spatial map.

        Each item is a 5-tuple ``(REGION, OFFSET, FIXED, OVERLAY, VALUE)``:

        - If FIXED is true, VALUE is stored in a separate list and is returned
          by every call to `get_values_in_region`. Fixed regions don't scroll
          and are always visible. REGION and OFFSET are not used for them.
        - Otherwise VALUE is placed in the bucket of every grid square touched
          by ``REGION + OFFSET``. Unless OVERLAY is true, the un-offset REGION
          is also unioned into `total_region`.

        Args:
            regions_and_values: An iterable of (REGION, OFFSET, FIXED, OVERLAY, VALUE).
        """
        # Bind hot-path lookups to locals once; this loop may see many widgets.
        append_fixed = self._fixed.append
        get_grid_list = self._map.__getitem__
        _region_to_grid = self._region_to_grid_coordinates
        total_region = self.total_region

        for region, offset, fixed, overlay, value in regions_and_values:
            if fixed:
                append_fixed(value)
            else:
                if not overlay:
                    # Overlays are excluded from the total extent; note this
                    # uses the region *without* the offset.
                    total_region = total_region.union(region)
                for grid in _region_to_grid(region + offset):
                    # `defaultdict.__getitem__` creates the bucket on demand.
                    get_grid_list(grid).append(value)
        self.total_region = total_region

    def get_values_in_region(self, region: Region) -> list[ValueType]:
        """Get a superset of all the values that intersect with a given region.

        Note that this may return false positives: every value sharing a grid
        square with `region` is returned, even if its own region does not
        intersect it. Fixed values are always included.

        Args:
            region: A region.

        Returns:
            Values under the region, without duplicates, in first-seen order
            (fixed values first). Values must be hashable.
        """
        results: list[ValueType] = self._fixed.copy()
        add_results = results.extend
        get_grid_values = self._map.get
        for grid_coordinate in self._region_to_grid_coordinates(region):
            # `.get` rather than `[]` so querying does not create empty buckets.
            grid_values = get_grid_values(grid_coordinate)
            if grid_values is not None:
                add_results(grid_values)
        # A value spanning several squares appears once per square; dedupe
        # while preserving order.
        unique_values = list(dict.fromkeys(results))
        return unique_values