851 lines
30 KiB
Python
851 lines
30 KiB
Python
|
|
import base64
|
||
|
|
import os
|
||
|
|
import urllib.parse
|
||
|
|
from typing import Dict, List, Optional, Tuple
|
||
|
|
|
||
|
|
import httpx
|
||
|
|
from fastapi import HTTPException
|
||
|
|
|
||
|
|
from chainlit.secret import random_secret
|
||
|
|
from chainlit.user import User
|
||
|
|
|
||
|
|
# Detail message of the HTTP 400 raised by the providers' get_token() methods
# when a token response carries no access token.
ACCESS_TOKEN_MISSING = "Access token missing in the response"
|
||
|
|
|
||
|
|
|
||
|
|
class OAuthProvider:
    """Common interface for the OAuth providers defined in this module.

    Concrete providers fill in the attributes below and override the three
    async hooks (``get_raw_token_response``, ``get_token``,
    ``get_user_info``), each of which raises ``NotImplementedError`` here.
    """

    # Provider name; get_env_prefix() derives the env-var prefix from it.
    id: str
    # Environment variables that must all be non-empty for is_configured().
    env: List[str]
    client_id: str
    client_secret: str
    authorize_url: str
    # Parameters such as scope / response_type; presumably sent along with
    # the request to authorize_url (the consumer is not in this module).
    authorize_params: Dict[str, str]
    # Fallback returned by get_prompt() when no prompt env var is set.
    default_prompt: Optional[str] = None

    def is_configured(self) -> bool:
        """Return True when every variable listed in ``env`` is set and non-empty."""
        # A generator instead of a temporary list: same result, and all()
        # can stop at the first missing variable.
        return all(os.environ.get(name) for name in self.env)

    async def get_raw_token_response(self, code: str, url: str) -> dict:
        """Exchange the authorization ``code`` for the raw token response.

        Subclasses send ``url`` as the ``redirect_uri`` where the provider
        needs it. The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError

    async def get_token(self, code: str, url: str) -> str:
        """Return the access token obtained for ``code``.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError

    async def get_user_info(self, token: str) -> Tuple[Dict[str, str], User]:
        """Return ``(raw provider user payload, User)`` for ``token``.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError

    def get_env_prefix(self) -> str:
        """Return environment prefix, like AZURE_AD."""
        return self.id.replace("-", "_").upper()

    def get_prompt(self) -> Optional[str]:
        """Return OAuth prompt param.

        Lookup order: ``OAUTH_<PREFIX>_PROMPT`` (provider specific), then
        ``OAUTH_PROMPT`` (global), then ``default_prompt``. Empty values are
        treated as unset.
        """
        if prompt := os.environ.get(f"OAUTH_{self.get_env_prefix()}_PROMPT"):
            return prompt

        if prompt := os.environ.get("OAUTH_PROMPT"):
            return prompt

        return self.default_prompt
|
||
|
|
|
||
|
|
|
||
|
|
class GithubOAuthProvider(OAuthProvider):
    """GitHub OAuth provider.

    The three endpoint URLs are resolved once, when the class body runs,
    from ``OAUTH_GITHUB_AUTH_URL``, ``OAUTH_GITHUB_TOKEN_URL`` and
    ``OAUTH_GITHUB_USER_INFO_URL``, defaulting to the public GitHub URLs.
    """

    id = "github"
    env = ["OAUTH_GITHUB_CLIENT_ID", "OAUTH_GITHUB_CLIENT_SECRET"]
    authorize_url = os.environ.get(
        "OAUTH_GITHUB_AUTH_URL", "https://github.com/login/oauth/authorize"
    )
    token_url = os.environ.get(
        "OAUTH_GITHUB_TOKEN_URL", "https://github.com/login/oauth/access_token"
    )
    user_info_url = os.environ.get(
        "OAUTH_GITHUB_USER_INFO_URL", "https://api.github.com/user"
    )

    def __init__(self):
        """Read the client credentials and build the authorize parameters."""
        self.client_id = os.environ.get("OAUTH_GITHUB_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_GITHUB_CLIENT_SECRET")
        self.authorize_params = {"scope": "user:email"}

        configured_prompt = self.get_prompt()
        if configured_prompt:
            self.authorize_params["prompt"] = configured_prompt

    async def get_raw_token_response(self, code: str, url: str) -> Dict[str, List[str]]:
        """POST ``code`` to the token endpoint and return the parsed reply.

        ``url`` is accepted for interface parity but is not sent. The reply
        body is parsed as a query string, so every value is a list of strings.
        Raises ``httpx.HTTPStatusError`` on a non-success status.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
        )
        async with httpx.AsyncClient() as http:
            reply = await http.post(self.token_url, data=form)
            reply.raise_for_status()
            return urllib.parse.parse_qs(reply.text)

    async def get_token(self, code: str, url: str):
        """Return the access token, or raise HTTP 400 if the reply has none."""
        parsed = await self.get_raw_token_response(code, url)
        candidates = parsed.get("access_token", [""])
        access_token = candidates[0]
        if access_token:
            return access_token
        raise HTTPException(status_code=400, detail=ACCESS_TOKEN_MISSING)

    async def get_user_info(self, token: str):
        """Fetch the GitHub profile and e-mail list for ``token``.

        Returns ``(github_user, User)`` where ``github_user`` is the profile
        payload with an added ``"emails"`` entry; the ``User`` identifier is
        the GitHub login.
        """
        auth_headers = {"Authorization": f"token {token}"}
        async with httpx.AsyncClient() as http:
            profile_reply = await http.get(self.user_info_url, headers=auth_headers)
            profile_reply.raise_for_status()
            github_user = profile_reply.json()

            # The e-mail list lives at "<user_info_url>/emails".
            emails_reply = await http.get(
                urllib.parse.urljoin(self.user_info_url + "/", "emails"),
                headers=auth_headers,
            )
            emails_reply.raise_for_status()
            github_user.update(emails=emails_reply.json())

            metadata = {"image": github_user["avatar_url"], "provider": "github"}
            user = User(identifier=github_user["login"], metadata=metadata)
            return (github_user, user)
