MXNet 源码解读系列之一 C++端如何解析NDArray参数文件

Stella981
• 阅读 678

本文相关代码: parsingNDArray

      要想弄清楚MXNet 是如何解析参数文件,并从中提取预训练好的权值,首先第一步要看

MXNet Python端是如何是调用C接口来完成读取NDArray参数文件的。

      这部分代码见源码 python/mxnet/ndarray/utils.py 第149行:

def load(fname):
    """Loads an array from file.

    See more details in ``save``.

    Parameters
    ----------
    fname : str
        The filename.

    Returns
    -------
    list of NDArray, RowSparseNDArray or CSRNDArray, or \
    dict of str to NDArray, RowSparseNDArray or CSRNDArray
        Loaded data.
    """
    if not isinstance(fname, string_types):
        raise TypeError('fname required to be a string')
    out_size = mx_uint()
    out_name_size = mx_uint()
    handles = ctypes.POINTER(NDArrayHandle)()
    names = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXNDArrayLoad(c_str(fname),                                         
                                  ctypes.byref(out_size),
                                  ctypes.byref(handles),
                                  ctypes.byref(out_name_size),
                                  ctypes.byref(names)))
    if out_name_size.value == 0:
        return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)]
    else:
        assert out_name_size.value == out_size.value
        return dict(
            (py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i])))
            for i in range(out_size.value))

       这个 load 函数接收参数路径作为输入,然后根据参数文件中有没有包含参数的名字选择返回

NDArray参数数组或者字典。然后可以看到是调用了 MXNDArrayLoad 这个C接口函数,这个函数的

代码见 src/c_api/c_api.cc 第308行:

int MXNDArrayLoad(const char* fname,
                  mx_uint *out_size,
                  NDArrayHandle** out_arr,
                  mx_uint *out_name_size,
                  const char*** out_names) {
  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
  ret->ret_vec_str.clear();
  API_BEGIN();
  std::vector<NDArray> data;
  std::vector<std::string> &names = ret->ret_vec_str;
  {
    std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
    mxnet::NDArray::Load(fi.get(), &data, &names);
  }
  ret->ret_handles.resize(data.size());
  for (size_t i = 0; i < data.size(); ++i) {
    NDArray *ptr = new NDArray();
    *ptr = data[i];
    ret->ret_handles[i] = ptr;
  }
  ret->ret_vec_charp.resize(names.size());
  for (size_t i = 0; i < names.size(); ++i) {
    ret->ret_vec_charp[i] = names[i].c_str();
  }
  *out_size = static_cast<mx_uint>(data.size());
  *out_arr = dmlc::BeginPtr(ret->ret_handles);
  *out_name_size = static_cast<mx_uint>(names.size());
  *out_names = dmlc::BeginPtr(ret->ret_vec_charp);
  API_END();
}

     然后可以看到最核心的代码就是第319行调用了NDArray类的静态Load函数获得参数的名字和

内容,Load函数具体实现见:src/ndarray/ndarray.cc 第 1812行:

void NDArray::Load(dmlc::Stream* fi,
                   std::vector<NDArray>* data,
                   std::vector<std::string>* keys) {
  uint64_t header, reserved;
  CHECK(fi->Read(&header))
      << "Invalid NDArray file format";
  CHECK(fi->Read(&reserved))
      << "Invalid NDArray file format";
  CHECK(header == kMXAPINDArrayListMagic)
      << "Invalid NDArray file format";
  CHECK(fi->Read(data))
      << "Invalid NDArray file format";
  CHECK(fi->Read(keys))
      << "Invalid NDArray file format";
  CHECK(keys->size() == 0 || keys->size() == data->size())
      << "Invalid NDArray file format";
}

        从这里读取内容的过程可以大概看出NDArray参数文件存储的内容的顺序是什么了,首先是会

存两个uint64_t类型的数字,然后就是NDArray数组,接着是每个NDArray对应的名字的数组。

        好了接下来就是解读源码中是如何从Stream中解析出内容的,首先我们来看下Stream类的

Read函数,具体见 io.h 第435行:

template<typename T>
inline bool Stream::Read(T *out_data) {
  return serializer::Handler<T>::Read(this, out_data);
}

        这里可以看到,Read 函数内部又调用了 Handler这个类的Read静态函数,这个静态函数对应的

代码见 serializer.h 第262行:

inline static bool Read(Stream *strm, T *data) {
    return IfThenElse<dmlc::is_pod<T>::value,
                      PODHandler<T>,
                      IfThenElse<dmlc::has_saveload<T>::value,
                                 SaveLoadClassHandler<T>,
                                 UndefinedSerializerFor<T>, T>,
                      T>
    ::Read(strm, data);
  }
};

        这里代码我第一次看的时候有点蒙,后来仔细研究了下也看懂了。首先我们要看 IfThenElse

