/*
 * Snake Game for Arduino Uno R4 WiFi
 * Uses built-in 12x8 LED Matrix and GY-521 (MPU-6050) Accelerometer
 * 
 * Connections:
 * GY-521 VCC -> 5V
 * GY-521 GND -> GND
 * GY-521 SDA -> A4
 * GY-521 SCL -> A5
 */

#include <Arduino_LED_Matrix.h>
#include <Wire.h>

// Driver object for the on-board LED matrix. It is started in setup() with
// matrix.begin() and receives the frame buffer in updateDisplay().
ArduinoLEDMatrix matrix;

// 7-bit I2C address used for every MPU6050 transaction in this sketch
constexpr int MPU6050_ADDR = 0x68;

// Game tuning values and playfield geometry
constexpr int MATRIX_WIDTH = 12;          // columns on the LED matrix
constexpr int MATRIX_HEIGHT = 8;          // rows on the LED matrix
constexpr int MAX_SNAKE_LENGTH = 64;      // capacity of the snake[] array
constexpr int INITIAL_SNAKE_LENGTH = 8;   // segments at the start of a round
constexpr int MOVE_DELAY = 300;           // milliseconds between moves
constexpr int GROWTH_INTERVAL = 2000;     // milliseconds between growth steps

// Travel directions of the snake. The numeric values are spelled out because
// initializeGame() converts an integer into a Direction.
enum Direction {
  UP = 0,
  DOWN = 1,
  LEFT = 2,
  RIGHT = 3
};

// A single cell position on the playfield; used for every snake segment.
struct Point {
  int x; // column; wrapped into 0 .. MATRIX_WIDTH - 1 by moveSnake()
  int y; // row; wrapped into 0 .. MATRIX_HEIGHT - 1 by moveSnake()
};

// Snake body storage; snake[0] is the head and the following entries trail it.
Point snake[MAX_SNAKE_LENGTH];
// Number of entries of snake[] currently in use (reset in initializeGame()).
int snakeLength = 1;
// Direction applied by moveSnake() on the current step.
Direction currentDirection = RIGHT;
// Direction requested by readAccelerometer(); copied into currentDirection
// by loop() just before each move.
Direction nextDirection = RIGHT;

// Game state
bool gameRunning = false;         // set to true by initializeGame(); loop() idles while false
bool gameOver = false;            // set by loop() on self-collision; cleared by initializeGame()
unsigned long lastMoveTime = 0;   // millis() timestamp of the most recent snake step
unsigned long gameStartTime = 0;  // millis() timestamp of the round start (not read in the code visible here)
unsigned long lastGrowthTime = 0; // millis() timestamp of the most recent growth step

// LED matrix buffer (12x8 bits)
uint8_t frame[8][12];

// One-time initialisation: serial port, LED matrix, MPU6050 wake-up, then the
// countdown and the first round.
void setup() {
  Serial.begin(115200);
  matrix.begin();
  
  // Initialize MPU6050
  Wire.begin();
  Wire.beginTransmission(MPU6050_ADDR);
  Wire.write(0x6B); // PWR_MGMT_1 register
  Wire.write(0);    // Wake up MPU6050
  const uint8_t i2cStatus = Wire.endTransmission(true);
  
  // endTransmission() returns 0 on success. Report anything else so that a
  // missing or miswired sensor is visible on the serial monitor; the game is
  // still started so the display keeps working.
  if (i2cStatus != 0) {
    Serial.print("MPU6050 not responding, I2C status: ");
    Serial.println(i2cStatus);
  }
  
  // Give the sensor a moment after leaving sleep mode
  delay(100);
  
  // Start game sequence
  showCountdown();
  initializeGame();
}

void loop() {
  if (gameOver) {
    showGameOver();
    delay(3000);
    // Reset game
    showCountdown();
    initializeGame();
    return;
  }
  
  if (!gameRunning) {
    return;
  }
  
  // Read accelerometer and update direction
  readAccelerometer();
  
  // Move snake at regular intervals
  unsigned long currentTime = millis();
  if (currentTime - lastMoveTime >= MOVE_DELAY) {
    lastMoveTime = currentTime;
    
    // Update direction
    currentDirection = nextDirection;
    
    // Move snake
    moveSnake();
    
    // Check for collision with self
    if (checkSelfCollision()) {
      gameOver = true;
      return;
    }
    
    // Check for growth
    if (currentTime - lastGrowthTime >= GROWTH_INTERVAL) {
      if (snakeLength < MAX_SNAKE_LENGTH) {
        snakeLength++;
        lastGrowthTime = currentTime;
      }
    }
    
    // Update display
    drawGame();
  }
}

