"""Google Health API → Reflect health feed.

Pulls the wearer's recent sleep, HRV, resting heart rate and SpO2 from the
**Google Health API** (the successor to the Fitbit Web API, which is retired in
September 2026) and distils it into a short, plain-English block the therapist
reads at the start of a session — so Reflect can open with "you slept rough and
your HRV's down, how are you doing?" rather than starting blind.

Design goals, mirroring therapist.py's memory blocks:
  * NEVER breaks the session. Every failure path returns "" / a cached value.
  * Reads from a local cache (`health_snapshot.json`) so building the agent's
    instructions never blocks on the network. A separate refresh writes it.

Auth: standard OAuth 2.0. You authorise once in a browser (see oauth_setup.py),
store the refresh token as a Fly secret, and this module silently exchanges it
for short-lived access tokens server-side.

Field names and data-type IDs below were confirmed live against James's own
Fitbit Air account (June 2026). Data-type IDs are kebab-case. Re-probe any time:
  `uv run python -m src.health --probe`
"""

import json
import logging
import os
import time
import urllib.parse
import urllib.request
from datetime import date, datetime, timezone

# Module-level logger used for every failure report in this file.
logger = logging.getLogger("health")

# Snapshot cache location: R1_DATA_DIR when set, otherwise two directory
# levels above this file.
PROJECT_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..")
DATA_DIR = os.getenv("R1_DATA_DIR", PROJECT_ROOT)
SNAPSHOT_FILE = os.path.join(DATA_DIR, "health_snapshot.json")

# --- OAuth / API config (from environment; set as Fly secrets) -----------------
TOKEN_URL = "https://oauth2.googleapis.com/token"
API_BASE = "https://health.googleapis.com/v4"

# Credentials default to "" so an unconfigured deployment is simply reported
# as "not configured" by is_configured() instead of failing at import.
CLIENT_ID = os.getenv("GOOGLE_HEALTH_CLIENT_ID", "")
CLIENT_SECRET = os.getenv("GOOGLE_HEALTH_CLIENT_SECRET", "")
REFRESH_TOKEN = os.getenv("GOOGLE_HEALTH_REFRESH_TOKEN", "")

# Confirmed against the real consent screen (June 2026): permissions are bundled.
# Sleep is its own scope; HRV, resting HR and SpO2 all fall under
# "health_metrics_and_measurements"; activity_and_fitness covers steps/activity.
DEFAULT_SCOPES = (
    "https://www.googleapis.com/auth/googlehealth.sleep.readonly "
    "https://www.googleapis.com/auth/googlehealth.health_metrics_and_measurements.readonly "
    "https://www.googleapis.com/auth/googlehealth.activity_and_fitness.readonly"
)
# NOTE(review): SCOPES is not read anywhere in this file — presumably consumed
# by oauth_setup.py during the one-off browser authorisation; confirm.
SCOPES = os.getenv("GOOGLE_HEALTH_SCOPES", DEFAULT_SCOPES)

# Confirmed kebab-case data-type IDs (see module docstring).
DT_SLEEP = "sleep"
DT_HRV = "daily-heart-rate-variability"
DT_RHR = "daily-resting-heart-rate"
DT_SPO2 = "daily-oxygen-saturation"
DT_TEMP = "daily-sleep-temperature-derivations"
DT_RESP = "daily-respiratory-rate"

# Timeout handed to urlopen for every request; override with GH_HTTP_TIMEOUT.
# NOTE: float() runs at import time, so a non-numeric env value raises here.
HTTP_TIMEOUT = float(os.getenv("GH_HTTP_TIMEOUT", "8"))
# format_health_block() ignores a snapshot whose `ts` is older than this many
# hours; override with GH_SNAPSHOT_MAX_AGE_H.
SNAPSHOT_MAX_AGE_H = float(os.getenv("GH_SNAPSHOT_MAX_AGE_H", "18"))


def is_configured() -> bool:
    """Report whether the client id, client secret and refresh token are all non-empty."""
    return all((CLIENT_ID, CLIENT_SECRET, REFRESH_TOKEN))


# --- Low-level HTTP -----------------------------------------------------------


