Logo
Published on

3.1.FastAPI 接口认证

Authors
  • avatar
    Name
    xiaobai
    Twitter

1.学习目标

  1. 本节我们学习如何使用FastAPI来对我们开发好的接口做权限认证功能;
  2. 一般情况下接口开发完毕之后,大部分接口都是需要用户经过登录之后,通过用户的登录凭证token才能够请求我们的接口,也就是说,当我们将模型服务器接口开发完毕之后,其实需要有一套完善的接口认证过程,这里我们使用比较流行的Oauth2以及JWT来实现接口认证,保证我们的接口安全;
  3. Oauth2:是一种开放授权协议,也可以理解为是一种授权框架,用来定义客户端如何获取令牌,客户端可以通过这个令牌来获取服务器部分资源(通常是接口)的访问权限;
  4. JWT:全名Json Web Token,是一种令牌传输的格式,可以用作Oauth2中的访问令牌;不是所有Oauth2都使用JWT来实现令牌,但是JWT通常都是用来作为Oauth2的令牌使用;
  5. 总结:Auth2处理授权流程,JWT作为令牌格式传递用户信息和权限;

官方原文:https://fastapi.tiangolo.com/zh/tutorial/security/

2.学习内容

这里我们会从零开始,逐步实现Oauth2与JWT实现接口认证功能,主要有以下步骤:

  1. 创建用户表;
  2. 现用户登录注册、登录;
  3. 使用登录得到的token来请求需要认证的接口;

3.创建用户表

create or replace table plain.pl_user
(
    id            varchar(50) default uuid()              not null comment '编号' primary key,
    created_at    datetime    default current_timestamp() null comment '创建时间',
    created_by    varchar(50)                             null comment '创建人id',
    updated_at    datetime    default current_timestamp() null comment '最后更新时间',
    updated_by    varchar(50)                             null comment '最后更新人id',
    username      varchar(50)                             not null comment '用户名',
    full_name     varchar(20)                             not null comment '用户昵称',
    email         varchar(50)                             not null comment '用户邮箱',
    valid         varchar(1)                              not null comment '用户是否已经激活',
    hash_password varchar(200)                            not null comment '用户密码的哈希值',
    constraint pl_user_id_uindex
        unique (id)
)
    collate = utf8mb4_unicode_ci;

字段说明:

  1. username:用户登录的用户名
  2. full_name:用户昵称,也就是平时显示的用户名称
  3. email:邮箱地址,由于注册之后要发送邮件激活账号的话需要实名认证备案域名以及部署https证书,所以这里就没有做邮箱校验的功能,注册成功之后会返回激活账号的url访问地址来激活账号;
  4. valid:字段标识用户账号是否已经激活;
  5. hash_password:用户密码的哈希值,数据库中不会存放用户明文密码;

4.安装依赖

执行如下命令,安装如下三个依赖:

poetry add passlib[bcrypt] pyjwt python-multipart

或者往“pyproject.toml”文件中的依赖清单中增加如下内容,再安装依赖;

passlib = {extras = ["bcrypt"], version = "^1.7.4"}
pyjwt = "^2.10.1"
python-multipart = "^0.0.20"

依赖说明:

  1. passlib:用来处理密码哈希的python包,支持很多哈希算法及其配套工具,本次课件中使用的是“Bcrypt”;
  2. pyjwt:用来生成以及校验JWT令牌;
  3. python-multipart:用于解析 multipart/form-data 类型的请求体,Oauth2规范要求使用密码流时,客户端必须以表单数据的形式发送username以及password,这两个字段名不可变更;

5.环境变量配置

这里我们需要在“.env”以及“.env.exmaple”文件中增加如下环境变量:

  1. SERVER_DOMAIN:后端服务部署的域名,用来用户注册之后,生成激活账号的访问地址;
  2. JWT_SECRET_KEY:JWT秘钥;
  3. JWT_ALGORITHM:JWT加密使用的算法;
  4. JWT_ACCESS_TOKEN_EXPIRE_MINUTES:JWT访问令牌默认有效时间(分钟);
  5. JWT_GLOBAL_ENABLE:用来控制是否开启全局的接口认证;
  6. JWT_WHITE_LIST:开启全局接口认证的情况下,哪些接口属于白名单不需要认证;

