diff --git a/formula/abs.go b/formula/abs.go index 5df50e1cd1ae2a8316cba1e660dbd45c5a682cc8..081ca7dee10aea042c5a88105d54ad228b93e9c8 100644 --- a/formula/abs.go +++ b/formula/abs.go @@ -5,6 +5,7 @@ import ( "gitee.com/quant1x/pandas/stat" ) +// ABS 计算S的绝对值 func ABS(S pandas.Series) pandas.Series { s := S.DTypes() d := stat.Abs(s) diff --git a/generic_diff_test.go b/generic_diff_test.go index ee2955027324c234fcee8c042cd2db7d4d7f3b8a..5c4827bdc044e6e4d8ef16677800e49a4c9e2c73 100644 --- a/generic_diff_test.go +++ b/generic_diff_test.go @@ -11,7 +11,7 @@ func TestNDFrame_Diff(t *testing.T) { df := NewDataFrame(s1) fmt.Println(df) fmt.Println("------------------------------------------------------------") - N := 2 + N := 1 fmt.Println("固定的参数, N =", N) r1 := df.Col("x").Diff(N).Values() fmt.Println("序列化结果:", r1) diff --git a/generic_rolling.go b/generic_rolling.go index caf86d38c7c263c53821ca9862b53375ebcc851d..3442ba32cd55be303773abe66f5f0a5950c2a468 100644 --- a/generic_rolling.go +++ b/generic_rolling.go @@ -42,7 +42,7 @@ func (r RollingAndExpandingMixin) getBlocks() (blocks []Series) { window := int(N) start := i + 1 - window end := i + 1 - blocks = append(blocks, r.series.Subset(start, end, true)) + blocks = append(blocks, r.series.Subset(start, end, false)) } return diff --git a/go.mod b/go.mod index 16ed4adc754edc538fccfc64bfed3dc5fe2f6a24..fe8a64f810f55e277cbfeafd7ffc4a850a22dfcb 100644 --- a/go.mod +++ b/go.mod @@ -6,36 +6,31 @@ require ( gitee.com/quant1x/gotdx v1.1.2 github.com/chewxy/math32 v1.10.1 github.com/huandu/go-clone v1.4.1 - github.com/mattn/go-runewidth v0.0.14 - github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db github.com/mymmsc/gox v1.3.1 github.com/olekukonko/tablewriter v0.0.5 - github.com/schollz/progressbar/v3 v3.13.0 - github.com/stretchr/testify v1.8.1 + github.com/qianlnk/pgbar v0.0.0-20210208085217-8c19b9f2477e github.com/tealeg/xlsx v1.0.5 github.com/tealeg/xlsx/v3 v3.2.4 github.com/viterin/partial v1.0.0 github.com/viterin/vek v0.4.0 golang.org/x/exp v0.0.0-20220907003533-145caa8ea1d0 - golang.org/x/term v0.5.0 gonum.org/v1/gonum v0.12.0 google.golang.org/protobuf v1.28.1 ) require ( - github.com/davecgh/go-spew v1.1.1 // indirect github.com/frankban/quicktest v1.11.2 // indirect github.com/google/btree v1.0.0 // indirect github.com/google/go-cmp v0.5.8 // indirect github.com/kr/pretty v0.2.1 // indirect github.com/kr/text v0.1.0 // indirect + github.com/mattn/go-runewidth v0.0.14 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/peterbourgon/diskv v2.0.1+incompatible // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/qianlnk/to v0.0.0-20191230085244-91e712717368 // indirect github.com/rivo/uniseg v0.4.3 // indirect github.com/rogpeppe/fastuuid v1.2.0 // indirect github.com/shabbyrobe/xmlwriter v0.0.0-20200208144257-9fca06d00ffa // indirect golang.org/x/sys v0.5.0 // indirect golang.org/x/text v0.5.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index dba14f1d9580d1407e72000a0ca4051e222afaf7..c111366e9dc9ada99126f8bca6d00d30d94e98e7 100644 --- a/go.sum +++ b/go.sum @@ -18,19 +18,15 @@ github.com/huandu/go-assert v1.1.5 h1:fjemmA7sSfYHJD7CUqs9qTwwfdNAx7/j2/ZlHXzNB3 github.com/huandu/go-assert v1.1.5/go.mod h1:yOLvuqZwmcHIC5rIzrBhT7D3Q9c3GFnd0JrPVhn/06U= github.com/huandu/go-clone v1.4.1 h1:QQYjiLadyxOvdwgZoH8f1xGkvvf4+Cm8be7fo9W2QQA= github.com/huandu/go-clone v1.4.1/go.mod h1:ReGivhG6op3GYr+UY3lS6mxjKp7MIGTknuU5TbTVaXE= -github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= -github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mymmsc/gox v1.3.1 h1:CM6bGBuf5+UK/af06Dv8U8becBlh6jyZ0RP2kEsYT84= @@ -44,24 +40,20 @@ github.com/pkg/profile v1.5.0 h1:042Buzk+NhDI+DeSAA62RwJL8VAuZUMQZUjCsRz1Mug= github.com/pkg/profile v1.5.0/go.mod h1:qBsxPvzyUincmltOk6iyRVxHYg4adc0OFOv72ZdLa18= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/qianlnk/pgbar v0.0.0-20210208085217-8c19b9f2477e h1:d2mxZa66Z5wS4zhRt8KiMmzIPLzZsMnYHR+uWXGz6m0= +github.com/qianlnk/pgbar v0.0.0-20210208085217-8c19b9f2477e/go.mod h1:4YWkn3EVkh8c1BDlVmw+Zh2QLhs+MbAg4xy4RqcKMsA= +github.com/qianlnk/to v0.0.0-20191230085244-91e712717368 h1:YWi/c6UOBSwKWfFwYd3B6kZW2GGjCFPKJxUE37IPyFQ= +github.com/qianlnk/to v0.0.0-20191230085244-91e712717368/go.mod h1:HYAQIJIdgW9cGr75BDsucQMgKREt00mECJHOskH5n5k= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.3 h1:utMvzDsuh3suAEnhH0RdHmoPbU648o6CvXxTx4SBMOw= github.com/rivo/uniseg v0.4.3/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/fastuuid v1.2.0 h1:Ppwyp6VYCF1nvBTXL3trRso7mXMlRrw9ooo375wvi2s= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= -github.com/schollz/progressbar/v3 v3.13.0 h1:9TeeWRcjW2qd05I8Kf9knPkW4vLM/hYoa6z9ABvxje8= -github.com/schollz/progressbar/v3 v3.13.0/go.mod h1:ZBYnSuLAX2LU8P8UiKN/KgF2DY58AJC8yfVYLPC8Ly4= github.com/shabbyrobe/xmlwriter v0.0.0-20200208144257-9fca06d00ffa h1:2cO3RojjYl3hVTbEvJVqrMaFmORhL6O06qdW42toftk= github.com/shabbyrobe/xmlwriter v0.0.0-20200208144257-9fca06d00ffa/go.mod h1:Yjr3bdWaVWyME1kha7X0jsz3k2DgXNa1Pj3XGyUAbx8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tealeg/xlsx v1.0.5 h1:+f8oFmvY8Gw1iUXzPk+kz+4GpbDZPK1FhPiQRd+ypgE= github.com/tealeg/xlsx v1.0.5/go.mod h1:btRS8dz54TDnvKNosuAqxrM1QgN1udgk9O34bDCnORM= github.com/tealeg/xlsx/v3 v3.2.4 h1:QPuk5v1xEivxoEUFmqszqINF52ppWCMejEd11ju3180= @@ -72,13 +64,8 @@ github.com/viterin/vek v0.4.0 h1:P34BWVGd3pSZFma9SE+G1pTucMGtw9p79I+Hull/+Ao= github.com/viterin/vek v0.4.0/go.mod h1:hVXEX7pnI4acHRhtFhmuBapUxhQ3TetMEp68jjxExBs= golang.org/x/exp v0.0.0-20220907003533-145caa8ea1d0 h1:17k44ji3KFYG94XS5QEFC8pyuOlMh3IoR+vkmTZmJJs= golang.org/x/exp v0.0.0-20220907003533-145caa8ea1d0/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= -golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= @@ -95,6 +82,4 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/stat/align.go b/stat/align.go index 60f280d250d3158718139b4b3715310432ba202a..45d20ac4e02641b96d9168cb0e00cdaf9fbc52c7 100644 --- a/stat/align.go +++ b/stat/align.go @@ -1,7 +1,8 @@ package stat // Align Data alignment -// a 通常是默认值 +// +// a 通常是默认值 func Align[T BaseType](x []T, a T, dLen int) []T { d := []T{} xLen := len(x) diff --git a/stat/builtin.go b/stat/builtin.go index 3a49319d84708beb95123b70ac20e9349d4b8e6c..532e72bc251e28f6de30154ca85df2ec9b9cd6b7 100644 --- a/stat/builtin.go +++ b/stat/builtin.go @@ -18,19 +18,32 @@ var ( var ( // IgnoreParseExceptions 忽略解析异常 IgnoreParseExceptions bool = true + + Avx2Enabled = false // AVX2加速开关 ) // 初始化 avx2 // 可以参考另一个实现库 gonum.org/v1/gonum/stat func init() { // 开启加速选项 - vek.SetAcceleration(true) + SetAvx2Enabled(true) Nil2Float64 = math.NaN() // 这个转换是对的, NaN对float32也有效 Nil2Float32 = float32(Nil2Float64) DTypeNaN = DType(Nil2Float64) } +// SetAvx2Enabled 设定AVX2加速开关 +func SetAvx2Enabled(enabled bool) { + vek.SetAcceleration(enabled) + Avx2Enabled = enabled +} + +// GetAvx2Enabled 获取avx2加速状态 +func GetAvx2Enabled() bool { + return Avx2Enabled +} + // 从指针/地址提取值 // Extract value from pointer func extraceValueFromPointer(v any) (any, bool) { diff --git a/stat/diff.go b/stat/diff.go index 9ffaa19cf22a471177b18785ec079aa86bcb7586..d8102dde2439d3c23cc31e7d415084ff147205ff 100644 --- a/stat/diff.go +++ b/stat/diff.go @@ -1,11 +1,32 @@ package stat -// Diff returns the n-th differences of the given array. -// TODO:这个代码有问题, 需要从generic_diff迁移过来 -func Diff[T Number](x []T) []T { - var result []T - for i := 1; i < len(x); i++ { - result = append(result, x[i]-x[i-1]) +// Diff 元素的第一个离散差 +// +// First discrete difference of element. +// Calculates the difference of a {klass} element compared with another +// element in the {klass} (default is element in previous row). +func Diff[T Number](s []T, param any) []T { + blocks := Rolling[T](s, param) + var d []T + var front = typeDefault[T]() + for _, block := range blocks { + vs := block + vl := len(block) + if vl == 0 { + d = append(d, typeDefault[T]()) + continue + } + vf := vs[0] + vc := vs[vl-1] + if DTypeIsNaN(Any2DType(vc)) || DTypeIsNaN(Any2DType(front)) { + front = vf + d = append(d, typeDefault[T]()) + continue + } + diff := vc - front + d = append(d, diff) + front = vf } - return result + + return d } diff --git a/stat/diff_test.go b/stat/diff_test.go index 61869491b50d7f7ac8d58faf2759f93e94ad922e..fc6e21956745f721d508f5e9fa3d112e301eba7c 100644 --- a/stat/diff_test.go +++ b/stat/diff_test.go @@ -1,9 +1,21 @@ package stat import ( + "fmt" "testing" ) func TestDiff(t *testing.T) { - + d1 := []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + fmt.Println(d1) + fmt.Println("------------------------------------------------------------") + N := 1 + fmt.Println("固定的参数, N =", N) + r1 := Diff(d1, N) + fmt.Println("序列化结果:", r1) + fmt.Println("------------------------------------------------------------") + s1 := []float64{1, 2, 3, 4, 3, 3, 2, 1, Nil2Float64, Nil2Float64, Nil2Float64, Nil2Float64} + fmt.Printf("序列化参数: %+v\n", s1) + r2 := Diff(d1, s1) + fmt.Println("序列化结果:", r2) } diff --git a/stat/fillna.go b/stat/fillna.go index a0201198d6bff90853095a046147a7171db9579d..f62d150af71aafca44c6f211ff939b2e1b983376 100644 --- a/stat/fillna.go +++ b/stat/fillna.go @@ -39,7 +39,7 @@ import "golang.org/x/exp/slices" // Returns // ------- // []T or None -func Fill[T StatType | ~string](v []T, d T, args ...any) (rows []T) { +func Fill[T Number | ~string](v []T, d T, args ...any) (rows []T) { // 默认不替换 var __optInplace = false if len(args) > 0 { @@ -84,7 +84,7 @@ func Fill[T StatType | ~string](v []T, d T, args ...any) (rows []T) { } // FillNa NaN填充默认值 -func FillNa[T StatType | ~string](v []T, args ...any) []T { +func FillNa[T Number | ~string](v []T, args ...any) []T { // 默认不copy var __optInplace = false if len(args) > 0 { diff --git a/stat/params.go b/stat/params.go index 4408426b52addb8237d46245fdf01c25c5955edd..63d5d8e501533492a4a9e2ba9bbb0e555ff4231d 100644 --- a/stat/params.go +++ b/stat/params.go @@ -1,20 +1,9 @@ package stat -//func detectParam[T StatType](v any) (T, []T, error) { -// var base T -// var slice []T -// switch val := v.(type) { -// case []T: -// slice = val -// case T: -// base = val -// } -// return base, slice, nil -//} - // AnyToSlice any转切片 -// 如果a是基础类型, 就是repeat -// 如果a是切片, 就做对齐处理 +// +// 如果a是基础类型, 就是repeat +// 如果a是切片, 就做对齐处理 func AnyToSlice[T BaseType](A any, n int) []T { var d any switch v := A.(type) { diff --git a/stat/rolling.go b/stat/rolling.go index 50a46f8a4ad5ba2025e78cca5c428624e154a6ad..d83b3071fea4bb89746a05ba82c9668b60ae9c68 100644 --- a/stat/rolling.go +++ b/stat/rolling.go @@ -2,7 +2,6 @@ package stat import ( "gitee.com/quant1x/pandas/exception" - "golang.org/x/exp/slices" ) // Rolling returns an array with elements that roll beyond the last position are re-introduced at the first. @@ -36,7 +35,7 @@ func Rolling[T Number | bool](S []T, N any) [][]T { } start := i + 1 - shift end := i + 1 - subSet := slices.Clone(S[start:end]) + subSet := S[start:end] blocks[i] = subSet } return blocks diff --git a/stat/shift.go b/stat/shift.go index 8150e6940ca2fae1a1a8bf241cc8b1e21a23503a..72078b766e1ac4a8e32de4503f7febd2daa23690 100644 --- a/stat/shift.go +++ b/stat/shift.go @@ -34,7 +34,6 @@ func Shift[T GenericType](S []T, periods int) []T { naVals = values } for i := range naVals { - //naVals[i] = cbNan() naVals[i] = typeDefault[T]() } _ = naVals @@ -52,7 +51,6 @@ func Shift2[T GenericType](S []T, N []DType) []T { for i, _ := range S { x := N[i] if DTypeIsNaN(x) || int(x) > i { - //values[i] = cbNan() values[i] = typeDefault[T]() continue } diff --git a/strategy/strategy.go b/strategy/strategy.go index 536e5ae49a9c7d5c2386bd38cc854b1fbd302f95..51fbc372a8c018d55b5e5a0c00d1f3a6046b07c9 100644 --- a/strategy/strategy.go +++ b/strategy/strategy.go @@ -7,11 +7,16 @@ import ( "gitee.com/quant1x/pandas/data/cache" "gitee.com/quant1x/pandas/data/category" "gitee.com/quant1x/pandas/data/security" + "gitee.com/quant1x/pandas/stat" + "github.com/mymmsc/gox/logger" "github.com/mymmsc/gox/util/treemap" termTable "github.com/olekukonko/tablewriter" - "github.com/schollz/progressbar/v3" + + "github.com/qianlnk/pgbar" "os" + "runtime" "sync" + "time" ) // Strategy 策略/公式指标(features)接口 @@ -26,11 +31,15 @@ type Strategy interface { func main() { var ( - path string - strategy int + path string // 数据路径 + strategy int // 策略编号 + avx2 bool // AVX2加速状态 + cpuNum int // cpu数量 ) flag.StringVar(&path, "path", category.DATA_ROOT_PATH, "stock history data path") flag.IntVar(&strategy, "strategy", 1, "strategy serial number") + flag.BoolVar(&avx2, "avx2", false, "Avx2 acceleration") + flag.IntVar(&cpuNum, "cpu", runtime.NumCPU()/2, "sets the maximum number of CPUs") flag.Parse() cache.Init(path) var api Strategy @@ -40,35 +49,19 @@ func main() { default: api = new(FormulaNo1) } + stat.SetAvx2Enabled(avx2) //numCPU := runtime.NumCPU() / 2 - //runtime.GOMAXPROCS(numCPU) + runtime.GOMAXPROCS(cpuNum) // 获取全部证券代码 ss := data.GetCodeList() count := len(ss) - //var wg = sync.WaitGroup{} - //doneCh := make(chan struct{}) - bar := progressbar.NewOptions(count, - progressbar.OptionEnableColorCodes(true), - progressbar.OptionShowBytes(true), - progressbar.OptionSetWidth(80), - progressbar.OptionSetDescription("[cyan][1/3][reset]执行["+api.Name()+"]..."), - progressbar.OptionSetTheme(progressbar.Theme{ - Saucer: "[red]=[reset]", - SaucerHead: "[red]>[reset]", - SaucerPadding: " ", - BarStart: "[", - BarEnd: "]", - //SaucerPadding: "[white]•", - //BarStart: "[blue]|[reset]", - //BarEnd: "[blue]|[reset]", - }), - //progressbar.OptionOnCompletion(func() { - // doneCh <- struct{}{} - //}), - ) - //fmt.Printf("计划买入, 信号日期, 委托价格, 目标价位\n") + var wg = sync.WaitGroup{} + fmt.Println("Quant1X 预警系统") + fmt.Printf("CPU: %d, AVX2: %t\n", cpuNum, stat.GetAvx2Enabled()) + bar := pgbar.NewBar(0, "执行["+api.Name()+"]", count) var mapStock *treemap.Map mapStock = treemap.NewWithStringComparator() + mainStart := time.Now() for i, v := range ss { fullCode := v basicInfo, err := security.GetBasicInfo(fullCode) @@ -82,15 +75,16 @@ func main() { bar.Add(1) continue } - //go evaluate(bar, api, &wg, fullCode, basicInfo, mapStock) - evaluate(bar, api, nil, fullCode, basicInfo, mapStock) + bar.Add(1) + go evaluate(api, &wg, fullCode, basicInfo, mapStock) _ = i } - // got notified that progress bar is complete. - //<-doneCh - //wg.Wait() + wg.Wait() fmt.Println("\n ======= [" + api.Name() + "] progress bar completed ==========\n") + elapsedTime := time.Since(mainStart) / time.Millisecond + fmt.Printf("CPU: %d, AVX2: %t, 总耗时: %.3fs, 总记录: %d, 平均: %.3f/s\n", cpuNum, stat.GetAvx2Enabled(), float64(elapsedTime)/1000, count, float64(count)/(float64(elapsedTime)/1000)) + logger.Infof("CPU: %d, AVX2: %t, 总耗时: %.3fs, 总记录: %d, 平均: %.3f/s", cpuNum, stat.GetAvx2Enabled(), float64(elapsedTime)/1000, count, float64(count)/(float64(elapsedTime)/1000)) table := termTable.NewWriter(os.Stdout) var row ResultInfo table.SetHeader(row.Headers()) @@ -102,10 +96,8 @@ func main() { table.Render() // Send output } -func evaluate(bar *progressbar.ProgressBar, api Strategy, wg *sync.WaitGroup, code string, info *security.StaticBasic, result *treemap.Map) { - //defer wg.Done() - defer bar.Add(1) - - //wg.Add(1) +func evaluate(api Strategy, wg *sync.WaitGroup, code string, info *security.StaticBasic, result *treemap.Map) { + defer wg.Done() + wg.Add(1) api.Evaluate(code, info, result) } diff --git a/utils/utils.go b/utils/utils.go index 2fc83df1845b74a2ecd1a0c21ea1773cc5d74994..69104ff770856e3bf7731284f8c5b88cfa8b76f9 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -1,6 +1,10 @@ package utils -import "math" +import ( + "math" + "reflect" + "unsafe" +) func WantFloat(got, want float64) bool { return got != want && !(math.IsNaN(want) && math.IsNaN(got)) @@ -16,3 +20,28 @@ func SliceWantFloat(got, want []float64) bool { } return b == len(got) } + +// ChanIsClosed 判断channel是否关闭 +func ChanIsClosed(ch any) bool { + if reflect.TypeOf(ch).Kind() != reflect.Chan { + + panic("only channels!") + + } + cptr := *(*uintptr)(unsafe.Pointer( + unsafe.Pointer(uintptr(unsafe.Pointer(&ch)) + unsafe.Sizeof(uint(0))), + )) + // this function will return true if chan.closed > 0 + // see hchan on https://github.com/golang/go/blob/master/src/runtime/chan.go + // type hchan struct { + // qcount uint // total data in the queue + // dataqsiz uint // size of the circular queue + // buf unsafe.Pointer // points to an array of dataqsiz elements + // elemsize uint16 + // closed uint32 + // ** + cptr += unsafe.Sizeof(uint(0)) * 2 + cptr += unsafe.Sizeof(unsafe.Pointer(uintptr(0))) + cptr += unsafe.Sizeof(uint16(0)) + return *(*uint32)(unsafe.Pointer(cptr)) > 0 +}