chatgpt_bot/db/db.py

202 lines
7.9 KiB
Python
Raw Permalink Normal View History

2024-10-09 20:50:46 +00:00
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, BigInteger
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, relationship
from config import DATABASE_URL
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = 'users'
    user_id = Column(BigInteger, primary_key=True, unique=True, index=True)  # unique user ID taken from aiogram
    first_name = Column(String)
    last_name = Column(String)
    user_name = Column(String)
    test_queries = Column(Integer)  # temporary prompts
    current_model = Column(String)
    # Subscription rows that belong to this user (paired with Subscription.user).
    subscriptions = relationship("Subscription", back_populates="user", cascade="all, delete-orphan")
class Subscription(Base):
    """ORM model mapped to the ``subscriptions`` table.

    Each row ties a user to a purchased model name and a token count.
    """

    __tablename__ = 'subscriptions'
    id = Column(Integer, primary_key=True)
    user_id = Column(BigInteger, ForeignKey('users.user_id'), nullable=False)
    purchased_model = Column(String, nullable=False)
    purchased_tokens = Column(Integer)
    # Owning User row (paired with User.subscriptions).
    user = relationship("User", back_populates="subscriptions")
# Async engine created from DATABASE_URL, which is imported from config.
engine = create_async_engine(DATABASE_URL)
# Session factory bound to the engine; it is configured to build AsyncSession objects.
async_session = sessionmaker(bind=engine, class_=AsyncSession)
async def init_db() -> None:
    """Run ``Base.metadata.create_all`` against the module-level engine.

    The call is dispatched through ``run_sync`` on a connection obtained
    from ``engine.begin()``.
    """
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
async def get_all_users():
    """Fetch everything returned by ``select(User)``.

    The statement is executed on a plain engine connection (not an ORM
    session).

    Returns:
        The list produced by ``fetchall()`` on the query result.
    """
    async with engine.connect() as conn:
        query_result = await conn.execute(select(User))
        # Pull all users out of the result in one go.
        return query_result.fetchall()
async def add_user_to_db(user_id: int, first_name: str, last_name: str, user_name: str) -> None:
    """Insert a new row into the ``users`` table.

    The user is created with 10 temporary prompts and an empty current model.
    No existence check is performed here.

    Args:
        user_id: Primary key of the new row.
        first_name: Value for the ``first_name`` column.
        last_name: Value for the ``last_name`` column.
        user_name: Value for the ``user_name`` column.
    """
    new_user = User(
        user_id=user_id,
        first_name=first_name,
        last_name=last_name,
        user_name=user_name,
        test_queries=10,  # number of temporary prompts granted on creation
        current_model="",
    )
    # session.begin() commits when the block exits normally and rolls back on
    # an exception, so no explicit commit is needed.
    async with async_session() as session, session.begin():
        session.add(new_user)
async def user_exists(user_id: int) -> bool:
    """Report whether a ``users`` row with the given ID is stored.

    Args:
        user_id: Value to match against ``User.user_id``.

    Returns:
        True when the lookup yields a row, False otherwise.
    """
    lookup = select(User).where(User.user_id == user_id)
    async with async_session() as session:
        async with session.begin():
            found = (await session.execute(lookup)).scalar()
            return found is not None
async def decrease_test_queries(user_id: int) -> None:
    """Consume one of the user's temporary prompts.

    The counter is never pushed below zero. Nothing is changed when the user
    is not found, or when the counter is zero or unset.

    Args:
        user_id: Value to match against ``User.user_id``.
    """
    # session.begin() commits on normal exit, so no explicit commit is needed.
    async with async_session() as session, session.begin():
        result = await session.execute(select(User).where(User.user_id == user_id))
        user = result.scalar()
        # test_queries is a nullable column, so treat NULL like an empty
        # counter instead of comparing None with an int.
        if user is not None and (user.test_queries or 0) > 0:
            user.test_queries -= 1
async def get_temp_prompts(user_id: int):
    """Read the user's temporary-prompt counter.

    Args:
        user_id: Value to match against ``User.user_id``.

    Returns:
        The stored ``test_queries`` value, or 0 when the user is not found.
    """
    async with async_session() as session:
        async with session.begin():
            lookup = select(User).where(User.user_id == user_id)
            row = (await session.execute(lookup)).scalar()
            if not row:
                return 0
            return row.test_queries
