import sqlite3 from 'sqlite3';
import path from 'path';
import fs from 'fs';
import { v4 as uuidv4 } from 'uuid';

const DEFAULT_DB_DIR = path.join(process.cwd(), 'data');
const DEFAULT_DB_PATH = path.join(DEFAULT_DB_DIR, 'survey.db');
// The DB_PATH environment variable overrides the default location
// (an empty value falls back to the default because of `||`).
const DB_PATH = process.env.DB_PATH || DEFAULT_DB_PATH;
const DB_DIR = path.dirname(DB_PATH);

// Make sure the data directory exists so SQLite can create the database file.
if (!fs.existsSync(DB_DIR)) {
  fs.mkdirSync(DB_DIR, { recursive: true });
}

/**
 * Shared SQLite connection used by every helper in this module.
 * A connection failure is logged, not thrown.
 */
export const db = new sqlite3.Database(DB_PATH, (err) => {
  if (err) {
    console.error('数据库连接失败:', err);
  } else {
    console.log('数据库连接成功');
  }
});

// Enable foreign-key enforcement so the ON DELETE CASCADE clauses declared
// in the schema below actually take effect on this connection.
db.run('PRAGMA foreign_keys = ON');

// ---------------------------------------------------------------------------
// Promise wrappers around the callback-based sqlite3 API.
// Rows are typed as `any` to stay compatible with existing callers.
// ---------------------------------------------------------------------------

/**
 * Runs a parameterized query and resolves with all result rows.
 * @param sql - SQL text with `?` placeholders.
 * @param params - Values bound to the placeholders.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const query = (sql: string, params: any[] = []): Promise<any[]> =>
  new Promise((resolve, reject) => {
    db.all(sql, params, (err, rows) => {
      if (err) {
        reject(err);
      } else {
        resolve(rows);
      }
    });
  });

/** Alias of `query`, kept for backward compatibility. */
export const all = query;

/**
 * Runs a statement that produces no rows (INSERT / UPDATE / DELETE / DDL).
 * @returns The statement's `lastID`, converted to a string.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const run = (sql: string, params: any[] = []): Promise<{ id: string }> =>
  new Promise((resolve, reject) => {
    // A `function` (not an arrow) is required: sqlite3 exposes lastID via `this`.
    db.run(sql, params, function (err) {
      if (err) {
        reject(err);
      } else {
        resolve({ id: this.lastID.toString() });
      }
    });
  });

/**
 * Runs a parameterized query and resolves with the first row
 * (`undefined` when the result set is empty).
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const get = (sql: string, params: any[] = []): Promise<any> =>
  new Promise((resolve, reject) => {
    db.get(sql, params, (err, row) => {
      if (err) {
        reject(err);
      } else {
        resolve(row);
      }
    });
  });

/** Executes one or more `;`-separated SQL statements that take no parameters. */
const exec = (sql: string): Promise<void> =>
  new Promise((resolve, reject) => {
    db.exec(sql, (err) => {
      if (err) {
        reject(err);
      } else {
        resolve();
      }
    });
  });

// ---------------------------------------------------------------------------
// Schema-inspection helpers
// ---------------------------------------------------------------------------

const SAFE_IDENTIFIER = /^[A-Za-z_][A-Za-z0-9_]*$/;

/**
 * Identifiers cannot be bound as `?` parameters, so they are interpolated
 * into SQL. Reject anything that is not a plain identifier first.
 * @throws Error when `name` contains characters outside [A-Za-z0-9_].
 */
const assertSafeIdentifier = (name: string): void => {
  if (!SAFE_IDENTIFIER.test(name)) {
    throw new Error(`Unsafe SQL identifier: ${name}`);
  }
};

/** Resolves to true when a table named `tableName` exists in the database. */
const tableExists = async (tableName: string): Promise<boolean> => {
  const row = await get(
    "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
    [tableName]
  );
  return Boolean(row);
};

/** Resolves to true when `tableName` already has a column named `columnName`. */
const columnExists = async (tableName: string, columnName: string): Promise<boolean> => {
  assertSafeIdentifier(tableName);
  const columns = await query(`PRAGMA table_info(${tableName})`);
  return columns.some((col) => col.name === columnName);
};

