#!/usr/bin/env python3
"""
massive_ws_us_last_prices.py

Purpose:
- Connect to Massive Stocks WebSocket using the official Massive SDK
- Subscribe to delayed stock aggregate updates
- Display any symbol that receives a last price from WS
- Store latest price in Redis as:
    HSET us_market:last_prices <SYMBOL> <PRICE>
- Publish each update to Redis Pub/Sub channel:
    PUBLISH prices {"symbol": "AAPL", "price": "213.45"}
- Track last WS update time per symbol
- If a symbol has no WS update for > 15 minutes, fallback to Massive REST
  daily close for that same symbol using RESTClient.get_daily_open_close_agg()
- Keep running continuously

Required env vars:
  MASSIVE_API_KEY=your_key

Optional env vars:
  REDIS_HOST=127.0.0.1
  REDIS_PORT=6379
  REDIS_DB=0
  REDIS_PASSWORD=
  REDIS_HASH_KEY=us_market:last_prices
  REDIS_CHANNEL=prices
  FALLBACK_AFTER_SECONDS=900
  FALLBACK_CHECK_INTERVAL=30
  PRINT_WS_TICKS=true
  PRINT_FALLBACKS=true
  MONITOR_SYMBOLS=AAPL,MSFT,NVDA,SPY,TSLA
"""

import json
import logging
import os
import signal
import sys
import threading
import time
from datetime import date, timedelta
from typing import Any, Dict, List, Optional, Set, Tuple

from dotenv import load_dotenv
load_dotenv()

import redis
from massive import RESTClient, WebSocketClient
from massive.websocket.models import Feed, Market, WebSocketMessage

# =========================
# Configuration
# =========================
API_KEY = os.getenv("MASSIVE_API_KEY", "").strip()

REDIS_HOST = os.getenv("REDIS_HOST", "127.0.0.1")
REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))
REDIS_DB = int(os.getenv("REDIS_DB", "0"))
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", "") or None  # empty string -> no auth
REDIS_HASH_KEY = os.getenv("REDIS_HASH_KEY", "us_market:last_prices")
REDIS_CHANNEL = os.getenv("REDIS_CHANNEL", "prices")

FALLBACK_AFTER_SECONDS = int(os.getenv("FALLBACK_AFTER_SECONDS", "900"))
FALLBACK_CHECK_INTERVAL = int(os.getenv("FALLBACK_CHECK_INTERVAL", "30"))
# BUG FIX: these flags used to compare against "false", so the documented
# default of "true" left printing disabled and setting "false" enabled it.
# Compare against "true" so the env vars mean what the module docstring says.
PRINT_WS_TICKS = os.getenv("PRINT_WS_TICKS", "true").strip().lower() == "true"
PRINT_FALLBACKS = os.getenv("PRINT_FALLBACKS", "true").strip().lower() == "true"

# Symbols always tracked by the fallback monitor, even before any WS tick.
MONITOR_SYMBOLS = {
    s.strip().upper()
    for s in os.getenv("MONITOR_SYMBOLS", "").split(",")
    if s.strip()
}

REDIS_PIPELINE_BATCH_SIZE = 1000  # max HSET/PUBLISH pairs per Redis pipeline
REST_LOOKBACK_DAYS = 10  # calendar days to walk back when looking for a daily close

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s"
)
logger = logging.getLogger("massive_ws_last_prices")

# =========================
# Shared state
# =========================
# All three dicts/sets below are guarded by state_lock.
# symbol -> epoch seconds of the last stored update (WS or fallback; see save_price).
symbol_last_ws_ts: Dict[str, float] = {}
# symbol -> last price string written to Redis.
symbol_last_price: Dict[str, str] = {}
# Symbols with a REST fallback currently running, to avoid duplicate fetches.
fallback_in_progress: Set[str] = set()
state_lock = threading.Lock()
stop_event = threading.Event()  # set by signal handlers to stop all loops

