Straight-Line Fitting, Quadratic Curve Fitting, Broken-Line Fitting, and KNN Regression (with Code)
In a project I needed to fit a data set with the four kinds of regression above and then evaluate new inputs against the fitted model; in other words, ordinary curve fitting.
RevPre below defines the methods and the data structure, Util is a usage example, and the type argument passed to Opt selects which model to fit.
1) Straight-line fitting: y = kx + b
2) Quadratic curve fitting: y = Ax^2 + Bx + C
Both are standard least-squares fits and need little explanation; a minimal standalone sketch of the straight-line case follows.
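The sketch below is not part of RevPre; FitLine is a hypothetical helper shown only to illustrate the closed-form least-squares solution that the normal-equation code further down also produces.

// Closed-form least squares for y = kx + b:
// k = (n*sum(xy) - sum(x)*sum(y)) / (n*sum(x^2) - sum(x)^2), b = (sum(y) - k*sum(x)) / n
bool FitLine( const double *x, const double *y, int len, double &k, double &b )
{
    if ( !x || !y || len < 2 ) return false;
    double sx = 0, sy = 0, sxx = 0, sxy = 0;
    for ( int i = 0; i < len; i ++ )
    {
        sx += x[i]; sy += y[i];
        sxx += x[i]*x[i]; sxy += x[i]*y[i];
    }
    double denom = len*sxx - sx*sx;
    if ( denom == 0.0 ) return false; // all x values identical: slope undefined
    k = ( len*sxy - sx*sy ) / denom;
    b = ( sy - k*sx ) / len;
    return true;
}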
3) KNN
The Opt stage for KNN does nothing except store all the data pairs; at prediction time the K nearest points (MIN_KNN_K <= K <= MAX_KNN_K) are selected and combined into a distance-weighted average, as in the sketch below.
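For intuition, here is a standalone sketch of the same distance-weighted idea written with std::vector. KnnPredict is a hypothetical name and not part of the RevPre API; the real implementation is the KNN branch of Predict() further down.

#include <vector>
#include <algorithm>
#include <utility>
#include <math.h>

double KnnPredict( const std::vector<double> &xs, const std::vector<double> &ys,
                   double x, int k )
{
    // pair each sample with its distance to the query point and sort by distance
    std::vector< std::pair<double,double> > byDist; // (distance, y)
    for ( size_t i = 0; i < xs.size(); i ++ )
        byDist.push_back( std::make_pair( fabs( xs[i] - x ), ys[i] ) );
    std::sort( byDist.begin(), byDist.end() );
    if ( k > (int)byDist.size() ) k = (int)byDist.size();
    double num = 0.0, den = 0.0;
    for ( int i = 0; i < k; i ++ )
    {
        if ( byDist[i].first == 0.0 ) return byDist[i].second; // exact match on x
        num += byDist[i].second / byDist[i].first; // weight = 1 / distance
        den += 1.0 / byDist[i].first;
    }
    return num / den;
}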
4) Broken-line fitting:
KNN is introduced first because the broken-line result looks much like it: the model also stores a list of data pairs, and prediction determines which segment the query falls on. The difference is that the list is simplified rather than stored in full.
First the data pairs are quick-sorted by x, and a straight-line fit decides whether the overall trend is ascending or descending. Points that violate that ordering are then removed one by one: whenever a "bump" or "dip" appears, the point farther from the fitted line is discarded, until every remaining point is in order and the result is a monotone polyline. A standalone sketch of this pruning step is shown below.
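The sketch assumes the points are already sorted by x and the trend line (k, b) has been fitted; Pt and PruneToMonotone are hypothetical names used only here, while BrokenLineOpt below does the same job on raw arrays.

#include <vector>
#include <math.h>

struct Pt { double x, y; };

void PruneToMonotone( std::vector<Pt> &pts, double k, double b )
{
    bool up = ( k > 0 );
    for ( size_t i = 0; i + 1 < pts.size(); )
    {
        bool ordered = up ? ( pts[i+1].y >= pts[i].y ) : ( pts[i+1].y <= pts[i].y );
        if ( ordered ) { i ++; continue; }
        // drop whichever of the two offending points lies farther from the trend line
        double e0 = fabs( k*pts[i].x + b - pts[i].y );
        double e1 = fabs( k*pts[i+1].x + b - pts[i+1].y );
        pts.erase( pts.begin() + ( e0 > e1 ? i : i + 1 ) );
        if ( i > 0 ) i --; // a removal can expose a violation with the previous point
    }
}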
Code:
RevPre.h
#include <iostream>
#include <math.h> // fabs
#include <string.h> // memset, memcpy, memmove
#include <stdlib.h> // exit
/************************Model type*********************************/
#define MAX_POL_DEPTH 3
#define MAX_KNN_K 10
#define MIN_KNN_K 1
#define MAX(x,y) ((x) > (y) ? (x) : (y))
#define MIN(x,y) ((x) < (y) ? (x) : (y))
enum modelType{
StraightLine = 0, // default
CurveAt2,
BrokenLine,
KNNModel
};
typedef struct Model{
enum modelType type;
// line / curve coefficients: lineParam[i] multiplies x^i
double lineParam[MAX_POL_DEPTH];
// stored point pairs (used by the broken-line and KNN models)
double *px, *py;
int len;
} Model;
Model* CreateModel();
void ReleaseModel( Model** _ptr );
bool SetOptData( Model* ptr, double *x, double *y, int len );
bool Opt( Model *ptr, modelType type );
double Predict( Model *ptr, double x );
/**************************Polynomial*******************************/
/* Internal */
void CalculatePower(double *powers, int ptNum, int maxDepth, double *x ); // store every power of each x[i] in a 2-D (maxDepth x ptNum) array
void CalculateParams(double *powers, int ptNum, int maxDepth,
double *params, double *y); // build the coefficient matrix of the normal equations
void DirectLU( double *params, int ptNum, int maxDepth, double *x ); // LU decomposition with column pivoting
inline void swap(double &,double &); // swap two values
/* External */
bool PolynomialOpt( Model *ptr );
/************************StraightLine********************************/
bool StraightLineOpt( Model *ptr );
/************************BrokenLine********************************/
/*Internal*/
int SingleSort( double *index, double *context, int start, int end );
void QuickSort( double *index, double *context, int start, int end );
int CheckSequence( double *context, int start, int end, bool upTrend );
/*External*/
bool BrokenLineOpt( Model *ptr );
/********************KNN(Lazy-learning)****************************/
bool KNNOpt( Model *ptr );
RevPre.cpp
#include "Revise.h"
Model* CreateModel(){
Model *ptr = new Model;
ptr->type = StraightLine;
ptr->px = ptr->py = NULL;
ptr->len = 0;
memset( ptr->lineParam, 0, sizeof(double)*MAX_POL_DEPTH );
return ptr;
}
void ReleaseModel( Model** _ptr ){
Model *ptr = *_ptr;
if ( ptr->px ) delete[] ptr->px;
if ( ptr->py ) delete[] ptr->py;
delete ptr;
*_ptr = NULL;
return ;
}
bool SetOptData( Model *ptr, double *x, double *y, int len ){
if ( !ptr || !x || !y || len <= 0 ) return false;
// (re)allocate the internal copies so repeated calls with a different len stay safe
delete[] ptr->px;
delete[] ptr->py;
ptr->px = new double[len];
ptr->py = new double[len];
ptr->len = len;
memcpy( ptr->px, x, sizeof(double)*len );
memcpy( ptr->py, y, sizeof(double)*len );
return true;
}
bool Opt( Model *ptr, modelType type ){
if ( !ptr ) return false;
switch( type )
{
case StraightLine:
return StraightLineOpt( ptr );
case CurveAt2:
return PolynomialOpt( ptr );
case BrokenLine:
return BrokenLineOpt( ptr );
case KNNModel:
return KNNOpt( ptr );
default:
return false;
}
}
double Predict( Model *ptr, double x ){
if ( !ptr ) exit (-1);
switch( ptr->type )
{
case StraightLine:
return ptr->lineParam[0] + ptr->lineParam[1]*x;
case CurveAt2:
return ptr->lineParam[0] + ptr->lineParam[1]*x + ptr->lineParam[2]*x*x;
case BrokenLine:
{
if ( ptr->len < 3 ) exit(-2);
int first = 0;
if ( x <= ptr->px[0] )
{
double x0 = ptr->px[0], x1 = ptr->px[1];
double y0 = ptr->py[0], y1 = ptr->py[1];
return y0 - (x0-x)*(y1-y0)/(x1-x0);
}
else if ( x >= ptr->px[ptr->len-1] )
{
double x0 = ptr->px[ptr->len-2], y0 = ptr->py[ptr->len - 2];
double x1 = ptr->px[ptr->len-1], y1 = ptr->py[ptr->len - 1];
return y1 -(x-x1)*(y0-y1)/(x1-x0);
}
else
{
while ( ptr->px[first] < x ) { first ++ ;}
first --;
double deltay = ptr->py[first+1] - ptr->py[first];
double deltax = ptr->px[first+1] - ptr->px[first];
return ptr->py[first] + deltay*(x-ptr->px[first])/deltax;
}
}
case KNNModel:
{
int K = MAX( MIN_KNN_K, MIN( int(ptr->len*0.1), MAX_KNN_K ) );
// Prepare the initial K neighbours
double *dist_team = new double[K];
int *idx_team = new int[K];
int farestIdt = -1;
double farestDist = 0;
int id = 0;
for ( ; id < K; id ++ )
{
idx_team[id] = id;
dist_team[id] = fabs( ptr->px[id] - x );
if ( farestDist <= dist_team[id] )
{
farestIdt = id;
farestDist = dist_team[id];
}
}
// Looking for the K nearest neighbours
while ( id < ptr->len )
{
if ( fabs( ptr->px[id] - x ) < farestDist )
{
// Update the team
idx_team[farestIdt] = id;
dist_team[farestIdt] = fabs( ptr->px[id] - x );
// Update the farthest record
farestIdt = 0;
farestDist = dist_team[0];
for ( int searchIdt = 1; searchIdt < K; searchIdt ++ )
{
if ( dist_team[searchIdt] > farestDist )
{
farestDist = dist_team[searchIdt];
farestIdt = searchIdt;
}
}
}
id ++;
}
// Calculate their contribution
double res = 0.0;
double weightSum = 0.0;
for ( int seachIdt = 0; seachIdt < K; seachIdt ++ )
{
// clamp the distance so an exact x match does not divide by zero
double d = MAX( dist_team[seachIdt], 1e-12 );
weightSum += 1.0/d;
res += 1.0/d*ptr->py[idx_team[seachIdt]];
}
delete[] dist_team;
delete[] idx_team;
return res/weightSum;
}
default:
exit(-2);
}
}
/**************************Polynomial*******************************/
bool StraightLineOpt( Model *ptr )
{
if ( !ptr ) return false;
if ( !ptr->px || !ptr->py ) return false;
int outLen = 2;
int ptNum = ptr->len, maxDepth = outLen;
double *powers = new double[maxDepth*ptNum];
double *params = new double[maxDepth*(maxDepth+1)];
CalculatePower( powers, ptNum, maxDepth, ptr->px );
CalculateParams( powers, ptNum, maxDepth, params, ptr->py ); // build the normal-equation coefficient matrix
DirectLU( params, ptNum, maxDepth, ptr->lineParam ); // solve it by LU decomposition with column pivoting
ptr->type = StraightLine;
std::cout<<"-------------------------"<<std::endl;
std::cout<<"拟合函数的系数分别为:\n";
for( int i=0;i<maxDepth;i++)
std::cout<<"a["<<i<<"]="<<ptr->lineParam[i]<<std::endl;
std::cout<<"-------------------------"<<std::endl;
delete[] powers;
delete[] params;
return true;
}
bool PolynomialOpt( Model *ptr )
{
if ( !ptr ) return false;
if ( !ptr->px || !ptr->py ) return false;
int outLen = MAX_POL_DEPTH;
int ptNum = ptr->len, maxDepth = outLen;
double *powers = new double[maxDepth*ptNum];
double *params = new double[maxDepth*(maxDepth+1)];
CalculatePower( powers, ptNum, maxDepth, ptr->px );
CalculateParams( powers, ptNum, maxDepth, params, ptr->py ); // build the normal-equation coefficient matrix
DirectLU( params, ptNum, maxDepth, ptr->lineParam ); // solve it by LU decomposition with column pivoting
ptr->type = CurveAt2;
/*std::cout<<"-------------------------"<<std::endl;
std::cout<<"拟合函数的系数分别为:\n";
for( int i=0;i<maxDepth;i++)
std::cout<<"a["<<i<<"]="<<ptr->lineParam[i]<<std::endl;
std::cout<<"-------------------------"<<std::endl;*/
delete[] powers;
delete[] params;
return true;
}
void CalculatePower(double *powers, int ptNum, int maxDepth, double *x )
{
if ( !powers || !x ) return ;
int i, j, k;
double temp;
for( i = 0; i < maxDepth; i ++ )
for( j = 0; j < ptNum; j ++ )
{
temp = 1;
for( k = 0; k < i; k ++ )
temp *= x[j];
powers[i*ptNum+j] = temp;
}
return ;
}
void CalculateParams(double *powers, int ptNum, int maxDepth,
double *params, double *y)
{
if ( !powers || !params || !y ) return ;
int i, j, k;
double temp;
int step = maxDepth + 1;
for( i = 0; i < maxDepth; i ++ )
{
for(j = 0; j < maxDepth; j ++ )
{
temp = 0;
for( k = 0; k < ptNum; k ++ )
temp += powers[i*ptNum+k]*powers[j*ptNum+k];
params[i*step+j] = temp;
}
temp = 0;
for( k = 0; k < ptNum; k ++ )
temp += y[k]*powers[i*ptNum+k];
params[i*step+maxDepth] = temp; // right-hand side of the normal equations
}
return ;
}
inline void swap(double &a,double &b)
{
// plain temp-based swap; the add/subtract trick loses precision for large magnitudes
double t = a; a = b; b = t;
}
void DirectLU( double *params, int ptNum, int maxDepth, double *x )
{
int i, r, k, j;
double max;
int step = maxDepth + 1;
double *s = new double[maxDepth];
double *t = new double[maxDepth];
// choose the main element
for( r = 0; r < maxDepth; r ++ )
{
max = 0;
j = r;
for( i = r; i < maxDepth; i ++ )
{
s[i] = params[i*step+r];
for( k = 0; k < r; k ++ )
s[i] -= params[i*step+k] * params[k*step+r];
s[i] = fabs(s[i]);
if( s[i] > max ){
j = i;
max = s[i];
}
}
// if the "main"element is not @ row r, swap the corresponding element
if( j != r )
{
for( i = 0; i < maxDepth + 1; i ++ )
swap( params[r*step+i], params[j*step+i] );
}
for( i = r; i < step; i ++ )
for( k = 0; k < r; k ++ ){
params[r*step+i] -= params[r*step+k] * params[k*step+i];
}
for(i = r+1; i < maxDepth; i ++ )
{
for ( k = 0; k < r; k ++ )
params[i*step+r] -= params[i*step+k] * params[k*step+r];
params[i*step+r] /= params[r*step+r];
}
}
for( i = 0; i < maxDepth; i ++ )
t[i] = params[ i*step + maxDepth ];
for ( i = maxDepth - 1; i >= 0; i -- ) // back substitution for the final solution
{
for ( r = maxDepth - 1; r > i; r -- )
t[i] -= params[ i*step + r ] * x[r];
x[i] = t[i]/params[i*step+i];
}
delete[] s;
delete[] t;
return ;
}
/**********************Broken Line***************************/
// Quick Sort
int SingleSort( double *index, double *context, int start, int end )
{
if ( end - start < 1 ) return start;
int i = start, j = end;
double key = index[i];
double key_ = context[i];
while ( i < j )
{
// scan from the right for an element smaller than the pivot
while ( i < j && index[j] >= key ) j --;
index[i] = index[j];
context[i] = context[j];
// scan from the left for an element larger than the pivot
while ( i < j && index[i] <= key ) i ++;
index[j] = index[i];
context[j] = context[i];
}
// place the pivot in its final position
index[i] = key;
context[i] = key_;
return i;
}
void QuickSort( double *index, double *context, int start, int end )
{
if ( end - start < 1 ) return ; // important
int mid = SingleSort( index, context, start, end );
QuickSort( index, context, start, mid - 1 );
QuickSort( index, context, mid+ 1, end );
}
int CheckSequence( double *context, int start, int end, bool upTrend )
{
int i = start;
for ( ; i < end; i ++ )
{
if ( upTrend && context[i+1] < context[i] )
{
return i;
}
if ( !upTrend && context[i] < context[i+1] )
{
return i;
}
}
return -1;
}
// Form the broken line
bool BrokenLineOpt( Model *ptr )
{
if ( !ptr ) return false;
if ( !ptr->len || !ptr->px || !ptr->py ) return false;
// analyse the trend of points and get its approximate line
StraightLineOpt( ptr );
double k = ptr->lineParam[1], b = ptr->lineParam[0];
bool upTrend = ( k > 0 );
// sort the points by ascending x (the py values follow their px)
QuickSort( ptr->px, ptr->py, 0, ptr->len - 1 );
int oddPoint = 0;
while ( (oddPoint = CheckSequence( ptr->py, oddPoint, ptr->len -1, upTrend ) ) != -1 )
{
double formerErr = fabs( k*ptr->px[oddPoint] + b - ptr->py[oddPoint] );
double laterErr = fabs( k*ptr->px[oddPoint+1] + b - ptr->py[oddPoint+1] );
oddPoint = formerErr > laterErr ? oddPoint : oddPoint + 1;
// remove the odd point: shift the tail one slot left (memmove because the regions overlap)
memmove( ptr->py + oddPoint, ptr->py + oddPoint + 1, sizeof(double)*(ptr->len - oddPoint - 1) );
memmove( ptr->px + oddPoint, ptr->px + oddPoint + 1, sizeof(double)*(ptr->len - oddPoint - 1) );
ptr->len --;
if ( oddPoint > 0 ) oddPoint --; // re-check the pair just before the removal
}
memset( ptr->lineParam, 0, sizeof(double)*MAX_POL_DEPTH );
ptr->type = BrokenLine;
return true;
}
/**********************Lazy Learning***************************/
bool KNNOpt( Model *ptr )
{
if ( !ptr ) return false;
// Lazy learning: nothing to fit here; the real work happens when Predict() is called
memset( ptr->lineParam, 0, sizeof(double)*MAX_POL_DEPTH );
ptr->type = KNNModel;
return true;
}
Util.cpp
#include "Revise.h"
int _tmain(int argc, _TCHAR* argv[])
{
double x[10], y[10];
for ( int i = 0; i < 10; i ++ )
{
x[i] = 10 - i;
y[i] = (i-2)*(i-2);
}
Model *model = CreateModel();
SetOptData( model, x, y, 10 );
Opt( model, BrokenLine );
double result = Predict( model, 1.5 );
std::cout << "Prediction at x = 1.5: " << result << std::endl;
ReleaseModel( &model );
return 0;
}