diff --git a/dataframe_join.go b/dataframe_join.go
new file mode 100644
index 0000000000000000000000000000000000000000..063d4be6682aaa4f71aec91f92bcd873eb2f5a13
--- /dev/null
+++ b/dataframe_join.go
@@ -0,0 +1,65 @@
+package pandas
+
+import "gitee.com/quant1x/pandas/stat"
+
+// align passes the values of every series in ss to stat.Align together with
+// the largest Len() found among them and a per-type filler value, and returns
+// the rebuilt series in the same order. It returns an empty slice when ss is
+// empty or when no series has a positive length.
+func (self DataFrame) align(ss ...Series) []Series {
+	defaultValue := []Series{}
+	if len(ss) == 0 {
+		return defaultValue
+	}
+	// Track the longest length as an int: a float32 cannot represent every
+	// integer above 2^24, so routing lengths through it could round them.
+	maxLength := 0
+	for _, v := range ss {
+		if n := v.Len(); n > maxLength {
+			maxLength = n
+		}
+	}
+	if maxLength <= 0 {
+		return defaultValue
+	}
+	cols := make([]Series, len(ss))
+	for i, v := range ss {
+		vt := v.Type()
+		vn := v.Name()
+		vs := v.Values()
+		// The raw values cannot be aligned directly: each element type needs
+		// its own Nil/NaN default, so dispatch on the series type. ns is an
+		// any so that it can hold whichever branch's result.
+		var ns any
+		switch vt {
+		case SERIES_TYPE_BOOL:
+			ns = stat.Align(vs.([]bool), Nil2Bool, maxLength)
+		case SERIES_TYPE_INT:
+			ns = stat.Align(vs.([]int64), Nil2Int64, maxLength)
+		case SERIES_TYPE_STRING:
+			ns = stat.Align(vs.([]string), Nil2String, maxLength)
+		case SERIES_TYPE_FLOAT:
+			ns = stat.Align(vs.([]float64), Nil2Float64, maxLength)
+		}
+		// For any other series type ns is left nil.
+		cols[i] = NewSeries(vt, vn, ns)
+	}
+	return cols
+}
+
+// Join returns a new DataFrame built from the receiver's columns followed by
+// series, after passing all of them through align. The original author
+// describes this as a right join by default.
+//
+// NOTE(review): the Len() < 0 guard looks unreachable unless Len can return a
+// negative value; `<= 0` may have been intended. Left as is to preserve
+// behavior; confirm with the author.
+func (self DataFrame) Join(series Series) DataFrame {
+	if series.Len() < 0 {
+		return self
+	}
+	cols := make([]Series, self.Ncol()+1)
+	cols[len(cols)-1] = series
+	copy(cols, self.columns)
+	return NewDataFrame(self.align(cols...)...)
+}
diff --git a/dataframe_test.go b/dataframe_test.go
index f1ef0733b183b07f286b83a4fcf9365bd3877cad..d1f42304915f1b4ec40559099e0f369bb1046085 100644
--- a/dataframe_test.go
+++ b/dataframe_test.go
@@ -52,3 +52,23 @@ func TestLoadStructs(t *testing.T) {
 	df2 := LoadStructs(dataTags)
 	fmt.Println(df2)
 }
+
+func TestDataFrame_Join(t *testing.T) {
+	type testStruct struct {
+		A string
+		B int
+		C bool
+		D float64
+	}
+	data := []testStruct{
+		{"a", 1, true, 0.0},
+		{"b", 2, false, 0.5},
+	}
+	df1 := LoadStructs(data)
+	fmt.Println(df1)
+
+	// Add one column.
+	extra := GenericSeries[string]("", "a0", "a1", "a2", "a3")
+	df2 := df1.Join(extra)
+	fmt.Println(df2)
+}
diff --git a/stat/align.go b/stat/align.go
index 83cc7680ec6ae98bc79bdd393e490642f41cc3f1..a7e845ed9461380811f4624b3a44b03210673fd0 100644
--- a/stat/align.go
+++ b/stat/align.go
@@ -1,7 +1,7 @@
 package stat
 
 // Align Data alignment
