user icon

ESP32でディープラーニング

先日作った熱画像をリアルタイムで見れるシステムを改造し、ディープラーニングしてみました。
ジャンケンの3つのジェスチャーを、赤外線アレイモジュールからの入力データをもとに推論します。
ジャンケンの画像データを使いディープラーニングする、というアイデアは雑誌Interface1月号の記事のまねをさせていただきました。他にもいろいろと参考にさせていただきました。
なお、雑誌Interface1月号の記事ではソニーのSPRESENSEボード、および通常のカメラボードを使って行っています。

ディープラーニングのツールとしてNeural Network Console(NNC)のクラウドを使い学習させました。
学習済みモデルをNNPファイルとしてダウンロードし、NNPファイルをC言語ファイルに変換後、ESP32に組み込みました。
【2019/02/04 追記】NNP以外にNNB(Cランタイムフォーマット)でもESP32に組み込むことができることを確認しましたので、それも追記しました。

以下のステップを行いました。

  1. WebSocketクライアントでデータ収集
  2. ブラウザで収集したデータの確認
  3. データを水増しする(Data Augmentation)
  4. データ変換しCSVファイル作成
  5. NNC(クラウド)で学習
  6. 学習済みモデルをESP32に実装

最終的なファイル構成は以下の構成です。

platformio.ini  // IDE設定ファイル
src  ─ main.cpp // メインプログラム
lib  ┬ MLX90640_API ┬ MLX90640_API.cpp // 以下のファイルはSparkFunにあるサンプルから取得
     │              └ MLX90640_API.h
     ├ MLX90640_I2C_Driver ┬ MLX90640_I2C_Driver.cpp
     │                     └ MLX90640_I2C_Driver.h   
     ├ MainRuntime_inference ┬ MainRuntime_inference.c // 推論するプログラム、NNPから変換(NNPを使う場合)
     │                       └ MainRuntime_inference.h  
     └ MainRuntime_parameters┬ MainRuntime_parameters.c // 重みの配列、NNPから変換(NNPを使う場合)
                             └ MainRuntime_parameters.h  
data ┬ index.html  // SPIFFSに置く静的コンテンツ
     ├ result.nnb  // (NNBを使う場合)
     ├ app.js
     ├ app.css
     ├ ace.js.gz  // 以降のファイルはSPIFFSEditor使用時に使う
     ├ ext-searchbox.js.gz  
     ├ mode-css.js.gz
     ├ mode-html.js.gz
     ├ mode-javascript.js.gz
     └ worker-html.js.gz
work ┬ image_data_generator.py  // WebSocketクライアントでデータ収集するプログラム
     ├ check.htm  // 作成したデータを確認するHTML
     ├ image_data_generator.py  // データを水増し(Data Augmentation)するプログラム
     └ make_csv.py  // データ変換しCSVファイル作成するプログラム


WebSocketクライアントでデータ収集

データ収集のやり方は、雑誌Interface1月号と同様に連写でジャンケンの同じジェスチャーをポーズを変えて撮影し、行いました。
データファイルの書き出しは、Pythonプログラムで作成したWebSocketクライアントで行いました。ブラウザでジェスチャーを確認しつつ、もう一つのWebSocketクライアントで接続してデータの収集・書き出しをします。

実行時の引数で、結果のラベルを指定します。
「グー」を0、「チョキ」を1、「パー」を2とします。
以下を実行し、その間カメラの前で「グー」のジェスチャーを、ポーズを変えてデータ収集します。
各ジェスチャー毎に500枚、計1500枚撮影しました。
0.5秒に1枚撮れるので、500×0.5=250秒 の3セットで10分ちょっとかかりました。
データは、あとでブラウザで確認するためJSONデータで出力します。

$ python collect_data.py 0

・collect_data.py

import websocket
import struct
from datetime import datetime
import json
import sys

args = sys.argv
if len(args) >= 2:
    y = int(args[1])
else:
    y = 9

path = "./temperature_%s.json" % datetime.now().strftime("%Y%m%d%H%M%S")
counter = 0

def on_message(ws, message):
    global counter
    tmps = [chunk[0] for chunk in struct.iter_unpack('<f', message)]
    data= [datetime.now().timestamp(), y, tmps]

    if counter%50 == 0:
        print('counter: %s' % counter)

    counter += 1
    with open(path, mode='a') as f:
        f.write(json.dumps(data) + ",\n")

def on_error(ws, error):
    print(error)

def on_close(ws):
    print("### closed ###")

def on_open(ws):
    print("### open ###")

if __name__ == "__main__":
    # websocket.enableTrace(True)
    ws = websocket.WebSocketApp("ws://esp32.local/ws",
                              on_message = on_message,
                              on_error = on_error,
                              on_close = on_close)
    ws.on_open = on_open
    ws.run_forever()
ブラウザで収集したデータの確認

上記で作成したデータファイル3つを結合します。その際、ファイルの頭に”[“を挿入、ファイルの最後の”,”を”]”に置換します。以下のようなかんじのコマンドを実行します。

$ cat temperature_201901* | sed -e '$s/.$/]/' -e '1s/^/[/' > temperature_201901.json

収集したデータをブラウザで確認します。
以下のようなHTMLファイルを作成し、ダブルクリックしてブラウザに表示します。前回作成したapp.js、app.cssを利用しています。
「ファイルを選択」で、作成したデータファイルを選択します。
画面上部のスライダーまたは「←」「→」キーで戻る・進むができます。
・check.htm