是什么东西,这里还是看到 io.h 的第 38 到 66行:

//! \cond Doxygen_Suppress
/*!
 * \brief Serializer that redirect calls by condition
 * \tparam cond the condition
 * \tparam Then the serializer used for then condition
 * \tparam Else the serializer used for else condition
 * \tparam Return the type of data the serializer handles
 */
template<bool cond, typename Then, typename Else, typename Return>
struct IfThenElse;

template<typename Then, typename Else, typename T>
struct IfThenElse<true, Then, Else, T> {
  inline static void Write(Stream *strm, const T &data) {
    Then::Write(strm, data);
  }
  inline static bool Read(Stream *strm, T *data) {
    return Then::Read(strm, data);
  }
};
template<typename Then, typename Else, typename T>
struct IfThenElse<false, Then, Else, T> {
  inline static void Write(Stream *strm, const T &data) {
    Else::Write(strm, data);
  }
  inline static bool Read(Stream *strm, T *data) {
    return Else::Read(strm, data);
  }
};

    这里可以看到 IfThenElse 就是一个结构体,有四个模板参数,意思很明显了,如果第一个参数

为true,则会调用Then这个类的Read静态函数,如果第一个参数为false,则会调用Else这个类的

Read静态函数。看完 IfThenElse 的定义之后,我们看回 262 行的Read函数就很清楚了,

inline static bool Read(Stream *strm, T *data) {
    return IfThenElse<dmlc::is_pod<T>::value,
                      PODHandler<T>,
                      IfThenElse<dmlc::has_saveload<T>::value,
                                 SaveLoadClassHandler<T>,
                                 UndefinedSerializerFor<T>, T>,
                      T>
    ::Read(strm, data);
  }
};

    意思就是,如果 dmlc::is_pod::value 这个值为 true,那么就会调用 PODHandler 的Read

函数,否则就会走到下一个条件判断,下一个条件判断是当 dmlc::has_saveload::value 这个值

为true的话就调用 SaveLoadClassHandler 的 Read 静态函数,否则就走到 UndefinedSerializerFor。

好了,那么现在就是要看具体走了哪个分支,首先我们要知道T在运行时时什么类型,看回上面的

NDArray 的 Load 函数,知道了首先读取得两个数字的类型是 uint64_t,接着跳转到

源码 type_traits.h,看第126和第152行:

/*! \brief macro to quickly declare traits information */
#define DMLC_DECLARE_TRAITS(Trait, Type, Value)       \
  template<>                                          \
  struct Trait<Type> {                                \
    static const bool value = Value;                  \
  }

DMLC_DECLARE_TRAITS(is_pod, uint64_t, true);

        很明显可以看到,dmlc::is_pod<uint64_t>::value 的值为 true,因此会调用 PODHandler 的

Read 函数,代码:

/*! \brief Serializer for POD(plain-old-data) data */
template<typename T>
struct PODHandler {
  inline static void Write(Stream *strm, const T &data) {
    strm->Write(&data, sizeof(T));
  }
  inline static bool Read(Stream *strm, T *dptr) {
    return strm->Read((void*)dptr, sizeof(T)) == sizeof(T);  // NOLINT(*)
  }
};

        PODHandler 的Read函数就是调用 Stream 的Read,这里如果读者想再详细了解 Stream 类

Read 函数的工作原理可以自己再去细看,不过对于本文来说,到这里知道了会根据T的字节数读取

内容到dptr里面就够了。

        Ok,现在已经读取完两个数字 header, reserved,然后就是读 NDArray Vector 了,然后这里

还是跳转到,调用 Handler::Read 函数,不过这里和读数字不一样的地方在于,这里传入的模板

参数是vector,所以调用的是下面这个Handler定义的Read函数:

//! \cond Doxygen_Suppress
template<typename T>
struct Handler<std::vector<T> > {
  inline static void Write(Stream *strm, const std::vector<T> &data) {
    IfThenElse<dmlc::is_pod<T>::value,
               PODVectorHandler<T>,
               ComposeVectorHandler<T>, std::vector<T> >
    ::Write(strm, data);
  }
  inline static bool Read(Stream *strm, std::vector<T> *data) {
    return IfThenElse<dmlc::is_pod<T>::value,
                      PODVectorHandler<T>,
                      ComposeVectorHandler<T>,
                      std::vector<T> >
    ::Read(strm, data);
  }
};

    然后这里的判断分支是会调用 ComposeVectorHandler 的 Read 函数:

/*!
 * \brief Serializer handler for std::vector<T> where T can be composed type
 * \tparam T element type
 */
