三十的博客

Go 断点续传指南

发布时间
阅读量 加载中...

什么是断点续传

断点续传(Resumable Transfer)是一种在网络传输中断后,能够从上次中断的位置继续传输的技术。这项技术广泛应用于大文件传输、下载管理器和云存储服务中。

断点续传实现原理

断点续传的实现基于以下几个关键技术点:

  1. 文件分块:将大文件分割为多个小块(chunks)
  2. 进度记录:记录已成功传输的块信息
  3. 校验机制:确保传输数据的完整性
  4. 续传请求:从最后成功的位置重新请求数据

实战讲解

案例是基于将大文件移动到别的位置,中途可以通过命令行停止来模拟断点行为。

案例可能不够完美,但是可以帮助您理解断点续传的实现原理和代码设计思想。

请确保您有 Go 的基础文件目录操作 知识

简易版本

go
package main

import (
	"fmt"
	"io"
	"os"
	"strconv"
)

const chunkSize = 1024 // 1KB 的块大小

func main() {
	srcFile := "src.md"        // 源文件
	dstFile := "copy.md"       // 目标文件
	offsetFile := "offset.txt" // 记录传输进度的文件

	// 获取源文件信息
	srcInfo, err := os.Stat(srcFile)
	if err != nil {
		fmt.Println("无法获取源文件信息:", err)
		return
	}
	srcSize := srcInfo.Size()

	// 打开源文件
	src, err := os.Open(srcFile)
	if err != nil {
		fmt.Println("无法打开源文件:", err)
		return
	}
	defer src.Close()

	// 初始化偏移量
	var offset int64 = 0

	// 检查是否有之前的传输进度
	if _, err := os.Stat(offsetFile); err == nil {
		// Deprecated: As of Go 1.16, this function simply calls [os.ReadFile].
		//data, err := ioutil.ReadFile(offsetFile)
		data, err := os.ReadFile(offsetFile)
		if err == nil {
			// 将读取到的文件内容转换为十进制数字
			offset, _ = strconv.ParseInt(string(data), 10, 64)
			fmt.Printf("从断点继续传输,已传输 %d 字节\n", offset)
		}
	}

	// 打开目标文件,如果存在则追加,否则创建
	dst, err := os.OpenFile(dstFile, os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		fmt.Println("无法打开目标文件:", err)
		return
	}
	defer dst.Close()

	// 定位到上次传输的位置
	_, err = src.Seek(offset, io.SeekStart)
	if err != nil {
		fmt.Println("定位文件位置错误:", err)
		return
	}
	_, err = dst.Seek(offset, io.SeekStart)
	if err != nil {
		fmt.Println("定位文件位置错误:", err)
		return
	}

	// 开始传输
	buf := make([]byte, chunkSize)
	for {
		n, err := src.Read(buf)
		if err != nil && err != io.EOF {
			fmt.Println("读取文件错误:", err)
			return
		}
		if n == 0 {
			break
		}

		_, err = dst.Write(buf[:n])
		if err != nil {
			fmt.Println("写入文件错误:", err)
			return
		}

		offset += int64(n)
		fmt.Printf("\r已传输 %d/%d 字节 (%.2f%%)", offset, srcSize, float64(offset)*100/float64(srcSize))

		// 更新传输进度
		// Deprecated: As of Go 1.16, this function simply calls [os.WriteFile].
		//err = ioutil.WriteFile(offsetFile, []byte(strconv.FormatInt(offset, 10)), 0644)

		err = os.WriteFile(offsetFile, []byte(strconv.FormatInt(offset, 10)), 0644)
		if err != nil {
			fmt.Println("保存传输进度错误:", err)
			return
		}
	}

	// 传输完成,删除进度文件
	_ = os.Remove(offsetFile)
	fmt.Println("文件传输完成")
}

进阶版本

go
package main