# Initialized in main(); module-global so the WS callback and fallback thread share them.
redis_client: Optional[redis.Redis] = None
rest_client: Optional[RESTClient] = None


def require_env() -> None:
    """Fail fast when the mandatory Massive API key is not configured."""
    if API_KEY:
        return
    raise RuntimeError("Missing MASSIVE_API_KEY")


def get_redis() -> redis.Redis:
    """Build a Redis client from the module-level settings (decoded str responses)."""
    options = dict(
        host=REDIS_HOST,
        port=REDIS_PORT,
        db=REDIS_DB,
        password=REDIS_PASSWORD,
        decode_responses=True,
        socket_timeout=10,
        socket_connect_timeout=10,
        health_check_interval=30,
    )
    return redis.Redis(**options)


def test_redis_connection(r: redis.Redis) -> None:
    """
    Verify connectivity and hash read/write against REDIS_HASH_KEY.

    Writes and deletes a throwaway probe field; logs the outcome and
    re-raises on failure so startup aborts early.
    """
    try:
        r.ping()
        logger.info("Redis connected successfully")

        probe_field = "__redis_test__"
        probe_value = str(int(time.time()))
        r.hset(REDIS_HASH_KEY, probe_field, probe_value)
        echoed = r.hget(REDIS_HASH_KEY, probe_field)

        if echoed != probe_value:
            logger.error("Redis read/write mismatch: wrote=%s read=%s", probe_value, echoed)
        else:
            logger.info("Redis read/write working")

        r.hdel(REDIS_HASH_KEY, probe_field)

    except Exception as e:
        logger.error("Redis connection failed: %s", e)
        raise


def save_price(symbol: str, price: str, source: str = "ws") -> None:
    """
    Save the latest price in Redis hash and publish the update to Pub/Sub.

    Also refreshes the in-memory last-price/last-update bookkeeping and,
    when the corresponding PRINT_* flag is on, logs the update.
    """
    payload = {"symbol": symbol, "price": price}
    redis_client.hset(REDIS_HASH_KEY, symbol, price)
    redis_client.publish(REDIS_CHANNEL, json.dumps(payload))

    with state_lock:
        symbol_last_price[symbol] = price
        # NOTE(review): the timestamp is refreshed for fallback writes too,
        # which throttles repeated fallbacks for the same symbol — confirm
        # that is intended, since the dict name says "ws".
        symbol_last_ws_ts[symbol] = time.time()

    # FIX: the PRINT_WS_TICKS / PRINT_FALLBACKS flags were dead because this
    # gated logging had been commented out — restore it.
    if source == "ws" and PRINT_WS_TICKS:
        logger.info("WS LAST PRICE: %s = %s", symbol, price)
    elif source != "ws" and PRINT_FALLBACKS:
        logger.info("%s PRICE: %s = %s", source.upper(), symbol, price)


def flush_to_redis(r: redis.Redis, pairs: Dict[str, str], source: str = "ws") -> int:
    """
    Write a symbol->price mapping to the Redis hash and publish each pair,
    in pipelined chunks of REDIS_PIPELINE_BATCH_SIZE.

    Returns the number of pairs written. Also refreshes the in-memory
    bookkeeping and (when PRINT_WS_TICKS is on) logs each WS tick.
    """
    if not pairs:
        return 0

    items = list(pairs.items())
    total = 0

    for i in range(0, len(items), REDIS_PIPELINE_BATCH_SIZE):
        chunk = items[i:i + REDIS_PIPELINE_BATCH_SIZE]

        pipe = r.pipeline(transaction=False)
        pipe.hset(REDIS_HASH_KEY, mapping=dict(chunk))

        for symbol, price in chunk:
            pipe.publish(REDIS_CHANNEL, json.dumps({"symbol": symbol, "price": price}))

        # PERF FIX: acquire state_lock once per chunk instead of once per
        # symbol inside the loop.
        now_ts = time.time()
        with state_lock:
            for symbol, price in chunk:
                symbol_last_price[symbol] = price
                symbol_last_ws_ts[symbol] = now_ts

        # FIX: restore the PRINT_WS_TICKS-gated logging that had been
        # commented out, leaving the flag dead.
        if source == "ws" and PRINT_WS_TICKS:
            for symbol, price in chunk:
                logger.info("WS LAST PRICE: %s = %s", symbol, price)

        pipe.execute()
        total += len(chunk)

    return total


