# pipeline_redis_mp.py
# Redis + multiprocessing TradingView US market pipeline
#
# Behavior:
# 1) Scans TradingView America screener with pagination until exhausted
# 2) Stores all US symbols in Redis
# 3) Downloads logos locally and stores metadata in Redis
# 4) Splits symbols into websocket bundles of WS_BATCH_SIZE symbols (500 in production)
# 5) Starts websocket bundles sequentially:
#       - start bundle 1
#       - wait until connected
#       - start bundle 2
#       - wait until connected
#       - ...
# 6) If one websocket bundle drops later, only that bundle reconnects
#
# Run:
#   python3 pipeline/pipeline_redis_mp.py

import re
import csv
import json
import time
import random
import signal
import string
import requests
import redis
import multiprocessing as mp
from pathlib import Path
from typing import Dict, List, Optional
from websocket import WebSocketApp


# =========================================================
# CONFIG
# =========================================================

REDIS_HOST = "127.0.0.1"
REDIS_PORT = 6379
REDIS_DB = 0
REDIS_PASSWORD = None

SCANNER_URL = "https://scanner.tradingview.com/america/scan"
WS_URL = "wss://data.tradingview.com/socket.io/websocket"

# Local output directories (created on import; mkdir is idempotent).
DATA_DIR = Path("data")
LOGOS_DIR = DATA_DIR / "logos"
EXPORT_DIR = DATA_DIR / "exports"
DATA_DIR.mkdir(exist_ok=True)
LOGOS_DIR.mkdir(exist_ok=True)
EXPORT_DIR.mkdir(exist_ok=True)

# Scanner pagination: rows requested per screener call.
PAGE_SIZE = 5000

# Websocket bundle behavior.
# NOTE: WS_BATCH_SIZE is reduced from the production value of 500 to 200
# while TEST_SYMBOL_LIMIT (below) is active.
WS_BATCH_SIZE = 200
WS_READY_TIMEOUT_SEC = 300
WS_RECONNECT_MIN_SEC = 180
WS_RECONNECT_MAX_SEC = 1800
WS_INITIAL_DELAY_SEC = 60
WS_MANAGER_STAGGER_SEC = 30
SYMBOL_SUBSCRIBE_DELAY_SEC = 0.05

# Logo and scanner timing
LOGO_SLEEP_SEC = 0.03
LOGO_TIMEOUT = 30
SCANNER_TIMEOUT = 60

# Optional test limit: 0 disables, otherwise only the first N symbols are used.
# Currently active for testing; set back to 0 for a full production run.
TEST_SYMBOL_LIMIT = 200

# Redis keys
KEY_SYMBOLS_SET = "tv:us:symbols:set"
KEY_SYMBOL_META_PREFIX = "tv:us:symbol:"
KEY_PRICE_PREFIX = "tv:us:price:"
KEY_STREAM_RAW = "tv:us:stream:raw"
KEY_LOGO_QUEUE = "tv:us:logo:queue"
KEY_STATUS = "tv:us:status"
KEY_ERRORS = "tv:us:errors"
KEY_HEARTBEAT_PREFIX = "tv:us:heartbeat:"
KEY_WS_READY_PREFIX = "tv:us:ws:ready:"
KEY_WS_STATE_PREFIX = "tv:us:ws:state:"

# Export files
SYMBOLS_CSV = EXPORT_DIR / "us_symbols_full.csv"
SYMBOLS_JSON = EXPORT_DIR / "us_symbols_full.json"

FULL_REFRESH_SYMBOLS = True


# =========================================================
# REDIS
# =========================================================

def get_redis() -> redis.Redis:
    """Build a fresh Redis client from module config, decoding replies to str."""
    client = redis.Redis(
        host=REDIS_HOST,
        port=REDIS_PORT,
        db=REDIS_DB,
        password=REDIS_PASSWORD,
        decode_responses=True,
    )
    return client


# =========================================================
# HELPERS
# =========================================================

def random_session(prefix: str) -> str:
    """Return a TradingView-style session id: '<prefix>_' plus 12 random [a-z0-9] chars."""
    alphabet = string.ascii_lowercase + string.digits
    tail = "".join(random.choice(alphabet) for _ in range(12))
    return f"{prefix}_{tail}"