|
||
|
|
|
||
|
|
|
||
|
|
class GoogleOAuthProvider(OAuthProvider):
    """Google OAuth provider using fixed Google endpoints."""

    id = "google"
    env = ["OAUTH_GOOGLE_CLIENT_ID", "OAUTH_GOOGLE_CLIENT_SECRET"]
    authorize_url = "https://accounts.google.com/o/oauth2/v2/auth"

    def __init__(self):
        """Read the client credentials and build the authorize parameters."""
        self.client_id = os.environ.get("OAUTH_GOOGLE_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_GOOGLE_CLIENT_SECRET")
        self.authorize_params = dict(
            scope="https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email",
            response_type="code",
            access_type="offline",
        )

        configured_prompt = self.get_prompt()
        if configured_prompt:
            self.authorize_params["prompt"] = configured_prompt

    async def get_raw_token_response(self, code: str, url: str) -> dict:
        """Exchange ``code`` for Google's JSON token response.

        ``url`` is sent as the ``redirect_uri``. Raises
        ``httpx.HTTPStatusError`` on a non-success status.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as http:
            reply = await http.post("https://oauth2.googleapis.com/token", data=form)
            reply.raise_for_status()
            return reply.json()

    async def get_token(self, code: str, url: str):
        """Return the access token, or raise HTTP 400 if the reply has none."""
        token_response = await self.get_raw_token_response(code, url)
        access_token = token_response.get("access_token")
        if access_token:
            return access_token
        raise HTTPException(status_code=400, detail=ACCESS_TOKEN_MISSING)

    async def get_user_info(self, token: str):
        """Return ``(google_user, User)``; the identifier is the e-mail."""
        auth_headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient() as http:
            reply = await http.get(
                "https://www.googleapis.com/userinfo/v2/me", headers=auth_headers
            )
            reply.raise_for_status()
            google_user = reply.json()

            metadata = {"image": google_user["picture"], "provider": "google"}
            user = User(identifier=google_user["email"], metadata=metadata)
            return (google_user, user)
|
||
|
|
|
||
|
|
|
||
|
|
class AzureADOAuthProvider(OAuthProvider):
    """Azure AD (Microsoft identity platform) OAuth provider.

    ``authorize_url`` and ``token_url`` are resolved once, when the class
    body runs: the tenant-specific endpoints when
    ``OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT`` is set, the ``common`` endpoints
    otherwise.
    """

    id = "azure-ad"
    env = [
        "OAUTH_AZURE_AD_CLIENT_ID",
        "OAUTH_AZURE_AD_CLIENT_SECRET",
        "OAUTH_AZURE_AD_TENANT_ID",
    ]
    authorize_url = (
        f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/authorize"
        if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT")
        else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    )
    token_url = (
        f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/token"
        if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT")
        else "https://login.microsoftonline.com/common/oauth2/v2.0/token"
    )

    def __init__(self):
        """Read the client credentials and build the authorize parameters."""
        self.client_id = os.environ.get("OAUTH_AZURE_AD_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_AZURE_AD_CLIENT_SECRET")
        self.authorize_params = {
            "tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"),
            "response_type": "code",
            "scope": "https://graph.microsoft.com/User.Read offline_access",
            "response_mode": "query",
        }

        if prompt := self.get_prompt():
            self.authorize_params["prompt"] = prompt

    async def get_raw_token_response(self, code: str, url: str) -> dict:
        """Exchange ``code`` for the JSON token response.

        ``url`` is sent as the ``redirect_uri``. Raises
        ``httpx.HTTPStatusError`` on a non-success status.
        """
        payload = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.token_url,
                data=payload,
            )
            response.raise_for_status()
            return response.json()

    async def get_token(self, code: str, url: str):
        """Return the access token and remember the refresh token.

        Raises ``HTTPException`` (400) when the response has no access token.
        """
        token_response = await self.get_raw_token_response(code, url)

        # .get() rather than indexing: a response without the key must reach
        # the 400 below instead of escaping as a KeyError.
        token = token_response.get("access_token")
        if not token:
            raise HTTPException(status_code=400, detail=ACCESS_TOKEN_MISSING)

        # NOTE(review): kept on the provider instance and read back in
        # get_user_info(); the module creates one instance per provider, so
        # overlapping logins could overwrite each other's value.
        self._refresh_token = token_response.get("refresh_token")
        return token

    async def get_user_info(self, token: str):
        """Return ``(azure_user, User)`` for ``token``.

        The profile comes from Microsoft Graph ``/me``. The 48x48 photo is
        fetched best-effort and, when available, stored as a base64 data URI
        under ``azure_user["image"]``. The ``User`` identifier is
        ``userPrincipalName``.
        """
        headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://graph.microsoft.com/v1.0/me",
                headers=headers,
            )
            response.raise_for_status()
            azure_user = response.json()

            try:
                photo_response = await client.get(
                    "https://graph.microsoft.com/v1.0/me/photos/48x48/$value",
                    headers=headers,
                )
                # Without this check a non-2xx reply (for example when no
                # photo is available) would be base64-encoded and stored as
                # if its error body were the picture.
                photo_response.raise_for_status()
                photo_data = await photo_response.aread()
                base64_image = base64.b64encode(photo_data)
                azure_user["image"] = (
                    f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}"
                )
            except Exception:
                # Deliberately best-effort: a missing or failing photo must
                # not block the login, so the user simply gets no image.
                pass

            user = User(
                identifier=azure_user["userPrincipalName"],
                metadata={
                    "image": azure_user.get("image"),
                    "provider": "azure-ad",
                    "refresh_token": getattr(self, "_refresh_token", None),
                },
            )
            return (azure_user, user)
|
||
|
|
|
||
|
|
|
||
|
|
class AzureADHybridOAuthProvider(OAuthProvider):
    """Azure AD provider using the hybrid flow (``code id_token``, form_post).

    ``authorize_url`` and ``token_url`` are resolved once, when the class
    body runs: the tenant-specific endpoints when
    ``OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT`` is set, the ``common``
    endpoints otherwise.
    """

    id = "azure-ad-hybrid"
    env = [
        "OAUTH_AZURE_AD_HYBRID_CLIENT_ID",
        "OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET",
        "OAUTH_AZURE_AD_HYBRID_TENANT_ID",
    ]
    authorize_url = (
        f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/authorize"
        if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT")
        else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
    )
    token_url = (
        f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_HYBRID_TENANT_ID', '')}/oauth2/v2.0/token"
        if os.environ.get("OAUTH_AZURE_AD_HYBRID_ENABLE_SINGLE_TENANT")
        else "https://login.microsoftonline.com/common/oauth2/v2.0/token"
    )

    def __init__(self):
        """Read the client credentials and build the authorize parameters."""
        self.client_id = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_AZURE_AD_HYBRID_CLIENT_SECRET")
        # NOTE(review): generated once per instance, so every authorization
        # request built from these params carries the same nonce.
        nonce = random_secret(16)
        self.authorize_params = {
            "tenant": os.environ.get("OAUTH_AZURE_AD_HYBRID_TENANT_ID"),
            "response_type": "code id_token",
            "scope": "https://graph.microsoft.com/User.Read https://graph.microsoft.com/openid offline_access",
            "response_mode": "form_post",
            "nonce": nonce,
        }

        if prompt := self.get_prompt():
            self.authorize_params["prompt"] = prompt

    async def get_raw_token_response(self, code: str, url: str) -> dict:
        """Exchange ``code`` for the JSON token response.

        ``url`` is sent as the ``redirect_uri``. Raises
        ``httpx.HTTPStatusError`` on a non-success status.
        """
        payload = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": url,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.token_url,
                data=payload,
            )
            response.raise_for_status()
            return response.json()

    async def get_token(self, code: str, url: str):
        """Return the access token and remember the refresh token.

        Raises ``HTTPException`` (400) when the response has no access token.
        """
        token_response = await self.get_raw_token_response(code, url)

        # .get() rather than indexing: a response without the key must reach
        # the 400 below instead of escaping as a KeyError.
        token = token_response.get("access_token")
        if not token:
            raise HTTPException(status_code=400, detail=ACCESS_TOKEN_MISSING)

        # NOTE(review): kept on the provider instance and read back in
        # get_user_info(); the module creates one instance per provider, so
        # overlapping logins could overwrite each other's value.
        self._refresh_token = token_response.get("refresh_token")
        return token

    async def get_user_info(self, token: str):
        """Return ``(azure_user, User)`` for ``token``.

        The profile comes from Microsoft Graph ``/me``. The 48x48 photo is
        fetched best-effort and, when available, stored as a base64 data URI
        under ``azure_user["image"]``. The ``User`` identifier is
        ``userPrincipalName``.
        """
        headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://graph.microsoft.com/v1.0/me",
                headers=headers,
            )
            response.raise_for_status()
            azure_user = response.json()

            try:
                photo_response = await client.get(
                    "https://graph.microsoft.com/v1.0/me/photos/48x48/$value",
                    headers=headers,
                )
                # Without this check a non-2xx reply (for example when no
                # photo is available) would be base64-encoded and stored as
                # if its error body were the picture.
                photo_response.raise_for_status()
                photo_data = await photo_response.aread()
                base64_image = base64.b64encode(photo_data)
                azure_user["image"] = (
                    f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}"
                )
            except Exception:
                # Deliberately best-effort: a missing or failing photo must
                # not block the login, so the user simply gets no image.
                pass

            user = User(
                identifier=azure_user["userPrincipalName"],
                metadata={
                    "image": azure_user.get("image"),
                    # NOTE(review): reports "azure-ad", not this class's id
                    # ("azure-ad-hybrid"); kept as-is — confirm it is intended.
                    "provider": "azure-ad",
                    "refresh_token": getattr(self, "_refresh_token", None),
                },
            )
            return (azure_user, user)
|
||
|
|
|
||
|
|
|
||
|
|
class OktaOAuthProvider(OAuthProvider):
    """Okta OAuth provider.

    All endpoints are built as
    ``<domain>/oauth2<authorization server path>/v1/<endpoint>``.
    """

    id = "okta"
    env = [
        "OAUTH_OKTA_CLIENT_ID",
        "OAUTH_OKTA_CLIENT_SECRET",
        "OAUTH_OKTA_DOMAIN",
    ]
    # Resolved once when the class body runs; a trailing slash in the
    # configured domain is dropped so paths can be appended directly.
    domain = f"https://{os.environ.get('OAUTH_OKTA_DOMAIN', '').rstrip('/')}"

    def __init__(self):
        """Read credentials and server id, then build URL and parameters."""
        self.client_id = os.environ.get("OAUTH_OKTA_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_OKTA_CLIENT_SECRET")
        self.authorization_server_id = os.environ.get(
            "OAUTH_OKTA_AUTHORIZATION_SERVER_ID", ""
        )
        server_path = self.get_authorization_server_path()
        self.authorize_url = f"{self.domain}/oauth2{server_path}/v1/authorize"
        self.authorize_params = dict(
            response_type="code",
            scope="openid profile email",
            response_mode="query",
        )

        configured_prompt = self.get_prompt()
        if configured_prompt:
            self.authorize_params["prompt"] = configured_prompt

    def get_authorization_server_path(self):
        """Return the URL path segment for the authorization server.

        ``"/default"`` when no id is configured, ``""`` when the id is the
        literal string ``"false"``, otherwise ``"/<id>"``.
        """
        server_id = self.authorization_server_id
        if server_id == "false":
            return ""
        return f"/{server_id}" if server_id else "/default"

    async def get_raw_token_response(self, code: str, url: str) -> dict:
        """Exchange ``code`` for the JSON token response.

        ``url`` is sent as the ``redirect_uri``. Raises
        ``httpx.HTTPStatusError`` on a non-success status.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as http:
            token_endpoint = (
                f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/token"
            )
            reply = await http.post(token_endpoint, data=form)
            reply.raise_for_status()
            return reply.json()

    async def get_token(self, code: str, url: str):
        """Return the access token, or raise HTTP 400 if the reply has none."""
        token_response = await self.get_raw_token_response(code, url)
        access_token = token_response.get("access_token")
        if access_token:
            return access_token
        raise HTTPException(status_code=400, detail=ACCESS_TOKEN_MISSING)

    async def get_user_info(self, token: str):
        """Return ``(okta_user, User)``; the identifier is the e-mail claim."""
        async with httpx.AsyncClient() as http:
            userinfo_endpoint = (
                f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/userinfo"
            )
            reply = await http.get(
                userinfo_endpoint, headers={"Authorization": f"Bearer {token}"}
            )
            reply.raise_for_status()
            okta_user = reply.json()

            metadata = {"image": "", "provider": "okta"}
            user = User(identifier=okta_user.get("email"), metadata=metadata)
            return (okta_user, user)
