import json
import logging
import numbers
import platform
import re
import sys
import time
from collections import defaultdict
from dataclasses import asdict, is_dataclass
from datetime import date, datetime, timezone
from decimal import Decimal
from typing import Any, Optional
from uuid import UUID

import distro  # For Linux OS detection
import six
from dateutil.tz import tzlocal, tzutc

log = logging.getLogger("posthog")


def is_naive(dt):
    """Determines if a given datetime.datetime is naive."""
    return dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None


def total_seconds(delta):
    """Determines total seconds with python < 2.7 compat."""
    # http://stackoverflow.com/questions/3694835/python-2-6-5-divide-timedelta-with-timedelta
    return (delta.microseconds + (delta.seconds + delta.days * 24 * 3600) * 1e6) / 1e6


def guess_timezone(dt):
    """Attempts to convert a naive datetime to an aware datetime."""
    if is_naive(dt):
        # Attempt to detect the common datetime.datetime.now() case and treat
        # it as local time; otherwise fall back to UTC.
        delta = datetime.now() - dt
        if total_seconds(delta) < 5:
            # this was created using datetime.datetime.now()
            # so we are in the local timezone
            return dt.replace(tzinfo=tzlocal())
        else:
            # at this point, the best we can do is guess UTC
            return dt.replace(tzinfo=tzutc())

    return dt


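# A minimal sketch of how `guess_timezone` behaves (illustrative only, not
# part of the module): a naive datetime created "just now" is assumed to be
# local time, while an older naive datetime is assumed to be UTC.
#
#     from datetime import datetime, timedelta
#     guess_timezone(datetime.now())                       # tzinfo == tzlocal()
#     guess_timezone(datetime.now() - timedelta(hours=2))  # tzinfo == tzutc()

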
def remove_trailing_slash(host):
    if host.endswith("/"):
        return host[:-1]
    return host


def clean(item):
    if isinstance(item, Decimal):
        return float(item)
    if isinstance(item, UUID):
        return str(item)
    if isinstance(
        item, (six.string_types, bool, numbers.Number, datetime, date, type(None))
    ):
        return item
    if isinstance(item, (set, list, tuple)):
        return _clean_list(item)
    # Pydantic model
    try:
        # v2+
        if hasattr(item, "model_dump") and callable(item.model_dump):
            item = item.model_dump()
        # v1
        elif hasattr(item, "dict") and callable(item.dict):
            item = item.dict()
    except TypeError as e:
        log.debug(f"Could not serialize Pydantic-like model: {e}")
    if isinstance(item, dict):
        return _clean_dict(item)
    if is_dataclass(item) and not isinstance(item, type):
        return _clean_dataclass(item)
    return _coerce_unicode(item)


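# A minimal usage sketch for `clean` (illustrative only): nested containers,
# Decimals, UUIDs, dataclasses, and Pydantic-like models are reduced to
# JSON-serializable primitives.
#
#     from decimal import Decimal
#     from uuid import UUID
#     clean({"price": Decimal("9.99"), "id": UUID(int=1), "tags": {"a", "b"}})
#     # -> {"price": 9.99, "id": "00000000-0000-0000-0000-000000000001",
#     #     "tags": ["a", "b"]}  (set ordering is not guaranteed)

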
def _clean_list(list_):
    return [clean(item) for item in list_]


def _clean_dict(dict_):
    data = {}
    for k, v in six.iteritems(dict_):
        try:
            data[k] = clean(v)
        except TypeError:
            log.warning(
                'Dictionary values must be serializable to JSON: key "%s" with value %s of type %s is unsupported.',
                k,
                v,
                type(v),
            )
    return data


def _clean_dataclass(dataclass_):
    data = asdict(dataclass_)
    data = _clean_dict(data)
    return data


def _coerce_unicode(cmplx: Any) -> Optional[str]:
    """
    In theory, this method is only called after the many isinstance checks
    in `utils.clean` have been carried out.

    When we supported Python 2 it was safe to call `decode` on a `str`,
    but in Python 3 that throws. So we check whether the input is bytes
    and only call `decode` in that case.

    Previously we would always call `decode` on the input, which would
    throw an error; then we would call `decode` on the stringified error,
    which would throw again; and then we would return `None`.

    To avoid a breaking change, we maintain the old behavior: anything
    that did not have `decode` in Python 2 returns `None`.
    """
    item = None
    try:
        if isinstance(cmplx, bytes):
            item = cmplx.decode("utf-8", "strict")
        elif isinstance(cmplx, str):
            item = cmplx
    except Exception as exception:
        item = ":".join(map(str, exception.args))
        log.warning("Error decoding: %s", item)
        return None

    return item


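# Behavior sketch for `_coerce_unicode` (illustrative only):
#
#     _coerce_unicode(b"caf\xc3\xa9")  # -> "café" (bytes are decoded as UTF-8)
#     _coerce_unicode("café")          # -> "café" (strings pass through)
#     _coerce_unicode(object())        # -> None   (anything else is dropped)

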
def is_valid_regex(value) -> bool:
    try:
        re.compile(value)
        return True
    except re.error:
        return False


class SizeLimitedDict(defaultdict):
    """A defaultdict that empties itself wholesale once it reaches max_size."""

    def __init__(self, max_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_size = max_size

    def __setitem__(self, key, value):
        if len(self) >= self.max_size:
            self.clear()

        super().__setitem__(key, value)


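# A minimal usage sketch for `SizeLimitedDict` (illustrative only): once the
# dict holds `max_size` entries, the next write clears it entirely rather
# than evicting a single entry.
#
#     d = SizeLimitedDict(2, int)
#     d["a"], d["b"] = 1, 2
#     d["c"] = 3  # triggers clear(); only "c" remains
#     len(d)      # -> 1

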
class FlagCacheEntry:
    """A cached flag evaluation, tagged with a flag-definition version and timestamp."""

    def __init__(self, flag_result, flag_definition_version, timestamp=None):
        self.flag_result = flag_result
        self.flag_definition_version = flag_definition_version
        self.timestamp = timestamp or time.time()

    def is_valid(self, current_time, ttl, current_flag_version):
        time_valid = (current_time - self.timestamp) < ttl
        version_valid = self.flag_definition_version == current_flag_version
        return time_valid and version_valid

    def is_stale_but_usable(self, current_time, max_stale_age=3600):
        return (current_time - self.timestamp) < max_stale_age


class FlagCache:
    """In-memory, per-distinct-id flag cache with TTL, version checks, and LRU eviction."""

    def __init__(self, max_size=10000, default_ttl=300):
        self.cache = {}  # distinct_id -> {flag_key: FlagCacheEntry}
        self.access_times = {}  # distinct_id -> last_access_time
        self.max_size = max_size
        self.default_ttl = default_ttl

    def get_cached_flag(self, distinct_id, flag_key, current_flag_version):
        current_time = time.time()

        if distinct_id not in self.cache:
            return None

        user_flags = self.cache[distinct_id]
        if flag_key not in user_flags:
            return None

        entry = user_flags[flag_key]
        if entry.is_valid(current_time, self.default_ttl, current_flag_version):
            self.access_times[distinct_id] = current_time
            return entry.flag_result

        return None

    def get_stale_cached_flag(self, distinct_id, flag_key, max_stale_age=3600):
        current_time = time.time()

        if distinct_id not in self.cache:
            return None

        user_flags = self.cache[distinct_id]
        if flag_key not in user_flags:
            return None

        entry = user_flags[flag_key]
        if entry.is_stale_but_usable(current_time, max_stale_age):
            return entry.flag_result

        return None

    def set_cached_flag(
        self, distinct_id, flag_key, flag_result, flag_definition_version
    ):
        current_time = time.time()

        # Evict LRU users if we're at capacity
        if distinct_id not in self.cache and len(self.cache) >= self.max_size:
            self._evict_lru()

        # Initialize user cache if needed
        if distinct_id not in self.cache:
            self.cache[distinct_id] = {}

        # Store the flag result
        self.cache[distinct_id][flag_key] = FlagCacheEntry(
            flag_result, flag_definition_version, current_time
        )
        self.access_times[distinct_id] = current_time

    def invalidate_version(self, old_version):
        users_to_remove = []

        for distinct_id, user_flags in self.cache.items():
            flags_to_remove = []
            for flag_key, entry in user_flags.items():
                if entry.flag_definition_version == old_version:
                    flags_to_remove.append(flag_key)

            # Remove invalidated flags
            for flag_key in flags_to_remove:
                del user_flags[flag_key]

            # Remove user entirely if no flags remain
            if not user_flags:
                users_to_remove.append(distinct_id)

        # Clean up empty users
        for distinct_id in users_to_remove:
            del self.cache[distinct_id]
            if distinct_id in self.access_times:
                del self.access_times[distinct_id]

    def _evict_lru(self):
        if not self.access_times:
            return

        # Remove 20% of least recently used entries
        sorted_users = sorted(self.access_times.items(), key=lambda x: x[1])
        to_remove = max(1, len(sorted_users) // 5)

        for distinct_id, _ in sorted_users[:to_remove]:
            if distinct_id in self.cache:
                del self.cache[distinct_id]
            if distinct_id in self.access_times:
                del self.access_times[distinct_id]

    def clear(self):
        self.cache.clear()
        self.access_times.clear()


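# A minimal usage sketch for `FlagCache` (illustrative only; the flag values
# and version numbers are made up):
#
#     cache = FlagCache(max_size=1000, default_ttl=300)
#     cache.set_cached_flag("user-1", "beta-ui", True, flag_definition_version=7)
#     cache.get_cached_flag("user-1", "beta-ui", current_flag_version=7)  # -> True
#     cache.get_cached_flag("user-1", "beta-ui", current_flag_version=8)  # -> None (version mismatch)
#     cache.get_stale_cached_flag("user-1", "beta-ui")                    # -> True (within stale window)

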
class RedisFlagCache:
    """Redis-backed flag cache with the same interface as FlagCache.

    All Redis errors are swallowed so that caching problems never break
    flag evaluation.
    """

    def __init__(
        self, redis_client, default_ttl=300, stale_ttl=3600, key_prefix="posthog:flags:"
    ):
        self.redis = redis_client
        self.default_ttl = default_ttl
        self.stale_ttl = stale_ttl
        self.key_prefix = key_prefix
        self.version_key = f"{key_prefix}version"

    def _get_cache_key(self, distinct_id, flag_key):
        return f"{self.key_prefix}{distinct_id}:{flag_key}"

    def _serialize_entry(self, flag_result, flag_definition_version, timestamp=None):
        if timestamp is None:
            timestamp = time.time()

        # Use clean to make flag_result JSON-serializable for cross-platform compatibility
        serialized_result = clean(flag_result)

        entry = {
            "flag_result": serialized_result,
            "flag_version": flag_definition_version,
            "timestamp": timestamp,
        }
        return json.dumps(entry)

    def _deserialize_entry(self, data):
        try:
            entry = json.loads(data)
            flag_result = entry["flag_result"]
            return FlagCacheEntry(
                flag_result=flag_result,
                flag_definition_version=entry["flag_version"],
                timestamp=entry["timestamp"],
            )
        except (json.JSONDecodeError, KeyError, ValueError):
            # If deserialization fails, treat as cache miss
            return None

    def get_cached_flag(self, distinct_id, flag_key, current_flag_version):
        try:
            cache_key = self._get_cache_key(distinct_id, flag_key)
            data = self.redis.get(cache_key)

            if data:
                entry = self._deserialize_entry(data)
                if entry and entry.is_valid(
                    time.time(), self.default_ttl, current_flag_version
                ):
                    return entry.flag_result

            return None
        except Exception:
            # Redis error - return None to fall back to normal evaluation
            return None

    def get_stale_cached_flag(self, distinct_id, flag_key, max_stale_age=None):
        try:
            if max_stale_age is None:
                max_stale_age = self.stale_ttl

            cache_key = self._get_cache_key(distinct_id, flag_key)
            data = self.redis.get(cache_key)

            if data:
                entry = self._deserialize_entry(data)
                if entry and entry.is_stale_but_usable(time.time(), max_stale_age):
                    return entry.flag_result

            return None
        except Exception:
            # Redis error - return None
            return None

    def set_cached_flag(
        self, distinct_id, flag_key, flag_result, flag_definition_version
    ):
        try:
            cache_key = self._get_cache_key(distinct_id, flag_key)
            serialized_entry = self._serialize_entry(
                flag_result, flag_definition_version
            )

            # Set with TTL for automatic cleanup (use stale_ttl for total lifetime)
            self.redis.setex(cache_key, self.stale_ttl, serialized_entry)

            # Update the current version
            self.redis.set(self.version_key, flag_definition_version)

        except Exception:
            # Redis error - silently fail, don't break flag evaluation
            pass

    def invalidate_version(self, old_version):
        try:
            # For Redis, we use a simple approach: scan for keys with the old
            # version and delete them. This can be expensive with many keys,
            # but it is necessary for correctness.

            cursor = 0
            pattern = f"{self.key_prefix}*"

            while True:
                cursor, keys = self.redis.scan(cursor, match=pattern, count=100)

                for key in keys:
                    # Keys may be bytes or str depending on the client's
                    # decode_responses setting.
                    key_name = key.decode() if isinstance(key, bytes) else key
                    if key_name == self.version_key:
                        continue

                    try:
                        data = self.redis.get(key)
                        if data:
                            entry_dict = json.loads(data)
                            if entry_dict.get("flag_version") == old_version:
                                self.redis.delete(key)
                    except (json.JSONDecodeError, KeyError):
                        # If we can't parse the entry, delete it to be safe
                        self.redis.delete(key)

                if cursor == 0:
                    break

        except Exception:
            # Redis error - silently fail
            pass

    def clear(self):
        try:
            # Delete all keys matching our pattern
            cursor = 0
            pattern = f"{self.key_prefix}*"

            while True:
                cursor, keys = self.redis.scan(cursor, match=pattern, count=100)
                if keys:
                    self.redis.delete(*keys)
                if cursor == 0:
                    break

        except Exception:
            # Redis error - silently fail
            pass


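# A minimal usage sketch for `RedisFlagCache` (illustrative only; assumes a
# redis-py style client, since only its `get`/`set`/`setex`/`scan`/`delete`
# methods are used above):
#
#     import redis
#     client = redis.Redis(host="localhost", port=6379)
#     cache = RedisFlagCache(client, default_ttl=300, stale_ttl=3600)
#     cache.set_cached_flag("user-1", "beta-ui", True, flag_definition_version=7)
#     cache.get_cached_flag("user-1", "beta-ui", current_flag_version=7)  # -> True

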
def convert_to_datetime_aware(date_obj):
    if date_obj.tzinfo is None:
        date_obj = date_obj.replace(tzinfo=timezone.utc)
    return date_obj


def str_icontains(source, search):
    """
    Check if a string contains another string, ignoring case.

    Args:
        source: The string to search within
        search: The substring to search for

    Returns:
        bool: True if search is a substring of source (case-insensitive), False otherwise

    Examples:
        >>> str_icontains("Hello World", "WORLD")
        True
        >>> str_icontains("Hello World", "python")
        False
    """
    return str(search).casefold() in str(source).casefold()


def str_iequals(value, comparand):
    """
    Check if a string equals another string, ignoring case.

    Args:
        value: The string to compare
        comparand: The string to compare with

    Returns:
        bool: True if value and comparand are equal (case-insensitive), False otherwise

    Examples:
        >>> str_iequals("Hello World", "hello world")
        True
        >>> str_iequals("Hello World", "hello")
        False
    """
    return str(value).casefold() == str(comparand).casefold()


def get_os_info():
    """
    Returns standardized OS name and version information,
    similar to how user-agent parsing works in the JS SDK.
    """
    os_name = ""
    os_version = ""

    platform_name = sys.platform

    if platform_name.startswith("win"):
        os_name = "Windows"
        if hasattr(platform, "win32_ver"):
            win_version = platform.win32_ver()[0]
            if win_version:
                os_version = win_version

    elif platform_name == "darwin":
        os_name = "Mac OS X"
        if hasattr(platform, "mac_ver"):
            mac_version = platform.mac_ver()[0]
            if mac_version:
                os_version = mac_version

    elif platform_name.startswith("linux"):
        os_name = "Linux"
        linux_info = distro.info()
        if linux_info["version"]:
            os_version = linux_info["version"]

    elif platform_name.startswith("freebsd"):
        os_name = "FreeBSD"
        if hasattr(platform, "release"):
            os_version = platform.release()

    else:
        os_name = platform_name
        if hasattr(platform, "release"):
            os_version = platform.release()

    return os_name, os_version


def system_context() -> dict[str, Any]:
    os_name, os_version = get_os_info()

    return {
        "$python_runtime": platform.python_implementation(),
        "$python_version": "%s.%s.%s" % (sys.version_info[:3]),
        "$os": os_name,
        "$os_version": os_version,
    }


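# Example of the shape returned by `system_context` (values vary by machine;
# this sample is illustrative only):
#
#     {
#         "$python_runtime": "CPython",
#         "$python_version": "3.11.4",
#         "$os": "Linux",
#         "$os_version": "22.04",
#     }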