end0tknr's kipple - web写経開発

太宰府天満宮の狛犬って、妙にカワイイ

再 - antlr4-python3-runtime for python3 による java source の parse / 構文解析

antlr4-python3-runtime for python3 による java source の parse / 構文解析 - end0tknr's kipple - web写経開発

以前に記載した上記 entry の修正版です。

*.java には複数の class を定義できるため、以下のファイルを修正しました。

以下に示すファイル以外は前回の entry と同じですので、修正後は

$ /usr/local/python3/bin/python3 ast_analyze_executor.py \
    path/to/src/hoge/BaseAction.java

のように、実行して下さい。

ast_analyze_executor.py

# -*- coding: utf-8 -*-

import os
import sys
sys.path.append( os.path.dirname(__file__) )

import logging.config
from ast.ast_processor import AstProcessor
from ast.basic_info_listener import BasicInfoListener
import sys
import yaml
import pprint

# Resolved against the current working directory.
# (Author's note: not confident about this logging setup.)
log_conf = './log_conf.yaml'


def main():
    """Parse the Java file named on the command line and print its summary.

    Usage: python3 ast_analyze_executor.py path/to/Source.java

    Logging is configured from the YAML dictConfig file ``log_conf``; the
    dict returned by ``AstProcessor.execute`` is pretty-printed to stdout.
    Exits with a usage message when no source path is given.
    """
    # Check arguments first so a missing path gives a usage message
    # instead of an IndexError traceback.
    if len(sys.argv) < 2:
        sys.exit("usage: %s path/to/Source.java"
                 % os.path.basename(sys.argv[0]))
    target_file_path = sys.argv[1]

    # `with` closes the config file; safe_load(stream) is the same as
    # load(..., Loader=SafeLoader).
    with open(log_conf) as conf_file:
        logging.config.dictConfig(yaml.safe_load(conf_file))

    ast_info = \
        AstProcessor(logging, BasicInfoListener()).execute(target_file_path)

    print(pprint.pformat(ast_info, width=80))  # wrap output at 80 columns


if __name__ == "__main__":
    main()
    

ast/ast_processor.py

# -*- coding: utf-8 -*-
from antlr4 import FileStream, CommonTokenStream, ParseTreeWalker
from ast.JavaLexer import JavaLexer
from ast.JavaParser import JavaParser
import pprint

source_encode = "utf-8"  # encoding used when reading the .java source


class AstProcessor:
    """Parse one Java source file and collect information through a listener.

    The listener is walked over the ``compilationUnit`` parse tree, then its
    ``ast_info`` dict is returned with each method's ``body_src`` filled in.
    """

    def __init__(self, logging, listener):
        """Keep the collaborators used by :meth:`execute`.

        :param logging: the ``logging`` module (or any object providing
            ``getLogger``); a logger named after this class is created from it.
        :param listener: parse-tree listener exposing an ``ast_info`` dict.
        """
        self.logging = logging
        self.listener = listener
        self.logger = logging.getLogger(type(self).__name__)

    def execute(self, input_source):
        """Parse the file at *input_source* and return the listener's ast_info.

        Every dict in ``ast_info['classes'][i]['methods']`` must carry a
        ``'body_pos'`` with ``'start_index'``/``'stop_index'`` token indices;
        a ``'body_src'`` string is added to each of them.
        """
        lexer = JavaLexer(FileStream(input_source, encoding=source_encode))
        token_stream = CommonTokenStream(lexer)
        tree = JavaParser(token_stream).compilationUnit()
        ParseTreeWalker().walk(self.listener, tree)

        ast_info = self.listener.ast_info
        self._attach_method_sources(ast_info, token_stream)
        return ast_info

    @staticmethod
    def _attach_method_sources(ast_info, token_stream):
        """Fill each method's 'body_src' from *token_stream*.

        getText() on a parse-tree node (as the listener would call it) yields
        text without whitespace or newlines, so the body text is re-read from
        the token stream using the recorded token index range instead.
        """
        every_method = (method
                        for klass in ast_info['classes']
                        for method in klass['methods'])
        for method in every_method:
            body_pos = method['body_pos']
            method['body_src'] = token_stream.getText(
                body_pos['start_index'], body_pos['stop_index'])