def safe_filename(s: str) -> str:
    """Collapse runs of characters outside [A-Za-z0-9._-] into '_' and trim edge separators."""
    collapsed = re.sub(r"[^A-Za-z0-9._-]+", "_", s)
    return collapsed.strip("._-")


def build_logo_url(logoid: str) -> str:
    """Absolute URL of the SVG logo on TradingView's logo CDN for a given logoid."""
    return "https://s3-symbol-logo.tradingview.com/" + logoid + ".svg"


def chunked(seq: List[str], size: int):
    """Lazily yield consecutive slices of *seq*, each at most *size* elements long."""
    for offset in range(0, len(seq), size):
        chunk = seq[offset:offset + size]
        yield chunk


def log_error(r: redis.Redis, msg: str):
    """Echo *msg* to stdout and append it, timestamped, to the shared Redis error list."""
    print(msg)
    stamped = f"{time.strftime('%Y-%m-%d %H:%M:%S')} | {msg}"
    r.rpush(KEY_ERRORS, stamped)


def redis_safe_mapping(data: Dict) -> Dict:
    """Coerce a mapping into the str->str form Redis HSET accepts.

    ``None`` values become empty strings; every other value is ``str()``-ified.
    Returns a new dict; *data* is not modified.
    """
    return {k: "" if v is None else str(v) for k, v in data.items()}


def ws_ready_key(batch_id: int) -> str:
    """Redis key holding the ready flag for websocket bundle *batch_id*."""
    return KEY_WS_READY_PREFIX + str(batch_id)


def ws_state_key(batch_id: int) -> str:
    """Redis key holding the state hash for websocket bundle *batch_id*."""
    return KEY_WS_STATE_PREFIX + str(batch_id)


# =========================================================
# SCANNER
# =========================================================

class TradingViewScanner:
    """Paginated client for the TradingView America stock screener endpoint."""

    # Keys of the per-symbol dict, in the same order as the "columns"
    # requested in _payload (column "name" maps to key "ticker").
    _FIELD_KEYS = [
        "ticker", "description", "logoid", "type", "subtype", "exchange",
        "currency", "country", "sector", "industry", "market_cap_basic", "close",
    ]

    def __init__(self):
        self.session = requests.Session()
        self.session.headers.update({
            "User-Agent": "Mozilla/5.0",
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Origin": "https://www.tradingview.com",
            "Referer": "https://www.tradingview.com/",
        })

    def _payload(self, start: int, end: int) -> Dict:
        """Build the screener request body covering rows [start, end] inclusive."""
        return {
            "filter": [
                {"left": "type", "operation": "in_range", "right": ["stock", "dr"]}
            ],
            "options": {"lang": "en"},
            "markets": ["america"],
            "symbols": {"query": {"types": []}, "tickers": []},
            "columns": [
                "name",
                "description",
                "logoid",
                "type",
                "subtype",
                "exchange",
                "currency",
                "country",
                "sector",
                "industry",
                "market_cap_basic",
                "close",
            ],
            "sort": {"sortBy": "name", "sortOrder": "asc"},
            "range": [start, end],
        }

    def fetch_page(self, start: int, page_size: int = PAGE_SIZE) -> List[Dict]:
        """POST one scan request for *page_size* rows starting at offset *start*."""
        payload = self._payload(start, start + page_size - 1)
        resp = self.session.post(SCANNER_URL, json=payload, timeout=SCANNER_TIMEOUT)
        resp.raise_for_status()
        return resp.json().get("data", [])

    def fetch_all(self) -> List[Dict]:
        """Fetch every page until an empty one, then dedupe by tv_symbol (last wins)."""
        collected: List[Dict] = []
        offset = 0

        while True:
            page = self.fetch_page(offset, PAGE_SIZE)
            if not page:
                break

            for row in page:
                tv_symbol = row.get("s", "")
                if not tv_symbol:
                    continue

                values = row.get("d", [])
                item = {"tv_symbol": tv_symbol}
                for col_idx, key in enumerate(self._FIELD_KEYS):
                    raw = values[col_idx] if col_idx < len(values) else None
                    item[key] = raw if raw is not None else ""
                collected.append(item)

            print(f"Fetched {len(page)} rows, total={len(collected)}")
            offset += PAGE_SIZE

        unique: Dict[str, Dict] = {}
        for item in collected:
            unique[item["tv_symbol"]] = item

        return list(unique.values())


