Again: java source parse / syntax analysis with antlr4-python3-runtime for python3 - end0tknr's kipple - web写経開発
This is a revision of the entry above, which I posted just earlier.
A plain parse / syntax analysis drops the comments, so ast.ast_processor.py collects the COMMENT and LINE_COMMENT tokens, somewhat forcibly(?).
See ast.ast_processor.py for the details.
▲ For sources longer than 1,000 lines, parsing sometimes seems to fail.
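I have not pinned down the cause, but if the failure is a RecursionError from the deeply recursive antlr4 Python runtime, a commonly suggested workaround is to raise the interpreter's recursion limit and run the parse in a thread with a larger stack. A minimal, untested sketch (run_parse is a hypothetical wrapper around AstProcessor.execute(), and 'Sample.java' is a placeholder):

import sys
import threading
import logging
from ast.ast_processor import AstProcessor
from ast.basic_info_listener import BasicInfoListener

sys.setrecursionlimit(10 ** 5)   # default is about 1,000
threading.stack_size(0x4000000)  # 64MB stack for threads created after this

def run_parse(java_path, results):
    # same call as in ast_analyze_executor.py, just inside a big-stack thread
    results[java_path] = \
        AstProcessor(logging, BasicInfoListener()).execute(java_path)

results = {}
t = threading.Thread(target=run_parse, args=('Sample.java', results))
t.start()
t.join()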
ast_analyze_executor.py
# -*- coding: utf-8 -*-
import os
import sys
sys.path.append( os.path.dirname(__file__) )
import glob
import logging.config
from ast.ast_processor import AstProcessor
from ast.basic_info_listener import BasicInfoListener
import re
import yaml
import pprint

log_conf = './log_conf.yaml'   # even I find this log config dubious
line_feed_str = "\r\n"

def main():
    logging.config.dictConfig(
        yaml.load(open(log_conf).read(), Loader=yaml.SafeLoader))
    logger = logging.getLogger('mainLogger')

    java_base_dir = sys.argv[1]
    if os.path.isdir(java_base_dir):
        java_paths = glob.glob(os.path.join(java_base_dir, '**/*.java'),
                               recursive=True)
    else:
        java_paths = [java_base_dir]

    src_infos = {}
    for java_path in sorted(java_paths):
        print(java_path)
        try:
            src_info_and_comments = \
                AstProcessor(logging, BasicInfoListener()).execute(java_path)
        except Exception:
            print("ERROR fail AstProcessor.execute()", java_path)
            continue

        src_info = src_info_and_comments[0]
        comments = src_info_and_comments[1]
        # merge comments on consecutive lines
        comments = merge_comments(comments, src_info)
        # attach each comment to the nearby source element(?)
        attach_comments_to_src(comments, src_info)
        src_infos[java_path] = src_info

    for java_path in src_infos:
        src_info = src_infos[java_path]
        print(pprint.pformat(src_info, width=80))

def attach_comments_to_src(comments, src_info):
    for comment in comments:
        for offset in [1, 2]:
            start_line = comment["pos"]["start_line"]
            stop_line  = comment["pos"]["stop_line"] + offset
            found_src = find_src_by_line_no_range([start_line, stop_line],
                                                  src_info)
            if not found_src:
                continue
            # in this case body_src already contains the comment, so skip
            if found_src["pos"]["start_line"] < start_line and \
               stop_line < found_src["pos"]["stop_line"]:
                continue
            found_src["comment"] = comment
            break

def merge_comments(comments, src_info):
    ret_comments = []
    org_comments_size = len(comments)
    i = 0
    while i + 1 < org_comments_size:
        comment_0 = comments[i]
        comment_1 = comments[i + 1]
        merge_result = merge_comments_sub(comment_0, comment_1, src_info)
        if len(merge_result) == 2:
            ret_comments.append(merge_result[0])
        i += 1
    # the pairwise loop never appends the final comment, so add it here
    if org_comments_size > 0:
        ret_comments.append(comments[-1])
    return ret_comments

def merge_comments_sub(comment_0, comment_1, src_info):
    if comment_0["pos"]["stop_line"] + 1 != comment_1["pos"]["start_line"]:
        return [comment_0, comment_1]

    line_nos = [comment_0["pos"]["start_line"],
                comment_1["pos"]["stop_line"]]
    # if actual source lies within the comment range, do not merge
    found_src = find_src_by_line_no_range(line_nos, src_info)
    if found_src:
        return [comment_0, comment_1]

    comment_1["text"] = comment_0["text"] + line_feed_str + comment_1["text"]
    comment_1["pos"]["start_line"]   = comment_0["pos"]["start_line"]
    comment_1["pos"]["start_column"] = comment_0["pos"]["start_column"]
    return [comment_1]

def find_src_by_line_no_range(line_nos, src_info):
    if type(src_info) is dict:
        for atri_key in src_info:
            atri_val = src_info[atri_key]
            if atri_key == "pos":
                if src_info[atri_key]["start_line"] <= line_nos[0] and \
                   line_nos[0] <= src_info[atri_key]["stop_line"]:
                    return src_info
                elif src_info[atri_key]["start_line"] <= line_nos[1] and \
                     line_nos[1] <= src_info[atri_key]["stop_line"]:
                    return src_info
            if type(atri_val) is list or type(atri_val) is dict:
                found_src = find_src_by_line_no_range(line_nos, atri_val)
                if found_src:
                    return found_src
    elif type(src_info) is list:
        for atri_val in src_info:
            if type(atri_val) is list or type(atri_val) is dict:
                found_src = find_src_by_line_no_range(line_nos, atri_val)
                if found_src:
                    return found_src
    return None

if __name__ == "__main__":
    main()
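For reference, ast_analyze_executor.py expects a ./log_conf.yaml holding a logging.dictConfig configuration that defines mainLogger. The actual file is not shown in this entry; a minimal sketch that would satisfy the code above (everything except the mainLogger name is my own placeholder):

version: 1
formatters:
  simple:
    format: '%(asctime)s %(name)s %(levelname)s %(message)s'
handlers:
  console:
    class: logging.StreamHandler
    formatter: simple
loggers:
  mainLogger:
    level: INFO
    handlers: [console]
    propagate: no
root:
  level: WARNING
  handlers: [console]

The script is then run with a single .java file or a directory to scan recursively, e.g. python ast_analyze_executor.py /path/to/java/src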
ast.ast_processor.py
# -*- coding: utf-8 -*-
from antlr4 import FileStream, CommonTokenStream, ParseTreeWalker
from ast.JavaLexer import JavaLexer
from ast.JavaParser import JavaParser
import copy
import pprint
import unicodedata

source_encode = "utf-8"

class AstProcessor:

    def __init__(self, logging, listener):
        self.logging = logging
        self.logger = logging.getLogger(self.__class__.__name__)
        self.listener = listener

    def execute(self, input_source):
        file_stream = FileStream(input_source, encoding=source_encode)
        java_lexer = JavaLexer(file_stream)
        # comments are only retrievable before the lexer is handed
        # to CommonTokenStream
        comments = self.extract_comments(java_lexer)

        # calling java_lexer.getAllTokens() once seems to break the
        # lexer for some reason, so rebuild it from a fresh FileStream
        file_stream = FileStream(input_source, encoding=source_encode)
        java_lexer = JavaLexer(file_stream)
        common_token_stream = CommonTokenStream(java_lexer)
        parser = JavaParser(common_token_stream)

        walker = ParseTreeWalker()
        walker.walk(self.listener, parser.compilationUnit())
        ast_info = self.listener.ast_info

        # getText() inside BasicInfoListener returns text stripped of
        # whitespace and newlines, so fetch the method body text from
        # CommonTokenStream here instead
        for tmp_class in ast_info['classes']:
            for method in tmp_class['methods']:
                start_index = method['body_pos']['start_index']
                stop_index  = method['body_pos']['stop_index']
                method['body_src'] = \
                    common_token_stream.getText(start_index, stop_index)

        return [ast_info, comments]

    def extract_comments(self, java_lexer):
        tmp_tokens = java_lexer.getAllTokens()
        comments = []
        for tmp_token in tmp_tokens:
            # LINE_COMMENT=110-1, IDENTIFIER=111-1 in JavaLexer.py
            if tmp_token.type not in [109, 110]:
                continue

            comment_text = tmp_token.text
            comment_lines = comment_text.splitlines()
            stop_line = tmp_token.line + len(comment_lines) - 1
            stop_column = len(comment_lines[-1])
            if tmp_token.line == stop_line:
                stop_column += tmp_token.column

            comments.append(
                {"text": comment_text,
                 "pos" : {"start_line"  : tmp_token.line,
                          "start_column": tmp_token.column,
                          "stop_line"   : stop_line,
                          "stop_column" : stop_column}})
        return comments
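Incidentally, the token types 109 / 110 above are magic numbers tied to the generated JavaLexer.py. Assuming the generated lexer exposes the symbolic constants JavaLexer.COMMENT and JavaLexer.LINE_COMMENT, as ANTLR-generated lexers normally do, the filter can be written without them; a small sketch:

from antlr4 import FileStream
from ast.JavaLexer import JavaLexer

def extract_comment_tokens(java_path):
    # as noted above, getAllTokens() must be run on a fresh lexer
    lexer = JavaLexer(FileStream(java_path, encoding='utf-8'))
    return [t for t in lexer.getAllTokens()
            if t.type in (JavaLexer.COMMENT, JavaLexer.LINE_COMMENT)]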
ast.basic_info_listener.py
# -*- coding: utf-8 -*-
from ast.JavaParserListener import JavaParserListener
from ast.JavaParser import JavaParser
import copy
import re
import sys
import pprint

class BasicInfoListener(JavaParserListener):

    def __init__(self):
        self.ast_info = {'package': {},
                         'imports': [],
                         'classes': []}
        self.class_base = {'name'       : '',
                           'annotation' : [],
                           'modifier'   : {},
                           'implements' : [],
                           'extends'    : '',
                           'fields'     : [],
                           'methods'    : []}
        self.tmp_class = {}
        self.tmp_annotation = []
        self.tmp_modifier = []

    def enterPackageDeclaration(self, ctx):
        self.ast_info['package'] = {
            'name': ctx.qualifiedName().getText(),
            'pos' : {'start_line'  : ctx.start.line,
                     'start_column': ctx.start.column,
                     'start_index' : ctx.start.tokenIndex,
                     'stop_line'   : ctx.stop.line,
                     'stop_column' : ctx.stop.column,
                     'stop_index'  : ctx.stop.tokenIndex}}

    def enterImportDeclaration(self, ctx):
        self.ast_info['imports'].append(
            {'name': ctx.qualifiedName().getText(),
             'pos' : {'start_line'  : ctx.start.line,
                      'start_column': ctx.start.column,
                      'start_index' : ctx.start.tokenIndex,
                      'stop_line'   : ctx.stop.line,
                      'stop_column' : ctx.stop.column,
                      'stop_index'  : ctx.stop.tokenIndex}})

    def enterClassOrInterfaceModifier(self, ctx):
        tmp_name = ctx.getText()
        # ctx.start.column equals ctx.stop.column for some reason,
        # so derive the stop column from the text length
        stop_column = ctx.stop.column + len(tmp_name)
        tmp_info = {'name': tmp_name,
                    'pos' : {'start_line'  : ctx.start.line,
                             'start_column': ctx.start.column,
                             'start_index' : ctx.start.tokenIndex,
                             'stop_line'   : ctx.stop.line,
                             'stop_column' : stop_column,
                             'stop_index'  : ctx.stop.tokenIndex}}
        if re.match('^@', tmp_info['name']):
            self.tmp_annotation.append(tmp_info)
        else:
            self.tmp_modifier.append(tmp_info)

    def enterClassDeclaration(self, ctx):
        # deepcopy, so the fields/methods lists are not shared
        # between classes
        self.tmp_class = copy.deepcopy(self.class_base)
        self.tmp_class['annotation'] = self.tmp_annotation
        self.tmp_class['modifier'] = self.tmp_modifier
        self.tmp_annotation = []
        self.tmp_modifier = []
        self.tmp_class['pos'] = {'start_line'  : ctx.start.line,
                                 'start_column': ctx.start.column,
                                 'start_index' : ctx.start.tokenIndex,
                                 'stop_line'   : ctx.stop.line,
                                 'stop_column' : ctx.stop.column,
                                 'stop_index'  : ctx.stop.tokenIndex}

        child_count = int(ctx.getChildCount())
        if child_count == 7:
            # ex. class Foo extends Bar implements Baz { ... }
            c1 = ctx.getChild(0)                        # 'class'
            c2 = ctx.getChild(1).getText()              # class name
            c3 = ctx.getChild(2)                        # 'extends'
            c4 = ctx.getChild(3).getChild(0).getText()  # superclass name
            c5 = ctx.getChild(4)                        # 'implements'
            c7 = ctx.getChild(6)                        # class body
            self.tmp_class['name'] = c2
            self.tmp_class['extends'] = c4
            self.tmp_class['implements'] = \
                self.parse_implements_block(ctx.getChild(5))
            return
        if child_count == 5:
            # ex. class Foo extends Bar { ... }
            #     class Foo implements Baz { ... }
            c1 = ctx.getChild(0)            # 'class'
            c2 = ctx.getChild(1).getText()  # class name
            c3 = ctx.getChild(2).getText()  # 'extends' or 'implements'
            c5 = ctx.getChild(4)            # class body
            self.tmp_class['name'] = c2
            if c3 == 'implements':
                self.tmp_class['implements'] = \
                    self.parse_implements_block(ctx.getChild(3))
            elif c3 == 'extends':
                c4 = ctx.getChild(3).getChild(0).getText()
                self.tmp_class['extends'] = c4
            return
        if child_count == 3:
            # ex. class Foo { ... }
            c1 = ctx.getChild(0)            # 'class'
            c2 = ctx.getChild(1).getText()  # class name
            c3 = ctx.getChild(2)            # class body
            self.tmp_class['name'] = c2
            return

        print("ERROR unknown child_count " + str(child_count))
        sys.exit()

    def exitClassDeclaration(self, ctx):
        self.ast_info['classes'].append(copy.copy(self.tmp_class))

    def enterFieldDeclaration(self, ctx):
        field = {'type'      : ctx.getChild(0).getText(),
                 'body_src'  : ctx.getChild(1).getText(),
                 'annotation': [],
                 'modifier'  : []}
        field['annotation'] = copy.copy(self.tmp_annotation)
        field['modifier']   = copy.copy(self.tmp_modifier)
        self.tmp_annotation = []
        self.tmp_modifier = []
        self.tmp_class['fields'].append(field)

    def enterMethodDeclaration(self, ctx):
        c1 = ctx.getChild(0).getText()  # return type
        c2 = ctx.getChild(1).getText()  # method name
        # params
        params = self.parse_method_params_block(ctx.getChild(2))
        # keep the body position, so AstProcessor can fetch the body
        # text from CommonTokenStream by tokenIndex
        ctx_method_body = ctx.getChild(-1)

        method_info = {
            'returnType': c1,
            'name'      : c2,
            'annotation': [],
            'modifier'  : [],
            'params'    : params,
            'pos'     : {'start_line'  : ctx.start.line,
                         'start_column': ctx.start.column,
                         'start_index' : ctx.start.tokenIndex,
                         'stop_line'   : ctx.stop.line,
                         'stop_column' : ctx.stop.column,
                         'stop_index'  : ctx.stop.tokenIndex},
            'body_pos': {'start_line'  : ctx_method_body.start.line,
                         'start_column': ctx_method_body.start.column,
                         'start_index' : ctx_method_body.start.tokenIndex,
                         'stop_line'   : ctx_method_body.stop.line,
                         'stop_column' : ctx_method_body.stop.column,
                         'stop_index'  : ctx_method_body.stop.tokenIndex}}
        method_info['annotation'] = self.tmp_annotation
        method_info['modifier'] = self.tmp_modifier
        self.tmp_annotation = []
        self.tmp_modifier = []
        self.tmp_class['methods'].append(method_info)

    def parse_implements_block(self, ctx):
        implements_child_count = int(ctx.getChildCount())
        result = []
        if implements_child_count == 1:
            impl_class = ctx.getChild(0).getText()
            result.append(impl_class)
        elif implements_child_count > 1:
            # children alternate between type names and ',' tokens
            for i in range(implements_child_count):
                if i % 2 == 0:
                    impl_class = ctx.getChild(i).getText()
                    result.append(impl_class)
        return result

    def parse_method_params_block(self, ctx):
        # 3 children means '(' formalParameterList ')'
        params_exist_check = int(ctx.getChildCount())
        result = []
        if params_exist_check == 3:
            params_child_count = int(ctx.getChild(1).getChildCount())
            if params_child_count == 1:
                param_type = ctx.getChild(1).getChild(0).getChild(0).getText()
                param_name = ctx.getChild(1).getChild(0).getChild(1).getText()
                result.append({'paramType': param_type,
                               'paramName': param_name})
            elif params_child_count > 1:
                # parameters alternate with ',' tokens
                for i in range(params_child_count):
                    if i % 2 == 0:
                        param_type = \
                            ctx.getChild(1).getChild(i).getChild(0).getText()
                        param_name = \
                            ctx.getChild(1).getChild(i).getChild(1).getText()
                        result.append({'paramType': param_type,
                                       'paramName': param_name})
        return result
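Finally, a minimal standalone sketch of how BasicInfoListener is driven; this is essentially what AstProcessor.execute() already does, minus the comment handling (dump_ast_info and 'Sample.java' are names of my own):

from antlr4 import FileStream, CommonTokenStream, ParseTreeWalker
from ast.JavaLexer import JavaLexer
from ast.JavaParser import JavaParser
from ast.basic_info_listener import BasicInfoListener
import pprint

def dump_ast_info(java_path):
    lexer  = JavaLexer(FileStream(java_path, encoding='utf-8'))
    parser = JavaParser(CommonTokenStream(lexer))
    listener = BasicInfoListener()
    # walk the parse tree; the enter* handlers above fill ast_info
    ParseTreeWalker().walk(listener, parser.compilationUnit())
    pprint.pprint(listener.ast_info)   # package / imports / classes

dump_ast_info('Sample.java')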