/**
 * Adds a column via `ALTER TABLE ... ADD COLUMN` unless it already exists.
 * @param tableName - Target table (must be a plain identifier).
 * @param columnDefSql - Full column definition, e.g. `category TEXT NOT NULL DEFAULT '通用'`.
 * @param columnName - Name checked for existence before altering.
 */
const ensureColumn = async (
  tableName: string,
  columnDefSql: string,
  columnName: string
): Promise<void> => {
  assertSafeIdentifier(tableName);
  assertSafeIdentifier(columnName);
  if (!(await columnExists(tableName, columnName))) {
    await exec(`ALTER TABLE ${tableName} ADD COLUMN ${columnDefSql}`);
  }
};

/** Executes a `CREATE TABLE IF NOT EXISTS ...` statement. */
const ensureTable = async (createTableSql: string): Promise<void> => {
  await exec(createTableSql);
};

/** Executes a `CREATE INDEX IF NOT EXISTS ...` statement. */
const ensureIndex = async (createIndexSql: string): Promise<void> => {
  await exec(createIndexSql);
};

// ---------------------------------------------------------------------------
// Migration steps (each one is idempotent)
// ---------------------------------------------------------------------------

/** Adds legacy columns to `users`/`questions` and sets up question categories. */
const migrateUsersAndQuestions = async (): Promise<void> => {
  await ensureColumn('users', `password TEXT NOT NULL DEFAULT ''`, 'password');
  await ensureColumn('questions', `category TEXT NOT NULL DEFAULT '通用'`, 'category');
  // Back-fill rows that predate the category column or were stored empty.
  await run(`UPDATE questions SET category = '通用' WHERE category IS NULL OR category = ''`);

  await ensureTable(`
    CREATE TABLE IF NOT EXISTS question_categories (
      id TEXT PRIMARY KEY,
      name TEXT UNIQUE NOT NULL,
      created_at DATETIME DEFAULT CURRENT_TIMESTAMP
    );
  `);
  await run(
    `INSERT OR IGNORE INTO question_categories (id, name) VALUES ('default', '通用')`
  );

  await ensureIndex(`CREATE INDEX IF NOT EXISTS idx_questions_category ON questions(category);`);
};

/** Creates the exam subject/task tables, their indexes and late-added columns. */
const ensureExamSchema = async (): Promise<void> => {
  await ensureTable(`
    CREATE TABLE IF NOT EXISTS exam_subjects (
      id TEXT PRIMARY KEY,
      name TEXT UNIQUE NOT NULL,
      type_ratios TEXT NOT NULL,
      category_ratios TEXT NOT NULL,
      total_score INTEGER NOT NULL,
      duration_minutes INTEGER NOT NULL DEFAULT 60,
      created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
      updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
    );
  `);

  await ensureTable(`
    CREATE TABLE IF NOT EXISTS exam_tasks (
      id TEXT PRIMARY KEY,
      name TEXT NOT NULL,
      subject_id TEXT NOT NULL,
      start_at DATETIME NOT NULL,
      end_at DATETIME NOT NULL,
      created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
      FOREIGN KEY (subject_id) REFERENCES exam_subjects(id)
    );
  `);

  await ensureTable(`
    CREATE TABLE IF NOT EXISTS exam_task_users (
      id TEXT PRIMARY KEY,
      task_id TEXT NOT NULL,
      user_id TEXT NOT NULL,
      created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
      UNIQUE(task_id, user_id),
      FOREIGN KEY (task_id) REFERENCES exam_tasks(id) ON DELETE CASCADE,
      FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
    );
  `);

  await ensureIndex(`CREATE INDEX IF NOT EXISTS idx_exam_tasks_subject_id ON exam_tasks(subject_id);`);
  await ensureIndex(`CREATE INDEX IF NOT EXISTS idx_exam_task_users_task_id ON exam_task_users(task_id);`);
  await ensureIndex(`CREATE INDEX IF NOT EXISTS idx_exam_task_users_user_id ON exam_task_users(user_id);`);

  // Selection configuration for exam tasks (added after the table first shipped).
  await ensureColumn('exam_tasks', 'selection_config TEXT', 'selection_config');
};

