diff --git a/app/DataBase/__init__.py b/app/DataBase/__init__.py
index 5c44b10..9a42506 100644
--- a/app/DataBase/__init__.py
+++ b/app/DataBase/__init__.py
@@ -13,6 +13,7 @@ from .media_msg import MediaMsg
 from .misc import Misc
 from .msg import Msg
 from .msg import MsgType
+from .db_pool import db_pool, close_db_pool
 
 misc_db = Misc()
 msg_db = Msg()
@@ -22,14 +23,18 @@ media_msg_db = MediaMsg()
 
 
 def close_db():
+    """Close all database connections."""
     misc_db.close()
     msg_db.close()
     micro_msg_db.close()
     hard_link_db.close()
     media_msg_db.close()
+    # Close the database connection pool
+    close_db_pool()
 
 
 def init_db():
+    """Initialize all database connections."""
    misc_db.init_database()
     msg_db.init_database()
     micro_msg_db.init_database()
@@ -37,4 +42,4 @@ def init_db():
     media_msg_db.init_database()
 
 
-__all__ = ['misc_db', 'micro_msg_db', 'msg_db', 'hard_link_db', 'MsgType', "media_msg_db", "close_db"]
+__all__ = ['misc_db', 'micro_msg_db', 'msg_db', 'hard_link_db', 'MsgType', "media_msg_db", "close_db", "db_pool"]
diff --git a/app/DataBase/db_pool.py b/app/DataBase/db_pool.py
new file mode 100644
index 0000000..faf81ac
--- /dev/null
+++ b/app/DataBase/db_pool.py
@@ -0,0 +1,256 @@
+import os
+import sqlite3
+import threading
+import queue
+import time
+from typing import Dict, Optional, List, Tuple
+
+
+class DatabaseConnectionPool:
+    """
+    SQLite connection pool that manages connections for multiple database
+    files and reduces the overhead of repeatedly creating and closing them.
+    """
+    _instance = None
+    _lock = threading.Lock()
+
+    def __new__(cls, *args, **kwargs):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super(DatabaseConnectionPool, cls).__new__(cls)
+                cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self, max_connections=5, timeout=5):
+        # Make sure the singleton is only initialized once
+        if self._initialized:
+            return
+
+        self._initialized = True
+        self.max_connections = max_connections
+        self.timeout = timeout
+        self.pools: Dict[str, queue.Queue] = {}
+        self.in_use: Dict[str, Dict[sqlite3.Connection, threading.Thread]] = {}
+        self.connection_locks: Dict[str, threading.Lock] = {}
+
+    def _create_connection(self, db_path: str) -> sqlite3.Connection:
+        """Create a new database connection."""
+        if not os.path.exists(db_path):
+            raise FileNotFoundError(f"Database file does not exist: {db_path}")
+
+        conn = sqlite3.connect(db_path, check_same_thread=False)
+        # Enforce foreign key constraints
+        conn.execute('PRAGMA foreign_keys = ON')
+        # Balance durability and speed for write confirmation
+        conn.execute('PRAGMA synchronous = NORMAL')
+        # Use write-ahead logging to improve write performance
+        conn.execute('PRAGMA journal_mode = WAL')
+        # Enlarge the page cache
+        conn.execute('PRAGMA cache_size = 10000')
+        return conn
+
+    def _get_pool(self, db_path: str) -> queue.Queue:
+        """Get or create the connection pool for the given database."""
+        if db_path not in self.pools:
+            with self._lock:
+                if db_path not in self.pools:
+                    self.pools[db_path] = queue.Queue(maxsize=self.max_connections)
+                    self.in_use[db_path] = {}
+                    self.connection_locks[db_path] = threading.Lock()
+
+                    # Pre-create a couple of connections
+                    for _ in range(min(2, self.max_connections)):
+                        try:
+                            conn = self._create_connection(db_path)
+                            self.pools[db_path].put(conn)
+                        except Exception as e:
+                            print(f"Failed to pre-create connection: {e}")
+
+        return self.pools[db_path]
+
+    def get_connection(self, db_path: str) -> sqlite3.Connection:
+        """
+        Get a database connection from the pool.
+
+        Args:
+            db_path: path to the database file
+
+        Returns:
+            sqlite3.Connection: a database connection object
+
+        Raises:
+            TimeoutError: no connection could be obtained before the timeout
+        """
+        pool = self._get_pool(db_path)
+
+        # Try to take a connection from the pool
+        try:
+            conn = pool.get(block=True, timeout=self.timeout)
+        except queue.Empty:
+            # If the pool is exhausted but fewer than max_connections are in
+            # use, create a new connection; otherwise give up.
+            with self.connection_locks[db_path]:
+                if len(self.in_use[db_path]) < self.max_connections:
+                    conn = self._create_connection(db_path)
+                else:
+                    raise TimeoutError(f"Unable to get a database connection, pool is full: {db_path}")
+
+        # Record which thread is using the connection
+        with self.connection_locks[db_path]:
+            self.in_use[db_path][conn] = threading.current_thread()
+
+        return conn
+
+    def release_connection(self, db_path: str, conn: sqlite3.Connection):
+        """
+        Return a database connection to the pool.
+
+        Args:
+            db_path: path to the database file
+            conn: the connection to release
+        """
+        if db_path not in self.pools:
+            conn.close()
+            return
+
+        with self.connection_locks[db_path]:
+            if conn in self.in_use[db_path]:
+                del self.in_use[db_path][conn]
+            try:
+                # Put the connection back into the pool
+                self.pools[db_path].put(conn, block=False)
+            except queue.Full:
+                # The pool is already full, close the surplus connection
+                conn.close()
+
+    def execute_batch(self, db_path: str, sql: str, params_list: List[tuple], commit=True) -> List[Optional[Tuple]]:
+        """
+        Execute the same SQL statement many times in one transaction.
+
+        Args:
+            db_path: path to the database file
+            sql: the SQL statement
+            params_list: list of parameter tuples, one per execution
+            commit: whether to commit the transaction automatically
+
+        Returns:
+            list: the result of each execution
+        """
+        conn = None
+        results = []
+
+        try:
+            conn = self.get_connection(db_path)
+            cursor = conn.cursor()
+
+            # Start the transaction
+            if commit:
+                conn.execute("BEGIN TRANSACTION")
+
+            # Execute the statement for every parameter tuple
+            for params in params_list:
+                cursor.execute(sql, params)
+                if cursor.description:  # the statement returned rows
+                    results.append(cursor.fetchall())
+                else:
+                    results.append(None)
+
+            # Commit the transaction
+            if commit:
+                conn.commit()
+
+            return results
+        except Exception as e:
+            if conn and commit:
+                conn.rollback()
+            raise e
+        finally:
+            if conn:
+                self.release_connection(db_path, conn)
+
+    def execute_query(self, db_path: str, sql: str, params=None) -> List[Tuple]:
+        """
+        Execute a SELECT statement.
+
+        Args:
+            db_path: path to the database file
+            sql: the SQL query
+            params: query parameters
+
+        Returns:
+            list: the query results
+        """
+        conn = None
+        try:
+            conn = self.get_connection(db_path)
+            cursor = conn.cursor()
+
+            if params:
+                cursor.execute(sql, params)
+            else:
+                cursor.execute(sql)
+
+            return cursor.fetchall()
+        finally:
+            if conn:
+                self.release_connection(db_path, conn)
+
+    def execute_update(self, db_path: str, sql: str, params=None) -> int:
+        """
+        Execute an INSERT/UPDATE/DELETE statement.
+
+        Args:
+            db_path: path to the database file
+            sql: the SQL statement
+            params: statement parameters
+
+        Returns:
+            int: the number of affected rows
+        """
+        conn = None
+        try:
+            conn = self.get_connection(db_path)
+            cursor = conn.cursor()
+
+            if params:
+                cursor.execute(sql, params)
+            else:
+                cursor.execute(sql)
+
+            conn.commit()
+            return cursor.rowcount
+        except Exception as e:
+            if conn:
+                conn.rollback()
+            raise e
+        finally:
+            if conn:
+                self.release_connection(db_path, conn)
+
+    def close_all(self):
+        """Close every connection managed by the pool."""
+        with self._lock:
+            for db_path, pool in self.pools.items():
+                # Close all idle connections
+                while not pool.empty():
+                    try:
+                        conn = pool.get(block=False)
+                        conn.close()
+                    except queue.Empty:
+                        break
+
+                # Close all connections that are still in use
+                with self.connection_locks[db_path]:
+                    for conn in list(self.in_use[db_path].keys()):
+                        try:
+                            conn.close()
+                        except Exception:
+                            pass
+                    self.in_use[db_path].clear()
+
+            # Clear the pools
+            self.pools.clear()
+
+
+# Global connection pool instance
+db_pool = DatabaseConnectionPool()
+
+
+def close_db_pool():
+    """Close every connection held by the connection pool."""
+    db_pool.close_all()
\ No newline at end of file
diff --git a/app/DataBase/msg.py b/app/DataBase/msg.py
index b56e671..22d60bb 100644
--- a/app/DataBase/msg.py
+++ b/app/DataBase/msg.py
@@ -5,8 +5,10 @@ import threading
 import traceback
+import time
 from collections import defaultdict
 from datetime import datetime, date
-from typing import Tuple
+from typing import Tuple, List, Optional, Dict, Any
 
+from app.DataBase.db_pool import db_pool
 from app.log import logger
 from app.util.compress_content import parser_reply
 from app.util.protocbuf.msg_pb2 import MessageBytesExtra
@@ -140,8 +141,6 @@ class MsgType:
 
 class Msg:
     def __init__(self):
-        self.DB = None
-        self.cursor = None
         self.open_flag = False
         self.init_database()
 
@@ -151,9 +150,6 @@ class Msg:
         if path:
             db_path = path
         if os.path.exists(db_path):
-            self.DB = sqlite3.connect(db_path, check_same_thread=False)
-            # '''创建游标'''
-            self.cursor = self.DB.cursor()
             self.open_flag = True
             if lock.locked():
                 lock.release()
@@ -200,48 +196,137 @@ class Msg:
             a[10]: BytesExtra,
             a[11]: CompressContent,
             a[12]: DisplayContent,
-            a[13]: 联系人的类(如果是群聊就有,不是的话没有这个字段)
         """
         if not self.open_flag:
-            return None
-        if time_range:
-            start_time, end_time = convert_to_timestamp(time_range)
-        sql = f'''
-            select localId,TalkerId,Type,SubType,IsSender,CreateTime,Status,StrContent,strftime('%Y-%m-%d %H:%M:%S',CreateTime,'unixepoch','localtime') as StrTime,MsgSvrID,BytesExtra,CompressContent,DisplayContent
-            from MSG
-            where StrTalker=?
-            {'AND CreateTime>' + str(start_time) + ' AND CreateTime<' + str(end_time) if time_range else ''}
-            order by CreateTime
+            return []
+
+        # (0, 0) disables the time filters in the query below
+        begin_time, end_time = convert_to_timestamp(time_range) if time_range else (0, 0)
+
+        sql = '''
+            SELECT
+                localId,
+                TalkerId,
+                Type,
+                SubType,
+                IsSender,
+                CreateTime,
+                Status,
+                StrContent,
+                strftime('%Y-%m-%d %H:%M:%S', datetime(CreateTime, 'unixepoch', 'localtime')) AS StrTime,
+                MsgSvrID,
+                BytesExtra,
+                CompressContent,
+                DisplayContent
+            FROM MSG
+            WHERE StrTalker = ?
+              AND (? = 0 OR CreateTime >= ?)
+              AND (? = 0 OR CreateTime <= ?)
+            ORDER BY CreateTime
         '''
+
+        params = (username_, begin_time, begin_time, end_time, end_time)
+
         try:
-            lock.acquire(True)
-            self.cursor.execute(sql, [username_])
-            result = self.cursor.fetchall()
-        finally:
-            lock.release()
-        return parser_chatroom_message(result) if username_.__contains__('@chatroom') else result
-        # result.sort(key=lambda x: x[5])
-        # return self.add_sender(result)
+            results = db_pool.execute_query(db_path, sql, params)
+
+            # Parse sender information for group-chat messages
+            if results and '@chatroom' in username_:
+                results = parser_chatroom_message(results)
+
+            return results
+        except Exception as e:
+            logger.error(f"Failed to get chat records: {e}\n{traceback.format_exc()}")
+            return []
+
+    def batch_insert_messages(self, messages_data: List[Dict[str, Any]]) -> bool:
+        """
+        Insert a batch of messages.
+
+        Args:
+            messages_data: list of message dicts, one per message, holding all fields
+
+        Returns:
+            bool: whether the insert succeeded
+        """
+        if not self.open_flag or not messages_data:
+            return False
+
+        # Build the insert statement
+        sql = '''
+            INSERT INTO MSG (
+                MsgId, TalkerId, Type, SubType, IsSender, CreateTime,
+                Status, StrContent, StrTalker, MsgSvrID, BytesExtra,
+                CompressContent, DisplayContent
+            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+        '''
+
+        # Build the parameter list
+        params_list = [
+            (
+                msg.get('MsgId', ''),
+                msg.get('TalkerId', ''),
+                msg.get('Type', 0),
+                msg.get('SubType', 0),
+                msg.get('IsSender', 0),
+                msg.get('CreateTime', int(time.time())),
+                msg.get('Status', 0),
+                msg.get('StrContent', ''),
+                msg.get('StrTalker', ''),
+                msg.get('MsgSvrID', ''),
+                msg.get('BytesExtra', None),
+                msg.get('CompressContent', None),
+                msg.get('DisplayContent', None)
+            )
+            for msg in messages_data
+        ]
+
+        try:
+            db_pool.execute_batch(db_path, sql, params_list)
+            return True
+        except Exception as e:
+            logger.error(f"Failed to batch-insert messages: {e}\n{traceback.format_exc()}")
+            return False
 
     def get_messages_all(self, time_range=None):
-        if time_range:
-            start_time, end_time = convert_to_timestamp(time_range)
-        sql = f'''
-            select localId,TalkerId,Type,SubType,IsSender,CreateTime,Status,StrContent,strftime('%Y-%m-%d %H:%M:%S',CreateTime,'unixepoch','localtime') as StrTime,MsgSvrID,BytesExtra,StrTalker,Reserved1,CompressContent
-            from MSG
-            {'WHERE CreateTime>' + str(start_time) + ' AND CreateTime<' + str(end_time) if time_range else ''}
-            order by CreateTime
-            '''
+        """
+        Get all chat records.
+        @param time_range: optional (start, end) time range
+        @return: list of message rows
+        """
         if not self.open_flag:
-            return None
+            return []
+
+        # (0, 0) disables the time filters in the query below
+        begin_time, end_time = convert_to_timestamp(time_range) if time_range else (0, 0)
+
+        sql = '''
+            SELECT
+                localId,
+                TalkerId,
+                Type,
+                SubType,
+                IsSender,
+                CreateTime,
+                Status,
+                StrContent,
+                strftime('%Y-%m-%d %H:%M:%S', datetime(CreateTime, 'unixepoch', 'localtime')) AS StrTime,
+                MsgSvrID,
+                BytesExtra,
+                CompressContent,
+                DisplayContent,
+                StrTalker
+            FROM MSG
+            WHERE (? = 0 OR CreateTime >= ?)
+              AND (? = 0 OR CreateTime <= ?)
+            ORDER BY CreateTime
+        '''
+
+        params = (begin_time, begin_time, end_time, end_time)
+
         try:
-            lock.acquire(True)
-            self.cursor.execute(sql)
-            result = self.cursor.fetchall()
-        finally:
-            lock.release()
-        result.sort(key=lambda x: x[5])
-        return result
+            return db_pool.execute_query(db_path, sql, params)
+        except Exception as e:
+            logger.error(f"Failed to get all chat records: {e}\n{traceback.format_exc()}")
+            return []
 
     def get_messages_group_by_day(
         self,
@@ -865,13 +950,8 @@ class Msg:
         return sum_type_1 + sum_type_49
 
     def close(self):
-        if self.open_flag:
-            try:
-                lock.acquire(True)
-                self.open_flag = False
-                self.DB.close()
-            finally:
-                lock.release()
+        """Connections are managed by the pool, so no explicit close is needed here."""
+        self.open_flag = False
 
     def __del__(self):
         self.close()
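
The diff above wires Msg to the new pool but shows no call sites. Below is a minimal usage sketch of the pooled API (execute_query, execute_batch, and a manual get_connection/release_connection pair); the database path and the wxid_example talker are illustrative placeholders, not values taken from this change.

    # Hypothetical caller, assuming app.DataBase re-exports db_pool as in __init__.py above.
    from app.DataBase import db_pool

    MSG_DB = './app/Database/Msg/MSG.db'  # placeholder path to the decrypted MSG database

    # Parameterized read: the connection is borrowed from the pool and returned automatically.
    rows = db_pool.execute_query(
        MSG_DB,
        'SELECT localId, StrTalker, StrContent FROM MSG WHERE StrTalker = ?',
        ('wxid_example',),  # placeholder talker id
    )

    # Batch write: every parameter tuple is executed inside a single transaction.
    db_pool.execute_batch(
        MSG_DB,
        'INSERT INTO MSG (MsgSvrID, StrTalker, StrContent) VALUES (?, ?, ?)',
        [(1, 'wxid_example', 'hello'), (2, 'wxid_example', 'world')],
    )

    # Manual borrow/return, for code that needs the raw sqlite3 connection.
    conn = db_pool.get_connection(MSG_DB)
    try:
        cursor = conn.cursor()
        cursor.execute('SELECT COUNT(*) FROM MSG')
        print(cursor.fetchone()[0])
    finally:
        db_pool.release_connection(MSG_DB, conn)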