此时环境变量文件中的完整内容如下所示:

DB_HOST=xxx.xxx.xxx.xxx 								                # 数据库连接ip地址
DB_PORT=xxx 											                      # 数据库连接端口
DB_USERNAME=xxx                                         # 数据库连接用户名
DB_PASSWORD=xxx                                         # 数据库连接密码
DB_DATABASE=xxx 										                    # 数据库连接的数据库名

LLM_KEY_LOCAL=123
LLM_KEY_HUOSHAN=a0311f2a-ba85-4428-b158-xxxxxxxxxxxx 	  # 火山引擎模型服务平台key
LLM_KEY_BAILIAN=sk-51d13ba8ea044d128c66dxxxxxxxxxxxx 	  # 阿里云百炼模型服务平台key
LLM_KEY_DEEPSEEK=sk-a89d0ff9421a43fca5f0xxxxxxxxxxxx 	  # Deepseek模型服务平台key

SERVER_PORT = 7002                                      # 服务启动端口
SERVER_DOMAIN = http://127.0.0.1                        # 后端服务部署的域名

JWT_SECRET_KEY = a2d41c19c766490a2bea4be9902ea9f60e44ca3d9ef82e8b803f77a4f9deb247 # JWT秘钥
JWT_ALGORITHM = HS256                                   # JWT加密算法
JWT_ACCESS_TOKEN_EXPIRE_MINUTES = 30                    # JWT访问令牌默认有效时间(分钟
JWT_GLOBAL_ENABLE = false                               # 是否开启全局JWT认证
JWT_WHITE_LIST = ["/token", "/login", "/registry", "/verify", "/async_delay"] # JWT认证白名单接口,不需要认证的接口

同样的,用来加载环境变量配置的文件“app/config/env.py”内容如下所示:

from typing import List

from pydantic_settings import BaseSettings
from pydantic import Field
from dotenv import load_dotenv


class Settings(BaseSettings):
  db_host: str = Field(..., env="DB_HOST")
  db_port: str = Field(..., env="DB_PORT")
  db_username: str = Field(..., env="DB_USERNAME")
  db_password: str = Field(..., env="DB_PASSWORD")
  db_database: str = Field(..., env="DB_DATABASE")

  llm_key_local: str = Field(..., env="LLM_KEY_LOCAL")
  llm_key_huoshan: str = Field(..., env="LLM_KEY_HUOSHAN")
  llm_key_bailian: str = Field(..., env="LLM_KEY_BAILIAN")
  llm_key_deepseek: str = Field(..., env="LLM_KEY_DEEPSEEK")

  server_port: str = Field(..., env="SERVER_PORT")
  server_domain: str = Field(..., env="SERVER_DOMAIN")

  jwt_secret_key: str = Field(..., env="JWT_SECRET_KEY")
  jwt_algorithm: str = Field(..., env="JWT_ALGORITHM")
  jwt_access_token_expire_minutes: int = Field(..., env="JWT_ACCESS_TOKEN_EXPIRE_MINUTES")
  jwt_global_enable: bool = Field(..., env="JWT_GLOBAL_ENABLE")
  jwt_white_list: List[str] = Field(..., env="JWT_WHITE_LIST")

  class Config:
    env_file = ".env"
    env_file_encoding = "utf-8"


load_dotenv(".env")  # 先加载 .env 文件
env = Settings()

6.加密工具类

新增文件“app/utils/CryptUtils.py”,内容如下所示,主要是基于jwt封装了用来加密解密token信息、生成密码哈希值以及校验明文密码登相关工具函数;

from datetime import timedelta, datetime, timezone

import jwt
from passlib.context import CryptContext

from app.config.env import env

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


