fastAPI教程:进阶操作

FastAPI

七、进阶操作

中间件与CORS

文档:https://fastapi.tiangolo.com/zh/advanced/middleware/

你可以向 FastAPI 应用添加中间件.

"中间件"是一个装饰器函数(也叫hook function,钩子函数 ),它会在fastAPI接受到客户端的每个请求操作之前响应操作以后自动执行。图示:

FastAPI中间件的写法

要创建中间件你可以在函数的顶部使用装饰器 @app.middleware("http").

中间件参数接收如下参数:

request.

  • 一个函数call_next,它将接收request,作为参数.
    • 这个函数将 request 传递给相应的 路径操作.
    • 然后它将返回由相应的路径操作 生成的 response.
  • 然后你可以在返回 response 前进一步修改它.

代码:

python 复制代码
import requests
import uvicorn
from fastapi import FastAPI, Request, Response


app = FastAPI()

# 创建一个中间件
@app.middleware('http') # 参数表示中间件的类型,如果值为'http',则表示当前中间件会在请求之前与响应以后自动执行
async def middleware1(request: Request, call_next):
    """
    中间件固定会有2个参数
    :param request: 本次客户端的http请求对象
    :param call_next: 本次客户端请求得人url地址绑定的接口函数
    :return:
    """
    # 执行请求之前的代码[]
    print("middleware1: 视图执行之前的自动执行的代码,例如:权限判断,读取缓存数据,访问日志,登录状态的判断等等")
    response: Response = await call_next(request) # api函数执行
    print('middleware1: 视图执行以后的自动执行的代码,例如:把数据缓存到数据库,操作日志,实现跨域共享')
    # 中间件必须有返回结果,返回的结果必须是Response对象
    return response # response将来返回给客户端

# 一个项目中可以存在0到多个中间件
@app.middleware('http')
async def middleware2(request: Request, call_next):
    """
    中间件固定会有2个参数
    :param request: 本次客户端的http请求对象
    :param call_next: 本次客户端请求得人url地址绑定的接口函数[当存在多个中间件函数时,call_next表示下一个中间件函数]
    """
    print("middleware2: 视图执行之前的自动执行的代码,例如:权限判断,读取缓存数据,访问日志,登录状态的判断等等")
    response: Response = await call_next(request)  # 进入下一个中间件
    print('middleware2: 视图执行以后的自动执行的代码,例如:把数据缓存到数据库,操作日志,实现跨域共享')
    return response  # response将来返回给客户端

# 测试接口
@app.get('/web')
async def web():
    print("视图执行了....")
    return {'title': '来自服务端的数据'}

@app.get('/music')
async def music(title: str):
    response = requests.get(f'http://msearchcdn.kugou.com/api/v3/search/song?keyword={title}')
    return response.json()

if __name__ == '__main__':
    uvicorn.run('main:app', host='0.0.0.0', port=8000, reload=True)
使用自定义中间实现跨域

main.py,代码:

python 复制代码
import requests
import uvicorn
from fastapi import FastAPI, Request, Response


app = FastAPI()

# FastAPI实现CORS跨域
@app.middleware('http')
async def cors_middleware(request: Request, call_next) -> Response:
    """
    CORS跨域支持中间件
    :param request: 本次客户端的HTTP请求对象
    :param call_next: 下一个调用的中间件,如果没有中间件,则调用API视图函数
    :return Response: HTTP响应对象
    """

    response: Response = await call_next(request)

    # 设置CORS响应保温
    response.headers['Access-Control-Allow-Origin'] = '*' # 允许任意客户端访问
    # response.headers['Access-Control-Allow-Origin'] = 'http://localhost:3000' # 仅允许指定客户端访问

    return response

# 测试接口
@app.get('/web')
async def web():
    print("视图执行了....")
    return {'title': '来自服务端的数据'}

@app.get('/music')
async def music(title: str):
    response = requests.get(f'http://msearchcdn.kugou.com/api/v3/search/song?keyword={title}')
    return response.json()

