model.fit_generator中的steps_per_epoch实际在做什么?

问题描述

在阅读steps_per_epoch方法中有关model.fit_generator必需参数的Keras文档之后,我对它的理解是:

如果数据集包含“ N”个样本,并且生成器函数(传递给Keras)对于每个调用均返回“ B = batch_size”个样本数(此处,我认为调用是生成器函数的单个收益),并且由于steps_per_epoch = ceil(N / B),生成器被调用steps_per_epoch次,因此完整数据集将在一个时期后通过模型,并且对于每个时期都重复相同的过程,直到训练完成。>

为了测试我的理解是否正确,我实施了以下

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

index = 0

def get_values(inputs,targets):
    i = 0
    while True:
        yield inputs[i],targets[i]
        i += 1
        if i >= len(inputs):
            i = 0


def get_batch(inputs,targets,batch_size=2):
    global index
    batch_X = []
    batch_Y = []
    for inp,targ in get_values(inputs,targets):
        batch_X.append(inp)
        batch_Y.append(targ)

        if len(batch_X) >= batch_size:
            yield np.array(batch_X),np.array(batch_Y)
            index += 1
            batch_X = []
            batch_Y = []


data = list(range(10))
labels = [2*val for val in range(10)]

model = Sequential([
    Dense(16,activation='relu',input_shape=(1,)),Dense(1)
])

model.compile(optimizer='rmsprop',loss='mean_squared_error')
model.fit_generator(get_batch(data,labels,batch_size=2),steps_per_epoch=5,epochs=1,verbose=False)

print(index) # Should Print 5 but it prints 15

程序不难理解...

但是根据我的解释,它应该打印5但打印15。我对steps_per_epoch的解释是否错误? 如果是这样,请给我正确解释steps_per_epoch

PS。我是Tensorflow和Keras的新手,谢谢。

解决方法

没有检查您的代码,但是您的原始密码是正确的。实际上,根据here中的文档,您可以每个时期省略步骤,并且model.fit将数据集的长度(N)除以批处理大小来确定步骤。我确实复制并运行了您的代码。猜猜它将索引打印为5的是什么。我唯一能想到的可能就是导入。

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...