class CryptUtils:
  @staticmethod
  def get_password_hash(password):
    """
    对密码进行哈希加密处理

    Args:
        password (str): 需要加密的原始密码字符串

    Returns:
        str: 经过哈希加密后的密码字符串
    """
    return pwd_context.hash(password)

  @staticmethod
  def verify_password(plain_password, hashed_password):
    """
    验证密码是否正确

    通过比较明文密码和哈希密码来验证密码是否匹配

    参数:
        plain_password (str): 明文密码
        hashed_password (str): 哈希后的密码

    返回:
        bool: 密码验证结果,True表示密码正确,False表示密码错误
    """
    return pwd_context.verify(plain_password, hashed_password)

  @staticmethod
  def create_access_token(username: str, expires_delta: timedelta | None = None):
    data: dict = {"sub": username}
    to_encode = data.copy()
    if expires_delta:
      expire = datetime.now(timezone.utc) + expires_delta
    else:
      expire = datetime.now(timezone.utc) + timedelta(minutes=env.jwt_access_token_expire_minutes)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, env.jwt_secret_key, algorithm=env.jwt_algorithm)
    return encoded_jwt

  @staticmethod
  def get_username_from_token(token: str):
    payload = jwt.decode(token, env.jwt_secret_key, algorithms=[env.jwt_algorithm])
    username = payload.get("sub")
    return username

7.认证具体实现

7.1.准备文件

新增文件“app/routes/user/add_user_route.py”,我们将用户登录注册相关的路由都放在这个文件中;

7.2.类型定义

在这个“add_user_route.py”文件中,我们增加如下类型定义:

class UserValidStatus(str, Enum):
  Y = "Y"
  N = "N"


class PublicUser(BasicModel):
  username: str = Field(..., description="用户名")
  full_name: str = Field(..., description="用户昵称")
  email: str = Field(..., description="用户邮箱地址")
  valid: UserValidStatus = Field(default=UserValidStatus.N, description="用户账号是否已经激活生效")


class UserModel(PublicUser, table=True):
  __tablename__ = "pl_user"
  hash_password: str = Field(..., description="经过哈希转换的密码")


class RegistryUser(PublicUser):
  password: str = Field(..., description="明文密码")


class Token(BaseModel):
  access_token: str
  token_type: str
  1. UserValidStatus:具体的枚举类型,用来
  2. PublicUser:接口返回用户信息类,主要是排除了哈希密码字段值;
  3. UserModel:用来对数据库表进行增删改查的model类;
  4. RegistryUser:注册用户信息时的用户信息;
  5. Token:登录之后返回的token信息;

7.3.工具函数

在这个“add_user_route.py”文件中,我们增加如下工具函数:

  1. get_user_by_username:通过username查询得到user信息;
  2. authenticate_user:校验用户名密码,校验成功则返回用户信息,否则返回False;
  3. get_current_user:通过注入的token信息以及session来查询返回用户信息;
async def get_user_by_username(username: str, session: AsyncSessionDep):
  query = (
    select(UserModel)
    .where(UserModel.username == username)
    .where(UserModel.valid == UserValidStatus.Y)
  )
  result = await session.execute(query)
  user_model: UserModel | None = result.scalars().first()
  return user_model


async def authenticate_user(
  session: AsyncSessionDep,
  username: str,
  password: str,
):
  user_model = await get_user_by_username(username, session)
  if not user_model:
    return False
  if not CryptUtils.verify_password(password, user_model.hash_password):
    return False
  public_user = PublicUser(**user_model.model_dump())
  return public_user


oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


unauthorized_exception = HTTPException(
  status_code=status.HTTP_401_UNAUTHORIZED,
  detail="The token is invalid or has expired",
  headers={"WWW-Authenticate": "Bearer"},
)


