概要
实现一个HTTP服务端框架说简单也简单,说不简单也不简单:Go的gin框架为什么那样设计API?hertz又在gin的基础上做了哪些改进?这篇文章带你入门。
分析
Hello world
go
package main
import (
"context"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/cloudwego/hertz/pkg/protocol/consts"
)
// main starts a minimal Hertz server with one top-level route and a /v1
// route group.
func main() {
	// server.Default() creates a Hertz with the recovery middleware installed;
	// server.New() creates a bare one with no middleware.
	h := server.Default()

	// respond builds a handler that always answers 200 with a fixed body.
	respond := func(body string) app.HandlerFunc {
		return func(ctx context.Context, c *app.RequestContext) {
			c.String(consts.StatusOK, body)
		}
	}

	h.GET("/hello", respond("Hello hertz!"))

	// Routes registered on v1 are served under the "/v1" prefix.
	v1 := h.Group("/v1")
	v1.GET("/get", respond("get"))
	v1.POST("/post", respond("post"))

	// Start serving (presumably blocks until the server shuts down).
	h.Spin()
}
API设计
引擎和路由
- route.Engine 是server的核心,即引擎:Hertz 内嵌了 *route.Engine,RouterGroup 等结构也持有指向它的引用,因此它会作为一个成员被到处引用
- route.RouterGroup 是路由的核心,包含了basePath和已注册的Handlers。为什么会有已注册的handlers?因为RouterGroup支持middleware,而middleware的本质就是handler:通过Use注册的middleware保存在group.Handlers中,注册路由时再经combineHandlers与路由自身的handler合并成一条handler链
go
// Default creates a Hertz server with the recovery middleware installed
// (server.New() creates one without it). Only the signature is shown in this
// excerpt; the body is elided.
func Default(opts ...config.Option) *Hertz
// Hertz is the top-level server type. It embeds *route.Engine, so the
// engine's exported fields and methods (including the RouterGroup methods
// such as Group and Use, via Engine's embedded RouterGroup) are available
// directly on *Hertz.
type Hertz struct {
*route.Engine
// signalWaiter is called with an error channel; presumably it also waits
// for OS shutdown signals — its implementation is not shown here.
signalWaiter func(err chan error) error
}
// Engine is the core of a Hertz server. It owns the route trees (one router
// per HTTP method, created lazily by addRoute), the protocol servers, the
// network transport, the RequestContext pool and the lifecycle hooks.
type Engine struct {
noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used
// engine name
Name string
serverName atomic.Value
// Options for route and protocol server
options *config.Options
// route
// The embedded RouterGroup is the engine's root group, which is why
// GET/Group/Use can be called directly on the engine.
RouterGroup
// trees holds one router per HTTP method (see addRoute).
trees MethodTrees
// maxParams presumably tracks the largest number of path params in any
// registered route (used to size Params) — confirm.
maxParams uint16
// Handler chains used when no route / no method matches; the split between
// the all* and plain variants is not visible in this excerpt.
allNoMethod app.HandlersChain
allNoRoute app.HandlersChain
noRoute app.HandlersChain
noMethod app.HandlersChain
// For render HTML
delims render.Delims
funcMap template.FuncMap
htmlRender render.HTMLRender
// NoHijackConnPool controls whether a pool is used to acquire/release hijacked connections (hijackConn).
// If it is difficult to guarantee that hijackConn will not be closed repeatedly, set it to true.
NoHijackConnPool bool
hijackConnPool sync.Pool
// KeepHijackedConns is an opt-in disable of connection
// close by hertz after connections' HijackHandler returns.
// This allows to save goroutines, e.g. when hertz used to upgrade
// http connections to WS and connection goes to another handler,
// which will close it when needed.
KeepHijackedConns bool
// underlying transport
transport network.Transporter
// trace
tracerCtl tracer.Controller
enableTrace bool
// protocol layer management
// protocolServers maps a protocol name (e.g. suite.HTTP1, suite.HTTP2) to
// the server that handles connections speaking it; see Engine.Serve.
protocolSuite *suite.Config
protocolServers map[string]protocol.Server
protocolStreamServers map[string]protocol.StreamServer
// RequestContext pool
ctxPool sync.Pool
// Function to handle panics recovered from http handlers.
// It should be used to generate an error page and return the http error code
// 500 (Internal Server Error).
// The handler can be used to keep your server from crashing because of
// unrecovered panics.
PanicHandler app.HandlerFunc
// ContinueHandler is called after receiving the Expect 100 Continue Header
//
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.1.1
// Using ContinueHandler a server can make decisioning on whether or not
// to read a potentially large request body based on the headers
//
// The default is to automatically read request bodies of Expect 100 Continue requests
// like they are normal requests
ContinueHandler func(header *protocol.RequestHeader) bool
// Indicates the engine status (Init/Running/Shutdown/Closed).
status uint32
// Hook functions get triggered sequentially when engine start
OnRun []CtxErrCallback
// Hook functions get triggered simultaneously when engine shutdown
OnShutdown []CtxCallback
// Custom Functions
clientIPFunc app.ClientIP
formValueFunc app.FormValueFunc
}
// Group creates a child RouterGroup for relativePath. The handlers are
// presumably used as middleware for every route registered on the child, and
// the child's basePath is presumably the parent's joined with relativePath.
// Only the signature is shown in this excerpt; the body is elided.
func (group *RouterGroup) Group(relativePath string, handlers ...app.HandlerFunc) *RouterGroup
// RouterGroup groups routes under a common path prefix and middleware chain.
type RouterGroup struct {
// Handlers is the group's middleware chain; Use appends to it and handle
// merges it into each route's chain via combineHandlers.
Handlers app.HandlersChain
// basePath is the group's path prefix (presumably joined with each route's
// relative path by calculateAbsolutePath).
basePath string
// engine is the owning Engine; handle registers routes through it.
engine *Engine
// root presumably marks the engine's top-level group — confirm.
root bool
}
请求上下文
- hertz的HandlerFunc接收两个上下文:第一个是context.Context,即Go标准库的context,用于在调用Go生态中的其他库(例如Gorm)时传递;
- 第二个是RequestContext,相当于gin的上下文,是Request级别的上下文(同一连接上的多个请求会复用同一个RequestContext,每个请求处理完后重置),封装了处理请求过程中需要的所有数据和方法。
go
// HandlersChain is an ordered list of handlers: the group middleware and the
// route's own handlers for one route live in a single chain.
type HandlersChain []HandlerFunc
// HandlerFunc is the signature shared by every middleware and handler.
// c is the standard context.Context (pass it on to other Go libraries);
// ctx is the RequestContext of the request being served.
type HandlerFunc func(c context.Context, ctx *RequestContext)
// RequestContext holds the per-request state: the connection, the parsed
// request, the response being built, route params, and the handler chain
// together with the position reached in it. The server reuses one instance
// for successive requests on a connection and resets it in between.
type RequestContext struct {
conn network.Conn
Request protocol.Request
Response protocol.Response
// Errors is a list of errors attached to all the handlers/middlewares who used this context.
Errors errors.ErrorChain
// Params holds the path parameters of the matched route (presumably filled
// during route lookup) — confirm.
Params param.Params
// handlers is the chain being executed for this request; index is the
// position of the handler currently running. Both are driven by Next.
handlers HandlersChain
fullPath string
index int8
HTMLRender render.HTMLRender
// This mutex protect Keys map.
mu sync.RWMutex
// Keys is a key/value pair exclusively for the context of each request.
Keys map[string]interface{}
hijackHandler HijackHandler
finishedMu sync.Mutex
// finished means the request end.
finished chan struct{}
// traceInfo defines the trace information.
traceInfo traceinfo.TraceInfo
// enableTrace defines whether enable trace.
enableTrace bool
// clientIPFunc gets the client IP using a custom function.
clientIPFunc ClientIP
// formValueFunc gets form values using a custom function.
formValueFunc FormValueFunc
}
API实现
handle
go
// handle registers handlers for httpMethod at relativePath, resolved against
// the group's base path. The group's middleware is merged into the chain via
// combineHandlers, so it runs for this route as well. The combined chain is
// stored in the engine's route tree and group.returnObj() is returned.
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers app.HandlersChain) IRoutes {
	fullPath := group.calculateAbsolutePath(relativePath)
	chain := group.combineHandlers(handlers)
	group.engine.addRoute(httpMethod, fullPath, chain)
	return group.returnObj()
}
// addRoute stores handlers for path in the router dedicated to method,
// creating that router (with an empty root node) the first time a route is
// registered for the method.
func (engine *Engine) addRoute(method, path string, handlers app.HandlersChain) {
	r := engine.trees.get(method)
	if r == nil {
		r = &router{
			method:        method,
			root:          &node{},
			hasTsrHandler: make(map[string]bool),
		}
		engine.trees = append(engine.trees, r)
	}
	r.addRoute(path, handlers)
}
// ServeHTTP routes a parsed request to its handler chain. This is an excerpt:
// the code elided at "// ..." defines t, i, rPath, paramsPointer and unescape.
// Presumably t is the engine's per-method trees and i the index of the tree
// for the request's method — confirm against the full source.
func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) {
// ...
//
// Look the request path up in the tree for the request's method.
value := t[i].find(rPath, paramsPointer, unescape)
if value.handlers != nil {
// Route found: install the chain registered by RouterGroup.handle (group
// middleware merged with the route's own handlers), record the matched
// path, and start executing the chain. Next increments index before
// calling a handler, so this presumably relies on index having been reset
// to point before the first handler — confirm.
ctx.SetHandlers(value.handlers)
ctx.SetFullPath(value.fullPath)
ctx.Next(c)
return
}
}
middleware
go
// example
//
// MyMiddleware1 returns a middleware that only does work before the rest of
// the chain. It never calls c.Next itself; the remaining handlers still run
// afterwards because RequestContext.Next keeps looping over the chain.
func MyMiddleware1() app.HandlerFunc {
	preHandle := func() {
		fmt.Println("pre-handle")
	}
	return func(ctx context.Context, c *app.RequestContext) {
		preHandle()
	}
}
// MyMiddleware2 returns a middleware that wraps the rest of the chain. It
// prints before, explicitly runs every downstream middleware/handler with
// c.Next, and prints again once they have all returned.
func MyMiddleware2() app.HandlerFunc {
	return func(ctx context.Context, c *app.RequestContext) {
		fmt.Println("pre-handle")

		// Everything after this call runs only once the downstream handlers are done.
		c.Next(ctx)

		fmt.Println("post-handle")
	}
}
// Use appends middleware to the group's own chain (group.Handlers). Routes
// registered through handle pick the group chain up via combineHandlers.
// It returns group.returnObj().
func (group *RouterGroup) Use(middleware ...app.HandlerFunc) IRoutes {
	chain := group.Handlers
	chain = append(chain, middleware...)
	group.Handlers = chain
	return group.returnObj()
}
// Next runs the remaining handlers of ctx.handlers in order, starting right
// after the current one. The engine installs the chain (SetHandlers) before
// the first call.
//
// index lives on the RequestContext and is advanced past every handler that
// has run. A middleware may therefore call Next and then continue with
// post-processing: when the nested call returns, the outer loop finds index
// already at the end of the chain and re-runs nothing. Handling of one
// request is sequential, which is why calling Next any number of times is
// safe.
//
// NOTE(review): int8(len(...)) wraps for chains longer than 127 handlers;
// presumably route registration caps the chain length — confirm.
func (ctx *RequestContext) Next(c context.Context) {
	for ctx.index++; ctx.index < int8(len(ctx.handlers)); ctx.index++ {
		ctx.handlers[ctx.index](c, ctx)
	}
}
深入并发模式
eventLoop
hertz的网络层(transport)有两种实现:一种基于netpoll库的eventLoop模式,类似于epoll的事件驱动(见下面的ListenAndServe);另一种基于Go标准库net,每accept一个连接就启动一个goroutine处理(见下面的serve)
go
// serve is the standard-library (net) transport: it listens on t.addr and
// starts one goroutine running t.handler per accepted connection. It returns
// the listen error, or the first Accept error (after logging it).
func (t *transport) serve() (err error) {
	network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck

	// t.ln is assigned under t.lock.
	t.lock.Lock()
	if cfg := t.listenConfig; cfg != nil {
		t.ln, err = cfg.Listen(context.Background(), t.network, t.addr)
	} else {
		t.ln, err = net.Listen(t.network, t.addr)
	}
	t.lock.Unlock()
	if err != nil {
		return err
	}
	hlog.SystemLogger().Infof("HTTP server listening on address=%s", t.ln.Addr().String())

	for {
		rawConn, acceptErr := t.ln.Accept()
		if acceptErr != nil {
			hlog.SystemLogger().Errorf("Error=%s", acceptErr.Error())
			return acceptErr
		}

		connCtx := context.Background()
		if t.OnAccept != nil {
			connCtx = t.OnAccept(rawConn)
		}

		// Wrap the raw conn, adding TLS when configured.
		var conn network.Conn
		if t.tls == nil {
			conn = newConn(rawConn, t.readBufferSize)
		} else {
			conn = newTLSConn(tls.Server(rawConn, t.tls), t.readBufferSize)
		}

		if t.OnConnect != nil {
			connCtx = t.OnConnect(connCtx, conn)
		}
		go t.handler(connCtx, conn)
	}
}
// ListenAndServe is the netpoll transport: it listens on t.addr, builds a
// netpoll event loop whose per-connection callback is onReq, and serves the
// listener with it. It blocks for as long as the event loop's Serve runs,
// holding t's read lock for that whole time.
//
// Setup and serve failures are reported by panicking, as before. Every panic
// message now includes the underlying error, matching the listener case;
// previously the event-loop and serve panics discarded it.
func (t *transporter) ListenAndServe(onReq network.OnData) (err error) {
	network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck
	if t.listenConfig != nil {
		t.listener, err = t.listenConfig.Listen(context.Background(), t.network, t.addr)
	} else {
		t.listener, err = net.Listen(t.network, t.addr)
	}
	if err != nil {
		panic("create netpoll listener fail: " + err.Error())
	}

	// Initialize custom option for EventLoop
	opts := []netpoll.Option{
		netpoll.WithIdleTimeout(t.keepAliveTimeout),
		// OnPrepare runs for each new connection: apply timeouts and build the
		// connection's initial context.
		netpoll.WithOnPrepare(func(conn netpoll.Connection) context.Context {
			conn.SetReadTimeout(t.readTimeout) // nolint:errcheck
			if t.writeTimeout > 0 {
				conn.SetWriteTimeout(t.writeTimeout) // nolint:errcheck
			}
			if t.OnAccept != nil {
				return t.OnAccept(newConn(conn))
			}
			return context.Background()
		}),
	}
	if t.OnConnect != nil {
		opts = append(opts, netpoll.WithOnConnect(func(ctx context.Context, conn netpoll.Connection) context.Context {
			return t.OnConnect(ctx, newConn(conn))
		}))
	}

	// Create EventLoop
	t.Lock()
	t.eventLoop, err = netpoll.NewEventLoop(func(ctx context.Context, connection netpoll.Connection) error {
		return onReq(ctx, newConn(connection))
	}, opts...)
	t.Unlock()
	if err != nil {
		panic("create netpoll event-loop fail: " + err.Error())
	}

	// Start Server
	hlog.SystemLogger().Infof("HTTP server listening on address=%s", t.listener.Addr().String())
	t.RLock()
	err = t.eventLoop.Serve(t.listener)
	t.RUnlock()
	if err != nil {
		panic("netpoll server exit: " + err.Error())
	}
	return nil
}
engine.Serve
go
// onData is the callback handed to the network layer. It dispatches a
// connection by dynamic type: network.Conn goes to Serve (checked first) and
// network.StreamConn to ServeStream. Any other value is ignored and nil is
// returned.
func (engine *Engine) onData(c context.Context, conn interface{}) (err error) {
	if plain, ok := conn.(network.Conn); ok {
		return engine.Serve(c, plain)
	}
	if stream, ok := conn.(network.StreamConn); ok {
		return engine.ServeStream(c, stream)
	}
	return nil
}
// Serve hands one accepted connection to the protocol server that should
// process it. Selection order:
//
//  1. H2C (cleartext HTTP/2): when enabled and the client opens with the
//     HTTP/2 connection preface, the registered HTTP/2 server takes over.
//  2. ALPN: when enabled together with TLS, the protocol negotiated in the
//     TLS handshake selects a registered server.
//  3. Fallback: the HTTP/1 server.
//
// Whatever error is finally returned is passed to errProcess by the deferred
// call.
func (engine *Engine) Serve(c context.Context, conn network.Conn) (err error) {
	defer func() {
		errProcess(conn, err)
	}()

	// H2C path
	if engine.options.H2C {
		// Protocol sniffer. The Peek error is ignored; presumably a short or
		// failed read just won't match the preface and we fall through.
		buf, _ := conn.Peek(len(bytestr.StrClientPreface))
		if bytes.Equal(buf, bytestr.StrClientPreface) {
			if h2 := engine.protocolServers[suite.HTTP2]; h2 != nil {
				return h2.Serve(c, conn)
			}
			// Warn only when the client really asked for HTTP/2. Plain HTTP/1
			// connections are normal on an H2C-enabled server and must not
			// trigger this warning on every connection.
			hlog.SystemLogger().Warn("HTTP2 server is not loaded, request is going to fallback to HTTP1 server")
		}
	}

	// ALPN path
	if engine.options.ALPN && engine.options.TLS != nil {
		proto, err1 := engine.getNextProto(conn)
		if err1 != nil {
			// The client closed the connection during the handshake, so just ignore it.
			if err1 == io.EOF {
				return nil
			}
			// A plaintext HTTP request sent to the TLS port: reply with a
			// readable 400 (best effort; write/close errors are ignored) and
			// give up on the connection.
			if re, ok := err1.(tls.RecordHeaderError); ok && re.Conn != nil && utils.TLSRecordHeaderLooksLikeHTTP(re.RecordHeader) {
				io.WriteString(re.Conn, "HTTP/1.0 400 Bad Request\r\n\r\nClient sent an HTTP request to an HTTPS server.\n")
				re.Conn.Close()
				return re
			}
			return err1
		}
		if server, ok := engine.protocolServers[proto]; ok {
			return server.Serve(c, conn)
		}
		// No server registered for the negotiated protocol: fall back to HTTP/1.
	}

	// HTTP1 path
	err = engine.protocolServers[suite.HTTP1].Serve(c, conn)
	return
}
server.Serve
- 在 s.Core.ServeHTTP(cc, ctx) 处进入用户代码:此时请求已解析完毕,路由匹配以及所有middleware和业务handler都在这一步中执行
go
// Serve runs the HTTP/1.x request loop for one connection.
//
// For every request on the connection it:
//   - reads the header and body (streamed or fully buffered, depending on
//     StreamRequestBody);
//   - answers "Expect: 100-continue";
//   - runs routing, middlewares and handlers through s.Core.ServeHTTP;
//   - writes and flushes the response, then releases a streamed request body.
//
// It then loops for the next keep-alive request, hands the connection to a
// hijack handler, or returns.
//
// A single RequestContext is taken from the core's pool for the whole
// connection. The deferred cleanup releases the reader (unless the connection
// was hijacked), resets the context and puts it back in the pool. When tracing
// is enabled, it also fires any still-pending trace events and calls DoFinish.
func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
var (
zr network.Reader
zw network.Writer
serverName []byte
isHTTP11 bool
connectionClose bool
continueReadingRequest = true
hijackHandler app.HijackHandler
// HTTP1 path
// 1. Get a request context
// 2. Prepare it
// 3. Process it
// 4. Reset and recycle
ctx = s.Core.GetCtxPool().Get().(*app.RequestContext)
traceCtl = s.Core.GetTracer()
eventsToTrigger *eventStack
// Use a new variable to hold the standard context to avoid modifying the initial
// context.
cc = c
)
// Trace events live on a pooled stack: each phase pushes its "finish"
// callback and pops/fires it when the phase completes.
if s.EnableTrace {
eventsToTrigger = s.eventStackPool.Get().(*eventStack)
}
defer func() {
if s.EnableTrace {
// Idle timeouts and hijacks are normal exits, not trace errors.
if err != nil && !errors.Is(err, errs.ErrIdleTimeout) && !errors.Is(err, errs.ErrHijacked) {
ctx.GetTraceInfo().Stats().SetError(err)
}
// in case of error, we need to trigger all events
if eventsToTrigger != nil {
for last := eventsToTrigger.pop(); last != nil; last = eventsToTrigger.pop() {
last(ctx.GetTraceInfo(), err)
}
s.eventStackPool.Put(eventsToTrigger)
}
traceCtl.DoFinish(cc, ctx, err)
}
// Hijack may release and close the connection already
if zr != nil && !errors.Is(err, errs.ErrHijacked) {
zr.Release() //nolint:errcheck
zr = nil
}
ctx.Reset()
s.Core.GetCtxPool().Put(ctx)
}()
// Per-connection setup, shared by every request on this connection.
ctx.HTMLRender = s.HTMLRender
ctx.SetConn(conn)
ctx.Request.SetIsTLS(s.TLS != nil)
ctx.SetEnableTrace(s.EnableTrace)
if !s.NoDefaultServerHeader {
serverName = s.ServerName
}
connRequestNum := uint64(0)
// One iteration per request on this connection.
for {
connRequestNum++
if zr == nil {
zr = ctx.GetReader()
}
// If this is a keep-alive connection we want to try and read the first bytes
// within the idle time.
if connRequestNum > 1 {
ctx.GetConn().SetReadTimeout(s.IdleTimeout) //nolint:errcheck
_, err = zr.Peek(4)
// This is not the first request, and we haven't read a single byte
// of a new request yet. This means it's just a keep-alive connection
// closing down either because the remote closed it or because of
// a read timeout on our side. Either way just close the connection
// and don't return any error response.
if err != nil {
// NOTE(review): presumably errIdleTimeout is errs.ErrIdleTimeout, so
// the deferred cleanup does not record it as a trace error — confirm.
err = errIdleTimeout
return
}
// Reset the real read timeout for the coming request
ctx.GetConn().SetReadTimeout(s.ReadTimeout) //nolint:errcheck
}
if s.EnableTrace {
cc = traceCtl.DoStart(c, ctx)
internalStats.Record(ctx.GetTraceInfo(), stats.ReadHeaderStart, err)
eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) {
internalStats.Record(ti, stats.ReadHeaderFinish, err)
})
}
// Read Headers
if err = req.ReadHeader(&ctx.Request.Header, zr); err == nil {
if s.EnableTrace {
// read header finished
if last := eventsToTrigger.pop(); last != nil {
last(ctx.GetTraceInfo(), err)
}
internalStats.Record(ctx.GetTraceInfo(), stats.ReadBodyStart, err)
eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) {
internalStats.Record(ti, stats.ReadBodyFinish, err)
})
}
// Read body
if s.StreamRequestBody {
err = req.ReadBodyStream(&ctx.Request, zr, s.MaxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
} else {
err = req.ReadLimitBody(&ctx.Request, zr, s.MaxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
}
}
if s.EnableTrace {
// Received size = raw header bytes + declared Content-Length (0 when
// the length is negative, i.e. not a plain fixed-length body).
if ctx.Request.Header.ContentLength() >= 0 {
ctx.GetTraceInfo().Stats().SetRecvSize(len(ctx.Request.Header.RawHeaders()) + ctx.Request.Header.ContentLength())
} else {
ctx.GetTraceInfo().Stats().SetRecvSize(0)
}
// read body finished
if last := eventsToTrigger.pop(); last != nil {
last(ctx.GetTraceInfo(), err)
}
}
if err != nil {
// Nothing was read at all: close quietly without an error response.
if errors.Is(err, errs.ErrNothingRead) {
return nil
}
if err == io.EOF {
return errUnexpectedEOF
}
// NOTE(review): zw may still be nil here (it is only set on the
// 100-continue path or after the handler runs); presumably
// writeErrorResponse obtains a writer itself — confirm.
writeErrorResponse(zw, ctx, serverName, err)
return
}
// 'Expect: 100-continue' request handling.
// See https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 for details.
if ctx.Request.MayContinue() {
// Allow the ability to deny reading the incoming request body
if s.ContinueHandler != nil {
if continueReadingRequest = s.ContinueHandler(&ctx.Request.Header); !continueReadingRequest {
ctx.SetStatusCode(consts.StatusExpectationFailed)
}
}
if continueReadingRequest {
zw = ctx.GetWriter()
// Send 'HTTP/1.1 100 Continue' response.
_, err = zw.WriteBinary(bytestr.StrResponseContinue)
if err != nil {
return
}
err = zw.Flush()
if err != nil {
return
}
// Read body.
if zr == nil {
zr = ctx.GetReader()
}
if s.StreamRequestBody {
err = req.ContinueReadBodyStream(&ctx.Request, zr, s.MaxRequestBodySize, !s.DisablePreParseMultipartForm)
} else {
err = req.ContinueReadBody(&ctx.Request, zr, s.MaxRequestBodySize, !s.DisablePreParseMultipartForm)
}
if err != nil {
writeErrorResponse(zw, ctx, serverName, err)
return
}
}
}
connectionClose = s.DisableKeepalive || ctx.Request.Header.ConnectionClose()
isHTTP11 = ctx.Request.Header.IsHTTP11()
if serverName != nil {
ctx.Response.Header.SetServerBytes(serverName)
}
if s.EnableTrace {
internalStats.Record(ctx.GetTraceInfo(), stats.ServerHandleStart, err)
eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) {
internalStats.Record(ti, stats.ServerHandleFinish, err)
})
}
// Handle the request
//
// NOTE: All middlewares and business handler will be executed in this. At this point the request has been
// parsed; route matching itself happens inside ServeHTTP.
s.Core.ServeHTTP(cc, ctx)
if s.EnableTrace {
// application layer handle finished
if last := eventsToTrigger.pop(); last != nil {
last(ctx.GetTraceInfo(), err)
}
}
// exit check
// Once the core is no longer running, finish this request and close.
if !s.Core.IsRunning() {
connectionClose = true
}
// HEAD responses must not carry a body. NOTE(review): IsHead() presumably
// already implies !IsGet(), making the first operand redundant.
if !ctx.IsGet() && ctx.IsHead() {
ctx.Response.SkipBody = true
}
// Take the hijack handler (if any) before writing; it runs after the flush.
hijackHandler = ctx.GetHijackHandler()
ctx.SetHijackHandler(nil)
connectionClose = connectionClose || ctx.Response.ConnectionClose()
// HTTP/1.0 needs an explicit "Connection: keep-alive" to keep the
// connection open; any closing connection announces "Connection: close".
if connectionClose {
ctx.Response.Header.SetCanonical(bytestr.StrConnection, bytestr.StrClose)
} else if !isHTTP11 {
ctx.Response.Header.SetCanonical(bytestr.StrConnection, bytestr.StrKeepAlive)
}
if zw == nil {
zw = ctx.GetWriter()
}
if s.EnableTrace {
internalStats.Record(ctx.GetTraceInfo(), stats.WriteStart, err)
eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) {
internalStats.Record(ti, stats.WriteFinish, err)
})
}
if err = writeResponse(ctx, zw); err != nil {
return
}
if s.EnableTrace {
if ctx.Response.Header.ContentLength() > 0 {
ctx.GetTraceInfo().Stats().SetSendSize(ctx.Response.Header.GetHeaderLength() + ctx.Response.Header.ContentLength())
} else {
ctx.GetTraceInfo().Stats().SetSendSize(0)
}
}
// Release the zeroCopyReader before flush to prevent data race
if zr != nil {
zr.Release() //nolint:errcheck
zr = nil
}
// Flush the response.
if err = zw.Flush(); err != nil {
return
}
if s.EnableTrace {
// write finished
if last := eventsToTrigger.pop(); last != nil {
last(ctx.GetTraceInfo(), err)
}
}
// Release request body stream
if ctx.Request.IsBodyStream() {
err = ext.ReleaseBodyStream(ctx.RequestBodyStream())
if err != nil {
return
}
}
if hijackHandler != nil {
// Hijacked conn process the timeout by itself
err = ctx.GetConn().SetReadTimeout(0)
if err != nil {
return
}
// Hijack and block the connection until the hijackHandler return
s.HijackConnHandle(ctx.GetConn(), hijackHandler)
err = errHijacked
return
}
// general case
// NOTE(review): the deferred cleanup also calls traceCtl.DoFinish, so on
// the returns below (and on a later idle timeout) DoFinish may run twice
// for the same request — confirm the tracer controller tolerates that.
if s.EnableTrace {
traceCtl.DoFinish(cc, ctx, err)
}
if connectionClose {
return errShortConnection
}
// Back to network layer to trigger.
// For now, only netpoll network mode has this feature.
// With IdleTimeout == 0 we return instead of blocking on the next Peek;
// presumably the event loop calls Serve again when more data arrives.
if s.IdleTimeout == 0 {
return
}
// Keep the connection bound to ctx and clear the rest for the next request.
ctx.ResetWithoutConn()
}
}