336 lines
12 KiB
Python
336 lines
12 KiB
Python
import sqlite3
|
||
import xml.etree.ElementTree as ET
|
||
from dataclasses import dataclass, field
|
||
from typing import List, Tuple, Optional, Dict, Any
|
||
import re
|
||
import chardet
|
||
import json
|
||
import os
|
||
|
||
|
||
@dataclass
class TableConfig:
    """Settings describing how one XML file is imported into one SQLite table.

    Instances are built by ``load_configs`` from a JSON file; ``data`` holds the
    row tuples extracted from the XML before they are inserted.
    """

    db_name: str  # SQLite database file name
    enabled: bool  # whether this configuration is active
    table_name: str
    data_source: str  # data source: uft30, uf20 or dts
    xml_file_path: str
    xml_root: str
    table_create_sql: str
    table_insert_sql: str
    xml_para_name: List[str] = field(default_factory=list)
    para_name: List[str] = field(default_factory=list)
    xml_childs: List[str] = field(default_factory=list)
    xml_children: str = field(default="")
    xml_sub_children: str = field(default="")
    # '0': element text, '1': attribute values, '2': attributes of sub-elements
    xml_get_type: str = field(default='0')
    data: List[Tuple] = field(default_factory=list)
|
||
|
||
|
||
class XMLtoSQLiteImporter:
    """Parse XML files described by TableConfig objects and load the rows into SQLite.

    NOTE(review): xml.etree.ElementTree is not hardened against maliciously
    constructed XML; the input files are presumably trusted local sources — confirm.
    """

    # Base directories prepended to ``config.xml_file_path`` for the 'uf20' and
    # 'uft30' data sources (plain string concatenation, so keep the trailing separator).
    DEFAULT_UF20_PATH = 'D:\\Sources\\经纪业务运营平台V21\\'
    DEFAULT_UFT30_PATH = 'F:\\sesCode\\'

    # An XML declaration at the very start of the text (leading whitespace allowed).
    # The mandatory whitespace after "xml" stops processing instructions such as
    # <?xml-stylesheet ...?> from matching, and anchoring the match keeps a
    # "<?xml ...?>" inside a comment or CDATA section from truncating the document.
    _XML_DECL_RE = re.compile(r'\s*(<\?xml\s.*?\?>)', re.DOTALL)
    _ENCODING_ATTR_RE = re.compile(r'encoding\s*=\s*["\'][^"\']*["\']')
    _UTF8_DECLARATION = '<?xml version="1.0" encoding="UTF-8"?>'

    def __init__(self):
        # Database file name -> open connection, shared by configs that target the same file.
        self.conn_cache: Dict[str, sqlite3.Connection] = {}

    def get_db_connection(self, db_name: str) -> sqlite3.Connection:
        """Return a cached connection to ``db_name``, opening it in WAL mode on first use.

        Missing parent directories of ``db_name`` are created.

        Raises:
            OSError: the parent directory cannot be created.
            sqlite3.Error: the database cannot be opened or configured.
        """
        conn = self.conn_cache.get(db_name)
        if conn is not None:
            return conn

        db_dir = os.path.dirname(db_name)
        if db_dir:
            # exist_ok avoids the race between an existence check and the creation.
            os.makedirs(db_dir, exist_ok=True)

        conn = sqlite3.connect(db_name)
        try:
            conn.execute("PRAGMA journal_mode = WAL;")
        except sqlite3.Error:
            # Neither leak nor cache a half-configured connection.
            conn.close()
            raise
        self.conn_cache[db_name] = conn
        return conn

    def close_db_connections(self):
        """Close every cached connection and empty the cache."""
        # Swap the cache out first so it is cleared even if a close() raises.
        connections, self.conn_cache = self.conn_cache, {}
        for conn in connections.values():
            conn.close()

    def create_table(self, conn: sqlite3.Connection, config: 'TableConfig'):
        """Execute ``config.table_create_sql`` on ``conn`` inside a transaction.

        Raises:
            sqlite3.Error: the statement fails (e.g. the table already exists and
                the SQL has no IF NOT EXISTS clause).
        """
        with conn:
            conn.execute(config.table_create_sql)
        print(f"已创建表: {config.table_name}")

    def insert_data(self, conn: sqlite3.Connection, config: 'TableConfig'):
        """Insert every tuple in ``config.data`` using ``config.table_insert_sql``.

        Does nothing except print a warning when ``config.data`` is empty.

        Raises:
            sqlite3.Error: the insert fails (e.g. placeholder/tuple length mismatch).
        """
        if not config.data:
            print(f"警告: 表 {config.table_name} 无数据可导入")
            return

        # One transaction for the whole batch: committed on success, rolled back on error.
        with conn:
            conn.executemany(config.table_insert_sql, config.data)
        print(f"成功导入 {len(config.data)} 条记录到表 {config.table_name}")

    def get_element_text(self, element, tag: str) -> str:
        """Return the stripped text of the first sub-element matching ``tag``.

        Returns "" when there is no such sub-element or it has no text.
        """
        child = element.find(tag)
        return child.text.strip() if child is not None and child.text else ""

    def get_child_node(self, root, path: List[str]) -> Optional[ET.Element]:
        """Walk ``path`` one tag at a time from ``root`` and return the node reached.

        Each step takes the first match of ``find``. Prints a warning and returns
        None as soon as a step has no match.
        """
        node = root
        for tag in path:
            node = node.find(tag)
            if node is None:
                print(f"警告: 找不到XML节点: {tag}")
                return None
        return node

    @staticmethod
    def _decode_bytes(raw_data: bytes, encoding: Optional[str]) -> str:
        """Decode ``raw_data``: try ``encoding`` (or UTF-8), then GBK, then lossy UTF-8."""
        for candidate in (encoding or 'utf-8', 'gbk'):
            try:
                return raw_data.decode(candidate)
            except (UnicodeDecodeError, LookupError):
                # LookupError: the detector may report a codec Python does not
                # provide (e.g. 'EUC-TW'); fall through to the next candidate.
                continue
        # Last resort: undecodable bytes are dropped.
        return raw_data.decode('utf-8', errors='ignore')

    @classmethod
    def _normalize_xml_declaration(cls, content: str) -> str:
        """Make ``content`` start with an XML declaration that says UTF-8 (or none).

        A leading BOM and whitespace before an existing declaration are removed,
        and its ``encoding`` attribute is rewritten to UTF-8. Without a
        declaration, a UTF-8 one is prepended.
        """
        content = content.lstrip('\ufeff')
        match = cls._XML_DECL_RE.match(content)
        if match is None:
            return cls._UTF8_DECLARATION + content
        declaration = cls._ENCODING_ATTR_RE.sub('encoding="UTF-8"', match.group(1))
        return declaration + content[match.end():]

    def convert_xml_encoding(self, file_path: str) -> str:
        """Read an XML file of unknown encoding and return it as a str declared as UTF-8.

        The encoding is guessed with ``chardet``. The declaration is rewritten
        because the text is handed to ``ET.fromstring`` already decoded.

        Raises:
            FileNotFoundError: ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"XML文件未找到: {file_path}")

        with open(file_path, 'rb') as f:
            raw_data = f.read()

        detected = chardet.detect(raw_data)['encoding']
        content = self._decode_bytes(raw_data, detected)
        return self._normalize_xml_declaration(content)

    def _extract_rows(self, start_node, config: 'TableConfig') -> Optional[List[Tuple]]:
        """Extract row tuples below ``start_node`` per ``config.xml_get_type``.

        Returns None (after printing an error) for an unusable configuration.
        """
        # JSON configs may give the mode as a number (e.g. 0); compare as text.
        mode = str(config.xml_get_type).strip()
        params = config.xml_para_name

        if mode == '0':
            # Mode 0: text content of each child element named in xml_para_name.
            return [tuple(self.get_element_text(node, param) for param in params)
                    for node in start_node.findall(config.xml_children)]

        if mode == '1':
            # Mode 1: attribute values of each matched element ("" when absent).
            return [tuple(node.get(param, "").strip() for param in params)
                    for node in start_node.findall(config.xml_children)]

        if mode == '2':
            # Mode 2: attributes of sub-children nested under each matched child.
            if not config.xml_sub_children:
                print("错误: 模式2需要设置xml_sub_children参数")
                return None
            return [tuple(node.get(param, "").strip() for param in params)
                    for parent_node in start_node.findall(config.xml_children)
                    for node in parent_node.findall(config.xml_sub_children)]

        print(f"错误: 不支持的xml_get_type: {config.xml_get_type!r} (应为 '0', '1' 或 '2')")
        return None

    def parse_xml(self, config: 'TableConfig', uf20_path=DEFAULT_UF20_PATH, uft30_path=DEFAULT_UFT30_PATH) -> bool:
        """Parse the XML file for ``config`` and append the extracted rows to ``config.data``.

        For data_source 'uf20' / 'uft30' the corresponding base path is prepended
        to ``config.xml_file_path``. Rows are appended only after extraction has
        fully succeeded, so a failure never leaves partial data behind.

        Returns:
            True on success, False on any error (errors are printed, not raised).
        """
        try:
            xml_file_path = config.xml_file_path
            print('data_source:', config.data_source)
            if config.data_source == 'uf20':
                xml_file_path = uf20_path + xml_file_path
            elif config.data_source == 'uft30':
                xml_file_path = uft30_path + xml_file_path
            print('xmlfilepath', xml_file_path)

            root = ET.fromstring(self.convert_xml_encoding(xml_file_path))

            if root.tag != config.xml_root:
                print(f"错误: XML根节点不匹配! 期望: {config.xml_root}, 实际: {root.tag}")
                return False

            start_node = root
            if config.xml_childs:
                start_node = self.get_child_node(root, config.xml_childs)
                if start_node is None:
                    print(f"错误: 找不到XML路径: {config.xml_childs}")
                    return False

            rows = self._extract_rows(start_node, config)
            if rows is None:
                return False

            config.data.extend(rows)
            print(f"从XML解析出 {len(config.data)} 条记录")
            return True

        except Exception as e:
            # Boundary handler: report and signal failure to the caller.
            print(f"处理XML时出错({config.xml_file_path}): {str(e)}")
            import traceback
            traceback.print_exc()
            return False

    def import_config(self, config: 'TableConfig', uf20_path=DEFAULT_UF20_PATH, uft30_path=DEFAULT_UFT30_PATH) -> bool:
        """Run the full import for one config: connect, create table, parse XML, insert.

        Returns:
            True when the rows were inserted (or there were none to insert);
            False when the config is disabled or any step failed.
        """
        if not config.enabled:
            print(f"跳过: {config.table_name} (配置已禁用)")
            return False

        try:
            conn = self.get_db_connection(config.db_name)
            self.create_table(conn, config)
            if not self.parse_xml(config, uf20_path, uft30_path):
                return False
            self.insert_data(conn, config)
            return True

        except sqlite3.Error as e:
            print(f"数据库错误({config.table_name}): {str(e)}")
            return False
        except OSError as e:
            # e.g. the database directory could not be created; keep it local to this config.
            print(f"文件系统错误({config.table_name}): {e}")
            return False
|
||
|
||
|
||
def load_configs(config_file: str) -> List[TableConfig]:
    """Load table import configurations from a JSON file.

    The file must contain a JSON array of objects. ``table_name``,
    ``xml_file_path``, ``xml_root``, ``table_create_sql`` and ``table_insert_sql``
    are required. ``db_name`` defaults to "default.db", ``enabled`` to True,
    ``xml_get_type`` to '0' (always stored as a string), the list fields to []
    and the other string fields to "".

    Raises:
        FileNotFoundError: ``config_file`` does not exist.
        ValueError: the top-level JSON value is not an array, or an entry is not an object.
        KeyError: an entry lacks a required field.
    """
    if not os.path.exists(config_file):
        raise FileNotFoundError(f"配置文件未找到: {config_file}")

    with open(config_file, 'r', encoding='utf-8') as f:
        config_data = json.load(f)

    if not isinstance(config_data, list):
        raise ValueError(f"配置文件顶层必须是数组: {config_file}")

    required_keys = ('table_name', 'xml_file_path', 'xml_root', 'table_create_sql', 'table_insert_sql')

    configs = []
    for index, item in enumerate(config_data):
        if not isinstance(item, dict):
            raise ValueError(f"配置项 #{index} 不是对象: {item!r}")
        missing = [key for key in required_keys if key not in item]
        if missing:
            raise KeyError(f"配置项 #{index} 缺少必填字段: {', '.join(missing)}")

        config = TableConfig(
            db_name=item.get('db_name', "default.db"),
            enabled=item.get('enabled', True),
            table_name=item['table_name'],
            xml_file_path=item['xml_file_path'],
            xml_root=item['xml_root'],
            table_create_sql=item['table_create_sql'],
            table_insert_sql=item['table_insert_sql'],
            xml_para_name=item.get('xml_para_name', []),
            para_name=item.get('para_name', []),
            xml_childs=item.get('xml_childs', []),
            xml_children=item.get('xml_children', ''),
            xml_sub_children=item.get('xml_sub_children', ''),
            # JSON may give the mode as a number (e.g. 0); the importer compares strings.
            xml_get_type=str(item.get('xml_get_type', '0')),
            data_source=item.get('data_source', ''),
        )

        # Column names and XML parameter names should correspond one-to-one.
        if config.xml_para_name and config.para_name:
            if len(config.xml_para_name) != len(config.para_name):
                print(f"警告: 表 {config.table_name} 的 xml_para_name 和 para_name 长度不一致")

        configs.append(config)

    print(f"已加载 {len(configs)} 个配置")
    return configs
|
||
|
||
import json
|
||
|
||
def update_table_enabled(json_file_path, table_name, enable_value):
    """
    Set ``enabled`` on the first entry whose ``table_name`` matches, and save the file.

    The new content is written to a temporary file in the same directory and
    then moved over the original with ``os.replace``. A failure while writing
    therefore leaves the existing configuration untouched.

    :param json_file_path: path of the JSON configuration file (an array of objects)
    :param table_name: table name to look for
    :param enable_value: value to store in ``enabled`` (normally True/False)
    :return: True if the entry was found and saved; False if it was not found
             or any error occurred (the error is printed)
    """
    import shutil
    import tempfile

    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)

        target = next((item for item in config_data if item.get("table_name") == table_name), None)
        if target is None:
            print(f"未找到 table_name 为 '{table_name}' 的配置项")
            return False
        target["enabled"] = enable_value

        target_dir = os.path.dirname(os.path.abspath(json_file_path))
        fd, tmp_path = tempfile.mkstemp(dir=target_dir, prefix='.' + os.path.basename(json_file_path) + '.', suffix='.tmp')
        try:
            with os.fdopen(fd, 'w', encoding='utf-8') as f:
                json.dump(config_data, f, ensure_ascii=False, indent=4)
            # mkstemp creates the file with mode 0600; keep the original permissions.
            shutil.copymode(json_file_path, tmp_path)
            os.replace(tmp_path, json_file_path)
        finally:
            # Only present if something failed before the replace.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        print(f"成功将 {table_name} 的 enabled 设置为 {enable_value}")
        return True

    except Exception as e:
        print(f"操作失败:{e}")
        return False
|
||
|
||
def import_xml_from_config(uf20_path='D:\\Sources\\经纪业务运营平台V21\\', uft30_path='F:\\sesCode\\', config_file="import_config.json"):
    """Import every enabled table listed in a JSON configuration file.

    Each configuration is imported independently. An unexpected error in one
    is printed and counted as a failure, and the remaining ones still run.
    All database connections are closed at the end.

    Args:
        uf20_path: base path prepended to xml_file_path for 'uf20' configs.
        uft30_path: base path prepended to xml_file_path for 'uft30' configs.
        config_file: path of the JSON configuration file.
    """
    import traceback

    importer = XMLtoSQLiteImporter()

    try:
        configs = load_configs(config_file)

        for config in configs:
            if not config.enabled:
                continue
            print(f"\n== 开始导入表: {config.table_name} (数据库: {config.db_name}) ==")
            try:
                success = importer.import_config(config, uf20_path, uft30_path)
            except Exception as e:
                # Isolate failures so one bad configuration does not abort the batch.
                print(f"导入表 {config.table_name} 时发生错误: {e}")
                traceback.print_exc()
                success = False
            finally:
                # Release parsed rows whether or not the import succeeded.
                config.data = []

            status = "成功" if success else "失败"
            print(f"== 表 {config.table_name} 导入{status} ==")

    except Exception as e:
        print(f"导入过程中发生错误: {str(e)}")
        traceback.print_exc()
    finally:
        importer.close_db_connections()
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the import with the default source paths.
    import_xml_from_config()
|