import json
import math
import os
from typing import Optional, List, Dict, Any

import redis
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import JSONResponse

app = FastAPI(title="TradingView US Market Redis API", version="1.0.0")

# =========================================================
# Redis Config
# =========================================================

# Connection settings; each one can be overridden through the environment.
REDIS_HOST = os.environ.get("REDIS_HOST", "127.0.0.1")
REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
REDIS_DB = int(os.environ.get("REDIS_DB", "0"))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD")  # None when unset

# Redis key layout read by the routes below.
KEY_SYMBOLS_SET = "tv:us:symbols:set"      # set of tv_symbol names (SMEMBERS)
KEY_SYMBOL_META_PREFIX = "tv:us:symbol:"   # + tv_symbol -> metadata hash
KEY_PRICE_PREFIX = "tv:us:price:"          # + tv_symbol -> price hash
KEY_STREAM_RAW = "tv:us:stream:raw"        # stream whose entries carry price fields
KEY_STATUS = "tv:us:status"                # hash echoed by /health

# Shared client; decode_responses=True makes replies str instead of bytes.
r = redis.Redis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    db=REDIS_DB,
    password=REDIS_PASSWORD,
    decode_responses=True,
)

# =========================================================
# Helpers
# =========================================================

def get_symbol_meta(tv_symbol: str) -> Dict[str, Any]:
    """Fetch the metadata hash stored for *tv_symbol*.

    An empty dict means no metadata key exists for the symbol.
    """
    meta_key = f"{KEY_SYMBOL_META_PREFIX}{tv_symbol}"
    return r.hgetall(meta_key)

def get_symbol_price(tv_symbol: str) -> Dict[str, Any]:
    """Fetch the latest price hash stored for *tv_symbol*.

    An empty dict means no price key exists for the symbol.
    """
    price_key = f"{KEY_PRICE_PREFIX}{tv_symbol}"
    return r.hgetall(price_key)

def parse_float(v):
    """Convert a Redis field value to a finite ``float``.

    Returns ``None`` when *v* is ``None``, an empty string, not numeric, or
    parses to NaN / +-infinity. Non-finite values are rejected because
    Starlette's ``JSONResponse`` (FastAPI's default) encodes with
    ``allow_nan=False`` and would fail the whole request. NaN would also
    break ordering comparisons when rows are sorted by a parsed value.
    """
    try:
        if v is None or v == "":
            return None
        result = float(v)
    except (TypeError, ValueError, OverflowError):
        # Unparseable input is treated the same as a missing value.
        return None
    return result if math.isfinite(result) else None

def parse_int(v):
    """Convert a Redis field value to an ``int``.

    Integer strings are parsed exactly, so values above 2**53 keep full
    precision. Otherwise the value is parsed as a float and truncated toward
    zero (``"3.9"`` -> 3, ``"1e3"`` -> 1000). Returns ``None`` for ``None``,
    an empty string, non-numeric input, NaN, or infinity.
    """
    if v is None or v == "":
        return None
    try:
        # Exact path: avoids the float round-trip that loses precision.
        return int(v)
    except (TypeError, ValueError, OverflowError):
        pass
    try:
        return int(float(v))
    except (TypeError, ValueError, OverflowError):
        # NaN raises ValueError and infinity raises OverflowError here.
        return None

def merge_symbol_and_price(tv_symbol: str) -> Dict[str, Any]:
    """Combine a symbol's metadata hash with its price hash.

    Returns an empty dict when the symbol has no metadata; the price hash is
    not fetched in that case. Numeric fields are converted with
    ``parse_float`` and come back as ``None`` when missing or unparseable.
    The metadata field ``close`` is exposed as ``scanner_close``.
    """
    meta = get_symbol_meta(tv_symbol)
    if not meta:
        return {}

    quote = get_symbol_price(tv_symbol)

    merged: Dict[str, Any] = {"tv_symbol": tv_symbol}
    for field in ("ticker", "description", "exchange", "country", "currency",
                  "sector", "industry", "type", "subtype"):
        merged[field] = meta.get(field)
    merged["market_cap_basic"] = parse_float(meta.get("market_cap_basic"))
    merged["scanner_close"] = parse_float(meta.get("close"))
    for field in ("logoid", "logo_url", "logo_local"):
        merged[field] = meta.get(field)

    numeric_price_fields = ("last_price", "change", "change_percent",
                            "regular_change", "regular_change_percent", "volume")
    price_block = {name: parse_float(quote.get(name)) for name in numeric_price_fields}
    price_block["session"] = quote.get("session")
    price_block["status"] = quote.get("status")
    price_block["ts"] = parse_float(quote.get("ts"))

    merged["price"] = price_block
    return merged