def get_attr(msg: Any, *names: str) -> Any:
    """Return the first attribute of *msg* among *names* whose value is not None."""
    for candidate in names:
        found = getattr(msg, candidate, None)
        if found is not None:
            return found
    return None


def extract_symbol_and_price(m: "WebSocketMessage") -> Tuple[Optional[str], Optional[str]]:
    """
    Pull (SYMBOL, close-price-string) out of one WS message.

    Massive aggregate messages are identified by event type "A"; both the
    SDK-style long attribute names and the wire-style short names are probed.
    Returns (None, None) for non-aggregate or incomplete messages.
    """
    def first_attr(*names: str) -> Any:
        # First non-None attribute among the candidates, else None.
        for name in names:
            value = getattr(m, name, None)
            if value is not None:
                return value
        return None

    if first_attr("event_type", "ev") != "A":
        return None, None

    symbol = first_attr("symbol", "sym")
    close_price = first_attr("close", "c")
    if not symbol or close_price is None:
        return None, None

    return str(symbol).upper(), str(close_price)


def fetch_latest_daily_close(symbol: str) -> Optional[Tuple[str, str]]:
    """
    Walk backward up to REST_LOOKBACK_DAYS and return (date_iso, close_str)
    for the most recent day with a valid daily close, or None.

    Uses official Massive RESTClient.get_daily_open_close_agg().
    """
    today = date.today()

    for offset in range(REST_LOOKBACK_DAYS):
        dt_str = (today - timedelta(days=offset)).isoformat()

        try:
            result = rest_client.get_daily_open_close_agg(
                symbol,
                dt_str,
                adjusted="true",
            )
        except Exception as e:
            # BUG FIX: this used to `return None` on the first exception,
            # which aborted the whole lookback walk — the REST call is
            # expected to fail for non-trading days (weekends, holidays).
            # Log and keep walking backward instead.
            logger.warning("REST fallback failed for %s on %s: %s", symbol, dt_str, e)
            continue

        if result is None:
            continue

        close_price = None
        from_date = dt_str

        # The SDK may return either an object with attributes or a raw dict.
        if hasattr(result, "close"):
            close_price = getattr(result, "close", None)
        elif isinstance(result, dict):
            close_price = result.get("close")
            from_date = result.get("from", dt_str)

        if close_price is None:
            continue

        return from_date, str(close_price)

    return None


def fallback_symbol_to_redis(r: redis.Redis, symbol: str) -> bool:
    """
    Fetch the latest REST daily close for *symbol* and store it via save_price.

    Returns True when a close was found and written, False otherwise.
    """
    result = fetch_latest_daily_close(symbol)
    if not result:
        # FIX: restore the outcome logging that had been commented out —
        # without it, fallback failures were completely silent.
        logger.warning("No REST fallback close found for %s", symbol)
        return False

    from_date, close_str = result

    try:
        save_price(symbol, close_str, source="rest_fallback")
        # Read back so operators can confirm the stored value and source date.
        read_back = r.hget(REDIS_HASH_KEY, symbol)
        logger.info("Redis fallback check: %s = %s (date=%s)", symbol, read_back, from_date)
        return True

    except Exception as e:
        logger.error("Redis fallback write failed for %s: %s", symbol, e)
        return False