def scan_and_store_symbols():
    """Scan the TradingView screener and (re)populate Redis with all US symbols.

    When FULL_REFRESH_SYMBOLS is set, first wipes the previous symbol set,
    per-symbol meta/price hashes and the logo queue. Then stores one metadata
    hash per symbol, queues symbols that have a logoid for the logo workers,
    updates the status hash, and writes the JSON/CSV export files.
    """
    r = get_redis()
    scanner = TradingViewScanner()

    if FULL_REFRESH_SYMBOLS:
        print("Refreshing all symbol keys in Redis...")
        old_symbols = list(r.smembers(KEY_SYMBOLS_SET))
        pipe = r.pipeline()
        for s in old_symbols:
            pipe.delete(f"{KEY_SYMBOL_META_PREFIX}{s}")
            pipe.delete(f"{KEY_PRICE_PREFIX}{s}")
        pipe.delete(KEY_SYMBOLS_SET)
        pipe.delete(KEY_LOGO_QUEUE)
        pipe.execute()

    symbols = scanner.fetch_all()
    print(f"Final unique symbols: {len(symbols)}")

    # Non-transactional pipeline: we only need batching, not atomicity.
    pipe = r.pipeline(transaction=False)
    for idx, item in enumerate(symbols, start=1):
        tv_symbol = item["tv_symbol"]
        logo_url = build_logo_url(item["logoid"]) if item.get("logoid") else ""
        item["logo_url"] = logo_url
        item["logo_local"] = ""  # filled in later by the logo workers

        safe_item = redis_safe_mapping(item)

        pipe.sadd(KEY_SYMBOLS_SET, tv_symbol)
        pipe.hset(f"{KEY_SYMBOL_META_PREFIX}{tv_symbol}", mapping=safe_item)

        if item.get("logoid"):
            pipe.rpush(KEY_LOGO_QUEUE, tv_symbol)

        # Flush the pipeline every 1000 symbols to bound memory usage.
        if idx % 1000 == 0:
            pipe.execute()

    pipe.execute()  # flush the final partial batch

    r.hset(KEY_STATUS, mapping={
        "symbols_total": len(symbols),
        "symbols_last_scan_ts": time.time(),
    })

    export_symbols_to_files()
    print("Symbols stored in Redis and exported.")


def export_symbols_to_files():
    """Write every symbol's Redis metadata hash to the JSON and CSV export files."""
    r = get_redis()

    rows = []
    for sym in sorted(r.smembers(KEY_SYMBOLS_SET)):
        meta = r.hgetall(f"{KEY_SYMBOL_META_PREFIX}{sym}")
        if meta:
            rows.append(meta)

    with open(SYMBOLS_JSON, "w", encoding="utf-8") as json_file:
        json.dump(rows, json_file, ensure_ascii=False, indent=2)

    columns = [
        "tv_symbol", "ticker", "description", "logoid", "type", "subtype",
        "exchange", "currency", "country", "sector", "industry",
        "market_cap_basic", "close", "logo_url", "logo_local"
    ]
    with open(SYMBOLS_CSV, "w", encoding="utf-8", newline="") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=columns)
        writer.writeheader()
        for row in rows:
            writer.writerow(row)

    print(f"Exported {len(rows)} symbols to:")
    print(f"  {SYMBOLS_JSON}")
    print(f"  {SYMBOLS_CSV}")


# =========================================================
# LOGO DOWNLOADER PROCESS
# =========================================================

