图像处理之高斯混合模型

太过爱你忘了你带给我的痛 2022-06-15 06:52 307阅读 0赞

图像处理之高斯混合模型

一:概述

高斯混合模型(GMM)在图像分割、对象识别、视频分析等方面均有应用,对于任意给定的数据样本集合,根据其分布概率, 可以计算每个样本数据向量的概率分布,从而根据概率分布对其进行分类,但是这些概率分布是混合在一起的,要从中分离出单个样本的概率分布就实现了样本数据聚类,而概率分布描述我们可以使用高斯函数实现,这个就是高斯混合模型-GMM。

20170526170138039

20170526170352714

这种方法也称为D-EM即基于距离的期望最大化。

三:算法步骤

1.初始化变量定义-指定的聚类数目K与数据维度D

2.初始化均值、协方差、先验概率分布

3.迭代E-M步骤

- E步计算期望

- M步更新均值、协方差、先验概率分布

-检测是否达到停止条件(最大迭代次数与最小误差满足),达到则退出迭代,否则继续E-M步骤

4.打印最终分类结果

四:代码实现

  1. package com.gloomyfish.image.gmm;
  2. import java.util.ArrayList;
  3. import java.util.Arrays;
  4. import java.util.List;
  5. /**
  6. *
  7. * @author gloomy fish
  8. *
  9. */
  10. public class GMMProcessor {
  11. public final static double MIN_VAR = 1E-10;
  12. public static double[] samples = new double[]{10, 9, 4, 23, 13, 16, 5, 90, 100, 80, 55, 67, 8, 93, 47, 86, 3};
  13. private int dimNum;
  14. private int mixNum;
  15. private double[] weights;
  16. private double[][] m_means;
  17. private double[][] m_vars;
  18. private double[] m_minVars;
  19. /***
  20. *
  21. * @param m_dimNum - 每个样本数据的维度, 对于图像每个像素点来说是RGB三个向量
  22. * @param m_mixNum - 需要分割为几个部分,即高斯混合模型中高斯模型的个数
  23. */
  24. public GMMProcessor(int m_dimNum, int m_mixNum) {
  25. dimNum = m_dimNum;
  26. mixNum = m_mixNum;
  27. weights = new double[mixNum];
  28. m_means = new double[mixNum][dimNum];
  29. m_vars = new double[mixNum][dimNum];
  30. m_minVars = new double[dimNum];
  31. }
  32. /***
  33. * data - 需要处理的数据
  34. * @param data
  35. */
  36. public void process(double[] data) {
  37. int m_maxIterNum = 100;
  38. double err = 0.001;
  39. boolean loop = true;
  40. double iterNum = 0;
  41. double lastL = 0;
  42. double currL = 0;
  43. int unchanged = 0;
  44. initParameters(data);
  45. int size = data.length;
  46. double[] x = new double[dimNum];
  47. double[][] next_means = new double[mixNum][dimNum];
  48. double[] next_weights = new double[mixNum];
  49. double[][] next_vars = new double[mixNum][dimNum];
  50. List<DataNode> cList = new ArrayList<DataNode>();
  51. while(loop) {
  52. Arrays.fill(next_weights, 0);
  53. cList.clear();
  54. for(int i=0; i<mixNum; i++) {
  55. Arrays.fill(next_means[i], 0);
  56. Arrays.fill(next_vars[i], 0);
  57. }
  58. lastL = currL;
  59. currL = 0;
  60. for (int k = 0; k < size; k++)
  61. {
  62. for(int j=0;j<dimNum;j++)
  63. x[j]=data[k*dimNum+j];
  64. double p = getProbability(x); // 总的概率密度分布
  65. DataNode dn = new DataNode(x);
  66. dn.index = k;
  67. cList.add(dn);
  68. double maxp = 0;
  69. for (int j = 0; j < mixNum; j++)
  70. {
  71. double pj = getProbability(x, j) * weights[j] / p; // 每个分类的概率密度分布百分比
  72. if(maxp < pj) {
  73. maxp = pj;
  74. dn.cindex = j;
  75. }
  76. next_weights[j] += pj; // 得到后验概率
  77. for (int d = 0; d < dimNum; d++)
  78. {
  79. next_means[j][d] += pj * x[d];
  80. next_vars[j][d] += pj* x[d] * x[d];
  81. }
  82. }
  83. currL += (p > 1E-20) ? Math.log10(p) : -20;
  84. }
  85. currL /= size;
  86. // Re-estimation: generate new weight, means and variances.
  87. for (int j = 0; j < mixNum; j++)
  88. {
  89. weights[j] = next_weights[j] / size;
  90. if (weights[j] > 0)
  91. {
  92. for (int d = 0; d < dimNum; d++)
  93. {
  94. m_means[j][d] = next_means[j][d] / next_weights[j];
  95. m_vars[j][d] = next_vars[j][d] / next_weights[j] - m_means[j][d] * m_means[j][d];
  96. if (m_vars[j][d] < m_minVars[d])
  97. {
  98. m_vars[j][d] = m_minVars[d];
  99. }
  100. }
  101. }
  102. }
  103. // Terminal conditions
  104. iterNum++;
  105. if (Math.abs(currL - lastL) < err * Math.abs(lastL))
  106. {
  107. unchanged++;
  108. }
  109. if (iterNum >= m_maxIterNum || unchanged >= 3)
  110. {
  111. loop = false;
  112. }
  113. }
  114. // print result
  115. System.out.println("=================最终结果=================");
  116. for(int i=0; i<mixNum; i++) {
  117. for(int k=0; k<dimNum; k++) {
  118. System.out.println("[" + i + "]: ");
  119. System.out.println("means : " + m_means[i][k]);
  120. System.out.println("var : " + m_vars[i][k]);
  121. System.out.println();
  122. }
  123. }
  124. // 获取分类
  125. for(int i=0; i<size; i++) {
  126. System.out.println("data[" + i + "]=" + data[i] + " cindex : " + cList.get(i).cindex);
  127. }
  128. }
  129. /**
  130. *
  131. * @param data
  132. */
  133. private void initParameters(double[] data) {
  134. // 随机方法初始化均值
  135. int size = data.length;
  136. for (int i = 0; i < mixNum; i++)
  137. {
  138. for (int d = 0; d < dimNum; d++)
  139. {
  140. m_means[i][d] = data[(int)(Math.random()*size)];
  141. }
  142. }
  143. // 根据均值获取分类
  144. int[] types = new int[size];
  145. for (int k = 0; k < size; k++)
  146. {
  147. double max = 0;
  148. for (int i = 0; i < mixNum; i++)
  149. {
  150. double v = 0;
  151. for(int j=0;j<dimNum;j++) {
  152. v += Math.abs(data[k*dimNum+j] - m_means[i][j]);
  153. }
  154. if(v > max) {
  155. max = v;
  156. types[k] = i;
  157. }
  158. }
  159. }
  160. double[] counts = new double[mixNum];
  161. for(int i=0; i<types.length; i++) {
  162. counts[types[i]]++;
  163. }
  164. // 计算先验概率权重
  165. for (int i = 0; i < mixNum; i++)
  166. {
  167. weights[i] = counts[i] / size;
  168. }
  169. // 计算每个分类的方差
  170. int label = -1;
  171. int[] Label = new int[size];
  172. double[] overMeans = new double[dimNum];
  173. double[] x = new double[dimNum];
  174. for (int i = 0; i < size; i++)
  175. {
  176. for(int j=0;j<dimNum;j++)
  177. x[j]=data[i*dimNum+j];
  178. label=Label[i];
  179. // Count each Gaussian
  180. counts[label]++;
  181. for (int d = 0; d < dimNum; d++)
  182. {
  183. m_vars[label][d] += (x[d] - m_means[types[i]][d]) * (x[d] - m_means[types[i]][d]);
  184. }
  185. // Count the overall mean and variance.
  186. for (int d = 0; d < dimNum; d++)
  187. {
  188. overMeans[d] += x[d];
  189. m_minVars[d] += x[d] * x[d];
  190. }
  191. }
  192. // Compute the overall variance (* 0.01) as the minimum variance.
  193. for (int d = 0; d < dimNum; d++)
  194. {
  195. overMeans[d] /= size;
  196. m_minVars[d] = Math.max(MIN_VAR, 0.01 * (m_minVars[d] / size - overMeans[d] * overMeans[d]));
  197. }
  198. // Initialize each Gaussian.
  199. for (int i = 0; i < mixNum; i++)
  200. {
  201. if (weights[i] > 0)
  202. {
  203. for (int d = 0; d < dimNum; d++)
  204. {
  205. m_vars[i][d] = m_vars[i][d] / counts[i];
  206. // A minimum variance for each dimension is required.
  207. if (m_vars[i][d] < m_minVars[d])
  208. {
  209. m_vars[i][d] = m_minVars[d];
  210. }
  211. }
  212. }
  213. }
  214. System.out.println("=================初始化=================");
  215. for(int i=0; i<mixNum; i++) {
  216. for(int k=0; k<dimNum; k++) {
  217. System.out.println("[" + i + "]: ");
  218. System.out.println("means : " + m_means[i][k]);
  219. System.out.println("var : " + m_vars[i][k]);
  220. System.out.println();
  221. }
  222. }
  223. }
  224. /***
  225. *
  226. * @param sample - 采样数据点
  227. * @return 该点总概率密度分布可能性
  228. */
  229. public double getProbability(double[] sample)
  230. {
  231. double p = 0;
  232. for (int i = 0; i < mixNum; i++)
  233. {
  234. p += weights[i] * getProbability(sample, i);
  235. }
  236. return p;
  237. }
  238. /**
  239. * Gaussian Model -> PDF
  240. * @param x - 表示采样数据点向量
  241. * @param j - 表示对对应的第J个分类的概率密度分布
  242. * @return - 返回概率密度分布可能性值
  243. */
  244. public double getProbability(double[] x, int j)
  245. {
  246. double p = 1;
  247. for (int d = 0; d < dimNum; d++)
  248. {
  249. p *= 1 / Math.sqrt(2 * 3.14159 * m_vars[j][d]);
  250. p *= Math.exp(-0.5 * (x[d] - m_means[j][d]) * (x[d] - m_means[j][d]) / m_vars[j][d]);
  251. }
  252. return p;
  253. }
  254. public static void main(String[] args) {
  255. GMMProcessor filter = new GMMProcessor(1, 2);
  256. filter.process(samples);
  257. }
  258. }