def fallback_monitor() -> None:
    """
    Background loop: every FALLBACK_CHECK_INTERVAL seconds, find symbols
    with no stored update for more than FALLBACK_AFTER_SECONDS (or none at
    all) and run a REST daily-close fallback for each.
    """
    while not stop_event.is_set():
        stop_event.wait(FALLBACK_CHECK_INTERVAL)
        if stop_event.is_set():
            break

        now = time.time()
        stale: List[str] = []

        with state_lock:
            candidates = MONITOR_SYMBOLS | set(symbol_last_ws_ts)
            for sym in candidates:
                last_ts = symbol_last_ws_ts.get(sym)
                # No timestamp yet (None or the 0.0 seed) counts as stale.
                if not last_ts or (now - last_ts) > FALLBACK_AFTER_SECONDS:
                    stale.append(sym)

        for sym in stale:
            # Claim the symbol so concurrent passes don't duplicate work.
            with state_lock:
                if sym in fallback_in_progress:
                    continue
                fallback_in_progress.add(sym)

            try:
                fallback_symbol_to_redis(redis_client, sym)
            finally:
                with state_lock:
                    fallback_in_progress.discard(sym)


def handle_msg(msgs: List[WebSocketMessage]) -> None:
    """WS callback: collect last prices from a message batch and flush to Redis."""
    updates: Dict[str, str] = {
        sym: price
        for sym, price in (extract_symbol_and_price(m) for m in msgs)
        if sym and price is not None
    }

    if not updates:
        return

    try:
        flush_to_redis(redis_client, updates, source="ws")
    except Exception as exc:
        logger.error("Redis write failed: %s", exc)


def run_ws_forever() -> None:
    """
    Run the Massive WebSocket client, reconnecting on any failure until
    stop_event is set.

    NOTE(review): ws_client.run() blocks; setting stop_event only prevents
    the *next* reconnect — it does not interrupt a live connection. Confirm
    whether the SDK's close() should be invoked from the signal handler.
    """
    while not stop_event.is_set():
        try:
            # FIX: restore the startup log that had been commented out, so
            # reconnect cycles are visible alongside the other lifecycle logs.
            logger.info("Starting Massive WebSocket client with feed=delayed market=stocks")

            ws_client = WebSocketClient(
                api_key=API_KEY,
                feed=Feed.Delayed,
                market=Market.Stocks,
            )

            # "A.*" = aggregate events for every symbol.
            ws_client.subscribe("A.*")
            ws_client.run(handle_msg)

        except Exception as exc:
            logger.exception("WebSocket client error: %s", exc)

        if not stop_event.is_set():
            logger.info("Reconnecting WebSocket client in 5 seconds...")
            stop_event.wait(5)


def handle_shutdown(signum, frame) -> None:
    """Signal handler for SIGINT/SIGTERM: request a clean shutdown.

    signum/frame are required by the signal API but unused here.
    """
    logger.info("Shutdown signal received")
    stop_event.set()


def main() -> int:
    """
    Wire everything together: Redis, the REST client, signal handlers, the
    fallback thread, and the (blocking) WS loop.

    Returns a process exit code: 0 on clean stop, 1 on fatal error.
    """
    global redis_client, rest_client

    try:
        require_env()

        redis_client = get_redis()
        test_redis_connection(redis_client)

        rest_client = RESTClient(API_KEY)

        # Seed monitored symbols with a zero timestamp so the fallback
        # monitor treats them as stale on its first pass.
        if MONITOR_SYMBOLS:
            with state_lock:
                for sym in MONITOR_SYMBOLS:
                    symbol_last_ws_ts.setdefault(sym, 0.0)
                    symbol_last_price.setdefault(sym, "")

        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, handle_shutdown)

        threading.Thread(target=fallback_monitor, daemon=True).start()

        run_ws_forever()

    except KeyboardInterrupt:
        logger.info("Stopped by user")
        return 0
    except Exception as exc:
        logger.error("Fatal error: %s", exc)
        return 1

    return 0


if __name__ == "__main__":
    # Exit with main()'s return code so process supervisors can detect failures.
    sys.exit(main())
