使用C# 调用tensorflow和keras 训练样本
对于样本量较小、算力要求不高的项目，我们可以直接使用个人电脑的 CPU 进行训练。在工业自动化或其他一些特殊场合，我们常常习惯用 C# 等语言编写人机交互的前端。那么，对于这样的项目，如何调用 TensorFlow 或 Keras 来训练模型呢？结合之前发布的《C#来部署tensorflow的培训模型》一文，我们就可以用 C# 完成从图片加载、分类、训练到部署的全部操作，从而实现一个完整的 AI 应用项目。
应用准备
本示例使用 VS 2015、Python 3.6、Django 2.1 和 TensorFlow 2.0。
实现方法
使用 Python 构建一个 Django 后台，由后台对样本进行训练并生成模型。C# 通过 System.Net 以 HTTP 协议与后台通信、传递数据。
django 后台
配置路由
# Route table: POST/GET /train/ is handled by the training entry view.
urlpatterns = [
    url(r'^train/$', views.train_start, name='train_start'),
]
配置视图
务必新建一个线程来执行训练。HTTP 请求应尽快返回响应，避免训练过程长时间占用服务器处理请求的资源。
def train_start(request):
    """Django view that decodes a JSON command and dispatches it to ``Manager``.

    The command is read from the raw request body for POST, or from the
    ``data`` query parameter for GET; any other method gets ``'access deny'``.
    A ``'TRAIN'`` command is started on a background thread so the HTTP
    request returns immediately with ``{"return": 0}``; any other event is run
    synchronously and its result is returned as JSON.

    Malformed requests (missing ``data`` parameter, invalid JSON, a payload
    that is not a JSON object, or no ``Event`` key) get a JSON error reply
    ``{"return": -1, "error": ...}`` instead of an unhandled exception
    (HTTP 500).

    NOTE(review): Django's CSRF middleware rejects POSTs without a token
    unless this view is ``@csrf_exempt`` - confirm how it is registered.
    """
    if request.method == 'POST':
        raw = request.body
    elif request.method == 'GET':
        raw = request.GET.get('data')
    else:
        return HttpResponse('access deny')

    if raw is None:
        return HttpResponse(json.dumps({'return': -1, 'error': 'missing data'}))
    try:
        rev = json.loads(raw)
    except ValueError:  # JSONDecodeError and UnicodeDecodeError both subclass ValueError
        return HttpResponse(json.dumps({'return': -1, 'error': 'invalid JSON'}))
    if not isinstance(rev, dict) or 'Event' not in rev:
        return HttpResponse(json.dumps({'return': -1, 'error': 'missing Event'}))

    event = rev['Event']
    if event == 'TRAIN':
        # Training can take a long time; run it off the request thread so the
        # client gets an immediate acknowledgement.
        threading.Thread(target=Manager, args=[rev, event]).start()
        res = {'return': 0}
    else:
        res = Manager(rev, event)
    return HttpResponse(json.dumps(res))
执行训练
任务分配
from AICore.MainTest.Train import Train
from AICore.MainTest.Msg import GetMsg

# Event name -> handler.  Built once at import time rather than on every call.
_EVENT_HANDLERS = {
    'TRAIN': Train,
    'CHECK': GetMsg,
}


def Manager(rec, event):
    """Dispatch a decoded client command to the handler registered for *event*.

    Args:
        rec: The decoded JSON command; passed unchanged to the handler.
        event: Event name, ``'TRAIN'`` or ``'CHECK'``.

    Returns:
        Whatever the handler returns.  For an unknown event a JSON-serialisable
        error dict ``{'return': -1, 'error': ...}`` is returned instead of
        raising ``KeyError`` (which surfaced as an HTTP 500 in the view).
    """
    func = _EVENT_HANDLERS.get(event)
    if func is None:
        return {'return': -1, 'error': 'unknown event: {}'.format(event)}
    return func(rec)
