微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

tensorflow 常用函数整理

tensorflow常用函数整理


记录自己使用过的tensorflow常用函数,以便以后查询

1 tf.reshape

tf.reshape(
    tensor,
    shape,
    name=None
)

将给定tensor的维度转换为shape。

arr = tf.Variable([[1,2], [3,4], [5,6]])
arr = tf.reshape(arr, [2, 3])
 [[1 2 3]
 [4 5 6]]

当缺失的维度为-1时,会根据给定的维度自动计算缺失的维度; 但是缺失的维度只能有一个

arr = tf.reshape(arr, [3, -1])
[[1 2]
 [3 4]
 [5 6]]

2 tf.concat()

tf.concat(
    values,
    axis,
    name='concat'
)

将输入的张量数据沿着axis维度连接,如果输入数据的维度分别为(2,3), (2,3), axis=0时将第0维的2和2加起来,第1维的两个3不变,连接起来的tensor的shape为(4, 3), 同理axis=1时连接起来的shape为(2,6)。(python中维度的索引从0开始计算,正axis取值范围为[0, rank(values) ,这里也就是[0, 2))

t1 = tf.Variable([[1, 2, 3], [4, 5, 6]])
t2 = tf.Variable([[7, 8, 9], [10, 11, 12]])
t3 = tf.concat([t1, t2], 0)
[[ 1  2  3]
 [ 4  5  6]
 [ 7  8  9]
 [10 11 12]]
t4 = tf.concat([t1, t2], 1)
[[ 1  2  3  7  8  9]
 [ 4  5  6 10 11 12]]

在Python中,axis可以为负值,解释为从rank的末尾开始计数,即
axis + rank(values)

t5 = tf.concat([t1, t2], -1)
[[ 1  2  3  7  8  9]
 [ 4  5  6 10 11 12]]

3.tf.expand_dims

tf.expand_dims(
    input,
    axis=None,
    name=None,
    dim=None
)

这个操作可以用来给单个元素添加batch维度,例如,如果你有一张维度为[height, width, channels]的图片,可以用expand_dims(image, 0)使它成为batch为1的图片,shape将变为
[1, height, width, channels]。

t1 = tf.Variable([[1, 2, 3], [4, 5, 6]])
t_expand = tf.expand_dims(t1, 0)
[[[1 2 3]
  [4 5 6]]]
shape:(1, 2, 3)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 [email protected] 举报,一经查实,本站将立刻删除。

相关推荐