template<typename T>
struct ComposeVectorHandler {
  inline static void Write(Stream *strm, const std::vector<T> &vec) {
    uint64_t sz = static_cast<uint64_t>(vec.size());
    strm->Write(&sz, sizeof(sz));
    for (size_t i = 0; i < vec.size(); ++i) {
      Handler<T>::Write(strm, vec[i]);
    }
  }
  inline static bool Read(Stream *strm, std::vector<T> *out_vec) {
    uint64_t sz;
    if (strm->Read(&sz, sizeof(sz)) != sizeof(sz)) return false;
    size_t size = static_cast<size_t>(sz);
    out_vec->resize(size);
    for (size_t i = 0; i < size; ++i) {
      if (!Handler<T>::Read(strm, &(*out_vec)[i])) return false;
    }
    return true;
  }
};

        首先先读出 vector 数组的大小,然后分别读取每个 NDArray,这里在读每个 NDArray 的时候

又会调用 Handler::Read 函数,这次 IfThenElse 分支判断那里会走 SaveLoadClassHandler

这个分支:

// serializer for class that have save/load function
template<typename T>
struct SaveLoadClassHandler {
  inline static void Write(Stream *strm, const T &data) {
    data.Save(strm);
  }
  inline static bool Read(Stream *strm, T *data) {
    return data->Load(strm);
  }
};

    最后看到其实就是调用了 NDArray 类本身的 Load 函数,见源码 src/ndarray/ndarray.cc

bool NDArray::Load(dmlc::Stream *strm) {
  uint32_t magic;
  if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
  if (magic != NDARRAY_V2_MAGIC) {
    return LegacyLoad(strm, magic);
  }

  // load storage type
  int32_t stype;
  if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false;
  const int32_t nad = num_aux_data(static_cast<NDArrayStorageType>(stype));

  // load storage shape
  TShape sshape;
  if (nad > 0) {
    if (!sshape.Load(strm)) return false;
  }

  // load shape
  TShape shape;
  if (!shape.Load(strm)) return false;
  if (shape.ndim() == 0) {
    *this = NDArray(); return true;
  }

  // load context
  Context ctx;
  if (!ctx.Load(strm)) return false;

  // load type flag
  int32_t type_flag;
  if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false;

  // load aux_types and aux_shapes
  std::vector<int32_t> aux_types;
  std::vector<TShape> aux_shapes;
  if (nad > 0) {
    aux_types.resize(nad);
    aux_shapes.resize(nad);
    for (int i = 0; i < nad; ++i) {
      // load aux_type(i)
      if (strm->Read(&aux_types[i], sizeof(aux_types[i])) != sizeof(aux_types[i])) return false;
      // load aux_shapes(i)
      if (!aux_shapes[i].Load(strm)) return false;
    }
  }

  // load data into CPU
  NDArray temp;
  if (0 == nad) {
    temp = NDArray(shape, Context::CPU(), false, type_flag);
  } else {
    temp = NDArray(static_cast<NDArrayStorageType>(stype), shape,
                   Context::CPU(), false, type_flag,
                   aux_types, aux_shapes, sshape);
  }
  // load data
  TBlob load_data = temp.data();
  size_t type_size = mshadow::mshadow_sizeof(type_flag);
  size_t nread = type_size * load_data.Size();
  if (strm->Read(load_data.dptr_, nread) != nread) return false;

  // load aux_data
  if (nad > 0) {
    for (int i = 0; i < nad; ++i) {
      load_data = temp.aux_data(i);
      type_size = mshadow::mshadow_sizeof(load_data.type_flag_);
      nread = type_size * load_data.Size();
      if (strm->Read(load_data.dptr_, nread) != nread) return false;
    }
  }

  if (ctx.dev_mask() == cpu::kDevMask) {
    *this = std::move(temp); return true;
  } else {
#if MXNET_USE_CUDA
    *this = temp.Copy(ctx); return true;
#else
    *this = std::move(temp); return true;
#endif
  }
}

    这里首先,读出一个 magic number ,如果用 V1.0 之后的MXNet版本,magic number

都是会等于  NDARRAY_V2_MAGIC,具体定义见下面:

/* magic number for ndarray version 1, with int64_t TShape */
static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8;

/* magic number for ndarray version 2, with storage type */
static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9;

        所以不会进入 LegacyLoad 函数,接着就是读 storage type,NDArray的类型,除了常用的

普通类型,现在也已经支持了稀疏类型:

enum NDArrayStorageType {
  kUndefinedStorage = -1,  // undefined storage
  kDefaultStorage,         // dense
  kRowSparseStorage,       // row sparse
  kCSRStorage,             // csr
};

    一般来说,storage type 都是 kDefaultStorage 类型,我现在写的解析小工具里面也只考虑

了解析普通类型的NDArray,之后再改进吧。然后看到 num_aux_data函数,这个函数如果传入

普通类型则返回0,所以 nad 的值为 0。

