import os
import time
import ssl
import wifi
import adafruit_requests
import socketpool
import json
import neopixel
import board

# Initialize NeoPixel strip (100 LEDs) on GP1. With auto_write=False nothing is
# sent to the LEDs until strip.show() is called, so many pixels can be changed
# per refresh.
# NOTE(review): the original declared pixel_order=GRB and compensated with
# swapped tuples (RED=(0,255,0), GREEN=(255,0,0)). Declaring RGB with true
# (r, g, b) tuples puts exactly the same bytes on the wire, so the physical
# colors are unchanged, but the constants now mean what their names say.
strip = neopixel.NeoPixel(board.GP1, 100, brightness=0.2, pixel_order=neopixel.RGB, auto_write=False)

print(os.getenv('test_variable'))  # Confirm .toml file is read correctly

# Define Colors as (red, green, blue)
RED = (255, 0, 0)    # 🚆 Trains (also used for failure on status LEDs 60/61)
GREEN = (0, 255, 0)  # ✅ Success indicators (WiFi/API working)
BLACK = (0, 0, 0)    # ⚫ Off

# Red Line stations in Glenmont-bound travel order (Shady Grove -> Glenmont).
# 27 stations, mapped onto LEDs 2-28.
GLENMONT_STATIONS = [
    "Shady Grove", "Rockville", "Twinbrook", "North Bethesda", "Grosvenor-Strathmore",
    "Medical Center", "Bethesda", "Friendship Heights", "Tenleytown-AU", "Van Ness-UDC",
    "Cleveland Park", "Woodley Park-Zoo/Adams Morgan", "Dupont Circle", "Farragut North",
    "Metro Center", "Gallery Pl-Chinatown", "Judiciary Square", "Union Station", "NoMa-Gallaudet U",
    "Rhode Island Ave-Brentwood", "Brookland-CUA", "Fort Totten", "Takoma",
    "Silver Spring", "Forest Glen", "Wheaton", "Glenmont"
]

# Shady Grove-bound stations are the same list in reverse travel order,
# mapped onto LEDs 30-56. LED 29 is not assigned to any station.
SHADY_GROVE_STATIONS = GLENMONT_STATIONS[::-1]  # Reverse order

# First LED index of each direction's run of station LEDs.
GLENMONT_FIRST_LED = 2
SHADY_GROVE_FIRST_LED = 30

# station name -> LED index, for constant-time lookup while parsing predictions
glenmont_led_map = {
    station: led for led, station in enumerate(GLENMONT_STATIONS, start=GLENMONT_FIRST_LED)
}
shady_grove_led_map = {
    station: led for led, station in enumerate(SHADY_GROVE_STATIONS, start=SHADY_GROVE_FIRST_LED)
}

def indicate_wifi_connection():
    """Mark a successful WiFi join by painting status LED 60 green.

    The frame is pushed to the strip immediately and the change is logged.
    """
    wifi_status_led = 60
    strip[wifi_status_led] = GREEN  # ✅ success colour
    strip.show()
    print("✅ LED 60 set to GREEN for successful WiFi connection.")

def indicate_wifi_failure():
    """Mark a failed WiFi join by painting status LED 60 red.

    The frame is pushed to the strip immediately and the change is logged.
    """
    wifi_status_led = 60
    strip[wifi_status_led] = RED  # ❌ failure colour
    strip.show()
    print("❌ LED 60 set to RED for WiFi failure.")

def indicate_api_status(has_trains):
    """Show on status LED 61 whether the latest poll found any trains.

    A truthy ``has_trains`` paints the LED GREEN and a falsy one paints it RED.
    The frame is pushed immediately and a matching message is printed.
    """
    if has_trains:
        colour = GREEN
        message = '✅ API is generating trains - LED 61 GREEN'
    else:
        colour = RED
        message = '❌ No trains detected - LED 61 RED'
    strip[61] = colour
    strip.show()
    print(message)

# Get WiFi credentials from settings.toml (read via os.getenv)
ssid = os.getenv("WIFI_SSID")
password = os.getenv("WIFI_PASSWORD")

# Convert SSID and password to bytes for wifi.radio.connect(). A missing key makes
# os.getenv() return None, and bytes(None, "utf-8") would raise TypeError and crash
# the script before any status LED is lit. Fall back to "" instead: a missing SSID
# is reported by the connect block below, and an empty password is passed as-is
# (NOTE(review): presumably treated as an open network by wifi.radio.connect — confirm).
ssid_bytes = bytes(ssid or "", "utf-8")
password_bytes = bytes(password or "", "utf-8")

# Debugging step: show the SSID, but never echo the password itself to the console.
print(f"SSID: {ssid_bytes}, Password: {'*' * len(password_bytes)}")

# Connect to WiFi. Any failure, including missing credentials, lights LED 60 red.
try:
    if not ssid_bytes:
        raise ValueError("WIFI_SSID is not set in settings.toml")
    wifi.radio.connect(ssid_bytes, password_bytes)
    print('Connected to WiFi')
    print(f"Connected to WiFi! IP address: {wifi.radio.ipv4_address}")

    # Indicate successful connection
    indicate_wifi_connection()

except Exception as e:
    print(f"❌ Failed to connect to WiFi: {e}")
    indicate_wifi_failure()

# Build one HTTP session (socket pool on the WiFi radio + default TLS context),
# reused by every poll of the WMATA API.
pool = socketpool.SocketPool(wifi.radio)
tls_context = ssl.create_default_context()
requests = adafruit_requests.Session(pool, tls_context)

