联系方式

  • QQ:99515681
  • 邮箱:99515681@qq.com
  • 工作时间:8:00-23:00
  • 微信:codinghelp

您当前位置:首页 >> javajava

日期:2018-09-24 01:41

#pragma once

#include "E:/c++/C++/BPCPLUSPLUS/Matrix.h"

#include "CNN.h"

// Global sigmoid activator instance. It is not referenced anywhere in the
// visible part of this file; presumably the network code pulled in through
// CNN.h refers to it -- TODO confirm.
SigmoidActivator sigmoidActivator;

// Loads samples from an MNIST idx file into an array of matrices.
// Arguments, in order (names taken from the definition further down):
//   FilePath, iTrainNum (number of samples), iHeight, iWidth,
//   MatTrainData (output array), iFlag (1 = images, 0 = labels),
//   scale (multiplier applied to each image byte).
void LoadData(char*, const int, int, int, Matrix[], int, double);

// Writes the per-layer weight and bias matrices of the given network to a
// text file named "parameter". The network is passed by value.
void saveParameters(FNN);

// Maps an error ratio to a learning rate using fixed thresholds.
double SelfFitErrorRatio(double);

int main(int argc, char* argv[])

{

const int iTrainNum = 60000, iTestNum = 10000;

//const int iTrainNum = 2, iTestNum = 1;

double last_error_ratio = 1.0;

Matrix MatTrainData[iTrainNum], MatTrainLabel[iTrainNum], MatTestData[iTestNum], MatTestLabel[iTestNum];

LoadData("./../../../MNISTDat/train-images-idx3-ubyte", iTrainNum, 28, 28, MatTrainData, 1, 1.0 / 255.0);

LoadData("./../../../MNISTDat/train-labels-idx1-ubyte", iTrainNum, 1, 10, MatTrainLabel, 0, 1.0 / 255.0);

LoadData("./../../../MNISTDat/t10k-images-idx3-ubyte", iTestNum, 28, 28, MatTestData, 1, 1.0 / 255.0);

LoadData("./../../../MNISTDat/t10k-labels-idx1-ubyte", iTestNum, 1, 10, MatTestLabel, 0, 1.0 / 255.0);

CNN cnn(1.0);

int epoch = 0;

double error_ratio = 10.0, rate;

for (int iterTime = 0; iterTime < 30; iterTime++)

{

epoch += 1;

rate = SelfFitErrorRatio(error_ratio);

rate = 0.1;

printf("%0.2lf\n", rate);

cnn.train(MatTrainData, MatTrainLabel, rate, 1, iTrainNum);

cnn.test(MatTestData, MatTestLabel, iTestNum);

}

system("pause");

return 0;

}


// Maps an error ratio to a learning rate using fixed, descending thresholds:
//
//   error_ratio >  1.0          -> 0.02
//   0.5 < error_ratio <= 1.0    -> 0.01
//   0.3 < error_ratio <= 0.5    -> 0.005
//   0.1 < error_ratio <= 0.3    -> 0.001
//   error_ratio <= 0.1          -> 0.0002
//
// A NaN input fails every comparison and therefore also yields the smallest
// rate (0.0002); previously that case returned an uninitialized value.
double SelfFitErrorRatio(double error_ratio)
{
    // Thresholds are tested from largest to smallest, so each test only
    // needs the lower bound of its range.
    if (error_ratio > 1.0)
        return 0.02;

    if (error_ratio > 0.5)
        return 0.01;

    if (error_ratio > 0.3)
        return 0.005;

    if (error_ratio > 0.1)
        return 0.001;

    // error_ratio <= 0.1, or NaN.
    return 0.0002;
}


// Dumps the weight matrix (m_W) and bias matrix (m_b) of every layer of
// `fnn` as text into a file named "parameter" in the current directory.
//
// Each layer is introduced by a header line (the runtime strings are
// Chinese for "layer N weight parameters" / "layer N bias parameters").
// Weight rows are written one per line; bias values follow their header on
// a single line. Every value uses the "%10.5lf" format.
//
// If the file cannot be opened, nothing is written.
void saveParameters(FNN fnn)
{
    FILE* pfout = fopen("parameter", "wb");

    // One guard around all output instead of one per fprintf call.
    if (pfout)
    {
        for (int layer = 0; layer < fnn.m_ilayerNum; ++layer)
        {
            const int layerNo = layer + 1;   // headers are 1-based

            fprintf(pfout, "\n第%d层权重参数为", layerNo);
            for (int r = 0; r < fnn.m_players[layer].m_W.m_iRow; ++r)
            {
                fprintf(pfout, "\n");
                for (int c = 0; c < fnn.m_players[layer].m_W.m_iCol; ++c)
                {
                    fprintf(pfout, "%10.5lf", fnn.m_players[layer].m_W(r, c));
                }
            }

            fprintf(pfout, "\n第%d层偏执参数为", layerNo);
            for (int r = 0; r < fnn.m_players[layer].m_b.m_iRow; ++r)
            {
                for (int c = 0; c < fnn.m_players[layer].m_b.m_iCol; ++c)
                {
                    fprintf(pfout, "%10.5lf", fnn.m_players[layer].m_b(r, c));
                }
            }
        }
    }

    // Project macro; invoked unconditionally, exactly as before.
    SAFE_FCLOSE(pfout);
}


