diff --git a/db.py b/db.py index 2e6b1ac..d3d55c8 100644 --- a/db.py +++ b/db.py @@ -22,6 +22,7 @@ import aiomysql import config from async_db import AsyncMysqlDB +from async_sqlite_db import AsyncSqliteDB from tools import utils from var import db_conn_pool_var, media_crawler_db_var @@ -33,11 +34,11 @@ async def init_mediacrawler_db(): """ pool = await aiomysql.create_pool( - host=config.RELATION_DB_HOST, - port=config.RELATION_DB_PORT, - user=config.RELATION_DB_USER, - password=config.RELATION_DB_PWD, - db=config.RELATION_DB_NAME, + host=config.MYSQL_DB_HOST, + port=config.MYSQL_DB_PORT, + user=config.MYSQL_DB_USER, + password=config.MYSQL_DB_PWD, + db=config.MYSQL_DB_NAME, autocommit=True, ) async_db_obj = AsyncMysqlDB(pool) @@ -47,6 +48,18 @@ async def init_mediacrawler_db(): media_crawler_db_var.set(async_db_obj) +async def init_sqlite_db(): + """ + 初始化SQLite数据库对象,并将该对象塞给media_crawler_db_var上下文变量 + Returns: + + """ + async_db_obj = AsyncSqliteDB(config.SQLITE_DB_PATH) + + # 将SQLite数据库对象放到上下文变量中 + media_crawler_db_var.set(async_db_obj) + + async def init_db(): """ 初始化db连接池 @@ -54,37 +67,124 @@ async def init_db(): """ utils.logger.info("[init_db] start init mediacrawler db connect object") - await init_mediacrawler_db() - utils.logger.info("[init_db] end init mediacrawler db connect object") + if config.SAVE_DATA_OPTION == "sqlite": + await init_sqlite_db() + utils.logger.info("[init_db] end init sqlite db connect object") + else: + await init_mediacrawler_db() + utils.logger.info("[init_db] end init mysql db connect object") async def close(): """ - 关闭连接池 + 关闭数据库连接 Returns: """ - utils.logger.info("[close] close mediacrawler db pool") - db_pool: aiomysql.Pool = db_conn_pool_var.get() - if db_pool is not None: - db_pool.close() + utils.logger.info("[close] close mediacrawler db connection") + if config.SAVE_DATA_OPTION == "sqlite": + # SQLite数据库连接会在AsyncSqliteDB对象销毁时自动关闭 + utils.logger.info("[close] sqlite db connection will be closed automatically") + else: + # MySQL连接池关闭 + db_pool: aiomysql.Pool = db_conn_pool_var.get() + if db_pool is not None: + db_pool.close() + utils.logger.info("[close] mysql db pool closed") -async def init_table_schema(): +async def init_table_schema(db_type: str = None): """ 用来初始化数据库表结构,请在第一次需要创建表结构的时候使用,多次执行该函数会将已有的表以及数据全部删除 + Args: + db_type: 数据库类型,可选值为 'sqlite' 或 'mysql',如果不指定则使用配置文件中的设置 Returns: """ - utils.logger.info("[init_table_schema] begin init mysql table schema ...") - await init_mediacrawler_db() - async_db_obj: AsyncMysqlDB = media_crawler_db_var.get() - async with aiofiles.open("schema/tables.sql", mode="r", encoding="utf-8") as f: - schema_sql = await f.read() - await async_db_obj.execute(schema_sql) - utils.logger.info("[init_table_schema] mediacrawler table schema init successful") - await close() + # 如果没有指定数据库类型,则使用配置文件中的设置 + if db_type is None: + db_type = config.SAVE_DATA_OPTION + + if db_type == "sqlite": + utils.logger.info("[init_table_schema] begin init sqlite table schema ...") + await init_sqlite_db() + async_db_obj: AsyncSqliteDB = media_crawler_db_var.get() + async with aiofiles.open("schema/sqlite_tables.sql", mode="r", encoding="utf-8") as f: + schema_sql = await f.read() + await async_db_obj.executescript(schema_sql) + utils.logger.info("[init_table_schema] sqlite table schema init successful") + elif db_type == "mysql": + utils.logger.info("[init_table_schema] begin init mysql table schema ...") + await init_mediacrawler_db() + async_db_obj: AsyncMysqlDB = media_crawler_db_var.get() + async with aiofiles.open("schema/tables.sql", mode="r", encoding="utf-8") as f: + schema_sql = await f.read() + await async_db_obj.execute(schema_sql) + utils.logger.info("[init_table_schema] mysql table schema init successful") + await close() + else: + utils.logger.error(f"[init_table_schema] 不支持的数据库类型: {db_type}") + raise ValueError(f"不支持的数据库类型: {db_type},支持的类型: sqlite, mysql") + + +def show_database_options(): + """ + 显示支持的数据库选项 + """ + print("\n=== MediaCrawler 数据库初始化工具 ===") + print("支持的数据库类型:") + print("1. sqlite - SQLite 数据库 (轻量级,无需额外配置)") + print("2. mysql - MySQL 数据库 (需要配置数据库连接信息)") + print("3. config - 使用配置文件中的设置") + print("4. exit - 退出程序") + print("="*50) + + +def get_user_choice(): + """ + 获取用户选择的数据库类型 + Returns: + str: 用户选择的数据库类型 + """ + while True: + choice = input("请输入数据库类型 (sqlite/mysql/config/exit): ").strip().lower() + + if choice in ['sqlite', 'mysql', 'config', 'exit']: + return choice + else: + print("❌ 无效的选择,请输入: sqlite, mysql, config 或 exit") + + +async def main(): + """ + 主函数,处理用户交互和数据库初始化 + """ + try: + show_database_options() + + while True: + choice = get_user_choice() + + if choice == 'exit': + print("👋 程序已退出") + break + elif choice == 'config': + print(f"📋 使用配置文件中的设置: {config.SAVE_DATA_OPTION}") + await init_table_schema() + print("✅ 数据库表结构初始化完成!") + break + else: + print(f"🚀 开始初始化 {choice.upper()} 数据库...") + await init_table_schema(choice) + print("✅ 数据库表结构初始化完成!") + break + + except KeyboardInterrupt: + print("\n\n⚠️ 用户中断操作") + except Exception as e: + print(f"\n❌ 初始化失败: {str(e)}") + utils.logger.error(f"[main] 数据库初始化失败: {str(e)}") if __name__ == '__main__': - asyncio.get_event_loop().run_until_complete(init_table_schema()) + asyncio.get_event_loop().run_until_complete(main()) diff --git a/main.py b/main.py index 7292701..707f02c 100644 --- a/main.py +++ b/main.py @@ -50,13 +50,13 @@ async def main(): await cmd_arg.parse_cmd() # init db - if config.SAVE_DATA_OPTION == "db": + if config.SAVE_DATA_OPTION in ["db", "sqlite"]: await db.init_db() crawler = CrawlerFactory.create_crawler(platform=config.PLATFORM) await crawler.start() - if config.SAVE_DATA_OPTION == "db": + if config.SAVE_DATA_OPTION in ["db", "sqlite"]: await db.close()