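  • This document is not yet finished
  • Mathematical basis of curve fitting with polynomials and least squares
  • Program
  • Explanation
    • Training and test data
    • Curve fitting with the original training data
    • Loss behavior under overfitting
    • Data augmentation
    • Curve fitting with the augmented data
    • Loss behavior after data augmentation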



A polynomial of degree $n$ has the form

$$y = a_0 + a_1x^1 + a_2x^2 + \dots + a_nx^n$$

The goal of fitting is to use the known points to solve for all the coefficients $a_i$.
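To make the coefficient order concrete ($a_0$ first, as in the formula above), here is a minimal sketch, assuming NumPy, that evaluates the polynomial directly and with `numpy.polynomial.polynomial.polyval`, which also takes the coefficients in increasing-degree order. `eval_poly` is a hypothetical helper used only for illustration.

```python
import numpy as np
from numpy.polynomial import polynomial as P


def eval_poly(coeffs, x):
    """Hypothetical helper: y = a_0 + a_1*x + ... + a_n*x**n with coeffs = [a_0, ..., a_n]."""
    x = np.asarray(x, dtype=float)
    return sum(a * x**i for i, a in enumerate(coeffs))


coeffs = [1.0, 2.0, 3.0]       # a_0 = 1, a_1 = 2, a_2 = 3, i.e. y = 1 + 2x + 3x^2
xs = np.array([0.0, 1.0, 2.0])
print(eval_poly(coeffs, xs))   # [ 1.  6. 17.]
print(P.polyval(xs, coeffs))   # same values: polyval also expects a_0 first
```

Note that the older `np.polyval` expects the opposite order (highest degree first), so mixing the two APIs is an easy way to get wrong results.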

Suppose we have $m$ points to fit, with coordinates $(x_1, y_1), (x_2, y_2), \dots, (x_m, y_m)$. Substituting each point into the polynomial gives a system of $m$ linear equations in the unknowns $a_0, \dots, a_n$:

$$
\begin{aligned}
a_0 + a_1x_1 + a_2x_1^2 + \dots + a_nx_1^n &= y_1 \\
a_0 + a_1x_2 + a_2x_2^2 + \dots + a_nx_2^n &= y_2 \\
&\;\;\vdots \\
a_0 + a_1x_m + a_2x_m^2 + \dots + a_nx_m^n &= y_m
\end{aligned}
$$
Written in matrix form:

$$
\begin{bmatrix}
1 & x_1^1 & x_1^2 & \cdots & x_1^n \\
1 & x_2^1 & x_2^2 & \cdots & x_2^n \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
1 & x_m^1 & x_m^2 & \cdots & x_m^n
\end{bmatrix}
\begin{bmatrix} a_0 \\ a_1 \\ \vdots \\ a_n \end{bmatrix}
=
\begin{bmatrix} y_1 \\ y_2 \\ \vdots \\ y_m \end{bmatrix}
$$

where

$$
X = \begin{bmatrix}
1 & x_1^1 & x_1^2 & \cdots & x_1^n \\
1 & x_2^1 & x_2^2 & \cdots & x_2^n \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
1 & x_m^1 & x_m^2 & \cdots & x_m^n
\end{bmatrix}
$$

$$
A = \begin{bmatrix} a_0 \\ a_1 \\ \vdots \\ a_n \end{bmatrix}
$$

$$
Y = \begin{bmatrix} y_1 \\ y_2 \\ \vdots \\ y_m \end{bmatrix}
$$
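$X$ is a Vandermonde matrix. As a minimal sketch, assuming NumPy, the hypothetical helper `design_matrix` below builds it column by column, the same way the program's `_generate_x_matrix` does, and checks the result against `np.vander` with `increasing=True`, which puts $x^0$ in the first column to match the definition of $X$.

```python
import numpy as np


def design_matrix(x, n):
    """Hypothetical helper: m x (n+1) matrix whose row i is [1, x_i, x_i^2, ..., x_i^n]."""
    x = np.asarray(x, dtype=float)
    X = np.ones((len(x), n + 1))   # column 0 is x^0 = 1
    for j in range(1, n + 1):
        X[:, j] = x**j             # column j is x^j
    return X


x = [-1.0, 0.5, 2.0]
X = design_matrix(x, 2)
print(X)
# increasing=True gives the same column order (x^0 first); the default is the reverse
print(np.allclose(X, np.vander(x, N=3, increasing=True)))  # True
```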
With these definitions the system reads $XA = Y$. Left-multiplying both sides by $X^T$ gives the normal equations, and if $X^TX$ is invertible, left-multiplying by $(X^TX)^{-1}$ isolates $A$:

$$
\begin{aligned}
XA &= Y \\
X^TXA &= X^TY \\
(X^TX)^{-1}X^TXA &= (X^TX)^{-1}X^TY \\
A &= (X^TX)^{-1}X^TY
\end{aligned}
$$
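A numerical check of the last line, as a sketch assuming NumPy. The data are synthetic: the true coefficients 1, -2, 0.5 and the noise level 0.3 are made up for illustration. Instead of forming $(X^TX)^{-1}$ explicitly, the sketch solves $X^TXA = X^TY$ with `np.linalg.solve`, which yields the same $A$ and is the more common numerical choice; the result is compared with NumPy's least-squares solver `np.linalg.lstsq`.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(-3.0, 3.0, 20)                               # m = 20 points
y = 1.0 - 2.0 * x + 0.5 * x**2 + rng.normal(0.0, 0.3, x.size)

n = 2                                                        # fit a degree-2 polynomial
X = np.vander(x, N=n + 1, increasing=True)                   # columns 1, x, x^2

# normal equations X^T X A = X^T Y, solved without an explicit inverse
A_normal = np.linalg.solve(X.T @ X, X.T @ y)

# reference result from NumPy's least-squares solver
A_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)

print(A_normal)                        # estimates of [a_0, a_1, a_2]
print(np.allclose(A_normal, A_lstsq))  # True
```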
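Note that the polynomial has $n+1$ unknown coefficients $a_0, \dots, a_n$. Assuming the $x_i$ are distinct: when $m > n+1$, the result is a solution in the least-squares sense, and in general the curve does not pass through every point; when $m = n+1$, $X$ is square and invertible, the result is the exact solution, and the fitted curve passes exactly through all the known points; when $m < n+1$, the result can no longer be called a least-squares solution. In that case the polynomial has more than enough freedom to pass through all the known points (infinitely many such polynomials exist), but $X^TX$ has rank at most $m < n+1$ and is therefore singular, so the linear-algebra method above cannot produce such a solution. The usual remedy is the Moore-Penrose pseudoinverse, which selects the exact solution with the smallest norm $\|A\|$, known as the minimum-norm solution.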
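A small numerical illustration of these cases, as a sketch assuming NumPy (the five points are made up). With $m = 5$, degree 4 gives $m = n+1$ and an exact fit; degree 6 gives $m < n+1$, where $X^TX$ is rank-deficient and `np.linalg.pinv` returns the minimum-norm solution.

```python
import numpy as np

# five made-up points with distinct x coordinates (m = 5)
x = np.array([-2.0, -1.0, 0.5, 1.0, 3.0])
y = np.array([1.0, -1.0, 2.0, 0.5, 1.5])

# m = n + 1: degree n = 4 -> X is square and invertible, so
# (X^T X)^{-1} X^T = X^{-1} and the curve passes through every point
X_sq = np.vander(x, N=5, increasing=True)
A_exact = np.linalg.solve(X_sq, y)
print(np.allclose(X_sq @ A_exact, y))  # True

# m < n + 1: degree n = 6 -> 7 unknowns but only 5 equations
X_wide = np.vander(x, N=7, increasing=True)
# rank(X^T X) = rank(X) = 5 < 7, so the 7 x 7 matrix X^T X is singular
print(np.linalg.matrix_rank(X_wide))  # 5

# the pseudoinverse picks the exact solution with the smallest norm
A_min_norm = np.linalg.pinv(X_wide) @ y
print(np.allclose(X_wide @ A_min_norm, y))  # True

# adding a null-space direction of X gives another exact solution with a larger norm
null_vec = np.linalg.svd(X_wide)[2][-1]
A_other = A_min_norm + null_vec
print(np.allclose(X_wide @ A_other, y))                      # True
print(np.linalg.norm(A_other) > np.linalg.norm(A_min_norm))  # True
```

The last two lines show why the pseudoinverse result is "minimum-norm": every other exact solution differs from it by a null-space vector, which can only increase the norm.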


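In the code, the coefficients $a_i$ are called the model (`model`), the known points to be fitted are called the training set (train data), and the points used for testing are called the test set (test data).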


The code constructs `x_train`, `y_train`, `x_test` and `y_test` for the experiments, and draws several figures from these data to aid understanding.

```python
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt

DEGREE = [2, 4, 6, 8]              # degrees fitted to the original training data
DEGREE_AUG = [2, 4, 6, 8, 10, 12]  # degrees fitted to the augmented training data


class PolynomialModel(object):
    def __init__(self, degree):
        self.model = None  # coefficient column vector [a_0, a_1, ..., a_n]^T
        self.degree = degree

    def _generate_x_matrix(self, X):
        """
        generate X matrix

        Parameters:
        -----------
        X: list or 1-dim numpy array
            x coordinates of points

        Returns:
        --------
        x_matrix: 2-dim numpy array
            X matrix
        """
        # column i holds x**i, so column 0 stays all ones
        x_matrix = np.ones([len(X), self.degree + 1])
        for i in range(1, self.degree + 1):
            x_matrix[:, i] = np.power(X, i)
        return x_matrix

    def fit(self, X, Y):
        """
        compute model by least square fitting

        Parameters:
        -----------
        X, Y: list or 1-dim numpy array
            coordinates of points (x, y)
        """
        Y = np.array(Y)
        x_mat = self._generate_x_matrix(X)
        Y = np.reshape(Y, [Y.shape[0], 1])
        # A = (X^T X)^{-1} X^T Y. When degree + 1 > number of points, X^T X is
        # singular (and it is badly conditioned for high degrees in general),
        # so np.linalg.inv may fail or return unreliable values.
        self.model = np.linalg.inv(x_mat.T @ x_mat) @ x_mat.T @ Y

    def predict(self, X):
        """
        predict output of input X, using the computed model

        Parameters:
        -----------
        X: list or 1-dim numpy array
            x coordinates of points
        """
        x_mat = self._generate_x_matrix(X)
        Y = x_mat @ self.model
        return Y


def compute_mse(y1, y2):
    """
    compute mean square error

    Parameters:
    -----------
    y1, y2: list or numpy array, with same shape

    Returns:
    --------
    mse: mean square error
    """
    y1 = np.array(y1)
    y2 = np.array(y2)
    # compare element counts with .size (np.product is deprecated and removed in NumPy 2.0)
    assert y1.size == y2.size
    y2 = np.reshape(y2, y1.shape)
    mse = np.mean(np.power(y1 - y2, 2))
    return mse


def gauss_augmentation(X, Y, aug_num, std):
    """
    augment data using gauss distributed random number

    Parameters:
    -----------
    X, Y: list or 1-dim numpy array
        coordinates of points (x, y)
    aug_num: augmentation number for each point
    std: std of random number

    Returns:
    --------
    x_aug, y_aug: augmented points
    """
    assert len(X) == len(Y)
    length = len(X)
    x_aug = []
    y_aug = []
    # each original point yields aug_num noisy copies, with noise on both x and y;
    # the original points themselves are not part of the output
    for i in range(length):
        x = X[i] + np.random.randn(aug_num) * std
        y = Y[i] + np.random.randn(aug_num) * std
        x_aug.append(x)
        y_aug.append(y)
    x_aug = np.array(x_aug).reshape([-1])
    y_aug = np.array(y_aug).reshape([-1])
    return x_aug, y_aug


def get_x_range(x, increment):
    """
    get evenly distributed coordinates by input x, the low limit and high limit
    of coordinates are min(x) and max(x), separately.

    Parameters:
    -----------
    x: list or 1-dim numpy array
        x coordinates of points
    increment: float number
        interval of adjacent coordinates

    Returns:
    --------
    x_range: evenly distributed coordinates
    """
    x = np.array(x)
    xmin = np.min(x)
    xmax = np.max(x) + increment  # np.arange excludes the stop value, so extend by one step
    x_range = np.arange(xmin, xmax, increment)  # renamed to avoid shadowing built-in range
    return x_range


def plot_original_dataset(x_train, y_train, x_test, y_test):
    plt.figure(1, figsize=(9, 6))
    plt.plot(x_train, y_train,
             color='red', marker='o', markersize=6,
             linewidth=0, label='train_data')
    plt.plot(x_test, y_test,
             color='green', marker='o', markersize=6,
             linewidth=0, label='test_data')
    plt.legend()
    plt.savefig('original_dataset.png')


def plot_augmented_dataset(x_train, y_train, x_test, y_test, x_aug, y_aug):
    plt.figure(2, figsize=(9, 6))
    plt.plot(x_train, y_train,
             color='red', marker='o', markersize=6,
             linewidth=0, label='train_data')
    plt.plot(x_test, y_test,
             color='green', marker='o', markersize=6,
             linewidth=0, label='test_data')
    plt.plot(x_aug, y_aug,
             color='blue', marker='o', markersize=3,
             linewidth=0, label='augmented_data')
    plt.legend()
    plt.savefig('augmented_dataset.png')


def plot_no_augmentation_model(x_train, y_train, x_test, y_test):
    # fit each degree in DEGREE to the original training points
    x_range = get_x_range(x_train, 0.1)
    plt.figure(3, figsize=(15, 9))
    for i, degree in enumerate(DEGREE):
        plt.subplot(2, 2, i + 1)
        poly_model = PolynomialModel(degree)
        poly_model.fit(x_train, y_train)
        y_predict = poly_model.predict(x_range)
        plt.plot(x_train, y_train,
                 color='red', marker='o', markersize=6,
                 linewidth=0, label='train_data')
        plt.plot(x_test, y_test,
                 color='green', marker='o', markersize=6,
                 linewidth=0, label='test_data')
        plt.plot(x_range, y_predict, color='black', label='fitted curve')
        plt.title('degree = %d' % degree)
        plt.legend()
    plt.savefig('no_augmentation_model.png')


def plot_loss_for_overfitting(x_train, y_train, x_test, y_test):
    # train and test MSE for every integer degree from min(DEGREE) to max(DEGREE)
    plt.figure(4, figsize=(9, 6))
    degrees = get_x_range(DEGREE, 1)
    train_loss = []
    test_loss = []
    for degree in degrees:
        poly_model = PolynomialModel(degree)
        poly_model.fit(x_train, y_train)
        train_predict = poly_model.predict(x_train)
        test_predict = poly_model.predict(x_test)
        train_mse = compute_mse(y_train, train_predict)
        test_mse = compute_mse(y_test, test_predict)
        train_loss.append(train_mse)
        test_loss.append(test_mse)
    plt.plot(degrees, train_loss, color='red', label='train_loss')
    plt.plot(degrees, test_loss, color='green', label='test_loss')
    plt.xlabel('degree')
    plt.ylabel('mean square error')
    plt.legend()
    plt.savefig('loss_for_overfitting.png')


def plot_augmentation_model(x_train, y_train, x_test, y_test, x_aug, y_aug):
    # fit each degree in DEGREE_AUG to the augmented points;
    # x_test and y_test are now parameters instead of implicit globals
    x_range = get_x_range(x_train, 0.1)
    plt.figure(5, figsize=(15, 15))
    for i, degree in enumerate(DEGREE_AUG):
        plt.subplot(3, 2, i + 1)
        poly_model = PolynomialModel(degree)
        poly_model.fit(x_aug, y_aug)
        y_predict = poly_model.predict(x_range)
        plt.plot(x_train, y_train,
                 color='red', marker='o', markersize=6,
                 linewidth=0, label='train_data')
        plt.plot(x_test, y_test,
                 color='green', marker='o', markersize=6,
                 linewidth=0, label='test_data')
        plt.plot(x_range, y_predict, color='black', label='fitted curve')
        plt.title('degree = %d' % degree)
        plt.legend()
    plt.savefig('augmentation_model.png')


def plot_loss_for_augmentation(x_train, y_train, x_test, y_test, x_aug, y_aug):
    # models are fitted on the augmented data, but the train loss is
    # still measured on the original training points
    plt.figure(6, figsize=(9, 6))
    degrees = get_x_range(DEGREE_AUG, 1)
    train_loss = []
    test_loss = []
    for degree in degrees:
        poly_model = PolynomialModel(degree)
        poly_model.fit(x_aug, y_aug)
        train_predict = poly_model.predict(x_train)
        test_predict = poly_model.predict(x_test)
        train_mse = compute_mse(y_train, train_predict)
        test_mse = compute_mse(y_test, test_predict)
        train_loss.append(train_mse)
        test_loss.append(test_mse)
    plt.plot(degrees, train_loss, color='red', label='train_loss')
    plt.plot(degrees, test_loss, color='green', label='test_loss')
    plt.xlabel('degree')
    plt.ylabel('mean square error')
    plt.legend()
    plt.savefig('loss_for_augmentation.png')


if __name__ == '__main__':
    np.random.seed(1337)
    x_train = [-3.0, -2.1, -0.9, 0.1, 1.2, 2.0, 3]
    y_train = [2.5, 1.2, 1.1, -2.9, -0.7, -3.2, 1.3]
    x_test = [-3.0, -2.7, -2.3, -2.0, -1.8, -1.6, -1.3, -1.0, -0.9, -0.6, -0.2,
              0.1, 0.4, 0.7, 1.0, 1.2, 1.5, 1.8, 2.0, 2.3, 2.5, 2.7, 3.0]
    y_test = [2.4, 2.1, 1.6, 1.1, 1.3, 1.0, 1.2, 1.0, 0.8, -0.2, -1.3, -2.3,
              -2.7, -2.3, -1.5, -1.2, -1.5, -2.9, -2.5, -1.3, -1.1, -0.4, 1.1]
    # 20 noisy copies of each training point, noise std 0.4
    x_aug, y_aug = gauss_augmentation(x_train, y_train, 20, 0.4)
    plot_original_dataset(x_train, y_train, x_test, y_test)
    plot_augmented_dataset(x_train, y_train, x_test, y_test, x_aug, y_aug)
    plot_no_augmentation_model(x_train, y_train, x_test, y_test)
    plot_loss_for_overfitting(x_train, y_train, x_test, y_test)
    plot_augmentation_model(x_train, y_train, x_test, y_test, x_aug, y_aug)
    plot_loss_for_augmentation(x_train, y_train, x_test, y_test, x_aug, y_aug)
```
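The training-loss trend in `plot_loss_for_overfitting` can be cross-checked without any plotting. The sketch below is an illustration, not part of the original program: it refits the same `x_train`/`y_train` with NumPy's `numpy.polynomial.polynomial.polyfit` (a different solver from the explicit inverse used in `fit`) and prints the train and test MSE per degree. It stops at degree 6 because, with 7 training points, degree 6 already has 7 coefficients ($m = n+1$).

```python
import numpy as np
from numpy.polynomial import polynomial as P

# data copied from the program's __main__ block
x_train = np.array([-3.0, -2.1, -0.9, 0.1, 1.2, 2.0, 3])
y_train = np.array([2.5, 1.2, 1.1, -2.9, -0.7, -3.2, 1.3])
x_test = np.array([-3.0, -2.7, -2.3, -2.0, -1.8, -1.6, -1.3, -1.0, -0.9, -0.6, -0.2,
                   0.1, 0.4, 0.7, 1.0, 1.2, 1.5, 1.8, 2.0, 2.3, 2.5, 2.7, 3.0])
y_test = np.array([2.4, 2.1, 1.6, 1.1, 1.3, 1.0, 1.2, 1.0, 0.8, -0.2, -1.3, -2.3,
                   -2.7, -2.3, -1.5, -1.2, -1.5, -2.9, -2.5, -1.3, -1.1, -0.4, 1.1])

train_mse = {}
test_mse = {}
for degree in range(2, 7):
    coeffs = P.polyfit(x_train, y_train, degree)  # least-squares [a_0, ..., a_degree]
    train_mse[degree] = np.mean((P.polyval(x_train, coeffs) - y_train) ** 2)
    test_mse[degree] = np.mean((P.polyval(x_test, coeffs) - y_test) ** 2)
    print(f"degree={degree}  train_mse={train_mse[degree]:.4f}  test_mse={test_mse[degree]:.4f}")
```

Because each higher-degree model contains the lower-degree ones, the training MSE can only decrease as the degree grows, and at degree 6 it is numerically zero; the test MSE carries no such guarantee, which is exactly the gap that overfitting refers to.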
















