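// Command apiclient sends an HTTP request built from command-line flags,
// retrying on failure, and logs the JSON-decoded response body with zerolog.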
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
)
// APIClient describes one HTTP request and how many times to attempt it.
type APIClient struct {
	// Url is the endpoint to call.
	Url string
	// Timeout is filled by main from the --timeout flag, which is in seconds.
	Timeout int
	// Method is the HTTP method, for example "GET" or "POST".
	Method string
	// Headers are set on the outgoing request; nil means no extra headers.
	Headers map[string]string
	// Retry is the number of attempts DoRequest makes before giving up.
	Retry int
	// Body is the POST payload, or the query string appended for GET.
	Body string
}
// DoRequest sends the request described by c and returns the first response
// whose body was read completely.
//
// Attempts: c.Retry is the total number of attempts; values below 1 mean a
// single attempt. Only transport and body-read failures are retried. Any HTTP
// status code, including 4xx/5xx, is returned as a successful response.
// Between attempts DoRequest waits a linearly growing delay, and it stops
// early once ctx is done.
//
// Method (case-insensitive; empty means GET):
//   - POST requires a non-empty c.Body and sends it as the request body.
//   - GET and HEAD send no body. A non-empty c.Body is appended to c.Url as a
//     query string, joined with "&" when c.Url already contains "?".
//   - Any other method sends c.Body, if non-empty, as the request body.
//
// When c.Timeout is positive, it limits each attempt to that many seconds,
// the same unit main uses for the --timeout flag. c itself is never modified.
//
// Errors are returned, not logged; the caller decides how to report them.
func (c *APIClient) DoRequest(ctx context.Context) (*APIResponse, error) {
	// Base delay between attempts; attempt i (1-based retry) waits i*retryBackoff.
	const retryBackoff = 200 * time.Millisecond

	method := strings.ToUpper(strings.TrimSpace(c.Method))
	if method == "" {
		method = http.MethodGet
	}
	if method == http.MethodPost && c.Body == "" {
		return nil, fmt.Errorf("body不能为空")
	}

	attempts := c.Retry
	if attempts < 1 {
		attempts = 1
	}

	client := &http.Client{}
	if c.Timeout > 0 {
		client.Timeout = time.Duration(c.Timeout) * time.Second
	}
	// Computed once, so retries never re-append the query string.
	target := c.requestTarget(method)

	var lastError error
	for i := 0; i < attempts; i++ {
		if i > 0 {
			timer := time.NewTimer(time.Duration(i) * retryBackoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				return nil, fmt.Errorf("重试已取消,最后一次错误:%v: %w", lastError, ctx.Err())
			case <-timer.C:
			}
			log.Info().Msgf("重试第 %d 次...", i)
		}
		resp, err := c.doAttempt(ctx, client, method, target)
		if err != nil {
			lastError = err
			continue
		}
		return resp, nil
	}
	return nil, fmt.Errorf("达到最大重试次数 %d,最后一次错误:%w", attempts, lastError)
}

// doAttempt performs one HTTP exchange and closes the response body before returning.
func (c *APIClient) doAttempt(ctx context.Context, client *http.Client, method, target string) (*APIResponse, error) {
	// body must stay a nil interface (not a typed nil) when nothing is sent.
	var body io.Reader
	if method != http.MethodGet && method != http.MethodHead && c.Body != "" {
		body = bytes.NewBufferString(c.Body)
	}
	req, err := http.NewRequestWithContext(ctx, method, target, body)
	if err != nil {
		return nil, err
	}
	for k, v := range c.Headers {
		req.Header.Set(k, v)
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	return &APIResponse{Status: resp.StatusCode, Body: data}, nil
}

// requestTarget returns the URL to call, appending c.Body as a query string for GET/HEAD.
func (c *APIClient) requestTarget(method string) string {
	if c.Body == "" || (method != http.MethodGet && method != http.MethodHead) {
		return c.Url
	}
	sep := "?"
	if strings.Contains(c.Url, "?") {
		sep = "&"
	}
	return c.Url + sep + c.Body
}
// APIResponse is the result of one completed HTTP exchange.
type APIResponse struct {
	Status int    // HTTP status code of the response
	Body   []byte // complete response body
	Error  error  // carried alongside the response; DoRequest leaves it nil
}

// ParseBody decodes the response body as a JSON object.
//
// When Body is not valid JSON or is not a JSON object, it returns a nil map
// and the json.Unmarshal error. A literal JSON null yields (nil, nil).
func (r *APIResponse) ParseBody() (map[string]interface{}, error) {
	var fields map[string]interface{}
	if err := json.Unmarshal(r.Body, &fields); err != nil {
		return nil, err
	}
	return fields, nil
}
// main parses the command-line flags, sends one HTTP request through
// APIClient and logs the response body decoded as a JSON object.
//
// Exit status: 0 on success and when --url is missing (help is printed);
// 1 when flag parsing fails or the request ultimately fails.
func main() {
	var inputurl, method, inputheaders, inputbodys, level string
	var timeout, retry int
	var outputConsole bool
	// Run records failure here instead of calling os.Exit itself, so that its
	// deferred cleanup (closing the log file, cancelling the context) runs.
	exitCode := 0

	rootCmd := &cobra.Command{
		Use: "apiclient",
		PreRun: func(cmd *cobra.Command, args []string) {
			if inputurl == "" {
				_ = cmd.Help()
				// Print the problem prominently after the help text.
				log.Error().Msg("url不能为空,请重新输入")
				// Intentionally exit with status 0: a missing URL just shows help.
				os.Exit(0)
			}
		},
		Run: func(cmd *cobra.Command, args []string) {
			// Logs are appended to apiclient.log in the working directory. If
			// the file cannot be opened, logFile is nil and initLogger gets nil.
			logFile, err := os.OpenFile("apiclient.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
			if err != nil {
				log.Error().Msgf("打开日志文件失败: %v", err)
			} else {
				defer logFile.Close()
			}
			initLogger(level, outputConsole, logFile)

			client := APIClient{
				Url:     inputurl,
				Timeout: timeout,
				Method:  method,
				Retry:   retry,
				Body:    inputbodys,
			}
			if inputheaders != "" {
				client.Headers = parseHeaders(inputheaders)
			}

			// One deadline shared by every attempt. A --timeout <= 0 disables it
			// instead of producing an already-expired context.
			ctx := context.Background()
			if timeout > 0 {
				var cancel context.CancelFunc
				ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
				defer cancel()
			}

			response, err := client.DoRequest(ctx)
			if err != nil {
				log.Error().Msgf("请求失败: %v", err)
				exitCode = 1
				return
			}
			// DoRequest returns any status code as success; surface non-2xx ones.
			if response.Status < 200 || response.Status > 299 {
				log.Warn().Msgf("响应状态码异常: %d", response.Status)
			}
			result, err := response.ParseBody()
			if err != nil {
				log.Error().Msgf("json格式化body失败: %v", err)
				return
			}
			log.Info().Msgf("json格式化body成功: %v", result)
		},
	}
	rootCmd.Flags().StringVarP(&inputurl, "url", "u", "", "api地址")
	rootCmd.Flags().IntVarP(&timeout, "timeout", "t", 10, "超时时间,单位秒")
	rootCmd.Flags().StringVarP(&method, "method", "m", "GET", "请求方法")
	rootCmd.Flags().StringVarP(&inputheaders, "headers", "H", "", "请输入headers,格式为key:value,key:value")
	rootCmd.Flags().IntVarP(&retry, "retry", "r", 3, "重试次数")
	rootCmd.Flags().StringVarP(&inputbodys, "body", "b", "", "请输入body,格式为key=value&key=value")
	rootCmd.Flags().StringVarP(&level, "level", "l", "info", "日志级别")
	rootCmd.Flags().BoolVarP(&outputConsole, "console", "c", true, "是否输出到控制台")
	if err := rootCmd.Execute(); err != nil {
		log.Error().Msgf("执行命令失败: %v", err)
		exitCode = 1
	}
	os.Exit(exitCode)
}
// parseHeaders parses a comma-separated list of "key:value" pairs into a map.
//
// Each pair is split on its first colon only, so values may contain colons
// (URLs, host:port, timestamps). Keys and values are trimmed of surrounding
// whitespace. Pairs without a colon, or with an empty key, are skipped. A
// later duplicate key overwrites an earlier one. Values cannot contain commas,
// because commas separate the pairs.
func parseHeaders(inputheaders string) map[string]string {
	headers := make(map[string]string)
	for _, pair := range strings.Split(inputheaders, ",") {
		parts := strings.SplitN(pair, ":", 2)
		if len(parts) != 2 {
			continue
		}
		key := strings.TrimSpace(parts[0])
		if key == "" {
			// An empty header name would make the HTTP request itself fail.
			continue
		}
		headers[key] = strings.TrimSpace(parts[1])
	}
	return headers
}
// initLogger configures the global zerolog logger (log.Logger).
//
// level selects the minimum level: "debug", "info", "warn" or "error",
// case-insensitive; anything else falls back to info. Events are written as
// JSON to logFile. When outputConsole is true, they are also written as
// coloured text to stdout. If logFile is nil (for example when the log file
// could not be opened), events go to the console only, so they are not lost.
func initLogger(level string, outputConsole bool, logFile *os.File) {
	// Set the global minimum log level.
	switch strings.ToLower(strings.TrimSpace(level)) {
	case "debug":
		zerolog.SetGlobalLevel(zerolog.DebugLevel)
	case "warn":
		zerolog.SetGlobalLevel(zerolog.WarnLevel)
	case "error":
		zerolog.SetGlobalLevel(zerolog.ErrorLevel)
	default:
		zerolog.SetGlobalLevel(zerolog.InfoLevel)
	}
	// Timestamp format used in the JSON output.
	zerolog.TimeFieldFormat = "2006-01-02 15:04:05"

	// ANSI colour codes for the console level tag.
	const (
		colorReset  = "\033[0m"
		colorRed    = "\033[31m" // error
		colorYellow = "\033[33m" // warn
		colorGreen  = "\033[32m" // info
		colorBlue   = "\033[34m" // debug
	)

	// asString renders a ConsoleWriter field value. A field can be absent
	// (nil), e.g. for an event with an empty message, so a bare i.(string)
	// type assertion would panic.
	asString := func(i interface{}) string {
		switch v := i.(type) {
		case nil:
			return ""
		case string:
			return v
		default:
			return fmt.Sprint(v)
		}
	}

	// Human-readable console output.
	console := zerolog.ConsoleWriter{
		Out:        os.Stdout,
		TimeFormat: "2006-01-02 15:04:05",
		FormatLevel: func(i interface{}) string {
			lvl := asString(i)
			switch lvl {
			case "info":
				return colorGreen + "[INFO]" + colorReset
			case "warn":
				return colorYellow + "[WARN]" + colorReset
			case "error":
				return colorRed + "[ERROR]" + colorReset
			case "debug":
				return colorBlue + "[DEBUG]" + colorReset
			default:
				return "[" + lvl + "]"
			}
		},
		FormatMessage: func(i interface{}) string {
			return "- " + asString(i)
		},
		FormatTimestamp: func(i interface{}) string {
			return "[" + asString(i) + "]"
		},
	}

	// Choose the destination(s). Without a usable log file, fall back to the
	// console rather than writing to a nil *os.File, where every write fails.
	var out io.Writer
	switch {
	case logFile == nil:
		out = console
	case outputConsole:
		out = zerolog.MultiLevelWriter(console, logFile)
	default:
		out = logFile
	}
	log.Logger = zerolog.New(out).With().Timestamp().Logger()
}