hertz源码入门: 如何设计一个好用的HTTP服务端框架

概要

实现一个HTTP服务端框架说简单也简单,说难也难:Go 的 gin 框架为什么这样设计 API?hertz 又在 gin 的基础上做了哪些改进?这篇文章带你入门。

分析

Hello world

代码(Go):
package main

import (
	"context"

	"github.com/cloudwego/hertz/pkg/app"
	"github.com/cloudwego/hertz/pkg/app/server"
	"github.com/cloudwego/hertz/pkg/protocol/consts"
)

func main() {
	// server.Default() creates a Hertz with recovery middleware.
	// If you need a pure hertz, you can use server.New()
	h := server.Default()

	// A route registered directly on the engine's root group.
	h.GET("/hello", func(ctx context.Context, c *app.RequestContext) {
		c.String(consts.StatusOK, "Hello hertz!")
	})

	// Routes registered on a group share its "/v1" path prefix and any
	// middleware attached to the group. The braces only add a visual scope.
	v1 := h.Group("/v1")
	{
		v1.GET("/get", func(ctx context.Context, c *app.RequestContext) {
			c.String(consts.StatusOK, "get")
		})
		v1.POST("/post", func(ctx context.Context, c *app.RequestContext) {
			c.String(consts.StatusOK, "post")
		})
	}

	// Start serving; presumably blocks until the server is shut down.
	h.Spin()
}

API设计

引擎和路由

  • route.Engine 是 server 的核心,即引擎;它以成员的形式被各处引用,例如 Hertz 结构体就内嵌了 *route.Engine。
  • route.RouterGroup 是路由的核心,包含 basePath 和已注册的 Handlers。为什么会有已注册的 Handlers?因为 RouterGroup 支持 middleware,而 middleware 本质上就是 handler:通过 Use 追加到分组的 Handlers 中,注册路由时再经 combineHandlers 与路由自身的 handler 合并成一条处理链。
