实现一个自定义的protoc插件

我们使用protobuf+grpc技术栈来开发微服务时,会要使用相关protoc插件来生成相关代码。有时可能会需要自定义一些插件,本文就来实现一个自定义的protoc插件。

新旧接口的说明

以前开发protoc插件时,需要实现generator接口(github.com/golang/protobuf/protoc-gen-go/generator),现在网上有不少稍老一些资料也是这样介绍的。但是实际上,这个接口已经被废弃了,现在要开发插件,应该使用的是"google.golang.org/protobuf/compiler/protogen"包。我们开发插件时,必须要开发一个如下签名的函数:

go 复制代码
func(*Plugin) error

代码实现

废话不多说,这里直接上代码
main.go

go 复制代码
package main

import (
	"google.golang.org/protobuf/compiler/protogen"
	"strconv"
	"strings"
)

func myPlugin(p *protogen.Plugin) error {
	// 插件的代码在这里实现
	for _, f := range p.Files {
		if !f.Generate {
			continue
		}

		generateFile(p, f)
	}
	return nil
}

func generateFile(p *protogen.Plugin, f *protogen.File) {
	g := p.NewGeneratedFile(f.GeneratedFilenamePrefix+".pb.myplugin.go", f.GoImportPath)
	g.P("// Code generated by protoc-gen-myplugin. DO NOT EDIT.")
	g.P()
	g.P("package ", f.GoPackageName)

	g.P()

	g.P("import (")
	g.P(" \"errors\"")
	g.P(" \"strings\"")
	g.P(" \"strconv\"")
	g.P(" \"regexp\"")
	g.P(")")
	g.P()

	g.P(validatorTpl)

	for _, service := range f.Services {
		for _, method := range service.Methods {
			// 生成一个validator,对方法入参中的每个字段进行校验
			g.P("// Validate 参数校验")
			g.P("func (x *", method.Input.GoIdent.GoName, ") Validate() error {")

			for i, field := range method.Input.Fields {
				g.P("validator" + strconv.Itoa(i) + " := NewValidator(`" + strings.TrimSpace(field.Comments.Leading.String()) + "`)")

				g.P("if err := validator" + strconv.Itoa(i) + ".Validate(x." + field.GoName + ");err!= nil {")
				g.P("  return err")
				g.P("}")
				g.P()
			}
			g.P("  return nil")

			g.P("}")
			g.P()
		}
	}

	g.P()
}

const validatorTpl = `type Validator struct {
	fieldDesc     string // 字段描述,用于提示
	fieldLengthLt int    // 字段长度最小值
	fieldLengthGt int    // 字段长度最大值
	fieldType 	  string // 特殊字段类型,如mobile,使用内置的方法进行校验
}

// NewValidator 初始化
func NewValidator(fieldComment string) *Validator {
	v := &Validator{}
	// 过滤前缀和后面的空格
	fieldComment = strings.TrimPrefix(fieldComment, "// ")
	fieldComment = strings.TrimSpace(fieldComment)

	fields := strings.Split(fieldComment, " ")

	// 如果字段描述中包含了must,则表示该字段必填
	for _, field := range fields {
		columns := strings.Split(field, ":")

		if len(columns) != 2 {
			continue
		}

		switch columns[0] {
		case "type":
			v.fieldType = columns[1]
		case "desc":
			v.fieldDesc = columns[1]
		case "length":
			lengths := strings.Split(columns[1], "-")
			if len(lengths) != 2 {
				continue
			}

			lt, _ := strconv.ParseInt(lengths[0], 10, 64)
			gt, _ := strconv.ParseInt(lengths[1], 10, 64)
			v.fieldLengthGt = int(gt)
			v.fieldLengthLt = int(lt)
		}

		if columns[0] == "desc" {
			v.fieldDesc = columns[1]
			continue
		}
	}

	return v
}

// Validate 校验
func (v *Validator) Validate(fieldValue interface{}) error {
	// 判断字段类型
	switch fieldValue.(type) {
	case string:
		if v.fieldType == "mobile" {
			return v.validateMobile(fieldValue.(string))
		}
		return v.validateStringLength(fieldValue.(string))
	case int64, uint64, int32, uint32:
		// todo
	}

	return nil
}

func (v *Validator) validateStringLength(fieldValue string) error {

	if len([]rune(fieldValue)) > v.fieldLengthGt {
		return errors.New(v.fieldDesc + "长度超出最大值-" + fieldValue)
	}

	if len([]rune(fieldValue)) < v.fieldLengthLt {
		return errors.New(v.fieldDesc + "长度低于最小值-" + fieldValue)
	}

	return nil
}

const mobileReg = "^1[3456789]\\d{9}$"

func (v *Validator) validateMobile(fieldValue string) error {
	regM := mobileReg
	pattern := regexp.MustCompile(regM)
	if !pattern.MatchString(fieldValue) {
		return errors.New(v.fieldDesc + "格式不正确")
	}

	return nil
}
`

