import argparse
import os
import struct
import sys
import time
from pathlib import Path

import serial


class FlashDownloader:
    """Download a binary image to flash through a UART bootloader.

    Workflow (see ``run``): toggle RTS/DTR to enter the bootloader and wait
    for a handshake, erase the destination region, write the image in
    chunks, optionally read it back for verification, then toggle RTS/DTR
    again.

    Each erase/write/read request is framed by ``send_command`` as
    (big-endian)::

        cmd(2) ~cmd(2) data_length(2) ~data_length(2) checksum(4) data

    and the target is expected to acknowledge with ``b'OKAY'``.
    """

    # Number of bytes carried by one write request / requested by one read.
    CHUNK_SIZE = 256
    # Erase granularity used to round the erase length up (assumed 4 KB).
    ERASE_BLOCK_SIZE = 4096

    def __init__(self):
        self.port = None
        self.baudrate = 750000  # default 750 Kbps
        self.filename = None
        self.address = 0x00010000  # default download address
        self.check = False
        self.ser = None

    def parse_arguments(self):
        """Parse command-line options into the downloader's settings.

        ``--COMX`` is required. ``--BAUDRATE`` and ``--ADDR`` override the
        defaults from ``__init__`` whenever they are supplied (an explicit
        ``--ADDR 0`` is honoured). Without ``--file`` the image defaults to
        ``build/<current directory name>.bin``. Invalid values terminate the
        program through ``argparse``'s usual error path.
        """
        parser = argparse.ArgumentParser(description='UART Flash Downloader')
        parser.add_argument('--COMX', required=True, help='COM port (e.g. COM3)')
        parser.add_argument('--BAUDRATE', type=int, help='Baud rate (default: 750000)')
        parser.add_argument('--file', help='File to download')
        parser.add_argument('--ADDR', type=lambda x: int(x, 0), help='Download address (hex/dec)')
        parser.add_argument('--CHECK', action='store_true', help='Enable readback verification')
        args = parser.parse_args()

        self.port = args.COMX

        # Compare against None so that an explicitly given 0 is not mistaken
        # for "option absent".
        if args.BAUDRATE is not None:
            if args.BAUDRATE <= 0:
                parser.error('--BAUDRATE must be a positive integer')
            self.baudrate = args.BAUDRATE

        if args.file:
            self.filename = args.file
        else:
            # Default: ./build/<name of the current directory>.bin
            current_dir = Path.cwd().name
            self.filename = f"build/{current_dir}.bin"

        if args.ADDR is not None:
            # The address is sent as a 32-bit unsigned field.
            if not 0 <= args.ADDR <= 0xFFFFFFFF:
                parser.error('--ADDR must be in the range 0x0-0xFFFFFFFF')
            self.address = args.ADDR

        self.check = args.CHECK

    def connect_serial(self):
        """Open the port as 8N1 with a 1 s read timeout and set idle RTS/DTR.

        Exits the process with status 1 if the port cannot be opened.
        """
        try:
            self.ser = serial.Serial(
                port=self.port,
                baudrate=self.baudrate,
                bytesize=serial.EIGHTBITS,
                parity=serial.PARITY_NONE,
                stopbits=serial.STOPBITS_ONE,
                timeout=1,
            )
            print(f"Connected to {self.port} at {self.baudrate} baud")
            # Initial line state: rts=True drives the RTS pin low and
            # dtr=False leaves the DTR pin high (per the original wiring notes).
            self.ser.rts = True
            self.ser.dtr = False
        except serial.SerialException as e:
            print(f"Error opening serial port: {e}")
            sys.exit(1)

    def control_signals_pre_erase(self):
        """Pulse RTS (with DTR high) and wait for the bootloader handshake.

        Returns:
            True if the target sent four 0x55 bytes before the port's read
            timeout expired, False otherwise.
        """
        print("Setting pre-erase signals: RTS low, DTR high...")
        self.ser.rts = True   # RTS pin low
        self.ser.dtr = False  # DTR pin high
        time.sleep(0.1)       # hold for 100 ms
        self.ser.rts = False  # release RTS (pin high)
        print("Pre-erase signals set")
        shakehand = self.ser.read(4)
        print("shakehand=", shakehand)
        return shakehand == b'\x55\x55\x55\x55'

    def control_signals_post_download(self):
        """Pulse RTS with DTR held low once the download has finished."""
        print("Setting post-download signals: RTS low, DTR low...")
        self.ser.rts = True  # RTS pin low
        self.ser.dtr = True  # DTR pin low
        time.sleep(0.1)      # hold for 100 ms
        self.ser.rts = False  # release RTS (pin high)
        print("Post-download signals set")

    def calculate_checksum(self, data):
        """Return the byte sum of *data* truncated to 32 bits."""
        return sum(data) & 0xFFFFFFFF

    def send_command(self, cmd, cmd_inv, data_length, data_length_inv, checksum, *args):
        """Frame and send one command, then wait for the 4-byte reply.

        Args:
            cmd, cmd_inv: command code and its 16-bit complement.
            data_length, data_length_inv: length field and its complement.
            checksum: 32-bit checksum placed right after the header.
            *args: payload pieces. Bytes-like objects are sent verbatim;
                ints are packed big-endian into the smallest of 1/2/4 bytes
                that holds their value.

        Returns:
            True if the target replied ``b'OKAY'``.

        Raises:
            TypeError: if a payload piece is neither an int nor bytes-like.
                Dropping it silently would send data that no longer matches
                the checksum computed by the caller.
        """
        header = struct.pack('>HHHH', cmd, cmd_inv, data_length, data_length_inv)

        data_parts = []
        for arg in args:
            if isinstance(arg, int):
                if arg <= 0xFF:
                    fmt = '>B'
                elif arg <= 0xFFFF:
                    fmt = '>H'
                else:
                    fmt = '>I'
                data_parts.append(struct.pack(fmt, arg))
            elif isinstance(arg, (bytes, bytearray, memoryview)):
                data_parts.append(bytes(arg))
            else:
                raise TypeError(f"Unsupported payload type: {type(arg).__name__}")

        full_cmd = header + struct.pack('>I', checksum) + b''.join(data_parts)
        self.ser.write(full_cmd)

        response = self.ser.read(4)
        return response == b'OKAY'

    def _request(self, cmd, addr, length, payload=b''):
        """Send an address/length request plus optional payload; True on ACK."""
        addr_bytes = struct.pack('>I', addr)
        length_bytes = struct.pack('>I', length)
        # Length field is 0x000C plus the payload size: erase/read carry no
        # payload, write carries the data chunk. (Presumably 0x0C covers the
        # checksum, address and length fields - confirm against the spec.)
        data_length = 0x000C + len(payload)
        checksum = self.calculate_checksum(addr_bytes + length_bytes + payload)
        return self.send_command(
            cmd, cmd ^ 0xFFFF,
            data_length, data_length ^ 0xFFFF,
            checksum,
            addr_bytes, length_bytes, payload,
        )

    def erase_flash(self, addr, length):
        """Erase *length* bytes starting at *addr*; return True on success."""
        print(f"Erasing {length} bytes at 0x{addr:08X}...")
        if not self._request(0x0001, addr, length):
            print("Erase failed!")
            return False
        print("Erase successful")
        return True

    def write_flash(self, addr, data):
        """Write *data* to *addr* in CHUNK_SIZE pieces; return True on success."""
        chunk_size = self.CHUNK_SIZE
        total_length = len(data)
        for offset in range(0, total_length, chunk_size):
            chunk = data[offset:offset + chunk_size]
            actual_length = len(chunk)
            progress = (offset / total_length) * 100
            print(f"Writing {actual_length} bytes at 0x{addr+offset:08X}... [{progress:.1f}%]")
            if not self._request(0x0002, addr + offset, actual_length, chunk):
                print(f"Write failed at offset 0x{offset:08X}!")
                return False
        print("Write completed successfully [100.0%]")
        return True

    def read_flash(self, addr, length):
        """Read *length* bytes from *addr* in CHUNK_SIZE pieces.

        Returns:
            The data as a ``bytearray``, or None if any request was not
            acknowledged or returned fewer bytes than requested.
        """
        chunk_size = self.CHUNK_SIZE
        read_data = bytearray()
        for offset in range(0, length, chunk_size):
            actual_length = min(chunk_size, length - offset)
            progress = (offset / length) * 100
            print(f"Reading {actual_length} bytes from 0x{addr+offset:08X}... [{progress:.1f}%]")

            if not self._request(0x0003, addr + offset, actual_length):
                print(f"Read failed at offset 0x{offset:08X}!")
                return None  # callers test for None; never return False here

            # The requested bytes follow the OKAY acknowledgement.
            chunk = self.ser.read(actual_length)
            if len(chunk) != actual_length:
                print(f"Read incomplete at offset 0x{offset:08X}!")
                return None
            read_data.extend(chunk)
        print("Read completed successfully [100.0%]")
        return read_data

    def verify_flash(self, addr, original_data):
        """Read back ``len(original_data)`` bytes from *addr* and compare.

        Returns:
            True if the read-back data matches exactly, False otherwise
            (the first mismatching offset is reported).
        """
        print("Starting verification...")
        read_data = self.read_flash(addr, len(original_data))
        if read_data is None:
            print("Verification failed - could not read data")
            return False
        if len(read_data) != len(original_data):
            print(f"Verification failed - length mismatch (expected {len(original_data)}, got {len(read_data)})")
            return False
        if read_data != original_data:
            for i, (expected, actual) in enumerate(zip(original_data, read_data)):
                if actual != expected:
                    print(f"Verification failed at offset 0x{i:08X} (expected 0x{expected:02X}, got 0x{actual:02X})")
                    return False
        print("Verification successful")
        return True

    def run(self):
        """Parse arguments and perform the full download; exits 1 on failure."""
        self.parse_arguments()

        # Load the image before touching the hardware so a bad path or an
        # empty build output fails without resetting the target.
        try:
            with open(self.filename, 'rb') as f:
                file_data = f.read()
        except OSError as e:
            print(f"Error opening file: {e}")
            sys.exit(1)
        if not file_data:
            print(f"Error: {self.filename} is empty, nothing to download")
            sys.exit(1)

        self.connect_serial()
        try:
            print(f"Preparing to download {len(file_data)} bytes from {self.filename} to 0x{self.address:08X}")

            # Enter the bootloader and wait for its handshake.
            if not self.control_signals_pre_erase():
                print("shake hands failed")
                sys.exit(1)

            # Round the erase length up to whole erase blocks.
            block = self.ERASE_BLOCK_SIZE
            erase_size = ((len(file_data) + block - 1) // block) * block
            if not self.erase_flash(self.address, erase_size):
                sys.exit(1)

            if not self.write_flash(self.address, file_data):
                sys.exit(1)

            if self.check and not self.verify_flash(self.address, file_data):
                sys.exit(1)

            print("Flash download completed successfully")
            self.control_signals_post_download()
        finally:
            # Release the port on every path, including sys.exit() above.
            self.ser.close()


if __name__ == "__main__":
    downloader = FlashDownloader()
    downloader.run()