if __name__ == '__main__':
    uvicorn.run('main:app', host='0.0.0.0', port=8000, reload=True)
使用FastAPI提供的CORS中间件实现跨域

FastAPI提供了一些内置中间件,我们可以根据官方文档的说明,直接配置使用,就不需要自己编写代码。main.py,代码:

python 复制代码
import requests
import uvicorn
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# 注册中间件
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_methods=['*'],
    allow_headers=['*'],
)

# 测试接口
@app.get('/web')
async def web():
    print("视图执行了....")
    return {'title': '来自服务端的数据'}

@app.get('/music')
async def music(title: str):
    response = requests.get(f'http://msearchcdn.kugou.com/api/v3/search/song?keyword={title}')
    return response.json()

if __name__ == '__main__':
    uvicorn.run('main:app', host='0.0.0.0', port=8000, reload=True)

依赖注入

依赖注入,是指在编程中,为了保证功能的使用,先导入或声明所需以来,如子函数、数据库连接等。但又不同于装饰器。

fastapi编写接口通常用函数式编程,没有像django似的视图编程的优势,所以就没有封装、继承、多态等,于是就依靠依赖注入来实现

优势是提高代码复用率:

  • 共享数据库连接
  • 增强安全、认证和角色管理
  • 响应数据注入,可以在原来的响应数据基础上再做出更改,如抽出公用函数,在公用函数中对数据处理
注入依赖

函数作为依赖:定义公共函数(相当于创建依赖),如下,意思是返回的数据是字典形式,依赖于公共函数。可以在async def中调用def依赖,也可以在def中导入async def依赖

python 复制代码
async def common_parameters(q: Optional[str] = None, page: int = 1, limit: int = 100):
    return {"q": q, "page": page, "limit": limit}

@app05.get("/dependency01")
async def dependency01(commons: dict = Depends(common_parameters)):
    return commons

类作为依赖,有3种写法。下面实现的是更新数据

python 复制代码
fake_items_db = [{"item_name": "Foo"}, {"item_name": "Bar"}, {"item_name": "Baz"}]


class CommonQueryParams:
    def __init__(self, q: Optional[str] = None, page: int = 1, limit: int = 100):
        self.q = q
        self.page = page
        self.limit = limit


@app.get("/classes_as_dependencies")
# async def classes_as_dependencies(commons: CommonQueryParams = Depends(CommonQueryParams)):
# async def classes_as_dependencies(commons: CommonQueryParams = Depends()):
async def classes_as_dependencies(commons=Depends(CommonQueryParams)):
    response = {}
    if commons.q:
        response.update({"q": commons.q})
    items = fake_items_db[commons.page: commons.page + commons.limit]
    response.update({"items": items})
    return response

JWT

基本概念

在用户登录后,我们需要在不同请求之间记录用户的登录状态,常用方式一般有三种:Cookie,Session和Token。

这里我们使用第三种Token令牌方式来实现认证鉴权,采用Json Web Token认证机制(简称:jwt)。