func main() {
	protogen.Options{}.Run(myPlugin)
}

在上面,我们实现了一个简单的校验器,可以对方法入参中的每个字段进行校验。可以在字段的注释中添加校验内容:

  • type: 字段类型,这里指特殊类型,如mobile(手机号),使用内置的方法进行校验
  • desc: 字段描述,用于提示 ,如 {desc} 不能为空
  • length: 字段长度,如 1-10,表示字段长度在1-10之间

可以看到,我们定义了一个myPlugin函数(注意函数类型),并在main方法中将其传入protogen.Options{}.Run方法中。

编译插件

为了方便,我们写一个简单的Makefile文件来完成编译,安装操作:

makefile 复制代码
install:
	go build -o protoc-gen-myplugin ./
	mv protoc-gen-myplugin /Users/gq/go/bin/

执行make install命令,将会生成protoc-gen-myplugin文件,并将其移动到$GOPATH/bin目录下。

使用插件

我们再另外新建一个项目bufdemo,来测试一下生成的插件。 注意:我们这里使用buf来调用相关插件,并生成go代码。

编写proto文件

新建一个user.proto文件
user.proto

proto 复制代码
syntax = "proto3";


package pb;

option go_package = "bufdemo/pb";

service UserService {
  // 添加用户
  rpc CreateUser (CreateUserRequest) returns (CreateUserResponse) {}
}


message CreateUserRequest {
  // desc:姓名 length:2-20
  string name = 1;
  // desc:手机号码 type:mobile
  string mobile = 2;
}

message CreateUserResponse {
  uint32 id = 2;
}

修改buf.gen.yml

yaml 复制代码
version: v1
plugins:
  - plugin: go
    out: pb
    opt:
      - paths=source_relative
  - plugin: go-grpc
    out: pb
    opt:
      - paths=source_relative
  - plugin: myplugin
    out: pb
    opt:
      - paths=source_relative

生成代码

执行buf generate命令,将会生成相关代码:

文件内容为:
user.pb.myplugin.go ```go // Code generated by protoc-gen-myplugin. DO NOT EDIT.

package pb

import ( "errors" "strings" "strconv" "regexp" )

type Validator struct { fieldDesc string // 字段描述,用于提示 fieldLengthLt int // 字段长度最小值 fieldLengthGt int // 字段长度最大值 fieldType string // 特殊字段类型,如mobile,使用内置的方法进行校验 }

// NewValidator 初始化 func NewValidator(fieldComment string) *Validator { v := &Validator{} // 过滤前缀和后面的空格 fieldComment = strings.TrimPrefix(fieldComment, "// ") fieldComment = strings.TrimSpace(fieldComment)

go 复制代码
fields := strings.Split(fieldComment, " ")

// 如果字段描述中包含了must,则表示该字段必填
for _, field := range fields {
	columns := strings.Split(field, ":")

	if len(columns) != 2 {
		continue
	}

	switch columns[0] {
	case "type":
		v.fieldType = columns[1]
	case "desc":
		v.fieldDesc = columns[1]
	case "length":
		lengths := strings.Split(columns[1], "-")
		if len(lengths) != 2 {
			continue
		}

		lt, _ := strconv.ParseInt(lengths[0], 10, 64)
		gt, _ := strconv.ParseInt(lengths[1], 10, 64)
		v.fieldLengthGt = int(gt)
		v.fieldLengthLt = int(lt)
	}

	if columns[0] == "desc" {
		v.fieldDesc = columns[1]
		continue
	}
}

return v

}

// Validate 校验 func (v *Validator) Validate(fieldValue interface{}) error { // 判断字段类型 switch fieldValue.(type) { case string: if v.fieldType == "mobile" { return v.validateMobile(fieldValue.(string)) } return v.validateStringLength(fieldValue.(string)) case int64, uint64, int32, uint32: // todo }

go 复制代码
return nil

}

func (v *Validator) validateStringLength(fieldValue string) error {

go 复制代码
if len([]rune(fieldValue)) > v.fieldLengthGt {
	return errors.New(v.fieldDesc + "长度超出最大值-" + fieldValue)
}

if len([]rune(fieldValue)) < v.fieldLengthLt {
	return errors.New(v.fieldDesc + "长度低于最小值-" + fieldValue)
}

return nil

}

const mobileReg = "^1[3456789]\d{9}$"

func (v *Validator) validateMobile(fieldValue string) error { regM := mobileReg pattern := regexp.MustCompile(regM) if !pattern.MatchString(fieldValue) { return errors.New(v.fieldDesc + "格式不正确") }

go 复制代码
return nil

}

// Validate 参数校验 func (x *CreateUserRequest) Validate() error { validator0 := NewValidator(// desc:姓名 length:2-20) if err := validator0.Validate(x.Name); err != nil { return err }

go 复制代码
validator1 := NewValidator(`// desc:手机号码 type:mobile`)
if err := validator1.Validate(x.Mobile); err != nil {
	return err
}

