#!/usr/bin/env python3
"""
Earnings Strategy Backtester
Strategy: Buy pre-earnings stagnant stocks, sell day before, re-enter on beat+dip
"""

import os
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import yfinance as yf
from datetime import datetime, timedelta
import time

# ── Config ────────────────────────────────────────────────────────────────────
# Backtest window. Both values are passed to yf.download as start/end and are
# also used to filter the earnings calendar in fetch_earnings.
START_DATE = "2021-03-01"
END_DATE   = "2026-03-01"
# Starting capital that compute_summary compounds the trade returns on
# (shown with a "$" prefix in the text summary).
INITIAL_CAPITAL = 10_000

# Universe of tickers that main() downloads and backtests, grouped by sector.
STOCKS = [
    # Tech
    "GOOGL", "META", "AAPL", "MSFT", "NVDA",
    # Retail
    "AMZN", "WMT", "COST", "TGT",
    # Finance
    "JPM", "BAC", "GS",
    # Healthcare
    "JNJ", "LLY", "UNH",
    # Energy
    "XOM", "CVX",
    # Industrial
    "CAT", "HON", "GE",
]

# Output directory: a "results" folder next to this file.
# NOTE: it is created as a side effect of importing/running this module.
RESULTS_DIR = os.path.join(os.path.dirname(__file__), "results")
os.makedirs(RESULTS_DIR, exist_ok=True)

# ── Helpers ───────────────────────────────────────────────────────────────────

def fetch_price_data(tickers: list[str]) -> dict[str, pd.DataFrame]:
    """Batch-fetch OHLCV history for ``tickers`` between START_DATE and END_DATE.

    A single ``yf.download`` call is made with ``group_by="ticker"`` and
    ``auto_adjust=True``; the combined frame is then split into one frame
    per ticker with all-NaN rows removed.

    Args:
        tickers: Ticker symbols to download.

    Returns:
        Mapping of ticker -> price DataFrame. Tickers for which no usable
        rows came back are omitted (a warning is printed for each), so
        callers must check for membership before indexing.
    """
    if not tickers:
        # Nothing to request; skip the download call entirely.
        return {}

    print(f"Fetching price data for {len(tickers)} tickers…")
    raw = yf.download(
        tickers,
        start=START_DATE,
        end=END_DATE,
        auto_adjust=True,
        progress=False,
        group_by="ticker",
        threads=True,
    )

    result: dict[str, pd.DataFrame] = {}
    # Depending on the yfinance version, the columns are either a two-level
    # index keyed by ticker first (also for a single ticker) or, for a single
    # ticker only, a flat OHLCV index. Handle both layouts explicitly instead
    # of assuming one from the number of tickers.
    multi_level = isinstance(raw.columns, pd.MultiIndex)
    known = set(raw.columns.get_level_values(0)) if multi_level else set()

    for t in tickers:
        if multi_level:
            if t not in known:
                print(f"  ⚠  No price data for {t}")
                continue
            df = raw[t]
        elif len(tickers) == 1:
            df = raw
        else:
            # Flat columns cannot be attributed to one of several tickers.
            print(f"  ⚠  No price data for {t}")
            continue

        df = df.dropna(how="all").copy()
        if df.empty:
            # A failed download shows up as all-NaN columns, not as an error.
            print(f"  ⚠  No price data for {t}")
            continue
        result[t] = df
    return result


def fetch_earnings(ticker: str) -> pd.DataFrame:
    """Return the earnings calendar of ``ticker`` restricted to the backtest window.

    The data comes from ``yf.Ticker(ticker).earnings_dates``. The index is
    converted to timezone-naive timestamps (only the timezone is dropped,
    the index is not normalised to midnight), sorted ascending and limited
    to START_DATE..END_DATE inclusive. Column names are stripped of
    surrounding whitespace. Rows are NOT filtered on whether EPS values
    are present.

    Any exception is reported on stdout and turned into an empty DataFrame,
    which is also what is returned when yfinance has no data.
    """
    try:
        calendar = yf.Ticker(ticker).earnings_dates
        if calendar is None or calendar.empty:
            return pd.DataFrame()

        events = calendar.copy()
        events.index = pd.to_datetime(events.index).tz_localize(None)
        events = events.sort_index()

        # Inclusive window on both ends.
        window_start = pd.Timestamp(START_DATE)
        window_end = pd.Timestamp(END_DATE)
        in_window = (events.index >= window_start) & (events.index <= window_end)
        events = events[in_window]

        # Normalise header whitespace so later lookups by column name match.
        events.columns = [name.strip() for name in events.columns]
        return events
    except Exception as exc:
        print(f"  ⚠  Earnings fetch failed for {ticker}: {exc}")
        return pd.DataFrame()