def _get_access_token() -> str | None:
    """Trade the stored refresh token for an access token at TOKEN_URL.

    Returns the `access_token` field of the JSON response, or None when the
    module is not configured or the request/parse fails (the failure is logged).
    """
    if not is_configured():
        return None
    form = {
        "client_id": CLIENT_ID,
        "client_secret": CLIENT_SECRET,
        "refresh_token": REFRESH_TOKEN,
        "grant_type": "refresh_token",
    }
    body = urllib.parse.urlencode(form).encode()
    try:
        # Passing `data` makes urllib send this as a POST.
        request = urllib.request.Request(TOKEN_URL, data=body)
        with urllib.request.urlopen(request, timeout=HTTP_TIMEOUT) as response:
            payload = json.loads(response.read().decode())
            return payload.get("access_token")
    except Exception:
        logger.exception("Google Health: token refresh failed")
        return None


def _list_points(data_type: str, token: str, page_size: int = 30) -> list[dict]:
    """Fetch one page of data points for `data_type`.

    Sends a bearer-authenticated GET with only a `pageSize` query parameter and
    returns the response's `dataPoints` list. Any failure is logged and yields
    an empty list. (The original author notes the API returns newest first
    without a filter — not verifiable from this file.)
    """
    query = urllib.parse.urlencode({"pageSize": page_size})
    url = f"{API_BASE}/users/me/dataTypes/{data_type}/dataPoints?{query}"
    try:
        auth_header = {"Authorization": f"Bearer {token}"}
        request = urllib.request.Request(url, headers=auth_header)
        with urllib.request.urlopen(request, timeout=HTTP_TIMEOUT) as response:
            payload = json.loads(response.read().decode())
            # `or []` also covers an explicit null in the payload.
            return payload.get("dataPoints", []) or []
    except Exception:
        logger.exception("Google Health: fetch for %s failed", data_type)
        return []


# --- Parsing (exact schema, confirmed live) -----------------------------------


def _date_key(d: dict) -> tuple:
    """Sortable (year, month, day) key for an API date object.

    Missing, null or non-numeric components become 0, and a non-dict `d`
    (e.g. a JSON null) yields (0, 0, 0). The key is therefore always a tuple
    of ints, so lists of keys can be sorted without a TypeError.
    """
    if not isinstance(d, dict):
        return (0, 0, 0)

    def _component(name: str) -> int:
        try:
            return int(d.get(name) or 0)
        except (TypeError, ValueError):
            return 0

    return (_component("year"), _component("month"), _component("day"))


def _mean(values: list[float]) -> float | None:
    nums = [v for v in values if isinstance(v, (int, float))]
    return round(sum(nums) / len(nums), 1) if nums else None


def _parse_sleep(points: list[dict]) -> dict:
    """Latest sleep session: minutes asleep + per-stage minutes.

    The latest session is the one whose `interval.endTime` string compares
    greatest. Returns {} when no point carries a `sleep` dict. Every numeric
    field is parsed leniently: a value that cannot be converted to int becomes
    None instead of raising, and null/malformed `interval`, `summary` or
    `stagesSummary` nodes are treated as empty.
    """

    def _i(v):
        try:
            return int(v)
        except (TypeError, ValueError):
            return None

    def _dict(v) -> dict:
        return v if isinstance(v, dict) else {}

    def _end_key(session: dict) -> str:
        # Non-string end times sort first so they never win over a real one.
        end = _dict(session.get("interval")).get("endTime")
        return end if isinstance(end, str) else ""

    sessions = [
        p["sleep"]
        for p in points
        if isinstance(p, dict) and isinstance(p.get("sleep"), dict)
    ]
    if not sessions:
        return {}
    latest = max(sessions, key=_end_key)
    summary = _dict(latest.get("summary"))
    # Stage minutes go through the same guarded conversion as the summary
    # fields; a missing "minutes" key still counts as 0, as before.
    stages = {
        s.get("type"): _i(s.get("minutes", 0))
        for s in summary.get("stagesSummary") or []
        if isinstance(s, dict)
    }

    return {
        "end_time": _dict(latest.get("interval")).get("endTime"),
        "asleep_min": _i(summary.get("minutesAsleep")),
        "awake_min": _i(summary.get("minutesAwake")),
        "deep_min": stages.get("DEEP"),
        "rem_min": stages.get("REM"),
        "light_min": stages.get("LIGHT"),
    }


