56 lines
1.7 KiB
Python
56 lines
1.7 KiB
Python
import os
|
|
import json
|
|
import datasets
|
|
|
|
# Directory holding the raw split files. Defaults to the original hard-coded
# location; setting the VIRUS_COMMENT_DATA_DIR environment variable (to a
# non-empty value) points the builder at a different directory instead.
_BASE_URL = (
    os.environ.get("VIRUS_COMMENT_DATA_DIR")
    or "/root/workspace/LabFinal/data/raw_dataset/"
)

# Split name -> path of the JSON file read for that split. os.path.join is
# used so an override directory works with or without a trailing slash.
JSON_URLS = {
    "train": os.path.join(_BASE_URL, 'virus_train.txt'),
    "test": os.path.join(_BASE_URL, 'virus_eval_labeled.txt'),
}

# Emotion label vocabulary; a label's position in this list is its integer id.
# NOTE: 'neural' is kept verbatim because these strings are looked up against
# the 'label' values in the data files (presumably a misspelling of "neutral"
# inherited from the source data -- do not change it without the data).
LABELS = ['happy', 'angry', 'sad', 'fear', 'surprise', 'neural']

# Label string -> integer id.
LABEL2ID = {label: idx for idx, label in enumerate(LABELS)}
|
|
|
|
|
|
class VirusComment(datasets.GeneratorBasedBuilder):
    """Dataset builder for the virus-comment emotion classification data.

    Each split is read from the local file listed for it in ``JSON_URLS``.
    Every example carries an integer ``id``, the comment ``text`` and an
    integer ``label`` obtained by mapping the raw label string through
    ``LABEL2ID``.
    """

    BUILDER_CONFIGS = [datasets.BuilderConfig(name='main', version='0.0.1')]

    def _info(self):
        """Return the DatasetInfo describing the example schema and version."""
        return datasets.DatasetInfo(
            features=datasets.Features({
                "id": datasets.Value("int32"),
                "text": datasets.Value("string"),
                "label": datasets.Value("int32"),
            }),
            version=self.config.version,
        )

    def _split_generators(self, dl_manager):
        """Resolve the train/test files and build one SplitGenerator per split.

        Args:
            dl_manager: the download manager supplied by ``datasets``; its
                ``download`` is given a dict of split key -> path and the
                resulting dict is indexed by the same keys.

        Returns:
            A list of ``datasets.SplitGenerator``, train first, then test.
        """
        split_names = {
            "train": datasets.Split.TRAIN,
            "test": datasets.Split.TEST,
        }
        local_paths = dl_manager.download(
            {key: JSON_URLS[key] for key in split_names}
        )
        return [
            datasets.SplitGenerator(
                name=split,
                gen_kwargs={"json_path": local_paths[key]},
            )
            for key, split in split_names.items()
        ]

    def _generate_examples(self, json_path):
        """Yield ``(key, example)`` pairs read from one split's file.

        The file is parsed as a single JSON document and iterated; each item
        is expected to be an object with ``'id'``, ``'content'`` and
        ``'label'`` entries. The record's ``id`` is used both as the example
        key and as the ``id`` feature.

        Raises:
            KeyError: if a record lacks one of those entries, or its label is
                not present in ``LABEL2ID``.
        """
        with open(json_path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)

        for record in records:
            example = {
                'id': record['id'],
                'text': record['content'],
                'label': LABEL2ID[record['label']],
            }
            yield record['id'], example
|