代码(Go):
// Default creates a Hertz with recovery middleware; server.New() creates one
// without it. Only the signature is shown in this excerpt.
func Default(opts ...config.Option) *Hertz 
// Hertz is the server value returned by Default. It embeds *route.Engine, so
// the engine's exported methods (including the promoted RouterGroup routing
// methods such as GET and Group) can be called directly on a *Hertz.
type Hertz struct {
	*route.Engine
	// NOTE(review): presumably blocks until a shutdown signal arrives or an
	// error is received on err — confirm against the server package.
	signalWaiter func(err chan error) error
}
// Engine is the core of a Hertz server. It owns the per-method route trees,
// the embedded root RouterGroup (so routing methods are available on the
// engine itself), the protocol servers, the network transport and the
// RequestContext pool.
type Engine struct {
	noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used

	// engine name
	Name       string
	serverName atomic.Value

	// Options for route and protocol server
	options *config.Options

	// route
	// RouterGroup is the root group; its methods are promoted onto Engine.
	RouterGroup
	// trees holds one router per HTTP method; addRoute looks up the router
	// for a method and appends a new one the first time a method is used.
	trees MethodTrees

	// NOTE(review): presumably the largest number of path parameters among
	// registered routes (used to size param storage) — confirm where it is set.
	maxParams uint16

	// Handler chains used when no route matches (noRoute) or the path exists
	// only under another method (noMethod). NOTE(review): the all* variants
	// presumably include the engine's global middleware — confirm.
	allNoMethod app.HandlersChain
	allNoRoute  app.HandlersChain
	noRoute     app.HandlersChain
	noMethod    app.HandlersChain

	// For render HTML: template delimiters, template functions and the
	// renderer used by the HTML helpers.
	delims     render.Delims
	funcMap    template.FuncMap
	htmlRender render.HTMLRender

	// NoHijackConnPool controls whether a pool is used to acquire/release
	// hijacked connections (true disables the pool).
	// If it is difficult to guarantee that hijackConn will not be closed repeatedly, set it to true.
	NoHijackConnPool bool
	hijackConnPool   sync.Pool
	// KeepHijackedConns is an opt-in disable of connection
	// close by hertz after connections' HijackHandler returns.
	// This allows to save goroutines, e.g. when hertz used to upgrade
	// http connections to WS and connection goes to another handler,
	// which will close it when needed.
	KeepHijackedConns bool

	// underlying transport
	transport network.Transporter

	// trace
	tracerCtl   tracer.Controller
	enableTrace bool

	// protocol layer management: protocolServers is keyed by protocol name
	// (e.g. suite.HTTP1 / suite.HTTP2) and looked up by Engine.Serve.
	protocolSuite         *suite.Config
	protocolServers       map[string]protocol.Server
	protocolStreamServers map[string]protocol.StreamServer

	// RequestContext pool
	ctxPool sync.Pool

	// Function to handle panics recovered from http handlers.
	// It should be used to generate an error page and return the http error code
	// 500 (Internal Server Error).
	// The handler can be used to keep your server from crashing because of
	// unrecovered panics.
	PanicHandler app.HandlerFunc

	// ContinueHandler is called after receiving the Expect 100 Continue Header
	//
	// https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3
	// https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.1.1
	// Using ContinueHandler a server can make decisioning on whether or not
	// to read a potentially large request body based on the headers
	//
	// The default is to automatically read request bodies of Expect 100 Continue requests
	// like they are normal requests
	ContinueHandler func(header *protocol.RequestHeader) bool

	// Indicates the engine status (Init/Running/Shutdown/Closed).
	status uint32

	// Hook functions get triggered sequentially when engine start
	OnRun []CtxErrCallback

	// Hook functions get triggered simultaneously when engine shutdown
	OnShutdown []CtxCallback

	// Custom Functions
	clientIPFunc  app.ClientIP
	formValueFunc app.FormValueFunc
}
// Group creates a child RouterGroup. NOTE(review): presumably the child's
// base path is this group's base path joined with relativePath, and its
// Handlers are this group's Handlers followed by handlers (mirroring what
// handle does with calculateAbsolutePath/combineHandlers) — only the
// signature is shown in this excerpt; confirm.
func (group *RouterGroup) Group(relativePath string, handlers ...app.HandlerFunc) *RouterGroup
// RouterGroup is a set of routes that share a base path and a middleware chain.
type RouterGroup struct {
	// Handlers is the group's middleware chain. Use appends to it, and handle
	// merges it with each route's own handlers via combineHandlers.
	Handlers app.HandlersChain
	// basePath is the path prefix for routes registered on this group.
	basePath string
	// engine is the owning Engine; handle registers routes through it.
	engine   *Engine
	// NOTE(review): presumably true only for the engine's root group — confirm.
	root     bool
}

请求上下文

  • hertz 的 HandlerFunc 接收两个上下文参数:第一个是 Go 标准库的 context.Context,用于传递给 Go 生态中接收 context 的接口(例如 GORM);
  • RequestContext相当于gin的上下文,是生命周期为Request级别的上下文,封装了处理请求过程中需要的所有东西和方法。
代码(Go):
// HandlersChain is the ordered list of handlers (group middleware plus the
// route's own handlers) that RequestContext.Next walks through.
type HandlersChain []HandlerFunc
// HandlerFunc is the signature shared by every handler and middleware.
// c is the standard context.Context (pass it to context-aware APIs);
// ctx is the per-request RequestContext.
type HandlerFunc func(c context.Context, ctx *RequestContext)

// RequestContext carries everything about one request while it is handled:
// the connection, the request and response, the matched params and handler
// chain, per-request Keys and tracing state. Server.Serve takes it from a
// pool and resets it between requests on the same connection.
type RequestContext struct {
	conn     network.Conn
	Request  protocol.Request
	Response protocol.Response

	// Errors is a list of errors attached to all the handlers/middlewares who used this context.
	Errors errors.ErrorChain

	// Params holds the path parameters of the matched route.
	Params     param.Params
	// handlers is the chain installed by the engine (SetHandlers) for the matched route.
	handlers   HandlersChain
	// fullPath is the full path of the matched route (SetFullPath).
	fullPath   string
	// index is the position of the handler currently running; Next advances it.
	index      int8
	HTMLRender render.HTMLRender

	// mu protects the Keys map.
	mu sync.RWMutex

	// Keys is a key/value pair exclusively for the context of each request.
	Keys map[string]interface{}

	// hijackHandler, when set by a handler, is run by Server.Serve on the raw
	// connection after the response has been flushed.
	hijackHandler HijackHandler

	finishedMu sync.Mutex

	// finished means the request end.
	finished chan struct{}

	// traceInfo defines the trace information.
	traceInfo traceinfo.TraceInfo

	// enableTrace defines whether enable trace.
	enableTrace bool

	// clientIPFunc gets the client IP using a custom function.
	clientIPFunc ClientIP

	// formValueFunc gets form values using a custom function.
	formValueFunc FormValueFunc
}

