Source code for tests.database.test_database

"""
Module for tests for application database

Classes:
  TestGetDB
"""

from sqlalchemy import inspect
from sqlalchemy.orm import Session

from api.endpoints import app
from database import db
from models.chat import Base
from tests.database import base


class TestGetDB(base.BaseTestDatabaseTestCase):
    """
    Test case class for testing get_db function

    Both tests look up the callable registered for ``db.get_db`` in
    ``app.dependency_overrides``, call it, and inspect the first value the
    resulting iterator yields.
    """

    def test_get_db_returns_session_object(self) -> None:
        """
        Test that get_db returns session object

        :return: None
        """
        override_get_db = app.dependency_overrides[db.get_db]
        db_generator = override_get_db()
        session = next(db_generator)

        self.assertIsInstance(
            session, Session, f"The returned type is: {type(session)}"
        )

        # Session needs to be properly closed.
        session.close()

    def test_get_db_returns_session_object_without_any_missing_tables(self) -> None:
        """
        Test that get_db returns session of database without any missing tables

        The table names reported by the database the session is bound to are
        compared against the tables declared on ``Base.metadata``; the
        ``alembic_version`` table is ignored.

        :return: None
        """
        overridden_get_db = app.dependency_overrides[db.get_db]
        db_generator = overridden_get_db()
        session: Session = next(db_generator)

        try:
            inspector = inspect(session.get_bind())
            existing_tables = set(inspector.get_table_names())
            declared_tables = set(Base.metadata.tables.keys())

            # alembic_version table is not created in test database; it is
            # excluded so that its presence never influences the result.
            # NOTE(review): this difference yields tables that exist in the
            # database but are not declared on Base.metadata. Declared tables
            # that are absent from the database are not detected here --
            # confirm whether the reverse difference was intended.
            unexpected_tables = (
                existing_tables - {"alembic_version"} - declared_tables
            )

            # Comparing against an empty set (instead of a length of 0) makes
            # a failure report the offending table names.
            self.assertEqual(set(), unexpected_tables)
        finally:
            # Release the session even when an assertion above fails.
            session.close()