TensorFlow 维度变换

本文最后更新于：2022年12月4日

Reference


reshape

import tensorflow as tf  # the snippet uses `tf`; without this import it raises NameError

# Build the 1-D tensor [0, 1, ..., 59] and view it as a (3, 4, 5) tensor.
arr = tf.range(60)
t = tf.reshape(arr, [3, 4, 5])
# Merge the last two axes explicitly: (3, 4, 5) -> (3, 20).
t1 = tf.reshape(t, [3, 4 * 5])
# Same result as t1: a -1 entry lets TensorFlow infer that axis (60 / 3 = 20).
t2 = tf.reshape(t, [3, -1])
# A lone -1 flattens everything into a single axis: shape (60,).
t3 = tf.reshape(t, [-1])
# Expected output below is the eager-tensor repr printed by TensorFlow 2.x.
print(t)  # tf.Tensor(
# [[[ 0 1 2 3 4]
# [ 5 6 7 8 9]
# [10 11 12 13 14]
# [15 16 17 18 19]]
#
# [[20 21 22 23 24]
# [25 26 27 28 29]
# [30 31 32 33 34]
# [35 36 37 38 39]]
#
# [[40 41 42 43 44]
# [45 46 47 48 49]
# [50 51 52 53 54]
# [55 56 57 58 59]]], shape=(3, 4, 5), dtype=int32)
print(t1)  # tf.Tensor(
# [[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39]
# [40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59]], shape=(3, 20), dtype=int32)
print(t2)  # tf.Tensor(
# [[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39]
# [40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59]], shape=(3, 20), dtype=int32)
print(t3)  # tf.Tensor(
# [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
# 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
# 48 49 50 51 52 53 54 55 56 57 58 59], shape=(60,), dtype=int32)

transpose

import tensorflow as tf  # the snippet uses `tf`; without this import it raises NameError

arr = tf.range(60)
t1 = tf.reshape(arr, [3, -1])    # shape (3, 20)
t2 = tf.reshape(arr, [3, 4, 5])  # shape (3, 4, 5)
# With no `perm`, transpose reverses the axis order: (3, 20) -> (20, 3).
t11 = tf.transpose(t1)
# The same default applies to rank 3: (3, 4, 5) -> (5, 4, 3),
# i.e. t21[k][j][i] == t2[i][j][k].
t21 = tf.transpose(t2)
# perm=[0, 2, 1] keeps axis 0 and swaps the last two axes: (3, 4, 5) -> (3, 5, 4).
t22 = tf.transpose(t2, perm=[0, 2, 1])
# Expected output below is the eager-tensor repr printed by TensorFlow 2.x.
print(t1)  # tf.Tensor(
# [[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
# [20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39]
# [40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59]], shape=(3, 20), dtype=int32)
print(t11)  # tf.Tensor(
# [[ 0 20 40]
# [ 1 21 41]
# [ 2 22 42]
# [ 3 23 43]
# [ 4 24 44]
# [ 5 25 45]
# [ 6 26 46]
# [ 7 27 47]
# [ 8 28 48]
# [ 9 29 49]
# [10 30 50]
# [11 31 51]
# [12 32 52]
# [13 33 53]
# [14 34 54]
# [15 35 55]
# [16 36 56]
# [17 37 57]
# [18 38 58]
# [19 39 59]], shape=(20, 3), dtype=int32)
print(t2)  # tf.Tensor(
# [[[ 0 1 2 3 4]
# [ 5 6 7 8 9]
# [10 11 12 13 14]
# [15 16 17 18 19]]
#
# [[20 21 22 23 24]
# [25 26 27 28 29]
# [30 31 32 33 34]
# [35 36 37 38 39]]
#
# [[40 41 42 43 44]
# [45 46 47 48 49]
# [50 51 52 53 54]
# [55 56 57 58 59]]], shape=(3, 4, 5), dtype=int32)
print(t21)  # tf.Tensor(
# [[[ 0 20 40]
# [ 5 25 45]
# [10 30 50]
# [15 35 55]]
#
# [[ 1 21 41]
# [ 6 26 46]
# [11 31 51]
# [16 36 56]]
#
# [[ 2 22 42]
# [ 7 27 47]
# [12 32 52]
# [17 37 57]]
#
# [[ 3 23 43]
# [ 8 28 48]
# [13 33 53]
# [18 38 58]]
#
# [[ 4 24 44]
# [ 9 29 49]
# [14 34 54]
# [19 39 59]]], shape=(5, 4, 3), dtype=int32)
print(t22)  # tf.Tensor(
# [[[ 0 5 10 15]
# [ 1 6 11 16]
# [ 2 7 12 17]
# [ 3 8 13 18]
# [ 4 9 14 19]]
#
# [[20 25 30 35]
# [21 26 31 36]
# [22 27 32 37]
# [23 28 33 38]
# [24 29 34 39]]
#
# [[40 45 50 55]
# [41 46 51 56]
# [42 47 52 57]
# [43 48 53 58]
# [44 49 54 59]]], shape=(3, 5, 4), dtype=int32)

