Commit fd241b6c authored by oscar's avatar oscar

修改读取nuScenes数据集逻辑

parent 2ae8dad0
...@@ -16,6 +16,9 @@ include_directories(/usr/include/eigen3) ...@@ -16,6 +16,9 @@ include_directories(/usr/include/eigen3)
# include_directories(./src/Endecryption) # include_directories(./src/Endecryption)
add_executable(test main.cpp) add_executable(test main.cpp)
add_executable(main1 main1.cpp)
add_executable(main2 main2.cpp)
add_executable(main3 main3.cpp)
add_executable(generate generate.cpp) add_executable(generate generate.cpp)
add_executable(test1 test.cpp) add_executable(test1 test.cpp)
......
...@@ -84,16 +84,48 @@ float calculate_iou_3d(const BoundingBox& bbox1, const BoundingBox& bbox2) { ...@@ -84,16 +84,48 @@ float calculate_iou_3d(const BoundingBox& bbox1, const BoundingBox& bbox2) {
map<int, map<int, map<int, float>>> calculate_match_score_3d( map<int, map<int, map<int, float>>> calculate_match_score_3d(
const vector<Object>& tracking_data, const vector<Object>& gt_data, float iou_threshold) { const vector<Object>& tracking_data, const vector<Object>& gt_data, float iou_threshold) {
map<int, map<int, map<int, float>>> match_score; map<int, map<int, map<int, float>>> match_score;
for (const auto& obj1 : tracking_data) { // for (const auto& obj1 : tracking_data) {
for (const auto& obj2 : gt_data) { // for (const auto& obj2 : gt_data) {
if (obj1.frame_id == obj2.frame_id) { // if (obj1.frame_id == obj2.frame_id) {
float iou = calculate_iou_3d(obj1.bbox, obj2.bbox); // float iou = calculate_iou_3d(obj1.bbox, obj2.bbox);
if (iou > iou_threshold) { // if (iou > iou_threshold) {
match_score[obj1.frame_id][obj1.object_id][obj2.object_id] = iou; // match_score[obj1.frame_id][obj1.object_id][obj2.object_id] = iou;
// }
// }
// }
// }
for (const auto& obj_gt : gt_data)
{
for (const auto& obj_tra : tracking_data)
{
if (obj_gt.frame_id == obj_tra.frame_id)
{
float iou = calculate_iou_3d(obj_gt.bbox, obj_tra.bbox);
if (iou > iou_threshold)
{
match_score[obj_gt.frame_id][obj_tra.object_id][obj_gt.object_id] = iou;
} }
} }
} }
} }
//获取帧数
// std::map<int,int> frameMap;
// std::vector<int> frameVet;
// for(auto& obj_gt : gt_data)
// {
// if(frameMap.find(obj_gt.frame_id) == frameMap.end())
// {
// frameVet.push_back(obj_gt.frame_id);
// frameMap[obj_gt.frame_id] = frameVet.size() - 1;
// }
// }
// for(auto frameId : frameVet)
// {
// }
return match_score; return match_score;
} }
......
#include <iostream>
#include <vector>
#include <fstream>
#include <sstream>
struct Object {
int id;
int frameNumber;
float x;
float y;
float z;
float length;
float width;
float height;
float orientation;
int type;
std::string typeName;
};
float calculateMOTA(const std::vector<Object>& trackerOutput, const std::vector<Object>& groundTruth) {
int numMissed = 0;
int numFalsePositives = 0;
int numMismatch = 0;
int numTotalObjects = groundTruth.size();
for (const auto& groundTruthObject : groundTruth) {
bool matched = false;
for (const auto& trackerObject : trackerOutput) {
if (groundTruthObject.id == trackerObject.id &&
groundTruthObject.frameNumber == trackerObject.frameNumber) {
matched = true;
break;
}
}
if (!matched) {
numMissed++;
}
}
for (const auto& trackerObject : trackerOutput) {
bool matched = false;
for (const auto& groundTruthObject : groundTruth) {
if (groundTruthObject.id == trackerObject.id &&
groundTruthObject.frameNumber == trackerObject.frameNumber) {
matched = true;
break;
}
}
if (!matched) {
numFalsePositives++;
}
}
for (const auto& trackerObject : trackerOutput) {
bool matched = false;
for (const auto& groundTruthObject : groundTruth) {
if (groundTruthObject.id == trackerObject.id &&
groundTruthObject.frameNumber == trackerObject.frameNumber) {
matched = true;
break;
}
}
if (matched && trackerObject.type != groundTruth.at(trackerObject.id).type) {
numMismatch++;
}
}
float mota = 1.0 - (numMissed + numFalsePositives + numMismatch) / static_cast<float>(numTotalObjects);
return mota;
}
// 读取CSV格式的文件
std::vector<Object> read_csv_file(const std::string& filename) {
std::vector<Object> data;
std::ifstream file(filename);
std::string line;
while (getline(file, line)) {
if (line.empty()) {
continue;
}
std::stringstream ss(line);
std::string field;
std::vector<std::string> fields;
while (getline(ss, field, ',')) {
fields.push_back(field);
}
if(fields[0] == "id")
continue;
Object obj;
obj.id = stoi(fields[0]);
obj.frameNumber = stoi(fields[1]);
obj.x = stof(fields[2]);
obj.y = stof(fields[3]);
obj.z = stof(fields[4]);
obj.length = stof(fields[5]);
obj.width = stof(fields[6]);
obj.height = stof(fields[7]);
obj.orientation = stof(fields[8]);
// obj.confidence = stof(fields[9]);
obj.type = stoi(fields[10]);
obj.typeName = fields[11];
data.push_back(obj);
}
return data;
}
int main(int argc, char** argv) {
// // 假设有以下示例跟踪结果和真实标注数据
// std::vector<Object> trackerOutput = {
// // 跟踪结果示例
// { 1, 1, 10.0f, 5.0f, 0.0f, 1.0f, 0.5f, 1.5f, 0.0f, 1, "Car" },
// { 2, 1, 4.0f, 6.0f, 0.0f, 0.8f, 0.6f, 1.2f, 0.0f, 1, "Car" },
// { 3, 1, 8.0f, 3.0f, 0.0f, 1.2f, 0.4f, 1.0f, 0.0f, 1, "Car" },
// // ...
// };
// std::vector<Object> groundTruth = {
// // 真实标注示例
// { 1, 1, 10.0f, 5.0f, 0.0f, 1.0f, 0.5f, 1.5f, 0.0f, 1, "Car" },
// { 2, 1, 4.0f, 6.0f, 0.0f, 0.8f, 0.6f, 1.2f, 0.0f, 1, "Car" },
// { 3, 1, 8.0f, 3.0f, 0.0f, 1.2f, 0.4f, 1.0f, 0.0f, 1, "Car" },
// // ...
// };
std::string tracking_csv = "tracking_data.csv";
std::string gt_csv = "gt_data.csv";
if( argc == 3 )
{
tracking_csv = argv[1];
gt_csv = argv[2];
}
// 读取跟踪结果和真值
std::vector<Object> tracking_data = read_csv_file(tracking_csv);
std::vector<Object> gt_data = read_csv_file(gt_csv);
float mota = calculateMOTA(tracking_data, gt_data);
std::cout << "MOTA: " << mota << std::endl;
return 0;
}
\ No newline at end of file
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <cmath>
#include <limits>
struct ObjectData {
int id;
int frame;
double x;
double y;
double z;
double length;
double width;
double height;
double orientation;
double confidence;
int type;
std::string typeName;
};
struct MatchedPair {
int gtIndex; // index in groundTruthData vector
int detIndex; // index in detectionData vector
};
double calculateDistance(double x1, double y1, double x2, double y2) {
double dx = x1 - x2;
double dy = y1 - y2;
return std::sqrt(dx * dx + dy * dy);
}
double calculateMOTA(const std::vector<ObjectData>& groundTruthData, const std::vector<ObjectData>& detectionData) {
int GT = groundTruthData.size();
int MT = 0; // number of matches
int PT = 0; // number of false positives
int ML = 0; // number of missed detections
// Create a 2D matrix to store distance between ground truth and detection objects
std::vector<std::vector<double>> costMatrix(GT, std::vector<double>(detectionData.size(), 0.0));
// Assign distances to the cost matrix
for (int i = 0; i < GT; i++) {
for (int j = 0; j < detectionData.size(); j++) {
double distance = calculateDistance(groundTruthData[i].x, groundTruthData[i].y, detectionData[j].x, detectionData[j].y);
costMatrix[i][j] = distance;
}
}
// Greedy matching algorithm
std::vector<MatchedPair> matchedPairs;
bool done[detectionData.size()] = { false };
for (int i = 0; i < GT; i++) {
double minDistance = std::numeric_limits<double>::max();
int minIndex = -1;
for (int j = 0; j < detectionData.size(); j++) {
if (!done[j] && costMatrix[i][j] < minDistance) {
minDistance = costMatrix[i][j];
minIndex = j;
}
}
if (minIndex != -1) {
MatchedPair matchedPair;
matchedPair.gtIndex = i;
matchedPair.detIndex = minIndex;
matchedPairs.push_back(matchedPair);
done[minIndex] = true;
}
}
MT = matchedPairs.size();
PT = detectionData.size() - MT;
ML = GT - MT;
double MOTA = 1.0 - (static_cast<double>(ML + PT) / static_cast<double>(GT));
return MOTA;
}
// 读取CSV格式的文件
std::vector<ObjectData> read_csv_file(const std::string& filename) {
std::vector<ObjectData> data;
std::ifstream file(filename);
std::string line;
while (getline(file, line)) {
if (line.empty()) {
continue;
}
std::stringstream ss(line);
std::string field;
std::vector<std::string> fields;
while (getline(ss, field, ',')) {
fields.push_back(field);
}
if(fields[0] == "id")
continue;
ObjectData obj;
obj.id = stoi(fields[0]);
obj.frame = stoi(fields[1]);
obj.x = stof(fields[2]);
obj.y = stof(fields[3]);
obj.z = stof(fields[4]);
obj.length = stof(fields[5]);
obj.width = stof(fields[6]);
obj.height = stof(fields[7]);
obj.orientation = stof(fields[8]);
obj.confidence = stof(fields[9]);
obj.type = stoi(fields[10]);
obj.typeName = fields[11];
data.push_back(obj);
}
return data;
}
int main(int argc, char** argv) {
std::string tracking_csv = "tracking_data.csv";
std::string gt_csv = "gt_data.csv";
if( argc == 3 )
{
tracking_csv = argv[1];
gt_csv = argv[2];
}
// 读取跟踪结果和真值
std::vector<ObjectData> tracking_data = read_csv_file(tracking_csv);
std::vector<ObjectData> gt_data = read_csv_file(gt_csv);
float mota = calculateMOTA(tracking_data, gt_data);
std::cout << "MOTA: " << mota << std::endl;
return 0;
}
\ No newline at end of file
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <map>
#include <cmath>
using namespace std;
struct BoundingBox {
float x;
float y;
float z;
float length;
float width;
float height;
float orientation;
};
struct Object {
int object_id;
int frame_id;
BoundingBox bbox;
float confidence;
int type;
string class_name;
};
// 读取CSV格式的文件
vector<Object> read_csv_file(const string& filename) {
vector<Object> data;
ifstream file(filename);
string line;
while (getline(file, line)) {
if (line.empty()) {
continue;
}
stringstream ss(line);
string field;
vector<string> fields;
while (getline(ss, field, ',')) {
fields.push_back(field);
}
if(fields[0] == "id")
continue;
Object obj;
obj.object_id = stoi(fields[0]);
obj.frame_id = stoi(fields[1]);
obj.bbox.x = stof(fields[2]);
obj.bbox.y = stof(fields[3]);
obj.bbox.z = stof(fields[4]);
obj.bbox.length = stof(fields[5]);
obj.bbox.width = stof(fields[6]);
obj.bbox.height = stof(fields[7]);
obj.bbox.orientation = stof(fields[8]);
obj.confidence = stof(fields[9]);
obj.type = stoi(fields[10]);
obj.class_name = fields[11];
data.push_back(obj);
}
return data;
}
// 计算IoU
float calculate_iou_3d(const BoundingBox& bbox1, const BoundingBox& bbox2) {
float dx = abs(bbox1.x - bbox2.x);
float dy = abs(bbox1.y - bbox2.y);
float dz = abs(bbox1.z - bbox2.z);
float delta_l = abs(bbox1.length - bbox2.length);
float delta_w = abs(bbox1.width - bbox2.width);
float delta_h = abs(bbox1.height - bbox2.height);
float delta_o = abs(bbox1.orientation - bbox2.orientation);
float delta_xyz = sqrt(dx * dx + dy * dy + dz * dz);
float delta_dim = sqrt(delta_l * delta_l + delta_w * delta_w + delta_h * delta_h);
float delta_orien = delta_o;
float iou_xyz = 1.0f - delta_xyz / (bbox1.length + bbox1.width + bbox1.height + bbox2.length + bbox2.width + bbox2.height);
float iou_dim = 1.0f - delta_dim / (bbox1.length + bbox1.width + bbox1.height + bbox2.length + bbox2.width + bbox2.height);
float iou_orien = 1.0f - delta_orien / M_PI;
float iou = iou_xyz * iou_dim * iou_orien;
return iou;
}
// 计算匹配分数
map<int, map<int, map<int, float>>> calculate_match_score_3d(
const vector<Object>& tracking_data, const vector<Object>& gt_data, float iou_threshold) {
map<int, map<int, map<int, float>>> match_score;
// for (const auto& obj1 : tracking_data) {
// for (const auto& obj2 : gt_data) {
// if (obj1.frame_id == obj2.frame_id) {
// float iou = calculate_iou_3d(obj1.bbox, obj2.bbox);
// if (iou > iou_threshold) {
// match_score[obj1.frame_id][obj1.object_id][obj2.object_id] = iou;
// }
// }
// }
// }
for (const auto& obj_gt : gt_data)
{
for (const auto& obj_tra : tracking_data)
{
if (obj_gt.frame_id == obj_tra.frame_id)
{
float iou = calculate_iou_3d(obj_gt.bbox, obj_tra.bbox);
if (iou > iou_threshold)
{
match_score[obj_gt.frame_id][obj_tra.object_id][obj_gt.object_id] = iou;
}
}
}
}
//获取帧数
// std::map<int,int> frameMap;
// std::vector<int> frameVet;
// for(auto& obj_gt : gt_data)
// {
// if(frameMap.find(obj_gt.frame_id) == frameMap.end())
// {
// frameVet.push_back(obj_gt.frame_id);
// frameMap[obj_gt.frame_id] = frameVet.size() - 1;
// }
// }
// for(auto frameId : frameVet)
// {
// }
return match_score;
}
// 确定真实值和跟踪结果之间的匹配关系
map<int, map<int, int>> get_matches(const map<int, map<int, map<int, float>>>& match_score) {
map<int, map<int, int>> matches;
for (const auto& frame_pair : match_score) {
matches[frame_pair.first];
map<int, int> used_objects;
map<int, int> used_gt_objects;
for (const auto& object_pair : frame_pair.second) {
int object_id = object_pair.first;
int max_gt_object_id = -1;
float max_iou = 0;
for (const auto& gt_object_pair : object_pair.second) {
int gt_object_id = gt_object_pair.first;
float iou = gt_object_pair.second;
if (iou > max_iou && used_gt_objects.count(gt_object_id) == 0) {
max_iou = iou;
max_gt_object_id = gt_object_id;
}
}
if (max_gt_object_id != -1) {
matches[frame_pair.first][object_id] = max_gt_object_id;
used_objects[object_id] = 1;
used_gt_objects[max_gt_object_id] = 1;
}
}
for (const auto& object_pair : frame_pair.second) {
int object_id = object_pair.first;
if (used_objects.count(object_id) == 0) {
matches[frame_pair.first][object_id] = -1;
}
}
for (const auto& gt_object_pair : frame_pair.second.begin()->second) {
int gt_object_id = gt_object_pair.first;
if (used_gt_objects.count(gt_object_id) == 0) {
matches[frame_pair.first][-1] = gt_object_id;
}
}
}
return matches;
}
// 计算MOTA指标
float calculate_mota_3d(const vector<Object>& tracking_data, const vector<Object>& gt_data, float iou_threshold) {
auto match_score = calculate_match_score_3d(tracking_data, gt_data, iou_threshold);
auto matches = get_matches(match_score);
int num_misses = 0;
int num_false_positives = 0;
int num_switches = 0;
int num_objects = 0;
for (const auto& frame_pair : matches) {
for (const auto& object_pair : frame_pair.second) {
int object_id = object_pair.first;
int gt_object_id = object_pair.second;
if (object_id == -1) {
num_misses++;
} else if (gt_object_id == -1) {
num_false_positives++;
} else {
num_objects++;
if (match_score[frame_pair.first][object_id][gt_object_id] < iou_threshold) {
num_switches++;
}
}
}
}
//获取帧数
std::map<int,int> frameMap;
std::vector<int> frameVet;
for(auto& obj_gt : gt_data)
{
if(frameMap.count(obj_gt.frame_id) == 0)
{
frameVet.push_back(obj_gt.frame_id);
frameMap[obj_gt.frame_id] = frameVet.size() - 1;
}
}
//检测出来的物体,但是没有匹配到任何真值的,要计算到误检测num_false_positives
for(auto& obj_tracking : tracking_data)
{
if(matches.count(obj_tracking.frame_id) == 0)//检测到的帧没有
{
num_false_positives++;
}
else if(matches[obj_tracking.frame_id].count(obj_tracking.object_id) == 0)
{
num_false_positives++;
}
}
//真值数据,但没有匹配到任何检测物体的,要计算到漏检测里num_misses
for(auto& obj_gt : gt_data)
{
if(matches.count(obj_gt.frame_id) == 0)//检测到的帧没有
{
num_misses++;
}
else
{
int isHas = 0;
for(auto& iter : matches[obj_gt.frame_id])
{
if(iter.second == obj_gt.object_id)
{
isHas = 1;
break;
}
}
if(isHas == 0)
{
num_misses++;
}
}
}
printf("num_misses = %d, num_false_positives = %d, num_switches = %d, num_objects = %d\n",num_misses,num_false_positives,num_switches,num_objects);
float mota = 1.0f - (float)(num_misses + num_false_positives + num_switches) / num_objects;
return mota;
}
int main(int argc, char** argv)
{
std::string tracking_csv = "tracking_data.csv";
std::string gt_csv = "gt_data.csv";
if( argc == 3 )
{
tracking_csv = argv[1];
gt_csv = argv[2];
}
// 读取跟踪结果和真值
vector<Object> tracking_data = read_csv_file(tracking_csv);
vector<Object> gt_data = read_csv_file(gt_csv);
// 计算MOTA指标
float iou_threshold = 0.1;
float mota = calculate_mota_3d(tracking_data, gt_data, iou_threshold);
cout << "MOTA = " << mota << endl;
return 0;
}
\ No newline at end of file
./main3 /media/sf_shared/nuScenes/output/scene-0061-tracking.csv /media/sf_shared/nuScenes/output/scene-0061-gt.csv
./main3 /media/sf_shared/nuScenes/output/scene-0103-tracking.csv /media/sf_shared/nuScenes/output/scene-0103-gt.csv
./main3 /media/sf_shared/nuScenes/output/scene-0553-tracking.csv /media/sf_shared/nuScenes/output/scene-0553-gt.csv
./main3 /media/sf_shared/nuScenes/output/scene-0655-tracking.csv /media/sf_shared/nuScenes/output/scene-0655-gt.csv
./main3 /media/sf_shared/nuScenes/output/scene-0757-tracking.csv /media/sf_shared/nuScenes/output/scene-0757-gt.csv
./main3 /media/sf_shared/nuScenes/output/scene-0796-tracking.csv /media/sf_shared/nuScenes/output/scene-0796-gt.csv
./main3 /media/sf_shared/nuScenes/output/scene-0916-tracking.csv /media/sf_shared/nuScenes/output/scene-0916-gt.csv
./main3 /media/sf_shared/nuScenes/output/scene-1077-tracking.csv /media/sf_shared/nuScenes/output/scene-1077-gt.csv
./main3 /media/sf_shared/nuScenes/output/scene-1094-tracking.csv /media/sf_shared/nuScenes/output/scene-1094-gt.csv
./main3 /media/sf_shared/nuScenes/output/scene-1100-tracking.csv /media/sf_shared/nuScenes/output/scene-1100-gt.csv
\ No newline at end of file
This diff is collapsed.
#ifndef _NUSCENES_ROS_MSG_HPP_
#define _NUSCENES_ROS_MSG_HPP_
#include "ros/ros.h"
#include <jsk_recognition_msgs/BoundingBoxArray.h>
#include <visualization_msgs/MarkerArray.h>
#include "jfx_common_msgs/det_tracking_array.h"
#include "jfx_common_msgs/localization.h"
#include <sensor_msgs/PointCloud2.h>
#include "nlohmann_json/json.hpp"
class NuScenesRosMsg
{
public:
NuScenesRosMsg() = default;
~NuScenesRosMsg() = default;
int SaveGT(ros::NodeHandle& nh);
int Play(ros::NodeHandle& nh);
ros::Publisher m_pubFusionRes;//发送融合后的结果
// ros::Publisher m_pubLocalization;//发送定位结果
ros::Publisher m_pubBoundingBoxes;//发送3D框信息
ros::Publisher m_pubMarkerArray;//发送marker框信息
ros::Publisher m_pubMarkerArrow;//发送marker框信息
};
#endif
\ No newline at end of file
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment