-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathcnn_model.py
35 lines (28 loc) · 1.13 KB
/
cnn_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 1 14:32:58 2017
@author: Xiangyong Cao
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib import slim
from utils import patch_size
# Dataset-dependent sizes used by conv_net() below; edit these when switching
# datasets. The author's earlier notes listed 162, 220 and 103 as alternative
# band counts and 16 or 9 as alternative class counts. Which values pair up is
# not shown here (presumably 220 with 16 and 103 with 9 -- confirm against
# the data loader before changing them).
num_band = 220    # input channels per pixel (last axis of the reshaped input)
num_classes = 16  # width of the final logits layer
def conv_net(x, bands=None, classes=None, patch=None):
    """Build a two-conv / two-dense CNN and return unnormalised class logits.

    Each example in ``x`` is reshaped into a ``patch x patch`` image with
    ``bands`` channels (NHWC layout). It is then passed through:

    * 5x5 conv, 100 filters, VALID padding, ReLU -> 2x2 max-pool (SAME)
    * 3x3 conv, 300 filters, VALID padding, ReLU -> 2x2 max-pool (SAME)
    * flatten -> dense 200 (ReLU) -> dense 100 (ReLU)
    * dense ``classes`` with no activation (the logits)

    Args:
        x: Input tensor whose elements can be reshaped to
            ``[-1, patch, patch, bands]``. NOTE(review): the element order
            is assumed to already be row-major height/width/channel; confirm
            against the data pipeline.
        bands: Number of input channels. Defaults to the module-level
            ``num_band``, read at call time so reassigning that global still
            takes effect.
        classes: Number of output logits. Defaults to the module-level
            ``num_classes``, read at call time.
        patch: Spatial side length of each example. Defaults to
            ``patch_size`` imported from ``utils``, read at call time.
            NOTE(review): with the VALID convolutions and slim's default
            pooling stride of 2, ``patch`` must be at least 9 for the second
            convolution to receive a non-empty input; confirm the value
            provided by ``utils``.

    Returns:
        A 2-D tensor of logits with ``classes`` columns and one row per
        example. No softmax is applied.
    """
    # Resolve defaults here rather than in the signature. This reads the
    # module globals at call time, exactly as the hard-coded original did.
    if bands is None:
        bands = num_band
    if classes is None:
        classes = num_classes
    if patch is None:
        patch = patch_size

    # ReLU is the default activation for every conv / dense layer in scope.
    # The final logits layer overrides it with None.
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.relu):
        net = tf.reshape(x, [-1, patch, patch, bands])

        # Block 1: 5x5 conv without padding, then 2x2 max-pool.
        net = slim.conv2d(net, 100, 5, padding='VALID',
                          weights_initializer=tf.contrib.layers.xavier_initializer())
        net = slim.max_pool2d(net, 2, padding='SAME')

        # Block 2: 3x3 conv without padding, then 2x2 max-pool.
        net = slim.conv2d(net, 300, 3, padding='VALID',
                          weights_initializer=tf.contrib.layers.xavier_initializer())
        net = slim.max_pool2d(net, 2, padding='SAME')

        # Classifier head.
        net = slim.flatten(net)
        net = slim.fully_connected(net, 200)
        net = slim.fully_connected(net, 100)
        logits = slim.fully_connected(net, classes, activation_fn=None)
    return logits