def list_all_symbols() -> List[str]:
    """Return every member of the symbols set, sorted lexicographically."""
    # sorted() accepts any iterable; wrapping the set in list() first only
    # made an extra copy (ruff C414).
    return sorted(r.smembers(KEY_SYMBOLS_SET))

# =========================================================
# Routes
# =========================================================

@app.get("/health")
def health():
    """Ping Redis and echo the contents of the status hash.

    If either the ping or the status read raises, responds with HTTP 500 and
    a detail prefixed with "Redis error:".
    """
    try:
        alive = r.ping()
        status_hash = r.hgetall(KEY_STATUS)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Redis error: {exc}")
    return {"ok": bool(alive), "redis": "connected", "status": status_hash}

@app.get("/symbols")
def symbols(
    limit: int = Query(100, ge=1, le=5000),
    offset: int = Query(0, ge=0),
    exchange: Optional[str] = None,
    sector: Optional[str] = None,
    with_price: bool = Query(False)
):
    """List symbols with optional exchange/sector filters and pagination.

    ``exchange`` and ``sector`` are matched case-insensitively against the
    metadata hash. ``total`` is the number of matching symbols before
    pagination. With ``with_price=True`` each returned row is the merged
    metadata + price payload; otherwise it is a metadata summary.
    """
    exchange_wanted = exchange.upper() if exchange else None
    sector_wanted = sector.lower() if sector else None

    matches = []
    for tv_symbol in list_all_symbols():
        meta = get_symbol_meta(tv_symbol)
        if not meta:
            continue
        if exchange_wanted and (meta.get("exchange") or "").upper() != exchange_wanted:
            continue
        if sector_wanted and (meta.get("sector") or "").lower() != sector_wanted:
            continue
        matches.append((tv_symbol, meta))

    # Rows, and price lookups when requested, are built for the requested
    # page only, so matches outside the page cost no extra Redis calls.
    results = []
    for tv_symbol, meta in matches[offset:offset + limit]:
        if with_price:
            results.append(merge_symbol_and_price(tv_symbol))
        else:
            results.append({
                "tv_symbol": tv_symbol,
                "ticker": meta.get("ticker"),
                "description": meta.get("description"),
                "exchange": meta.get("exchange"),
                "sector": meta.get("sector"),
                "industry": meta.get("industry"),
                "logo_url": meta.get("logo_url"),
                "logo_local": meta.get("logo_local"),
            })

    return {
        "total": len(matches),
        "offset": offset,
        "limit": limit,
        "count": len(results),
        "results": results
    }

@app.get("/symbol/{tv_symbol:path}")
def symbol_detail(tv_symbol: str):
    """Return merged metadata + price for one symbol, or 404 if it is unknown.

    The ``:path`` converter lets *tv_symbol* contain ``/`` characters.
    """
    payload = merge_symbol_and_price(tv_symbol)
    if payload:
        return payload
    raise HTTPException(status_code=404, detail="Symbol not found")

@app.get("/price/{tv_symbol:path}")
def price_detail(tv_symbol: str):
    """Return the stored price hash for one symbol, or 404 if there is none.

    Numeric fields go through ``parse_float`` (``None`` when missing or
    unparseable); the descriptive fields are passed through as stored.
    """
    quote = get_symbol_price(tv_symbol)
    if not quote:
        raise HTTPException(status_code=404, detail="Price not found for symbol")

    def as_float(field):
        return parse_float(quote.get(field))

    response = {"tv_symbol": tv_symbol}
    for field in ("last_price", "change", "change_percent",
                  "regular_change", "regular_change_percent", "volume"):
        response[field] = as_float(field)
    for field in ("currency", "exchange", "description", "session", "status"):
        response[field] = quote.get(field)
    response["ts"] = as_float("ts")
    return response

