202 lines
7.9 KiB
Python
202 lines
7.9 KiB
Python
|
||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, BigInteger
|
||
from sqlalchemy.ext.declarative import declarative_base
|
||
from sqlalchemy.future import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||
from sqlalchemy.orm import sessionmaker, relationship
|
||
from config import DATABASE_URL
|
||
|
||
# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
|
||
class User(Base):
    """ORM model mapped to the ``users`` table.

    A user is linked to its ``Subscription`` rows through ``subscriptions``,
    which is configured with the ``all, delete-orphan`` cascade.
    """

    __tablename__ = 'users'

    # Unique user ID coming from aiogram.
    user_id = Column(
        BigInteger,
        primary_key=True,
        unique=True,
        index=True,
    )
    first_name = Column(String)
    last_name = Column(String)
    user_name = Column(String)
    # Temporary prompts.
    test_queries = Column(Integer)
    current_model = Column(String)

    subscriptions = relationship(
        "Subscription",
        back_populates="user",
        cascade="all, delete-orphan",
    )
|
||
|
||
class Subscription(Base):
    """ORM model mapped to the ``subscriptions`` table.

    Each row ties one user (``user_id`` references ``users.user_id``) to a
    purchased model name and a token count for that model.
    """

    __tablename__ = 'subscriptions'

    id = Column(Integer, primary_key=True)
    user_id = Column(
        BigInteger,
        ForeignKey('users.user_id'),
        nullable=False,
    )
    purchased_model = Column(String, nullable=False)
    purchased_tokens = Column(Integer)

    # Reverse side of ``User.subscriptions``.
    user = relationship("User", back_populates="subscriptions")
|
||
|
||
# Async engine used by every helper in this module; the URL is imported
# from config.DATABASE_URL.
engine = create_async_engine(DATABASE_URL)

# Factory producing AsyncSession objects bound to the engine.
# expire_on_commit=False keeps attributes that were already loaded readable
# after a commit. With the default (True) they are expired on commit and
# reading them afterwards would need an implicit lazy-load, which an
# AsyncSession cannot perform outside of an awaited call.
async_session = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
||
|
||
async def init_db() -> None:
    """Create the tables registered on ``Base.metadata``.

    ``create_all`` is a synchronous callable, so it is handed to
    ``run_sync`` on a connection opened with ``engine.begin()``.
    """
    create_tables = Base.metadata.create_all
    async with engine.begin() as connection:
        await connection.run_sync(create_tables)
|
||
|
||
async def get_all_users():
    """Fetch all rows selected by ``select(User)``.

    The query is executed on a plain engine connection rather than on an
    ORM session, and the result of ``fetchall()`` is returned unchanged.
    """
    query = select(User)
    async with engine.connect() as conn:
        rows = await conn.execute(query)
        # Fetch all users.
        return rows.fetchall()
|
||
|
||
|
||
async def add_user_to_db(user_id: int, first_name: str, last_name: str, user_name: str) -> None:
    """Insert a new ``User`` row.

    The new user starts with 10 temporary prompts and an empty
    ``current_model``. No duplicate check is performed here.

    Args:
        user_id: Value stored in the ``user_id`` primary-key column.
        first_name: Value stored in ``first_name``.
        last_name: Value stored in ``last_name``.
        user_name: Value stored in ``user_name``.
    """
    new_user = User(
        user_id=user_id,
        first_name=first_name,
        last_name=last_name,
        user_name=user_name,
        test_queries=10,  # number of temporary prompts
        current_model="",
    )
    async with async_session() as session:
        # session.begin() commits when the block exits normally and rolls
        # back on an exception, so no explicit commit() is needed.
        async with session.begin():
            session.add(new_user)
|
||
|
||
async def user_exists(user_id: int) -> bool:
    """Tell whether a ``User`` row with the given ``user_id`` is stored."""
    lookup = select(User).where(User.user_id == user_id)
    async with async_session() as session:
        async with session.begin():
            found = (await session.execute(lookup)).scalar()
            return found is not None
|
||
|
||
async def decrease_test_queries(user_id: int) -> None:
    """Spend one temporary prompt of the user, never going below zero.

    Nothing is changed when the user does not exist or when the counter is
    zero, negative or unset.

    Args:
        user_id: ``user_id`` of the user whose counter is decremented.
    """
    async with async_session() as session:
        # session.begin() commits on normal exit (including an early
        # return), so no explicit commit() is needed.
        async with session.begin():
            result = await session.execute(select(User).where(User.user_id == user_id))
            user = result.scalar()
            if user is None:
                return

            # The column is nullable: treat NULL like an empty counter
            # instead of failing on the comparison.
            if (user.test_queries or 0) > 0:
                user.test_queries -= 1
|
||
|
||
|
||
async def get_temp_prompts(user_id: int):
    """Return the user's ``test_queries`` value, or 0 if no such user exists."""
    query = select(User).where(User.user_id == user_id)
    async with async_session() as session:
        async with session.begin():
            record = (await session.execute(query)).scalar()
            return record.test_queries if record else 0
