代码拉取完成,页面将自动刷新
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package trino
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
)
// TestConfig verifies that FormatDSN renders a plain-HTTP server URI plus
// session properties into the expected DSN, including the default source.
func TestConfig(t *testing.T) {
	cfg := &Config{
		ServerURI:         "http://foobar@localhost:8080",
		SessionProperties: map[string]string{"query_priority": "1"},
	}
	got, err := cfg.FormatDSN()
	if err != nil {
		t.Fatal(err)
	}
	if want := "http://foobar@localhost:8080?session_properties=query_priority%3D1&source=trino-go-client"; got != want {
		t.Fatal("unexpected dsn:", got)
	}
}
// TestConfigSSLCertPath verifies that a configured SSLCertPath is emitted as a
// query parameter of the DSN built for an https server URI.
func TestConfigSSLCertPath(t *testing.T) {
	cfg := &Config{
		ServerURI:         "https://foobar@localhost:8080",
		SessionProperties: map[string]string{"query_priority": "1"},
		SSLCertPath:       "cert.pem",
	}
	got, err := cfg.FormatDSN()
	if err != nil {
		t.Fatal(err)
	}
	if want := "https://foobar@localhost:8080?SSLCertPath=cert.pem&session_properties=query_priority%3D1&source=trino-go-client"; got != want {
		t.Fatal("unexpected dsn:", got)
	}
}
// TestConfigWithoutSSLCertPath verifies that an https server URI with no
// SSLCertPath configured yields a DSN without an SSLCertPath parameter.
func TestConfigWithoutSSLCertPath(t *testing.T) {
	cfg := &Config{
		ServerURI:         "https://foobar@localhost:8080",
		SessionProperties: map[string]string{"query_priority": "1"},
	}
	got, err := cfg.FormatDSN()
	if err != nil {
		t.Fatal(err)
	}
	if want := "https://foobar@localhost:8080?session_properties=query_priority%3D1&source=trino-go-client"; got != want {
		t.Fatal("unexpected dsn:", got)
	}
}
// TestKerberosConfig verifies that every Kerberos option and the SSL cert path
// are URL-encoded into the DSN, in sorted parameter order.
func TestKerberosConfig(t *testing.T) {
	cfg := &Config{
		ServerURI:          "https://foobar@localhost:8090",
		SessionProperties:  map[string]string{"query_priority": "1"},
		KerberosEnabled:    "true",
		KerberosKeytabPath: "/opt/test.keytab",
		KerberosPrincipal:  "trino/testhost",
		KerberosRealm:      "example.com",
		KerberosConfigPath: "/etc/krb5.conf",
		SSLCertPath:        "/tmp/test.cert",
	}
	got, err := cfg.FormatDSN()
	if err != nil {
		t.Fatal(err)
	}
	// The expected DSN is split per query parameter for readability only.
	const want = "https://foobar@localhost:8090?" +
		"KerberosConfigPath=%2Fetc%2Fkrb5.conf" +
		"&KerberosEnabled=true" +
		"&KerberosKeytabPath=%2Fopt%2Ftest.keytab" +
		"&KerberosPrincipal=trino%2Ftesthost" +
		"&KerberosRealm=example.com" +
		"&SSLCertPath=%2Ftmp%2Ftest.cert" +
		"&session_properties=query_priority%3D1" +
		"&source=trino-go-client"
	if got != want {
		t.Fatal("unexpected dsn:", got)
	}
}
// TestInvalidKerberosConfig verifies that FormatDSN rejects a configuration
// that enables Kerberos on a plain-http server URI.
func TestInvalidKerberosConfig(t *testing.T) {
	c := &Config{
		ServerURI:       "http://foobar@localhost:8090",
		KerberosEnabled: "true",
	}
	if _, err := c.FormatDSN(); err == nil {
		t.Fatal("dsn generated from insecure (http) url although kerberos is enabled; kerberos requires SSL")
	}
}
// TestConfigWithMalformedURL verifies that FormatDSN returns an error when the
// server URI cannot be parsed.
func TestConfigWithMalformedURL(t *testing.T) {
	cfg := &Config{ServerURI: ":("}
	if _, err := cfg.FormatDSN(); err == nil {
		t.Fatal("dsn generated from malformed url")
	}
}
// TestConnErrorDSN checks that DSNs accepted by sql.Open still make the first
// query fail. It covers a syntactically malformed URL and a URL naming a custom
// HTTP client that was never registered.
func TestConnErrorDSN(t *testing.T) {
	testcases := []struct {
		Name string
		DSN  string
	}{
		{Name: "malformed", DSN: "://"},
		{Name: "unknown_client", DSN: "http://localhost?custom_client=unknown"},
	}
	for _, tc := range testcases {
		t.Run(tc.Name, func(t *testing.T) {
			db, err := sql.Open("trino", tc.DSN)
			if err != nil {
				t.Fatal(err)
			}
			// Close the pool on every path, not only when the query unexpectedly succeeds.
			defer db.Close()
			rows, err := db.Query("SELECT 1")
			if err == nil {
				rows.Close()
				t.Fatal("test dsn is supposed to fail:", tc.DSN)
			}
		})
	}
}
// TestRegisterCustomClientReserved ensures that the keys "true" and "false"
// cannot be registered as custom HTTP client names.
func TestRegisterCustomClientReserved(t *testing.T) {
	reserved := []string{"true", "false"}
	for i := range reserved {
		key := reserved[i]
		t.Run(fmt.Sprintf("%v", key), func(t *testing.T) {
			if err := RegisterCustomClient(key, &http.Client{}); err == nil {
				t.Fatal("client key name supposed to fail:", key)
			}
		})
	}
}
// TestRoundTripRetryQueryError has the fake server answer the first request
// with 503 and every later request with a 200 whose body carries a query error.
// The query must surface as *ErrQueryFailed.
func TestRoundTripRetryQueryError(t *testing.T) {
	unavailableSent := false
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !unavailableSent {
			unavailableSent = true
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		}
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(&stmtResponse{
			Error: stmtError{ErrorName: "TEST"},
		})
	}))
	defer srv.Close()

	db, err := sql.Open("trino", srv.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	_, err = db.Query("SELECT 1")
	if _, isQueryFailure := err.(*ErrQueryFailed); !isQueryFailure {
		t.Fatal("unexpected error:", err)
	}
}
// TestRoundTripCancellation points the client at a server that always answers
// 503. It checks that a query bounded by a short context deadline returns an
// error rather than succeeding.
func TestRoundTripCancellation(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusServiceUnavailable)
	}))
	defer srv.Close()

	db, err := sql.Open("trino", srv.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	const deadline = 50 * time.Millisecond
	ctx, cancel := context.WithTimeout(context.Background(), deadline)
	defer cancel()
	if _, err := db.QueryContext(ctx, "SELECT 1"); err == nil {
		t.Fatal("unexpected query with cancelled context succeeded")
	}
}
// TestAuthFailure checks that a server answering every request with
// 401 Unauthorized makes the query fail instead of succeeding silently.
func TestAuthFailure(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
	}))
	defer ts.Close()
	db, err := sql.Open("trino", ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()
	// sql.Open alone does not necessarily contact the server, so a query is
	// needed for the 401 handler to be exercised at all.
	rows, err := db.Query("SELECT 1")
	if err == nil {
		rows.Close()
		t.Fatal("query succeeded despite 401 Unauthorized response")
	}
}
// TestQueryForUsername runs "SELECT current_user" with the X-Trino-User header
// passed as a named argument. It expects every returned row to equal that user.
// NOTE(review): this connects to a server at localhost:8080, so it needs a live
// Trino instance there and is effectively an integration test.
func TestQueryForUsername(t *testing.T) {
	c := &Config{
		ServerURI:         "http://foobar@localhost:8080",
		SessionProperties: map[string]string{"query_priority": "1"},
	}
	dsn, err := c.FormatDSN()
	if err != nil {
		t.Fatal(err)
	}
	db, err := sql.Open("trino", dsn)
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()
	rows, err := db.Query("SELECT current_user", sql.Named("X-Trino-User", "TestUser"))
	if err != nil {
		t.Fatal("Failed executing query", err.Error())
	}
	// Release the result set (and its connection) when the test finishes.
	defer rows.Close()
	const want = "TestUser"
	for rows.Next() {
		var user string
		if err := rows.Scan(&user); err != nil {
			t.Fatal("Failed scanning query result", err.Error())
		}
		if user != want {
			t.Fatal("Expected value does not equal result value : ", user, " != ", want)
		}
	}
	// Next returns false both at the end of the rows and on failure; Err tells them apart.
	if err := rows.Err(); err != nil {
		t.Fatal("Failed iterating query result", err.Error())
	}
}
// TestQueryCancellation checks that a 200 response whose error name is
// USER_CANCELLED is reported to the caller as ErrQueryCancelled.
func TestQueryCancellation(t *testing.T) {
	cancelled := &stmtResponse{
		Error: stmtError{ErrorName: "USER_CANCELLED"},
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		json.NewEncoder(w).Encode(cancelled)
	}))
	defer srv.Close()

	db, err := sql.Open("trino", srv.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	if _, err := db.Query("SELECT 1"); err != ErrQueryCancelled {
		t.Fatal("unexpected error:", err)
	}
}
// TestQueryFailure checks that a 500 Internal Server Error response is
// surfaced as *ErrQueryFailed.
func TestQueryFailure(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer srv.Close()

	db, err := sql.Open("trino", srv.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	_, err = db.Query("SELECT 1")
	if _, isQueryFailure := err.(*ErrQueryFailed); !isQueryFailure {
		t.Fatal("unexpected error:", err)
	}
}
// TestUnsupportedHeader checks that a response carrying the set-role header
// (trinoSetRoleHeader) makes the query fail with ErrUnsupportedHeader.
func TestUnsupportedHeader(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set(trinoSetRoleHeader, "foo")
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	db, err := sql.Open("trino", srv.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	if _, err := db.Query("SELECT 1"); err != ErrUnsupportedHeader {
		t.Fatal("unexpected error:", err)
	}
}
// TestSSLCertPath checks that an SSLCertPath the driver cannot load makes Ping
// fail with an "Error loading SSL Cert File" error.
// NOTE(review): assumes /tmp/invalid_test.cert does not exist (or is not a
// valid certificate) on the test machine.
func TestSSLCertPath(t *testing.T) {
	db, err := sql.Open("trino", "https://localhost:9?SSLCertPath=/tmp/invalid_test.cert")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()
	want := "Error loading SSL Cert File"
	err = db.Ping()
	if err == nil {
		// Reporting err here would only print "<nil>"; state the expectation instead.
		t.Fatalf("expected ping to fail with %q, got no error", want)
	}
	if !strings.Contains(err.Error(), want) {
		t.Fatalf("want: %q, got: %v", want, err)
	}
}
// TestWithoutSSLCertPath checks that an https DSN without SSLCertPath can be
// opened and pinged without error.
// NOTE(review): this presumably passes because Ping does not reach
// localhost:9 — confirm against the driver's connection setup.
func TestWithoutSSLCertPath(t *testing.T) {
	db, err := sql.Open("trino", "https://localhost:9")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	err = db.Ping()
	if err != nil {
		t.Fatal(err)
	}
}
// TestUnsupportedTransaction checks that starting a transaction fails with an
// "operation not supported" error.
func TestUnsupportedTransaction(t *testing.T) {
	db, err := sql.Open("trino", "http://localhost:9")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	const expected = "operation not supported"
	switch _, err = db.Begin(); {
	case err == nil:
		t.Fatal("unsupported transaction succeeded with no error")
	case !strings.Contains(err.Error(), expected):
		t.Fatalf("expected begin to fail with %s but got %v", expected, err)
	}
}
// TestTypeConversion runs three checks against the converter for each Trino
// type: it accepts nil, rejects a value of an unexpected Go type, and turns a
// representative JSON-decoded sample into the expected Go value.
func TestTypeConversion(t *testing.T) {
	type conversionCase struct {
		trinoType string
		sample    interface{} // value as produced by decoding the JSON response
		want      interface{}
	}
	// time.UTC is the same *Location that time.LoadLocation("UTC") returns.
	cases := []conversionCase{
		{"boolean", true, true},
		{"varchar(1)", "hello", "hello"},
		{"bigint", json.Number("1234516165077230279"), int64(1234516165077230279)},
		{"double", json.Number("1.0"), float64(1)},
		{"date", "2017-07-10", time.Date(2017, 7, 10, 0, 0, 0, 0, time.Local)},
		{"time", "01:02:03.000", time.Date(0, 1, 1, 1, 2, 3, 0, time.Local)},
		{"time with time zone", "01:02:03.000 UTC", time.Date(0, 1, 1, 1, 2, 3, 0, time.UTC)},
		{"timestamp", "2017-07-10 01:02:03.000", time.Date(2017, 7, 10, 1, 2, 3, 0, time.Local)},
		{"timestamp with time zone", "2017-07-10 01:02:03.000 UTC", time.Date(2017, 7, 10, 1, 2, 3, 0, time.UTC)},
		{"map", nil, nil},
		// arrays return data as-is for slice scanners
		{"array", nil, nil},
	}
	for _, c := range cases {
		conv := newTypeConverter(c.trinoType)
		t.Run(c.trinoType+":nil", func(t *testing.T) {
			if _, err := conv.ConvertValue(nil); err != nil {
				t.Fatal(err)
			}
		})
		t.Run(c.trinoType+":bogus", func(t *testing.T) {
			if _, err := conv.ConvertValue(struct{}{}); err == nil {
				t.Fatal("bogus data scanned with no error")
			}
		})
		t.Run(c.trinoType+":sample", func(t *testing.T) {
			got, err := conv.ConvertValue(c.sample)
			if err != nil {
				t.Fatal(err)
			}
			if !reflect.DeepEqual(got, c.want) {
				t.Fatalf("unexpected data from sample:\nhave %+v\nwant %+v", got, c.want)
			}
		})
	}
}
// TestSliceTypeConversion exercises the one-dimensional NullSlice* scanners.
// Each must accept nil, reject values of the wrong shape or element type, and
// report Valid after scanning a well-formed sample.
func TestSliceTypeConversion(t *testing.T) {
	// valid reports the Valid flag of whichever NullSlice* scanner s is.
	valid := func(s sql.Scanner) bool {
		switch v := s.(type) {
		case *NullSliceBool:
			return v.Valid
		case *NullSliceString:
			return v.Valid
		case *NullSliceInt64:
			return v.Valid
		case *NullSliceFloat64:
			return v.Valid
		case *NullSliceTime:
			return v.Valid
		case *NullSliceMap:
			return v.Valid
		}
		return false
	}
	cases := []struct {
		goType  string
		scanner sql.Scanner
		sample  interface{} // value as decoded from the JSON response
	}{
		{"[]bool", &NullSliceBool{}, []interface{}{true}},
		{"[]string", &NullSliceString{}, []interface{}{"hello"}},
		{"[]int64", &NullSliceInt64{}, []interface{}{json.Number("1")}},
		{"[]float64", &NullSliceFloat64{}, []interface{}{json.Number("1.0")}},
		{"[]time.Time", &NullSliceTime{}, []interface{}{"2017-07-01"}},
		{"[]map[string]interface{}", &NullSliceMap{}, []interface{}{map[string]interface{}{"hello": "world"}}},
	}
	for _, c := range cases {
		t.Run(c.goType+":nil", func(t *testing.T) {
			if err := c.scanner.Scan(nil); err != nil {
				t.Error(err)
			}
		})
		t.Run(c.goType+":bogus", func(t *testing.T) {
			if err := c.scanner.Scan(struct{}{}); err == nil {
				t.Error("bogus data scanned with no error")
			}
			if err := c.scanner.Scan([]interface{}{struct{}{}}); err == nil {
				t.Error("bogus data scanned with no error")
			}
		})
		t.Run(c.goType+":sample", func(t *testing.T) {
			if err := c.scanner.Scan(c.sample); err != nil {
				t.Error(err)
			}
			if !valid(c.scanner) {
				t.Fatal("scanner failed")
			}
		})
	}
}
// TestSlice2TypeConversion exercises the two-dimensional NullSlice2* scanners.
// Each must accept nil at the top level and as an inner element, reject
// malformed input at every nesting depth, and report Valid after scanning a
// well-formed sample.
func TestSlice2TypeConversion(t *testing.T) {
	// valid reports the Valid flag of whichever NullSlice2* scanner s is.
	valid := func(s sql.Scanner) bool {
		switch v := s.(type) {
		case *NullSlice2Bool:
			return v.Valid
		case *NullSlice2String:
			return v.Valid
		case *NullSlice2Int64:
			return v.Valid
		case *NullSlice2Float64:
			return v.Valid
		case *NullSlice2Time:
			return v.Valid
		case *NullSlice2Map:
			return v.Valid
		}
		return false
	}
	cases := []struct {
		goType  string
		scanner sql.Scanner
		sample  interface{} // value as decoded from the JSON response
	}{
		{"[][]bool", &NullSlice2Bool{}, []interface{}{[]interface{}{true}}},
		{"[][]string", &NullSlice2String{}, []interface{}{[]interface{}{"hello"}}},
		{"[][]int64", &NullSlice2Int64{}, []interface{}{[]interface{}{json.Number("1")}}},
		{"[][]float64", &NullSlice2Float64{}, []interface{}{[]interface{}{json.Number("1.0")}}},
		{"[][]time.Time", &NullSlice2Time{}, []interface{}{[]interface{}{"2017-07-01"}}},
		{"[][]map[string]interface{}", &NullSlice2Map{}, []interface{}{[]interface{}{map[string]interface{}{"hello": "world"}}}},
	}
	for _, c := range cases {
		t.Run(c.goType+":nil", func(t *testing.T) {
			if err := c.scanner.Scan(nil); err != nil {
				t.Error(err)
			}
			if err := c.scanner.Scan([]interface{}{nil}); err != nil {
				t.Error(err)
			}
		})
		t.Run(c.goType+":bogus", func(t *testing.T) {
			if err := c.scanner.Scan(struct{}{}); err == nil {
				t.Error("bogus data scanned with no error")
			}
			if err := c.scanner.Scan([]interface{}{struct{}{}}); err == nil {
				t.Error("bogus data scanned with no error")
			}
			if err := c.scanner.Scan([]interface{}{[]interface{}{struct{}{}}}); err == nil {
				t.Error("bogus data scanned with no error")
			}
		})
		t.Run(c.goType+":sample", func(t *testing.T) {
			if err := c.scanner.Scan(c.sample); err != nil {
				t.Error(err)
			}
			if !valid(c.scanner) {
				t.Fatal("scanner failed")
			}
		})
	}
}
// TestSlice3TypeConversion exercises the three-dimensional NullSlice3*
// scanners. Each must accept nil at the top level and as a nested element,
// reject malformed input, and report Valid after scanning a well-formed sample.
func TestSlice3TypeConversion(t *testing.T) {
	// valid reports the Valid flag of whichever NullSlice3* scanner s is.
	valid := func(s sql.Scanner) bool {
		switch v := s.(type) {
		case *NullSlice3Bool:
			return v.Valid
		case *NullSlice3String:
			return v.Valid
		case *NullSlice3Int64:
			return v.Valid
		case *NullSlice3Float64:
			return v.Valid
		case *NullSlice3Time:
			return v.Valid
		case *NullSlice3Map:
			return v.Valid
		}
		return false
	}
	cases := []struct {
		goType  string
		scanner sql.Scanner
		sample  interface{} // value as decoded from the JSON response
	}{
		{"[][][]bool", &NullSlice3Bool{}, []interface{}{[]interface{}{[]interface{}{true}}}},
		{"[][][]string", &NullSlice3String{}, []interface{}{[]interface{}{[]interface{}{"hello"}}}},
		{"[][][]int64", &NullSlice3Int64{}, []interface{}{[]interface{}{[]interface{}{json.Number("1")}}}},
		{"[][][]float64", &NullSlice3Float64{}, []interface{}{[]interface{}{[]interface{}{json.Number("1.0")}}}},
		{"[][][]time.Time", &NullSlice3Time{}, []interface{}{[]interface{}{[]interface{}{"2017-07-01"}}}},
		{"[][][]map[string]interface{}", &NullSlice3Map{}, []interface{}{[]interface{}{[]interface{}{map[string]interface{}{"hello": "world"}}}}},
	}
	for _, c := range cases {
		t.Run(c.goType+":nil", func(t *testing.T) {
			if err := c.scanner.Scan(nil); err != nil {
				t.Fatal(err)
			}
			if err := c.scanner.Scan([]interface{}{[]interface{}{nil}}); err != nil {
				t.Fatal(err)
			}
		})
		t.Run(c.goType+":bogus", func(t *testing.T) {
			if err := c.scanner.Scan(struct{}{}); err == nil {
				t.Error("bogus data scanned with no error")
			}
			if err := c.scanner.Scan([]interface{}{[]interface{}{struct{}{}}}); err == nil {
				t.Error("bogus data scanned with no error")
			}
			if err := c.scanner.Scan([]interface{}{[]interface{}{[]interface{}{struct{}{}}}}); err == nil {
				t.Error("bogus data scanned with no error")
			}
		})
		t.Run(c.goType+":sample", func(t *testing.T) {
			if err := c.scanner.Scan(c.sample); err != nil {
				t.Error(err)
			}
			if !valid(c.scanner) {
				t.Fatal("scanner failed")
			}
		})
	}
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。