Mean Shift(均值漂移)是一种基于密度的非参数聚类算法。其基本思想是:假设不同簇的数据服从不同的概率密度分布,对任一样本点,沿其密度增大最快的方向移动(这个"最快方向"就是 Mean Shift 向量,即"均值漂移"的含义);样本密度高的区域对应于该分布的局部极大值,样本点最终会收敛到某个局部密度极大值处,收敛到同一局部极大值的点被认为属于同一簇。(参考:https://cloud.tencent.com/developer/article/1459530)
1.核密度估计
Mean Shift算法用核函数估计样本的密度,最常用的核函数是高斯核。它的工作原理是:在数据集的每一个样本点上都放置一个核函数,然后将所有核函数相加,得到数据集的核密度估计(kernel density estimation)。需要注意的是,下文的 MATLAB 代码使用的是平核(flat kernel):以当前均值为中心、半径为 bandWidth 的范围内的样本点权重相同,范围外的样本点不参与计算。
function [clustCent,data2cluster,cluster2dataCell] = MeanShiftCluster(dataPts,bandWidth,plotFlag)
%MEANSHIFTCLUSTER  Mean-shift clustering of data using a flat kernel.
%
% ---INPUT---
% dataPts   - input data, (numDim x numPts)
% bandWidth - bandwidth parameter (positive scalar); radius of the flat kernel
% plotFlag  - when true, figure 1 is redrawn on every iteration; the points
%             are only drawn (followed by a pause) for 2-D data.
%             Defaults to false.
% ---OUTPUT---
% clustCent        - locations of cluster centers (numDim x numClust)
% data2cluster     - for every data point which cluster it belongs to (1 x numPts)
% cluster2dataCell - for every cluster which points are in it (numClust x 1
%                    cell); only built when this output is requested
%
% Bryan Feldman 02/24/06
% MeanShift first appears in
% K. Fukunaga and L.D. Hostetler, "The Estimation of the Gradient of a
% Density Function, with Applications in Pattern Recognition"

%*** Check input ****
if nargin < 2
    error('no bandwidth specified')
end
if nargin < 3
    plotFlag = false;
end
% A zero, negative or NaN bandwidth selects no points, the mean becomes NaN
% and the convergence test can never succeed, so reject it up front.
if ~isscalar(bandWidth) || ~(bandWidth > 0)
    error('bandWidth must be a positive scalar')
end

%**** Initialize stuff ***
[numDim,numPts] = size(dataPts);
numClust        = 0;
bandSq          = bandWidth^2;
initPtInds      = 1:numPts;          % points still usable as seeds
stopThresh      = 1e-3*bandWidth;    % mean shift below this => converged
clustCent       = [];                % one column per cluster center
beenVisitedFlag = false(1,numPts);   % true once a point fell inside any window
numInitPts      = numPts;            % number of remaining seed candidates
clusterVotes    = zeros(1,numPts);   % one row per cluster: visits per point

while numInitPts

    tempInd = ceil((numInitPts-1e-6)*rand);   % pick a random seed point
    stInd   = initPtInds(tempInd);
    myMean  = dataPts(:,stInd);               % start the mean at the seed
    myMembers        = [];                    % every point seen from this seed (used for plotting)
    thisClusterVotes = zeros(1,numPts);       % votes collected from this seed

    while true   % loop until convergence

        % Squared distance from the current mean to every point. The explicit
        % dimension matters: without it, 1-D data (a single row) would be
        % summed across points into one scalar.
        sqDistToAll = sum((repmat(myMean,1,numPts) - dataPts).^2, 1);
        inInds      = find(sqDistToAll < bandSq);           % points inside the window
        thisClusterVotes(inInds) = thisClusterVotes(inInds) + 1;

        myOldMean = myMean;
        myMean    = mean(dataPts(:,inInds),2);              % shift the mean
        myMembers = [myMembers inInds]; %#ok<AGROW>
        beenVisitedFlag(inInds) = true;                     % these can no longer seed

        %*** plot stuff ****
        if plotFlag
            figure(1),clf,hold on
            if numDim == 2
                plot(dataPts(1,:),dataPts(2,:),'.')
                plot(dataPts(1,myMembers),dataPts(2,myMembers),'ys')
                plot(myMean(1),myMean(2),'go')
                plot(myOldMean(1),myOldMean(2),'rd')
                pause
            end
        end

        %**** if mean doesn't move much stop this cluster ***
        if norm(myMean-myOldMean) < stopThresh

            % Look for an existing center closer than bandWidth/2.
            mergeWith = 0;
            for cN = 1:numClust
                distToOther = norm(myMean-clustCent(:,cN));
                if distToOther < bandWidth/2
                    mergeWith = cN;
                    break;
                end
            end

            if mergeWith > 0
                % Merge: the stored center becomes the midpoint of the two
                % (biased towards the newer one) and the votes are pooled.
                clustCent(:,mergeWith)    = 0.5*(myMean+clustCent(:,mergeWith));
                clusterVotes(mergeWith,:) = clusterVotes(mergeWith,:) + thisClusterVotes;
            else
                % Record a new cluster.
                numClust = numClust+1;
                clustCent(:,numClust)    = myMean;
                clusterVotes(numClust,:) = thisClusterVotes;
            end

            break;
        end

    end

    initPtInds = find(~beenVisitedFlag);   % seed only from unvisited points
    numInitPts = length(initPtInds);
end

% A point belongs to the cluster that visited it most often.
[~,data2cluster] = max(clusterVotes,[],1);

%*** If they want the cluster2data cell find it for them
if nargout > 2
    cluster2dataCell = cell(numClust,1);
    for cN = 1:numClust
        cluster2dataCell{cN} = find(data2cluster == cN);
    end
end
% plot(myClustCen(1),myClustCen(2),'o','MarkerEdgeColor','k','MarkerFaceColor',cVec(k), 'MarkerSize',10) end forj=1:size(data,1) text(data(j,2)+0.2,data(j,1)-0.2,num2str(j)); end
function [clustCent,data2cluster,cluster2dataCell] = MeanShiftCluster(dataPts,bandWidth,plotFlag) %#ok<INUSD>
%MEANSHIFTCLUSTER  Flat-kernel mean shift with deterministic seeding and
% label bookkeeping in outCom.
%
% Seeds are taken in index order (lowest unvisited point first). Cluster
% labels are handed out while the window moves and are returned directly
% as data2cluster.
%
% ---INPUT---
% dataPts   - input data, (numDim x numPts)
% bandWidth - bandwidth parameter (positive scalar); radius of the flat kernel
% plotFlag  - accepted for interface compatibility; this variant does not plot
% ---OUTPUT---
% clustCent        - cluster centers, stored in the column given by the
%                    label of the seed that produced them
% data2cluster     - label of every data point (1 x numPts); 0 means the
%                    point was visited but never labelled
% cluster2dataCell - for labels 1..numClust the indices of the points
%                    carrying that label; only built when requested
%
% Based on MeanShiftCluster by Bryan Feldman 02/24/06
% MeanShift first appears in
% K. Fukunaga and L.D. Hostetler, "The Estimation of the Gradient of a
% Density Function, with Applications in Pattern Recognition"

%*** Check input ****
if nargin < 2
    error('no bandwidth specified')
end
% A zero, negative or NaN bandwidth selects no points, the mean becomes NaN
% and the convergence test can never succeed, so reject it up front.
if ~isscalar(bandWidth) || ~(bandWidth > 0)
    error('bandWidth must be a positive scalar')
end

%**** Initialize stuff ***
numPts          = size(dataPts,2);
numClust        = 0;
bandSq          = bandWidth^2;
initPtInds      = 1:numPts;          % points still usable as seeds (ascending)
stopThresh      = 1e-3*bandWidth;    % mean shift below this => converged
clustCent       = [];
beenVisitedFlag = false(1,numPts);   % true once a point fell inside any window
numInitPts      = numPts;
nowNum          = 1;                 % next unused label
outCom          = zeros(1,numPts);   % label per point, 0 = not labelled yet

while numInitPts

    stInd  = min(initPtInds);        % deterministic seed: lowest unvisited index
    myMean = dataPts(:,stInd);

    while true   % loop until convergence

        % Explicit dimension: without it, 1-D data (a single row) would be
        % summed across points into one scalar.
        sqDistToAll = sum((repmat(myMean,1,numPts) - dataPts).^2, 1);
        inInds      = find(sqDistToAll < bandSq);   % points inside the window

        if length(inInds) > 1 && min(outCom(inInds)) == 0   % author's cases 2, 6
            % At least one point in the window is still unlabelled.
            beyondMe = inInds;
            beyondMe(beyondMe == stInd) = [];               % window without the seed
            neighborLabels = outCom(beyondMe);
            neighborLabels = neighborLabels(neighborLabels > 0);
            if ~isempty(neighborLabels)
                % Adopt a label that already exists in the window. Taking the
                % minimum keeps this a scalar assignment when several
                % neighbours are labelled (the original assigned the whole
                % vector and errored in that case).
                outCom(stInd) = min(neighborLabels);
                % NOTE(review): unlabelled neighbours are left at 0 here even
                % though they get marked as visited below, so they can end up
                % with label 0 in the result - confirm whether intended.
            else                                            % author's case 4
                outCom(stInd)    = nowNum;
                outCom(beyondMe) = nowNum;
                nowNum = nowNum + 1;
            end
        elseif length(inInds) == 1                          % author's cases 1, 3, 5
            % Only one point in the window: give the seed a fresh label.
            outCom(stInd) = nowNum;
            nowNum = nowNum + 1;
        end

        myOldMean = myMean;
        myMean    = mean(dataPts(:,inInds),2);              % shift the mean
        beenVisitedFlag(inInds) = true;                     % these can no longer seed

        %**** if mean doesn't move much stop this cluster ***
        if norm(myMean-myOldMean) < stopThresh

            % Look for an existing center closer than bandWidth/2.
            mergeWith = 0;
            for cN = 1:numClust
                distToOther = norm(myMean-clustCent(:,cN));
                if distToOther < bandWidth/2
                    mergeWith = cN;
                    break;
                end
            end

            if mergeWith > 0
                % Merge: the stored center becomes the midpoint of the two.
                clustCent(:,mergeWith) = 0.5*(myMean+clustCent(:,mergeWith));
            else
                % NOTE(review): numClust takes the seed's label rather than
                % being incremented, so it is the last label recorded here,
                % not necessarily the number of distinct labels - confirm.
                numClust = outCom(stInd);
                clustCent(:,numClust) = myMean;
            end

            break;
        end

    end

    initPtInds = find(~beenVisitedFlag);   % seed only from unvisited points
    numInitPts = length(initPtInds);
end

data2cluster = outCom;

%*** If they want the cluster2data cell find it for them
if nargout > 2
    cluster2dataCell = cell(numClust,1);
    for cN = 1:numClust
        cluster2dataCell{cN} = find(data2cluster == cN);
    end
end
function [clustCent,data2cluster,cluster2dataCell] = MeanShiftCluster(dataPts,bandWidth,plotFlag) %#ok<INUSD>
%MEANSHIFTCLUSTER  Flat-kernel mean shift with random seeding and label
% propagation in outCom.
%
% Seeds are drawn at random from the unvisited points. Whenever the window
% contains an unlabelled point, the smallest label present in the window is
% spread to the seed and to every other point in the window; if no label is
% present a fresh one is issued. The labels are returned as data2cluster.
%
% ---INPUT---
% dataPts   - input data, (numDim x numPts)
% bandWidth - bandwidth parameter (positive scalar); radius of the flat kernel
% plotFlag  - accepted for interface compatibility; this variant does not plot
% ---OUTPUT---
% clustCent        - cluster centers, stored in the column given by the
%                    label of the seed that produced them
% data2cluster     - label of every data point (1 x numPts)
% cluster2dataCell - for labels 1..numClust the indices of the points
%                    carrying that label; only built when requested
%
% Based on MeanShiftCluster by Bryan Feldman 02/24/06
% MeanShift first appears in
% K. Fukunaga and L.D. Hostetler, "The Estimation of the Gradient of a
% Density Function, with Applications in Pattern Recognition"

%*** Check input ****
if nargin < 2
    error('no bandwidth specified')
end
% A zero, negative or NaN bandwidth selects no points, the mean becomes NaN
% and the convergence test can never succeed, so reject it up front.
if ~isscalar(bandWidth) || ~(bandWidth > 0)
    error('bandWidth must be a positive scalar')
end

%**** Initialize stuff ***
numPts          = size(dataPts,2);
numClust        = 0;
bandSq          = bandWidth^2;
initPtInds      = 1:numPts;          % points still usable as seeds
stopThresh      = 1e-3*bandWidth;    % mean shift below this => converged
clustCent       = [];
beenVisitedFlag = false(1,numPts);   % true once a point fell inside any window
numInitPts      = numPts;
nowNum          = 1;                 % next unused label
outCom          = zeros(1,numPts);   % label per point, 0 = not labelled yet

while numInitPts

    tempInd = ceil((numInitPts-1e-6)*rand);   % pick a random seed point
    stInd   = initPtInds(tempInd);
    myMean  = dataPts(:,stInd);

    while true   % loop until convergence

        % Explicit dimension: without it, 1-D data (a single row) would be
        % summed across points into one scalar.
        sqDistToAll = sum((repmat(myMean,1,numPts) - dataPts).^2, 1);
        inInds      = find(sqDistToAll < bandSq);   % points inside the window

        if length(inInds) > 1 && min(outCom(inInds)) == 0   % author's cases 2, 6
            % At least one point in the window is still unlabelled.
            beyondMe = inInds;
            beyondMe(beyondMe == stInd) = [];               % window without the seed
            neighborLabels = outCom(beyondMe);
            neighborLabels = neighborLabels(neighborLabels > 0);
            if ~isempty(neighborLabels)
                % Spread the smallest label found in the window to the seed
                % and to every other point in the window. This also
                % overwrites neighbours that carried a different label.
                sharedLabel      = min(neighborLabels);
                outCom(stInd)    = sharedLabel;
                outCom(beyondMe) = sharedLabel;
            else                                            % author's case 4
                outCom(stInd)    = nowNum;
                outCom(beyondMe) = nowNum;
                nowNum = nowNum + 1;
            end
        elseif length(inInds) == 1                          % author's cases 1, 3, 5
            % Only one point in the window: give the seed a fresh label.
            outCom(stInd) = nowNum;
            nowNum = nowNum + 1;
        end

        myOldMean = myMean;
        myMean    = mean(dataPts(:,inInds),2);              % shift the mean
        beenVisitedFlag(inInds) = true;                     % these can no longer seed

        %**** if mean doesn't move much stop this cluster ***
        if norm(myMean-myOldMean) < stopThresh

            % Look for an existing center closer than bandWidth/2.
            mergeWith = 0;
            for cN = 1:numClust
                distToOther = norm(myMean-clustCent(:,cN));
                if distToOther < bandWidth/2
                    mergeWith = cN;
                    break;
                end
            end

            if mergeWith > 0
                % Merge: the stored center becomes the midpoint of the two.
                clustCent(:,mergeWith) = 0.5*(myMean+clustCent(:,mergeWith));
            else
                % NOTE(review): numClust takes the seed's label rather than
                % being incremented, so it is the last label recorded here,
                % not necessarily the number of distinct labels - confirm.
                numClust = outCom(stInd);
                clustCent(:,numClust) = myMean;
            end

            break;
        end

    end

    initPtInds = find(~beenVisitedFlag);   % seed only from unvisited points
    numInitPts = length(initPtInds);
end

data2cluster = outCom;

%*** If they want the cluster2data cell find it for them
if nargout > 2
    cluster2dataCell = cell(numClust,1);
    for cN = 1:numClust
        cluster2dataCell{cN} = find(data2cluster == cN);
    end
end