import codecs
import glob
import os
import sys
import threading
import time

import serial

def list_serial_ports():
    """Return the serial ports on this machine that can currently be opened.

    Candidate names depend on the platform: COM1..COM256 on Windows,
    ``/dev/tty[A-Za-z]*`` on Linux/Cygwin, ``/dev/tty.*`` on macOS. A
    candidate is kept only if ``serial.Serial`` can open it. The probe handle
    is closed again immediately.

    Raises:
        EnvironmentError: on any other platform.
    """
    platform = sys.platform
    if platform.startswith('win'):
        candidates = [f'COM{number}' for number in range(1, 257)]
    elif platform.startswith(('linux', 'cygwin')):
        candidates = glob.glob('/dev/tty[A-Za-z]*')
    elif platform.startswith('darwin'):
        candidates = glob.glob('/dev/tty.*')
    else:
        raise EnvironmentError('Unsupported platform')

    available = []
    for candidate in candidates:
        try:
            # Open-then-close probe; any failure just means "not usable".
            serial.Serial(candidate).close()
        except (OSError, serial.SerialException):
            continue
        available.append(candidate)
    return available

def choose_port():
    """Ask the user to pick one of the openable serial ports and return it.

    Prints a 1-based numbered list and re-prompts until a listed number is
    entered. Exits the process with status 1 when no port is available.
    """
    available = list_serial_ports()
    if not available:
        print("No serial ports found.")
        sys.exit(1)

    print("Available serial ports:")
    for number, name in enumerate(available, start=1):
        print(f"{number}: {name}")

    while True:
        answer = input("Select port number: ")
        try:
            index = int(answer) - 1
        except ValueError:
            index = -1  # non-numeric input: fall through to the error message
        if 0 <= index < len(available):
            return available[index]
        print("Invalid selection.")

