# yfinance_stocks.py

import json
import math
import time
from io import StringIO
from pathlib import Path

import pandas as pd
import requests
import yfinance as yf


# =========================================================
# CONFIG
# =========================================================
# Pipe-delimited symbol directories published by Nasdaq Trader:
# nasdaqlisted.txt = NASDAQ issues, otherlisted.txt = other US exchanges.
NASDAQ_URL = "https://www.nasdaqtrader.com/dynamic/symdir/nasdaqlisted.txt"
OTHER_URL = "https://www.nasdaqtrader.com/dynamic/symdir/otherlisted.txt"

OUT_DIR = Path("data")
CACHE_DIR = OUT_DIR / "cache"

# NOTE: directories are created at import time (module-level side effect).
OUT_DIR.mkdir(parents=True, exist_ok=True)
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Cached copies of the symbol files; used as a fallback when download fails.
NASDAQ_CACHE = CACHE_DIR / "nasdaqlisted.txt"
OTHER_CACHE = CACHE_DIR / "otherlisted.txt"

# Output artifacts written by main().
SYMBOLS_FILE = OUT_DIR / "us_symbols_all.json"
PRICES_FILE = OUT_DIR / "us_symbols_prices.json"
MISSING_FILE = OUT_DIR / "missing_symbols.json"
SUMMARY_FILE = OUT_DIR / "price_summary.json"

# Batch download tuning: symbols per yf.download call and pauses used to
# rate-limit batch and per-symbol retry requests.
BATCH_SIZE = 100
BATCH_SLEEP_SEC = 2.0
RETRY_SLEEP_SEC = 0.2

# HTTP settings for fetching the symbol directory files.
DOWNLOAD_RETRIES = 3
DOWNLOAD_TIMEOUT_SEC = 30

# yfinance query window: 1-minute bars over the last day, incl. pre/post market.
YF_PERIOD = "1d"
YF_INTERVAL = "1m"
YF_PREPOST = True

TEST_SYMBOL_LIMIT = 0   # 0 = all symbols


# =========================================================
# HELPERS
# =========================================================
def save_json(path: Path, obj):
    """Write *obj* to *path* as pretty-printed UTF-8 JSON."""
    text = json.dumps(obj, indent=2, ensure_ascii=False)
    path.write_text(text, encoding="utf-8")


def chunked(seq, size):
    """Yield consecutive slices of *seq*, each at most *size* items long."""
    start = 0
    while start < len(seq):
        yield seq[start:start + size]
        start += size


def safe_float(x):
    """Convert *x* to a finite float, or return None.

    Returns None when conversion fails (None, non-numeric strings, huge
    ints) and for non-finite results (NaN / inf), which would otherwise
    leak non-standard values into the JSON price output.
    """
    try:
        value = float(x)
    except (TypeError, ValueError, OverflowError):
        return None
    # NaN/inf convert "successfully" but are not usable prices.
    return value if math.isfinite(value) else None


# =========================================================
# SYMBOL DOWNLOAD / CACHE
# =========================================================
def download_text(url: str, cache_path: Path, retries=3, timeout=30) -> str:
    """Fetch *url* as text, caching the body to *cache_path*.

    Tries up to *retries* HTTP GETs with linear backoff. On success the
    body is cached and returned. When every attempt fails, a previously
    cached copy is returned if one exists; otherwise a RuntimeError
    chaining the last network error is raised.
    """
    headers = {
        "User-Agent": "Mozilla/5.0",
        "Accept": "text/plain,text/csv,*/*",
    }

    last_err = None

    for attempt in range(1, retries + 1):
        try:
            r = requests.get(url, headers=headers, timeout=timeout)
            r.raise_for_status()
        except requests.RequestException as e:
            last_err = e
            print(f"Attempt {attempt}/{retries} failed for {url}: {e}")
            # Back off only between attempts; sleeping after the final
            # failure would just delay the cache fallback below.
            if attempt < retries:
                time.sleep(2 * attempt)
            continue

        text = r.text
        cache_path.write_text(text, encoding="utf-8")

        print(f"Downloaded: {url}")
        return text

    if cache_path.exists():
        print(f"Using cached file: {cache_path}")
        return cache_path.read_text(encoding="utf-8")

    raise RuntimeError(
        f"Failed to download {url} and no cache exists. Last error: {last_err}"
    ) from last_err


