From 7af6cc8a47e58894a0f3237114c2e2d225cfc62e Mon Sep 17 00:00:00 2001 From: taoying <2474671424@qq.com> Date: Fri, 14 Feb 2025 14:29:20 +0800 Subject: [PATCH] =?UTF-8?q?opengauss-sqlalchemy=E9=80=82=E9=85=8Dbit?= =?UTF-8?q?=E3=80=81vector=E3=80=81sparsevec=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- opengauss_sqlalchemy/usertype/__init__.py | 13 ++ opengauss_sqlalchemy/usertype/bit.py | 26 +++ opengauss_sqlalchemy/usertype/sparsevec.py | 51 ++++++ opengauss_sqlalchemy/usertype/vector.py | 51 ++++++ opengauss_sqlalchemy/utils/__init__.py | 9 + opengauss_sqlalchemy/utils/bit.py | 65 +++++++ opengauss_sqlalchemy/utils/sparsevec.py | 161 +++++++++++++++++ opengauss_sqlalchemy/utils/vector.py | 83 +++++++++ test/test_usertypes.py | 129 ++++++++++++++ test/test_utils.py | 191 +++++++++++++++++++++ 10 files changed, 779 insertions(+) create mode 100644 opengauss_sqlalchemy/usertype/__init__.py create mode 100644 opengauss_sqlalchemy/usertype/bit.py create mode 100644 opengauss_sqlalchemy/usertype/sparsevec.py create mode 100644 opengauss_sqlalchemy/usertype/vector.py create mode 100644 opengauss_sqlalchemy/utils/__init__.py create mode 100644 opengauss_sqlalchemy/utils/bit.py create mode 100644 opengauss_sqlalchemy/utils/sparsevec.py create mode 100644 opengauss_sqlalchemy/utils/vector.py create mode 100644 test/test_usertypes.py create mode 100644 test/test_utils.py diff --git a/opengauss_sqlalchemy/usertype/__init__.py b/opengauss_sqlalchemy/usertype/__init__.py new file mode 100644 index 0000000..feaa60b --- /dev/null +++ b/opengauss_sqlalchemy/usertype/__init__.py @@ -0,0 +1,13 @@ +from .bit import BIT +from .sparsevec import SPARSEVEC +from .vector import VECTOR +from .vector import VECTOR as Vector +from ..utils import SparseVector + +__all__ = [ + 'Vector', + 'VECTOR', + 'BIT', + 'SPARSEVEC', + 'SparseVector' +] \ No newline at end of file diff --git a/opengauss_sqlalchemy/usertype/bit.py b/opengauss_sqlalchemy/usertype/bit.py new file mode 100644 index 0000000..0f83f3c --- /dev/null +++ b/opengauss_sqlalchemy/usertype/bit.py @@ -0,0 +1,26 @@ +from sqlalchemy.dialects.postgresql.base import ischema_names +from sqlalchemy.types import UserDefinedType, Float + + +class BIT(UserDefinedType): + cache_ok = True + + def __init__(self, length=None): + super(UserDefinedType, self).__init__() + self.length = length + + def get_col_spec(self, **kw): + if self.length is None: + return 'BIT' + return 'BIT(%d)' % self.length + + class comparator_factory(UserDefinedType.Comparator): + def hamming_distance(self, other): + return self.op('<~>', return_type=Float)(other) + + def jaccard_distance(self, other): + return self.op('<%>', return_type=Float)(other) + + +# for reflection +ischema_names['bit'] = BIT diff --git a/opengauss_sqlalchemy/usertype/sparsevec.py b/opengauss_sqlalchemy/usertype/sparsevec.py new file mode 100644 index 0000000..370f5d1 --- /dev/null +++ b/opengauss_sqlalchemy/usertype/sparsevec.py @@ -0,0 +1,51 @@ +from sqlalchemy.dialects.postgresql.base import ischema_names +from sqlalchemy.types import UserDefinedType, Float, String +from ..utils import SparseVector + + +class SPARSEVEC(UserDefinedType): + cache_ok = True + _string = String() + + def __init__(self, dim=None): + super(UserDefinedType, self).__init__() + self.dim = dim + + def get_col_spec(self, **kw): + if self.dim is None: + return 'SPARSEVEC' + return 'SPARSEVEC(%d)' % self.dim + + def bind_processor(self, dialect): + def process(value): + return SparseVector._to_db(value, self.dim) + return process + + def literal_processor(self, dialect): + string_literal_processor = self._string._cached_literal_processor(dialect) + + def process(value): + return string_literal_processor(SparseVector._to_db(value, self.dim)) + return process + + def result_processor(self, dialect, coltype): + def process(value): + return SparseVector._from_db(value) + return process + + class comparator_factory(UserDefinedType.Comparator): + def l2_distance(self, other): + return self.op('<->', return_type=Float)(other) + + def max_inner_product(self, other): + return self.op('<#>', return_type=Float)(other) + + def cosine_distance(self, other): + return self.op('<=>', return_type=Float)(other) + + def l1_distance(self, other): + return self.op('<+>', return_type=Float)(other) + + +# for reflection +ischema_names['sparsevec'] = SPARSEVEC diff --git a/opengauss_sqlalchemy/usertype/vector.py b/opengauss_sqlalchemy/usertype/vector.py new file mode 100644 index 0000000..f57a045 --- /dev/null +++ b/opengauss_sqlalchemy/usertype/vector.py @@ -0,0 +1,51 @@ +from sqlalchemy.dialects.postgresql.base import ischema_names +from sqlalchemy.types import UserDefinedType, Float, String +from ..utils import Vector + + +class VECTOR(UserDefinedType): + cache_ok = True + _string = String() + + def __init__(self, dim=None): + super(UserDefinedType, self).__init__() + self.dim = dim + + def get_col_spec(self, **kw): + if self.dim is None: + return 'VECTOR' + return 'VECTOR(%d)' % self.dim + + def bind_processor(self, dialect): + def process(value): + return Vector._to_db(value, self.dim) + return process + + def literal_processor(self, dialect): + string_literal_processor = self._string._cached_literal_processor(dialect) + + def process(value): + return string_literal_processor(Vector._to_db(value, self.dim)) + return process + + def result_processor(self, dialect, coltype): + def process(value): + return Vector._from_db(value) + return process + + class comparator_factory(UserDefinedType.Comparator): + def l2_distance(self, other): + return self.op('<->', return_type=Float)(other) + + def max_inner_product(self, other): + return self.op('<#>', return_type=Float)(other) + + def cosine_distance(self, other): + return self.op('<=>', return_type=Float)(other) + + def l1_distance(self, other): + return self.op('<+>', return_type=Float)(other) + + +# for reflection +ischema_names['vector'] = VECTOR diff --git a/opengauss_sqlalchemy/utils/__init__.py b/opengauss_sqlalchemy/utils/__init__.py new file mode 100644 index 0000000..4ba1ea1 --- /dev/null +++ b/opengauss_sqlalchemy/utils/__init__.py @@ -0,0 +1,9 @@ +from .bit import Bit +from .sparsevec import SparseVector +from .vector import Vector + +__all__ = [ + 'Vector', + 'Bit', + 'SparseVector' +] \ No newline at end of file diff --git a/opengauss_sqlalchemy/utils/bit.py b/opengauss_sqlalchemy/utils/bit.py new file mode 100644 index 0000000..652c92f --- /dev/null +++ b/opengauss_sqlalchemy/utils/bit.py @@ -0,0 +1,65 @@ +import numpy as np +from struct import pack, unpack_from + + +class Bit: + def __init__(self, value): + if isinstance(value, str): + self._value = self.from_text(value)._value + else: + if isinstance(value, np.ndarray): + if value.dtype == np.uint8: + value = np.unpackbits(value).astype(bool) + elif value.dtype != np.bool_: + raise ValueError('expected dtype to be bool or uint8') + else: + value = np.asarray(value, dtype=bool) + + if value.ndim != 1: + raise ValueError('expected ndim to be 1') + + self._value = value + + def __repr__(self): + return f'Bit({self.to_text()})' + + def __eq__(self, other): + if isinstance(other, self.__class__): + return np.array_equal(self.to_numpy(), other.to_numpy()) + return False + + def to_list(self): + return self._value.tolist() + + def to_numpy(self): + return self._value + + def to_text(self): + return ''.join(self._value.astype(np.uint8).astype(str)) + + def to_binary(self): + return pack('>i', len(self._value)) + np.packbits(self._value).tobytes() + + @classmethod + def from_text(cls, value): + return cls(np.asarray([v != '0' for v in value], dtype=bool)) + + @classmethod + def from_binary(cls, value): + count = unpack_from('>i', value)[0] + buf = np.frombuffer(value, dtype=np.uint8, offset=4) + return cls(np.unpackbits(buf, count=count).astype(bool)) + + @classmethod + def _to_db(cls, value): + if not isinstance(value, cls): + raise ValueError('expected bit') + + return value.to_text() + + @classmethod + def _to_db_binary(cls, value): + if not isinstance(value, cls): + raise ValueError('expected bit') + + return value.to_binary() diff --git a/opengauss_sqlalchemy/utils/sparsevec.py b/opengauss_sqlalchemy/utils/sparsevec.py new file mode 100644 index 0000000..8df2dfd --- /dev/null +++ b/opengauss_sqlalchemy/utils/sparsevec.py @@ -0,0 +1,161 @@ +import numpy as np +from struct import pack, unpack_from + +NO_DEFAULT = object() + + +class SparseVector: + def __init__(self, value, dimensions=NO_DEFAULT, /): + if value.__class__.__module__.startswith('scipy.sparse.'): + if dimensions is not NO_DEFAULT: + raise ValueError('extra argument') + + self._from_sparse(value) + elif isinstance(value, dict): + if dimensions is NO_DEFAULT: + raise ValueError('missing dimensions') + + self._from_dict(value, dimensions) + else: + if dimensions is not NO_DEFAULT: + raise ValueError('extra argument') + + self._from_dense(value) + + def __repr__(self): + elements = dict(zip(self._indices, self._values)) + return f'SparseVector({elements}, {self._dim})' + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.dimensions() == other.dimensions() and self.indices() == other.indices() and self.values() == other.values() + return False + + def dimensions(self): + return self._dim + + def indices(self): + return self._indices + + def values(self): + return self._values + + def to_coo(self): + from scipy.sparse import coo_array + + coords = ([0] * len(self._indices), self._indices) + return coo_array((self._values, coords), shape=(1, self._dim)) + + def to_list(self): + vec = [0.0] * self._dim + for i, v in zip(self._indices, self._values): + vec[i] = v + return vec + + def to_numpy(self): + vec = np.repeat(0.0, self._dim).astype(np.float32) + for i, v in zip(self._indices, self._values): + vec[i] = v + return vec + + def to_text(self): + return '{' + ','.join([f'{int(i) + 1}:{float(v)}' for i, v in zip(self._indices, self._values)]) + '}/' + str(int(self._dim)) + + def to_binary(self): + nnz = len(self._indices) + return pack(f'>iii{nnz}i{nnz}f', self._dim, nnz, 0, *self._indices, *self._values) + + def _from_dict(self, d, dim): + elements = [(i, v) for i, v in d.items() if v != 0] + elements.sort() + + self._dim = int(dim) + self._indices = [int(v[0]) for v in elements] + self._values = [float(v[1]) for v in elements] + + def _from_sparse(self, value): + value = value.tocoo() + + if value.ndim == 1: + self._dim = value.shape[0] + elif value.ndim == 2 and value.shape[0] == 1: + self._dim = value.shape[1] + else: + raise ValueError('expected ndim to be 1') + + if hasattr(value, 'coords'): + # scipy 1.13+ + self._indices = value.coords[0].tolist() + else: + self._indices = value.col.tolist() + self._values = value.data.tolist() + + def _from_dense(self, value): + self._dim = len(value) + self._indices = [i for i, v in enumerate(value) if v != 0] + self._values = [float(value[i]) for i in self._indices] + + @classmethod + def from_text(cls, value): + elements, dim = value.split('/', 2) + indices = [] + values = [] + # split on empty string returns single element list + if len(elements) > 2: + for e in elements[1:-1].split(','): + i, v = e.split(':', 2) + indices.append(int(i) - 1) + values.append(float(v)) + return cls._from_parts(int(dim), indices, values) + + @classmethod + def from_binary(cls, value): + dim, nnz, unused = unpack_from('>iii', value) + indices = unpack_from(f'>{nnz}i', value, 12) + values = unpack_from(f'>{nnz}f', value, 12 + nnz * 4) + return cls._from_parts(int(dim), list(indices), list(values)) + + @classmethod + def _from_parts(cls, dim, indices, values): + vec = cls.__new__(cls) + vec._dim = dim + vec._indices = indices + vec._values = values + return vec + + @classmethod + def _to_db(cls, value, dim=None): + if value is None: + return value + + if not isinstance(value, cls): + value = cls(value) + + if dim is not None and value.dimensions() != dim: + raise ValueError('expected %d dimensions, not %d' % (dim, value.dimensions())) + + return value.to_text() + + @classmethod + def _to_db_binary(cls, value): + if value is None: + return value + + if not isinstance(value, cls): + value = cls(value) + + return value.to_binary() + + @classmethod + def _from_db(cls, value): + if value is None or isinstance(value, cls): + return value + + return cls.from_text(value) + + @classmethod + def _from_db_binary(cls, value): + if value is None or isinstance(value, cls): + return value + + return cls.from_binary(value) diff --git a/opengauss_sqlalchemy/utils/vector.py b/opengauss_sqlalchemy/utils/vector.py new file mode 100644 index 0000000..ebbcafd --- /dev/null +++ b/opengauss_sqlalchemy/utils/vector.py @@ -0,0 +1,83 @@ +import numpy as np +from struct import pack, unpack_from + + +class Vector: + def __init__(self, value): + # asarray still copies if same dtype + if not isinstance(value, np.ndarray) or value.dtype != '>f4': + value = np.asarray(value, dtype='>f4') + + if value.ndim != 1: + raise ValueError('expected ndim to be 1') + + self._value = value + + def __repr__(self): + return f'Vector({self.to_list()})' + + def __eq__(self, other): + if isinstance(other, self.__class__): + return np.array_equal(self.to_numpy(), other.to_numpy()) + return False + + def dimensions(self): + return len(self._value) + + def to_list(self): + return self._value.tolist() + + def to_numpy(self): + return self._value + + def to_text(self): + return '[' + ','.join([str(float(v)) for v in self._value]) + ']' + + def to_binary(self): + return pack('>HH', self.dimensions(), 0) + self._value.tobytes() + + @classmethod + def from_text(cls, value): + return cls([float(v) for v in value[1:-1].split(',')]) + + @classmethod + def from_binary(cls, value): + dim, unused = unpack_from('>HH', value) + return cls(np.frombuffer(value, dtype='>f4', count=dim, offset=4)) + + @classmethod + def _to_db(cls, value, dim=None): + if value is None: + return value + + if not isinstance(value, cls): + value = cls(value) + + if dim is not None and value.dimensions() != dim: + raise ValueError('expected %d dimensions, not %d' % (dim, value.dimensions())) + + return value.to_text() + + @classmethod + def _to_db_binary(cls, value): + if value is None: + return value + + if not isinstance(value, cls): + value = cls(value) + + return value.to_binary() + + @classmethod + def _from_db(cls, value): + if value is None or isinstance(value, np.ndarray): + return value + + return cls.from_text(value).to_numpy().astype(np.float32) + + @classmethod + def _from_db_binary(cls, value): + if value is None or isinstance(value, np.ndarray): + return value + + return cls.from_binary(value).to_numpy().astype(np.float32) diff --git a/test/test_usertypes.py b/test/test_usertypes.py new file mode 100644 index 0000000..131a89b --- /dev/null +++ b/test/test_usertypes.py @@ -0,0 +1,129 @@ +from sqlalchemy import Column, Integer, MetaData, Table +from sqlalchemy.sql import select +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.assertions import AssertsCompiledSQL + +from opengauss_sqlalchemy import dc_psycopg2, psycopg2 +from opengauss_sqlalchemy.utils import Vector, Bit, SparseVector +from opengauss_sqlalchemy.usertype import BIT, SPARSEVEC, VECTOR + +m = MetaData() +tbl = Table( + "test", + m, + Column("id", Integer), + Column("bit_embedding", BIT(3)), + Column("sparsevec_embedding", SPARSEVEC(3)), + Column("vector_embedding", VECTOR(3)) +) + +class TestBit(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = psycopg2.dialect() + + def test_bit_get_col_spec(self): + bit = BIT() + assert bit.get_col_spec() == 'BIT' + bit_with_length = BIT(5) + assert bit_with_length.get_col_spec() == 'BIT(5)' + + def test_bit_distance(self): + hamming_stmt = select(tbl.c.id).order_by(tbl.c.bit_embedding.hamming_distance('110')) + self.assert_compile( + hamming_stmt, + "SELECT test.id FROM test ORDER BY test.bit_embedding <~> %(bit_embedding_1)s", + checkparams = {"bit_embedding_1" : '110'} + ) + + jaccard_stmt = select(tbl.c.id).order_by(tbl.c.bit_embedding.jaccard_distance('110')) + self.assert_compile( + jaccard_stmt, + "SELECT test.id FROM test ORDER BY test.bit_embedding <%%> %(bit_embedding_1)s", + checkparams = {"bit_embedding_1" : '110'} + ) + +class TestSparseVec(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = psycopg2.dialect() + + def test_sparsevec_get_col_spec(self): + sparsevec = SPARSEVEC() + assert sparsevec.get_col_spec() == 'SPARSEVEC' + sparsevec_with_dim = SPARSEVEC(5) + assert sparsevec_with_dim.get_col_spec() == 'SPARSEVEC(5)' + + def test_sparsevec_distance(self): + l2_stmt = select(tbl.c.id).order_by(tbl.c.sparsevec_embedding.l2_distance(SparseVector([1, 2, 3]))) + self.assert_compile( + l2_stmt, + "SELECT test.id FROM test ORDER BY test.sparsevec_embedding <-> %(sparsevec_embedding_1)s", + checkparams = {"sparsevec_embedding_1" : SparseVector([1, 2, 3])} + ) + + max_inner_product_stmt = select(tbl.c.id).order_by(tbl.c.sparsevec_embedding.max_inner_product(SparseVector([1, 2, 3]))) + self.assert_compile( + max_inner_product_stmt, + "SELECT test.id FROM test ORDER BY test.sparsevec_embedding <#> %(sparsevec_embedding_1)s", + checkparams = {"sparsevec_embedding_1" : SparseVector([1, 2, 3])} + ) + + cosine_stmt = select(tbl.c.id).order_by(tbl.c.sparsevec_embedding.cosine_distance(SparseVector([1, 2, 3]))) + self.assert_compile( + cosine_stmt, + "SELECT test.id FROM test ORDER BY test.sparsevec_embedding <=> %(sparsevec_embedding_1)s", + checkparams = {"sparsevec_embedding_1" : SparseVector([1, 2, 3])} + ) + + l1_stmt = select(tbl.c.id).order_by(tbl.c.sparsevec_embedding.l1_distance(SparseVector([1, 2, 3]))) + self.assert_compile( + l1_stmt, + "SELECT test.id FROM test ORDER BY test.sparsevec_embedding <+> %(sparsevec_embedding_1)s", + checkparams = {"sparsevec_embedding_1" : SparseVector([1, 2, 3])} + ) + + def test_sparsevec_literal_binds(self): + sql = select(tbl.c.id).order_by(tbl.c.sparsevec_embedding.l2_distance(SparseVector([1, 2, 3])))\ + .compile(compile_kwargs = {'literal_binds' : True}) + assert "embedding <-> '{1:1.0,2:2.0,3:3.0}/3'" in str(sql) + + +class TestVector(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = psycopg2.dialect() + + def test_vector_get_col_spec(self): + vec = VECTOR() + assert vec.get_col_spec() == 'VECTOR' + vec_with_dim = VECTOR(5) + assert vec_with_dim.get_col_spec() == 'VECTOR(5)' + + def test_vector_distance(self): + l2_stmt = select(tbl.c.id).order_by(tbl.c.vector_embedding.l2_distance([1,2,3])) + self.assert_compile( + l2_stmt, + "SELECT test.id FROM test ORDER BY test.vector_embedding <-> %(vector_embedding_1)s", + checkparams = {"vector_embedding_1" : [1,2,3]} + ) + + max_inner_product_stmt = select(tbl.c.id).order_by(tbl.c.vector_embedding.max_inner_product([1,2,3])) + self.assert_compile( + max_inner_product_stmt, + "SELECT test.id FROM test ORDER BY test.vector_embedding <#> %(vector_embedding_1)s", + checkparams = {"vector_embedding_1" : [1,2,3]} + ) + + cosine_stmt = select(tbl.c.id).order_by(tbl.c.vector_embedding.cosine_distance([1,2,3])) + self.assert_compile( + cosine_stmt, + "SELECT test.id FROM test ORDER BY test.vector_embedding <=> %(vector_embedding_1)s", + checkparams = {"vector_embedding_1" : [1,2,3]} + ) + + l1_stmt = select(tbl.c.id).order_by(tbl.c.vector_embedding.l1_distance([1,2,3])) + self.assert_compile( + l1_stmt, + "SELECT test.id FROM test ORDER BY test.vector_embedding <+> %(vector_embedding_1)s", + checkparams = {"vector_embedding_1" : [1,2,3]} + ) + + def test_vector_literal_binds(self): + sql = select(tbl.c.id).order_by(tbl.c.vector_embedding.l2_distance([1, 2, 3]))\ + .compile(compile_kwargs = {'literal_binds' : True}) + assert "embedding <-> '[1.0,2.0,3.0]'" in str(sql) \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 0000000..5a031f7 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,191 @@ +import numpy as np +from struct import pack +from scipy.sparse import coo_array +import pytest + +from opengauss_sqlalchemy.utils import Bit, Vector, SparseVector + +class TestBit: + def test_list(self): + assert Bit([True, False, True]).to_list() == [True, False, True] + + def test_tuple(self): + assert Bit((True, False, True)).to_list() == [True, False, True] + + def test_str(self): + assert Bit('101').to_list() == [True, False, True] + + def test_ndarray_uint8(self): + arr = np.array([254, 7, 0], dtype=np.uint8) + assert Bit(arr).to_text() == '111111100000011100000000' + + def test_ndarray_uint16(self): + arr = np.array([254, 7, 0], dtype=np.uint16) + with pytest.raises(ValueError) as error: + Bit(arr) + assert str(error.value) == 'expected dtype to be bool or uint8' + + def test_ndarray_same_object(self): + arr = np.array([True, False, True]) + assert Bit(arr).to_list() == [True, False, True] + assert Bit(arr).to_numpy() is arr + + def test_ndim_two(self): + with pytest.raises(ValueError) as error: + Bit([[True, False], [True, False]]) + assert str(error.value) == 'expected ndim to be 1' + + def test_ndim_zero(self): + with pytest.raises(ValueError) as error: + Bit(True) + assert str(error.value) == 'expected ndim to be 1' + + def test_repr(self): + assert repr(Bit([True, False, True])) == 'Bit(101)' + assert str(Bit([True, False, True])) == 'Bit(101)' + + def test_equality(self): + assert Bit([True, False, True]) == Bit([True, False, True]) + assert Bit([True, False, True]) != Bit([True, False, False]) + +class TestSparseVector: + def test_list(self): + vec = SparseVector([1, 0, 2, 0, 3, 0]) + assert vec.to_list() == [1, 0, 2, 0, 3, 0] + assert np.array_equal(vec.to_numpy(), [1, 0, 2, 0, 3, 0]) + assert vec.indices() == [0, 2, 4] + + def test_list_dimensions(self): + with pytest.raises(ValueError) as error: + SparseVector([1, 0, 2, 0, 3, 0], 6) + assert str(error.value) == 'extra argument' + + def test_ndarray(self): + vec = SparseVector(np.array([1, 0, 2, 0, 3, 0])) + assert vec.to_list() == [1, 0, 2, 0, 3, 0] + assert vec.indices() == [0, 2, 4] + + def test_dict(self): + vec = SparseVector({2: 2, 4: 3, 0: 1, 3: 0}, 6) + assert vec.to_list() == [1, 0, 2, 0, 3, 0] + assert vec.indices() == [0, 2, 4] + + def test_dict_no_dimensions(self): + with pytest.raises(ValueError) as error: + SparseVector({0: 1, 2: 2, 4: 3}) + assert str(error.value) == 'missing dimensions' + + def test_coo_array(self): + arr = coo_array(np.array([1, 0, 2, 0, 3, 0])) + vec = SparseVector(arr) + assert vec.to_list() == [1, 0, 2, 0, 3, 0] + assert vec.indices() == [0, 2, 4] + + def test_coo_array_dimensions(self): + with pytest.raises(ValueError) as error: + SparseVector(coo_array(np.array([1, 0, 2, 0, 3, 0])), 6) + assert str(error.value) == 'extra argument' + + def test_dok_array(self): + arr = coo_array(np.array([1, 0, 2, 0, 3, 0])).todok() + vec = SparseVector(arr) + assert vec.to_list() == [1, 0, 2, 0, 3, 0] + assert vec.indices() == [0, 2, 4] + + def test_repr(self): + assert repr(SparseVector([1, 0, 2, 0, 3, 0])) == 'SparseVector({0: 1.0, 2: 2.0, 4: 3.0}, 6)' + assert str(SparseVector([1, 0, 2, 0, 3, 0])) == 'SparseVector({0: 1.0, 2: 2.0, 4: 3.0}, 6)' + + def test_equality(self): + assert SparseVector([1, 0, 2, 0, 3, 0]) == SparseVector([1, 0, 2, 0, 3, 0]) + assert SparseVector([1, 0, 2, 0, 3, 0]) != SparseVector([1, 0, 2, 0, 3, 1]) + assert SparseVector([1, 0, 2, 0, 3, 0]) == SparseVector({2: 2, 4: 3, 0: 1, 3: 0}, 6) + assert SparseVector({}, 1) != SparseVector({}, 2) + + def test_dimensions(self): + assert SparseVector([1, 0, 2, 0, 3, 0]).dimensions() == 6 + + def test_indices(self): + assert SparseVector([1, 0, 2, 0, 3, 0]).indices() == [0, 2, 4] + + def test_values(self): + assert SparseVector([1, 0, 2, 0, 3, 0]).values() == [1, 2, 3] + + def test_to_coo(self): + assert np.array_equal(SparseVector([1, 0, 2, 0, 3, 0]).to_coo().toarray(), [[1, 0, 2, 0, 3, 0]]) + + def test_zero_vector_text(self): + vec = SparseVector({}, 3) + assert vec.to_list() == SparseVector.from_text(vec.to_text()).to_list() + + def test_from_text(self): + vec = SparseVector.from_text('{1:1.5,3:2,5:3}/6') + assert vec.dimensions() == 6 + assert vec.indices() == [0, 2, 4] + assert vec.values() == [1.5, 2, 3] + assert vec.to_list() == [1.5, 0, 2, 0, 3, 0] + assert np.array_equal(vec.to_numpy(), [1.5, 0, 2, 0, 3, 0]) + + def test_from_binary(self): + data = pack('>iii3i3f', 6, 3, 0, 0, 2, 4, 1.5, 2, 3) + vec = SparseVector.from_binary(data) + assert vec.dimensions() == 6 + assert vec.indices() == [0, 2, 4] + assert vec.values() == [1.5, 2, 3] + assert vec.to_list() == [1.5, 0, 2, 0, 3, 0] + assert np.array_equal(vec.to_numpy(), [1.5, 0, 2, 0, 3, 0]) + assert vec.to_binary() == data + +class TestVector: + def test_list(self): + assert Vector([1, 2, 3]).to_list() == [1, 2, 3] + + def test_list_str(self): + with pytest.raises(ValueError, match='could not convert string to float'): + Vector([1, 'two', 3]) + + def test_tuple(self): + assert Vector((1, 2, 3)).to_list() == [1, 2, 3] + + def test_ndarray(self): + arr = np.array([1, 2, 3]) + assert Vector(arr).to_list() == [1, 2, 3] + assert Vector(arr).to_numpy() is not arr + + def test_ndarray_same_object(self): + arr = np.array([1, 2, 3], dtype='>f4') + assert Vector(arr).to_list() == [1, 2, 3] + assert Vector(arr).to_numpy() is arr + + def test_ndim_two(self): + with pytest.raises(ValueError) as error: + Vector([[1, 2], [3, 4]]) + assert str(error.value) == 'expected ndim to be 1' + + def test_ndim_zero(self): + with pytest.raises(ValueError) as error: + Vector(1) + assert str(error.value) == 'expected ndim to be 1' + + def test_repr(self): + assert repr(Vector([1, 2, 3])) == 'Vector([1.0, 2.0, 3.0])' + assert str(Vector([1, 2, 3])) == 'Vector([1.0, 2.0, 3.0])' + + def test_equality(self): + assert Vector([1, 2, 3]) == Vector([1, 2, 3]) + assert Vector([1, 2, 3]) != Vector([1, 2, 4]) + + def test_dimensions(self): + assert Vector([1, 2, 3]).dimensions() == 3 + + def test_from_text(self): + vec = Vector.from_text('[1.5,2,3]') + assert vec.to_list() == [1.5, 2, 3] + assert np.array_equal(vec.to_numpy(), [1.5, 2, 3]) + + def test_from_binary(self): + data = pack('>HH3f', 3, 0, 1.5, 2, 3) + vec = Vector.from_binary(data) + assert vec.to_list() == [1.5, 2, 3] + assert np.array_equal(vec.to_numpy(), [1.5, 2, 3]) + assert vec.to_binary() == data \ No newline at end of file -- Gitee