1. 文章
  2. 文章详情

TensorFlow.js 训练神经网络

TensorFlow.js概述

TensorFlow.js 是一个开源库,不仅可以在浏览器中运行机器学习模型,还可以训练模型。
具有 GPU 加速功能,并自动支持 WebGL
可以导入已经训练好的模型,也可以在浏览器中重新训练现有的所有机器学习模型
运行 Tensorflow.js 只需要你的浏览器,而且在本地开发的代码与发送给用户的代码是相同的。

TensorFlow.js 对未来 web 开发有着重要的影响,JS 开发者可以更容易地实现机器学习,工程师和数据科学家们可以有一种新的方法来训练算法,例如官网上 Emoji Scavenger Hunt 这样的游戏界面,让用户一边玩游戏一边将模型训练地更好。

浏览器中训练神经网络

用户端的机器学习,用来训练模型的数据还有模型的使用都在用户的设备上完成。几乎每个电脑手机平板上都有浏览器,运行JS,无需下载或安装任何应用程序。

JS运行平台

直接从浏览器里写代码,例如 chrome 的 View > Developer > Javascript Console,

有三个流行的在线 JS 平台:CodePen, JSFiddle, JSBin.
https://codepen.io/thekevinscott/pen/aGapZL
https://jsfiddle.net/
https://jsbin.com/?html,output

TensorFlow.js线性回归预测

<html>
 <head>
    <!-- Load TensorFlow.js -->
    <!-- Get latest version at https://github.com/tensorflow/tfjs -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.2">   
    </script>
 </head>
 
 <body>
   <div id="output_field"></div>
 </body>
 
 <script>
    async function learnLinear(){
    
        const model = tf.sequential();
        model.add(tf.layers.dense({
            units: 1, 
            inputShape: [1]
        }));
        
        model.compile({
            loss: 'meanSquaredError',
            optimizer: 'sgd'
        });
  
        const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
        const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
  
        await model.fit(xs, ys, {epochs: 500});
  
        document.getElementById('output_field').innerText =
            model.predict( tf.tensor2d([10], [1, 1]) );
    }
    
    learnLinear();
 </script>
</html>
        const model = tf.sequential();
        model.add(tf.layers.dense({
            units: 1, 
            inputShape: [1]
        }));
  • 接着定义 loss 为 MSE 和 optimizer 为 SGD:
        model.compile({
            loss: 'meanSquaredError',
            optimizer: 'sgd'
        });
  • 同时需要定义 input 的 tensor,X 和 y,以及它们的维度都是 [6, 1]:
        const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
        const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
  • 然后用 fit 来训练模型,因为要等模型训练完才能预测,所以要用 await:
        await model.fit(xs, ys, {epochs: 500});
  • 训练结束后,用 predict 进行预测,输入的是 [1, 1] 维的 值为 10 的tensor ,
        document.getElementById('output_field').innerText =
            model.predict( tf.tensor2d([10], [1, 1]) );
  • 最后得到的输出为
Tensor 
[[18.9862976],]

学习资料:
https://medium.com/tensorflow/getting-started-with-tensorflow-js-50f6783489b2
https://thekevinscott.com/reasons-for-machine-learning-in-the-browser/
https://www.analyticsvidhya.com/blog/2018/04/tensorflow-js-build-machine-learning-models-javascript/
https://hackernoon.com/introducing-tensorflow-js-3f31d70f5904
https://thekevinscott.com/tensorflowjs-hello-world/


推荐阅读 历史技术博文链接汇总
http://www.jianshu.com/p/28f02bb59fe5

参考:https://www.jianshu.com/p/a42e47c12f3b

发表评论

登录后才能评论

评论列表(0条)