def load_us_symbols():
    """Return a sorted, de-duplicated list of US-listed ticker symbols.

    Combines the NASDAQ-listed and other-listed pipe-delimited files from
    nasdaqtrader.com. Honors TEST_SYMBOL_LIMIT for quick partial runs.
    """
    nasdaq_text = download_text(
        NASDAQ_URL,
        NASDAQ_CACHE,
        retries=DOWNLOAD_RETRIES,
        timeout=DOWNLOAD_TIMEOUT_SEC,
    )
    other_text = download_text(
        OTHER_URL,
        OTHER_CACHE,
        retries=DOWNLOAD_RETRIES,
        timeout=DOWNLOAD_TIMEOUT_SEC,
    )

    nasdaq = pd.read_csv(StringIO(nasdaq_text), sep="|")
    other = pd.read_csv(StringIO(other_text), sep="|")

    # Remove the footer row. The trailer line actually reads
    # "File Creation Time: <timestamp>", so an exact equality test
    # (== "File Creation Time") never matches it; filter by prefix.
    nasdaq = nasdaq[
        ~nasdaq["Symbol"].astype(str).str.startswith("File Creation Time")
    ]
    other = other[
        ~other["ACT Symbol"].astype(str).str.startswith("File Creation Time")
    ]

    nasdaq_symbols = nasdaq["Symbol"].dropna().astype(str).str.strip().tolist()
    other_symbols = other["ACT Symbol"].dropna().astype(str).str.strip().tolist()

    # De-duplicate and drop any empty strings left after stripping.
    all_symbols = sorted({s for s in nasdaq_symbols + other_symbols if s})

    if TEST_SYMBOL_LIMIT > 0:
        all_symbols = all_symbols[:TEST_SYMBOL_LIMIT]

    return all_symbols


# =========================================================
# YFINANCE PRICE EXTRACTION
# =========================================================
def _last_close(frame):
    """Return the most recent non-null Close in *frame* as a float, else None."""
    if "Close" not in frame.columns:
        return None
    closes = frame["Close"].dropna()
    if closes.empty:
        return None
    try:
        return float(closes.iloc[-1])
    except (TypeError, ValueError):
        return None


def get_price_from_download(data, symbol):
    """Extract the latest close price for *symbol* from a yf.download result.

    Supports both:
    - multi-ticker DataFrame with MultiIndex columns (group_by="ticker")
    - single-ticker DataFrame with flat columns

    Returns a float, or None when the symbol or a usable Close is absent.
    """
    try:
        if data is None or len(data) == 0:
            return None

        # Multi-ticker shape: top column level is the ticker symbol.
        if isinstance(data.columns, pd.MultiIndex):
            if symbol not in data.columns.get_level_values(0):
                return None
            return _last_close(data[symbol])

        # Single-ticker shape: flat columns.
        return _last_close(data)

    except Exception:
        # Defensive: a malformed frame from yfinance must not crash the
        # batch loop; the symbol is simply recorded as missing.
        return None


def retry_single_symbol(symbol):
    """Fetch one symbol's latest price directly via yfinance.

    Tries fast_info's lastPrice first, then falls back to the last close
    of a 5-day daily history. Returns a float price or None.
    """
    try:
        ticker = yf.Ticker(symbol)
    except Exception:
        return None

    # Attempt 1: lightweight fast_info lookup.
    try:
        info = ticker.fast_info
        if info:
            last = safe_float(info.get("lastPrice"))
            if last is not None:
                return last
    except Exception:
        pass

    # Attempt 2: recent daily history.
    try:
        hist = ticker.history(period="5d", interval="1d", auto_adjust=False)
    except Exception:
        return None

    if hist is None or len(hist) == 0 or "Close" not in hist.columns:
        return None

    closes = hist["Close"].dropna()
    if len(closes) == 0:
        return None
    return safe_float(closes.iloc[-1])