def get_trading_day(price_df: pd.DataFrame, ref_date, offset: int = 0) -> pd.Timestamp | None:
    """Return the trading day nearest to ref_date + offset calendar days."""
    target = pd.Timestamp(ref_date) + pd.Timedelta(days=offset)
    idx = price_df.index
    # Find closest date >= target
    future = idx[idx >= target]
    if len(future) == 0:
        return None
    return future[0]


def get_close(price_df: pd.DataFrame, date) -> float | None:
    if date is None or date not in price_df.index:
        return None
    return float(price_df.loc[date, "Close"])


def trading_days_between(price_df: pd.DataFrame, d1, d2) -> int:
    """Count the price-index entries in the closed interval [d1, d2].

    Both endpoints are inclusive. Returns 0 when either endpoint is None
    (and, naturally, when d2 lies before d1).
    """
    if d1 is None or d2 is None:
        return 0
    dates = price_df.index
    span = dates[(dates >= d1) & (dates <= d2)]
    return len(span)


# ── Core Strategy ─────────────────────────────────────────────────────────────

def backtest_stock(ticker: str, price_df: pd.DataFrame, earnings_df: pd.DataFrame) -> list[dict]:
    """Run the earnings strategy on a single stock and return its trades.

    For every earnings event three closes are looked up: T-30 and T-14 (the
    first price row on or after 30 / 14 calendar days before the event) and
    T-1 (the last price row strictly before the event).

    * The event is traded only if the T-30 -> T-14 move is within ±3%.
    * Entry 1 is the T-30 close. Entry 2 (T-14 close) is added only if the
      T-1 close is no more than 3% below the T-30 close.
    * The position is sold at the T-1 close.
    * If EPS beat the estimate and the close of the first price row on or
      after the event is at least 2% below the T-1 close, the position is
      re-opened at that close. It is closed on the first of the following
      5 price rows whose close is back at or above the T-1 close, otherwise
      on the last of those rows.

    Args:
        ticker: Symbol written into every trade record.
        price_df: Price history indexed by date with "Close" and "Low"
            columns.
        earnings_df: Earnings events indexed by timestamp; the columns
            "EPS Actual" and "EPS Estimate" are read when present.

    Returns:
        One dict per traded earnings event (possibly an empty list). Return
        fields are expressed in percent. ``trade_return_pct`` is the
        pre-earnings return plus, if a re-entry happened, the post-earnings
        return (a plain sum, not compounded). Returns are per unit invested
        and are not scaled down when only the first entry was taken.
    """
    trades: list[dict] = []

    if price_df.empty or earnings_df.empty:
        return trades

    price_idx = price_df.index

    # A duplicated earnings timestamp would otherwise produce the same trade
    # twice (and made a label-based row lookup return a frame, not a row).
    earnings_df = earnings_df[~earnings_df.index.duplicated(keep="first")]

    for earn_date, row in earnings_df.iterrows():
        # ── Earnings metadata ──────────────────────────────────────────────
        eps_actual = row.get("EPS Actual", np.nan)
        eps_est    = row.get("EPS Estimate", np.nan)

        # A beat needs both figures and a non-zero estimate.
        beat = False
        if pd.notna(eps_actual) and pd.notna(eps_est) and eps_est != 0:
            beat = bool(eps_actual > eps_est)

        # ── Reference dates ────────────────────────────────────────────────
        # T-1: last price row strictly before the earnings timestamp.
        # NOTE(review): if earn_date carries a time of day and the price index
        # is stamped at midnight, the row of the earnings day itself counts as
        # "before" the event — confirm this is intended.
        pre_earn_days = price_idx[price_idx < earn_date]
        if len(pre_earn_days) < 31:
            continue  # not enough history ahead of this event
        t_minus_1 = pre_earn_days[-1]
        t_minus_14_approx = get_trading_day(price_df, earn_date, offset=-14)
        t_minus_30_approx = get_trading_day(price_df, earn_date, offset=-30)

        if t_minus_14_approx is None or t_minus_30_approx is None:
            continue

        # get_trading_day only searches forward, so a gap in the price history
        # can push the T-14 anchor past T-1; such an entry would come after
        # the exit and is meaningless.
        if t_minus_14_approx > t_minus_1:
            continue

        price_t30 = get_close(price_df, t_minus_30_approx)
        price_t14 = get_close(price_df, t_minus_14_approx)
        price_t1  = get_close(price_df, t_minus_1)

        if price_t30 is None or price_t14 is None or price_t1 is None:
            continue

        # ── Window 1: T-30 to T-14 stagnant check ─────────────────────────
        # NOTE(review): the entry at T-30 is conditioned on the T-14 close,
        # and the entry at T-14 (below) on the T-1 close, i.e. on prices not
        # yet known at the entry date. Confirm this look-ahead is intended.
        change_30_to_14 = (price_t14 - price_t30) / price_t30
        if abs(change_30_to_14) > 0.03:
            continue  # not stagnant during window 1

        # Entry 1: buy at T-30 close (50% of position)
        entry1_date  = t_minus_30_approx
        entry1_price = price_t30

        # ── Window 2: T-1 close stagnant or rising relative to T-30 ───────
        change_30_to_1 = (price_t1 - price_t30) / price_t30
        eligible_window2 = change_30_to_1 >= -0.03

        entry2_date  = t_minus_14_approx
        entry2_price = price_t14 if eligible_window2 else None

        # ── Pre-Earnings Exit at T-1 close ─────────────────────────────────
        exit_pre_date  = t_minus_1
        exit_pre_price = price_t1

        # Blended entry price: mean of both entries, or entry 1 alone when
        # the second half of the position was not taken.
        if entry2_price is not None:
            avg_entry = (entry1_price + entry2_price) / 2
        else:
            avg_entry = entry1_price

        pre_earn_return = (exit_pre_price - avg_entry) / avg_entry

        trade = {
            "ticker":           ticker,
            "earnings_date":    earn_date.date(),
            "eps_actual":       round(eps_actual, 4) if pd.notna(eps_actual) else np.nan,
            "eps_estimate":     round(eps_est, 4)    if pd.notna(eps_est)    else np.nan,
            "beat":             beat,
            "entry1_date":      entry1_date.date(),
            "entry1_price":     round(entry1_price, 4),
            "entry2_date":      entry2_date.date() if eligible_window2 else None,
            "entry2_price":     round(entry2_price, 4) if entry2_price is not None else None,
            "avg_entry_price":  round(avg_entry, 4),
            "exit_pre_date":    exit_pre_date.date(),
            "exit_pre_price":   round(exit_pre_price, 4),
            "pre_earn_return":  round(pre_earn_return * 100, 4),
            # Post-earnings fields (filled in below if a re-entry happens)
            "post_reentry":     False,
            "reentry_price":    None,
            "reentry_date":     None,
            "rebound_exit_date":None,
            "rebound_exit_price":None,
            "post_earn_return": None,
            "hold_days_post":   None,
            # Combined
            "trade_return_pct": round(pre_earn_return * 100, 4),
            "hold_days_total":  trading_days_between(price_df, entry1_date, exit_pre_date),
            # Drawdown covers the pre-earnings leg only.
            "max_drawdown_pct": _max_drawdown(price_df, entry1_date, exit_pre_date, avg_entry),
        }

        # ── Post-Earnings Re-entry ─────────────────────────────────────────
        if beat:
            # First price row on or after the earnings timestamp.
            earn_trading_day = get_trading_day(price_df, earn_date, offset=0)
            earn_day_close = get_close(price_df, earn_trading_day)
            dipped = (
                earn_day_close is not None
                and (earn_day_close - exit_pre_price) / exit_pre_price <= -0.02
            )
            if dipped:
                # Re-enter at that day's close
                reentry_price = earn_day_close
                reentry_date  = earn_trading_day

                # Rebound exit: recover to exit_pre_price OR 5 trading days
                rebound_exit_price = None
                rebound_exit_date  = None
                post_return = None

                # Exactly the 5 price rows after the re-entry day.
                future_days = price_idx[price_idx > reentry_date][:5]
                for fd in future_days:
                    fd_close = get_close(price_df, fd)
                    if fd_close is None:
                        continue
                    if fd_close >= exit_pre_price:
                        rebound_exit_price = fd_close
                        rebound_exit_date  = fd
                        break
                else:
                    # No recovery within the window — exit at the last
                    # available row (nothing to do if there is none).
                    if len(future_days) > 0:
                        last_day = future_days[-1]
                        rebound_exit_price = get_close(price_df, last_day)
                        rebound_exit_date  = last_day

                if rebound_exit_price is not None:
                    post_return = (rebound_exit_price - reentry_price) / reentry_price

                hold_post = trading_days_between(price_df, reentry_date, rebound_exit_date) if rebound_exit_date else 0

                trade.update({
                    "post_reentry":      True,
                    "reentry_price":     round(reentry_price, 4),
                    "reentry_date":      reentry_date.date(),
                    "rebound_exit_date": rebound_exit_date.date() if rebound_exit_date else None,
                    "rebound_exit_price":round(rebound_exit_price, 4) if rebound_exit_price is not None else None,
                    "post_earn_return":  round(post_return * 100, 4) if post_return is not None else None,
                    "hold_days_post":    hold_post,
                })
                # Combined return: plain sum of both legs (not compounded).
                combined = pre_earn_return + (post_return or 0)
                trade["trade_return_pct"] = round(combined * 100, 4)
                trade["hold_days_total"] += hold_post

        trades.append(trade)

    return trades