size_t num_aux_data(NDArrayStorageType stype) {
  size_t num = 0;
  switch (stype) {
    case kDefaultStorage: num = 0; break;
    case kCSRStorage: num = 2; break;
    case kRowSparseStorage: num = 1; break;
     default: LOG(FATAL) << "Unknown storage type" << stype; break;
  }
  return num;
}

    nad 值为0 的话整个代码就简洁很多了,简化之后如下:

bool NDArray::Load(dmlc::Stream *strm) {
  uint32_t magic;
  if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false;
  if (magic != NDARRAY_V2_MAGIC) {
    return LegacyLoad(strm, magic);
  }

  // load storage type
  int32_t stype;
  if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false;

  // load shape
  TShape shape;
  if (!shape.Load(strm)) return false;
  if (shape.ndim() == 0) {
    *this = NDArray(); return true;
  }

  // load context
  Context ctx;
  if (!ctx.Load(strm)) return false;

  // load type flag
  int32_t type_flag;
  if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false;
  // load data into CPU
  NDArray temp;
  temp = NDArray(shape, Context::CPU(), false, type_flag);
  
  // load data
  TBlob load_data = temp.data();
  size_t type_size = mshadow::mshadow_sizeof(type_flag);
  size_t nread = type_size * load_data.Size();
  if (strm->Read(load_data.dptr_, nread) != nread) return false;
}

    到这里为,大概怎么读取NDArray,相信应该挺清晰的了。

点赞
收藏
评论区
推荐文章
blmius blmius
3年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
皕杰报表之UUID
​在我们用皕杰报表工具设计填报报表时,如何在新增行里自动增加id呢?能新增整数排序id吗?目前可以在新增行里自动增加id,但只能用uuid函数增加UUID编码,不能新增整数排序id。uuid函数说明:获取一个UUID,可以在填报表中用来创建数据ID语法:uuid()或uuid(sep)参数说明:sep布尔值,生成的uuid中是否包含分隔符'',缺省为
待兔 待兔
6个月前
手写Java HashMap源码
HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程22
Jacquelyn38 Jacquelyn38
3年前
2020年前端实用代码段,为你的工作保驾护航
有空的时候,自己总结了几个代码段,在开发中也经常使用,谢谢。1、使用解构获取json数据let jsonData  id: 1,status: "OK",data: 'a', 'b';let  id, status, data: number   jsonData;console.log(id, status, number )
Wesley13 Wesley13
3年前
mysql设置时区
mysql设置时区mysql\_query("SETtime\_zone'8:00'")ordie('时区设置失败,请联系管理员!');中国在东8区所以加8方法二:selectcount(user\_id)asdevice,CONVERT\_TZ(FROM\_UNIXTIME(reg\_time),'08:00','0
Wesley13 Wesley13
3年前
Java日期时间API系列36
  十二时辰,古代劳动人民把一昼夜划分成十二个时段,每一个时段叫一个时辰。二十四小时和十二时辰对照表:时辰时间24时制子时深夜11:00凌晨01:0023:0001:00丑时上午01:00上午03:0001:0003:00寅时上午03:00上午0
Wesley13 Wesley13
3年前
00:Java简单了解
浅谈Java之概述Java是SUN(StanfordUniversityNetwork),斯坦福大学网络公司)1995年推出的一门高级编程语言。Java是一种面向Internet的编程语言。随着Java技术在web方面的不断成熟,已经成为Web应用程序的首选开发语言。Java是简单易学,完全面向对象,安全可靠,与平台无关的编程语言。
Stella981 Stella981
3年前
Django中Admin中的一些参数配置
设置在列表中显示的字段,id为django模型默认的主键list_display('id','name','sex','profession','email','qq','phone','status','create_time')设置在列表可编辑字段list_editable
Wesley13 Wesley13
3年前
MySQL部分从库上面因为大量的临时表tmp_table造成慢查询
背景描述Time:20190124T00:08:14.70572408:00User@Host:@Id:Schema:sentrymetaLast_errno:0Killed:0Query_time:0.315758Lock_
Python进阶者 Python进阶者
1年前
Excel中这日期老是出来00:00:00,怎么用Pandas把这个去除
大家好,我是皮皮。一、前言前几天在Python白银交流群【上海新年人】问了一个Pandas数据筛选的问题。问题如下:这日期老是出来00:00:00,怎么把这个去除。二、实现过程后来【论草莓如何成为冻干莓】给了一个思路和代码如下:pd.toexcel之前把这
Oracle 分组与拼接字符串同时使用
SELECTT.,ROWNUMIDFROM(SELECTT.EMPLID,T.NAME,T.BU,T.REALDEPART,T.FORMATDATE,SUM(T.S0)S0,MAX(UPDATETIME)CREATETIME,LISTAGG(TOCHAR(