Сравните (утвердите равенство) две сложные структуры данных, содержащие массивы numpy в unittest
Я использую модуль Python unittest
и хочу проверить, равны ли две сложные структуры данных. Объектами могут быть списки dicts со всеми значениями: номерами, строками, контейнерами Python (списки/кортежи/dicts) и numpy
массивами. Последние являются причиной задавать вопрос, потому что я не могу просто сделать
self.assertEqual(big_struct1, big_struct2)
поскольку он создает
ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()
Я предполагаю, что для этого мне нужно написать свой собственный тест на равенство. Он должен работать для произвольных структур. Моя текущая идея - это рекурсивная функция, которая:
- пытается прямое сравнение текущего "node"
arg1
с соответствующим node of arg2
;
- Если исключение не создано, он перемещается (здесь также обрабатываются "конечные" узлы/листья);
- если
ValueError
пойман, идет глубже, пока не найдет numpy.array
;
- сравнивает массивы (например, как это).
Кажется, что проблематично отслеживать "соответствующие" узлы двух структур, но, возможно, zip
- это все, что мне нужно.
Вопрос: есть ли хорошие (более простые) альтернативы этому подходу? Может быть, numpy
содержит некоторые инструменты для этого? Если альтернативы не предложены, я реализую эту идею (если у меня не будет лучшей) и опубликую ответ.
P.S. У меня есть смутное чувство, что я мог бы рассмотреть вопрос, касающийся этой проблемы, но я не могу найти его сейчас.
P.P.S. Альтернативный подход - это функция, которая пересекает структуру и преобразует все numpy.array
в списки, но проще ли это реализовать? Кажется таким же для меня.
Изменить: Подклассификация numpy.ndarray
звучит очень многообещающе, но, очевидно, у меня нет двух сторон сравнения, жестко закодированных в тесте. Один из них, правда, действительно жестко закодирован, поэтому я могу:
- заполнить его пользовательскими подклассами
numpy.array
;
- измените
isinstance(other, SaneEqualityArray)
на isinstance(other, np.ndarray)
в jterrace ответ;
- всегда используйте его как LHS в сравнении.
Мои вопросы в этом отношении:
- Будет ли это работать (я имею в виду, это звучит для меня правильно, но, может быть, некоторые сложные случаи кросс не будут обрабатываться правильно)? Будет ли мой пользовательский объект всегда заканчиваться как LHS в рекурсивных проверках равенства, как я ожидаю?
- Опять же, есть ли лучшие способы (учитывая, что я получаю хотя бы одну из структур с реальными массивами
numpy
).
Изменить 2. Я пробовал это, рабочая версия (по-видимому) показана в этом ответе.
Ответы
Ответ 1
Итак, идея, проиллюстрированная jterrace, кажется, работает для меня с небольшой модификацией:
class SaneEqualityArray(np.ndarray):
def __eq__(self, other):
return (isinstance(other, np.ndarray) and self.shape == other.shape and
np.allclose(self, other))
Как я уже сказал, контейнер с этими объектами должен находиться в левой части проверки равенства. Я создаю объекты SaneEqualityArray
из существующего numpy.ndarray
следующим образом:
SaneEqualityArray(my_array.shape, my_array.dtype, my_array)
в соответствии с ndarray
сигнатурой конструктора:
ndarray(shape, dtype=float, buffer=None, offset=0,
strides=None, order=None)
Этот класс определен в наборе тестов и служит только для тестирования. RHS проверки равенства является фактическим объектом, возвращаемым проверенной функцией и содержит реальные объекты numpy.ndarray
.
P.S. Благодаря авторам обоих ответов, опубликованных до сих пор, они оба были очень полезными. Если кто-либо увидит какие-либо проблемы с этим подходом, я буду благодарен за ваши отзывы.
Ответ 2
Прокомментировал бы, но он слишком длинный...
Забавно, вы не можете использовать ==
для проверки того, являются ли массивы одинаковыми, я бы предложил вместо np.testing.assert_array_equal
.
- который проверяет тип, форму и т.д.,
- что не подходит для аккуратной математической математики
(float('nan') == float('nan')) == False
(нормальная последовательность python ==
имеет еще более интересный способ игнорировать это иногда, потому что она использует PyObject_RichCompareBool
, которая делает (для NaNs неверно) is
быстрая проверка (для тестирования, конечно, это идеально)...
- Существует также
assert_allclose
, поскольку равенство с плавающей запятой может стать очень сложным, если вы выполняете фактические вычисления, и обычно вы хотите получить почти одинаковые значения, поскольку значения могут стать зависимыми от оборудования или, возможно, случайными, в зависимости от того, что вы с ними делаете.
Я бы почти предложил попробовать сериализовать его с рассолом, если вы хотите что-то такое безумно вложенное, но это чересчур строго (а точка 3, конечно, полностью сломана), например, макет памяти вашего массива не имеет значения, но имеет значение для его сериализации.
Ответ 3
Функция assertEqual
будет вызывать метод объектов __eq__
, который должен обрабатывать сложные типы данных. Исключением является numpy, который не имеет разумного метода __eq__
. Используя подкласс numpy из этого вопроса, вы можете восстановить здравомыслие в отношении поведения равенства:
import copy
import numpy
import unittest
class SaneEqualityArray(numpy.ndarray):
def __eq__(self, other):
return (isinstance(other, SaneEqualityArray) and
self.shape == other.shape and
numpy.ndarray.__eq__(self, other).all())
class TestAsserts(unittest.TestCase):
def testAssert(self):
tests = [
[1, 2],
{'foo': 2},
[2, 'foo', {'d': 4}],
SaneEqualityArray([1, 2]),
{'foo': {'hey': SaneEqualityArray([2, 3])}},
[{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}},
SaneEqualityArray([5, 6]), 34]
]
for t in tests:
self.assertEqual(t, copy.deepcopy(t))
if __name__ == '__main__':
unittest.main()
Этот тест проходит.
Ответ 4
Я бы определил свой собственный метод assertNumpyArraysEqual(), который явно делает сравнение, которое вы хотите использовать. Таким образом, ваш производственный код не изменится, но вы можете сделать разумные утверждения в своих модульных тестах. Обязательно определите его в модуле, который включает __unittest = True
, чтобы он не включался в трассировку стека:
import numpy
__unittest = True
def assertNumpyArraysEqual(self, other):
if self.shape != other.shape:
raise AssertionError("Shapes don't match")
if not numpy.allclose(self, other)
raise AssertionError("Elements don't match!")
Ответ 5
Я столкнулся с той же проблемой и разработал функцию сравнения равенства на основе создания фиксированного хэша для объекта. Это дает дополнительное преимущество в том, что вы можете проверить, что объект, как и ожидалось, сравнивает его хэш с фиксированным, который хранится в коде.
Код (автономный файл python, здесь). Существуют две функции: fixed_hash_eq
, которая решает вашу проблему, и compute_fixed_hash
, что делает хэш из структуры. Тесты здесь
Здесь тест:
obj1 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj2 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3[2]['b'][4] = 0
assert fixed_hash_eq(obj1, obj2)
assert not fixed_hash_eq(obj1, obj3)
Ответ 6
Основываясь на @dbw (с благодарностью), следующий метод, вставленный в подклассу тестового случая, хорошо работал у меня:
def assertNumpyArraysEqual(self,this,that,msg=''):
'''
modified from http://stackoverflow.com/a/15399475/5459638
'''
if this.shape != that.shape:
raise AssertionError("Shapes don't match")
if not np.allclose(this,that):
raise AssertionError("Elements don't match!")
Я использовал его как self.assertNumpyArraysEqual(this,that)
внутри моих тестовых примеров и работал как шарм.
Ответ 7
check numpy.testing.assert_almost_equal
, который вызывает AssertionError, если два элемента не равны до нужной точности, например:
import numpy.testing as npt
npt.assert_almost_equal(np.array([1.0,2.3333333333333]),
np.array([1.0,2.33333334]), decimal=9)