结构类DataNode

  1. package com.gloomyfish.image.gmm;
  2. public class DataNode {
  3. public int cindex; // cluster
  4. public int index;
  5. public double[] value;
  6. public DataNode(double[] v) {
  7. this.value = v;
  8. cindex = -1;
  9. index = -1;
  10. }
  11. }

五:结果

SouthEast

这里初始中心均值的方法我是通过随机数来实现,GMM算法运行结果跟初始化有很大关系,常见初始化中心点的方法是通过K-Means来计算出中心点。大家可以尝试修改代码基于K-Means初始化参数,我之所以选择随机参数初始,主要是为了省事!

不炒作概念,只分享干货!

请继续关注本博客!

发表评论

表情:
评论列表 (有 0 条评论,307人围观)

还没有评论,来说两句吧...

相关阅读

    相关 图像处理混合模型

    一、高斯混合模型 现有的图像中目标的分类常用深度学习模型处理,但是深度学习需要大量模型处理。对于明显提取的目标,常常有几个明显特征,利用这几个明显特征使用少量图片便可以完

    相关 混合模型(GMM)

    1. 前言 高斯混合模型是使用高斯分布对原始数据进行估计,其中高斯函数的均值 μ \\mu μ和方差 σ \\sigma σ以及各个高斯函数分量占的比例 α \\alph