diff --git a/.gitignore b/.gitignore index 66cfded..cd3b44c 100644 --- a/.gitignore +++ b/.gitignore @@ -196,3 +196,4 @@ Data/WikiEvents/*.json* dueefin_*.json PTPCG*.json Data/ChFinAnn/types/*.json +Data/CCKS2020/*.json diff --git a/Data/CCKS2020/build_data.py b/Data/CCKS2020/build_data.py new file mode 100644 index 0000000..32d4e98 --- /dev/null +++ b/Data/CCKS2020/build_data.py @@ -0,0 +1,12 @@ +from dee.event_types import get_doc_type +from dee.utils import default_dump_json, default_load_json + +if __name__ == "__main__": + for dname in ["train", "dev", "test"]: + data = default_load_json(f"Data/CCKS2020/{dname}.json") + new_data = [] + for doc_id, content in data: + doc_type = get_doc_type(content["recguid_eventname_eventdict_list"]) + content["doc_type"] = doc_type + new_data.append([doc_id, content]) + default_dump_json(new_data, f"Data/CCKS2020/typed_{dname}.json") diff --git a/Data/WikiEvents/build_data.py b/Data/WikiEvents/build_data.py index f29fb76..619e590 100644 --- a/Data/WikiEvents/build_data.py +++ b/Data/WikiEvents/build_data.py @@ -1,6 +1,7 @@ import json from collections import defaultdict +from dee.event_types import get_doc_type from dee.utils import default_dump_json, default_load_json @@ -39,20 +40,6 @@ def extract_event_template(input_filepaths: list, output_filepath: str): default_dump_json(dict(event_type_to_roles), output_filepath) -def get_doc_type(recguid_eventname_eventdict_list): - doc_type = "unk" - num_ins = len(recguid_eventname_eventdict_list) - if num_ins == 1: - doc_type = "o2o" - else: - event_types = {x[1] for x in recguid_eventname_eventdict_list} - if len(event_types) == num_ins: - doc_type = "o2m" - else: - doc_type = "m2m" - return doc_type - - def get_string_from_absolute_index( sentence_tokens, sent_idx, start, end, sent_lens=None ): diff --git a/Data/trigger.py b/Data/trigger.py index 46d3d3a..c1af581 100644 --- a/Data/trigger.py +++ b/Data/trigger.py @@ -175,7 +175,8 @@ def auto_select( # data = load_json("typed_test.json") tot_data = [] for dname in ["train", "dev", "test"]: - tot_data += load_json(f"Data/WikiEvents/{dname}.post.wTgg.json") + # tot_data += load_json(f"Data/WikiEvents/{dname}.post.wTgg.json") + tot_data += load_json(f"Data/CCKS2020/{dname}.post.json") # check_trigger(data, num_trigger_group=num_trigger_group) auto_select( diff --git a/dee/event_types/__init__.py b/dee/event_types/__init__.py index 0ff66ca..3ffc941 100644 --- a/dee/event_types/__init__.py +++ b/dee/event_types/__init__.py @@ -1,5 +1,7 @@ import importlib import os +import re +from collections import defaultdict __current_dir = os.listdir(os.path.dirname(__file__)) AVAILABLE_TEMPLATES = list( @@ -16,6 +18,242 @@ def get_event_template(template_name): return template +def get_doc_type(recguid_eventname_eventdict_list): + doc_type = "unk" + num_ins = len(recguid_eventname_eventdict_list) + if num_ins == 0: + doc_type = "unk" + elif num_ins == 1: + doc_type = "o2o" + else: + event_types = {x[1] for x in recguid_eventname_eventdict_list} + if len(event_types) == 1: + doc_type = "o2m" + else: + doc_type = "m2m" + return doc_type + + +def get_schema_from_chfinann(data): + event_type_to_roles = defaultdict(set) + for _, content in data: + for _, etype, role_arg_pairs in content["recguid_eventname_eventdict_list"]: + event_type_to_roles[etype].update(role_arg_pairs.keys()) + return dict(event_type_to_roles) + + +def generate_event_template_from_trigger_string( + trigger_string: str, event_type_to_class_name: dict +): + """ + Args: + trigger_string: string generated by Data/trigger.py + 破产清算 = { + 1: ['公司名称'], # importance: 0.9950641658440277 + 2: ['公司名称', '公告时间'], # importance: 0.9990128331688055 + 3: ['公司名称', '公告时间', '受理法院'], # importance: 1.0 + 4: ['公司名称', '公司行业', '公告时间', '受理法院'], # importance: 1.0 + 5: ['公司名称', '公司行业', '公告时间', '受理法院', '裁定时间'], # importance: 1.0 + } + TRIGGERS['all'] = ['公司名称', '公告时间', '受理法院', '裁定时间', '公司行业'] + event_type_to_class_name: {"破产清算": "Liquidation"} + """ + string_template = """class {class_name}Event(BaseEvent):\n NAME = "{event_type}"\n FIELDS = {fields}\n TRIGGERS = {triggers}\n TRIGGERS["all"] = {triggers_all}\n def __init__(self, recguid=None):\n super().__init__(self.FIELDS, event_name=self.NAME, recguid=recguid)\n self.set_key_fields(self.TRIGGERS)\n\n""" + final_string = r"""class BaseEvent(object): + def __init__(self, fields, event_name="Event", key_fields=(), recguid=None): + self.recguid = recguid + self.name = event_name + self.fields = list(fields) + self.field2content = {f: None for f in fields} + self.nonempty_count = 0 + self.nonempty_ratio = self.nonempty_count / len(self.fields) + + self.key_fields = set(key_fields) + for key_field in self.key_fields: + assert key_field in self.field2content + + def __repr__(self): + event_str = "\n{}[\n".format(self.name) + event_str += " {}={}\n".format("recguid", self.recguid) + event_str += " {}={}\n".format("nonempty_count", self.nonempty_count) + event_str += " {}={:.3f}\n".format("nonempty_ratio", self.nonempty_ratio) + event_str += "] (\n" + for field in self.fields: + if field in self.key_fields: + key_str = " (key)" + else: + key_str = "" + event_str += ( + " " + + field + + "=" + + str(self.field2content[field]) + + ", {}\n".format(key_str) + ) + event_str += ")\n" + return event_str + + def update_by_dict(self, field2text, recguid=None): + self.nonempty_count = 0 + self.recguid = recguid + + for field in self.fields: + if field in field2text and field2text[field] is not None: + self.nonempty_count += 1 + self.field2content[field] = field2text[field] + else: + self.field2content[field] = None + + self.nonempty_ratio = self.nonempty_count / len(self.fields) + + def field_to_dict(self): + return dict(self.field2content) + + def set_key_fields(self, key_fields): + self.key_fields = set(key_fields) + + def is_key_complete(self): + for key_field in self.key_fields: + if self.field2content[key_field] is None: + return False + + return True + + def get_argument_tuple(self): + args_tuple = tuple(self.field2content[field] for field in self.fields) + return args_tuple + + def is_good_candidate(self, min_match_count=2): + key_flag = self.is_key_complete() + if key_flag: + if self.nonempty_count >= min_match_count: + return True + return False + +""" + groups = re.findall( + r"\s*(.*?) = ({.*?})\s*TRIGGERS\['all'\] = (\[.*?\])", trigger_string, re.DOTALL + ) + for event_type, triggers, trigger_all in groups: + final_string += string_template.format( + class_name=event_type_to_class_name[event_type], + event_type=event_type, + fields=str(trigger_all), + triggers=str(triggers), + triggers_all=str(trigger_all), + ) + event_type2event_class = ( + "{" + + " ".join( + [ + f"{event_class_name}Event.NAME: {event_class_name}Event," + for event_class_name in event_type_to_class_name.values() + ] + ) + + "}" + ) + event_type_fields_list = ( + "[" + + " ".join( + [ + f"({event_class_name}Event.NAME, {event_class_name}Event.FIELDS, {event_class_name}Event.TRIGGERS, 2)," + for event_class_name in event_type_to_class_name.values() + ] + ) + + "]" + ) + final_string += f"\ncommon_fields = []\nevent_type2event_class={event_type2event_class}\n\nevent_type_fields_list={event_type_fields_list}" + + return final_string + + if __name__ == "__main__": - template = get_event_template("zheng2019") - print(template) + # template = get_event_template("zheng2019") + # print(template) + final_string = generate_event_template_from_trigger_string( + """破产清算 = { + 1: ['公司名称'], # importance: 0.9950641658440277 + 2: ['公司名称', '公告时间'], # importance: 0.9990128331688055 + 3: ['公司名称', '公告时间', '受理法院'], # importance: 1.0 + 4: ['公司名称', '公司行业', '公告时间', '受理法院'], # importance: 1.0 + 5: ['公司名称', '公司行业', '公告时间', '受理法院', '裁定时间'], # importance: 1.0 +} +TRIGGERS['all'] = ['公司名称', '公告时间', '受理法院', '裁定时间', '公司行业'] + +重大安全事故 = { + 1: ['公司名称'], # importance: 0.9974424552429667 + 2: ['公司名称', '公告时间'], # importance: 1.0 + 3: ['伤亡人数', '公司名称', '公告时间'], # importance: 1.0 + 4: ['伤亡人数', '公司名称', '公告时间', '损失金额'], # importance: 1.0 + 5: ['伤亡人数', '公司名称', '公告时间', '其他影响', '损失金额'], # importance: 1.0 +} +TRIGGERS['all'] = ['公司名称', '公告时间', '伤亡人数', '损失金额', '其他影响'] + +股东减持 = { + 1: ['减持金额'], # importance: 0.9486062717770035 + 2: ['减持开始日期', '减持金额'], # importance: 0.9817073170731707 + 3: ['减持开始日期', '减持的股东', '减持金额'], # importance: 0.990418118466899 +} +TRIGGERS['all'] = ['减持金额', '减持开始日期', '减持的股东'] + +股权质押 = { + 1: ['质押金额'], # importance: 0.9625668449197861 + 2: ['质押开始日期', '质押金额'], # importance: 0.9910873440285205 + 3: ['接收方', '质押开始日期', '质押金额'], # importance: 0.9964349376114082 + 4: ['接收方', '质押开始日期', '质押结束日期', '质押金额'], # importance: 1.0 + 5: ['接收方', '质押开始日期', '质押方', '质押结束日期', '质押金额'], # importance: 1.0 +} +TRIGGERS['all'] = ['质押金额', '质押开始日期', '接收方', '质押方', '质押结束日期'] + +股东增持 = { + 1: ['增持金额'], # importance: 0.9607609988109393 + 2: ['增持的股东', '增持金额'], # importance: 0.9892984542211652 + 3: ['增持开始日期', '增持的股东', '增持金额'], # importance: 1.0 +} +TRIGGERS['all'] = ['增持金额', '增持开始日期', '增持的股东'] + +股权冻结 = { + 1: ['冻结金额'], # importance: 0.8524822695035461 + 2: ['冻结开始日期', '冻结金额'], # importance: 0.9687943262411347 + 3: ['冻结开始日期', '冻结金额', '被冻结股东'], # importance: 0.9716312056737588 + 4: ['冻结开始日期', '冻结结束日期', '冻结金额', '被冻结股东'], # importance: 0.9730496453900709 +} +TRIGGERS['all'] = ['冻结金额', '冻结开始日期', '被冻结股东', '冻结结束日期'] + +高层死亡 = { + 1: ['公司名称'], # importance: 1.0 + 2: ['公司名称', '高层人员'], # importance: 1.0 + 3: ['公司名称', '高层人员', '高层职务'], # importance: 1.0 + 4: ['公司名称', '死亡/失联时间', '高层人员', '高层职务'], # importance: 1.0 + 5: ['公司名称', '死亡/失联时间', '死亡年龄', '高层人员', '高层职务'], # importance: 1.0 +} +TRIGGERS['all'] = ['公司名称', '高层人员', '高层职务', '死亡/失联时间', '死亡年龄'] + +重大资产损失 = { + 1: ['公司名称'], # importance: 0.9949494949494949 + 2: ['公司名称', '公告时间'], # importance: 1.0 + 3: ['公司名称', '公告时间', '损失金额'], # importance: 1.0 + 4: ['公司名称', '公告时间', '其他损失', '损失金额'], # importance: 1.0 +} +TRIGGERS['all'] = ['公司名称', '公告时间', '损失金额', '其他损失'] + +重大对外赔付 = { + 1: ['公告时间'], # importance: 0.984251968503937 + 2: ['公司名称', '公告时间'], # importance: 1.0 + 3: ['公司名称', '公告时间', '赔付对象'], # importance: 1.0 + 4: ['公司名称', '公告时间', '赔付对象', '赔付金额'], # importance: 1.0 +} +TRIGGERS['all'] = ['公告时间', '公司名称', '赔付对象', '赔付金额']""", + { + "破产清算": "Bankruptcy", + "重大安全事故": "Accident", + "股东减持": "EquityUnderweight", + "股权质押": "EquityPledge", + "股东增持": "EquityOverweight", + "股权冻结": "EquityFreeze", + "高层死亡": "LeaderDeath", + "重大资产损失": "AssetLoss", + "重大对外赔付": "ExternalIndemnity", + }, + ) + print(final_string) diff --git a/dee/event_types/ccks2020.py b/dee/event_types/ccks2020.py new file mode 100644 index 0000000..4172f8f --- /dev/null +++ b/dee/event_types/ccks2020.py @@ -0,0 +1,256 @@ +class BaseEvent(object): + def __init__(self, fields, event_name="Event", key_fields=(), recguid=None): + self.recguid = recguid + self.name = event_name + self.fields = list(fields) + self.field2content = {f: None for f in fields} + self.nonempty_count = 0 + self.nonempty_ratio = self.nonempty_count / len(self.fields) + + self.key_fields = set(key_fields) + for key_field in self.key_fields: + assert key_field in self.field2content + + def __repr__(self): + event_str = "\n{}[\n".format(self.name) + event_str += " {}={}\n".format("recguid", self.recguid) + event_str += " {}={}\n".format("nonempty_count", self.nonempty_count) + event_str += " {}={:.3f}\n".format("nonempty_ratio", self.nonempty_ratio) + event_str += "] (\n" + for field in self.fields: + if field in self.key_fields: + key_str = " (key)" + else: + key_str = "" + event_str += ( + " " + + field + + "=" + + str(self.field2content[field]) + + ", {}\n".format(key_str) + ) + event_str += ")\n" + return event_str + + def update_by_dict(self, field2text, recguid=None): + self.nonempty_count = 0 + self.recguid = recguid + + for field in self.fields: + if field in field2text and field2text[field] is not None: + self.nonempty_count += 1 + self.field2content[field] = field2text[field] + else: + self.field2content[field] = None + + self.nonempty_ratio = self.nonempty_count / len(self.fields) + + def field_to_dict(self): + return dict(self.field2content) + + def set_key_fields(self, key_fields): + self.key_fields = set(key_fields) + + def is_key_complete(self): + for key_field in self.key_fields: + if self.field2content[key_field] is None: + return False + + return True + + def get_argument_tuple(self): + args_tuple = tuple(self.field2content[field] for field in self.fields) + return args_tuple + + def is_good_candidate(self, min_match_count=2): + key_flag = self.is_key_complete() + if key_flag: + if self.nonempty_count >= min_match_count: + return True + return False + + +class BankruptcyEvent(BaseEvent): + NAME = "破产清算" + FIELDS = ["公司名称", "公告时间", "受理法院", "裁定时间", "公司行业"] + TRIGGERS = { + 1: ["公司名称"], # importance: 0.9950641658440277 + 2: ["公司名称", "公告时间"], # importance: 0.9990128331688055 + 3: ["公司名称", "公告时间", "受理法院"], # importance: 1.0 + 4: ["公司名称", "公司行业", "公告时间", "受理法院"], # importance: 1.0 + 5: ["公司名称", "公司行业", "公告时间", "受理法院", "裁定时间"], # importance: 1.0 + } + TRIGGERS["all"] = ["公司名称", "公告时间", "受理法院", "裁定时间", "公司行业"] + + def __init__(self, recguid=None): + super().__init__(self.FIELDS, event_name=self.NAME, recguid=recguid) + self.set_key_fields(self.TRIGGERS) + + +class AccidentEvent(BaseEvent): + NAME = "重大安全事故" + FIELDS = ["公司名称", "公告时间", "伤亡人数", "损失金额", "其他影响"] + TRIGGERS = { + 1: ["公司名称"], # importance: 0.9974424552429667 + 2: ["公司名称", "公告时间"], # importance: 1.0 + 3: ["伤亡人数", "公司名称", "公告时间"], # importance: 1.0 + 4: ["伤亡人数", "公司名称", "公告时间", "损失金额"], # importance: 1.0 + 5: ["伤亡人数", "公司名称", "公告时间", "其他影响", "损失金额"], # importance: 1.0 + } + TRIGGERS["all"] = ["公司名称", "公告时间", "伤亡人数", "损失金额", "其他影响"] + + def __init__(self, recguid=None): + super().__init__(self.FIELDS, event_name=self.NAME, recguid=recguid) + self.set_key_fields(self.TRIGGERS) + + +class EquityUnderweightEvent(BaseEvent): + NAME = "股东减持" + FIELDS = ["减持金额", "减持开始日期", "减持的股东"] + TRIGGERS = { + 1: ["减持金额"], # importance: 0.9486062717770035 + 2: ["减持开始日期", "减持金额"], # importance: 0.9817073170731707 + 3: ["减持开始日期", "减持的股东", "减持金额"], # importance: 0.990418118466899 + } + TRIGGERS["all"] = ["减持金额", "减持开始日期", "减持的股东"] + + def __init__(self, recguid=None): + super().__init__(self.FIELDS, event_name=self.NAME, recguid=recguid) + self.set_key_fields(self.TRIGGERS) + + +class EquityPledgeEvent(BaseEvent): + NAME = "股权质押" + FIELDS = ["质押金额", "质押开始日期", "接收方", "质押方", "质押结束日期"] + TRIGGERS = { + 1: ["质押金额"], # importance: 0.9625668449197861 + 2: ["质押开始日期", "质押金额"], # importance: 0.9910873440285205 + 3: ["接收方", "质押开始日期", "质押金额"], # importance: 0.9964349376114082 + 4: ["接收方", "质押开始日期", "质押结束日期", "质押金额"], # importance: 1.0 + 5: ["接收方", "质押开始日期", "质押方", "质押结束日期", "质押金额"], # importance: 1.0 + } + TRIGGERS["all"] = ["质押金额", "质押开始日期", "接收方", "质押方", "质押结束日期"] + + def __init__(self, recguid=None): + super().__init__(self.FIELDS, event_name=self.NAME, recguid=recguid) + self.set_key_fields(self.TRIGGERS) + + +class EquityOverweightEvent(BaseEvent): + NAME = "股东增持" + FIELDS = ["增持金额", "增持开始日期", "增持的股东"] + TRIGGERS = { + 1: ["增持金额"], # importance: 0.9607609988109393 + 2: ["增持的股东", "增持金额"], # importance: 0.9892984542211652 + 3: ["增持开始日期", "增持的股东", "增持金额"], # importance: 1.0 + } + TRIGGERS["all"] = ["增持金额", "增持开始日期", "增持的股东"] + + def __init__(self, recguid=None): + super().__init__(self.FIELDS, event_name=self.NAME, recguid=recguid) + self.set_key_fields(self.TRIGGERS) + + +class EquityFreezeEvent(BaseEvent): + NAME = "股权冻结" + FIELDS = ["冻结金额", "冻结开始日期", "被冻结股东", "冻结结束日期"] + TRIGGERS = { + 1: ["冻结金额"], # importance: 0.8524822695035461 + 2: ["冻结开始日期", "冻结金额"], # importance: 0.9687943262411347 + 3: ["冻结开始日期", "冻结金额", "被冻结股东"], # importance: 0.9716312056737588 + 4: ["冻结开始日期", "冻结结束日期", "冻结金额", "被冻结股东"], # importance: 0.9730496453900709 + } + TRIGGERS["all"] = ["冻结金额", "冻结开始日期", "被冻结股东", "冻结结束日期"] + + def __init__(self, recguid=None): + super().__init__(self.FIELDS, event_name=self.NAME, recguid=recguid) + self.set_key_fields(self.TRIGGERS) + + +class LeaderDeathEvent(BaseEvent): + NAME = "高层死亡" + FIELDS = ["公司名称", "高层人员", "高层职务", "死亡/失联时间", "死亡年龄"] + TRIGGERS = { + 1: ["公司名称"], # importance: 1.0 + 2: ["公司名称", "高层人员"], # importance: 1.0 + 3: ["公司名称", "高层人员", "高层职务"], # importance: 1.0 + 4: ["公司名称", "死亡/失联时间", "高层人员", "高层职务"], # importance: 1.0 + 5: ["公司名称", "死亡/失联时间", "死亡年龄", "高层人员", "高层职务"], # importance: 1.0 + } + TRIGGERS["all"] = ["公司名称", "高层人员", "高层职务", "死亡/失联时间", "死亡年龄"] + + def __init__(self, recguid=None): + super().__init__(self.FIELDS, event_name=self.NAME, recguid=recguid) + self.set_key_fields(self.TRIGGERS) + + +class AssetLossEvent(BaseEvent): + NAME = "重大资产损失" + FIELDS = ["公司名称", "公告时间", "损失金额", "其他损失"] + TRIGGERS = { + 1: ["公司名称"], # importance: 0.9949494949494949 + 2: ["公司名称", "公告时间"], # importance: 1.0 + 3: ["公司名称", "公告时间", "损失金额"], # importance: 1.0 + 4: ["公司名称", "公告时间", "其他损失", "损失金额"], # importance: 1.0 + } + TRIGGERS["all"] = ["公司名称", "公告时间", "损失金额", "其他损失"] + + def __init__(self, recguid=None): + super().__init__(self.FIELDS, event_name=self.NAME, recguid=recguid) + self.set_key_fields(self.TRIGGERS) + + +class ExternalIndemnityEvent(BaseEvent): + NAME = "重大对外赔付" + FIELDS = ["公告时间", "公司名称", "赔付对象", "赔付金额"] + TRIGGERS = { + 1: ["公告时间"], # importance: 0.984251968503937 + 2: ["公司名称", "公告时间"], # importance: 1.0 + 3: ["公司名称", "公告时间", "赔付对象"], # importance: 1.0 + 4: ["公司名称", "公告时间", "赔付对象", "赔付金额"], # importance: 1.0 + } + TRIGGERS["all"] = ["公告时间", "公司名称", "赔付对象", "赔付金额"] + + def __init__(self, recguid=None): + super().__init__(self.FIELDS, event_name=self.NAME, recguid=recguid) + self.set_key_fields(self.TRIGGERS) + + +common_fields = [] +event_type2event_class = { + BankruptcyEvent.NAME: BankruptcyEvent, + AccidentEvent.NAME: AccidentEvent, + EquityUnderweightEvent.NAME: EquityUnderweightEvent, + EquityPledgeEvent.NAME: EquityPledgeEvent, + EquityOverweightEvent.NAME: EquityOverweightEvent, + EquityFreezeEvent.NAME: EquityFreezeEvent, + LeaderDeathEvent.NAME: LeaderDeathEvent, + AssetLossEvent.NAME: AssetLossEvent, + ExternalIndemnityEvent.NAME: ExternalIndemnityEvent, +} +event_type_fields_list = [ + (BankruptcyEvent.NAME, BankruptcyEvent.FIELDS, BankruptcyEvent.TRIGGERS, 2), + (AccidentEvent.NAME, AccidentEvent.FIELDS, AccidentEvent.TRIGGERS, 2), + ( + EquityUnderweightEvent.NAME, + EquityUnderweightEvent.FIELDS, + EquityUnderweightEvent.TRIGGERS, + 2, + ), + (EquityPledgeEvent.NAME, EquityPledgeEvent.FIELDS, EquityPledgeEvent.TRIGGERS, 2), + ( + EquityOverweightEvent.NAME, + EquityOverweightEvent.FIELDS, + EquityOverweightEvent.TRIGGERS, + 2, + ), + (EquityFreezeEvent.NAME, EquityFreezeEvent.FIELDS, EquityFreezeEvent.TRIGGERS, 2), + (LeaderDeathEvent.NAME, LeaderDeathEvent.FIELDS, LeaderDeathEvent.TRIGGERS, 2), + (AssetLossEvent.NAME, AssetLossEvent.FIELDS, AssetLossEvent.TRIGGERS, 2), + ( + ExternalIndemnityEvent.NAME, + ExternalIndemnityEvent.FIELDS, + ExternalIndemnityEvent.TRIGGERS, + 2, + ), +] diff --git a/scripts/run_ptpcg_ccks2020.sh b/scripts/run_ptpcg_ccks2020.sh new file mode 100644 index 0000000..430815c --- /dev/null +++ b/scripts/run_ptpcg_ccks2020.sh @@ -0,0 +1,91 @@ +#!/bin/bash + +{ + MODEL_NAME='TriggerAwarePrunedCompleteGraph' + TASK_NAME='PTPCG_CCKS2020' + echo "('${TASK_NAME}', '${MODEL_NAME}'), # $(date)" >> RECORDS.md + echo "Task Name: $TASK_NAME" + echo "Model Name: $MODEL_NAME" + + # GPU_SCOPE="0,1,2,3" + # REQ_GPU_NUM=1 + GPUS="4" + # GPUS=$(python wait.py --task_name="$TASK_NAME" --cuda=$GPU_SCOPE --wait="schedule" --req_gpu_num=$REQ_GPU_NUM) + echo "GPUS: $GPUS" + EPOCH_NUM=50 + + if [[ -z "$GPUS" ]]; then + echo "GPUS is empty, stop..." + # python send_message.py "Task $TASK_NAME not started due to empty gpu assigning, please check the log." + echo "Task $TASK_NAME not started due to empty gpu assigning, please check the log." + else + echo "GPU ready." + # python send_message.py "Task $TASK_NAME started." + echo "Task $TASK_NAME started." + CUDA_VISIBLE_DEVICES=${GPUS} python -u run_dee_task.py \ + --data_dir='Data/CCKS2020' \ + --use_bert=True \ + --bert_model='/data/tzhu/PLM/bert-base-chinese' \ + --seed=99 \ + --task_name=${TASK_NAME} \ + --model_type=${MODEL_NAME} \ + --cpt_file_name=${MODEL_NAME} \ + --save_cpt_flag=False \ + --save_best_cpt=True \ + --remove_last_cpt=True \ + --resume_latest_cpt=False \ + --optimizer='adam' \ + --learning_rate=0.00005 \ + --dropout=0.1 \ + --gradient_accumulation_steps=8 \ + --train_batch_size=64 \ + --eval_batch_size=16 \ + --max_clique_decode=True \ + --num_triggers=1 \ + --eval_num_triggers=1 \ + --with_left_trigger=True \ + --directed_trigger_graph=True \ + --use_scheduled_sampling=True \ + --schedule_epoch_start=5 \ + --schedule_epoch_length=5 \ + --num_train_epochs=${EPOCH_NUM} \ + --run_mode='full' \ + --filtered_data_types='o2o,o2m,m2m' \ + --skip_train=False \ + --re_eval_flag=False \ + --add_greedy_dec=False \ + --num_lstm_layers=2 \ + --hidden_size=768 \ + --biaffine_hidden_size=512 \ + --biaffine_hard_threshold=0.5 \ + --at_least_one_comb=True \ + --include_complementary_ents=True \ + --event_type_template='ccks2020' \ + --use_span_lstm=True \ + --span_lstm_num_layer=2 \ + --role_by_encoding=True \ + --use_token_role=True \ + --ment_feature_type='concat' \ + --ment_type_hidden_size=32 + fi + + # check if the process has finished normally + LOG_FILE="Logs/$TASK_NAME.log" + if [[ -f "$LOG_FILE" ]]; then + echo "$LOG_FILE exists." + MATCHED=$(grep 'Combination' "$LOG_FILE") + if [[ -z "$MATCHED" ]]; then + # python send_message.py " Something's wrong in task $TASK_NAME, please check the log." + echo " Something's wrong in task $TASK_NAME, please check the log." + else + # python send_message.py --send_result --task_name ${TASK_NAME} --model_name ${MODEL_NAME} --max_epoch ${EPOCH_NUM} "Task $TASK_NAME finished." + echo "Task $TASK_NAME finished." + fi + else + echo "$LOG_FILE not found." + # python send_message.py "Task $TASK_NAME finished, but log is not found." + echo "Task $TASK_NAME finished, but log is not found." + fi + + exit +}