原文链接: tensorflowjs 简单使用
github
https://github.com/tensorflow/tfjs
引入和安装
简单引入js文件
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"> </script>
在js中即可使用tf变量
npm方式
cnpm install @tensorflow/tfjs -D
import * as tf from '@tensorflow/tfjs';
常见操作
https://js.tensorflow.org/api/latest/index.html
标量,默认是float类型
const scalar = tf.scalar(5);
console.log('scalar ', scalar);
scalar.print()
向量
const vector = tf.tensor([1, 2, 3, 4])
vector.print()
console.log(vector)
for (let i of [vector.mean(), vector.max(), vector.min(), vector.sum()]) {
i.print()
}
矩阵的加减乘除与np类似
const mat = tf.tensor([[1, 2, 3], [4, 5, 6]])
mat.print()
const s = tf.scalar(2)
mat.add(s).print()
mat.mul(s).print()
const mat2 = tf.tensor([[1, 2, 3], [4, 5, 6]])
mat2.add(mat).print()
mat2.sub(mat).print()
mat2.mul(mat).print()
乘法和转置
// 矩阵乘法
const a1 = tf.tensor([[1, 2], [3, 4]])
const a2 = tf.tensor([[1, 2], [3, 4]])
a1.matMul(a2).print()
a1.transpose().print()
// 单位矩阵
const e = tf.oneHot(tf.tensor1d([0, 1, 2],'int32'),3)
e.print()
const e2 = tf.eye(3)
e2.print()
线性取点
tf.linspace(0, 9, 20).print();