|
||
|
|
|
||
|
|
|
||
|
|
class Auth0OAuthProvider(OAuthProvider):
    """Auth0 OAuth provider.

    ``domain`` serves the authorize and token endpoints. ``original_domain``
    (from ``OAUTH_AUTH0_ORIGINAL_DOMAIN``, defaulting to ``domain``) serves
    the userinfo endpoint and the ``audience`` parameter.
    """

    id = "auth0"
    env = ["OAUTH_AUTH0_CLIENT_ID", "OAUTH_AUTH0_CLIENT_SECRET", "OAUTH_AUTH0_DOMAIN"]

    def __init__(self):
        """Read credentials and domains, then build URL and parameters."""
        self.client_id = os.environ.get("OAUTH_AUTH0_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_AUTH0_CLIENT_SECRET")

        # Trailing slashes are dropped so paths can be appended directly.
        configured_domain = os.environ.get("OAUTH_AUTH0_DOMAIN", "").rstrip("/")
        self.domain = f"https://{configured_domain}"

        original = os.environ.get("OAUTH_AUTH0_ORIGINAL_DOMAIN")
        if original:
            self.original_domain = f"https://{original.rstrip('/')}"
        else:
            self.original_domain = self.domain

        self.authorize_url = f"{self.domain}/authorize"
        self.authorize_params = dict(
            response_type="code",
            scope="openid profile email",
            audience=f"{self.original_domain}/userinfo",
        )

        configured_prompt = self.get_prompt()
        if configured_prompt:
            self.authorize_params["prompt"] = configured_prompt

    async def get_raw_token_response(self, code: str, url: str) -> dict:
        """Exchange ``code`` for the JSON token response.

        ``url`` is sent as the ``redirect_uri``. Raises
        ``httpx.HTTPStatusError`` on a non-success status.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as http:
            reply = await http.post(f"{self.domain}/oauth/token", data=form)
            reply.raise_for_status()
            return reply.json()

    async def get_token(self, code: str, url: str):
        """Return the access token, or raise HTTP 400 if the reply has none."""
        token_response = await self.get_raw_token_response(code, url)
        access_token = token_response.get("access_token")
        if access_token:
            return access_token
        raise HTTPException(status_code=400, detail=ACCESS_TOKEN_MISSING)

    async def get_user_info(self, token: str):
        """Return ``(auth0_user, User)``; the identifier is the e-mail claim."""
        async with httpx.AsyncClient() as http:
            reply = await http.get(
                f"{self.original_domain}/userinfo",
                headers={"Authorization": f"Bearer {token}"},
            )
            reply.raise_for_status()
            auth0_user = reply.json()

            metadata = {
                "image": auth0_user.get("picture", ""),
                "provider": "auth0",
            }
            user = User(identifier=auth0_user.get("email"), metadata=metadata)
            return (auth0_user, user)
|
||
|
|
|
||
|
|
|
||
|
|
class DescopeOAuthProvider(OAuthProvider):
    """Descope OAuth provider using the fixed Descope OAuth2 endpoints."""

    id = "descope"
    env = ["OAUTH_DESCOPE_CLIENT_ID", "OAUTH_DESCOPE_CLIENT_SECRET"]
    # Fixed base URL, written without a trailing slash so endpoint paths can
    # be appended directly.
    domain = "https://api.descope.com/oauth2/v1"

    authorize_url = f"{domain}/authorize"

    def __init__(self):
        """Read the client credentials and build the authorize parameters."""
        self.client_id = os.environ.get("OAUTH_DESCOPE_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_DESCOPE_CLIENT_SECRET")
        self.authorize_params = dict(
            response_type="code",
            scope="openid profile email",
            audience=f"{self.domain}/userinfo",
        )

        configured_prompt = self.get_prompt()
        if configured_prompt:
            self.authorize_params["prompt"] = configured_prompt

    async def get_raw_token_response(self, code: str, url: str) -> dict:
        """Exchange ``code`` for the JSON token response.

        ``url`` is sent as the ``redirect_uri``. Raises
        ``httpx.HTTPStatusError`` on a non-success status.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as http:
            reply = await http.post(f"{self.domain}/token", data=form)
            reply.raise_for_status()
            return reply.json()

    async def get_token(self, code: str, url: str):
        """Return the access token, or raise HTTP 400 if the reply has none."""
        token_response = await self.get_raw_token_response(code, url)
        access_token = token_response.get("access_token")
        if access_token:
            return access_token
        raise HTTPException(status_code=400, detail=ACCESS_TOKEN_MISSING)

    async def get_user_info(self, token: str):
        """Return ``(descope_user, User)``; the identifier is the e-mail claim."""
        auth_headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient() as http:
            reply = await http.get(f"{self.domain}/userinfo", headers=auth_headers)
            # Raises for 4xx/5xx replies.
            reply.raise_for_status()
            descope_user = reply.json()

            metadata = {"image": "", "provider": "descope"}
            user = User(identifier=descope_user.get("email"), metadata=metadata)
            return (descope_user, user)
