Skip to content

Commit

Permalink
[NFC][MLIR][TableGen] Eliminate llvm:: for common types in LSP Serv…
Browse files Browse the repository at this point in the history
…er (llvm#110867)
  • Loading branch information
jurahul authored and EricWF committed Oct 2, 2024
1 parent 40225c7 commit 4b8d751
Showing 1 changed file with 51 additions and 55 deletions.
106 changes: 51 additions & 55 deletions mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
#include <optional>

using namespace mlir;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::RecordVal;
using llvm::SourceMgr;

/// Returns the range of a lexical token given a SMLoc corresponding to the
/// start of an token location. The range is computed heuristically, and
Expand All @@ -32,7 +36,7 @@ static SMRange convertTokenLocToRange(SMLoc loc) {

/// Returns a language server uri for the given source location. `mainFileURI`
/// corresponds to the uri for the main file of the source manager.
static lsp::URIForFile getURIFromLoc(const llvm::SourceMgr &mgr, SMLoc loc,
static lsp::URIForFile getURIFromLoc(const SourceMgr &mgr, SMLoc loc,
const lsp::URIForFile &mainFileURI) {
int bufferId = mgr.FindBufferContainingLoc(loc);
if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
Expand All @@ -47,12 +51,12 @@ static lsp::URIForFile getURIFromLoc(const llvm::SourceMgr &mgr, SMLoc loc,
}

/// Returns a language server location from the given source range.
static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange loc,
static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMRange loc,
const lsp::URIForFile &uri) {
return lsp::Location(getURIFromLoc(mgr, loc.Start, uri),
lsp::Range(mgr, loc));
}
static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMLoc loc,
static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMLoc loc,
const lsp::URIForFile &uri) {
return getLocationFromLoc(mgr, convertTokenLocToRange(loc), uri);
}
Expand All @@ -61,7 +65,7 @@ static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMLoc loc,
static std::optional<lsp::Diagnostic>
getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag,
const lsp::URIForFile &uri) {
auto *sourceMgr = const_cast<llvm::SourceMgr *>(diag.getSourceMgr());
auto *sourceMgr = const_cast<SourceMgr *>(diag.getSourceMgr());
if (!sourceMgr || !diag.getLoc().isValid())
return std::nullopt;

Expand All @@ -79,17 +83,17 @@ getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag,

// Convert the severity for the diagnostic.
switch (diag.getKind()) {
case llvm::SourceMgr::DK_Warning:
case SourceMgr::DK_Warning:
lspDiag.severity = lsp::DiagnosticSeverity::Warning;
break;
case llvm::SourceMgr::DK_Error:
case SourceMgr::DK_Error:
lspDiag.severity = lsp::DiagnosticSeverity::Error;
break;
case llvm::SourceMgr::DK_Note:
case SourceMgr::DK_Note:
// Notes are emitted separately from the main diagnostic, so we just treat
// them as remarks given that we can't determine the diagnostic to relate
// them to.
case llvm::SourceMgr::DK_Remark:
case SourceMgr::DK_Remark:
lspDiag.severity = lsp::DiagnosticSeverity::Information;
break;
}
Expand All @@ -100,16 +104,15 @@ getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag,

/// Get the base definition of the given record value, or nullptr if one
/// couldn't be found.
static std::pair<const llvm::Record *, const llvm::RecordVal *>
getBaseValue(const llvm::Record *record, const llvm::RecordVal *value) {
static std::pair<const Record *, const RecordVal *>
getBaseValue(const Record *record, const RecordVal *value) {
if (value->isTemplateArg())
return {nullptr, nullptr};

// Find a base value for the field in the super classes of the given record.
// On success, `record` is updated to the new parent record.
StringRef valueName = value->getName();
auto findValueInSupers =
[&](const llvm::Record *&record) -> llvm::RecordVal * {
auto findValueInSupers = [&](const Record *&record) -> RecordVal * {
for (auto [parentRecord, loc] : record->getSuperClasses()) {
if (auto *newBase = parentRecord->getValue(valueName)) {
record = parentRecord;
Expand All @@ -120,8 +123,8 @@ getBaseValue(const llvm::Record *record, const llvm::RecordVal *value) {
};

// Try to find the lowest definition of the record value.
std::pair<const llvm::Record *, const llvm::RecordVal *> baseValue = {};
while (const llvm::RecordVal *newBase = findValueInSupers(record))
std::pair<const Record *, const RecordVal *> baseValue = {};
while (const RecordVal *newBase = findValueInSupers(record))
baseValue = {record, newBase};

// Check that the base isn't the same as the current value (e.g. if the value
Expand All @@ -140,15 +143,15 @@ namespace {
/// contains the definition of the symbol, the location of the symbol, and any
/// recorded references.
struct TableGenIndexSymbol {
TableGenIndexSymbol(const llvm::Record *record)
TableGenIndexSymbol(const Record *record)
: definition(record),
defLoc(convertTokenLocToRange(record->getLoc().front())) {}
TableGenIndexSymbol(const llvm::RecordVal *value)
TableGenIndexSymbol(const RecordVal *value)
: definition(value), defLoc(convertTokenLocToRange(value->getLoc())) {}
virtual ~TableGenIndexSymbol() = default;

// The main definition of the symbol.
PointerUnion<const llvm::Record *, const llvm::RecordVal *> definition;
PointerUnion<const Record *, const RecordVal *> definition;

/// The source location of the definition.
SMRange defLoc;
Expand All @@ -158,37 +161,33 @@ struct TableGenIndexSymbol {
};
/// This class represents a single record symbol.
struct TableGenRecordSymbol : public TableGenIndexSymbol {
TableGenRecordSymbol(const llvm::Record *record)
: TableGenIndexSymbol(record) {}
TableGenRecordSymbol(const Record *record) : TableGenIndexSymbol(record) {}
~TableGenRecordSymbol() override = default;

static bool classof(const TableGenIndexSymbol *symbol) {
return symbol->definition.is<const llvm::Record *>();
return symbol->definition.is<const Record *>();
}

/// Return the value of this symbol.
const llvm::Record *getValue() const {
return definition.get<const llvm::Record *>();
}
const Record *getValue() const { return definition.get<const Record *>(); }
};
/// This class represents a single record value symbol.
struct TableGenRecordValSymbol : public TableGenIndexSymbol {
TableGenRecordValSymbol(const llvm::Record *record,
const llvm::RecordVal *value)
TableGenRecordValSymbol(const Record *record, const RecordVal *value)
: TableGenIndexSymbol(value), record(record) {}
~TableGenRecordValSymbol() override = default;

static bool classof(const TableGenIndexSymbol *symbol) {
return symbol->definition.is<const llvm::RecordVal *>();
return symbol->definition.is<const RecordVal *>();
}

/// Return the value of this symbol.
const llvm::RecordVal *getValue() const {
return definition.get<const llvm::RecordVal *>();
const RecordVal *getValue() const {
return definition.get<const RecordVal *>();
}

/// The parent record of this symbol.
const llvm::Record *record;
const Record *record;
};

/// This class provides an index for definitions/uses within a TableGen
Expand All @@ -199,7 +198,7 @@ class TableGenIndex {
TableGenIndex() : intervalMap(allocator) {}

/// Initialize the index with the given RecordKeeper.
void initialize(const llvm::RecordKeeper &records);
void initialize(const RecordKeeper &records);

/// Lookup a symbol for the given location. Returns nullptr if no symbol could
/// be found. If provided, `overlappedRange` is set to the range that the
Expand All @@ -217,15 +216,15 @@ class TableGenIndex {
llvm::IntervalMapHalfOpenInfo<const char *>>;

/// Get or insert a symbol for the given record.
TableGenIndexSymbol *getOrInsertDef(const llvm::Record *record) {
TableGenIndexSymbol *getOrInsertDef(const Record *record) {
auto it = defToSymbol.try_emplace(record, nullptr);
if (it.second)
it.first->second = std::make_unique<TableGenRecordSymbol>(record);
return &*it.first->second;
}
/// Get or insert a symbol for the given record value.
TableGenIndexSymbol *getOrInsertDef(const llvm::Record *record,
const llvm::RecordVal *value) {
TableGenIndexSymbol *getOrInsertDef(const Record *record,
const RecordVal *value) {
auto it = defToSymbol.try_emplace(value, nullptr);
if (it.second) {
it.first->second =
Expand All @@ -246,7 +245,7 @@ class TableGenIndex {
};
} // namespace

void TableGenIndex::initialize(const llvm::RecordKeeper &records) {
void TableGenIndex::initialize(const RecordKeeper &records) {
intervalMap.clear();
defToSymbol.clear();

Expand Down Expand Up @@ -282,7 +281,7 @@ void TableGenIndex::initialize(const llvm::RecordKeeper &records) {
llvm::make_pointee_range(llvm::make_second_range(records.getClasses()));
auto defs =
llvm::make_pointee_range(llvm::make_second_range(records.getDefs()));
for (const llvm::Record &def : llvm::concat<llvm::Record>(classes, defs)) {
for (const Record &def : llvm::concat<Record>(classes, defs)) {
auto *sym = getOrInsertDef(&def);
insertRef(sym, sym->defLoc, /*isDef=*/true);

Expand All @@ -293,7 +292,7 @@ void TableGenIndex::initialize(const llvm::RecordKeeper &records) {
insertRef(sym, loc);

// Add definitions for any values.
for (const llvm::RecordVal &value : def.getValues()) {
for (const RecordVal &value : def.getValues()) {
auto *sym = getOrInsertDef(&def, &value);
insertRef(sym, sym->defLoc, /*isDef=*/true);
for (SMRange refLoc : value.getReferenceLocs())
Expand Down Expand Up @@ -359,13 +358,12 @@ class TableGenTextFile {

std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
const lsp::Position &hoverPos);
lsp::Hover buildHoverForRecord(const llvm::Record *record,
lsp::Hover buildHoverForRecord(const Record *record,
const SMRange &hoverRange);
lsp::Hover buildHoverForTemplateArg(const llvm::Record *record,
const llvm::RecordVal *value,
lsp::Hover buildHoverForTemplateArg(const Record *record,
const RecordVal *value,
const SMRange &hoverRange);
lsp::Hover buildHoverForField(const llvm::Record *record,
const llvm::RecordVal *value,
lsp::Hover buildHoverForField(const Record *record, const RecordVal *value,
const SMRange &hoverRange);

private:
Expand All @@ -383,10 +381,10 @@ class TableGenTextFile {
std::vector<std::string> includeDirs;

/// The source manager containing the contents of the input file.
llvm::SourceMgr sourceMgr;
SourceMgr sourceMgr;

/// The record keeper containing the parsed tablegen constructs.
std::unique_ptr<llvm::RecordKeeper> recordKeeper;
std::unique_ptr<RecordKeeper> recordKeeper;

/// The index of the parsed file.
TableGenIndex index;
Expand Down Expand Up @@ -430,8 +428,8 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri,
int64_t newVersion,
std::vector<lsp::Diagnostic> &diagnostics) {
version = newVersion;
sourceMgr = llvm::SourceMgr();
recordKeeper = std::make_unique<llvm::RecordKeeper>();
sourceMgr = SourceMgr();
recordKeeper = std::make_unique<RecordKeeper>();

// Build a buffer for this file.
auto memBuffer = llvm::MemoryBuffer::getMemBuffer(contents, uri.file());
Expand All @@ -442,7 +440,7 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri,
sourceMgr.setIncludeDirs(includeDirs);
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());

// This class provides a context argument for the llvm::SourceMgr diagnostic
// This class provides a context argument for the SourceMgr diagnostic
// handler.
struct DiagHandlerContext {
std::vector<lsp::Diagnostic> &diagnostics;
Expand Down Expand Up @@ -543,13 +541,13 @@ TableGenTextFile::findHover(const lsp::URIForFile &uri,
// Build hover for a RecordVal, which is either a template argument or a
// field.
auto *recordVal = cast<TableGenRecordValSymbol>(symbol);
const llvm::RecordVal *value = recordVal->getValue();
const RecordVal *value = recordVal->getValue();
if (value->isTemplateArg())
return buildHoverForTemplateArg(recordVal->record, value, hoverRange);
return buildHoverForField(recordVal->record, value, hoverRange);
}

lsp::Hover TableGenTextFile::buildHoverForRecord(const llvm::Record *record,
lsp::Hover TableGenTextFile::buildHoverForRecord(const Record *record,
const SMRange &hoverRange) {
lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
{
Expand All @@ -570,7 +568,7 @@ lsp::Hover TableGenTextFile::buildHoverForRecord(const llvm::Record *record,
auto printAndFormatField = [&](StringRef fieldName) {
// Check that the record actually has the given field, and that it's a
// string.
const llvm::RecordVal *value = record->getValue(fieldName);
const RecordVal *value = record->getValue(fieldName);
if (!value || !value->getValue())
return;
auto *stringValue = dyn_cast<llvm::StringInit>(value->getValue());
Expand All @@ -593,10 +591,8 @@ lsp::Hover TableGenTextFile::buildHoverForRecord(const llvm::Record *record,
return hover;
}

lsp::Hover
TableGenTextFile::buildHoverForTemplateArg(const llvm::Record *record,
const llvm::RecordVal *value,
const SMRange &hoverRange) {
lsp::Hover TableGenTextFile::buildHoverForTemplateArg(
const Record *record, const RecordVal *value, const SMRange &hoverRange) {
lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
{
llvm::raw_string_ostream hoverOS(hover.contents.value);
Expand All @@ -609,8 +605,8 @@ TableGenTextFile::buildHoverForTemplateArg(const llvm::Record *record,
return hover;
}

lsp::Hover TableGenTextFile::buildHoverForField(const llvm::Record *record,
const llvm::RecordVal *value,
lsp::Hover TableGenTextFile::buildHoverForField(const Record *record,
const RecordVal *value,
const SMRange &hoverRange) {
lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
{
Expand Down

0 comments on commit 4b8d751

Please sign in to comment.