def _parse_daily(points: list[dict], key: str, value_path: list[str]):
    """Extract a daily metric and return (latest_value, recent_average).

    For each point, `key` selects the metric node and `value_path` is walked
    through nested dicts to reach the reading, which must convert to float.
    Readings are ordered by their date key; the latest is the last one and the
    average covers up to seven readings immediately before it. Returns
    (None, None) when nothing usable is found.
    """
    dated = []
    for point in points:
        node = point.get(key)
        if not isinstance(node, dict):
            continue
        cursor = node
        for segment in value_path:
            cursor = cursor.get(segment) if isinstance(cursor, dict) else None
        if cursor is None:
            continue
        try:
            reading = float(cursor)
        except (TypeError, ValueError):
            continue
        dated.append((_date_key(node.get("date", {})), reading))
    if not dated:
        return None, None
    ordered = sorted(dated)
    *earlier, (_, latest) = ordered
    # Up to the seven readings before the latest: ~prior week, excl. today.
    baseline = _mean([value for _, value in earlier[-7:]])
    return latest, baseline


# --- Multi-day trend / early-warning assessment -------------------------------


def _series(points: list[dict], key: str, value_path: list[str]) -> list[float]:
    """Values of a daily metric ordered by ascending ISO date (for trend maths)."""
    readings = _daily_by_date(points, key, value_path)
    return [readings[day] for day in sorted(readings)]


def _recent_vs_prior(vals: list[float], recent_n=3, prior_n=7):
    """Return (mean of the last `recent_n` values, mean of up to `prior_n` values before them).

    Needs at least `recent_n + 2` values; otherwise returns (None, None).
    """
    if len(vals) >= recent_n + 2:
        window_start = -(recent_n + prior_n)
        recent_avg = _mean(vals[-recent_n:])
        prior_avg = _mean(vals[window_start:-recent_n])
        return recent_avg, prior_avg
    return None, None


def _assess_trends(rhr, hrv, temp_delta, resp) -> str | None:
    """Scan recent daily series for drift in the 'run down / unwell' direction.

    Each argument is an ascending list of daily values. Four checks are made:
    resting HR recent mean at least 2.5 above the prior mean, HRV recent mean
    at least 12% below the prior mean, mean of the last three temperature
    deltas at least 0.3, and breathing rate recent mean at least 1.0 above the
    prior mean. Returns None when nothing trips, a soft one-signal sentence
    for a single hit, and an early-warning sentence for two or more.
    """
    flagged = []

    rhr_now, rhr_before = _recent_vs_prior(rhr)
    if rhr_now and rhr_before:
        if rhr_now - rhr_before >= 2.5:  # bpm
            flagged.append("resting heart rate has been running higher than usual")

    hrv_now, hrv_before = _recent_vs_prior(hrv)
    if hrv_now and hrv_before:
        if (hrv_before - hrv_now) / hrv_before >= 0.12:
            flagged.append("HRV has been lower than usual")

    # temp_delta already holds (nightly − personal baseline) per recent night
    if temp_delta:
        recent_temp = _mean(temp_delta[-3:])
        if recent_temp is not None and recent_temp >= 0.3:
            flagged.append("skin temperature has been above your baseline")

    resp_now, resp_before = _recent_vs_prior(resp)
    if resp_now and resp_before:
        if resp_now - resp_before >= 1.0:  # breaths/min
            flagged.append("breathing rate has crept up")

    if not flagged:
        return None
    if len(flagged) == 1:
        joined = flagged[0]
        return (f"One thing drifting over the last few days: {joined}. Worth keeping in "
                "mind as a soft signal, not a worry.")
    joined = ", ".join(flagged[:-1]) + " and " + flagged[-1]
    return (f"EARLY-WARNING PATTERN over the last few days — {joined}. Several "
            "signals are pointing the same way, which can mean they're getting "
            "run down, overstrained, or coming down with something. Hold this "
            "gently and, if it fits, check in on how they've been feeling and "
            "whether they need to ease off — do not alarm them or diagnose.")


# --- Snapshot build + cache ---------------------------------------------------


