Since the shape of weights is known in advance, the MatMul weights can be created with format tag dnnl::memory::format_tag::any to enable the library to choose the most appropriate layout for best performance.
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <random>
#include <stdexcept>
#include <vector>
#include "example_utils.hpp"
namespace {
// Fill `v` with deterministic pseudo-random floats uniformly drawn from [0, 1).
// A default-seeded generator is used so every call produces the same sequence.
void init_vector(std::vector<float> &v) {
    std::mt19937 prng;
    std::uniform_real_distribution<float> dist(0, 1);
    for (std::size_t i = 0; i < v.size(); ++i)
        v[i] = dist(prng);
}
// Fill `v` with deterministic pseudo-random bytes in [0, 255].
void init_vector(std::vector<uint8_t> &v) {
    std::mt19937 prng;
    std::uniform_int_distribution<unsigned int> dist(0, 255);
    for (std::size_t i = 0; i < v.size(); ++i)
        v[i] = static_cast<uint8_t>(dist(prng));
}
} // namespace
int number_of_runs = 1;
// Create a MatMul primitive descriptor for C_u8 = ReLU(A_u8 * B_s8) with
// runtime output scales and src/dst zero points. The number of rows M is not
// known at creation time, so it is declared as a runtime wildcard.
// NOTE(review): the original extraction dropped the declarations of a_md,
// b_md, and c_md (the function did not compile); they are restored here based
// on the runtime-wildcard macros and data types referenced by this example.
matmul::primitive_desc matmul_pd_create(
        int64_t K, int64_t N, const engine &eng) {
    const int64_t M = DNNL_RUNTIME_DIM_VAL;

    // Row-major activations; `any` lets the library pick the best weights
    // layout since the weights' shape is known in advance.
    memory::desc a_md({M, K}, memory::data_type::u8, {K, 1}); // M x K
    memory::desc b_md({K, N}, memory::data_type::s8, memory::format_tag::any);
    memory::desc c_md({M, N}, memory::data_type::u8, {N, 1}); // M x N

    // Scales and zero points are supplied at execution time, hence the
    // runtime wildcard values. Mask (1 << 1) means per-N-column scales,
    // matching the N-sized scale memory prepared in prepare_input().
    primitive_attr attr;
    attr.set_output_scales(/* mask */ (1 << 1), {DNNL_RUNTIME_F32_VAL});
    attr.set_zero_points(DNNL_ARG_SRC, /* mask */ 0, {DNNL_RUNTIME_S32_VAL});
    attr.set_zero_points(DNNL_ARG_DST, /* mask */ 0, {DNNL_RUNTIME_S32_VAL});

    // Fuse ReLU so every de-quantized output is non-negative; sanity_check()
    // relies on this (no quantized value below the dst zero point).
    post_ops po;
    po.append_eltwise(1.f, algorithm::eltwise_relu, 0.f, 0.f);
    attr.set_post_ops(po);

    matmul::desc matmul_d(a_md, b_md, c_md);
    return matmul::primitive_desc(matmul_d, attr, eng);
}
// Populate the execution-time inputs: u8 activations, per-column f32 output
// scales, and the src/dst zero points. In a real application these values
// would come from the previous layer rather than a random generator.
void prepare_input(memory &A_u8_mem, memory &scale_f32_mem, memory &zp_A_mem,
        memory &zp_C_mem) {
    // Recover the problem sizes from the memory descriptors.
    const auto a_dims = A_u8_mem.get_desc().dims();
    const int64_t M = a_dims[0];
    const int64_t N = scale_f32_mem.get_desc().dims()[0];
    const int64_t K = a_dims[1];

    std::vector<uint8_t> src_u8(M * K);
    init_vector(src_u8);

    std::vector<float> col_scales(N);
    init_vector(col_scales);

    // Example zero points for the source and destination tensors.
    int32_t zp_A = 128, zp_C = 40;

    write_to_dnnl_memory(src_u8.data(), A_u8_mem);
    write_to_dnnl_memory(&zp_A, zp_A_mem);
    write_to_dnnl_memory(&zp_C, zp_C_mem);
    write_to_dnnl_memory(col_scales.data(), scale_f32_mem);
}
// Smoke-test the result: because ReLU is fused before dst quantization, every
// quantized output value must be at least the destination zero point.
// Throws std::logic_error when that invariant is violated.
void sanity_check(memory &C_u8_mem, memory &zp_C_mem) {
    const auto c_dims = C_u8_mem.get_desc().dims();
    const int64_t M = c_dims[0];
    const int64_t N = c_dims[1];

    std::vector<uint8_t> result(M * N);
    read_from_dnnl_memory(result.data(), C_u8_mem);

    int32_t zp_C = 0;
    read_from_dnnl_memory(&zp_C, zp_C_mem);

    for (const uint8_t q : result)
        if (q < zp_C)
            throw std::logic_error(
                    "Smoke check failed."
                    "\n\tQuantized value is smaller than the zero point,"
                    "\n\twhich should not happen since ReLU was applied.");
}
void infer(const matmul &matmul_p, int64_t M, int64_t N, int64_t K,
const memory &B_s8_mem, const engine &eng) {
prepare_input(A_u8_mem, scale_f32_mem, zp_A_mem, zp_C_mem);
stream s(eng);
for (int run = 0; run < number_of_runs; ++run)
matmul_p.execute(s,
{{DNNL_ARG_SRC, A_u8_mem}, {DNNL_ARG_WEIGHTS, B_s8_mem},
{DNNL_ARG_DST, C_u8_mem},
{DNNL_ARG_ATTR_OUTPUT_SCALES, scale_f32_mem},
{DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, zp_A_mem},
{DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zp_C_mem}});
s.wait();
sanity_check(C_u8_mem, zp_C_mem);
}
// End-to-end example: create the engine, build the int8 MatMul primitive
// once with a runtime M, quantize the f32 weights into the library-chosen
// s8 layout, and run inference for two different row counts.
void inference_int8_matmul(engine::kind engine_kind) {
    // NOTE(review): `eng` was used below but never declared in the extracted
    // source — restore the engine construction everything depends on.
    engine eng(engine_kind, 0);

    const int64_t K = 96;
    const int64_t N = 1000;
    auto matmul_pd = matmul_pd_create(K, N, eng);

    // Original f32 weights (placeholder data for the example).
    std::vector<float> B_f32(K * N);
    init_vector(B_f32);

    // Quantize the weights: reorder f32 (plain `ab` layout) into the s8
    // layout the library selected via format_tag::any.
    memory B_s8_mem(matmul_pd.weights_desc(), eng);
    {
        stream s(eng);
        memory B_f32_mem(
                {{K, N}, memory::data_type::f32, memory::format_tag::ab}, eng);
        write_to_dnnl_memory(B_f32.data(), B_f32_mem);
        reorder(B_f32_mem, B_s8_mem).execute(s, B_f32_mem, B_s8_mem);
        s.wait();
    }

    matmul matmul_p(matmul_pd);

    // The same primitive handles both problem sizes thanks to the runtime M.
    for (int64_t M : {1, 100})
        infer(matmul_p, M, N, K, B_s8_mem, eng);
}
int main(int argc, char **argv) {
engine::kind engine_kind = parse_engine_kind(argc, argv);
return handle_example_errors(inference_int8_matmul, engine_kind);
}
@ eltwise_relu
Elementwise: rectified linear unit (ReLU)
#define DNNL_RUNTIME_S32_VAL
A wildcard value for int32_t values that are unknown at a primitive creation time.
Definition: dnnl_types.h:1330
#define DNNL_RUNTIME_DIM_VAL
A wildcard value for dimensions that are unknown at a primitive creation time.
Definition: dnnl_types.h:1305
#define DNNL_RUNTIME_F32_VAL
A wildcard value for floating point values that are unknown at a primitive creation time.
Definition: dnnl_types.h:1322
#define DNNL_ARG_DST
A special mnemonic for destination argument for primitives that have a single destination.
Definition: dnnl_types.h:2307
#define DNNL_ARG_SRC
A special mnemonic for source argument for primitives that have a single source.
Definition: dnnl_types.h:2283
@ matmul_d
matmul descriptor
oneDNN namespace
Definition: dnnl.hpp:74
@ any
Placeholder memory format tag.
@ u8
8-bit unsigned integer.
@ s8
8-bit signed integer.
@ f32
32-bit/single-precision floating point.
@ s32
32-bit signed integer.