# live_price_updater.py

import json
import time
import random
import string
import threading
from pathlib import Path
from typing import Dict, List, Set
from queue import Queue, Empty

import redis
from websocket import WebSocketApp


# ===============================
# CONFIG
# ===============================

WS_URL = "wss://data.tradingview.com/socket.io/websocket"

EXPORT_DIR = Path("data/exports")
SYMBOLS_FILE = EXPORT_DIR / "us_symbols_full.json"

# Optional TradingView auth token (empty -> anonymous placeholder token)
TV_AUTH_TOKEN = ""

# Redis
REDIS_HOST = "127.0.0.1"
REDIS_PORT = 6379
REDIS_DB = 0
REDIS_PASSWORD = None

# Redis keys
REDIS_PRICES_HASH = "tv:us:prices"
REDIS_UPDATED_AT_KEY = "tv:us:prices:updated_at"
REDIS_SYMBOLS_TOTAL_KEY = "tv:us:symbols_total"

# WS batching
SESSIONS_PER_SOCKET = 1
SYMBOLS_PER_SESSION = 10
# Symbols assigned to each socket at startup (1 * 10 = 10 with the values above).
SOCKET_SYMBOL_CAP = SESSIONS_PER_SOCKET * SYMBOLS_PER_SESSION

# Timing
SOCKET_STAGGER_SEC = 120.0    # pause between starting consecutive sockets
SESSION_CREATE_DELAY = 0.40   # pause after creating a quote session, before adding symbols
SYMBOL_DELAY = 0.06           # pause between individual symbol subscriptions

# Redis flush
REDIS_FLUSH_INTERVAL_SEC = 1.0
REDIS_BATCH_SIZE = 1000

# Health / retry
# How often health_monitor_loop scans symbol and socket state. It must be
# defined: the health thread references it and would otherwise die with a
# NameError. Kept well below SOCKET_STALE_SEC / SYMBOL_STALE_SEC.
HEALTH_CHECK_INTERVAL_SEC = 30.0
SYMBOL_STALE_SEC = 300
SOCKET_STALE_SEC = 180
RETRY_BATCH_SIZE = 10
RETRY_REQUEUE_DELAY_SEC = 90
MAX_FAILED_REASONS_PER_SYMBOL = 5

# Reconnect (random delay range between reconnect attempts)
RECONNECT_MIN_SEC = 600
RECONNECT_MAX_SEC = 1200

# 429 circuit breaker (cooldown grows linearly per consecutive 429, capped)
BASE_429_COOLDOWN_SEC = 300
MAX_429_COOLDOWN_SEC = 1800

# Testing (0 disables the limit; otherwise only the first N symbols are used)
TEST_SYMBOL_LIMIT = 0


# ===============================
# GLOBAL STATE
# ===============================

# Deduplicated tv_symbol list; load_symbols() rebinds it wholesale.
symbols: List[str] = []

# Latest cached prices in memory (symbol -> price string from safe_float_to_str)
prices: Dict[str, str] = {}

# Pending Redis updates (symbol -> price string), drained by flush_prices_to_redis
redis_dirty: Dict[str, str] = {}

# Symbol metadata
symbol_last_update: Dict[str, float] = {}  # symbol -> now_ts() of the last received lp
symbol_to_socket: Dict[str, int] = {}  # symbol -> socket_id that most recently subscribed it
# Status values written in this file: "pending", "active", "retry", "stale".
# "failed" is checked in several places but not assigned anywhere in this file.
symbol_status: Dict[str, str] = {}
# symbol -> most recent failure reasons, capped at MAX_FAILED_REASONS_PER_SYMBOL
symbol_fail_reasons: Dict[str, List[str]] = {}

# Socket metadata kept only in memory (socket_id -> diagnostics dict, keys set in TVSocket.__init__)
socket_health: Dict[int, dict] = {}

# Retry queue; retry_pending mirrors the queue contents so enqueue_retry can dedupe
retry_queue: "Queue[str]" = Queue()
retry_pending: Set[str] = set()

# Global locking / shutdown.
# RLock (re-entrant) because enqueue_retry calls add_symbol_failure while already holding it.
lock = threading.RLock()
stop_event = threading.Event()

# Handshake / rate-limit circuit breaker
handshake_gate = threading.Lock()  # serializes websocket connection attempts in TVSocket.run
global_cooldown_until: float = 0.0  # epoch seconds before which no socket should (re)connect
consecutive_429: int = 0  # reset to 0 by TVSocket.on_open

# Redis client (decode_responses=True: replies come back as str, not bytes)
redis_client = redis.Redis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    db=REDIS_DB,
    password=REDIS_PASSWORD,
    decode_responses=True,
    socket_timeout=5,
    socket_connect_timeout=5,
    health_check_interval=30,
)


# ===============================
# HELPERS
# ===============================

def now_ts() -> float:
    """Return the current wall-clock time in seconds since the Unix epoch."""
    clock = time.time
    return clock()


def random_session(prefix: str) -> str:
    """Build a session id of the form ``<prefix>_<12 random [a-z0-9] chars>``."""
    alphabet = string.ascii_lowercase + string.digits
    suffix = "".join(random.choices(alphabet, k=12))
    return f"{prefix}_{suffix}"


def chunked(seq, size):
    """Lazily yield consecutive slices of *seq* holding at most *size* items each."""
    yield from (seq[start:start + size] for start in range(0, len(seq), size))


def safe_float_to_str(value):
    """Return ``str(float(value))``, or ``str(value)`` if *value* is not float-convertible."""
    try:
        as_float = float(value)
    except Exception:
        return str(value)
    return str(as_float)


def add_symbol_failure(symbol: str, reason: str):
    """Record *reason* in the bounded failure history of *symbol*.

    Only the newest MAX_FAILED_REASONS_PER_SYMBOL reasons are kept. The list
    is trimmed after appending, so a cap of 0 (or below) simply keeps nothing.
    The previous pop-before-append version raised IndexError on an empty list
    in that case.
    """
    with lock:
        history = symbol_fail_reasons.setdefault(symbol, [])
        history.append(reason)
        overflow = len(history) - max(0, MAX_FAILED_REASONS_PER_SYMBOL)
        if overflow > 0:
            del history[:overflow]


def enqueue_retry(symbol: str, reason: str = ""):
    """Put *symbol* on the retry queue once and mark it "retry".

    A symbol already waiting in the queue is left alone; its *reason* is not
    recorded again. For a newly queued symbol, a non-empty *reason* is
    appended to its failure history.
    """
    with lock:
        already_queued = symbol in retry_pending
        if not already_queued:
            retry_pending.add(symbol)
            symbol_status[symbol] = "retry"
            if reason:
                add_symbol_failure(symbol, reason)

    if already_queued:
        return
    retry_queue.put(symbol)


def mark_socket_event(socket_id: int, field: str, value):
    """Set ``socket_health[socket_id][field] = value``, creating the entry if needed."""
    with lock:
        socket_health.setdefault(socket_id, {})[field] = value


def touch_socket(socket_id: int):
    """Stamp the socket's ``last_message_at`` with the current time."""
    with lock:
        socket_health.setdefault(socket_id, {})["last_message_at"] = now_ts()


def set_global_cooldown(seconds: float):
    """Extend the global cooldown to at least *seconds* from now; never shorten it."""
    global global_cooldown_until
    with lock:
        global_cooldown_until = max(global_cooldown_until, now_ts() + seconds)


def wait_for_global_cooldown():
    """Block until the global cooldown has expired or shutdown is requested.

    Sleeps in slices of at most 5 s and logs the remaining time on each pass.
    """
    while not stop_event.is_set():
        with lock:
            deadline = global_cooldown_until
            remaining = deadline - now_ts()

        if remaining <= 0:
            return

        print(f"[gate] cooldown active for {remaining:.1f}s")
        stop_event.wait(min(remaining, 5.0))


def handle_429(source_name: str):
    """React to an HTTP 429 from *source_name* with an escalating global cooldown.

    Each consecutive 429 adds BASE_429_COOLDOWN_SEC to the cooldown, capped at
    MAX_429_COOLDOWN_SEC. The streak is cleared by reset_429_counter().
    """
    global consecutive_429

    with lock:
        consecutive_429 += 1
        streak = consecutive_429

    penalty = BASE_429_COOLDOWN_SEC * streak
    if penalty > MAX_429_COOLDOWN_SEC:
        penalty = MAX_429_COOLDOWN_SEC

    print(f"[{source_name}] detected 429 -> global cooldown {penalty}s")
    set_global_cooldown(penalty)