# =========================================================
# BATCH FETCH
# =========================================================
def fetch_latest_prices(symbols):
    """Download latest prices for *symbols* in batches via yf.download.

    Returns (prices, valid, missing): a {symbol: price} dict, the list of
    symbols that yielded a price, and the list that did not (candidates
    for the per-symbol retry pass).
    """
    prices = {}
    valid = []
    missing = []

    total = len(symbols)
    num_batches = (total + BATCH_SIZE - 1) // BATCH_SIZE  # ceil division

    for idx, batch in enumerate(chunked(symbols, BATCH_SIZE), start=1):
        batch_start = (idx - 1) * BATCH_SIZE + 1
        batch_end = min(idx * BATCH_SIZE, total)

        tickers_str = " ".join(batch)

        data = None
        try:
            data = yf.download(
                tickers=tickers_str,
                period=YF_PERIOD,
                interval=YF_INTERVAL,
                group_by="ticker",
                auto_adjust=False,
                progress=False,
                prepost=YF_PREPOST,
                threads=True,
            )
        except Exception as e:
            # A failed batch is not fatal: every symbol in it falls
            # through to `missing` and gets the individual retry pass.
            print(f"Batch {batch_start}-{batch_end} download error: {e}")

        for sym in batch:
            price = get_price_from_download(data, sym)

            if price is not None:
                prices[sym] = price
                valid.append(sym)
            else:
                missing.append(sym)

        print(
            f"Processed {batch_end}/{total} | "
            f"found: {len(valid)} | missing: {len(missing)}"
        )

        # Rate-limit between batches only; sleeping after the final
        # batch would just delay the caller for no reason.
        if idx < num_batches:
            time.sleep(BATCH_SLEEP_SEC)

    return prices, valid, missing


# =========================================================
# RETRY MISSING
# =========================================================
def retry_missing_symbols(missing):
    """Retry each symbol in *missing* individually.

    Returns (recovered, still_missing): a {symbol: price} dict of
    recovered prices and the list of symbols that still have none.
    """
    recovered = {}
    still_missing = []

    total = len(missing)

    for i, sym in enumerate(missing, start=1):
        price = retry_single_symbol(sym)

        if price is not None:
            recovered[sym] = price
        else:
            still_missing.append(sym)

        # Progress line every 50 symbols and once at the very end.
        if i % 50 == 0 or i == total:
            print(
                f"Retry progress: {i}/{total} | "
                f"recovered: {len(recovered)} | still missing: {len(still_missing)}"
            )

        # Throttle between lookups; skip the pointless trailing sleep
        # after the last symbol.
        if i < total:
            time.sleep(RETRY_SLEEP_SEC)

    return recovered, still_missing


# =========================================================
# MAIN
# =========================================================
def main():
    """Entry point: load symbols, fetch prices, retry misses, write reports."""
    print("Downloading US symbol lists...")
    symbols = load_us_symbols()

    print(f"Total US symbols loaded: {len(symbols)}")
    save_json(SYMBOLS_FILE, symbols)

    print("\nFetching latest prices in batches...")
    prices, valid, missing = fetch_latest_prices(symbols)

    print("\nRetrying missing symbols individually...")
    recovered, still_missing = retry_missing_symbols(missing)

    # Merge the retry results and sort the final mapping by symbol.
    prices.update(recovered)
    prices_sorted = dict(sorted(prices.items()))

    symbol_count = len(symbols)
    found_count = len(prices_sorted)
    pct = round((found_count / symbol_count) * 100, 2) if symbol_count else 0.0

    save_json(PRICES_FILE, prices_sorted)
    save_json(MISSING_FILE, still_missing)

    # One summary dict drives both the JSON report and the console output.
    summary = {
        "total_symbols": symbol_count,
        "prices_found_initial_batch": len(valid),
        "missing_after_batch": len(missing),
        "recovered_on_retry": len(recovered),
        "final_prices_found": found_count,
        "final_missing": len(still_missing),
        "coverage_percent": pct,
        "symbols_file": str(SYMBOLS_FILE),
        "prices_file": str(PRICES_FILE),
        "missing_file": str(MISSING_FILE),
    }
    save_json(SUMMARY_FILE, summary)

    print("\n==================== SUMMARY ====================")
    print(f"Total symbols         : {summary['total_symbols']}")
    print(f"Initial prices found  : {summary['prices_found_initial_batch']}")
    print(f"Missing after batch   : {summary['missing_after_batch']}")
    print(f"Recovered on retry    : {summary['recovered_on_retry']}")
    print(f"Final prices found    : {summary['final_prices_found']}")
    print(f"Final missing         : {summary['final_missing']}")
    print(f"Coverage              : {summary['coverage_percent']}%")
    print(f"Saved symbols to      : {SYMBOLS_FILE}")
    print(f"Saved prices to       : {PRICES_FILE}")
    print(f"Saved missing to      : {MISSING_FILE}")
    print(f"Saved summary to      : {SUMMARY_FILE}")


if __name__ == "__main__":
    main()