// Package storage persists collected account data to PostgreSQL through GORM.
package storage

import (
	"fmt"
	"log"
	"time"

	"git.apinb.com/quant/collector/models"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

// Storage wraps the GORM database handle used to persist collected data.
type Storage struct {
	db *gorm.DB
}

// NewStorage opens a PostgreSQL connection described by connStr, configures
// the GORM logger and the underlying connection pool, and returns a ready
// Storage. It returns a wrapped error if the connection cannot be opened or
// the underlying *sql.DB cannot be obtained.
func NewStorage(connStr string) (*Storage, error) {
	// GORM logger: warn level, queries slower than one second are reported,
	// and "record not found" errors are not logged.
	newLogger := logger.New(
		log.New(log.Writer(), "\r\n", log.LstdFlags),
		logger.Config{
			SlowThreshold:             time.Second,
			LogLevel:                  logger.Warn,
			IgnoreRecordNotFoundError: true,
			Colorful:                  true,
		},
	)

	db, err := gorm.Open(postgres.Open(connStr), &gorm.Config{
		Logger: newLogger,
	})
	if err != nil {
		return nil, fmt.Errorf("打开数据库连接失败: %w", err)
	}

	// The connection pool is configured on the underlying *sql.DB.
	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("获取数据库实例失败: %w", err)
	}

	// Connection pool limits.
	sqlDB.SetMaxOpenConns(25)
	sqlDB.SetMaxIdleConns(5)
	sqlDB.SetConnMaxLifetime(5 * time.Minute)

	log.Println("数据库连接成功")
	return &Storage{db: db}, nil
}

// Close closes the underlying database connection pool. It is a no-op that
// returns nil when the receiver or its database handle is nil.
func (s *Storage) Close() error {
	if s == nil || s.db == nil {
		return nil
	}

	sqlDB, err := s.db.DB()
	if err != nil {
		return err
	}
	return sqlDB.Close()
}

// AutoMigrate runs GORM auto-migration for every model this package stores:
// asset snapshots, order records, position records, tick records and
// collection logs. It returns a wrapped error if the migration fails.
func (s *Storage) AutoMigrate() error {
	log.Println("开始自动迁移数据库表结构...")

	err := s.db.AutoMigrate(
		&models.AssetSnapshot{},
		&models.OrderRecord{},
		&models.PositionRecord{},
		&models.TickRecord{},
		&models.CollectionLog{},
	)
	if err != nil {
		return fmt.Errorf("自动迁移失败: %w", err)
	}

	log.Println("数据库表结构迁移完成")
	return nil
}