async def get_current_user(session: AsyncSessionDep, token: str = Depends(oauth2_scheme)):
  try:
    username = CryptUtils.get_username_from_token(token)
    if username is None:
      raise unauthorized_exception
  except InvalidTokenError:
    raise unauthorized_exception
  user_model = await get_user_by_username(username, session)
  if user_model is None:
    raise unauthorized_exception

  public_user = PublicUser(**user_model.model_dump())
  return public_user

7.4.注册接口

如下所示为注册接口的具体实现代码:

  1. 接受到注册的用户信息之后,先检查用户名以及邮箱是否已经存在;
  2. 然后将用户信息新增到用户表中;
  3. 生成激活用户账号的token,并且将这个token生成一个激活账号的url访问地址,访问这个url地址就会激活这个账号;
  @app.post("/registry")
  async def registry(registry_user: RegistryUser, session: AsyncSessionDep):
    """通过用户名、密码、邮箱、用户昵称注册账号"""
    hash_password = CryptUtils.get_password_hash(registry_user.password)
    username = registry_user.username
    # /*---------------------------------------用户名是否已经存在-------------------------------------------*/
    query = select(UserModel).where(UserModel.username == username)
    result = await session.execute(query)
    item_cls = result.scalars().first()

    if item_cls:
      return {"result": None, "error": f"用户名'{registry_user.username}'已经存在"}

    # /*---------------------------------------邮箱是否已经注册-------------------------------------------*/

    query = select(UserModel).where(UserModel.email == registry_user.email)
    result = await session.execute(query)
    item_cls = result.scalars().first()

    if item_cls:
      return {"result": None, "error": f"邮箱'{registry_user.email}'已经被注册"}

    # /*---------------------------------------注册用户信息-------------------------------------------*/

    user = UserModel(
      username=registry_user.username,
      full_name=registry_user.full_name,
      email=registry_user.email,
      hash_password=hash_password,
      valid='N'
    )
    user.id = await next_id(1)

    session.add(user)
    # 提交事务(保存到数据库)
    await session.commit()
    # 刷新实例,获取数据库生成的最新数据(如自动更新的时间字段)
    await session.refresh(user)

    public_user = PublicUser(**user.model_dump())

    # 验证用户账号的token 3年内有效
    verify_token = CryptUtils.create_access_token(data={"sub": user.username}, expires_delta=timedelta(days=365 * 3))

    return {
      # 返回用户信息
      "result": public_user,
      # 访问这个url地址就可以激活账号
      "valid_url": f"{env.server_domain}:{env.server_port}/verify?token={verify_token}"
    }

7.5.激活账号

如下所示,定义一个叫做“verify”的get接口,访问这个接口通过查询参数传递token信息,然后通过解析token中的username,最后使用这个username来激活账号;

  @app.get("/verify")
  async def verify(session: AsyncSessionDep, token: str = Query(..., description="激活用户账号的token")):
    """激活账号"""
    username = CryptUtils.get_username_from_token(token)
    if not username:
      return {"error": "token无效或者已经过期"}
    query = select(UserModel).where(UserModel.username == username)
    result = await session.execute(query)
    item_cls: UserModel = result.scalars().first()
    if item_cls.valid == UserValidStatus.N:
      print("激活账号:", username)
      item_cls.valid = UserValidStatus.Y
      session.add(item_cls)
      await session.commit()
      await session.refresh(item_cls)
      return {
        "result": f"账号'{username}'激活成功",
      }
    else:
      return {
        "result": f"账号'{username}'已经激活",
      }

7.6.登录接口

如下所示:使用依赖注入的方式得到表单参数中的username以及password,校验用户名密码成功之后,返回访问令牌token;

  @app.post("/login")
  @app.post("/token")
  async def login(
    session: AsyncSessionDep,
    form_data: OAuth2PasswordRequestForm = Depends(),
  ):
    user = await authenticate_user(session, form_data.username, form_data.password)
    if not user:
      raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
      )
    access_token_expires = timedelta(minutes=env.jwt_access_token_expire_minutes)
    access_token = CryptUtils.create_access_token(username=user.username, expires_delta=access_token_expires)
    return Token(access_token=access_token, token_type="bearer")