def _parse_temp(points: list[dict]):
    """Return (latest nightly−baseline delta, ascending list of all deltas).

    Only points whose `dailySleepTemperatureDerivations` node has numeric
    nightly and baseline temperatures are used; each delta is rounded to two
    decimals. Returns (None, []) when no point qualifies.
    """
    dated = []
    for point in points:
        node = point.get("dailySleepTemperatureDerivations")
        if not isinstance(node, dict):
            continue
        nightly = node.get("nightlyTemperatureCelsius")
        base = node.get("baselineTemperatureCelsius")
        if isinstance(nightly, (int, float)) and isinstance(base, (int, float)):
            dated.append((_date_key(node.get("date", {})), round(nightly - base, 2)))
    if not dated:
        return None, []
    dated.sort()
    deltas = [delta for _, delta in dated]
    return deltas[-1], deltas


def refresh_snapshot() -> dict | None:
    """Pull recent metrics and write health_snapshot.json.

    Returns the snapshot dict, or None when the module is not configured, the
    token exchange failed, or no data type returned any points. In those cases
    the existing cache file is left untouched, so one bad fetch cannot wipe a
    good snapshot. A failure to write the file is logged and the freshly built
    snapshot is still returned.
    """
    token = _get_access_token()
    if not token:
        return None

    hrv_pts = _list_points(DT_HRV, token)
    rhr_pts = _list_points(DT_RHR, token)
    resp_pts = _list_points(DT_RESP, token)
    temp_pts = _list_points(DT_TEMP, token)
    sleep_pts = _list_points(DT_SLEEP, token, page_size=5)
    spo2_pts = _list_points(DT_SPO2, token)

    # _list_points returns [] on failure, so "everything empty" means the
    # fetch failed (or there is nothing to report): keep the old cache.
    if not any((hrv_pts, rhr_pts, resp_pts, temp_pts, sleep_pts, spo2_pts)):
        logger.warning(
            "Google Health: no data points fetched; keeping existing snapshot"
        )
        return None

    sleep = _parse_sleep(sleep_pts)
    hrv, hrv_base = _parse_daily(
        hrv_pts, "dailyHeartRateVariability",
        ["averageHeartRateVariabilityMilliseconds"],
    )
    rhr, rhr_base = _parse_daily(
        rhr_pts, "dailyRestingHeartRate", ["beatsPerMinute"],
    )
    spo2, _ = _parse_daily(
        spo2_pts, "dailyOxygenSaturation", ["averagePercentage"],
    )
    resp, resp_base = _parse_daily(
        resp_pts, "dailyRespiratoryRate", ["breathsPerMinute"],
    )
    temp_delta, temp_series = _parse_temp(temp_pts)

    heads_up = _assess_trends(
        _series(rhr_pts, "dailyRestingHeartRate", ["beatsPerMinute"]),
        _series(hrv_pts, "dailyHeartRateVariability",
                ["averageHeartRateVariabilityMilliseconds"]),
        temp_series,
        _series(resp_pts, "dailyRespiratoryRate", ["breathsPerMinute"]),
    )

    snapshot = {
        "fetched_at": datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC"),
        "ts": time.time(),
        "sleep": sleep,
        "hrv_ms": hrv,
        "hrv_baseline_ms": hrv_base,
        "resting_hr": rhr,
        "resting_hr_baseline": rhr_base,
        "spo2_pct": spo2,
        "resp_rate": resp,
        "resp_baseline": resp_base,
        "temp_delta_c": temp_delta,
        "heads_up": heads_up,
    }
    # Write to a sibling temp file and swap it in, so a reader never sees a
    # half-written snapshot.
    tmp_path = SNAPSHOT_FILE + ".tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(snapshot, f, indent=2)
        os.replace(tmp_path, SNAPSHOT_FILE)
        logger.info("Google Health: snapshot refreshed")
    except Exception:
        logger.exception("Google Health: failed to write snapshot")
    return snapshot