调用训练
def Train(rec):
    """Load the training images and train the model type selected by the client.

    When *rec* is non-empty, the project/mode settings it carries overwrite the
    module-level training parameters (image size, batch size, model name,
    epoch count, model type) and the sample directory is
    ``server_dir + str(<project iID>)``.  When *rec* is empty, the previously
    set globals are reused with the fallback directory ``train_dir + '4'``.

    Raises:
        ValueError: if ``ModelType`` is neither ``'SEQModel'`` nor ``'VGG19'``.
            Checked before the (slow) image loading; previously an unknown
            type loaded every image and then silently did nothing.
    """
    # Only the names assigned below need ``global``; server_dir, train_dir
    # and labels are only read.
    global IMG_W, IMG_H, BATCH_SIZE, ModelName, Echos, ModelType

    if rec:
        project = rec['project']
        mode = rec['Mode']
        Dir = server_dir + str(project['iID'])
        IMG_W = project['fWidth']
        # 'fHeighth' spelling presumably mirrors the C# client's field name - do not "fix" here alone.
        IMG_H = project['fHeighth']
        ModelName = project['strPN']
        BATCH_SIZE = mode['iBatchSize']
        Echos = mode['iEchos']
        ModelType = mode['strName']
    else:
        Dir = train_dir + '4'

    if ModelType not in ('SEQModel', 'VGG19'):
        raise ValueError('unsupported model type: {!r}'.format(ModelType))

    # Training data and labels.
    train, train_label = get_file(Dir, labels)
    X_train, X_val, Y_train, Y_val = get_batch(train, train_label, IMG_W, IMG_H)

    args = (X_train, X_val, Y_train, Y_val, IMG_W, IMG_H, BATCH_SIZE, Dir, ModelName, Echos)
    if ModelType == 'SEQModel':
        seq_model(*args)
    else:
        VGG19Model(*args)
关键代码，以序贯（Sequential）模型为例。
import os

import tensorflow.compat.v1 as tf1
import numpy as np
import matplotlib.pyplot as plt

# Directory for the intermediate Keras .h5 files.  Same value as before, but
# every backslash is now escaped (the old literal relied on the invalid escape
# sequence "\B", which Python warns about).
ModelPath = "D:\\AI_Vision\\AIServer\\AICore\\BaseModel\\"

# Progress/log text built by seq_model() and read via SEQModel_Msg().
# Written by the training thread and read by request-handling code.
Msg = ''


