This C++ API example demonstrates how to build training (forward and backward propagation) for a few layers of an AlexNet model.
#include <assert.h>
#include <math.h>
#include "example_utils.hpp"
// Builds and runs forward and backward passes for a small network stack:
// convolution -> ReLU (eltwise) -> LRN -> pooling, then the backward chain
// pooling -> LRN -> ReLU -> convolution backward-weights.
//
// NOTE(review): this extract is missing many lines of the original function:
// its signature, the declarations of engine_kind, the dims vectors
// (conv_src_tz, conv_weights_tz, conv_bias_tz, conv_dst_tz, lrn_data_tz,
// relu_data_tz, pool_dst_tz, strides/padding/kernel), the `tag`/`dt` aliases,
// every operation-descriptor construction, and most argument-map entries.
// Statements marked "truncated" below are syntactically incomplete as shown;
// confirm against the full source.

// Create the engine and a stream on it; all primitives execute on `s`.
auto eng =
engine(engine_kind, 0);
stream s(eng);
// Primitive i in net_fwd/net_bwd is executed with the argument map at index i
// of net_fwd_args/net_bwd_args (see the execution loop at the end).
std::vector<primitive> net_fwd, net_bwd;
std::vector<std::unordered_map<int, memory>> net_fwd_args, net_bwd_args;
const int batch = 32;
// Host buffers: network input (batch * 3 * 227 * 227 floats) and the final
// pooling output (batch * 96 * 27 * 27 floats).
std::vector<float> net_src(batch * 3 * 227 * 227);
std::vector<float> net_dst(batch * 96 * 27 * 27);
// Deterministic input data: sin(i) for each element.
for (size_t i = 0; i < net_src.size(); ++i)
net_src[i] = sinf((float)i);
// Convolution weights and bias, also filled with sin(i).
std::vector<float> conv_weights(product(conv_weights_tz));
std::vector<float> conv_bias(product(conv_bias_tz));
for (size_t i = 0; i < conv_weights.size(); ++i)
conv_weights[i] = sinf((float)i);
for (size_t i = 0; i < conv_bias.size(); ++i)
conv_bias[i] = sinf((float)i);
// User-owned memory in plain layouts (nchw / oihw / x), filled from the host
// buffers above.
auto conv_user_src_memory
= memory({{conv_src_tz}, dt::f32, tag::nchw}, eng);
write_to_dnnl_memory(net_src.data(), conv_user_src_memory);
auto conv_user_weights_memory
= memory({{conv_weights_tz}, dt::f32, tag::oihw}, eng);
write_to_dnnl_memory((void *)conv_weights.data(), conv_user_weights_memory);
auto conv_user_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng);
write_to_dnnl_memory(conv_bias.data(), conv_user_bias_memory);
// tag::any leaves the layout choice to the convolution primitive descriptor;
// the layouts it chooses are compared against the user layouts below.
auto conv_src_md = memory::desc({conv_src_tz}, dt::f32, tag::any);
auto conv_bias_md = memory::desc({conv_bias_tz}, dt::f32, tag::any);
auto conv_weights_md = memory::desc({conv_weights_tz}, dt::f32, tag::any);
auto conv_dst_md = memory::desc({conv_dst_tz}, dt::f32, tag::any);
// NOTE(review): truncated -- the statement opening this argument list
// (presumably the construction of `conv_desc`, used on the next statement)
// is missing from this extract.
conv_bias_md, conv_dst_md, conv_strides, conv_padding,
conv_padding);
auto conv_pd = convolution_forward::primitive_desc(conv_desc, eng);
// Reorder the user source into the primitive's chosen layout only when the
// two layouts differ; otherwise use the user memory directly.
auto conv_src_memory = conv_user_src_memory;
if (conv_pd.src_desc() != conv_user_src_memory.get_desc()) {
conv_src_memory = memory(conv_pd.src_desc(), eng);
// NOTE(review): no matching net_fwd_args entry follows this reorder in the
// extract; the size assert before execution requires one.
net_fwd.push_back(reorder(conv_user_src_memory, conv_src_memory));
}
// Same conditional reorder for the weights.
auto conv_weights_memory = conv_user_weights_memory;
if (conv_pd.weights_desc() != conv_user_weights_memory.get_desc()) {
conv_weights_memory = memory(conv_pd.weights_desc(), eng);
net_fwd.push_back(
reorder(conv_user_weights_memory, conv_weights_memory));
// NOTE(review): truncated -- this argument map is incomplete as shown.
net_fwd_args.push_back({{
DNNL_ARG_FROM, conv_user_weights_memory},
}
// Convolution output in the layout chosen by conv_pd.
auto conv_dst_memory = memory(conv_pd.dst_desc(), eng);
// NOTE(review): the argument map for this convolution is missing here.
net_fwd.push_back(convolution_forward(conv_pd));
// ReLU: negative_slope of 0 is also passed to the backward descriptor below.
// NOTE(review): the construction of `relu_desc` is missing here.
const float negative_slope = 0.0f;
auto relu_pd = eltwise_forward::primitive_desc(relu_desc, eng);
auto relu_dst_memory = memory(relu_pd.dst_desc(), eng);
net_fwd.push_back(eltwise_forward(relu_pd));
// NOTE(review): truncated -- the ReLU argument map is missing.
net_fwd_args.push_back(
// LRN hyper-parameters; they are reused for the LRN backward descriptor.
const uint32_t local_size = 5;
const float alpha = 0.0001f;
const float beta = 0.75f;
const float k = 1.0f;
// NOTE(review): truncated -- the opening of the `lrn_desc` construction is
// missing.
alpha, beta, k);
auto lrn_pd = lrn_forward::primitive_desc(lrn_desc, eng);
auto lrn_dst_memory = memory(lrn_pd.dst_desc(), eng);
// LRN forward produces a workspace that is allocated here.
auto lrn_workspace_memory = memory(lrn_pd.workspace_desc(), eng);
net_fwd.push_back(lrn_forward(lrn_pd));
// NOTE(review): truncated -- the LRN argument map is missing.
net_fwd_args.push_back(
// User-owned pooling output (nchw), initialized from net_dst.
auto pool_user_dst_memory
= memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng);
write_to_dnnl_memory(net_dst.data(), pool_user_dst_memory);
auto pool_dst_md = memory::desc({pool_dst_tz}, dt::f32, tag::any);
// NOTE(review): truncated -- the opening of the `pool_desc` construction is
// missing.
pool_strides, pool_kernel, pool_padding, pool_padding);
auto pool_pd = pooling_forward::primitive_desc(pool_desc, eng);
// Pooling forward produces a workspace that is allocated here.
auto pool_workspace_memory = memory(pool_pd.workspace_desc(), eng);
// NOTE(review): the net_fwd_args entry for this pooling primitive is missing
// here; the .back().insert(...) calls below assume it was already pushed.
net_fwd.push_back(pooling_forward(pool_pd));
// If pooling chose a non-user dst layout, pool into a temporary and reorder
// into the user memory; otherwise pool straight into the user memory. Either
// way the DST argument is added to the pooling primitive's argument map.
auto pool_dst_memory = pool_user_dst_memory;
if (pool_pd.dst_desc() != pool_user_dst_memory.get_desc()) {
pool_dst_memory = memory(pool_pd.dst_desc(), eng);
net_fwd_args.back().insert({
DNNL_ARG_DST, pool_dst_memory});
// NOTE(review): no net_fwd_args entry follows this reorder in the extract.
net_fwd.push_back(reorder(pool_dst_memory, pool_user_dst_memory));
} else {
net_fwd_args.back().insert({
DNNL_ARG_DST, pool_dst_memory});
}
// ---- Backward pass ----
// Synthetic gradient w.r.t. the pooling output: sin(i) per element.
std::vector<float> net_diff_dst(batch * 96 * 27 * 27);
for (size_t i = 0; i < net_diff_dst.size(); ++i)
net_diff_dst[i] = sinf((float)i);
auto pool_user_diff_dst_memory
= memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng);
write_to_dnnl_memory(net_diff_dst.data(), pool_user_diff_dst_memory);
auto pool_diff_src_md = memory::desc({lrn_data_tz}, dt::f32, tag::any);
auto pool_diff_dst_md = memory::desc({pool_dst_tz}, dt::f32, tag::any);
// NOTE(review): truncated -- the opening of the `pool_bwd_desc` construction
// is missing.
pool_diff_src_md, pool_diff_dst_md, pool_strides, pool_kernel,
pool_padding, pool_padding);
// The forward pool_pd is passed as a hint to the backward primitive descriptor.
auto pool_bwd_pd
= pooling_backward::primitive_desc(pool_bwd_desc, eng, pool_pd);
// Reorder the user diff_dst into the forward pooling dst layout when they
// differ.
// NOTE(review): the target layout is taken from pool_dst_memory rather than
// pool_bwd_pd.diff_dst_desc(); confirm this is what pooling backward expects.
auto pool_diff_dst_memory = pool_user_diff_dst_memory;
if (pool_dst_memory.get_desc() != pool_user_diff_dst_memory.get_desc()) {
pool_diff_dst_memory = memory(pool_dst_memory.get_desc(), eng);
net_bwd.push_back(
reorder(pool_user_diff_dst_memory, pool_diff_dst_memory));
// NOTE(review): truncated -- this argument map is incomplete as shown.
net_bwd_args.push_back({{
DNNL_ARG_FROM, pool_user_diff_dst_memory},
}
auto pool_diff_src_memory = memory(pool_bwd_pd.diff_src_desc(), eng);
// NOTE(review): the argument map for pooling backward is missing here.
net_bwd.push_back(pooling_backward(pool_bwd_pd));
auto lrn_diff_dst_md = memory::desc({lrn_data_tz}, dt::f32, tag::any);
// NOTE(review): truncated -- the opening of the `lrn_bwd_desc` construction is
// missing; it reuses lrn_pd's src layout and the forward LRN parameters.
lrn_pd.src_desc(), lrn_diff_dst_md, local_size, alpha, beta, k);
auto lrn_bwd_pd = lrn_backward::primitive_desc(lrn_bwd_desc, eng, lrn_pd);
// Pooling's diff_src feeds LRN backward as diff_dst; reorder only when the
// LRN backward primitive wants a different layout.
auto lrn_diff_dst_memory = pool_diff_src_memory;
if (lrn_diff_dst_memory.get_desc() != lrn_bwd_pd.diff_dst_desc()) {
lrn_diff_dst_memory = memory(lrn_bwd_pd.diff_dst_desc(), eng);
// NOTE(review): no net_bwd_args entry follows this reorder in the extract.
net_bwd.push_back(reorder(pool_diff_src_memory, lrn_diff_dst_memory));
}
auto lrn_diff_src_memory = memory(lrn_bwd_pd.diff_src_desc(), eng);
// NOTE(review): the argument map for LRN backward is missing here.
net_bwd.push_back(lrn_backward(lrn_bwd_pd));
auto relu_diff_dst_md = memory::desc({relu_data_tz}, dt::f32, tag::any);
// The ReLU backward src descriptor is the convolution's forward dst layout.
auto relu_src_md = conv_pd.dst_desc();
// NOTE(review): truncated -- the opening of the `relu_bwd_desc` construction
// is missing.
relu_diff_dst_md, relu_src_md, negative_slope);
auto relu_bwd_pd
= eltwise_backward::primitive_desc(relu_bwd_desc, eng, relu_pd);
// LRN's diff_src feeds ReLU backward as diff_dst; reorder only if needed.
auto relu_diff_dst_memory = lrn_diff_src_memory;
if (relu_diff_dst_memory.get_desc() != relu_bwd_pd.diff_dst_desc()) {
relu_diff_dst_memory = memory(relu_bwd_pd.diff_dst_desc(), eng);
// NOTE(review): no net_bwd_args entry follows this reorder in the extract.
net_bwd.push_back(reorder(lrn_diff_src_memory, relu_diff_dst_memory));
}
auto relu_diff_src_memory = memory(relu_bwd_pd.diff_src_desc(), eng);
// NOTE(review): the argument map for ReLU backward is missing here.
net_bwd.push_back(eltwise_backward(relu_bwd_pd));
// Host buffers for the weight/bias gradients; std::vector value-initializes
// them to zero before they are copied into the user memory objects.
std::vector<float> conv_user_diff_weights_buffer(product(conv_weights_tz));
std::vector<float> conv_diff_bias_buffer(product(conv_bias_tz));
// NOTE(review): the forward user weights use tag::oihw while this gradient
// memory uses tag::nchw; for 4-D tensors these presumably describe the same
// plain layout -- confirm.
auto conv_user_diff_weights_memory
= memory({{conv_weights_tz}, dt::f32, tag::nchw}, eng);
write_to_dnnl_memory(conv_user_diff_weights_buffer.data(),
conv_user_diff_weights_memory);
auto conv_diff_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng);
write_to_dnnl_memory(conv_diff_bias_buffer.data(), conv_diff_bias_memory);
auto conv_bwd_src_md = memory::desc({conv_src_tz}, dt::f32, tag::any);
auto conv_diff_bias_md = memory::desc({conv_bias_tz}, dt::f32, tag::any);
auto conv_diff_weights_md
= memory::desc({conv_weights_tz}, dt::f32, tag::any);
auto conv_diff_dst_md = memory::desc({conv_dst_tz}, dt::f32, tag::any);
// NOTE(review): truncated -- the initializer of conv_bwd_weights_desc (the
// backward-weights operation descriptor constructor call) is missing between
// the declaration and its argument list.
auto conv_bwd_weights_desc
conv_bwd_src_md, conv_diff_weights_md, conv_diff_bias_md,
conv_diff_dst_md, conv_strides, conv_padding, conv_padding);
// The forward conv_pd is passed as a hint to the backward-weights descriptor.
auto conv_bwd_weights_pd = convolution_backward_weights::primitive_desc(
conv_bwd_weights_desc, eng, conv_pd);
// Backward-weights may prefer a different src layout than the forward pass
// used; reorder the forward source only in that case.
auto conv_bwd_src_memory = conv_src_memory;
if (conv_bwd_weights_pd.src_desc() != conv_src_memory.get_desc()) {
conv_bwd_src_memory = memory(conv_bwd_weights_pd.src_desc(), eng);
// NOTE(review): no net_bwd_args entry follows this reorder in the extract.
net_bwd.push_back(reorder(conv_src_memory, conv_bwd_src_memory));
}
// ReLU's diff_src is the convolution's diff_dst; reorder only if needed.
auto conv_diff_dst_memory = relu_diff_src_memory;
if (conv_bwd_weights_pd.diff_dst_desc()
!= relu_diff_src_memory.get_desc()) {
conv_diff_dst_memory = memory(conv_bwd_weights_pd.diff_dst_desc(), eng);
// NOTE(review): no net_bwd_args entry follows this reorder in the extract.
net_bwd.push_back(reorder(relu_diff_src_memory, conv_diff_dst_memory));
}
net_bwd.push_back(convolution_backward_weights(conv_bwd_weights_pd));
// NOTE(review): truncated -- this argument map is incomplete as shown.
net_bwd_args.push_back({{
DNNL_ARG_SRC, conv_bwd_src_memory},
// If backward-weights chose a non-user diff-weights layout, compute into a
// temporary and reorder it back to the user memory; otherwise write the
// gradient directly into the user memory.
auto conv_diff_weights_memory = conv_user_diff_weights_memory;
if (conv_bwd_weights_pd.diff_weights_desc()
!= conv_user_diff_weights_memory.get_desc()) {
conv_diff_weights_memory
= memory(conv_bwd_weights_pd.diff_weights_desc(), eng);
// NOTE(review): truncated -- the inserted argument is missing.
net_bwd_args.back().insert(
net_bwd.push_back(reorder(
conv_diff_weights_memory, conv_user_diff_weights_memory));
// NOTE(review): truncated -- this argument map is incomplete as shown.
net_bwd_args.push_back({{
DNNL_ARG_FROM, conv_diff_weights_memory},
} else {
// NOTE(review): truncated -- the inserted argument is missing.
net_bwd_args.back().insert(
}
// Every primitive needs an argument map at the same index.
assert(net_fwd.size() == net_fwd_args.size() && "something is missing");
assert(net_bwd.size() == net_bwd_args.size() && "something is missing");
// Run the forward list and then the backward list n_iter times (once here),
// then block until all work submitted to the stream has completed.
int n_iter = 1;
while (n_iter) {
for (size_t i = 0; i < net_fwd.size(); ++i)
net_fwd.at(i).execute(s, net_fwd_args.at(i));
for (size_t i = 0; i < net_bwd.size(); ++i)
net_bwd.at(i).execute(s, net_bwd_args.at(i));
--n_iter;
}
s.wait();
}
int main(int argc, char **argv) {
    // Determine the engine kind from the command-line arguments, then run
    // simple_net through handle_example_errors; its return value is used as
    // the process exit status.
    const auto requested_kind = parse_engine_kind(argc, argv);
    return handle_example_errors(simple_net, requested_kind);
}
@ convolution_direct
Direct convolution.
@ pooling_max
Max pooling.
@ lrn_across_channels
Local response normalization (LRN) across multiple channels.
@ eltwise_relu
Elementwise: rectified linear unit (ReLU).
@ forward
Forward data propagation, alias for dnnl::prop_kind::forward_training.
#define DNNL_ARG_DIFF_SRC
A special mnemonic for primitives that have a single diff source argument.
Definition: dnnl_types.h:2374
#define DNNL_ARG_DIFF_BIAS
Gradient (diff) of the bias tensor argument.
Definition: dnnl_types.h:2443
#define DNNL_ARG_DIFF_WEIGHTS
A special mnemonic for primitives that have a single diff weights argument.
Definition: dnnl_types.h:2416
#define DNNL_ARG_DST
A special mnemonic for the destination argument of primitives that have a single destination.
Definition: dnnl_types.h:2307
#define DNNL_ARG_WORKSPACE
Workspace tensor argument.
Definition: dnnl_types.h:2366
#define DNNL_ARG_FROM
A special mnemonic for reorder source argument.
Definition: dnnl_types.h:2289
#define DNNL_ARG_SRC
A special mnemonic for the source argument of primitives that have a single source.
Definition: dnnl_types.h:2283
#define DNNL_ARG_DIFF_DST
A special mnemonic for primitives that have a single diff destination argument.
Definition: dnnl_types.h:2395
#define DNNL_ARG_BIAS
Bias tensor argument.
Definition: dnnl_types.h:2357
#define DNNL_ARG_WEIGHTS
A special mnemonic for primitives that have a single weights argument.
Definition: dnnl_types.h:2330
#define DNNL_ARG_TO
A special mnemonic for reorder destination argument.
Definition: dnnl_types.h:2310
oneDNN namespace
Definition: dnnl.hpp:74
kind
Kinds of engines.
Definition: dnnl.hpp:874
format_tag
Memory format tag specification.
Definition: dnnl.hpp:1205
data_type
Data type specification.
Definition: dnnl.hpp:1130
std::vector< dim > dims
Vector of dimensions.
Definition: dnnl.hpp:1115