def _load_snapshot() -> dict | None:
    """Read the cached snapshot from SNAPSHOT_FILE.

    Returns the decoded dict, or None when the file is missing, unreadable,
    not valid JSON, or valid JSON that is not an object. Callers use `.get`
    on the result, so anything other than a dict is rejected here.
    """
    try:
        with open(SNAPSHOT_FILE) as f:
            data = json.load(f)
    except FileNotFoundError:
        # No snapshot has been written yet; nothing to report.
        return None
    except Exception:
        logger.exception("Google Health: failed to read snapshot")
        return None
    return data if isinstance(data, dict) else None


# --- On-demand history lookup (any past date the watch recorded) --------------


def _daily_by_date(points: list[dict], key: str, value_path: list[str]) -> dict:
    """Build {YYYY-MM-DD: float value} for a daily metric.

    `key` selects the metric node in each point and `value_path` is walked
    through nested dicts to the reading. Points with a missing/invalid date or
    a reading that does not convert to float are skipped. If a date occurs
    more than once, the later point in `points` wins.
    """
    result = {}
    for point in points:
        node = point.get(key)
        if not isinstance(node, dict):
            continue
        stamp = node.get("date", {})
        try:
            iso_day = date(stamp["year"], stamp["month"], stamp["day"]).isoformat()
        except (KeyError, TypeError, ValueError):
            continue
        found = node
        for segment in value_path:
            if isinstance(found, dict):
                found = found.get(segment)
            else:
                found = None
        if found is None:
            continue
        try:
            result[iso_day] = float(found)
        except (TypeError, ValueError):
            continue
    return result


def _sleep_by_date(points: list[dict]) -> dict:
    """Map {YYYY-MM-DD (morning it ended): summary dict} for sleep sessions.

    The day is the first ten characters of `interval.endTime`. Each value is
    {"asleep_min": int, "stages": {stage type: minutes or None}}. Sessions
    without a usable end time or a non-zero `minutesAsleep` are skipped. When
    several sessions end on one day, the longest is kept (main sleep vs nap).
    Malformed or null nodes are skipped instead of raising.
    """

    def _as_int(v):
        try:
            return int(v)
        except (TypeError, ValueError):
            return None

    out = {}
    for p in points:
        s = p.get("sleep") if isinstance(p, dict) else None
        if not isinstance(s, dict):
            continue
        interval = s.get("interval")
        end = interval.get("endTime") if isinstance(interval, dict) else None
        if not isinstance(end, str):
            continue
        day = end[:10]  # 'YYYY-MM-DD'
        if not day:
            continue
        summary = s.get("summary")
        if not isinstance(summary, dict):
            continue
        asleep = _as_int(summary.get("minutesAsleep"))
        if not asleep:
            continue
        # Keep the longest session if multiple ended that day.
        if day in out and asleep <= out[day]["asleep_min"]:
            continue
        # Guarded conversion: one bad stage entry must not sink the lookup.
        stages = {
            x.get("type"): _as_int(x.get("minutes", 0))
            for x in summary.get("stagesSummary") or []
            if isinstance(x, dict)
        }
        out[day] = {"asleep_min": asleep, "stages": stages}
    return out