def reset_429_counter():
    """Clear the consecutive-429 streak (called after a successful connect)."""
    global consecutive_429
    lock.acquire()
    try:
        consecutive_429 = 0
    finally:
        lock.release()


# ===============================
# REDIS
# ===============================

def init_redis():
    """Ping Redis once, raising RuntimeError (chained to the cause) if it is unreachable."""
    try:
        redis_client.ping()
    except Exception as exc:
        raise RuntimeError(f"Redis connection failed: {exc}") from exc
    print(f"Connected to Redis at {REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}")


def flush_prices_to_redis(force: bool = False):
    """Write pending price updates and bookkeeping keys to Redis.

    Args:
        force: also write the updated_at / symbols_total keys when no prices
            are pending (used on shutdown).

    The pending dict is swapped out under ``lock`` and written in one
    non-transactional pipeline. HSETs are split into REDIS_BATCH_SIZE-field
    chunks. On failure the unsent prices are put back, but only for symbols
    that did not receive a newer price while the flush was in flight.
    Previously ``update()`` let stale values overwrite fresh ones.
    """
    with lock:
        if not redis_dirty and not force:
            return

        payload = dict(redis_dirty)
        redis_dirty.clear()
        updated_at = str(now_ts())
        symbols_total = str(len(symbols))

    try:
        pipe = redis_client.pipeline(transaction=False)

        if payload:
            items = list(payload.items())
            for part in chunked(items, max(1, REDIS_BATCH_SIZE)):
                pipe.hset(REDIS_PRICES_HASH, mapping=dict(part))

        pipe.set(REDIS_UPDATED_AT_KEY, updated_at)
        pipe.set(REDIS_SYMBOLS_TOTAL_KEY, symbols_total)
        pipe.execute()

        if payload:
            print(f"Flushed {len(payload)} prices to Redis")
    except Exception as e:
        # Best-effort: keep the data for the next flush instead of crashing the loop.
        print("redis flush error:", e)
        with lock:
            for sym, price in payload.items():
                # setdefault: a newer price queued during the failed flush wins.
                redis_dirty.setdefault(sym, price)


def redis_flush_loop():
    """Flush pending prices every REDIS_FLUSH_INTERVAL_SEC until stop_event is set."""
    # Event.wait returns True as soon as shutdown is requested, ending the loop.
    while not stop_event.wait(REDIS_FLUSH_INTERVAL_SEC):
        flush_prices_to_redis(force=False)


# ===============================
# FILE IO
# ===============================

def load_symbols():
    """Load unique ``tv_symbol`` values from SYMBOLS_FILE into the global ``symbols``.

    File order is preserved, and empty or duplicate symbols are dropped.
    TEST_SYMBOL_LIMIT (if non-zero) truncates the list. Every loaded symbol
    gets a "pending" status unless it already has one.
    """
    global symbols

    raw_rows = json.loads(SYMBOLS_FILE.read_text(encoding="utf-8"))

    # A dict keeps first-seen insertion order while deduplicating.
    unique = {}
    for row in raw_rows:
        tv_symbol = row.get("tv_symbol")
        if tv_symbol and tv_symbol not in unique:
            unique[tv_symbol] = True

    selected = list(unique)
    if TEST_SYMBOL_LIMIT:
        selected = selected[:TEST_SYMBOL_LIMIT]

    symbols = selected

    with lock:
        for tv_symbol in symbols:
            symbol_status.setdefault(tv_symbol, "pending")

    print(f"Loaded {len(symbols)} symbols from file")


# ===============================
# TRADINGVIEW SOCKET CLIENT
# ===============================