def send_file(ser, filename, *, start_delay=1.0, line_delay=0.1):
    """Send a UTF-8 text file over *ser* line by line, framed by markers.

    The device receives ``|sof|\\n``, then every line of the file, then
    ``|eof|\\n``. Each file line is sent newline-terminated. Without this, a
    file whose last line has no trailing newline would have that line and the
    ``|eof|`` marker arrive as a single line.

    Args:
        ser: open serial port. Only ``write(bytes)`` and ``flush()`` are used.
        filename: path of the text file to upload.
        start_delay: seconds to wait before sending the first marker.
        line_delay: seconds to wait after each line is written and flushed.

    Progress is printed to stdout as ``Progress: sent/total``.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    total = len(lines) + 2  # +2 for |sof| and |eof|
    print(f"Uploading {filename} ({len(lines)} lines)...")
    time.sleep(start_delay)

    def send_line(line):
        ser.write(line.encode('utf-8'))
        ser.flush()
        time.sleep(line_delay)  # pace the device (default 100 ms per line)

    send_line('|sof|\n')
    sent = 1
    print(f"Progress: {sent}/{total}", end='\r')
    for line in lines:
        if not line.endswith('\n'):
            line += '\n'  # keep |eof| on its own line
        send_line(line)
        sent += 1
        print(f"Progress: {sent}/{total}", end='\r')
    send_line('|eof|\n')
    sent += 1
    print(f"Progress: {sent}/{total} - Done.")

def serial_reader(ser, stop_event):
    """Echo data received on *ser* to stdout until *stop_event* is set.

    Polls ``ser.in_waiting`` every 50 ms and reads whatever bytes are
    pending. Bytes are decoded as UTF-8 with an incremental decoder, so a
    multi-byte character split across two reads is still printed correctly.
    Invalid bytes become U+FFFD.

    Intended as a background-thread target. A failed read is skipped and
    polling continues. Any other error ends the loop after a one-line notice
    on stderr.
    """
    decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
    try:
        while not stop_event.is_set():
            if ser.in_waiting:
                try:
                    text = decoder.decode(ser.read(ser.in_waiting))
                    if text:
                        print(text, end='', flush=True)
                except Exception:
                    # Deliberate best effort: one bad read should not stop the echo.
                    pass
            time.sleep(0.05)
        # Emit any incomplete trailing byte sequence (as U+FFFD) before exiting.
        tail = decoder.decode(b'', final=True)
        if tail:
            print(tail, end='', flush=True)
    except Exception as e:
        # e.g. the port disappeared; tell the user instead of dying silently.
        print(f"\n[serial reader stopped: {e}]", file=sys.stderr, flush=True)

def _start_reader(ser):
    """Start a daemon thread running serial_reader on *ser*.

    Returns:
        (stop_event, thread): set ``stop_event`` to ask the thread to exit.
    """
    stop_event = threading.Event()
    thread = threading.Thread(target=serial_reader, args=(ser, stop_event), daemon=True)
    thread.start()
    return stop_event, thread


def _stop_reader(stop_event, thread):
    """Ask a reader thread from _start_reader to exit and wait for it."""
    stop_event.set()
    thread.join()


def _wait_for_interrupt():
    """Sleep until the user presses Ctrl+C, then return normally."""
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        pass


def main():
    """Upload a file to a serial device, then echo the device's output.

    Usage: ``<script> <file> [port]``. Without a port argument, the user
    picks one interactively. After the upload, device output is echoed until
    Ctrl+C, which opens a menu to re-upload, disconnect, or keep watching.
    Pressing Ctrl+C (or sending EOF) at the menu prompt exits cleanly. The
    port is always closed on the way out.
    """
    if len(sys.argv) < 2:
        print(f"Usage: {os.path.basename(sys.argv[0])} <file> [port]")
        sys.exit(1)
    filename = sys.argv[1]
    # Report a bad path before the user is asked to pick a port.
    if not os.path.isfile(filename):
        print(f"File not found: {filename}")
        sys.exit(1)
    port = sys.argv[2] if len(sys.argv) > 2 else choose_port()

    try:
        ser = serial.Serial(port, 115200, timeout=0.1)
    except Exception as e:
        print(f"Failed to open port {port}: {e}")
        sys.exit(1)

    stop_event, reader_thread = _start_reader(ser)
    try:
        try:
            send_file(ser, filename)
            print("\n--- Serial output (Ctrl+C for menu) ---")
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nInterrupted.")

        while True:
            print("\nOptions:")
            print("1: Reupload")
            print("2: Disconnect and restart device")
            print("3: Continue debugging")
            try:
                choice = input("Select option [1/2/3]: ").strip()
            except (KeyboardInterrupt, EOFError):
                # Leave cleanly; the finally clause stops the reader and
                # closes the port.
                print()
                return
            if choice == '1':
                # Pause the echo while the file is being re-sent.
                _stop_reader(stop_event, reader_thread)
                print("Reuploading...")
                try:
                    send_file(ser, filename)
                except KeyboardInterrupt:
                    # Aborted upload: resume echoing and return to the menu.
                    print("\nInterrupted.")
                    stop_event, reader_thread = _start_reader(ser)
                    continue
                stop_event, reader_thread = _start_reader(ser)
                print("\n--- Serial output (Ctrl+C for menu) ---")
                _wait_for_interrupt()
            elif choice == '2':
                _stop_reader(stop_event, reader_thread)
                ser.close()
                try:
                    input("Connection closed. Press Enter to exit (device will restart).")
                except (KeyboardInterrupt, EOFError):
                    pass
                sys.exit(0)
            elif choice == '3':
                print("Continuing debugging. Press Ctrl+C for menu.")
                # Restart the echo if the reader thread has exited.
                if not reader_thread.is_alive():
                    stop_event, reader_thread = _start_reader(ser)
                _wait_for_interrupt()
            else:
                print("Invalid selection.")
    finally:
        _stop_reader(stop_event, reader_thread)
        ser.close()


if __name__ == "__main__":
    # Start the interactive uploader only when run as a script, not on import.
    main()