70 lines
1.7 KiB
Python
70 lines
1.7 KiB
Python
import inspect
|
|
import os
|
|
from functools import wraps
|
|
from typing import TYPE_CHECKING, Callable, Optional
|
|
|
|
from literalai.my_types import Environment
|
|
|
|
if TYPE_CHECKING:
|
|
from literalai.client import BaseLiteralClient
|
|
|
|
|
|
class EnvContextManager:
    """Temporarily set the ``LITERAL_ENV`` environment variable to ``env``.

    Can be used as a sync context manager (``with``), an async context manager
    (``async with``), or a decorator for sync or async functions (``__call__``).
    On exit, ``LITERAL_ENV`` is restored to the value it had when the matching
    enter happened. If the variable was unset at that point, it is removed.
    Restoration happens even when the body raises, and the exception still
    propagates.

    NOTE(review): ``os.environ`` is process-global. Concurrent threads or
    asyncio tasks that use different environments will observe each other's
    values.
    """

    def __init__(self, client: "BaseLiteralClient", env: "Environment" = "prod"):
        self.client = client
        self.env = env
        # Value of LITERAL_ENV at construction time. Kept for backward
        # compatibility only; restoration uses the per-enter stack below.
        self.original_env = os.environ.get("LITERAL_ENV", "")
        # One saved value per active __enter__ (None means "was unset").
        # A stack makes nested/recursive use of the same instance restore
        # correctly. Decorated functions reuse one instance for every call.
        self._saved_envs = []

    def __call__(self, func):
        """Return ``func`` wrapped so each call runs inside this context manager."""
        return env_decorator(
            self.client,
            func=func,
            ctx_manager=self,
        )

    def __enter__(self):
        """Switch LITERAL_ENV to ``self.env``, remembering the current value."""
        previous = os.environ.get("LITERAL_ENV")
        # Assign before pushing: if the assignment raises (e.g. a non-str
        # env), no stack entry is left behind without a matching __exit__.
        os.environ["LITERAL_ENV"] = self.env
        self._saved_envs.append(previous)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore LITERAL_ENV to the value saved by the matching ``__enter__``."""
        previous = self._saved_envs.pop()
        if previous is None:
            # The variable did not exist before entering: remove it rather
            # than leaving an empty string behind.
            os.environ.pop("LITERAL_ENV", None)
        else:
            os.environ["LITERAL_ENV"] = previous
        # Implicitly returns None, so exceptions from the body propagate.

    async def __aenter__(self):
        """Async counterpart of ``__enter__``; nothing here needs to await."""
        return self.__enter__()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async counterpart of ``__exit__``."""
        self.__exit__(exc_type, exc_val, exc_tb)
|
|
|
|
|
|
def env_decorator(
    client: "BaseLiteralClient",
    func: Callable,
    env: "Environment" = "prod",
    ctx_manager: Optional["EnvContextManager"] = None,
    **decorator_kwargs,
):
    """Wrap ``func`` so that every call runs inside an environment context manager.

    Args:
        client: Client passed to a newly created ``EnvContextManager`` when
            ``ctx_manager`` is not supplied.
        func: The sync function or coroutine function to wrap.
        env: Environment to activate. Only used when ``ctx_manager`` is None.
        ctx_manager: An existing context manager to enter around each call.
            When given, ``client`` and ``env`` are not used.
        **decorator_kwargs: Accepted for backward compatibility and ignored.
            They used to be forwarded to ``EnvContextManager``, whose
            constructor takes only ``client`` and ``env``. Passing any of them
            therefore raised ``TypeError``.

    Returns:
        A wrapper carrying ``func``'s metadata (via ``functools.wraps``). The
        wrapper is a coroutine function exactly when ``func`` is one.
    """
    if ctx_manager is None:
        ctx_manager = EnvContextManager(client=client, env=env)

    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            # The synchronous protocol is deliberate: EnvContextManager's
            # enter/exit only read and write os.environ and never await.
            with ctx_manager:
                return await func(*args, **kwargs)

        return async_wrapper

    @wraps(func)
    def sync_wrapper(*args, **kwargs):
        with ctx_manager:
            return func(*args, **kwargs)

    return sync_wrapper
|