Ответ 1
Достаточно просто написать собственную реализацию алгоритма EM. Это также даст вам хорошую интуицию в этом процессе. Я предполагаю, что ковариация известна и что предыдущие вероятности компонентов равны и соответствуют только средствам.
Класс будет выглядеть так (в Python 3):
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
class FixedCovMixture:
""" The model to estimate gaussian mixture with fixed covariance matrix. """
def __init__(self, n_components, cov, max_iter=100, random_state=None, tol=1e-10):
self.n_components = n_components
self.cov = cov
self.random_state = random_state
self.max_iter = max_iter
self.tol=tol
def fit(self, X):
# initialize the process:
np.random.seed(self.random_state)
n_obs, n_features = X.shape
self.mean_ = X[np.random.choice(n_obs, size=self.n_components)]
# make EM loop until convergence
i = 0
for i in range(self.max_iter):
new_centers = self.updated_centers(X)
if np.sum(np.abs(new_centers-self.mean_)) < self.tol:
break
else:
self.mean_ = new_centers
self.n_iter_ = i
def updated_centers(self, X):
""" A single iteration """
# E-step: estimate probability of each cluster given cluster centers
cluster_posterior = self.predict_proba(X)
# M-step: update cluster centers as weighted average of observations
weights = (cluster_posterior.T / cluster_posterior.sum(axis=1)).T
new_centers = np.dot(weights, X)
return new_centers
def predict_proba(self, X):
likelihood = np.stack([multivariate_normal.pdf(X, mean=center, cov=self.cov)
for center in self.mean_])
cluster_posterior = (likelihood / likelihood.sum(axis=0))
return cluster_posterior
def predict(self, X):
return np.argmax(self.predict_proba(X), axis=0)
В данных, подобных вашей, модель будет сходиться быстро:
np.random.seed(1)
X = np.random.normal(size=(100,2), scale=3)
X[50:] += (10, 5)
model = FixedCovMixture(2, cov=[[3,0],[0,3]], random_state=1)
model.fit(X)
print(model.n_iter_, 'iterations')
print(model.mean_)
plt.scatter(X[:,0], X[:,1], s=10, c=model.predict(X))
plt.scatter(model.mean_[:,0], model.mean_[:,1], s=100, c='k')
plt.axis('equal')
plt.show();
и вывод
11 iterations
[[9.92301067 4.62282807]
[0.09413883 0.03527411]]
Вы можете видеть, что расчетные центры ((9.9, 4.6)
и (0.09, 0.03)
) близки к истинным центрам ((10, 5)
и (0, 0)
).