# mongo_viewset.py
from bson import ObjectId
from functools import wraps
from flasgger import swag_from
from flask import request, Blueprint
from typing import Type, Any, Dict, List, Optional
from .mongo_filters import FilterSet
from .mongo_models import MongoBaseModel
from .mongo_responses import custom_response
from .mongo_paginations import MongoPagination
from .mongo_decorators import handle_api_errors
from .mongo_swagger import generate_swagger_config
from .mongo_exceptions import NotFound, ValidationError
from .mongo_serializers import Serializer, ModelSerializer
class MongoModelViewSet:
    """
    DRF-style ModelViewSet backed by MongoDB.

    Provides the ``list``, ``create``, ``retrieve``, ``update``,
    ``partial_update`` and ``destroy`` actions.

    Subclasses configure:
    - model_class: model class, instantiated with no arguments (preferred).
    - collection: PyMongo Collection instance, wrapped in ``MongoBaseModel``
      (fallback). One of ``model_class`` / ``collection`` is required.
    - serializer_class: serializer class (optional; an empty
      ``ModelSerializer`` subclass is generated when omitted).
    - filterset_class: filter class (optional).
    - pagination_class: paginator class (defaults to ``MongoPagination``;
      a falsy value disables pagination).
    - default_sort_by / default_sort_order: sort passed to ``find_all`` for
      the non-aggregate list query.
    """

    model_class = None  # model class (recommended)
    collection = None  # Collection object passed in directly (alternative)
    serializer_class: Optional[Type[Serializer]] = None
    filterset_class: Optional[Type[FilterSet]] = None
    pagination_class: Optional[Type[MongoPagination]] = MongoPagination
    default_sort_by: str = "_id"
    default_sort_order: int = -1

    def __init__(self):
        """
        Bind the model and resolve the serializer class.

        Raises:
            ValueError: if neither ``model_class`` nor ``collection`` is set.
        """
        # 1. Obtain the model instance; model_class takes precedence.
        if self.model_class is not None:
            self.model = self.model_class()
        elif self.collection is not None:
            self.model = MongoBaseModel(self.collection)
        else:
            raise ValueError("Either model_class or collection must be provided")
        # 2. Resolve the serializer class, generating a default if needed.
        self._serializer_class = self.serializer_class or self._get_default_serializer()

    def _get_default_serializer(self):
        """
        Build a placeholder ``ModelSerializer`` subclass.

        Used only when ``serializer_class`` is not set. The generated class
        has ``Meta.model = None``; no fields are inferred here, so users are
        expected to supply a real serializer for meaningful output.
        """

        class AutoSerializer(ModelSerializer):
            class Meta:
                model = None  # must be defined explicitly by the user

        return AutoSerializer

    @staticmethod
    def _require_valid_id(id):
        """
        Return ``ObjectId(id)``.

        Raises:
            ValidationError: if ``id`` is not a valid ObjectId value.
        """
        if not ObjectId.is_valid(id):
            raise ValidationError("Invalid ID format")
        return ObjectId(id)

    @staticmethod
    def _require_json_body():
        """
        Return the parsed JSON request body.

        ``silent=True`` makes Flask return ``None`` instead of raising its
        own HTTP error for a missing, non-JSON or malformed body, so every
        such case is reported through the same ValidationError.

        Raises:
            ValidationError: if the body is absent, unparsable or empty.
        """
        data = request.get_json(silent=True)
        if not data:
            raise ValidationError("JSON body is required")
        return data

    def get_serializer(self, instance=None, data=None, partial=False):
        """Instantiate the resolved serializer class with the given arguments."""
        return self._serializer_class(instance=instance, data=data, partial=partial)

    def get_filter_dict(self) -> Dict[str, Any]:
        """
        Build the MongoDB filter from the request query string.

        Returns an empty dict when no ``filterset_class`` is configured.
        """
        if self.filterset_class:
            return self.filterset_class(data=request.args.to_dict()).get_query_dict()
        return {}

    def get_pipeline(self, filter_dict: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
        """
        Hook for subclasses: return an aggregation pipeline for ``list``.

        When a non-empty pipeline is returned, ``list`` runs it through
        ``aggregate`` and the default sort and plain filter query are not
        applied. ``None`` (the default) or an empty list selects the normal
        find-based query.
        """
        return None

    def _apply_pagination(self, cursor, total):
        """
        Apply skip/limit to ``cursor`` and materialise it.

        Returns:
            ``(documents, pagination_meta)``; the meta dict is empty when
            pagination is disabled.
        """
        if self.pagination_class:
            paginator = self.pagination_class(request.args, total, request.endpoint)
            cursor = cursor.skip(paginator.skip).limit(paginator.limit)
            return list(cursor), paginator.get_meta()
        return list(cursor), {}

    def _handle_normal_list(self, filter_dict):
        """Serve the list action with a plain filtered, sorted find query."""
        total = self.model.count_documents(filter_dict)
        cursor = self.model.find_all(filter_dict, self.default_sort_by, self.default_sort_order)
        results, meta = self._apply_pagination(cursor, total)
        serializer = self.get_serializer()
        serialized_data = [serializer.to_representation(doc) for doc in results]
        return custom_response(data=serialized_data, total=total, **meta)

    def _handle_aggregate_list(self, filter_dict, pipeline):
        """
        Serve the list action through an aggregation pipeline.

        ``filter_dict`` is accepted for signature parity with
        ``_handle_normal_list`` but is not applied here; the pipeline is
        expected to do its own matching.
        """
        # Count with a trailing $count stage rather than a $facet that also
        # collects every document: only a single small document comes back.
        count_result = list(self.model.aggregate(pipeline + [{"$count": "count"}]))
        # $count emits no document at all when the pipeline matches nothing.
        total = count_result[0]["count"] if count_result else 0

        paginator = None
        if self.pagination_class:
            paginator = self.pagination_class(request.args, total, request.endpoint)
            pipeline = pipeline + [{"$skip": paginator.skip}, {"$limit": paginator.limit}]
        results = list(self.model.aggregate(pipeline))
        meta = paginator.get_meta() if paginator is not None else {}

        serializer = self.get_serializer()
        serialized_data = [serializer.to_representation(doc) for doc in results]
        return custom_response(data=serialized_data, total=total, **meta)

    @handle_api_errors
    def list(self):
        """GET / - list resources (supports filtering, pagination, aggregation)."""
        filter_dict = self.get_filter_dict()
        pipeline = self.get_pipeline(filter_dict)
        if pipeline:
            return self._handle_aggregate_list(filter_dict, pipeline)
        return self._handle_normal_list(filter_dict)

    @handle_api_errors
    def create(self):
        """POST / - create a new resource from the JSON body."""
        data = self._require_json_body()
        serializer = self.get_serializer(data=data)
        if not serializer.is_valid(raise_exception=True):
            raise ValidationError(serializer.errors)
        validated_data = serializer.validated_data
        result = self.model.insert_one(validated_data)
        return custom_response(message="Created", code=201, id=str(result.inserted_id))

    @handle_api_errors
    def retrieve(self, id):
        """GET /<id> - fetch a single resource."""
        object_id = self._require_valid_id(id)
        doc = self.model.find_one({"_id": object_id})
        if not doc:
            raise NotFound("Resource not found")
        serializer = self.get_serializer(instance=doc)
        return custom_response(data=serializer.data())

    @handle_api_errors
    def update(self, id):
        """PUT /<id> - replace a resource in full."""
        object_id = self._require_valid_id(id)
        data = self._require_json_body()
        # Make sure the original document exists before validating.
        old_doc = self.model.find_one({"_id": object_id})
        if not old_doc:
            raise NotFound("Resource not found")
        serializer = self.get_serializer(data=data)
        if not serializer.is_valid(raise_exception=True):
            raise ValidationError(serializer.errors)
        validated_data = serializer.validated_data
        validated_data['_id'] = old_doc['_id']  # keep the original _id
        result = self.model.replace_by_id(id, validated_data)
        # The document may have been removed between the lookup and the write.
        if result.matched_count == 0:
            raise NotFound("Resource not found")
        return custom_response(message="Updated", modified_count=result.modified_count)

    @handle_api_errors
    def partial_update(self, id):
        """PATCH /<id> - update part of a resource."""
        object_id = self._require_valid_id(id)
        data = self._require_json_body()
        old_doc = self.model.find_one({"_id": object_id})
        if not old_doc:
            raise NotFound("Resource not found")
        serializer = self.get_serializer(instance=old_doc, data=data, partial=True)
        if not serializer.is_valid(raise_exception=True):
            raise ValidationError(serializer.errors)
        validated_data = serializer.validated_data
        result = self.model.update_one_by_id(id, validated_data)
        # The document may have been removed between the lookup and the write.
        if result.matched_count == 0:
            raise NotFound("Resource not found")
        return custom_response(message="Partially Updated", modified_count=result.modified_count)

    @handle_api_errors
    def destroy(self, id):
        """DELETE /<id> - delete a resource."""
        self._require_valid_id(id)
        result = self.model.delete_by_id(id)
        if result.deleted_count == 0:
            raise NotFound("Resource not found")
        return custom_response(message="Deleted", code=204)

    @classmethod
    def register_routes(cls, blueprint: Blueprint, url_prefix: str,
                        actions: Optional[List[str]] = None):
        """
        Register the viewset's URL rules (and Swagger specs) on a blueprint.

        A single viewset instance is created and shared by all routes.

        Args:
            blueprint: blueprint that receives the URL rules.
            url_prefix: path segment for the resource; rules are
                ``/<url_prefix>`` and ``/<url_prefix>/<id>`` and endpoints are
                named ``<url_prefix>_<action>``.
            actions: action names to expose. ``None`` or an empty list
                exposes all six; names outside the known six are ignored.
        """
        view_instance = cls()
        allowed = actions or ['list', 'create', 'retrieve', 'update', 'partial_update', 'destroy']
        all_specs = generate_swagger_config(cls)
        # (URL suffix, action / method name, HTTP methods)
        mapping = [
            ('', 'list', ['GET']),
            ('', 'create', ['POST']),
            ('/<id>', 'retrieve', ['GET']),
            ('/<id>', 'update', ['PUT']),
            ('/<id>', 'partial_update', ['PATCH']),
            ('/<id>', 'destroy', ['DELETE']),
        ]

        def make_view_func(method_name, method):
            # Wrap the bound method in a plain function so each route gets its
            # own view object, optionally decorated with its Swagger spec.
            @wraps(method)
            def view(*args, **kwargs):
                return method(*args, **kwargs)

            if method_name in all_specs:
                return swag_from(all_specs[method_name])(view)
            return view

        for suffix, method_name, http_methods in mapping:
            if method_name not in allowed:
                continue
            current_method = getattr(view_instance, method_name)
            final_view = make_view_func(method_name, current_method)
            blueprint.add_url_rule(
                f'/{url_prefix}{suffix}',
                view_func=final_view,
                methods=http_methods,
                endpoint=f'{url_prefix}_{method_name}'
            )