コード内容
コード保存のため投稿。コードは汚い。
SONYが開発した、neural network console(NNC)をCUIから実行。SONYが提供しているNNablaライブラリを使用する。
準備
- NNCのProject File(*.sdcproj)
- Neural networkのArchitecture(*.nntxt)
- 学習したparameter(*.h5)
- Testするinput data(*.csv)
実行
start_DL.sh
#!/bin/bash#delete previous outputrm output_result.csv
#folder_dir=`pwd`sdcproj=`ls*.sdcproj`testdata=`ls*.csv`network=`ls*.nntxt`parameters=`ls*.h5`#Value test data
nnabla_cli forward -c$network-p$parameters-d$testdata-o ./
#nnabla_cli forward -c net.nntxt -p parameters.h5 -d output_result.csv -o ./#calculate accuracy
python calc_accuracy.py
結果の確認
脳SPECT 認知症分類
脳SPECT画像から認知症を分類するAIを作成。
AIの精度の確認はPythonを用いて以下のコードを実行。
4クラス分類なので、4クラスごとにprobabilityが算出される。
この4クラスの内、最大のprobabilityをAIの推定したクラスとする。
その後、実際のラベルと比較することでAccuracyを計測。
calc_accuracy.py
importcsv,osimportnumpyasnp#for confusion matrix
fromsklearn.metricsimportconfusion_matriximportpandasaspdimportseabornassnsimportmatplotlib.pyplotasplttotal_cnt=0match_cnt=0csv_filename="output_result.csv"csv_obj=open(csv_filename)reader_obj=csv.reader(csv_obj)org_values=list()result_values=list()forrowinreader_obj:ifreader_obj.line_num==1:continue# 最初の行をスキップする
org_label=int(row[1])results=np.array(row[2:11],np.float32)result_label=int(results.argmax())total_cnt+=1if(org_label==result_label):match_cnt+=1org_values.append(org_label)result_values.append(result_label)print("match_cnt={0} : total_cnt={1}".format(match_cnt,total_cnt))print("Accuracy:{0:.3f}\n".format(match_cnt/total_cnt))print(org_values)print(result_values)defprint_cmx(org_values,result_values):label_name=['AD','DLB','FTLD','NPH']cmx_data=confusion_matrix(org_values,result_values)df_cmx=pd.DataFrame(cmx_data,index=label_name,columns=label_name)print(df_cmx)title="Accuracy:"+str(match_cnt/total_cnt)sns.heatmap(df_cmx,annot=True,square=True)plt.title(title)plt.ylabel('Estimated Age')plt.xlabel('True Age')plt.figure()plt.show()plt.savefig('confusion_matrix.png')print_cmx(org_values,result_values)csv_obj.close()
脳MRI 小児年齢推定
脳MRI画像から小児の年齢を推定。
#!/bin/bash<<COMMENTOUT
#folder archtecher
0_12_months(or 13_months)[working directory]
├── start_evaluation.sh
├── images
│ ├── select
│ │ ├── T1_0.25m_01 (12)_cropped.bmp
│ │ ├── T1_0.25m_01 (16)_cropped.bmp
│ │ ├── T1_0.25m_01 (19)_cropped.bmp
│ │ ├── ・
│ │ ├── ・
│ │ ├── ・
│ ├── select_T2
├── net.nntxt
├── parameters.h5
└── test.csv
COMMENTOUT
#delete previous alldata folderif[-e output_result.csv ];then
rm-f output_result.csv progress.txt result.txt
fi#copy files in each folders to alldata folder unset testdata network parameters
testdata=`ls*.csv`network=`ls*.nntxt`parameters=`ls*.h5`#value test data using NNabla
nnabla_cli forward -c$network-p$parameters-d$testdata-o.#show outputcounter=0
while read row;do
counter=$((${counter}+1))#remove title in first rowif[$counter-eq 1 ]then
continue
fi
file_name=`echo${row} |cut -d , -f 1`file_name=`basename$file_name |sed -e's/T1_//g'`true_age=`echo${row} |cut -d , -f 11`estimated_age=`echo${row} |cut -d , -f 12`printf"File name : ${file_name}\nTrue Age : ${true_age} months\nEstimated Age : %.2f months\n\n"${estimated_age}>> result.txt
done< output_result.csv
cat result.txt