// Prints `message`, closes `fp` if it is open, and terminates the process
// with exit code -1. Used for every unrecoverable error in LoadData().
static void LoadDataFail(FILE* fp, const char* message)
{
    printf("%s", message);
    if (fp != NULL)
    {
        fclose(fp);
    }
    exit(-1);
}

// Loads iTrainNum samples from an MNIST idx file into MatTrainData.
//
// Parameters:
//   FilePath     - path of the idx file to read.
//   iTrainNum    - number of samples to read; MatTrainData must have room
//                  for at least this many matrices.
//   iHeight      - image height in pixels (used only when iFlag == 1).
//   iWidth       - image width in pixels when iFlag == 1; otherwise the
//                  number of columns of each label vector.
//   MatTrainData - output array, one matrix per sample.
//   iFlag        - 1: read images; any other value: read labels. The label
//                  file header is skipped only when iFlag == 0.
//   scale        - multiplier applied to each image byte (unused for labels).
//
// Images become iHeight x iWidth matrices of byte * scale.
// Labels become 1 x iWidth vectors filled with 0.1, with 0.9 at the column
// equal to the label byte (left at 0.1 everywhere if the byte >= iWidth).
//
// On any failure (file cannot be opened, invalid size, short read) an error
// message is printed and the process exits with code -1.
void LoadData(char* FilePath, const int iTrainNum, int iHeight, int iWidth, Matrix MatTrainData[], int iFlag, double scale)
{
    FILE* fp = fopen(FilePath, "rb");
    if (fp == NULL)
    {
        LoadDataFail(NULL, "Open file error");
    }

    // The idx header is four 32-bit words for image files and two for label
    // files. Its values are not used, so it is skipped as raw bytes; this
    // avoids both byte-order conversion and any assumption about sizeof(int).
    unsigned char header[16];
    size_t headerBytes = 0;
    if (iFlag == 1)
    {
        headerBytes = 16;
    }
    else if (iFlag == 0)
    {
        headerBytes = 8;
    }
    if (headerBytes > 0 && fread(header, 1, headerBytes, fp) != headerBytes)
    {
        LoadDataFail(fp, "Read file header error\n");
    }

    // Scratch buffer for raw bytes. Images larger than the buffer are read
    // in several chunks, so the buffer size does not limit the image size.
    unsigned char buffer[1000];

    if (iFlag == 1)
    {
        if (iHeight < 0 || iWidth < 0)
        {
            LoadDataFail(fp, "Invalid image size\n");
        }

        const size_t pixelCount = (size_t)iHeight * (size_t)iWidth;

        for (int i = 0; i < iTrainNum; i++)
        {
            MatTrainData[i] = Matrix(iHeight, iWidth, 0.0);

            size_t done = 0;   // pixels of the current image already stored
            while (done < pixelCount)
            {
                size_t want = pixelCount - done;
                if (want > sizeof(buffer))
                {
                    want = sizeof(buffer);
                }

                if (fread(buffer, 1, want, fp) != want)
                {
                    LoadDataFail(fp, "Read image data error\n");
                }

                // Pixels are stored row by row: pixel p lives at
                // (p / iWidth, p % iWidth).
                for (size_t k = 0; k < want; k++)
                {
                    const size_t pixel = done + k;
                    const int row = (int)(pixel / (size_t)iWidth);
                    const int col = (int)(pixel % (size_t)iWidth);
                    MatTrainData[i](row, col) = ((double)buffer[k]) * scale;
                }

                done += want;
            }
        }
    }
    else
    {
        if (iWidth <= 0)
        {
            LoadDataFail(fp, "Invalid label vector size\n");
        }

        for (int i = 0; i < iTrainNum; i++)
        {
            if (fread(buffer, 1, 1, fp) != 1)
            {
                LoadDataFail(fp, "Read label data error\n");
            }

            // The vector width follows iWidth (10 for the MNIST callers in
            // this file) so the index written below can never exceed it.
            MatTrainData[i] = Matrix(1, iWidth, 0.1);

            const int label = (int)buffer[0];
            if (label < iWidth)
            {
                MatTrainData[i](0, label) = 0.9;
            }
        }
    }

    fclose(fp);
}


版权所有:留学生程序网 2020 All Rights Reserved 联系方式:QQ:99515681 电子信箱:99515681@qq.com
免责声明:本站部分内容从网络整理而来,只供参考!如有版权问题可联系本站删除。