ai-station/.venv/lib/python3.12/site-packages/textual/worker.py

456 lines
14 KiB
Python
Raw Normal View History

2025-12-25 14:54:33 +00:00
"""
This module contains the `Worker` class and related objects.
See the guide for how to use [workers](/guide/workers).
"""
from __future__ import annotations
import asyncio
import enum
import inspect
from contextvars import ContextVar
from threading import Event
from time import monotonic
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Coroutine,
Generic,
TypeVar,
Union,
cast,
)
import rich.repr
from typing_extensions import TypeAlias
from textual.message import Message
if TYPE_CHECKING:
from textual.app import App
from textual.dom import DOMNode
# Set by the worker run paths (`Worker._run` and the thread runners inside
# `Worker._run_threaded`) and read back by `get_current_worker` and `Worker.wait`.
active_worker: ContextVar[Worker] = ContextVar("active_worker")
"""Currently active worker context var."""
class NoActiveWorker(Exception):
    """There is no active worker in the current context."""
class WorkerError(Exception):
    """Base class for worker related errors."""
class WorkerFailed(WorkerError):
    """The worker raised an exception and did not complete."""

    def __init__(self, error: BaseException) -> None:
        """Wrap the exception raised by the worker.

        Args:
            error: The original exception; kept on `self.error` and included
                (via its `repr`) in this exception's message.
        """
        self.error = error
        super().__init__(f"Worker raised exception: {error!r}")
class DeadlockError(WorkerError):
    """The operation would result in a deadlock."""
class WorkerCancelled(WorkerError):
    """The worker was cancelled and did not complete."""
def get_current_worker() -> Worker:
    """Get the currently active worker.

    Raises:
        NoActiveWorker: If there is no active worker.

    Returns:
        A Worker instance.
    """
    try:
        return active_worker.get()
    except LookupError:
        # The context var has no value here; `from None` hides the
        # LookupError so callers only see the domain-specific error.
        raise NoActiveWorker(
            "There is no active worker in this task or thread."
        ) from None
class WorkerState(enum.Enum):
    """A description of the worker's current state."""

    PENDING = 1
    """Worker is initialized, but not running."""
    RUNNING = 2
    """Worker is running."""
    CANCELLED = 3
    """Worker is not running, and was cancelled."""
    ERROR = 4
    """Worker is not running, and exited with an error."""
    SUCCESS = 5
    """Worker is not running, and completed successfully."""
# Type variable for the value produced by a unit of work.
ResultType = TypeVar("ResultType")
# The accepted shapes of work, in order: a zero-argument coroutine function,
# a zero-argument plain callable, or an already-created awaitable object.
WorkType: TypeAlias = Union[
Callable[[], Coroutine[None, None, ResultType]],
Callable[[], ResultType],
Awaitable[ResultType],
]
"""Type used for [workers](/guide/workers/)."""
class _ReprText:
    """Shim to insert a word into the Worker's repr."""

    def __init__(self, text: str) -> None:
        """Store the text to show.

        Args:
            text: The exact string that `repr()` of this object returns.
        """
        self.text = text

    def __repr__(self) -> str:
        # Return the raw text (no quotes), so it reads as a bare word.
        return self.text