/**
 * Adds subject/task links to `quiz_records` when that table exists.
 * The indexes are created inside the same guard: creating an index on a
 * missing table would throw and abort the whole migration.
 */
const migrateQuizRecords = async (): Promise<void> => {
  if (!(await tableExists('quiz_records'))) {
    return;
  }
  await ensureColumn('quiz_records', `subject_id TEXT`, 'subject_id');
  await ensureColumn('quiz_records', `task_id TEXT`, 'task_id');
  await ensureIndex(`CREATE INDEX IF NOT EXISTS idx_quiz_records_subject_id ON quiz_records(subject_id);`);
  await ensureIndex(`CREATE INDEX IF NOT EXISTS idx_quiz_records_task_id ON quiz_records(task_id);`);
};

/** Creates the user-group table and the user <-> group membership table. */
const ensureUserGroupSchema = async (): Promise<void> => {
  await ensureTable(`
    CREATE TABLE IF NOT EXISTS user_groups (
      id TEXT PRIMARY KEY,
      name TEXT UNIQUE NOT NULL,
      description TEXT,
      is_system BOOLEAN DEFAULT 0,
      created_at DATETIME DEFAULT CURRENT_TIMESTAMP
    );
  `);

  await ensureTable(`
    CREATE TABLE IF NOT EXISTS user_group_members (
      group_id TEXT NOT NULL,
      user_id TEXT NOT NULL,
      created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
      PRIMARY KEY (group_id, user_id),
      FOREIGN KEY (group_id) REFERENCES user_groups(id) ON DELETE CASCADE,
      FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
    );
  `);
};

/**
 * Ensures the system "All Users" group exists (the row with is_system = 1),
 * then adds every existing user who is not yet a member.
 */
const syncAllUsersGroup = async (): Promise<void> => {
  const existingGroup = await get(`SELECT id FROM user_groups WHERE is_system = 1`);
  let allUsersGroupId: string | undefined = existingGroup?.id;

  if (!allUsersGroupId) {
    allUsersGroupId = uuidv4();
    await run(
      `INSERT INTO user_groups (id, name, description, is_system) VALUES (?, ?, ?, ?)`,
      [allUsersGroupId, '全体用户', '包含系统所有用户的默认组', 1]
    );
    console.log('已创建"全体用户"系统组');
  }

  // Users not yet in the group (counted only for the log message).
  const usersNotInGroup = await query(
    `SELECT id FROM users WHERE id NOT IN (
       SELECT user_id FROM user_group_members WHERE group_id = ?
     )`,
    [allUsersGroupId]
  );

  if (usersNotInGroup.length > 0) {
    // A single awaited INSERT ... SELECT replaces a loop of un-awaited
    // prepared-statement runs, so the migration does not resolve before the
    // rows are written and insert errors propagate to the caller.
    await run(
      `INSERT OR IGNORE INTO user_group_members (group_id, user_id)
       SELECT ?, id FROM users
       WHERE id NOT IN (
         SELECT user_id FROM user_group_members WHERE group_id = ?
       )`,
      [allUsersGroupId, allUsersGroupId]
    );
    console.log(`已将 ${usersNotInGroup.length} 名现有用户添加到"全体用户"组`);
  }
};

/** Applies every idempotent schema migration, in dependency order. */
const migrateDatabase = async (): Promise<void> => {
  await migrateUsersAndQuestions();
  await ensureExamSchema();
  await migrateQuizRecords();
  await ensureUserGroupSchema();
  await syncAllUsersGroup();
};

/**
 * Initializes the database and then runs the migrations.
 * When no `users` table exists, the schema in `api/database/init.sql`
 * (relative to the working directory) is executed first. Migrations run in
 * both cases.
 */
export const initDatabase = async (): Promise<void> => {
  const hasUsersTable = await tableExists('users');
  if (!hasUsersTable) {
    // Read the schema file only when it is needed, so an already-initialized
    // database does not depend on init.sql being present.
    const initSQL = fs.readFileSync(
      path.join(process.cwd(), 'api', 'database', 'init.sql'),
      'utf8'
    );
    await exec(initSQL);
    console.log('数据库初始化成功');
  } else {
    console.log('数据库已初始化,准备执行迁移检查');
  }
  await migrateDatabase();
};