return nil

}

erlang 复制代码
</details>

### 编写服务端文件
新建server目录,并在server目录下新建一个server.go文件,内容如下:
<details>
  <summary>server.go</summary>
  
```go
package main

import (
	"bufdemo/pb"
	"context"
	"google.golang.org/genproto/googleapis/rpc/code"
	"google.golang.org/grpc"
	codes "google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"log"
	"net"
)

type UserServiceImpl struct {
	pb.UnimplementedUserServiceServer
}

func (u *UserServiceImpl) CreateUser(ctx context.Context, request *pb.CreateUserRequest) (*pb.CreateUserResponse, error) {
	resp := &pb.CreateUserResponse{}

	if err := request.Validate(); err != nil {
		return nil, status.Error(codes.Code(code.Code_INVALID_ARGUMENT), err.Error())
	}

	resp.Id = 1
	return resp, nil
}

func main() {
	lis, err := net.Listen("tcp", ":8091")

	if err != nil {
		log.Fatalf("failed to listen:%v", err)
	}

	s := grpc.NewServer()

	pb.RegisterUserServiceServer(s, &UserServiceImpl{})

	if err = s.Serve(lis); err != nil {
		log.Fatalf("failed to serve:%v", err)
	}
}

在上面的方法中,我们调用了Validate方法,对入参进行了校验。

运行服务端:

shell 复制代码
go run server/server.go

使用postman测试

我们这里使用postman来测试一下:

  • 测试1:姓名长度不正确:
  • 测试2:手机号码格式不正确:

可以看到,我们的校验器已经生效了。

注:此文原载于本人个人网站,链接地址

本文由mdnice多平台发布

相关推荐
烛阴4 小时前
bignumber.js深度解析:驾驭任意精度计算的终极武器
前端·javascript·后端
服务端技术栈4 小时前
电商营销系统中的幂等性设计:从抽奖积分发放谈起
后端
你的人类朋友4 小时前
✍️Node.js CMS框架概述:Directus与Strapi详解
javascript·后端·node.js
面朝大海,春不暖,花不开5 小时前
自定义Spring Boot Starter的全面指南
java·spring boot·后端
钡铼技术ARM工业边缘计算机5 小时前
【成本降40%·性能翻倍】RK3588边缘控制器在安防联动系统的升级路径
后端
CryptoPP6 小时前
使用WebSocket实时获取印度股票数据源(无调用次数限制)实战
后端·python·websocket·网络协议·区块链
白宇横流学长6 小时前
基于SpringBoot实现的大创管理系统设计与实现【源码+文档】
java·spring boot·后端
草捏子6 小时前
状态机设计:比if-else优雅100倍的设计
后端
考虑考虑8 小时前
Springboot3.5.x结构化日志新属性
spring boot·后端·spring
涡能增压发动积8 小时前
一起来学 Langgraph [第三节]
后端