问题描述
我正在使用 AllenNLP 来训练分层注意力网络模型。我的训练数据集包含一个 JSON 对象 列表(例如,列表中的每个对象都是一个 JSON 对象,键为 := ["text","label"]。与文本键关联的值是一个列表列表,例如:
[{"text":[["i","feel","sad"],["not","sure","i","guess","the","weather"]],"label":0} ... {"text":[[str]],"label":int}]
我的 DatasetReader 类看起来像:
@DatasetReader.register("my_reader")
class TranscriptDataReader(DatasetReader):
def __init__(self,token_indexers: Optional[Dict[str,TokenIndexer]] = None,lazy: bool = True) -> None:
super().__init__(lazy)
self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
def _read(self,file_path: str) -> Iterator[Instance]:
with open(file_path,'r') as f:
data = json.loads(f.read())
for _,data_json in enumerate(data):
sent_list = []
for segment in data_json["text"]:
sent_list.append(self.get_text_field(segment))
yield self.create_instance(sent_list,str(data_json["label"]))
def get_text_field(self,segment):
return TextField([Token(token.lower()) for token in segment],self._token_indexers)
def create_instance(self,sent_list,label):
label_field = LabelField(label,skip_indexing=False)
fields = {'tokens': ListField(sent_list),'label': label_field}
return Instance(fields)
{
dataset_reader: {
type: 'my_reader',},train_data_path: 'data/train.json',validation_data_path: 'data/dev.json',data_loader: {
batch_sampler: {
type: 'bucket',batch_size: 10
}
},
我已尝试(或者)将数据集读取器的 lazy
参数设置为 True
和 False
。
- 当设置为
True
时,模型能够进行训练,但是,当我的数据集包含 ~100 时,我观察到实际上只有一列火车和一个开发实例被加载。 - 设置为
False
时,我已将yield
中的_read
行修改为return
;然而,这会导致基本词汇类中的类型错误。我还尝试将yield
设置为False
时保持原样;在这种情况下,根本没有加载任何实例,并且由于实例集是空的,词汇表不会被实例化,并且嵌入类会抛出错误。
希望得到指点和/或调试技巧。
解决方法
如果您使用 allennlp>=v2.0.0
,lazy
构造函数中的 DatasetReader
参数已弃用。因此,您的 super().__init__(lazy)
将被解释为新的构造函数参数 max_instances
,即相当于 max_instances=True
的 max_instances=1
。
您能否打印并告诉我们读取 json 文件后加载了多少个实例(为了清楚起见,在下面添加了打印命令)
def _read(self,file_path: str) -> Iterator[Instance]:
with open(file_path,'r') as f:
data = json.loads(f.read())
print(len(data))
for _,data_json in enumerate(data):
sent_list = []
for segment in data_json["text"]:
sent_list.append(self.get_text_field(segment))
yield self.create_instance(sent_list,str(data_json["label"]))