import (
	"crypto/md5"
	"crypto/sha1"
	"crypto/sha256"
	"encoding"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"hash"
	"io"
	"log"
	"os"
	"path/filepath"
	"time"
)

// PersistedState 需要持久化的传输状态
type PersistedState struct {
	Downloaded    int64  `json:"downloaded"`    // 已传输字节数
	TotalSize     int64  `json:"totalSize"`     // 文件总大小
	TempFileName  string `json:"tempFileName"`  // 临时文件名(不含路径)
	Checksum      string `json:"checksum"`      // 最终校验和
	HashState     []byte `json:"hashState"`     // 哈希计算器状态
	HashSize      int64  `json:"hashSize"`      // 已计算哈希的数据量
	HashAlgorithm string `json:"hashAlgorithm"` // 使用的哈希算法
}

// ResumeTransfer 断点续传结构体
type ResumeTransfer struct {
	// 配置参数
	SourcePath    string // 源文件路径
	DestPath      string // 目标文件路径
	TempDir       string // 临时目录
	ChunkSize     int64  // 分片大小(字节)
	HashAlgorithm string // 哈希算法(md5/sha1/sha256)

	// 运行时状态
	statusFile   string         // 状态文件完整路径
	tempFilePath string         // 临时文件完整路径
	status       PersistedState // 持久化状态
	hash         hash.Hash      // 哈希计算器
}

// NewResumeTransfer 创建新的断点续传实例
func NewResumeTransfer(source, dest string) *ResumeTransfer {
	return &ResumeTransfer{
		SourcePath:    source,
		TempDir:       ".",
		DestPath:      dest,
		ChunkSize:     1 * 1024 * 1024, // 默认1MB
		HashAlgorithm: "md5",           // 默认使用MD5
	}
}

// SetHashAlgorithm 设置哈希算法
func (rt *ResumeTransfer) SetHashAlgorithm(algorithm string) error {
	switch algorithm {
	case "md5", "sha1", "sha256":
		rt.HashAlgorithm = algorithm
		return nil
	default:
		return fmt.Errorf("不支持的哈希算法: %s", algorithm)
	}
}

// newHash 创建新的哈希计算器
func (rt *ResumeTransfer) newHash() hash.Hash {
	switch rt.HashAlgorithm {
	case "sha1":
		return sha1.New()
	case "sha256":
		return sha256.New()
	default:
		return md5.New()
	}
}