def summarise_range(start_date: str, end_date: str = "") -> str:
    """Plain-English summary of sleep + daily vitals for a past date or range.

    Dates are 'YYYY-MM-DD'; `end_date` defaults to `start_date`, and a reversed
    range is swapped. Returns one line per day that has any data, or an
    explanatory sentence when the token is unavailable, a date does not parse,
    or nothing was recorded. Only the most recent 120 points per data type are
    fetched, so older dates come back as "no data".
    """
    token = _get_access_token()
    if not token:
        return "I can't reach your health data just now."
    try:
        s = datetime.strptime(start_date, "%Y-%m-%d").date()
        e = datetime.strptime(end_date or start_date, "%Y-%m-%d").date()
    except (TypeError, ValueError):
        # TypeError covers a non-string argument (e.g. None from a tool call).
        return "I need the date in YYYY-MM-DD form to look that up."
    if e < s:
        s, e = e, s

    # ~3 months of history covers any realistic 'last Thursday' style question.
    sleep = _sleep_by_date(_list_points(DT_SLEEP, token, page_size=120))
    hrv = _daily_by_date(_list_points(DT_HRV, token, 120),
                         "dailyHeartRateVariability",
                         ["averageHeartRateVariabilityMilliseconds"])
    rhr = _daily_by_date(_list_points(DT_RHR, token, 120),
                         "dailyRestingHeartRate", ["beatsPerMinute"])
    spo2 = _daily_by_date(_list_points(DT_SPO2, token, 120),
                          "dailyOxygenSaturation", ["averagePercentage"])

    days = [date.fromordinal(o) for o in range(s.toordinal(), e.toordinal() + 1)]
    lines = []
    for d in days:
        iso = d.isoformat()
        bits = []
        if iso in sleep:
            st = sleep[iso]["stages"]
            # Labels are used as written ("REM" stays upper-case), matching
            # the wording format_health_block produces.
            stage_str = ", ".join(
                f"{lbl} {_fmt_dur(st[t])}"
                for lbl, t in (("deep", "DEEP"), ("REM", "REM"), ("light", "LIGHT"))
                if st.get(t)
            )
            bits.append(f"slept {_fmt_dur(sleep[iso]['asleep_min'])}"
                        + (f" ({stage_str})" if stage_str else ""))
        if iso in hrv:
            bits.append(f"HRV {hrv[iso]:g} ms")
        if iso in rhr:
            bits.append(f"resting HR {int(rhr[iso])} bpm")
        if iso in spo2:
            bits.append(f"blood oxygen {spo2[iso]:g}%")
        if bits:
            # d.day instead of "%-d": the dash flag is a platform-specific
            # strftime extension and is not available everywhere.
            label = f"{d:%A} {d.day} {d:%B}"
            lines.append(f"{label}: " + "; ".join(bits) + ".")

    if not lines:
        return (f"No Fitbit data is recorded for "
                f"{'that day' if s == e else 'that period'} — the watch may not "
                "have been worn or synced.")
    header = "" if len(lines) == 1 else "Here's what the Fitbit recorded:\n"
    return header + "\n".join(lines)


# --- The block injected into the therapist's instructions ---------------------


def _fmt_dur(minutes) -> str:
    """Render a minute count as 'Xh YYm', or 'Ym' when under an hour.

    Falsy input (None, 0, "") gives an empty string; other values are
    truncated to int first.
    """
    if not minutes:
        return ""
    total = int(minutes)
    hours, mins = total // 60, total % 60
    if hours:
        return f"{hours}h {mins:02d}m"
    return f"{mins}m"


def _trend(value, baseline, *, lower_is_better=False) -> str:
    """Build a gentle ' (… vs usual)' suffix comparing `value` with `baseline`.

    Returns "" when either number is missing or zero. A relative gap under 10%
    reads as 'about your usual'; otherwise the tag says 'up on' or 'down on'
    the baseline. `lower_is_better` is accepted for callers but does not
    currently change the wording.
    """
    if not (value and baseline):
        return ""
    ratio = (value - baseline) / baseline
    if abs(ratio) < 0.1:
        return f" (about your usual ~{baseline:g})"
    if ratio > 0:
        return f" (up on your recent ~{baseline:g})"
    return f" (down on your recent ~{baseline:g})"


