Python Learning Path: MNIST in Practice, First Version
Published: 2019-06-21


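1. Project overview:

Build a shallow neural network that recognizes the handwritten digit images in MNIST.

2. Detailed steps:

(1) Flatten each 2-D image into a 1-D vector. MNIST images are 28*28 pixels, so each flattened vector has 784 elements.

(2) Define the network's hyperparameters: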
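As a small illustration of the flattening step, the sketch below turns a 28*28 grid into a 784-element list in row-major order. The `image` grid and the `flatten` helper are hypothetical stand-ins written for this example; the MNIST loader used later already returns flattened images.

```python
# Hypothetical stand-in for one MNIST image: a 28x28 grid of pixel values.
ROWS, COLS = 28, 28
image = [[(r * COLS + c) % 256 for c in range(COLS)] for r in range(ROWS)]


def flatten(image_2d):
    """Concatenate the rows of a 2-D image into one flat list (row-major order)."""
    return [pixel for row in image_2d for pixel in row]


flat = flatten(image)
print(len(flat))  # 28 * 28 = 784
```

With this layout, the pixel at row r and column c ends up at index r * 28 + c of the flat vector.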

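import tensorflow as tf  # TensorFlow 1.x API
from tensorflow.examples.tutorials.mnist import input_data

# Constants for the MNIST dataset and the network
INPUT_NODE = 784              # 28*28 pixels per image
OUTPUT_NODE = 10              # digits 0-9
LAYER1_NODE = 500             # hidden-layer size
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8      # initial learning rate
LEARNING_RATE_DECAY = 0.99    # learning-rate decay factor
REGULARIZATION_RATE = 0.0001  # weight of the L2 penalty
TRAINING_STEPS = 5000
MOVING_ACERTAGE_DECAY = 0.99  # decay of the moving average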
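To get a feel for how LEARNING_RATE_BASE and LEARNING_RATE_DECAY interact, here is a plain-Python sketch of the formula that tf.train.exponential_decay applies with its default staircase=False: base * decay ** (global_step / decay_steps). The value NUM_TRAIN_EXAMPLES = 55000 is an assumption (the usual size of the training split returned by read_data_sets), and `decayed_learning_rate` is a hypothetical helper, not part of the original code.

```python
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
BATCH_SIZE = 100
NUM_TRAIN_EXAMPLES = 55000  # assumption: default MNIST training split size


def decayed_learning_rate(global_step, decay_steps):
    """Continuous exponential decay: base * decay ** (global_step / decay_steps)."""
    return LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step / decay_steps)


# One pass over the training set takes this many steps.
decay_steps = NUM_TRAIN_EXAMPLES / BATCH_SIZE

for step in (0, 550, 5000):
    print(step, decayed_learning_rate(step, decay_steps))
```

Under these assumptions the learning rate shrinks by a factor of 0.99 per pass over the training data, so it changes only slowly over the 5000 training steps.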

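(3) Define a helper function that computes the network's forward-pass output. It is a separate function so that the same forward pass can be reused with the moving averages of the variables: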

def interface(input_tensor, avg_class, weights1, biases1, weights2, biases2):
    if avg_class is None:
        # No averaging class given: use the current variable values.
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        return tf.matmul(layer1, weights2) + biases2
    else:
        # Use the moving averages (shadow values) of the variables instead.
        layer1 = tf.nn.relu(
            tf.matmul(input_tensor, avg_class.average(weights1))
            + avg_class.average(biases1))
        return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)
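The moving-average idea behind avg_class can be sketched in plain Python. tf.train.ExponentialMovingAverage keeps a shadow copy of each variable and updates it as shadow = decay * shadow + (1 - decay) * variable; when a step counter is supplied as num_updates, the effective decay is min(decay, (1 + num_updates) / (10 + num_updates)). The helper names `ema_decay` and `ema_update` and the toy variable below are hypothetical, used only for illustration.

```python
def ema_decay(base_decay, num_updates):
    """Effective decay when a step counter is supplied: small early on, capped at base_decay."""
    return min(base_decay, (1 + num_updates) / (10 + num_updates))


def ema_update(shadow, value, decay):
    """One moving-average step: shadow moves a fraction (1 - decay) toward value."""
    return decay * shadow + (1 - decay) * value


# Toy example: a scalar "variable" that training increases by 1.0 each step.
variable = 0.0
shadow = variable  # the shadow copy starts at the variable's initial value
for step in range(1, 4):
    variable += 1.0
    shadow = ema_update(shadow, variable, ema_decay(0.99, step))
    print(step, variable, shadow)
```

The shadow value lags behind the variable, which smooths out step-to-step noise. Early in training the smaller effective decay lets the shadow copy catch up faster.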

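(4) Define the main training function:

The training function proceeds in this order: create the input and output placeholders, define the weights and biases of each layer, set up the moving average, compute the two outputs y and average_y, define the cross-entropy of y plus the regularization term, define the exponentially decaying learning rate, and train.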

def train(mnist):
    # Placeholders for the flattened images and the one-hot labels.
    x = tf.placeholder(dtype=tf.float32, shape=[None, INPUT_NODE], name="x_input")
    y_ = tf.placeholder(dtype=tf.float32, shape=[None, OUTPUT_NODE], name="y_output")

    # Hidden-layer parameters.
    weights1 = tf.Variable(tf.truncated_normal(shape=[INPUT_NODE, LAYER1_NODE], stddev=0.1))
    biases1 = tf.Variable(tf.constant(0.1, dtype=tf.float32, shape=[LAYER1_NODE]))

    # Output-layer parameters.
    weights2 = tf.Variable(tf.truncated_normal(shape=[LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
    biases2 = tf.Variable(tf.constant(0.1, dtype=tf.float32, shape=[OUTPUT_NODE]))

    # Forward pass with the current variable values.
    y = interface(x, None, weights1, biases1, weights2, biases2)

    # Moving averages of all trainable variables; global_step itself is not trainable.
    global_step = tf.Variable(0, trainable=False)
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_ACERTAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
    average_y = interface(x, variable_averages, weights1, biases1, weights2, biases2)

    # The cross-entropy is computed on y, not average_y: gradient descent updates
    # the variables themselves, while the moving averages are shadow copies that
    # are only used for evaluation. The sparse variant expects integer class
    # labels, hence the argmax over the one-hot y_.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)

    # L2 regularization on the weights only (not the biases).
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)

    loss = cross_entropy_mean + regularization

    # Exponentially decaying learning rate; one decay period is one pass over the training set.
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                               global_step,
                                               mnist.train.num_examples / BATCH_SIZE,
                                               LEARNING_RATE_DECAY)

    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Group the parameter update and the moving-average update into a single op.
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name="train")

    # Accuracy is measured with the averaged model.
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}

        for i in range(TRAINING_STEPS):
            # Report validation accuracy every 1000 steps.
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                print("After %d training step(s), validation accuracy using average model is %g "
                      % (i, validate_acc))
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})

        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training step(s), test accuracy using average model is %g"
              % (TRAINING_STEPS, test_acc))
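The loss built in train() is the mean softmax cross-entropy plus an L2 penalty on the weights. The plain-Python sketch below mirrors that structure on made-up numbers. The logits, labels, and weights are hypothetical, and `l2_penalty` assumes the rate * sum(w^2) / 2 convention used by tf.contrib.layers.l2_regularizer.

```python
import math


def softmax_cross_entropy(logits, label):
    """Cross-entropy of softmax(logits) against an integer class label."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def l2_penalty(weights, rate):
    """L2 penalty on a weight matrix: rate * sum(w^2) / 2."""
    return rate * sum(w * w for row in weights for w in row) / 2


# Made-up batch of two examples with three classes each.
logits_batch = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]]
labels = [0, 1]
weights = [[0.1, -0.2], [0.3, 0.4]]

cross_entropy_mean = sum(
    softmax_cross_entropy(z, y) for z, y in zip(logits_batch, labels)
) / len(labels)
loss = cross_entropy_mean + l2_penalty(weights, 0.0001)
print(cross_entropy_mean, loss)
```

With a rate as small as 0.0001 the penalty barely changes the loss value, but it still nudges the weights toward smaller magnitudes during training.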

(5) The main function:

def main(argv=None):
    mnist = input_data.read_data_sets(
        "C://Users/hasee/TensorFlow/实战TensorFlow代码/datasets/MNIST_data/",
        one_hot=True)
    train(mnist)

(6) Run the program:

if __name__ == "__main__":
    main()

 

Reposted from: https://www.cnblogs.com/doubest/p/10695369.html