class TVSocket:
    """One TradingView websocket connection serving a fixed batch of symbols.

    The batch is split into quote sessions of at most SYMBOLS_PER_SESSION
    symbols. ``run()`` connects and reconnects until ``stop_event`` is set.
    ``on_open`` (re)subscribes every session, and ``qsd`` frames carrying an
    ``lp`` value update the shared ``prices`` / ``redis_dirty`` dicts.
    Per-socket diagnostics live in ``socket_health[socket_id]``.

    Connection attempts are serialized through the module-level
    ``handshake_gate``. It is taken before connecting and released once the
    ``on_open`` subscription burst finishes, or when the attempt ends. It is
    no longer held for the connection's whole lifetime, which had allowed
    only one socket to be connected at a time.
    """

    def __init__(self, socket_id: int, socket_symbols: List[str]):
        self.socket_id = socket_id
        self.symbols = list(socket_symbols)
        self.name = f"ws-{socket_id}"

        # Symbols per quote session; retry injection may append to these lists.
        self.session_chunks = [list(x) for x in chunked(self.symbols, SYMBOLS_PER_SESSION)]
        self.quote_sessions = [random_session("qs") for _ in self.session_chunks]

        self.ws = None
        # Mutated only while holding `lock` (see _subscribe_symbol).
        self.subscribed_symbols: Set[str] = set()
        # Whether this instance currently owns handshake_gate; guarded by `lock`.
        self._gate_held = False

        with lock:
            socket_health[self.socket_id] = {
                "socket_id": self.socket_id,
                "name": self.name,
                "started_at": None,
                "connected_at": None,
                "last_message_at": None,
                "last_error": "",
                "last_close_code": None,
                "last_close_msg": "",
                "status": "created",
                "connect_count": 0,
                "disconnect_count": 0,
                "symbol_target_count": len(self.symbols),
                "subscribed_count": 0,
                "active_count": 0,
                "stale_count": 0,
                "retry_injected_count": 0,
            }

    # ---------- protocol helpers ----------

    def _auth_token(self) -> str:
        """Return TV_AUTH_TOKEN, or the anonymous placeholder token when it is blank."""
        token = TV_AUTH_TOKEN.strip()
        return token if token else "unauthorized_user_token"

    def send(self, ws, method: str, params: list):
        """Send ``{"m": method, "p": params}`` using ``~m~<len>~m~<body>`` framing."""
        msg = json.dumps({"m": method, "p": params}, separators=(",", ":"))
        frame = f"~m~{len(msg)}~m~{msg}"
        ws.send(frame)

    def _split_frames(self, raw: str) -> List[str]:
        """Split a ``~m~<len>~m~<body>...`` message into its raw frame bodies.

        Parsing stops at the first malformed, truncated or negative-length
        frame. A negative length used to move the cursor backwards and could
        loop forever.
        """
        bodies = []
        i = 0
        n = len(raw)

        while i < n:
            if not raw.startswith("~m~", i):
                break

            i += 3
            j = raw.find("~m~", i)
            if j == -1:
                break

            try:
                msg_len = int(raw[i:j])
            except ValueError:
                break
            if msg_len < 0:
                break

            start = j + 3
            end = start + msg_len
            if end > n:
                break

            bodies.append(raw[start:end])
            i = end

        return bodies

    def extract_frames(self, raw: str):
        """Return the JSON-decoded frames in *raw*; heartbeat and non-JSON bodies are skipped."""
        msgs = []
        for body in self._split_frames(raw):
            try:
                msgs.append(json.loads(body))
            except ValueError:
                pass
        return msgs

    def _send_heartbeat(self, ws, frame: str):
        """Echo a heartbeat frame back, best-effort.

        A send failure here is left to the websocket's own error/close handling.
        """
        try:
            ws.send(frame)
        except Exception:
            pass

    def _handle_quote_message(self, m: dict):
        """Apply a ``qsd`` message that carries an ``lp`` value to the shared price caches."""
        if m.get("m") != "qsd":
            return

        p = m.get("p")
        if not isinstance(p, list) or len(p) < 2 or not isinstance(p[1], dict):
            return

        payload = p[1]
        symbol = payload.get("n")
        values = payload.get("v", {})

        if not symbol or not isinstance(values, dict):
            return

        lp = values.get("lp")
        if isinstance(lp, dict):
            lp = lp.get("v")

        if lp is None:
            return

        t = now_ts()
        lp_str = safe_float_to_str(lp)

        with lock:
            prices[symbol] = lp_str
            redis_dirty[symbol] = lp_str
            symbol_last_update[symbol] = t
            symbol_status[symbol] = "active"

    # ---------- subscriptions ----------

    def _subscribe_symbol(self, ws, quote_session: str, sym: str):
        """Add *sym* to *quote_session* and record this socket as its owner."""
        self.send(ws, "quote_add_symbols", [quote_session, sym])

        with lock:
            # The set is mutated under `lock` so snapshots taken by
            # _requeue_stale_subscriptions never see it change mid-iteration.
            self.subscribed_symbols.add(sym)
            symbol_to_socket[sym] = self.socket_id
            if symbol_status.get(sym) in ("pending", "retry", "stale", "failed"):
                symbol_status[sym] = "active"
            socket_health[self.socket_id]["subscribed_count"] = len(self.subscribed_symbols)

    def _subscribe_session(self, ws, quote_session: str, session_symbols: List[str]):
        """Create *quote_session*, select its fields and subscribe *session_symbols*.

        A symbol whose subscription send fails is handed to the retry queue.
        """
        self.send(ws, "quote_create_session", [quote_session])

        self.send(ws, "quote_set_fields", [
            quote_session,
            "lp",
            "ch",
            "chp",
            "rtc",
            "rch",
            "rchp",
            "current_session",
            "currency_code",
            "volume",
        ])

        time.sleep(SESSION_CREATE_DELAY)

        for sym in session_symbols:
            if stop_event.is_set():
                return
            try:
                self._subscribe_symbol(ws, quote_session, sym)
            except Exception as e:
                enqueue_retry(sym, f"subscribe failed on open: {e}")
            time.sleep(SYMBOL_DELAY)

        try:
            self.send(ws, "quote_fast_symbols", [quote_session] + session_symbols)
        except Exception:
            # Optional hint to the server; the symbols are already subscribed above.
            pass

    def _pick_least_loaded_session(self):
        """Return ``(index, quote_session)`` of the session with the fewest symbols (lowest index on ties)."""
        idx = min(range(len(self.session_chunks)), key=lambda k: len(self.session_chunks[k]))
        return idx, self.quote_sessions[idx]

    def inject_retry_symbols(self, limit=None):
        """Subscribe up to *limit* symbols from the global retry queue on this socket.

        Args:
            limit: maximum number of new subscriptions. Defaults to
                RETRY_BATCH_SIZE, resolved at call time.

        Symbols this socket already subscribed are only marked active again.
        On a send failure the symbol is requeued and the method returns: the
        connection is most likely unusable. Previously it dequeued the same
        symbol again and spun forever.
        """
        if limit is None:
            limit = RETRY_BATCH_SIZE

        ws = self.ws
        if not ws:
            return

        injected = 0

        while injected < limit and not stop_event.is_set():
            try:
                sym = retry_queue.get_nowait()
            except Empty:
                break

            with lock:
                already_here = sym in self.subscribed_symbols
                if already_here:
                    retry_pending.discard(sym)
                    symbol_status[sym] = "active"
            if already_here:
                continue

            try:
                idx, qs = self._pick_least_loaded_session()
                self._subscribe_symbol(ws, qs, sym)
            except Exception as e:
                with lock:
                    retry_pending.discard(sym)
                enqueue_retry(sym, f"retry inject failed: {e}")
                break

            # Recorded only after a successful send so a reconnect resubscribes it here.
            self.session_chunks[idx].append(sym)

            try:
                self.send(ws, "quote_fast_symbols", [qs, sym])
            except Exception:
                # Optional hint; the subscription itself already succeeded.
                pass

            with lock:
                retry_pending.discard(sym)
                symbol_status[sym] = "active"
                socket_health[self.socket_id]["retry_injected_count"] += 1

            injected += 1
            time.sleep(SYMBOL_DELAY)

    # ---------- handshake gate ----------

    def _acquire_handshake_gate(self) -> bool:
        """Take handshake_gate, polling so shutdown is noticed; False if stopping."""
        while not stop_event.is_set():
            if handshake_gate.acquire(timeout=1.0):
                with lock:
                    self._gate_held = True
                return True
        return False

    def _release_handshake_gate(self):
        """Release handshake_gate if this socket holds it; safe to call repeatedly."""
        with lock:
            if not self._gate_held:
                return
            self._gate_held = False
        handshake_gate.release()

    def _requeue_stale_subscriptions(self, reason: str):
        """Queue for retry every subscribed symbol with no price in the last SYMBOL_STALE_SEC."""
        cutoff = now_ts() - SYMBOL_STALE_SEC
        with lock:
            stale = []
            for sym in self.subscribed_symbols:
                last = symbol_last_update.get(sym, 0)
                if last == 0 or last < cutoff:
                    stale.append(sym)

        for sym in stale:
            enqueue_retry(sym, reason)

    # ---------- websocket callbacks ----------

    def on_open(self, ws):
        """Authenticate and (re)subscribe all sessions, then free the handshake gate."""
        self.ws = ws
        reset_429_counter()

        try:
            mark_socket_event(self.socket_id, "status", "connected")
            mark_socket_event(self.socket_id, "connected_at", now_ts())

            with lock:
                socket_health[self.socket_id]["connect_count"] += 1

            print(f"[{self.name}] connected ({len(self.symbols)} symbols, {len(self.quote_sessions)} sessions)")

            self.send(ws, "set_auth_token", [self._auth_token()])

            try:
                self.send(ws, "set_locale", ["en", "US"])
            except Exception:
                pass

            for idx, (quote_session, session_symbols) in enumerate(zip(self.quote_sessions, self.session_chunks), start=1):
                if stop_event.is_set():
                    return
                print(f"[{self.name}] opening session {idx}/{len(self.quote_sessions)} with {len(session_symbols)} symbols")
                self._subscribe_session(ws, quote_session, session_symbols)
        finally:
            # Let the next socket connect only after this subscription burst,
            # so bursts from different sockets do not overlap.
            self._release_handshake_gate()

    def on_message(self, ws, raw: str):
        """Echo heartbeats and apply quote updates contained in one websocket message.

        Heartbeats arrive framed like any other message (e.g. ``~m~4~m~~h~1``),
        so the bare ``~h~`` prefix check alone never matched them. Every
        heartbeat frame found inside the message is now echoed back.
        """
        touch_socket(self.socket_id)

        if raw.startswith("~h~"):
            self._send_heartbeat(ws, raw)
            return

        for body in self._split_frames(raw):
            if body.startswith("~h~"):
                self._send_heartbeat(ws, f"~m~{len(body)}~m~{body}")
                continue

            try:
                m = json.loads(body)
            except ValueError:
                continue

            if isinstance(m, dict):
                self._handle_quote_message(m)

    def on_error(self, ws, err):
        """Record the error, trip the 429 breaker if applicable, and requeue stale symbols."""
        msg = str(err)
        mark_socket_event(self.socket_id, "last_error", msg)
        mark_socket_event(self.socket_id, "status", "error")
        print(f"[{self.name}] error: {msg}")

        if "429" in msg or "Too Many Requests" in msg:
            handle_429(self.name)

        self._requeue_stale_subscriptions(f"socket error: {msg}")

    def on_close(self, ws, code, msg):
        """Record the close and requeue stale symbols."""
        mark_socket_event(self.socket_id, "last_close_code", code)
        mark_socket_event(self.socket_id, "last_close_msg", str(msg))
        mark_socket_event(self.socket_id, "status", "closed")

        with lock:
            socket_health[self.socket_id]["disconnect_count"] += 1

        print(f"[{self.name}] closed code={code} msg={msg}")

        self._requeue_stale_subscriptions(f"socket closed: {code} {msg}")

    # ---------- main loop ----------

    def run(self):
        """Connect, serve until disconnected, then reconnect after a random delay; repeat until stopped."""
        mark_socket_event(self.socket_id, "started_at", now_ts())
        time.sleep(random.uniform(1.0, 3.0))

        while not stop_event.is_set():
            wait_for_global_cooldown()

            try:
                if not self._acquire_handshake_gate():
                    break
                try:
                    wait_for_global_cooldown()
                    if stop_event.is_set():
                        break

                    self.ws = WebSocketApp(
                        WS_URL,
                        on_open=self.on_open,
                        on_message=self.on_message,
                        on_error=self.on_error,
                        on_close=self.on_close,
                        header=[
                            "Origin: https://www.tradingview.com",
                            "User-Agent: Mozilla/5.0",
                        ],
                    )

                    self.ws.run_forever(
                        ping_interval=20,
                        ping_timeout=10,
                        origin="https://www.tradingview.com",
                    )
                finally:
                    # Normally already released by on_open; this covers failed handshakes.
                    self._release_handshake_gate()

            except Exception as e:
                err_msg = str(e)
                mark_socket_event(self.socket_id, "last_error", err_msg)
                mark_socket_event(self.socket_id, "status", "exception")
                print(f"[{self.name}] restart after exception: {err_msg}")

                if "429" in err_msg or "Too Many Requests" in err_msg:
                    handle_429(self.name)

            if stop_event.is_set():
                break

            wait_for_global_cooldown()

            sleep_for = random.uniform(RECONNECT_MIN_SEC, RECONNECT_MAX_SEC)
            print(f"[{self.name}] reconnecting in {sleep_for:.1f}s")
            stop_event.wait(sleep_for)