7.7.接口认证

  • 如下所示,如果一个接口需要token认证,那么只需要定义参数注入当前用户信息即可;
  • 如果注入用户信息失败则说明校验不通过;
  • 如果大部分接口都需要token认证,那么请看下文中的中间件处理;
  @app.get("/users/me", response_model=PublicUser)
  async def read_users_me(current_user: PublicUser = Depends(get_current_user)):
    return current_user

7.8.完整代码

最后不要忘了将这个“add_user_route”函数在“server.py”中调用;

from datetime import timedelta
from enum import Enum

from fastapi import FastAPI, HTTPException
from fastapi.params import Query, Depends
from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
from jwt import InvalidTokenError
from pydantic import BaseModel
from sqlmodel import select, Field
from starlette import status

from app.config.env import env
from app.models.BasicModel import BasicModel
from app.utils.CryptUtils import CryptUtils
from app.utils.db_utils import AsyncSessionDep
from app.utils.next_id import next_id


class UserValidStatus(str, Enum):
  Y = "Y"
  N = "N"


class PublicUser(BasicModel):
  username: str = Field(..., description="用户名")
  full_name: str = Field(..., description="用户昵称")
  email: str = Field(..., description="用户邮箱地址")
  valid: UserValidStatus = Field(default=UserValidStatus.N, description="用户账号是否已经激活生效")


class UserModel(PublicUser, table=True):
  __tablename__ = "pl_user"
  hash_password: str = Field(..., description="经过哈希转换的密码")


class RegistryUser(PublicUser):
  password: str = Field(..., description="明文密码")


class Token(BaseModel):
  access_token: str
  token_type: str


def add_user_route(app: FastAPI):
  @app.post("/login")
  @app.post("/token")
  async def login(
    session: AsyncSessionDep,
    form_data: OAuth2PasswordRequestForm = Depends(),
  ):
    user = await authenticate_user(session, form_data.username, form_data.password)
    if not user:
      raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
      )
    access_token_expires = timedelta(minutes=env.jwt_access_token_expire_minutes)
    access_token = CryptUtils.create_access_token(username=user.username, expires_delta=access_token_expires)
    return Token(access_token=access_token, token_type="bearer")

  @app.post("/registry")
  async def registry(registry_user: RegistryUser, session: AsyncSessionDep):
    """通过用户名、密码、邮箱、用户昵称注册账号"""
    hash_password = CryptUtils.get_password_hash(registry_user.password)
    username = registry_user.username
    # /*---------------------------------------用户名是否已经存在-------------------------------------------*/
    query = select(UserModel).where(UserModel.username == username)
    result = await session.execute(query)
    item_cls = result.scalars().first()

    if item_cls:
      return {"result": None, "error": f"用户名'{registry_user.username}'已经存在"}

    # /*---------------------------------------邮箱是否已经注册-------------------------------------------*/

    query = select(UserModel).where(UserModel.email == registry_user.email)
    result = await session.execute(query)
    item_cls = result.scalars().first()

    if item_cls:
      return {"result": None, "error": f"邮箱'{registry_user.email}'已经被注册"}

    # /*---------------------------------------注册用户信息-------------------------------------------*/

    user = UserModel(
      username=registry_user.username,
      full_name=registry_user.full_name,
      email=registry_user.email,
      hash_password=hash_password,
      valid='N'
    )
    user.id = await next_id(1)

    session.add(user)
    # 提交事务(保存到数据库)
    await session.commit()
    # 刷新实例,获取数据库生成的最新数据(如自动更新的时间字段)
    await session.refresh(user)

    public_user = PublicUser(**user.model_dump())

    # 验证用户账号的token 3年内有效
    verify_token = CryptUtils.create_access_token(data={"sub": user.username}, expires_delta=timedelta(days=365 * 3))

    return {
      # 返回用户信息
      "result": public_user,
      # 访问这个url地址就可以激活账号
      "valid_url": f"{env.server_domain}:{env.server_port}/verify?token={verify_token}"
    }

  @app.get("/verify")
  async def verify(session: AsyncSessionDep, token: str = Query(..., description="激活用户账号的token")):
    """激活账号"""
    username = CryptUtils.get_username_from_token(token)
    if not username:
      return {"error": "token无效或者已经过期"}
    query = select(UserModel).where(UserModel.username == username)
    result = await session.execute(query)
    item_cls: UserModel = result.scalars().first()
    if item_cls.valid == UserValidStatus.N:
      print("激活账号:", username)
      item_cls.valid = UserValidStatus.Y
      session.add(item_cls)
      await session.commit()
      await session.refresh(item_cls)
      return {
        "result": f"账号'{username}'激活成功",
      }
    else:
      return {
        "result": f"账号'{username}'已经激活",
      }

  @app.get("/users/me", response_model=PublicUser)
  async def read_users_me(current_user: PublicUser = Depends(get_current_user)):
    return current_user


