问题描述
我正在尝试解决一个非常简单的任务(我认为是这样),即在 TPU 上的自定义层中复制张量。
我的输入是 2 个形状为 A=(BS,H,n,C) 和 B = (BS,W,C) 的张量,其中 n 在我的情况下可以是 (1,3,5,7),但可能也适用于其他数字。
我的任务是重复两个张量 A 和 B 以形成 (BS,C) 并将它们相加作为输出。如果 H(或 W)总是被 n 整除会很容易,但它们不是。因此,A 的每个切片 (BS,1,C) 的重复次数会有所不同。因此,使用以下伪代码计算输出:
for i in range(W):
A1[BS,i,C] = A[BS,floor(n*i/W),C]
我尝试以多种方式实现它:
class StripPoolingCombine(tf.keras.layers.Layer):
def __init__(self,n=1):
super(StripPoolingCombine,self).__init__()
self.n = n
def call(self,v,h,training=False):
H,W = v.shape[1],h.shape[2]
v_repeats = tf.unique_with_counts(tf.math.floor(tf.range(W) * self.n / W))[-1]
h_repeats = tf.unique_with_counts(tf.math.floor(tf.range(H) * self.n / H))[-1]
v = tf.repeat(v,repeats=v_repeats,axis=2)
h = tf.repeat(h,repeats=h_repeats,axis=1)
return Add()([v,h])
或者将 unique_with_counts
替换为以下逻辑:
tf.math.bincount(tf.cast(tf.math.floor(tf.range(W) * self.n / W),dtype=tf.int32)
- 使用临时公式:
f = tf.cast(tf.math.ceil(W / self.n),dtype=tf.int32)
s = tf.cast(tf.math.floor(W / self.n),dtype=tf.int32)
b = tf.cast(f!=s,dtype=tf.int32)
r = W - f - s * (self.n - 1)
x1 = s * tf.ones(self.n-1,dtype=tf.int32)
x2 = (1 - tf.range(r*2) % 2) * b
x2 = tf.pad(x2,paddings=[[0,self.n-r*2-1]])
x3 = tf.concat([[f],tf.add(x1,x2)],axis=0)
但是正如在 Available TensorFlow Ops 中可以看到的 TPU,它不支持动态 tf.range
、tf.unique_with_counts
或 tf.math.bincount
,并且我的实现在构建时都会导致错误一个模型并调用 model.fit()
或 model.predict()
。然而我仍然希望 tensorflow 提供了一些方法来以适合我的任务的方式处理动态形状,并且我不会为这样一个微不足道的问题重写整个 Ops 模块。请帮忙!
完全可重现的示例(使用 Colab TPU):
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,Add
try:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
print(f'Running on TPU: {tpu.master()}')
except ValueError:
print('Could not connect to TPU')
tpu = None
if tpu:
try:
print('Initializing TPU...')
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.TPUStrategy(tpu)
print('TPU initialized!')
except Exception:
print('Failed to initialize TPU')
# class StripPoolingCombine(tf.keras.layers.Layer):
# def __init__(self,n=1):
# super(StripPoolingCombine,self).__init__()
# self.n = n
# def call(self,training=False):
# H,h.shape[2]
# v_repeats = tf.unique_with_counts(tf.math.floor(tf.range(W) * self.n / W))[-1]
# h_repeats = tf.unique_with_counts(tf.math.floor(tf.range(H) * self.n / H))[-1]
# v = tf.repeat(v,axis=2)
# h = tf.repeat(h,axis=1)
# return Add()([v,h])
class StripPoolingCombine(tf.keras.layers.Layer):
def __init__(self,W = tf.shape(v)[1],tf.shape(h)[2]
f = tf.cast(tf.math.ceil(W / self.n),dtype=tf.int32)
s = tf.cast(tf.math.floor(W / self.n),dtype=tf.int32)
b = tf.cast(f!=s,dtype=tf.int32)
r = W - f - s * (self.n - 1)
x1 = s * tf.ones(self.n-1,dtype=tf.int32)
x2 = (1 - tf.range(r*2) % 2) * b
x2 = tf.pad(x2,self.n-r*2-1]])
x3 = tf.concat([[f],axis=0)
v = tf.repeat(v,repeats=x3,axis=1)
output = tf.add(v,h)
return output
def build_model(n=7):
v = Input(shape=(256,3))
h = Input(shape=(n,256,3))
outputs = StripPoolingCombine()(v,h)
model = Model(inputs=[v,h],outputs=outputs)
return model
tf.keras.backend.clear_session()
with strategy.scope():
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4,beta_1=0.9,beta_2=0.999)
model = build_model()
model.compile(optimizer=optimizer,loss='mean_squared_error')
rng_1 = tf.random.uniform([1,7,3])
rng_2 = tf.random.uniform([1,3])
model.predict([rng_1,rng_2])
解决方法
使用tf.gather
:
def call(self,v,h,training=False):
def out(A,H,axis):
r = tf.range(H)
inds = tf.floor(self.n * r / H)
inds = tf.cast(inds,tf.int32)
return tf.gather(A,inds,axis=axis)
H,W = tf.shape(v)[1],tf.shape(h)[2]
v = out(v,W,2)
h = out(h,1)
output = tf.add(v,h)
return output