API实现

handle

代码(Go):
// handle registers a route on this group: the route's path is resolved
// against the group's base path, its handlers are merged with the group's
// middleware chain, and the result is added to the engine's route trees.
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers app.HandlersChain) IRoutes {
	group.engine.addRoute(
		httpMethod,
		group.calculateAbsolutePath(relativePath),
		group.combineHandlers(handlers),
	)
	return group.returnObj()
}

// addRoute registers handlers for method+path in the engine's route trees,
// creating the router for method the first time that method is used.
//
// It panics on programmer errors in route registration: an empty path, a
// path that does not start with '/', an empty method, or an empty handler
// chain. Routes are registered at start-up, so failing fast here surfaces
// the mistake immediately instead of storing a malformed route.
func (engine *Engine) addRoute(method, path string, handlers app.HandlersChain) {
	if len(path) == 0 {
		panic("path should not be ''")
	}
	if path[0] != '/' {
		panic("path must begin with '/'")
	}
	if method == "" {
		panic("HTTP method can not be empty")
	}
	if len(handlers) == 0 {
		panic("there must be at least one handler")
	}

	methodRouter := engine.trees.get(method)
	if methodRouter == nil {
		methodRouter = &router{method: method, root: &node{}, hasTsrHandler: make(map[string]bool)}
		engine.trees = append(engine.trees, methodRouter)
	}
	methodRouter.addRoute(path, handlers)
}

// ServeHTTP dispatches one parsed request to the handler chain of the
// matched route.
// NOTE: abridged excerpt. The elided part presumably derives the lookup path
// (rPath), selects the method's router t[i] and prepares paramsPointer and
// unescape; those names are defined there, not in this fragment.
func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) {
    // ...
    
    // Look up rPath in the method's route tree.
    value := t[i].find(rPath, paramsPointer, unescape)

    // On a match, install the route's handler chain and full path on the
    // request context, then run the chain via Next.
    if value.handlers != nil {
            ctx.SetHandlers(value.handlers)
            ctx.SetFullPath(value.fullPath)
            ctx.Next(c)
            return
    }
    // NOTE(review): the no-route / no-method handling for a miss is elided here.
}

middleware

代码(Go):
// Example middlewares:
// MyMiddleware1 returns a middleware that only does work before the rest of
// the chain. It never calls c.Next; the downstream handlers still run
// afterwards because RequestContext.Next keeps looping over the chain.
func MyMiddleware1() app.HandlerFunc {
	preOnly := func(ctx context.Context, c *app.RequestContext) {
		fmt.Println("pre-handle")
	}
	return preOnly
}

// MyMiddleware2 returns a middleware that wraps the rest of the chain. It
// prints before, hands control to every downstream handler with c.Next, and
// prints again once they have all returned.
func MyMiddleware2() app.HandlerFunc {
	var around app.HandlerFunc
	around = func(ctx context.Context, c *app.RequestContext) {
		fmt.Println("pre-handle")

		// Run the remaining middlewares/handlers before resuming here.
		c.Next(ctx)

		fmt.Println("post-handle")
	}
	return around
}

// Use attaches middleware to the group by appending it to the group's
// Handlers. handle merges those Handlers into the chain of each route
// registered on the group.
func (group *RouterGroup) Use(middleware ...app.HandlerFunc) IRoutes {
	chain := group.Handlers
	chain = append(chain, middleware...)
	group.Handlers = chain
	return group.returnObj()
}