|
||
|
||
|
||
async def set_curren_model(user_id: int, name_of_model: str) -> None:
    """Store ``name_of_model`` as the user's ``current_model``.

    Does nothing when no user with ``user_id`` exists.

    Args:
        user_id: ``user_id`` of the user to update.
        name_of_model: New value for ``current_model``.
    """
    async with async_session() as session:
        # The change is committed when the session.begin() block exits, so
        # an explicit commit() is not needed.
        async with session.begin():
            result = await session.execute(select(User).where(User.user_id == user_id))
            user = result.scalar()
            if user is not None:
                user.current_model = name_of_model
|
||
|
||
|
||
async def add_tokens_to_user(user_id: int, name_of_model: str, queries: int) -> None:
    """Credit ``queries`` tokens to the user's subscription for a model.

    If the user already has a ``Subscription`` row for ``name_of_model`` its
    ``purchased_tokens`` is increased; otherwise a new row is created with
    ``purchased_tokens=queries``. When the user does not exist a message is
    printed and nothing is written.

    Args:
        user_id: ``user_id`` of the subscription owner.
        name_of_model: Value matched against / stored in ``purchased_model``.
        queries: Number of tokens to add.
    """
    async with async_session() as session:
        # Everything below runs in one transaction that session.begin()
        # commits on normal exit, so no explicit commit() is needed.
        async with session.begin():
            result = await session.execute(select(User).where(User.user_id == user_id))
            user = result.scalar()
            if user is None:
                # User not found; nothing to update.
                print(f"Пользователь с ID {user_id} не найден.")
                return

            # Check whether a subscription for this model already exists.
            subscription_result = await session.execute(
                select(Subscription).where(
                    Subscription.user_id == user_id,
                    Subscription.purchased_model == name_of_model,
                )
            )
            subscription = subscription_result.scalar()

            if subscription is None:
                # No subscription yet: create one.
                session.add(
                    Subscription(
                        user_id=user_id,
                        purchased_model=name_of_model,
                        purchased_tokens=queries,
                    )
                )
            else:
                # Subscription exists: top up its balance. The column is
                # nullable, so a NULL balance is treated as 0.
                subscription.purchased_tokens = (subscription.purchased_tokens or 0) + queries
|
||
|
||
|
||
async def get_current_model(user_id: int) -> str:
    """Return the user's ``current_model``, or ``""`` if no such user exists."""
    query = select(User).where(User.user_id == user_id)
    async with async_session() as session:
        async with session.begin():
            record = (await session.execute(query)).scalar()
            if not record:
                return ""
            return record.current_model
|
||
|
||
async def get_current_model_tokens(user_id: int, name_of_model: str) -> int:
    """Return ``purchased_tokens`` of the user's subscription for a model.

    0 is returned when either the user or a matching subscription row is
    not found.
    """
    user_query = select(User).where(User.user_id == user_id)
    subscription_query = select(Subscription).where(
        Subscription.user_id == user_id,
        Subscription.purchased_model == name_of_model,
    )
    async with async_session() as session:
        async with session.begin():
            owner = (await session.execute(user_query)).scalar()
            if not owner:
                return 0

            plan = (await session.execute(subscription_query)).scalar()
            if not plan:
                return 0

            return plan.purchased_tokens
|
||
|
||
async def decrease_current_model_tokens(user_id: int, name_of_model: str) -> None:
    """Spend one token of the user's subscription for ``name_of_model``.

    The balance never goes below zero. Nothing is changed when the
    subscription does not exist or its balance is zero, negative or unset;
    when the user does not exist a message is printed.

    Args:
        user_id: ``user_id`` of the subscription owner.
        name_of_model: Value matched against ``purchased_model``.
    """
    async with async_session() as session:
        # session.begin() commits on normal exit (including the early
        # returns), so no explicit commit() is needed.
        async with session.begin():
            result = await session.execute(select(User).where(User.user_id == user_id))
            user = result.scalar()
            if user is None:
                print('No user found')
                return

            subscription_result = await session.execute(
                select(Subscription).where(
                    Subscription.user_id == user_id,
                    Subscription.purchased_model == name_of_model,
                )
            )
            subscription = subscription_result.scalar()
            if subscription is None:
                return

            # The column is nullable: treat NULL like an empty balance
            # instead of failing on the comparison.
            if (subscription.purchased_tokens or 0) > 0:
                subscription.purchased_tokens -= 1
|
||
|
||
async def remoove_current_model(user_id: int) -> None:
    """Reset the user's ``current_model`` to an empty string.

    Does nothing when no user with ``user_id`` exists.

    Args:
        user_id: ``user_id`` of the user to update.
    """
    async with async_session() as session:
        # The change is committed when the session.begin() block exits, so
        # an explicit commit() is not needed.
        async with session.begin():
            result = await session.execute(select(User).where(User.user_id == user_id))
            user = result.scalar()
            if user is not None:
                user.current_model = ""
|