def logo_downloader_process(worker_id: int):
    """Worker process: consume symbols from the Redis logo queue forever.

    For each queued symbol it downloads the TradingView SVG logo into
    LOGOS_DIR and records the local path in the symbol's metadata hash.
    Runs until the parent terminates it (started as a daemon process).
    """
    r = get_redis()
    session = requests.Session()
    session.headers.update({
        "User-Agent": "Mozilla/5.0",
        "Referer": "https://www.tradingview.com/",
    })

    print(f"[logo-{worker_id}] started")

    while True:
        try:
            # Blocking pop with a timeout so the loop can notice shutdown/errors.
            item = r.blpop(KEY_LOGO_QUEUE, timeout=5)
            if not item:
                continue

            _, tv_symbol = item
            meta_key = f"{KEY_SYMBOL_META_PREFIX}{tv_symbol}"
            meta = r.hgetall(meta_key)
            if not meta:
                continue  # symbol meta vanished since it was queued

            logoid = meta.get("logoid", "")
            if not logoid:
                continue

            # Fallback parses "EXCHANGE:TICKER" when meta fields are absent.
            exchange = meta.get("exchange", tv_symbol.split(":")[0])
            ticker = meta.get("ticker", tv_symbol.split(":")[-1])

            url = build_logo_url(logoid)
            filename = safe_filename(f"{exchange}_{ticker}.svg")
            out_path = LOGOS_DIR / filename

            # Skip the download when a non-empty file already exists on disk.
            if out_path.exists() and out_path.stat().st_size > 0:
                r.hset(meta_key, "logo_local", str(out_path))
                r.set(f"{KEY_HEARTBEAT_PREFIX}logo-{worker_id}", time.time())
                continue

            resp = session.get(url, timeout=LOGO_TIMEOUT)
            if resp.status_code == 200 and resp.content:
                with open(out_path, "wb") as f:
                    f.write(resp.content)
                r.hset(meta_key, "logo_local", str(out_path))
            else:
                log_error(r, f"[logo-{worker_id}] failed {tv_symbol} status={resp.status_code}")

            # Heartbeat so the orchestrator can see this worker is alive.
            r.set(f"{KEY_HEARTBEAT_PREFIX}logo-{worker_id}", time.time())
            time.sleep(LOGO_SLEEP_SEC)  # gentle throttle between downloads

        except Exception as e:
            # Best-effort worker: log and keep consuming the queue.
            log_error(r, f"[logo-{worker_id}] exception: {e}")
            time.sleep(1)


# =========================================================
# TRADINGVIEW WS PROCESS
# =========================================================

