Skip to content

Commit

Permalink
Merge pull request cms-sw#2 from waredjeb/GNN_Linking_ticlGraph
Browse files Browse the repository at this point in the history
TICLGraph for GNN Linking
  • Loading branch information
jejarosl authored Oct 31, 2022
2 parents 648b103 + d7d4ddd commit 208aaaa
Show file tree
Hide file tree
Showing 8 changed files with 313 additions and 31 deletions.
40 changes: 40 additions & 0 deletions DataFormats/HGCalReco/interface/TICLGraph.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef DataFormats_HGCalReco_TICLGraph_h
#define DataFormats_HGCalReco_TICLGraph_h

#include "DataFormats/HGCalReco/interface/Trackster.h"
#include "DataFormats/TrackReco/interface/Track.h"

class Node {
public:
Node() = default;
Node(unsigned index, bool isTrackster = true) : index_(index), isTrackster_(isTrackster){};

void addInner(unsigned int trackster_id) { innerNodes_.push_back(trackster_id); }
void addOuter(unsigned int trackster_id) { outerNodes_.push_back(trackster_id); }

const unsigned int getId() const { return index_; }
std::vector<unsigned int> getInner() const { return innerNodes_; }
std::vector<unsigned int> getOuter() const { return outerNodes_; }

~Node() = default;

private:
unsigned index_;
bool isTrackster_;
std::vector<unsigned int> innerNodes_;
std::vector<unsigned int> outerNodes_;
};

class TICLGraph {
public:
TICLGraph() = default;
TICLGraph(std::vector<Node> &n) { nodes_ = n; };
const std::vector<Node> &getNodes() const { return nodes_; }
const Node &getNode(int i) const { return nodes_[i]; }
~TICLGraph() = default;

private:
std::vector<Node> nodes_;
};

#endif
1 change: 1 addition & 0 deletions DataFormats/HGCalReco/src/classes.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
#include "DataFormats/HGCalReco/interface/TICLSeedingRegion.h"
#include "DataFormats/HGCalReco/interface/TICLCandidate.h"
#include "DataFormats/Common/interface/Wrapper.h"
#include "DataFormats/HGCalReco/interface/TICLGraph.h"
12 changes: 12 additions & 0 deletions DataFormats/HGCalReco/src/classes_def.xml
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,16 @@
<class name="std::vector<TICLCandidate>" />
<class name="edm::Wrapper<TICLCandidate>" />
<class name="edm::Wrapper<std::vector<TICLCandidate> >" />

<class name="Node">
</class>
<class name="std::vector<Node>" />
<class name="edm::Wrapper<Node>" />
<class name="edm::Wrapper<std::vector<Node> >" />