// SaveStatus stores one complete status snapshot inside a single GORM
// transaction: the asset snapshot, then the orders, positions and tick data
// (each group is skipped when empty). Every row is tagged with dataHash and
// with the same collection timestamp, so all rows written by one call can be
// correlated. It returns an error if status is nil or if any insert fails.
func (s *Storage) SaveStatus(status *models.Status, dataHash string) error {
	if status == nil {
		// Guard against a nil pointer dereference below.
		return fmt.Errorf("状态数据为空")
	}

	// Number of rows sent per INSERT by CreateInBatches.
	const batchSize = 100

	// One timestamp for the whole snapshot, so every row of the same
	// collection carries an identical CollectedAt value.
	collectedAt := time.Now()
	accountID := status.Data.Assets.AccountID

	return s.db.Transaction(func(tx *gorm.DB) error {
		// Asset snapshot.
		asset := models.AssetSnapshot{
			AccountID:   accountID,
			Cash:        status.Data.Assets.Cash,
			FrozenCash:  status.Data.Assets.FrozenCash,
			MarketValue: status.Data.Assets.MarketValue,
			Profit:      status.Data.Assets.Profit,
			TotalAsset:  status.Data.Assets.TotalAsset,
			DataHash:    dataHash,
			CollectedAt: collectedAt,
		}
		if err := tx.Create(&asset).Error; err != nil {
			return fmt.Errorf("保存资产快照失败: %w", err)
		}

		// Orders, inserted in batches.
		if len(status.Data.Orders) > 0 {
			orders := make([]models.OrderRecord, 0, len(status.Data.Orders))
			for _, order := range status.Data.Orders {
				orders = append(orders, models.OrderRecord{
					OrderID:      order.OrderID,
					AccountID:    accountID,
					StockCode:    order.StockCode,
					Price:        order.Price,
					Volume:       order.Volume,
					TradedPrice:  order.TradedPrice,
					TradedVolume: order.TradedVolume,
					OrderStatus:  order.OrderStatus,
					OrderTime:    order.OrderTime,
					OrderRemark:  order.OrderRemark,
					DataHash:     dataHash,
					CollectedAt:  collectedAt,
				})
			}
			if err := tx.CreateInBatches(orders, batchSize).Error; err != nil {
				return fmt.Errorf("保存订单失败: %w", err)
			}
		}

		// Positions, inserted in batches.
		if len(status.Data.Positions) > 0 {
			positions := make([]models.PositionRecord, 0, len(status.Data.Positions))
			for _, pos := range status.Data.Positions {
				positions = append(positions, models.PositionRecord{
					AccountID:     accountID,
					Code:          pos.Code,
					Volume:        pos.Volume,
					CanUseVolume:  pos.CanUseVolume,
					FrozenVolume:  pos.FrozenVolume,
					AvgPrice:      pos.AvgPrice,
					OpenPrice:     pos.OpenPrice,
					CurrentPrice:  pos.CurrentPrice,
					MarketValue:   pos.MarketValue,
					Profit:        pos.Profit,
					ProfitRate:    pos.ProfitRate,
					MinProfitRate: pos.MinProfitRate,
					DataHash:      dataHash,
					CollectedAt:   collectedAt,
				})
			}
			if err := tx.CreateInBatches(positions, batchSize).Error; err != nil {
				return fmt.Errorf("保存持仓失败: %w", err)
			}
		}

		// Tick data, inserted in batches. The range key is stored as the
		// record's stock code.
		if len(status.Data.TickData) > 0 {
			ticks := make([]models.TickRecord, 0, len(status.Data.TickData))
			for code, tick := range status.Data.TickData {
				ticks = append(ticks, models.TickRecord{
					StockCode:   code,
					LastPrice:   tick.LastPrice,
					Open:        tick.Open,
					High:        tick.High,
					Low:         tick.Low,
					LastClose:   tick.LastClose,
					Volume:      tick.Volume,
					Amount:      tick.Amount,
					PVolume:     tick.PVolume,
					BidPrices:   tick.BidPrice,
					BidVolumes:  tick.BidVol,
					AskPrices:   tick.AskPrice,
					AskVolumes:  tick.AskVol,
					Time:        tick.Time,
					TimeTag:     tick.TimeTag,
					StockStatus: tick.StockStatus,
					DataHash:    dataHash,
					CollectedAt: collectedAt,
				})
			}
			if err := tx.CreateInBatches(ticks, batchSize).Error; err != nil {
				return fmt.Errorf("保存行情数据失败: %w", err)
			}
		}

		return nil
	})
}

// SaveCollectionLog records one collection attempt: the hash of the collected
// data, whether it changed, and a status message, stamped with the current
// time. It returns a wrapped error if the insert fails.
func (s *Storage) SaveCollectionLog(dataHash string, hasChanged bool, statusMessage string) error {
	// Named entry rather than log so the imported log package is not shadowed.
	entry := models.CollectionLog{
		DataHash:      dataHash,
		HasChanged:    hasChanged,
		StatusMessage: statusMessage,
		CollectedAt:   time.Now(),
	}

	if err := s.db.Create(&entry).Error; err != nil {
		return fmt.Errorf("保存采集日志失败: %w", err)
	}
	return nil
}

// GetDB returns the underlying GORM handle for queries this type does not
// wrap itself.
func (s *Storage) GetDB() *gorm.DB {
	return s.db
}