diff --git a/dataframe.go b/dataframe.go
index 1f0505cb778a4172acec8b5bd76eef5f97926923..e2df4d5238fd786b469662ce380d2650ce8c47d0 100644
--- a/dataframe.go
+++ b/dataframe.go
@@ -25,13 +25,13 @@ func NewDataFrame(se ...Series) DataFrame {
 	for i, s := range se {
 		var d Series
 		if s.Type() == SERIES_TYPE_INT {
-			d = NewSeriesInt64(s.Name(), s.Values())
+			d = NewSeries(SERIES_TYPE_INT, s.Name(), s.Values())
 		} else if s.Type() == SERIES_TYPE_BOOL {
-			d = NewSeriesBool(s.Name(), s.Values())
+			d = NewSeries(SERIES_TYPE_BOOL, s.Name(), s.Values())
 		} else if s.Type() == SERIES_TYPE_STRING {
-			d = NewSeriesString(s.Name(), s.Values())
+			d = NewSeries(SERIES_TYPE_STRING, s.Name(), s.Values())
 		} else {
-			d = NewSeriesFloat64(s.Name(), s.Values())
+			d = NewSeries(SERIES_TYPE_FLOAT, s.Name(), s.Values())
 		}
 		columns[i] = d
 	}
diff --git a/exception/errors.go b/exception/errors.go
new file mode 100644
index 0000000000000000000000000000000000000000..bfcccaef9a87e7cc2fb2d50a0f6422fffb5102af
--- /dev/null
+++ b/exception/errors.go
@@ -0,0 +1,37 @@
+// Package exception provides an error type that carries a numeric code.
+package exception
+
+import "fmt"
+
+// Throwable is an error that also exposes a numeric code.
+type Throwable interface {
+	error
+	Code() int
+}
+
+// Compile-time check that Exception implements Throwable.
+var _ Throwable = Exception{}
+
+// Exception is the basic Throwable: a numeric code plus a message.
+type Exception struct {
+	code    int
+	message string
+}
+
+// New returns an Exception holding the given code and message.
+func New(code int, message string) *Exception {
+	return &Exception{
+		code:    code,
+		message: message,
+	}
+}
+
+// Error implements the error interface; it formats both code and message.
+func (e Exception) Error() string {
+	return fmt.Sprintf("#%d, message=%s", e.code, e.message)
+}
+
+// Code returns the numeric code passed to New.
+func (e Exception) Code() int {
+	return e.code
+}
diff --git a/frame.go b/frame.go
index a422e7993caa3f798511f33b82b6015e640de291..ae9b41bec021abdc5f2d0464717887b5743efadc 100644
--- a/frame.go
+++ b/frame.go
@@ -16,7 +16,7 @@ type Frame[T GenericType] interface {
 	// Len 获得行数
 	Len() int
 	// Values 获得全部数据集
-	Values() []T
+	Values() []T // NOTE: once the element type is fixed here, it may not be adjustable automatically later
 }
 
 type GenericFrame[T GenericType] struct {
diff --git a/generic.go b/generic.go
index cc04076732bb956f17697096c7923a4c3dcca1ef..427e6f333cfffa4cdeb3a89e4a89ea0112cbf886 100644
--- a/generic.go
+++ b/generic.go
@@ -231,13 +231,6 @@ func (self *NDFrame) Shift(periods int) Series {
 	}
 }
 
-func (self *NDFrame) Rolling(window int) RollingWindow {
-	return RollingWindow{
-		window: window,
-		series: self,
-	}
-}
-
 func (self *NDFrame) Mean() float64 {
 	if self.Len() < 1 {
 		return NaN()
diff --git a/generic_range.go b/generic_range.go
index 649244fbaef5cdf16fd3962e5be4fbd152f124d7..3f97ce870d014cbe26adbf4135c1ce8c52220e8f 100644
--- a/generic_range.go
+++ b/generic_range.go
@@ -26,12 +26,12 @@ func (self *NDFrame) Subset(start, end int, opt ...any) Series {
 	vk := vv.Kind()
 	switch vk {
 	case reflect.Slice, reflect.Array: // 切片和数组同样的处理逻辑
-		vs = vv.Slice(start, end).Interface()
-		rows = vv.Len()
+		vvs := vv.Slice(start, end)
+		vs = vvs.Interface()
+		rows = vvs.Len()
 		if __optCopy && rows > 0 {
 			vs = gc.Clone(vs)
 		}
-		rows = vv.Len()
 		frame := NDFrame{
 			formatter: self.formatter,
 			name:      self.name,
diff --git a/generic_rolling.go b/generic_rolling.go
new file mode 100644
index 0000000000000000000000000000000000000000..18591089a5f72783f8f04cf6d6e509f71977106c
--- /dev/null
+++ b/generic_rolling.go
@@ -0,0 +1,74 @@
+package pandas
+
+import (
+	"gitee.com/quant1x/pandas/exception"
+	"gitee.com/quant1x/pandas/stat"
+)
+
+// Rolling returns a fixed-size RollingWindow over the series.
+func (self *NDFrame) Rolling(window int) RollingWindow {
+	return RollingWindow{
+		window: window,
+		series: self,
+	}
+}
+
+// RollingAndExpandingMixin is a rolling view whose window size may differ
+// from row to row: window[i] is the size of the window ending at row i.
+type RollingAndExpandingMixin struct {
+	window []float32
+	series Series
+}
+
+// Rolling2 returns a rolling view over the series.
+//
+// param is either an int, which is used as the window size of every row, or
+// a Series supplying one window size per row; the Series values are converted
+// to float32 and aligned to the length of the receiver using Nil2Float32.
+// Any other type panics with an *exception.Exception.
+func (self *NDFrame) Rolling2(param any) RollingAndExpandingMixin {
+	var sizes []float32
+	switch v := param.(type) {
+	case int:
+		sizes = stat.Repeat[float32](float32(v), self.Len())
+	case Series:
+		sizes = sliceToFloat32(v.Values())
+		sizes = stat.Align(sizes, Nil2Float32, self.Len())
+	default:
+		panic(exception.New(1, "error window"))
+	}
+	return RollingAndExpandingMixin{
+		window: sizes,
+		series: self,
+	}
+}
+
+// getBlocks returns one sub-series per row: the rows covered by the window
+// ending at that row. The block is empty when the window size is NaN, smaller
+// than 1, or larger than the number of rows seen so far.
+func (r RollingAndExpandingMixin) getBlocks() []Series {
+	rows := r.series.Len()
+	blocks := make([]Series, 0, rows)
+	for i := 0; i < rows; i++ {
+		end := i + 1
+		n := r.window[i]
+		// The size check also rejects non-positive windows, which would
+		// otherwise make the Subset start lie beyond its end.
+		if Float32IsNaN(n) || int(n) < 1 || int(n) > end {
+			blocks = append(blocks, r.series.Empty())
+			continue
+		}
+		blocks = append(blocks, r.series.Subset(end-int(n), end))
+	}
+	return blocks
+}
+
+// Mean returns a float series, named after the source series, holding the
+// mean of every rolling block.
+func (r RollingAndExpandingMixin) Mean() Series {
+	var d []float64
+	for _, block := range r.getBlocks() {
+		d = append(d, block.Mean())
+	}
+	return NewSeries(SERIES_TYPE_FLOAT, r.series.Name(), d)
+}
diff --git a/generic_test.go b/generic_test.go
index fda2edb420b1f1468722d9273dd1786a1467eb8b..2a2574973e9dc9100bda9a89a20095dd21ef524a 100644
--- a/generic_test.go
+++ b/generic_test.go
@@ -65,3 +65,22 @@ func TestNDFrameNew(t *testing.T) {
 	fmt.Println(nd2.Records())
 	fmt.Println(nd2.Empty())
 }
+
+func TestRolling2(t *testing.T) {
+	d1 := []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
+	s1 := NewNDFrame[float64]("x", d1...)
+	df := NewDataFrame(s1)
+	fmt.Println(df)
+	fmt.Println("------------------------------------------------------------")
+
+	N := 5
+	fmt.Println("固定的参数, N =", N)
+	r1 := df.Col("x").Rolling2(N).Mean().Values()
+	fmt.Println("序列化结果:", r1)
+	fmt.Println("------------------------------------------------------------")
+	d2 := []float64{1, 2, 3, 4, 3, 3, 2, 1, Nil2Float64, Nil2Float64, Nil2Float64, Nil2Float64}
+	s2 := NewSeries(SERIES_TYPE_FLOAT, "x", d2)
+	fmt.Printf("序列化参数: %+v\n", s2.Values())
+	r2 := df.Col("x").Rolling2(s2).Mean().Values()
+	fmt.Println("序列化结果:", r2)
+}
diff --git a/series.go b/series.go
index 2cb8507bd3c95b505ddeb927c2768e9c18ec0e9f..397d938b30947b3020d15874793c56b536d76c06 100644
--- a/series.go
+++ b/series.go
@@ -59,6 +59,8 @@ type Series interface {
 	Shift(periods int) Series
 	// Rolling creates new RollingWindow
 	Rolling(window int) RollingWindow
+	// Rolling2 creates a rolling view whose window size is a fixed int or a per-row Series
+	Rolling2(param any) RollingAndExpandingMixin
 	// Mean calculates the average value of a series
 	Mean() float64
 	// StdDev calculates the standard deviation of a series
diff --git a/series_number.go b/series_number.go
index b87ef2d94a83a848cef5cb1178042c6aed2c1fb5..a425471638a38062484c1845dc696cab27b38092 100644
--- a/series_number.go
+++ b/series_number.go
@@ -159,3 +159,13 @@ func point_to_number[T Number](v any, nil2t T, bool2t func(b bool) T, string2t f
 	}
 	return T(0)
 }
+
+//func anyToNumber(v any) int {
+//	switch val := v.(type) {
+//	case nil, int8, uint8, int16, uint16, int32, uint32, int64, uint64, int, uint, float32, float64, bool, string:
+//		// basic types
+//		series_append(&frame, idx, size, val)
+//	default:
+//	}
+//	return 0
+//}
diff --git a/stat/align.go b/stat/align.go
index b84e12690fcf988fbbfee1e3585b78ede2706591..83cc7680ec6ae98bc79bdd393e490642f41cc3f1 100644
--- a/stat/align.go
+++ b/stat/align.go
@@ -1,11 +1,12 @@
 package stat
 
-// Data alignment
-func align[T StatType](x []T, a T, dLen int) []T {
+// Align resizes x to dLen elements: longer input is truncated, shorter input is extended (presumably padded with a; TODO confirm).
+func Align[T StatType](x []T, a T, dLen int) []T {
 	d := []T{}
 	xLen := len(x)
 	if xLen >= dLen {
 		// 截断
+		d = make([]T, dLen)
 		copy(d, x[0:dLen])
 	} else {
 		// 扩展内存
diff --git a/stat/repeat.go b/stat/repeat.go
new file mode 100644
index 0000000000000000000000000000000000000000..53564e86ed8c5ec74e5cd92e8e838ea874abc81e
--- /dev/null
+++ b/stat/repeat.go
@@ -0,0 +1,30 @@
+package stat
+
+import (
+	"github.com/viterin/vek"
+	"github.com/viterin/vek/vek32"
+)
+
+// Repeat returns a slice of length n whose elements are all f.
+//
+// Plain float32 and float64 values are filled by the vek routines. Every
+// other type admitted by Float (a defined type with a float underlying type)
+// is filled by a loop, because the []float32 or []float64 produced by vek
+// cannot be type-asserted to []T for such a T.
+func Repeat[T Float](f T, n int) []T {
+	switch v := any(f).(type) {
+	case float32:
+		if d, ok := any(vek32.Repeat(v, n)).([]T); ok {
+			return d
+		}
+	case float64:
+		if d, ok := any(vek.Repeat(v, n)).([]T); ok {
+			return d
+		}
+	}
+	d := make([]T, n)
+	for i := range d {
+		d[i] = f
+	}
+	return d
+}
diff --git a/stat/repeat_test.go b/stat/repeat_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..364a04cf52d4746567571a6330ed36343efce340
--- /dev/null
+++ b/stat/repeat_test.go
@@ -0,0 +1,26 @@
+package stat
+
+import "testing"
+
+// namedFloat32 is a defined float type; Repeat must not panic for it.
+type namedFloat32 float32
+
+// checkRepeat fails the test unless got holds exactly n copies of want.
+func checkRepeat[T Float](t *testing.T, got []T, want T, n int) {
+	t.Helper()
+	if len(got) != n {
+		t.Fatalf("len = %d, want %d", len(got), n)
+	}
+	for i, v := range got {
+		if v != want {
+			t.Fatalf("got[%d] = %v, want %v", i, v, want)
+		}
+	}
+}
+
+func TestRepeat(t *testing.T) {
+	const n = 10
+	checkRepeat(t, Repeat(float32(1), n), 1, n)
+	checkRepeat(t, Repeat(float64(1), n), 1, n)
+	checkRepeat(t, Repeat(namedFloat32(1), n), 1, n)
+}
diff --git a/stat/type.go b/stat/type.go
index 12293572243b0c24ebf08d7b1227accd394dbb73..3f2a1679a355d869b0bc62a8d61b5538a8a141f1 100644
--- a/stat/type.go
+++ b/stat/type.go
@@ -5,6 +5,11 @@ import (
 	"reflect"
 )
 
+// Float constrains a type parameter to types whose underlying type is float32 or float64.
+type Float interface {
+	~float32 | ~float64
+}
+
 type StatType interface {
 	~int32 | ~int64 | ~float32 | ~float64
 }
diff --git a/stat/where.go b/stat/where.go
index a7d52539391aebd540982b0bb813c3bc45fb2e11..38a1c9b499d794ad776d13c44314ef0e0c63e11c 100644
--- a/stat/where.go
+++ b/stat/where.go
@@ -30,13 +30,13 @@ func Where[T StatType](condition []T, x, y []T) []T {
 	defaultValue := typeDefault(T(0))
 	// 对齐所有长度
 	if clen < maxLength {
-		condition = align(condition, defaultValue, maxLength)
+		condition = Align(condition, defaultValue, maxLength)
 	}
 	if xlen < maxLength {
-		x = align(x, defaultValue, maxLength)
+		x = Align(x, defaultValue, maxLength)
 	}
 	if ylen < maxLength {
-		y = align(y, defaultValue, maxLength)
+		y = Align(y, defaultValue, maxLength)
 	}
 	// 初始化返回值
 	d := make([]T, maxLength)