TensorFlow (2): Univariate Linear Regression

- 日理万妓 2022-03-15 11:54

TensorFlow Univariate Linear Regression

  • Code
  • Results

Code

```python
# Written against the TensorFlow 1.x graph API (placeholders + Session).
# tensorflow.compat.v1 keeps that API available on TensorFlow 2.x;
# on TensorFlow 1.x, plain `import tensorflow as tf` also works.
import tensorflow.compat.v1 as tf
import numpy as np
import matplotlib.pyplot as plt

tf.disable_v2_behavior()

# generate data: y = 1.5 * x + 0.5 plus zero-mean uniform noise in roughly [-0.1, 0.1]
data_size = 100
x_data = np.random.rand(data_size) * 0.6 + 0.2   # x uniform in [0.2, 0.8)
noise = np.random.rand(data_size)
noise = (noise - noise.mean()) * 0.2
true_w = 1.5
true_b = 0.5
y_data = x_data * true_w + true_b + noise

# first 70% of the samples for training, remaining 30% for testing
train_test_boundary = int(0.7 * data_size)
x_train = x_data[:train_test_boundary]
y_train = y_data[:train_test_boundary]
x_test = x_data[train_test_boundary:]
y_test = y_data[train_test_boundary:]

# my model: y = w * x + b, trained by minimizing the mean squared error
train_loss_history = []
test_loss_history = []
nb_epoch = 100
my_w = 0
my_b = 0
learning_rate = 0.1
with tf.Graph().as_default():
    x = tf.placeholder(dtype=tf.float32, shape=[None])
    y_ = tf.placeholder(dtype=tf.float32, shape=[None])   # ground-truth labels
    w = tf.Variable(0, dtype=tf.float32)
    b = tf.Variable(0, dtype=tf.float32)
    y = x * w + b                                         # predictions
    loss = tf.reduce_mean((y - y_) ** 2)                  # mean squared error
    train = tf.train.AdamOptimizer(learning_rate).minimize(loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(nb_epoch):
            # one full-batch optimization step on the training set
            sess.run(train, feed_dict={x: x_train, y_: y_train})
            # record the loss on both splits after the step
            l = sess.run(loss, feed_dict={x: x_train, y_: y_train})
            train_loss_history.append(l)
            l = sess.run(loss, feed_dict={x: x_test, y_: y_test})
            test_loss_history.append(l)
        # read the learned parameters while the session is still open
        my_w, my_b = sess.run([w, b])

# display
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.xlim([0, 1])
plt.ylim([0, 3])
line_x = np.linspace(0, 1, 100)
plt.plot(line_x, my_w * line_x + my_b, 'r--')             # learned line
plt.scatter(x_train, y_train, s=1.0, c='b', marker='o')   # training points
plt.title('dataset & hypothesis function')
plt.subplot(1, 2, 2)
plt.title('training & testing loss')
plt.plot(train_loss_history, 'b--')
plt.plot(test_loss_history, 'r--')
plt.legend(['train', 'test'])
plt.show()
```
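The loss in the code, `tf.reduce_mean((y - y_) ** 2)`, is the mean squared error over the fed samples. Below it is written out for n samples, where \hat{y}_i corresponds to the prediction `y` and y_i to the label fed into `y_`. The block also gives the partial derivatives that an optimizer such as `AdamOptimizer` consumes. Adam rescales these gradients per parameter using running moment estimates instead of applying them directly.

$$
\hat{y}_i = w x_i + b, \qquad
L(w, b) = \frac{1}{n} \sum_{i=1}^{n} \left( \hat{y}_i - y_i \right)^2
$$

$$
\frac{\partial L}{\partial w} = \frac{2}{n} \sum_{i=1}^{n} \left( \hat{y}_i - y_i \right) x_i, \qquad
\frac{\partial L}{\partial b} = \frac{2}{n} \sum_{i=1}^{n} \left( \hat{y}_i - y_i \right)
$$

Setting both derivatives to zero gives the closed-form least-squares solution. If training has converged, the learned `my_w` and `my_b` should roughly approach it:

$$
w^{*} = \frac{\sum_{i=1}^{n} (x_i - \bar{x})(y_i - \bar{y})}{\sum_{i=1}^{n} (x_i - \bar{x})^2}, \qquad
b^{*} = \bar{y} - w^{*} \bar{x}
$$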

Results

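[Figure: the left panel, "dataset & hypothesis function", shows the training points and the learned line. The right panel, "training & testing loss", shows both losses per epoch.]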
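You can sanity-check the learned `my_w` and `my_b` without TensorFlow. The sketch below regenerates data the same way in pure Python. The helper names `make_data`, `least_squares` and `gradient_descent` are hypothetical, and the seed is an assumption. The sketch fits the same 70% training split twice: once with the closed-form least-squares formula and once with plain full-batch gradient descent on the same MSE loss. It uses plain gradient descent rather than the article's Adam optimizer, with an assumed learning rate of 0.5 and 2000 steps, because plain gradient descent needs many more iterations on this data.

```python
import random


def make_data(n=100, true_w=1.5, true_b=0.5, seed=0):
    """Mimic the article's generator: x uniform in [0.2, 0.8), centred uniform noise scaled by 0.2."""
    rng = random.Random(seed)  # seed is an assumption, for reproducibility
    xs = [rng.random() * 0.6 + 0.2 for _ in range(n)]
    raw = [rng.random() for _ in range(n)]
    mean = sum(raw) / n
    ys = [x * true_w + true_b + (r - mean) * 0.2 for x, r in zip(xs, raw)]
    return xs, ys


def least_squares(xs, ys):
    """Closed-form minimizer of the mean squared error for y = w * x + b."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    w = sxy / sxx
    return w, my - w * mx


def gradient_descent(xs, ys, lr=0.5, steps=2000):
    """Plain full-batch gradient descent on the same MSE loss (not Adam)."""
    n = len(xs)
    w = b = 0.0  # same zero initialization as the TensorFlow variables
    for _ in range(steps):
        residuals = [w * x + b - y for x, y in zip(xs, ys)]
        grad_w = 2.0 / n * sum(r * x for r, x in zip(residuals, xs))
        grad_b = 2.0 / n * sum(residuals)
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


if __name__ == "__main__":
    xs, ys = make_data()
    x_train, y_train = xs[:70], ys[:70]  # same 70/30 split as the article
    print("least squares:   ", least_squares(x_train, y_train))
    print("gradient descent:", gradient_descent(x_train, y_train))
```

Both fits should land close to the true parameters 1.5 and 0.5 and agree closely with each other. The TensorFlow run can be compared against these numbers.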
