Бинарная классификация PyTorch - та же структура сети, "более простые" данные, но худшая производительность?
Чтобы разобраться с PyTorch (и углубленным изучением в целом), я начал с работы с некоторыми базовыми классификационными примерами. Одним из таких примеров была классификация нелинейного набора данных, созданного с использованием sklearn (полный код доступен в виде блокнота здесь)
n_pts = 500
X, y = datasets.make_circles(n_samples=n_pts, random_state=123, noise=0.1, factor=0.2)
x_data = torch.FloatTensor(X)
y_data = torch.FloatTensor(y.reshape(500, 1))
Затем это точно классифицируется с использованием довольно простой нейронной сети
class Model(nn.Module):
def __init__(self, input_size, H1, output_size):
super().__init__()
self.linear = nn.Linear(input_size, H1)
self.linear2 = nn.Linear(H1, output_size)
def forward(self, x):
x = torch.sigmoid(self.linear(x))
x = torch.sigmoid(self.linear2(x))
return x
def predict(self, x):
pred = self.forward(x)
if pred >= 0.5:
return 1
else:
return 0
Поскольку меня интересуют данные о состоянии здоровья, я решил попробовать использовать ту же структуру сети, чтобы классифицировать некоторые базовые наборы данных реального мира. Я взял данные о частоте сердечных сокращений для одного пациента из здесь и изменил их, чтобы все значения> 91 были помечены как аномалии (например, a 1
и все & lt; = 91 помечено как 0
). Это совершенно произвольно, но я просто хотел посмотреть, как будет работать классификация. Полный блокнот для этого примера здесь.
Для меня не интуитивно понятно, почему первый пример достигает потери 0,0016 после 1000 эпох, в то время как второй пример достигает потери 0,4296 после 10000 эпох
Возможно, я наивен, думая, что пример сердечного ритма будет гораздо легче классифицировать. Любое понимание, которое поможет мне понять, почему это не то, что я вижу, было бы замечательно!
Ответы
Ответ 1
TL; DR
Ваши входные данные не нормализованы.
- используйте
x_data = (x_data - x_data.mean()) / x_data.std()
- увеличить скорость обучения
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
Вы получите
сходимость всего за 1000 итераций.
Подробнее
Основное различие между этими двумя примерами состоит в том, что данные x
в первом примере сосредоточены вокруг (0, 0) и имеют очень низкую дисперсию.
С другой стороны, данные во втором примере сосредоточены вокруг 92 и имеют относительно большую дисперсию.
Это начальное смещение в данных не учитывается, когда вы случайно инициализируете весовые коэффициенты, что делается на основе предположения о том, что входные данные примерно нормально распределены вокруг нуля.
Для процесса оптимизации практически невозможно компенсировать это грубое отклонение - таким образом, модель застревает в неоптимальном решении.
После нормализации входных данных путем вычитания среднего значения и деления на стандартное значение процесс оптимизации снова становится стабильным и быстро сходится к хорошему решению.
Более подробную информацию о нормализации ввода и инициализации весов вы можете прочитать в разделе 2.2 в статье He et al. Углубление в выпрямители: превосходство уровня человеческого уровня в классификации ImageNet (ICCV 2015).
Что если я не могу нормализовать данные?
Если по какой-то причине вы не можете заранее рассчитать средние и стандартные данные, вы все равно можете использовать nn.BatchNorm1d
для оценки и нормализации данных как части процесса обучения. Например,
class Model(nn.Module):
def __init__(self, input_size, H1, output_size):
super().__init__()
self.bn = nn.BatchNorm1d(input_size) # adding batchnorm
self.linear = nn.Linear(input_size, H1)
self.linear2 = nn.Linear(H1, output_size)
def forward(self, x):
x = torch.sigmoid(self.linear(self.bn(x))) # batchnorm the input x
x = torch.sigmoid(self.linear2(x))
return x
Эта модификация без каких-либо изменений входных данных дает сходную конвергенцию только после 1000 эпох:
Небольшой комментарий
Для обеспечения стабильности чисел лучше использовать nn.BCEWithLogitsLoss
вместо nn.BCELoss
. Для этого вам необходимо удалить torch.sigmoid
из выхода forward()
, sigmoid
будет вычислен внутри потерь.
См., Например, эту ветку относительно связанной сигмоидальной потери + перекрестной энтропии для двоичных предсказаний.
Ответ 2
Давайте начнем с понимания того, как работают нейронные сети, нейронные сети наблюдают закономерности, отсюда и необходимость в больших наборах данных. В случае примера два, какой шаблон вы намереваетесь найти, это когда if HR < 91: label = 0
, это условие if может быть представлено формулой sigmoid ((HR-91) * 1), если вы подключите различные значения в формуле вы можете видеть, что все значения <91, метка 0 и другие метка 1. Я вывел эту формулу, и она может быть чем угодно, пока она дает правильные значения.
В основном, мы применяем формулу wx + b, где x в наших входных данных, и мы изучаем значения для w и b. Теперь изначально все значения являются случайными, поэтому получение значения b от 1030131190 (случайное значение) до 98 может быть быстрым, поскольку потеря велика, скорость обучения позволяет значениям быстро перемещаться. Но как только вы достигнете 98, ваша потеря уменьшается, и когда вы применяете скорость обучения, требуется больше времени, чтобы приблизиться к 91, отсюда и медленное уменьшение потери. По мере приближения значений предпринимаемые шаги становятся еще медленнее.
Это можно подтвердить с помощью значений потерь, они постоянно уменьшаются, вначале замедление выше, но затем становится меньше. Ваша сеть все еще учится, но медленно.
Следовательно, в глубоком обучении вы используете этот метод, называемый ступенчатой скоростью обучения, при которой с увеличением эпох вы снижаете скорость обучения, чтобы ваше обучение было быстрее.