-func Align[T StatType](x []T, a T, dLen int) []T {
+func Align[T MoveType](x []T, a T, dLen int) []T {
 	d := []T{}
 	xLen := len(x)
 	if xLen >= dLen {
diff --git a/stat/max.go b/stat/max.go
new file mode 100644
index 0000000000000000000000000000000000000000..d0d5b566a5db6f2d6efd7addbc96b8cfbdcd5f28
--- /dev/null
+++ b/stat/max.go
@@ -0,0 +1,29 @@
+package stat
+
+import (
+	"github.com/viterin/vek"
+	"github.com/viterin/vek/vek32"
+	"unsafe"
+)
+
+// Max returns the maximum value of f, or zero when f is empty.
+//
+// The element size selects the implementation: vek32.Max for 4-byte elements
+// and vek.Max for 8-byte elements.
+func Max[T Float](f []T) T {
+	if len(f) == 0 {
+		return T(0)
+	}
+	var s any = f
+	var d any
+	switch unsafe.Sizeof(f[0]) {
+	case 4:
+		d = vek32.Max(s.([]float32))
+	case 8:
+		d = vek.Max(s.([]float64))
+	default:
+		// Not expected to be reached.
+		d = T(0)
+	}
+	return d.(T)
+}
diff --git a/stat/max_test.go b/stat/max_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..405bafdd951b09cb87b89e2dde8957659a7b10a1
--- /dev/null
+++ b/stat/max_test.go
@@ -0,0 +1,20 @@
+package stat
+
+import (
+	"fmt"
+	"testing"
+)
+
+func TestMax(t *testing.T) {
+	// Exercise both the 4-byte (vek32) and the 8-byte (vek) branch of Max.
+	f1 := []float32{1.1, 1.2, 1.3}
+	f2 := []float64{1.1, 1.2, 1.3}
+	fmt.Println(Max(f1))
+	fmt.Println(Max(f2))
+	if got := Max(f1); got != 1.3 {
+		t.Errorf("Max(%v) = %v, want 1.3", f1, got)
+	}
+	if got := Max(f2); got != 1.3 {
+		t.Errorf("Max(%v) = %v, want 1.3", f2, got)
+	}
+}
diff --git a/stat/repeat.go b/stat/repeat.go
index 53564e86ed8c5ec74e5cd92e8e838ea874abc81e..9f472960aed6401910c7b64a6ff796d90063633e 100644
--- a/stat/repeat.go
+++ b/stat/repeat.go
@@ -9,10 +9,10 @@ import (
 // Repeat repeat
 func Repeat[T Float](f T, n int) []T {
 	var d any
-	bitsize := unsafe.Sizeof(f)
-	if bitsize == 4 {
+	bitSize := unsafe.Sizeof(f)
+	if bitSize == 4 {
 		d = vek32.Repeat(float32(f), n)
-	} else if bitsize == 8 {
+	} else if bitSize == 8 {
 		d = vek.Repeat(float64(f), n)
 	} else {
 		// 应该不会走到这里
diff --git a/stat/type.go b/stat/type.go
index 3f2a1679a355d869b0bc62a8d61b5538a8a141f1..1ee3dc0ed268c3c33a2060f8ec6e41ea5d3c0ce1 100644
--- a/stat/type.go
+++ b/stat/type.go
@@ -13,6 +13,11 @@ type StatType interface {
 	~int32 | ~int64 | ~float32 | ~float64
 }
 
+// MoveType extends StatType with bool and string element types.
+type MoveType interface {
+	StatType | ~bool | ~string
+}
+
 // 随便输入一个什么值
 func typeDefault[T StatType](x T) T {
 	xv := reflect.ValueOf(x)