@app.get("/top-movers")
def top_movers(
    direction: str = Query("gainers", pattern="^(gainers|losers)$"),
    limit: int = Query(50, ge=1, le=500)
):
    """Rank symbols by ``change_percent``.

    ``gainers`` sorts descending and ``losers`` ascending. A symbol is
    skipped if it lacks a price hash or metadata, or if its
    ``change_percent`` / ``last_price`` is missing or non-finite. Ties keep
    the order produced by ``list_all_symbols`` because the sort is stable.
    """
    rows = []
    for tv_symbol in list_all_symbols():
        # Check the quote first so symbols without a usable price skip the
        # metadata round trip entirely.
        price = get_symbol_price(tv_symbol)
        if not price:
            continue

        chg_pct = parse_float(price.get("change_percent"))
        last_price = parse_float(price.get("last_price"))
        # NaN would corrupt the ordering, and non-finite floats cannot be
        # JSON-encoded by the default response class.
        if chg_pct is None or last_price is None:
            continue
        if not (math.isfinite(chg_pct) and math.isfinite(last_price)):
            continue

        meta = get_symbol_meta(tv_symbol)
        if not meta:
            continue

        rows.append({
            "tv_symbol": tv_symbol,
            "ticker": meta.get("ticker"),
            "description": meta.get("description"),
            "exchange": meta.get("exchange"),
            "sector": meta.get("sector"),
            "industry": meta.get("industry"),
            "logo_url": meta.get("logo_url"),
            "logo_local": meta.get("logo_local"),
            "last_price": last_price,
            "change": parse_float(price.get("change")),
            "change_percent": chg_pct,
            "volume": parse_float(price.get("volume")),
            "ts": parse_float(price.get("ts")),
        })

    # One stable sort covers both directions. change_percent is always a
    # finite float here, so no None fallback key is needed.
    rows.sort(key=lambda row: row["change_percent"], reverse=(direction == "gainers"))

    top = rows[:limit]
    return {
        "direction": direction,
        "count": len(top),
        "results": top
    }

@app.get("/search")
def search(
    q: str = Query(..., min_length=1),
    limit: int = Query(50, ge=1, le=500),
    with_price: bool = Query(False)
):
    """Case-insensitive substring search over symbol metadata.

    The query is lower-cased and stripped, then matched against one
    space-joined string built from the symbol, ticker, description,
    exchange, sector and industry. Scanning stops once ``limit`` matches
    are collected. A query that is empty after stripping matches nothing.
    """
    q_lower = q.lower().strip()
    if not q_lower:
        # "" is a substring of every haystack, so a whitespace-only query
        # would otherwise return the first `limit` symbols unfiltered.
        return {"query": q, "count": 0, "results": []}

    results = []
    for tv_symbol in list_all_symbols():
        meta = get_symbol_meta(tv_symbol)
        if not meta:
            continue

        haystack = " ".join([
            tv_symbol,
            meta.get("ticker", ""),
            meta.get("description", ""),
            meta.get("exchange", ""),
            meta.get("sector", ""),
            meta.get("industry", ""),
        ]).lower()
        if q_lower not in haystack:
            continue

        if with_price:
            results.append(merge_symbol_and_price(tv_symbol))
        else:
            results.append({
                "tv_symbol": tv_symbol,
                "ticker": meta.get("ticker"),
                "description": meta.get("description"),
                "exchange": meta.get("exchange"),
                "sector": meta.get("sector"),
                "industry": meta.get("industry"),
                "logo_url": meta.get("logo_url"),
                "logo_local": meta.get("logo_local"),
            })

        if len(results) >= limit:
            break

    return {
        "query": q,
        "count": len(results),
        "results": results
    }

@app.get("/stream/latest")
def stream_latest(count: int = Query(20, ge=1, le=500)):
    """Return up to *count* of the newest raw-stream entries, newest first.

    Each entry's fields are returned as-is, plus ``stream_id``. The known
    numeric fields are converted with ``parse_float``. A failure reading the
    stream responds with HTTP 500 carrying the error text.
    """
    # Only the Redis call is guarded. The post-processing below cannot raise
    # for well-formed replies, and any bug in it should surface as a normal
    # server error rather than being relabelled with its message.
    try:
        items = r.xrevrange(KEY_STREAM_RAW, count=count)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    numeric_fields = ("last_price", "change", "change_percent",
                      "regular_change", "regular_change_percent", "volume", "ts")
    results = []
    for stream_id, fields in items:
        item = dict(fields)
        item["stream_id"] = stream_id
        for name in numeric_fields:
            item[name] = parse_float(item.get(name))
        results.append(item)

    return {
        "count": len(results),
        "results": results
    }