mirror of
https://github.com/NanmiCoder/MediaCrawler.git
synced 2026-02-24 17:10:48 +08:00
fix: 修复SQLite数据库初始化和关闭逻辑\n\n- 更新 main.py: 修复数据库初始化条件,支持sqlite选项\n- 更新 db.py: 添加SQLite数据库初始化和关闭支持
This commit is contained in:
144
db.py
144
db.py
@@ -22,6 +22,7 @@ import aiomysql
|
||||
|
||||
import config
|
||||
from async_db import AsyncMysqlDB
|
||||
from async_sqlite_db import AsyncSqliteDB
|
||||
from tools import utils
|
||||
from var import db_conn_pool_var, media_crawler_db_var
|
||||
|
||||
@@ -33,11 +34,11 @@ async def init_mediacrawler_db():
|
||||
|
||||
"""
|
||||
pool = await aiomysql.create_pool(
|
||||
host=config.RELATION_DB_HOST,
|
||||
port=config.RELATION_DB_PORT,
|
||||
user=config.RELATION_DB_USER,
|
||||
password=config.RELATION_DB_PWD,
|
||||
db=config.RELATION_DB_NAME,
|
||||
host=config.MYSQL_DB_HOST,
|
||||
port=config.MYSQL_DB_PORT,
|
||||
user=config.MYSQL_DB_USER,
|
||||
password=config.MYSQL_DB_PWD,
|
||||
db=config.MYSQL_DB_NAME,
|
||||
autocommit=True,
|
||||
)
|
||||
async_db_obj = AsyncMysqlDB(pool)
|
||||
@@ -47,6 +48,18 @@ async def init_mediacrawler_db():
|
||||
media_crawler_db_var.set(async_db_obj)
|
||||
|
||||
|
||||
async def init_sqlite_db():
|
||||
"""
|
||||
初始化SQLite数据库对象,并将该对象塞给media_crawler_db_var上下文变量
|
||||
Returns:
|
||||
|
||||
"""
|
||||
async_db_obj = AsyncSqliteDB(config.SQLITE_DB_PATH)
|
||||
|
||||
# 将SQLite数据库对象放到上下文变量中
|
||||
media_crawler_db_var.set(async_db_obj)
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""
|
||||
初始化db连接池
|
||||
@@ -54,37 +67,124 @@ async def init_db():
|
||||
|
||||
"""
|
||||
utils.logger.info("[init_db] start init mediacrawler db connect object")
|
||||
await init_mediacrawler_db()
|
||||
utils.logger.info("[init_db] end init mediacrawler db connect object")
|
||||
if config.SAVE_DATA_OPTION == "sqlite":
|
||||
await init_sqlite_db()
|
||||
utils.logger.info("[init_db] end init sqlite db connect object")
|
||||
else:
|
||||
await init_mediacrawler_db()
|
||||
utils.logger.info("[init_db] end init mysql db connect object")
|
||||
|
||||
|
||||
async def close():
|
||||
"""
|
||||
关闭连接池
|
||||
关闭数据库连接
|
||||
Returns:
|
||||
|
||||
"""
|
||||
utils.logger.info("[close] close mediacrawler db pool")
|
||||
db_pool: aiomysql.Pool = db_conn_pool_var.get()
|
||||
if db_pool is not None:
|
||||
db_pool.close()
|
||||
utils.logger.info("[close] close mediacrawler db connection")
|
||||
if config.SAVE_DATA_OPTION == "sqlite":
|
||||
# SQLite数据库连接会在AsyncSqliteDB对象销毁时自动关闭
|
||||
utils.logger.info("[close] sqlite db connection will be closed automatically")
|
||||
else:
|
||||
# MySQL连接池关闭
|
||||
db_pool: aiomysql.Pool = db_conn_pool_var.get()
|
||||
if db_pool is not None:
|
||||
db_pool.close()
|
||||
utils.logger.info("[close] mysql db pool closed")
|
||||
|
||||
|
||||
async def init_table_schema():
|
||||
async def init_table_schema(db_type: str = None):
|
||||
"""
|
||||
用来初始化数据库表结构,请在第一次需要创建表结构的时候使用,多次执行该函数会将已有的表以及数据全部删除
|
||||
Args:
|
||||
db_type: 数据库类型,可选值为 'sqlite' 或 'mysql',如果不指定则使用配置文件中的设置
|
||||
Returns:
|
||||
|
||||
"""
|
||||
utils.logger.info("[init_table_schema] begin init mysql table schema ...")
|
||||
await init_mediacrawler_db()
|
||||
async_db_obj: AsyncMysqlDB = media_crawler_db_var.get()
|
||||
async with aiofiles.open("schema/tables.sql", mode="r", encoding="utf-8") as f:
|
||||
schema_sql = await f.read()
|
||||
await async_db_obj.execute(schema_sql)
|
||||
utils.logger.info("[init_table_schema] mediacrawler table schema init successful")
|
||||
await close()
|
||||
# 如果没有指定数据库类型,则使用配置文件中的设置
|
||||
if db_type is None:
|
||||
db_type = config.SAVE_DATA_OPTION
|
||||
|
||||
if db_type == "sqlite":
|
||||
utils.logger.info("[init_table_schema] begin init sqlite table schema ...")
|
||||
await init_sqlite_db()
|
||||
async_db_obj: AsyncSqliteDB = media_crawler_db_var.get()
|
||||
async with aiofiles.open("schema/sqlite_tables.sql", mode="r", encoding="utf-8") as f:
|
||||
schema_sql = await f.read()
|
||||
await async_db_obj.executescript(schema_sql)
|
||||
utils.logger.info("[init_table_schema] sqlite table schema init successful")
|
||||
elif db_type == "mysql":
|
||||
utils.logger.info("[init_table_schema] begin init mysql table schema ...")
|
||||
await init_mediacrawler_db()
|
||||
async_db_obj: AsyncMysqlDB = media_crawler_db_var.get()
|
||||
async with aiofiles.open("schema/tables.sql", mode="r", encoding="utf-8") as f:
|
||||
schema_sql = await f.read()
|
||||
await async_db_obj.execute(schema_sql)
|
||||
utils.logger.info("[init_table_schema] mysql table schema init successful")
|
||||
await close()
|
||||
else:
|
||||
utils.logger.error(f"[init_table_schema] 不支持的数据库类型: {db_type}")
|
||||
raise ValueError(f"不支持的数据库类型: {db_type},支持的类型: sqlite, mysql")
|
||||
|
||||
|
||||
def show_database_options():
|
||||
"""
|
||||
显示支持的数据库选项
|
||||
"""
|
||||
print("\n=== MediaCrawler 数据库初始化工具 ===")
|
||||
print("支持的数据库类型:")
|
||||
print("1. sqlite - SQLite 数据库 (轻量级,无需额外配置)")
|
||||
print("2. mysql - MySQL 数据库 (需要配置数据库连接信息)")
|
||||
print("3. config - 使用配置文件中的设置")
|
||||
print("4. exit - 退出程序")
|
||||
print("="*50)
|
||||
|
||||
|
||||
def get_user_choice():
|
||||
"""
|
||||
获取用户选择的数据库类型
|
||||
Returns:
|
||||
str: 用户选择的数据库类型
|
||||
"""
|
||||
while True:
|
||||
choice = input("请输入数据库类型 (sqlite/mysql/config/exit): ").strip().lower()
|
||||
|
||||
if choice in ['sqlite', 'mysql', 'config', 'exit']:
|
||||
return choice
|
||||
else:
|
||||
print("❌ 无效的选择,请输入: sqlite, mysql, config 或 exit")
|
||||
|
||||
|
||||
async def main():
|
||||
"""
|
||||
主函数,处理用户交互和数据库初始化
|
||||
"""
|
||||
try:
|
||||
show_database_options()
|
||||
|
||||
while True:
|
||||
choice = get_user_choice()
|
||||
|
||||
if choice == 'exit':
|
||||
print("👋 程序已退出")
|
||||
break
|
||||
elif choice == 'config':
|
||||
print(f"📋 使用配置文件中的设置: {config.SAVE_DATA_OPTION}")
|
||||
await init_table_schema()
|
||||
print("✅ 数据库表结构初始化完成!")
|
||||
break
|
||||
else:
|
||||
print(f"🚀 开始初始化 {choice.upper()} 数据库...")
|
||||
await init_table_schema(choice)
|
||||
print("✅ 数据库表结构初始化完成!")
|
||||
break
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n⚠️ 用户中断操作")
|
||||
except Exception as e:
|
||||
print(f"\n❌ 初始化失败: {str(e)}")
|
||||
utils.logger.error(f"[main] 数据库初始化失败: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.get_event_loop().run_until_complete(init_table_schema())
|
||||
asyncio.get_event_loop().run_until_complete(main())
|
||||
|
||||
4
main.py
4
main.py
@@ -50,13 +50,13 @@ async def main():
|
||||
await cmd_arg.parse_cmd()
|
||||
|
||||
# init db
|
||||
if config.SAVE_DATA_OPTION == "db":
|
||||
if config.SAVE_DATA_OPTION in ["db", "sqlite"]:
|
||||
await db.init_db()
|
||||
|
||||
crawler = CrawlerFactory.create_crawler(platform=config.PLATFORM)
|
||||
await crawler.start()
|
||||
|
||||
if config.SAVE_DATA_OPTION == "db":
|
||||
if config.SAVE_DATA_OPTION in ["db", "sqlite"]:
|
||||
await db.close()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user