我想强制转换或定义一个符号列表变量,以便可以使用符号索引来索引Node对象列表(见下文)。我对此感兴趣,以便在图算法中定义一个成本函数,在该算法中对节点本身进行采样。考虑到索引是张量,我不确定如何将节点对象的python列表转换为符号对象以执行此索引。我可以索引到张量字符串列表,但不能索引到自定义对象列表。我收到TypeError:无法将类型
import tensorflow as tf
import numpy as np
class Node:
def __init__(self,id=None,data=None):
self.id = id
self.data = data
self.left = None
self.right = None
v1 = Node(id='A')
v2 = Node(id='B')
v3 = Node(id='C')
node_list = [v1,v2,v3]
node_names = tf.Variable(['node_A','node_B','node_C'],name='VertexNames')
v_idx = tf.Variable([2,1],dtype=tf.int32)
indexed_node_names = tf.gather(node_names,v_idx)
# This works
indexed_node_names
# This however,returns an error:
indexed_nodes = tf.gather(node_list,v_idx)