问题描述
我正在尝试将 csv(有 10 列)转换为 vavepal wabbit 输入格式 txt 文件。有些 csv 列具有整数值,有些具有字符串(例如:com.12346.xyz)。例如,如果我的 csv 列如下所示:
loss weight SSD_id weight label imp feat_val
0.693147 0.693147 1 1.0 -1.0000 0.0000 com.12346.xyz
0.419189 0.145231 2 2.0 1.0000 -1.8559 com.12346.xyz
0.235457 0.051725 4 4.0 -1.0000 -2.7588 com.12356.xyz
6.371911 12.508365 8 8.0 -1.0000 -3.7784 com.12346.xyz
3.485084 0.598258 16 16.0 1.0000 -2.2767 com.12346.xyz
1.765249 0.045413 32 32.0 -1.0000 -2.8924 com.1236.xyz
1.017911 0.270573 64 64.0 -1.0000 -3.0438 com.12236.xyz
0.611419 0.204927 128 128.0 1.0000 -3.1539 com.16746.xyz
0.469127 0.326834 256 256.0 -1.0000 -1.6101 com.1946.xyz
0.403473 0.337820 512 512.0 1.0000 -2.8843 com.126.xyz
0.337348 0.271222 1024 1024.0 -1.0000 -2.5209 com.1346.xyz
0.328909 0.320471 2048 2048.0 1.0000 -2.0732 com.1234.xyz
0.309401 0.289892 4096 4096.0 1.0000 -2.7639 com.12396.xyz
和vovpal wabbit 输入格式如下所示:
label weight |i imp SSD_id loss |c feat_val
并且里面的vovepal wabbit txt文件值应该是:
-1 0.051725 |i imp:-2.7588 SSD_id:4 loss:0.235457 |c feat_val=com.12356.xyz
1 0.598258 |i imp:-2.7588 SSD_id:4 loss:3.485034 |c feat_val=com.12346.xyz
... 等等...对于所有行值。我想将 csv 文件中的大量行转换为上述格式并将它们全部保存在单个 txt 文件中。我已经开始使用下面给出的这个小函数:
def to_new_format(document,label=None):
return str(label or '') + ' |i ' + ' '.join(re.findall('\w{3,}',document.lower())) + '\n'
to_new_format(str(text_train[1])
但是在对数据框、csv 格式和尝试功能进行了多次试验后,我现在完全迷失了。有人可以给我一些指导,我可以如何以最少的代码行实现这一目标。
解决方法
这比看起来更简单,因为 Pandas 可以通过一些方便的方式让您像处理 Python 中的单个值一样处理序列。
首先,我们将导入您的 CSV 文件,将所有值视为字符串以简化格式化:
import pandas as pd
df = pd.read_csv('test_data.txt',dtype=pd.StringDtype())
您的 label
列在您的文件中记录为 1.0000
,但您不希望输出中包含小数点或零。我们可以使用 Pandas 的 str.replace
方法解决这个问题。
df.label = df.label.str.replace('.0000','',regex=False)
这就是神奇的部分:我们可以将它们连接起来,就像它们是单独的字符串一样!
formatted = (
df.label + ' ' + df.weight +
' |i imp:' + df.imp +
' SSD_id: ' + df.SSD_id +
' loss:' + df.loss +
' |c feat_val=' + df.feat_val +
'\n'
)
该代码看起来会创建一个字符串,但由于它包含数据帧的列(每个列都是 Pandas 序列),结果也是一个序列:
print(formatted)
0 -1 0.693147 |i imp:0.0000 SSD_id: 1 loss:0.693...
1 1 0.145231 |i imp:-1.8559 SSD_id: 2 loss:0.419...
2 -1 0.051725 |i imp:-2.7588 SSD_id: 4 loss:0.23...
3 -1 12.508365 |i imp:-3.7784 SSD_id: 8 loss:6.3...
4 1 0.598258 |i imp:-2.2767 SSD_id: 16 loss:3.48...
5 -1 0.045413 |i imp:-2.8924 SSD_id: 32 loss:1.7...
6 -1 0.270573 |i imp:-3.0438 SSD_id: 64 loss:1.0...
7 1 0.204927 |i imp:-3.1539 SSD_id: 128 loss:0.6...
8 -1 0.326834 |i imp:-1.6101 SSD_id: 256 loss:0....
9 1 0.337820 |i imp:-2.8843 SSD_id: 512 loss:0.4...
10 -1 0.271222 |i imp:-2.5209 SSD_id: 1024 loss:0...
11 1 0.320471 |i imp:-2.0732 SSD_id: 2048 loss:0....
12 1 0.289892 |i imp:-2.7639 SSD_id: 4096 loss:0....
像这样打印时,每一行都会被截断,但它都在那里。例如:
print(formatted[0])
-1 0.693147 |i imp:0.0000 SSD_id: 1 loss:0.693147 |c feat_val=com.12346.xyz
剩下的就是将它保存到一个文件中:
with open('out.txt','w') as f:
f.writelines(formatted)