<!DOCTYPE html>
<html>
  <head>
    <meta http-equiv="Content-type" content="text/html; charset=utf-8">
    <meta name="viewport" content="width=350,initial-scale=0.5">
    <title>赤外線アレイカメラ MLX90640</title>
    <link rel="stylesheet" type="text/css" href="../data/app.css" >
    <link rel="stylesheet" type="text/css" href="https://code.jquery.com/ui/1.12.1/themes/base/jquery-ui.css" >
    <style>
        a#download{
            position: relative;
            display: inline-block;
            font-weight: bold;
            padding: 0.25em 0.5em;
            text-decoration: none;
            color: #00BCD4;
            background: #ECECEC;
            transition: .4s;
        }
        a#download:hover{
            background: #00bcd4;
            color: white;
        }       
        #slider {
            width: 640px;
            margin: 10px 0;
        } 
        #pos {
            display: inline-block;
        }
        #labeling {
            height: 60px;
        }
        #dt, #yw {
            margin: 0 10px;
        }
        #yw {
            display: inline-block;
        }
    </style>
  </head>
  <body id="body" onload="onBodyLoad()">
    <div id="container">
      <div id="labeling">
        <form method="post" enctype="multipart/form-data">
            <input type="file" id="file" accept="application/json">
            <!-- <span><a id="download" href="#" download="test.txt" onclick="tmps.handleDownload()">ダウンロード</a></span> -->
            <div id="pos">
                <span id="slider-pos"></span>
                <span id="dt"></span>
                <span id="yw" style="display: none;">
                    <select id="y">
                        <option value="0">グー</option>
                        <option value="1">チョキ</option>
                        <option value="2">パー</option>
                        <!-- <option value="3">3</option>
                        <option value="4">4</option>
                        <option value="5">5</option>
                        <option value="6">6</option>
                        <option value="7">7</option>
                        <option value="8">8</option>
                        <option value="9">9</option> -->
                    </select>
                </span>
            </div>
        </form>
        <div id='slider'></div> 
      </div>
      <canvas id="canvas" width="32" height="24"></canvas>
      <div id="scale"></div>  
      <div id="scale-divisions">
        <div id="min-tmp-division"><span id="min-down" class="divisionBtn">&#x25c0;</span><span id="min-tmp"></span><span id="min-up" class="divisionBtn">&#x25b6;</span></div>
        <div id="max-tmp-division"><span id="max-down" class="divisionBtn">&#x25c0;</span><span id="max-tmp"></span><span id="max-up" class="divisionBtn">&#x25b6;</span></div>
      </div>
      <div id="messages"></div>
    </div>
    <script src="https://code.jquery.com/jquery-3.3.1.min.js"></script>
    <script src="https://code.jquery.com/ui/1.12.1/jquery-ui.min.js"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/moment.js/2.23.0/moment.min.js"></script>
    <script src="../data/app.js"></script>
    <script>
        const tmps = {
            obj:null,
            index: 0,
            orgFileName: null,
            handleDownload: function () {
                var content = JSON.stringify(this.obj);
                var blob = new Blob([ content ], { "type" : "text/plain" });
                var d = ge("download");
                d.download = this.orgFileName;
                d.href = window.URL.createObjectURL(blob);
            },
            draw: function() {
                const data = this.obj[this.index];
                const timestamp = data[0];
                const y = data[1];
                const temperature = data[2];
                this.setDt(timestamp);
                cv.draw(temperature);
                $("#y").val(y);
            },
            setDt: function(timestamp) {
                ge("dt").innerHTML = moment.unix(timestamp).format();
            }, 
            setY: function (value) {
                value = parseInt(value);
                if (value != this.obj[this.index][1]) {
                    console.log(`y is ${value}`);
                    this.obj[this.index][1] = value;
                }           
            }   
        };
        const slider = {
            slider: null,
            setUp: (length)=> {
                this.slider = $("#slider").slider({
                    value:0,
                    min:0,
                    max:length - 1,
                    step:1,
                    change: function (e, ui) {
                        tmps.index = ui.value;
                        tmps.draw();
                    }
                });
            },
            setValue: (value) => {
                this.slider.slider("value", value);
            }
        };
        onBodyLoad = function(){
            const reader = new FileReader();
            //HTMLを初期化し、新たなファイルを文字列として読込む     
            file.addEventListener('change', function(e) {
                tmps.orgFileName = e.target.files[0].name;
                reader.readAsText(e.target.files[0]);
            });
            //ファイルをオブジェクト化して表示
            reader.onload = function(e) {
                try {
                    tmps.obj = JSON.parse(e.target.result);
                } catch (err1) {
                    console.log(err1);
                    try {
                        tmps.obj = JSON.parse('[' + e.target.result.slice( 0, -2 ) + ']');
                    } catch (err2) {
                        alert("jsonに誤りがあります。");
                        console.log(e2);
                        return;
                    }
                }
                console.log(tmps.obj);
                slider.setUp(tmps.obj.length);
                $("#yw").show();
                tmps.draw();
            };

            cv.createScale();
            cv.createCanvas();
        }
        $(function() {
            $("#y").change(function() {
                const value = $(this).val();
                tmps.setY(value);
            });
            $(document).on('keydown', function(e) {
                if (!tmps.obj) return;
                console.log(`pressed keyCode:${e.keyCode}`);
                if (48 <= e.keyCode && e.keyCode <= 57) {
                    const value = e.keyCode - 48;
                    $("#y").val(value);
                    tmps.setY(value);
                }
                switch( e.keyCode ) {
                    case 37:
                        if (tmps.index > 0) {
                            console.log("戻る");
                            tmps.index--;
                            tmps.draw();
                            slider.setValue(tmps.index);
                        }
                        break;
                    case 39:
                        // 進む
                        if (tmps.index < tmps.obj.length) {
                            console.log("進む");
                            tmps.index++;
                            tmps.draw();
                            slider.setValue(tmps.index);
                        }                        
                        break;
                }
            });
        });                
    </script>
  </body>
</html>

確認画面

データを水増しする(Data Augmentation)