|
||
|
|
|
||
|
|
|
||
|
|
class AWSCognitoOAuthProvider(OAuthProvider):
    """AWS Cognito OAuth provider.

    ``authorize_url`` and ``token_url`` are built from
    ``OAUTH_COGNITO_DOMAIN`` once, when the class body runs; the user-info
    URL is built from the same variable on each ``get_user_info`` call.
    """

    id = "aws-cognito"
    env = [
        "OAUTH_COGNITO_CLIENT_ID",
        "OAUTH_COGNITO_CLIENT_SECRET",
        "OAUTH_COGNITO_DOMAIN",
    ]
    authorize_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/login"
    token_url = f"https://{os.environ.get('OAUTH_COGNITO_DOMAIN')}/oauth2/token"

    def __init__(self):
        """Read credentials and scopes, then build the authorize parameters."""
        self.client_id = os.environ.get("OAUTH_COGNITO_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_COGNITO_CLIENT_SECRET")
        self.scopes = os.environ.get("OAUTH_COGNITO_SCOPE", "openid profile email")
        self.authorize_params = dict(
            response_type="code",
            client_id=self.client_id,
            scope=self.scopes,
        )

        configured_prompt = self.get_prompt()
        if configured_prompt:
            self.authorize_params["prompt"] = configured_prompt

    async def get_raw_token_response(self, code: str, url: str) -> dict:
        """Exchange ``code`` for the JSON token response.

        ``url`` is sent as the ``redirect_uri``. Raises
        ``httpx.HTTPStatusError`` on a non-success status.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as http:
            reply = await http.post(self.token_url, data=form)
            reply.raise_for_status()
            return reply.json()

    async def get_token(self, code: str, url: str):
        """Return the access token, or raise HTTP 400 if the reply has none."""
        token_response = await self.get_raw_token_response(code, url)
        access_token = token_response.get("access_token")
        if access_token:
            return access_token
        raise HTTPException(status_code=400, detail=ACCESS_TOKEN_MISSING)

    async def get_user_info(self, token: str):
        """Return ``(cognito_user, User)``; the identifier is the e-mail."""
        cognito_domain = os.environ.get("OAUTH_COGNITO_DOMAIN")
        user_info_url = f"https://{cognito_domain}/oauth2/userInfo"
        auth_headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient() as http:
            reply = await http.get(user_info_url, headers=auth_headers)
            reply.raise_for_status()
            cognito_user = reply.json()

            # Customize user metadata as needed
            metadata = {
                "image": cognito_user.get("picture", ""),
                "provider": "aws-cognito",
            }
            user = User(identifier=cognito_user["email"], metadata=metadata)
            return (cognito_user, user)
|
||
|
|
|
||
|
|
|
||
|
|
class GitlabOAuthProvider(OAuthProvider):
    """GitLab OAuth provider for the instance named by ``OAUTH_GITLAB_DOMAIN``."""

    id = "gitlab"
    env = [
        "OAUTH_GITLAB_CLIENT_ID",
        "OAUTH_GITLAB_CLIENT_SECRET",
        "OAUTH_GITLAB_DOMAIN",
    ]

    def __init__(self):
        """Read credentials and domain, then build URL and parameters."""
        self.client_id = os.environ.get("OAUTH_GITLAB_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_GITLAB_CLIENT_SECRET")

        # A trailing slash is dropped so paths can be appended directly.
        configured_domain = os.environ.get("OAUTH_GITLAB_DOMAIN", "").rstrip("/")
        self.domain = f"https://{configured_domain}"

        self.authorize_url = f"{self.domain}/oauth/authorize"
        self.authorize_params = dict(
            scope="openid profile email",
            response_type="code",
        )

        configured_prompt = self.get_prompt()
        if configured_prompt:
            self.authorize_params["prompt"] = configured_prompt

    async def get_raw_token_response(self, code: str, url: str) -> dict:
        """Exchange ``code`` for the JSON token response.

        ``url`` is sent as the ``redirect_uri``. Raises
        ``httpx.HTTPStatusError`` on a non-success status.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as http:
            reply = await http.post(f"{self.domain}/oauth/token", data=form)
            reply.raise_for_status()
            return reply.json()

    async def get_token(self, code: str, url: str):
        """Return the access token, or raise HTTP 400 if the reply has none."""
        token_response = await self.get_raw_token_response(code, url)
        access_token = token_response.get("access_token")
        if access_token:
            return access_token
        raise HTTPException(status_code=400, detail=ACCESS_TOKEN_MISSING)

    async def get_user_info(self, token: str):
        """Return ``(gitlab_user, User)``; the identifier is the e-mail claim."""
        auth_headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient() as http:
            reply = await http.get(
                f"{self.domain}/oauth/userinfo", headers=auth_headers
            )
            reply.raise_for_status()
            gitlab_user = reply.json()

            metadata = {
                "image": gitlab_user.get("picture", ""),
                "provider": "gitlab",
            }
            user = User(identifier=gitlab_user.get("email"), metadata=metadata)
            return (gitlab_user, user)
