本文相关代码: parsingNDArray
要想弄清楚MXNet 是如何解析参数文件,并从中提取预训练好的权值,首先第一步要看
MXNet Python端是如何是调用C接口来完成读取NDArray参数文件的。
这部分代码见源码 python/mxnet/ndarray/utils.py 第149行:
def load(fname):
"""Loads an array from file.
See more details in ``save``.
Parameters
----------
fname : str
The filename.
Returns
-------
list of NDArray, RowSparseNDArray or CSRNDArray, or \
dict of str to NDArray, RowSparseNDArray or CSRNDArray
Loaded data.
"""
if not isinstance(fname, string_types):
raise TypeError('fname required to be a string')
out_size = mx_uint()
out_name_size = mx_uint()
handles = ctypes.POINTER(NDArrayHandle)()
names = ctypes.POINTER(ctypes.c_char_p)()
check_call(_LIB.MXNDArrayLoad(c_str(fname),
ctypes.byref(out_size),
ctypes.byref(handles),
ctypes.byref(out_name_size),
ctypes.byref(names)))
if out_name_size.value == 0:
return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)]
else:
assert out_name_size.value == out_size.value
return dict(
(py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i])))
for i in range(out_size.value))
这个 load 函数接收参数路径作为输入,然后根据参数文件中有没有包含参数的名字选择返回
NDArray参数数组或者字典。然后可以看到是调用了 MXNDArrayLoad 这个C接口函数,这个函数的
代码见 src/c_api/c_api.cc 第308行:
int MXNDArrayLoad(const char* fname,
mx_uint *out_size,
NDArrayHandle** out_arr,
mx_uint *out_name_size,
const char*** out_names) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
ret->ret_vec_str.clear();
API_BEGIN();
std::vector<NDArray> data;
std::vector<std::string> &names = ret->ret_vec_str;
{
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
mxnet::NDArray::Load(fi.get(), &data, &names);
}
ret->ret_handles.resize(data.size());
for (size_t i = 0; i < data.size(); ++i) {
NDArray *ptr = new NDArray();
*ptr = data[i];
ret->ret_handles[i] = ptr;
}
ret->ret_vec_charp.resize(names.size());
for (size_t i = 0; i < names.size(); ++i) {
ret->ret_vec_charp[i] = names[i].c_str();
}
*out_size = static_cast<mx_uint>(data.size());
*out_arr = dmlc::BeginPtr(ret->ret_handles);
*out_name_size = static_cast<mx_uint>(names.size());
*out_names = dmlc::BeginPtr(ret->ret_vec_charp);
API_END();
}
然后可以看到最核心的代码就是第319行调用了NDArray类的静态Load函数获得参数的名字和
内容,Load函数具体实现见:src/ndarray/ndarray.cc 第 1812行:
void NDArray::Load(dmlc::Stream* fi,
std::vector<NDArray>* data,
std::vector<std::string>* keys) {
uint64_t header, reserved;
CHECK(fi->Read(&header))
<< "Invalid NDArray file format";
CHECK(fi->Read(&reserved))
<< "Invalid NDArray file format";
CHECK(header == kMXAPINDArrayListMagic)
<< "Invalid NDArray file format";
CHECK(fi->Read(data))
<< "Invalid NDArray file format";
CHECK(fi->Read(keys))
<< "Invalid NDArray file format";
CHECK(keys->size() == 0 || keys->size() == data->size())
<< "Invalid NDArray file format";
}
从这里读取内容的过程可以大概看出NDArray参数文件存储的内容的顺序是什么了,首先是会
存两个uint64_t类型的数字,然后就是NDArray数组,接着是每个NDArray对应的名字的数组。
好了接下来就是解读源码中是如何从Stream中解析出内容的,首先我们来看下Stream类的
Read函数,具体见 io.h 第435行:
template<typename T>
inline bool Stream::Read(T *out_data) {
return serializer::Handler<T>::Read(this, out_data);
}
这里可以看到,Read 函数内部又调用了 Handler这个类的Read静态函数,这个静态函数对应的
代码见 serializer.h 第262行:
inline static bool Read(Stream *strm, T *data) {
return IfThenElse<dmlc::is_pod<T>::value,
PODHandler<T>,
IfThenElse<dmlc::has_saveload<T>::value,
SaveLoadClassHandler<T>,
UndefinedSerializerFor<T>, T>,
T>
::Read(strm, data);
}
};
这里代码我第一次看的时候有点蒙,后来仔细研究了下也看懂了。首先我们要看 IfThenElse
是什么东西,这里还是看到 io.h 的第 38 到 66行:
//! \cond Doxygen_Suppress
/*!
* \brief Serializer that redirect calls by condition
* \tparam cond the condition
* \tparam Then the serializer used for then condition
* \tparam Else the serializer used for else condition
* \tparam Return the type of data the serializer handles
*/
template<bool cond, typename Then, typename Else, typename Return>
struct IfThenElse;
template<typename Then, typename Else, typename T>
struct IfThenElse<true, Then, Else, T> {
inline static void Write(Stream *strm, const T &data) {
Then::Write(strm, data);
}
inline static bool Read(Stream *strm, T *data) {
return Then::Read(strm, data);
}
};
template<typename Then, typename Else, typename T>
struct IfThenElse<false, Then, Else, T> {
inline static void Write(Stream *strm, const T &data) {
Else::Write(strm, data);
}
inline static bool Read(Stream *strm, T *data) {
return Else::Read(strm, data);
}
};
这里可以看到 IfThenElse 就是一个结构体,有四个模板参数,意思很明显了,如果第一个参数
为true,则会调用Then这个类的Read静态函数,如果第一个参数为false,则会调用Else这个类的
Read静态函数。看完 IfThenElse 的定义之后,我们看回 262 行的Read函数就很清楚了,
inline static bool Read(Stream *strm, T *data) {
return IfThenElse<dmlc::is_pod<T>::value,
PODHandler<T>,
IfThenElse<dmlc::has_saveload<T>::value,
SaveLoadClassHandler<T>,
UndefinedSerializerFor<T>, T>,
T>
::Read(strm, data);
}
};
意思就是,如果 dmlc::is_pod
函数,否则就会走到下一个条件判断,下一个条件判断是当 dmlc::has_saveload
为true的话就调用 SaveLoadClassHandler 的 Read 静态函数,否则就走到 UndefinedSerializerFor。
好了,那么现在就是要看具体走了哪个分支,首先我们要知道T在运行时时什么类型,看回上面的
NDArray 的 Load 函数,知道了首先读取得两个数字的类型是 uint64_t,接着跳转到
源码 type_traits.h,看第126和第152行:
/*! \brief macro to quickly declare traits information */
#define DMLC_DECLARE_TRAITS(Trait, Type, Value) \
template<> \
struct Trait<Type> { \
static const bool value = Value; \
}
DMLC_DECLARE_TRAITS(is_pod, uint64_t, true);
很明显可以看到,dmlc::is_pod<uint64_t>::value 的值为 true,因此会调用 PODHandler 的
Read 函数,代码:
/*! \brief Serializer for POD(plain-old-data) data */
template<typename T>
struct PODHandler {
inline static void Write(Stream *strm, const T &data) {
strm->Write(&data, sizeof(T));
}
inline static bool Read(Stream *strm, T *dptr) {
return strm->Read((void*)dptr, sizeof(T)) == sizeof(T); // NOLINT(*)
}
};
PODHandler 的Read函数就是调用 Stream 的Read,这里如果读者想再详细了解 Stream 类
Read 函数的工作原理可以自己再去细看,不过对于本文来说,到这里知道了会根据T的字节数读取
内容到dptr里面就够了。
Ok,现在已经读取完两个数字 header, reserved,然后就是读 NDArray Vector 了,然后这里
还是跳转到,调用 Handler
参数是vector
//! \cond Doxygen_Suppress
template<typename T>
struct Handler<std::vector<T> > {
inline static void Write(Stream *strm, const std::vector<T> &data) {
IfThenElse<dmlc::is_pod<T>::value,
PODVectorHandler<T>,
ComposeVectorHandler<T>, std::vector<T> >
::Write(strm, data);
}
inline static bool Read(Stream *strm, std::vector<T> *data) {
return IfThenElse<dmlc::is_pod<T>::value,
PODVectorHandler<T>,
ComposeVectorHandler<T>,
std::vector<T> >
::Read(strm, data);
}
};
然后这里的判断分支是会调用 ComposeVectorHandler 的 Read 函数:
/*!
* \brief Serializer handler for std::vector<T> where T can be composed type
* \tparam T element type
*/
template<typename T>
struct ComposeVectorHandler {
inline static void Write(Stream *strm, const std::vector<T> &vec) {
uint64_t sz = static_cast<uint64_t>(vec.size());
strm->Write(&sz, sizeof(sz));
for (size_t i = 0; i < vec.size(); ++i) {
Handler<T>::Write(strm, vec[i]);
}
}
inline static bool Read(Stream *strm, std::vector<T> *out_vec) {
uint64_t sz;
if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false;
size_t size = static_cast<size_t>(sz);
out_vec->resize(size);
for (size_t i = 0; i < size; ++i) {
if (!Handler<T>::Read(strm, &(*out_vec)[i])) return false;
}
return true;
}
};
首先先读出 vector 数组的大小,然后分别读取每个 NDArray,这里在读每个 NDArray 的时候
又会调用 Handler
这个分支:
// serializer for class that have save/load function
template<typename T>
struct SaveLoadClassHandler {
inline static void Write(Stream *strm, const T &data) {
data.Save(strm);
}
inline static bool Read(Stream *strm, T *data) {
return data->Load(strm);
}
};
最后看到其实就是调用了 NDArray 类本身的 Load 函数,见源码 src/ndarray/ndarray.cc:
bool NDArray::Load(dmlc::Stream *strm) {
uint32_t magic;
if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
if (magic != NDARRAY_V2_MAGIC) {
return LegacyLoad(strm, magic);
}
// load storage type
int32_t stype;
if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false;
const int32_t nad = num_aux_data(static_cast<NDArrayStorageType>(stype));
// load storage shape
TShape sshape;
if (nad > 0) {
if (!sshape.Load(strm)) return false;
}
// load shape
TShape shape;
if (!shape.Load(strm)) return false;
if (shape.ndim() == 0) {
*this = NDArray(); return true;
}
// load context
Context ctx;
if (!ctx.Load(strm)) return false;
// load type flag
int32_t type_flag;
if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false;
// load aux_types and aux_shapes
std::vector<int32_t> aux_types;
std::vector<TShape> aux_shapes;
if (nad > 0) {
aux_types.resize(nad);
aux_shapes.resize(nad);
for (int i = 0; i < nad; ++i) {
// load aux_type(i)
if (strm->Read(&aux_types[i], sizeof(aux_types[i])) != sizeof(aux_types[i])) return false;
// load aux_shapes(i)
if (!aux_shapes[i].Load(strm)) return false;
}
}
// load data into CPU
NDArray temp;
if (0 == nad) {
temp = NDArray(shape, Context::CPU(), false, type_flag);
} else {
temp = NDArray(static_cast<NDArrayStorageType>(stype), shape,
Context::CPU(), false, type_flag,
aux_types, aux_shapes, sshape);
}
// load data
TBlob load_data = temp.data();
size_t type_size = mshadow::mshadow_sizeof(type_flag);
size_t nread = type_size * load_data.Size();
if (strm->Read(load_data.dptr_, nread) != nread) return false;
// load aux_data
if (nad > 0) {
for (int i = 0; i < nad; ++i) {
load_data = temp.aux_data(i);
type_size = mshadow::mshadow_sizeof(load_data.type_flag_);
nread = type_size * load_data.Size();
if (strm->Read(load_data.dptr_, nread) != nread) return false;
}
}
if (ctx.dev_mask() == cpu::kDevMask) {
*this = std::move(temp); return true;
} else {
#if MXNET_USE_CUDA
*this = temp.Copy(ctx); return true;
#else
*this = std::move(temp); return true;
#endif
}
}
这里首先,读出一个 magic number ,如果用 V1.0 之后的MXNet版本,magic number
都是会等于 NDARRAY_V2_MAGIC,具体定义见下面:
/* magic number for ndarray version 1, with int64_t TShape */
static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8;
/* magic number for ndarray version 2, with storage type */
static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;
所以不会进入 LegacyLoad 函数,接着就是读 storage type,NDArray的类型,除了常用的
普通类型,现在也已经支持了稀疏类型:
enum NDArrayStorageType {
kUndefinedStorage = -1, // undefined storage
kDefaultStorage, // dense
kRowSparseStorage, // row sparse
kCSRStorage, // csr
};
一般来说,storage type 都是 kDefaultStorage 类型,我现在写的解析小工具里面也只考虑
了解析普通类型的NDArray,之后再改进吧。然后看到 num_aux_data函数,这个函数如果传入
普通类型则返回0,所以 nad 的值为 0。
size_t num_aux_data(NDArrayStorageType stype) {
size_t num = 0;
switch (stype) {
case kDefaultStorage: num = 0; break;
case kCSRStorage: num = 2; break;
case kRowSparseStorage: num = 1; break;
default: LOG(FATAL) << "Unknown storage type" << stype; break;
}
return num;
}
nad 值为0 的话整个代码就简洁很多了,简化之后如下:
bool NDArray::Load(dmlc::Stream *strm) {
uint32_t magic;
if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
if (magic != NDARRAY_V2_MAGIC) {
return LegacyLoad(strm, magic);
}
// load storage type
int32_t stype;
if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false;
// load shape
TShape shape;
if (!shape.Load(strm)) return false;
if (shape.ndim() == 0) {
*this = NDArray(); return true;
}
// load context
Context ctx;
if (!ctx.Load(strm)) return false;
// load type flag
int32_t type_flag;
if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false;
// load data into CPU
NDArray temp;
temp = NDArray(shape, Context::CPU(), false, type_flag);
// load data
TBlob load_data = temp.data();
size_t type_size = mshadow::mshadow_sizeof(type_flag);
size_t nread = type_size * load_data.Size();
if (strm->Read(load_data.dptr_, nread) != nread) return false;
}
到这里为,大概怎么读取NDArray,相信应该挺清晰的了。