class TVSocketClient:
    """One TradingView websocket connection serving a fixed bundle of symbols.

    Each instance runs inside its own process (see ws_worker_process). It
    subscribes its symbols to a quote session, writes every quote update into
    Redis (one hash per symbol plus a shared stream), and reconnects forever
    with exponential backoff when the socket drops. Readiness and state are
    published to Redis so the orchestrator can start bundles sequentially.
    """

    def __init__(self, batch_id: int, symbols: List[str]):
        self.batch_id = batch_id  # 1-based bundle index, used in Redis keys
        self.symbols = symbols  # symbols this bundle subscribes to
        self.r = get_redis()
        self.ws = None  # current WebSocketApp; created in run_forever
        self.quote_session = random_session("qs")
        self.chart_session = random_session("cs")
        self.proc_name = f"ws-{batch_id}"  # label for logs and heartbeat key
        # Streak of consecutive 429 (rate limit) errors; lengthens backoff.
        self.consecutive_429 = 0

    @staticmethod
    def _frame_message(payload: Dict) -> str:
        """Wrap a JSON payload in TradingView's ~m~<len>~m~<body> wire framing."""
        body = json.dumps(payload, separators=(",", ":"))
        return f"~m~{len(body)}~m~{body}"

    def _send(self, method: str, params: List):
        """Send one framed protocol message of the form {"m": method, "p": params}."""
        msg = {"m": method, "p": params}
        self.ws.send(self._frame_message(msg))

    def _mark_connected(self):
        """Publish the ready flag and a 'connected' state hash for this bundle."""
        now = time.time()
        self.r.set(ws_ready_key(self.batch_id), "1")
        self.r.hset(ws_state_key(self.batch_id), mapping={
            "state": "connected",
            "connected_at": now,
            "last_event_at": now,
            "symbols_count": len(self.symbols),
        })
        self.r.set(f"{KEY_HEARTBEAT_PREFIX}{self.proc_name}", now)

    def _mark_disconnected(self, reason: str):
        """Clear the ready flag and record a 'disconnected' state with *reason*."""
        now = time.time()
        self.r.delete(ws_ready_key(self.batch_id))
        self.r.hset(ws_state_key(self.batch_id), mapping={
            "state": "disconnected",
            "last_event_at": now,
            "reason": reason,
            "symbols_count": len(self.symbols),
        })

    def _mark_connecting(self):
        """Clear the ready flag and record a 'connecting' state for this bundle."""
        now = time.time()
        self.r.delete(ws_ready_key(self.batch_id))
        self.r.hset(ws_state_key(self.batch_id), mapping={
            "state": "connecting",
            "last_event_at": now,
            "symbols_count": len(self.symbols),
        })

    def on_open(self, ws):
        """Socket opened: create sessions, subscribe every symbol, then mark ready."""
        self.consecutive_429 = 0
        print(f"[{self.proc_name}] connected with {len(self.symbols)} symbols")

        # Anonymous token is accepted for free/delayed quote data.
        self._send("set_auth_token", ["unauthorized_user_token"])
        self._send("set_locale", ["en", "US"])
        self._send("quote_create_session", [self.quote_session])
        self._send("chart_create_session", [self.chart_session, ""])

        # Fields requested for every quote update ("qsd" message).
        self._send("quote_set_fields", [
            self.quote_session,
            "lp",
            "ch",
            "chp",
            "rtc",
            "rch",
            "rchp",
            "current_session",
            "description",
            "exchange",
            "currency_code",
            "volume",
            "logoid",
            "status",
            "type"
        ])

        # Subscribe symbols sequentially, throttled to avoid rate limiting.
        for idx, sym in enumerate(self.symbols, start=1):

            self._send("quote_add_symbols", [self.quote_session, sym])

            if idx % 50 == 0:
                time.sleep(SYMBOL_SUBSCRIBE_DELAY_SEC * 2)
            else:
                time.sleep(SYMBOL_SUBSCRIBE_DELAY_SEC)

        # Only mark ready AFTER subscriptions are sent
        self._mark_connected()

    def _extract_json_messages(self, raw: str) -> List[Dict]:
        """Split one raw ~m~-framed packet into a list of parsed JSON messages.

        Heartbeat bodies ("~h~...") and undecodable bodies are skipped;
        parsing stops at the first structurally malformed frame.
        """
        messages = []
        idx = 0

        while idx < len(raw):
            if raw.startswith("~m~", idx):
                idx += 3
                len_end = raw.find("~m~", idx)
                if len_end == -1:
                    break

                try:
                    msg_len = int(raw[idx:len_end])
                except ValueError:
                    break

                start = len_end + 3
                body = raw[start:start + msg_len]
                idx = start + msg_len

                if body.startswith("~h~"):
                    continue

                try:
                    messages.append(json.loads(body))
                except json.JSONDecodeError:
                    continue
            else:
                break

        return messages

    def _normalize_quote_update(self, msg: Dict) -> Optional[Dict]:
        """Convert one "qsd" protocol message into a flat quote dict.

        Returns None for non-qsd messages or payloads lacking a symbol name.
        """
        method = msg.get("m")
        params = msg.get("p", [])

        if method != "qsd" or len(params) < 2:
            return None

        # Payload position varies; take the first dict among p[1] / p[2].
        raw_payload = None
        if len(params) > 1 and isinstance(params[1], dict):
            raw_payload = params[1]
        elif len(params) > 2 and isinstance(params[2], dict):
            raw_payload = params[2]

        if not raw_payload or not isinstance(raw_payload, dict):
            return None

        tv_symbol = raw_payload.get("n")
        if not tv_symbol or not isinstance(tv_symbol, str):
            return None

        payload = raw_payload.get("v", {})
        if not isinstance(payload, dict):
            payload = {}

        def val(k):
            # Some fields arrive wrapped as {"v": value}; unwrap those.
            x = payload.get(k)
            if isinstance(x, dict) and "v" in x:
                return x["v"]
            return x

        data = {
            "ts": time.time(),
            "tv_symbol": tv_symbol,
            "status": raw_payload.get("s", ""),
            "type": val("type"),
            "currency": val("currency_code"),
            "exchange": val("exchange"),
            "description": val("description"),
            "logoid": val("logoid"),
            "session": val("current_session"),
            "change": val("ch"),
            "change_percent": val("chp"),
            "regular_change": val("rch"),
            "regular_change_percent": val("rchp"),
            "volume": val("volume"),
        }

        # "lp" (last price) is only present on messages carrying a trade/tick.
        lp = val("lp")
        if lp is not None:
            data["last_price"] = lp

        return data

    def on_message(self, ws, raw_message: str):
        """Handle an incoming packet: echo heartbeats, persist quote updates."""
        try:
            if raw_message.startswith("~h~"):
                ws.send(raw_message)  # protocol keepalive: echo back verbatim
                return

            messages = self._extract_json_messages(raw_message)
            pipe = self.r.pipeline(transaction=False)
            updates = 0

            for msg in messages:
                normalized = self._normalize_quote_update(msg)
                if not normalized:
                    continue

                tv_symbol = normalized["tv_symbol"]
                # Valid symbols are "EXCHANGE:TICKER"; skip anything else.
                if not isinstance(tv_symbol, str) or ":" not in tv_symbol:
                    continue

                price_key = f"{KEY_PRICE_PREFIX}{tv_symbol}"
                # NOTE(review): this read hits Redis per message, outside the
                # pipeline — a potential hotspot under heavy update volume.
                old_data = self.r.hgetall(price_key)

                update_data = {
                    "ts": normalized.get("ts"),
                    "tv_symbol": tv_symbol,
                    "status": normalized.get("status"),
                    "type": normalized.get("type"),
                    "currency": normalized.get("currency"),
                    "exchange": normalized.get("exchange"),
                    "description": normalized.get("description"),
                    "logoid": normalized.get("logoid"),
                    "session": normalized.get("session"),
                }

                # Numeric fields are only written when actually present.
                for field in [
                    "last_price",
                    "change",
                    "change_percent",
                    "regular_change",
                    "regular_change_percent",
                    "volume",
                ]:
                    value = normalized.get(field)
                    if value is not None and value != "":
                        update_data[field] = value

                # Drop None values and stringify the rest for Redis HSET.
                safe_update_data = {}
                for k, v in update_data.items():
                    if v is None:
                        continue
                    safe_update_data[k] = str(v)

                if safe_update_data:
                    pipe.hset(price_key, mapping=safe_update_data)

                    # Stream entries carry last_price even when this update
                    # lacked one, falling back to the previously stored value.
                    stream_data = dict(safe_update_data)
                    if "last_price" not in stream_data:
                        existing_last_price = old_data.get("last_price")
                        if existing_last_price not in (None, ""):
                            stream_data["last_price"] = existing_last_price

                    pipe.xadd(
                        KEY_STREAM_RAW,
                        stream_data,
                        maxlen=500000,
                        approximate=True
                    )

                    updates += 1

            if updates:
                now = time.time()
                pipe.set(f"{KEY_HEARTBEAT_PREFIX}{self.proc_name}", now)
                pipe.hincrby(KEY_STATUS, "price_updates_total", updates)
                pipe.hset(ws_state_key(self.batch_id), mapping={
                    "state": "connected",
                    "last_event_at": now,
                    "symbols_count": len(self.symbols),
                })
                pipe.execute()

        except Exception as e:
            log_error(self.r, f"[{self.proc_name}] on_message exception: {e}")

    def on_error(self, ws, error):
        """Record the error in Redis and track 429 (rate limit) streaks."""
        error_text = str(error)
        log_error(self.r, f"[{self.proc_name}] error: {error_text}")

        if "429" in error_text:
            self.consecutive_429 += 1
        else:
            self.consecutive_429 = 0

        self._mark_disconnected(error_text)

    def on_close(self, ws, status_code, msg):
        """Mark the bundle disconnected when the socket closes."""
        reason = f"closed code={status_code} msg={msg}"
        log_error(self.r, f"[{self.proc_name}] {reason}")
        self._mark_disconnected(reason)

    def run_forever(self):
        """Connect and stream forever, reconnecting only this bundle on drops."""
        reconnect_delay = WS_RECONNECT_MIN_SEC

        print(f"[{self.proc_name}] waiting {WS_INITIAL_DELAY_SEC}s before first websocket attempt...")
        time.sleep(WS_INITIAL_DELAY_SEC)

        while True:
            try:
                self._mark_connecting()
                # Fresh session ids for every connection attempt.
                self.quote_session = random_session("qs")
                self.chart_session = random_session("cs")

                self.ws = WebSocketApp(
                    WS_URL,
                    on_open=self.on_open,
                    on_message=self.on_message,
                    on_error=self.on_error,
                    on_close=self.on_close,
                    header=[
                        "Origin: https://www.tradingview.com",
                        "User-Agent: Mozilla/5.0"
                    ],
                )

                # Blocks until the connection closes or errors out.
                self.ws.run_forever(
                    ping_interval=20,
                    ping_timeout=10,
                    sslopt={"check_hostname": True}
                )

                # Back off harder after a streak of 429 rate-limit errors.
                if self.consecutive_429 > 0:
                    sleep_for = min(
                        WS_RECONNECT_MAX_SEC,
                        reconnect_delay * (2 ** min(self.consecutive_429, 4))
                    ) + random.uniform(60, 180)
                else:
                    sleep_for = reconnect_delay + random.uniform(20, 60)

                print(f"[{self.proc_name}] reconnecting its own bundle after {sleep_for:.1f}s")
                time.sleep(sleep_for)
                # Exponential backoff clamped to [MIN, MAX] reconnect window.
                reconnect_delay = min(max(reconnect_delay * 2, WS_RECONNECT_MIN_SEC), WS_RECONNECT_MAX_SEC)

            except Exception as e:
                self._mark_disconnected(str(e))
                log_error(self.r, f"[{self.proc_name}] restart exception: {e}")
                sleep_for = reconnect_delay + random.uniform(20, 60)
                print(f"[{self.proc_name}] restart sleep {sleep_for:.1f}s")
                time.sleep(sleep_for)
                reconnect_delay = min(max(reconnect_delay * 2, WS_RECONNECT_MIN_SEC), WS_RECONNECT_MAX_SEC)