@rich.repr.auto(angular=True)
class Worker(Generic[ResultType]):
    """A class to manage concurrent work (either a task or a thread)."""

    @rich.repr.auto
    class StateChanged(Message, bubble=False, namespace="worker"):
        """The worker state changed."""

        def __init__(self, worker: Worker, state: WorkerState) -> None:
            """Initialize the StateChanged message.

            Args:
                worker: The worker object.
                state: New state.
            """
            self.worker = worker
            self.state = state
            super().__init__()

        def __rich_repr__(self) -> rich.repr.Result:
            """Yield the worker and the new state for the Rich repr."""
            yield self.worker
            yield self.state

    def __init__(
        self,
        node: DOMNode,
        work: WorkType,
        *,
        name: str = "",
        group: str = "default",
        description: str = "",
        exit_on_error: bool = True,
        thread: bool = False,
    ) -> None:
        """Initialize a Worker.

        Args:
            node: The widget, screen, or App that initiated the work.
            work: A callable, coroutine, or other awaitable object to run in the worker.
            name: Name of the worker (short string to help identify when debugging).
            group: The worker group.
            description: Description of the worker (longer string with more details).
                Anything beyond 1000 characters is cut and replaced by "...".
            exit_on_error: Exit the app if the worker raises an error. Set to `False` to suppress exceptions.
            thread: Mark the worker as a thread worker.
        """
        self._node = node
        self._work = work
        self.name = name
        self.group = group
        self.description = (
            description if len(description) <= 1000 else description[:1000] + "..."
        )
        self.exit_on_error = exit_on_error
        self.cancelled_event: Event = Event()
        """A threading event set when the worker is cancelled."""
        self._thread_worker = thread
        self._state = WorkerState.PENDING
        self._error: BaseException | None = None
        self._completed_steps: int = 0
        self._total_steps: int | None = None
        self._cancelled: bool = False
        self._created_time = monotonic()
        self._result: ResultType | None = None
        self._task: asyncio.Task | None = None
        # Post the initial PENDING state explicitly: the `state` setter only
        # posts a message when the state actually changes.
        self._node.post_message(self.StateChanged(self, self._state))

    def __rich_repr__(self) -> rich.repr.Result:
        """Yield the fields shown in the Rich repr.

        The three-element tuples carry a default as their last item
        (for example `""` for the name).
        """
        yield _ReprText(self.state.name)
        yield "name", self.name, ""
        yield "group", self.group, "default"
        yield "description", self.description, ""
        yield "progress", round(self.progress, 1), 0.0

    @property
    def node(self) -> DOMNode:
        """The node where this worker was run from."""
        return self._node

    @property
    def state(self) -> WorkerState:
        """The current state of the worker."""
        return self._state

    @state.setter
    def state(self, state: WorkerState) -> None:
        """Set the state, and send a message if it differs from the current one."""
        changed = state != self._state
        self._state = state
        if changed:
            self._node.post_message(self.StateChanged(self, state))

    @property
    def is_cancelled(self) -> bool:
        """Has the work been cancelled?

        Note that cancelled work may still be running.
        """
        return self._cancelled

    @property
    def is_running(self) -> bool:
        """Is the task running?"""
        return self.state == WorkerState.RUNNING

    @property
    def is_finished(self) -> bool:
        """Has the task finished (cancelled, error, or success)?"""
        return self.state in (
            WorkerState.CANCELLED,
            WorkerState.ERROR,
            WorkerState.SUCCESS,
        )

    @property
    def completed_steps(self) -> int:
        """The number of completed steps."""
        return self._completed_steps

    @property
    def total_steps(self) -> int | None:
        """The number of total steps, or None if indeterminate."""
        return self._total_steps

    @property
    def progress(self) -> float:
        """Progress as a percentage.

        If the total steps is None (or 0), then this will return 0. The
        percentage will be clamped between 0 and 100.
        """
        if not self._total_steps:
            return 0.0
        percentage = (self._completed_steps / self._total_steps) * 100.0
        # Float bounds keep the result a float even when the value is clamped.
        return max(0.0, min(100.0, percentage))

    @property
    def result(self) -> ResultType | None:
        """The result of the worker, or `None` if there is no result."""
        return self._result

    @property
    def error(self) -> BaseException | None:
        """The exception raised by the worker, or `None` if there was no error."""
        return self._error

    def update(
        self, completed_steps: int | None = None, total_steps: int | None = -1
    ) -> None:
        """Update the progress counters.

        Args:
            completed_steps: Number of steps to add to the completed count, or
                `None` to leave it unchanged.
            total_steps: The total number of steps, `None` for indeterminate, or -1 to leave unchanged.
                Negative totals other than -1 are stored as 0.
        """
        if completed_steps is not None:
            self._completed_steps += completed_steps
        if total_steps != -1:
            self._total_steps = None if total_steps is None else max(0, total_steps)

    def advance(self, steps: int = 1) -> None:
        """Advance the number of completed steps.

        Args:
            steps: Number of steps to advance.
        """
        self._completed_steps += steps

    def _is_coroutine_function_work(self) -> bool:
        """Check if the work is a coroutine function.

        Also true when the work exposes a coroutine function through a `func`
        attribute, as `functools.partial` objects do.
        """
        work = self._work
        return inspect.iscoroutinefunction(work) or (
            hasattr(work, "func") and inspect.iscoroutinefunction(work.func)
        )

    async def _run_threaded(self) -> ResultType:
        """Run a threaded worker.

        Raises:
            WorkerError: If the work is not a coroutine function, an awaitable,
                or a callable.

        Returns:
            Return value of the work.
        """

        def run_awaitable(work: Awaitable[ResultType]) -> ResultType:
            """Set the active worker and await the awaitable."""

            async def do_work() -> ResultType:
                active_worker.set(self)
                return await work

            # Runs in the executor thread, so it drives its own event loop.
            return asyncio.run(do_work())

        def run_coroutine(
            work: Callable[[], Coroutine[None, None, ResultType]],
        ) -> ResultType:
            """Set the active worker and await coroutine."""
            return run_awaitable(work())

        def run_callable(work: Callable[[], ResultType]) -> ResultType:
            """Set the active worker, and call the callable."""
            active_worker.set(self)
            return work()

        # Order matters: a coroutine function is also callable, so it must be
        # checked before the plain-callable case.
        if self._is_coroutine_function_work():
            runner = run_coroutine
        elif inspect.isawaitable(self._work):
            runner = run_awaitable
        elif callable(self._work):
            runner = run_callable
        else:
            raise WorkerError("Unsupported attempt to run a thread worker")

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, runner, self._work)

    async def _run_async(self) -> ResultType:
        """Run an async worker.

        Raises:
            WorkerError: If the work is a non-async callable, or is not
                runnable as an async worker at all.

        Returns:
            Return value of the work.
        """
        if self._is_coroutine_function_work():
            return await self._work()
        if inspect.isawaitable(self._work):
            return await self._work
        if callable(self._work):
            raise WorkerError("Request to run a non-async function as an async worker")
        raise WorkerError("Unsupported attempt to run an async worker")

    async def run(self) -> ResultType:
        """Run the work.

        Implement this method in a subclass, or pass a callable to the constructor.

        Returns:
            Return value of the work.
        """
        return await (
            self._run_threaded() if self._thread_worker else self._run_async()
        )

    async def _run(self, app: App) -> None:
        """Run the worker.

        Marks the worker as running, awaits `run()`, then records the outcome
        in `state`, `_result` and `_error`.

        Args:
            app: App instance.
        """
        with app._context():
            active_worker.set(self)

            self.state = WorkerState.RUNNING
            app.log.worker(self)
            try:
                self._result = await self.run()
            except asyncio.CancelledError as error:
                self.state = WorkerState.CANCELLED
                self._error = error
                app.log.worker(self)
            except Exception as error:
                self.state = WorkerState.ERROR
                self._error = error
                app.log.worker(self, "failed", repr(error))
                from rich.traceback import Traceback

                app.log.worker(Traceback())
                if self.exit_on_error:
                    # Hand the failure to the app, wrapped in WorkerFailed.
                    worker_failed = WorkerFailed(self._error)
                    app._handle_exception(worker_failed)
            else:
                self.state = WorkerState.SUCCESS
                app.log.worker(self)

    def _start(
        self, app: App, done_callback: Callable[[Worker], None] | None = None
    ) -> None:
        """Start the worker.

        Does nothing if the worker already has a task.

        Args:
            app: An app instance.
            done_callback: A callback to call when the task is done.
        """
        if self._task is not None:
            return
        self.state = WorkerState.RUNNING
        self._task = asyncio.create_task(self._run(app))

        def task_done_callback(_task: asyncio.Task) -> None:
            """Run the callback.

            Called by `Task.add_done_callback`.

            Args:
                _task: The worker's task.
            """
            if done_callback is not None:
                done_callback(self)

        self._task.add_done_callback(task_done_callback)

    def cancel(self) -> None:
        """Cancel the task.

        Flags the worker as cancelled, cancels the asyncio task if there is
        one, and sets `cancelled_event`.
        """
        self._cancelled = True
        if self._task is not None:
            self._task.cancel()
        self.cancelled_event.set()

    async def wait(self) -> ResultType:
        """Wait for the work to complete.

        Raises:
            DeadlockError: If called from within the worker itself.
            WorkerError: If the worker has not been started.
            WorkerFailed: If the Worker raised an exception.
            WorkerCancelled: If the Worker was cancelled before it completed.

        Returns:
            The return value of the work.
        """
        try:
            if active_worker.get() is self:
                raise DeadlockError(
                    "Can't call worker.wait from within the worker function!"
                )
        except LookupError:
            # Not in a worker
            pass

        if self.state == WorkerState.PENDING:
            raise WorkerError("Worker must be started before calling this method.")

        if self._task is not None:
            try:
                await self._task
            except asyncio.CancelledError as error:
                self.state = WorkerState.CANCELLED
                self._error = error

        if self.state == WorkerState.ERROR:
            assert self._error is not None
            raise WorkerFailed(self._error)
        elif self.state == WorkerState.CANCELLED:
            raise WorkerCancelled("Worker was cancelled, and did not complete.")

        return cast("ResultType", self._result)