Json web token (JWT), 是为了在网络应用环境间传递声明而执行的一种基于JSON的开放标准((RFC 7519).该token被设计为紧凑且安全的,特别适用于分布式站点的单点登录(SSO)场景。JWT的声明一般被用来在身份提供者和服务提供者间传递被认证的用户身份信息,以便于从资源服务器获取资源,也可以增加一些额外的其它业务逻辑所必须的声明信息,该token也可直接被用于认证,也可被加密。

jwt官网:https://jwt.io/

jwt规范:https://datatracker.ietf.org/doc/html/draft-ietf-oauth-json-web-token

JWT的构成

JWT就一段由三段信息构成的字符串,将这三段信息文本用.拼接一起就构成的。就像这样:

eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ

第一部分我们称它为头部(header),第二部分我们称其为载荷(payload, 类似于飞机上承载的物品),第三部分是签证(signature).

python 复制代码
jwtToken = f"{header}.{payload}.{signature}"

jwt的头部承载两部分信息:

  • typ: type的缩写,声明token的类型,值一般可以是 JWT Bear
  • alg: algorithm的缩写,声明token的第三方部分(签证)的加密算法,通常直接使用 HMAC SHA256

完整的头部就像下面这样的JSON:

json 复制代码
{
  "typ": "Bear",
  "alg": "HS256"
}

然后将头部进行base64.b64urlencode()编码,构成了第一部分头部。

eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9

python代码实现过程:

python 复制代码
import base64,json
data = {
  'typ': 'JWT',
  'alg': 'HS256'
}

header = base64.b64encode(json.dumps(data).encode()).decode()

# 各个语言中都有base64加密解密的功能,所以我们jwt为了安全,需要配合第三段加密
payload

载荷(payload)就是jwt存放有效信息的部分。这个名字像是特指飞机上承载的货仓,这些有效信息包含三种不同类型的数据:

  • 标准声明
  • 公共声明
  • 私有声明

标准声明 (官方提出建议但不强制使用) :

  • iss: jwt签发者

  • sub: jwt所面向的用户

  • aud: 接收jwt的一方

  • exp: jwt的过期时间,这个过期时间必须要大于签发时间

  • nbf: 定义在什么时间之后,该jwt可以正常使用。

  • iat: jwt的签发时间

  • jti: jwt的唯一身份标识,主要用来作为一次性token,往往采用UUID字符串或随机字符串来充当。

    以上是JWT规范中提供的7个官方字段,开发者根据自己的业务进行选用。

公共声明:公共的声明可以添加任何的信息,一般添加用户的相关信息或其他业务需要的必要信息。但不建议添加敏感信息,因为该部分在客户端直接可以查看。

私有声明:私有声明是服务端和客户端所共同定义的声明,一般使用类似ace算法进行非对称加密和解密的,意味着该部分信息可以归类为明文信息。

定义一个payload,json格式的数据:

json 复制代码
{
  "sub": "1234567890",  // 时间戳
  "exp": "3422335555",  // 时间戳
  "name": "John Doe",
  "admin": true,
}

然后将其进行base64.b64encode() 编码,得到JWT的第二部分。

eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9

python代码实现过程:

python 复制代码
import base64,json
data = {
  "sub": "1234567890",
  "exp": "3422335555",
  "name": "John Doe",
  "admin": True,
  "info": "232323ssdgerere3335dssss"
}

payload = base64.b64encode(json.dumps(data).encode()).decode()

# 各个语言中都有base64编码和解码,所以我们jwt为了安全,需要配合第三段签证来进行加密保证jwt不会被人篡改。
signature

JWT的第三部分叫签证信息,主要用于辨真伪,防篡改。签证信息使用加密算法生成,公式:

python 复制代码
secret_key = "秘钥" # 只保存服务端,不能外泄
signature = SHA256(base64.b64encode(header) + "." +base64.b64encode(payload),secret_key) 

python代码实现过程:

python 复制代码
import sys, json, base64, time, hmac

if __name__ == '__main__':
    # 头部
    data = {'typ': 'JWT', 'alg': 'HS256'}
    header = base64.b64encode(json.dumps(data).encode()).decode()

    # 载荷
    data = {"sub": "1234567890", "exp": "3422335555", "name": "John Doe", "admin": True,
            "info": "232323ssdgerere3335dssss"}
    payload = base64.b64encode(json.dumps(data).encode()).decode()

    # 签证,生成jwt token 提供给客户端
    # from django.conf import settings
    # secret = settings.SECRET_KEY
    secret = 'django-insecure-(_+qtd5edmhm%2rdsg+qc3wi@s_k*3cbk-+k2gpg3@qx)z6r+p'
    sign = base64.b64encode(f"{header}.{payload}".encode())
    signature = base64.b64encode(hmac.digest(secret.encode(), sign, digest="sha256")).decode()
    jwt = f"f{header}.{payload}.{signature}"
    print(jwt)
    # 将这三部分用`.`连接成一个完整的字符串,构成了最终的jwt:
    # feyJ0eXAiOiAiSldUIiwgImFsZyI6ICJIUzI1NiJ9.eyJzdWIiOiAiMTIzNDU2Nzg5MCIsICJleHAiOiAiMzQyMjMzNTU1NSIsICJuYW1lIjogIkpvaG4gRG9lIiwgImFkbWluIjogdHJ1ZSwgImluZm8iOiAiMjMyMzIzc3NkZ2VyZXJlMzMzNWRzc3NzIn0=.3OnGXAx5wWA5AjxyewICSn5Hirz1tXfzxOc4tns4elM=

注意:

secret是保存在服务器端的,jwt的签发生成代码也是在服务器端的,secret就是用来进行jwt的签发和jwt的验证,

所以它应该是服务端的私钥,在任何场景下都不应该流露出去,而且应该在每次服务端更新维护后及时更新。

一旦第三方得知这个secret, 那就意味着他们绕过服务端伪造jwt了。

优缺点

优点:

  1. 实现分布式集群的单点登陆非常方便
  2. Token实际保存在客户端,所以我们可以分担服务端的存储压力。
  3. jwt不仅可用于认证,还可用于信息交换。善用JWT有助于减少服务器请求数据库的次数,jwt的构成非常简单,字节占用很小,所以它是非常便于传输的。

缺点:

  1. jwt保存在客户端,我们服务端只认jwt,不识别客户端。
    • 解决方案1. 设置客户端唯一登陆
    • 解决方案2. 绑定客户端的标记符和IP,机器码
  2. jwt可以设置过期时间,但是因为jwt保存在了客户端,所以对于过期时间不好调整,一旦签发不可控。
    • 解决方案1:设置短有效期,例如:30分、15分钟、10分钟之类。
    • 解决方案2:生成jwt的时候,提供给客户端之前先在内存(一般使用内存数据库redis,而不是变量)备份jwt,每次用户访问需要登录身份数据时,把token去内存中验证一样。
使用jwt实现认证流程

所谓的认证流程,实际上就是用户登录的过程。

python代码实现认证流程,代码:

python 复制代码
import sys, json, base64, time, hmac

if __name__ == '__main__':
    # 模拟客户端提交的token
    client_token = "feyJ0eXAiOiAiSldUIiwgImFsZyI6ICJIUzI1NiJ9.eyJzdWIiOiAiMTIzNDU2Nzg5MCIsICJleHAiOiAiMzQyMjMzNTU1NSIsICJuYW1lIjogIkpvaG4gRG9lIiwgImFkbWluIjogdHJ1ZSwgImluZm8iOiAiMjMyMzIzc3NkZ2VyZXJlMzMzNWRzc3NzIn0=.3OnGXAx5wWA5AjxyewICSn5Hirz1tXfzxOc4tns4elM="
    # 把客户端提交的token分割成三段:头部、载荷、签证
    header, payload, signature = client_token.split(".")

    # 验证是否过期了,先基于base64,接着使用json解码,提供载荷中的过期时间进行比较
    payload_data = json.loads(base64.b64decode(payload.encode()))

    exp = int(payload_data.get("exp", 0))
    now = int(time.time())
    if exp < now:
        print("token已经过期!")
        sys.exit()  # 退出程序,实际开发中,应该时响应代码给客户端,不会继续往下执行了。

    secret = "django-insecure-(_+qtd5edmhm%2rdsg+qc3wi@s_k*3cbk-+k2gpg3@qx)z6r+p"
    # 与生成token时一样的秘钥和数据,再次生成一个签证
    new_signature = hmac.digest(secret.encode(), sign, digest="sha256")
    # 拿客户端提交上面的token中的签证进行base64解码得到原始的签证
    signature = base64.b64decode(signature)
    # 通过compare_digest比较两者是否吻合
    if hmac.compare_digest(signature, new_signature):
        print("认证通过")
    else:
        print("认证失败,token被串改!")
基本使用

开发中除非找不到,否则我们可以直接使用第三方已经开源的模块来完成相关的功能。大部分要求使用第三方模块是必须star数量>150。

依赖库安装
bash 复制代码
# python-jose 用于生成和检验JWT令牌
pip install jwt
pip install python-jose
JWT基本使用

生成一个随机的密钥,用于对JWT令牌进行签名加密的。终端执行命令如下:

python 复制代码
openssl rand -hex 32
# eac77e4e9a9a767b792779132e84ea37b1f4c31bec56714607f617a3fbdfbd53

创建JWT需要的相关配置项,settings.py,代码:

python 复制代码
# 加密数据所使用的秘钥[盐值]
SECRET_KEY = "eac77e4e9a9a767b792779132e84ea37b1f4c31bec56714607f617a3fbdfbd53"
# 设定JWT令牌签名算法
ALGORITHM = "HS256"
# 设置令牌过期时间变量(单位:秒)
ACCESS_TOKEN_EXPIRE_MINUTES = 30 * 60

创建JWT工具类,utils.py,代码:

python 复制代码
from typing import Optional
from datetime import timedelta, datetime
import settings
from jose import jwt
import uuid


class JWT(object):
    JWTError = jwt.JWTError
    ExpiredSignatureError = jwt.ExpiredSignatureError
    def create_token(self, data: dict, expire_time: Optional[timedelta] = None):
        """
        生成Token
        :param data: 需要进行JWT令牌加密的用户信息(解密的时候会用到)
        :param expire_time: 令牌有效期,单位:秒
        :return: token
        """

        now_time = datetime.utcnow()
        if expire_time:
            expire = now_time + timedelta(seconds=expire_time)
        else:
            expire = now_time + timedelta(seconds=settings.ACCESS_TOKEN_EXPIRE_TIME)

        payload = {
            "exp": expire,
            "iat": now_time,
            "nbf": now_time,
            "jti": str(uuid.uuid4())
        }
        payload.update(data)

        token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
        return token

    def verify_token(self, token: str) -> dict:
        """
        验证token
        :param token: 客户端发送过来的token
        :return: 返回用户信息
        """
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=settings.ALGORITHM)
        return payload


