diff --git a/__pycache__/dependencies.cpython-310.pyc b/__pycache__/dependencies.cpython-310.pyc index 6290831..1cdf2c5 100644 Binary files a/__pycache__/dependencies.cpython-310.pyc and b/__pycache__/dependencies.cpython-310.pyc differ diff --git a/__pycache__/main.cpython-310.pyc b/__pycache__/main.cpython-310.pyc index afa68d0..737b915 100644 Binary files a/__pycache__/main.cpython-310.pyc and b/__pycache__/main.cpython-310.pyc differ diff --git a/__pycache__/test.cpython-310.pyc b/__pycache__/test.cpython-310.pyc new file mode 100644 index 0000000..84329ce Binary files /dev/null and b/__pycache__/test.cpython-310.pyc differ diff --git a/dependencies.py b/dependencies.py index b798ddd..30c16da 100644 --- a/dependencies.py +++ b/dependencies.py @@ -1,74 +1,72 @@ -from datetime import datetime, timedelta - -from fastapi import HTTPException, Depends, status +from datetime import datetime, timedelta, timezone from jose import JWTError, jwt from passlib.context import CryptContext +from fastapi.security import OAuth2PasswordBearer +from fastapi import Depends,HTTPException,status +from internal.models import TokenData,UserInDB,User +from internal.database import execute_query -from internal.database import execute_query, create_connection - -# 设置密码上下文 -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - -# 随机秘钥 -SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" -# JWT签名算法变量 +# openssl rand -hex 32 +SECRET_KEY = "e86c54c19962d562dab09081e5a6ce0c8ef49ac9a49cdb7135aa670707bbc894" ALGORITHM = "HS256" -# 设置令牌过期时间 ACCESS_TOKEN_EXPIRE_MINUTES = 30 -# 校验密码方法 -def verify_password(plain_password, hashed_password): - return pwd_context.verify(plain_password, hashed_password) - -# 将密码进行hash -def get_password_hash(password): - return pwd_context.hash(password) - -# 认证用户 -def authenticate_user(username: str, password: str): - query = "SELECT * FROM users WHERE username = %s" - user_data = execute_query(query, (username,)) - if not user_data: - return None - stored_password = user_data["password"] - if not verify_password(password, stored_password): - return None - return user_data +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") # 创建访问令牌 -def create_access_token(data: dict, expires_delta: timedelta = None): +def create_access_token(data: dict, expires_delta: timedelta): to_encode = data.copy() - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + expire = datetime.now(timezone.utc) + expires_delta to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt -# 获取当前用户 -async def get_current_user(token: str = Depends(create_connection)): - # 首先定义一个 HTTPException 用于处理认证失败的情况 +async def get_current_user(token: str = Depends(oauth2_scheme)): credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) - try: - # 解码 JWT 令牌,验证签名和有效期 payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") if username is None: raise credentials_exception + token_data = TokenData(username=username) except JWTError: - # 如果 JWT 解码失败,则返回认证异常 raise credentials_exception - - # 根据用户名从数据库中获取用户数据 - user = authenticate_user(username) + user = get_user(username=token_data.username) if user is None: raise credentials_exception - - # 返回获取到的用户数据 + return user + +# 验证用户是否为活跃用户 +async def get_current_active_user(current_user: User = Depends(get_current_user)): + if current_user.disabled: + raise HTTPException(status_code=400, detail="Inactive user") + return current_user + +# 验证密码 +def verify_password(plain_password, hashed_password): + return pwd_context.verify(plain_password, hashed_password) + +# 获取密码哈希 +def get_password_hash(password): + return pwd_context.hash(password) + +# 从数据库获取信息 +def get_user(username: str): + query = "SELECT * FROM users WHERE username = %s" + result = execute_query(query, (username,), fetchall=False) + if result: + return UserInDB(**result) + +# 验证用户密码 +def authenticate_user(username: str, password: str): + user = get_user(username) + if not user: + return False + if not verify_password(password, user.hashed_password): + return False return user diff --git a/internal/__pycache__/__init__.cpython-310.pyc b/internal/__pycache__/__init__.cpython-310.pyc index 1522cfa..d2544a1 100644 Binary files a/internal/__pycache__/__init__.cpython-310.pyc and b/internal/__pycache__/__init__.cpython-310.pyc differ diff --git a/internal/__pycache__/database.cpython-310.pyc b/internal/__pycache__/database.cpython-310.pyc index bc94863..bfb96b5 100644 Binary files a/internal/__pycache__/database.cpython-310.pyc and b/internal/__pycache__/database.cpython-310.pyc differ diff --git a/internal/__pycache__/models.cpython-310.pyc b/internal/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000..f39384e Binary files /dev/null and b/internal/__pycache__/models.cpython-310.pyc differ diff --git a/internal/__pycache__/schemas.cpython-310.pyc b/internal/__pycache__/schemas.cpython-310.pyc index 592a5dc..3ea6ffa 100644 Binary files a/internal/__pycache__/schemas.cpython-310.pyc and b/internal/__pycache__/schemas.cpython-310.pyc differ diff --git a/internal/database.py b/internal/database.py index 81983df..298451e 100644 --- a/internal/database.py +++ b/internal/database.py @@ -4,7 +4,7 @@ DB_CONFIG = { "host": "111.229.38.129", "user": "root", "password": "zl981023", - "database": "blogapi", + "database": "test", "charset": "utf8mb4", "cursorclass": pymysql.cursors.DictCursor, } diff --git a/internal/models.py b/internal/models.py new file mode 100644 index 0000000..fb7de62 --- /dev/null +++ b/internal/models.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel + +# Token相关的模型 +class Token(BaseModel): + access_token: str + token_type: str + +class TokenData(BaseModel): + username: str = None + +# User相关的模型 +class User(BaseModel): + username: str + email: str = None + full_name: str = None + disabled: bool = None + +class UserInDB(User): + hashed_password: str diff --git a/internal/schemas.py b/internal/schemas.py deleted file mode 100644 index 5804c64..0000000 --- a/internal/schemas.py +++ /dev/null @@ -1,11 +0,0 @@ -from pydantic import BaseModel - -class Token(BaseModel): - access_token: str - token_type: str - -class TokenData(BaseModel): - username: str | None=None - -class User(BaseModel): - username: str \ No newline at end of file diff --git a/main.py b/main.py index e1a51fa..344049e 100644 --- a/main.py +++ b/main.py @@ -1,15 +1,16 @@ -from fastapi import FastAPI,HTTPException +from datetime import timedelta +from fastapi.security import OAuth2PasswordRequestForm +from fastapi import Depends, FastAPI, HTTPException, status from dependencies import * -from internal.schemas import * -# 初始化 FastAPI 应用 -app = FastAPI() +from internal.models import Token +app=FastAPI() -# 数据库连接参数 - -# 定义登录接口 +# 用户登录 @app.post("/token", response_model=Token) -async def login_for_access_token(username: str, password: str): - user = authenticate_user(username,password) +async def login_for_access_token( + form_data: OAuth2PasswordRequestForm = Depends(), +) -> Token: + user = authenticate_user(form_data.username, form_data.password) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -20,5 +21,29 @@ async def login_for_access_token(username: str, password: str): access_token = create_access_token( data={"sub": user.username}, expires_delta=access_token_expires ) - return Token(access_token=access_token, token_type="bearer") + return {"access_token": access_token, "token_type": "bearer"} + +# 注册新用户 +@app.post("/register/", response_model=UserInDB) +async def register_user(user: UserInDB): + # 检查用户名是否已经存在 + existing_user = get_user(user.username) + if existing_user: + raise HTTPException(status_code=400, detail="Username already registered") + + # 创建新用户并保存到数据库 + hashed_password = get_password_hash(user.hashed_password) + insert_query = "INSERT INTO users (username, email, full_name, hashed_password) VALUES (%s, %s, %s, %s)" + user_data = (user.username, user.email, user.full_name, hashed_password) + execute_query(insert_query, user_data) + + # 返回创建的用户信息 + return user + +@app.get("/users/me/", response_model=User) +async def read_users_me(current_user: User = Depends(get_current_active_user)): + return current_user +@app.get("/users/me/items/") +async def read_own_items(current_user: User = Depends(get_current_active_user)): + return [{"item_id": "Foo", "owner": current_user.username}]