|
||
|
|
|
||
|
|
|
||
|
|
class KeycloakOAuthProvider(OAuthProvider):
    """Keycloak OAuth provider.

    Endpoints live under
    ``<base_url>/realms/<realm>/protocol/openid-connect/``. The provider id
    is read from ``OAUTH_KEYCLOAK_NAME`` (default ``"keycloak"``) when the
    class body runs.
    """

    env = [
        "OAUTH_KEYCLOAK_CLIENT_ID",
        "OAUTH_KEYCLOAK_CLIENT_SECRET",
        "OAUTH_KEYCLOAK_REALM",
        "OAUTH_KEYCLOAK_BASE_URL",
    ]
    id = os.environ.get("OAUTH_KEYCLOAK_NAME", "keycloak")

    def __init__(self):
        """Read credentials, realm and base URL; build URL and parameters."""
        self.refresh_token = None
        self.client_id = os.environ.get("OAUTH_KEYCLOAK_CLIENT_ID")
        self.client_secret = os.environ.get("OAUTH_KEYCLOAK_CLIENT_SECRET")
        self.realm = os.environ.get("OAUTH_KEYCLOAK_REALM")
        self.base_url = os.environ.get("OAUTH_KEYCLOAK_BASE_URL")
        oidc_root = f"{self.base_url}/realms/{self.realm}/protocol/openid-connect"
        self.authorize_url = f"{oidc_root}/auth"

        self.authorize_params = dict(
            scope="profile email openid",
            response_type="code",
        )

        configured_prompt = self.get_prompt()
        if configured_prompt:
            self.authorize_params["prompt"] = configured_prompt

    async def get_raw_token_response(self, code: str, url: str) -> dict:
        """Exchange ``code`` for the JSON token response.

        ``url`` is sent as the ``redirect_uri``. Raises
        ``httpx.HTTPStatusError`` on a non-success status.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as http:
            oidc_root = f"{self.base_url}/realms/{self.realm}/protocol/openid-connect"
            reply = await http.post(f"{oidc_root}/token", data=form)
            reply.raise_for_status()
            return reply.json()

    async def get_token(self, code: str, url: str):
        """Return the access token and remember the refresh token.

        Raises ``HTTPException`` (400) when the response has no access token;
        in that case ``self.refresh_token`` is left untouched.
        """
        token_response = await self.get_raw_token_response(code, url)
        access_token = token_response.get("access_token")
        if not access_token:
            raise HTTPException(status_code=400, detail=ACCESS_TOKEN_MISSING)
        self.refresh_token = token_response.get("refresh_token")
        return access_token

    async def get_user_info(self, token: str):
        """Return ``(kc_user, User)``; the identifier is the e-mail claim."""
        auth_headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient() as http:
            oidc_root = f"{self.base_url}/realms/{self.realm}/protocol/openid-connect"
            reply = await http.get(f"{oidc_root}/userinfo", headers=auth_headers)
            reply.raise_for_status()
            kc_user = reply.json()

            user = User(identifier=kc_user["email"], metadata={"provider": "keycloak"})
            return (kc_user, user)
|
||
|
|
|
||
|
|
|
||
|
|
class GenericOAuthProvider(OAuthProvider):
    """OAuth provider configured entirely through ``OAUTH_GENERIC_*`` variables.

    The provider id is read from ``OAUTH_GENERIC_NAME`` (default
    ``"generic"``) when the class body runs.
    """

    env = [
        "OAUTH_GENERIC_CLIENT_ID",
        "OAUTH_GENERIC_CLIENT_SECRET",
        "OAUTH_GENERIC_AUTH_URL",
        "OAUTH_GENERIC_TOKEN_URL",
        "OAUTH_GENERIC_USER_INFO_URL",
        "OAUTH_GENERIC_SCOPES",
    ]
    id = os.environ.get("OAUTH_GENERIC_NAME", "generic")

    def __init__(self):
        """Read every setting from the environment and build the parameters."""
        read = os.environ.get
        self.client_id = read("OAUTH_GENERIC_CLIENT_ID")
        self.client_secret = read("OAUTH_GENERIC_CLIENT_SECRET")
        self.authorize_url = read("OAUTH_GENERIC_AUTH_URL")
        self.token_url = read("OAUTH_GENERIC_TOKEN_URL")
        self.user_info_url = read("OAUTH_GENERIC_USER_INFO_URL")
        self.scopes = read("OAUTH_GENERIC_SCOPES")
        # Key of the user-info payload used as the User identifier.
        self.user_identifier = read("OAUTH_GENERIC_USER_IDENTIFIER", "email")

        self.authorize_params = dict(scope=self.scopes, response_type="code")

        configured_prompt = self.get_prompt()
        if configured_prompt:
            self.authorize_params["prompt"] = configured_prompt

    async def get_raw_token_response(self, code: str, url: str) -> dict:
        """Exchange ``code`` for the JSON token response.

        ``url`` is sent as the ``redirect_uri``. Raises
        ``httpx.HTTPStatusError`` on a non-success status.
        """
        form = dict(
            client_id=self.client_id,
            client_secret=self.client_secret,
            code=code,
            grant_type="authorization_code",
            redirect_uri=url,
        )
        async with httpx.AsyncClient() as http:
            reply = await http.post(self.token_url, data=form)
            reply.raise_for_status()
            return reply.json()

    async def get_token(self, code: str, url: str) -> str:
        """Return the access token, or raise HTTP 400 if the reply has none."""
        token_response = await self.get_raw_token_response(code, url)
        access_token = token_response.get("access_token")
        if access_token:
            return access_token
        raise HTTPException(status_code=400, detail=ACCESS_TOKEN_MISSING)

    async def get_user_info(self, token: str):
        """Return ``(server_user, User)`` for ``token``.

        The identifier is the payload value stored under
        ``self.user_identifier``.
        """
        auth_headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient() as http:
            reply = await http.get(self.user_info_url, headers=auth_headers)
            reply.raise_for_status()
            server_user = reply.json()

            user = User(
                identifier=server_user.get(self.user_identifier),
                metadata={"provider": self.id},
            )
            return (server_user, user)
|
||
|
|
|
||
|
|
|
||
|
|
# One shared instance of every supported provider, created at import time in
# the order the lookups below iterate them.
providers = [
    provider_class()
    for provider_class in (
        GithubOAuthProvider,
        GoogleOAuthProvider,
        AzureADOAuthProvider,
        AzureADHybridOAuthProvider,
        OktaOAuthProvider,
        Auth0OAuthProvider,
        DescopeOAuthProvider,
        AWSCognitoOAuthProvider,
        GitlabOAuthProvider,
        KeycloakOAuthProvider,
        GenericOAuthProvider,
    )
]
|
||
|
|
|
||
|
|
|
||
|
|
def get_oauth_provider(provider: str) -> Optional[OAuthProvider]:
    """Return the first registered provider whose ``id`` equals ``provider``.

    Returns None when no entry of ``providers`` matches.
    """
    matches = (candidate for candidate in providers if candidate.id == provider)
    return next(matches, None)
|
||
|
|
|
||
|
|
|
||
|
|
def get_configured_oauth_providers():
    """Return the ids of all registered providers that report being configured.

    The ids keep the order of the ``providers`` list.
    """
    configured_ids = []
    for candidate in providers:
        if candidate.is_configured():
            configured_ids.append(candidate.id)
    return configured_ids
|