def ws_worker_process(batch_id: int, symbols: List[str]):
    """Process entry point: run one websocket bundle client forever."""
    TVSocketClient(batch_id=batch_id, symbols=symbols).run_forever()


# =========================================================
# ORCHESTRATION
# =========================================================

def load_all_symbols_from_redis() -> List[str]:
    """Return all stored symbols, sorted, truncated when TEST_SYMBOL_LIMIT > 0."""
    symbols = sorted(get_redis().smembers(KEY_SYMBOLS_SET))
    if TEST_SYMBOL_LIMIT and TEST_SYMBOL_LIMIT > 0:
        symbols = symbols[:TEST_SYMBOL_LIMIT]
        print(f"TEST_SYMBOL_LIMIT enabled: using only {len(symbols)} symbols")
    return symbols


def clear_ws_manager_state(total_batches: int):
    """Delete the ready flags and state hashes for bundles 1..total_batches."""
    pipe = get_redis().pipeline()
    for batch_id in range(1, total_batches + 1):
        pipe.delete(ws_ready_key(batch_id))
        pipe.delete(ws_state_key(batch_id))
    pipe.execute()


def wait_until_logo_queue_empty(check_every_sec: int = 3, stable_rounds: int = 3):
    """Block until the Redis logo queue stays empty for *stable_rounds* checks.

    Requiring several consecutive empty readings avoids racing with workers
    that are still finishing the last few queued items.
    """
    r = get_redis()
    consecutive_empty = 0

    while True:
        pending = r.llen(KEY_LOGO_QUEUE)
        print(f"Remaining logos in queue: {pending}")

        if pending != 0:
            consecutive_empty = 0
        else:
            consecutive_empty += 1
            if consecutive_empty >= stable_rounds:
                print("Logo queue is fully empty.")
                return

        time.sleep(check_every_sec)