ast/basic_info_listener.py

# -*- coding: utf-8 -*-
from ast.JavaParserListener import JavaParserListener
from ast.JavaParser import JavaParser
import copy
import re
import sys
import pprint

class BasicInfoListener(JavaParserListener):
    """Collect package, import, class, field and method summaries.

    Meant to be driven by ParseTreeWalker.  After the walk ``self.ast_info``
    holds::

        {'package': {'name', 'pos'},
         'imports': [{'name', 'pos'}, ...],
         'classes': [{'name', 'annotation', 'modifier', 'implements',
                      'extends', 'fields', 'methods', 'pos'}, ...]}

    A class is appended when its declaration is *exited*, so a nested class
    appears before the class enclosing it.

    Modifiers and annotations arrive as ``classOrInterfaceModifier`` nodes
    before the declaration they belong to.  They are buffered in
    ``tmp_annotation``/``tmp_modifier`` and handed to the next recorded
    class, field or method declaration.
    """

    def __init__(self):
        self.ast_info   = {'package'    : {},
                           'imports'    : [],
                           'classes'    : [] }
        # Template for one class entry.  It is always deep-copied: a shallow
        # copy would make every class share one 'fields'/'methods' list.
        self.class_base = {'name'       : '',
                           'annotation' : [],
                           'modifier'   : [],
                           'implements' : [],
                           'extends'    : '',
                           'fields'     : [],
                           'methods'    : [] }
        # Class currently being populated ({} when outside any class).
        self.tmp_class = {}
        # Open class declarations, innermost last, so that members following
        # a nested class go back to the enclosing class.
        self._class_stack = []

        # Modifiers/annotations waiting for the next declaration.
        self.tmp_annotation = []
        self.tmp_modifier   = []

    # ------------------------------------------------------------ helpers

    @staticmethod
    def _make_pos(ctx, stop_column=None):
        """Return the position dict for parse-tree node *ctx*.

        ``stop_column`` defaults to the column at which ctx's last token
        starts; pass a value to override it.
        """
        return {'start_line'  : ctx.start.line,
                'start_column': ctx.start.column,
                'start_index' : ctx.start.tokenIndex,
                'stop_line'   : ctx.stop.line,
                'stop_column' : (ctx.stop.column if stop_column is None
                                 else stop_column),
                'stop_index'  : ctx.stop.tokenIndex}

    def _take_pending(self):
        """Return (annotations, modifiers) buffered so far and reset both."""
        pending = (self.tmp_annotation, self.tmp_modifier)
        self.tmp_annotation = []
        self.tmp_modifier = []
        return pending

    # Declarations this listener does not record (constructors, interfaces,
    # enums, ...) never consume their buffered modifiers, which would then be
    # attached to the next field or method.  Each rule below starts a fresh
    # "modifiers + declaration" sequence, so anything still pending is
    # discarded there.
    def enterTypeDeclaration(self, ctx):
        self._take_pending()

    def enterClassBodyDeclaration(self, ctx):
        self._take_pending()

    def enterInterfaceBodyDeclaration(self, ctx):
        self._take_pending()

    def enterAnnotationTypeElementDeclaration(self, ctx):
        self._take_pending()

    def enterLocalTypeDeclaration(self, ctx):
        self._take_pending()

    # ------------------------------------------------------------ listeners

    def enterPackageDeclaration(self, ctx):
        self.ast_info['package'] = {
            'name' : ctx.qualifiedName().getText(),
            'pos'  : self._make_pos(ctx)}

    def enterImportDeclaration(self, ctx):
        # NOTE(review): for on-demand imports (`import a.b.*;`) qualifiedName
        # presumably excludes the trailing '.*' — confirm if that matters.
        self.ast_info['imports'].append(
            {'name' : ctx.qualifiedName().getText(),
             'pos'  : self._make_pos(ctx)})

    def enterClassOrInterfaceModifier(self, ctx):
        tmp_name = ctx.getText()
        # ctx.stop.column is where the *last token* starts (for a one-token
        # modifier it equals start.column).  Adding that token's length gives
        # the column just past the modifier, also for multi-token
        # annotations such as @SuppressWarnings("x").
        stop_column = ctx.stop.column + len(ctx.stop.text)

        tmp_info = {'name' : tmp_name,
                    'pos'  : self._make_pos(ctx, stop_column)}

        if tmp_name.startswith('@'):
            self.tmp_annotation.append(tmp_info)
        else:
            self.tmp_modifier.append(tmp_info)

    def enterClassDeclaration(self, ctx):
        new_class = copy.deepcopy(self.class_base)
        new_class['annotation'], new_class['modifier'] = self._take_pending()
        new_class['pos'] = self._make_pos(ctx)

        # Children: 'class' <name> [type parameters] ['extends' <type>]
        # ['implements' <type list>] <class body>.  Scanning for the keywords
        # (instead of switching on the child count) also handles generic
        # classes such as `class Foo<T> extends Bar`.
        children = [ctx.getChild(i) for i in range(ctx.getChildCount())]
        new_class['name'] = children[1].getText()
        for keyword, following in zip(children, children[1:]):
            text = keyword.getText()
            if text == 'extends':
                new_class['extends'] = following.getChild(0).getText()
            elif text == 'implements':
                new_class['implements'] = \
                    self.parse_implements_block(following)

        self._class_stack.append(new_class)
        self.tmp_class = new_class

    def exitClassDeclaration(self, ctx):
        finished = self._class_stack.pop()
        self.ast_info['classes'].append(finished)
        # Resume the enclosing class, if any.
        self.tmp_class = self._class_stack[-1] if self._class_stack else {}

    def enterFieldDeclaration(self, ctx):
        annotation, modifier = self._take_pending()
        if not self.tmp_class:
            # Not inside any class declaration (e.g. a top-level enum):
            # there is no class entry to attach the field to.
            return
        self.tmp_class['fields'].append(
            {'type'      : ctx.getChild(0).getText(),
             'body_src'  : ctx.getChild(1).getText(),
             'annotation': annotation,
             'modifier'  : modifier})

    def enterMethodDeclaration(self, ctx):
        annotation, modifier = self._take_pending()
        if not self.tmp_class:
            # Not inside any class declaration (e.g. a top-level enum).
            return

        method_info = {
            'returnType': ctx.getChild(0).getText(),
            'name'      : ctx.getChild(1).getText(),
            'annotation': annotation,
            'modifier'  : modifier,
            'params'    : self.parse_method_params_block(ctx.getChild(2)),
            'pos'       : self._make_pos(ctx),
            # The body is the last child; its token range is kept so the
            # caller can recover the source text (with whitespace) from the
            # token stream.
            'body_pos'  : self._make_pos(ctx.getChild(-1))}

        self.tmp_class['methods'].append(method_info)

    def parse_implements_block(self, ctx):
        """Return the type names of a type-list node.

        Even-indexed children are the types; odd ones are ',' separators.
        """
        return [ctx.getChild(i).getText()
                for i in range(0, ctx.getChildCount(), 2)]

    def parse_method_params_block(self, ctx):
        """Return [{'paramType', 'paramName'}, ...] for a formalParameters node.

        Only the three-child form '(' <parameter list> ')' carries
        parameters.  Type and name are taken through the typeType() /
        variableDeclaratorId() accessors, so leading `final` or annotations
        are skipped.  A varargs parameter gets '...' appended to its type.
        """
        if ctx.getChildCount() != 3:
            return []

        param_list = ctx.getChild(1)
        result = []
        # Even-indexed children are parameters; odd ones are ','.
        for i in range(0, param_list.getChildCount(), 2):
            param = param_list.getChild(i)
            param_type = param.typeType().getText()
            if any(param.getChild(j).getText() == '...'
                   for j in range(param.getChildCount())):
                param_type += '...'
            result.append({'paramType': param_type,
                           'paramName': param.variableDeclaratorId().getText()})
        return result