# ===============================
# HEALTH / RETRY LOOPS
# ===============================

def health_monitor_loop():
    """Periodically flag stale symbols for retry and refresh per-socket health counters.

    Every HEALTH_CHECK_INTERVAL_SEC seconds, until stop_event is set:
      * a symbol whose last price is older than SYMBOL_STALE_SEC is marked
        "stale" and queued for retry, unless it is already "retry";
      * a subscribed symbol that never produced a price is treated the same
        way. This only happens once its owning socket has been connected for
        SYMBOL_STALE_SEC, so freshly (re)subscribed symbols are not churned
        into the retry queue before their first quote can arrive;
      * a "connected" socket silent for SOCKET_STALE_SEC becomes "degraded",
        and a "degraded" socket that receives messages again is restored to
        "connected";
      * each socket's active_count / stale_count is rebuilt in a single pass
        over symbol_to_socket. Previously that map was rescanned once per
        socket.
    """
    while not stop_event.is_set():
        if stop_event.wait(HEALTH_CHECK_INTERVAL_SEC):
            break

        t = now_ts()
        stale_cutoff = t - SYMBOL_STALE_SEC
        socket_cutoff = t - SOCKET_STALE_SEC

        stale_to_retry = []

        with lock:
            for sym in symbols:
                last = symbol_last_update.get(sym, 0)
                st = symbol_status.get(sym, "pending")

                if last == 0:
                    if st in ("pending", "retry", "failed"):
                        continue
                    owner = symbol_to_socket.get(sym)
                    connected_at = socket_health.get(owner, {}).get("connected_at") or 0
                    if connected_at > stale_cutoff:
                        # Grace period: the owning socket connected too recently to judge.
                        continue
                    symbol_status[sym] = "stale"
                    stale_to_retry.append((sym, "never updated"))
                    continue

                if last < stale_cutoff and st != "retry":
                    symbol_status[sym] = "stale"
                    stale_to_retry.append((sym, "stale price"))

            # socket_id -> [active_count, stale_count], built in one pass.
            counts = {}
            for sym, sock_id in symbol_to_socket.items():
                st = symbol_status.get(sym)
                bucket = counts.setdefault(sock_id, [0, 0])
                if st == "active":
                    bucket[0] += 1
                elif st in ("stale", "retry", "failed"):
                    bucket[1] += 1

            for sid, info in socket_health.items():
                last_msg = info.get("last_message_at") or 0
                status = info.get("status")
                if last_msg and last_msg < socket_cutoff and status == "connected":
                    info["status"] = "degraded"
                elif status == "degraded" and last_msg >= socket_cutoff:
                    info["status"] = "connected"

                active_count, stale_count = counts.get(sid, (0, 0))
                info["active_count"] = active_count
                info["stale_count"] = stale_count

        # enqueue_retry takes the lock itself; call it after releasing ours.
        for sym, reason in stale_to_retry:
            enqueue_retry(sym, reason)


