-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfully_connect_filter.cpp
More file actions
51 lines (41 loc) · 1.37 KB
/
fully_connect_filter.cpp
File metadata and controls
51 lines (41 loc) · 1.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include "fully_connect_filter.hpp"
#include "block.hpp"
namespace fool{
// Initializes the two learnable parameter blocks of the filter.
//
// m_lr_params[0] (weights) is filled with the "gaussian" filler built from
// the arguments (0.1, 0.001). m_lr_params[1] (bias) is always present, since
// there is no bias_term switch, and is filled with the "constant" filler
// built from 1.0.
// NOTE(review): the meaning of the numeric filler arguments is defined by
// GetFiller. They are presumably mean/std for "gaussian" and the fill value
// for "constant"; confirm. m_lr_params must hold at least two entries.
template <typename Dtype>
void FullyConnectFilter<Dtype>::FilterInitialize() {
  // Weights first; both fillers stay alive until the end of the function,
  // matching the original lifetime and destruction order.
  const std::shared_ptr<FillerFilter<Dtype>> gaussian_filler(
      GetFiller<Dtype>("gaussian", 0.1, 0.001));
  gaussian_filler->Fill(this->m_lr_params[0].get());

  // Then the bias term, which this filter always uses.
  const std::shared_ptr<FillerFilter<Dtype>> constant_filler(
      GetFiller<Dtype>("constant", 1.0));
  constant_filler->Fill(this->m_lr_params[1].get());
}
// Shapes the output block and the bias-multiplier block for the current
// input.
//
// - m_M is set to count(0, 1) of inputs[0], i.e. the size of its axis 0
//   (the number of rows fed to the matrix product).
// - outputs[0] is reshaped to the 2-D shape (shape[0], m_N) of inputs[0].
//   For a 2-D input this is exactly what the previous implementation
//   produced.
// - m_output_bias is resized to m_M elements, all set to 1. Presumably it is
//   the ones vector used to broadcast the bias across the m_M rows in the
//   forward/backward passes, which are not implemented yet; confirm.
//
// @param inputs  input blocks; only inputs[0] is read. It must be non-empty
//                and inputs[0] must have at least one axis.
// @param outputs output blocks; only outputs[0] is reshaped. It must be
//                non-empty.
template<typename Dtype>
void FullyConnectFilter<Dtype>::Reshape(
    const std::vector<Block<Dtype>*>& inputs,
    const std::vector<Block<Dtype>*>& outputs) {
  m_M = inputs[0]->count(0, 1);

  // A fully connected product yields an (m_M, m_N) result. Copying the whole
  // input shape and overwriting only axis 1 kept any trailing axes: a 4-D
  // (M, C, H, W) input produced (M, m_N, H, W), which is H*W times too many
  // elements. It also indexed past the end of a 1-D shape. Build the 2-D
  // shape explicitly instead.
  const std::vector<int>& input_shape = inputs[0]->shape();
  std::vector<int> output_shape(2);
  output_shape[0] = input_shape[0];
  output_shape[1] = m_N;
  outputs[0]->SyncedBlock(output_shape);

  // default use bias_term: one multiplier per row, all set to 1.
  // Kept non-const in case SyncedBlock takes its argument by non-const ref.
  std::vector<int> output_bias_shape(1, m_M);
  m_output_bias.SyncedBlock(output_bias_shape);
  fool_set(m_M, Dtype(1), m_output_bias.mutable_cpu_data());
}
// CPU forward pass of the fully connected filter.
//
// NOTE(review): this is currently a stub. The body is empty, so outputs[0]
// is never written and keeps whatever data it held before the call.
// TODO: implement output = input x weights (+ bias). Presumably this uses
// m_lr_params[0] as the weights, m_lr_params[1] as the bias and
// m_output_bias (the length-m_M ones vector set up in Reshape) to broadcast
// the bias. Confirm the weight layout (m_N x m_K vs m_K x m_N) before
// implementing.
//
// @param inputs  input blocks; the draft below only references inputs[0].
// @param outputs output blocks; outputs[0] is expected to have been shaped
//                by Reshape().
template<typename Dtype>
void FullyConnectFilter<Dtype>::Forward_cpu(
const std::vector<Block<Dtype>*>& inputs,
const std::vector<Block<Dtype>*>& outputs){
// Incomplete draft of the matrix product. Its argument list is truncated,
// and it is hard-wired to float rather than Dtype.
//Gemm<float>(false,
// m_K, m_N, (Dtype)1.,
// inputs[0]->cpu_data(), );
}
// CPU backward pass of the fully connected filter.
//
// NOTE(review): this is a stub. The body is empty, so no gradients are
// computed, either for the inputs or for the learnable parameters in
// m_lr_params. TODO: implement.
// The parameters are taken in (outputs, inputs) order, the reverse of
// Reshape and Forward_cpu. Presumably this is because gradients flow from
// the outputs back to the inputs; confirm against the base-class
// declaration.
template<typename Dtype>
void FullyConnectFilter<Dtype>::Backward_cpu(
const std::vector<Block<Dtype>*>& outputs,
const std::vector<Block<Dtype>*>& inputs){
}
// Explicitly instantiate FullyConnectFilter for the supported Dtypes. This is
// needed because the member templates above are defined in this .cpp file
// rather than in the header. The macro is defined elsewhere; presumably it
// covers float and double, as in Caffe (confirm).
INSTANTIATE_CLASS(FullyConnectFilter);
}