// Shows the digits 5 down to 1 in the top-left corner, one second each,
// followed by the "GO!" screen for one second.
void showCountdown() {
  int digit = 5;
  while (digit >= 1) {
    clearFrame();
    drawNumber(digit, 0, 0);
    updateDisplay();
    delay(1000);
    --digit;
  }

  // Final "GO!" screen
  clearFrame();
  drawGo();
  updateDisplay();
  delay(1000);
}

void initializeGame() {
  // Reset game state
  gameOver = false;
  gameRunning = true;
  snakeLength = INITIAL_SNAKE_LENGTH;
  
  // Place snake in center
  int centerX = MATRIX_WIDTH / 2;
  int centerY = MATRIX_HEIGHT / 2;
  
  for (int i = 0; i < snakeLength; i++) {
    snake[i].x = centerX - i;
    snake[i].y = centerY;
  }
  
  // Random initial direction
  randomSeed(analogRead(0));
  int dir = random(4);
  currentDirection = (Direction)dir;
  nextDirection = currentDirection;
  
  // Reset timers
  lastMoveTime = millis();
  gameStartTime = millis();
  lastGrowthTime = millis();
  
  // Initial draw
  drawGame();
}

// Reads the X and Y acceleration from the MPU6050 and, when the board is
// tilted past the threshold, stores the requested direction in nextDirection.
// A request that would reverse the snake onto itself is ignored. If the I2C
// transfer fails, nextDirection is left unchanged.
void readAccelerometer() {
  // ACCEL_XOUT_H, ACCEL_XOUT_L, ACCEL_YOUT_H, ACCEL_YOUT_L
  const int BYTES_NEEDED = 4;
  
  Wire.beginTransmission(MPU6050_ADDR);
  Wire.write(0x3B); // Starting register for accelerometer readings
  if (Wire.endTransmission(false) != 0) {
    return; // sensor did not acknowledge
  }
  
  const int received = static_cast<int>(Wire.requestFrom(MPU6050_ADDR, BYTES_NEEDED, true));
  if (received != BYTES_NEEDED) {
    return; // short read: Wire.read() would hand back -1 placeholders
  }
  
  // Read every byte in its own statement. Combining two Wire.read() calls in
  // one expression leaves the order of the calls up to the compiler, which
  // could swap the high and low bytes.
  const int xHigh = Wire.read();
  const int xLow = Wire.read();
  const int yHigh = Wire.read();
  const int yLow = Wire.read();
  
  const int16_t accelX = static_cast<int16_t>((xHigh << 8) | xLow);
  const int16_t accelY = static_cast<int16_t>((yHigh << 8) | yLow);
  
  // Tilt threshold
  const int TILT_THRESHOLD = 4000;
  
  // Determine direction based on tilt
  // Prevent 180-degree turns
  if (accelY < -TILT_THRESHOLD && currentDirection != DOWN) {
    nextDirection = UP;
  } else if (accelY > TILT_THRESHOLD && currentDirection != UP) {
    nextDirection = DOWN;
  } else if (accelX < -TILT_THRESHOLD && currentDirection != RIGHT) {
    nextDirection = LEFT;
  } else if (accelX > TILT_THRESHOLD && currentDirection != LEFT) {
    nextDirection = RIGHT;
  }
}

void moveSnake() {
  // Calculate new head position
  Point newHead = snake[0];
  
  switch (currentDirection) {
    case UP:
      newHead.y--;
      break;
    case DOWN:
      newHead.y++;
      break;
    case LEFT:
      newHead.x--;
      break;
    case RIGHT:
      newHead.x++;
      break;
  }
  
  // Wrap around edges
  if (newHead.x < 0) newHead.x = MATRIX_WIDTH - 1;
  if (newHead.x >= MATRIX_WIDTH) newHead.x = 0;
  if (newHead.y < 0) newHead.y = MATRIX_HEIGHT - 1;
  if (newHead.y >= MATRIX_HEIGHT) newHead.y = 0;
  
  // Shift body
  for (int i = snakeLength - 1; i > 0; i--) {
    snake[i] = snake[i - 1];
  }
  
  // Place new head
  snake[0] = newHead;
}

