将 npz jax 权重转换为 keras h5 权重

问题描述

有没有办法将 jax npz 预先训练好的权重转换为 kers/tf.keras h5 格式的权重?

在网上找不到任何东西。

谢谢

解决方法

npz 格式转换为 h5 格式的最直接方法是将数据加载到内存中,然后重写。

这是一个简单的例子:

import jax.numpy as jnp
from jax import random
import h5py

# Create some random weights
key = random.PRNGKey(1701)
weights = random.normal(key,shape=(100,))

# Save to an npz file
jnp.savez('weights.npz',weights=weights)

# Load the npz and convert to h5
data = jnp.load('weights.npz')
with h5py.File('weights.h5','w') as hf:
    hf.create_dataset('weights',data=data['weights'])

请注意,此操作的详细信息将取决于 npz 文件的内容以及生成的 h5 文件所需的结构。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...