chatgpt_bot/db/db.py

202 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, BigInteger
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, relationship
from config import DATABASE_URL
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
class User(Base):
    """ORM model for a row of the ``users`` table."""

    __tablename__ = 'users'

    # Unique user ID taken from aiogram.
    user_id = Column(
        BigInteger,
        primary_key=True,
        unique=True,
        index=True,
    )
    first_name = Column(String)
    last_name = Column(String)
    user_name = Column(String)
    # Temporary (trial) prompts.
    test_queries = Column(Integer)
    current_model = Column(String)

    # Paired with Subscription.user; the cascade setting propagates all
    # session operations to the subscriptions, including delete-orphan.
    subscriptions = relationship(
        "Subscription",
        back_populates="user",
        cascade="all, delete-orphan",
    )
class Subscription(Base):
    """ORM model for a row of the ``subscriptions`` table.

    Each row ties a user (``user_id``) to a purchased model name and a
    token count for that model.
    """

    __tablename__ = 'subscriptions'

    id = Column(Integer, primary_key=True)
    # Owner of the subscription; references users.user_id.
    user_id = Column(
        BigInteger,
        ForeignKey('users.user_id'),
        nullable=False,
    )
    purchased_model = Column(String, nullable=False)
    # Nullable: no nullable=False is declared for this column.
    purchased_tokens = Column(Integer)

    # Paired with User.subscriptions.
    user = relationship(
        "User",
        back_populates="subscriptions",
    )
# Async engine for the database addressed by DATABASE_URL (imported from config).
engine = create_async_engine(DATABASE_URL)
# Session factory: calling async_session() yields an AsyncSession bound to `engine`.
async_session = sessionmaker(bind=engine, class_=AsyncSession)
async def init_db() -> None:
    """Run ``Base.metadata.create_all`` against the engine.

    ``create_all`` is a synchronous API, so it is dispatched through
    ``run_sync`` inside a transaction opened with ``engine.begin()``.
    """
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
async def get_all_users():
    """Fetch every row of the ``users`` table.

    Returns:
        The list produced by ``fetchall()`` on the query result.

    NOTE(review): the query runs on a plain engine connection rather than an
    ORM session, so the items are presumably result rows, not ``User``
    instances -- confirm against callers.
    """
    async with engine.connect() as conn:
        all_rows = await conn.execute(select(User))
        # Fetch all users.
        return all_rows.fetchall()
async def add_user_to_db(
    user_id: int,
    first_name: str,
    last_name: str,
    user_name: str,
    *,
    test_queries: int = 10,
) -> None:
    """Insert a new row into the ``users`` table.

    The user is created with an empty ``current_model``.

    Args:
        user_id: Primary key of the new user.
        first_name: Value for ``User.first_name``.
        last_name: Value for ``User.last_name``.
        user_name: Value for ``User.user_name``.
        test_queries: Initial number of temporary prompts. Defaults to 10,
            the value that was previously hard-coded.

    NOTE(review): no duplicate check is done here; presumably callers use
    ``user_exists`` first -- confirm.
    """
    new_user = User(
        user_id=user_id,
        first_name=first_name,
        last_name=last_name,
        user_name=user_name,
        test_queries=test_queries,
        current_model="",
    )
    async with async_session() as session:
        # session.begin() commits when the block exits without an error and
        # rolls back otherwise, so no explicit commit() is needed.
        async with session.begin():
            session.add(new_user)
async def user_exists(user_id: int) -> bool:
    """Tell whether the ``users`` table has a row with this ``user_id``.

    Args:
        user_id: Primary key to look up.

    Returns:
        ``True`` when a matching user was found, ``False`` otherwise.
    """
    lookup = select(User).where(User.user_id == user_id)
    async with async_session() as session, session.begin():
        found = (await session.execute(lookup)).scalar()
    return found is not None
async def decrease_test_queries(user_id: int) -> None:
    """Use up one temporary prompt of the given user.

    The counter never goes below zero. Nothing happens when the user does
    not exist or has no temporary prompts left.

    Args:
        user_id: Primary key of the user to update.
    """
    lookup = select(User).where(User.user_id == user_id)
    async with async_session() as session:
        # The begin() block commits on successful exit, so the change is
        # persisted without an explicit commit().
        async with session.begin():
            user = (await session.execute(lookup)).scalar()
            if user is None:
                return
            # The column is nullable; treat a missing value as zero.
            remaining = user.test_queries or 0
            if remaining > 0:
                user.test_queries = remaining - 1
async def get_temp_prompts(user_id: int):
    """Read the ``test_queries`` counter of a user.

    Args:
        user_id: Primary key of the user to look up.

    Returns:
        The stored ``test_queries`` value, or ``0`` when the user is not
        found.
    """
    lookup = select(User).where(User.user_id == user_id)
    async with async_session() as session, session.begin():
        owner = (await session.execute(lookup)).scalar()
        return owner.test_queries if owner else 0