def _build_health_block(snap) -> str:
    """Render a loaded snapshot into the prompt block.

    Returns "" when `snap` is not a non-empty dict, when its `ts` is older
    than SNAPSHOT_MAX_AGE_H hours, or when it holds no usable reading. May
    raise on malformed values; format_health_block() contains that.
    """
    if not isinstance(snap, dict) or not snap:
        return ""
    if (time.time() - snap.get("ts", 0)) / 3600 > SNAPSHOT_MAX_AGE_H:
        return ""  # stale (e.g. watch not synced) — better to say nothing

    sleep = snap.get("sleep") or {}
    parts: list[str] = []

    if sleep.get("asleep_min"):
        stages = []
        for label, key in (("deep", "deep_min"), ("REM", "rem_min"), ("light", "light_min")):
            if sleep.get(key):
                stages.append(f"{label} {_fmt_dur(sleep[key])}")
        if sleep.get("awake_min"):
            stages.append(f"awake {_fmt_dur(sleep['awake_min'])}")
        stage_str = f" ({', '.join(stages)})" if stages else ""
        parts.append(f"Slept {_fmt_dur(sleep['asleep_min'])} last night{stage_str}.")

    if snap.get("hrv_ms"):
        parts.append(
            f"Overnight HRV {snap['hrv_ms']:g} ms"
            + _trend(snap["hrv_ms"], snap.get("hrv_baseline_ms")) + "."
        )
    if snap.get("resting_hr"):
        parts.append(
            f"Resting heart rate {int(snap['resting_hr'])} bpm"
            + _trend(snap["resting_hr"], snap.get("resting_hr_baseline"),
                     lower_is_better=True) + "."
        )
    if snap.get("resp_rate"):
        parts.append(
            f"Breathing rate {snap['resp_rate']:g}/min"
            + _trend(snap["resp_rate"], snap.get("resp_baseline")) + "."
        )
    # `is not None` rather than truthiness: a delta of exactly 0 is a reading.
    if snap.get("temp_delta_c") is not None:
        d = snap["temp_delta_c"]
        if abs(d) < 0.2:
            parts.append("Skin temperature about your usual baseline.")
        else:
            parts.append(f"Skin temperature {d:+.1f} °C vs your baseline"
                         + (" (notably warm)" if d >= 0.4 else "") + ".")
    if snap.get("spo2_pct"):
        parts.append(f"Blood oxygen {snap['spo2_pct']:g}%.")

    if not parts:
        return ""

    guidance = (
        "RECENT BODY DATA from their Fitbit. Two ways to use it:\n"
        "1) If they ASK directly about their sleep, heart rate, HRV, blood oxygen "
        "or recovery, answer the specific question plainly and accurately FIRST — "
        "give the actual figure (e.g. 'your resting heart rate was 73, right around "
        "your usual') — then you can reflect on what it means.\n"
        "2) Otherwise, use it like a perceptive friend who can tell they're "
        "run-down, NOT a dashboard: weave it in gently and only if it fits, without "
        "reciting numbers unprompted. Lower HRV, a higher resting heart rate, "
        "short/broken sleep, a raised breathing rate or a skin temperature above "
        "their baseline usually mean they're tired, depleted, more stressed or "
        "possibly coming down with something; the opposite means well-rested.\n"
        "Below are last night's sleep and today's latest daily readings. For any "
        "OTHER date or a multi-day trend, use your health-lookup tool. You do NOT "
        "have live/right-now readings, so say so plainly if asked. Never diagnose — "
        "you notice and gently reflect, you don't medically assess.\n"
        "The readings:"
    )
    block = guidance + "\n" + " ".join(parts)
    if snap.get("heads_up"):
        block += "\n" + snap["heads_up"]
    return block


def format_health_block() -> str:
    """Plain-English read of the wearer's recent body data for the system prompt.

    Returns "" when there is no snapshot, it is stale, or it holds nothing
    usable, so Reflect is wholly unaffected. The snapshot comes from disk, so
    a malformed one (wrong types, hand-edited, written by an older version) is
    logged and also yields "" rather than raising into the session.
    """
    try:
        return _build_health_block(_load_snapshot())
    except Exception:
        logger.exception("Google Health: could not format health block")
        return ""


# --- Probe / CLI --------------------------------------------------------------


def _probe() -> None:
    """CLI helper: refresh the snapshot, then print it and the formatted block.

    Prints a setup hint instead when the OAuth credentials are not configured.
    """
    logging.basicConfig(level=logging.INFO)
    if is_configured():
        print("Refreshing snapshot from Google Health...\n")
        snapshot = refresh_snapshot()
        print(json.dumps(snapshot, indent=2))
        rendered = format_health_block() or "(empty)"
        print("\nFormatted block Reflect will see:\n" + rendered)
        return
    print("Not configured. Set GOOGLE_HEALTH_CLIENT_ID / _SECRET / "
          "_REFRESH_TOKEN (see oauth_setup.py).")


if __name__ == "__main__":
    import sys

    # --probe: refresh and dump everything; --refresh: refresh only;
    # no flag: print the block built from the current cache.
    cli_args = sys.argv
    if "--probe" in cli_args:
        _probe()
    elif "--refresh" not in cli_args:
        print(format_health_block() or "(no health block)")
    else:
        logging.basicConfig(level=logging.INFO)
        refresh_snapshot()