# WMATA rail-prediction endpoint (all stations). The API key is sent in the
# `api_key` request header instead of the query string, so it no longer shows
# up in the URL that fetch_wmata_data() prints on every poll.
# NOTE(review): the fallback literal is a credential committed to source. Set
# WMATA_API_KEY in settings.toml, rotate this key, then delete the fallback.
WMATA_API_KEY = os.getenv("WMATA_API_KEY") or "011400faedcc4a42abfb572f8ba92f7b"
WMATA_URL = "https://api.wmata.com/StationPrediction.svc/json/GetPrediction/All"
HEADERS = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)",  # Prevents API blocking
    "Cache-Control": "no-cache",
    "api_key": WMATA_API_KEY,
}

def _iter_json_objects(chunks):
    """Yield each innermost ``{...}`` JSON object parsed from a stream of byte chunks.

    The GetPrediction body is expected to look like ``{"Trains":[{...},{...}]}``
    with flat train records, so every innermost brace pair is one train. The
    outer wrapper is discarded piece by piece as records are consumed. (The old
    ``"},"`` split fused the first train with the wrapper and never emitted the
    last one.) Scanning bytes rather than decoded text means a multi-byte UTF-8
    character cut across two chunks cannot raise a decode error, because UTF-8
    continuation bytes never equal ``{`` or ``}``.

    NOTE(review): assumes no string value in the feed contains a literal brace.

    :param chunks: iterable of ``bytes`` chunks, e.g. ``response.iter_content(512)``.
    :yields: each parsed object that is a ``dict``. Records that fail to decode
        or parse are skipped.
    """
    buffer = b""
    for chunk in chunks:
        buffer += chunk
        while True:
            end = buffer.find(b"}")
            if end == -1:
                break  # record incomplete; wait for the next chunk
            start = buffer.rfind(b"{", 0, end)
            record = buffer[start:end + 1] if start != -1 else b""
            buffer = buffer[end + 1:]  # drop everything consumed so far
            if not record:
                continue  # stray closing brace, e.g. the wrapper's final "}"
            try:
                obj = json.loads(record.decode("utf-8"))
            except ValueError:
                continue  # malformed record: skip it but keep streaming
            if isinstance(obj, dict):
                yield obj


def fetch_wmata_data():
    """Fetch WMATA rail predictions and repaint the Red Line station LEDs.

    The GetPrediction/All response is streamed in 512-byte chunks. Every Red
    Line train whose ``Min`` is ``ARR`` or ``BRD`` sets "trains found". If its
    destination is Glenmont or Shady Grove and its station is mapped, that
    station's LED is lit RED. All other station LEDs (2-56) are cleared, and
    status LED 61 is set via ``indicate_api_status``.

    Station LEDs are only rewritten after the whole response has been read, so
    a failed request leaves the previous frame intact. The HTTP response is
    always closed. Errors are reported on the console and on LED 61; nothing
    is raised to the caller.
    """
    response = None
    try:
        print(f"\nFetching data from: {WMATA_URL}")
        response = requests.get(WMATA_URL, headers=HEADERS, stream=True)  # Stream data

        print(f"Response Code: {response.status_code}")
        if response.status_code != 200:
            print(f"Failed to retrieve data: {response.status_code}")
            indicate_api_status(False)  # 🚨 API failed, set LED 61 to red
            return

        occupied_leds = []    # station LEDs to light once parsing succeeds
        train_found = False   # at least one Red Line train arriving/boarding

        for train in _iter_json_objects(response.iter_content(512)):
            # Only Red Line (RD) trains that are arriving or boarding matter
            if train.get("Line") != "RD" or train.get("Min") not in ("ARR", "BRD"):
                continue
            train_found = True
            station_name = train.get("LocationName", "Unknown")
            # DestinationName may be JSON null; `or ""` keeps .strip() from raising
            destination = (train.get("DestinationName") or "").strip().lower()

            if destination == "glenmont" and station_name in glenmont_led_map:
                led_index = glenmont_led_map[station_name]
                occupied_leds.append(led_index)
                print(f"🚆 Glenmont-bound train at {station_name}: LED {led_index} → RED")
            elif destination in ("shady grove", "shady grv") and station_name in shady_grove_led_map:
                led_index = shady_grove_led_map[station_name]
                occupied_leds.append(led_index)
                print(f"🚆 Shady Grove-bound train at {station_name}: LED {led_index} → RED")

        # Whole response parsed: clear every station LED (2-56), then light trains
        for i in range(2, 57):
            strip[i] = BLACK  # ⚫ Turn off all station LEDs
        for led_index in occupied_leds:
            strip[led_index] = RED  # 🚆 Red for trains

        indicate_api_status(train_found)  # 🔴 Update LED 61 based on API data
        strip.show()  # ✅ Update LED strip

    except MemoryError:
        print("❌ Memory allocation failed. The response might still be too large.")
        indicate_api_status(False)
    except Exception as e:
        print(f"❌ Request failed due to: {e}")
        indicate_api_status(False)
    finally:
        # Release the socket on every path (non-200, parse error, success)
        if response is not None:
            response.close()

# Poll WMATA forever: refresh the station/status LEDs, then wait before the next request.
POLL_INTERVAL_SECONDS = 5
while True:
    fetch_wmata_data()
    time.sleep(POLL_INTERVAL_SECONDS)