async def set_curren_model(user_id: int, name_of_model: str) -> None:
    """Store ``name_of_model`` as the user's current model.

    Does nothing when no user with ``user_id`` exists.

    Args:
        user_id: Primary key of the user to update.
        name_of_model: Value written to ``User.current_model``.
    """
    lookup = select(User).where(User.user_id == user_id)
    async with async_session() as session:
        # Leaving the begin() block without an error commits the change;
        # an explicit commit() would be redundant.
        async with session.begin():
            user = (await session.execute(lookup)).scalar()
            if user is not None:
                user.current_model = name_of_model
async def add_tokens_to_user(user_id: int, name_of_model: str, queries: int) -> None:
    """Credit ``queries`` tokens to the user's subscription for a model.

    If the user already has a subscription for ``name_of_model`` its
    ``purchased_tokens`` is increased; otherwise a new subscription row is
    created with ``purchased_tokens`` set to ``queries``. When the user does
    not exist a message is printed and nothing is written.

    Args:
        user_id: Primary key of the user receiving the tokens.
        name_of_model: Model name matched against
            ``Subscription.purchased_model``.
        queries: Number of tokens to add.
    """
    async with async_session() as session:
        # All changes are committed when the begin() block exits cleanly.
        async with session.begin():
            user = (
                await session.execute(select(User).where(User.user_id == user_id))
            ).scalar()
            if user is None:
                # User not found: report it and leave the database untouched.
                print(f"Пользователь с ID {user_id} не найден.")
                return

            # Check whether a subscription for this model already exists.
            subscription = (
                await session.execute(
                    select(Subscription).where(
                        Subscription.user_id == user_id,
                        Subscription.purchased_model == name_of_model,
                    )
                )
            ).scalar()

            if subscription is None:
                # No subscription yet: create one.
                session.add(
                    Subscription(
                        user_id=user_id,
                        purchased_model=name_of_model,
                        purchased_tokens=queries,
                    )
                )
            else:
                # purchased_tokens is nullable; treat a missing value as 0
                # instead of failing on None + int.
                subscription.purchased_tokens = (
                    subscription.purchased_tokens or 0
                ) + queries
async def get_current_model(user_id: int) -> str:
    """Read the ``current_model`` stored for a user.

    Args:
        user_id: Primary key of the user to look up.

    Returns:
        The stored ``current_model`` value, or an empty string when the
        user is not found.
    """
    lookup = select(User).where(User.user_id == user_id)
    async with async_session() as session, session.begin():
        owner = (await session.execute(lookup)).scalar()
    if owner:
        return owner.current_model
    return ""
async def get_current_model_tokens(user_id: int, name_of_model: str) -> int:
    """Read the token balance of a user's subscription for one model.

    Args:
        user_id: Primary key of the user.
        name_of_model: Model name matched against
            ``Subscription.purchased_model``.

    Returns:
        The subscription's ``purchased_tokens``, or ``0`` when either the
        user or a matching subscription is not found.
    """
    async with async_session() as session, session.begin():
        owner = (
            await session.execute(select(User).where(User.user_id == user_id))
        ).scalar()
        if not owner:
            return 0

        matching = select(Subscription).where(
            Subscription.user_id == user_id,
            Subscription.purchased_model == name_of_model,
        )
        plan = (await session.execute(matching)).scalar()
        return plan.purchased_tokens if plan else 0
async def decrease_current_model_tokens(user_id: int, name_of_model: str) -> None:
    """Use up one token of the user's subscription for a model.

    The balance never goes below zero. Nothing is written when the user or
    the matching subscription does not exist; a missing user is also
    reported with a printed message.

    Args:
        user_id: Primary key of the user.
        name_of_model: Model name matched against
            ``Subscription.purchased_model``.
    """
    async with async_session() as session:
        # The begin() block commits on successful exit, so no explicit
        # commit() is required after the update.
        async with session.begin():
            user = (
                await session.execute(select(User).where(User.user_id == user_id))
            ).scalar()
            if user is None:
                print('No user found')
                return

            subscription = (
                await session.execute(
                    select(Subscription).where(
                        Subscription.user_id == user_id,
                        Subscription.purchased_model == name_of_model,
                    )
                )
            ).scalar()
            if subscription is None:
                return

            # purchased_tokens is nullable; treat a missing value as zero.
            tokens_left = subscription.purchased_tokens or 0
            if tokens_left > 0:
                subscription.purchased_tokens = tokens_left - 1
async def remoove_current_model(user_id: int) -> None:
    """Reset the user's ``current_model`` to an empty string.

    Does nothing when no user with ``user_id`` exists.

    Args:
        user_id: Primary key of the user to update.
    """
    lookup = select(User).where(User.user_id == user_id)
    async with async_session() as session:
        # Exiting the begin() block without an error commits the change.
        async with session.begin():
            user = (await session.execute(lookup)).scalar()
            if user is not None:
                user.current_model = ""