def wait_for_ws_ready(batch_id: int, timeout_sec: int) -> bool:
    """Poll Redis once per second until bundle *batch_id* is ready or timeout.

    Returns True as soon as the bundle's ready flag equals "1", False if the
    deadline passes first.
    """
    r = get_redis()
    key = ws_ready_key(batch_id)
    deadline = time.time() + timeout_sec

    while time.time() < deadline:
        if r.get(key) == "1":
            return True
        time.sleep(1)

    return False


def start_logo_workers(n: int) -> List[mp.Process]:
    """Spawn *n* daemon logo-download workers with ids 1..n and return them."""
    workers = []
    for worker_id in range(1, n + 1):
        proc = mp.Process(target=logo_downloader_process, args=(worker_id,), daemon=True)
        proc.start()
        workers.append(proc)
    return workers


def start_ws_workers_sequential(symbols: List[str], batch_size: int) -> List[mp.Process]:
    """Start one websocket process per symbol bundle, strictly one at a time.

    Bundle N+1 is only launched after bundle N has set its ready flag in
    Redis. A bundle that fails to become ready within WS_READY_TIMEOUT_SEC is
    terminated and retried after a randomized pause — the loop never skips
    past a failed bundle. Returns the successfully started processes.
    """
    procs = []
    batches = list(chunked(symbols, batch_size))

    # Drop stale ready flags/state from any previous run.
    clear_ws_manager_state(len(batches))

    print(f"Preparing {len(batches)} websocket bundles with size up to {batch_size}")

    for batch_id, batch in enumerate(batches, start=1):
        while True:
            print(f"Starting ws bundle {batch_id}/{len(batches)} with {len(batch)} symbols...")

            p = mp.Process(target=ws_worker_process, args=(batch_id, batch), daemon=True)
            p.start()

            ok = wait_for_ws_ready(batch_id, WS_READY_TIMEOUT_SEC)

            if ok:
                print(f"ws bundle {batch_id} connected successfully.")
                procs.append(p)
                # Small stagger before launching the next bundle.
                time.sleep(WS_MANAGER_STAGGER_SEC)
                break

            print(f"ws bundle {batch_id} did not report ready within {WS_READY_TIMEOUT_SEC}s.")
            print(f"Stopping failed ws bundle {batch_id} and retrying the same bundle later...")

            if p.is_alive():
                p.terminate()
                p.join(timeout=5)

            # Important: do not move to next bundle
            retry_sleep = random.uniform(180, 420)
            print(f"Retrying ws bundle {batch_id} after {retry_sleep:.1f}s")
            time.sleep(retry_sleep)

    return procs

