[Deep Learning] ONNX Runtime C++ (1): DL Model Inference Using ONNX Runtime


  • Download the ONNX Runtime release zip file.
    • I chose onnxruntime-win-x64-gpu-1.13.1.zip
    • Add its include directory and library dependencies to the Visual Studio project.

Example: FnsCandyStyleTransfer

This example is based on microsoft/onnxruntime-inference-examples.

I use the Windows x64 platform, so libpng (and zlib) is a dependency; see "Compile and use libpng and zlib in Visual Studio".

#pragma once
// fns_candy_style_transfer.h
// https://github.com/microsoft/onnxruntime-inference-examples/blob/main/c_cxx/fns_candy_style_transfer/fns_candy_style_transfer.c
#include "OnnxSample.h"
#include "onnxruntime_c_api.h"
#include <string>

class FnsCandyStyleTransfer : public OnnxSample
{
public:
    FnsCandyStyleTransfer();
    ~FnsCandyStyleTransfer();

    // The raw ONNX Runtime handles below are released in the destructor, so an implicit copy
    // would release them twice. Copying (and therefore moving) is disabled.
    FnsCandyStyleTransfer(const FnsCandyStyleTransfer&) = delete;
    FnsCandyStyleTransfer& operator=(const FnsCandyStyleTransfer&) = delete;

    // Runs the whole sample: initialize ONNX Runtime, prepare the input image, run inference.
    virtual void Run();

private:
    // Obtains the C API table and creates env / session options / session.
    // Returns false if the C API table cannot be obtained.
    bool Init();
    // Appends the CUDA execution provider to session_options; returns false if that fails.
    bool EnableCuda();
    // Checks (via assert) that the loaded model has exactly one input and one output.
    void VerifyInputOutputCount();
    // Stylizes input_file and writes the result to output_file; returns true on success.
    bool RunInference(const char* input_file, const char* output_file);
    // ONNX Runtime C API function table, obtained in Init(); null until then.
    const OrtApi* g_ort;

    // Writes a 720 x 720 resized copy of the image at rawPath to resizedPath.
    static void ResizeImage720(const char* rawPath, const char* resizedPath);

private:
    const ORTCHAR_T* model_path;           // path to the .onnx model file
    const std::string execution_provider;  // "cuda" requests the CUDA execution provider
    OrtSession* session;                   // owned; released in the destructor
    OrtSessionOptions* session_options;    // owned; released in the destructor
    OrtEnv* env;                           // owned; released in the destructor
};
// fns_candy_style_transfer.cpp
#include "fns_candy_style_transfer.h"
#include <iostream>
#include <assert.h>
#include "image_file.h"
#include <opencv2/core.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>

#ifdef _WIN32
#include <objbase.h>
#endif

// Evaluates an ONNX Runtime C API call. If it returns a non-null OrtStatus, the macro prints the
// error message, releases the status and aborts the process. It requires a `g_ort`
// (const OrtApi*) to be in scope at the expansion site.
// There is no trailing semicolon after `while (0)`: the caller supplies it, so the macro behaves
// as a single statement (e.g. inside an unbraced if/else).
#define ORT_ABORT_ON_ERROR(expr)                             \
  do {                                                       \
    OrtStatus* onnx_status = (expr);                         \
    if (onnx_status != nullptr) {                            \
      const char* msg = g_ort->GetErrorMessage(onnx_status); \
      std::cerr << msg << '\n';                              \
      g_ort->ReleaseStatus(onnx_status);                     \
      abort();                                               \
    }                                                        \
  } while (0)

// All ONNX Runtime handles start out null; Init() (called from Run()) creates them.
FnsCandyStyleTransfer::FnsCandyStyleTransfer()
    : g_ort(nullptr),
      // ORT_TSTR yields a literal of the right type for ORTCHAR_T: wide (L"...") on Windows,
      // narrow elsewhere. A bare L"..." only compiles where ORTCHAR_T is wchar_t.
      model_path{ ORT_TSTR("../models/candy.onnx") },
      execution_provider{ "cuda" },
      session(nullptr),
      session_options(nullptr),
      env(nullptr)
{
}

FnsCandyStyleTransfer::~FnsCandyStyleTransfer()
{
    // g_ort is only set by Init(). If Run()/Init() never got that far, g_ort is null and there is
    // nothing to release. Calling through a null g_ort would crash.
    if (g_ort != nullptr)
    {
        // Release the session before the options it was created with and the environment.
        if (session != nullptr) g_ort->ReleaseSession(session);
        if (session_options != nullptr) g_ort->ReleaseSessionOptions(session_options);
        if (env != nullptr) g_ort->ReleaseEnv(env);
    }
#ifdef _WIN32
    // NOTE(review): no matching CoInitialize/CoInitializeEx call is visible in this file. Confirm
    // that one exists (e.g. in OnnxSample); otherwise this call is unbalanced.
    CoUninitialize();
#endif
}

/**
 * Appends the CUDA execution provider to session_options.
 *
 * @return true if the provider was appended. On failure, prints ONNX Runtime's error message to
 *         stderr and returns false; the session can then still be created without CUDA.
 */
bool FnsCandyStyleTransfer::EnableCuda() {
    // OrtCUDAProviderOptions is a C struct. In C++ builds, recent onnxruntime_c_api.h versions also
    // give it a default constructor with non-zero defaults, so it must not be memset() to zero.
    // Value-initialize it, then set every field this sample relies on explicitly, so the result
    // does not depend on which header version is in use.
    OrtCUDAProviderOptions o{};
    o.device_id = 0;
    o.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchExhaustive;
    o.gpu_mem_limit = SIZE_MAX;
    // Copy on the compute stream. Zero-filling the struct would set this to 0, which the ORT
    // header warns may cause data races for some models.
    o.do_copy_in_default_stream = 1;

    OrtStatus* onnx_status = g_ort->SessionOptionsAppendExecutionProvider_CUDA(session_options, &o);
    if (onnx_status != nullptr) {
        std::cerr << g_ort->GetErrorMessage(onnx_status) << '\n';
        g_ort->ReleaseStatus(onnx_status);
        return false;
    }
    return true;
}