// Next runs the remaining handlers of the chain in order, starting with the
// one after the current index. The chain is installed by the engine
// (SetHandlers) before handling starts. A middleware calls Next to run the
// downstream handlers, and resumes its own post-processing once they return.
//
// The loop re-reads the shared ctx.index after every call. A handler that
// calls Next itself therefore advances the same index, and each handler runs
// at most once however many nested Next calls occur. index is updated
// without synchronization, so Next must not be called concurrently on the
// same RequestContext.
func (ctx *RequestContext) Next(c context.Context) {
	for ctx.index++; ctx.index < int8(len(ctx.handlers)); ctx.index++ {
		ctx.handlers[ctx.index](c, ctx)
	}
}

深入并发模式

eventLoop

hertz 的传输层有两种实现:一种基于 netpoll 包的 EventLoop(事件循环)模式,类似于 epoll 式的 I/O 多路复用;另一种基于 Go 标准库 net,为每个连接启动一个 goroutine(见下面 transport.serve 中的 go t.handler(ctx, c))。

代码(Go):
// serve is the accept loop of the standard-library (net) transport. It
// creates the listener, then accepts connections forever, handing each one
// to t.handler on its own goroutine. It returns only when the listener cannot
// be created or Accept fails (e.g. once the listener has been closed).
func (t *transport) serve() (err error) {
	network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck

	// Assign t.ln under t.lock.
	t.lock.Lock()
	if t.listenConfig == nil {
		t.ln, err = net.Listen(t.network, t.addr)
	} else {
		t.ln, err = t.listenConfig.Listen(context.Background(), t.network, t.addr)
	}
	t.lock.Unlock()
	if err != nil {
		return err
	}

	hlog.SystemLogger().Infof("HTTP server listening on address=%s", t.ln.Addr().String())

	for {
		rawConn, acceptErr := t.ln.Accept()
		if acceptErr != nil {
			hlog.SystemLogger().Errorf("Error=%s", acceptErr.Error())
			return acceptErr
		}

		// Per-connection context: OnAccept may derive it from the raw
		// connection, and OnConnect may refine it from the wrapped one.
		connCtx := context.Background()
		if t.OnAccept != nil {
			connCtx = t.OnAccept(rawConn)
		}

		var c network.Conn
		if t.tls == nil {
			c = newConn(rawConn, t.readBufferSize)
		} else {
			c = newTLSConn(tls.Server(rawConn, t.tls), t.readBufferSize)
		}

		if t.OnConnect != nil {
			connCtx = t.OnConnect(connCtx, c)
		}

		go t.handler(connCtx, c)
	}
}

// ListenAndServe creates the listener, builds a netpoll EventLoop that calls
// onReq with each connection it serves, and blocks serving on it.
//
// A failure to create the listener or the event loop, or an error returned
// by Serve, causes a panic. Every panic message includes the underlying
// error so the cause is not lost.
func (t *transporter) ListenAndServe(onReq network.OnData) (err error) {
	network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck
	if t.listenConfig != nil {
		t.listener, err = t.listenConfig.Listen(context.Background(), t.network, t.addr)
	} else {
		t.listener, err = net.Listen(t.network, t.addr)
	}

	if err != nil {
		panic("create netpoll listener fail: " + err.Error())
	}

	// Initialize custom option for EventLoop
	opts := []netpoll.Option{
		netpoll.WithIdleTimeout(t.keepAliveTimeout),
		// Prepare hook: apply the configured read/write timeouts to the
		// connection and let OnAccept derive its context.
		netpoll.WithOnPrepare(func(conn netpoll.Connection) context.Context {
			conn.SetReadTimeout(t.readTimeout) // nolint:errcheck
			if t.writeTimeout > 0 {
				conn.SetWriteTimeout(t.writeTimeout) //nolint:errcheck
			}
			if t.OnAccept != nil {
				return t.OnAccept(newConn(conn))
			}
			return context.Background()
		}),
	}

	if t.OnConnect != nil {
		opts = append(opts, netpoll.WithOnConnect(func(ctx context.Context, conn netpoll.Connection) context.Context {
			return t.OnConnect(ctx, newConn(conn))
		}))
	}

	// Create EventLoop
	t.Lock()
	t.eventLoop, err = netpoll.NewEventLoop(func(ctx context.Context, connection netpoll.Connection) error {
		return onReq(ctx, newConn(connection))
	}, opts...)
	t.Unlock()
	if err != nil {
		panic("create netpoll event-loop fail: " + err.Error())
	}

	// Start Server
	hlog.SystemLogger().Infof("HTTP server listening on address=%s", t.listener.Addr().String())
	t.RLock()
	err = t.eventLoop.Serve(t.listener)
	t.RUnlock()
	if err != nil {
		panic("netpoll server exit: " + err.Error())
	}

	return nil
}

