如何找出模型本身使用的VRAM数量? LSTM

问题描述

如何找出此模型的VRAM使用情况? (这与训练数据无关,而是模型及其权重已加载到VRAM中

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Dropout,LSTM,Batchnormalization
import tensorflow as tf


model = Sequential()
model.add(LSTM(700,input_shape=(10000,5000,20),return_sequences=True))
model.add(Dropout(0.2))
model.add(Batchnormalization())


model.add(LSTM(700,return_sequences=True))
model.add(Dropout(0.2))
model.add(Batchnormalization())


model.add(LSTM(700))
model.add(Dropout(0.2))
model.add(Batchnormalization())


model.add(Dense(64,activation='sigmoid'))
model.add(Dropout(0.2))


model.add(Dense(32,activation='sigmoid'))
model.add(Dropout(0.2))

解决方法

您可以根据参数数量和参数数据类型对其进行近似。

model.summary()

会告诉您您有多少个参数,然后乘以大小即可。例如,如果发现有5000个参数并且它们是32b浮点,则模型至少为5000 * 32/8 = 1250字节。

自然地,会有更多的东西进入其中,但这通常是获得相当接近的下限的一种简便方法。您还可以保存模型,并查看保存的文件有多大,其中还包括模型结构。

tf.saved_model.save(model,path)

要小心,因为如果您的模型很大,则训练它甚至会很大,因为它还将存储所有可学习参数的梯度。因此,您可以轻松想象1GB模型可能需要5-10GB的训练量。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...