def stop_processes(procs: List[mp.Process]):
    """Terminate every live worker, then join them all with a 5s grace each."""
    for proc in procs:
        if proc.is_alive():
            proc.terminate()
    for proc in procs:
        proc.join(timeout=5)


def print_summary():
    """Print the symbol count and the pipeline status hash from Redis."""
    r = get_redis()
    total = r.scard(KEY_SYMBOLS_SET)
    status_hash = r.hgetall(KEY_STATUS)

    print("\n===== SUMMARY =====")
    print(f"Symbols in Redis: {total}")
    print(f"Status: {json.dumps(status_hash, indent=2)}")
    print("===================\n")


def main():
    """Pipeline entry point: scan symbols, start ws bundles, then supervise."""
    # "spawn" keeps children free of inherited sockets/Redis connections.
    mp.set_start_method("spawn", force=True)

    r = get_redis()
    r.hset(KEY_STATUS, mapping={
        "started_at": time.time(),
        "price_updates_total": 0,
    })

    print("Step 1: scanning and storing symbols...")
    scan_and_store_symbols()

    symbols = load_all_symbols_from_redis()
    print(f"Total symbols loaded from Redis: {len(symbols)}")

    # Steps 2-4 (logo download) are currently disabled.
    #logo_workers_count = 8
    #print(f"Step 2: starting {logo_workers_count} logo workers...")
    #logo_procs = start_logo_workers(logo_workers_count)

    #print("Step 3: waiting for all logos to finish downloading...")
    #wait_until_logo_queue_empty(check_every_sec=3, stable_rounds=3)

    #print("Step 4: stopping logo workers...")
    #stop_processes(logo_procs)

    print("Step 5: starting websocket bundles sequentially...")
    ws_procs = start_ws_workers_sequential(symbols, WS_BATCH_SIZE)

    all_procs = ws_procs
    print_summary()

    def handle_exit(signum, frame):
        # Graceful shutdown on Ctrl-C / SIGTERM: stop children, then exit.
        print("\nStopping all worker processes...")
        stop_processes(all_procs)
        print("Stopped.")
        raise SystemExit(0)

    signal.signal(signal.SIGINT, handle_exit)
    signal.signal(signal.SIGTERM, handle_exit)

    # Supervision loop: report the live worker count every 10 seconds.
    while True:
        time.sleep(10)
        alive = sum(1 for p in all_procs if p.is_alive())
        r.hset(KEY_STATUS, "alive_processes", alive)
        print(f"Alive worker processes: {alive}/{len(all_procs)}")


if __name__ == "__main__":  # run the pipeline only when executed as a script
    main()