// Start 开始或恢复传输
func (rt *ResumeTransfer) Start() error {
	// 1. 初始化工作区
	if err := rt.initWorkspace(); err != nil {
		return fmt.Errorf("初始化工作区失败: %v", err)
	}

	// 2. 获取源文件信息
	srcInfo, err := os.Stat(rt.SourcePath)
	if err != nil {
		return fmt.Errorf("获取源文件信息失败: %v", err)
	}
	rt.status.TotalSize = srcInfo.Size()
	rt.status.HashAlgorithm = rt.HashAlgorithm

	// 3. 尝试恢复传输
	if err := rt.tryResume(); err != nil {
		log.Printf("恢复失败: %v, 将重新开始传输", err)
		rt.cleanup() // 清理可能损坏的状态
		if err := rt.initWorkspace(); err != nil {
			return err
		}
		rt.status.Downloaded = 0
		rt.hash = rt.newHash() // 重置哈希计算器
	}

	// 4. 打开源文件和目标文件
	srcFile, err := os.Open(rt.SourcePath)
	if err != nil {
		return fmt.Errorf("打开源文件失败: %v", err)
	}
	defer srcFile.Close()

	destFile, err := os.OpenFile(rt.tempFilePath, os.O_WRONLY|os.O_CREATE, 0644)
	if err != nil {
		return fmt.Errorf("打开目标文件失败: %v", err)
	}

	// 5. 定位到上次传输的位置
	if _, err := srcFile.Seek(rt.status.Downloaded, io.SeekStart); err != nil {
		return fmt.Errorf("源文件定位失败: %v", err)
	}

	if _, err := destFile.Seek(rt.status.Downloaded, io.SeekStart); err != nil {
		return fmt.Errorf("目标文件定位失败: %v", err)
	}

	// 6. 初始化哈希计算器
	if rt.hash == nil {
		rt.hash = rt.newHash()
	}

	// 7. 开始分片传输
	buf := make([]byte, rt.ChunkSize)
	startTime := time.Now()

	for rt.status.Downloaded < rt.status.TotalSize {
		// 计算当前分片大小
		remaining := rt.status.TotalSize - rt.status.Downloaded
		chunkSize := min(rt.ChunkSize, remaining)

		// 读取分片数据
		n, err := io.ReadFull(srcFile, buf[:chunkSize])
		if err != nil && err != io.EOF {
			return fmt.Errorf("读取分片失败: %v", err)
		}

		// 写入目标文件
		if _, err := destFile.Write(buf[:n]); err != nil {
			return fmt.Errorf("写入分片失败: %v", err)
		}

		// 更新哈希计算
		if _, err := rt.hash.Write(buf[:n]); err != nil {
			return fmt.Errorf("更新哈希失败: %v", err)
		}

		// 更新已传输大小
		rt.status.Downloaded += int64(n)
		rt.status.HashSize += int64(n)

		// 保存状态(包括哈希状态)
		if err := rt.saveStatus(); err != nil {
			return fmt.Errorf("保存状态失败: %v", err)
		}

		// 确保数据写入磁盘
		if err := destFile.Sync(); err != nil {
			return fmt.Errorf("同步文件失败: %v", err)
		}

		// 打印进度
		progress := float64(rt.status.Downloaded) / float64(rt.status.TotalSize) * 100
		speed := float64(rt.status.Downloaded) / time.Since(startTime).Seconds() / 1024 / 1024
		fmt.Printf("\r进度: %.2f%%, 速度: %.2f MB/s", progress, speed)
	}

	// 8. 完成传输
	_ = destFile.Close()
	if err := rt.finishTransfer(); err != nil {
		return fmt.Errorf("完成传输失败: %v", err)
	}

	fmt.Printf("\n传输完成! 总耗时: %v\n", time.Since(startTime))
	return nil
}

// initWorkspace 初始化工作区
func (rt *ResumeTransfer) initWorkspace() error {
	// 创建临时目录
	if rt.TempDir == "" {
		tempDir, err := os.MkdirTemp("", "transfer-*")
		if err != nil {
			return err
		}
		rt.TempDir = tempDir
	} else {
		if err := os.MkdirAll(rt.TempDir, 0755); err != nil {
			return err
		}
	}

	// 设置状态文件和临时文件路径
	rt.statusFile = filepath.Join(rt.TempDir, "status.json")
	tempFileName := filepath.Base(rt.DestPath) + ".tmp"
	rt.tempFilePath = filepath.Join(rt.TempDir, tempFileName)
	rt.status.TempFileName = tempFileName // 只存储文件名

	return nil
}

