Building a compiler for your own language: from the parse tree to the Abstract Syntax Tree
In this post we are going to see how to process and transform the information obtained from the parser. The ANTLR parser recognizes the elements present in the source code and builds a parse tree. From the parse tree we will obtain the Abstract Syntax Tree, which we will use to perform validation and produce compiled code.
Note that the terminology can vary: many would call the tree obtained from ANTLR an Abstract Syntax Tree. I prefer to mark the difference between these two steps. To me the parse tree is the information as it is meaningful to the parser, while the abstract syntax tree is the information reorganized to better support the next steps.
Series on building your own language
Previous posts:
- Building a lexer
- Building a parser
- Creating an editor with syntax highlighting
- Build an editor with autocompletion
Code is available on GitHub under the tag 05_ast
Spice up our language
In this series of posts we have been working on a very simple language for expressions. It is time to make our language slightly more complex by introducing:
- casts, for example: 10 as Decimal or (1 * 2.45) as Int
- the print statement, for example: print(3 + 11)
To do so we need to revise our lexer and parser grammar. The syntax highlighting and autocompletion which we have built in previous posts will just keep working.
The new lexer grammar:
lexer grammar SandyLexer;

// Whitespace
NEWLINE            : '\r\n' | '\r' | '\n' ;
WS                 : [\t ]+ -> skip ;

// Keywords
VAR                : 'var' ;
PRINT              : 'print';
AS                 : 'as';
INT                : 'Int';
DECIMAL            : 'Decimal';

// Literals
INTLIT             : '0'|[1-9][0-9]* ;
// The integer part is parenthesized so that the fractional part is required in both cases
DECLIT             : ('0'|[1-9][0-9]*) '.' [0-9]+ ;

// Operators
PLUS               : '+' ;
MINUS              : '-' ;
ASTERISK           : '*' ;
DIVISION           : '/' ;
ASSIGN             : '=' ;
LPAREN             : '(' ;
RPAREN             : ')' ;

// Identifiers
ID                 : [_]*[a-z][A-Za-z0-9_]* ;
And the new parser grammar:
parser grammar SandyParser;

options { tokenVocab=SandyLexer; }

sandyFile : lines=line+ ;

line      : statement (NEWLINE | EOF) ;

statement : varDeclaration # varDeclarationStatement
          | assignment     # assignmentStatement
          | print          # printStatement ;

print : PRINT LPAREN expression RPAREN ;

varDeclaration : VAR assignment ;

assignment : ID ASSIGN expression ;

expression : left=expression operator=(DIVISION|ASTERISK) right=expression # binaryOperation
           | left=expression operator=(PLUS|MINUS) right=expression        # binaryOperation
           | value=expression AS targetType=type                           # typeConversion
           | LPAREN expression RPAREN                                      # parenExpression
           | ID                                                            # varReference
           | MINUS expression                                              # minusExpression
           | INTLIT                                                        # intLiteral
           | DECLIT                                                        # decimalLiteral ;

type : INT     # integer
     | DECIMAL # decimal ;
The Abstract Syntax Tree metamodel
The Abstract Syntax Tree metamodel is simply the structure of the data we want to use for our Abstract Syntax Tree (AST). In this case we define it through the classes which we will use for our AST.
The AST metamodel will look reasonably similar to the parse tree metamodel, i.e., the set of classes generated by ANTLR to contain the nodes.
There will be a few key differences:
- it will have a simpler and nicer API than the classes generated by ANTLR (i.e., the classes composing the parse tree). In the next sections we will see how this API permits us to perform transformations on the AST
- we will remove elements which are meaningful only while parsing but are logically useless: for example the parenthesis expression or the line node
- some nodes for which we have separate instances in the parse tree can correspond to a single instance in the AST. This is the case of the type references Int and Decimal which in the AST are defined using singleton objects
- we can define common interfaces for related node types like BinaryExpression
- to define how to parse a variable declaration we reuse the assignment rule. In the AST the two concepts are completely separated
- certain operations have the same node type in the parse tree but are separated in the AST. This is the case of the different types of binary expressions
Let’s see how we can define our AST metamodel using Kotlin.
interface Node

//
// Sandy specific part
//

data class SandyFile(val statements : List<Statement>) : Node

interface Statement : Node { }

interface Expression : Node { }

interface Type : Node { }

//
// Types
//

object IntType : Type

object DecimalType : Type

//
// Expressions
//

interface BinaryExpression : Expression {
    val left: Expression
    val right: Expression
}

data class SumExpression(override val left: Expression, override val right: Expression) : BinaryExpression

data class SubtractionExpression(override val left: Expression, override val right: Expression) : BinaryExpression

data class MultiplicationExpression(override val left: Expression, override val right: Expression) : BinaryExpression

data class DivisionExpression(override val left: Expression, override val right: Expression) : BinaryExpression

data class UnaryMinusExpression(val value: Expression) : Expression

data class TypeConversion(val value: Expression, val targetType: Type) : Expression

data class VarReference(val varName: String) : Expression

data class IntLit(val value: String) : Expression

data class DecLit(val value: String) : Expression

//
// Statements
//

data class VarDeclaration(val varName: String, val value: Expression) : Statement

data class Assignment(val varName: String, val value: Expression) : Statement

data class Print(val value: Expression) : Statement
We start by defining Node. A Node represents any possible node of an AST and is fully general: it could be reused for other languages too. All the rest is instead specific to the language (Sandy in our case). In our specific language we need three important interfaces:
- Statement
- Expression
- Type
Each of these interfaces extends Node.
We then declare the two types we use in our language. They are defined as singleton objects, which means that there is just one instance of each.
We then have the BinaryExpression interface, which extends Expression. Four classes implement it, one for each of the basic arithmetic operations.
Most of the expressions have other nodes as children. A few have simple values instead: VarReference (which has a property varName of type String), and IntLit and DecLit (both have a property value of type String).
Finally we have the three classes implementing Statement.
Note that we are using data classes, so we get the hashCode, equals and toString methods for free. Kotlin also generates constructors and getters for us. Try to imagine how much code that would be in Java.
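To make the point about data classes and singleton objects concrete, here is a minimal sketch on a reduced copy of the metamodel. The helper functions buildSum and describeSum are hypothetical and exist only for this illustration.

```kotlin
// Reduced copy of the metamodel: just enough to show what data classes give us.
interface Node
interface Expression : Node
interface Type : Node

// A singleton object: there is exactly one IntType instance.
object IntType : Type

data class IntLit(val value: String) : Expression
data class SumExpression(val left: Expression, val right: Expression) : Expression
data class TypeConversion(val value: Expression, val targetType: Type) : Expression

// Hypothetical helper: every call builds a new, independent tree.
fun buildSum(): Expression = SumExpression(IntLit("1"), IntLit("2"))

// Hypothetical helper: relies on the toString generated by the compiler.
fun describeSum(): String = buildSum().toString()
```

Two calls to buildSum() return distinct instances that are nevertheless equal according to the generated equals, which is what lets a test compare a whole expected tree against the actual one with a single assertEquals.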
Mapping the parse tree to the abstract syntax tree
Let’s see how we can take the parse tree produced by ANTLR and map it onto our AST classes.
fun SandyFileContext.toAst() : SandyFile = SandyFile(this.line().map { it.statement().toAst() })

fun StatementContext.toAst() : Statement = when (this) {
    is VarDeclarationStatementContext -> VarDeclaration(varDeclaration().assignment().ID().text, varDeclaration().assignment().expression().toAst())
    is AssignmentStatementContext -> Assignment(assignment().ID().text, assignment().expression().toAst())
    is PrintStatementContext -> Print(print().expression().toAst())
    else -> throw UnsupportedOperationException(this.javaClass.canonicalName)
}

fun ExpressionContext.toAst() : Expression = when (this) {
    is BinaryOperationContext -> toAst()
    is IntLiteralContext -> IntLit(text)
    is DecimalLiteralContext -> DecLit(text)
    // The parentheses matter only while parsing: we map the inner expression directly
    is ParenExpressionContext -> expression().toAst()
    is VarReferenceContext -> VarReference(text)
    is TypeConversionContext -> TypeConversion(expression().toAst(), targetType.toAst())
    // Added so that the minusExpression alternative of the grammar maps to UnaryMinusExpression
    is MinusExpressionContext -> UnaryMinusExpression(expression().toAst())
    else -> throw UnsupportedOperationException(this.javaClass.canonicalName)
}

fun TypeContext.toAst() : Type = when (this) {
    is IntegerContext -> IntType
    is DecimalContext -> DecimalType
    else -> throw UnsupportedOperationException(this.javaClass.canonicalName)
}

// One parse tree node type (binaryOperation) becomes four distinct AST node types
fun BinaryOperationContext.toAst() : Expression = when (operator.text) {
    "+" -> SumExpression(left.toAst(), right.toAst())
    "-" -> SubtractionExpression(left.toAst(), right.toAst())
    "*" -> MultiplicationExpression(left.toAst(), right.toAst())
    "/" -> DivisionExpression(left.toAst(), right.toAst())
    else -> throw UnsupportedOperationException(this.javaClass.canonicalName)
}
To implement this we have taken advantage of three very useful features of Kotlin:
- extension methods: we added the method toAst to several existing classes
- the when construct, which is a more powerful version of switch
- smart casts: after we check that an object has a certain class the compiler implicitly casts it to that type, so that we can use the specific methods of that class
We could come up with a mechanism to derive this mapping automatically for most of the rules and just customize it where the parse tree and the AST differ. To avoid using too much reflection black magic we are not going to do that for now. If I were using Java I would go down the reflection road to avoid having to write manually a lot of redundant and boring code. However, using Kotlin this code is compact and clear.
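The three features can be seen in isolation in the following sketch, which does not depend on ANTLR. The classes ExprCtx, IntLiteralCtx, ParenCtx and SumCtx are hypothetical stand-ins for generated context classes, not part of the Sandy code.

```kotlin
// Hypothetical stand-ins for the context classes a parser generator would produce.
open class ExprCtx
class IntLiteralCtx(val text: String) : ExprCtx()
class ParenCtx(val inner: ExprCtx) : ExprCtx()
class SumCtx(val left: ExprCtx, val right: ExprCtx) : ExprCtx()

// A minimal AST with two node types.
interface Expression
data class IntLit(val value: String) : Expression
data class SumExpression(val left: Expression, val right: Expression) : Expression

// Extension method: toAst() is added to ExprCtx without touching its declaration.
fun ExprCtx.toAst(): Expression = when (this) {
    // Smart cast: in each branch `this` has the checked type, so `text`,
    // `inner`, `left` and `right` are accessible without an explicit cast.
    is IntLiteralCtx -> IntLit(text)
    is ParenCtx -> inner.toAst()
    is SumCtx -> SumExpression(left.toAst(), right.toAst())
    else -> throw UnsupportedOperationException(this.javaClass.canonicalName)
}
```

For instance, ParenCtx(SumCtx(IntLiteralCtx("1"), ParenCtx(IntLiteralCtx("2")))).toAst() yields SumExpression(IntLit("1"), IntLit("2")): the parentheses disappear, as they do in the real mapping.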
Testing the mapping
Of course we need to test this stuff. Let’s see if the AST we get for a certain piece of code is the one we expect.
import kotlin.test.assertEquals
// Assumption: JUnit 4, with the annotation aliased so that @test resolves
import org.junit.Test as test

class MappingTest {

    @test fun mapSimpleFile() {
        val code = """var a = 1 + 2
                     |a = 7 * (2 / 3)""".trimMargin("|")
        val ast = SandyParserFacade.parse(code).root!!.toAst()
        val expectedAst = SandyFile(listOf(
                VarDeclaration("a", SumExpression(IntLit("1"), IntLit("2"))),
                Assignment("a", MultiplicationExpression(
                        IntLit("7"),
                        DivisionExpression(
                                IntLit("2"),
                                IntLit("3"))))))
        assertEquals(expectedAst, ast)
    }

    @test fun mapCastInt() {
        val code = "a = 7 as Int"
        val ast = SandyParserFacade.parse(code).root!!.toAst()
        val expectedAst = SandyFile(listOf(Assignment("a", TypeConversion(IntLit("7"), IntType))))
        assertEquals(expectedAst, ast)
    }

    @test fun mapCastDecimal() {
        val code = "a = 7 as Decimal"
        val ast = SandyParserFacade.parse(code).root!!.toAst()
        val expectedAst = SandyFile(listOf(Assignment("a", TypeConversion(IntLit("7"), DecimalType))))
        assertEquals(expectedAst, ast)
    }

    @test fun mapPrint() {
        val code = "print(a)"
        val ast = SandyParserFacade.parse(code).root!!.toAst()
        val expectedAst = SandyFile(listOf(Print(VarReference("a"))))
        assertEquals(expectedAst, ast)
    }
}
Considering positions
So far so good: we have a clean model of the information present in the code, and the metamodel and the mapping code look simple and clear. However, we need to add a little detail: the position of the nodes in the source code. This is needed to show errors to the user. We want to be able to specify the positions of our AST nodes, but we do not want to be forced to do so. In this way, depending on the operations we need to perform, we can consider the positions or ignore them. Consider the tests we have written so far: wouldn’t it be cumbersome and annoying to have to specify fake positions for all the nodes? I think so.
This is the new Node definition, with a few supporting classes:
interface Node {
    val position: Position?
}

data class Point(val line: Int, val column: Int)

data class Position(val start: Point, val end: Point)

// Shorthand to build a Position from four numbers
fun pos(startLine:Int, startCol:Int, endLine:Int, endCol:Int) = Position(Point(startLine,startCol),Point(endLine,endCol))
We also need to add position as an optional parameter to all the classes, with null as the default value. For example, this is how SandyFile looks now:
data class SandyFile(val statements : List<Statement>, override val position: Position? = null) : Node
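The effect of the optional parameter can be sketched on two node types. This is a reduced, self-contained copy of the metamodel; the helpers sumWithoutPositions and sumWithPositions are hypothetical and the positions are illustrative values.

```kotlin
data class Point(val line: Int, val column: Int)
data class Position(val start: Point, val end: Point)

fun pos(startLine: Int, startCol: Int, endLine: Int, endCol: Int) =
    Position(Point(startLine, startCol), Point(endLine, endCol))

interface Node { val position: Position? }
interface Expression : Node

data class IntLit(val value: String, override val position: Position? = null) : Expression
data class SumExpression(
    val left: Expression,
    val right: Expression,
    override val position: Position? = null
) : Expression

// Hypothetical helper: positions omitted, so every node gets position = null.
fun sumWithoutPositions(): Expression = SumExpression(IntLit("1"), IntLit("2"))

// Hypothetical helper: the same tree with explicit (illustrative) positions.
fun sumWithPositions(): Expression = SumExpression(
    IntLit("1", pos(1, 8, 1, 9)),
    IntLit("2", pos(1, 12, 1, 13)),
    pos(1, 8, 1, 13)
)
```

Because position is declared in the primary constructor, it takes part in the generated equals: sumWithPositions() is not equal to sumWithoutPositions(). A test therefore has to be consistent, either ignoring positions on both the expected and the actual tree or specifying them on both.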
The mapping just got a bit more complicated:
fun SandyFileContext.toAst(considerPosition: Boolean = false) : SandyFile = SandyFile(this.line().map { it.statement().toAst(considerPosition) }, toPosition(considerPosition))

fun Token.startPoint() = Point(line, charPositionInLine)

fun Token.endPoint() = Point(line, charPositionInLine + text.length)

// A node spans from the start of its first token to the end of its last token
fun ParserRuleContext.toPosition(considerPosition: Boolean) : Position? {
    return if (considerPosition) Position(start.startPoint(), stop.endPoint()) else null
}

fun StatementContext.toAst(considerPosition: Boolean = false) : Statement = when (this) {
    is VarDeclarationStatementContext -> VarDeclaration(varDeclaration().assignment().ID().text, varDeclaration().assignment().expression().toAst(considerPosition), toPosition(considerPosition))
    is AssignmentStatementContext -> Assignment(assignment().ID().text, assignment().expression().toAst(considerPosition), toPosition(considerPosition))
    is PrintStatementContext -> Print(print().expression().toAst(considerPosition), toPosition(considerPosition))
    else -> throw UnsupportedOperationException(this.javaClass.canonicalName)
}

fun ExpressionContext.toAst(considerPosition: Boolean = false) : Expression = when (this) {
    is BinaryOperationContext -> toAst(considerPosition)
    is IntLiteralContext -> IntLit(text, toPosition(considerPosition))
    is DecimalLiteralContext -> DecLit(text, toPosition(considerPosition))
    is ParenExpressionContext -> expression().toAst(considerPosition)
    is VarReferenceContext -> VarReference(text, toPosition(considerPosition))
    is TypeConversionContext -> TypeConversion(expression().toAst(considerPosition), targetType.toAst(considerPosition), toPosition(considerPosition))
    // Added so that the minusExpression alternative of the grammar maps to UnaryMinusExpression
    is MinusExpressionContext -> UnaryMinusExpression(expression().toAst(considerPosition), toPosition(considerPosition))
    else -> throw UnsupportedOperationException(this.javaClass.canonicalName)
}

fun TypeContext.toAst(considerPosition: Boolean = false) : Type = when (this) {
    is IntegerContext -> IntType(toPosition(considerPosition))
    is DecimalContext -> DecimalType(toPosition(considerPosition))
    else -> throw UnsupportedOperationException(this.javaClass.canonicalName)
}

fun BinaryOperationContext.toAst(considerPosition: Boolean = false) : Expression = when (operator.text) {
    "+" -> SumExpression(left.toAst(considerPosition), right.toAst(considerPosition), toPosition(considerPosition))
    "-" -> SubtractionExpression(left.toAst(considerPosition), right.toAst(considerPosition), toPosition(considerPosition))
    "*" -> MultiplicationExpression(left.toAst(considerPosition), right.toAst(considerPosition), toPosition(considerPosition))
    "/" -> DivisionExpression(left.toAst(considerPosition), right.toAst(considerPosition), toPosition(considerPosition))
    else -> throw UnsupportedOperationException(this.javaClass.canonicalName)
}
At this point all the previous tests keep passing, but we want to add a test to verify that the positions are defined correctly:
@test fun mapSimpleFileWithPositions() {
    val code = """var a = 1 + 2
                 |a = 7 * (2 / 3)""".trimMargin("|")
    val ast = SandyParserFacade.parse(code).root!!.toAst(considerPosition = true)
    val expectedAst = SandyFile(listOf(
            VarDeclaration("a",
                    SumExpression(
                            IntLit("1", pos(1,8,1,9)),
                            IntLit("2", pos(1,12,1,13)),
                            pos(1,8,1,13)),
                    pos(1,0,1,13)),
            Assignment("a",
                    MultiplicationExpression(
                            IntLit("7", pos(2,4,2,5)),
                            DivisionExpression(
                                    IntLit("2", pos(2,9,2,10)),
                                    IntLit("3", pos(2,13,2,14)),
                                    pos(2,9,2,14)),
                            pos(2,4,2,15)),
                    pos(2,0,2,15))),
            pos(1,0,2,15))
    assertEquals(expectedAst, ast)
}
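To see where numbers such as pos(1,8,1,9) come from, here is a sketch that reproduces the start and end computation without ANTLR. SimpleToken, tokenize and positionOf are hypothetical names introduced for this illustration; SimpleToken mirrors the ANTLR Token convention of 1-based lines and 0-based columns, and the tokenizer simply splits on whitespace, which is enough for the first line of the test.

```kotlin
data class Point(val line: Int, val column: Int)
data class Position(val start: Point, val end: Point)

// Hypothetical stand-in for a token: `line` is 1-based, `charPositionInLine` is 0-based.
data class SimpleToken(val text: String, val line: Int, val charPositionInLine: Int)

fun SimpleToken.startPoint() = Point(line, charPositionInLine)

// Assumes the token does not span several lines.
fun SimpleToken.endPoint() = Point(line, charPositionInLine + text.length)

// A node spans from the start of its first token to the end of its last token.
fun positionOf(first: SimpleToken, last: SimpleToken) =
    Position(first.startPoint(), last.endPoint())

// Hypothetical tokenizer: splits one line on whitespace and records each column.
fun tokenize(lineText: String, line: Int): List<SimpleToken> =
    Regex("\\S+").findAll(lineText)
        .map { SimpleToken(it.value, line, it.range.first) }
        .toList()
```

With tokenize("var a = 1 + 2", 1) the token 1 starts at column 8 and has length 1, so it spans (1,8) to (1,9). The sum goes from the start of 1 to the end of 2, that is (1,8) to (1,13), and the whole declaration from var to 2 spans (1,0) to (1,13), matching the values expected in the test.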
Conclusions
The parse tree contains the information organized in the way most convenient for the parser. That is typically not the most convenient way for the steps which follow. Think about the variable declaration rule being implemented by reusing the assignment rule: sure, this makes the grammar shorter and it makes sense for the parse tree. However, from the logical point of view the two elements are separate, and in the AST they indeed are.
Most of the rest of our tools will operate on the AST so it is better to spend some time working on an AST that makes sense.
Reference: Building a compiler for your own language: from the parse tree to the Abstract Syntax Tree, from our JCG partner Federico Tomassetti at the Federico Tomassetti blog.