engine.Serve

代码(Go):
// onData dispatches a connection by kind: a network.Conn is served by Serve
// and a network.StreamConn by ServeStream. The network.Conn check comes
// first, so a value implementing both goes to Serve. Any other value is
// ignored and nil is returned.
func (engine *Engine) onData(c context.Context, conn interface{}) (err error) {
	if plain, ok := conn.(network.Conn); ok {
		return engine.Serve(c, plain)
	}
	if stream, ok := conn.(network.StreamConn); ok {
		return engine.ServeStream(c, stream)
	}
	return nil
}

// Serve handles one accepted connection by choosing a protocol server:
//   - H2C: when enabled and the client opens with the HTTP/2 connection
//     preface, the HTTP/2 server (if registered) takes the connection.
//   - ALPN: when enabled with TLS, the protocol negotiated in the handshake
//     is served if a server is registered for it.
//   - Otherwise the HTTP/1.x server handles it.
//
// The returned error is also passed to errProcess on exit.
func (engine *Engine) Serve(c context.Context, conn network.Conn) (err error) {
	defer func() {
		errProcess(conn, err)
	}()

	// H2C path
	if engine.options.H2C {
		// Protocol sniffer: compare the first bytes with the HTTP/2 client
		// preface. Peek presumably does not consume them, which the HTTP/1
		// fallback relies on.
		// NOTE(review): Peek(n) may wait for n bytes; confirm that a request
		// shorter than the preface cannot stall here until a read timeout.
		buf, _ := conn.Peek(len(bytestr.StrClientPreface))
		if bytes.Equal(buf, bytestr.StrClientPreface) {
			if h2 := engine.protocolServers[suite.HTTP2]; h2 != nil {
				return h2.Serve(c, conn)
			}
			// Warn only when the client actually spoke HTTP/2 but no HTTP/2
			// server is registered; plain HTTP/1.x clients are the normal case.
			hlog.SystemLogger().Warn("HTTP2 server is not loaded, request is going to fallback to HTTP1 server")
		}
	}

	// ALPN path
	if engine.options.ALPN && engine.options.TLS != nil {
		proto, err1 := engine.getNextProto(conn)
		if err1 != nil {
			// The client closes the connection when handshake. So just ignore it.
			if err1 == io.EOF {
				return nil
			}
			// A plain-HTTP request arrived on the TLS port: reply with a
			// readable 400 (best effort) and close the raw connection.
			if re, ok := err1.(tls.RecordHeaderError); ok && re.Conn != nil && utils.TLSRecordHeaderLooksLikeHTTP(re.RecordHeader) {
				io.WriteString(re.Conn, "HTTP/1.0 400 Bad Request\r\n\r\nClient sent an HTTP request to an HTTPS server.\n") //nolint:errcheck
				re.Conn.Close()                                                                                             //nolint:errcheck
				return re
			}
			return err1
		}
		if srv, ok := engine.protocolServers[proto]; ok {
			return srv.Serve(c, conn)
		}
	}

	// HTTP1 path
	err = engine.protocolServers[suite.HTTP1].Serve(c, conn)

	return
}

server.Serve

  • 调用 s.Core.ServeHTTP(cc, ctx) 时进入用户代码(所有中间件和业务 handler 都在其中执行)。
