456 lines
14 KiB
Python
456 lines
14 KiB
Python
|
|
"""
|
||
|
|
This module contains the `Worker` class and related objects.
|
||
|
|
|
||
|
|
See the guide for how to use [workers](/guide/workers).
|
||
|
|
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import enum
|
||
|
|
import inspect
|
||
|
|
from contextvars import ContextVar
|
||
|
|
from threading import Event
|
||
|
|
from time import monotonic
|
||
|
|
from typing import (
|
||
|
|
TYPE_CHECKING,
|
||
|
|
Awaitable,
|
||
|
|
Callable,
|
||
|
|
Coroutine,
|
||
|
|
Generic,
|
||
|
|
TypeVar,
|
||
|
|
Union,
|
||
|
|
cast,
|
||
|
|
)
|
||
|
|
|
||
|
|
import rich.repr
|
||
|
|
from typing_extensions import TypeAlias
|
||
|
|
|
||
|
|
from textual.message import Message
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from textual.app import App
|
||
|
|
from textual.dom import DOMNode
|
||
|
|
|
||
|
|
|
||
|
|
active_worker: ContextVar[Worker] = ContextVar("active_worker")
|
||
|
|
"""Currently active worker context var."""
|
||
|
|
|
||
|
|
|
||
|
|
class NoActiveWorker(Exception):
|
||
|
|
"""There is no active worker."""
|
||
|
|
|
||
|
|
|
||
|
|
class WorkerError(Exception):
|
||
|
|
"""A worker related error."""
|
||
|
|
|
||
|
|
|
||
|
|
class WorkerFailed(WorkerError):
|
||
|
|
"""The worker raised an exception and did not complete."""
|
||
|
|
|
||
|
|
def __init__(self, error: BaseException) -> None:
|
||
|
|
self.error = error
|
||
|
|
super().__init__(f"Worker raised exception: {error!r}")
|
||
|
|
|
||
|
|
|
||
|
|
class DeadlockError(WorkerError):
|
||
|
|
"""The operation would result in a deadlock."""
|
||
|
|
|
||
|
|
|
||
|
|
class WorkerCancelled(WorkerError):
|
||
|
|
"""The worker was cancelled and did not complete."""
|
||
|
|
|
||
|
|
|
||
|
|
def get_current_worker() -> Worker:
|
||
|
|
"""Get the currently active worker.
|
||
|
|
|
||
|
|
Raises:
|
||
|
|
NoActiveWorker: If there is no active worker.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
A Worker instance.
|
||
|
|
"""
|
||
|
|
try:
|
||
|
|
return active_worker.get()
|
||
|
|
except LookupError:
|
||
|
|
raise NoActiveWorker(
|
||
|
|
"There is no active worker in this task or thread."
|
||
|
|
) from None
|
||
|
|
|
||
|
|
|
||
|
|
class WorkerState(enum.Enum):
|
||
|
|
"""A description of the worker's current state."""
|
||
|
|
|
||
|
|
PENDING = 1
|
||
|
|
"""Worker is initialized, but not running."""
|
||
|
|
RUNNING = 2
|
||
|
|
"""Worker is running."""
|
||
|
|
CANCELLED = 3
|
||
|
|
"""Worker is not running, and was cancelled."""
|
||
|
|
ERROR = 4
|
||
|
|
"""Worker is not running, and exited with an error."""
|
||
|
|
SUCCESS = 5
|
||
|
|
"""Worker is not running, and completed successfully."""
|
||
|
|
|
||
|
|
|
||
|
|
ResultType = TypeVar("ResultType")
|
||
|
|
|
||
|
|
|
||
|
|
WorkType: TypeAlias = Union[
|
||
|
|
Callable[[], Coroutine[None, None, ResultType]],
|
||
|
|
Callable[[], ResultType],
|
||
|
|
Awaitable[ResultType],
|
||
|
|
]
|
||
|
|
"""Type used for [workers](/guide/workers/)."""
|
||
|
|
|
||
|
|
|
||
|
|
class _ReprText:
|
||
|
|
"""Shim to insert a word into the Worker's repr."""
|
||
|
|
|
||
|
|
def __init__(self, text: str) -> None:
|
||
|
|
self.text = text
|
||
|
|
|
||
|
|
def __repr__(self) -> str:
|
||
|
|
return self.text
|
||
|
|
|
||
|
|
|
||
|
|
@rich.repr.auto(angular=True)
|
||
|
|
class Worker(Generic[ResultType]):
|
||
|
|
"""A class to manage concurrent work (either a task or a thread)."""
|
||
|
|
|
||
|
|
@rich.repr.auto
|
||
|
|
class StateChanged(Message, bubble=False, namespace="worker"):
|
||
|
|
"""The worker state changed."""
|
||
|
|
|
||
|
|
def __init__(self, worker: Worker, state: WorkerState) -> None:
|
||
|
|
"""Initialize the StateChanged message.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
worker: The worker object.
|
||
|
|
state: New state.
|
||
|
|
"""
|
||
|
|
self.worker = worker
|
||
|
|
self.state = state
|
||
|
|
super().__init__()
|
||
|
|
|
||
|
|
def __rich_repr__(self) -> rich.repr.Result:
|
||
|
|
yield self.worker
|
||
|
|
yield self.state
|
||
|
|
|
||
|
|
def __init__(
|
||
|
|
self,
|
||
|
|
node: DOMNode,
|
||
|
|
work: WorkType,
|
||
|
|
*,
|
||
|
|
name: str = "",
|
||
|
|
group: str = "default",
|
||
|
|
description: str = "",
|
||
|
|
exit_on_error: bool = True,
|
||
|
|
thread: bool = False,
|
||
|
|
) -> None:
|
||
|
|
"""Initialize a Worker.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
node: The widget, screen, or App that initiated the work.
|
||
|
|
work: A callable, coroutine, or other awaitable object to run in the worker.
|
||
|
|
name: Name of the worker (short string to help identify when debugging).
|
||
|
|
group: The worker group.
|
||
|
|
description: Description of the worker (longer string with more details).
|
||
|
|
exit_on_error: Exit the app if the worker raises an error. Set to `False` to suppress exceptions.
|
||
|
|
thread: Mark the worker as a thread worker.
|
||
|
|
"""
|
||
|
|
self._node = node
|
||
|
|
self._work = work
|
||
|
|
self.name = name
|
||
|
|
self.group = group
|
||
|
|
self.description = (
|
||
|
|
description if len(description) <= 1000 else description[:1000] + "..."
|
||
|
|
)
|
||
|
|
self.exit_on_error = exit_on_error
|
||
|
|
self.cancelled_event: Event = Event()
|
||
|
|
"""A threading event set when the worker is cancelled."""
|
||
|
|
self._thread_worker = thread
|
||
|
|
self._state = WorkerState.PENDING
|
||
|
|
self.state = self._state
|
||
|
|
self._error: BaseException | None = None
|
||
|
|
self._completed_steps: int = 0
|
||
|
|
self._total_steps: int | None = None
|
||
|
|
self._cancelled: bool = False
|
||
|
|
self._created_time = monotonic()
|
||
|
|
self._result: ResultType | None = None
|
||
|
|
self._task: asyncio.Task | None = None
|
||
|
|
self._node.post_message(self.StateChanged(self, self._state))
|
||
|
|
|
||
|
|
def __rich_repr__(self) -> rich.repr.Result:
|
||
|
|
yield _ReprText(self.state.name)
|
||
|
|
yield "name", self.name, ""
|
||
|
|
yield "group", self.group, "default"
|
||
|
|
yield "description", self.description, ""
|
||
|
|
yield "progress", round(self.progress, 1), 0.0
|
||
|
|
|
||
|
|
@property
|
||
|
|
def node(self) -> DOMNode:
|
||
|
|
"""The node where this worker was run from."""
|
||
|
|
return self._node
|
||
|
|
|
||
|
|
@property
|
||
|
|
def state(self) -> WorkerState:
|
||
|
|
"""The current state of the worker."""
|
||
|
|
return self._state
|
||
|
|
|
||
|
|
@state.setter
|
||
|
|
def state(self, state: WorkerState) -> None:
|
||
|
|
"""Set the state, and send a message."""
|
||
|
|
changed = state != self._state
|
||
|
|
self._state = state
|
||
|
|
if changed:
|
||
|
|
self._node.post_message(self.StateChanged(self, state))
|
||
|
|
|
||
|
|
@property
|
||
|
|
def is_cancelled(self) -> bool:
|
||
|
|
"""Has the work been cancelled?
|
||
|
|
|
||
|
|
Note that cancelled work may still be running.
|
||
|
|
"""
|
||
|
|
return self._cancelled
|
||
|
|
|
||
|
|
@property
|
||
|
|
def is_running(self) -> bool:
|
||
|
|
"""Is the task running?"""
|
||
|
|
return self.state == WorkerState.RUNNING
|
||
|
|
|
||
|
|
@property
|
||
|
|
def is_finished(self) -> bool:
|
||
|
|
"""Has the task finished (cancelled, error, or success)?"""
|
||
|
|
return self.state in (
|
||
|
|
WorkerState.CANCELLED,
|
||
|
|
WorkerState.ERROR,
|
||
|
|
WorkerState.SUCCESS,
|
||
|
|
)
|
||
|
|
|
||
|
|
@property
|
||
|
|
def completed_steps(self) -> int:
|
||
|
|
"""The number of completed steps."""
|
||
|
|
return self._completed_steps
|
||
|
|
|
||
|
|
@property
|
||
|
|
def total_steps(self) -> int | None:
|
||
|
|
"""The number of total steps, or None if indeterminate."""
|
||
|
|
return self._total_steps
|
||
|
|
|
||
|
|
@property
|
||
|
|
def progress(self) -> float:
|
||
|
|
"""Progress as a percentage.
|
||
|
|
|
||
|
|
If the total steps is None, then this will return 0. The percentage will be clamped between 0 and 100.
|
||
|
|
"""
|
||
|
|
if not self._total_steps:
|
||
|
|
return 0.0
|
||
|
|
return max(0, min(100, (self._completed_steps / self._total_steps) * 100.0))
|
||
|
|
|
||
|
|
@property
|
||
|
|
def result(self) -> ResultType | None:
|
||
|
|
"""The result of the worker, or `None` if there is no result."""
|
||
|
|
return self._result
|
||
|
|
|
||
|
|
@property
|
||
|
|
def error(self) -> BaseException | None:
|
||
|
|
"""The exception raised by the worker, or `None` if there was no error."""
|
||
|
|
return self._error
|
||
|
|
|
||
|
|
def update(
|
||
|
|
self, completed_steps: int | None = None, total_steps: int | None = -1
|
||
|
|
) -> None:
|
||
|
|
"""Update the number of completed steps.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
completed_steps: The number of completed seps, or `None` to not change.
|
||
|
|
total_steps: The total number of steps, `None` for indeterminate, or -1 to leave unchanged.
|
||
|
|
"""
|
||
|
|
if completed_steps is not None:
|
||
|
|
self._completed_steps += completed_steps
|
||
|
|
if total_steps != -1:
|
||
|
|
self._total_steps = None if total_steps is None else max(0, total_steps)
|
||
|
|
|
||
|
|
def advance(self, steps: int = 1) -> None:
|
||
|
|
"""Advance the number of completed steps.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
steps: Number of steps to advance.
|
||
|
|
"""
|
||
|
|
self._completed_steps += steps
|
||
|
|
|
||
|
|
async def _run_threaded(self) -> ResultType:
|
||
|
|
"""Run a threaded worker.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Return value of the work.
|
||
|
|
"""
|
||
|
|
|
||
|
|
def run_awaitable(work: Awaitable[ResultType]) -> ResultType:
|
||
|
|
"""Set the active worker and await the awaitable."""
|
||
|
|
|
||
|
|
async def do_work() -> ResultType:
|
||
|
|
active_worker.set(self)
|
||
|
|
return await work
|
||
|
|
|
||
|
|
return asyncio.run(do_work())
|
||
|
|
|
||
|
|
def run_coroutine(
|
||
|
|
work: Callable[[], Coroutine[None, None, ResultType]],
|
||
|
|
) -> ResultType:
|
||
|
|
"""Set the active worker and await coroutine."""
|
||
|
|
return run_awaitable(work())
|
||
|
|
|
||
|
|
def run_callable(work: Callable[[], ResultType]) -> ResultType:
|
||
|
|
"""Set the active worker, and call the callable."""
|
||
|
|
active_worker.set(self)
|
||
|
|
return work()
|
||
|
|
|
||
|
|
if (
|
||
|
|
inspect.iscoroutinefunction(self._work)
|
||
|
|
or hasattr(self._work, "func")
|
||
|
|
and inspect.iscoroutinefunction(self._work.func)
|
||
|
|
):
|
||
|
|
runner = run_coroutine
|
||
|
|
elif inspect.isawaitable(self._work):
|
||
|
|
runner = run_awaitable
|
||
|
|
elif callable(self._work):
|
||
|
|
runner = run_callable
|
||
|
|
else:
|
||
|
|
raise WorkerError("Unsupported attempt to run a thread worker")
|
||
|
|
|
||
|
|
loop = asyncio.get_running_loop()
|
||
|
|
assert loop is not None
|
||
|
|
return await loop.run_in_executor(None, runner, self._work)
|
||
|
|
|
||
|
|
async def _run_async(self) -> ResultType:
|
||
|
|
"""Run an async worker.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Return value of the work.
|
||
|
|
"""
|
||
|
|
if (
|
||
|
|
inspect.iscoroutinefunction(self._work)
|
||
|
|
or hasattr(self._work, "func")
|
||
|
|
and inspect.iscoroutinefunction(self._work.func)
|
||
|
|
):
|
||
|
|
return await self._work()
|
||
|
|
elif inspect.isawaitable(self._work):
|
||
|
|
return await self._work
|
||
|
|
elif callable(self._work):
|
||
|
|
raise WorkerError("Request to run a non-async function as an async worker")
|
||
|
|
raise WorkerError("Unsupported attempt to run an async worker")
|
||
|
|
|
||
|
|
async def run(self) -> ResultType:
|
||
|
|
"""Run the work.
|
||
|
|
|
||
|
|
Implement this method in a subclass, or pass a callable to the constructor.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Return value of the work.
|
||
|
|
"""
|
||
|
|
return await (
|
||
|
|
self._run_threaded() if self._thread_worker else self._run_async()
|
||
|
|
)
|
||
|
|
|
||
|
|
async def _run(self, app: App) -> None:
|
||
|
|
"""Run the worker.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
app: App instance.
|
||
|
|
"""
|
||
|
|
with app._context():
|
||
|
|
active_worker.set(self)
|
||
|
|
|
||
|
|
self.state = WorkerState.RUNNING
|
||
|
|
app.log.worker(self)
|
||
|
|
try:
|
||
|
|
self._result = await self.run()
|
||
|
|
except asyncio.CancelledError as error:
|
||
|
|
self.state = WorkerState.CANCELLED
|
||
|
|
self._error = error
|
||
|
|
app.log.worker(self)
|
||
|
|
except Exception as error:
|
||
|
|
self.state = WorkerState.ERROR
|
||
|
|
self._error = error
|
||
|
|
app.log.worker(self, "failed", repr(error))
|
||
|
|
from rich.traceback import Traceback
|
||
|
|
|
||
|
|
app.log.worker(Traceback())
|
||
|
|
if self.exit_on_error:
|
||
|
|
worker_failed = WorkerFailed(self._error)
|
||
|
|
app._handle_exception(worker_failed)
|
||
|
|
else:
|
||
|
|
self.state = WorkerState.SUCCESS
|
||
|
|
app.log.worker(self)
|
||
|
|
|
||
|
|
def _start(
|
||
|
|
self, app: App, done_callback: Callable[[Worker], None] | None = None
|
||
|
|
) -> None:
|
||
|
|
"""Start the worker.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
app: An app instance.
|
||
|
|
done_callback: A callback to call when the task is done.
|
||
|
|
"""
|
||
|
|
if self._task is not None:
|
||
|
|
return
|
||
|
|
self.state = WorkerState.RUNNING
|
||
|
|
self._task = asyncio.create_task(self._run(app))
|
||
|
|
|
||
|
|
def task_done_callback(_task: asyncio.Task) -> None:
|
||
|
|
"""Run the callback.
|
||
|
|
|
||
|
|
Called by `Task.add_done_callback`.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
The worker's task.
|
||
|
|
"""
|
||
|
|
if done_callback is not None:
|
||
|
|
done_callback(self)
|
||
|
|
|
||
|
|
self._task.add_done_callback(task_done_callback)
|
||
|
|
|
||
|
|
def cancel(self) -> None:
|
||
|
|
"""Cancel the task."""
|
||
|
|
self._cancelled = True
|
||
|
|
if self._task is not None:
|
||
|
|
self._task.cancel()
|
||
|
|
self.cancelled_event.set()
|
||
|
|
|
||
|
|
async def wait(self) -> ResultType:
|
||
|
|
"""Wait for the work to complete.
|
||
|
|
|
||
|
|
Raises:
|
||
|
|
WorkerFailed: If the Worker raised an exception.
|
||
|
|
WorkerCancelled: If the Worker was cancelled before it completed.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
The return value of the work.
|
||
|
|
"""
|
||
|
|
try:
|
||
|
|
if active_worker.get() is self:
|
||
|
|
raise DeadlockError(
|
||
|
|
"Can't call worker.wait from within the worker function!"
|
||
|
|
)
|
||
|
|
except LookupError:
|
||
|
|
# Not in a worker
|
||
|
|
pass
|
||
|
|
|
||
|
|
if self.state == WorkerState.PENDING:
|
||
|
|
raise WorkerError("Worker must be started before calling this method.")
|
||
|
|
if self._task is not None:
|
||
|
|
try:
|
||
|
|
await self._task
|
||
|
|
except asyncio.CancelledError as error:
|
||
|
|
self.state = WorkerState.CANCELLED
|
||
|
|
self._error = error
|
||
|
|
if self.state == WorkerState.ERROR:
|
||
|
|
assert self._error is not None
|
||
|
|
raise WorkerFailed(self._error)
|
||
|
|
elif self.state == WorkerState.CANCELLED:
|
||
|
|
raise WorkerCancelled("Worker was cancelled, and did not complete.")
|
||
|
|
return cast("ResultType", self._result)
|