Ответ 1
Если вы пытаетесь предсказать одно значение из двух других, то вы должны использовать lstsq
с аргументом a
как ваши независимые переменные (плюс столбец 1 для оценки перехвата) и b
как ваш зависимая переменная.
Если, с другой стороны, вы просто хотите получить наилучшую подходящую строку для данных, то есть линию, которая, если вы проецируете данные на нее, минимизирует квадрат расстояния между реальной точкой и ее проекцией, тогда то, что вы хотите, является первым основным компонентом.
Один из способов определить это линия, вектор направления которой является собственным вектором ковариационной матрицы, соответствующей самому большому собственному значению, которое проходит через среднее значение ваших данных. Тем не менее, eig(cov(data))
- действительно плохой способ вычислить его, поскольку он делает много ненужных вычислений и копирования и потенциально менее точен, чем при использовании svd
. См. Ниже:
import numpy as np
# Generate some data that lies along a line
x = np.mgrid[-2:5:120j]
y = np.mgrid[1:9:120j]
z = np.mgrid[-5:3:120j]
data = np.concatenate((x[:, np.newaxis],
y[:, np.newaxis],
z[:, np.newaxis]),
axis=1)
# Perturb with some Gaussian noise
data += np.random.normal(size=data.shape) * 0.4
# Calculate the mean of the points, i.e. the 'center' of the cloud
datamean = data.mean(axis=0)
# Do an SVD on the mean-centered data.
uu, dd, vv = np.linalg.svd(data - datamean)
# Now vv[0] contains the first principal component, i.e. the direction
# vector of the 'best fit' line in the least squares sense.
# Now generate some points along this best fit line, for plotting.
# I use -7, 7 since the spread of the data is roughly 14
# and we want it to have mean 0 (like the points we did
# the svd on). Also, it a straight line, so we only need 2 points.
linepts = vv[0] * np.mgrid[-7:7:2j][:, np.newaxis]
# shift by the mean to get the line in the right place
linepts += datamean
# Verify that everything looks right.
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d as m3d
ax = m3d.Axes3D(plt.figure())
ax.scatter3D(*data.T)
ax.plot3D(*linepts.T)
plt.show()
Вот что это выглядит: