antlr4-python3-runtime for python3 による java source の parse / 構文解析 - end0tknr's kipple - web写経開発
以前に記載した上記 entry の修正版です。
1つの *.java ファイルには複数の class を定義できるため、複数 class に対応するよう以下のファイルを修正しました。
上記以外は前回の entry の通りです。修正後は、
$ /usr/local/python3/bin/python3 ast_analyze_executor.py path/to/src/hoge/BaseAction.java
のように、実行して下さい。
ast_analyze_executor.py
# -*- coding: utf-8 -*-
"""Command-line entry point: parse one Java source file and pretty-print
the information collected by BasicInfoListener.

Usage:
    python3 ast_analyze_executor.py path/to/src/hoge/BaseAction.java
"""
import os
import sys
# Make the sibling ``ast`` package importable from this script's directory.
# NOTE(review): the local package name ``ast`` collides with the standard-library
# module ``ast``; if the stdlib module has already been imported, the imports
# below would fail ("'ast' is not a package"). Consider renaming the package.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import logging.config
from ast.ast_processor import AstProcessor
from ast.basic_info_listener import BasicInfoListener
import yaml
import pprint

# Logging configuration file, resolved relative to the current working
# directory (the original author notes the logging setup is tentative).
log_conf = './log_conf.yaml'


def main():
    """Parse the Java file named by ``sys.argv[1]`` and print its AST info.

    Exits with a usage message (exit status 1) when no file argument is given.
    """
    if len(sys.argv) < 2:
        sys.exit('usage: %s path/to/Source.java'
                 % os.path.basename(sys.argv[0]))

    # ``with`` closes the config file handle; SafeLoader keeps the YAML from
    # constructing arbitrary Python objects.
    with open(log_conf, encoding='utf-8') as conf_file:
        logging.config.dictConfig(yaml.load(conf_file, Loader=yaml.SafeLoader))

    target_file_path = sys.argv[1]
    ast_info = AstProcessor(logging, BasicInfoListener()).execute(target_file_path)
    # Pretty-print, wrapping at 80 columns.
    print(pprint.pformat(ast_info, width=80))


if __name__ == "__main__":
    main()
ast/ast_processor.py
# -*- coding: utf-8 -*-
"""Run the ANTLR-generated Java lexer/parser over a file and return the
information gathered by a parse-tree listener."""
from antlr4 import FileStream, CommonTokenStream, ParseTreeWalker
from ast.JavaLexer import JavaLexer
from ast.JavaParser import JavaParser
import pprint

# Character encoding used to read the Java source file.
source_encode = "utf-8"


class AstProcessor:
    """Parse a Java source file and walk the resulting tree with a listener.

    Args:
        logging: the ``logging`` module (anything providing ``getLogger``).
        listener: parse-tree listener that exposes an ``ast_info`` dict after
            the walk (e.g. ``BasicInfoListener``).
    """

    def __init__(self, logging, listener):
        self.logging = logging
        self.logger = logging.getLogger(self.__class__.__name__)
        self.listener = listener

    def execute(self, input_source):
        """Parse ``input_source`` and return the listener's ``ast_info``.

        Every method dict under ``ast_info['classes'][*]['methods']`` receives
        a ``'body_src'`` entry: the source text of the tokens between its
        ``body_pos`` start and stop token indices.
        """
        chars = FileStream(input_source, encoding=source_encode)
        tokens = CommonTokenStream(JavaLexer(chars))
        tree = JavaParser(tokens).compilationUnit()
        ParseTreeWalker().walk(self.listener, tree)

        info = self.listener.ast_info
        # A rule context's getText() inside the listener yields text with the
        # whitespace and newlines removed, so the method body is rebuilt from
        # the token stream instead.
        for cls in info['classes']:
            for method in cls['methods']:
                body = method['body_pos']
                method['body_src'] = tokens.getText(body['start_index'],
                                                    body['stop_index'])
        return info
ast/basic_info_listener.py
# -*- coding: utf-8 -*-
"""ANTLR parse-tree listener that collects package, import, class, field and
method information from a Java compilation unit."""
from ast.JavaParserListener import JavaParserListener
from ast.JavaParser import JavaParser
import copy
import re
import sys
import pprint


