- Published on
2.1.FastAPI 增删改查
- Authors

- Name
- xiaobai
1.概念介绍
- 这里我们将介绍如何在FastAPI中连接mysql数据库对表进行增删改查;
- 首先我们需要安装两个包:sqlmodel以及asyncmy;
- asyncmy:是一个纯Python实现的异步mysql客户端连接库,专为asyncio设计;能够支持绝大部分mysql的功能特性,比如存储过程,事务处理等等;
- sqlmodel:是一个结合了sqlalchemy和pydantic的Python库,是一个现代化的ORM框架;可以用面向对象的方式来实现数据增删改查;
- sqlalchemy:一个比较古老的数据库操作工具,最早出现在2005年,提供比较全面的数据库访问能力;
- sqlmodel与sqlalchemy的区别在于,sqlmodel在继承sqlalchemy所有功能的基础上,全面接入pydantic类型定义,内置pydantic数据验证功能,以及自动处理数据转换;在fastapi中,sqlmodel可以与fastapi无缝集成,但是sqlalchemy则需要手动处理pydantic类型转换;
- 我们执行如下命令,安装这两个包:
poetry add asyncmy sqlmodel
2.数据库连接工具
我们创建一个文件“app/utils/db_utils.py”,这个文件专门存放数据库连接相关的变量以及工具函数,代码如下
# 异步引擎 (用于生产环境)
from typing import Annotated
from fastapi import Depends
from sqlalchemy import text, AsyncAdaptedQueuePool
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.orm import sessionmaker
from sqlmodel import create_engine
from sqlmodel.ext.asyncio.session import AsyncSession
from app.config.env import env
# 构建MySQL连接URL,拼接数据库连接所需的各项参数
DATABASE_URL = f"mysql+asyncmy://{env.db_username}:{env.db_password}@{env.db_host}:{env.db_port}/{env.db_database}?charset=utf8mb4"
# 创建异步引擎实例,用于异步操作数据库
async_engine = AsyncEngine(create_engine(
DATABASE_URL,
poolclass=AsyncAdaptedQueuePool, # 使用异步适配的队列池
pool_size=5, # 连接池保持的连接数
max_overflow=10, # 允许超过pool_size的最大连接数
pool_timeout=30, # 获取连接的超时时间(秒)
pool_recycle=3600, # 连接回收时间(秒)
echo=True, # 启用SQL语句日志输出,便于开发调试
future=True, # 启用SQLAlchemy 2.0风格的未来模式API
))
# 创建异步会话工厂,用于生成数据库会话对象
async_session = sessionmaker(
bind=async_engine, # 绑定到之前创建的异步引擎
class_=AsyncSession, # 指定会话类为AsyncSession,支持异步操作
expire_on_commit=False # 提交后不使对象过期,避免自动刷新
)
# 定义获取异步会话的依赖函数,供FastAPI接口使用
async def get_async_session() -> AsyncSession:
# 使用异步上下文管理器创建会话,确保会话正确关闭
async with async_session() as session:
# 生成会话实例,供接口处理函数使用
yield session
# 创建带依赖的注解类型,简化接口函数中会话参数的声明
AsyncSessionDep = Annotated[AsyncSession, Depends(get_async_session)]
# 定义检查数据库连接的异步函数
async def check_database_connection():
"""检查MySQL连接是否可用"""
try:
# 使用异步引擎创建连接并开启事务
async with async_engine.begin() as conn:
# 执行简单查询语句检查连接是否正常
await conn.execute(text("SELECT 1"))
# 打印连接成功信息及连接URL
print("✅ Database connection successful:", DATABASE_URL)
except Exception as e:
# 打印连接失败信息及错误详情
print(f"❌ Database connection failed: {e}")
# 重新抛出异常,让上层处理
raise
# 返回异步引擎实例
return async_engine
这个代码中我们做了如下几个事情:
- 定义了数据库连接驱动的URL地址,也就是“DATABASE_URL”;
- 创建了异步引擎实例“async_engine”,这个实例引擎会自动管理连接池,处理每个连接的sql调用,控制sql日志输出等等;
- 创建了异步会话工厂“async_session”,用来创建异步会话,后续的sql执行都是通过这个异步会话来执行;
- 定义了注解类型“AsyncSessionDep”,用于在fastapi中自动注入得到这个异步会话对象;
- 定义了一个叫做“check_database_connection”的一个函数,用于启动服务时,立即执行sql语句以判断数据库连接是否正常;
3.启动时检查数据库连接
我们需要修改我们之前设置的lifespan函数,在应用启动的时候测试数据库连接,代码如下所示:
@asynccontextmanager
async def lifespan(app: FastAPI):
# 1. yield之前:启动服务的准备工作
print("lifespan 启动服务:加载模型、连接数据库...", type(app))
# (比如加载AI模型、连接数据库等)
async_engine = await check_database_connection()
# 2. yield:暂停,让服务开始处理请求
yield # 分界线:上面是启动,下面是关闭
# 3. yield之后:关闭服务的收拾工作
print("lifespan 关闭服务:断开数据库、释放模型...", type(app))
# (比如断开连接、释放内存等)
await async_engine.dispose()
4.测试数据库连接
我们这里新增两个接口:
- query_llm_user_list:查询数据库中llm_user这张表的所有数据,没有参数;
- query_llm_user:查询数据库中llm_user这张表指定username的数据,有一个查询参数username;
代码如下所示:
@app.get("/query_llm_user_list")
async def query_llm_user_list(session: AsyncSessionDep):
result = await session.execute(text("select * from llm_user"))
return [dict(row._mapping) for row in result]
@app.get("/query_llm_user")
async def query_llm_user(username: str, session: AsyncSessionDep):
result = await session.execute(text("select * from llm_user where username = :username"), {"username": username})
list = [dict(row._mapping) for row in result]
return {"result": list[0]} if len(list) == 1 else {"result": None}
然后我们使用接口调试工具请求这个“query_llm_user_list”接口,结果应该如下所示能够看到所有的数据:

当调用接口“query_llm_user“接口,结果如下所示,指定username可以查到对应用户信息:

并且如果采用sql注入的方式来设置查询参数,那么会查不到数据:

5.面向对象操作数据库
这里我们介绍如何使用sqlmodel提供的面向对象的方式来操作数据库,不需要手写sql,首先我们准备一个叫做“app/models/BasicModel.py”的文件,代码如下所示:
- 这个文件的作用对应我们表中的几个标准字段:id、created_at、created_by、updated_at、updated_by
- 我们后面定义的model类,只需要继承这个基类,就不需要写上这几个标准字段了;同时这个基类还配置了Config类,作用是将类中“date、datetime”类型的字段,在转化成json时转化成对应格式的日期时间字符串返回给前端;
- 注意的是这里继承的是sqlmodel的SQLModel,而不是pydantic的BaseModel,但是SQLModel本身也是继承了pydantic的BaseModel;
from datetime import datetime, timezone, timedelta, date
from pydantic import model_validator
from sqlmodel import SQLModel, Field
# 定义北京时区(UTC+8)
beijing_timezone = timezone(timedelta(hours=8))
# 定义获取当前北京时区时间的匿名函数,用于默认值生成
current_datetime = lambda: datetime.now(beijing_timezone)
# 定义基础模型类,所有其他模型类的父类,包含通用字段和配置
class BasicModel(SQLModel):
# 唯一标识字段,主键,默认为None(通常由系统生成),描述为“唯一标识,编号”
id: str = Field(default=None, primary_key=True, description="唯一标识,编号")
# 创建时间字段,默认值为当前北京时区时间,描述为“创建时间”
created_at: datetime = Field(default_factory=current_datetime, description="创建时间")
# 更新时间字段,默认值为当前北京时区时间,描述为“更新时间”
updated_at: datetime = Field(default_factory=current_datetime, description="更新时间")
# 创建人ID字段,默认为None,描述为“创建人id”
created_by: str = Field(default=None, description="创建人id")
# 更新人ID字段,默认为None,描述为“更新人id”
updated_by: str = Field(default=None, description="更新人id")
# 模型配置类,用于设置JSON序列化等配置
class Config:
# 定义datetime和date类型的JSON编码器,将其格式化为指定字符串
json_encoders = {
# 若为datetime类型,格式化为“年-月-日 时:分:秒”,若为None则保持None
datetime: lambda dt: dt.strftime("%Y-%m-%d %H:%M:%S") if dt is not None else None,
# 若为date类型,格式化为“年-月-日”,若为None则保持None
date: lambda dt: dt.strftime("%Y-%m-%d") if dt is not None else None
}
# 定义模型验证器,在数据解析前(mode='before')执行,用于处理字符串格式的日期时间
@model_validator(mode='before')
def parse_string_datetimes(cls, data: dict) -> dict:
# 处理datetime类型字段:将字符串格式的日期时间转换为datetime对象
datetime_fields = {
k: datetime.strptime(v, "%Y-%m-%d %H:%M:%S") # 使用strptime解析字符串为datetime
for k, v in data.items() # 遍历输入数据的键值对
if isinstance(v, str) # 只处理值为字符串的项
and k in cls.model_fields # 键必须是模型中定义的字段
and cls.model_fields[k].annotation is datetime # 字段的注解类型是datetime
}
# 处理date类型字段:将字符串格式的日期转换为date对象(通过datetime解析后取date部分)
date_fields = {
k: datetime.strptime(v, "%Y-%m-%d").date() # 使用strptime解析字符串为datetime后取date
for k, v in data.items() # 遍历输入数据的键值对
if isinstance(v, str) # 只处理值为字符串的项
and k in cls.model_fields # 键必须是模型中定义的字段
and cls.model_fields[k].annotation is date # 字段的注解类型是date
}
# 打印转换后的datetime字段,用于调试
# print("datetime_fields", datetime_fields)
# 合并原始数据、转换后的datetime字段和date字段,转换后的字段会覆盖原始数据中的对应键
result = {**data, **datetime_fields, **date_fields}
# 打印合并后的结果,用于调试
# print("result", result)
# 返回处理后的数据字典
return result
准备文件“app/models/LlmUser.py”,代码如下所示:
- 我们准备基于数据库中的表“llm_user”来进行测试数据库操作,这里我们创建对应的model类“LlmUser”
- “tablename”是设置这个类对应在数据库中的表名,如果不设置,按照类名“LlmUser”得到的默认表名是“llmuser”;
- 其余几个字段是我们额外设置的“llm_user”表中额外的字段;
from datetime import datetime
from sqlmodel import Field
from app.models.BasicModel import BasicModel
class LlmUser(BasicModel, table=True):
__tablename__ = "llm_user"
full_name: str = Field(default=None)
username: str = Field(default=None)
password: str = Field(default=None)
member_start: datetime = Field(default=None)
member_end: datetime = Field(default=None)
准备文件“app/utils/PageQueryParams.py”,代码如下所示:
- 这个文件继承pydantic的BaseModel,用来作为类型定义以及校验接收请求体参数,这个参数就是分页查询参数;
from pydantic import BaseModel, Field
class PageQueryParams(BaseModel):
page: int = Field(default=0, description="分页查询,表示第几页")
page_size: int = Field(default=10, description="分页查询,每页多少条数据")
all: bool = Field(default=False, description="不分页查询所有数据")
filters: dict = Field(default=None, description="筛选字典对象")
准备接口,如下所示,我们新增如下两个接口,用来对比刚刚手动执行sql的语法差异:
- llm_user_list:是查询所有数据的接口;
- llm_user:是根据username查询指定用户的接口;
@app.post("/llm_user_list")
async def query_llm_user_list(param: PageQueryParams, session: AsyncSessionDep):
result = await session.execute(select(LlmUser).offset(param.page * param.page_size).limit(param.page_size))
return result.scalars().all()
@app.get("/llm_user")
async def query_llm_user(username: str, session: AsyncSessionDep):
result = await session.execute(select(LlmUser).where(LlmUser.username == username))
return {"result": result.scalars().first()}
这里我们调用接口llm_user_list,得到的结果示例如下所示:

调用查询指定用户名的用户:

6.行记录ID生成
在新建数据时,新建的数据的主键ID的生成现在一般已经不用自增ID的方法了,取而代之的是一般用的Mysql数据自带的“uuid”方法,这里我们封装这个工具函数及其对应的接口,用来前后端新建数据时生成行记录的ID;比如如下sql语句,执行之后就可以生成一个id:
select uuid() as _1

如下生成多个ID:
select uuid() as _1, uuid() as _2, uuid() as _3

新建一个文件“app/utils/next_id.py”,代码如下所示:
- next_id:用来生成id;
- add_route_for_next_id:用来注册一个名为“/next_id”的接口,调用这个接口会调用next_id函数生成id返回调用端;
from fastapi import FastAPI
from sqlalchemy import text
from app.utils.db_utils import async_session
async def next_id(num: int = 1):
# 使用异步上下文管理器创建会话,确保会话正确关闭
async with async_session() as session:
sql_string = ", ".join([f"uuid() as _{index}" for index in range(num)])
result = await session.execute(text(f"SELECT {sql_string}"))
# 获取第一行结果中的所有值
result = list(result.first() or [])
return result[0] if num == 1 else result
def add_route_for_next_id(app: FastAPI):
@app.get("/next_id")
async def _next_id(num: int = 1):
return {
"data": await next_id(num)
}
调用示例如下所示,只获取一个id时:

获取多个id时:

7.新建数据
如下所示,我们定义一个名为“/llm_user/insert”的一个接口,用来获取请求参数作为用户信息,然后新建一条用户记录到表中:
@app.post("/llm_user/insert")
async def insert_llm_user(user: LlmUser, session: AsyncSessionDep):
if user.id is None:
user.id = await next_id()
session.add(user)
await session.commit()
await session.refresh(user)
return {"result": user}
然后我们使用接口调试工具调用这个接口:

此时数据库表中,应该有对应的记录,并且字段值应该是一一对应的:

8.更新数据
如下所示,我们直接创建更新接口:
@app.post("/llm_user/update")
async def update_llm_user(user: dict, session: AsyncSessionDep):
# 先查询要更新的对象
update_user = (await session.exec(select(LlmUser).where(LlmUser.id == user.get('id')))).first()
if not update_user:
raise HTTPException(status_code=500, detail="Update row not found")
# 按字段更新需要的字段
for key, value in user.items():
setattr(update_user, key, value)
session.add(update_user)
await session.commit()
await session.refresh(update_user)
return {"result": update_user}
执行结果如下所示,这里我们仅更新“full_name”字段:

9.删除数据
如下所示,我们直接创建删除接口:
@app.post("/llm_user/delete")
async def delete_llm_user(user: dict, session: AsyncSessionDep):
# 先查询要更新的对象
delete_user = (await session.exec(select(LlmUser).where(LlmUser.id == user.get('id')))).first()
if not delete_user:
return {"result": False}
await session.delete(delete_user)
await session.commit()
return {"result": True}
如下所示,当我们调用这个删除接口,应该能够返回删除成功还是失败的标识:

10.通用接口:介绍
可以看到上面的代码,要给一张表生成对应的增删改查接口,其实大多数情况下写法都是大同小异的,这里为了能够省略掉平时开发过程中的时间,我们把这个封装成标准功能;
然后我们新建一个文件“app/models/LlmProduct.py”,代码如下所示:
- 刚刚我们基于“llm_user”来演示实现增删改查功能,这里我们演示基于“llm_product”表来实现增删改查功能;
- 这里这个LlmProduct,就是用来操作“llm_product”表的model类;
from datetime import datetime
from sqlmodel import Field
from app.models.BasicModel import BasicModel
class LlmProduct(BasicModel, table=True):
__tablename__ = "llm_product"
name: str = Field(default=None)
price: float = Field(default=None)
valid_start: datetime = Field(default=None)
valid_end: datetime = Field(default=None)
然后是我们的核心实现函数“add_module_route”,功能如下所示:
- 这个“create_model_service”函数会一次性新增“list”、“item”、“insert”、“update”、“delete”、“batch_insert”、“batch_update”、“batch_delete”8个接口;
- list接口除了可以分页查询之外,还支持通过filters来控制筛选参数;
- item接口的参数就是筛选参数;
- insert接口的参数这里我们也改成dict,为了与update、delete保持一致;
- update接口为了能够实现按需更新字段值,所以请求参数也是dict;
- delete接口参数为了保持统一,也是使用dict接收参数,但是为了某些业务场景下的方便使用,这个dict中可以仅含有id属性值;
- 其余的批量操作接口以此类推;
代码如下所示,这个代码只是将函数列了出来,并没有实现函数体的逻辑,后续的部分会逐一讲解并且实现完整:
def create_model_service(
app: FastAPI, # FastAPI实例,用来注册路由
path: str, # 路由前缀地址
Cls: Type[BasicModel], # model实体类
end_points: List[str] = None, # 生成的端点入口接口清单
):
# 定义模型服务类,封装模型相关的CRUD接口及业务逻辑
class ModelService:
# 支持的所有端点列表,包含常用的CRUD及批量操作
END_POINTS = ['list', 'item', 'insert', 'batch_insert', 'update', 'batch_update', 'delete', 'batch_delete']
def __init__(self):
# 做一些初始化动作,包括注册接口路由
pass
# 检查字典中的键是否为模型类的有效属性
# 参数:
# row_dict: 待检查的字典(通常为请求参数)
def check_invalid_keys(self, row_dict: dict):
pass
# 分页查询工具方法:执行带过滤和分页的查询
async def query_list(self, query_param: PageQueryParams, session: AsyncSessionDep):
pass
# 单条查询工具方法:根据条件查询单条记录
async def query_item(self, session: AsyncSessionDep, row_dict: dict = Body(..., description=f"查询数据的字段筛选值,字段参考{Cls.__name__}")):
pass
# 单条插入工具方法:新增一条记录
async def item_insert(self, session: AsyncSessionDep, row_dict: dict = Body(..., description=f"插入的数据,字段参考{Cls.__name__}")):
pass
# 批量插入工具方法:批量新增记录
async def batch_insert(self, session: AsyncSessionDep, row_dict_list: List[dict] = Body(..., description=f"批量插入的数据数组,字段参考{Cls.__name__}")):
pass
# 单条更新工具方法:更新一条记录
async def item_update(self, session: AsyncSessionDep, row_dict: dict = Body(..., description=f"更新的数据,字段参考{Cls.__name__}")):
pass
# 批量更新工具方法:批量更新记录
async def batch_update(self, session: AsyncSessionDep, row_dict_list: List[dict] = Body(..., description=f"批量更新的数据数组,字段参考{Cls.__name__}")):
pass
# 单条删除工具方法:删除一条记录
async def item_delete(self, session: AsyncSessionDep, row_dict: dict = Body(..., description=f"删除的数据,字段参考{Cls.__name__}")):
pass
# 批量删除工具方法:批量删除记录
async def batch_delete(self, session: AsyncSessionDep, row_dict_list: List[dict] = Body(..., description=f"批量删除的数据数组,字段参考{Cls.__name__}")):
pass
# 创建并返回ModelService实例
return ModelService()
11.通用接口:分页查询
如下代码所示,这个就是我们通用的分页查询接口,在支持通过page以及size设置页数以及页大小的前提下,还支持设置filters来对部分字段做简单筛选;
# 分页查询工具方法:执行带过滤和分页的查询
async def query_list(self, query_param: PageQueryParams, session: AsyncSessionDep):
if before_query_list is not None:
await before_query_list(query_param, session)
# 创建基础查询:查询当前模型类的所有记录
query = select(Cls)
# 若有过滤条件,验证并应用过滤
if query_param.filters:
self.check_invalid_keys(query_param.filters)
# 为每个过滤条件添加WHERE子句(字段=值)
for key, value in query_param.filters.items():
query = query.where(getattr(Cls, key) == value)
# 若不查询全部数据(即启用分页)
if query_param.all is False:
# 计算偏移量(跳过前N条),并查询比一页多1条的记录(用于判断是否有下一页)
query = query.offset(query_param.page * query_param.page_size).limit(query_param.page_size + 1)
# 执行查询并获取结果
result = await session.execute(query)
# 将查询结果转换为标量列表(模型实例列表)
query_cls_list: List[Any] = result.scalars().all()
# 打印查询结果类型和内容(调试用)
print("query_cls_list", type(query_cls_list), query_cls_list)
# 判断是否有下一页:若查询结果数量等于一页大小+1,则说明有下一页
has_next = len(query_cls_list) == query_param.page_size + 1
# 若有下一页,移除多查询的那一条记录
if has_next:
query_cls_list.pop()
if after_query_list is not None:
await after_query_list(query_cls_list, has_next, query_param, session)
# 返回处理后的结果列表和是否有下一页的标识
return (query_cls_list, has_next)
注册分页查询接口:
# 若启用"list"端点,注册列表查询接口
if 'list' in self.end_points:
# 列表查询接口:支持过滤和分页,响应模型为ListResponse
@self.router.post("/list", response_model=self.ListResponse)
async def _list(query_param: PageQueryParams, session: AsyncSessionDep):
# 调用query_list方法执行查询,获取数据列表和是否有下一页
query_cls_list, has_next = await self.query_list(query_param, session)
# 返回符合响应模型的结果
return {
"list": query_cls_list,
"has_next": has_next,
}
先看目前表中的所有数据:

分页查询:

分页查询带参数:

12.通用接口:单条查询
如下所示,单条查询的接口实现代码:
# 单条查询工具方法:根据条件查询单条记录
async def query_item(self, session: AsyncSessionDep, row_dict: dict = Body(..., description=f"查询数据的字段筛选值,字段参考{Cls.__name__}")):
if before_query_item is not None:
await before_query_item(row_dict, session)
# 创建基础查询:查询当前模型类的所有记录
query = select(Cls)
# 验证查询条件中的键是否有效
self.check_invalid_keys(row_dict)
# 为每个条件添加WHERE子句(字段=值)
for key, value in row_dict.items():
query = query.where(getattr(Cls, key) == value)
# 执行查询
result = await session.execute(query)
# 返回第一条匹配的记录(若存在)
item_cls = result.scalars().first()
if after_query_item is not None:
await after_query_item(item_cls, row_dict, session)
return item_cls
注册单条查询接口:
# 若启用"item"端点,注册单条查询接口
if 'item' in self.end_points:
# 单条查询接口:根据条件查询单条记录,响应模型为ItemResponse
@self.router.post("/item", response_model=self.ItemResponse)
async def _item(
session: AsyncSessionDep,
row_dict: dict = Body(..., description=f"插入的数据,字段参考{Cls.__name__}")
):
# 调用query_item方法查询单条记录并返回
return {"result": await self.query_item(session, row_dict)}
查询单条数据:

13.通用接口:单条新建
单条新建接口如下所示:
# 单条插入工具方法:新增一条记录
async def item_insert(self, session: AsyncSessionDep, row_dict: dict = Body(..., description=f"插入的数据,字段参考{Cls.__name__}")):
if before_insert is not None:
await before_insert(row_dict, session)
# 若未提供id,自动生成唯一id
if row_dict.get("id") is None:
row_dict["id"] = await next_id()
try:
# 使用模型类验证数据并创建实例(校验字段类型和约束)
insert_cls = Cls.model_validate(row_dict)
except ValueError as e:
# 数据验证失败时,抛出HTTP 500异常并返回错误详情
raise HTTPException(status_code=500, detail=str(e))
# 将实例添加到数据库会话
session.add(insert_cls)
# 提交事务(保存到数据库)
await session.commit()
# 刷新实例,获取数据库生成的最新数据(如自动更新的时间字段)
await session.refresh(insert_cls)
if after_insert is not None:
await after_insert(insert_cls, row_dict, session)
# 返回插入的实例
return insert_cls
注册单条新建接口:
# 若启用"insert"端点,注册单条插入接口
if 'insert' in self.end_points:
# 单条插入接口:新增一条记录,响应模型为ItemResponse
@self.router.post("/insert", response_model=self.ItemResponse)
async def _insert(
session: AsyncSessionDep,
row_dict: dict = Body(..., description=f"插入的数据,字段参考{Cls.__name__}")
):
# 调用item_insert方法执行插入并返回结果
return {"result": await self.item_insert(session, row_dict)}
新建单条数据:

14.通用接口:批量新建
批量新建接口代码如下所示:
# 批量插入工具方法:批量新增记录
async def batch_insert(self, session: AsyncSessionDep, row_dict_list: List[dict] = Body(..., description=f"批量插入的数据数组,字段参考{Cls.__name__}")):
if before_batch_insert is not None:
await before_batch_insert(row_dict_list, session)
# 筛选出没有id的记录(需要自动生成id)
row_dict_list_without_id = []
for row_dict in row_dict_list:
if row_dict.get("id") is None:
row_dict_list_without_id.append(row_dict)
# 若存在需要自动生成id的记录
if len(row_dict_list_without_id):
# 批量生成唯一id(数量等于需要生成id的记录数)
new_id_list = await next_id(len(row_dict_list_without_id))
# 为每条记录分配生成的id
for index, id in enumerate(new_id_list):
row_dict_list_without_id[index]["id"] = id
try:
# 验证所有记录并转换为模型实例列表
insert_cls_list = [Cls.model_validate(row_dict) for row_dict in row_dict_list]
except ValueError as e:
# 验证失败时抛出异常
raise HTTPException(status_code=500, detail=str(e))
# 将所有实例添加到会话
session.add_all(insert_cls_list)
# 提交事务,保存数据到数据库
await session.commit()
# 查询并返回所有插入的实例(刷新数据,确保获取最新状态)
refresh_cls_list = (await session.execute(select(Cls).where(Cls.id.in_([obj.id for obj in insert_cls_list])))).scalars().all()
if after_batch_insert is not None:
await after_batch_insert(refresh_cls_list, row_dict_list, session)
# 返回刷新后的实例列表
return refresh_cls_list
注册批量新建接口:
# 若启用"batch_insert"端点,注册批量插入接口
if 'batch_insert' in self.end_points:
# 批量插入接口:批量新增记录,响应模型为BatchResponse
@self.router.post("/batch_insert", response_model=self.BatchResponse)
async def _batch_insert(
session: AsyncSessionDep,
row_dict_list: List[dict] = Body(..., description=f"批量插入的数据数组,字段参考{Cls.__name__}")
):
# 调用batch_insert方法执行批量插入并返回结果
return {"result": await self.batch_insert(session, row_dict_list)}