<class name="TICLGraph">
</class>
<class name="std::vector<TICLGraph>" />
<class name="edm::Wrapper<TICLGraph>" />
<class name="edm::Wrapper<std::vector<TICLGraph> >" />
</lcgdict>
3 changes: 2 additions & 1 deletion RecoHGCal/Configuration/python/RecoHGCal_EventContent_cff.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
'keep *_ticlTrackstersHFNoseMIP_*_*',
'keep *_ticlTrackstersHFNoseHAD_*_*',
'keep *_ticlTrackstersHFNoseMerge_*_*',] +
['keep *_pfTICL_*_*']
['keep *_pfTICL_*_*'] +
['keep *_ticlGraph_*_*']
)
)
TICL_RECO.outputCommands.extend(TICL_AOD.outputCommands)
Expand Down
68 changes: 38 additions & 30 deletions RecoHGCal/TICL/plugins/LinkingAlgoByGNN.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ void Graph::addEdge(int v, int w) {
void Graph::DFSUtil(int v) {
// Mark the current node as visited and print it
visited[v] = true;
std::cout << v << " ";
connected_components.back().push_back(v);
std::cout << v << std::endl;

// Recur for all the vertices adjacent to this vertex
list<int>::iterator i;
for (auto i = adj[v].begin(); i != adj[v].end(); ++i)
if (!visited[*i]) {
connected_components.emplace_back(*i);
connected_components.back().push_back(*i);
std::cout << "Pushed back " << *i << std::endl;
DFSUtil(*i);
}
}
Expand All @@ -59,7 +59,9 @@ void Graph::DFS() {
// traversal starting from all vertices one by one
for (auto i : adj)
if (visited[i.first] == false) {
//connected_components.emplace_back(i.first);
std::cout << "Emplaced back: " << i.first << std::endl;
connected_components.emplace_back(1, i.first); // {i.first}
std::cout << "Starting DFS from node: " << i.first << std::endl;
DFSUtil(i.first);
}
}
Expand Down Expand Up @@ -174,40 +176,51 @@ void LinkingAlgoByGNN::linkTracksters(const edm::Handle<std::vector<reco::Track>
}

input_shapes.push_back({1, N, shapeFeatures});
input_shapes.push_back({1, 2, 3 * N});

data.emplace_back(features);

// Creating Edges: uncomment when have a Graph as an input
/*
std::vector<float_t> edges_src;
std::vector<float_t> edges_dst;
for (int i = 0; i < N; i++){
for (auto & i_neighbour : graph.node_linked_inners[i]){
// Create an edge between the tracksters
edges_src.push_back(i_neighbour);
edges_dst.push_back(i);
}
}
*/

//std::vector<float_t> edges_src;
//std::vector<float_t> edges_dst;
//for (int i = 0; i < N; i++){
// for (auto & i_neighbour : graph.node_linked_inners[i]){
// // Create an edge between the tracksters
// edges_src.push_back(i_neighbour);
// edges_dst.push_back(i);
// }
//}

// Create fully connected graph for testing
std::vector<float> edges_src;
std::vector<float> edges_dst;

for (int i = 0; i < N; i++) {
for (int j = i; i < N; j++) {
edges_src.push_back(i);
edges_dst.push_back(j);
std::cout << "i: " << i << std::endl;
for (int j = i; j < N; j++) {
std::cout << "j: " << j << std::endl;
edges_src.push_back(static_cast<float>(i));
edges_dst.push_back(static_cast<float>(j));
}
}

long unsigned int numEdges = edges_src.size();
input_shapes.push_back({1, 2, static_cast<int>(numEdges)});
std::cout << "Num edges: " << numEdges << std::endl;

data.emplace_back(edges_src);
for (auto &dst : edges_dst) {
data.back().push_back(dst);
}

std::vector<float> edge_predictions = cache->run(input_names, data, input_shapes)[0];

std::cout << "Network output shape is " << edge_predictions.size() << std::endl;

for (long unsigned int i = 0; i < edge_predictions.size(); i++) {
std::cout << "Network output for edge " << data[1][i] << "-" << data[1][numEdges + i]
<< " is: " << edge_predictions[i] << std::endl;
}

// Create a graph
Graph g;
const auto classification_threshold = 0.7;
Expand All @@ -223,8 +236,10 @@ void LinkingAlgoByGNN::linkTracksters(const edm::Handle<std::vector<reco::Track>
}
}

std::cout << "Following is Depth First Traversal\n";
std::cout << "Connected components are:\n";
std::cout << "HERE 8" << std::endl;

std::cout << "Following is Depth First Traversal" << std::endl;
std::cout << "Connected components are: " << std::endl;
g.DFS();

int i = 0;
Expand All @@ -235,18 +250,11 @@ void LinkingAlgoByGNN::linkTracksters(const edm::Handle<std::vector<reco::Track>
for (auto &trackster_id : component) {
std::cout << "Component " << i << ": trackster id " << trackster_id << std::endl;
tracksterCandidate.addTrackster(edm::Ptr<Trackster>(tsH, trackster_id));
i++;
}
i++;
connectedCandidates.push_back(tracksterCandidate);
}

std::cout << "Network output shape is " << edge_predictions.size() << std::endl;

for (long unsigned int i = 0; i < edge_predictions.size(); i++) {
std::cout << "Network output for edge " << data[1][i] << "-" << data[1][numEdges + i]
<< " is: " << edge_predictions[i] << std::endl;
}

// The final candidates are passed to `resultLinked`
resultLinked.insert(std::end(resultLinked), std::begin(connectedCandidates), std::end(connectedCandidates));

Expand Down
189 changes: 189 additions & 0 deletions RecoHGCal/TICL/plugins/TICLGraphProducer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
#include <memory>

#include "FWCore/Framework/interface/stream/EDProducer.h"
#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "FWCore/Utilities/interface/ESGetToken.h"
#include "FWCore/Framework/interface/ESHandle.h"
#include "FWCore/Framework/interface/Frameworkfwd.h"
#include "FWCore/Framework/interface/MakerMacros.h"
#include "FWCore/ParameterSet/interface/ParameterSetDescription.h"

#include "DataFormats/HGCalReco/interface/TICLGraph.h"
#include "DataFormats/HGCalReco/interface/Trackster.h"
#include "DataFormats/HGCalReco/interface/TICLLayerTile.h"

#include "DataFormats/TrackReco/interface/Track.h"

#include "TrackingTools/TrajectoryState/interface/TrajectoryStateTransform.h"
#include "TrackingTools/GeomPropagators/interface/Propagator.h"
#include "TrackingTools/Records/interface/TrackingComponentsRecord.h"

#include "RecoLocalCalo/HGCalRecAlgos/interface/RecHitTools.h"
#include "CommonTools/Utils/interface/StringCutObjectSelector.h"

#include "MagneticField/Engine/interface/MagneticField.h"
#include "MagneticField/Records/interface/IdealMagneticFieldRecord.h"

#include "Geometry/HGCalCommonData/interface/HGCalDDDConstants.h"
#include "Geometry/Records/interface/IdealGeometryRecord.h"
#include "Geometry/CaloGeometry/interface/CaloGeometry.h"
#include "Geometry/Records/interface/CaloGeometryRecord.h"
#include "Geometry/CommonDetUnit/interface/GeomDet.h"

using namespace ticl;

class TICLGraphProducer : public edm::stream::EDProducer<> {
public:
explicit TICLGraphProducer(const edm::ParameterSet &ps);
~TICLGraphProducer() override{};
void produce(edm::Event &, const edm::EventSetup &) override;
static void fillDescriptions(edm::ConfigurationDescriptions &descriptions);

void beginJob();
void endJob();

void beginRun(edm::Run const &iEvent, edm::EventSetup const &es) override;

private:
typedef math::XYZVector Vector;
const edm::EDGetTokenT<std::vector<Trackster>> tracksters_clue3d_token_;
const edm::EDGetTokenT<std::vector<reco::Track>> tracks_token_;
const StringCutObjectSelector<reco::Track> cutTk_;
const edm::ESGetToken<CaloGeometry, CaloGeometryRecord> geometry_token_;
const std::string detector_;
const std::string propName_;
const edm::ESGetToken<MagneticField, IdealMagneticFieldRecord> bfield_token_;
const edm::ESGetToken<Propagator, TrackingComponentsRecord> propagator_token_;

const HGCalDDDConstants *hgcons_;
hgcal::RecHitTools rhtools_;
edm::ESGetToken<HGCalDDDConstants, IdealGeometryRecord> hdc_token_;
};

TICLGraphProducer::TICLGraphProducer(const edm::ParameterSet &ps)
: tracksters_clue3d_token_(consumes<std::vector<Trackster>>(ps.getParameter<edm::InputTag>("trackstersclue3d"))),
tracks_token_(consumes<std::vector<reco::Track>>(ps.getParameter<edm::InputTag>("tracks"))),
cutTk_(ps.getParameter<std::string>("cutTk")),
geometry_token_(esConsumes<CaloGeometry, CaloGeometryRecord, edm::Transition::BeginRun>()),
detector_(ps.getParameter<std::string>("detector")),
propName_(ps.getParameter<std::string>("propagator")),
bfield_token_(esConsumes<MagneticField, IdealMagneticFieldRecord, edm::Transition::BeginRun>()),
propagator_token_(
esConsumes<Propagator, TrackingComponentsRecord, edm::Transition::BeginRun>(edm::ESInputTag("", propName_))) {
produces<TICLGraph>();
std::string detectorName_ = (detector_ == "HFNose") ? "HGCalHFNoseSensitive" : "HGCalEESensitive";
hdc_token_ =
esConsumes<HGCalDDDConstants, IdealGeometryRecord, edm::Transition::BeginRun>(edm::ESInputTag("", detectorName_));
}

void TICLGraphProducer::beginJob() {}

void TICLGraphProducer::endJob(){};

void TICLGraphProducer::beginRun(edm::Run const &iEvent, edm::EventSetup const &es) {
edm::ESHandle<HGCalDDDConstants> hdc = es.getHandle(hdc_token_);
hgcons_ = hdc.product();

edm::ESHandle<CaloGeometry> geom = es.getHandle(geometry_token_);
rhtools_.setGeometry(*geom);

edm::ESHandle<MagneticField> bfield = es.getHandle(bfield_token_);
edm::ESHandle<Propagator> propagator = es.getHandle(propagator_token_);
};

void TICLGraphProducer::produce(edm::Event &evt, const edm::EventSetup &es) {
edm::Handle<std::vector<Trackster>> trackstersclue3d_h;
evt.getByToken(tracksters_clue3d_token_, trackstersclue3d_h);
auto trackstersclue3d = *trackstersclue3d_h;

//std::vector<Trackster> trackstersclue3d_sorted(trackstersclue3d);
//std::sort(trackstersclue3d_sorted.begin(), trackstersclue3d_sorted.end(), [](Trackster& t1, Trackster& t2){return t1.barycenter().z() < t2.barycenter().z();});

TICLLayerTile tracksterTilePos;
TICLLayerTile tracksterTileNeg;

for (size_t id_t = 0; id_t < trackstersclue3d.size(); ++id_t) {
auto t = trackstersclue3d[id_t];
if (t.barycenter().eta() > 0.) {
tracksterTilePos.fill(t.barycenter().eta(), t.barycenter().phi(), id_t);
} else if (t.barycenter().eta() < 0.) {
tracksterTileNeg.fill(t.barycenter().eta(), t.barycenter().phi(), id_t);
}
}

std::vector<Node> allNodes;

for (size_t id_t = 0; id_t < trackstersclue3d.size(); ++id_t) {
auto t = trackstersclue3d[id_t];

Node tNode(id_t);

auto bary = t.barycenter();
double del = 0.1;

double eta_min = std::max(abs(bary.eta()) - del, (double)TileConstants::minEta);
double eta_max = std::min(abs(bary.eta()) + del, (double)TileConstants::maxEta);

if (bary.eta() > 0.) {
std::array<int, 4> search_box =
tracksterTilePos.searchBoxEtaPhi(eta_min, eta_max, bary.phi() - del, bary.phi() + del);
if (search_box[2] > search_box[3]) {
search_box[3] += TileConstants::nPhiBins;
}

for (int eta_i = search_box[0]; eta_i <= search_box[1]; ++eta_i) {
for (int phi_i = search_box[2]; phi_i <= search_box[3]; ++phi_i) {
auto &neighbours = tracksterTilePos[tracksterTilePos.globalBin(eta_i, (phi_i % TileConstants::nPhiBins))];
for (auto n : neighbours) {
if (trackstersclue3d[n].barycenter().z() < bary.z()) {
tNode.addInner(n);
} else if (trackstersclue3d[n].barycenter().z() > bary.z()) {
tNode.addOuter(n);
}
}
}
}
}

else if (bary.eta() < 0.) {
std::array<int, 4> search_box =
tracksterTileNeg.searchBoxEtaPhi(eta_min, eta_max, bary.phi() - del, bary.phi() + del);
if (search_box[2] > search_box[3]) {
search_box[3] += TileConstants::nPhiBins;
}

for (int eta_i = search_box[0]; eta_i <= search_box[1]; ++eta_i) {
for (int phi_i = search_box[2]; phi_i <= search_box[3]; ++phi_i) {
auto &neighbours = tracksterTileNeg[tracksterTileNeg.globalBin(eta_i, (phi_i % TileConstants::nPhiBins))];
for (auto n : neighbours) {
if (abs(trackstersclue3d[n].barycenter().z()) < abs(bary.z())) {
tNode.addInner(n);
} else if (abs(trackstersclue3d[n].barycenter().z()) > abs(bary.z())) {
tNode.addOuter(n);
}
}
}
}
}
allNodes.push_back(tNode);
}
auto resultGraph = std::make_unique<TICLGraph>(allNodes);

evt.put(std::move(resultGraph));
}

void TICLGraphProducer::fillDescriptions(edm::ConfigurationDescriptions &descriptions) {
edm::ParameterSetDescription desc;

desc.add<edm::InputTag>("trackstersclue3d", edm::InputTag("ticlTrackstersCLUE3DHigh"));
desc.add<edm::InputTag>("tracks", edm::InputTag("generalTracks"));
desc.add<edm::InputTag>("muons", edm::InputTag("muons1stStep"));
desc.add<std::string>("detector", "HGCAL");
desc.add<std::string>("propagator", "PropagatorWithMaterial");
desc.add<std::string>("cutTk",
"1.48 < abs(eta) < 3.0 && pt > 1. && quality(\"highPurity\") && "
"hitPattern().numberOfLostHits(\"MISSING_OUTER_HITS\") < 5");
descriptions.add("ticlGraphProducer", desc);
}

DEFINE_FWK_MODULE(TICLGraphProducer);
Loading

0 comments on commit 208aaaa

Please sign in to comment.