Хранение малоразмерной матрицы в HDF5 (PyTables)
У меня возникли проблемы с сохранением numpy csr_matrix с PyTables. Я получаю эту ошибку:
TypeError: objects of type ``csr_matrix`` are not supported in this context, sorry; supported objects are: NumPy array, record or scalar; homogeneous list or tuple, integer, float, complex or string
Мой код:
f = tables.openFile(path,'w')
atom = tables.Atom.from_dtype(self.count_vector.dtype)
ds = f.createCArray(f.root, 'count', atom, self.count_vector.shape)
ds[:] = self.count_vector
f.close()
Любые идеи?
Спасибо
Ответы
Ответ 1
CSR-матрица может быть полностью восстановлена из ее атрибутов data
, indices
и indptr
. Это просто регулярные массивы numpy, поэтому не должно быть проблем с их хранением в виде трех отдельных массивов в pytables, а затем их возврата к конструктору csr_matrix
. См. scipy docs.
Изменить: Ответ Pietro указал, что элемент shape
также должен быть сохранен
Ответ 2
Ответ DaveP почти прав... но может вызвать проблемы для очень разреженных матриц: если последний столбец или строка (строки) пустые, они отбрасываются. Поэтому, чтобы быть уверенным, что все работает, атрибут "shape" также должен быть сохранен.
Это код, который я регулярно использую:
import tables as tb
from numpy import array
from scipy import sparse
def store_sparse_mat(m, name, store='store.h5'):
msg = "This code only works for csr matrices"
assert(m.__class__ == sparse.csr.csr_matrix), msg
with tb.openFile(store,'a') as f:
for par in ('data', 'indices', 'indptr', 'shape'):
full_name = '%s_%s' % (name, par)
try:
n = getattr(f.root, full_name)
n._f_remove()
except AttributeError:
pass
arr = array(getattr(m, par))
atom = tb.Atom.from_dtype(arr.dtype)
ds = f.createCArray(f.root, full_name, atom, arr.shape)
ds[:] = arr
def load_sparse_mat(name, store='store.h5'):
with tb.openFile(store) as f:
pars = []
for par in ('data', 'indices', 'indptr', 'shape'):
pars.append(getattr(f.root, '%s_%s' % (name, par)).read())
m = sparse.csr_matrix(tuple(pars[:3]), shape=pars[3])
return m
Тривиально адаптировать его к csc-матрицам.
Ответ 3
Я обновил Pietro Battiston отличный ответ для Python 3.6 и PyTables 3.x, поскольку некоторые имена функций PyTables изменились при обновлении с 2.x.
import numpy as np
from scipy import sparse
import tables
def store_sparse_mat(M, name, filename='store.h5'):
"""
Store a csr matrix in HDF5
Parameters
----------
M : scipy.sparse.csr.csr_matrix
sparse matrix to be stored
name: str
node prefix in HDF5 hierarchy
filename: str
HDF5 filename
"""
assert(M.__class__ == sparse.csr.csr_matrix), 'M must be a csr matrix'
with tables.open_file(filename, 'a') as f:
for attribute in ('data', 'indices', 'indptr', 'shape'):
full_name = f'{name}_{attribute}'
# remove existing nodes
try:
n = getattr(f.root, full_name)
n._f_remove()
except AttributeError:
pass
# add nodes
arr = np.array(getattr(M, attribute))
atom = tables.Atom.from_dtype(arr.dtype)
ds = f.create_carray(f.root, full_name, atom, arr.shape)
ds[:] = arr
def load_sparse_mat(name, filename='store.h5'):
"""
Load a csr matrix from HDF5
Parameters
----------
name: str
node prefix in HDF5 hierarchy
filename: str
HDF5 filename
Returns
----------
M : scipy.sparse.csr.csr_matrix
loaded sparse matrix
"""
with tables.open_file(filename) as f:
# get nodes
attributes = []
for attribute in ('data', 'indices', 'indptr', 'shape'):
attributes.append(getattr(f.root, f'{name}_{attribute}').read())
# construct sparse matrix
M = sparse.csr_matrix(tuple(attributes[:3]), shape=attributes[3])
return M