0.0.1
This commit is contained in:
12
.gitignore
vendored
Normal file
12
.gitignore
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
__pycache__
|
||||
.env
|
||||
.trae
|
||||
.idea
|
||||
.DS_Store
|
||||
*.baiduyun.*
|
||||
.vscode
|
||||
对比
|
||||
logs/sessions.json
|
||||
logs/sessions.log
|
||||
222.py
|
||||
333.py
|
||||
5
app/apis/__init__.py
Normal file
5
app/apis/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from fastapi import APIRouter
|
||||
from .country import app as country_app
|
||||
|
||||
app = APIRouter()
|
||||
app.include_router(country_app, prefix='/country')
|
||||
9
app/apis/country/__init__.py
Normal file
9
app/apis/country/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from fastapi import APIRouter
|
||||
from .info.view import app as info_app
|
||||
from .food.view import app as food_app
|
||||
from .shop.view import app as shop_app
|
||||
|
||||
app = APIRouter()
|
||||
app.include_router(info_app, prefix='/info', tags=['信息'])
|
||||
app.include_router(food_app, prefix='/food', tags=['食物'])
|
||||
app.include_router(shop_app, prefix='/shop', tags=['商店'])
|
||||
66
app/apis/country/food/schema.py
Normal file
66
app/apis/country/food/schema.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
from utils.time_tool import TimestampModel
|
||||
|
||||
CHINA_TZ = timezone(timedelta(hours=8))
|
||||
|
||||
|
||||
class Base(BaseModel):
|
||||
"""
|
||||
基础食物信息模型
|
||||
|
||||
仅包含食物名称
|
||||
"""
|
||||
name: str = Field(..., description='食物名称')
|
||||
|
||||
|
||||
class Create(Base):
|
||||
"""
|
||||
创建请求模型
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Update(BaseModel):
|
||||
"""
|
||||
更新请求模型,支持部分更新
|
||||
"""
|
||||
name: str | None = Field(None, description='食物名称')
|
||||
|
||||
|
||||
class Out(TimestampModel, Base):
|
||||
"""
|
||||
输出模型
|
||||
"""
|
||||
code: int = Field(200, description='状态码')
|
||||
message: str = Field('成功', description='提示信息')
|
||||
id: UUID = Field(..., description='ID')
|
||||
|
||||
create_time: datetime = Field(..., description='创建时间')
|
||||
update_time: datetime = Field(..., description='更新时间')
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def create_time_cn(self) -> str:
|
||||
return self.create_time.astimezone(CHINA_TZ).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def update_time_cn(self) -> str:
|
||||
return self.update_time.astimezone(CHINA_TZ).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class OutList(BaseModel):
|
||||
"""
|
||||
列表输出模型
|
||||
"""
|
||||
code: int = Field(200, description='状态码')
|
||||
message: str = Field('成功', description='提示信息')
|
||||
count: int = Field(0, description='总数')
|
||||
num: int = Field(0, description='当前数量')
|
||||
items: List[Out] = Field([], description='列表数据')
|
||||
122
app/apis/country/food/view.py
Normal file
122
app/apis/country/food/view.py
Normal file
@@ -0,0 +1,122 @@
|
||||
|
||||
from fastapi import APIRouter, Query, Body, HTTPException
|
||||
from uuid import UUID
|
||||
from .schema import Create, Update, Out, OutList
|
||||
from ..models import Food
|
||||
from utils.decorators import handle_exceptions_unified
|
||||
from utils.time_tool import parse_time
|
||||
from utils.out_base import CommonOut
|
||||
|
||||
app = APIRouter()
|
||||
|
||||
|
||||
# 创建食物
|
||||
@app.post("", response_model=Out, description='创建食物', summary='创建食物')
|
||||
@handle_exceptions_unified()
|
||||
async def post(item: Create = Body(..., description='创建数据')):
|
||||
"""
|
||||
创建食物记录
|
||||
"""
|
||||
res = await Food.create(**item.model_dump())
|
||||
if not res:
|
||||
raise HTTPException(status_code=400, detail='创建失败')
|
||||
return res
|
||||
|
||||
|
||||
# 查询食物
|
||||
@app.get("", response_model=OutList, description='获取食物', summary='获取食物')
|
||||
@handle_exceptions_unified()
|
||||
async def gets(
|
||||
id: UUID | None = Query(None, description='主键ID'),
|
||||
name: str | None = Query(None, description='食物名称'),
|
||||
order_by: str | None = Query('create_time', description='排序字段',
|
||||
regex='^(-)?(id|name|create_time|update_time)$'),
|
||||
res_count: bool = Query(False, description='是否返回总数'),
|
||||
create_time_start: str | int | None = Query(
|
||||
None, description='创建时间开始 (支持 YYYY-MM-DD / YYYY-MM-DD HH:mm:ss / 13位时间戳)'),
|
||||
create_time_end: str | int | None = Query(
|
||||
None, description='创建时间结束 (支持 YYYY-MM-DD / YYYY-MM-DD HH:mm:ss / 13位时间戳)'),
|
||||
update_time_start: str | int | None = Query(
|
||||
None, description='更新时间开始 (支持 YYYY-MM-DD / YYYY-MM-DD HH:mm:ss / 13位时间戳)'),
|
||||
update_time_end: str | int | None = Query(
|
||||
None, description='更新时间结束 (支持 YYYY-MM-DD / YYYY-MM-DD HH:mm:ss / 13位时间戳)'),
|
||||
page: int = Query(1, ge=1, description='页码'),
|
||||
limit: int = Query(10, ge=1, le=1000, description='每页数量'),
|
||||
):
|
||||
"""
|
||||
获取食物列表
|
||||
"""
|
||||
query = Food.all()
|
||||
if id:
|
||||
query = query.filter(id=id)
|
||||
if name:
|
||||
query = query.filter(name=name)
|
||||
if create_time_start:
|
||||
query = query.filter(create_time__gte=parse_time(create_time_start))
|
||||
if create_time_end:
|
||||
query = query.filter(create_time__lte=parse_time(
|
||||
create_time_end, is_end=True))
|
||||
if update_time_start:
|
||||
query = query.filter(update_time__gte=parse_time(update_time_start))
|
||||
if update_time_end:
|
||||
query = query.filter(update_time__lte=parse_time(
|
||||
update_time_end, is_end=True))
|
||||
|
||||
if order_by:
|
||||
query = query.order_by(order_by)
|
||||
|
||||
if res_count:
|
||||
count = await query.count()
|
||||
else:
|
||||
count = -1
|
||||
offset = (page - 1) * limit # 计算偏移量
|
||||
query = query.limit(limit).offset(offset) # 应用分页
|
||||
|
||||
res = await query
|
||||
if not res:
|
||||
raise HTTPException(status_code=404, detail='食物不存在')
|
||||
num = len(res)
|
||||
return OutList(count=count, num=num, items=res)
|
||||
|
||||
|
||||
# 更新食物
|
||||
@app.put("", response_model=Out, description='更新食物', summary='更新食物')
|
||||
@handle_exceptions_unified()
|
||||
async def put(id: UUID = Query(..., description='主键ID'),
|
||||
item: Update = Body(..., description='更新数据'),
|
||||
):
|
||||
"""
|
||||
部分更新食物,只更新传入的非空字段
|
||||
"""
|
||||
# 检查食物是否存在
|
||||
secret = await Food.get_or_none(id=id)
|
||||
if not secret:
|
||||
raise HTTPException(status_code=404, detail='食物不存在')
|
||||
|
||||
# 获取要更新的字段(排除None值的字段)
|
||||
update_data = item.model_dump(exclude_unset=True)
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail='没有要更新的字段')
|
||||
|
||||
# 更新食物字段
|
||||
await secret.update_from_dict(update_data)
|
||||
await secret.save()
|
||||
return secret
|
||||
|
||||
|
||||
# 删除食物
|
||||
|
||||
@app.delete("", response_model=CommonOut, description='删除食物', summary='删除食物')
|
||||
@handle_exceptions_unified()
|
||||
async def delete(id: UUID = Query(..., description='主键ID'),
|
||||
):
|
||||
"""删除食物"""
|
||||
secret = await Food.get_or_none(id=id)
|
||||
if not secret:
|
||||
raise HTTPException(status_code=404, detail='食物不存在')
|
||||
await secret.delete()
|
||||
# Tortoise ORM 单个实例的 delete() 方法返回 None,而不是删除的记录数
|
||||
# 删除成功时手动返回 1,如果有异常会被装饰器捕获
|
||||
return CommonOut(count=1)
|
||||
84
app/apis/country/info/schema.py
Normal file
84
app/apis/country/info/schema.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
from utils.time_tool import TimestampModel
|
||||
|
||||
CHINA_TZ = timezone(timedelta(hours=8))
|
||||
|
||||
|
||||
class Base(BaseModel):
|
||||
"""
|
||||
基础地址信息模型
|
||||
|
||||
包含地址相关的通用字段,供创建与输出模型复用
|
||||
"""
|
||||
firstname: str = Field(..., description='名')
|
||||
lastname: str = Field(..., description='姓')
|
||||
full_name: str = Field(..., description='全名')
|
||||
birthday: str = Field(..., description='生日')
|
||||
street_address: str = Field(..., description='街道地址')
|
||||
city: str = Field(..., description='城市')
|
||||
phone: str = Field(..., description='电话')
|
||||
zip_code: str = Field(..., description='邮编')
|
||||
state_fullname: str = Field(..., description='州全称')
|
||||
status: bool = Field(False, description='状态')
|
||||
|
||||
|
||||
class Create(Base):
|
||||
"""
|
||||
创建请求模型
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Update(BaseModel):
|
||||
"""
|
||||
更新请求模型,支持部分更新
|
||||
"""
|
||||
firstname: str | None = Field(None, description='名')
|
||||
lastname: str | None = Field(None, description='姓')
|
||||
full_name: str | None = Field(None, description='全名')
|
||||
birthday: str | None = Field(None, description='生日')
|
||||
street_address: str | None = Field(None, description='街道地址')
|
||||
city: str | None = Field(None, description='城市')
|
||||
phone: str | None = Field(None, description='电话')
|
||||
zip_code: str | None = Field(None, description='邮编')
|
||||
state_fullname: str | None = Field(None, description='州全称')
|
||||
status: bool | None = Field(None, description='状态')
|
||||
|
||||
|
||||
class Out(TimestampModel, Base):
|
||||
"""
|
||||
输出模型
|
||||
"""
|
||||
code: int = Field(200, description='状态码')
|
||||
message: str = Field('成功', description='提示信息')
|
||||
id: UUID = Field(..., description='ID')
|
||||
|
||||
create_time: datetime = Field(..., description='创建时间')
|
||||
update_time: datetime = Field(..., description='更新时间')
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def create_time_cn(self) -> str:
|
||||
return self.create_time.astimezone(CHINA_TZ).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def update_time_cn(self) -> str:
|
||||
return self.update_time.astimezone(CHINA_TZ).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class OutList(BaseModel):
|
||||
"""
|
||||
列表输出模型
|
||||
"""
|
||||
code: int = Field(200, description='状态码')
|
||||
message: str = Field('成功', description='提示信息')
|
||||
count: int = Field(0, description='总数')
|
||||
num: int = Field(0, description='当前数量')
|
||||
items: List[Out] = Field([], description='列表数据')
|
||||
171
app/apis/country/info/view.py
Normal file
171
app/apis/country/info/view.py
Normal file
@@ -0,0 +1,171 @@
|
||||
|
||||
from fastapi import APIRouter, Query, Body, HTTPException
|
||||
import random
|
||||
from uuid import UUID
|
||||
from .schema import Create, Update, Out, OutList
|
||||
from ..models import Info
|
||||
from utils.decorators import handle_exceptions_unified
|
||||
from utils.time_tool import parse_time
|
||||
from utils.out_base import CommonOut
|
||||
from tortoise.transactions import in_transaction
|
||||
|
||||
app = APIRouter()
|
||||
|
||||
|
||||
# 创建信息
|
||||
@app.post("", response_model=Out, description='创建信息', summary='创建信息')
|
||||
@handle_exceptions_unified()
|
||||
async def post(item: Create = Body(..., description='创建数据')):
|
||||
"""
|
||||
创建信息记录
|
||||
"""
|
||||
res = await Info.create(**item.model_dump())
|
||||
if not res:
|
||||
raise HTTPException(status_code=400, detail='创建失败')
|
||||
return res
|
||||
|
||||
|
||||
# 查询信息
|
||||
@app.get("", response_model=OutList, description='获取信息', summary='获取信息')
|
||||
@handle_exceptions_unified()
|
||||
async def gets(
|
||||
id: UUID | None = Query(None, description='主键ID'),
|
||||
firstname: str | None = Query(None, description='名'),
|
||||
lastname: str | None = Query(None, description='姓'),
|
||||
full_name: str | None = Query(None, description='全名'),
|
||||
birthday: str | None = Query(None, description='生日'),
|
||||
street_address: str | None = Query(None, description='街道地址'),
|
||||
city: str | None = Query(None, description='城市'),
|
||||
phone: str | None = Query(None, description='电话'),
|
||||
zip_code: str | None = Query(None, description='邮编'),
|
||||
state_fullname: str | None = Query(None, description='州全称'),
|
||||
status: bool | None = Query(None, description='状态'),
|
||||
order_by: str | None = Query('create_time', description='排序字段',
|
||||
regex='^(-)?(id|firstname|lastname|city|zip_code|create_time|update_time)$'),
|
||||
res_count: bool = Query(False, description='是否返回总数'),
|
||||
create_time_start: str | int | None = Query(
|
||||
None, description='创建时间开始 (支持 YYYY-MM-DD / YYYY-MM-DD HH:mm:ss / 13位时间戳)'),
|
||||
create_time_end: str | int | None = Query(
|
||||
None, description='创建时间结束 (支持 YYYY-MM-DD / YYYY-MM-DD HH:mm:ss / 13位时间戳)'),
|
||||
update_time_start: str | int | None = Query(
|
||||
None, description='更新时间开始 (支持 YYYY-MM-DD / YYYY-MM-DD HH:mm:ss / 13位时间戳)'),
|
||||
update_time_end: str | int | None = Query(
|
||||
None, description='更新时间结束 (支持 YYYY-MM-DD / YYYY-MM-DD HH:mm:ss / 13位时间戳)'),
|
||||
page: int = Query(1, ge=1, description='页码'),
|
||||
limit: int = Query(10, ge=1, le=1000, description='每页数量'),
|
||||
):
|
||||
"""
|
||||
获取信息列表
|
||||
"""
|
||||
query = Info.all()
|
||||
if id:
|
||||
query = query.filter(id=id)
|
||||
if firstname:
|
||||
query = query.filter(firstname=firstname)
|
||||
if lastname:
|
||||
query = query.filter(lastname=lastname)
|
||||
if full_name:
|
||||
query = query.filter(full_name=full_name)
|
||||
if birthday:
|
||||
query = query.filter(birthday=birthday)
|
||||
if street_address:
|
||||
query = query.filter(street_address=street_address)
|
||||
if city:
|
||||
query = query.filter(city=city)
|
||||
if phone:
|
||||
query = query.filter(phone=phone)
|
||||
if zip_code:
|
||||
query = query.filter(zip_code=zip_code)
|
||||
if state_fullname:
|
||||
query = query.filter(state_fullname=state_fullname)
|
||||
if status is not None:
|
||||
query = query.filter(status=status)
|
||||
if create_time_start:
|
||||
query = query.filter(create_time__gte=parse_time(create_time_start))
|
||||
if create_time_end:
|
||||
query = query.filter(create_time__lte=parse_time(
|
||||
create_time_end, is_end=True))
|
||||
if update_time_start:
|
||||
query = query.filter(update_time__gte=parse_time(update_time_start))
|
||||
if update_time_end:
|
||||
query = query.filter(update_time__lte=parse_time(
|
||||
update_time_end, is_end=True))
|
||||
|
||||
if order_by:
|
||||
query = query.order_by(order_by)
|
||||
|
||||
if res_count:
|
||||
count = await query.count()
|
||||
else:
|
||||
count = -1
|
||||
offset = (page - 1) * limit # 计算偏移量
|
||||
query = query.limit(limit).offset(offset) # 应用分页
|
||||
|
||||
res = await query
|
||||
if not res:
|
||||
raise HTTPException(status_code=404, detail='信息不存在')
|
||||
num = len(res)
|
||||
return OutList(count=count, num=num, items=res)
|
||||
|
||||
|
||||
# 更新信息
|
||||
@app.put("", response_model=Out, description='更新信息', summary='更新信息')
|
||||
@handle_exceptions_unified()
|
||||
async def put(id: UUID = Query(..., description='主键ID'),
|
||||
item: Update = Body(..., description='更新数据'),
|
||||
):
|
||||
"""
|
||||
部分更新信息,只更新传入的非空字段
|
||||
"""
|
||||
# 检查信息是否存在
|
||||
secret = await Info.get_or_none(id=id)
|
||||
if not secret:
|
||||
raise HTTPException(status_code=404, detail='信息不存在')
|
||||
|
||||
# 获取要更新的字段(排除None值的字段)
|
||||
update_data = item.model_dump(exclude_unset=True)
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail='没有要更新的字段')
|
||||
|
||||
# 更新信息字段
|
||||
await secret.update_from_dict(update_data)
|
||||
await secret.save()
|
||||
return secret
|
||||
|
||||
|
||||
# 删除信息
|
||||
|
||||
@app.delete("", response_model=CommonOut, description='删除信息', summary='删除信息')
|
||||
@handle_exceptions_unified()
|
||||
async def delete(id: UUID = Query(..., description='主键ID'),
|
||||
):
|
||||
"""删除信息"""
|
||||
secret = await Info.get_or_none(id=id)
|
||||
if not secret:
|
||||
raise HTTPException(status_code=404, detail='信息不存在')
|
||||
await secret.delete()
|
||||
# Tortoise ORM 单个实例的 delete() 方法返回 None,而不是删除的记录数
|
||||
# 删除成功时手动返回 1,如果有异常会被装饰器捕获
|
||||
return CommonOut(count=1)
|
||||
|
||||
|
||||
# 随机获取一条状态修改为True的记录
|
||||
@app.put("/one", response_model=Out, description='随机获取一条状态修改为True的记录', summary='随机获取一条状态修改为True的记录')
|
||||
@handle_exceptions_unified()
|
||||
async def random_update_status():
|
||||
"""
|
||||
随机获取一条状态为 False 的记录并在事务中更新为 True
|
||||
"""
|
||||
async with in_transaction() as conn:
|
||||
q = Info.filter(status=False).using_db(conn)
|
||||
current_running_count = await q.count()
|
||||
if current_running_count == 0:
|
||||
raise HTTPException(status_code=404, detail='没有状态为False的记录')
|
||||
pick_index = random.choice(range(current_running_count))
|
||||
item = await q.order_by('create_time').offset(pick_index).first()
|
||||
updated = await Info.filter(id=item.id, status=False).using_db(conn).update(status=True)
|
||||
if updated == 0:
|
||||
raise HTTPException(status_code=400, detail='并发冲突,未更新')
|
||||
return item
|
||||
110
app/apis/country/models.py
Normal file
110
app/apis/country/models.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import uuid
|
||||
from tortoise import fields
|
||||
from tortoise.models import Model
|
||||
|
||||
|
||||
class Shop(Model):
|
||||
"""
|
||||
店铺模型
|
||||
|
||||
字段:
|
||||
id (UUIDField): 主键,默认使用 UUID 生成
|
||||
province (CharField): 省份,最大长度 255
|
||||
city (CharField): 城市,最大长度 255
|
||||
street (CharField): 街道,最大长度 255
|
||||
shop_name (CharField): 店铺名称,最大长度 255
|
||||
shop_number (CharField): 店铺号码,最大长度 255, nullable 为 True
|
||||
"""
|
||||
id = fields.UUIDField(pk=True, default=uuid.uuid4, description="ID")
|
||||
province = fields.CharField(max_length=255, index=True, description="省份")
|
||||
city = fields.CharField(max_length=255, index=True, description="城市")
|
||||
street = fields.CharField(max_length=255, index=True, description="街道")
|
||||
shop_name = fields.CharField(max_length=255, index=True, description="店铺名称")
|
||||
shop_number = fields.CharField(max_length=255, null=True, description="店铺号码")
|
||||
create_time = fields.DatetimeField(auto_now_add=True, index=True, description='创建时间')
|
||||
update_time = fields.DatetimeField(auto_now=True, description='更新时间')
|
||||
|
||||
|
||||
class Meta:
|
||||
table = "shop"
|
||||
table_description = "店铺表"
|
||||
ordering = ["create_time"]
|
||||
indexes = [
|
||||
("province", "city", "street"),
|
||||
]
|
||||
def __repr__(self):
|
||||
return f"<Shop(id={self.id}, province={self.province}, city={self.city}, street={self.street}, shop_name={self.shop_name})>"
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
class Food(Model):
|
||||
"""
|
||||
食物模型
|
||||
|
||||
字段:
|
||||
id (UUIDField): 主键,默认使用 UUID 生成
|
||||
name (CharField): 食物名称,最大长度 255
|
||||
"""
|
||||
id = fields.UUIDField(pk=True, default=uuid.uuid4, description="ID")
|
||||
name = fields.CharField(max_length=255, index=True, description="食物名称")
|
||||
create_time = fields.DatetimeField(auto_now_add=True, index=True, description='创建时间')
|
||||
update_time = fields.DatetimeField(auto_now=True, description='更新时间')
|
||||
|
||||
|
||||
class Meta:
|
||||
table = "food"
|
||||
table_description = "食物表"
|
||||
ordering = ["create_time"]
|
||||
indexes = [
|
||||
("name",),
|
||||
]
|
||||
def __repr__(self):
|
||||
return f"<Food(id={self.id}, name={self.name})>"
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
|
||||
class Info(Model):
|
||||
"""
|
||||
信息模型
|
||||
|
||||
字段:
|
||||
id (UUIDField): 主键,默认使用 UUID 生成
|
||||
firstname (CharField): 名,最大长度 255
|
||||
lastname (CharField): 姓,最大长度 255
|
||||
full_name (CharField): 全名,最大长度 255
|
||||
birthday (CharField): 生日(原始字符串),最大长度 32
|
||||
street_address (CharField): 街道地址,最大长度 255
|
||||
city (CharField): 城市,最大长度 255
|
||||
phone (CharField): 电话,最大长度 64
|
||||
zip_code (CharField): 邮编,最大长度 20
|
||||
state_fullname (CharField): 州全称,最大长度 255
|
||||
"""
|
||||
id = fields.UUIDField(pk=True, default=uuid.uuid4, description="ID")
|
||||
firstname = fields.CharField(max_length=255, index=True, description="名")
|
||||
lastname = fields.CharField(max_length=255, index=True, description="姓")
|
||||
full_name = fields.CharField(max_length=255, index=True, description="全名")
|
||||
birthday = fields.CharField(max_length=32, description="生日")
|
||||
street_address = fields.CharField(max_length=255, index=True, description="街道地址")
|
||||
city = fields.CharField(max_length=255, index=True, description="城市")
|
||||
phone = fields.CharField(max_length=64, description="电话")
|
||||
zip_code = fields.CharField(max_length=20, index=True, description="邮编")
|
||||
state_fullname = fields.CharField(max_length=255, index=True, description="州全称")
|
||||
status = fields.BooleanField(default=False, description="状态")
|
||||
create_time = fields.DatetimeField(auto_now_add=True, index=True, description='创建时间')
|
||||
update_time = fields.DatetimeField(auto_now=True, description='更新时间')
|
||||
|
||||
|
||||
class Meta:
|
||||
table = "info"
|
||||
table_description = "信息表"
|
||||
ordering = ["create_time"]
|
||||
indexes = [
|
||||
("city", "zip_code", "state_fullname"),
|
||||
("firstname", "lastname"),
|
||||
]
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Info(id={self.id}, firstname={self.firstname}, lastname={self.lastname}, full_name={self.full_name}, birthday={self.birthday}, street_address={self.street_address}, city={self.city}, phone={self.phone}, zip_code={self.zip_code}, state_fullname={self.state_fullname})>"
|
||||
|
||||
__str__ = __repr__
|
||||
74
app/apis/country/shop/schema.py
Normal file
74
app/apis/country/shop/schema.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
from utils.time_tool import TimestampModel
|
||||
|
||||
CHINA_TZ = timezone(timedelta(hours=8))
|
||||
|
||||
|
||||
class Base(BaseModel):
|
||||
"""
|
||||
基础店铺信息模型
|
||||
|
||||
包含店铺相关的通用字段,供创建与输出模型复用
|
||||
"""
|
||||
province: str = Field(..., description='省份')
|
||||
city: str = Field(..., description='城市')
|
||||
street: str = Field(..., description='街道')
|
||||
shop_name: str = Field(..., description='店铺名称')
|
||||
shop_number: str | None = Field(None, description='店铺号码')
|
||||
|
||||
|
||||
class Create(Base):
|
||||
"""
|
||||
创建请求模型
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Update(BaseModel):
|
||||
"""
|
||||
更新请求模型,支持部分更新
|
||||
"""
|
||||
province: str | None = Field(None, description='省份')
|
||||
city: str | None = Field(None, description='城市')
|
||||
street: str | None = Field(None, description='街道')
|
||||
shop_name: str | None = Field(None, description='店铺名称')
|
||||
shop_number: str | None = Field(None, description='店铺号码')
|
||||
|
||||
|
||||
class Out(TimestampModel, Base):
|
||||
"""
|
||||
输出模型
|
||||
"""
|
||||
code: int = Field(200, description='状态码')
|
||||
message: str = Field('成功', description='提示信息')
|
||||
id: UUID = Field(..., description='ID')
|
||||
|
||||
create_time: datetime = Field(..., description='创建时间')
|
||||
update_time: datetime = Field(..., description='更新时间')
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def create_time_cn(self) -> str:
|
||||
return self.create_time.astimezone(CHINA_TZ).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def update_time_cn(self) -> str:
|
||||
return self.update_time.astimezone(CHINA_TZ).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class OutList(BaseModel):
|
||||
"""
|
||||
列表输出模型
|
||||
"""
|
||||
code: int = Field(200, description='状态码')
|
||||
message: str = Field('成功', description='提示信息')
|
||||
count: int = Field(0, description='总数')
|
||||
num: int = Field(0, description='当前数量')
|
||||
items: List[Out] = Field([], description='列表数据')
|
||||
134
app/apis/country/shop/view.py
Normal file
134
app/apis/country/shop/view.py
Normal file
@@ -0,0 +1,134 @@
|
||||
|
||||
from fastapi import APIRouter, Query, Body, HTTPException
|
||||
from uuid import UUID
|
||||
from .schema import Create, Update, Out, OutList
|
||||
from ..models import Shop
|
||||
from utils.decorators import handle_exceptions_unified
|
||||
from utils.time_tool import parse_time
|
||||
from utils.out_base import CommonOut
|
||||
|
||||
app = APIRouter()
|
||||
|
||||
|
||||
# 创建店铺
|
||||
@app.post("", response_model=Out, description='创建店铺', summary='创建店铺')
|
||||
@handle_exceptions_unified()
|
||||
async def post(item: Create = Body(..., description='创建数据')):
|
||||
"""
|
||||
创建店铺记录
|
||||
"""
|
||||
res = await Shop.create(**item.model_dump())
|
||||
if not res:
|
||||
raise HTTPException(status_code=400, detail='创建失败')
|
||||
return res
|
||||
|
||||
|
||||
# 查询店铺
|
||||
@app.get("", response_model=OutList, description='获取店铺', summary='获取店铺')
|
||||
@handle_exceptions_unified()
|
||||
async def gets(
|
||||
id: UUID | None = Query(None, description='主键ID'),
|
||||
province: str | None = Query(None, description='省份'),
|
||||
city: str | None = Query(None, description='城市'),
|
||||
street: str | None = Query(None, description='街道'),
|
||||
shop_name: str | None = Query(None, description='店铺名称'),
|
||||
shop_number: str | None = Query(None, description='店铺号码'),
|
||||
order_by: str | None = Query('create_time', description='排序字段',
|
||||
regex='^(-)?(id|province|city|street|shop_name|create_time|update_time)$'),
|
||||
res_count: bool = Query(False, description='是否返回总数'),
|
||||
create_time_start: str | int | None = Query(
|
||||
None, description='创建时间开始 (支持 YYYY-MM-DD / YYYY-MM-DD HH:mm:ss / 13位时间戳)'),
|
||||
create_time_end: str | int | None = Query(
|
||||
None, description='创建时间结束 (支持 YYYY-MM-DD / YYYY-MM-DD HH:mm:ss / 13位时间戳)'),
|
||||
update_time_start: str | int | None = Query(
|
||||
None, description='更新时间开始 (支持 YYYY-MM-DD / YYYY-MM-DD HH:mm:ss / 13位时间戳)'),
|
||||
update_time_end: str | int | None = Query(
|
||||
None, description='更新时间结束 (支持 YYYY-MM-DD / YYYY-MM-DD HH:mm:ss / 13位时间戳)'),
|
||||
page: int = Query(1, ge=1, description='页码'),
|
||||
limit: int = Query(10, ge=1, le=1000, description='每页数量'),
|
||||
):
|
||||
"""
|
||||
获取店铺列表
|
||||
"""
|
||||
query = Shop.all()
|
||||
if id:
|
||||
query = query.filter(id=id)
|
||||
if province:
|
||||
query = query.filter(province=province)
|
||||
if city:
|
||||
query = query.filter(city=city)
|
||||
if street:
|
||||
query = query.filter(street=street)
|
||||
if shop_name:
|
||||
query = query.filter(shop_name=shop_name)
|
||||
if shop_number:
|
||||
query = query.filter(shop_number=shop_number)
|
||||
if create_time_start:
|
||||
query = query.filter(create_time__gte=parse_time(create_time_start))
|
||||
if create_time_end:
|
||||
query = query.filter(create_time__lte=parse_time(
|
||||
create_time_end, is_end=True))
|
||||
if update_time_start:
|
||||
query = query.filter(update_time__gte=parse_time(update_time_start))
|
||||
if update_time_end:
|
||||
query = query.filter(update_time__lte=parse_time(
|
||||
update_time_end, is_end=True))
|
||||
|
||||
if order_by:
|
||||
query = query.order_by(order_by)
|
||||
|
||||
if res_count:
|
||||
count = await query.count()
|
||||
else:
|
||||
count = -1
|
||||
offset = (page - 1) * limit # 计算偏移量
|
||||
query = query.limit(limit).offset(offset) # 应用分页
|
||||
|
||||
res = await query
|
||||
if not res:
|
||||
raise HTTPException(status_code=404, detail='店铺不存在')
|
||||
num = len(res)
|
||||
return OutList(count=count, num=num, items=res)
|
||||
|
||||
|
||||
# 更新店铺
|
||||
@app.put("", response_model=Out, description='更新店铺', summary='更新店铺')
|
||||
@handle_exceptions_unified()
|
||||
async def put(id: UUID = Query(..., description='主键ID'),
|
||||
item: Update = Body(..., description='更新数据'),
|
||||
):
|
||||
"""
|
||||
部分更新店铺,只更新传入的非空字段
|
||||
"""
|
||||
# 检查店铺是否存在
|
||||
secret = await Shop.get_or_none(id=id)
|
||||
if not secret:
|
||||
raise HTTPException(status_code=404, detail='店铺不存在')
|
||||
|
||||
# 获取要更新的字段(排除None值的字段)
|
||||
update_data = item.model_dump(exclude_unset=True)
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail='没有要更新的字段')
|
||||
|
||||
# 更新店铺字段
|
||||
await secret.update_from_dict(update_data)
|
||||
await secret.save()
|
||||
return secret
|
||||
|
||||
|
||||
# 删除店铺
|
||||
|
||||
@app.delete("", response_model=CommonOut, description='删除店铺', summary='删除店铺')
|
||||
@handle_exceptions_unified()
|
||||
async def delete(id: UUID = Query(..., description='主键ID'),
|
||||
):
|
||||
"""删除店铺"""
|
||||
secret = await Shop.get_or_none(id=id)
|
||||
if not secret:
|
||||
raise HTTPException(status_code=404, detail='店铺不存在')
|
||||
await secret.delete()
|
||||
# Tortoise ORM 单个实例的 delete() 方法返回 None,而不是删除的记录数
|
||||
# 删除成功时手动返回 1,如果有异常会被装饰器捕获
|
||||
return CommonOut(count=1)
|
||||
152
app/main.py
Normal file
152
app/main.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from fastapi import FastAPI
|
||||
from settings import TORTOISE_ORM
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from tortoise.contrib.fastapi import register_tortoise
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from tortoise import Tortoise
|
||||
from contextlib import asynccontextmanager
|
||||
from apis import app as main_router
|
||||
import asyncio
|
||||
import signal
|
||||
import sys
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
应用生命周期管理函数
|
||||
|
||||
- 启动:注册定时任务并启动调度器
|
||||
- 关闭:优雅关闭调度器与数据库连接
|
||||
"""
|
||||
print('项目启动...')
|
||||
|
||||
# 初始化数据库连接(使用 Tortoise 直接初始化,确保路由与定时任务可用)
|
||||
try:
|
||||
await Tortoise.init(config=TORTOISE_ORM)
|
||||
print('数据库初始化完成')
|
||||
except Exception as e:
|
||||
print(f'数据库初始化失败: {e}')
|
||||
|
||||
# 每30分钟保持一次数据库连接活跃
|
||||
scheduler.add_job(
|
||||
keep_db_connection_alive,
|
||||
IntervalTrigger(minutes=30),
|
||||
id='keep_db_alive',
|
||||
name='保持数据库连接',
|
||||
coalesce=True,
|
||||
misfire_grace_time=30,
|
||||
)
|
||||
|
||||
|
||||
scheduler.start()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
print('项目结束...')
|
||||
|
||||
# 关闭数据库连接
|
||||
print('关闭数据库连接...')
|
||||
try:
|
||||
await asyncio.wait_for(Tortoise.close_connections(), timeout=2)
|
||||
except asyncio.TimeoutError:
|
||||
print('关闭数据库连接超时')
|
||||
except Exception as e:
|
||||
print(f'关闭数据库连接出错: {e}')
|
||||
|
||||
# 关闭调度器
|
||||
print('关闭调度器...')
|
||||
try:
|
||||
if scheduler is not None and hasattr(scheduler, 'shutdown'):
|
||||
scheduler.shutdown(wait=False)
|
||||
except Exception as e:
|
||||
print(f'关闭调度器出错: {e}')
|
||||
|
||||
|
||||
|
||||
# 创建 FastAPI 应用实例
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# 配置 CORS 中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 创建调度器实例
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
# 包含主路由
|
||||
app.include_router(main_router)
|
||||
|
||||
# 注意:使用自定义 lifespan 已在启动时手动初始化数据库。
|
||||
# 若改回默认事件机制,可重新启用 register_tortoise。
|
||||
|
||||
|
||||
async def keep_db_connection_alive():
|
||||
"""
|
||||
保持数据库连接活跃的函数
|
||||
定期执行简单查询以防止连接超时
|
||||
"""
|
||||
try:
|
||||
conn = Tortoise.get_connection("default")
|
||||
await conn.execute_query("SELECT 1")
|
||||
print("数据库连接检查成功")
|
||||
except Exception as e:
|
||||
print(f"数据库连接检查失败: {e}")
|
||||
|
||||
|
||||
def signal_handler():
|
||||
"""
|
||||
处理终止信号,确保资源正确释放
|
||||
"""
|
||||
|
||||
async def shutdown():
|
||||
print("收到终止信号,开始优雅关闭...")
|
||||
|
||||
# 关闭数据库连接
|
||||
print("关闭数据库连接...")
|
||||
try:
|
||||
await Tortoise.close_connections()
|
||||
except Exception as e:
|
||||
print(f"关闭数据库连接时出错: {e}")
|
||||
|
||||
# 关闭调度器
|
||||
print("关闭调度器...")
|
||||
try:
|
||||
scheduler.shutdown()
|
||||
except Exception as e:
|
||||
print(f"关闭调度器时出错: {e}")
|
||||
|
||||
print("所有资源已关闭,程序退出")
|
||||
sys.exit(0)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.create_task(shutdown())
|
||||
# 给异步任务一些时间完成
|
||||
loop.run_until_complete(asyncio.sleep(2))
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from uvicorn import run
|
||||
|
||||
# 注册信号处理
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
signal.signal(sig, lambda sig, frame: signal_handler())
|
||||
|
||||
run(
|
||||
'main:app',
|
||||
host='0.0.0.0',
|
||||
port=6060,
|
||||
reload=False,
|
||||
workers=1,
|
||||
# loop='uvloop',
|
||||
http='httptools',
|
||||
limit_concurrency=10000,
|
||||
backlog=4096,
|
||||
timeout_keep_alive=5
|
||||
)
|
||||
65
app/migrations/models/0_20251118164406_init.py
Normal file
65
app/migrations/models/0_20251118164406_init.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from tortoise import BaseDBAsyncClient
|
||||
|
||||
|
||||
async def upgrade(db: BaseDBAsyncClient) -> str:
|
||||
return """
|
||||
CREATE TABLE IF NOT EXISTS `food` (
|
||||
`id` CHAR(36) NOT NULL PRIMARY KEY COMMENT 'ID',
|
||||
`name` VARCHAR(255) NOT NULL COMMENT '食物名称',
|
||||
`create_time` DATETIME(6) NOT NULL COMMENT '创建时间' DEFAULT CURRENT_TIMESTAMP(6),
|
||||
`update_time` DATETIME(6) NOT NULL COMMENT '更新时间' DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6),
|
||||
KEY `idx_food_name_b88f83` (`name`),
|
||||
KEY `idx_food_create__2db565` (`create_time`)
|
||||
) CHARACTER SET utf8mb4 COMMENT='食物表';
|
||||
CREATE TABLE IF NOT EXISTS `info` (
|
||||
`id` CHAR(36) NOT NULL PRIMARY KEY COMMENT 'ID',
|
||||
`firstname` VARCHAR(255) NOT NULL COMMENT '名',
|
||||
`lastname` VARCHAR(255) NOT NULL COMMENT '姓',
|
||||
`full_name` VARCHAR(255) NOT NULL COMMENT '全名',
|
||||
`birthday` VARCHAR(32) NOT NULL COMMENT '生日',
|
||||
`street_address` VARCHAR(255) NOT NULL COMMENT '街道地址',
|
||||
`city` VARCHAR(255) NOT NULL COMMENT '城市',
|
||||
`phone` VARCHAR(64) NOT NULL COMMENT '电话',
|
||||
`zip_code` VARCHAR(20) NOT NULL COMMENT '邮编',
|
||||
`state_fullname` VARCHAR(255) NOT NULL COMMENT '州全称',
|
||||
`status` BOOL NOT NULL COMMENT '状态' DEFAULT 0,
|
||||
`create_time` DATETIME(6) NOT NULL COMMENT '创建时间' DEFAULT CURRENT_TIMESTAMP(6),
|
||||
`update_time` DATETIME(6) NOT NULL COMMENT '更新时间' DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6),
|
||||
KEY `idx_info_firstna_11312f` (`firstname`),
|
||||
KEY `idx_info_lastnam_c1f2c2` (`lastname`),
|
||||
KEY `idx_info_full_na_bc9dc4` (`full_name`),
|
||||
KEY `idx_info_street__632b0d` (`street_address`),
|
||||
KEY `idx_info_city_7b94a7` (`city`),
|
||||
KEY `idx_info_zip_cod_7d259e` (`zip_code`),
|
||||
KEY `idx_info_state_f_58c986` (`state_fullname`),
|
||||
KEY `idx_info_create__3bea91` (`create_time`),
|
||||
KEY `idx_info_city_014fff` (`city`, `zip_code`, `state_fullname`),
|
||||
KEY `idx_info_firstna_8d37ca` (`firstname`, `lastname`)
|
||||
) CHARACTER SET utf8mb4 COMMENT='信息表';
|
||||
CREATE TABLE IF NOT EXISTS `shop` (
|
||||
`id` CHAR(36) NOT NULL PRIMARY KEY COMMENT 'ID',
|
||||
`province` VARCHAR(255) NOT NULL COMMENT '省份',
|
||||
`city` VARCHAR(255) NOT NULL COMMENT '城市',
|
||||
`street` VARCHAR(255) NOT NULL COMMENT '街道',
|
||||
`shop_name` VARCHAR(255) NOT NULL COMMENT '店铺名称',
|
||||
`shop_number` VARCHAR(255) COMMENT '店铺号码',
|
||||
`create_time` DATETIME(6) NOT NULL COMMENT '创建时间' DEFAULT CURRENT_TIMESTAMP(6),
|
||||
`update_time` DATETIME(6) NOT NULL COMMENT '更新时间' DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6),
|
||||
KEY `idx_shop_provinc_904758` (`province`),
|
||||
KEY `idx_shop_city_69d82f` (`city`),
|
||||
KEY `idx_shop_street_5aaa95` (`street`),
|
||||
KEY `idx_shop_shop_na_938b2f` (`shop_name`),
|
||||
KEY `idx_shop_create__e13964` (`create_time`),
|
||||
KEY `idx_shop_provinc_72e64a` (`province`, `city`, `street`)
|
||||
) CHARACTER SET utf8mb4 COMMENT='店铺表';
|
||||
CREATE TABLE IF NOT EXISTS `aerich` (
|
||||
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
|
||||
`version` VARCHAR(255) NOT NULL,
|
||||
`app` VARCHAR(100) NOT NULL,
|
||||
`content` JSON NOT NULL
|
||||
) CHARACTER SET utf8mb4;"""
|
||||
|
||||
|
||||
async def downgrade(db: BaseDBAsyncClient) -> str:
|
||||
return """
|
||||
"""
|
||||
4
app/pyproject.toml
Normal file
4
app/pyproject.toml
Normal file
@@ -0,0 +1,4 @@
|
||||
[tool.aerich]
|
||||
tortoise_orm = "settings.TORTOISE_ORM"
|
||||
location = "./migrations"
|
||||
src_folder = "./."
|
||||
34
app/settings.py
Normal file
34
app/settings.py
Normal file
@@ -0,0 +1,34 @@
|
||||
TORTOISE_ORM = {
|
||||
'connections': {
|
||||
'default': {
|
||||
# 'engine': 'tortoise.backends.asyncpg', PostgreSQL
|
||||
'engine': 'tortoise.backends.mysql', # MySQL or Mariadb
|
||||
'credentials': {
|
||||
'host': '192.168.11.67',
|
||||
'port': 3306,
|
||||
'user': 'country',
|
||||
'password': 'sWdFeXMmAbHE5MXj',
|
||||
'database': 'country',
|
||||
'minsize': 10, # 最小连接数设为10,避免连接过多
|
||||
'maxsize': 30, # 最大连接数设为30,避免超出数据库限制
|
||||
'charset': 'utf8mb4',
|
||||
"echo": False,
|
||||
'pool_recycle': 3600, # 增加连接回收时间从300秒到3600秒(1小时)
|
||||
'connect_timeout': 10, # 连接超时时间
|
||||
}
|
||||
},
|
||||
},
|
||||
'apps': {
|
||||
'models': {
|
||||
# 仅注册实际存在的模型模块,移除不存在的 apis.project.models,避免 Aerich 初始化失败
|
||||
'models': [
|
||||
"apis.country.models",
|
||||
"aerich.models"
|
||||
],
|
||||
'default_connection': 'default',
|
||||
|
||||
}
|
||||
},
|
||||
'use_tz': False,
|
||||
'timezone': 'Asia/Shanghai'
|
||||
}
|
||||
0
app/utils/__init__.py
Normal file
0
app/utils/__init__.py
Normal file
143
app/utils/browser_api.py
Normal file
143
app/utils/browser_api.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import datetime
|
||||
import asyncio
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from utils.decorators import handle_exceptions_unified
|
||||
|
||||
|
||||
class BrowserApi:
|
||||
"""
|
||||
浏览器接口
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.local_url = 'http://127.0.0.1:54345'
|
||||
self.headers = {'Content-Type': 'application/json'}
|
||||
# 使用异步 HTTP 客户端,启用连接池和超时设置
|
||||
self.client = httpx.AsyncClient(
|
||||
base_url=self.local_url,
|
||||
headers=self.headers,
|
||||
timeout=httpx.Timeout(30.0, connect=10.0), # 总超时30秒,连接超时10秒
|
||||
limits=httpx.Limits(max_keepalive_connections=50, max_connections=100), # 连接池配置
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口,关闭客户端"""
|
||||
await self.aclose()
|
||||
|
||||
async def aclose(self):
|
||||
"""关闭 HTTP 客户端"""
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
# 打开指纹浏览器
|
||||
@handle_exceptions_unified()
|
||||
async def open_browser(self, id: str, jc: int = 0):
|
||||
"""
|
||||
打开指纹浏览器(异步优化版本)
|
||||
:param jc: 计次
|
||||
:param id: 浏览器id
|
||||
:return:http, pid
|
||||
"""
|
||||
if jc > 3:
|
||||
return None, None
|
||||
url = '/browser/open'
|
||||
data = {
|
||||
'id': id
|
||||
}
|
||||
try:
|
||||
res = await self.client.post(url, json=data)
|
||||
res.raise_for_status() # 检查 HTTP 状态码
|
||||
res_data = res.json()
|
||||
logger.info(f'打开指纹浏览器: {res_data}')
|
||||
if not res_data.get('success'):
|
||||
logger.error(f'打开指纹浏览器失败: {res_data}')
|
||||
return await self.open_browser(id, jc + 1)
|
||||
data = res_data.get('data')
|
||||
http = data.get('http')
|
||||
pid = data.get('pid')
|
||||
logger.info(f'打开指纹浏览器成功: {http}, {pid}')
|
||||
return http, pid
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f'打开指纹浏览器超时: {e}')
|
||||
if jc < 3:
|
||||
return await self.open_browser(id, jc + 1)
|
||||
return None, None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f'打开指纹浏览器请求错误: {e}')
|
||||
if jc < 3:
|
||||
return await self.open_browser(id, jc + 1)
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.error(f'打开指纹浏览器异常: {e}')
|
||||
if jc < 3:
|
||||
return await self.open_browser(id, jc + 1)
|
||||
return None, None
|
||||
|
||||
# 关闭指纹浏览器
|
||||
@handle_exceptions_unified()
|
||||
async def close_browser(self, id: str, jc: int = 0):
|
||||
"""
|
||||
关闭指纹浏览器(异步优化版本)
|
||||
:param jc: 计次
|
||||
:param id: 浏览器id
|
||||
:return:
|
||||
"""
|
||||
if jc > 3:
|
||||
return None
|
||||
url = '/browser/close'
|
||||
data = {
|
||||
'id': id
|
||||
}
|
||||
try:
|
||||
res = await self.client.post(url, json=data)
|
||||
res.raise_for_status() # 检查 HTTP 状态码
|
||||
res_data = res.json()
|
||||
logger.info(f'关闭指纹浏览器: {res_data}')
|
||||
if not res_data.get('success'):
|
||||
msg = res_data.get('msg', '')
|
||||
# 如果浏览器正在打开中,等待后重试(不是真正的错误)
|
||||
if '正在打开中' in msg or 'opening' in msg.lower():
|
||||
if jc < 3:
|
||||
# 等待 1-3 秒后重试(根据重试次数递增等待时间)
|
||||
wait_time = (jc + 1) * 1.0 # 第1次重试等1秒,第2次等2秒,第3次等3秒
|
||||
logger.info(f'浏览器正在打开中,等待 {wait_time} 秒后重试关闭: browser_id={id}')
|
||||
await asyncio.sleep(wait_time)
|
||||
return await self.close_browser(id, jc + 1)
|
||||
else:
|
||||
# 超过重试次数,记录警告但不作为错误
|
||||
logger.warning(f'关闭指纹浏览器失败(浏览器正在打开中,已重试3次): browser_id={id}')
|
||||
return None
|
||||
else:
|
||||
# 其他错误,记录为错误并重试
|
||||
logger.error(f'关闭指纹浏览器失败: {res_data}')
|
||||
if jc < 3:
|
||||
await asyncio.sleep(0.5) # 短暂等待后重试
|
||||
return await self.close_browser(id, jc + 1)
|
||||
return None
|
||||
logger.info(f'关闭指纹浏览器成功: browser_id={id}')
|
||||
return True
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f'关闭指纹浏览器超时: {e}')
|
||||
if jc < 3:
|
||||
await asyncio.sleep(1.0)
|
||||
return await self.close_browser(id, jc + 1)
|
||||
return None
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f'关闭指纹浏览器请求错误: {e}')
|
||||
if jc < 3:
|
||||
await asyncio.sleep(1.0)
|
||||
return await self.close_browser(id, jc + 1)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f'关闭指纹浏览器异常: {e}')
|
||||
if jc < 3:
|
||||
await asyncio.sleep(1.0)
|
||||
return await self.close_browser(id, jc + 1)
|
||||
return None
|
||||
|
||||
browser_api = BrowserApi()
|
||||
165
app/utils/decorators.py
Normal file
165
app/utils/decorators.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
from typing import Callable, Any, Optional
|
||||
import logging
|
||||
import asyncio
|
||||
from tortoise.exceptions import OperationalError
|
||||
|
||||
# 获取日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def handle_exceptions_unified(
|
||||
max_retries: int = 0,
|
||||
retry_delay: float = 1.0,
|
||||
status_code: int = 500,
|
||||
custom_message: Optional[str] = None,
|
||||
is_background_task: bool = False
|
||||
):
|
||||
"""
|
||||
统一的异常处理装饰器
|
||||
|
||||
集成了所有异常处理功能:数据库重试、自定义状态码、自定义消息、后台任务处理
|
||||
|
||||
Args:
|
||||
max_retries: 最大重试次数,默认0(不重试)
|
||||
retry_delay: 重试间隔时间(秒),默认1秒
|
||||
status_code: HTTP状态码,默认500
|
||||
custom_message: 自定义错误消息前缀
|
||||
is_background_task: 是否为后台任务(不抛出HTTPException)
|
||||
|
||||
使用方法:
|
||||
# 基础异常处理
|
||||
@handle_exceptions_unified()
|
||||
async def basic_function(...):
|
||||
pass
|
||||
|
||||
# 带数据库重试
|
||||
@handle_exceptions_unified(max_retries=3, retry_delay=1.0)
|
||||
async def db_function(...):
|
||||
pass
|
||||
|
||||
# 自定义状态码和消息
|
||||
@handle_exceptions_unified(status_code=400, custom_message="参数错误")
|
||||
async def validation_function(...):
|
||||
pass
|
||||
|
||||
# 后台任务处理
|
||||
@handle_exceptions_unified(is_background_task=True)
|
||||
async def background_function(...):
|
||||
pass
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs) -> Any:
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except HTTPException as e:
|
||||
# HTTPException 直接抛出,不重试
|
||||
if is_background_task:
|
||||
logger.error(f"后台任务 {func.__name__} HTTPException: {str(e)}")
|
||||
return False
|
||||
raise
|
||||
except OperationalError as e:
|
||||
last_exception = e
|
||||
error_msg = str(e).lower()
|
||||
|
||||
# 检查是否是连接相关的错误
|
||||
if any(keyword in error_msg for keyword in [
|
||||
'lost connection', 'connection', 'timeout',
|
||||
'server has gone away', 'broken pipe'
|
||||
]):
|
||||
if attempt < max_retries:
|
||||
logger.warning(
|
||||
f"函数 {func.__name__} 数据库连接错误 (尝试 {attempt + 1}/{max_retries + 1}): {str(e)}"
|
||||
)
|
||||
# 等待一段时间后重试,使用指数退避
|
||||
await asyncio.sleep(retry_delay * (2 ** attempt))
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"函数 {func.__name__} 数据库连接错误,已达到最大重试次数: {str(e)}"
|
||||
)
|
||||
else:
|
||||
# 非连接错误,直接处理
|
||||
logger.error(f"函数 {func.__name__} 发生数据库错误: {str(e)}")
|
||||
if is_background_task:
|
||||
return False
|
||||
error_detail = f"{custom_message}: {str(e)}" if custom_message else f"数据库操作失败: {str(e)}"
|
||||
raise HTTPException(status_code=status_code, detail=error_detail)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries:
|
||||
logger.warning(
|
||||
f"函数 {func.__name__} 发生异常 (尝试 {attempt + 1}/{max_retries + 1}): {str(e)}"
|
||||
)
|
||||
await asyncio.sleep(retry_delay * (2 ** attempt))
|
||||
continue
|
||||
else:
|
||||
logger.error(f"函数 {func.__name__} 发生异常: {str(e)}", exc_info=True)
|
||||
if is_background_task:
|
||||
return False
|
||||
break
|
||||
|
||||
# 所有重试都失败了,处理最后一个异常
|
||||
if is_background_task:
|
||||
return False
|
||||
|
||||
if isinstance(last_exception, OperationalError):
|
||||
error_detail = f"{custom_message}: 数据库连接失败: {str(last_exception)}" if custom_message else f"数据库连接失败: {str(last_exception)}"
|
||||
else:
|
||||
error_detail = f"{custom_message}: {str(last_exception)}" if custom_message else str(last_exception)
|
||||
|
||||
raise HTTPException(status_code=status_code, detail=error_detail)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
# 向后兼容的别名函数
|
||||
def handle_exceptions_with_db_retry(max_retries: int = 3, retry_delay: float = 1.0):
|
||||
"""
|
||||
带数据库连接重试的异常处理装饰器(向后兼容)
|
||||
|
||||
这是 handle_exceptions_unified 的别名,保持向后兼容性
|
||||
"""
|
||||
return handle_exceptions_unified(max_retries=max_retries, retry_delay=retry_delay)
|
||||
|
||||
|
||||
def handle_exceptions(func: Callable) -> Callable:
|
||||
"""
|
||||
基础异常处理装饰器(向后兼容)
|
||||
|
||||
这是 handle_exceptions_unified() 的别名,保持向后兼容性
|
||||
"""
|
||||
return handle_exceptions_unified()(func)
|
||||
|
||||
|
||||
def handle_background_task_exceptions(func: Callable) -> Callable:
|
||||
"""
|
||||
后台任务异常处理装饰器(向后兼容)
|
||||
|
||||
这是 handle_exceptions_unified 的别名,保持向后兼容性
|
||||
"""
|
||||
return handle_exceptions_unified(is_background_task=True)(func)
|
||||
|
||||
|
||||
def handle_exceptions_with_custom_message(message: str = "操作失败"):
|
||||
"""
|
||||
带自定义错误消息的异常处理装饰器(向后兼容)
|
||||
|
||||
这是 handle_exceptions_unified 的别名,保持向后兼容性
|
||||
"""
|
||||
return handle_exceptions_unified(custom_message=message)
|
||||
|
||||
|
||||
def handle_exceptions_with_status_code(status_code: int = 500, message: str = None):
|
||||
"""
|
||||
带自定义状态码和错误消息的异常处理装饰器(向后兼容)
|
||||
|
||||
这是 handle_exceptions_unified 的别名,保持向后兼容性
|
||||
"""
|
||||
return handle_exceptions_unified(status_code=status_code, custom_message=message)
|
||||
47
app/utils/exceptions.py
Normal file
47
app/utils/exceptions.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
from fastapi import Request, status
|
||||
from fastapi.exceptions import HTTPException, RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from .logs import getLogger
|
||||
|
||||
logger = getLogger(os.environ.get('APP_NAME'))
|
||||
|
||||
|
||||
def global_http_exception_handler(request: Request, exc):
|
||||
"""
|
||||
全局HTTP请求处理异常
|
||||
:param request: HTTP请求对象
|
||||
:param exc: 本次发生的异常对象
|
||||
:return:
|
||||
"""
|
||||
|
||||
# 使用日志记录异常
|
||||
logger.error(f"发生异常:{exc.detail}")
|
||||
|
||||
# 直接返回JSONResponse,避免重新抛出异常导致循环
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
'err_msg': exc.detail,
|
||||
'status': False
|
||||
},
|
||||
headers=getattr(exc, 'headers', None)
|
||||
)
|
||||
|
||||
|
||||
def global_request_exception_handler(request: Request, exc):
|
||||
"""
|
||||
全局请求校验异常处理函数
|
||||
:param request: HTTP请求对象
|
||||
:param exc: 本次发生的异常对象
|
||||
:return:
|
||||
"""
|
||||
|
||||
# 直接返回JSONResponse,避免重新抛出异常
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
'err_msg': exc.errors()[0],
|
||||
'status': False
|
||||
}
|
||||
)
|
||||
218
app/utils/logs.py
Normal file
218
app/utils/logs.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import logging
|
||||
import os
|
||||
from logging import Logger
|
||||
from concurrent_log_handler import ConcurrentRotatingFileHandler
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
import gzip
|
||||
import shutil
|
||||
import glob
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def getLogger(name: str = 'root') -> Logger:
|
||||
"""
|
||||
创建一个按2小时滚动、支持多进程安全、自动压缩日志的 Logger
|
||||
:param name: 日志器名称
|
||||
:return: 单例 Logger 对象
|
||||
"""
|
||||
logger: Logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
if not logger.handlers:
|
||||
# 控制台输出
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
|
||||
# 日志目录
|
||||
log_dir = "logs"
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# 日志文件路径
|
||||
log_file = os.path.join(log_dir, f"{name}.log")
|
||||
|
||||
# 文件处理器:每2小时滚动一次,保留7天,共84个文件,支持多进程写入
|
||||
file_handler = TimedRotatingFileHandler(
|
||||
filename=log_file,
|
||||
when='H',
|
||||
interval=2, # 每2小时切一次
|
||||
backupCount=84, # 保留7天 = 7 * 24 / 2 = 84个文件
|
||||
encoding='utf-8',
|
||||
delay=False,
|
||||
utc=False # 你也可以改成 True 表示按 UTC 时间切
|
||||
)
|
||||
|
||||
# 设置 Formatter - 简化格式,去掉路径信息
|
||||
formatter = logging.Formatter(
|
||||
fmt="【{name}】{levelname} {asctime} {message}",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
style="{"
|
||||
)
|
||||
console_formatter = logging.Formatter(
|
||||
fmt="{levelname} {asctime} {message}",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
style="{"
|
||||
)
|
||||
|
||||
file_handler.setFormatter(formatter)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
|
||||
logger.addHandler(console_handler)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# 添加压缩功能(在第一次创建 logger 时执行一次)
|
||||
_compress_old_logs(log_dir, name)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def _compress_old_logs(log_dir: str, name: str):
|
||||
"""
|
||||
将旧日志压缩成 .gz 格式
|
||||
"""
|
||||
pattern = os.path.join(log_dir, f"{name}.log.*")
|
||||
for filepath in glob.glob(pattern):
|
||||
if filepath.endswith('.gz'):
|
||||
continue
|
||||
try:
|
||||
with open(filepath, 'rb') as f_in:
|
||||
with gzip.open(filepath + '.gz', 'wb') as f_out:
|
||||
shutil.copyfileobj(f_in, f_out)
|
||||
os.remove(filepath)
|
||||
except Exception as e:
|
||||
print(f"日志压缩失败: {filepath}, 原因: {e}")
|
||||
|
||||
|
||||
def compress_old_logs(log_dir: str = None, name: str = "root"):
|
||||
"""
|
||||
压缩旧的日志文件(公共接口)
|
||||
|
||||
Args:
|
||||
log_dir: 日志目录,如果不指定则使用默认目录
|
||||
name: 日志器名称
|
||||
"""
|
||||
if log_dir is None:
|
||||
log_dir = "logs"
|
||||
|
||||
_compress_old_logs(log_dir, name)
|
||||
|
||||
|
||||
def log_api_call(logger: Logger, user_id: str = None, endpoint: str = None, method: str = None, params: dict = None, response_status: int = None, client_ip: str = None):
|
||||
"""
|
||||
记录API调用信息,包含用户ID、接口路径、请求方法、参数、响应状态和来源IP
|
||||
|
||||
Args:
|
||||
logger: 日志器对象
|
||||
user_id: 用户ID
|
||||
endpoint: 接口路径
|
||||
method: 请求方法 (GET, POST, PUT, DELETE等)
|
||||
params: 请求参数
|
||||
response_status: 响应状态码
|
||||
client_ip: 客户端IP地址
|
||||
"""
|
||||
try:
|
||||
# 构建日志信息
|
||||
log_parts = []
|
||||
|
||||
if user_id:
|
||||
log_parts.append(f"用户={user_id}")
|
||||
|
||||
if client_ip:
|
||||
log_parts.append(f"IP={client_ip}")
|
||||
|
||||
if method and endpoint:
|
||||
log_parts.append(f"{method} {endpoint}")
|
||||
elif endpoint:
|
||||
log_parts.append(f"接口={endpoint}")
|
||||
|
||||
if params:
|
||||
# 过滤敏感信息
|
||||
safe_params = {k: v for k, v in params.items()
|
||||
if k.lower() not in ['password', 'token', 'secret', 'key']}
|
||||
if safe_params:
|
||||
log_parts.append(f"参数={safe_params}")
|
||||
|
||||
if response_status:
|
||||
log_parts.append(f"状态码={response_status}")
|
||||
|
||||
if log_parts:
|
||||
log_message = " ".join(log_parts)
|
||||
logger.info(log_message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录API调用日志失败: {e}")
|
||||
|
||||
|
||||
def delete_old_compressed_logs(log_dir: str = None, days: int = 7):
|
||||
"""
|
||||
删除超过指定天数的压缩日志文件
|
||||
|
||||
Args:
|
||||
log_dir: 日志目录,如果不指定则使用默认目录
|
||||
days: 保留天数,默认7天
|
||||
"""
|
||||
try:
|
||||
if log_dir is None:
|
||||
log_dir = "logs"
|
||||
|
||||
log_path = Path(log_dir)
|
||||
if not log_path.exists():
|
||||
return
|
||||
|
||||
# 计算截止时间
|
||||
cutoff_time = datetime.now() - timedelta(days=days)
|
||||
|
||||
# 获取所有压缩日志文件
|
||||
gz_files = [f for f in log_path.iterdir()
|
||||
if f.is_file() and f.name.endswith('.log.gz')]
|
||||
|
||||
deleted_count = 0
|
||||
for gz_file in gz_files:
|
||||
# 获取文件修改时间
|
||||
file_mtime = datetime.fromtimestamp(gz_file.stat().st_mtime)
|
||||
|
||||
# 如果文件超过保留期限,删除它
|
||||
if file_mtime < cutoff_time:
|
||||
gz_file.unlink()
|
||||
print(f"删除旧压缩日志文件: {gz_file}")
|
||||
deleted_count += 1
|
||||
|
||||
if deleted_count > 0:
|
||||
print(f"总共删除了 {deleted_count} 个旧压缩日志文件")
|
||||
|
||||
except Exception as e:
|
||||
print(f"删除旧压缩日志文件失败: {e}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger = getLogger('WebAPI')
|
||||
|
||||
# 基础日志测试
|
||||
logger.info("系统启动")
|
||||
logger.debug("调试信息")
|
||||
logger.warning("警告信息")
|
||||
logger.error("错误信息")
|
||||
|
||||
# API调用日志测试
|
||||
log_api_call(
|
||||
logger=logger,
|
||||
user_id="user123",
|
||||
endpoint="/api/users/info",
|
||||
method="GET",
|
||||
params={"id": 123, "fields": ["name", "email"]},
|
||||
response_status=200,
|
||||
client_ip="192.168.1.100"
|
||||
)
|
||||
|
||||
log_api_call(
|
||||
logger=logger,
|
||||
user_id="user456",
|
||||
endpoint="/api/users/login",
|
||||
method="POST",
|
||||
params={"username": "test", "password": "hidden"}, # password会被过滤
|
||||
response_status=401,
|
||||
client_ip="10.0.0.50"
|
||||
)
|
||||
|
||||
# 单例验证
|
||||
logger2 = getLogger('WebAPI')
|
||||
print(f"Logger单例验证: {id(logger) == id(logger2)}")
|
||||
8
app/utils/out_base.py
Normal file
8
app/utils/out_base.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CommonOut(BaseModel):
|
||||
"""操作结果详情模型"""
|
||||
code: int = Field(200, description='状态码')
|
||||
message: str = Field('成功', description='提示信息')
|
||||
count: int = Field(0, description='操作影响的记录数')
|
||||
96
app/utils/redis_tool.py
Normal file
96
app/utils/redis_tool.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import redis
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class RedisClient:
|
||||
def __init__(self, host: str = 'localhost', port: int = 6379, password: str = None):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.password = password
|
||||
self.browser_client = None
|
||||
self.task_client = None
|
||||
self.cache_client = None
|
||||
self.ok_client = None
|
||||
self.init()
|
||||
|
||||
# 初始化
|
||||
def init(self):
|
||||
"""
|
||||
初始化Redis客户端
|
||||
:return:
|
||||
"""
|
||||
if self.browser_client is None:
|
||||
self.browser_client = redis.Redis(host=self.host, port=self.port, password=self.password, db=0,
|
||||
decode_responses=True)
|
||||
|
||||
if self.task_client is None:
|
||||
self.task_client = redis.Redis(host=self.host, port=self.port, password=self.password, db=1,
|
||||
decode_responses=True)
|
||||
|
||||
if self.cache_client is None:
|
||||
self.cache_client = redis.Redis(host=self.host, port=self.port, password=self.password, db=2,
|
||||
decode_responses=True)
|
||||
|
||||
if self.ok_client is None:
|
||||
self.ok_client = redis.Redis(host=self.host, port=self.port, password=self.password, db=3,
|
||||
decode_responses=True)
|
||||
|
||||
logger.info("Redis连接已初始化")
|
||||
|
||||
# 关闭连接
|
||||
def close(self):
|
||||
self.browser_client.close()
|
||||
self.task_client.close()
|
||||
self.cache_client.close()
|
||||
self.ok_client.close()
|
||||
logger.info("Redis连接已关闭")
|
||||
|
||||
"""browser_client"""
|
||||
|
||||
# 写入浏览器信息
|
||||
async def set_browser(self, browser_id: str, data: dict):
|
||||
try:
|
||||
# 处理None值,将其转换为空字符串
|
||||
processed_data = {}
|
||||
for key, value in data.items():
|
||||
if value is None:
|
||||
processed_data[key] = ""
|
||||
else:
|
||||
processed_data[key] = value
|
||||
|
||||
self.browser_client.hset(browser_id, mapping=processed_data)
|
||||
logger.info(f"写入浏览器信息: {browser_id} - {processed_data}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"写入浏览器信息失败: {browser_id} - {e}")
|
||||
return False
|
||||
|
||||
# 获取浏览器信息
|
||||
async def get_browser(self, browser_id: str = None):
|
||||
try:
|
||||
if browser_id is None:
|
||||
# 获取全部数据
|
||||
data = self.browser_client.hgetall()
|
||||
else:
|
||||
data = self.browser_client.hgetall(browser_id)
|
||||
logger.info(f"获取浏览器信息: {browser_id} - {data}")
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.error(f"获取浏览器信息失败: {browser_id} - {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
host = '183.66.27.14'
|
||||
port = 50086
|
||||
password = 'redis_AdJsBP'
|
||||
redis_client = RedisClient(host, port, password)
|
||||
# await redis_client.set_browser('9eac7f95ca2d47359ace4083a566e119', {'status': 'online', 'current_task_id': None})
|
||||
await redis_client.get_browser('9eac7f95ca2d47359ace4083a566e119')
|
||||
# 关闭连接
|
||||
redis_client.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
177
app/utils/session_store.py
Normal file
177
app/utils/session_store.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import os
|
||||
import json
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class SessionStore:
|
||||
"""
|
||||
会话持久化存储(日志文件版 + 内存缓存)
|
||||
|
||||
优化方案:
|
||||
1. 使用日志文件记录(追加模式,性能好,不会因为文件变大而变慢)
|
||||
2. 在内存中保留最近的会话记录(用于快速查询)
|
||||
3. 定期清理过期的内存记录(保留最近1小时或最多1000条)
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str = 'logs/sessions.log', enable_log: bool = True, max_memory_records: int = 1000):
|
||||
"""
|
||||
初始化会话存储。
|
||||
|
||||
Args:
|
||||
file_path (str): 日志文件路径(默认 logs/sessions.log)
|
||||
enable_log (bool): 是否启用日志记录,False 则不记录到文件
|
||||
max_memory_records (int): 内存中保留的最大记录数,默认1000
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.enable_log = enable_log
|
||||
self.max_memory_records = max_memory_records
|
||||
self._lock = threading.Lock()
|
||||
# 内存中的会话记录 {pid: record}
|
||||
self._memory_cache: Dict[int, Dict[str, Any]] = {}
|
||||
# 记录创建时间,用于清理过期记录
|
||||
self._cache_timestamps: Dict[int, datetime] = {}
|
||||
|
||||
if enable_log:
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
def _write_log(self, action: str, record: Dict[str, Any]) -> None:
|
||||
"""
|
||||
写入日志文件(追加模式,性能好)
|
||||
|
||||
Args:
|
||||
action (str): 操作类型(CREATE/UPDATE)
|
||||
record (Dict[str, Any]): 会话记录
|
||||
"""
|
||||
if not self.enable_log:
|
||||
return
|
||||
|
||||
try:
|
||||
with self._lock:
|
||||
log_line = json.dumps({
|
||||
'action': action,
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'data': record
|
||||
}, ensure_ascii=False)
|
||||
with open(self.file_path, 'a', encoding='utf-8') as f:
|
||||
f.write(log_line + '\n')
|
||||
except Exception as e:
|
||||
# 静默处理日志写入错误,避免影响主流程
|
||||
logger.debug(f"写入会话日志失败: {e}")
|
||||
|
||||
def _cleanup_old_cache(self) -> None:
|
||||
"""
|
||||
清理过期的内存缓存记录
|
||||
- 保留最近1小时的记录
|
||||
- 最多保留 max_memory_records 条记录
|
||||
"""
|
||||
now = datetime.now()
|
||||
expire_time = now - timedelta(hours=1)
|
||||
|
||||
# 清理过期记录
|
||||
expired_pids = [
|
||||
pid for pid, timestamp in self._cache_timestamps.items()
|
||||
if timestamp < expire_time
|
||||
]
|
||||
for pid in expired_pids:
|
||||
self._memory_cache.pop(pid, None)
|
||||
self._cache_timestamps.pop(pid, None)
|
||||
|
||||
# 如果记录数仍然超过限制,删除最旧的记录
|
||||
if len(self._memory_cache) > self.max_memory_records:
|
||||
# 按时间戳排序,删除最旧的
|
||||
sorted_pids = sorted(
|
||||
self._cache_timestamps.items(),
|
||||
key=lambda x: x[1]
|
||||
)
|
||||
# 计算需要删除的数量
|
||||
to_remove = len(self._memory_cache) - self.max_memory_records
|
||||
for pid, _ in sorted_pids[:to_remove]:
|
||||
self._memory_cache.pop(pid, None)
|
||||
self._cache_timestamps.pop(pid, None)
|
||||
|
||||
def create_session(self, record: Dict[str, Any]) -> None:
|
||||
"""
|
||||
创建新会话记录。
|
||||
|
||||
Args:
|
||||
record (Dict[str, Any]): 会话信息字典
|
||||
"""
|
||||
record = dict(record)
|
||||
record.setdefault('created_at', datetime.now().isoformat())
|
||||
pid = record.get('pid')
|
||||
|
||||
if pid is not None:
|
||||
with self._lock:
|
||||
# 保存到内存缓存
|
||||
self._memory_cache[pid] = record
|
||||
self._cache_timestamps[pid] = datetime.now()
|
||||
# 清理过期记录
|
||||
self._cleanup_old_cache()
|
||||
|
||||
# 写入日志文件(追加模式,性能好)
|
||||
self._write_log('CREATE', record)
|
||||
|
||||
def update_session(self, pid: int, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
按 PID 更新会话记录。
|
||||
|
||||
Args:
|
||||
pid (int): 进程ID
|
||||
updates (Dict[str, Any]): 更新字段字典
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 更新后的会话记录
|
||||
"""
|
||||
with self._lock:
|
||||
# 从内存缓存获取
|
||||
record = self._memory_cache.get(pid)
|
||||
if record:
|
||||
record.update(updates)
|
||||
record.setdefault('updated_at', datetime.now().isoformat())
|
||||
self._cache_timestamps[pid] = datetime.now()
|
||||
else:
|
||||
# 如果内存中没有,创建一个新记录
|
||||
record = {'pid': pid}
|
||||
record.update(updates)
|
||||
record.setdefault('created_at', datetime.now().isoformat())
|
||||
record.setdefault('updated_at', datetime.now().isoformat())
|
||||
self._memory_cache[pid] = record
|
||||
self._cache_timestamps[pid] = datetime.now()
|
||||
|
||||
if record:
|
||||
# 写入日志文件
|
||||
self._write_log('UPDATE', record)
|
||||
|
||||
return record
|
||||
|
||||
def get_session_by_pid(self, pid: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
按 PID 查询会话记录(仅从内存缓存查询,性能好)
|
||||
|
||||
Args:
|
||||
pid (int): 进程ID
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 会话记录
|
||||
"""
|
||||
with self._lock:
|
||||
return self._memory_cache.get(pid)
|
||||
|
||||
def list_sessions(self, status: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出会话记录,可按状态过滤(仅从内存缓存查询)
|
||||
|
||||
Args:
|
||||
status (Optional[int]): 状态码过滤(如 100 运行中、200 已结束、500 失败)
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 会话记录列表
|
||||
"""
|
||||
with self._lock:
|
||||
records = list(self._memory_cache.values())
|
||||
if status is None:
|
||||
return records
|
||||
return [r for r in records if r.get('status') == status]
|
||||
56
app/utils/time_tool.py
Normal file
56
app/utils/time_tool.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pydantic import BaseModel, field_serializer
|
||||
CN_TZ = timezone(timedelta(hours=8))
|
||||
|
||||
|
||||
def now_cn() -> datetime:
|
||||
"""
|
||||
获取中国时区的当前时间
|
||||
返回带有中国时区信息的 datetime 对象
|
||||
"""
|
||||
return datetime.now(CN_TZ)
|
||||
|
||||
def parse_time(val: str | int, is_end: bool = False) -> datetime:
|
||||
"""
|
||||
将传入的字符串或时间戳解析为中国时区的 datetime,用于数据库查询时间比较。
|
||||
支持格式:
|
||||
- "YYYY-MM-DD"
|
||||
- "YYYY-MM-DD HH:mm:ss"
|
||||
- 10 位时间戳(秒)
|
||||
- 13 位时间戳(毫秒)
|
||||
"""
|
||||
dt_cn: datetime
|
||||
|
||||
if isinstance(val, int) or (isinstance(val, str) and val.isdigit()):
|
||||
ts = int(val)
|
||||
# 根据量级判断是秒还是毫秒
|
||||
if ts >= 10**12:
|
||||
dt_cn = datetime.fromtimestamp(ts / 1000, CN_TZ)
|
||||
else:
|
||||
dt_cn = datetime.fromtimestamp(ts, CN_TZ)
|
||||
else:
|
||||
try:
|
||||
dt_cn = datetime.strptime(val, "%Y-%m-%d").replace(tzinfo=CN_TZ)
|
||||
if is_end:
|
||||
dt_cn = dt_cn.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
except ValueError:
|
||||
try:
|
||||
dt_cn = datetime.strptime(val, "%Y-%m-%d %H:%M:%S").replace(tzinfo=CN_TZ)
|
||||
except ValueError:
|
||||
raise ValueError("时间格式错误,支持 'YYYY-MM-DD' 或 'YYYY-MM-DD HH:mm:ss' 或 10/13位时间戳")
|
||||
|
||||
# 与 ORM 配置保持一致(use_tz=False),返回本地时区的“朴素”时间
|
||||
return dt_cn.replace(tzinfo=None)
|
||||
|
||||
|
||||
# 自动把 datetime 序列化为 13位时间戳的基类
|
||||
class TimestampModel(BaseModel):
|
||||
"""自动把 datetime 序列化为 13位时间戳的基类"""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@field_serializer("*", when_used="json", check_fields=False) # "*" 表示作用于所有字段
|
||||
def serialize_datetime(self, value):
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp()*1000) # 转成 13 位 int 时间戳
|
||||
return value
|
||||
Reference in New Issue
Block a user