From 31b96b60bc9c325ba1aa1d9f07eb19ec5b8a5c40 Mon Sep 17 00:00:00 2001 From: wangfeng Date: Thu, 9 Feb 2023 14:55:50 +0800 Subject: [PATCH] =?UTF-8?q?#I6CYP0=20=E5=AE=9E=E7=8E=B0COUNT=E5=87=BD?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- formula/README.md | 2 +- formula/count.go | 115 ++++++++++++++++++++++++++++++++++++++++++ formula/count_test.go | 46 +++++++++++++++++ generic.go | 6 ++- rolling_count.go | 23 +++++++++ 5 files changed, 189 insertions(+), 3 deletions(-) create mode 100644 formula/count.go create mode 100644 formula/count_test.go create mode 100644 rolling_count.go diff --git a/formula/README.md b/formula/README.md index fee1677..dc1bf4e 100644 --- a/formula/README.md +++ b/formula/README.md @@ -31,7 +31,7 @@ formula | 0 | SLOPE | S序列N周期回线性回归斜率 | SLOPE(CLOSE,5) | [X] | [X] | | 0 | FORCAST | S序列N周期回线性回归后的预测值 | FORCAST(CLOSE,5) | [X] | [X] | | 0 | LAST | 从前A日到前B日一直满足S条件,要求A>B & A>0 & B>=0 | LAST(CLOSE>REF(CLOSE,1),LOW,HIGH) | [X] | [X] | -| 1 | COUNT | COUNT(CLOSE>O,N),最近N天满足S的天数True的天数 | COUNT(CLOSE>LOW,5) | [X] | [X] | +| 1 | COUNT | COUNT(CLOSE>O,N),最近N天满足S的天数True的天数 | COUNT(CLOSE>LOW,5) | [√] | [√] | | 1 | EVERY | EVERY(CLOSE>O,5),最近N天是否都是True | EVERY(CLOSE>LOW,5) | [X] | [X] | | 1 | EXIST | EXIST(CLOSE>O,5),最近N天是否都是True | EXIST(CLOSE>LOW,5) | [X] | [X] | | 1 | FILTER | FILTER函数,S满足条件后,将其后N周期内的数据置为0 | FILTER(CLOSE>LOW,5) | [X] | [X] | diff --git a/formula/count.go b/formula/count.go new file mode 100644 index 0000000..5abb133 --- /dev/null +++ b/formula/count.go @@ -0,0 +1,115 @@ +package formula + +import ( + "gitee.com/quant1x/pandas" + "gitee.com/quant1x/pandas/exception" + "gitee.com/quant1x/pandas/stat" + "github.com/viterin/vek" +) + +// COUNT 统计S为真的天数 +func COUNT(S pandas.Series, N any) pandas.Series { + return S.Rolling(N).Count() +} + +// COUNT_v1 一般性比较 +func COUNT_v1(S pandas.Series, N any) []stat.Int { + //values := S.DTypes() + return S.Rolling(N).Apply(func(X pandas.Series, W stat.DType) stat.DType { + x := X.DTypes() + n := 0 + for _, v := range x { + if v != 0 { + n++ + } + } + return stat.DType(n) + }).AsInt() +} + +func GT(v []stat.DType, x any) []int { + vlen := len(v) + + // 处理默认值 + defaultValue := stat.DType(0) + var X []stat.DType + switch vx := x.(type) { + case int: + X = stat.Repeat[stat.DType](stat.DType(vx), vlen) + case []stat.DType: + xlen := len(vx) + if vlen < xlen { + vlen = xlen + } + X = stat.Align[stat.DType](vx, defaultValue, vlen) + case pandas.Series: + vs := vx.DTypes() + xlen := len(vs) + if vlen < xlen { + vlen = xlen + } + X = stat.Align(vs, defaultValue, vlen) + default: + panic(exception.New(1, "error window")) + } + //bs := vek.Gt(v, X) + //vek.Count() + ns := make([]int, vlen) + for i := 0; i < vlen; i++ { + if v[i] > X[i] { + ns[i] = 1 + } else { + ns[i] = 0 + } + } + return ns +} + +// CompareGt 比较 v > x +func CompareGt(v []stat.DType, x any) []bool { + return __compare(v, x, vek.Gt) +} + +// CompareGte 比较 v >= x +func CompareGte(v []stat.DType, x any) []bool { + return __compare(v, x, vek.Gte) +} + +// CompareLt 比较 v < x +func CompareLt(v []stat.DType, x any) []bool { + return __compare(v, x, vek.Lt) +} + +// CompareLte 比较 v <= x +func CompareLte(v []stat.DType, x any) []bool { + return __compare(v, x, vek.Lte) +} + +// __compare 比较 v 和 x +func __compare(v []stat.DType, x any, comparator func(x, y []float64) []bool) []bool { + vlen := len(v) + + // 处理默认值 + defaultValue := stat.DType(0) + var X []stat.DType + switch vx := x.(type) { + case int: + X = stat.Repeat[stat.DType](stat.DType(vx), vlen) + case []stat.DType: + xlen := len(vx) + if vlen < xlen { + vlen = xlen + } + X = stat.Align[stat.DType](vx, defaultValue, vlen) + case pandas.Series: + vs := vx.DTypes() + xlen := len(vs) + if vlen < xlen { + vlen = xlen + } + X = stat.Align(vs, defaultValue, vlen) + default: + panic(exception.New(1, "error window")) + } + return comparator(v, X) +} diff --git a/formula/count_test.go b/formula/count_test.go new file mode 100644 index 0000000..1ba8001 --- /dev/null +++ b/formula/count_test.go @@ -0,0 +1,46 @@ +package formula + +import ( + "fmt" + "gitee.com/quant1x/pandas" + "testing" +) + +func TestCOUNT(t *testing.T) { + f0 := []float64{1, 2, 3, 4, 5, 6, 0, 8, 9, 10, 11, 12} + i0 := CompareGte(f0, 1) + s0 := pandas.NewSeriesWithoutType("f0", i0) + fmt.Println(COUNT(s0, 5)) + //s2 := []float64{1, 2, 3, 4, 3, 3, 2, 1, stat.DTypeNaN, stat.DTypeNaN, stat.DTypeNaN, stat.DTypeNaN} + //fmt.Println(s2) + ////stat.Fill(s2, 1.0, true) + ////fmt.Println(s2) + //fmt.Println(DMA(s0, s2)) + //csv := "~/.quant1x/data/cn/002528.csv" + //df := pandas.ReadCSV(csv) + //df.SetNames("data", "open", "close", "high", "low", "volume", "amount", "zf", "zdf", "zde", "hsl") + //fmt.Println(df) + //CLOSE := df.Col("close") + // + //cs := CLOSE.Values().([]float32) + //REF10 := REF(CLOSE, 10).([]float32) + //v1 := vek32.Div(cs, REF10) + //df01 := pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "x", v1) + //x0 := make([]stat.DType, CLOSE.Len()) + //df01.Apply(func(idx int, v any) { + // f := v.(float32) + // t := stat.DType(0) + // if f >= 1.05 { + // t = stat.DType(1) + // } + // x0[idx] = t + //}) + //n := BARSLAST(pandas.NewSeries(pandas.SERIES_TYPE_FLOAT32, "", x0)) + //fmt.Println(n[len(n)-10:]) + //x := DMA(CLOSE, pandas.NewSeries(pandas.SERIES_TYPE_DTYPE, "", n)) + // + ////x := EMA(CLOSE, 7) + //sx := pandas.NewSeries(pandas.SERIES_TYPE_DTYPE, "x", x) + //df = pandas.NewDataFrame(CLOSE, sx) + //fmt.Println(df) +} diff --git a/generic.go b/generic.go index 823a7ec..3128bfb 100644 --- a/generic.go +++ b/generic.go @@ -156,7 +156,7 @@ func (self *NDFrame) Float() []float32 { return ToFloat32(self) } -// DType 计算以这个函数为主 +// DTypes 计算以这个函数为主 func (self *NDFrame) DTypes() []stat.DType { return stat.Slice2DType(self.Values()) } @@ -164,7 +164,9 @@ func (self *NDFrame) DTypes() []stat.DType { // AsInt 强制转换成整型 func (self *NDFrame) AsInt() []stat.Int { values := self.DTypes() - return stat.DType2Int(values) + fs := stat.Fill[stat.DType](values, stat.DType(0)) + ns := stat.DType2Int(fs) + return ns } func (self *NDFrame) Empty() Series { diff --git a/rolling_count.go b/rolling_count.go new file mode 100644 index 0000000..668438a --- /dev/null +++ b/rolling_count.go @@ -0,0 +1,23 @@ +package pandas + +import ( + "gitee.com/quant1x/pandas/stat" + "github.com/viterin/vek" +) + +func (r RollingAndExpandingMixin) Count() (s Series) { + if r.series.Type() != SERIES_TYPE_BOOL { + panic("不支持非bool序列") + } + values := make([]stat.DType, r.series.Len()) + for i, block := range r.getBlocks() { + if block.Len() == 0 { + values[i] = 0 + continue + } + bs := block.Values().([]bool) + values[i] = stat.DType(vek.Count(bs)) + } + s = NewSeries(SERIES_TYPE_DTYPE, r.series.Name(), values) + return +} -- Gitee