Go 断点续传指南
什么是断点续传
断点续传(Resumable Transfer)是一种在网络传输中断后,能够从上次中断的位置继续传输的技术。这项技术广泛应用于大文件传输、下载管理器和云存储服务中。
断点续传实现原理
断点续传的实现基于以下几个关键技术点:
- 文件分块:将大文件分割为多个小块(chunks)
- 进度记录:记录已成功传输的块信息
- 校验机制:确保传输数据的完整性
- 续传请求:从最后成功的位置重新请求数据
实战讲解
案例是基于将大文件移动到别的位置,中途可以通过命令行停止来模拟断点行为。
案例可能不够完美,但是可以帮助您理解断点续传的实现原理和代码设计思想。
请确保您有 Go 的基础文件目录操作 知识
简易版本
go
package main
import (
"fmt"
"io"
"os"
"strconv"
)
const chunkSize = 1024 // 1KB 的块大小
func main() {
srcFile := "src.md" // 源文件
dstFile := "copy.md" // 目标文件
offsetFile := "offset.txt" // 记录传输进度的文件
// 获取源文件信息
srcInfo, err := os.Stat(srcFile)
if err != nil {
fmt.Println("无法获取源文件信息:", err)
return
}
srcSize := srcInfo.Size()
// 打开源文件
src, err := os.Open(srcFile)
if err != nil {
fmt.Println("无法打开源文件:", err)
return
}
defer src.Close()
// 初始化偏移量
var offset int64 = 0
// 检查是否有之前的传输进度
if _, err := os.Stat(offsetFile); err == nil {
// Deprecated: As of Go 1.16, this function simply calls [os.ReadFile].
//data, err := ioutil.ReadFile(offsetFile)
data, err := os.ReadFile(offsetFile)
if err == nil {
// 将读取到的文件内容转换为十进制数字
offset, _ = strconv.ParseInt(string(data), 10, 64)
fmt.Printf("从断点继续传输,已传输 %d 字节\n", offset)
}
}
// 打开目标文件,如果存在则追加,否则创建
dst, err := os.OpenFile(dstFile, os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
fmt.Println("无法打开目标文件:", err)
return
}
defer dst.Close()
// 定位到上次传输的位置
_, err = src.Seek(offset, io.SeekStart)
if err != nil {
fmt.Println("定位文件位置错误:", err)
return
}
_, err = dst.Seek(offset, io.SeekStart)
if err != nil {
fmt.Println("定位文件位置错误:", err)
return
}
// 开始传输
buf := make([]byte, chunkSize)
for {
n, err := src.Read(buf)
if err != nil && err != io.EOF {
fmt.Println("读取文件错误:", err)
return
}
if n == 0 {
break
}
_, err = dst.Write(buf[:n])
if err != nil {
fmt.Println("写入文件错误:", err)
return
}
offset += int64(n)
fmt.Printf("\r已传输 %d/%d 字节 (%.2f%%)", offset, srcSize, float64(offset)*100/float64(srcSize))
// 更新传输进度
// Deprecated: As of Go 1.16, this function simply calls [os.WriteFile].
//err = ioutil.WriteFile(offsetFile, []byte(strconv.FormatInt(offset, 10)), 0644)
err = os.WriteFile(offsetFile, []byte(strconv.FormatInt(offset, 10)), 0644)
if err != nil {
fmt.Println("保存传输进度错误:", err)
return
}
}
// 传输完成,删除进度文件
_ = os.Remove(offsetFile)
fmt.Println("文件传输完成")
}
进阶版本
go
package main
import (
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"encoding"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"log"
"os"
"path/filepath"
"time"
)
// PersistedState 需要持久化的传输状态
type PersistedState struct {
Downloaded int64 `json:"downloaded"` // 已传输字节数
TotalSize int64 `json:"totalSize"` // 文件总大小
TempFileName string `json:"tempFileName"` // 临时文件名(不含路径)
Checksum string `json:"checksum"` // 最终校验和
HashState []byte `json:"hashState"` // 哈希计算器状态
HashSize int64 `json:"hashSize"` // 已计算哈希的数据量
HashAlgorithm string `json:"hashAlgorithm"` // 使用的哈希算法
}
// ResumeTransfer 断点续传结构体
type ResumeTransfer struct {
// 配置参数
SourcePath string // 源文件路径
DestPath string // 目标文件路径
TempDir string // 临时目录
ChunkSize int64 // 分片大小(字节)
HashAlgorithm string // 哈希算法(md5/sha1/sha256)
// 运行时状态
statusFile string // 状态文件完整路径
tempFilePath string // 临时文件完整路径
status PersistedState // 持久化状态
hash hash.Hash // 哈希计算器
}
// NewResumeTransfer 创建新的断点续传实例
func NewResumeTransfer(source, dest string) *ResumeTransfer {
return &ResumeTransfer{
SourcePath: source,
TempDir: ".",
DestPath: dest,
ChunkSize: 1 * 1024 * 1024, // 默认1MB
HashAlgorithm: "md5", // 默认使用MD5
}
}
// SetHashAlgorithm 设置哈希算法
func (rt *ResumeTransfer) SetHashAlgorithm(algorithm string) error {
switch algorithm {
case "md5", "sha1", "sha256":
rt.HashAlgorithm = algorithm
return nil
default:
return fmt.Errorf("不支持的哈希算法: %s", algorithm)
}
}
// newHash 创建新的哈希计算器
func (rt *ResumeTransfer) newHash() hash.Hash {
switch rt.HashAlgorithm {
case "sha1":
return sha1.New()
case "sha256":
return sha256.New()
default:
return md5.New()
}
}
// Start 开始或恢复传输
func (rt *ResumeTransfer) Start() error {
// 1. 初始化工作区
if err := rt.initWorkspace(); err != nil {
return fmt.Errorf("初始化工作区失败: %v", err)
}
// 2. 获取源文件信息
srcInfo, err := os.Stat(rt.SourcePath)
if err != nil {
return fmt.Errorf("获取源文件信息失败: %v", err)
}
rt.status.TotalSize = srcInfo.Size()
rt.status.HashAlgorithm = rt.HashAlgorithm
// 3. 尝试恢复传输
if err := rt.tryResume(); err != nil {
log.Printf("恢复失败: %v, 将重新开始传输", err)
rt.cleanup() // 清理可能损坏的状态
if err := rt.initWorkspace(); err != nil {
return err
}
rt.status.Downloaded = 0
rt.hash = rt.newHash() // 重置哈希计算器
}
// 4. 打开源文件和目标文件
srcFile, err := os.Open(rt.SourcePath)
if err != nil {
return fmt.Errorf("打开源文件失败: %v", err)
}
defer srcFile.Close()
destFile, err := os.OpenFile(rt.tempFilePath, os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return fmt.Errorf("打开目标文件失败: %v", err)
}
// 5. 定位到上次传输的位置
if _, err := srcFile.Seek(rt.status.Downloaded, io.SeekStart); err != nil {
return fmt.Errorf("源文件定位失败: %v", err)
}
if _, err := destFile.Seek(rt.status.Downloaded, io.SeekStart); err != nil {
return fmt.Errorf("目标文件定位失败: %v", err)
}
// 6. 初始化哈希计算器
if rt.hash == nil {
rt.hash = rt.newHash()
}
// 7. 开始分片传输
buf := make([]byte, rt.ChunkSize)
startTime := time.Now()
for rt.status.Downloaded < rt.status.TotalSize {
// 计算当前分片大小
remaining := rt.status.TotalSize - rt.status.Downloaded
chunkSize := min(rt.ChunkSize, remaining)
// 读取分片数据
n, err := io.ReadFull(srcFile, buf[:chunkSize])
if err != nil && err != io.EOF {
return fmt.Errorf("读取分片失败: %v", err)
}
// 写入目标文件
if _, err := destFile.Write(buf[:n]); err != nil {
return fmt.Errorf("写入分片失败: %v", err)
}
// 更新哈希计算
if _, err := rt.hash.Write(buf[:n]); err != nil {
return fmt.Errorf("更新哈希失败: %v", err)
}
// 更新已传输大小
rt.status.Downloaded += int64(n)
rt.status.HashSize += int64(n)
// 保存状态(包括哈希状态)
if err := rt.saveStatus(); err != nil {
return fmt.Errorf("保存状态失败: %v", err)
}
// 确保数据写入磁盘
if err := destFile.Sync(); err != nil {
return fmt.Errorf("同步文件失败: %v", err)
}
// 打印进度
progress := float64(rt.status.Downloaded) / float64(rt.status.TotalSize) * 100
speed := float64(rt.status.Downloaded) / time.Since(startTime).Seconds() / 1024 / 1024
fmt.Printf("\r进度: %.2f%%, 速度: %.2f MB/s", progress, speed)
}
// 8. 完成传输
_ = destFile.Close()
if err := rt.finishTransfer(); err != nil {
return fmt.Errorf("完成传输失败: %v", err)
}
fmt.Printf("\n传输完成! 总耗时: %v\n", time.Since(startTime))
return nil
}
// initWorkspace 初始化工作区
func (rt *ResumeTransfer) initWorkspace() error {
// 创建临时目录
if rt.TempDir == "" {
tempDir, err := os.MkdirTemp("", "transfer-*")
if err != nil {
return err
}
rt.TempDir = tempDir
} else {
if err := os.MkdirAll(rt.TempDir, 0755); err != nil {
return err
}
}
// 设置状态文件和临时文件路径
rt.statusFile = filepath.Join(rt.TempDir, "status.json")
tempFileName := filepath.Base(rt.DestPath) + ".tmp"
rt.tempFilePath = filepath.Join(rt.TempDir, tempFileName)
rt.status.TempFileName = tempFileName // 只存储文件名
return nil
}
// tryResume 尝试恢复传输
func (rt *ResumeTransfer) tryResume() error {
// 检查状态文件是否存在
if _, err := os.Stat(rt.statusFile); os.IsNotExist(err) {
return nil // 全新传输
}
// 加载状态
data, err := os.ReadFile(rt.statusFile)
if err != nil {
return err
}
if err := json.Unmarshal(data, &rt.status); err != nil {
return err
}
// 重建临时文件路径
rt.tempFilePath = filepath.Join(rt.TempDir, rt.status.TempFileName)
// 验证临时文件
fileInfo, err := os.Stat(rt.tempFilePath)
if os.IsNotExist(err) {
return errors.New("临时文件不存在")
} else if err != nil {
return err
}
// 验证文件大小
if fileInfo.Size() != rt.status.Downloaded {
return fmt.Errorf("文件大小不匹配: 实际 %d, 记录 %d", fileInfo.Size(), rt.status.Downloaded)
}
// 验证哈希状态
if rt.status.HashSize != rt.status.Downloaded {
return fmt.Errorf("哈希状态不匹配: 哈希计算量 %d, 已传输量 %d",
rt.status.HashSize, rt.status.Downloaded)
}
// 恢复哈希计算器状态
if len(rt.status.HashState) > 0 {
rt.hash = rt.newHash()
if unmarshaler, ok := rt.hash.(encoding.BinaryUnmarshaler); ok {
if err := unmarshaler.UnmarshalBinary(rt.status.HashState); err != nil {
return fmt.Errorf("恢复哈希状态失败: %v", err)
}
}
} else {
rt.hash = rt.newHash()
}
return nil
}
// saveStatus 保存传输状态
func (rt *ResumeTransfer) saveStatus() error {
// 备份原状态
oldState := rt.status.HashState
// 保存哈希计算器状态
var err error
if rt.hash != nil {
if marshaler, ok := rt.hash.(encoding.BinaryMarshaler); ok {
rt.status.HashState, err = marshaler.MarshalBinary()
if err != nil {
rt.status.HashState = oldState
return err
}
}
}
data, err := json.Marshal(rt.status)
if err != nil {
rt.status.HashState = oldState
return err
}
if err := os.WriteFile(rt.statusFile, data, 0644); err != nil {
rt.status.HashState = oldState
return err
}
return nil
}
// finishTransfer 完成传输
func (rt *ResumeTransfer) finishTransfer() error {
// 计算最终校验和
rt.status.Checksum = hex.EncodeToString(rt.hash.Sum(nil))
// 保存最终状态
if err := rt.saveStatus(); err != nil {
return err
}
// 重命名临时文件
if err := os.Rename(rt.tempFilePath, rt.DestPath); err != nil {
return err
}
// 验证目标文件
if err := rt.verifyFile(); err != nil {
return fmt.Errorf("文件验证失败: %v", err)
}
// 清理状态文件
_ = os.Remove(rt.statusFile)
return nil
}
// verifyFile 验证文件完整性
func (rt *ResumeTransfer) verifyFile() error {
file, err := os.Open(rt.DestPath)
if err != nil {
return err
}
defer file.Close()
hash := rt.newHash()
if _, err := io.Copy(hash, file); err != nil {
return err
}
actualChecksum := hex.EncodeToString(hash.Sum(nil))
if rt.status.Checksum != "" && actualChecksum != rt.status.Checksum {
return fmt.Errorf("校验和不匹配: 期望 %s, 实际 %s", rt.status.Checksum, actualChecksum)
}
return nil
}
// cleanup 清理临时文件
func (rt *ResumeTransfer) cleanup() {
_ = os.Remove(rt.statusFile)
_ = os.Remove(rt.tempFilePath)
}
func min(a, b int64) int64 {
if a < b {
return a
}
return b
}
func main() {
if len(os.Args) < 3 {
fmt.Println("用法: ./resume <源文件> <目标文件> [哈希算法]")
fmt.Println("支持的哈希算法: md5 (默认), sha1, sha256")
return
}
transfer := NewResumeTransfer(os.Args[1], os.Args[2])
// 设置哈希算法(如果指定)
if len(os.Args) > 3 {
if err := transfer.SetHashAlgorithm(os.Args[3]); err != nil {
log.Fatalf("错误: %v", err)
}
}
if err := transfer.Start(); err != nil {
log.Fatalf("传输失败: %v", err)
}
}