if __name__ == '__main__':
    """密码加密与验证"""
    # hashing = Hashing()
    # hashed_pwd = hashing.hash("123456")
    # print(hashed_pwd) # 加密后要保存到数据库中的哈希串
    # # 把原密码和加密后的哈希串进行配对,验证通过则返回结果为True
    # ret = hashing.verify("123456", hashed_pwd)
    # print(ret)


    """JWT"""
    jwt_tool = JWT()

    try:
        # 正确使用
        token = jwt_tool.create_token({'username': 'admin', 'sex': True})
        print(token)
        data = jwt_tool.verify_token(token)
        print(data)

        # # 因为Token过期导致验证失败
        # token = jwt_tool.create_token({'username': 'admin', 'sex': True}, -300)
        # print(token)
        # data = jwt_tool.verify_token(token)
        # print(data)

        # # 因为Token被串改导致验证失败
        # token = jwt_tool.create_token({'username': 'admin', 'sex': True})
        # print(token)
        # data = jwt_tool.verify_token(token[:-1])
        # print(data)

    except (jwt_tool.ExpiredSignatureError, jwt_tool.JWTError) as e:
        print("验证失败,", e)

基于自定义中间件创建JWT中间件实现用户身份认证

python 复制代码
async def jwt_middleware(request: Request, call_next):
    try:
        token: str = request.headers["Authorization"].split()[1]
        payload = jwt_took.verify(token)
        id: str = payload.get("id")
        # 查询数据库,是否存在当前用户
        user = await models.User.filter(id=id).first(id)
        if user is None:
            raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid authentication credentials")
        request.user = user
    except (jwt_tool.ExpiredSignatureError, jwt_tool.JWTError):
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid authentication credentials")
    response = await call_next(request)
    return response