async def get_user_by_username(username: str, session: AsyncSessionDep):
  query = (
    select(UserModel)
    .where(UserModel.username == username)
    .where(UserModel.valid == UserValidStatus.Y)
  )
  result = await session.execute(query)
  user_model: UserModel | None = result.scalars().first()
  return user_model


async def authenticate_user(
  session: AsyncSessionDep,
  username: str,
  password: str,
):
  user_model = await get_user_by_username(username, session)
  if not user_model:
    return False
  if not CryptUtils.verify_password(password, user_model.hash_password):
    return False
  public_user = PublicUser(**user_model.model_dump())
  return public_user


oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

unauthorized_exception = HTTPException(
  status_code=status.HTTP_401_UNAUTHORIZED,
  detail="The token is invalid or has expired",
  headers={"WWW-Authenticate": "Bearer"},
)


async def get_current_user(session: AsyncSessionDep, token: str = Depends(oauth2_scheme)):
  try:
    username = CryptUtils.get_username_from_token(token)
    if username is None:
      raise unauthorized_exception
  except InvalidTokenError:
    raise unauthorized_exception
  user_model = await get_user_by_username(username, session)
  if user_model is None:
    raise unauthorized_exception

  public_user = PublicUser(**user_model.model_dump())
  return public_user

8.中间件

8.1.自定义中间件

  • 新增文件“app/middlewares/app_middlewares.py”,我们将所有新增的中间件都通过这个“add_middlewares”函数来注册;最后不要忘了在“server.py”文件中调用这个函数;
  • 如下所示,我们自定义了一个中间件,我们可以在这个中间件中,访问参数request来处理请求信息,比如请求头、请求体参数等等;可以通过执行异步函数call_next来执行下一步内容;
  • 这里这个示例我们统计接口的执行耗时,最后将这个耗时写入到响应头的“X-Process-Time”字段中;
  @app.middleware("http")
  async def add_process_time_header(request: Request, call_next):
    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = f"time: {process_time}s"
    return response

8.2.认证中间件

认证中间件处理的内容:

  1. 读取环境变量,判断是否需要全局开启接口认证;
  2. 全局开启接口认证的情况下,排除掉白名单中的接口;
  3. 接口认证处理token以及session,获取当前用户信息,将用户信息保存到request.state中;
  4. 如果认证失败,抛出“unauthorized_exception”异常;
  @app.middleware("http")
  async def add_oauth_middleware(request: Request, call_next):

    # 没有开启全局JWT认证
    if not env.jwt_global_enable:
      return await call_next(request)

    # 当前接口为JWT认证白名单接口
    if request.url.path in env.jwt_white_list:
      return await call_next(request)

    token: str | None = None
    auth_header = request.headers.get("Authorization")  # 从请求头获取Authorization
    if auth_header and auth_header.startswith("Bearer "):
      # 提取 Bearer 后的 token(格式:Bearer <token> → 取索引1的部分)
      token = auth_header.split(" ")[1].strip()

    if not token:
      raise unauthorized_exception

    try:
      session = async_session()
      public_user = await get_current_user(session, token)
      request.state.user = public_user
      request.state.token = token
    except InvalidTokenError:
      return unauthorized_exception

    response = await call_next(request)

    return response

