[feat] add auto_flash

This commit is contained in:
zhji 2025-05-05 21:42:38 +08:00
parent cf55c5cc2a
commit 30fe7c4377
3 changed files with 293 additions and 31 deletions

View File

@ -15,8 +15,6 @@ PHDRS
{
boot2_pre PT_LOAD FLAGS(5); /* R + X */
text PT_LOAD FLAGS(5); /* R + X */
rodata PT_LOAD FLAGS(5); /* R + W */
data PT_LOAD FLAGS(6); /* R + W */
bss PT_LOAD FLAGS(6); /* R + W */
}
@ -40,22 +38,10 @@ SECTIONS
. = ALIGN(4);
_boot2_copy_self_start_addr = .;
*(.text*)
. = ALIGN(4);
} >RAM AT > FLASH :text
.rodata :
{
. = ALIGN(4);
*(.rodata*)
. = ALIGN(4);
} >RAM AT > FLASH :rodata
.data :
{
. = ALIGN(4);
*(.data*)
. = ALIGN(4);
} >RAM AT > FLASH :data
} >RAM AT > FLASH :text
_boot2_copy_self_end_addr = .;

View File

@ -62,8 +62,6 @@ void __attribute__((section(".text.boot2_pre"))) boot2_copy_self(void)
uint8_t uart_tx_buffer[512];
uint8_t uart_rx_buffer[512];
uint8_t flash_tx_buffer[512];
uint8_t flash_rx_buffer[512];
struct uart_cfg_s uart_cfg = {
.baudrate = 2 * 1000 * 1000,
@ -181,7 +179,7 @@ void uart_state_machine(uint8_t id)
uint16_t code, code_inv, length, length_inv;
if (uart_get_flags(id) & UART_FLAG_RXFE) {
if (timer_count_read() - time_fifo_empty < 1000) {
if (timer_count_read() - time_fifo_empty < 100) {
return;
}
time_fifo_empty = timer_count_read();
@ -202,15 +200,25 @@ void uart_state_machine(uint8_t id)
} else {
uart_rx_buffer[uart_rx_length++] = uart0_hw->dr & 0xFF;
time_fifo_empty = timer_count_read();
if (uart_rx_length >= sizeof(flash_rx_buffer)) {
if (uart_rx_length >= sizeof(uart_rx_buffer)) {
uart_rx_length = 0;
return;
}
}
}
__attribute__((naked)) void jump_to_address(uint32_t address) {
__asm volatile (
"bx %0\n"
:
: "r" (address)
);
}
int main(void)
{
uint32_t boot_pin_low, boot_pin_high;
clock_ref_set_src(CLOCK_REF_SRC_XOSC_GLITCHLESS);
clock_sys_set_src(CLOCK_SYS_SRC_REF_GLITCHLESS);
/* refdiv >= 5MHz, VCO=[750:1600]MHz, fbdiv=[16:320], postdiv=[1:7] */
@ -223,20 +231,26 @@ int main(void)
gpio_init_simple(0, GPIO_FUNC_UART, DISABLE, ENABLE);
gpio_init_simple(1, GPIO_FUNC_UART, DISABLE, ENABLE);
uart_init(UART_ID_0, &uart_cfg);
*(volatile uint32_t *)(WATCHDOG_BASE + 0x2C) = ((1 << 9) | (12 << 0));
timer_start();
printf("boot2 system clock = 60MHz\r\n");
printf("boot2 peripheral clock = 120MHz\r\n");
while (1) {
// int c = uart_get_char(UART_ID_0);
// if (c >= 0) {
// uart_put_char(UART_ID_0, c);
// }
uart_state_machine(UART_ID_0);
gpio_init_simple(2, GPIO_FUNC_SIO, ENABLE, DISABLE);
boot_pin_low = 0;
boot_pin_high = 0;
for (uint32_t i = 0; i < 1000; i++) {
if (gpio_read(2)) {
boot_pin_high++;
} else {
boot_pin_low++;
}
}
// flash_erase(addr);
flash_read(0x1200, flash_rx_buffer, FLASH_WRITE_SIZE);
// flash_write(addr, data, FLASH_WRITE_SIZE);
if (boot_pin_high > boot_pin_low) {
while (1) {
uart_state_machine(UART_ID_0);
}
} else {
jump_to_address(0x00100000);
}
return 0;
}

262
tools/flash_download.py Normal file
View File

@ -0,0 +1,262 @@
import os
import sys
import time
import argparse
import serial
import struct
from pathlib import Path
class FlashDownloader:
def __init__(self):
self.port = None
self.baudrate = 2000000 # 默认2Mbps
self.filename = None
self.address = 0x00010000 # 默认地址
self.check = False
self.ser = None
def parse_arguments(self):
parser = argparse.ArgumentParser(description='UART Flash Downloader')
parser.add_argument('--COMX', required=True, help='COM port (e.g. COM3)')
parser.add_argument('--BAUDRATE', type=int, help='Baud rate (default: 2000000)')
parser.add_argument('--file', help='File to download')
parser.add_argument('--ADDR', type=lambda x: int(x, 0), help='Download address (hex/dec)')
parser.add_argument('--CHECK', action='store_true', help='Enable readback verification')
args = parser.parse_args()
self.port = args.COMX
if args.BAUDRATE:
self.baudrate = args.BAUDRATE
if args.file:
self.filename = args.file
else:
# 默认文件名为当前目录/build/当前文件夹名.bin
current_dir = Path.cwd().name
self.filename = f"build/{current_dir}.bin"
if args.ADDR:
self.address = args.ADDR
self.check = args.CHECK
def connect_serial(self):
try:
self.ser = serial.Serial(
port=self.port,
baudrate=self.baudrate,
bytesize=serial.EIGHTBITS,
parity=serial.PARITY_NONE,
stopbits=serial.STOPBITS_ONE,
timeout=1
)
print(f"Connected to {self.port} at {self.baudrate} baud")
# 初始化信号状态
self.ser.rts = True # 初始拉低
self.ser.dtr = False # 初始拉高
except serial.SerialException as e:
print(f"Error opening serial port: {e}")
sys.exit(1)
def control_signals_pre_erase(self):
"""擦除前的信号控制"""
print("Setting pre-erase signals: RTS low, DTR high...")
self.ser.rts = True # RTS拉低
self.ser.dtr = False # DTR拉高
time.sleep(0.1) # 等待100ms
self.ser.rts = False # RTS拉高
time.sleep(0.1) # 等待100ms
print("Pre-erase signals set")
def control_signals_post_download(self):
"""下载完成后的信号控制"""
print("Setting post-download signals: RTS low, DTR low...")
self.ser.rts = True # RTS拉低
self.ser.dtr = True # DTR拉低
time.sleep(0.1) # 等待100ms
self.ser.rts = False # RTS拉高
print("Post-download signals set")
def calculate_checksum(self, data):
return sum(data) & 0xFFFFFFFF
def send_command(self, cmd, cmd_inv, data_length, data_length_inv, checksum, *args):
# 构建命令头
header = struct.pack('>HHHH', cmd, cmd_inv, data_length, data_length_inv)
# 构建数据部分
data_parts = []
for arg in args:
if isinstance(arg, int):
if arg <= 0xFF:
data_parts.append(struct.pack('>B', arg))
elif arg <= 0xFFFF:
data_parts.append(struct.pack('>H', arg))
else:
data_parts.append(struct.pack('>I', arg))
elif isinstance(arg, bytes):
data_parts.append(arg)
data = b''.join(data_parts)
# 构建完整命令
full_cmd = header + struct.pack('>I', checksum) + data
# 发送命令
self.ser.write(full_cmd)
# 等待响应
response = self.ser.read(4)
return response == b'OKAY'
def erase_flash(self, addr, length):
print(f"Erasing {length} bytes at 0x{addr:08X}...")
cmd = 0x0001
cmd_inv = 0xFFFE
data_length = 0x000C
data_length_inv = 0xFFF3
# 准备数据
addr_bytes = struct.pack('>I', addr)
length_bytes = struct.pack('>I', length)
data = addr_bytes + length_bytes
checksum = self.calculate_checksum(data)
if not self.send_command(cmd, cmd_inv, data_length, data_length_inv, checksum, addr_bytes, length_bytes):
print("Erase failed!")
return False
print("Erase successful")
return True
def write_flash(self, addr, data):
chunk_size = 256
total_length = len(data)
offset = 0
while offset < total_length:
chunk = data[offset:offset+chunk_size]
actual_length = len(chunk)
progress = (offset / total_length) * 100
print(f"Writing {actual_length} bytes at 0x{addr+offset:08X}... [{progress:.1f}%]")
cmd = 0x0002
cmd_inv = 0xFFFD
data_length = 0x000C + actual_length
data_length_inv = 0xFFFF - data_length
# 准备数据
addr_bytes = struct.pack('>I', addr + offset)
length_bytes = struct.pack('>I', actual_length)
header_data = addr_bytes + length_bytes
checksum_data = header_data + chunk
checksum = self.calculate_checksum(checksum_data)
if not self.send_command(cmd, cmd_inv, data_length, data_length_inv, checksum, addr_bytes, length_bytes, chunk):
print(f"Write failed at offset 0x{offset:08X}!")
return False
offset += actual_length
print("Write completed successfully [100.0%]")
return True
def read_flash(self, addr, length):
chunk_size = 256
remaining = length
offset = 0
read_data = bytearray()
while remaining > 0:
actual_length = min(chunk_size, remaining)
progress = (offset / length) * 100
print(f"Reading {actual_length} bytes from 0x{addr+offset:08X}... [{progress:.1f}%]")
cmd = 0x0003
cmd_inv = 0xFFFC
data_length = 0x000C
data_length_inv = 0xFFF3
# 准备数据
addr_bytes = struct.pack('>I', addr + offset)
length_bytes = struct.pack('>I', actual_length)
data = addr_bytes + length_bytes
checksum = self.calculate_checksum(data)
# 发送命令
if not self.send_command(cmd, cmd_inv, data_length, data_length_inv, checksum, addr_bytes, length_bytes):
print(f"Read failed at offset 0x{offset:08X}!")
return False
# 读取数据
chunk = self.ser.read(actual_length)
if len(chunk) != actual_length:
print(f"Read incomplete at offset 0x{offset:08X}!")
return None
read_data.extend(chunk)
offset += actual_length
remaining -= actual_length
print("Read completed successfully [100.0%]")
return read_data
def verify_flash(self, addr, original_data):
print("Starting verification...")
read_data = self.read_flash(addr, len(original_data))
if read_data is None:
print("Verification failed - could not read data")
return False
if len(read_data) != len(original_data):
print(f"Verification failed - length mismatch (expected {len(original_data)}, got {len(read_data)})")
return False
for i in range(len(original_data)):
if read_data[i] != original_data[i]:
print(f"Verification failed at offset 0x{i:08X} (expected 0x{original_data[i]:02X}, got 0x{read_data[i]:02X})")
return False
print("Verification successful")
return True
def run(self):
self.parse_arguments()
self.connect_serial()
# 读取文件
try:
with open(self.filename, 'rb') as f:
file_data = f.read()
except IOError as e:
print(f"Error opening file: {e}")
sys.exit(1)
print(f"Preparing to download {len(file_data)} bytes from {self.filename} to 0x{self.address:08X}")
# 擦除前的信号控制
self.control_signals_pre_erase()
# 擦除Flash (擦除足够的空间)
erase_size = ((len(file_data) + 4095) // 4096) * 4096 # 假设擦除块大小为4KB
if not self.erase_flash(self.address, erase_size):
sys.exit(1)
# 写入数据
if not self.write_flash(self.address, file_data):
sys.exit(1)
# 校验
if self.check:
if not self.verify_flash(self.address, file_data):
sys.exit(1)
print("Flash download completed successfully")
# 下载完成后的信号控制
self.control_signals_post_download()
if __name__ == "__main__":
downloader = FlashDownloader()
downloader.run()