bool checkSelfCollision() {
  // Check if head collides with body
  for (int i = 1; i < snakeLength; i++) {
    if (snake[0].x == snake[i].x && snake[0].y == snake[i].y) {
      return true;
    }
  }
  return false;
}

void drawGame() {
  clearFrame();
  
  // Draw snake
  for (int i = 0; i < snakeLength; i++) {
    setPixel(snake[i].x, snake[i].y, true);
  }
  
  updateDisplay();
}

// Switches every pixel in the frame buffer off. The display itself only
// changes once updateDisplay() is called.
void clearFrame() {
  for (auto &row : frame) {
    for (auto &cell : row) {
      cell = 0;
    }
  }
}

// Sets the frame-buffer pixel at column x, row y to on (1) or off (0).
// Coordinates outside the matrix are silently ignored.
void setPixel(int x, int y, bool on) {
  const bool outsideColumns = (x < 0) || (x >= MATRIX_WIDTH);
  const bool outsideRows = (y < 0) || (y >= MATRIX_HEIGHT);
  if (outsideColumns || outsideRows) {
    return;
  }
  frame[y][x] = static_cast<uint8_t>(on);
}

void updateDisplay() {
  matrix.renderBitmap(frame, 8, 12);
}

// Draws a single decimal digit as a 3x5 glyph into the frame buffer.
//   num     - digit to draw; values outside 0..9 draw nothing
//   offsetX - column of the glyph's left edge
//   offsetY - row of the glyph's top edge
// Only lit glyph pixels are written; existing pixels are never cleared.
void drawNumber(int num, int offsetX, int offsetY) {
  // One entry per digit, one byte per glyph row; bit 2 is the leftmost column
  static const uint8_t GLYPHS[10][5] = {
    {0b111, 0b101, 0b101, 0b101, 0b111}, // 0
    {0b010, 0b110, 0b010, 0b010, 0b111}, // 1
    {0b111, 0b001, 0b111, 0b100, 0b111}, // 2
    {0b111, 0b001, 0b111, 0b001, 0b111}, // 3
    {0b101, 0b101, 0b111, 0b001, 0b001}, // 4
    {0b111, 0b100, 0b111, 0b001, 0b111}, // 5
    {0b111, 0b100, 0b111, 0b101, 0b111}, // 6
    {0b111, 0b001, 0b001, 0b001, 0b001}, // 7
    {0b111, 0b101, 0b111, 0b101, 0b111}, // 8
    {0b111, 0b101, 0b111, 0b001, 0b111}  // 9
  };

  if (num < 0 || num > 9) {
    return;
  }

  const uint8_t *glyph = GLYPHS[num];
  for (int row = 0; row < 5; ++row) {
    uint8_t mask = 0b100;
    for (int col = 0; col < 3; ++col, mask >>= 1) {
      if (glyph[row] & mask) {
        setPixel(offsetX + col, offsetY + row, true);
      }
    }
  }
}

// Lights the pixels that spell "GO!" in the frame buffer. The caller is
// expected to clear the frame first and call updateDisplay() afterwards.
void drawGo() {
  // {column, row} of every lit pixel
  static const uint8_t LIT_PIXELS[][2] = {
    // G
    {1, 2}, {1, 3}, {1, 4}, {1, 5},
    {2, 2}, {2, 5},
    {3, 2}, {3, 4}, {3, 5},
    // O
    {5, 2}, {5, 3}, {5, 4}, {5, 5},
    {6, 2}, {6, 5},
    {7, 2}, {7, 3}, {7, 4}, {7, 5},
    // Exclamation mark
    {9, 2}, {9, 3}, {9, 4}, {9, 6}
  };

  for (const auto &pixel : LIT_PIXELS) {
    setPixel(pixel[0], pixel[1], true);
  }
}

// Game-over animation: three blank/full flashes of the whole matrix, each
// phase held for 200 ms. The matrix is left fully lit on return.
void showGameOver() {
  const int FLASH_COUNT = 3;
  const int PHASE_MS = 200;

  int flashesLeft = FLASH_COUNT;
  while (flashesLeft-- > 0) {
    // Dark phase
    clearFrame();
    updateDisplay();
    delay(PHASE_MS);

    // Bright phase: switch every pixel on
    for (int col = 0; col < MATRIX_WIDTH; ++col) {
      for (int row = 0; row < MATRIX_HEIGHT; ++row) {
        setPixel(col, row, true);
      }
    }
    updateDisplay();
    delay(PHASE_MS);
  }
}