Examining the Arm NN code
In this section of the guide, we examine the demonstration code and highlight some of the key functional points for integrating with the Arm NN C++ API.
vision_detector_node.cpp contains the entry point to the ROS node. The key tasks performed by this code include the following:
- Initialize and register the ROS node:
    ros::init(argc, argv, "vision_detector");
- Create an Arm NN runtime object:
    // Create Arm NN Runtime
    armnn::IRuntime::CreationOptions options;
    options.m_EnableGpuProfiling = false;
    armnn::IRuntimePtr runtime = armnn::IRuntime::Create(options);
- Specify the list of backends that can be used to optimize the network:
    // Enumerate Compute Device backends
    std::vector<armnn::BackendId> computeDevices;
    computeDevices.push_back(armnn::Compute::GpuAcc);
    computeDevices.push_back(armnn::Compute::CpuAcc);
    computeDevices.push_back(armnn::Compute::CpuRef);

    detector_armnn::Yolo2TinyDetector<DataType> yolo(runtime);
    yolo.load_network(pretrained_model_file, computeDevices);
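The order of the backend list matters. As we understand it, Arm NN treats the list as a preference order and assigns each layer to the first listed backend that supports it. The following sketch illustrates that fallback idea in plain C++ without calling Arm NN. The function pick_backend and the string backend names are hypothetical stand-ins, not part of the Arm NN API:

```cpp
#include <set>
#include <string>
#include <vector>

// Hypothetical helper, not an Arm NN function.
// Returns the first backend in the preference list that the layer supports,
// or an empty string if no listed backend supports it.
inline std::string pick_backend(const std::vector<std::string>& preferred,
                                const std::set<std::string>& supported_by_layer) {
    for (const std::string& backend : preferred) {
        if (supported_by_layer.count(backend) != 0) {
            return backend;
        }
    }
    return std::string();
}
```

For example, with the preference order GpuAcc, CpuAcc, CpuRef, a layer that only the two CPU backends support would be assigned to CpuAcc in this sketch.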
armnn_yolo2tiny.hpp contains the definition of Yolo2TinyDetector. In Yolo2TinyDetector, we call Arm NN to create a parser object that loads the network file. Arm NN has parsers for several model file types, including TensorFlow, TensorFlow Lite, ONNX, and Caffe. Parsers create the underlying Arm NN graph, so you do not need to construct your model graph by hand.
The following example creates a TensorFlow parser to load our TensorFlow protobuf file from the specified path:
    // Setup Arm NN Network
    // Parse pre-trained model from TensorFlow protobuf format
    using ParserType = armnnTfParser::ITfParser;
    auto parser(ParserType::Create());
    armnn::INetworkPtr network{nullptr, [](armnn::INetwork *) {}};
    network = parser->CreateNetworkFromBinaryFile(model_path.c_str(), inputShapes, requestedOutputs);
This network is then optimized and loaded into the Arm NN runtime as follows:
    // Set optimisation options
    armnn::OptimizerOptions options;
    options.m_ReduceFp32ToFp16 = false;

    // Optimize network
    armnn::IOptimizedNetworkPtr optNet{nullptr, [](armnn::IOptimizedNetwork *) {}};
    optNet = armnn::Optimize(*network, compute_devices, runtime->GetDeviceSpec(), options);
    if (!optNet) {
        throw armnn::Exception("armnn::Optimize failed");
    }

    // Load network into runtime
    armnn::Status ret = this->runtime->LoadNetwork(this->networkID, std::move(optNet));
    if (ret == armnn::Status::Failure) {
        throw armnn::Exception("IRuntime::LoadNetwork failed");
    }
Every time an image is published to the ROS topic /image_raw, the callback function DetectorNode<T>::callback_image is invoked. To run inference, the callback function calls Yolo2TinyDetector<T>::run_inference, which in turn calls Arm NN to execute the inference.
The Arm NN parser extracts the input and output information for the network. We obtain the input and output tensors, then retrieve their binding information, which contains everything we need to know about each bound layer. Each binding pairs an integer identifier for a bindable input or output layer with the tensor information for that layer. Tensor information consists of data type, quantization information, number of dimensions, and total number of elements.
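The total number of elements in a tensor is the product of its dimension sizes, which is why this value can be used to size a buffer. The following sketch shows that relationship in plain C++. ShapeSketch is a hypothetical stand-in written for illustration, not an Arm NN type, and returning zero for an empty shape is a choice made for this sketch only:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the shape part of a tensor description.
struct ShapeSketch {
    std::vector<std::size_t> dims;  // size of each dimension

    // Number of dimensions (rank) of the tensor.
    std::size_t num_dimensions() const { return dims.size(); }

    // Total element count: the product of all dimension sizes.
    // An empty shape is treated as holding no elements in this sketch.
    std::size_t num_elements() const {
        if (dims.empty()) {
            return 0;
        }
        std::size_t n = 1;
        for (std::size_t d : dims) {
            n *= d;
        }
        return n;
    }
};
```

For example, a shape of {1, 2, 3, 4} has four dimensions and 24 elements, so a buffer for it would need room for 24 values of the tensor's data type.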
The EnqueueWorkload() function of the runtime context executes inference on the loaded network, as follows:
    // Allocate output container
    size_t output_size = this->output_tensor_shape.GetNumElements();
    std::vector<T> output(output_size);

    // Create input and output tensors and their bindings
    armnn::InputTensors inputTensors{
        {0, armnn::ConstTensor(this->runtime->GetInputTensorInfo(this->networkID, 0), input_tensor.data())}};
    // Output tensors must be writable, so use armnn::Tensor rather than armnn::ConstTensor
    armnn::OutputTensors outputTensors{
        {0, armnn::Tensor(this->runtime->GetOutputTensorInfo(this->networkID, 0), output.data())}};

    // Run inference
    this->runtime->EnqueueWorkload(this->networkID, inputTensors, outputTensors);
The output of the inference is decoded in the Yolo2TinyDetector<T>::process_output function. A score threshold is applied first, then the non_maximum_suppression algorithm removes spurious detections. The final detections are output in autoware_msgs::DetectedObjectArray format.
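To make the post-processing step concrete, the following is a minimal sketch of score thresholding followed by greedy non-maximum suppression. It is not the demo's implementation: the Detection struct, the iou and filter_detections functions, and the choice to suppress only boxes of the same class are all assumptions made for illustration.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical detection record; the demo's own types may differ.
struct Detection {
    float x1, y1, x2, y2;  // box corners (top-left and bottom-right)
    float score;           // detection confidence
    int class_id;          // predicted class index
};

// Intersection over union of two axis-aligned boxes.
inline float iou(const Detection& a, const Detection& b) {
    const float iw = std::max(0.0f, std::min(a.x2, b.x2) - std::max(a.x1, b.x1));
    const float ih = std::max(0.0f, std::min(a.y2, b.y2) - std::max(a.y1, b.y1));
    const float inter = iw * ih;
    const float area_a = (a.x2 - a.x1) * (a.y2 - a.y1);
    const float area_b = (b.x2 - b.x1) * (b.y2 - b.y1);
    const float uni = area_a + area_b - inter;
    return uni > 0.0f ? inter / uni : 0.0f;
}

// Drop boxes below score_threshold, then visit the rest in descending score
// order, keeping a box only if it does not overlap an already kept box of the
// same class by more than iou_threshold.
inline std::vector<Detection> filter_detections(std::vector<Detection> dets,
                                                float score_threshold,
                                                float iou_threshold) {
    dets.erase(std::remove_if(dets.begin(), dets.end(),
                              [score_threshold](const Detection& d) {
                                  return d.score < score_threshold;
                              }),
               dets.end());
    std::sort(dets.begin(), dets.end(),
              [](const Detection& a, const Detection& b) { return a.score > b.score; });

    std::vector<Detection> kept;
    for (const Detection& d : dets) {
        bool suppressed = false;
        for (const Detection& k : kept) {
            if (k.class_id == d.class_id && iou(k, d) > iou_threshold) {
                suppressed = true;
                break;
            }
        }
        if (!suppressed) {
            kept.push_back(d);
        }
    }
    return kept;
}
```

Thresholding before suppression keeps the pairwise overlap checks cheap, because low-confidence boxes are discarded before any IoU is computed. Both threshold values are tuning parameters; the sketch does not assume the values used by the demo.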