expand_dims

import tensorflow as tf  # the snippet uses `tf`; without this import it raises NameError

arr = tf.range(60)
t = tf.reshape(arr, [3, 4, 5])  # shape (3, 4, 5)
# Insert a new axis of size 1 at position 0: (3, 4, 5) -> (1, 3, 4, 5).
# The values are unchanged; only one more level of nesting appears.
t1 = tf.expand_dims(t, axis=0)
# Expected output below is the eager-tensor repr printed by TensorFlow 2.x.
print(t)  # tf.Tensor(
# [[[ 0 1 2 3 4]
# [ 5 6 7 8 9]
# [10 11 12 13 14]
# [15 16 17 18 19]]
#
# [[20 21 22 23 24]
# [25 26 27 28 29]
# [30 31 32 33 34]
# [35 36 37 38 39]]
#
# [[40 41 42 43 44]
# [45 46 47 48 49]
# [50 51 52 53 54]
# [55 56 57 58 59]]], shape=(3, 4, 5), dtype=int32)
print(t1)  # tf.Tensor(
# [[[[ 0 1 2 3 4]
# [ 5 6 7 8 9]
# [10 11 12 13 14]
# [15 16 17 18 19]]
#
# [[20 21 22 23 24]
# [25 26 27 28 29]
# [30 31 32 33 34]
# [35 36 37 38 39]]
#
# [[40 41 42 43 44]
# [45 46 47 48 49]
# [50 51 52 53 54]
# [55 56 57 58 59]]]], shape=(1, 3, 4, 5), dtype=int32)

squeeze

import tensorflow as tf  # the snippet uses `tf`; without this import it raises NameError

arr = tf.range(60)
t = tf.reshape(arr, [1, 3, 4, 5])  # shape (1, 3, 4, 5); axis 0 has size 1
# Remove the size-1 axis at position 0: (1, 3, 4, 5) -> (3, 4, 5).
# This is the inverse of tf.expand_dims(..., axis=0).
t1 = tf.squeeze(t, axis=0)
# Expected output below is the eager-tensor repr printed by TensorFlow 2.x.
print(t)  # tf.Tensor(
# [[[[ 0 1 2 3 4]
# [ 5 6 7 8 9]
# [10 11 12 13 14]
# [15 16 17 18 19]]
#
# [[20 21 22 23 24]
# [25 26 27 28 29]
# [30 31 32 33 34]
# [35 36 37 38 39]]
#
# [[40 41 42 43 44]
# [45 46 47 48 49]
# [50 51 52 53 54]
# [55 56 57 58 59]]]], shape=(1, 3, 4, 5), dtype=int32)
print(t1)  # tf.Tensor(
# [[[ 0 1 2 3 4]
# [ 5 6 7 8 9]
# [10 11 12 13 14]
# [15 16 17 18 19]]
#
# [[20 21 22 23 24]
# [25 26 27 28 29]
# [30 31 32 33 34]
# [35 36 37 38 39]]
#
# [[40 41 42 43 44]
# [45 46 47 48 49]
# [50 51 52 53 54]
# [55 56 57 58 59]]], shape=(3, 4, 5), dtype=int32)


TensorFlow 维度变换
https://weichao.io/cf5f694f697b/
作者
魏超
发布于
2019年5月13日
更新于
2022年12月4日
许可协议