Straight-Line Fitting, Quadratic Curve Fitting, Broken-Line Fitting and KNN Nearest Neighbours (with Code)

冷不防 2022-08-27

This came up in an engineering project: a set of data had to be fitted with the four kinds of regression above, and the resulting model was then used to evaluate new inputs; in other words, it is plain function/curve fitting.

RevPre below defines the data structure and the methods, and Util.cpp is a usage example; the type argument passed to Opt selects which model to fit. A short sketch of the calling sequence follows.
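As a quick preview (Util.cpp at the end of the article is the full example), a minimal sketch of the intended workflow looks like this; the wrapper name FitAndEvaluate, the input arrays and the query value 3.7 are placeholders for illustration, not part of the original code:

    #include "RevPre.h"

    // Minimal sketch of the calling sequence, assuming the RevPre.h from this article.
    double FitAndEvaluate( double *x, double *y, int n )
    {
        Model *m = CreateModel();          // allocate and zero-initialize the model
        SetOptData( m, x, y, n );          // copy the n (x, y) training pairs into the model
        Opt( m, CurveAt2 );                // fit the chosen model type, here the quadratic
        double yhat = Predict( m, 3.7 );   // evaluate the fitted model at a new input
        ReleaseModel( &m );                // free the copied data and the model itself
        return yhat;
    }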

1) Straight-line fitting: y = kx + b

2) Quadratic curve fitting: y = Ax^2 + Bx + C

These first two are standard least-squares fits and need little explanation; the normal-equation system they both reduce to is sketched below.
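For reference, CalculateParams in the code below assembles exactly this system and DirectLU solves it by LU decomposition with column pivoting. The symbols n, x_k, y_k, a_j and the degree d are notation for this sketch only (d = 1 gives the straight line, d = 2 the quadratic):

    % Minimizing E(a) = \sum_{k=1}^{n} ( a_0 + a_1 x_k + \dots + a_d x_k^d - y_k )^2
    % and setting \partial E / \partial a_i = 0 gives the normal equations
    \sum_{j=0}^{d} \Big( \sum_{k=1}^{n} x_k^{\,i+j} \Big) a_j
        \;=\; \sum_{k=1}^{n} y_k \, x_k^{\,i}, \qquad i = 0, 1, \dots, d.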

3) KNN

For KNN, the Opt stage does nothing except keep all the data pairs; at prediction time the K nearest points (with K clamped between MIN_KNN_K and MAX_KNN_K) are selected and a distance-weighted average of their y values is returned.
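To make the prediction step concrete, here is a simplified, self-contained sketch of the same idea. The function name KnnPredictSketch, the std::vector inputs and the small epsilon are illustrative choices; the actual Predict() below keeps a fixed-size neighbour "team" over the stored arrays instead of sorting:

    #include <math.h>
    #include <algorithm>
    #include <utility>
    #include <vector>

    // Distance-weighted KNN prediction (illustrative sketch, assumes 1 <= K <= px.size()).
    double KnnPredictSketch( const std::vector<double>& px, const std::vector<double>& py,
                             double x, int K )
    {
        // pair each training point with its distance to the query
        std::vector< std::pair<double,double> > d;                  // (distance, y)
        for ( size_t i = 0; i < px.size(); i ++ )
            d.push_back( std::make_pair( fabs( px[i] - x ), py[i] ) );
        // move the K smallest distances to the front
        std::partial_sort( d.begin(), d.begin() + K, d.end() );
        // inverse-distance weighted average of the K nearest y values
        double num = 0.0, den = 0.0;
        for ( int i = 0; i < K; i ++ )
        {
            double w = 1.0 / ( d[i].first + 1e-12 );                // epsilon guards an exact x match
            num += w * d[i].second;
            den += w;
        }
        return num / den;
    }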

4) Broken-line fitting:

KNN is introduced first because the broken-line model behaves much like it: it also stores a series of data pairs and, at prediction time, determines which segment the query falls in. The difference is that the stored point set is simplified instead of being kept in full.

First the data pairs are quick-sorted by x; a straight-line fit then decides whether the overall trend is increasing or decreasing. Points that violate that order are then eliminated one by one: whenever a "bump" or "dip" appears, the point farther from the fitted line is removed, until every remaining point is in order and the result is a monotone polyline. The elimination rule is sketched right below.
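Written as a small helper, one elimination step looks like this: when points i and i+1 break the expected order, the one farther from the trend line y = kx + b is dropped. DropOddPoint and the std::vector containers are illustrative only; BrokenLineOpt below works in place on the model's arrays:

    #include <math.h>
    #include <vector>

    // Illustrative sketch of one elimination step of the broken-line simplification.
    void DropOddPoint( std::vector<double>& px, std::vector<double>& py,
                       int i, double k, double b )
    {
        double errFormer = fabs( k*px[i]   + b - py[i]   );   // vertical error of point i against the trend line
        double errLater  = fabs( k*px[i+1] + b - py[i+1] );   // same for point i+1
        int drop = ( errFormer > errLater ) ? i : i + 1;      // drop whichever deviates more
        px.erase( px.begin() + drop );
        py.erase( py.begin() + drop );
    }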

Code:

RevPre.h

    #pragma once
    #include <iostream>
    #include <math.h>
    #include <string.h>   // memset / memcpy
    #include <stdlib.h>   // exit
    /************************Model type*********************************/
    #define MAX_POL_DEPTH 3
    #define MAX_KNN_K 10
    #define MIN_KNN_K 1
    #define MAX(x,y) ((x) > (y) ? (x) : (y))
    #define MIN(x,y) ((x) < (y) ? (x) : (y))
    enum modelType{
        StraightLine = 0, // default
        CurveAt2,
        BrokenLine,
        KNNModel
    };
    struct Model{
        enum modelType type;
        // Polynomial coefficients (straight line / quadratic)
        double lineParam[MAX_POL_DEPTH];
        // Stored point pairs (broken line / KNN)
        double *px, *py;
        int len;
    };
    Model* CreateModel();
    void ReleaseModel( Model** _ptr );
    bool SetOptData( Model* ptr, double *x, double *y, int len );
    bool Opt( Model *ptr, modelType type );
    double Predict( Model *ptr, double x );
    /**************************Polynomial*******************************/
    /* Internal */
    void CalculatePower(double *powers, int ptNum, int maxDepth, double *x ); // store the powers of every x[i] in a 2-D array
    void CalculateParams(double *powers, int ptNum, int maxDepth,
                         double *params, double *y); // build the coefficient matrix of the normal equations
    void DirectLU( double *params, int ptNum, int maxDepth, double *x ); // LU decomposition with column pivoting
    inline void swap(double &,double &); // swap the values of two variables
    /* External */
    bool PolynomialOpt( Model *ptr );
    /************************StraightLine********************************/
    bool StraightLineOpt( Model *ptr );
    /************************BrokenLine********************************/
    /*Internal*/
    int SingleSort( double *index, double *context, int start, int end );
    void QuickSort( double *index, double *context, int start, int end );
    int CheckSequence( double *context, int start, int end, bool upTrend );
    /*External*/
    bool BrokenLineOpt( Model *ptr );
    /********************KNN(Lazy-learning)****************************/
    bool KNNOpt( Model *ptr );

RevPre.cpp

    #include "RevPre.h"

    Model* CreateModel(){
        Model *ptr = new Model;
        ptr->type = StraightLine;
        ptr->px = ptr->py = NULL;
        ptr->len = 0;
        memset( ptr->lineParam, 0, sizeof(double)*MAX_POL_DEPTH );
        return ptr;
    }

    void ReleaseModel( Model** _ptr ){
        Model *ptr = *_ptr;
        if ( ptr->px ) delete[] ptr->px;
        if ( ptr->py ) delete[] ptr->py;
        delete ptr;
        *_ptr = NULL;
        return ;
    }

    bool SetOptData( Model *ptr, double *x, double *y, int len ){
        if ( !ptr || !x || !y ) return false;
        if ( !ptr->px ) ptr->px = new double[len];
        if ( !ptr->py ) ptr->py = new double[len];
        ptr->len = len;
        memcpy( ptr->px, x, sizeof(double)*len );
        memcpy( ptr->py, y, sizeof(double)*len );
        return true;
    }
    bool Opt( Model *ptr, modelType type ){
        if ( !ptr ) return false;
        switch( type )
        {
        case StraightLine:
            return StraightLineOpt( ptr );
        case CurveAt2:
            return PolynomialOpt( ptr );
        case BrokenLine:
            return BrokenLineOpt( ptr );
        case KNNModel:
            return KNNOpt( ptr );
        default:
            return false;
        }
    }
    double Predict( Model *ptr, double x ){
        if ( !ptr ) exit(-1);
        switch( ptr->type )
        {
        case StraightLine:
            return ptr->lineParam[0] + ptr->lineParam[1]*x;
        case CurveAt2:
            return ptr->lineParam[0] + ptr->lineParam[1]*x + ptr->lineParam[2]*x*x;
        case BrokenLine:
        {
            if ( ptr->len < 3 ) exit(-2);
            int first = 0;
            if ( x <= ptr->px[0] )
            {
                // extrapolate along the first segment
                double x0 = ptr->px[0], x1 = ptr->px[1];
                double y0 = ptr->py[0], y1 = ptr->py[1];
                return y0 - (x0-x)*(y1-y0)/(x1-x0);
            }
            else if ( x >= ptr->px[ptr->len-1] )
            {
                // extrapolate along the last segment
                double x0 = ptr->px[ptr->len-2], y0 = ptr->py[ptr->len-2];
                double x1 = ptr->px[ptr->len-1], y1 = ptr->py[ptr->len-1];
                return y1 - (x-x1)*(y0-y1)/(x1-x0);
            }
            else
            {
                // locate the segment that contains x and interpolate linearly
                while ( ptr->px[first] < x ) { first ++; }
                first --;
                double deltay = ptr->py[first+1] - ptr->py[first];
                double deltax = ptr->px[first+1] - ptr->px[first];
                return ptr->py[first] + deltay*(x-ptr->px[first])/deltax;
            }
        }
        case KNNModel:
        {
            int K = MAX( MIN_KNN_K, MIN( int(ptr->len*0.1), MAX_KNN_K ) );
            // prepare the initial K neighbours
            double *dist_team = new double[K];
            int *idx_team = new int[K];
            int farestIdt = -1;
            double farestDist = 0;
            int id = 0;
            for ( ; id < K; id ++ )
            {
                idx_team[id] = id;
                dist_team[id] = fabs( ptr->px[id] - x );
                if ( farestDist <= dist_team[id] )
                {
                    farestIdt = id;
                    farestDist = dist_team[id];
                }
            }
            // scan the remaining points for the K nearest neighbours
            while ( id < ptr->len )
            {
                if ( fabs( ptr->px[id] - x ) < farestDist )
                {
                    // replace the current farthest member of the team
                    idx_team[farestIdt] = id;
                    dist_team[farestIdt] = fabs( ptr->px[id] - x );
                    // update the farthest record
                    farestIdt = 0;
                    farestDist = dist_team[0];
                    for ( int searchIdt = 1; searchIdt < K; searchIdt ++ )
                    {
                        if ( dist_team[searchIdt] > farestDist )
                        {
                            farestDist = dist_team[searchIdt];
                            farestIdt = searchIdt;
                        }
                    }
                }
                id ++;
            }
            // distance-weighted average of the neighbours' y values
            double res = 0.0;
            double weightSum = 0.0;
            for ( int seachIdt = 0; seachIdt < K; seachIdt ++ )
            {
                weightSum += 1.0/dist_team[seachIdt];
                res += 1.0/dist_team[seachIdt]*ptr->py[idx_team[seachIdt]];
            }
            delete[] dist_team;
            delete[] idx_team;
            return res/weightSum;
        }
        default:
            exit(-2);
        }
    }
    /**************************Polynomial*******************************/
    bool StraightLineOpt( Model *ptr )
    {
        if ( !ptr ) return false;
        if ( !ptr->px || !ptr->py ) return false;
        int outLen = 2;  // y = a0 + a1*x : two coefficients
        int ptNum = ptr->len, maxDepth = outLen;
        double *powers = new double[maxDepth*ptNum];
        double *params = new double[maxDepth*(maxDepth+1)];
        CalculatePower( powers, ptNum, maxDepth, ptr->px );
        CalculateParams( powers, ptNum, maxDepth, params, ptr->py ); // build the normal-equation matrix
        DirectLU( params, ptNum, maxDepth, ptr->lineParam );         // solve it by LU with column pivoting
        ptr->type = StraightLine;
        std::cout<<"-------------------------"<<std::endl;
        std::cout<<"Fitted coefficients:\n";
        for( int i = 0; i < maxDepth; i ++ )
            std::cout<<"a["<<i<<"]="<<ptr->lineParam[i]<<std::endl;
        std::cout<<"-------------------------"<<std::endl;
        delete[] powers;
        delete[] params;
        return true;
    }
    bool PolynomialOpt( Model *ptr )
    {
        if ( !ptr ) return false;
        if ( !ptr->px || !ptr->py ) return false;
        int outLen = MAX_POL_DEPTH;  // y = a0 + a1*x + a2*x^2 : three coefficients
        int ptNum = ptr->len, maxDepth = outLen;
        double *powers = new double[maxDepth*ptNum];
        double *params = new double[maxDepth*(maxDepth+1)];
        CalculatePower( powers, ptNum, maxDepth, ptr->px );
        CalculateParams( powers, ptNum, maxDepth, params, ptr->py ); // build the normal-equation matrix
        DirectLU( params, ptNum, maxDepth, ptr->lineParam );         // solve it by LU with column pivoting
        ptr->type = CurveAt2;
        /*std::cout<<"-------------------------"<<std::endl;
        std::cout<<"Fitted coefficients:\n";
        for( int i = 0; i < maxDepth; i ++ )
            std::cout<<"a["<<i<<"]="<<ptr->lineParam[i]<<std::endl;
        std::cout<<"-------------------------"<<std::endl;*/
        delete[] powers;
        delete[] params;
        return true;
    }
    // powers[i*ptNum + j] = x[j]^i : each row holds one power of the x values
    void CalculatePower(double *powers, int ptNum, int maxDepth, double *x )
    {
        if ( !powers || !x ) return ;
        int i, j, k;
        double temp;
        for( i = 0; i < maxDepth; i ++ )
            for( j = 0; j < ptNum; j ++ )
            {
                temp = 1;
                for( k = 0; k < i; k ++ )
                    temp *= x[j];
                powers[i*ptNum+j] = temp;
            }
        return ;
    }

    // Build the augmented matrix of the normal equations:
    // params[i][j] = sum_k x_k^i * x_k^j,  params[i][maxDepth] = sum_k y_k * x_k^i
    void CalculateParams(double *powers, int ptNum, int maxDepth,
                         double *params, double *y)
    {
        if ( !powers || !params || !y ) return ;
        int i, j, k;
        double temp;
        int step = maxDepth + 1;
        for( i = 0; i < maxDepth; i ++ )
        {
            for( j = 0; j < maxDepth; j ++ )
            {
                temp = 0;
                for( k = 0; k < ptNum; k ++ )
                    temp += powers[i*ptNum+k]*powers[j*ptNum+k];
                params[i*step+j] = temp;
            }
            temp = 0;
            for( k = 0; k < ptNum; k ++ )
            {
                temp += y[k]*powers[i*ptNum+k];
                params[i*step+maxDepth] = temp;
            }
        }
        return ;
    }

    // Swap two doubles without a temporary
    inline void swap(double &a,double &b)
    {
        a=a+b;
        b=a-b;
        a=a-b;
    }
    // Solve the normal equations by LU decomposition with column pivoting
    void DirectLU( double *params, int ptNum, int maxDepth, double *x )
    {
        int i, r, k, j;
        double max;
        int step = maxDepth + 1;
        double *s = new double[maxDepth];
        double *t = new double[maxDepth];
        // choose the pivot (largest) element in each column
        for( r = 0; r < maxDepth; r ++ )
        {
            max = 0;
            j = r;
            for( i = r; i < maxDepth; i ++ )
            {
                s[i] = params[i*step+r];
                for( k = 0; k < r; k ++ )
                    s[i] -= params[i*step+k] * params[k*step+r];
                s[i] = fabs(s[i]);
                if( s[i] > max ){
                    j = i;
                    max = s[i];
                }
            }
            // if the pivot element is not at row r, swap the two rows element by element
            if( j != r )
            {
                for( i = 0; i < maxDepth + 1; i ++ )
                    swap( params[r*step+i], params[j*step+i] );
            }
            for( i = r; i < step; i ++ )
                for( k = 0; k < r; k ++ ){
                    params[r*step+i] -= params[r*step+k] * params[k*step+i];
                }
            for( i = r+1; i < maxDepth; i ++ )
            {
                for ( k = 0; k < r; k ++ )
                    params[i*step+r] -= params[i*step+k] * params[k*step+r];
                params[i*step+r] /= params[r*step+r];
            }
        }
        for( i = 0; i < maxDepth; i ++ )
            t[i] = params[ i*step + maxDepth ];
        for ( i = maxDepth - 1; i >= 0; i -- ) // back substitution for the final solution
        {
            for ( r = maxDepth - 1; r > i; r -- )
                t[i] -= params[ i*step + r ] * x[r];
            x[i] = t[i]/params[i*step+i];
        }
        delete[] s;
        delete[] t;
        return ;
    }
    /**********************Broken Line***************************/
    // Quick-sort partition: order the (index, context) pairs by index
    int SingleSort( double *index, double *context, int start, int end )
    {
        if ( end - start < 1 ) return start;
        int i = start, j = end;
        double key  = index[i];
        double key_ = context[i];
        while ( i < j )
        {
            while ( index[j] >= key && j > i ) j --;
            index[i] = index[j];
            context[i] = context[j];
            while ( index[i] <= key && j > i ) i ++;
            index[j] = index[i];
            context[j] = context[i];
        }
        // i == j here: put the pivot pair back in its final position
        index[i] = key;
        context[i] = key_;
        return i;
    }

    void QuickSort( double *index, double *context, int start, int end )
    {
        if ( end - start < 1 ) return ; // important: empty or single-element range
        int mid = SingleSort( index, context, start, end );
        QuickSort( index, context, start, mid - 1 );
        QuickSort( index, context, mid + 1, end );
    }

    // Return the index of the first pair that breaks the expected order, or -1 if none does
    int CheckSequence( double *context, int start, int end, bool upTrend )
    {
        int i = start;
        for ( ; i < end; i ++ )
        {
            if ( upTrend && context[i+1] < context[i] )
            {
                return i;
            }
            if ( !upTrend && context[i] < context[i+1] )
            {
                return i;
            }
        }
        return -1;
    }
    // Form the broken line
    bool BrokenLineOpt( Model *ptr )
    {
        if ( !ptr ) return false;
        if ( !ptr->len || !ptr->px || !ptr->py ) return false;
        // analyse the trend of the points via a rough straight-line fit
        StraightLineOpt( ptr );
        double k = ptr->lineParam[1], b = ptr->lineParam[0];
        bool upTrend = ( k > 0 );
        // sort the (x, y) pairs by px
        QuickSort( ptr->px, ptr->py, 0, ptr->len - 1 );
        int oddPoint = 0;
        while ( (oddPoint = CheckSequence( ptr->py, oddPoint, ptr->len - 1, upTrend )) != -1 )
        {
            // of the two points that break the order, drop the one farther from the fitted line
            double formerErr = fabs( k*ptr->px[oddPoint] + b - ptr->py[oddPoint] );
            double laterErr  = fabs( k*ptr->px[oddPoint+1] + b - ptr->py[oddPoint+1] );
            oddPoint = formerErr > laterErr ? oddPoint : oddPoint + 1;
            // remove the odd point by shifting the tail of both arrays left by one
            memmove( ptr->py + oddPoint, ptr->py + oddPoint + 1, sizeof(double)*(ptr->len - oddPoint - 1) );
            memmove( ptr->px + oddPoint, ptr->px + oddPoint + 1, sizeof(double)*(ptr->len - oddPoint - 1) );
            ptr->len --;
            if ( oddPoint > 0 ) oddPoint --; // re-check from the previous point
        }
        memset( ptr->lineParam, 0, sizeof(double)*MAX_POL_DEPTH );
        ptr->type = BrokenLine;
        return true;
    }
    /**********************Lazy Learning***************************/
    bool KNNOpt( Model *ptr )
    {
        if ( !ptr ) return false;
        // Nothing to fit: KNN is a lazy-learning method, so the actual
        // work only happens when Predict() is called.
        memset( ptr->lineParam, 0, sizeof(double)*MAX_POL_DEPTH );
        ptr->type = KNNModel;
        return true;
    }

Util.cpp

    #include "RevPre.h"

    int main()
    {
        double x[10], y[10];
        for ( int i = 0; i < 10; i ++ )
        {
            x[i] = 10 - i;        // x runs from 10 down to 1
            y[i] = (i-2)*(i-2);   // a parabola in i, so y is not monotone in x
        }
        Model *model = CreateModel();
        SetOptData( model, x, y, 10 );
        Opt( model, BrokenLine );               // pick the model type to fit here
        double result = Predict( model, 1.5 );
        std::cout << "Predict(1.5) = " << result << std::endl;
        ReleaseModel( &model );
        return 0;
    }
