三十的博客

Go 简单网络请求封装

发布时间
阅读量 加载中...

背景

在实践 Go 语言开发的过程中，我曾构建过一个天气通知公众号项目。该项目需要集成第三方天气 API，因此我封装了一个具备重试机制和错误处理的 HTTP 客户端。尽管我对 Go 的掌握仍在持续精进，但经 AI 辅助优化后的实现方案已具备较好的健壮性。现将此方案整理成文，既是为了沉淀技术方案、便于复用，也为后续将其封装为独立库并开源至 GitHub 做准备。

功能介绍

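该 HTTP 客户端封装了以下功能：

- 统一的 `Service` 接口，提供 `Do`、`Get`、`Post` 等方法；
- 可配置的请求超时，以及基于 `BaseURL` 的相对路径拼接；
- 带指数退避与随机抖动的失败重试；
- 请求前/后拦截器（如日志记录、耗时统计）；
- 请求体自动 JSON 序列化、查询参数拼接，以及响应体 JSON 解析；
- 统一的自定义错误定义。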

目录结构

httputil/
├── client.go
├── client_test.go
├── config.go
├── errors.go
├── interceptor/
│ ├── logging.go
│ ├── metrics.go
│ └── test_log.go
├── request.go
├── response.go
├── retry.go
└── types.go

代码实现

1. 定义对外接口

go
// httputil/types.go

package httputil

import "context"

type (
	// Service is the HTTP client API exposed to callers of this package.
	Service interface {
		// Do sends req and returns the wrapped response.
		Do(ctx context.Context, req *Request) (*Response, error)
		// Get issues a GET request to url after applying opts to it.
		Get(ctx context.Context, url string, opts ...RequestOption) (*Response, error)
		// Post issues a POST request to url carrying body after applying opts to it.
		Post(ctx context.Context, url string, body any, opts ...RequestOption) (*Response, error)
	}

	// ConfigurableService is a Service whose interceptor chain can be extended
	// after construction.
	ConfigurableService interface {
		Service
		AddInterceptor(i Interceptor)
	}

	// Interceptor hooks into the lifecycle of every request.
	//
	// Before runs before the request is sent; returning a non-nil error aborts
	// the request. After runs once the request has finished and receives the
	// final response and error (either of which may be nil).
	Interceptor interface {
		Before(ctx context.Context, req *Request) error
		After(ctx context.Context, req *Request, resp *Response, err error)
	}

	// RequestOption customises a Request before it is sent (e.g. WithHeader).
	RequestOption func(*Request)
)

2. 基本配置结构体

go
// httputil/config.go

package httputil

import "time"

// Config holds the client-wide settings consumed by New.
type Config struct {
	// Timeout bounds each request; a zero value disables the timeout.
	Timeout time.Duration
	// MaxRetries is the number of extra attempts after the first one;
	// zero or negative disables retrying.
	MaxRetries int
	// RetryWaitTime is the base delay for the exponential backoff between retries.
	RetryWaitTime time.Duration
	// BaseURL, when non-empty, is prefixed to request URLs that are not absolute.
	BaseURL string
}

// DefaultConfig returns a fresh Config with a 10s timeout, 3 retries and a
// 1s base retry delay. BaseURL is left empty.
func DefaultConfig() *Config {
	cfg := new(Config)
	cfg.Timeout = 10 * time.Second
	cfg.MaxRetries = 3
	cfg.RetryWaitTime = time.Second
	return cfg
}

3. 定义客户端

go
// httputil/client.go

package httputil

import (
	"context"
	"fmt"
	"net/http"
	"strings"
	"sync"
	"time"
)

// httpServiceImpl is the unexported implementation returned by New; callers
// only see it through the ConfigurableService interface.
type httpServiceImpl struct {
	// client performs the actual HTTP round trips.
	client *http.Client
	// config is the configuration handed to New (or DefaultConfig()).
	config *Config
	// mu guards interceptors: AddInterceptor appends under the write lock and
	// Do copies the slice under the read lock.
	mu sync.RWMutex
	// interceptors is the chain whose Before/After hooks wrap every request.
	interceptors []Interceptor
	// defaultHeaders are copied into a request unless it already sets the same
	// key. NOTE(review): nothing in this listing populates it.
	defaultHeaders map[string]string
}

// New builds a ConfigurableService from cfg, falling back to DefaultConfig()
// when cfg is nil.
//
// The transport is a clone of http.DefaultTransport, so environment proxy
// settings (HTTP_PROXY/HTTPS_PROXY/NO_PROXY), the dial timeout and the TLS
// handshake timeout are kept. Only the connection-pool limits are tuned.
func New(cfg *Config) ConfigurableService {
	if cfg == nil {
		cfg = DefaultConfig()
	}

	return &httpServiceImpl{
		client: &http.Client{
			Timeout:   cfg.Timeout,
			Transport: newPooledTransport(),
		},
		config: cfg,
	}
}

// newPooledTransport returns a transport derived from http.DefaultTransport
// with the pool limits this client uses.
func newPooledTransport() *http.Transport {
	var t *http.Transport
	if dt, ok := http.DefaultTransport.(*http.Transport); ok {
		t = dt.Clone()
	} else {
		// DefaultTransport was replaced by a custom RoundTripper; keep at least
		// environment-based proxy support.
		t = &http.Transport{Proxy: http.ProxyFromEnvironment}
	}
	t.MaxIdleConns = 100                 // max idle connections across all hosts
	t.MaxIdleConnsPerHost = 10           // max idle connections per host
	t.IdleConnTimeout = 30 * time.Second // how long an idle connection is kept
	return t
}

// AddInterceptor appends i to the interceptor chain. It takes the write lock,
// so it may be called while requests are in flight.
//
// A nil interceptor is ignored: registering it would only cause a nil-pointer
// panic later, when a request invokes its hooks.
func (s *httpServiceImpl) AddInterceptor(i Interceptor) {
	if i == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.interceptors = append(s.interceptors, i)
}

// Get builds a GET request for url, lets every option customise it, and
// delegates to Do.
func (s *httpServiceImpl) Get(ctx context.Context, url string, opts ...RequestOption) (*Response, error) {
	r := NewRequest(http.MethodGet, url)
	for idx := range opts {
		opts[idx](r)
	}
	return s.Do(ctx, r)
}

// Post builds a POST request for url carrying body, lets every option
// customise it, and delegates to Do.
func (s *httpServiceImpl) Post(ctx context.Context, url string, body interface{}, opts ...RequestOption) (*Response, error) {
	r := NewRequest(http.MethodPost, url)
	r.SetBody(body)
	for idx := range opts {
		opts[idx](r)
	}
	return s.Do(ctx, r)
}

// Put builds a PUT request for url carrying body, lets every option customise
// it, and delegates to Do.
//
// NOTE(review): Put is not declared on Service/ConfigurableService, so callers
// holding the value returned by New can only reach it via a type assertion.
func (s *httpServiceImpl) Put(ctx context.Context, url string, body interface{}, opts ...RequestOption) (*Response, error) {
	r := NewRequest(http.MethodPut, url)
	r.SetBody(body)
	for idx := range opts {
		opts[idx](r)
	}
	return s.Do(ctx, r)
}

// Delete builds a DELETE request for url, lets every option customise it, and
// delegates to Do.
//
// NOTE(review): Delete is not declared on Service/ConfigurableService, so
// callers holding the value returned by New can only reach it via a type
// assertion.
func (s *httpServiceImpl) Delete(ctx context.Context, url string, opts ...RequestOption) (*Response, error) {
	r := NewRequest(http.MethodDelete, url)
	for idx := range opts {
		opts[idx](r)
	}
	return s.Do(ctx, r)
}

// Do executes req; it is the core method the shortcut helpers delegate to.
//
// Steps:
//  1. Resolve req.URL against Config.BaseURL when req.URL is not an absolute
//     http(s) URL.
//  2. Bound the whole call (every attempt and every retry wait) by
//     Config.Timeout when it is positive.
//  3. Copy default headers the request does not already set.
//  4. Run every interceptor's Before; the first error aborts the call.
//  5. Send the request, retrying via retryOperation when Config.MaxRetries > 0.
//  6. Run every interceptor's After with the final response and error.
//
// A nil req yields ErrInvalidRequest.
func (s *httpServiceImpl) Do(ctx context.Context, req *Request) (*Response, error) {
	if req == nil {
		return nil, ErrInvalidRequest
	}

	// Match the full scheme so a relative path such as "httpbin/get" is not
	// mistaken for an absolute URL.
	if s.config.BaseURL != "" && !isAbsoluteURL(req.URL) {
		req.URL = strings.TrimRight(s.config.BaseURL, "/") + "/" + strings.TrimLeft(req.URL, "/")
	}

	// This deadline spans all attempts and retry waits; the http.Client
	// timeout additionally bounds each individual attempt.
	if s.config.Timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, s.config.Timeout)
		defer cancel()
	}

	// A Request built as a struct literal may have a nil Headers map.
	if len(s.defaultHeaders) > 0 && req.Headers == nil {
		req.Headers = make(map[string]string, len(s.defaultHeaders))
	}
	for k, v := range s.defaultHeaders {
		if _, ok := req.Headers[k]; !ok {
			req.Headers[k] = v
		}
	}

	// Snapshot the chain so AddInterceptor can run concurrently with requests.
	s.mu.RLock()
	interceptors := make([]Interceptor, len(s.interceptors))
	copy(interceptors, s.interceptors)
	s.mu.RUnlock()

	for _, ic := range interceptors {
		if err := ic.Before(ctx, req); err != nil {
			// After still runs for every interceptor so each can observe the failure.
			for _, each := range interceptors {
				each.After(ctx, req, nil, err)
			}
			return nil, err
		}
	}

	var (
		finalResp *Response
		finalErr  error
	)

	// The closure reads finalResp/finalErr at return time, so After sees the
	// final outcome.
	defer func() {
		for _, ic := range interceptors {
			ic.After(ctx, req, finalResp, finalErr)
		}
	}()

	if s.config.MaxRetries <= 0 {
		finalResp, finalErr = s.doRequest(ctx, req)
	} else {
		finalResp, finalErr = s.retryOperation(ctx, req, s.doRequest)
	}
	return finalResp, finalErr
}

// isAbsoluteURL reports whether raw starts with "http://" or "https://",
// ignoring case.
func isAbsoluteURL(raw string) bool {
	for _, scheme := range [...]string{"http://", "https://"} {
		if len(raw) >= len(scheme) && strings.EqualFold(raw[:len(scheme)], scheme) {
			return true
		}
	}
	return false
}

// doRequest performs a single HTTP round trip for req under ctx.
//
// Any status code >= 400 is reported as an error, but the wrapped response is
// still returned next to it so callers (and the retry logic) can inspect it.
func (s *httpServiceImpl) doRequest(ctx context.Context, req *Request) (*Response, error) {
	httpReq, buildErr := req.ToHTTPRequest(ctx)
	if buildErr != nil {
		return nil, buildErr
	}

	raw, sendErr := s.client.Do(httpReq)
	if sendErr != nil {
		return nil, sendErr
	}

	wrapped, readErr := NewResponse(raw)
	switch {
	case readErr != nil:
		return nil, readErr
	case wrapped.StatusCode >= http.StatusBadRequest:
		return wrapped, fmt.Errorf("请求失败,状态码: %d", wrapped.StatusCode)
	default:
		return wrapped, nil
	}
}

4. 请求封装

go
// httputil/request.go

package httputil

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"net/http"
	"strings"
)

// WithHeader returns a RequestOption that sets header key to value on the
// request, replacing any value already stored under the same key.
func WithHeader(key, value string) RequestOption {
	setter := func(r *Request) {
		r.SetHeader(key, value)
	}
	return setter
}

// Request is a transport-agnostic description of an outgoing HTTP call.
// ToHTTPRequest turns it into a *http.Request for each attempt.
type Request struct {
	Method  string            // HTTP method (upper-cased by NewRequest)
	URL     string            // absolute URL, or a path to resolve against a base URL
	Headers map[string]string // request headers, one value per key
	Query   map[string]string // query parameters merged into URL
	Body    interface{}       // payload; see ToHTTPRequest for supported types
}

// NewRequest creates a Request for method and path with empty header and
// query maps. method is upper-cased, so "get" and "GET" are equivalent.
func NewRequest(method, path string) *Request {
	return &Request{
		Method:  strings.ToUpper(method),
		URL:     path,
		Headers: make(map[string]string),
		Query:   make(map[string]string),
	}
}

// SetHeader sets header key to value, overwriting any previous value. It
// initialises Headers when the Request was not built with NewRequest.
func (r *Request) SetHeader(key, value string) {
	if r.Headers == nil {
		r.Headers = make(map[string]string)
	}
	r.Headers[key] = value
}

// SetQueryParam sets query parameter key to value, overwriting any previous
// value. It initialises Query when the Request was not built with NewRequest.
func (r *Request) SetQueryParam(key, value string) {
	if r.Query == nil {
		r.Query = make(map[string]string)
	}
	r.Query[key] = value
}

// SetBody stores body as the payload and returns r for chaining.
func (r *Request) SetBody(body interface{}) *Request {
	r.Body = body
	return r
}

// ToHTTPRequest converts r into a *http.Request bound to ctx.
//
// Query parameters are URL-encoded and merged into any query string already
// present in r.URL; on a key clash the value from r.Query wins.
//
// Body handling:
//   - []byte and string are sent as-is;
//   - an io.Reader is read fully once and cached back into r.Body as []byte,
//     so a retried attempt resends the same payload instead of a drained reader;
//   - any other value is JSON-encoded, and Content-Type defaults to
//     application/json unless the caller already set one.
func (r *Request) ToHTTPRequest(ctx context.Context) (*http.Request, error) {
	body, isJSON, err := r.bodyReader()
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, r.Method, r.URL, body)
	if err != nil {
		return nil, err
	}

	if len(r.Query) > 0 {
		q := req.URL.Query()
		for k, v := range r.Query {
			q.Set(k, v)
		}
		req.URL.RawQuery = q.Encode()
	}

	for k, v := range r.Headers {
		req.Header.Set(k, v)
	}
	// Header.Get canonicalises the key, so a caller-set "content-type" counts too.
	if isJSON && req.Header.Get("Content-Type") == "" {
		req.Header.Set("Content-Type", "application/json")
	}

	return req, nil
}

// bodyReader returns a fresh reader over r.Body for one attempt and reports
// whether the payload was JSON-encoded.
func (r *Request) bodyReader() (io.Reader, bool, error) {
	switch v := r.Body.(type) {
	case nil:
		return nil, false, nil
	case []byte:
		return bytes.NewReader(v), false, nil
	case string:
		return strings.NewReader(v), false, nil
	case io.Reader:
		data, err := io.ReadAll(v)
		if c, ok := v.(io.Closer); ok {
			// The transport would have closed a ReadCloser body after sending it;
			// since we consume it ourselves, close it here. Best-effort.
			_ = c.Close()
		}
		if err != nil {
			return nil, false, err
		}
		r.Body = data
		return bytes.NewReader(data), false, nil
	default:
		data, err := json.Marshal(v)
		if err != nil {
			return nil, false, err
		}
		return bytes.NewReader(data), true, nil
	}
}

5. 重试机制

go
// httputil/retry.go

package httputil

import (
	"context"
	"errors"
	"log"
	"math"
	"math/rand/v2"
	"net"
	"net/http"
	"strings"
	"time"
)

// shouldRetry decides whether a failed attempt should be retried.
//
// Rules, in order:
//   - explicit cancellation (context.Canceled) is never retried;
//   - only idempotent methods (GET, HEAD, OPTIONS, TRACE, PUT, DELETE) are
//     retried, so a non-idempotent POST/PATCH is never sent twice;
//   - network-level failures (see isNetworkError) are retried;
//   - otherwise only transient statuses are retried: 408, 429, 500, 502, 503,
//     504.
//
// A failure with no response that is not a network error (for example a JSON
// encoding or URL parsing error) is not retried, because it would fail the
// same way again. Client errors such as 404 and non-2xx successes such as 3xx
// are not retried either.
func shouldRetry(err error, method string, resp *Response) bool {
	if errors.Is(err, context.Canceled) {
		return false
	}
	if !isIdempotentMethod(method) {
		return false
	}
	if isNetworkError(err) {
		return true
	}
	if resp == nil {
		return false
	}

	switch resp.StatusCode {
	case http.StatusRequestTimeout,
		http.StatusTooManyRequests,
		http.StatusInternalServerError,
		http.StatusBadGateway,
		http.StatusServiceUnavailable,
		http.StatusGatewayTimeout:
		return true
	default:
		return false
	}
}

// isIdempotentMethod reports whether method is idempotent as defined by
// RFC 9110 §9.2.2. The comparison is case-insensitive.
func isIdempotentMethod(method string) bool {
	switch strings.ToUpper(method) {
	case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace,
		http.MethodPut, http.MethodDelete:
		return true
	default:
		return false
	}
}

// retryOperation runs operation and retries it while shouldRetry allows. It
// makes at most max(Config.MaxRetries, 0)+1 attempts and waits
// calculateRetryDelay between them.
//
// It returns as soon as an attempt succeeds with a 2xx response, a failure is
// judged non-retryable, or the attempts are exhausted. In the last two cases
// the final response and error are returned unchanged.
//
// If ctx ends before or between attempts, the returned error satisfies
// errors.Is for both ErrContextCanceled and the context's own error
// (context.Canceled or context.DeadlineExceeded). The last response received,
// if any, is returned alongside it.
func (s *httpServiceImpl) retryOperation(
	ctx context.Context,
	req *Request,
	operation func(context.Context, *Request) (*Response, error),
) (*Response, error) {
	var (
		lastResp *Response
		lastErr  error
	)

	// Clamp so a negative setting still yields exactly one attempt instead of
	// zero attempts and a (nil, nil) result.
	maxRetries := max(s.config.MaxRetries, 0)

	for attempt := 0; attempt <= maxRetries; attempt++ {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return lastResp, errors.Join(ErrContextCanceled, ctxErr)
		}

		lastResp, lastErr = operation(ctx, req)
		if lastErr == nil && lastResp != nil && lastResp.IsSuccess() {
			return lastResp, nil
		}

		if attempt >= maxRetries || !shouldRetry(lastErr, req.Method, lastResp) {
			return lastResp, lastErr
		}

		// Stop the timer on cancellation so it is released promptly.
		timer := time.NewTimer(s.calculateRetryDelay(attempt))
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
			return lastResp, errors.Join(ErrContextCanceled, ctx.Err())
		}
	}

	return lastResp, lastErr
}

// calculateRetryDelay returns how long to wait before the next retry. attempt
// is the zero-based index of the attempt that just failed, so the returned
// wait precedes retry number attempt+1.
//
// The delay is RetryWaitTime * 2^attempt, plus up to 10% random jitter, capped
// at 5 seconds. The math is done in float64 and the cap is applied before
// converting to time.Duration. An out-of-range float-to-integer conversion is
// implementation-dependent in Go and can yield a negative duration, which
// would make the wait fire immediately.
func (s *httpServiceImpl) calculateRetryDelay(attempt int) time.Duration {
	const maxDelay = 5 * time.Second

	delay := float64(s.config.RetryWaitTime) * math.Pow(2, float64(attempt))
	delay *= 1 + 0.1*rand.Float64() // up to 10% jitter to avoid synchronized retries
	if delay > float64(maxDelay) {
		delay = float64(maxDelay)
	}
	wait := time.Duration(delay)

	// Log the capped wait with a one-based retry counter.
	log.Printf("请求重试中,第%d次重试,等待时间: %v", attempt+1, wait)

	return wait
}

// isNetworkError reports whether err looks like a transport-level failure.
//
// Any error in the chain that implements net.Error counts, including
// *net.DNSError, *net.OpError, *url.Error and errors wrapped with %w. As a
// last resort, messages mentioning "timeout" or "handshake" (e.g. TLS
// handshake failures) also count. A nil error is never a network error.
func isNetworkError(err error) bool {
	if err == nil {
		return false
	}

	var netErr net.Error
	if errors.As(err, &netErr) {
		return true
	}

	msg := err.Error()
	return strings.Contains(msg, "timeout") || strings.Contains(msg, "handshake")
}

6. 响应封装

go
// httputil/response.go

package httputil

import (
	"encoding/json"
	"io"
	"net/http"
)

// Response is a fully buffered HTTP response.
type Response struct {
	StatusCode int               // HTTP status code
	Headers    map[string]string // first value of each response header
	Body       []byte            // entire response body
	Raw        *http.Response    // underlying response; its Body is already read and closed
}

// NewResponse reads resp.Body to the end, closes it, and wraps the result.
//
// The body is closed on every path, including when reading fails. Previously
// a read error returned early and left the body open.
func NewResponse(resp *http.Response) (*Response, error) {
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	headers := make(map[string]string, len(resp.Header))
	for k, v := range resp.Header {
		if len(v) > 0 {
			headers[k] = v[0]
		}
	}

	return &Response{
		StatusCode: resp.StatusCode,
		Headers:    headers,
		Body:       body,
		Raw:        resp,
	}, nil
}

// JSON unmarshals the response body into v.
func (r *Response) JSON(v interface{}) error {
	return json.Unmarshal(r.Body, v)
}

// String returns the response body as a string.
func (r *Response) String() string {
	return string(r.Body)
}

// Bytes returns the response body. The slice is shared, not copied.
func (r *Response) Bytes() []byte {
	return r.Body
}

// DecodeJSON unmarshals the body into v. It returns ErrEmptyResponse when the
// response itself is nil or its body is empty, instead of dereferencing nil.
func (r *Response) DecodeJSON(v interface{}) error {
	if r != nil && len(r.Body) > 0 {
		return json.Unmarshal(r.Body, v)
	}
	return ErrEmptyResponse
}

// IsSuccess reports whether the status code lies in the 2xx range.
func (r *Response) IsSuccess() bool {
	// Integer division maps 200..299 (and only those) to 2.
	return r.StatusCode/100 == 2
}

7. 自定义错误

go
// httputil/errors.go

package httputil

import "errors"

// Sentinel errors for this package; compare against them with errors.Is.

// ErrMaxRetriesExceeded signals that retries were exhausted.
// NOTE(review): no code in this listing returns it.
var ErrMaxRetriesExceeded = errors.New("超出最大重试次数")

// ErrInvalidRequest signals an unusable request.
var ErrInvalidRequest = errors.New("无效请求")

// ErrContextCanceled is returned by the retry loop when the context ends
// before or between attempts.
var ErrContextCanceled = errors.New("上下文已取消")

// ErrEmptyResponse is returned by Response.DecodeJSON for a nil response or
// an empty body.
var ErrEmptyResponse = errors.New("空的请求结果")

拦截器示例

1. 日志拦截器

go
// interceptor/logging.go

package interceptor

import (
	"context"
	"log"

	"gitee.com/iswleii/wechat-weather/pkg/httputil"
	"go.uber.org/zap"
)

// LoggingInterceptor logs each outgoing request and the outcome it produced.
type LoggingInterceptor struct{}

// NewLogging returns a LoggingInterceptor ready to be registered on a client.
func NewLogging() httputil.Interceptor {
	var li LoggingInterceptor
	return &li
}

// Before prints the method and URL of the request about to be sent; it never
// aborts the request.
func (l *LoggingInterceptor) Before(ctx context.Context, req *httputil.Request) error {
	method, target := req.Method, req.URL
	log.Printf("Request: %s %s", method, target)
	return nil
}

// After logs the response through the global zap logger. On failure it logs
// the request, response and error at error level; otherwise it logs the
// response at info level.
func (l *LoggingInterceptor) After(ctx context.Context, req *httputil.Request, resp *httputil.Response, err error) {
	logger := zap.L()
	if err == nil {
		logger.Info("api query response data", zap.Any("data", resp))
		return
	}
	logger.Error("api query err", zap.Any("query", req), zap.Any("resp", resp), zap.Error(err))
}

2. 耗时拦截器

go
// interceptor/metrics.go

package interceptor

import (
	"context"
	"log"
	"sync"
	"time"

	"gitee.com/iswleii/wechat-weather/pkg/httputil"
)

// MetricsInterceptor measures and logs how long each request takes.
//
// Start times are tracked per *httputil.Request under a mutex, so one instance
// can be shared by concurrent requests without their timings overwriting each
// other. The zero value is ready to use.
type MetricsInterceptor struct {
	mu     sync.Mutex
	starts map[*httputil.Request]time.Time
}

// Before records the moment req is about to be sent.
func (m *MetricsInterceptor) Before(ctx context.Context, req *httputil.Request) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.starts == nil {
		m.starts = make(map[*httputil.Request]time.Time)
	}
	m.starts[req] = time.Now()
	return nil
}

// After logs the time elapsed since Before for req and discards the recorded
// start. If no start was recorded (e.g. an earlier interceptor's Before failed
// and this Before never ran), the duration is reported as 0.
func (m *MetricsInterceptor) After(ctx context.Context, req *httputil.Request, resp *httputil.Response, err error) {
	m.mu.Lock()
	start, ok := m.starts[req]
	delete(m.starts, req)
	m.mu.Unlock()

	var duration time.Duration
	if ok {
		duration = time.Since(start)
	}
	// This is the place to report the metric to a monitoring system.
	log.Printf("[METRIC] %s %s 耗时: %v", req.Method, req.URL, duration)
}

测试用例

go
package httputil_test

import (
	"context"
	"encoding/json"
	"net"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"gitee.com/iswleii/wechat-weather/pkg/httputil"
	"gitee.com/iswleii/wechat-weather/pkg/httputil/interceptor"
)

// testData is the JSON payload exchanged between the tests and their stub
// servers.
type testData struct {
	// Name is serialised as "name".
	Name string `json:"name"`
	// Age is serialised as "age".
	Age int `json:"age"`
}

// TestGetRequest checks that Get sends a GET request and that the JSON reply
// decodes into testData.
func TestGetRequest(t *testing.T) {
	want := testData{Name: "Test", Age: 30}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if got := r.Method; got != http.MethodGet {
			t.Errorf("Expected GET request, got %s", got)
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(want)
	}))
	defer server.Close()

	cfg := &httputil.Config{Timeout: 10 * time.Second, MaxRetries: 2}
	resp, err := httputil.New(cfg).Get(context.Background(), server.URL)
	if err != nil {
		t.Fatalf("Get request failed: %v", err)
	}
	if !resp.IsSuccess() {
		t.Errorf("Expected success status, got %d", resp.StatusCode)
	}

	var got testData
	if err := resp.DecodeJSON(&got); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	if got.Name != "Test" || got.Age != 30 {
		t.Errorf("Unexpected response data: %+v", got)
	}
}

// TestPostRequest posts a JSON payload to an echo server and checks that the
// 201 response carries the same payload back.
func TestPostRequest(t *testing.T) {
	echo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if got := r.Method; got != http.MethodPost {
			t.Errorf("Expected POST request, got %s", got)
		}

		var in testData
		if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
			t.Errorf("Failed to decode request body: %v", err)
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusCreated)
		json.NewEncoder(w).Encode(in)
	}))
	defer echo.Close()

	sent := testData{Name: "Alice", Age: 25}
	cfg := &httputil.Config{Timeout: 10 * time.Second, MaxRetries: 2}
	resp, err := httputil.New(cfg).Post(context.Background(), echo.URL, sent)
	if err != nil {
		t.Fatalf("Post request failed: %v", err)
	}
	if resp.StatusCode != http.StatusCreated {
		t.Errorf("Expected status 201, got %d", resp.StatusCode)
	}

	var received testData
	if err := resp.DecodeJSON(&received); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	if received != sent {
		t.Errorf("Expected %+v, got %+v", sent, received)
	}
}

// TestRetryLogic serves 500 for the first maxRetries calls and 200 afterwards,
// then verifies that the client retried exactly maxRetries times and decoded
// the final payload.
func TestRetryLogic(t *testing.T) {
	const maxRetries = 2
	hits := 0

	flaky := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hits++
		if hits <= maxRetries {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(testData{Name: "Success", Age: 42})
	}))
	defer flaky.Close()

	cfg := &httputil.Config{Timeout: 10 * time.Second, MaxRetries: maxRetries}
	resp, err := httputil.New(cfg).Get(context.Background(), flaky.URL)
	if err != nil {
		t.Fatalf("Get request failed after retries: %v", err)
	}

	// The first call is not a retry, hence the +1 / -1.
	if hits != maxRetries+1 {
		t.Errorf("Expected %d retries, got %d", maxRetries, hits-1)
	}

	var got testData
	if err := resp.DecodeJSON(&got); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	if got.Name != "Success" {
		t.Errorf("Unexpected response data: %+v", got)
	}
}

// TestTimeout verifies that a server slower than Config.Timeout produces a
// timeout error when retries are disabled.
func TestTimeout(t *testing.T) {
	slow := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(200 * time.Millisecond) // longer than the client timeout below
		w.WriteHeader(http.StatusOK)
	}))
	defer slow.Close()

	client := httputil.New(&httputil.Config{
		Timeout:    100 * time.Millisecond,
		MaxRetries: 0,
	})

	_, err := client.Get(context.Background(), slow.URL)
	switch {
	case err == nil:
		t.Fatal("Expected timeout error, got nil")
	case !isTimeoutError(err):
		t.Errorf("Expected timeout error, got %v", err)
	}
}

// isTimeoutError reports whether err represents a timeout: either its message
// is exactly "context deadline exceeded", or it is a net.Error whose Timeout()
// is true. A nil error is never a timeout.
func isTimeoutError(err error) bool {
	switch e := err.(type) {
	case nil:
		return false
	case net.Error:
		if e.Timeout() {
			return true
		}
	}
	return err.Error() == "context deadline exceeded"
}

// TestInterceptors registers the logging interceptor and a probe, then checks
// that the probe's Before and After hooks both run for one request.
func TestInterceptors(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(testData{Name: "Test", Age: 30})
	}))
	defer server.Close()

	client := httputil.New(&httputil.Config{
		Timeout:    10 * time.Second,
		MaxRetries: 0,
	})

	var beforeCalled, afterCalled bool
	probe := &mockInterceptor{
		beforeFn: func(context.Context, *httputil.Request) error {
			beforeCalled = true
			return nil
		},
		afterFn: func(context.Context, *httputil.Request, *httputil.Response, error) {
			afterCalled = true
		},
	}

	for _, ic := range []httputil.Interceptor{interceptor.NewLogging(), probe} {
		client.AddInterceptor(ic)
	}

	if _, err := client.Get(context.Background(), server.URL); err != nil {
		t.Fatalf("Request failed: %v", err)
	}

	if !(beforeCalled && afterCalled) {
		t.Errorf("Interceptor not called properly: before=%v, after=%v", beforeCalled, afterCalled)
	}
}

// mockInterceptor is a test double whose hooks are optional callbacks; an
// unset callback makes the corresponding hook a no-op.
type mockInterceptor struct {
	beforeFn func(ctx context.Context, req *httputil.Request) error
	afterFn  func(ctx context.Context, req *httputil.Request, resp *httputil.Response, err error)
}

// Before forwards to beforeFn when it is set, and otherwise succeeds.
func (m *mockInterceptor) Before(ctx context.Context, req *httputil.Request) error {
	if m.beforeFn == nil {
		return nil
	}
	return m.beforeFn(ctx, req)
}

// After forwards to afterFn when it is set.
func (m *mockInterceptor) After(ctx context.Context, req *httputil.Request, resp *httputil.Response, err error) {
	if fn := m.afterFn; fn != nil {
		fn(ctx, req, resp, err)
	}
}
#Http请求封装 #Http #Golang