async def set_curren_model(user_id: int, name_of_model: str) -> None:
    """Store ``name_of_model`` as the user's current model.

    Nothing is changed when the user is not found.

    Args:
        user_id: Value to match against ``User.user_id``.
        name_of_model: New value for the ``current_model`` column.
    """
    # session.begin() commits on normal exit, so no explicit commit is needed.
    async with async_session() as session, session.begin():
        result = await session.execute(select(User).where(User.user_id == user_id))
        user = result.scalar()
        if user is not None:
            user.current_model = name_of_model
async def add_tokens_to_user(user_id: int, name_of_model: str, queries: int) -> None:
    """Credit ``queries`` tokens to the user's subscription for a model.

    If the user already has a subscription row for ``name_of_model``, its
    token count is increased. Otherwise a new subscription row is created
    with ``queries`` tokens. When the user is not found, a message is printed
    and nothing is written.

    Args:
        user_id: Value to match against ``User.user_id``.
        name_of_model: Model name to match against ``purchased_model``.
        queries: Number of tokens to add.
    """
    # session.begin() commits on normal exit, so no explicit commit is needed.
    async with async_session() as session, session.begin():
        user_result = await session.execute(select(User).where(User.user_id == user_id))
        if user_result.scalar() is None:
            # Unknown user: report it and leave the database untouched.
            print(f"Пользователь с ID {user_id} не найден.")
            return

        # Check whether the user already has a subscription for this model.
        subscription_result = await session.execute(
            select(Subscription).where(
                Subscription.user_id == user_id,
                Subscription.purchased_model == name_of_model,
            )
        )
        subscription = subscription_result.scalar()

        if subscription is None:
            # No subscription yet: create one holding the purchased tokens.
            session.add(
                Subscription(
                    user_id=user_id,
                    purchased_model=name_of_model,
                    purchased_tokens=queries,
                )
            )
        else:
            # purchased_tokens is a nullable column; count NULL as zero so
            # the addition cannot fail on None.
            subscription.purchased_tokens = (subscription.purchased_tokens or 0) + queries
async def get_current_model(user_id: int) -> str:
    """Read the user's ``current_model`` value.

    Args:
        user_id: Value to match against ``User.user_id``.

    Returns:
        The stored ``current_model``, or an empty string when the user is
        not found.
    """
    lookup = select(User).where(User.user_id == user_id)
    async with async_session() as session:
        async with session.begin():
            row = (await session.execute(lookup)).scalar()
            if row:
                return row.current_model
            return ""
async def get_current_model_tokens(user_id: int, name_of_model: str) -> int:
    """Read the token count of the user's subscription for a model.

    Args:
        user_id: Value to match against ``User.user_id``.
        name_of_model: Model name to match against ``purchased_model``.

    Returns:
        The stored ``purchased_tokens`` of the matching subscription, or 0
        when either the user or the subscription is not found.
    """
    async with async_session() as session:
        async with session.begin():
            owner = (
                await session.execute(select(User).where(User.user_id == user_id))
            ).scalar()
            if not owner:
                return 0

            subscription_query = select(Subscription).where(
                Subscription.user_id == user_id,
                Subscription.purchased_model == name_of_model,
            )
            found = (await session.execute(subscription_query)).scalar()
            if not found:
                return 0
            return found.purchased_tokens
async def decrease_current_model_tokens(user_id: int, name_of_model: str) -> None:
    """Consume one token from the user's subscription for a model.

    The token count is never pushed below zero. When the user is not found a
    message is printed. When the subscription is not found, or its count is
    zero or unset, nothing is changed.

    Args:
        user_id: Value to match against ``User.user_id``.
        name_of_model: Model name to match against ``purchased_model``.
    """
    # session.begin() commits on normal exit, so no explicit commit is needed.
    async with async_session() as session, session.begin():
        user_result = await session.execute(select(User).where(User.user_id == user_id))
        if user_result.scalar() is None:
            print('No user found')
            return

        subscription_result = await session.execute(
            select(Subscription).where(
                Subscription.user_id == user_id,
                Subscription.purchased_model == name_of_model,
            )
        )
        subscription = subscription_result.scalar()
        # purchased_tokens is a nullable column, so treat NULL like an empty
        # balance instead of comparing None with an int.
        if subscription is not None and (subscription.purchased_tokens or 0) > 0:
            subscription.purchased_tokens -= 1
async def remoove_current_model(user_id: int) -> None:
    """Reset the user's ``current_model`` to an empty string.

    Nothing is changed when the user is not found.

    Args:
        user_id: Value to match against ``User.user_id``.
    """
    # session.begin() commits on normal exit, so no explicit commit is needed.
    async with async_session() as session, session.begin():
        result = await session.execute(select(User).where(User.user_id == user_id))
        user = result.scalar()
        if user is not None:
            user.current_model = ""