// tryResume 尝试恢复传输
func (rt *ResumeTransfer) tryResume() error {
	// 检查状态文件是否存在
	if _, err := os.Stat(rt.statusFile); os.IsNotExist(err) {
		return nil // 全新传输
	}

	// 加载状态
	data, err := os.ReadFile(rt.statusFile)
	if err != nil {
		return err
	}

	if err := json.Unmarshal(data, &rt.status); err != nil {
		return err
	}

	// 重建临时文件路径
	rt.tempFilePath = filepath.Join(rt.TempDir, rt.status.TempFileName)

	// 验证临时文件
	fileInfo, err := os.Stat(rt.tempFilePath)
	if os.IsNotExist(err) {
		return errors.New("临时文件不存在")
	} else if err != nil {
		return err
	}

	// 验证文件大小
	if fileInfo.Size() != rt.status.Downloaded {
		return fmt.Errorf("文件大小不匹配: 实际 %d, 记录 %d", fileInfo.Size(), rt.status.Downloaded)
	}

	// 验证哈希状态
	if rt.status.HashSize != rt.status.Downloaded {
		return fmt.Errorf("哈希状态不匹配: 哈希计算量 %d, 已传输量 %d",
			rt.status.HashSize, rt.status.Downloaded)
	}

	// 恢复哈希计算器状态
	if len(rt.status.HashState) > 0 {
		rt.hash = rt.newHash()
		if unmarshaler, ok := rt.hash.(encoding.BinaryUnmarshaler); ok {
			if err := unmarshaler.UnmarshalBinary(rt.status.HashState); err != nil {
				return fmt.Errorf("恢复哈希状态失败: %v", err)
			}
		}
	} else {
		rt.hash = rt.newHash()
	}

	return nil
}

// saveStatus 保存传输状态
func (rt *ResumeTransfer) saveStatus() error {
	// 备份原状态
	oldState := rt.status.HashState

	// 保存哈希计算器状态
	var err error
	if rt.hash != nil {
		if marshaler, ok := rt.hash.(encoding.BinaryMarshaler); ok {
			rt.status.HashState, err = marshaler.MarshalBinary()
			if err != nil {
				rt.status.HashState = oldState
				return err
			}
		}
	}

	data, err := json.Marshal(rt.status)
	if err != nil {
		rt.status.HashState = oldState
		return err
	}

	if err := os.WriteFile(rt.statusFile, data, 0644); err != nil {
		rt.status.HashState = oldState
		return err
	}

	return nil
}

// finishTransfer 完成传输
func (rt *ResumeTransfer) finishTransfer() error {
	// 计算最终校验和
	rt.status.Checksum = hex.EncodeToString(rt.hash.Sum(nil))

	// 保存最终状态
	if err := rt.saveStatus(); err != nil {
		return err
	}

	// 重命名临时文件
	if err := os.Rename(rt.tempFilePath, rt.DestPath); err != nil {
		return err
	}

	// 验证目标文件
	if err := rt.verifyFile(); err != nil {
		return fmt.Errorf("文件验证失败: %v", err)
	}

	// 清理状态文件
	_ = os.Remove(rt.statusFile)

	return nil
}

// verifyFile 验证文件完整性
func (rt *ResumeTransfer) verifyFile() error {
	file, err := os.Open(rt.DestPath)
	if err != nil {
		return err
	}
	defer file.Close()

	hash := rt.newHash()
	if _, err := io.Copy(hash, file); err != nil {
		return err
	}

	actualChecksum := hex.EncodeToString(hash.Sum(nil))
	if rt.status.Checksum != "" && actualChecksum != rt.status.Checksum {
		return fmt.Errorf("校验和不匹配: 期望 %s, 实际 %s", rt.status.Checksum, actualChecksum)
	}

	return nil
}

// cleanup 清理临时文件
func (rt *ResumeTransfer) cleanup() {
	_ = os.Remove(rt.statusFile)
	_ = os.Remove(rt.tempFilePath)
}

func min(a, b int64) int64 {
	if a < b {
		return a
	}
	return b
}

func main() {
	if len(os.Args) < 3 {
		fmt.Println("用法: ./resume <源文件> <目标文件> [哈希算法]")
		fmt.Println("支持的哈希算法: md5 (默认), sha1, sha256")
		return
	}

	transfer := NewResumeTransfer(os.Args[1], os.Args[2])

	// 设置哈希算法(如果指定)
	if len(os.Args) > 3 {
		if err := transfer.SetHashAlgorithm(os.Args[3]); err != nil {
			log.Fatalf("错误: %v", err)
		}
	}

	if err := transfer.Start(); err != nil {
		log.Fatalf("传输失败: %v", err)
	}
}
#文件目录操作 #Golang