def _max_drawdown(price_df: pd.DataFrame, start, end, entry_price: float) -> float:
    """Return the lowest "Low" in [start, end] as a % change from ``entry_price``.

    The value is rounded to 4 decimals and is negative when the low is below
    the entry price (it can be positive if the price never dipped below it).
    An empty window yields 0.0, and so does any error raised while computing
    the figure (best-effort metric).
    """
    try:
        dates = price_df.index
        window_lows = price_df.loc[(dates >= start) & (dates <= end), "Low"]
        if len(window_lows) == 0:
            return 0.0
        trough = float(window_lows.min())
        return round((trough - entry_price) / entry_price * 100, 4)
    except Exception:
        return 0.0


# ── Summary Stats ─────────────────────────────────────────────────────────────

def compute_summary(trades_df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-trade results into per-ticker statistics plus an OVERALL row.

    Only trades with a non-null ``trade_return_pct`` are counted. Tickers
    without any such trade are left out. The final row, labelled
    "OVERALL", covers all counted trades of all tickers.

    Args:
        trades_df: One row per trade with at least the columns "ticker",
            "trade_return_pct", "hold_days_total" and "max_drawdown_pct".

    Returns:
        DataFrame with the columns ticker, total_trades, win_rate_pct,
        avg_return_pct, avg_hold_days, max_drawdown_pct, final_capital and
        cumulative_return_pct. For an empty ``trades_df`` an empty frame
        with those columns is returned.
    """
    columns = [
        "ticker", "total_trades", "win_rate_pct", "avg_return_pct",
        "avg_hold_days", "max_drawdown_pct", "final_capital",
        "cumulative_return_pct",
    ]
    if trades_df.empty:
        # Nothing to group; an empty frame may not even have a "ticker" column.
        return pd.DataFrame(columns=columns)

    rows = []
    for ticker, grp in trades_df.groupby("ticker"):
        valid = grp.dropna(subset=["trade_return_pct"])
        if valid.empty:
            continue
        rows.append(_summary_row(ticker, valid))

    # Overall row: every counted trade is compounded one after another on a
    # single capital base (the product does not depend on the trade order).
    valid_all = trades_df.dropna(subset=["trade_return_pct"])
    rows.append(_summary_row("OVERALL", valid_all))

    return pd.DataFrame(rows, columns=columns)


def _summary_row(label: str, valid: pd.DataFrame) -> dict:
    """Build one summary row from trades that all have a trade_return_pct."""
    n = len(valid)
    returns = valid["trade_return_pct"]

    # Compound each trade sequentially on the initial capital.
    capital = INITIAL_CAPITAL
    for ret in returns:
        capital *= (1 + ret / 100)

    wins = int((returns > 0).sum())
    return {
        "ticker":             label,
        "total_trades":       n,
        "win_rate_pct":       round(wins / n * 100, 2) if n else 0,
        "avg_return_pct":     round(returns.mean(), 4),
        "avg_hold_days":      round(valid["hold_days_total"].mean(), 1),
        # Most negative per-trade drawdown in the group.
        "max_drawdown_pct":   round(valid["max_drawdown_pct"].min(), 4),
        "final_capital":      round(capital, 2),
        "cumulative_return_pct": round((capital - INITIAL_CAPITAL) / INITIAL_CAPITAL * 100, 2),
    }


def write_summary_txt(summary_df: pd.DataFrame, path: str):
    """Write the summary table and a strategy description to a text file.

    Args:
        summary_df: Frame with the columns ticker, total_trades,
            win_rate_pct, avg_return_pct, avg_hold_days, max_drawdown_pct,
            final_capital and cumulative_return_pct; a row whose ticker is
            "OVERALL" is preceded by a separator line.
        path: Destination file; it is overwritten. The file is written as
            UTF-8 because the text contains non-ASCII characters.
    """
    lines = [
        "=" * 80,
        "  EARNINGS STRATEGY BACKTEST — SUMMARY",
        f"  Period: {START_DATE} → {END_DATE}  |  Initial Capital: ${INITIAL_CAPITAL:,}",
        "=" * 80,
        "",
    ]
    header = (
        f"{'Ticker':<8} {'Trades':>7} {'Win%':>7} {'AvgRet%':>9} "
        f"{'HoldDays':>10} {'MaxDD%':>8} {'FinalCap':>12} {'CumRet%':>10}"
    )
    lines.append(header)
    lines.append("-" * 80)

    for _, row in summary_df.iterrows():
        if row["ticker"] == "OVERALL":
            # Visually separate the aggregate row from the per-ticker rows.
            lines.append("-" * 80)
        lines.append(
            f"{row['ticker']:<8} {int(row['total_trades']):>7} "
            f"{row['win_rate_pct']:>7.2f} {row['avg_return_pct']:>9.4f} "
            f"{row['avg_hold_days']:>10.1f} {row['max_drawdown_pct']:>8.4f} "
            f"${row['final_capital']:>11,.2f} {row['cumulative_return_pct']:>9.2f}%"
        )

    lines += ["", "Strategy:", "  • Entry W1 (T-30): Buy 50% if stagnant ±3% from T-30",
              "  • Entry W2 (T-14): Buy remaining 50% if stagnant/rising from T-30",
              "  • Exit: Sell all at T-1 close",
              "  • Re-entry: If earnings beat AND price drops ≥2%, buy back",
              "  • Rebound exit: Recover to pre-earnings level OR 5 trading days",
              ""]
    # Explicit encoding: the locale default may be unable to encode
    # characters such as "→" or "≥", which would abort the write.
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))


# ── Main ──────────────────────────────────────────────────────────────────────

def main():
    """Run the full backtest: download data, simulate trades, save and print results.

    Prices for all STOCKS are fetched in one batch; earnings are fetched per
    ticker, with a short pause after each ticker that was backtested. The
    per-trade table is written as CSV and the summary as text into
    RESULTS_DIR. If no trade was generated at all, a notice is printed and
    nothing is written.
    """
    banner = "=" * 60
    print(banner)
    print("  EARNINGS STRATEGY BACKTESTER")
    print(f"  {START_DATE} → {END_DATE}")
    print(banner)

    # 1. One batched download for every ticker's prices
    prices_by_ticker = fetch_price_data(STOCKS)

    # 2. Per ticker: earnings calendar, then the strategy simulation
    collected = []
    for symbol in STOCKS:
        print(f"\n[{symbol}] Fetching earnings…", end=" ", flush=True)
        events = fetch_earnings(symbol)
        print(f"{len(events)} earnings events found.")

        prices = prices_by_ticker.get(symbol)
        if prices is None or prices.empty:
            print(f"  ⚠  Skipping {symbol} — no price data")
            continue

        symbol_trades = backtest_stock(symbol, prices, events)
        print(f"  → {len(symbol_trades)} trades generated")
        collected += symbol_trades
        time.sleep(0.3)  # polite delay between earnings fetches

    if len(collected) == 0:
        print("\n⚠  No trades generated. Check data availability.")
        return

    # 3. Tabulate trades (ordered by ticker, then earnings date) and summarise
    trades_df = (
        pd.DataFrame(collected)
        .sort_values(["ticker", "earnings_date"])
        .reset_index(drop=True)
    )
    summary_df = compute_summary(trades_df)

    # 4. Persist both outputs
    csv_path = os.path.join(RESULTS_DIR, "earnings_backtest_results.csv")
    txt_path = os.path.join(RESULTS_DIR, "earnings_backtest_summary.txt")
    trades_df.to_csv(csv_path, index=False)
    write_summary_txt(summary_df, txt_path)

    # 5. Console report
    print("\n" + banner)
    print("  RESULTS")
    print(banner)
    print(summary_df.to_string(index=False))
    print(f"\n✅  Saved:\n  {csv_path}\n  {txt_path}")

# Run the backtest only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