class BasicInfoListener(JavaParserListener):
    """Collect basic structural information while walking a Java parse tree.

    After the walk the result is in ``self.ast_info``::

        {'package': {'name': ..., 'pos': {...}},
         'imports': [{'name': ..., 'pos': {...}}, ...],
         'classes': [{'name', 'annotation', 'modifier', 'implements',
                      'extends', 'fields', 'methods', 'pos'}, ...]}

    A class is appended to ``ast_info['classes']`` when its declaration is
    exited, so a nested class appears before the class that encloses it.
    Every class entry owns its own lists, and fields/methods are attributed
    to the innermost enclosing class declaration.
    """

    def __init__(self):
        self.ast_info = {'package': {},
                         'imports': [],
                         'classes': []}
        # Template for one class entry. It is deep-copied for every class so
        # that no two classes share their 'fields' / 'methods' / ... lists.
        self.class_base = {'name': '',
                           'annotation': [],
                           'modifier': [],
                           'implements': [],
                           'extends': '',
                           'fields': [],
                           'methods': []}
        # Class entry currently being filled; {} while outside any class.
        self.tmp_class = {}
        # Enclosing class entries of tmp_class (innermost last), for nesting.
        self._class_stack = []
        # Annotations / modifiers seen since the last consumed declaration.
        self.tmp_annotation = []
        self.tmp_modifier = []

    # ------------------------------------------------------------ helpers

    @staticmethod
    def _pos(ctx):
        """Return the start/stop line, column and token index of ``ctx``.

        ``stop_column`` is the column where the last token of ``ctx`` starts.
        """
        return {'start_line': ctx.start.line,
                'start_column': ctx.start.column,
                'start_index': ctx.start.tokenIndex,
                'stop_line': ctx.stop.line,
                'stop_column': ctx.stop.column,
                'stop_index': ctx.stop.tokenIndex}

    def _take_pending_modifiers(self):
        """Return the pending (annotations, modifiers) lists and reset them."""
        annotations, modifiers = self.tmp_annotation, self.tmp_modifier
        self.tmp_annotation = []
        self.tmp_modifier = []
        return annotations, modifiers

    # ------------------------------------------------ package / imports

    def enterPackageDeclaration(self, ctx):
        self.ast_info['package'] = {'name': ctx.qualifiedName().getText(),
                                    'pos': self._pos(ctx)}

    def enterImportDeclaration(self, ctx):
        self.ast_info['imports'].append(
            {'name': ctx.qualifiedName().getText(),
             'pos': self._pos(ctx)})

    # -------------------------------------------------------- modifiers

    # Modifiers are collected by enterClassOrInterfaceModifier and consumed by
    # the class/field/method handlers below. Declarations this listener does
    # not handle (constructors, interfaces, enums, interface members, ...)
    # never consume theirs, so the pending lists are reset at the start of
    # each declaration; otherwise e.g. a constructor's ``public`` would be
    # attached to the next field or method. The modifiers of a declaration are
    # its children, so they are entered after these hooks run. A hook whose
    # rule does not exist in the generated parser is simply never called.
    def enterTypeDeclaration(self, ctx):
        self._take_pending_modifiers()

    def enterClassBodyDeclaration(self, ctx):
        self._take_pending_modifiers()

    def enterInterfaceBodyDeclaration(self, ctx):
        self._take_pending_modifiers()

    def enterLocalTypeDeclaration(self, ctx):
        self._take_pending_modifiers()

    def enterClassOrInterfaceModifier(self, ctx):
        """Queue one modifier (``public`` ...) or annotation (``@Foo(...)``)."""
        tmp_name = ctx.getText()
        # Exclusive end column: start column of the last token plus its length.
        # (For a single-token modifier start and stop are the same token; for
        # annotations such as ``@Override`` the stop token is the name only.)
        stop_column = ctx.stop.column + len(ctx.stop.text)
        pos = self._pos(ctx)
        pos['stop_column'] = stop_column
        tmp_info = {'name': tmp_name, 'pos': pos}
        if tmp_name.startswith('@'):
            self.tmp_annotation.append(tmp_info)
        else:
            self.tmp_modifier.append(tmp_info)

    # ---------------------------------------------------------- classes

    def enterClassDeclaration(self, ctx):
        """Start a new class entry, remembering any enclosing class."""
        if self.tmp_class:
            self._class_stack.append(self.tmp_class)
        self.tmp_class = copy.deepcopy(self.class_base)
        annotations, modifiers = self._take_pending_modifiers()
        self.tmp_class['annotation'] = annotations
        self.tmp_class['modifier'] = modifiers
        self.tmp_class['pos'] = self._pos(ctx)

        # Children: 'class' name [typeParameters] ['extends' typeType]
        #           ['implements' typeList] classBody
        children = [ctx.getChild(i) for i in range(ctx.getChildCount())]
        self.tmp_class['name'] = children[1].getText()
        # Scan (keyword, operand) pairs. The last child (classBody) is never
        # used as a keyword, which avoids getText() over the whole body.
        for keyword, operand in zip(children[:-1], children[1:]):
            keyword_text = keyword.getText()
            if keyword_text == 'extends':
                self.tmp_class['extends'] = operand.getChild(0).getText()
            elif keyword_text == 'implements':
                self.tmp_class['implements'] = \
                    self.parse_implements_block(operand)

    def exitClassDeclaration(self, ctx):
        """Record the finished class and resume filling the enclosing one."""
        self.ast_info['classes'].append(self.tmp_class)
        self.tmp_class = self._class_stack.pop() if self._class_stack else {}

    # --------------------------------------------------- fields / methods

    def enterFieldDeclaration(self, ctx):
        """Record a field (``type declarators;``) on the current class."""
        annotations, modifiers = self._take_pending_modifiers()
        if not self.tmp_class:
            # Member of a type this listener does not model (e.g. a top-level
            # enum); there is no class entry to attach it to.
            return
        field = {'type': ctx.getChild(0).getText(),
                 'body_src': ctx.getChild(1).getText(),
                 'annotation': annotations,
                 'modifier': modifiers}
        self.tmp_class['fields'].append(field)

    def enterMethodDeclaration(self, ctx):
        """Record a method on the current class.

        Children: returnType name formalParameters ... methodBody. The body's
        token indices are stored in 'body_pos' so that the caller can rebuild
        the body text, whitespace included, from the token stream.
        """
        annotations, modifiers = self._take_pending_modifiers()
        if not self.tmp_class:
            # See enterFieldDeclaration: no class entry to attach it to.
            return
        ctx_method_body = ctx.getChild(-1)
        method_info = {'returnType': ctx.getChild(0).getText(),
                       'name': ctx.getChild(1).getText(),
                       'annotation': annotations,
                       'modifier': modifiers,
                       'params': self.parse_method_params_block(ctx.getChild(2)),
                       'pos': self._pos(ctx),
                       'body_pos': self._pos(ctx_method_body)}
        self.tmp_class['methods'].append(method_info)

    # ---------------------------------------------------------- parsers

    def parse_implements_block(self, ctx):
        """Return the type names of a typeList context (``A, B, C``).

        Children alternate type, ',', type, ... so types sit at even indices.
        """
        return [ctx.getChild(i).getText()
                for i in range(0, ctx.getChildCount(), 2)]

    def parse_method_params_block(self, ctx):
        """Return ``[{'paramType': ..., 'paramName': ...}, ...]``.

        ``ctx`` is a formalParameters context: ``'(' formalParameterList? ')'``,
        so it has three children only when parameters are present. Parameters
        in the list alternate with ',' separators (even indices).
        """
        if ctx.getChildCount() != 3:
            return []
        param_list = ctx.getChild(1)
        return [self._parse_param(param_list.getChild(i))
                for i in range(0, param_list.getChildCount(), 2)]

    @staticmethod
    def _parse_param(param_ctx):
        """Parse one ``[final|@Anno]* Type [...] name`` parameter context.

        Leading variable modifiers are skipped; for varargs ``'...'`` is
        appended to the type (``String... args`` -> ``'String...'``, ``'args'``).
        """
        texts = [param_ctx.getChild(i).getText()
                 for i in range(param_ctx.getChildCount())]
        param_name = texts[-1]
        type_index = 0
        # The last child is the identifier, so this loop always terminates.
        while texts[type_index] == 'final' or texts[type_index].startswith('@'):
            type_index += 1
        param_type = texts[type_index]
        if '...' in texts:
            param_type += '...'
        return {'paramType': param_type, 'paramName': param_name}