注册JWT中间件,代码:

python 复制代码
app.add_middleware(jwt_middleware)

Admin

fastapi-admin 是一个基于 FastAPI 和 TortoiseORM 以及 tabler UI框架的后台管理面板, 灵感来自Django admin。

官方文档:https://fastapi-admin-docs.long2ice.io/zh/

终端下执行如下命令安装:

bash 复制代码
pip install fastapi-admin

Admin插件添加到FastAPI

首先,你需要挂载admin app到已存在的FastAPI应用。

python 复制代码
from fastapi_admin.app import app as admin_app
from fastapi import FastAPI

app = FastAPI()
app.mount("/admin", admin_app)

需要在FastAPIstartup事件中进行初始化配置。

python 复制代码
from fastapi_admin.app import app as admin_app
from fastapi_admin.providers.login import UsernamePasswordProvider
import aioredis
from fastapi import FastAPI

login_provider = UsernamePasswordProvider(user_model=User, enable_captcha=True)

app = FastAPI()


@app.on_event("startup")
async def startup():
    redis = await aioredis.create_redis_pool("redis://localhost", encoding="utf8")
    admin_app.configure(
        logo_url="https://preview.tabler.io/static/logo-white.svg",
        login_logo_url="https://preview.tabler.io/static/logo.svg",
        template_folders=[os.path.join(BASE_DIR, "templates")],
        login_provider=login_provider,
        maintenance=False,
        redis=redis,
    )