8.3.认证异常处理中间件

前面的认证失败之后都是抛出异常,这里在中间件中捕获这个异常,当抛出的是401认证失败异常时,返回一个HttpResponse,状态吗为401,优化响应结果:

  @app.middleware("http")
  async def catch_authorized(request: Request, call_next):
    try:
      response = await call_next(request)
    except HTTPException as e:
      if e.status_code == status.HTTP_401_UNAUTHORIZED:
        return JSONResponse(content=e.detail, status_code=status.HTTP_401_UNAUTHORIZED)
      else:
        raise e
    return response

8.4.完整代码

import time

from fastapi import FastAPI, HTTPException
from jwt import InvalidTokenError
from starlette import status
from starlette.requests import Request
from starlette.responses import JSONResponse

from app.config.env import env
from app.routes.user.add_user_route import get_current_user, unauthorized_exception
from app.utils.db_utils import async_session


def add_middlewares(app: FastAPI):
  @app.middleware("http")
  async def add_process_time_header(request: Request, call_next):
    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = f"time: {process_time}s"
    return response

  @app.middleware("http")
  async def add_oauth_middleware(request: Request, call_next):

    # 没有开启全局JWT认证
    if not env.jwt_global_enable:
      return await call_next(request)

    # 当前接口为JWT认证白名单接口
    if request.url.path in env.jwt_white_list:
      return await call_next(request)

    token: str | None = None
    auth_header = request.headers.get("Authorization")  # 从请求头获取Authorization
    if auth_header and auth_header.startswith("Bearer "):
      # 提取 Bearer 后的 token(格式:Bearer <token> → 取索引1的部分)
      token = auth_header.split(" ")[1].strip()

    if not token:
      raise unauthorized_exception

    try:
      session = async_session()
      public_user = await get_current_user(session, token)
      request.state.user = public_user
      request.state.token = token
    except InvalidTokenError:
      return unauthorized_exception

    response = await call_next(request)

    return response

  @app.middleware("http")
  async def catch_authorized(request: Request, call_next):
    try:
      response = await call_next(request)
    except HTTPException as e:
      if e.status_code == status.HTTP_401_UNAUTHORIZED:
        return JSONResponse(content=e.detail, status_code=status.HTTP_401_UNAUTHORIZED)
      else:
        raise e
    return response

9.async_session

如下示例所示:

  1. async_session既能够使用“async with”调用,实现自动关闭连接,也能够使用“async_session()”创建连接,最后finally中手动关闭连接的一个原理简单示例说明;
  2. “async_session()”返回的是一个符合异步上下文管理器协议的对象,这个对象的"aenter“函数中最后return的是它自己对象实例,并且在“aexit”函数中自动调用了自身的关闭连接方法;
  3. 这样就实现了既可以通过“async with”来调用“async_session”函数,也可以手动执行这个函数获得session对象,手动处理连接关闭动作;
import asyncio


class async_session:

  async def execute(selfm, sql: str):
    await asyncio.sleep(0.2)
    print("执行sql语句:", sql)

  def close(self, name: str):
    print(f"{name}:关闭数据库连接")

  # 实现__enter__方法:对应生成器中yield之前的代码
  async def __aenter__(self):
    await asyncio.sleep(0.2)
    return self

  # 实现__exit__方法:对应生成器中yield之后的代码
  async def __aexit__(self, exc_type, exc_val, exc_tb):
    self.close("自动")


async def way_01():
  print("start")
  async with async_session() as session:
    await session.execute("hello")
  print("end")


async def way_02():
  print("start")
  session = async_session()
  try:
    await session.execute("world")
  finally:
    session.close("手动")
  print("end")


async def main():
  await way_01()
  await way_02()


asyncio.run(main())