bool FnsCandyStyleTransfer::RunInference(const char* input_file, const char* output_file) {
    size_t input_height;
    size_t input_width;
    float* model_input;
    size_t model_input_ele_count;

    // read the image data: float*
    bool status = read_image_file(input_file, &input_height, &input_width, &model_input, &model_input_ele_count);
    assert(status);
    assert(input_height == 720 && input_width == 720);

    OrtMemoryInfo* memory_info;
    ORT_ABORT_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info));
    const int64_t input_shape[] = { 1, 3, 720, 720 };  // 1 * 3 * 720 * 720
    const size_t input_shape_len = sizeof(input_shape) / sizeof(input_shape[0]);
    const size_t model_input_len = model_input_ele_count * sizeof(float);

    OrtValue* input_tensor = NULL;  // this is a tensor
    ORT_ABORT_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, model_input, model_input_len, input_shape,
        input_shape_len, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
        &input_tensor));
    assert(input_tensor != NULL);
    int is_tensor;
    ORT_ABORT_ON_ERROR(g_ort->IsTensor(input_tensor, &is_tensor));
    assert(is_tensor);
    g_ort->ReleaseMemoryInfo(memory_info);
    const char* input_names[] = { "inputImage" };
    const char* output_names[] = { "outputImage" };
    // inference
    OrtValue* output_tensor = NULL;
    ORT_ABORT_ON_ERROR(g_ort->Run(session, NULL, input_names, (const OrtValue* const*)&input_tensor, 1, output_names, 1,
        &output_tensor));
    assert(output_tensor != NULL);
    ORT_ABORT_ON_ERROR(g_ort->IsTensor(output_tensor, &is_tensor));
    assert(is_tensor);
    
    float* output_tensor_data = NULL;
    // extract output
    ORT_ABORT_ON_ERROR(g_ort->GetTensorMutableData(output_tensor, (void**)&output_tensor_data));
    uint8_t* output_image_data = NULL;
    chw_to_hwc(output_tensor_data, 720, 720, &output_image_data);

    bool ret = true;
    if (write_image_file(output_image_data, 720, 720, output_file) != 0) ret = false;
    g_ort->ReleaseValue(output_tensor);
    g_ort->ReleaseValue(input_tensor);
    free(model_input);
    return ret;
}

/**
 * Sanity-checks the loaded session: the candy model should expose exactly one input tensor and
 * one output tensor. API failures abort; count mismatches trip an assert (debug builds only).
 */
void FnsCandyStyleTransfer::VerifyInputOutputCount() {
    size_t num_inputs = 0;
    ORT_ABORT_ON_ERROR(g_ort->SessionGetInputCount(session, &num_inputs));
    assert(num_inputs == 1);

    size_t num_outputs = 0;
    ORT_ABORT_ON_ERROR(g_ort->SessionGetOutputCount(session, &num_outputs));
    assert(num_outputs == 1);
}

/**
 * Initializes ONNX Runtime for this sample.
 *
 * Obtains the C API table, creates the environment and session options, optionally requests the
 * CUDA execution provider, loads the model into a session and checks its input/output count.
 *
 * @return false if the C API table cannot be obtained. Other ONNX Runtime errors abort the process.
 *         Returns true immediately if a session already exists, so repeated calls do not leak the
 *         previously created handles.
 */
bool FnsCandyStyleTransfer::Init()
{
    if (session != nullptr) return true;

    g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
    if (!g_ort)
    {
        std::cerr << "Failed to init ONNX Runtime engine.\n";
        return false;
    }

    ORT_ABORT_ON_ERROR(g_ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "FnsCandyStyleTransfer", &env));
    assert(env != nullptr);

    ORT_ABORT_ON_ERROR(g_ort->CreateSessionOptions(&session_options));

    if (execution_provider == "cuda")
    {
        std::cout << "Try to enable CUDA...\n";
        // EnableCuda() returns true on success. On failure the session is simply created without
        // the CUDA provider.
        if (EnableCuda()) std::cout << "CUDA is enabled\n";
        else std::cerr << "CUDA is not available\n";
    }

    ORT_ABORT_ON_ERROR(g_ort->CreateSession(env, model_path, session_options, &session));

    VerifyInputOutputCount();

    return true;
}

void FnsCandyStyleTransfer::Run()
{
    std::cout << "------------------------------------------\n";
    std::cout << "Sample: FnsCandyStyleTransfer\n\n";
    std::cout << "ONNX Runtime Version: " << ORT_API_VERSION << std::endl;
    if (!Init()) return;

    // I don't have 720*720 image, so I use this function to resize image for the code running.
    const char* input_file = "../assets/180_resized.png";
    ResizeImage720("../assets/180.png", input_file);
    const char* output_file = "../output/FnsCandyStyleTransfer.png";

    bool success = RunInference(input_file, output_file);
    if (!success) std::cerr << "fail.\n";
}

/**
 * Use this function to generate 720 * 720 * 3 image.
 */
void FnsCandyStyleTransfer::ResizeImage720(const char* rawPath, const char* resizedPath)
{
    cv::Mat img = cv::imread(rawPath);
    cv::Mat imgResize;
    cv::resize(img, imgResize, cv::Size(720, 720));
    cv::imwrite(resizedPath, imgResize);
}

References


Leave a Reply

Your email address will not be published. Required fields are marked *

css.php