Source code for tests.database.base

"""
Base database test class module.

Classes:
  BaseTestDatabaseTestCase
"""

import inspect
import os
import unittest

from sqlalchemy import create_engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import StaticPool
from starlette.testclient import TestClient

import models
from api.endpoints import app
from database import db
from models.chat import Base


class BaseTestDatabaseTestCase(unittest.TestCase):
    """
    Base class for database test cases.

    Once per test class, an in-memory SQLite engine is created, the tables
    registered on ``Base.metadata`` are created, and the application's
    ``db.get_db`` dependency is overridden so that requests issued through
    ``cls.client`` use the test session.

    Class attributes set by ``setUpClass``:
      engine  -- SQLAlchemy engine bound to the in-memory database.
      Session -- a single session instance shared by all tests of the class
                 (note: an instance, not a session factory).
      client  -- ``TestClient`` wrapping the application.
    """

    @classmethod
    def setUpClass(cls):
        """
        Create the engine, schema, shared session and test client.
        """
        cls.engine = create_engine(
            "sqlite:///:memory:",
            # SQL statements are echoed only when DEBUG is exactly "True".
            echo=os.getenv("DEBUG", False) == "True",
            connect_args={"check_same_thread": False},
            # An in-memory SQLite database lives inside one connection.
            # StaticPool hands that same connection to every thread, so code
            # running outside the test thread (e.g. request handlers driven
            # by the test client) sees the tables created here instead of a
            # fresh, empty database.
            poolclass=StaticPool,
        )
        cls.__session = sessionmaker(
            autocommit=False, autoflush=False, bind=cls.engine
        )
        Base.metadata.create_all(bind=cls.engine)
        cls.Session = cls.__session()

        def override_get_db():
            """
            Replacement for ``db.get_db`` that yields the shared test session.
            """
            database = cls.Session
            try:
                yield database
            finally:
                # Ends the session's current transaction after each request;
                # the same session object is handed out again next time.
                database.close()

        # Route the application's database dependency to the test session.
        app.dependency_overrides[db.get_db] = override_get_db
        cls.client = TestClient(app)

    def setUp(self):
        """
        Expose the shared session as ``self.session`` and make sure the
        tables of all model classes exist before each test.
        """
        self.session: Session = self.Session
        # Guarantee all tables of models are available in the test database:
        # probe every class exposed by the ``models`` package and (re)create
        # the schema when a probe fails at the database level.
        for _, model in inspect.getmembers(models, inspect.isclass):
            try:
                self.session.query(model).count()
            except OperationalError:
                Base.metadata.create_all(bind=self.engine)

    @classmethod
    def tearDownClass(cls):
        """
        Undo everything ``setUpClass`` installed and release the database.
        """
        # Remove the dependency override so it does not leak into test
        # classes (or application code) that run afterwards.
        app.dependency_overrides.pop(db.get_db, None)
        # Release the session's connection before the schema is dropped.
        cls.Session.close()
        Base.metadata.drop_all(cls.engine)
        cls.engine.dispose()