15.通用接口:单条更新
单条更新接口代码如下所示,支持按需更新字段
# 单条更新工具方法:更新一条记录
async def item_update(self, session: AsyncSessionDep, row_dict: dict = Body(..., description=f"更新的数据,字段参考{Cls.__name__}")):
if before_update is not None:
await before_update(row_dict, session)
# 检查id是否存在(更新必须指定id)
if not row_dict.get('id'):
raise HTTPException(status_code=400, detail="ID不能为空")
# 根据id查询要更新的记录
update_cls = (await session.exec(select(Cls).where(Cls.id == row_dict.get('id')))).first()
if not update_cls:
# 若记录不存在,抛出异常
raise HTTPException(status_code=500, detail="Update row not found")
# 遍历更新字段:为记录的每个键设置新值
for key, value in row_dict.items():
setattr(update_cls, key, value)
# 将更新后的实例添加到会话
session.add(update_cls)
# 提交事务
await session.commit()
# 刷新实例,获取最新数据
await session.refresh(update_cls)
if after_update is not None:
await after_update(update_cls, row_dict, session)
# 返回更新后的实例
return update_cls
注册单条更新接口:
# 若启用"update"端点,注册单条更新接口
if 'update' in self.end_points:
# 单条更新接口:更新一条记录,响应模型为ItemResponse
@self.router.post("/update", response_model=self.ItemResponse)
async def _update(
session: AsyncSessionDep,
row_dict: dict = Body(..., description=f"更新的数据,字段参考{Cls.__name__}")
):
# 调用item_update方法执行更新并返回结果
return {"result": await self.item_update(session, row_dict)}
更新单条数据:

16.通用接口:批量更新
批量更新接口代码:
# 批量更新工具方法:批量更新记录
async def batch_update(self, session: AsyncSessionDep, row_dict_list: List[dict] = Body(..., description=f"批量更新的数据数组,字段参考{Cls.__name__}")):
if before_batch_update is not None:
await before_batch_update(row_dict_list, session)
# 提取所有待更新记录的id
update_id_list = [row_dict['id'] for row_dict in row_dict_list]
# 根据id查询所有待更新的记录
update_cls_list = (await session.exec(select(Cls).where(Cls.id.in_(update_id_list)))).all()
# 若查询到的记录数量与待更新数量不一致,说明部分id不存在
if len(update_cls_list) != len(row_dict_list):
# 抛出异常并提示不存在的id
raise HTTPException(status_code=500, detail="Update row not found:" + json.dumps(row_dict_list, ensure_ascii=False))
# 创建id到更新数据的映射(便于快速查找)
id_2_row_dict = {row_dict["id"]: row_dict for row_dict in row_dict_list}
# 遍历每条查询到的记录,更新其字段
for update_cls in update_cls_list:
# 获取当前记录对应的更新数据(根据id)
row_dict = id_2_row_dict.get(update_cls.id, None)
# 遍历更新字段
for key, value in row_dict.items():
setattr(update_cls, key, value)
# 将所有更新后的实例添加到会话
session.add_all(update_cls_list)
# 提交事务
await session.commit()
# 查询并返回所有更新后的实例(刷新数据)
refresh_cls_list = (await session.execute(select(Cls).where(Cls.id.in_([obj.id for obj in update_cls_list])))).scalars().all()
if after_batch_update is not None:
await after_batch_update(refresh_cls_list, row_dict_list, session)
# 返回刷新后的实例列表
return refresh_cls_list
注册批量更新接口:
# 若启用"batch_update"端点,注册批量更新接口
if 'batch_update' in self.end_points:
# 批量更新接口:批量更新记录,响应模型为BatchResponse
@self.router.post("/batch_update", response_model=self.BatchResponse)
async def _batch_update(
session: AsyncSessionDep,
row_dict_list: List[dict] = Body(..., description=f"批量更新的数据数组,字段参考{Cls.__name__}")
):
# 调用batch_update方法执行批量更新并返回结果
return {"result": await self.batch_update(session, row_dict_list)}
批量更新接口调用:

17.通用接口:单条删除
单条删除接口:
# 单条删除工具方法:删除一条记录
async def item_delete(self, session: AsyncSessionDep, row_dict: dict = Body(..., description=f"删除的数据,字段参考{Cls.__name__}")):
if before_delete is not None:
await before_delete(row_dict, session)
# 根据id查询要删除的记录
delete_cls = (await session.exec(select(Cls).where(Cls.id == row_dict.get('id')))).first()
if not delete_cls:
# 若记录不存在,返回删除失败
return False
# 从会话中删除记录
await session.delete(delete_cls)
# 提交事务,执行删除
await session.commit()
if after_delete is not None:
await after_delete(delete_cls, row_dict, session)
# 返回删除成功
return True
注册单条删除接口:
# 若启用"delete"端点,注册单条删除接口
if 'delete' in self.end_points:
# 单条删除接口:删除一条记录,响应模型为DeleteResponse
@self.router.post("/delete", response_model=self.DeleteResponse)
async def _delete(
session: AsyncSessionDep,
row_dict: dict = Body(..., description=f"删除的数据,字段参考{Cls.__name__}")
):
# 调用item_delete方法执行删除并返回结果
return {"result": await self.item_delete(session, row_dict)}
删除单条数据:

18.通用接口:批量删除
批量删除接口代码:
# 批量删除工具方法:批量删除记录
async def batch_delete(self, session: AsyncSessionDep, row_dict_list: List[dict] = Body(..., description=f"批量删除的数据数组,字段参考{Cls.__name__}")):
if before_batch_delete is not None:
await before_batch_delete(row_dict_list, session)
# 若待删除列表为空,返回失败
if not row_dict_list:
return False
# 提取所有待删除记录的id
row_id_list = [row_dict.get("id") for row_dict in row_dict_list]
# 根据id查询所有待删除的记录
delete_cls_list = (await session.exec(select(Cls).where(Cls.id.in_(row_id_list)))).all()
# 若查询到的记录数量与待删除数量不一致,说明部分id不存在
if len(delete_cls_list) != len(row_id_list):
# 抛出异常并提示不存在的id
raise HTTPException(status_code=500, detail="Delete row not found:" + json.dumps(row_id_list, ensure_ascii=False))
# 异步批量删除所有记录(使用gather并发执行删除操作)
await asyncio.gather(*[asyncio.create_task(session.delete(delete_cls)) for delete_cls in delete_cls_list])
# 提交事务,执行删除
await session.commit()
if after_batch_delete is not None:
await after_batch_delete(delete_cls_list, row_dict_list, session)
# 返回删除成功
return True
注册批量删除接口:
# 若启用"batch_delete"端点,注册批量删除接口
if 'batch_delete' in self.end_points:
# 批量删除接口:批量删除记录,响应模型为DeleteResponse
@self.router.post("/batch_delete", response_model=self.DeleteResponse)
async def _delete(
session: AsyncSessionDep,
row_dict_list: List[dict] = Body(..., description=f"批量删除的数据数组,字段参考{Cls.__name__}")
):
# 调用batch_delete方法执行批量删除并返回结果
return {"result": await self.batch_delete(session, row_dict_list)}
批量删除接口调用:

19.通用接口:完整代码
import asyncio
import json
from typing import Type, List, Any
from fastapi import FastAPI, APIRouter, HTTPException, Body
from pydantic import create_model
from sqlmodel import select
from app.models.BasicModel import BasicModel
from app.utils.PageQueryParams import PageQueryParams
from app.utils.db_utils import AsyncSessionDep
from app.utils.next_id import next_id
def create_model_service(
#/*@formatter:off*/
app: FastAPI, # FastAPI实例,用来注册路由
path: str, # 路由前缀地址
Cls: Type[BasicModel], # model实体类
end_points: List[str] = None, # 生成的端点入口接口清单
before_query_list=None, # 分页查询前异步处理函数,参数:(query_param, session)
after_query_list=None, # 分页查询后异步处理函数,参数:(query_cls_list, has_next, query_param, session)
before_query_item=None, # 单条查询前异步处理函数,参数:(row_dict, session)
after_query_item=None, # 单条查询后异步处理函数,参数:(item_cls, row_dict, session)
before_insert=None, # 单条新建前异步处理函数,参数:(row_dict, session)
after_insert=None, # 单条新建后异步处理函数,参数:(insert_cls, row_dict, session)
before_update=None, # 单条更新前异步处理函数,参数:(row_dict, session)
after_update=None, # 单条更新后异步处理函数,参数:(update_cls, row_dict, session)
before_delete=None, # 单条删除前异步处理函数,参数:(row_dict, session)
after_delete=None, # 单条删除后异步处理函数,参数:(delete_cls, row_dict, session)
before_batch_insert=None, # 批量新建前异步处理函数,参数:(row_dict_list, session)
after_batch_insert=None, # 批量新建后异步处理函数,参数:(refresh_cls_list, row_dict_list, session)
before_batch_update=None, # 批量更新前异步处理函数,参数:(row_dict_list, session)
after_batch_update=None, # 批量更新后异步处理函数,参数:(refresh_cls_list, row_dict_list, session)
before_batch_delete=None, # 批量删除前异步处理函数,参数:(row_dict_list, session)
after_batch_delete=None, # 批量删除后异步处理函数,参数:(delete_cls_list, row_dict_list, session)
# /*@formatter:on*/
):
# 定义模型服务类,封装模型相关的CRUD接口及业务逻辑
class ModelService:
# 支持的所有端点列表,包含常用的CRUD及批量操作
END_POINTS = ['list', 'item', 'insert', 'batch_insert', 'update', 'batch_update', 'delete', 'batch_delete']
def __init__(self):
# 验证传入的模型类是否继承自BasicModel,确保基础字段存在
if not issubclass(Cls, BasicModel):
raise TypeError(f"{Cls.__name__} 必须继承自 BasicModel")
# 保存FastAPI应用实例
self.app = app
# 保存路由前缀
self.path = path
# 保存当前操作的模型类
self.Cls = Cls
# 确定启用的端点,默认为全部支持的端点
self.end_points = end_points or self.END_POINTS
# 动态创建分页查询的响应模型:包含数据列表和是否有下一页的标识
self.ListResponse = create_model(f"{Cls.__name__}ListResponse", list=(List[Cls], ...), has_next=(bool, ...))
# 动态创建单条查询的响应模型:包含单个模型实例
self.ItemResponse = create_model(f"{Cls.__name__}ItemResponse", result=(Cls, ...))
# 动态创建批量操作的响应模型:包含操作后的模型实例列表
self.BatchResponse = create_model(f"{Cls.__name__}BatchResponse", result=(List[Cls], ...))
# 动态创建批量删除的响应模型:包含删除操作是否成功的标识
self.DeleteResponse = create_model(f"{Cls.__name__}BatchResponse", result=(bool, ...))
# 创建APIRouter实例,设置路由前缀和标签(标签用于API文档分组)
self.router = APIRouter(prefix=path, tags=[path])
# 若启用"list"端点,注册列表查询接口
if 'list' in self.end_points:
# 列表查询接口:支持过滤和分页,响应模型为ListResponse
@self.router.post("/list", response_model=self.ListResponse)
async def _list(query_param: PageQueryParams, session: AsyncSessionDep):
# 调用query_list方法执行查询,获取数据列表和是否有下一页
query_cls_list, has_next = await self.query_list(query_param, session)
# 返回符合响应模型的结果
return {
"list": query_cls_list,
"has_next": has_next,
}
# 若启用"item"端点,注册单条查询接口
if 'item' in self.end_points:
# 单条查询接口:根据条件查询单条记录,响应模型为ItemResponse
@self.router.post("/item", response_model=self.ItemResponse)
async def _item(
session: AsyncSessionDep,
row_dict: dict = Body(..., description=f"插入的数据,字段参考{Cls.__name__}")
):
# 调用query_item方法查询单条记录并返回
return {"result": await self.query_item(session, row_dict)}
# 若启用"insert"端点,注册单条插入接口
if 'insert' in self.end_points:
# 单条插入接口:新增一条记录,响应模型为ItemResponse
@self.router.post("/insert", response_model=self.ItemResponse)
async def _insert(
session: AsyncSessionDep,
row_dict: dict = Body(..., description=f"插入的数据,字段参考{Cls.__name__}")
):
# 调用item_insert方法执行插入并返回结果
return {"result": await self.item_insert(session, row_dict)}
# 若启用"batch_insert"端点,注册批量插入接口
if 'batch_insert' in self.end_points:
# 批量插入接口:批量新增记录,响应模型为BatchResponse
@self.router.post("/batch_insert", response_model=self.BatchResponse)
async def _batch_insert(
session: AsyncSessionDep,
row_dict_list: List[dict] = Body(..., description=f"批量插入的数据数组,字段参考{Cls.__name__}")
):
# 调用batch_insert方法执行批量插入并返回结果
return {"result": await self.batch_insert(session, row_dict_list)}
# 若启用"update"端点,注册单条更新接口
if 'update' in self.end_points:
# 单条更新接口:更新一条记录,响应模型为ItemResponse
@self.router.post("/update", response_model=self.ItemResponse)
async def _update(
session: AsyncSessionDep,
row_dict: dict = Body(..., description=f"更新的数据,字段参考{Cls.__name__}")
):
# 调用item_update方法执行更新并返回结果
return {"result": await self.item_update(session, row_dict)}
# 若启用"batch_update"端点,注册批量更新接口
if 'batch_update' in self.end_points:
# 批量更新接口:批量更新记录,响应模型为BatchResponse
@self.router.post("/batch_update", response_model=self.BatchResponse)
async def _batch_update(
session: AsyncSessionDep,
row_dict_list: List[dict] = Body(..., description=f"批量更新的数据数组,字段参考{Cls.__name__}")
):
# 调用batch_update方法执行批量更新并返回结果
return {"result": await self.batch_update(session, row_dict_list)}
# 若启用"delete"端点,注册单条删除接口
if 'delete' in self.end_points:
# 单条删除接口:删除一条记录,响应模型为DeleteResponse
@self.router.post("/delete", response_model=self.DeleteResponse)
async def _delete(
session: AsyncSessionDep,
row_dict: dict = Body(..., description=f"删除的数据,字段参考{Cls.__name__}")
):
# 调用item_delete方法执行删除并返回结果
return {"result": await self.item_delete(session, row_dict)}
# 若启用"batch_delete"端点,注册批量删除接口
if 'batch_delete' in self.end_points:
# 批量删除接口:批量删除记录,响应模型为DeleteResponse
@self.router.post("/batch_delete", response_model=self.DeleteResponse)
async def _delete(
session: AsyncSessionDep,
row_dict_list: List[dict] = Body(..., description=f"批量删除的数据数组,字段参考{Cls.__name__}")
):
# 调用batch_delete方法执行批量删除并返回结果
return {"result": await self.batch_delete(session, row_dict_list)}
# 若有启用的端点,将路由添加到FastAPI应用
if self.end_points:
app.include_router(self.router)
# 检查字典中的键是否为模型类的有效属性
# 参数:
# row_dict: 待检查的字典(通常为请求参数)
def check_invalid_keys(self, row_dict: dict):
# 筛选出所有不在模型类属性中的键(无效键)
invalid_keys = [key for key in row_dict.keys() if not hasattr(Cls, key)]
if invalid_keys:
# 若存在无效键,抛出HTTP 500异常,提示无效键和有效键列表
raise HTTPException(
status_code=500,
detail=f"Invalid filter keys: {invalid_keys}. Valid keys are: {Cls.__annotations__.keys()}"
)
# 分页查询工具方法:执行带过滤和分页的查询
async def query_list(self, query_param: PageQueryParams, session: AsyncSessionDep):
if before_query_list is not None:
await before_query_list(query_param, session)
# 创建基础查询:查询当前模型类的所有记录
query = select(Cls)
# 若有过滤条件,验证并应用过滤
if query_param.filters:
self.check_invalid_keys(query_param.filters)
# 为每个过滤条件添加WHERE子句(字段=值)
for key, value in query_param.filters.items():
query = query.where(getattr(Cls, key) == value)
# 若不查询全部数据(即启用分页)
if query_param.all is False:
# 计算偏移量(跳过前N条),并查询比一页多1条的记录(用于判断是否有下一页)
query = query.offset(query_param.page * query_param.page_size).limit(query_param.page_size + 1)
# 执行查询并获取结果
result = await session.execute(query)
# 将查询结果转换为标量列表(模型实例列表)
query_cls_list: List[Any] = result.scalars().all()
# 打印查询结果类型和内容(调试用)
print("query_cls_list", type(query_cls_list), query_cls_list)
# 判断是否有下一页:若查询结果数量等于一页大小+1,则说明有下一页
has_next = len(query_cls_list) == query_param.page_size + 1
# 若有下一页,移除多查询的那一条记录
if has_next:
query_cls_list.pop()
if after_query_list is not None:
await after_query_list(query_cls_list, has_next, query_param, session)
# 返回处理后的结果列表和是否有下一页的标识
return (query_cls_list, has_next)
# 单条查询工具方法:根据条件查询单条记录
async def query_item(self, session: AsyncSessionDep, row_dict: dict = Body(..., description=f"查询数据的字段筛选值,字段参考{Cls.__name__}")):
if before_query_item is not None:
await before_query_item(row_dict, session)
# 创建基础查询:查询当前模型类的所有记录
query = select(Cls)
# 验证查询条件中的键是否有效
self.check_invalid_keys(row_dict)
# 为每个条件添加WHERE子句(字段=值)
for key, value in row_dict.items():
query = query.where(getattr(Cls, key) == value)
# 执行查询
result = await session.execute(query)
# 返回第一条匹配的记录(若存在)
item_cls = result.scalars().first()
if after_query_item is not None:
await after_query_item(item_cls, row_dict, session)
return item_cls
# 单条插入工具方法:新增一条记录
async def item_insert(self, session: AsyncSessionDep, row_dict: dict = Body(..., description=f"插入的数据,字段参考{Cls.__name__}")):
if before_insert is not None:
await before_insert(row_dict, session)
# 若未提供id,自动生成唯一id
if row_dict.get("id") is None:
row_dict["id"] = await next_id()
try:
# 使用模型类验证数据并创建实例(校验字段类型和约束)
insert_cls = Cls.model_validate(row_dict)
except ValueError as e:
# 数据验证失败时,抛出HTTP 500异常并返回错误详情
raise HTTPException(status_code=500, detail=str(e))
# 将实例添加到数据库会话
session.add(insert_cls)
# 提交事务(保存到数据库)
await session.commit()
# 刷新实例,获取数据库生成的最新数据(如自动更新的时间字段)
await session.refresh(insert_cls)
if after_insert is not None:
await after_insert(insert_cls, row_dict, session)
# 返回插入的实例
return insert_cls
# 批量插入工具方法:批量新增记录
async def batch_insert(self, session: AsyncSessionDep, row_dict_list: List[dict] = Body(..., description=f"批量插入的数据数组,字段参考{Cls.__name__}")):
if before_batch_insert is not None:
await before_batch_insert(row_dict_list, session)
# 筛选出没有id的记录(需要自动生成id)
row_dict_list_without_id = []
for row_dict in row_dict_list:
if row_dict.get("id") is None:
row_dict_list_without_id.append(row_dict)
# 若存在需要自动生成id的记录
if len(row_dict_list_without_id):
# 批量生成唯一id(数量等于需要生成id的记录数)
new_id_list = await next_id(len(row_dict_list_without_id))
# 为每条记录分配生成的id
for index, id in enumerate(new_id_list):
row_dict_list_without_id[index]["id"] = id
try:
# 验证所有记录并转换为模型实例列表
insert_cls_list = [Cls.model_validate(row_dict) for row_dict in row_dict_list]
except ValueError as e:
# 验证失败时抛出异常
raise HTTPException(status_code=500, detail=str(e))
# 将所有实例添加到会话
session.add_all(insert_cls_list)
# 提交事务,保存数据到数据库
await session.commit()
# 查询并返回所有插入的实例(刷新数据,确保获取最新状态)
refresh_cls_list = (await session.execute(select(Cls).where(Cls.id.in_([obj.id for obj in insert_cls_list])))).scalars().all()
if after_batch_insert is not None:
await after_batch_insert(refresh_cls_list, row_dict_list, session)
# 返回刷新后的实例列表
return refresh_cls_list
# 单条更新工具方法:更新一条记录
async def item_update(self, session: AsyncSessionDep, row_dict: dict = Body(..., description=f"更新的数据,字段参考{Cls.__name__}")):
if before_update is not None:
await before_update(row_dict, session)
# 检查id是否存在(更新必须指定id)
if not row_dict.get('id'):
raise HTTPException(status_code=400, detail="ID不能为空")
# 根据id查询要更新的记录
update_cls = (await session.exec(select(Cls).where(Cls.id == row_dict.get('id')))).first()
if not update_cls:
# 若记录不存在,抛出异常
raise HTTPException(status_code=500, detail="Update row not found")
# 遍历更新字段:为记录的每个键设置新值
for key, value in row_dict.items():
setattr(update_cls, key, value)
# 将更新后的实例添加到会话
session.add(update_cls)
# 提交事务
await session.commit()
# 刷新实例,获取最新数据
await session.refresh(update_cls)
if after_update is not None:
await after_update(update_cls, row_dict, session)
# 返回更新后的实例
return update_cls
# 批量更新工具方法:批量更新记录
async def batch_update(self, session: AsyncSessionDep, row_dict_list: List[dict] = Body(..., description=f"批量更新的数据数组,字段参考{Cls.__name__}")):
if before_batch_update is not None:
await before_batch_update(row_dict_list, session)
# 提取所有待更新记录的id
update_id_list = [row_dict['id'] for row_dict in row_dict_list]
# 根据id查询所有待更新的记录
update_cls_list = (await session.exec(select(Cls).where(Cls.id.in_(update_id_list)))).all()
# 若查询到的记录数量与待更新数量不一致,说明部分id不存在
if len(update_cls_list) != len(row_dict_list):
# 抛出异常并提示不存在的id
raise HTTPException(status_code=500, detail="Update row not found:" + json.dumps(row_dict_list, ensure_ascii=False))
# 创建id到更新数据的映射(便于快速查找)
id_2_row_dict = {row_dict["id"]: row_dict for row_dict in row_dict_list}
# 遍历每条查询到的记录,更新其字段
for update_cls in update_cls_list:
# 获取当前记录对应的更新数据(根据id)
row_dict = id_2_row_dict.get(update_cls.id, None)
# 遍历更新字段
for key, value in row_dict.items():
setattr(update_cls, key, value)
# 将所有更新后的实例添加到会话
session.add_all(update_cls_list)
# 提交事务
await session.commit()
# 查询并返回所有更新后的实例(刷新数据)
refresh_cls_list = (await session.execute(select(Cls).where(Cls.id.in_([obj.id for obj in update_cls_list])))).scalars().all()
if after_batch_update is not None:
await after_batch_update(refresh_cls_list, row_dict_list, session)
# 返回刷新后的实例列表
return refresh_cls_list
# 单条删除工具方法:删除一条记录
async def item_delete(self, session: AsyncSessionDep, row_dict: dict = Body(..., description=f"删除的数据,字段参考{Cls.__name__}")):
if before_delete is not None:
await before_delete(row_dict, session)
# 根据id查询要删除的记录
delete_cls = (await session.exec(select(Cls).where(Cls.id == row_dict.get('id')))).first()
if not delete_cls:
# 若记录不存在,返回删除失败
return False
# 从会话中删除记录
await session.delete(delete_cls)
# 提交事务,执行删除
await session.commit()
if after_delete is not None:
await after_delete(delete_cls, row_dict, session)
# 返回删除成功
return True
# 批量删除工具方法:批量删除记录
async def batch_delete(self, session: AsyncSessionDep, row_dict_list: List[dict] = Body(..., description=f"批量删除的数据数组,字段参考{Cls.__name__}")):
if before_batch_delete is not None:
await before_batch_delete(row_dict_list, session)
# 若待删除列表为空,返回失败
if not row_dict_list:
return False
# 提取所有待删除记录的id
row_id_list = [row_dict.get("id") for row_dict in row_dict_list]
# 根据id查询所有待删除的记录
delete_cls_list = (await session.exec(select(Cls).where(Cls.id.in_(row_id_list)))).all()
# 若查询到的记录数量与待删除数量不一致,说明部分id不存在
if len(delete_cls_list) != len(row_id_list):
# 抛出异常并提示不存在的id
raise HTTPException(status_code=500, detail="Delete row not found:" + json.dumps(row_id_list, ensure_ascii=False))
# 异步批量删除所有记录(使用gather并发执行删除操作)
await asyncio.gather(*[asyncio.create_task(session.delete(delete_cls)) for delete_cls in delete_cls_list])
# 提交事务,执行删除
await session.commit()
if after_batch_delete is not None:
await after_batch_delete(delete_cls_list, row_dict_list, session)
# 返回删除成功
return True
# 创建并返回ModelService实例
return ModelService()
