From 30fe7c43779ba10febd14a178f978524781b4c2a Mon Sep 17 00:00:00 2001 From: zhji Date: Mon, 5 May 2025 21:42:38 +0800 Subject: [PATCH] [feat] add auto_flash --- example/boot2/boot2.ld | 16 +-- example/boot2/main.c | 46 ++++--- tools/flash_download.py | 262 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 293 insertions(+), 31 deletions(-) create mode 100644 tools/flash_download.py diff --git a/example/boot2/boot2.ld b/example/boot2/boot2.ld index b1cce04..60f96ad 100644 --- a/example/boot2/boot2.ld +++ b/example/boot2/boot2.ld @@ -15,8 +15,6 @@ PHDRS { boot2_pre PT_LOAD FLAGS(5); /* R + X */ text PT_LOAD FLAGS(5); /* R + X */ - rodata PT_LOAD FLAGS(5); /* R + W */ - data PT_LOAD FLAGS(6); /* R + W */ bss PT_LOAD FLAGS(6); /* R + W */ } @@ -40,22 +38,10 @@ SECTIONS . = ALIGN(4); _boot2_copy_self_start_addr = .; *(.text*) - . = ALIGN(4); - } >RAM AT > FLASH :text - - .rodata : - { - . = ALIGN(4); *(.rodata*) - . = ALIGN(4); - } >RAM AT > FLASH :rodata - - .data : - { - . = ALIGN(4); *(.data*) . = ALIGN(4); - } >RAM AT > FLASH :data + } >RAM AT > FLASH :text _boot2_copy_self_end_addr = .; diff --git a/example/boot2/main.c b/example/boot2/main.c index 6293ada..90e08d2 100644 --- a/example/boot2/main.c +++ b/example/boot2/main.c @@ -62,8 +62,6 @@ void __attribute__((section(".text.boot2_pre"))) boot2_copy_self(void) uint8_t uart_tx_buffer[512]; uint8_t uart_rx_buffer[512]; -uint8_t flash_tx_buffer[512]; -uint8_t flash_rx_buffer[512]; struct uart_cfg_s uart_cfg = { .baudrate = 2 * 1000 * 1000, @@ -181,7 +179,7 @@ void uart_state_machine(uint8_t id) uint16_t code, code_inv, length, length_inv; if (uart_get_flags(id) & UART_FLAG_RXFE) { - if (timer_count_read() - time_fifo_empty < 1000) { + if (timer_count_read() - time_fifo_empty < 100) { return; } time_fifo_empty = timer_count_read(); @@ -202,15 +200,25 @@ void uart_state_machine(uint8_t id) } else { uart_rx_buffer[uart_rx_length++] = uart0_hw->dr & 0xFF; time_fifo_empty = timer_count_read(); - if (uart_rx_length >= sizeof(flash_rx_buffer)) { + if (uart_rx_length >= sizeof(uart_rx_buffer)) { uart_rx_length = 0; return; } } } +__attribute__((naked)) void jump_to_address(uint32_t address) { + __asm volatile ( + "bx %0\n" + : + : "r" (address) + ); +} + int main(void) { + uint32_t boot_pin_low, boot_pin_high; + clock_ref_set_src(CLOCK_REF_SRC_XOSC_GLITCHLESS); clock_sys_set_src(CLOCK_SYS_SRC_REF_GLITCHLESS); /* refdiv >= 5MHz, VCO=[750:1600]MHz, fbdiv=[16:320], postdiv=[1:7] */ @@ -223,20 +231,26 @@ int main(void) gpio_init_simple(0, GPIO_FUNC_UART, DISABLE, ENABLE); gpio_init_simple(1, GPIO_FUNC_UART, DISABLE, ENABLE); uart_init(UART_ID_0, &uart_cfg); + *(volatile uint32_t *)(WATCHDOG_BASE + 0x2C) = ((1 << 9) | (12 << 0)); timer_start(); - printf("boot2 system clock = 60MHz\r\n"); - printf("boot2 peripheral clock = 120MHz\r\n"); - - while (1) { - // int c = uart_get_char(UART_ID_0); - // if (c >= 0) { - // uart_put_char(UART_ID_0, c); - // } - uart_state_machine(UART_ID_0); + gpio_init_simple(2, GPIO_FUNC_SIO, ENABLE, DISABLE); + boot_pin_low = 0; + boot_pin_high = 0; + for (uint32_t i = 0; i < 1000; i++) { + if (gpio_read(2)) { + boot_pin_high++; + } else { + boot_pin_low++; + } } - // flash_erase(addr); - flash_read(0x1200, flash_rx_buffer, FLASH_WRITE_SIZE); - // flash_write(addr, data, FLASH_WRITE_SIZE); + if (boot_pin_high > boot_pin_low) { + while (1) { + uart_state_machine(UART_ID_0); + } + } else { + jump_to_address(0x00100000); + } + return 0; } diff --git a/tools/flash_download.py b/tools/flash_download.py new file mode 100644 index 0000000..543f4f1 --- /dev/null +++ b/tools/flash_download.py @@ -0,0 +1,262 @@ +import os +import sys +import time +import argparse +import serial +import struct +from pathlib import Path + +class FlashDownloader: + def __init__(self): + self.port = None + self.baudrate = 2000000 # 默认2Mbps + self.filename = None + self.address = 0x00010000 # 默认地址 + self.check = False + self.ser = None + + def parse_arguments(self): + parser = argparse.ArgumentParser(description='UART Flash Downloader') + parser.add_argument('--COMX', required=True, help='COM port (e.g. COM3)') + parser.add_argument('--BAUDRATE', type=int, help='Baud rate (default: 2000000)') + parser.add_argument('--file', help='File to download') + parser.add_argument('--ADDR', type=lambda x: int(x, 0), help='Download address (hex/dec)') + parser.add_argument('--CHECK', action='store_true', help='Enable readback verification') + + args = parser.parse_args() + + self.port = args.COMX + if args.BAUDRATE: + self.baudrate = args.BAUDRATE + if args.file: + self.filename = args.file + else: + # 默认文件名为当前目录/build/当前文件夹名.bin + current_dir = Path.cwd().name + self.filename = f"build/{current_dir}.bin" + if args.ADDR: + self.address = args.ADDR + self.check = args.CHECK + + def connect_serial(self): + try: + self.ser = serial.Serial( + port=self.port, + baudrate=self.baudrate, + bytesize=serial.EIGHTBITS, + parity=serial.PARITY_NONE, + stopbits=serial.STOPBITS_ONE, + timeout=1 + ) + print(f"Connected to {self.port} at {self.baudrate} baud") + + # 初始化信号状态 + self.ser.rts = True # 初始拉低 + self.ser.dtr = False # 初始拉高 + except serial.SerialException as e: + print(f"Error opening serial port: {e}") + sys.exit(1) + + def control_signals_pre_erase(self): + """擦除前的信号控制""" + print("Setting pre-erase signals: RTS low, DTR high...") + self.ser.rts = True # RTS拉低 + self.ser.dtr = False # DTR拉高 + time.sleep(0.1) # 等待100ms + self.ser.rts = False # RTS拉高 + time.sleep(0.1) # 等待100ms + print("Pre-erase signals set") + + def control_signals_post_download(self): + """下载完成后的信号控制""" + print("Setting post-download signals: RTS low, DTR low...") + self.ser.rts = True # RTS拉低 + self.ser.dtr = True # DTR拉低 + time.sleep(0.1) # 等待100ms + self.ser.rts = False # RTS拉高 + print("Post-download signals set") + + def calculate_checksum(self, data): + return sum(data) & 0xFFFFFFFF + + def send_command(self, cmd, cmd_inv, data_length, data_length_inv, checksum, *args): + # 构建命令头 + header = struct.pack('>HHHH', cmd, cmd_inv, data_length, data_length_inv) + + # 构建数据部分 + data_parts = [] + for arg in args: + if isinstance(arg, int): + if arg <= 0xFF: + data_parts.append(struct.pack('>B', arg)) + elif arg <= 0xFFFF: + data_parts.append(struct.pack('>H', arg)) + else: + data_parts.append(struct.pack('>I', arg)) + elif isinstance(arg, bytes): + data_parts.append(arg) + + data = b''.join(data_parts) + + # 构建完整命令 + full_cmd = header + struct.pack('>I', checksum) + data + + # 发送命令 + self.ser.write(full_cmd) + + # 等待响应 + response = self.ser.read(4) + return response == b'OKAY' + + def erase_flash(self, addr, length): + print(f"Erasing {length} bytes at 0x{addr:08X}...") + cmd = 0x0001 + cmd_inv = 0xFFFE + data_length = 0x000C + data_length_inv = 0xFFF3 + + # 准备数据 + addr_bytes = struct.pack('>I', addr) + length_bytes = struct.pack('>I', length) + data = addr_bytes + length_bytes + checksum = self.calculate_checksum(data) + + if not self.send_command(cmd, cmd_inv, data_length, data_length_inv, checksum, addr_bytes, length_bytes): + print("Erase failed!") + return False + print("Erase successful") + return True + + def write_flash(self, addr, data): + chunk_size = 256 + total_length = len(data) + offset = 0 + + while offset < total_length: + chunk = data[offset:offset+chunk_size] + actual_length = len(chunk) + progress = (offset / total_length) * 100 + + print(f"Writing {actual_length} bytes at 0x{addr+offset:08X}... [{progress:.1f}%]") + + cmd = 0x0002 + cmd_inv = 0xFFFD + data_length = 0x000C + actual_length + data_length_inv = 0xFFFF - data_length + + # 准备数据 + addr_bytes = struct.pack('>I', addr + offset) + length_bytes = struct.pack('>I', actual_length) + header_data = addr_bytes + length_bytes + checksum_data = header_data + chunk + checksum = self.calculate_checksum(checksum_data) + + if not self.send_command(cmd, cmd_inv, data_length, data_length_inv, checksum, addr_bytes, length_bytes, chunk): + print(f"Write failed at offset 0x{offset:08X}!") + return False + + offset += actual_length + + print("Write completed successfully [100.0%]") + return True + + def read_flash(self, addr, length): + chunk_size = 256 + remaining = length + offset = 0 + read_data = bytearray() + + while remaining > 0: + actual_length = min(chunk_size, remaining) + progress = (offset / length) * 100 + + print(f"Reading {actual_length} bytes from 0x{addr+offset:08X}... [{progress:.1f}%]") + + cmd = 0x0003 + cmd_inv = 0xFFFC + data_length = 0x000C + data_length_inv = 0xFFF3 + + # 准备数据 + addr_bytes = struct.pack('>I', addr + offset) + length_bytes = struct.pack('>I', actual_length) + data = addr_bytes + length_bytes + checksum = self.calculate_checksum(data) + + # 发送命令 + if not self.send_command(cmd, cmd_inv, data_length, data_length_inv, checksum, addr_bytes, length_bytes): + print(f"Read failed at offset 0x{offset:08X}!") + return False + + # 读取数据 + chunk = self.ser.read(actual_length) + if len(chunk) != actual_length: + print(f"Read incomplete at offset 0x{offset:08X}!") + return None + + read_data.extend(chunk) + offset += actual_length + remaining -= actual_length + + print("Read completed successfully [100.0%]") + return read_data + + def verify_flash(self, addr, original_data): + print("Starting verification...") + read_data = self.read_flash(addr, len(original_data)) + + if read_data is None: + print("Verification failed - could not read data") + return False + + if len(read_data) != len(original_data): + print(f"Verification failed - length mismatch (expected {len(original_data)}, got {len(read_data)})") + return False + + for i in range(len(original_data)): + if read_data[i] != original_data[i]: + print(f"Verification failed at offset 0x{i:08X} (expected 0x{original_data[i]:02X}, got 0x{read_data[i]:02X})") + return False + + print("Verification successful") + return True + + def run(self): + self.parse_arguments() + self.connect_serial() + + # 读取文件 + try: + with open(self.filename, 'rb') as f: + file_data = f.read() + except IOError as e: + print(f"Error opening file: {e}") + sys.exit(1) + + print(f"Preparing to download {len(file_data)} bytes from {self.filename} to 0x{self.address:08X}") + + # 擦除前的信号控制 + self.control_signals_pre_erase() + + # 擦除Flash (擦除足够的空间) + erase_size = ((len(file_data) + 4095) // 4096) * 4096 # 假设擦除块大小为4KB + if not self.erase_flash(self.address, erase_size): + sys.exit(1) + + # 写入数据 + if not self.write_flash(self.address, file_data): + sys.exit(1) + + # 校验 + if self.check: + if not self.verify_flash(self.address, file_data): + sys.exit(1) + + print("Flash download completed successfully") + + # 下载完成后的信号控制 + self.control_signals_post_download() + +if __name__ == "__main__": + downloader = FlashDownloader() + downloader.run()