《BP算法 java實現(xiàn)》由會員分享,可在線閱讀,更多相關《BP算法 java實現(xiàn)(5頁珍藏版)》請在裝配圖網(wǎng)上搜索。
1、package backp;
import java.*;
import java.awt.*;
import java.io.*;
import java.util.Scanner;
//by realmagician
import org.omg.CORBA.portable.InputStream;
public class backpro {
public static void main(String args[])
{
String filename=new String("delta.in");
try {
FileInpu
2、tStream fileInputStream=new FileInputStream(filename);
Scanner sinScanner=new Scanner(fileInputStream);
int attN,hidN,outN,samN;
attN=sinScanner.nextInt();
outN=sinScanner.nextInt();
hidN=sinScanner.nextInt();
samN=sinScanner.nextInt();
//System.out.println(attN+" "+outN+
3、" "+hidN+" "+samN);
double samin[][]=new double[samN][attN];
double samout[][]=new double[samN][outN];
for(int i=0;i
4、tDouble();
}
}
int times=10000;
double rate=0.5;
BP2 bp2=new BP2(attN,outN,hidN,samN,times,rate);
bp2.train(samin, samout);
for(int i=0;i
5、i=0;i
6、tN;++i)
{
testin[i]=testinScanner.nextDouble();
}
testout=bp2.getResault(testin);
for(int i=0;i
7、);
}
}
class BP2//包含一個隱含層的神經網(wǎng)絡
{
double dw1[][],dw2[][];
int hidN;//隱含層單元個數(shù)
int samN;//學習樣例個數(shù)
int attN;//輸入單元個數(shù)
int outN;//輸出單元個數(shù)
int times;//迭代次數(shù)
double rate;//學習速率
boolean trained=false;//保證在得結果前,先訓練
BP2(int attN,int outN,int hidN,int samN,int times,double rate)
{
this.
8、attN=attN;
this.outN=outN;
this.hidN=hidN;
this.samN=samN;
dw1=new double[hidN][attN+1];//每行最后一個是閾值w0
for(int i=0;i
9、;i
10、le tempout[]=new double[outN];
double wcout[]=new double[outN];
double wchid[]=new double[hidN];
while((count--)>0)//迭代訓練
{
dis=0;
for(int i=0;i
11、 temphid[j]+=dw1[j][k]*samin[i][k];
temphid[j]+=dw1[j][attN];//計算閾值產生的隱含層結果
temphid[j]=1.0/(1+Math.exp(-temphid[j] ));
}
for(int j=0;j
12、=dw2[j][hidN];//計算閾值產生的輸出結果
tempout[j]=1.0/(1+Math.exp( -tempout[j] ));
}
//計算每個輸出單元的誤差項
for(int j=0;j
13、 for(int j=0;j
14、][k]+=rate*wcout[j]*temphid[k];
}
dw2[j][hidN]=rate*wcout[j];
}
//改變隱含層的權值
for(int j=0;j
15、k;
}
trained=true;
}
public double[] getResault(double samin[])
{
double temphid[]=new double[hidN];
double tempout[]=new double[outN];
if(trained==false)
return null;
for(int j=0;j
16、id[j]+=dw1[j][k]*samin[k];
temphid[j]+=dw1[j][attN];//計算閾值產生的隱含層結果
temphid[j]=1.0/(1+Math.exp(-temphid[j] ));
}
for(int j=0;j