diff --git a/tests/st/test_ut/test_dataset/test_dataloader/test_adgen_dataloader.py b/tests/st/test_ut/test_dataset/test_dataloader/test_adgen_dataloader.py index 9e47c6e8cdce25ceda2eb4e7c16bb50579b7fbdd..efa6867c5a7bacf30d3959e51ad63c6667f8cf31 100644 --- a/tests/st/test_ut/test_dataset/test_dataloader/test_adgen_dataloader.py +++ b/tests/st/test_ut/test_dataset/test_dataloader/test_adgen_dataloader.py @@ -16,9 +16,9 @@ import os import unittest import tempfile -from mindformers.dataset.dataloader.adgen_dataloader import ADGenDataset, ADGenDataLoader import pytest from tests.st.test_ut.test_dataset.get_test_data import get_adgen_data +from mindformers.dataset.dataloader.adgen_dataloader import ADGenDataset, ADGenDataLoader class TestAdgenDataloader(unittest.TestCase): @@ -26,13 +26,17 @@ class TestAdgenDataloader(unittest.TestCase): @classmethod def setUpClass(cls): - cls.temp_dir = tempfile.TemporaryDirectory() + cls.temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with cls.path = cls.temp_dir.name cls.phase = "train" cls.data_path = os.path.join(cls.path, f"{cls.phase}.json") cls.columns = ["content", "summary"] get_adgen_data(cls.path) + @classmethod + def tearDownClass(cls): + cls.temp_dir.cleanup() + @pytest.mark.level1 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -45,8 +49,8 @@ class TestAdgenDataloader(unittest.TestCase): dataloader = dataloader.batch(1) for item in dataloader: assert len(item) == 2 - assert item[0].asnumpy()[0] == "类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤" - assert item[1].asnumpy()[0] == "宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。" + assert item[0][0] == "类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤" + assert item[1][0] == "宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。" break @pytest.mark.level1 @@ -84,9 +88,13 @@ class TestAdgenDataSet(unittest.TestCase): @classmethod def setUpClass(cls): - cls.temp_dir = tempfile.TemporaryDirectory() + cls.temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with cls.path = cls.temp_dir.name + @classmethod + def tearDownClass(cls): + cls.temp_dir.cleanup() + @pytest.mark.level1 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard