我们使用protobuf+grpc技术栈来开发微服务时,会要使用相关protoc插件来生成相关代码。有时可能会需要自定义一些插件,本文就来实现一个自定义的protoc插件。
新旧接口的说明
以前开发protoc插件时,需要实现generator接口(github.com/golang/protobuf/protoc-gen-go/generator),现在网上有不少稍老一些资料也是这样介绍的。但是实际上,这个接口已经被废弃了,现在要开发插件,应该使用的是"google.golang.org/protobuf/compiler/protogen"包。我们开发插件时,必须要开发一个如下签名的函数:
go
func(*Plugin) error
代码实现
废话不多说,这里直接上代码
main.go
go
package main
import (
"google.golang.org/protobuf/compiler/protogen"
"strconv"
"strings"
)
func myPlugin(p *protogen.Plugin) error {
// 插件的代码在这里实现
for _, f := range p.Files {
if !f.Generate {
continue
}
generateFile(p, f)
}
return nil
}
func generateFile(p *protogen.Plugin, f *protogen.File) {
g := p.NewGeneratedFile(f.GeneratedFilenamePrefix+".pb.myplugin.go", f.GoImportPath)
g.P("// Code generated by protoc-gen-myplugin. DO NOT EDIT.")
g.P()
g.P("package ", f.GoPackageName)
g.P()
g.P("import (")
g.P(" \"errors\"")
g.P(" \"strings\"")
g.P(" \"strconv\"")
g.P(" \"regexp\"")
g.P(")")
g.P()
g.P(validatorTpl)
for _, service := range f.Services {
for _, method := range service.Methods {
// 生成一个validator,对方法入参中的每个字段进行校验
g.P("// Validate 参数校验")
g.P("func (x *", method.Input.GoIdent.GoName, ") Validate() error {")
for i, field := range method.Input.Fields {
g.P("validator" + strconv.Itoa(i) + " := NewValidator(`" + strings.TrimSpace(field.Comments.Leading.String()) + "`)")
g.P("if err := validator" + strconv.Itoa(i) + ".Validate(x." + field.GoName + ");err!= nil {")
g.P(" return err")
g.P("}")
g.P()
}
g.P(" return nil")
g.P("}")
g.P()
}
}
g.P()
}
const validatorTpl = `type Validator struct {
fieldDesc string // 字段描述,用于提示
fieldLengthLt int // 字段长度最小值
fieldLengthGt int // 字段长度最大值
fieldType string // 特殊字段类型,如mobile,使用内置的方法进行校验
}
// NewValidator 初始化
func NewValidator(fieldComment string) *Validator {
v := &Validator{}
// 过滤前缀和后面的空格
fieldComment = strings.TrimPrefix(fieldComment, "// ")
fieldComment = strings.TrimSpace(fieldComment)
fields := strings.Split(fieldComment, " ")
// 如果字段描述中包含了must,则表示该字段必填
for _, field := range fields {
columns := strings.Split(field, ":")
if len(columns) != 2 {
continue
}
switch columns[0] {
case "type":
v.fieldType = columns[1]
case "desc":
v.fieldDesc = columns[1]
case "length":
lengths := strings.Split(columns[1], "-")
if len(lengths) != 2 {
continue
}
lt, _ := strconv.ParseInt(lengths[0], 10, 64)
gt, _ := strconv.ParseInt(lengths[1], 10, 64)
v.fieldLengthGt = int(gt)
v.fieldLengthLt = int(lt)
}
if columns[0] == "desc" {
v.fieldDesc = columns[1]
continue
}
}
return v
}
// Validate 校验
func (v *Validator) Validate(fieldValue interface{}) error {
// 判断字段类型
switch fieldValue.(type) {
case string:
if v.fieldType == "mobile" {
return v.validateMobile(fieldValue.(string))
}
return v.validateStringLength(fieldValue.(string))
case int64, uint64, int32, uint32:
// todo
}
return nil
}
func (v *Validator) validateStringLength(fieldValue string) error {
if len([]rune(fieldValue)) > v.fieldLengthGt {
return errors.New(v.fieldDesc + "长度超出最大值-" + fieldValue)
}
if len([]rune(fieldValue)) < v.fieldLengthLt {
return errors.New(v.fieldDesc + "长度低于最小值-" + fieldValue)
}
return nil
}
const mobileReg = "^1[3456789]\\d{9}$"
func (v *Validator) validateMobile(fieldValue string) error {
regM := mobileReg
pattern := regexp.MustCompile(regM)
if !pattern.MatchString(fieldValue) {
return errors.New(v.fieldDesc + "格式不正确")
}
return nil
}
`
func main() {
protogen.Options{}.Run(myPlugin)
}
在上面,我们实现了一个简单的校验器,可以对方法入参中的每个字段进行校验。可以在字段的注释中添加校验内容:
- type: 字段类型,这里指特殊类型,如mobile(手机号),使用内置的方法进行校验
- desc: 字段描述,用于提示 ,如 {desc} 不能为空
- length: 字段长度,如 1-10,表示字段长度在1-10之间
可以看到,我们定义了一个myPlugin函数(注意函数类型),并在main方法中将其传入protogen.Options{}.Run方法中。
编译插件
为了方便,我们写一个简单的Makefile文件来完成编译,安装操作:
makefile
install:
go build -o protoc-gen-myplugin ./
mv protoc-gen-myplugin /Users/gq/go/bin/
执行make install命令,将会生成protoc-gen-myplugin文件,并将其移动到$GOPATH/bin目录下。
使用插件
我们再另外新建一个项目bufdemo,来测试一下生成的插件。 注意:我们这里使用buf来调用相关插件,并生成go代码。
编写proto文件
新建一个user.proto文件
user.proto
proto
syntax = "proto3";
package pb;
option go_package = "bufdemo/pb";
service UserService {
// 添加用户
rpc CreateUser (CreateUserRequest) returns (CreateUserResponse) {}
}
message CreateUserRequest {
// desc:姓名 length:2-20
string name = 1;
// desc:手机号码 type:mobile
string mobile = 2;
}
message CreateUserResponse {
uint32 id = 2;
}
修改buf.gen.yml
yaml
version: v1
plugins:
- plugin: go
out: pb
opt:
- paths=source_relative
- plugin: go-grpc
out: pb
opt:
- paths=source_relative
- plugin: myplugin
out: pb
opt:
- paths=source_relative
生成代码
执行buf generate命令,将会生成相关代码:
文件内容为:
user.pb.myplugin.go ```go // Code generated by protoc-gen-myplugin. DO NOT EDIT.
package pb
import ( "errors" "strings" "strconv" "regexp" )
type Validator struct { fieldDesc string // 字段描述,用于提示 fieldLengthLt int // 字段长度最小值 fieldLengthGt int // 字段长度最大值 fieldType string // 特殊字段类型,如mobile,使用内置的方法进行校验 }
// NewValidator 初始化 func NewValidator(fieldComment string) *Validator { v := &Validator{} // 过滤前缀和后面的空格 fieldComment = strings.TrimPrefix(fieldComment, "// ") fieldComment = strings.TrimSpace(fieldComment)
go
fields := strings.Split(fieldComment, " ")
// 如果字段描述中包含了must,则表示该字段必填
for _, field := range fields {
columns := strings.Split(field, ":")
if len(columns) != 2 {
continue
}
switch columns[0] {
case "type":
v.fieldType = columns[1]
case "desc":
v.fieldDesc = columns[1]
case "length":
lengths := strings.Split(columns[1], "-")
if len(lengths) != 2 {
continue
}
lt, _ := strconv.ParseInt(lengths[0], 10, 64)
gt, _ := strconv.ParseInt(lengths[1], 10, 64)
v.fieldLengthGt = int(gt)
v.fieldLengthLt = int(lt)
}
if columns[0] == "desc" {
v.fieldDesc = columns[1]
continue
}
}
return v
}
// Validate 校验 func (v *Validator) Validate(fieldValue interface{}) error { // 判断字段类型 switch fieldValue.(type) { case string: if v.fieldType == "mobile" { return v.validateMobile(fieldValue.(string)) } return v.validateStringLength(fieldValue.(string)) case int64, uint64, int32, uint32: // todo }
go
return nil
}
func (v *Validator) validateStringLength(fieldValue string) error {
go
if len([]rune(fieldValue)) > v.fieldLengthGt {
return errors.New(v.fieldDesc + "长度超出最大值-" + fieldValue)
}
if len([]rune(fieldValue)) < v.fieldLengthLt {
return errors.New(v.fieldDesc + "长度低于最小值-" + fieldValue)
}
return nil
}
const mobileReg = "^1[3456789]\d{9}$"
func (v *Validator) validateMobile(fieldValue string) error { regM := mobileReg pattern := regexp.MustCompile(regM) if !pattern.MatchString(fieldValue) { return errors.New(v.fieldDesc + "格式不正确") }
go
return nil
}
// Validate 参数校验 func (x *CreateUserRequest) Validate() error { validator0 := NewValidator(// desc:姓名 length:2-20
) if err := validator0.Validate(x.Name); err != nil { return err }
go
validator1 := NewValidator(`// desc:手机号码 type:mobile`)
if err := validator1.Validate(x.Mobile); err != nil {
return err
}
return nil
}
erlang
</details>
### 编写服务端文件
新建server目录,并在server目录下新建一个server.go文件,内容如下:
<details>
<summary>server.go</summary>
```go
package main
import (
"bufdemo/pb"
"context"
"google.golang.org/genproto/googleapis/rpc/code"
"google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"log"
"net"
)
type UserServiceImpl struct {
pb.UnimplementedUserServiceServer
}
func (u *UserServiceImpl) CreateUser(ctx context.Context, request *pb.CreateUserRequest) (*pb.CreateUserResponse, error) {
resp := &pb.CreateUserResponse{}
if err := request.Validate(); err != nil {
return nil, status.Error(codes.Code(code.Code_INVALID_ARGUMENT), err.Error())
}
resp.Id = 1
return resp, nil
}
func main() {
lis, err := net.Listen("tcp", ":8091")
if err != nil {
log.Fatalf("failed to listen:%v", err)
}
s := grpc.NewServer()
pb.RegisterUserServiceServer(s, &UserServiceImpl{})
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve:%v", err)
}
}
在上面的方法中,我们调用了Validate方法,对入参进行了校验。
运行服务端:
shell
go run server/server.go
使用postman测试
我们这里使用postman来测试一下:
- 测试1:姓名长度不正确:
- 测试2:手机号码格式不正确:

可以看到,我们的校验器已经生效了。
注:此文原载于本人个人网站,链接地址。
本文由mdnice多平台发布