# TensorFlow broadcasting

发表于 2019-05-14 | 字数统计 674 字 | 阅读时长 4 分钟

## Reference

- 深度学习与 TensorFlow 2 入门实战
- TensorFlow-2.x-Tutorials

## broadcasting

```python
arr = tf.range(60)
arr1 = tf.reshape(arr, [3, 4, 5])
t = tf.cast(arr1, dtype=tf.float32)
t1 = tf.ones(5)
t2 = tf.ones([4, 1])
t3 = tf.ones([3, 4, 1])
t11 = t + t1
t21 = t + t2
t31 = t + t3
print(t)
# tf.Tensor(
# [[[ 0.  1.  2.  3.  4.]
#   [ 5.  6.  7.  8.  9.]
#   [10. 11. 12. 13. 14.]
#   [15. 16. 17. 18. 19.]]
#
#  [[20. 21. 22. 23. 24.]
#   [25. 26. 27. 28. 29.]
#   [30. 31. 32. 33. 34.]
#   [35. 36. 37. 38. 39.]]
#
#  [[40. 41. 42. 43. 44.]
#   [45. 46. 47. 48. 49.]
#   [50. 51. 52. 53. 54.]
#   [55. 56. 57. 58. 59.]]], shape=(3, 4, 5), dtype=float32)
print(t1)
# tf.Tensor([1. 1. 1. 1. 1.], shape=(5,), dtype=float32)
print(t2)
# tf.Tensor(
# [[1.]
#  [1.]
#  [1.]
#  [1.]], shape=(4, 1), dtype=float32)
print(t3)
# tf.Tensor(
# [[[1.]
#   [1.]
#   [1.]
#   [1.]]
#
#  [[1.]
#   [1.]
#   [1.]
#   [1.]]
#
#  [[1.]
#   [1.]
#   [1.]
#   [1.]]], shape=(3, 4, 1), dtype=float32)
print(t11)
# tf.Tensor(
# [[[ 1.  2.  3.  4.  5.]
#   [ 6.  7.  8.  9. 10.]
#   [11. 12. 13. 14. 15.]
#   [16. 17. 18. 19. 20.]]
#
#  [[21. 22. 23. 24. 25.]
#   [26. 27. 28. 29. 30.]
#   [31. 32. 33. 34. 35.]
#   [36. 37. 38. 39. 40.]]
#
#  [[41. 42. 43. 44. 45.]
#   [46. 47. 48. 49. 50.]
#   [51. 52. 53. 54. 55.]
#   [56. 57. 58. 59. 60.]]], shape=(3, 4, 5), dtype=float32)
print(t21)
# tf.Tensor(
# [[[ 1.  2.  3.  4.  5.]
#   [ 6.  7.  8.  9. 10.]
#   [11. 12. 13. 14. 15.]
#   [16. 17. 18. 19. 20.]]
#
#  [[21. 22. 23. 24. 25.]
#   [26. 27. 28. 29. 30.]
#   [31. 32. 33. 34. 35.]
#   [36. 37. 38. 39. 40.]]
#
#  [[41. 42. 43. 44. 45.]
#   [46. 47. 48. 49. 50.]
#   [51. 52. 53. 54. 55.]
#   [56. 57. 58. 59. 60.]]], shape=(3, 4, 5), dtype=float32)
print(t31)
# tf.Tensor(
# [[[ 1.  2.  3.  4.  5.]
#   [ 6.  7.  8.  9. 10.]
#   [11. 12. 13. 14. 15.]
#   [16. 17. 18. 19. 20.]]
#
#  [[21. 22. 23. 24. 25.]
#   [26. 27. 28. 29. 30.]
#   [31. 32. 33. 34. 35.]
#   [36. 37. 38. 39. 40.]]
#
#  [[41. 42. 43. 44. 45.]
#   [46. 47. 48. 49. 50.]
#   [51. 52. 53. 54. 55.]
#   [56. 57. 58. 59. 60.]]], shape=(3, 4, 5), dtype=float32)
t4 = tf.ones([3, 4, 2])
t41 = t + t4
# tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [3,4,5] vs. [3,4,2] [Op:Add] name: add/
```

## broadcast_to

```python
arr = tf.range(12)
arr1 = tf.reshape(arr, [3, 4, 1])
t = tf.cast(arr1, dtype=tf.float32)
t1 = tf.broadcast_to(t, [3, 4, 5])
print(t)
# tf.Tensor(
# [[[ 0.]
#   [ 1.]
#   [ 2.]
#   [ 3.]]
#
#  [[ 4.]
#   [ 5.]
#   [ 6.]
#   [ 7.]]
#
#  [[ 8.]
#   [ 9.]
#   [10.]
#   [11.]]], shape=(3, 4, 1), dtype=float32)
print(t1)
# tf.Tensor(
# [[[ 0.  0.  0.  0.  0.]
#   [ 1.  1.  1.  1.  1.]
#   [ 2.  2.  2.  2.  2.]
#   [ 3.  3.  3.  3.  3.]]
#
#  [[ 4.  4.  4.  4.  4.]
#   [ 5.  5.  5.  5.  5.]
#   [ 6.  6.  6.  6.  6.]
#   [ 7.  7.  7.  7.  7.]]
#
#  [[ 8.  8.  8.  8.  8.]
#   [ 9.  9.  9.  9.  9.]
#   [10. 10. 10. 10. 10.]
#   [11. 11. 11. 11. 11.]]], shape=(3, 4, 5), dtype=float32)
```

## tile

```python
arr = tf.range(12)
a = tf.reshape(arr, [3, 4])
a1 = tf.broadcast_to(a, [2, 3, 4])
a2 = tf.expand_dims(a, axis=0)
a21 = tf.tile(a2, [2, 1, 1])
print(a)
# tf.Tensor(
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]], shape=(3, 4), dtype=int32)
print(a1)
# tf.Tensor(
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
#
#  [[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]], shape=(2, 3, 4), dtype=int32)
print(a2)
# tf.Tensor(
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]], shape=(1, 3, 4), dtype=int32)
print(a21)
# tf.Tensor(
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
#
#  [[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]], shape=(2, 3, 4), dtype=int32)
```

本文标题:TensorFlow broadcasting

文章作者:魏超

发布时间:2019年05月14日 - 09:05

最后更新:2019年06月04日 - 15:06

原始链接:http://www.weichao.io/2019/05/14/TensorFlow-broadcasting/

许可协议: 署名-非商业性使用-禁止演绎 4.0 国际 转载请保留原文链接及作者。

---------------------本文结束---------------------