def SEQModel_Msg(MsgEx):
    """Return the pending progress text, clearing it once the caller has seen it.

    If *MsgEx* equals the current message, the message is reset to ``''`` and
    ``''`` is returned; otherwise the current message is returned unchanged.
    """
    global Msg
    if MsgEx == Msg:
        Msg = ''
    return Msg


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze *session*'s graph into a GraphDef with variables turned into constants.

    Args:
        session: TF1 session whose graph and variable values are frozen.
        keep_var_names: Unused; kept for interface compatibility.
        output_names: Names of the output ops to keep (everything they do not
            depend on is pruned).
        clear_devices: If true, device placements are cleared so the graph can
            be loaded on any device.

    Returns:
        The frozen GraphDef, with training-only nodes removed.
    """
    graph = session.graph
    with graph.as_default():
        output_names = output_names or []
        print("output_names", output_names)
        input_graph_def = graph.as_graph_def()
        print("len node1", len(input_graph_def.node))
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf1.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names)
        # Strip nodes that are not needed for inference.
        outgraph = tf1.graph_util.remove_training_nodes(frozen_graph)
        print("##################################################################")
        for node in outgraph.node:
            print('node:', node.name)
        print("len node1", len(outgraph.node))
        return outgraph


def _accuracy_keys(history):
    """Return the (train, validation) accuracy keys used by this Keras version."""
    if 'accuracy' in history.history:
        return 'accuracy', 'val_accuracy'
    return 'acc', 'val_acc'


def _save_curve(path, train_values, val_values, train_label, val_label,
                title, ylabel, legend_loc):
    """Plot one training/validation curve pair to *path* and free the figure."""
    fig = plt.figure()
    try:
        plt.plot(train_values, label=train_label)
        plt.plot(val_values, label=val_label)
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel('epoch')
        plt.legend(loc=legend_loc)
        fig.savefig(path)
    finally:
        # pyplot keeps every figure alive until it is closed; in a long-running
        # server each training run would otherwise leak two figures.
        plt.close(fig)


def showCurve(dir, history):
    """Save ``accuracy.png`` and ``loss.png`` for *history* under ``<dir>/TrainModel``."""
    out_dir = os.path.join(dir, "TrainModel")
    acc_key, val_acc_key = _accuracy_keys(history)
    hist = history.history
    _save_curve(os.path.join(out_dir, "accuracy.png"), hist[acc_key], hist[val_acc_key],
                'training acc', 'val acc', 'model accuracy', 'accuracy', 'lower right')
    _save_curve(os.path.join(out_dir, "loss.png"), hist['loss'], hist['val_loss'],
                'training loss', 'val loss', 'model loss', 'loss', 'upper right')


def _build_seq_model(img_width, img_height):
    """Build and compile the 3 x (Conv-ReLU-MaxPool) + Dense binary classifier."""
    layers = tf1.keras.layers
    model = tf1.keras.models.Sequential()
    # NOTE(review): Keras reads input_shape as (rows, cols, channels), i.e.
    # (height, width, 3); this passes (width, height, 3).  Equivalent only for
    # square images - confirm against get_batch's output shape.
    model.add(layers.Conv2D(32, (3, 3), input_shape=(img_width, img_height, 3)))
    model.add(layers.Activation('relu'))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    for filters in (32, 64):
        model.add(layers.Conv2D(filters, (3, 3)))
        model.add(layers.Activation('relu'))
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(64))
    model.add(layers.Activation('relu'))
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(1))
    model.add(layers.Activation('sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
    return model


def seq_model(X_train, X_val, Y_train, Y_val, img_width, img_height, batch_size, dir, modelName,
              echos):
    """Train a small sequential CNN binary classifier and export it as a frozen ``.pb``.

    Progress (model summary, per-epoch metrics, input/output tensor names) is
    accumulated in the module-level ``Msg``, which is set to ``'Finish'`` when
    everything has been written.

    Args:
        X_train, X_val: Training / validation images (converted with np.array).
        Y_train, Y_val: Matching binary labels.
        img_width, img_height: Image size used for the model's input shape.
        batch_size: Batch size for both generators.
        dir: Project directory; outputs go to ``<dir>/TrainModel``.
        modelName: Base name of the exported ``.pb`` file.
        echos: Number of training epochs.
    """
    global Msg
    # Start a fresh log for this run; previously a second run appended to the
    # previous run's final 'Finish' text.
    Msg = ''

    nb_train_samples = len(X_train)
    nb_validation_samples = len(X_val)

    model = _build_seq_model(img_width, img_height)
    stringlist = []
    model.summary(print_fn=lambda x: stringlist.append(x))
    Msg += "\n".join(stringlist) + "\r\n"

    # Augment only the training data; validation images are just rescaled so
    # val_loss / val_accuracy measure performance on the real images.
    train_datagen = tf1.keras.preprocessing.image.ImageDataGenerator(
        rescale=1. / 255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
    val_datagen = tf1.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
    train_generator = train_datagen.flow(np.array(X_train), Y_train, batch_size=batch_size)
    validation_generator = val_datagen.flow(np.array(X_val), Y_val, batch_size=batch_size)

    history = model.fit_generator(
        train_generator,
        # At least one step: with fewer samples than batch_size the floor
        # division yields 0 steps, which Keras rejects.
        steps_per_epoch=max(1, nb_train_samples // batch_size),
        epochs=echos,
        validation_data=validation_generator,
        validation_steps=max(1, nb_validation_samples // batch_size))

    acc_key, val_acc_key = _accuracy_keys(history)
    hist = history.history
    Msg += "loss" + str(hist['loss']) + "\r\n"
    Msg += "val_loss" + str(hist['val_loss']) + "\r\n"
    Msg += "accuracy" + str(hist[acc_key]) + "\r\n"
    Msg += "val_accuracy" + str(hist[val_acc_key]) + "\r\n"

    # File name spelling (sic) kept: other code may load it by this name.
    model.save_weights(ModelPath + 'model_wieghts.h5')
    model.save(ModelPath + 'model_keras.h5')
    Msg += 'input is :' + model.input.name + "\r\n" + 'output is:' + model.output.name + "\r\n"

    # Reload the saved model in TF1 graph mode and freeze it into a .pb.
    # Order matters: reset the graph, fix the learning phase to inference
    # (must happen before the model is loaded), then disable TF2 behaviour.
    tf1.reset_default_graph()
    tf1.keras.backend.set_learning_phase(0)
    tf1.disable_v2_behavior()
    # NOTE(review): disable_v2_behavior() and set_learning_phase(0) are
    # process-wide, and TF documents disable_v2_behavior for program start-up.
    # Later training runs in this server process inherit both - confirm they
    # still train as intended (e.g. Dropout active).
    network = tf1.keras.models.load_model(ModelPath + 'model_keras.h5')
    frozen_graph = freeze_session(tf1.keras.backend.get_session(),
                                  output_names=[out.op.name for out in network.outputs])
    tf1.train.write_graph(frozen_graph, os.path.join(dir, "TrainModel"), modelName + ".pb",
                          as_text=False)
    showCurve(dir, history)
    Msg = 'Finish'
前端代码
创建一个 PostMan 类，用于向后台发送请求并获取返回信息。
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Net;
using System.IO;
using Newtonsoft.Json;

namespace LT.Device
{
    /// <summary>
    /// Minimal synchronous HTTP client used to talk to the Django training back end.
    /// Both methods return the response body on success and the exception message on failure.
    /// </summary>
    class PostMan
    {
        /// <summary>
        /// POSTs <paramref name="paramData"/>, encoded with <paramref name="dataEncode"/>,
        /// to <paramref name="postUrl"/> and returns the response body (5 s timeout).
        /// </summary>
        public string PostWebRequest(string postUrl, string paramData, Encoding dataEncode)
        {
            try
            {
                byte[] byteArray = dataEncode.GetBytes(paramData);
                HttpWebRequest webReq = (HttpWebRequest)WebRequest.Create(new Uri(postUrl));
                webReq.Method = "POST";
                webReq.ContentType = "application/x-www-form-urlencoded";
                webReq.Timeout = 5000;
                webReq.ContentLength = byteArray.Length;
                // using: the request stream is closed even if Write throws.
                using (Stream newStream = webReq.GetRequestStream())
                {
                    newStream.Write(byteArray, 0, byteArray.Length);
                }
                return ReadResponse(webReq);
            }
            catch (Exception ex)
            {
                return ex.Message;
            }
        }

        /// <summary>
        /// Sends a GET to <paramref name="postUrl"/> with <paramref name="paramData"/> passed as the
        /// <c>data</c> query parameter (the name the Django view reads from request.GET).
        /// A GET cannot carry a request body - HttpWebRequest.GetRequestStream throws
        /// ProtocolViolationException for it - so the payload goes in the URL.
        /// The query is percent-encoded as UTF-8 (RFC 3986); <paramref name="dataEncode"/> is
        /// accepted for signature compatibility but not used.
        /// </summary>
        public string GetWebRequest(string postUrl, string paramData, Encoding dataEncode)
        {
            try
            {
                string url = postUrl;
                if (!string.IsNullOrEmpty(paramData))
                {
                    string separator = postUrl.Contains("?") ? "&" : "?";
                    url = postUrl + separator + "data=" + Uri.EscapeDataString(paramData);
                }
                HttpWebRequest webReq = (HttpWebRequest)WebRequest.Create(new Uri(url));
                webReq.Method = "GET";
                webReq.Timeout = 5000;
                return ReadResponse(webReq);
            }
            catch (Exception ex)
            {
                return ex.Message;
            }
        }

        /// <summary>
        /// Reads the whole response body and disposes the response, even on error.
        /// Decoded as UTF-8, Django's default response charset; Encoding.Default (the
        /// system ANSI code page, e.g. GBK) garbled non-ASCII replies.
        /// </summary>
        private static string ReadResponse(HttpWebRequest webReq)
        {
            using (HttpWebResponse response = (HttpWebResponse)webReq.GetResponse())
            using (StreamReader sr = new StreamReader(response.GetResponseStream(), Encoding.UTF8))
            {
                return sr.ReadToEnd();
            }
        }
    }
}
在线程中调用上面编写的类来发送信息。
#region Method StepsCode
/// <summary>
/// One tick of the training state machine; the current step is numberActiveStep.
/// 0: init, 1: CopyFolder to the FTP path, 2: send TRAIN, 3: poll CHECK,
/// 4-8: wait and loop back to 3, 9: stop after failure, 10: stop after "Finish".
/// </summary>
protected override void StepsCode(ref DataBaseStruct dataDirector, ref ChainsElement chains)
{
    PostMan man = new PostMan();
    PostParamStruct param = new PostParamStruct();
    switch (numberActiveStep)
    {
        case 0: // initial
            RootDir = BasePath + dataDirector.CurrProject.iID.ToString();
            viewModeLocal.UpdateMessage("", true);
            GoToNextStep(0.2);
            break;
        case 1: // Check Model Exist (CopyFolder presumably returns 1 on success - confirm)
            if (1 == CopyFolder(RootDir, Properties.Settings.Default.FTPPath))
            {
                GoToNextStep(0.1);
            }
            else
                GoToStep(9, 0.1);
            break;
        case 2: // Send TRAIN request
            param.Event = "TRAIN";
            param.project = dataDirector.CurrProject;
            param.Labels = dataDirector.LabelCatetorys;
            param.Mode = dataDirector.CurrMode;
            strPram = JsonConvert.SerializeObject(param);
            try
            {
                string ret = man.PostWebRequest("http://127.0.0.1:8000/ai/train/", strPram, Encoding.UTF8);
                JObject jo = JObject.Parse(ret);
                viewModeLocal.UpdateMessage(strPram);
                // The back end acknowledges a started training run with {"return": 0}.
                // Anything else means training was not started, so stop instead of polling forever.
                JToken ack = jo["return"];
                if (ack != null && ack.ToString() == "0")
                    GoToNextStep(0.5);
                else
                {
                    viewModeLocal.UpdateMessage(ret);
                    GoToStep(9, 0.5);
                }
            }
            catch (Exception e)
            {
                viewModeLocal.UpdateMessage(e.Message);
                GoToStep(9, 0.5);
            }
            break;
        case 3: // Poll progress with CHECK; "Finish" ends the run
            param.Event = "CHECK";
            strPram = JsonConvert.SerializeObject(param);
            try
            {
                string ret = man.PostWebRequest("http://127.0.0.1:8000/ai/train/", strPram, Encoding.UTF8);
                JObject jo = JObject.Parse(ret);
                ret = jo["return"].ToString();
                if (ret == "Finish")
                    GoToStep(10, 0.5);
                else
                {
                    if (ret != "")
                        viewModeLocal.UpdateMessage(jo["return"].ToString());
                    GoToNextStep(0.5);
                }
            }
            catch (Exception e)
            {
                viewModeLocal.UpdateMessage(e.Message);
                GoToStep(9, 0.5);
            }
            break;
        case 4: //...
            GoToNextStep(0.1);
            break;
        case 5: //...
            GoToNextStep(0.1);
            break;
        case 6: //...
            GoToNextStep(0.5);
            break;
        case 7: // Check Msg
            GoToNextStep(0.5);
            break;
        case 8: //... loop back to polling
            GoToStep(3, 0.5);
            break;
        case 9: //... failure: stop
            this.Enable = false;
            break;
        case 10: // finished: stop
            this.Enable = false;
            break;
    }
}
#endregion
最终效果
c# 显示训练的过程信息
显示损失-精度图谱