最近在做bert文本分类,有一个生成器,记录一下使用,跟我网上查到的不太一样,主要在.iter()这个地方,很多代码都是没有这个,不知道是不是版本原因
另外需要注意,自定义的生成器需要注意什么时候结束,不然会一直产生数据
datalist, labellist = get_data_from_excel(r'data/test.xlsx')
data = data_generator(datalist).__iter__() # 注意这个.__iter__()
# 获取一批数据
print(next(data))
# 或者
for x in data:
print(x)
点击查看代码
class data_generator:
"""
data_generator只是一种为了节约内存的数据方式
"""
def __init__(self, data, batch_size=Batch_size, shuffle=True):
"""
:param data: 训练的文本列表
:param batch_size: 每次训练的个数
:param shuffle: 文本是否打乱
"""
self.data = data
self.batch_size = batch_size
self.shuffle = shuffle
self.steps = len(self.data) // self.batch_size
if len(self.data) % self.batch_size != 0:
self.steps += 1
def __len__(self):
return self.steps
def __iter__(self):
while True:
idxs = list(range(len(self.data))) # 生成一个序列
if self.shuffle:
np.random.shuffle(idxs) # 打乱序列
X1, X2, Y = [], [], []
for i in idxs:
d = self.data[i]
text = d[0][:maxlen]
x1, x2 = tokenizer.encode(first=text) # 添加[CLS]和[SEP]
y = d[1]
X1.append(x1)
X2.append(x2)
Y.append([y])
if len(X1) == self.batch_size or i == idxs[-1]:
# 对一批数据(最后一批不满batch_size)进行padding
X1 = seq_padding(X1) # 内部转为了np.array
X2 = seq_padding(X2)
Y = seq_padding(Y)
yield [X1, X2], Y[:, 0, :]
[X1, X2, Y] = [], [], []
原文链接: https://www.cnblogs.com/lxzbky/p/16345205.html
欢迎关注
微信关注下方公众号,第一时间获取干货硬货;公众号内回复【pdf】免费获取数百本计算机经典书籍;
也有高质量的技术群,里面有嵌入式、搜广推等BAT大佬
原创文章受到原创版权保护。转载请注明出处:https://www.ccppcoding.com/archives/399904
非原创文章文中已经注明原地址,如有侵权,联系删除
关注公众号【高性能架构探索】,第一时间获取最新文章
转载文章受原作者版权保护。转载请注明原作者出处!