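Introduction
FastAPI is a rising star among modern Python web frameworks. Its performance, intuitive API design, and rich feature set have quickly won over developers. It builds on Starlette (for the web layer and routing) and Pydantic (for data validation), and provides async support, automatic API documentation, and validation driven by type annotations. This article explores FastAPI's advanced features, from asynchronous processing to dependency injection, together with practical techniques for building high-performance APIs efficiently.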
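FastAPI basics recap
Before turning to the advanced features, here is a brief recap of the basics.
FastAPI is a modern, fast (high-performance) web framework for building APIs, based on Python 3.6+ type hints. A minimal application looks like this:
- from typing import Optional
-
- from fastapi import FastAPI
-
- app = FastAPI()
-
- @app.get("/")
- async def read_root():
-     return {"Hello": "World"}
-
- @app.get("/items/{item_id}")
- async def read_item(item_id: int, q: Optional[str] = None):
-     # item_id is a path parameter; q is an optional query parameter
-     return {"item_id": item_id, "q": q}
This example shows basic route definitions and path-parameter handling. The rest of the article covers FastAPI's advanced features.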
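Asynchronous processing in depth
Asynchronous programming basics
One of FastAPI's core strengths is its native support for asynchronous programming. While one request waits on I/O (a database query, a network call), the event loop can serve other requests, which significantly improves concurrency.
In Python, asynchronous code is written with the async and await keywords. A simple example:
- import asyncio
- import time
-
- async def say_after(delay, what_to_say):
-     await asyncio.sleep(delay)
-     print(what_to_say)
-
- async def main():
-     print(f"started at {time.strftime('%X')}")
-     # The two awaits run one after the other, so this takes about 3 seconds
-     await say_after(1, "Hello")
-     await say_after(2, "World")
-     print(f"finished at {time.strftime('%X')}")
-
- asyncio.run(main())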
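The example above awaits its two coroutines in sequence, so the delays add up. As a minimal sketch of the concurrent alternative, the coroutines can be handed to asyncio.gather, which runs them together on the event loop. The function name run_concurrently and the shortened delays are illustrative choices, not part of the original example.

```python
import asyncio
import time


async def say_after(delay: float, what_to_say: str) -> str:
    """Wait for `delay` seconds without blocking the event loop, then return the text."""
    await asyncio.sleep(delay)
    return what_to_say


async def run_concurrently() -> tuple[list[str], float]:
    """Run both coroutines at the same time and report the elapsed wall-clock time."""
    started = time.perf_counter()
    # gather() schedules both coroutines at once and returns results in argument order.
    results = await asyncio.gather(
        say_after(0.2, "Hello"),
        say_after(0.4, "World"),
    )
    return results, time.perf_counter() - started


results, elapsed = asyncio.run(run_concurrently())
print(results, f"{elapsed:.2f}s")
```

The total time is governed by the slowest coroutine rather than by the sum of the delays, which is the property that makes async handlers scale for I/O-bound work.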
Async routes in FastAPI
Defining an asynchronous route handler in FastAPI is straightforward:
- from fastapi import FastAPI
- import asyncio
-
- app = FastAPI()
-
- @app.get("/async-endpoint")
- async def async_endpoint():
-     # Simulate a slow, non-blocking operation
-     await asyncio.sleep(1)
-     return {"message": "This is an async endpoint"}
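Choosing between sync and async routes
FastAPI lets you write route handlers as either synchronous or asynchronous functions:
- from fastapi import FastAPI
- import asyncio
- import time
-
- app = FastAPI()
-
- @app.get("/sync-endpoint")
- def sync_endpoint():
-     # Simulate a slow, blocking operation
-     time.sleep(1)
-     return {"message": "This is a sync endpoint"}
-
- @app.get("/async-endpoint")
- async def async_endpoint():
-     # Simulate a slow, non-blocking operation
-     await asyncio.sleep(1)
-     return {"message": "This is an async endpoint"}
Note: FastAPI runs plain def handlers in a thread pool, so blocking calls such as time.sleep do not stall the event loop, whereas async def handlers run on the event loop itself and should only await non-blocking operations. Blocking or CPU-bound work is therefore usually a better fit for a def handler, while I/O-bound work done through async libraries typically performs better in an async def handler.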
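When an async handler has to call blocking code, one option is to push that call onto a worker thread. The sketch below uses only the standard library (asyncio.to_thread, available from Python 3.9); blocking_work and heartbeat are hypothetical stand-ins that show the event loop staying responsive while the blocking call runs.

```python
import asyncio
import time


def blocking_work(seconds: float) -> str:
    """Stand-in for a blocking call such as a legacy SDK or a file operation."""
    time.sleep(seconds)
    return "done"


async def heartbeat(ticks: list[float]) -> None:
    """Record timestamps; this only makes progress if the event loop is not blocked."""
    for _ in range(3):
        ticks.append(time.perf_counter())
        await asyncio.sleep(0.05)


async def main() -> tuple[str, int]:
    ticks: list[float] = []
    # to_thread() runs the blocking function in a worker thread,
    # so heartbeat() keeps running on the event loop in the meantime.
    result, _ = await asyncio.gather(
        asyncio.to_thread(blocking_work, 0.2),
        heartbeat(ticks),
    )
    return result, len(ticks)


result, tick_count = asyncio.run(main())
print(result, tick_count)
```

Calling time.sleep directly inside an async def function would instead freeze every other task on the loop for the duration of the call.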
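Asynchronous database operations
Combining FastAPI with an async database driver (such as asyncpg or aiomysql) lets handlers await queries instead of blocking on them, which improves throughput under concurrent load. The example uses the databases package:
- from fastapi import FastAPI, HTTPException
- from databases import Database
-
- app = FastAPI()
-
- # Database connection settings (placeholder credentials)
- DATABASE_URL = "postgresql://user:password@localhost/dbname"
- database = Database(DATABASE_URL)
-
- # Note: on_event is deprecated in recent FastAPI releases in favor of lifespan handlers
- @app.on_event("startup")
- async def database_connect():
-     await database.connect()
-
- @app.on_event("shutdown")
- async def database_disconnect():
-     await database.disconnect()
-
- @app.get("/users/{user_id}")
- async def read_user(user_id: int):
-     # Named parameters keep the query safe from SQL injection
-     query = "SELECT * FROM users WHERE id = :user_id"
-     user = await database.fetch_one(query, {"user_id": user_id})
-     if not user:
-         raise HTTPException(status_code=404, detail="User not found")
-     return user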
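Async HTTP client
FastAPI applications often need to talk to other services. An async HTTP client such as httpx keeps those calls from blocking the event loop:
- from fastapi import FastAPI
- import httpx
-
- app = FastAPI()
-
- @app.get("/external-data")
- async def get_external_data():
-     async with httpx.AsyncClient() as client:
-         response = await client.get("https://api.example.com/data")
-         # Raise an error for 4xx/5xx responses instead of returning their body
-         response.raise_for_status()
-         return response.json()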
Background tasks
FastAPI supports background tasks, which keep running after the response has been returned:
- from fastapi import FastAPI, BackgroundTasks
- import time
-
- app = FastAPI()
-
- def write_notification(email: str, message: str = ""):
-     # Simulate the slow work of sending an email
-     time.sleep(5)
-     # Append so that earlier notifications are not overwritten
-     with open("log.txt", mode="a") as email_file:
-         content = f"notification for {email}: {message}\n"
-         email_file.write(content)
-
- @app.post("/send-notification/{email}")
- async def send_notification(email: str, background_tasks: BackgroundTasks):
-     background_tasks.add_task(write_notification, email, message="Some notification message")
-     return {"message": "Notification sent in the background"}
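Async generators and streaming responses
FastAPI can build a streaming response from an async generator, which suits real-time data delivery. The generator has to be wrapped in a StreamingResponse and must yield str or bytes chunks:
- from fastapi import FastAPI
- from fastapi.responses import StreamingResponse
- import asyncio
-
- app = FastAPI()
-
- async def generate_numbers():
-     for i in range(10):
-         # Each chunk must be text or bytes, so send one number per line
-         yield f"{i}\n"
-         await asyncio.sleep(0.5)
-
- @app.get("/stream-numbers")
- async def stream_numbers():
-     return StreamingResponse(generate_numbers(), media_type="text/plain")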
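To see what a streaming response consumes, it can help to drive an async generator by hand. The following standard-library sketch pulls chunks one at a time with async for, which is roughly how a server would forward them to the client; the collect helper and the short delay are illustrative assumptions.

```python
import asyncio
from typing import AsyncIterator


async def generate_numbers(count: int, delay: float = 0.01) -> AsyncIterator[str]:
    """Yield one text chunk per number, pausing between chunks."""
    for i in range(count):
        yield f"{i}\n"
        await asyncio.sleep(delay)


async def collect(count: int) -> list[str]:
    """Consume the generator chunk by chunk, as a streaming consumer would."""
    return [chunk async for chunk in generate_numbers(count)]


chunks = asyncio.run(collect(3))
print(chunks)
```

Because each chunk is produced lazily, the full payload never has to be held in memory at once.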
WebSocket support
FastAPI supports WebSockets natively, which makes real-time applications easy to build:
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
-
- app = FastAPI()
-
- @app.websocket("/ws")
- async def websocket_endpoint(websocket: WebSocket):
-     await websocket.accept()
-     try:
-         while True:
-             data = await websocket.receive_text()
-             await websocket.send_text(f"Message text was: {data}")
-     except WebSocketDisconnect:
-         # The client closed the connection; there is nothing left to close
-         pass
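The dependency injection system in depth
Dependency injection basics
Dependency injection is one of FastAPI's core features. It lets you declare what a component depends on, which makes code more modular, testable, and maintainable.
A simple example:
- from typing import Optional
-
- from fastapi import FastAPI, Depends
-
- app = FastAPI()
-
- # Dependency function: its parameters are read from the query string
- def common_parameters(q: Optional[str] = None, skip: int = 0, limit: int = 100):
-     return {"q": q, "skip": skip, "limit": limit}
-
- @app.get("/users/")
- async def read_users(commons: dict = Depends(common_parameters)):
-     return commons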
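Classes as dependencies
Besides functions, FastAPI also accepts classes as dependencies; the parameters of __init__ are resolved in the same way as function parameters:
- from fastapi import FastAPI, Depends
- from typing import Optional
-
- app = FastAPI()
-
- class CommonQueryParams:
-     def __init__(self, q: Optional[str] = None, skip: int = 0, limit: int = 100):
-         self.q = q
-         self.skip = skip
-         self.limit = limit
-
- @app.get("/users/")
- async def read_users(commons: CommonQueryParams = Depends()):
-     # Depends() with no argument uses the annotated class as the dependency
-     return {"q": commons.q, "skip": commons.skip, "limit": commons.limit}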
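Dependencies with sub-dependencies
Dependencies can themselves declare dependencies, forming a multi-level structure:
- from typing import Optional
-
- from fastapi import Cookie, Depends, FastAPI
-
- app = FastAPI()
-
- def query_extractor(q: Optional[str] = None):
-     return q
-
- def query_or_cookie_extractor(
-     q: Optional[str] = Depends(query_extractor),
-     last_query: Optional[str] = Cookie(None),  # read from the "last_query" cookie
- ):
-     # Fall back to the cookie when no query parameter was sent
-     if not q:
-         return last_query
-     return q
-
- @app.get("/items/")
- async def read_query(query_or_default: Optional[str] = Depends(query_or_cookie_extractor)):
-     return {"q_or_cookie": query_or_default}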
Dependencies in path operation decorators
A dependency can be attached directly to the path operation decorator when the handler does not need its return value:
- from fastapi import FastAPI, Depends, Header, HTTPException
-
- app = FastAPI()
-
- async def verify_token(x_token: str = Header(...)):
-     if x_token != "fake-super-secret-token":
-         raise HTTPException(status_code=400, detail="X-Token header invalid")
-     return x_token
-
- async def verify_key(x_key: str = Header(...)):
-     if x_key != "fake-super-secret-key":
-         raise HTTPException(status_code=400, detail="X-Key header invalid")
-     return x_key
-
- # Both dependencies run for every request; their return values are discarded
- @app.get("/items/", dependencies=[Depends(verify_token), Depends(verify_key)])
- async def read_items():
-     return [{"item": "Foo"}, {"item": "Bar"}]
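Dependencies with yield
A dependency that uses yield can perform setup before the handler runs and cleanup afterwards:
- from fastapi import FastAPI, Depends
-
- app = FastAPI()
-
- # DBSession and Users are placeholders for your own session factory and ORM model
- def get_db():
-     db = DBSession()
-     try:
-         yield db
-     finally:
-         # Cleanup always runs, even if the handler raised an exception
-         db.close()
-
- # The session is synchronous, so the handler is a plain def
- @app.get("/users/")
- def read_users(db = Depends(get_db)):
-     users = db.query(Users).all()
-     return users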
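A yield dependency behaves much like a context manager: the code before yield is setup, and the code after it is cleanup. The sketch below reproduces that pattern with contextlib alone so the ordering can be observed without a database; FakeSession, get_session, and handle_request are hypothetical names introduced for illustration.

```python
from contextlib import contextmanager

events: list[str] = []


class FakeSession:
    """Hypothetical stand-in for a database session."""

    def close(self) -> None:
        events.append("closed")


@contextmanager
def get_session():
    session = FakeSession()
    events.append("opened")
    try:
        yield session  # the caller's code runs here
    finally:
        session.close()  # always runs, even if the caller raised


def handle_request(fail: bool) -> str:
    """Simulated handler that may fail while holding the session."""
    with get_session():
        if fail:
            raise RuntimeError("handler failed")
        return "ok"
```

Whether the handler succeeds or raises, the session is closed exactly once, which is the guarantee the try/finally around yield provides.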
Context managers and application lifespan
Python's async context managers can also drive application-level setup and teardown. The function is passed to FastAPI through the lifespan parameter:
- from contextlib import asynccontextmanager
-
- from fastapi import FastAPI
-
- @asynccontextmanager
- async def lifespan(app: FastAPI):
-     # Code that runs at startup
-     print("Starting up...")
-     yield
-     # Code that runs at shutdown
-     print("Shutting down...")
-
- app = FastAPI(lifespan=lifespan)
-
- @app.get("/")
- async def root():
-     return {"message": "Hello World"}
Overriding dependencies
Tests often need to replace a dependency. FastAPI provides app.dependency_overrides for this purpose:
- from fastapi import FastAPI, Depends
- from fastapi.testclient import TestClient
-
- app = FastAPI()
-
- async def get_token():
-     return "normal_token"
-
- @app.get("/token")
- async def read_token(token: str = Depends(get_token)):
-     return {"token": token}
-
- # Test code
- async def override_get_token():
-     return "test_token"
-
- # Map the original dependency to its replacement
- app.dependency_overrides[get_token] = override_get_token
- client = TestClient(app)
-
- def test_read_token():
-     response = client.get("/token")
-     assert response.status_code == 200
-     assert response.json() == {"token": "test_token"}
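Dependency caching
When the same dependency appears several times in the dependency graph of a single request, FastAPI calls it once and reuses the result by default. Passing use_cache=False forces a fresh call for that particular use:
- from fastapi import FastAPI, Depends
-
- app = FastAPI()
-
- async def reusable_dependency():
-     print("Called once per request when the cached result is reused")
-     return {"result": "reusable"}
-
- @app.get("/endpoint1")
- async def endpoint1(
-     first: dict = Depends(reusable_dependency),
-     second: dict = Depends(reusable_dependency),  # reuses the cached result
- ):
-     return first
-
- @app.get("/endpoint2")
- async def endpoint2(
-     first: dict = Depends(reusable_dependency),
-     second: dict = Depends(reusable_dependency, use_cache=False),  # called again
- ):
-     return second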
Advanced routing and middleware
Grouping routes
As an application grows, APIRouter helps organize its routes:
- from fastapi import FastAPI, APIRouter
-
- app = FastAPI()
-
- # Create the routers
- users_router = APIRouter(prefix="/users", tags=["users"])
- items_router = APIRouter(prefix="/items", tags=["items"])
-
- # User routes
- @users_router.get("/")
- async def read_users():
-     return [{"username": "Rick"}, {"username": "Morty"}]
-
- @users_router.get("/{user_id}")
- async def read_user(user_id: int):
-     return {"user_id": user_id}
-
- # Item routes
- @items_router.get("/")
- async def read_items():
-     return [{"item_id": 1}, {"item_id": 2}]
-
- # Attach the routers to the application
- app.include_router(users_router)
- app.include_router(items_router)
Route classes
Route handlers can also be grouped inside a class that registers them on a router:
- from fastapi import FastAPI, APIRouter
-
- app = FastAPI()
-
- class UserRoutes:
-     def __init__(self, router: APIRouter):
-         self.router = router
-
-     def register_routes(self):
-         @self.router.get("/users")
-         async def get_users():
-             return {"users": []}
-
-         @self.router.get("/users/{user_id}")
-         async def get_user(user_id: int):
-             return {"user_id": user_id}
-
- router = APIRouter()
- user_routes = UserRoutes(router)
- user_routes.register_routes()
- app.include_router(router)
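Middleware
Middleware is a function that runs before each request reaches its path operation and again after the response has been produced:
- from fastapi import FastAPI, Request
- import time
-
- app = FastAPI()
-
- @app.middleware("http")
- async def add_process_time_header(request: Request, call_next):
-     # perf_counter() is a monotonic clock suited to measuring durations
-     start_time = time.perf_counter()
-     response = await call_next(request)
-     process_time = time.perf_counter() - start_time
-     response.headers["X-Process-Time"] = str(process_time)
-     return response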
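CORS middleware
Handling cross-origin resource sharing (CORS) is a common requirement for web applications. Browsers reject a wildcard origin when credentials are allowed, so list the allowed origins explicitly:
- from fastapi import FastAPI
- from fastapi.middleware.cors import CORSMiddleware
-
- app = FastAPI()
-
- # Placeholder origin; list the front-end origins you actually serve
- origins = ["https://example.com"]
-
- # Configure the CORS middleware
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=origins,
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )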
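Custom middleware
More elaborate middleware can implement cross-cutting concerns such as authentication. An HTTPException raised inside middleware is not processed by FastAPI's exception handlers, so the middleware returns error responses directly:
- from fastapi import FastAPI, Request
- from fastapi.responses import JSONResponse
- import jwt
-
- app = FastAPI()
- PUBLIC_PATHS = {"/login", "/docs", "/openapi.json"}
-
- @app.middleware("http")
- async def auth_middleware(request: Request, call_next):
-     # Let public paths through without authentication
-     if request.url.path in PUBLIC_PATHS:
-         return await call_next(request)
-
-     # Expect an "Authorization: Bearer <token>" header
-     auth_header = request.headers.get("Authorization", "")
-     scheme, _, token = auth_header.partition(" ")
-     if scheme.lower() != "bearer" or not token:
-         return JSONResponse(status_code=401, content={"detail": "Not authenticated"})
-
-     # Validate the token and expose the user id to downstream handlers
-     try:
-         payload = jwt.decode(token, "your-secret-key", algorithms=["HS256"])
-         request.state.user_id = payload["sub"]
-     except (jwt.PyJWTError, KeyError):
-         return JSONResponse(
-             status_code=401,
-             content={"detail": "Invalid authentication credentials"},
-         )
-
-     return await call_next(request)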
Route-level middleware
Middleware-like logic can be applied to specific routes through a custom APIRoute subclass:
- from typing import Callable
-
- from fastapi import FastAPI, Request, Response
- from fastapi.routing import APIRoute
-
- app = FastAPI()
-
- async def log_route_info(request: Request):
-     print(f"Request to {request.url.path} received")
-
- # Custom route class that wraps the original handler
- class LoggedRoute(APIRoute):
-     def get_route_handler(self) -> Callable:
-         original_route_handler = super().get_route_handler()
-
-         async def custom_route_handler(request: Request) -> Response:
-             await log_route_info(request)
-             return await original_route_handler(request)
-
-         return custom_route_handler
-
- # Apply the custom route class; this must happen before routes are registered
- app.router.route_class = LoggedRoute
-
- @app.get("/")
- async def read_root():
-     return {"message": "Hello World"}
Data validation and serialization
Pydantic models in depth
FastAPI uses Pydantic for data validation and serialization. The examples in this section use the Pydantic v1 API (validator, root_validator, regex=, orm_mode, .dict()); in Pydantic v2 the corresponding names are field_validator, model_validator, pattern=, from_attributes, and .model_dump().
- from fastapi import FastAPI
- from pydantic import BaseModel, Field, validator
- from typing import Optional
- from datetime import datetime
- from enum import Enum
-
- app = FastAPI()
-
- class UserRole(str, Enum):
-     ADMIN = "admin"
-     USER = "user"
-     GUEST = "guest"
-
- class UserBase(BaseModel):
-     username: str = Field(..., min_length=3, max_length=20)
-     email: str = Field(..., regex=r'^[^@]+@[^@]+\.[^@]+$')
-     full_name: Optional[str] = None
-     role: UserRole = UserRole.USER
-
- class UserCreate(UserBase):
-     password: str = Field(..., min_length=8)
-
-     @validator('password')
-     def validate_password(cls, v):
-         if not any(c.isupper() for c in v):
-             raise ValueError('Password must contain at least one uppercase letter')
-         if not any(c.isdigit() for c in v):
-             raise ValueError('Password must contain at least one digit')
-         return v
-
- class User(UserBase):
-     id: int
-     created_at: datetime
-
-     class Config:
-         orm_mode = True  # allow building the model from ORM objects
-
- @app.post("/users/", response_model=User)
- async def create_user(user: UserCreate):
-     # A real application would save the user to a database here.
-     # For the example we return a mock user; response_model=User
-     # filters the password out of the response.
-     return {
-         "id": 1,
-         "created_at": datetime.now(),
-         **user.dict()
-     }
Complex model validation
Pydantic supports validation logic that spans several fields:
- from pydantic import BaseModel, validator, root_validator
- from typing import List
-
- class OrderItem(BaseModel):
-     product_id: int
-     quantity: int
-     unit_price: float
-
-     @validator('quantity')
-     def quantity_positive(cls, v):
-         if v <= 0:
-             raise ValueError('Quantity must be positive')
-         return v
-
-     @validator('unit_price')
-     def price_positive(cls, v):
-         if v <= 0:
-             raise ValueError('Unit price must be positive')
-         return v
-
- class Order(BaseModel):
-     items: List[OrderItem]
-     total_amount: float
-
-     @root_validator
-     def validate_total(cls, values):
-         items = values.get('items') or []
-         total_amount = values.get('total_amount')
-         if total_amount is None:
-             # Field-level validation already failed; nothing to compare
-             return values
-
-         total = sum(item.quantity * item.unit_price for item in items)
-         if abs(total - total_amount) > 0.01:  # tolerate floating-point rounding
-             raise ValueError(f'Total amount should be {total}')
-
-         return values
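The tolerance in the total check exists because binary floats cannot represent most decimal prices exactly. As a small standard-library sketch of the alternative, monetary amounts can be handled as Decimal values, which compare exactly; order_total and the sample prices are illustrative assumptions rather than part of the models above.

```python
from decimal import Decimal
from typing import Iterable, Tuple


def order_total(lines: Iterable[Tuple[int, str]]) -> Decimal:
    """Sum quantity * unit price, with prices given as decimal strings."""
    return sum((Decimal(qty) * Decimal(price) for qty, price in lines), Decimal("0"))


# Three items at 0.10 each: exact with Decimal, inexact with float.
decimal_total = order_total([(3, "0.10")])
float_total = 3 * 0.1
print(decimal_total, float_total)
```

With exact arithmetic the validator could compare totals with == instead of an epsilon, at the cost of serializing amounts as strings or Decimal-aware types.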
Custom validators
Custom validators can encode more involved rules:
- from pydantic import BaseModel, validator
- import re
-
- class User(BaseModel):
-     username: str
-     password: str
-
-     @validator('username')
-     def username_alphanumeric(cls, v):
-         if not re.match(r'^[a-zA-Z0-9_]+$', v):
-             raise ValueError('Username may only contain letters, digits, and underscores')
-         return v
-
-     @validator('password')
-     def password_strength(cls, v):
-         if len(v) < 8:
-             raise ValueError('Password must be at least 8 characters')
-         if not any(c.isupper() for c in v):
-             raise ValueError('Password must contain an uppercase letter')
-         if not any(c.islower() for c in v):
-             raise ValueError('Password must contain a lowercase letter')
-         if not any(c.isdigit() for c in v):
-             raise ValueError('Password must contain a digit')
-         return v
Data conversion
Pydantic can also convert data while validating it. Because the validator returns date objects, the fields are declared as date rather than str:
- from datetime import date, datetime
- from typing import Optional
-
- from pydantic import BaseModel, validator
-
- class Event(BaseModel):
-     name: str
-     start_date: date
-     end_date: Optional[date] = None
-
-     @validator('start_date', 'end_date', pre=True)
-     def parse_date(cls, v):
-         if isinstance(v, str):
-             # Try each supported format in order; the first match wins,
-             # so an ambiguous value like 01/02/2024 is read as day/month
-             for fmt in ('%Y-%m-%d', '%d/%m/%Y', '%m/%d/%Y'):
-                 try:
-                     return datetime.strptime(v, fmt).date()
-                 except ValueError:
-                     continue
-             raise ValueError(f'Invalid date format: {v!r}')
-         return v
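The format-fallback logic can be exercised on its own, without a Pydantic model. The sketch below is a standalone version of the same idea using only the standard library; parse_date and DATE_FORMATS are names chosen for illustration.

```python
from datetime import date, datetime
from typing import Optional

# Formats are tried in order, so the first match wins for ambiguous inputs.
DATE_FORMATS = ("%Y-%m-%d", "%d/%m/%Y", "%m/%d/%Y")


def parse_date(value: Optional[str]) -> Optional[date]:
    """Return a date parsed from the first matching format, or None for None."""
    if value is None:
        return None
    for fmt in DATE_FORMATS:
        try:
            return datetime.strptime(value, fmt).date()
        except ValueError:
            continue
    raise ValueError(f"Invalid date format: {value!r}")


print(parse_date("2024-03-05"), parse_date("03/25/2024"))
```

Note the consequence of the ordering: a value such as 05/03/2024 is read as 5 March, and the month-first format is only reached when the day-first reading is impossible.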
Response models
FastAPI lets you declare a precise model for each response. EmailStr requires the optional email-validator package:
- from fastapi import FastAPI
- from pydantic import BaseModel, EmailStr
- from typing import Optional
-
- app = FastAPI()
-
- class UserBase(BaseModel):
-     username: str
-     email: EmailStr
-     full_name: Optional[str] = None
-
- class UserCreate(UserBase):
-     password: str
-
- class User(UserBase):
-     id: int
-     is_active: bool
-
-     class Config:
-         orm_mode = True
-
- class ItemBase(BaseModel):
-     title: str
-     description: Optional[str] = None
-
- class ItemCreate(ItemBase):
-     pass
-
- class Item(ItemBase):
-     id: int
-     owner_id: int
-
-     class Config:
-         orm_mode = True
-
- @app.post("/users/", response_model=User)
- async def create_user(user: UserCreate):
-     # A real application would save the user to a database
-     return {"id": 1, "is_active": True, **user.dict()}
-
- @app.get("/users/{user_id}", response_model=User)
- async def read_user(user_id: int):
-     # A real application would load the user from a database
-     return {
-         "id": user_id,
-         "username": "example",
-         "email": "example@example.com",
-         "full_name": "Example User",
-         "is_active": True
-     }
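Response models and field exclusion
The response model controls which fields appear in the response. Here the password is accepted on input but never returned:
- from typing import Optional
-
- from fastapi import FastAPI
- from pydantic import BaseModel, EmailStr
-
- app = FastAPI()
-
- class UserIn(BaseModel):
-     username: str
-     password: str
-     email: EmailStr
-     full_name: Optional[str] = None
-
- class UserOut(BaseModel):
-     username: str
-     email: EmailStr
-     full_name: Optional[str] = None
-
- @app.post("/user/", response_model=UserOut)
- async def create_user(user: UserIn):
-     # A real application would save the user to a database.
-     # Returning the input is safe here because UserOut drops the password.
-     return user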
Response models with Union
A route can declare several possible response models:
- from fastapi import FastAPI
- from pydantic import BaseModel
- from typing import Union
-
- app = FastAPI()
-
- class BaseItem(BaseModel):
-     description: str
-     type: str
-
- class CarItem(BaseItem):
-     type: str = "car"
-     model: str
-
- class PlaneItem(BaseItem):
-     type: str = "plane"
-     size: int
-
- items = {
-     "item1": {"description": "All my friends drive a low rider", "type": "car", "model": "Lowrider"},
-     "item2": {"description": "Music is my aeroplane", "type": "plane", "size": 5},
- }
-
- @app.get("/items/{item_id}", response_model=Union[CarItem, PlaneItem])
- async def read_item(item_id: str):
-     return items[item_id]
Lists of response models
A route can also return a list of models:
- from typing import List
-
- from fastapi import FastAPI
- from pydantic import BaseModel
-
- app = FastAPI()
-
- class Item(BaseModel):
-     name: str
-     description: str
-
- items = [
-     {"name": "Foo", "description": "There comes my hero"},
-     {"name": "Red", "description": "It's my aeroplane"},
- ]
-
- # List[Item] also works on Python versions older than 3.9
- @app.get("/items/", response_model=List[Item])
- async def read_items():
-     return items
Security and authentication
API key authentication
An API key is one of the simplest authentication schemes:
- from fastapi import FastAPI, Depends, HTTPException, status
- from fastapi.security import APIKeyHeader
-
- app = FastAPI()
-
- API_KEY_NAME = "X-API-KEY"
- api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=True)
-
- async def get_api_key(api_key: str = Depends(api_key_header)):
-     # The hard-coded key is a placeholder; load the real key from configuration
-     if api_key != "my-secret-api-key":
-         raise HTTPException(
-             status_code=status.HTTP_401_UNAUTHORIZED,
-             detail="Invalid API Key",
-         )
-     return api_key
-
- @app.get("/protected-route")
- async def protected_route(api_key: str = Depends(get_api_key)):
-     return {"message": "You have access to this protected route"}
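Comparing secrets with != can, in principle, leak timing information about how many leading characters matched. A hedged sketch of a constant-time check with the standard library's secrets.compare_digest follows; is_valid_api_key and the hard-coded key are illustrative placeholders, and a real deployment would load the key from configuration.

```python
import secrets

EXPECTED_API_KEY = "my-secret-api-key"  # placeholder; load from configuration in practice


def is_valid_api_key(candidate: str) -> bool:
    """Compare the supplied key with the expected one in constant time."""
    # Encoding to bytes lets compare_digest accept non-ASCII input as well.
    return secrets.compare_digest(
        candidate.encode("utf-8"),
        EXPECTED_API_KEY.encode("utf-8"),
    )


print(is_valid_api_key("my-secret-api-key"), is_valid_api_key("wrong"))
```

A dependency like get_api_key could call such a helper in place of the direct string comparison.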
OAuth2 password flow
The OAuth2 password flow is a common authentication scheme for web applications. The example is deliberately simplified: it does not check passwords and uses the username as the token:
- from fastapi import FastAPI, Depends, HTTPException, status
- from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
- from pydantic import BaseModel
- from typing import Optional
-
- app = FastAPI()
-
- # Simulated user database
- fake_users_db = {
-     "johndoe": {
-         "username": "johndoe",
-         "full_name": "John Doe",
-         "email": "johndoe@example.com",
-         "hashed_password": "fakehashedsecret",
-         "disabled": False,
-     }
- }
-
- # Simulated token store
- fake_tokens_db = {}
-
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-
- class User(BaseModel):
-     username: str
-     email: Optional[str] = None
-     full_name: Optional[str] = None
-     disabled: Optional[bool] = None
-
- class UserInDB(User):
-     hashed_password: str
-
- def get_user(db, username: str):
-     if username in db:
-         user_dict = db[username]
-         return UserInDB(**user_dict)
-     return None
-
- def fake_decode_token(token):
-     # A real application would decode a JWT here;
-     # the example treats the token as the username
-     user = get_user(fake_users_db, token)
-     return user
-
- async def get_current_user(token: str = Depends(oauth2_scheme)):
-     user = fake_decode_token(token)
-     if not user:
-         raise HTTPException(
-             status_code=status.HTTP_401_UNAUTHORIZED,
-             detail="Invalid authentication credentials",
-             headers={"WWW-Authenticate": "Bearer"},
-         )
-     return user
-
- async def get_current_active_user(current_user: User = Depends(get_current_user)):
-     if current_user.disabled:
-         raise HTTPException(status_code=400, detail="Inactive user")
-     return current_user
-
- @app.post("/token")
- async def login(form_data: OAuth2PasswordRequestForm = Depends()):
-     # A real application must verify both the username and the password;
-     # the example only checks that the user exists
-     if form_data.username not in fake_users_db:
-         raise HTTPException(status_code=400, detail="Incorrect username or password")
-
-     # Create and store the token (a real application would issue a JWT)
-     token = form_data.username
-     fake_tokens_db[token] = {"sub": form_data.username}
-
-     return {"access_token": token, "token_type": "bearer"}
-
- @app.get("/users/me")
- async def read_users_me(current_user: User = Depends(get_current_active_user)):
-     return current_user
JWT authentication
JSON Web Token (JWT) is a widely used token format for stateless authentication. The example uses PyJWT (imported as jwt) for tokens and passlib for password hashing:
- from fastapi import FastAPI, Depends, HTTPException, status
- from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
- from pydantic import BaseModel
- from typing import Optional
- from datetime import datetime, timedelta, timezone
- import jwt
- from jwt import PyJWTError
- from passlib.context import CryptContext
-
- app = FastAPI()
-
- # Password hashing context
- pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
-
- # JWT settings (the secret is a placeholder; load it from configuration)
- SECRET_KEY = "your-secret-key"
- ALGORITHM = "HS256"
- ACCESS_TOKEN_EXPIRE_MINUTES = 30
-
- # Simulated user database
- fake_users_db = {
-     "johndoe": {
-         "username": "johndoe",
-         "full_name": "John Doe",
-         "email": "johndoe@example.com",
-         "hashed_password": pwd_context.hash("secret"),
-         "disabled": False,
-     }
- }
-
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-
- class Token(BaseModel):
-     access_token: str
-     token_type: str
-
- class TokenData(BaseModel):
-     username: Optional[str] = None
-
- class User(BaseModel):
-     username: str
-     email: Optional[str] = None
-     full_name: Optional[str] = None
-     disabled: Optional[bool] = None
-
- class UserInDB(User):
-     hashed_password: str
-
- def verify_password(plain_password, hashed_password):
-     return pwd_context.verify(plain_password, hashed_password)
-
- def get_password_hash(password):
-     return pwd_context.hash(password)
-
- def get_user(db, username: str):
-     if username in db:
-         user_dict = db[username]
-         return UserInDB(**user_dict)
-     return None
-
- def authenticate_user(fake_db, username: str, password: str):
-     user = get_user(fake_db, username)
-     if not user:
-         return False
-     if not verify_password(password, user.hashed_password):
-         return False
-     return user
-
- def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
-     to_encode = data.copy()
-     # Use timezone-aware UTC timestamps; datetime.utcnow() is deprecated
-     now = datetime.now(timezone.utc)
-     if expires_delta:
-         expire = now + expires_delta
-     else:
-         expire = now + timedelta(minutes=15)
-     to_encode.update({"exp": expire})
-     encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
-     return encoded_jwt
-
- async def get_current_user(token: str = Depends(oauth2_scheme)):
-     credentials_exception = HTTPException(
-         status_code=status.HTTP_401_UNAUTHORIZED,
-         detail="Could not validate credentials",
-         headers={"WWW-Authenticate": "Bearer"},
-     )
-     try:
-         # decode() verifies the signature and the exp claim
-         payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
-         username: Optional[str] = payload.get("sub")
-         if username is None:
-             raise credentials_exception
-         token_data = TokenData(username=username)
-     except PyJWTError:
-         raise credentials_exception
-     user = get_user(fake_users_db, username=token_data.username)
-     if user is None:
-         raise credentials_exception
-     return user
-
- async def get_current_active_user(current_user: User = Depends(get_current_user)):
-     if current_user.disabled:
-         raise HTTPException(status_code=400, detail="Inactive user")
-     return current_user
-
- @app.post("/token", response_model=Token)
- async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
-     user = authenticate_user(fake_users_db, form_data.username, form_data.password)
-     if not user:
-         raise HTTPException(
-             status_code=status.HTTP_401_UNAUTHORIZED,
-             detail="Incorrect username or password",
-             headers={"WWW-Authenticate": "Bearer"},
-         )
-     access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
-     access_token = create_access_token(
-         data={"sub": user.username}, expires_delta=access_token_expires
-     )
-     return {"access_token": access_token, "token_type": "bearer"}
-
- @app.get("/users/me", response_model=User)
- async def read_users_me(current_user: User = Depends(get_current_active_user)):
-     return current_user
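To make the token format less opaque, here is a rough standard-library sketch of what signing and verifying an HS256 token involves: two base64url-encoded JSON segments plus an HMAC-SHA256 signature over them. It is an illustration only, not a substitute for a JWT library; the helper names are hypothetical, and the sketch skips checks a real library performs, such as validating the alg header.

```python
import base64
import hashlib
import hmac
import json
import time


def b64url(data: bytes) -> str:
    """Base64url-encode without padding, as JWT segments require."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def b64url_decode(text: str) -> bytes:
    """Restore the stripped padding before decoding."""
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def _signature(signing_input: str, secret: str) -> bytes:
    return hmac.new(secret.encode("utf-8"), signing_input.encode("ascii"), hashlib.sha256).digest()


def sign_hs256(payload: dict, secret: str) -> str:
    """Build header.payload.signature for the given claims."""
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = ".".join(
        b64url(json.dumps(part, separators=(",", ":")).encode("utf-8"))
        for part in (header, payload)
    )
    return f"{signing_input}.{b64url(_signature(signing_input, secret))}"


def verify_hs256(token: str, secret: str) -> dict:
    """Check the signature and the exp claim; raise ValueError on failure."""
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("Malformed token")
    header_b64, payload_b64, signature_b64 = parts
    expected = _signature(f"{header_b64}.{payload_b64}", secret)
    if not hmac.compare_digest(expected, b64url_decode(signature_b64)):
        raise ValueError("Bad signature")
    payload = json.loads(b64url_decode(payload_b64))
    if "exp" in payload and payload["exp"] < time.time():
        raise ValueError("Token expired")
    return payload


token = sign_hs256({"sub": "johndoe", "exp": int(time.time()) + 60}, "your-secret-key")
print(verify_hs256(token, "your-secret-key")["sub"])
```

The payload is only encoded, not encrypted, so anyone holding the token can read its claims; the signature merely proves the claims were issued by a holder of the secret.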
Role-based access control
Role-based access control (RBAC) is a common way to manage complex permissions. The example ranks the roles so that "manager or above" and "user or above" behave as the route names suggest:
- from fastapi import FastAPI, Depends, HTTPException, status
- from fastapi.security import OAuth2PasswordBearer
- from pydantic import BaseModel
- from typing import List
- from enum import Enum
-
- app = FastAPI()
-
- class Role(str, Enum):
-     ADMIN = "admin"
-     MANAGER = "manager"
-     USER = "user"
-     GUEST = "guest"
-
- # Higher rank means more privileges
- ROLE_RANK = {Role.GUEST: 0, Role.USER: 1, Role.MANAGER: 2, Role.ADMIN: 3}
-
- class User(BaseModel):
-     username: str
-     roles: List[Role]
-     disabled: bool = False
-
- # Simulated user database
- fake_users_db = {
-     "admin": {
-         "username": "admin",
-         "roles": [Role.ADMIN],
-         "disabled": False,
-     },
-     "manager": {
-         "username": "manager",
-         "roles": [Role.MANAGER],
-         "disabled": False,
-     },
-     "user": {
-         "username": "user",
-         "roles": [Role.USER],
-         "disabled": False,
-     }
- }
-
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-
- def get_user(db, username: str):
-     if username in db:
-         user_dict = db[username]
-         return User(**user_dict)
-     return None
-
- async def get_current_user(token: str = Depends(oauth2_scheme)):
-     # A real application would decode a JWT here;
-     # the example treats the token as the username
-     user = get_user(fake_users_db, token)
-     if not user:
-         raise HTTPException(
-             status_code=status.HTTP_401_UNAUTHORIZED,
-             detail="Invalid authentication credentials",
-             headers={"WWW-Authenticate": "Bearer"},
-         )
-     return user
-
- async def get_current_active_user(current_user: User = Depends(get_current_user)):
-     if current_user.disabled:
-         raise HTTPException(status_code=400, detail="Inactive user")
-     return current_user
-
- def require_role(minimum_role: Role):
-     # Dependency factory: returns a checker bound to the required minimum role
-     def role_checker(current_user: User = Depends(get_current_active_user)):
-         highest = max((ROLE_RANK[role] for role in current_user.roles), default=-1)
-         if highest < ROLE_RANK[minimum_role]:
-             raise HTTPException(
-                 status_code=status.HTTP_403_FORBIDDEN,
-                 detail="Insufficient permissions",
-             )
-         return current_user
-     return role_checker
-
- @app.get("/admin-only")
- async def admin_only_route(user: User = Depends(require_role(Role.ADMIN))):
-     return {"message": "Welcome, admin!"}
-
- @app.get("/manager-or-above")
- async def manager_route(user: User = Depends(require_role(Role.MANAGER))):
-     return {"message": "Welcome, manager!"}
-
- @app.get("/user-or-above")
- async def user_route(user: User = Depends(require_role(Role.USER))):
-     return {"message": "Welcome, user!"}
Security middleware
FastAPI ships several security-related middleware classes:
- from fastapi import FastAPI
- from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
- from fastapi.middleware.trustedhost import TrustedHostMiddleware
-
- app = FastAPI()
-
- # Redirect all HTTP requests to HTTPS
- app.add_middleware(HTTPSRedirectMiddleware)
-
- # Only accept requests addressed to trusted hosts
- app.add_middleware(
-     TrustedHostMiddleware,
-     allowed_hosts=["example.com", "*.example.com"]
- )
Security headers
Security-related HTTP headers can be added in a middleware function:
- from fastapi import FastAPI, Request
- from fastapi.middleware.gzip import GZipMiddleware
-
- app = FastAPI()
-
- # Compress responses larger than 1000 bytes with GZip
- app.add_middleware(GZipMiddleware, minimum_size=1000)
-
- @app.middleware("http")
- async def add_security_headers(request: Request, call_next):
-     response = await call_next(request)
-
-     # Add security-related HTTP headers
-     response.headers["X-Content-Type-Options"] = "nosniff"
-     response.headers["X-Frame-Options"] = "DENY"
-     response.headers["X-XSS-Protection"] = "1; mode=block"
-     response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
-     response.headers["Permissions-Policy"] = "geolocation=(), microphone=()"
-     response.headers["Content-Security-Policy"] = "default-src 'self'"
-
-     return response
Testing strategies
Testing with TestClient
FastAPI provides TestClient for testing an API:
- from fastapi import FastAPI
- from fastapi.testclient import TestClient
-
- app = FastAPI()
-
- @app.get("/")
- async def read_root():
-     return {"message": "Hello World"}
-
- client = TestClient(app)
-
- def test_read_root():
-     response = client.get("/")
-     assert response.status_code == 200
-     assert response.json() == {"message": "Hello World"}
Testing async endpoints
Async endpoints are tested the same way as sync ones; TestClient drives the event loop, so the test function itself stays synchronous:
- from fastapi import FastAPI
- from fastapi.testclient import TestClient
- import asyncio
-
- app = FastAPI()
-
- @app.get("/async-endpoint")
- async def async_endpoint():
-     await asyncio.sleep(0.1)  # simulate an async operation
-     return {"message": "Async response"}
-
- client = TestClient(app)
-
- def test_async_endpoint():
-     response = client.get("/async-endpoint")
-     assert response.status_code == 200
-     assert response.json() == {"message": "Async response"}
Testing endpoints with dependencies
Testing an endpoint that has dependencies may require overriding them. The override is installed inside the test that needs it and removed afterwards, so other tests still see the real dependency:
- from fastapi import FastAPI, Depends
- from fastapi.testclient import TestClient
-
- app = FastAPI()
-
- def get_token():
-     return "normal_token"
-
- @app.get("/protected-route")
- async def protected_route(token: str = Depends(get_token)):
-     return {"token": token}
-
- client = TestClient(app)
-
- def test_protected_route():
-     response = client.get("/protected-route")
-     assert response.status_code == 200
-     assert response.json() == {"token": "normal_token"}
-
- # Replacement dependency
- def override_get_token():
-     return "test_token"
-
- def test_protected_route_with_override():
-     app.dependency_overrides[get_token] = override_get_token
-     try:
-         response = client.get("/protected-route")
-         assert response.status_code == 200
-         assert response.json() == {"token": "test_token"}
-     finally:
-         # Remove the override even if an assertion fails
-         app.dependency_overrides.clear()
Testing authenticated endpoints
Endpoints that require authentication are tested by sending the Authorization header:
- from fastapi import FastAPI, Depends, HTTPException, status
- from fastapi.security import OAuth2PasswordBearer
- from fastapi.testclient import TestClient
-
- app = FastAPI()
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-
- async def get_current_user(token: str = Depends(oauth2_scheme)):
-     if token != "valid-token":
-         raise HTTPException(
-             status_code=status.HTTP_401_UNAUTHORIZED,
-             detail="Invalid authentication credentials",
-             headers={"WWW-Authenticate": "Bearer"},
-         )
-     return {"username": "john", "token": token}
-
- @app.get("/users/me")
- async def read_users_me(current_user: dict = Depends(get_current_user)):
-     return current_user
-
- client = TestClient(app)
-
- def test_read_users_me_with_valid_token():
-     response = client.get("/users/me", headers={"Authorization": "Bearer valid-token"})
-     assert response.status_code == 200
-     assert response.json() == {"username": "john", "token": "valid-token"}
-
- def test_read_users_me_with_invalid_token():
-     response = client.get("/users/me", headers={"Authorization": "Bearer invalid-token"})
-     assert response.status_code == 401
-     assert response.json() == {"detail": "Invalid authentication credentials"}
Advanced testing with pytest
More complex tests can combine pytest fixtures with a dedicated test database. The example assumes that the application under test defines a get_db dependency and /users/ routes:
- import pytest
- from fastapi.testclient import TestClient
- from sqlalchemy import Column, Integer, String, create_engine
- from sqlalchemy.orm import declarative_base, sessionmaker  # SQLAlchemy 1.4+
-
- # Assumption: the application and its get_db dependency live in your project;
- # "myapp.main" is a hypothetical module path, adjust it as needed
- from myapp.main import app, get_db
-
- # Create the test database
- SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
- engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
- TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
- Base = declarative_base()
-
- class User(Base):
-     __tablename__ = "users"
-
-     id = Column(Integer, primary_key=True, index=True)
-     username = Column(String, unique=True, index=True)
-     email = Column(String, unique=True, index=True)
-     hashed_password = Column(String)
-
- # Dependency override: hand out sessions bound to the test database
- def override_get_db():
-     db = TestingSessionLocal()
-     try:
-         yield db
-     finally:
-         db.close()
-
- app.dependency_overrides[get_db] = override_get_db
- client = TestClient(app)
-
- @pytest.fixture(scope="module")
- def db():
-     # Create the tables before the module's tests and drop them afterwards
-     Base.metadata.create_all(bind=engine)
-     yield
-     Base.metadata.drop_all(bind=engine)
-
- def test_create_user(db):
-     response = client.post(
-         "/users/",
-         json={"username": "testuser", "email": "test@example.com", "password": "testpassword"},
-     )
-     assert response.status_code == 200
-     data = response.json()
-     assert data["username"] == "testuser"
-     assert data["email"] == "test@example.com"
-     assert "id" in data
-
- def test_read_user(db):
-     # First create a user
-     response = client.post(
-         "/users/",
-         json={"username": "testuser2", "email": "test2@example.com", "password": "testpassword"},
-     )
-     user_id = response.json()["id"]
-
-     # Then read it back
-     response = client.get(f"/users/{user_id}")
-     assert response.status_code == 200
-     data = response.json()
-     assert data["username"] == "testuser2"
-     assert data["email"] == "test2@example.com"
-     assert data["id"] == user_id
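Performance Testing
Use Locust for load testing:
- import random
- from locust import HttpUser, task, between
- class FastAPIUser(HttpUser):
-     # Each simulated user waits 1 to 2.5 seconds between tasks
-     wait_time = between(1, 2.5)
-
-     @task
-     def read_root(self):
-         self.client.get("/")
-
-     @task(3)
-     def read_item(self):
-         for item_id in range(1, 10):
-             # `name` groups all item URLs under one entry in the statistics
-             self.client.get(f"/items/{item_id}", name="/items/[id]")
-
-     @task(2)
-     def create_item(self):
-         # HttpUser has no `user_id` attribute, so use a random suffix instead
-         self.client.post(
-             "/items/",
-             json={"name": f"Test Item {random.randint(1, 10000)}", "description": "A test item"},
-         )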
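Locust picks the next task at random in proportion to the task weights. As a rough sketch (the helper `task_shares` is a hypothetical name, not part of Locust), the weights 1, 3 and 2 used in this section translate into the following expected shares of task picks:

```python
from fractions import Fraction


def task_shares(weights: dict[str, int]) -> dict[str, Fraction]:
    """Return the expected fraction of task picks for each weighted task."""
    total = sum(weights.values())
    return {name: Fraction(weight, total) for name, weight in weights.items()}


# Weights as declared with @task, @task(3) and @task(2); a bare @task counts as 1.
shares = task_shares({"read_root": 1, "read_item": 3, "create_item": 2})
for name, share in shares.items():
    print(f"{name}: {share} of task picks")
```

Note that these are shares of task executions, not of HTTP requests: a task that loops over several URLs issues several requests each time it is picked.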
Performance Optimization
Asynchronous Database Operations
An asynchronous database driver lets the event loop serve other requests while a query is in flight, which improves throughput for I/O-bound endpoints:
- from fastapi import FastAPI, HTTPException
- from databases import Database
- import sqlalchemy
- app = FastAPI()
- # Database configuration
- DATABASE_URL = "postgresql://user:password@localhost/dbname"
- database = Database(DATABASE_URL)
- # SQLAlchemy table definition
- metadata = sqlalchemy.MetaData()
- users = sqlalchemy.Table(
-     "users",
-     metadata,
-     sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
-     sqlalchemy.Column("name", sqlalchemy.String),
-     sqlalchemy.Column("email", sqlalchemy.String),
- )
- # Connect on startup and disconnect on shutdown
- @app.on_event("startup")
- async def startup():
-     await database.connect()
- @app.on_event("shutdown")
- async def shutdown():
-     await database.disconnect()
- @app.get("/users/{user_id}")
- async def read_user(user_id: int):
-     query = users.select().where(users.c.id == user_id)
-     row = await database.fetch_one(query)
-     # fetch_one returns None when no row matches
-     if row is None:
-         raise HTTPException(status_code=404, detail="User not found")
-     return row
- @app.post("/users/")
- async def create_user(name: str, email: str):
-     query = users.insert().values(name=name, email=email)
-     last_record_id = await database.execute(query)
-     return {"id": last_record_id, "name": name, "email": email}
Caching
Caching reduces database queries and shortens response times:
- import asyncio
- from fastapi import FastAPI
- from fastapi_cache import FastAPICache
- from fastapi_cache.backends.redis import RedisBackend
- from fastapi_cache.decorator import cache
- from redis import asyncio as aioredis
- app = FastAPI()
- @app.on_event("startup")
- async def startup():
-     redis = aioredis.from_url("redis://localhost")
-     FastAPICache.init(RedisBackend(redis), prefix="fastapi-cache")
- @app.get("/expensive-operation")
- @cache(expire=60)  # Cache the response for 60 seconds
- async def expensive_operation():
-     # Simulate a slow operation; time.sleep() here would block the event loop
-     await asyncio.sleep(2)
-     return {"result": "This is an expensive operation result"}
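To make the idea behind an expiring cache concrete, here is a minimal in-memory sketch using only the standard library. It is not how fastapi-cache is implemented; `ttl_cache` and the `calls` counter are hypothetical names, and the sketch only handles coroutine functions without arguments:

```python
import asyncio
import time
from functools import wraps


def ttl_cache(expire: float):
    """Cache the result of a zero-argument coroutine function for `expire` seconds."""
    def decorator(func):
        cached = {}  # holds "value" and "expires_at" once populated

        @wraps(func)
        async def wrapper():
            now = time.monotonic()
            if cached and now < cached["expires_at"]:
                return cached["value"]  # still fresh: skip the expensive call
            value = await func()
            cached["value"] = value
            cached["expires_at"] = now + expire
            return value

        return wrapper

    return decorator


calls = {"count": 0}  # counts how often the real work runs


@ttl_cache(expire=60)
async def expensive_operation():
    calls["count"] += 1
    await asyncio.sleep(0.01)  # stands in for slow work
    return {"result": "computed"}


async def demo():
    first = await expensive_operation()
    second = await expensive_operation()  # served from the cache
    return first, second


first, second = asyncio.run(demo())
```

A shared backend such as Redis matters once the application runs in several worker processes, because an in-process dictionary like this one is not visible to the other workers.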
Connection Pool Management
A connection pool reuses open connections instead of opening a new one for every request:
- from fastapi import FastAPI, HTTPException
- from databases import Database
- import sqlalchemy
- app = FastAPI()
- # Database configuration with pool limits.
- # With the PostgreSQL backend, `databases` takes min_size and max_size;
- # pool_size and max_overflow are SQLAlchemy engine options, and
- # force_rollback=True is meant for tests because it rolls back every change.
- DATABASE_URL = "postgresql://user:password@localhost/dbname"
- database = Database(DATABASE_URL, min_size=5, max_size=20)
- # SQLAlchemy table definition
- metadata = sqlalchemy.MetaData()
- users = sqlalchemy.Table(
-     "users",
-     metadata,
-     sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
-     sqlalchemy.Column("name", sqlalchemy.String),
-     sqlalchemy.Column("email", sqlalchemy.String),
- )
- @app.on_event("startup")
- async def startup():
-     await database.connect()
- @app.on_event("shutdown")
- async def shutdown():
-     await database.disconnect()
- @app.get("/users/{user_id}")
- async def read_user(user_id: int):
-     query = users.select().where(users.c.id == user_id)
-     row = await database.fetch_one(query)
-     if row is None:
-         raise HTTPException(status_code=404, detail="User not found")
-     return row
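The effect of an upper pool limit can be illustrated without a database. The sketch below is a toy model, not a real pool: `TinyPool` and `simulate` are hypothetical names, and a semaphore stands in for the set of available connections:

```python
import asyncio


class TinyPool:
    """Toy pool: at most `max_size` callers hold a 'connection' at the same time."""

    def __init__(self, max_size: int):
        self._semaphore = asyncio.Semaphore(max_size)
        self.in_use = 0
        self.peak = 0  # highest number of simultaneous holders observed

    async def run_query(self, delay: float) -> None:
        async with self._semaphore:  # waits here when the pool is exhausted
            self.in_use += 1
            self.peak = max(self.peak, self.in_use)
            try:
                await asyncio.sleep(delay)  # stands in for a database query
            finally:
                self.in_use -= 1


async def simulate(requests: int, max_size: int) -> int:
    pool = TinyPool(max_size)
    await asyncio.gather(*(pool.run_query(0.01) for _ in range(requests)))
    return pool.peak


peak = asyncio.run(simulate(requests=20, max_size=5))
print("peak concurrent connections:", peak)
```

Twenty concurrent requests never hold more than five connections here; the rest wait their turn, which is what protects the database from being flooded.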
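Response Compression
Enabling response compression reduces the amount of data sent over the network:
- from fastapi import FastAPI
- from fastapi.middleware.gzip import GZipMiddleware
- app = FastAPI()
- # Add the GZip middleware; responses smaller than 1000 bytes are sent uncompressed
- app.add_middleware(GZipMiddleware, minimum_size=1000)
- @app.get("/large-data")
- async def get_large_data():
-     # Return a large payload
-     return {"data": "x" * 10000}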
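To see why a size threshold makes sense, the following standard-library sketch applies the same rule by hand. `maybe_compress` is a hypothetical helper, not the middleware's code; the real middleware also only compresses when the client sends an `Accept-Encoding` header that includes gzip:

```python
import gzip
import json

MINIMUM_SIZE = 1000  # same threshold as in the middleware example


def maybe_compress(body: bytes, minimum_size: int = MINIMUM_SIZE) -> tuple[bytes, bool]:
    """Compress `body` only when it is at least `minimum_size` bytes long."""
    if len(body) < minimum_size:
        return body, False  # small bodies gain little and still cost CPU time
    return gzip.compress(body), True


small = json.dumps({"data": "x" * 10}).encode()
large = json.dumps({"data": "x" * 10000}).encode()

small_out, small_compressed = maybe_compress(small)
large_out, large_compressed = maybe_compress(large)
print(len(large), "->", len(large_out), "bytes")
```

Highly repetitive data like this compresses extremely well; real JSON responses usually shrink less, so it is worth measuring with representative payloads.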
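Serving Static Files
Mounting a static directory lets the application serve assets such as images and stylesheets directly; in production, a dedicated static file server or reverse proxy in front of the application can take this load off the application server:
- from fastapi import FastAPI
- from fastapi.staticfiles import StaticFiles
- from fastapi.responses import HTMLResponse
- app = FastAPI()
- # Mount the static file directory
- app.mount("/static", StaticFiles(directory="static"), name="static")
- @app.get("/")
- async def main():
-     return HTMLResponse(content="""
-     <html>
-         <head>
-             <title>FastAPI Static Files</title>
-         </head>
-         <body>
-             <h1>Hello, World!</h1>
-             <img src="/static/logo.png" alt="Logo">
-         </body>
-     </html>
-     """)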
Using a CDN
Hosting static assets on a CDN can noticeably shorten load times for users far from the origin server:
- from fastapi import FastAPI, Request
- from fastapi.responses import HTMLResponse
- from fastapi.templating import Jinja2Templates
- app = FastAPI()
- templates = Jinja2Templates(directory="templates")
- @app.get("/", response_class=HTMLResponse)
- async def read_item(request: Request):
-     # Pass the CDN base URL to the template (placeholder value)
-     return templates.TemplateResponse("index.html", {
-         "request": request,
-         "cdn_url": "https://your-cdn-url.com"
-     })
In the template file:
- <!DOCTYPE html>
- <html>
-     <head>
-         <title>FastAPI CDN Example</title>
-         <link rel="stylesheet" href="{{ cdn_url }}/css/style.css">
-     </head>
-     <body>
-         <h1>Hello, World!</h1>
-         <script src="{{ cdn_url }}/js/main.js"></script>
-     </body>
- </html>
Database Query Optimization
Reducing the number of queries per request is often the most effective database optimization:
- from fastapi import FastAPI
- from databases import Database
- import sqlalchemy
- app = FastAPI()
- DATABASE_URL = "postgresql://user:password@localhost/dbname"
- database = Database(DATABASE_URL)
- metadata = sqlalchemy.MetaData()
- # Table definitions
- users = sqlalchemy.Table(
-     "users",
-     metadata,
-     sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
-     sqlalchemy.Column("name", sqlalchemy.String),
-     sqlalchemy.Column("email", sqlalchemy.String),
- )
- posts = sqlalchemy.Table(
-     "posts",
-     metadata,
-     sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
-     sqlalchemy.Column("title", sqlalchemy.String),
-     sqlalchemy.Column("content", sqlalchemy.String),
-     sqlalchemy.Column("user_id", sqlalchemy.Integer, sqlalchemy.ForeignKey("users.id")),
- )
- @app.on_event("startup")
- async def startup():
-     await database.connect()
- @app.on_event("shutdown")
- async def shutdown():
-     await database.disconnect()
- # Inefficient: the N+1 problem (one query for the users, then one per user)
- @app.get("/users-with-posts-inefficient")
- async def get_users_with_posts_inefficient():
-     users_query = users.select()
-     users_list = await database.fetch_all(users_query)
-
-     result = []
-     for user in users_list:
-         posts_query = posts.select().where(posts.c.user_id == user["id"])
-         user_posts = await database.fetch_all(posts_query)
-         result.append({
-             "user": user,
-             "posts": user_posts
-         })
-
-     return result
- # Efficient: a single JOIN
- @app.get("/users-with-posts-efficient")
- async def get_users_with_posts_efficient():
-     # Both tables have an "id" column, so label the columns explicitly.
-     # Passing columns positionally works on SQLAlchemy 1.4 and 2.0;
-     # the list form select([...]) was removed in 2.0.
-     query = sqlalchemy.select(
-         users.c.id.label("user_id"),
-         users.c.name,
-         users.c.email,
-         posts.c.id.label("post_id"),
-         posts.c.title,
-         posts.c.content,
-     ).select_from(
-         # An outer join keeps users without posts, matching the endpoint above
-         users.outerjoin(posts, users.c.id == posts.c.user_id)
-     )
-
-     results = await database.fetch_all(query)
-
-     # Group the flat rows by user
-     users_dict = {}
-     for row in results:
-         user_id = row["user_id"]
-         if user_id not in users_dict:
-             users_dict[user_id] = {
-                 "user": {
-                     "id": user_id,
-                     "name": row["name"],
-                     "email": row["email"]
-                 },
-                 "posts": []
-             }
-
-         # post_id is NULL for users that have no posts
-         if row["post_id"] is not None:
-             users_dict[user_id]["posts"].append({
-                 "id": row["post_id"],
-                 "title": row["title"],
-                 "content": row["content"]
-             })
-
-     return list(users_dict.values())
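The difference between the two access patterns can be reproduced with an in-memory SQLite database and the standard library alone. This is a simplified sketch with made-up sample data; `setup`, `n_plus_one` and `single_join` are hypothetical helper names. Each function returns its result together with the number of queries it issued:

```python
import sqlite3


def setup() -> sqlite3.Connection:
    """Create an in-memory database with three users and three posts."""
    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);
        CREATE TABLE posts (
            id INTEGER PRIMARY KEY,
            title TEXT,
            user_id INTEGER REFERENCES users(id)
        );
        INSERT INTO users (id, name) VALUES (1, 'alice'), (2, 'bob'), (3, 'carol');
        INSERT INTO posts (title, user_id) VALUES ('a1', 1), ('a2', 1), ('b1', 2);
    """)
    return conn


def n_plus_one(conn: sqlite3.Connection):
    """One query for the users, then one more query per user."""
    users = conn.execute("SELECT id, name FROM users ORDER BY id").fetchall()
    queries = 1
    result = {}
    for user_id, name in users:
        rows = conn.execute(
            "SELECT title FROM posts WHERE user_id = ? ORDER BY id", (user_id,)
        ).fetchall()
        queries += 1
        result[name] = [title for (title,) in rows]
    return result, queries


def single_join(conn: sqlite3.Connection):
    """One LEFT JOIN query; rows are grouped by user in Python."""
    rows = conn.execute("""
        SELECT u.name, p.title
        FROM users AS u
        LEFT JOIN posts AS p ON p.user_id = u.id
        ORDER BY u.id, p.id
    """).fetchall()
    result = {}
    for name, title in rows:
        titles = result.setdefault(name, [])
        if title is not None:  # NULL title means the user has no posts
            titles.append(title)
    return result, 1


conn = setup()
slow_result, slow_queries = n_plus_one(conn)
fast_result, fast_queries = single_join(conn)
print(slow_queries, "queries vs", fast_queries)
```

Both functions produce the same mapping, but the query count of the first grows with the number of users while the second stays at one.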
Batch Operations
Batching writes reduces round trips to the database:
- from fastapi import FastAPI, HTTPException
- from databases import Database
- import sqlalchemy
- from pydantic import BaseModel
- from typing import List
- app = FastAPI()
- DATABASE_URL = "postgresql://user:password@localhost/dbname"
- database = Database(DATABASE_URL)
- metadata = sqlalchemy.MetaData()
- users = sqlalchemy.Table(
-     "users",
-     metadata,
-     sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
-     sqlalchemy.Column("name", sqlalchemy.String),
-     sqlalchemy.Column("email", sqlalchemy.String),
- )
- @app.on_event("startup")
- async def startup():
-     await database.connect()
- @app.on_event("shutdown")
- async def shutdown():
-     await database.disconnect()
- class UserCreate(BaseModel):
-     name: str
-     email: str
- class User(BaseModel):
-     id: int
-     name: str
-     email: str
-
-     class Config:
-         orm_mode = True
- # Create one user per request: inefficient for bulk imports
- @app.post("/users/single", response_model=User)
- async def create_user(user: UserCreate):
-     query = users.insert().values(name=user.name, email=user.email)
-     user_id = await database.execute(query)
-     return {**user.dict(), "id": user_id}
- # Create many users in one request: efficient
- @app.post("/users/batch", response_model=List[User])
- async def create_users_batch(users_data: List[UserCreate]):
-     if not users_data:
-         return []
-     # Prepare the rows for the bulk insert
-     values = [{"name": user.name, "email": user.email} for user in users_data]
-     emails = [user.email for user in users_data]
-
-     async with database.transaction():
-         # Run the bulk insert
-         await database.execute_many(query=users.insert(), values=values)
-         # execute_many does not return the generated IDs, so read the rows back
-         # instead of guessing them. This assumes emails identify the new rows.
-         rows = await database.fetch_all(users.select().where(users.c.email.in_(emails)))
-
-     return [{"id": row["id"], "name": row["name"], "email": row["email"]} for row in rows]
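Besides saving round trips, a batch that runs inside one transaction either succeeds completely or leaves the table untouched. The following sketch shows this with the standard library's `sqlite3` module; the `UNIQUE` constraint on `email` and the helper `insert_batch` are assumptions made for the illustration:

```python
import sqlite3


def insert_batch(conn: sqlite3.Connection, rows) -> int:
    """Insert all rows in one transaction and return the number of inserted rows."""
    with conn:  # commits on success, rolls back if any row fails
        cursor = conn.executemany(
            "INSERT INTO users (name, email) VALUES (?, ?)", rows
        )
    return cursor.rowcount


def count_users(conn: sqlite3.Connection) -> int:
    return conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, email TEXT UNIQUE)"
)

inserted = insert_batch(conn, [
    ("alice", "alice@example.com"),
    ("bob", "bob@example.com"),
    ("carol", "carol@example.com"),
])

# The second row repeats an existing email, so the whole batch is rolled back.
try:
    insert_batch(conn, [("dave", "dave@example.com"), ("bob2", "bob@example.com")])
    batch_failed = False
except sqlite3.IntegrityError:
    batch_failed = True

total = count_users(conn)
```

After the failed batch the table still holds only the three rows from the first call, so callers never have to clean up a half-applied import.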
Practical Examples
Building a RESTful API
Let's build a complete blog API that brings together several of FastAPI's advanced features:
- from fastapi import FastAPI, Depends, HTTPException, status
- from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel, EmailStr, validator
- from typing import List, Optional
- from datetime import datetime, timedelta
- import jwt
- from passlib.context import CryptContext
- import sqlalchemy
- from sqlalchemy import Column, Integer, String, Text, ForeignKey, DateTime, Boolean
- from sqlalchemy.ext.declarative import declarative_base
- from sqlalchemy.orm import relationship, sessionmaker, Session
- import uvicorn
- app = FastAPI(title="Blog API", description="A complete blog API built with FastAPI")
- # Add CORS middleware (restrict allow_origins to known sites in production)
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
- # Database configuration
- DATABASE_URL = "sqlite:///./blog.db"
- # SQLAlchemy setup
- Base = declarative_base()
- class User(Base):
-     __tablename__ = "users"
-
-     id = Column(Integer, primary_key=True, index=True)
-     username = Column(String, unique=True, index=True)
-     email = Column(String, unique=True, index=True)
-     hashed_password = Column(String)
-     is_active = Column(Boolean, default=True)
-     posts = relationship("Post", back_populates="author")
- class Post(Base):
-     __tablename__ = "posts"
-
-     id = Column(Integer, primary_key=True, index=True)
-     title = Column(String, index=True)
-     content = Column(Text)
-     author_id = Column(Integer, ForeignKey("users.id"))
-     created_at = Column(DateTime, default=datetime.utcnow)
-     updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
-     author = relationship("User", back_populates="posts")
- # Create the engine, the tables, and the session factory used by get_db
- engine = sqlalchemy.create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
- Base.metadata.create_all(bind=engine)
- SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
- # Password hashing context
- pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
- # JWT settings (load the secret from the environment in a real deployment)
- SECRET_KEY = "your-secret-key-here"
- ALGORITHM = "HS256"
- ACCESS_TOKEN_EXPIRE_MINUTES = 30
- # OAuth2 settings
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
- # Pydantic models
- class UserBase(BaseModel):
-     username: str
-     email: EmailStr
- class UserCreate(UserBase):
-     password: str
-
-     @validator('password')
-     def validate_password(cls, v):
-         if len(v) < 8:
-             raise ValueError('Password must be at least 8 characters')
-         return v
- class UserResponse(UserBase):
-     id: int
-     is_active: bool
-
-     class Config:
-         orm_mode = True
- class Token(BaseModel):
-     access_token: str
-     token_type: str
- class TokenData(BaseModel):
-     username: Optional[str] = None
- class PostBase(BaseModel):
-     title: str
-     content: str
- class PostCreate(PostBase):
-     pass
- class PostResponse(PostBase):
-     id: int
-     author_id: int
-     created_at: datetime
-     updated_at: datetime
-
-     class Config:
-         orm_mode = True
- class PostWithAuthor(PostResponse):
-     author: UserResponse
-
-     class Config:
-         orm_mode = True
- # Helper functions
- def verify_password(plain_password, hashed_password):
-     return pwd_context.verify(plain_password, hashed_password)
- def get_password_hash(password):
-     return pwd_context.hash(password)
- def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
-     to_encode = data.copy()
-     if expires_delta:
-         expire = datetime.utcnow() + expires_delta
-     else:
-         expire = datetime.utcnow() + timedelta(minutes=15)
-     to_encode.update({"exp": expire})
-     encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
-     return encoded_jwt
- # Dependencies
- # Everything that touches the synchronous SQLAlchemy session is declared with
- # plain `def`, so FastAPI runs it in a thread pool instead of blocking the event loop.
- def get_db():
-     db = SessionLocal()
-     try:
-         yield db
-     finally:
-         db.close()
- def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
-     credentials_exception = HTTPException(
-         status_code=status.HTTP_401_UNAUTHORIZED,
-         detail="Could not validate credentials",
-         headers={"WWW-Authenticate": "Bearer"},
-     )
-     try:
-         payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
-         username: str = payload.get("sub")
-         if username is None:
-             raise credentials_exception
-         token_data = TokenData(username=username)
-     except jwt.PyJWTError:
-         raise credentials_exception
-
-     user = db.query(User).filter(User.username == token_data.username).first()
-     if user is None:
-         raise credentials_exception
-     return user
- def get_current_active_user(current_user: User = Depends(get_current_user)):
-     if not current_user.is_active:
-         raise HTTPException(status_code=400, detail="Inactive user")
-     return current_user
- # Routes
- @app.post("/token", response_model=Token)
- def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
-     user = db.query(User).filter(User.username == form_data.username).first()
-     if not user or not verify_password(form_data.password, user.hashed_password):
-         raise HTTPException(
-             status_code=status.HTTP_401_UNAUTHORIZED,
-             detail="Incorrect username or password",
-             headers={"WWW-Authenticate": "Bearer"},
-         )
-     access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
-     access_token = create_access_token(
-         data={"sub": user.username}, expires_delta=access_token_expires
-     )
-     return {"access_token": access_token, "token_type": "bearer"}
- @app.post("/users/", response_model=UserResponse)
- def create_user(user: UserCreate, db: Session = Depends(get_db)):
-     # Both columns are unique, so check each before inserting
-     if db.query(User).filter(User.email == user.email).first():
-         raise HTTPException(status_code=400, detail="Email already registered")
-     if db.query(User).filter(User.username == user.username).first():
-         raise HTTPException(status_code=400, detail="Username already taken")
-
-     hashed_password = get_password_hash(user.password)
-     db_user = User(
-         username=user.username,
-         email=user.email,
-         hashed_password=hashed_password
-     )
-     db.add(db_user)
-     db.commit()
-     db.refresh(db_user)
-     return db_user
- @app.get("/users/me", response_model=UserResponse)
- def read_users_me(current_user: User = Depends(get_current_active_user)):
-     return current_user
- @app.post("/posts/", response_model=PostResponse)
- def create_post(
-     post: PostCreate,
-     current_user: User = Depends(get_current_active_user),
-     db: Session = Depends(get_db)
- ):
-     db_post = Post(
-         title=post.title,
-         content=post.content,
-         author_id=current_user.id
-     )
-     db.add(db_post)
-     db.commit()
-     db.refresh(db_post)
-     return db_post
- @app.get("/posts/", response_model=List[PostWithAuthor])
- def read_posts(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
-     posts = db.query(Post).offset(skip).limit(limit).all()
-     return posts
- @app.get("/posts/{post_id}", response_model=PostWithAuthor)
- def read_post(post_id: int, db: Session = Depends(get_db)):
-     post = db.query(Post).filter(Post.id == post_id).first()
-     if post is None:
-         raise HTTPException(status_code=404, detail="Post not found")
-     return post
- @app.put("/posts/{post_id}", response_model=PostResponse)
- def update_post(
-     post_id: int,
-     post: PostCreate,
-     current_user: User = Depends(get_current_active_user),
-     db: Session = Depends(get_db)
- ):
-     db_post = db.query(Post).filter(Post.id == post_id).first()
-     if db_post is None:
-         raise HTTPException(status_code=404, detail="Post not found")
-
-     # Only the author may modify a post
-     if db_post.author_id != current_user.id:
-         raise HTTPException(status_code=403, detail="Not authorized to update this post")
-
-     db_post.title = post.title
-     db_post.content = post.content
-     db_post.updated_at = datetime.utcnow()
-
-     db.commit()
-     db.refresh(db_post)
-     return db_post
- @app.delete("/posts/{post_id}", status_code=status.HTTP_204_NO_CONTENT)
- def delete_post(
-     post_id: int,
-     current_user: User = Depends(get_current_active_user),
-     db: Session = Depends(get_db)
- ):
-     db_post = db.query(Post).filter(Post.id == post_id).first()
-     if db_post is None:
-         raise HTTPException(status_code=404, detail="Post not found")
-
-     # Only the author may delete a post
-     if db_post.author_id != current_user.id:
-         raise HTTPException(status_code=403, detail="Not authorized to delete this post")
-
-     db.delete(db_post)
-     db.commit()
-     return None
- if __name__ == "__main__":
-     uvicorn.run(app, host="0.0.0.0", port=8000)
Building a GraphQL API
FastAPI can also be combined with GraphQL, here through the Strawberry library:
- from fastapi import FastAPI
- from fastapi.middleware.cors import CORSMiddleware
- import sqlalchemy
- from sqlalchemy import Column, Integer, String, Text, ForeignKey, DateTime
- from sqlalchemy.ext.declarative import declarative_base
- from sqlalchemy.orm import relationship
- from datetime import datetime
- import strawberry
- from strawberry.fastapi import GraphQLRouter
- app = FastAPI(title="Blog GraphQL API")
- # Add CORS middleware (restrict allow_origins to known sites in production)
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
- # Database configuration
- DATABASE_URL = "sqlite:///./blog_graphql.db"
- # SQLAlchemy setup
- Base = declarative_base()
- class User(Base):
-     __tablename__ = "users"
-
-     id = Column(Integer, primary_key=True, index=True)
-     username = Column(String, unique=True, index=True)
-     email = Column(String, unique=True, index=True)
-     posts = relationship("Post", back_populates="author")
- class Post(Base):
-     __tablename__ = "posts"
-
-     id = Column(Integer, primary_key=True, index=True)
-     title = Column(String, index=True)
-     content = Column(Text)
-     author_id = Column(Integer, ForeignKey("users.id"))
-     created_at = Column(DateTime, default=datetime.utcnow)
-     author = relationship("User", back_populates="posts")
- # Create the database engine and tables
- engine = sqlalchemy.create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
- Base.metadata.create_all(bind=engine)
- # Strawberry type definitions
- @strawberry.type
- class UserType:
-     id: int
-     username: str
-     email: str
- @strawberry.type
- class PostType:
-     id: int
-     title: str
-     content: str
-     created_at: datetime
-     author: UserType
- @strawberry.type
- class Query:
-     @strawberry.field
-     async def posts(self) -> list[PostType]:
-         # Hard-coded sample data; a real application would query the database here
-         return [
-             PostType(
-                 id=1,
-                 title="First Post",
-                 content="This is the first post",
-                 created_at=datetime.now(),
-                 author=UserType(id=1, username="john", email="john@example.com")
-             ),
-             PostType(
-                 id=2,
-                 title="Second Post",
-                 content="This is the second post",
-                 created_at=datetime.now(),
-                 author=UserType(id=1, username="john", email="john@example.com")
-             )
-         ]
-
-     @strawberry.field
-     async def post(self, id: int) -> PostType:
-         # Hard-coded sample data; a real application would query the database here
-         if id == 1:
-             return PostType(
-                 id=1,
-                 title="First Post",
-                 content="This is the first post",
-                 created_at=datetime.now(),
-                 author=UserType(id=1, username="john", email="john@example.com")
-             )
-         else:
-             # The exception is reported to the client as a GraphQL error
-             raise ValueError("Post not found")
- @strawberry.type
- class Mutation:
-     @strawberry.mutation
-     async def create_post(self, title: str, content: str, author_id: int) -> PostType:
-         # A real application would save the post to the database here
-         return PostType(
-             id=3,
-             title=title,
-             content=content,
-             created_at=datetime.now(),
-             author=UserType(id=author_id, username="john", email="john@example.com")
-         )
- # Create the GraphQL schema
- schema = strawberry.Schema(query=Query, mutation=Mutation)
- # Create the GraphQL router
- graphql_app = GraphQLRouter(schema)
- # Register the GraphQL router
- app.include_router(graphql_app, prefix="/graphql")
- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=8000)
Building a WebSocket Chat Application
A simple WebSocket chat application built with FastAPI:
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
- from typing import List
- import json
- from datetime import datetime
- app = FastAPI(title="Chat Application")
- class ConnectionManager:
-     def __init__(self):
-         self.active_connections: List[WebSocket] = []
-     async def connect(self, websocket: WebSocket):
-         await websocket.accept()
-         self.active_connections.append(websocket)
-     def disconnect(self, websocket: WebSocket):
-         # Tolerate repeated calls for the same connection
-         if websocket in self.active_connections:
-             self.active_connections.remove(websocket)
-     async def send_personal_message(self, message: str, websocket: WebSocket):
-         await websocket.send_text(message)
-     async def broadcast(self, message: str):
-         # Iterate over a copy so dead connections can be removed during the loop
-         for connection in list(self.active_connections):
-             try:
-                 await connection.send_text(message)
-             except Exception:
-                 # The connection has probably closed; stop tracking it
-                 self.disconnect(connection)
- manager = ConnectionManager()
- @app.websocket("/ws/{client_id}")
- async def websocket_endpoint(websocket: WebSocket, client_id: int):
-     await manager.connect(websocket)
-     try:
-         while True:
-             data = await websocket.receive_text()
-             try:
-                 message_data = json.loads(data)
-                 if not isinstance(message_data, dict):
-                     raise ValueError("message must be a JSON object")
-             except ValueError:  # json.JSONDecodeError is a subclass of ValueError
-                 await manager.send_personal_message(json.dumps({"error": "Invalid message"}), websocket)
-                 continue
-
-             # Add a timestamp
-             message_data["timestamp"] = datetime.now().isoformat()
-
-             # Broadcast the message to every connected client
-             await manager.broadcast(json.dumps(message_data))
-     except WebSocketDisconnect:
-         manager.disconnect(websocket)
-         await manager.broadcast(json.dumps({
-             "sender": "System",
-             "message": f"Client #{client_id} left the chat",
-             "timestamp": datetime.now().isoformat()
-         }))
- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=8000)
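Broadcast logic is easy to exercise without a running server if the connection objects are replaced by stand-ins. The sketch below is a simplified model rather than the application's code: `FakeConnection` and `Broadcaster` are hypothetical classes, and a `ConnectionError` simulates a client that vanished without a clean disconnect:

```python
import asyncio


class FakeConnection:
    """Stand-in for a WebSocket; alive=False simulates a dropped client."""

    def __init__(self, alive: bool = True):
        self.alive = alive
        self.received: list[str] = []

    async def send_text(self, message: str) -> None:
        if not self.alive:
            raise ConnectionError("client went away")
        self.received.append(message)


class Broadcaster:
    def __init__(self):
        self.active_connections: list[FakeConnection] = []

    async def broadcast(self, message: str) -> int:
        """Send to every connection, drop the ones that fail, return the delivery count."""
        delivered = 0
        for connection in list(self.active_connections):  # copy: we may remove items
            try:
                await connection.send_text(message)
                delivered += 1
            except ConnectionError:
                self.active_connections.remove(connection)
        return delivered


async def demo():
    broadcaster = Broadcaster()
    alice, bob, ghost = FakeConnection(), FakeConnection(), FakeConnection(alive=False)
    broadcaster.active_connections.extend([alice, ghost, bob])
    delivered = await broadcaster.broadcast("hello")
    return broadcaster, alice, bob, ghost, delivered


broadcaster, alice, bob, ghost, delivered = asyncio.run(demo())
```

Pruning failed connections during the broadcast keeps the list from accumulating dead entries, and one failing client does not prevent delivery to the others.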
Building a File Upload Service
A file upload service built with FastAPI:
- from fastapi import FastAPI, File, UploadFile, HTTPException, status
- from fastapi.responses import HTMLResponse, FileResponse
- from fastapi.staticfiles import StaticFiles
- import os
- import uuid
- from typing import List
- import aiofiles
- app = FastAPI(title="File Upload Service")
- # Create the upload directory
- UPLOAD_DIR = "uploads"
- os.makedirs(UPLOAD_DIR, exist_ok=True)
- # Mount the upload directory as static files
- app.mount("/static", StaticFiles(directory=UPLOAD_DIR), name="static")
- @app.get("/", response_class=HTMLResponse)
- async def main():
-     return """
-     <!DOCTYPE html>
-     <html>
-         <head>
-             <title>File Upload</title>
-         </head>
-         <body>
-             <h1>Upload a file</h1>
-             <form action="/uploadfiles/" enctype="multipart/form-data" method="post">
-                 <input name="files" type="file" multiple>
-                 <input type="submit">
-             </form>
-         </body>
-     </html>
-     """
- @app.post("/uploadfiles/")
- async def create_upload_files(files: List[UploadFile] = File(...)):
-     uploaded_files = []
-
-     for file in files:
-         # Generate a unique file name; splitext returns "" when there is no extension
-         file_extension = os.path.splitext(file.filename or "")[1]
-         unique_filename = f"{uuid.uuid4()}{file_extension}"
-         file_path = os.path.join(UPLOAD_DIR, unique_filename)
-
-         # Save asynchronously in 1 MiB chunks so large uploads are not held in memory
-         async with aiofiles.open(file_path, 'wb') as f:
-             while chunk := await file.read(1024 * 1024):
-                 await f.write(chunk)
-
-         uploaded_files.append({
-             "original_filename": file.filename,
-             "stored_filename": unique_filename,
-             "content_type": file.content_type,
-             "file_path": file_path
-         })
-
-     return {"uploaded_files": uploaded_files}
- @app.get("/files/{file_name}")
- async def download_file(file_name: str):
-     # Reject names that contain directory components (path traversal)
-     if os.path.basename(file_name) != file_name:
-         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
-     file_path = os.path.join(UPLOAD_DIR, file_name)
-
-     if not os.path.isfile(file_path):
-         raise HTTPException(
-             status_code=status.HTTP_404_NOT_FOUND,
-             detail="File not found"
-         )
-
-     return FileResponse(
-         path=file_path,
-         filename=file_name,
-         media_type='application/octet-stream'
-     )
- @app.get("/files/")
- async def list_files():
-     files = []
-     for file_name in os.listdir(UPLOAD_DIR):
-         file_path = os.path.join(UPLOAD_DIR, file_name)
-         if os.path.isfile(file_path):
-             files.append({
-                 "filename": file_name,
-                 "size": os.path.getsize(file_path),
-                 "url": f"/static/{file_name}"
-             })
-
-     return {"files": files}
- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=8000)
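File names supplied by clients deserve some suspicion, both when storing and when serving files. The following standard-library sketch shows one possible way to handle them; `stored_name` and `resolve_download` are hypothetical helpers, and the `uploads` directory mirrors the one used in this section:

```python
import os
import uuid
from pathlib import Path

UPLOAD_DIR = Path("uploads")


def stored_name(original_filename: str) -> str:
    """Build a unique name that keeps only the original extension, if any."""
    extension = os.path.splitext(os.path.basename(original_filename))[1].lower()
    return f"{uuid.uuid4().hex}{extension}"


def resolve_download(file_name: str, base_dir: Path = UPLOAD_DIR) -> Path:
    """Return the path of `file_name` directly inside `base_dir`, or raise ValueError."""
    base = base_dir.resolve()
    candidate = (base / file_name).resolve()
    # After resolving "..", the file must still sit directly in the upload directory
    if candidate.parent != base:
        raise ValueError("invalid file name")
    return candidate


print(stored_name("holiday photo.PNG"))
print(resolve_download("report.pdf").name)
```

A name such as `../secret.txt` resolves to a location outside the upload directory and is rejected, while a plain name like `report.pdf` passes through unchanged.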
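Summary and Outlook
FastAPI is a modern, high-performance web framework. Its async support, dependency injection system, and automatic API documentation give Python developers a powerful toolkit for building fast APIs. This article covered FastAPI's advanced features: asynchronous processing, dependency injection, advanced routing and middleware, data validation and serialization, security and authentication, testing strategies, and performance optimization.
Summary of FastAPI's Strengths
1. High performance: built on Starlette and Pydantic; FastAPI's documentation describes its performance as on par with NodeJS and Go.
2. Fast development: automatic API documentation, type hints, and editor support speed up development considerably.
3. Fewer bugs: type hints and automatic data validation catch many errors before they reach runtime.
4. Intuitive: a concise API design and thorough documentation make FastAPI easy to learn and use.
5. Standards-based: fully compatible with OpenAPI and JSON Schema, which simplifies integration with other tools.
6. Async support: native support for asynchronous programming improves the performance of I/O-bound applications.
7. A powerful dependency injection system: it makes code more modular, testable, and maintainable.
Outlook
FastAPI is still evolving quickly and may continue to improve in the following areas:
1. Stronger async support: as Python's async ecosystem matures, FastAPI may offer more async-related features.
2. Richer middleware: more built-in middleware may simplify common functionality.
3. Better testing tools: more capable testing tools may make API testing easier.
4. Broader database integration: more integration options for database ORMs may appear.
5. More complete security features: more security-related functionality and best practices may be provided.
Best Practice Recommendations
When building applications with FastAPI, consider the following best practices:
1. Make full use of type hints: they improve editor support and let Pydantic validate data.
2. Use async sensibly: async can significantly improve performance for I/O-bound operations, while CPU-bound operations may be better handled by a task queue.
3. Organize code in modules: use APIRouter to group related routes and keep the code modular and maintainable.
4. Use dependency injection sensibly: it is a core FastAPI feature that makes code more modular and testable.
5. Write comprehensive tests: use TestClient and pytest to verify the stability and reliability of the API.
6. Take security seriously: protect the API with HTTPS, input validation, authentication, and authorization.
7. Monitoring and logging: set up appropriate monitoring and logging so problems are found and fixed promptly.
8. Optimize performance: use caching, connection pools, batch operations, and similar techniques.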
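The advice on using async sensibly can be made tangible with a small timing experiment. The sketch below uses only the standard library and assumes Python 3.9 or later for `asyncio.to_thread`; the helper names are hypothetical, and `time.sleep` stands in for any blocking call. Calling blocking code directly inside a coroutine serializes the work, while handing it to worker threads lets the calls overlap:

```python
import asyncio
import time


def blocking_work(seconds: float) -> None:
    time.sleep(seconds)  # stands in for any blocking call


async def run_blocking_in_loop(n: int, seconds: float) -> float:
    """Call the blocking function directly inside coroutines; the calls serialize."""
    async def handler() -> None:
        blocking_work(seconds)  # nothing else can run on the event loop meanwhile

    start = time.perf_counter()
    await asyncio.gather(*(handler() for _ in range(n)))
    return time.perf_counter() - start


async def run_in_threads(n: int, seconds: float) -> float:
    """Offload each blocking call to a worker thread; the calls overlap."""
    start = time.perf_counter()
    await asyncio.gather(*(asyncio.to_thread(blocking_work, seconds) for _ in range(n)))
    return time.perf_counter() - start


serialized = asyncio.run(run_blocking_in_loop(3, 0.1))
overlapped = asyncio.run(run_in_threads(3, 0.1))
print(f"blocking in the loop: {serialized:.2f}s, in threads: {overlapped:.2f}s")
```

Threads help with blocking I/O like this; genuinely CPU-bound work is limited by the interpreter lock, which is why a separate process or a task queue is the usual choice there.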
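By understanding and applying these advanced features, we can build web applications that are efficient, reliable, and easy to maintain. As FastAPI continues to develop and mature, it is likely to play an increasingly important role in Python web development.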
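Copyright Notice
1. Anyone reposting or quoting content from this site (In-Depth Exploration of FastAPI's Advanced Features: From Async Processing to Dependency Injection, a Full Look at a Modern Web Framework's Capabilities and Practical Techniques for Improving Development Efficiency and Building High-Performance APIs) must credit the original URL and the author (威震华夏关云长) and cite this site's address (https://pixtech.cc/).
2. This site accepts no liability for civil disputes, administrative actions, or other losses arising from improper reposting or quoting of its content.
3. This site reserves the right to pursue legal action against anyone who disregards this notice or otherwise uses its content unlawfully or maliciously.
Article URL: https://pixtech.cc/thread-41525-1-1.html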