KerasのImageDataGeneratorを使って、データの水増しを行います。画像を回転や水平/垂直方向に移動・ズーム等してデータを10倍にします。
このとき、ついでに以下のようにデータを加工しています。「28度以下は0」というのは、環境によってはうまくいかないかもしれないです。

  • 28度以下は0
  • 35度以上は1
  • 0以上1以下の値に変換

これも後で確認するためjsonファイルで出力しときます。
・image_data_generator.py

import json
import numpy as np
import scipy.stats
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator
from datetime import datetime

data_file = '<入力ファイルを設定する>'
with open(data_file) as f:
    df = json.load(f)
X = [data[2] for data in df]
X = np.array(X)

# 28度以下を0、35度以上を1とし、0以上1以下の値に変換
max = 35
min = 28
X = np.where(X >max, 1, (X - min)/(max - min))
X = np.where(X < 0, 0, X)

y = [data[1] for data in df]
y = np.array(y)
y = y.reshape(y.shape[0], 1)

datagen = ImageDataGenerator(
    rotation_range=90,  # 整数.画像をランダムに回転する回転範囲
    width_shift_range=0.1,  # 浮動小数点数(横幅に対する割合).ランダムに水平シフトする範囲
    height_shift_range=0.1,  # 浮動小数点数(縦幅に対する割合).ランダムに垂直シフトする範囲
    fill_mode='constant',   # 入力画像の境界周りを埋めるモード
    cval=0.0,               # constantで0.0で埋める
    horizontal_flip=True,  # 真理値.水平方向に入力をランダムに反転します
    vertical_flip=True,     # 真理値.垂直方向に入力をランダムに反転します
    zoom_range=0.3          # 浮動小数点数または[lower,upper].ランダムにズームする範囲.浮動小数点数
    )  # randomly flip images

# ImageDataGeneratorの入力値は(samples, channels, height, width)なのでそれに合わせる
X = X.reshape(X.shape[0], 24, 32, 1)

# 空のnumpyを作成し、ループでappendする
new_X = np.empty((0, 24, 32, 1), float)
new_Y = np.empty((0, 1), int)
counter = 0
for x_batch, y_batch in datagen.flow(X, y, batch_size=32):
    new_X = np.append(new_X, x_batch, axis=0)
    new_Y = np.append(new_Y, y_batch, axis=0)
    counter += 1
    if counter >= (X.shape[0] * 10 / 32):  # 10倍に水増し
        break

# Xの値を0=<X<=1 から28<=X<=35に戻す
new_X = new_X*(max - min) + min
new_X = new_X.reshape(new_X.shape[0], 32*24)

data = []
for (each_x, each_y) in zip(new_X, new_Y):
    data.append([
        datetime.now().timestamp(),
        each_y[0],
        each_x
    ])

class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(MyEncoder, self).default(obj)

path = "image_data_generator_%s.json" % datetime.now().strftime("%Y%m%d%H%M%S")
with open(path, mode='w') as f:
    f.write(json.dumps(data, cls = MyEncoder))

ImageDataGeneratorで生成した画像の例

データ変換しCSVファイル作成

上記で作成したJSONファイルからNNCにアップロードするCSVファイルを作成します。
・make_csv.py

import json
import numpy as np
import scipy.stats
from sklearn.model_selection import train_test_split


data_file = '<入力ファイルを設定する>'
with open(data_file) as f:
    df = json.load(f)
X = [data[2] for data in df]
X = np.array(X)

# 28度以下を0、35度以上を1とし、0以上1以下の値に変換
max = 35
min = 28
X = np.where(X >max, 1, (X - min)/(max - min))
X = np.where(X < 0, 0, X)

y = [data[1] for data in df]
y = np.array(y)
y = y.reshape(y.shape[0], 1)

index = ['x__{0}'.format(i) for i in range(0, 768)]
index.append('y')
header = ','.join(index)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
train_data = np.hstack((X_train, y_train))
test_data = np.hstack((X_test, y_test))

np.savetxt('<trainの出力ファイルを設定する>', train_data, fmt='%.3f', delimiter=',', header=header, comments='')
np.savetxt('<testの出力ファイルを設定する>', test_data, fmt='%.3f', delimiter=',', header=header, comments='')
NNC(クラウド)で学習

NNC専用のアップローダーでtrainとtestデータのCSVをアップロードします。

サンプルにあったimage_recognition.MNIST.LeNetのモデルをほぼそのまま使いました。

インプットは768個(32×24)の1次元、出力は3つです。

はまったのが、ESP32のメモリのサイズです。
最初にサンプルにあったパラメータのままで作成したら、生成されたESP32に実装する重みパラメータが書かれたファイル(MainRuntime_parameters.c)が2MB近くありました。そのためESP32のコンパイルのリンク時に以下のエラーが発生しました。

