ai gen
This commit is contained in:
207
storage/storage.go
Normal file
207
storage/storage.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"git.apinb.com/quant/collector/models"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// Storage 数据库存储器
|
||||
type Storage struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewStorage 创建新的数据库连接
|
||||
func NewStorage(connStr string) (*Storage, error) {
|
||||
// 配置GORM日志
|
||||
newLogger := logger.New(
|
||||
log.New(log.Writer(), "\r\n", log.LstdFlags),
|
||||
logger.Config{
|
||||
SlowThreshold: time.Second,
|
||||
LogLevel: logger.Warn,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: true,
|
||||
},
|
||||
)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(connStr), &gorm.Config{
|
||||
Logger: newLogger,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开数据库连接失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取底层的sql.DB以设置连接池
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置连接池参数
|
||||
sqlDB.SetMaxOpenConns(25)
|
||||
sqlDB.SetMaxIdleConns(5)
|
||||
sqlDB.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
log.Println("数据库连接成功")
|
||||
return &Storage{db: db}, nil
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (s *Storage) Close() error {
|
||||
if s.db != nil {
|
||||
sqlDB, err := s.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AutoMigrate 自动迁移数据库表结构
|
||||
func (s *Storage) AutoMigrate() error {
|
||||
log.Println("开始自动迁移数据库表结构...")
|
||||
|
||||
err := s.db.AutoMigrate(
|
||||
&models.AssetSnapshot{},
|
||||
&models.OrderRecord{},
|
||||
&models.PositionRecord{},
|
||||
&models.TickRecord{},
|
||||
&models.CollectionLog{},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("自动迁移失败: %w", err)
|
||||
}
|
||||
|
||||
log.Println("数据库表结构迁移完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveStatus 保存完整状态数据(使用事务)
|
||||
func (s *Storage) SaveStatus(status *models.Status, dataHash string) error {
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 保存资产快照
|
||||
asset := models.AssetSnapshot{
|
||||
AccountID: status.Data.Assets.AccountID,
|
||||
Cash: status.Data.Assets.Cash,
|
||||
FrozenCash: status.Data.Assets.FrozenCash,
|
||||
MarketValue: status.Data.Assets.MarketValue,
|
||||
Profit: status.Data.Assets.Profit,
|
||||
TotalAsset: status.Data.Assets.TotalAsset,
|
||||
DataHash: dataHash,
|
||||
CollectedAt: time.Now(),
|
||||
}
|
||||
if err := tx.Create(&asset).Error; err != nil {
|
||||
return fmt.Errorf("保存资产快照失败: %w", err)
|
||||
}
|
||||
|
||||
// 批量保存订单
|
||||
if len(status.Data.Orders) > 0 {
|
||||
orders := make([]models.OrderRecord, 0, len(status.Data.Orders))
|
||||
for _, order := range status.Data.Orders {
|
||||
orders = append(orders, models.OrderRecord{
|
||||
OrderID: order.OrderID,
|
||||
AccountID: status.Data.Assets.AccountID,
|
||||
StockCode: order.StockCode,
|
||||
Price: order.Price,
|
||||
Volume: order.Volume,
|
||||
TradedPrice: order.TradedPrice,
|
||||
TradedVolume: order.TradedVolume,
|
||||
OrderStatus: order.OrderStatus,
|
||||
OrderTime: order.OrderTime,
|
||||
OrderRemark: order.OrderRemark,
|
||||
DataHash: dataHash,
|
||||
CollectedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
if err := tx.CreateInBatches(orders, 100).Error; err != nil {
|
||||
return fmt.Errorf("保存订单失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量保存持仓
|
||||
if len(status.Data.Positions) > 0 {
|
||||
positions := make([]models.PositionRecord, 0, len(status.Data.Positions))
|
||||
for _, pos := range status.Data.Positions {
|
||||
positions = append(positions, models.PositionRecord{
|
||||
AccountID: status.Data.Assets.AccountID,
|
||||
Code: pos.Code,
|
||||
Volume: pos.Volume,
|
||||
CanUseVolume: pos.CanUseVolume,
|
||||
FrozenVolume: pos.FrozenVolume,
|
||||
AvgPrice: pos.AvgPrice,
|
||||
OpenPrice: pos.OpenPrice,
|
||||
CurrentPrice: pos.CurrentPrice,
|
||||
MarketValue: pos.MarketValue,
|
||||
Profit: pos.Profit,
|
||||
ProfitRate: pos.ProfitRate,
|
||||
MinProfitRate: pos.MinProfitRate,
|
||||
DataHash: dataHash,
|
||||
CollectedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
if err := tx.CreateInBatches(positions, 100).Error; err != nil {
|
||||
return fmt.Errorf("保存持仓失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量保存行情数据
|
||||
if len(status.Data.TickData) > 0 {
|
||||
ticks := make([]models.TickRecord, 0, len(status.Data.TickData))
|
||||
for code, tick := range status.Data.TickData {
|
||||
ticks = append(ticks, models.TickRecord{
|
||||
StockCode: code,
|
||||
LastPrice: tick.LastPrice,
|
||||
Open: tick.Open,
|
||||
High: tick.High,
|
||||
Low: tick.Low,
|
||||
LastClose: tick.LastClose,
|
||||
Volume: tick.Volume,
|
||||
Amount: tick.Amount,
|
||||
PVolume: tick.PVolume,
|
||||
BidPrices: tick.BidPrice,
|
||||
BidVolumes: tick.BidVol,
|
||||
AskPrices: tick.AskPrice,
|
||||
AskVolumes: tick.AskVol,
|
||||
Time: tick.Time,
|
||||
TimeTag: tick.TimeTag,
|
||||
StockStatus: tick.StockStatus,
|
||||
DataHash: dataHash,
|
||||
CollectedAt: time.Now(),
|
||||
})
|
||||
}
|
||||
if err := tx.CreateInBatches(ticks, 100).Error; err != nil {
|
||||
return fmt.Errorf("保存行情数据失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// SaveCollectionLog 保存采集日志
|
||||
func (s *Storage) SaveCollectionLog(dataHash string, hasChanged bool, statusMessage string) error {
|
||||
log := models.CollectionLog{
|
||||
DataHash: dataHash,
|
||||
HasChanged: hasChanged,
|
||||
StatusMessage: statusMessage,
|
||||
CollectedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := s.db.Create(&log).Error; err != nil {
|
||||
return fmt.Errorf("保存采集日志失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDB 获取GORM DB实例(用于高级查询)
|
||||
func (s *Storage) GetDB() *gorm.DB {
|
||||
return s.db
|
||||
}
|
||||
Reference in New Issue
Block a user