websocket

FastAPI官方文档中关于websocket:https://fastapi.tiangolo.com/zh/advanced/websockets/

首先需要安装 WebSockets

bash 复制代码
pip install websockets
基于websocket实现客户端与服务端的实时通信

ws.py,服务端,代码:

python 复制代码
import uvicorn
from fastapi import FastAPI, WebSocket

# 创建App应用对象
app = FastAPI()


# 绑定路由,监听来自websocket协议的路由
@app.websocket("/api") # 访问这个接口需要的地址  ws://127.0.0.1:8088/api
async def api(websocket: WebSocket):
    # 等待客户端的websocket请求连接
    await websocket.accept()
    while True:
        # await websocket.receive_json() 接收客户端通过websocket协议上传过来的json数据
        # await websocket.receive_text() 接收客户端通过websocket协议上传过来的文本数据
        # await websocket.receive_bytes() 接收客户端通过websocket协议上传过来的二进制数据[文件/图片/语音]
        data = await websocket.receive_json()
        print("接收到来自客户端的数据:", data)
        # await websocket.send_text 使用websocket协议把文本数据发送给客户端
        # await websocket.send_json 使用websocket协议把json数据发送给客户端
        # await websocket.send_bytes 使用websocket协议把二进制数据发送给客户端[文件/图片/语音]
        await websocket.send_json(f"您方才问的是: {data['message']}")


if __name__ == '__main__':
    uvicorn.run('ws:app', host='0.0.0.0', port=8088)

ws.html,客户端,代码:

html 复制代码
<!DOCTYPE html>
<html>
    <head>
        <title>Chat</title>
        <meta charset="UTF-8" />
    </head>
    <body>
        <h1>WebSocket Chat</h1>
        <input type="text" name="message" autocomplete="off"/><button>发送</button>
        <!-- 通话列表 -->
        <ul id='history'>

        </ul>
        <script>
        const message_inp = document.querySelector('input[name=message]')
        const send_btn = document.querySelector('button');
        const history = document.querySelector('#history');
        // 原生websocket
        ws = new WebSocket('ws://127.0.0.1:8088/api')

        // 客户端主动发送数据
        send_btn.onclick = ()=>{
            // 使用websocket发送数据
            let username = '用户A';
            let message = message_inp.value;
            ws.send(JSON.stringify({
                username,
                message
            }))

            history.innerHTML += `<li style="text-align: left">${username}: ${message}</li>`
        }

        // 客户端监听服务端主动发送的数据
        ws.onmessage = (message)=>{
            console.log("接受到服务端推送的数据:", message.data)
            history.innerHTML += `<li style="text-align: right">${message.data} :ChatGPT</li>`
        }
        </script>
    </body>
</html>
基于websocket实现客户端与客户端之间的聊天室功能

ws_chat.py,服务端,代码:

python 复制代码
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()


class ConnectionManager:
    """
    客户端连接管理器[在服务端中保存所有的websocket客户端连接]
    """
    def __init__(self):
        # 这个列表保存的是连接到服务端的所有websocket客户端连接对象
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """websocket连接到服务端"""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """客户端主动断开连接"""
        self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """私聊[把信息发送给指定的websocket连接对象]"""
        await websocket.send_text(message)

    async def broadcast(self, message: str, current_connection):
        """广播[把信息发给所有的websocket连接对象]"""
        for connection in self.active_connections:
            if current_connection != connection:
                # 上面的判断的主要作用是:不要把数据发给本人
                await connection.send_text(message)


manager = ConnectionManager()


@app.websocket("/chat")
async def chat(websocket: WebSocket):
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_json()
            await manager.send_personal_message(f"我: {data['message']}", websocket)
            await manager.broadcast(f"{data['username']}: {data}", websocket)
    except WebSocketDisconnect:
        # 当有客户端离开聊天室,通过广播的方式告知所有的客户端
        manager.disconnect(websocket)
        await manager.broadcast(f"用户{data['username']} 离开了聊天室", websocket)


if __name__ == '__main__':
    uvicorn.run('ws_chat:app', host='0.0.0.0', port=8123)

ws_chat.html,代码:

html 复制代码
<!DOCTYPE html>
<html>
    <head>
        <title>Chat</title>
        <meta charset="UTF-8" />
    </head>
    <body>
        <h1>聊天室[你叫<span id="user"></span></h1>
        <input type="text" name="message" autocomplete="off"/><button>发送</button>
        <!-- 通话列表 -->
        <ul id='history'></ul>
        <script>
        const message_inp = document.querySelector('input[name=message]')
        const send_btn = document.querySelector('button');
        const history = document.querySelector('#history');
        const user = document.querySelector('#user');
        const username = location.search.substring(1);
        user.innerHTML = decodeURI(username);
        // 原生websocket
        ws = new WebSocket(`ws://127.0.0.1:8123/chat`)

        // 客户端主动发送数据
        send_btn.onclick = ()=>{
            // 使用websocket发送数据
            let message = message_inp.value;
            ws.send(JSON.stringify({
                username,
                message
            }))

            history.innerHTML += `<li style="text-align: left">${username}: ${message}</li>`
        }

        // 客户端监听服务端主动发送的数据
        ws.onmessage = (message)=>{
            console.log("接受到服务端推送的数据:", message.data)
            history.innerHTML += `<li style="text-align: right">${message.data} :ChatGPT</li>`
        }
        </script>
    </body>
</html>
相关推荐
春末的南方城市5 分钟前
开源音乐分离器Audio Decomposition:可实现盲源音频分离,无需外部乐器分离库,从头开始制作。将音乐转换为五线谱的程序
人工智能·计算机视觉·aigc·音视频
Make_magic44 分钟前
Git学习教程(更新中)
大数据·人工智能·git·elasticsearch·计算机视觉
goomind1 小时前
深度学习模型评价指标介绍
人工智能·python·深度学习·计算机视觉
youcans_1 小时前
【微软报告:多模态基础模型】(2)视觉理解
人工智能·计算机视觉·大语言模型·多模态·视觉理解
金蝶软件小李1 小时前
基于深度学习的猫狗识别
图像处理·深度学习·计算机视觉
__基本操作__4 小时前
边缘提取函数 [OPENCV--2]
人工智能·opencv·计算机视觉
这是一个图像4 小时前
从opencv-python入门opencv--图像处理之图像滤波
图像处理·opencv·计算机视觉·中值滤波·高斯滤波·双边滤波·图像滤波
新手小白勇闯新世界13 小时前
深度学习知识点5-马尔可夫链
人工智能·深度学习·计算机视觉
LittroInno15 小时前
TofuAI处理BT1120时序视频要求
深度学习·计算机视觉·tofu
Seeklike15 小时前
OpenCV图像预处理
人工智能·opencv·计算机视觉