代码(Go):
// Serve runs the HTTP/1.x request loop for a single connection.
//
// For each request on conn it:
//   - reads the headers and the body, handling `Expect: 100-continue`;
//   - runs the engine's middleware/handler chain via s.Core.ServeHTTP;
//   - writes and flushes the response.
//
// It then loops for the next keep-alive request or returns. One
// *app.RequestContext is taken from the core's pool for the whole
// connection. It is reset between requests (ResetWithoutConn) and returned
// to the pool on exit.
//
// It returns nil when nothing was read (ErrNothingRead) or when IdleTimeout
// is 0 after a completed request. Otherwise it returns a sentinel
// (errIdleTimeout, errUnexpectedEOF, errHijacked, errShortConnection) or the
// underlying read/write error. With tracing enabled, any pending trace
// events are fired on exit.
func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
	var (
		zr network.Reader
		zw network.Writer

		serverName      []byte
		isHTTP11        bool
		connectionClose bool

		continueReadingRequest = true

		hijackHandler app.HijackHandler

		// HTTP1 path
		// 1. Get a request context
		// 2. Prepare it
		// 3. Process it
		// 4. Reset and recycle
		ctx = s.Core.GetCtxPool().Get().(*app.RequestContext)

		traceCtl        = s.Core.GetTracer()
		eventsToTrigger *eventStack

		// Use a new variable to hold the standard context to avoid modify the initial
		// context.
		cc = c
	)

	if s.EnableTrace {
		eventsToTrigger = s.eventStackPool.Get().(*eventStack)
	}

	// On exit: finish tracing (firing any pending events), release the reader
	// unless the connection was hijacked, then reset and recycle ctx.
	defer func() {
		if s.EnableTrace {
			if err != nil && !errors.Is(err, errs.ErrIdleTimeout) && !errors.Is(err, errs.ErrHijacked) {
				ctx.GetTraceInfo().Stats().SetError(err)
			}
			// in case of error, we need to trigger all events
			if eventsToTrigger != nil {
				for last := eventsToTrigger.pop(); last != nil; last = eventsToTrigger.pop() {
					last(ctx.GetTraceInfo(), err)
				}
				s.eventStackPool.Put(eventsToTrigger)
			}

			traceCtl.DoFinish(cc, ctx, err)
		}

		// Hijack may release and close the connection already
		if zr != nil && !errors.Is(err, errs.ErrHijacked) {
			zr.Release() //nolint:errcheck
			zr = nil
		}
		ctx.Reset()
		s.Core.GetCtxPool().Put(ctx)
	}()

	ctx.HTMLRender = s.HTMLRender
	ctx.SetConn(conn)
	ctx.Request.SetIsTLS(s.TLS != nil)
	ctx.SetEnableTrace(s.EnableTrace)

	if !s.NoDefaultServerHeader {
		serverName = s.ServerName
	}

	connRequestNum := uint64(0)

	// One iteration per request on this connection.
	for {
		connRequestNum++

		if zr == nil {
			zr = ctx.GetReader()
		}

		// If this is a keep-alive connection we want to try and read the first bytes
		// within the idle time.
		if connRequestNum > 1 {
			ctx.GetConn().SetReadTimeout(s.IdleTimeout) //nolint:errcheck

			_, err = zr.Peek(4)
			// This is not the first request, and we haven't read a single byte
			// of a new request yet. This means it's just a keep-alive connection
			// closing down either because the remote closed it or because
			// or a read timeout on our side. Either way just close the connection
			// and don't return any error response.
			if err != nil {
				err = errIdleTimeout
				return
			}

			// Reset the real read timeout for the coming request
			ctx.GetConn().SetReadTimeout(s.ReadTimeout) //nolint:errcheck
		}

		if s.EnableTrace {
			cc = traceCtl.DoStart(c, ctx)
			internalStats.Record(ctx.GetTraceInfo(), stats.ReadHeaderStart, err)
			eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) {
				internalStats.Record(ti, stats.ReadHeaderFinish, err)
			})
		}
		// Read Headers
		if err = req.ReadHeader(&ctx.Request.Header, zr); err == nil {
			if s.EnableTrace {
				// read header finished
				if last := eventsToTrigger.pop(); last != nil {
					last(ctx.GetTraceInfo(), err)
				}
				internalStats.Record(ctx.GetTraceInfo(), stats.ReadBodyStart, err)
				eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) {
					internalStats.Record(ti, stats.ReadBodyFinish, err)
				})
			}
			// Read body: either expose it as a stream or read it fully
			// (bounded by MaxRequestBodySize).
			if s.StreamRequestBody {
				err = req.ReadBodyStream(&ctx.Request, zr, s.MaxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
			} else {
				err = req.ReadLimitBody(&ctx.Request, zr, s.MaxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
			}
		}

		if s.EnableTrace {
			if ctx.Request.Header.ContentLength() >= 0 {
				ctx.GetTraceInfo().Stats().SetRecvSize(len(ctx.Request.Header.RawHeaders()) + ctx.Request.Header.ContentLength())
			} else {
				ctx.GetTraceInfo().Stats().SetRecvSize(0)
			}
			// read body finished
			if last := eventsToTrigger.pop(); last != nil {
				last(ctx.GetTraceInfo(), err)
			}
		}

		if err != nil {
			// Presumably: the peer closed without sending any request bytes.
			if errors.Is(err, errs.ErrNothingRead) {
				return nil
			}

			// Presumably: the connection ended part-way through a request.
			if err == io.EOF {
				return errUnexpectedEOF
			}
			// Any other read/parse error: send an error response and stop
			// serving this connection.
			writeErrorResponse(zw, ctx, serverName, err)
			return
		}

		// 'Expect: 100-continue' request handling.
		// See https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 for details.
		if ctx.Request.MayContinue() {
			// Allow the ability to deny reading the incoming request body
			// NOTE(review): on rejection only the 417 status is set. The body
			// is not read, the handler chain below still runs, and
			// connectionClose is not forced. If the client sends the body
			// anyway, those bytes stay on a keep-alive connection — confirm
			// whether the connection should be closed here.
			if s.ContinueHandler != nil {
				if continueReadingRequest = s.ContinueHandler(&ctx.Request.Header); !continueReadingRequest {
					ctx.SetStatusCode(consts.StatusExpectationFailed)
				}
			}

			if continueReadingRequest {
				zw = ctx.GetWriter()
				// Send 'HTTP/1.1 100 Continue' response.
				_, err = zw.WriteBinary(bytestr.StrResponseContinue)
				if err != nil {
					return
				}
				err = zw.Flush()
				if err != nil {
					return
				}

				// Read body.
				if zr == nil {
					zr = ctx.GetReader()
				}
				if s.StreamRequestBody {
					err = req.ContinueReadBodyStream(&ctx.Request, zr, s.MaxRequestBodySize, !s.DisablePreParseMultipartForm)
				} else {
					err = req.ContinueReadBody(&ctx.Request, zr, s.MaxRequestBodySize, !s.DisablePreParseMultipartForm)
				}
				if err != nil {
					writeErrorResponse(zw, ctx, serverName, err)
					return
				}
			}
		}

		// Initial keep-alive decision from config and the request's
		// Connection header; may be forced to close further below.
		connectionClose = s.DisableKeepalive || ctx.Request.Header.ConnectionClose()
		isHTTP11 = ctx.Request.Header.IsHTTP11()

		if serverName != nil {
			ctx.Response.Header.SetServerBytes(serverName)
		}
		if s.EnableTrace {
			internalStats.Record(ctx.GetTraceInfo(), stats.ServerHandleStart, err)
			eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) {
				internalStats.Record(ti, stats.ServerHandleFinish, err)
			})
		}
		// Handle the request
		//
		// NOTE: All middlewares and business handler will be executed in this. And at this point, the request has been parsed
		// and the route has been matched.
		s.Core.ServeHTTP(cc, ctx)
		if s.EnableTrace {
			// application layer handle finished
			if last := eventsToTrigger.pop(); last != nil {
				last(ctx.GetTraceInfo(), err)
			}
		}

		// exit check: stop keeping the connection alive once the engine is
		// no longer running.
		if !s.Core.IsRunning() {
			connectionClose = true
		}

		// Responses to HEAD carry no body.
		// NOTE(review): `!ctx.IsGet()` looks redundant next to IsHead() — confirm.
		if !ctx.IsGet() && ctx.IsHead() {
			ctx.Response.SkipBody = true
		}

		// Take the hijack handler (if any); it runs only after the response
		// has been flushed.
		hijackHandler = ctx.GetHijackHandler()
		ctx.SetHijackHandler(nil)

		// HTTP/1.1 keeps alive by default; HTTP/1.0 needs an explicit
		// `Connection: keep-alive` header to stay open.
		connectionClose = connectionClose || ctx.Response.ConnectionClose()
		if connectionClose {
			ctx.Response.Header.SetCanonical(bytestr.StrConnection, bytestr.StrClose)
		} else if !isHTTP11 {
			ctx.Response.Header.SetCanonical(bytestr.StrConnection, bytestr.StrKeepAlive)
		}

		if zw == nil {
			zw = ctx.GetWriter()
		}
		if s.EnableTrace {
			internalStats.Record(ctx.GetTraceInfo(), stats.WriteStart, err)
			eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) {
				internalStats.Record(ti, stats.WriteFinish, err)
			})
		}
		if err = writeResponse(ctx, zw); err != nil {
			return
		}

		if s.EnableTrace {
			if ctx.Response.Header.ContentLength() > 0 {
				ctx.GetTraceInfo().Stats().SetSendSize(ctx.Response.Header.GetHeaderLength() + ctx.Response.Header.ContentLength())
			} else {
				ctx.GetTraceInfo().Stats().SetSendSize(0)
			}
		}

		// Release the zeroCopyReader before flush to prevent data race
		if zr != nil {
			zr.Release() //nolint:errcheck
			zr = nil
		}
		// Flush the response.
		if err = zw.Flush(); err != nil {
			return
		}
		if s.EnableTrace {
			// write finished
			if last := eventsToTrigger.pop(); last != nil {
				last(ctx.GetTraceInfo(), err)
			}
		}

		// Release request body stream
		if ctx.Request.IsBodyStream() {
			err = ext.ReleaseBodyStream(ctx.RequestBodyStream())
			if err != nil {
				return
			}
		}

		// Hijacked connection: hand it to the hijack handler and leave the
		// HTTP loop.
		if hijackHandler != nil {
			// Hijacked conn process the timeout by itself
			err = ctx.GetConn().SetReadTimeout(0)
			if err != nil {
				return
			}

			// Hijack and block the connection until the hijackHandler return
			s.HijackConnHandle(ctx.GetConn(), hijackHandler)
			err = errHijacked
			return
		}

		// general case
		if s.EnableTrace {
			traceCtl.DoFinish(cc, ctx, err)
		}

		if connectionClose {
			return errShortConnection
		}
		// Back to network layer to trigger.
		// For now, only netpoll network mode has this feature.
		if s.IdleTimeout == 0 {
			return
		}

		// Keep the connection and prepare ctx for its next request.
		ctx.ResetWithoutConn()
	}
}
相关推荐
Code哈哈笑1 小时前
【基于Spring Boot 的图书购买系统】深度讲解 用户注册的前后端交互,Mapper操作MySQL数据库进行用户持久化
数据库·spring boot·后端·mysql·mybatis·交互
Javatutouhouduan1 小时前
线上问题排查:JVM OOM问题如何排查和解决
java·jvm·数据库·后端·程序员·架构师·oom
多多*2 小时前
Spring之Bean的初始化 Bean的生命周期 全站式解析
java·开发语言·前端·数据库·后端·spring·servlet
Villiam_AY3 小时前
Go 后端中双 token 的实现模板
开发语言·后端·golang
拾贰_C7 小时前
【SpringBoot】MyBatisPlus(MP | 分页查询操作
java·spring boot·后端·spring·maven·apache·intellij-idea
就叫飞六吧12 小时前
Spring Security 集成指南:避免 CORS 跨域问题
java·后端·spring
冼紫菜13 小时前
[特殊字符]CentOS 7.6 安装 JDK 11(适配国内服务器环境)
java·linux·服务器·后端·centos
秋野酱15 小时前
Spring Boot 项目的计算机专业论文参考文献
java·spring boot·后端
香饽饽~、15 小时前
【第二篇】 初步解析Spring Boot
java·spring boot·后端