Аналитическое решение для линейной регрессии с использованием Python vs. Julia
Использование примера из класса Andrew Ng (поиск параметров для линейной регрессии с использованием нормального уравнения):
С Python:
X = np.array([[1, 2104, 5, 1, 45], [1, 1416, 3, 2, 40], [1, 1534, 3, 2, 30], [1, 852, 2, 1, 36]])
y = np.array([[460], [232], [315], [178]])
θ = ((np.linalg.inv(X.T.dot(X))).dot(X.T)).dot(y)
print(θ)
Результат:
[[ 7.49398438e+02]
[ 1.65405273e-01]
[ -4.68750000e+00]
[ -4.79453125e+01]
[ -5.34570312e+00]]
С Джулией:
X = [1 2104 5 1 45; 1 1416 3 2 40; 1 1534 3 2 30; 1 852 2 1 36]
y = [460; 232; 315; 178]
θ = ((X' * X)^-1) * X' * y
Результат:
5-element Array{Float64,1}:
207.867
0.0693359
134.906
-77.0156
-7.81836
Кроме того, когда я несколько X от Julia, но не Python - θ, я получаю числа, близкие к y.
Я не могу понять, что я делаю неправильно. Спасибо!
Ответы
Ответ 1
Более гибкий подход на Python, без необходимости самостоятельно выполнять матричную алгебру, заключается в использовании numpy.linalg.lstsq
для выполнения регрессии:
In [29]: np.linalg.lstsq(X, y)
Out[29]:
(array([[ 188.40031942],
[ 0.3866255 ],
[ -56.13824955],
[ -92.9672536 ],
[ -3.73781915]]),
array([], dtype=float64),
4,
array([ 3.08487554e+03, 1.88409728e+01, 1.37100414e+00,
1.97618336e-01]))
(Сравните вектор решения с ответом @waTeim в Julia).
Вы можете увидеть источник плохого кондиционирования, распечатав обратный матричный код:
In [30]: np.linalg.inv(X.T.dot(X))
Out[30]:
array([[ -4.12181049e+13, 1.93633440e+11, -8.76643127e+13,
-3.06844458e+13, 2.28487459e+12],
[ 1.93633440e+11, -9.09646601e+08, 4.11827338e+11,
1.44148665e+11, -1.07338299e+10],
[ -8.76643127e+13, 4.11827338e+11, -1.86447963e+14,
-6.52609055e+13, 4.85956259e+12],
[ -3.06844458e+13, 1.44148665e+11, -6.52609055e+13,
-2.28427584e+13, 1.70095424e+12],
[ 2.28487459e+12, -1.07338299e+10, 4.85956259e+12,
1.70095424e+12, -1.26659193e+11]])
Eeep!
Взятие точечного произведения этого с X.T
приводит к катастрофической потере точности.
Ответ 2
Используя X ^ -1 vs псевдообратный
pinv (X), который соответствует псевдо-обратному, более широко применим, чем inv (X), которому X ^ -1 соответствует. Ни Julia, ни Python не умеют использовать inv, но в этом случае, по-видимому, Джулия делает лучше.
но если вы измените выражение на
julia> z=pinv(X'*X)*X'*y
5-element Array{Float64,1}:
188.4
0.386625
-56.1382
-92.9673
-3.73782
вы можете проверить, что X * z = y
julia> X*z
4-element Array{Float64,1}:
460.0
232.0
315.0
178.0
Ответ 3
Обратите внимание, что X
является матрицей 4x5 или в статистических терминах, что у вас меньше наблюдений, чем параметры для оценки. Поэтому задача наименьших квадратов имеет бесконечно много решений с суммой квадратов ошибок, точно равных нулю. В этом случае нормальные уравнения вам не помогут, потому что матрица X'X
является особой. Вместо этого вы должны просто найти решение для X*b=y
.
Большинство числовых систем линейной алгебры основаны на пакете FORTRAN LAPACK, который использует поворотную факторизацию QR для решения задачи X*b=y
. Поскольку существует бесконечно много решений, LAPACK выбирает решение с наименьшей нормой. В Julia вы можете получить это решение, просто написав
float(X)\y
(К сожалению, часть float
необходима прямо сейчас, но это изменится.)
В точной арифметике вы должны получить то же самое решение, что и выше, с одним из предложенных вами методов, но представление о проблеме с плавающей точкой представляет собой небольшие ошибки округления, и эти ошибки влияют на вычисленное решение. Эффект ошибок округления на решении намного больше при использовании нормальных уравнений по сравнению с использованием факторизации QR непосредственно на X
.
Это справедливо и в обычном случае, когда X
имеет больше строк, чем столбцы, поэтому часто рекомендуется избегать нормальных уравнений при решении проблем с наименьшими квадратами. Однако, когда X
имеет гораздо больше строк, чем столбцы, матрица X'X
относительно невелика. В этом случае гораздо быстрее решить проблему с нормальными уравнениями вместо использования QR-факторизации. Во многих статистических задачах дополнительная численная ошибка чрезвычайно мала по сравнению со статистической погрешностью, поэтому потеря точности из-за нормальных уравнений может быть просто проигнорирована.