代码拉取完成,页面将自动刷新
import numpy as np
import os
from random import shuffle
DATA_PATH = "data/prepared_data/"
class Data():
def __init__(self):
self.X_counter = 0
self.file_counter = 0
self.files = os.listdir(DATA_PATH)
self.files = [file for file in self.files if '.npy' in file]
shuffle(self.files)
self._load_data()
def _load_data(self):
datas = np.load(os.path.join(DATA_PATH, self.files[self.file_counter]))
self.X = []
for data in datas:
self.X.append(data)
shuffle(self.X)
self.X = np.asarray(self.X)
self.file_counter += 1
def get_data(self, batch_size):
if self.X_counter >= len(self.X):
if self.file_counter > len(self.files) - 1:
print("Data exhausted, Re Initialize")
self.__init__()
return None
else:
self._load_data()
self.X_counter = 0
if self.X_counter + batch_size <= len(self.X):
remaining = len(self.X) - (self.X_counter)
X = self.X[self.X_counter: self.X_counter + batch_size]
else:
X = self.X[self.X_counter: ]
self.X_counter += batch_size
return X
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。