AllenNLP DatasetReader:只加载单个实例,而不是迭代训练数据集中的所有实例

问题描述

我正在使用 AllenNLP 来训练分层注意力网络模型。我的训练数据集包含一个 JSON 对象 列表(例如,列表中的每个对象都是一个 JSON 对象,键为 := ["text","label"]。与文本键关联的值是一个列表列表,例如:

[{"text":[["i","feel","sad"],["not","sure","i","guess","the","weather"]],"label":0} ... {"text":[[str]],"label":int}] 

我的 DatasetReader 类看起来像:

@DatasetReader.register("my_reader")
class TranscriptDataReader(DatasetReader):
    def __init__(self,token_indexers: Optional[Dict[str,TokenIndexer]] = None,lazy: bool = True) -> None:
        super().__init__(lazy)
        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}

    def _read(self,file_path: str) -> Iterator[Instance]:
        with open(file_path,'r') as f:
            data = json.loads(f.read())
            for _,data_json in enumerate(data):
                sent_list = []
                for segment in data_json["text"]:
                    sent_list.append(self.get_text_field(segment))
                yield self.create_instance(sent_list,str(data_json["label"]))

    def get_text_field(self,segment):
        return TextField([Token(token.lower()) for token in segment],self._token_indexers)


    def create_instance(self,sent_list,label):
        label_field = LabelField(label,skip_indexing=False)
        fields = {'tokens': ListField(sent_list),'label': label_field}
        return Instance(fields)

在我的配置文件中,我有

{
  dataset_reader: {
    type: 'my_reader',},train_data_path: 'data/train.json',validation_data_path: 'data/dev.json',data_loader: {
    batch_sampler: {
      type: 'bucket',batch_size: 10
    }
 },

我已尝试(或者)将数据集读取器的 lazy 参数设置为 TrueFalse

  • 当设置为 True 时,模型能够进行训练,但是,当我的数据集包含 ~100 时,我观察到实际上只有一列火车和一个开发实例被加载。
  • 设置为 False 时,我已将 yield 中的 _read修改return;然而,这会导致基本词汇类中的类型错误。我还尝试将 yield 设置为 False 时保持原样;在这种情况下,根本没有加载任何实例,并且由于实例集是空的,词汇表不会被实例化,并且嵌入类会抛出错误

希望得到指点和/或调试技巧。

解决方法

如果您使用 allennlp>=v2.0.0lazy 构造函数中的 DatasetReader 参数已弃用。因此,您的 super().__init__(lazy) 将被解释为新的构造函数参数 max_instances,即相当于 max_instances=Truemax_instances=1

,

您能否打印并告诉我们读取 json 文件后加载了多少个实例(为了清楚起见,在下面添加了打印命令)

def _read(self,file_path: str) -> Iterator[Instance]:
        with open(file_path,'r') as f:
            data = json.loads(f.read())
            print(len(data))
            for _,data_json in enumerate(data):
               sent_list = []
                for segment in data_json["text"]:
                    sent_list.append(self.get_text_field(segment))
                yield self.create_instance(sent_list,str(data_json["label"]))