def retry_worker_loop(clients: List[TVSocket]):
    """Every RETRY_REQUEUE_DELAY_SEC, hand queued retry symbols to live sockets.

    Args:
        clients: sockets to choose from. Only those whose socket_health status
            is "connected" or "degraded" are used. The list may keep growing
            while this loop runs.

    At most RETRY_BATCH_SIZE symbols are handed out per cycle, split evenly
    across active sockets. Previously ``max(1, …)`` allowed one symbol per
    socket, i.e. up to len(clients) per cycle. The starting socket rotates
    each cycle, and the cycle ends early once the retry queue is empty.
    """
    rr = 0

    while not stop_event.is_set():
        if stop_event.wait(RETRY_REQUEUE_DELAY_SEC):
            break

        if not clients:
            continue

        with lock:
            active_clients = [
                c for c in clients
                if socket_health.get(c.socket_id, {}).get("status") in ("connected", "degraded")
            ]

        if not active_clients:
            continue

        n = len(active_clients)
        per_client = max(1, RETRY_BATCH_SIZE // n)
        budget = RETRY_BATCH_SIZE
        start = rr % n
        used = 0

        for offset in range(n):
            if budget <= 0 or stop_event.is_set() or retry_queue.empty():
                break

            client = active_clients[(start + offset) % n]
            limit = min(per_client, budget)
            budget -= limit
            used += 1

            try:
                client.inject_retry_symbols(limit=limit)
            except Exception as e:
                print(f"[retry-worker] inject failed on {client.name}: {e}")

        # Start the next cycle after the last socket served, so the head of the list is not always favoured.
        rr += max(1, used)


# ===============================
# MAIN
# ===============================

def start_ws(clients=None, threads=None):
    """Start one TVSocket thread per SOCKET_SYMBOL_CAP-sized batch of symbols.

    Each start first waits out any global 429 cooldown. Sockets are started
    SOCKET_STAGGER_SEC apart, with no wait after the last one. Startup stops
    early once stop_event is set.

    Args:
        clients: optional list that every new TVSocket is appended to as soon
            as it is created. Other threads sharing the list (e.g. the retry
            worker) can therefore use sockets while startup is in progress.
        threads: optional list that every socket thread is appended to.

    Returns:
        ``(clients, threads)``; these are the caller's lists when passed in.
    """
    if clients is None:
        clients = []
    if threads is None:
        threads = []

    socket_batches = list(chunked(symbols, SOCKET_SYMBOL_CAP))
    print("Total sockets:", len(socket_batches))
    print("Socket symbol cap:", SOCKET_SYMBOL_CAP)
    print("Sessions per socket:", SESSIONS_PER_SOCKET)
    print("Symbols per session:", SYMBOLS_PER_SESSION)

    for i, batch in enumerate(socket_batches, 1):
        wait_for_global_cooldown()
        if stop_event.is_set():
            break

        client = TVSocket(i, batch)
        clients.append(client)

        t = threading.Thread(target=client.run, daemon=True, name=f"tv-ws-{i}")
        t.start()
        threads.append(t)

        print(f"Started socket {i} with {len(batch)} symbols")
        if i < len(socket_batches):
            stop_event.wait(SOCKET_STAGGER_SEC)

    return clients, threads


def main():
    """Run the updater: load symbols, start all worker threads, report status until Ctrl-C.

    The health monitor and retry worker start before the sockets. Sockets are
    started from a background thread, because the SOCKET_STAGGER_SEC stagger
    can take hours for large symbol lists and would otherwise delay
    monitoring and status output. Redis gets a final forced flush on
    shutdown, including a Ctrl-C during startup.
    """
    init_redis()
    load_symbols()

    if not symbols:
        print("No symbols found.")
        return

    try:
        redis_client.set(REDIS_SYMBOLS_TOTAL_KEY, len(symbols))
        redis_client.set(REDIS_UPDATED_AT_KEY, now_ts())
    except Exception as e:
        print("redis init write error:", e)

    redis_flusher = threading.Thread(target=redis_flush_loop, daemon=True, name="redisflusher")
    redis_flusher.start()

    # Shared with the starter and retry-worker threads; filled while sockets are staggered in.
    clients = []
    threads = []

    healthmon = threading.Thread(target=health_monitor_loop, daemon=True, name="healthmon")
    healthmon.start()

    retryworker = threading.Thread(
        target=retry_worker_loop,
        args=(clients,),
        daemon=True,
        name="retryworker",
    )
    retryworker.start()

    starter = threading.Thread(
        target=start_ws,
        args=(clients, threads),
        daemon=True,
        name="wsstarter",
    )
    starter.start()

    try:
        while True:
            time.sleep(10)

            alive = sum(t.is_alive() for t in threads)
            with lock:
                prices_count = len(prices)
                retry_size = retry_queue.qsize()
                active_count = sum(1 for v in symbol_status.values() if v == "active")
                stale_count = sum(1 for v in symbol_status.values() if v == "stale")
                retry_count = sum(1 for v in symbol_status.values() if v == "retry")
                pending_count = sum(1 for v in symbol_status.values() if v == "pending")
                cooldown_left = max(0.0, global_cooldown_until - now_ts())

            print(
                f"alive sockets: {alive}/{len(threads)} | "
                f"prices cached: {prices_count} | "
                f"active: {active_count} | pending: {pending_count} | "
                f"stale: {stale_count} | retry: {retry_count} | "
                f"retry_queue: {retry_size} | cooldown: {cooldown_left:.1f}s"
            )

    except KeyboardInterrupt:
        pass
    finally:
        stop_event.set()
        flush_prices_to_redis(force=True)
        print("stopped")


if __name__ == "__main__":
    main()