region `dram0_0_seg' overflowed by 225088 bytes

Convolution(畳み込み)のフィルター数やフィルターサイズを調整して、最終的にMainRuntime_parameters.cのサイズを30KBにしたら、上記エラーがでなくなりました。

学習した結果の精度(Acuacy)ですが、水増し(Data Augmentation)しなかった場合は、0.9865ですが、水増しした場合、0.8768となりました。データを加工しすぎかもしれません。

学習済みモデルをESP32に実装

■NNPを使った実装の場合
NNPファイルをダウンロードします。
以下のコマンドで展開します。

$ nnabla_cli convert -O CSRC -b 1 result_evaluate.nnp <出力ディレクトリ>

以下の5つのファイルができます。
・GNUmakefile
・MainRuntime_example.c // サンプルプログラム
・MainRuntime_inference.c  // 推論するプログラム
・MainRuntime_inference.h
・MainRuntime_parameters.c  // 重みパラメータの配列
・MainRuntime_parameters.h

MainRuntime_inference.c、MainRuntime_inference.h、MainRuntime_parameters.c、ainRuntime_parameters.hをlibディレクトリ以下に配置します。

これらはnnabla-c-runtimeライブラリに依存しますので、platformio.iniに以下のように追記します。

lib_deps = https://github.com/me-no-dev/ESPAsyncWebServer.git, https://github.com/sony/nnabla-c-runtime.git

推論を実装したプログラム(NNP版)です。
・main.c

#include <Wire.h>
#include "MLX90640_API.h"
#include "MLX90640_I2C_Driver.h"
#include <WiFi.h>
#include <ESPmDNS.h>
#include <ArduinoOTA.h>
// #include <FS.h>
#include <ESPAsyncWebServer.h>
#include <SPIFFS.h>
#include <SPIFFSEditor.h>

#include "MainRuntime_inference.h"
#include "MainRuntime_parameters.h"

// WIFI設定
const char* ssid = "*******";
const char* password =  "******";

// mDNS
const char *hostName = "esp32";

// SPIFFSEditorの認証
const char *http_username = "admin";
const char *http_password = "admin";

// MLX90640
const byte MLX90640_address = 0x33; //Default 7-bit unshifted address of the MLX90640
#define TA_SHIFT 8                  //Default shift for MLX90640 in open air
float mlx90640To[768];
paramsMLX90640 mlx90640;

// SKETCH BEGIN
AsyncWebServer server(80);
AsyncWebSocket ws("/ws");

void *_context = NULL;

int mode = 0;

// template <class T, size_t N>
// void standard(T (&data)[N])
// {
//   float ave = std::accumulate(std::begin(data), std::end(data), 0.0) / N;
//   float sd = sqrt(std::inner_product(std::begin(data), std::end(data), std::begin(data), 0.0) / N - ave * ave);
//   std::for_each(std::begin(data), std::end(data), [&ave, &sd](float &temperature) {
//     temperature = (temperature - ave) / sd;
//   });
// }

// 28度以下を0、35度以上を1とし、0以上1以下の値に変換
template <class T, size_t N>
int normalize(T (&data)[N], float min, float max)
{
  int cnt = 0;
  std::for_each(std::begin(data), std::end(data), [&min, &max, &cnt](float &temperature) {
    if (min > temperature)
    {
      temperature = 0;
      cnt++;
    } else if (max < temperature) {
      temperature = 1;
    } else {
      temperature = (temperature - min)/(max - min);
    }
  });
  return cnt;
}

int predict(float *data, float r)
{
  memcpy(nnablart_mainruntime_input_buffer(_context, 0), data, 768*4);
  nnablart_mainruntime_inference(_context);
  float *probs = nnablart_mainruntime_output_buffer(_context, 0);

  Serial.printf("predict %.3f, %.3f, %.3f\n", probs[0], probs[1], probs[2]);
  for (int cl = 0; cl < NNABLART_MAINRUNTIME_OUTPUT0_SIZE; cl++)
  {
    if (probs[cl] > r)
    {
      return cl;
    }
  }
  return 9;
}

void onWsEvent(AsyncWebSocket *server, AsyncWebSocketClient *client, AwsEventType type, void *arg, uint8_t *data, size_t len)
{
  if (type == WS_EVT_CONNECT)
  {
    Serial.printf("ws[%s][%u] connect\n", server->url(), client->id());
  }
  else if (type == WS_EVT_DISCONNECT)
  {
    Serial.printf("ws[%s][%u] disconnect\n", server->url(), client->id());
  }
  else if (type == WS_EVT_ERROR)
  {
    Serial.printf("ws[%s][%u] error(%u): %s\n", server->url(), client->id(), *((uint16_t *)arg), (char *)data);
  }
}

//Returns true if the MLX90640 is detected on the I2C bus
boolean isConnected()
{
  Wire.beginTransmission((uint8_t)MLX90640_address);
  if (Wire.endTransmission() != 0)
    return (false); //Sensor did not ACK
  return (true);
}

void setUpMLX90640()
{
  if (isConnected() == false)
  {
    Serial.println("MLX90640 not detected at default I2C address. Please check wiring. Freezing.");
    while (1)
      ;
  }

  //Get device parameters - We only have to do this once
  int status;
  uint16_t eeMLX90640[832];
  status = MLX90640_DumpEE(MLX90640_address, eeMLX90640);
  if (status != 0)
    Serial.println("Failed to load system parameters");

  status = MLX90640_ExtractParameters(eeMLX90640, &mlx90640);
  if (status != 0)
    Serial.println("Parameter extraction failed");
  Serial.println(status);

  //Once params are extracted, we can release eeMLX90640 array

  //MLX90640_SetRefreshRate(MLX90640_address, 0x02); //Set rate to 2Hz
  MLX90640_SetRefreshRate(MLX90640_address, 0x03); //Set rate to 4Hz
  //MLX90640_SetRefreshRate(MLX90640_address, 0x07); //Set rate to 64Hz
}

void setUpOTA()
{
  ArduinoOTA.onStart([]() { Serial.println("Update Start"); });
  ArduinoOTA.onEnd([]() { Serial.println("Update End"); });
  ArduinoOTA.onProgress([](unsigned int progress, unsigned int total) {
    Serial.printf("Progress: %u%%\r", (progress / (total / 100)));
  });
  ArduinoOTA.onError([](ota_error_t error) {
    Serial.println("OTA ERROR");
  });
  ArduinoOTA.setHostname(hostName);
  ArduinoOTA.begin();
}

void setup()
{
  Wire.begin();
  Serial.begin(115200);
  Serial.setDebugOutput(true);
  WiFi.begin(ssid, password);

  while (WiFi.status() != WL_CONNECTED)
  {
    delay(1000);
    Serial.println("Connecting to WiFi..");
  }
  Serial.println("WiFi connected!");

  //OTA
  setUpOTA();

  // mDNS
  if (!MDNS.begin(hostName))
  {
    Serial.println("Error setting up MDNS responder!");
    while (1)
    {
      delay(1000);
    }
  }

  SPIFFS.begin(true);

  ws.onEvent(onWsEvent);
  server.addHandler(&ws);

  // SPIFFSにあるファイルをブラウザで/editから編集できる
  server.addHandler(new SPIFFSEditor(SPIFFS, http_username, http_password));

  // SPIFFS
  server.serveStatic("/", SPIFFS, "/").setDefaultFile("index.htm");

  // predictモード開始
  server.on("/start", HTTP_GET, [](AsyncWebServerRequest *request) {
    _context = nnablart_mainruntime_allocate_context(MainRuntime_parameters);
    mode = 1;
    request->send(200, "text/plain", String("ok"));
  });

  // predictモード終了
  server.on("/stop", HTTP_GET, [](AsyncWebServerRequest *request) {
    mode = 0;
    nnablart_mainruntime_free_context(_context);
    request->send(200, "text/plain", String("ok"));
  });

  server.onNotFound([](AsyncWebServerRequest *request) {
    Serial.printf("NOT_FOUND: ");
    request->send(404);
  });
  server.begin();

  // MLX90640の初期設定
  setUpMLX90640();
}

void loop()
{
  ArduinoOTA.handle();

  // WebSocket接続してない時は何もしない
  if (ws.count() <= 0)
  {
    return;
  }

  long startTime = millis();
  for (byte x = 0; x < 2; x++)
  {
    uint16_t mlx90640Frame[834];
    MLX90640_GetFrameData(MLX90640_address, mlx90640Frame);
    // float vdd = MLX90640_GetVdd(mlx90640Frame, &mlx90640);
    float Ta = MLX90640_GetTa(mlx90640Frame, &mlx90640);

    float tr = Ta - TA_SHIFT; //Reflected temperature based on the sensor ambient temperature
    float emissivity = 0.95;

    MLX90640_CalculateTo(mlx90640Frame, &mlx90640, emissivity, tr, mlx90640To);
  }
  long calculatedTime = millis();

  AsyncWebSocketMessageBuffer *buffer = ws.makeBuffer((uint8_t *)&mlx90640To, sizeof(mlx90640To));
  ws.binaryAll(buffer); // バイナリー(uint8_tの配列)で全クライアントに送信

  int top_class = 9;
  if (mode == 1)
  {
    // 28度以下を0、35度以上を1とし、0以上1以下の値に変換
    int cnt = normalize(mlx90640To, 28, 35);

    Serial.printf("predict mode: count of below 28C: %d\n", cnt);

    // 28度以下が768ドット中の700ドット以上の場合は、predictしない
    if (cnt < 700)
    {
      // 精度が0.5以上の場合のみ、結果を返す
      top_class = predict(mlx90640To, 0.5);
    }
    ws.textAll("result:" + String(top_class));
  }
  long finishedTime = millis();
  Serial.printf("calculated secs:%.2f, finished secs:%.2f, top_class: %d\n", (float)(calculatedTime - startTime) / 1000, (float)(finishedTime - startTime) / 1000, top_class);
}

■NNBを使った実装の場合
NNBファイルをダウンロードします。
SPIFFSでファイルを読み込めるようにするため、dataディレクトリ以下にNNBファイルを置きます。

nnabla-c-runtimeライブラリに依存しますので、platformio.iniに以下のように追記します。

lib_deps = https://github.com/me-no-dev/ESPAsyncWebServer.git, https://github.com/sony/nnabla-c-runtime.git

推論を実装したプログラム(NNB版)です。
NNBファイルをcharの配列として読み込んで、それをnn_network_t構造体にキャストしています。
NNBファイルのサイズは98KBだと大丈夫でしたが、136KBだと読み込み時にエラーが発生しました。使えるヒープサイズは150KB以上あるのですが。連続したヒープメモリが必要だからかもしれません。

#include <Wire.h>
#include "MLX90640_API.h"
#include "MLX90640_I2C_Driver.h"
#include <WiFi.h>
#include <ESPmDNS.h>
#include <ArduinoOTA.h>
#include <FS.h>
#include <ESPAsyncWebServer.h>
#include <SPIFFS.h>
#include <SPIFFSEditor.h>

// #include "MainRuntime_inference.h"
// #include "MainRuntime_parameters.h"
#include <nnablart/network.h>
#include <nnablart/runtime.h>
// WIFI設定
const char *ssid = "******";
const char *password = "******";

// mDNS
const char *hostName = "esp32";

// SPIFFSEditorの認証
const char *http_username = "admin";
const char *http_password = "admin";

// MLX90640
const byte MLX90640_address = 0x33; //Default 7-bit unshifted address of the MLX90640
#define TA_SHIFT 8                  //Default shift for MLX90640 in open air
float mlx90640To[768];
paramsMLX90640 mlx90640;

// SKETCH BEGIN
AsyncWebServer server(80);
AsyncWebSocket ws("/ws");

void *_context = NULL;

int mode = 0;

char *nnb;


// template <class T, size_t N>
// void standard(T (&data)[N])
// {
//   float ave = std::accumulate(std::begin(data), std::end(data), 0.0) / N;
//   float sd = sqrt(std::inner_product(std::begin(data), std::end(data), std::begin(data), 0.0) / N - ave * ave);
//   std::for_each(std::begin(data), std::end(data), [&ave, &sd](float &temperature) {
//     temperature = (temperature - ave) / sd;
//   });
// }

// 28度以下を0、35度以上を1とし、0以上1以下の値に変換
template <class T, size_t N>
int normalize(T (&data)[N], float min, float max)
{
  int cnt = 0;
  std::for_each(std::begin(data), std::end(data), [&min, &max, &cnt](float &temperature) {
    if (min > temperature)
    {
      temperature = 0;
      cnt++;
    } else if (max < temperature) {
      temperature = 1;
    } else {
      temperature = (temperature - min)/(max - min);
    }
  });
  return cnt;
}

int predict(float *data, float r)
{
  memcpy(rt_input_buffer(_context, 0), data, 768*4);
  rt_forward(_context);

  // Serial.printf("num:%d, ", rt_num_of_output(_context));
  // Serial.printf("size:%d\n", rt_output_size(_context, 0));

  float  *probs = (float *)rt_output_buffer(_context, 0);

  Serial.printf("predict %.3f, %.3f, %.3f\n", probs[0], probs[1], probs[2]);
  for (int cl = 0; cl < 3; cl++)
  {
    if (probs[cl] > r)
    {
      return cl;
    }
  }
  return 9;
}

void onWsEvent(AsyncWebSocket *server, AsyncWebSocketClient *client, AwsEventType type, void *arg, uint8_t *data, size_t len)
{
  if (type == WS_EVT_CONNECT)
  {
    Serial.printf("ws[%s][%u] connect\n", server->url(), client->id());
  }
  else if (type == WS_EVT_DISCONNECT)
  {
    Serial.printf("ws[%s][%u] disconnect\n", server->url(), client->id());
  }
  else if (type == WS_EVT_ERROR)
  {
    Serial.printf("ws[%s][%u] error(%u): %s\n", server->url(), client->id(), *((uint16_t *)arg), (char *)data);
  }
}

//Returns true if the MLX90640 is detected on the I2C bus
boolean isConnected()
{
  Wire.beginTransmission((uint8_t)MLX90640_address);
  if (Wire.endTransmission() != 0)
    return (false); //Sensor did not ACK
  return (true);
}

void setUpMLX90640()
{
  if (isConnected() == false)
  {
    Serial.println("MLX90640 not detected at default I2C address. Please check wiring. Freezing.");
    while (1)
      ;
  }

  //Get device parameters - We only have to do this once
  int status;
  uint16_t eeMLX90640[832];
  status = MLX90640_DumpEE(MLX90640_address, eeMLX90640);
  if (status != 0)
    Serial.println("Failed to load system parameters");

  status = MLX90640_ExtractParameters(eeMLX90640, &mlx90640);
  if (status != 0)
    Serial.println("Parameter extraction failed");
  Serial.println(status);

  //Once params are extracted, we can release eeMLX90640 array

  //MLX90640_SetRefreshRate(MLX90640_address, 0x02); //Set rate to 2Hz
  MLX90640_SetRefreshRate(MLX90640_address, 0x03); //Set rate to 4Hz
  //MLX90640_SetRefreshRate(MLX90640_address, 0x07); //Set rate to 64Hz
}

void setUpOTA()
{
  ArduinoOTA.onStart([]() { Serial.println("Update Start"); });
  ArduinoOTA.onEnd([]() { Serial.println("Update End"); });
  ArduinoOTA.onProgress([](unsigned int progress, unsigned int total) {
    Serial.printf("Progress: %u%%\r", (progress / (total / 100)));
  });
  ArduinoOTA.onError([](ota_error_t error) {
    Serial.println("OTA ERROR");
  });
  ArduinoOTA.setHostname(hostName);
  ArduinoOTA.begin();
}

void setup()
{
  Wire.begin();
  Serial.begin(115200);
  Serial.setDebugOutput(true);
  WiFi.begin(ssid, password);

  while (WiFi.status() != WL_CONNECTED)
  {
    delay(1000);
    Serial.println("Connecting to WiFi..");
  }
  Serial.println("WiFi connected!");

  //OTA
  setUpOTA();

  // mDNS
  if (!MDNS.begin(hostName))
  {
    Serial.println("Error setting up MDNS responder!");
    while (1)
    {
      delay(1000);
    }
  }

  SPIFFS.begin(true);

  ws.onEvent(onWsEvent);
  server.addHandler(&ws);

  // SPIFFSにあるファイルをブラウザで/editから編集できる
  server.addHandler(new SPIFFSEditor(SPIFFS, http_username, http_password));

  // SPIFFS
  server.serveStatic("/", SPIFFS, "/").setDefaultFile("index.htm");

  // predictモード開始
  server.on("/start", HTTP_GET, [](AsyncWebServerRequest *request) {
    // _context = nnablart_mainruntime_allocate_context(MainRuntime_parameters);

    /* READ FILE */
    File fp = SPIFFS.open("/result.nnb", FILE_READ); // 読み取り
    Serial.printf("file size:%d\n", fp.size());
    rt_return_value_t ret = rt_allocate_context(&_context);
    nnb = (char *)malloc(fp.size());  // malloc使わないとバッファーオーバーフローエラーが発生した
    fp.readBytes(nnb, fp.size());
    fp.close();

    nn_network_t *net = (nn_network_t *)nnb;
    ret = rt_initialize_context(_context, net);
    Serial.println(ret);

    mode = 1;
    request->send(200, "text/plain", String("ok"));
  });

  // predictモード終了
  server.on("/stop", HTTP_GET, [](AsyncWebServerRequest *request) {
    mode = 0;

    rt_free_context(&_context);
    free(nnb);

    request->send(200, "text/plain", String("ok"));
  });

  server.onNotFound([](AsyncWebServerRequest *request) {
    Serial.printf("NOT_FOUND: ");
    request->send(404);
  });
  server.begin();

  // MLX90640の初期設定
  setUpMLX90640();
}

void loop()
{
  ArduinoOTA.handle();

  // WebSocket接続してない時は何もしない
  if (ws.count() <= 0)
  {
    return;
  }

  long startTime = millis();
  for (byte x = 0; x < 2; x++)
  {
    uint16_t mlx90640Frame[834];
    MLX90640_GetFrameData(MLX90640_address, mlx90640Frame);
    // float vdd = MLX90640_GetVdd(mlx90640Frame, &mlx90640);
    float Ta = MLX90640_GetTa(mlx90640Frame, &mlx90640);

    float tr = Ta - TA_SHIFT; //Reflected temperature based on the sensor ambient temperature
    float emissivity = 0.95;

    MLX90640_CalculateTo(mlx90640Frame, &mlx90640, emissivity, tr, mlx90640To);
  }
  long calculatedTime = millis();

  AsyncWebSocketMessageBuffer *buffer = ws.makeBuffer((uint8_t *)&mlx90640To, sizeof(mlx90640To));
  ws.binaryAll(buffer); // バイナリー(uint8_tの配列)で全クライアントに送信

  int top_class = 9;
  if (mode == 1)
  {
    // 28度以下を0、35度以上を1とし、0以上1以下の値に変換
    int cnt = normalize(mlx90640To, 28, 35);

    Serial.printf("predict mode: count of below 28C: %d\n", cnt);

    // 28度以下が768ドット中の700ドット以上の場合は、predictしない
    if (cnt < 700)
    {
      // 精度が0.5以上の場合のみ、結果を返す
      top_class = predict(mlx90640To, 0.5);
    }
    ws.textAll("result:" + String(top_class));
  }
  long finishedTime = millis();
  Serial.printf("calculated secs:%.2f, finished secs:%.2f, top_class: %d\n", (float)(calculatedTime - startTime) / 1000, (float)(finishedTime - startTime) / 1000, top_class);
}

推論の結果をブラウザで表示するためHTML,js、cssを以下のように修正しました。
・index.htm

<!DOCTYPE html>
<html>
  <head>
    <meta http-equiv="Content-type" content="text/html; charset=utf-8">
    <meta name="viewport" content="width=350,initial-scale=0.5">
    <title>赤外線アレイカメラ MLX90640</title>
    <link rel="stylesheet" type="text/css" href="app.css" >
  </head>
  <body id="body" onload="onBodyLoad()">
    <div id="container">
      <div id="header"><span><a id="predict" href="#">判定する</a></span>
      </div>
      <canvas id="canvas" width="32" height="24"></canvas>
      <div id="scale"></div>  
      <div id="scale-divisions">
        <div id="min-tmp-division"><span id="min-down" class="divisionBtn">&#x25c0;</span><span id="min-tmp"></span><span id="min-up" class="divisionBtn">&#x25b6;</span></div>
        <div id="max-tmp-division"><span id="max-down" class="divisionBtn">&#x25c0;</span><span id="max-tmp"></span><span id="max-up" class="divisionBtn">&#x25b6;</span></div>
      </div>
      <div id="messages"></div>
      <div id="result"></div>
    </div>
    <script src="app.js"></script>
  </body>
</html>

・app.js

const ge = (s) => { return document.getElementById(s); }
const ce = (s) => { return document.createElement(s); }
const gc = (s) => { return document.getElementsByClassName(s); }
const addMessage = (m) => {
  // メッセージ表示
  // console.log(m);
  const msg = ce("div");
  msg.innerText = m;
  ge("messages").appendChild(msg);
  ge("messages").append
}
skt = {
  ws: null,
  start: function () {
    // WebSocketを開始
    ws = ws = new WebSocket('ws://' + document.location.host + '/ws', ['arduino']);
    ws.binaryType = "arraybuffer";
    ws.onopen = (e) => {
      addMessage("Connected");
    };
    ws.onclose = (e) => {
      addMessage("Disconnected");
    };
    ws.onerror = (e) => {
      console.log("ws error", e);
      addMessage("Error");
    };
    ws.onmessage = (e) => {
      if (e.data instanceof ArrayBuffer) {
        // バイナリーデータの場合
        this.parseTemparatures(e.data);
      } else {
        console.log(predict.mode);
        const result = ge("result");
        const resultStrings = { "0": "グー", "1": "チョキ", "2": "パー" };
        const m = e.data.match(/^result:([0-2])/);
        if (m) {
          result.innerHTML = resultStrings[m[1]];
        } else {
          result.innerHTML = "";
        }
      }
    };
  },
  parseTemparatures: (data) => {
    // Uint8Arrayにセットされた4バイトのfloatの配列をFloat32Arrayの型付き配列にセットし、描画します。
    const dv = new DataView(data);
    const byteSize = 4;
    const tmps = new Float32Array(data.byteLength / byteSize);
    for (let i = 0; i < tmps.length; i++) {
      tmps[i] = dv.getFloat32(i * byteSize, true);
    }
    cv.draw(tmps);
  }
};
const HSVtoRGB = (h, s, v) => {
  // HSVからRGBに変換 パラメータh,s,vは0以上1以下
  let r, g, b, i, f, p, q, t;
  i = Math.floor(h * 6);
  f = h * 6 - i;
  p = v * (1 - s);
  q = v * (1 - f * s);
  t = v * (1 - (1 - f) * s);
  switch (i % 6) {
    case 0: r = v, g = t, b = p; break;
    case 1: r = q, g = v, b = p; break;
    case 2: r = p, g = v, b = t; break;
    case 3: r = p, g = q, b = v; break;
    case 4: r = t, g = p, b = v; break;
    case 5: r = v, g = p, b = q; break;
  }
  return {
    r: Math.round(r * 255),
    g: Math.round(g * 255),
    b: Math.round(b * 255)
  };
}

const rgb = {
  min: 20,  // 表示する最低温度
  max: 35,  // 表示する最高温度
  get: function (tmp) {
    // 温度から色(RGB)を取得
    let rate = 1 - (tmp - this.min) / (this.max - this.min);
    if (rate < 0) {
      rate = 0;
    } else if (rate > 1) {
      rate = 1;
    }
    // const h = 0.7*rate;
    const h = (Math.tanh(rate * 2 - 1.5) + 1) / 2 - 0.04; // 適当
    return HSVtoRGB(h, 1, 1);
  }
};

const cv = {
  canvas: null,
  content: null,
  imageData: null,
  createCanvas: function () {
    // Canvasを作成
    this.canvas = ge('canvas');
    this.context = this.canvas.getContext('2d');
    this.imageData = this.context.createImageData(32, 24);
  },
  createScale: function () {
    // スケールを作成
    const scale = ge('scale');
    let color, t, span;
    for (let i = 0; i < 100; i++) {
      t = i * (rgb.max - rgb.min) / 100 + rgb.min;
      span = ce('span');
      color = rgb.get(t);
      span.style.backgroundColor = span.style.color = 'rgb(' + color.r + ',' + color.g + ',' + color.b + ')';
      scale.appendChild(span);
    }
    this.createDivisions();
  },
  createDivisions: () => {
    // 目盛り作成
    ge("min-tmp").textContent = rgb.min;
    ge("max-tmp").textContent = rgb.max;

    var scaleDivisions = ge('scale-divisions');
    const divisions = gc('division');
    while (divisions.length > 0) {
      scaleDivisions.removeChild(divisions[0]);
    }
    let div;
    for (let temp = rgb.min + 5; temp < rgb.max; temp += 5) {
      div = ce('div');
      div.innerText = temp;
      div.classList.add("division");
      div.style.left = 640 * (temp - rgb.min) / (rgb.max - rgb.min) - 7 + 'px';
      scaleDivisions.appendChild(div);
    }
  },
  draw: function (tmps) {
    // 描画
    const data = this.imageData.data; // RGBA の順番のデータを含んだ 1次元配列。それぞれの値は 0 ~ 255 の範囲となります。
    if (data.length / 4 != tmps.length) {
      alert(なにかおかしいです);
      return;
    }
    let tmp, color, j, mirror;
    const maxValue = 255;
    for (let i = 0; i < tmps.length; i++) {
      if (true) {
        j = 4 * i;
      } else {
        // 左右反転させる
        mirror = (31 - i % 32) + parseInt(i / 32) * 32;
        j = 4 * mirror;
      }
      tmp = tmps[i];
      color = rgb.get(tmp);
      data[j] = color.r;
      data[j + 1] = color.g;
      data[j + 2] = color.b;
      data[j + 3] = maxValue;
    }
    this.context.putImageData(this.imageData, 0, 0);
  }
};

const divisionBtns = gc('divisionBtn');
for (let i = 0; i < divisionBtns.length; i++) {
  divisionBtns[i].addEventListener('click', function () {
    // 目盛り変更
    const d = 5; // 目盛りの間隔
    switch (this.id) {
      case 'min-down':
        if (rgb.min >= 5) rgb.min -= d;
        break;
      case 'min-up':
        if (rgb.min <= rgb.max - 2 * d) rgb.min += d;
        break;
      case 'max-down':
        if (rgb.min <= rgb.max - 2 * d) rgb.max -= d;
        break;
      case 'max-up':
        if (rgb.max <= 90) rgb.max += d;
        break;
    }
    cv.createDivisions();
  });
}

const predict = {
  mode: false,
  sw: function () {
    const elem = ge("predict");
    const result = ge("result");
    if (this.mode) {
      fetch("stop");
      elem.classList.remove('predictiong');
      elem.innerHTML = "判定する";
      result.style.display = "none";
      this.mode = false;
    } else {
      fetch("start");
      elem.classList.add('predictiong');
      elem.innerHTML = "判定中";
      result.style.display = "block";
      this.mode = true;
    }
  }
}

let onBodyLoad = function () {
  cv.createScale();
  skt.start();
  cv.createCanvas();

  ge("predict").onclick = predict.sw;
}

・app.css

body {
    display: flex;
    justify-content: center;
    align-items: center;
    background-color: black;
    color: #ffffff;
    font-size: 12px;
}
#container {
  position: relative;
}
#canvas {
        background: #666;
        width: 640px;
        height: 480px;
}
#scale, #scale-divisions {
  width: 100%;
  height: 24px;
}
#scale-divisions {
  position: relative;
}
#min-tmp-division {
  position: absolute;
  left: -20px;
}
#max-tmp-division {
  position: absolute;
  right: -20px;
}
.division {
  position: absolute;
}
#scale {
  display: flex;
}
#scale span {
  display: block;
  width: 1%;
  height: 23px;
}
#messages {
  overflow-y: auto;
}
.divisionBtn {
  cursor: pointer;
}
.divisionBtn:hover{
  color: #99b2ce;
}
#result {
  position: absolute;
  right: 1px;
  top: 1px;
  height: 18px;
  width: 50px;
  background: #faf62d;
  padding: 3px 10px;
  color: #000000;
  font-weight: bold;
  display: none;
  top: 30px;
}
#header {
  text-align: right;
  padding-bottom: 5px;
}
a#predict{
  position: relative;
  display: inline-block;
  font-weight: bold;
  padding: 0.25em 0.5em;
  text-decoration: none;
  color: rgb(242, 248, 250);
  background: rgb(20, 85, 7);
  transition: .4s;
  width: 80px;
  text-align: center;
}
a#predict.predictiong {
  background: rgb(230, 33, 7);
}
a#predict:hover{
    background: #5bf15b;
    color: white;
}       

@media only screen and (max-device-width: 480px) {
  body {
    font-size:24px;
  }
  #messages {
    margin-top:20px;
  }
}
結果

上記プログラムを実行した結果、だいたい正確に推論するようでした。
ただし背景に熱を持つもの、モニタとか蛍光灯等があると、ノイズが入ってうまくいかないです。

Facebooktwitterlinkedintumblrmail

タグ: ,

名前
E-mail
URL
コメント

日本語が含まれない投稿は無視されますのでご注意ください。(スパム対策)