"""Async SQLAlchemy models and data-access helpers for bot users and their per-model subscriptions."""
import logging
from typing import Optional

from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, BigInteger
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker, relationship

from config import DATABASE_URL

logger = logging.getLogger(__name__)

Base = declarative_base()


class User(Base):
    """A bot user together with their trial-prompt balance and selected model."""

    __tablename__ = 'users'

    # Unique user ID from aiogram; supplied explicitly by add_user_to_db.
    user_id = Column(BigInteger, primary_key=True, unique=True, index=True)
    first_name = Column(String)
    last_name = Column(String)
    user_name = Column(String)
    # Remaining temporary (trial) prompts. Nullable, so helpers treat None as 0.
    test_queries = Column(Integer)
    # Name of the currently selected model; "" means no model is selected.
    current_model = Column(String)

    subscriptions = relationship("Subscription", back_populates="user", cascade="all, delete-orphan")


class Subscription(Base):
    """Purchased query balance of one user for one model."""

    __tablename__ = 'subscriptions'

    id = Column(Integer, primary_key=True)
    user_id = Column(BigInteger, ForeignKey('users.user_id'), nullable=False)
    purchased_model = Column(String, nullable=False)
    # Remaining purchased queries for purchased_model. Nullable, so helpers treat None as 0.
    purchased_tokens = Column(Integer)

    user = relationship("User", back_populates="subscriptions")


engine = create_async_engine(DATABASE_URL)
# expire_on_commit=False: with AsyncSession, reading an expired attribute after a
# commit would require implicit lazy-load IO, which is not allowed under asyncio.
async_session = sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)


async def init_db() -> None:
    """Create every table registered on Base that does not exist yet."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)


async def get_all_users():
    """Return all rows of the users table.

    The query runs on a Core connection rather than an ORM session, so the
    result is a list of plain Row objects (columns accessible as
    ``row.user_id`` etc.), not ``User`` instances.
    """
    async with engine.connect() as connection:
        result = await connection.execute(select(User))
        return result.fetchall()


async def _get_user(session: AsyncSession, user_id: int) -> Optional[User]:
    """Load the User with primary key *user_id* in *session*, or None if absent."""
    return await session.get(User, user_id)


async def _get_subscription(session: AsyncSession, user_id: int, name_of_model: str) -> Optional[Subscription]:
    """Return the subscription of *user_id* for *name_of_model*, or None if there is none."""
    result = await session.execute(
        select(Subscription).where(
            Subscription.user_id == user_id,
            Subscription.purchased_model == name_of_model,
        )
    )
    return result.scalar()


async def add_user_to_db(user_id: int, first_name: str, last_name: str, user_name: str) -> None:
    """Insert a new user with 10 temporary prompts and no selected model.

    The insert is committed when ``session.begin()`` exits. Inserting a
    user_id that already exists violates the primary key, so the commit fails.
    """
    async with async_session() as session:
        async with session.begin():
            session.add(
                User(
                    user_id=user_id,
                    first_name=first_name,
                    last_name=last_name,
                    user_name=user_name,
                    test_queries=10,  # initial number of temporary prompts
                    current_model="",
                )
            )


async def user_exists(user_id: int) -> bool:
    """Return True if a user with *user_id* is stored."""
    async with async_session() as session:
        async with session.begin():
            return await _get_user(session, user_id) is not None


async def decrease_test_queries(user_id: int) -> None:
    """Spend one temporary prompt of *user_id*.

    Only a positive balance is decremented. A missing user, a zero/negative
    balance, or a NULL balance leaves the row unchanged.
    """
    async with async_session() as session:
        async with session.begin():
            user = await _get_user(session, user_id)
            # test_queries is nullable; `None > 0` would raise TypeError.
            if user is not None and (user.test_queries or 0) > 0:
                user.test_queries -= 1


async def get_temp_prompts(user_id: int) -> int:
    """Return the remaining temporary prompts of *user_id* (0 if unknown user or NULL)."""
    async with async_session() as session:
        async with session.begin():
            user = await _get_user(session, user_id)
            if user is None:
                return 0
            return user.test_queries or 0


async def set_curren_model(user_id: int, name_of_model: str) -> None:
    """Select *name_of_model* as the current model of *user_id*; no-op for an unknown user."""
    async with async_session() as session:
        async with session.begin():
            user = await _get_user(session, user_id)
            if user is not None:
                user.current_model = name_of_model


async def add_tokens_to_user(user_id: int, name_of_model: str, queries: int) -> None:
    """Add *queries* purchased queries for *name_of_model* to *user_id*.

    Updates the existing subscription for that model or creates a new one.
    If the user does not exist, a warning is logged and nothing is written.
    """
    async with async_session() as session:
        async with session.begin():
            user = await _get_user(session, user_id)
            if user is None:
                logger.warning("Пользователь с ID %s не найден.", user_id)
                return

            subscription = await _get_subscription(session, user_id, name_of_model)
            if subscription is not None:
                # purchased_tokens is nullable; `None + queries` would raise TypeError.
                subscription.purchased_tokens = (subscription.purchased_tokens or 0) + queries
            else:
                session.add(
                    Subscription(
                        user_id=user_id,
                        purchased_model=name_of_model,
                        purchased_tokens=queries,
                    )
                )


async def get_current_model(user_id: int) -> str:
    """Return the currently selected model of *user_id*, or "" if none/unknown user."""
    async with async_session() as session:
        async with session.begin():
            user = await _get_user(session, user_id)
            if user is None:
                return ""
            return user.current_model or ""


async def get_current_model_tokens(user_id: int, name_of_model: str) -> int:
    """Return the purchased queries left for *name_of_model*.

    Returns 0 when the user, the subscription, or the stored balance is missing.
    """
    async with async_session() as session:
        async with session.begin():
            if await _get_user(session, user_id) is None:
                return 0
            subscription = await _get_subscription(session, user_id, name_of_model)
            if subscription is None:
                return 0
            return subscription.purchased_tokens or 0


async def decrease_current_model_tokens(user_id: int, name_of_model: str) -> None:
    """Spend one purchased query of *user_id* for *name_of_model*.

    Only a positive balance is decremented. A missing subscription is a no-op.
    A missing user is logged and ignored.
    """
    async with async_session() as session:
        async with session.begin():
            if await _get_user(session, user_id) is None:
                logger.warning('No user found')
                return None
            subscription = await _get_subscription(session, user_id, name_of_model)
            # purchased_tokens is nullable; `None > 0` would raise TypeError.
            if subscription is not None and (subscription.purchased_tokens or 0) > 0:
                subscription.purchased_tokens -= 1
            return None


async def remoove_current_model(user_id: int) -> None:
    """Clear the selected model of *user_id* (sets it to ""); no-op for an unknown user."""
    async with async_session() as session:
        async with session.begin():
            user = await _get_user(session, user_id)
            if user is not None:
                user.current_model = ""


# Correctly spelled aliases; the original (misspelled) names stay for existing callers.
set_current_model = set_curren_model
remove_current_model = remoove_current_model