import json
import os
from pathlib import Path

import typer
from aitextgen import aitextgen
from aitextgen.TokenDataset import TokenDataset
from aitextgen.tokenizers import train_tokenizer
from aitextgen.utils import GPT2ConfigCPU

from outbox_parser import parse_outbox
from config import c, gpt2_name_mapper

app = typer.Typer()


@app.command()
def generate(prompt: str = c.generation_prompt, include_prompt: bool = c.generation_include_prompt,
             max_length: int = c.generation_max_length, temperature: float = c.generation_temperature):
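    """Generate a single post from the trained model and print it to stdout."""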
    # Run from the data directory so the model folder and tokenizer files resolve.
    orig_path = os.getcwd()
    os.chdir(c.datadir)

    ai = aitextgen(model_folder=c.model_folder, vocab_file=f'{c.tokenizer_output_prefix}-vocab.json',
                   merges_file=f'{c.tokenizer_output_prefix}-merges.txt')

    # ai.generate() can come back empty; retry until it returns something.
    generated_strings = []
    while not generated_strings:
        generated_strings = ai.generate(prompt=prompt or '', include_prompt=include_prompt,
                                        max_length=max_length, temperature=temperature, return_as_list=True)

    # Drop the end-of-text token (str.strip() would strip its characters, not the token).
    generated_string = generated_strings[0].replace('<|endoftext|>', '').strip()

    if c.generation_zwsp_mentions:
        # Insert a zero-width space after '@' so the output doesn't ping mentioned accounts.
        generated_string = generated_string.replace('@', '@\u200b')

    typer.echo(generated_string)
    os.chdir(orig_path)


@app.command()
def init(outbox_json_file: Path):
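    """Parse an outbox JSON export, train a tokenizer on the posts, and train the model."""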
    with open(outbox_json_file) as f:
        outbox = json.load(f)

    # Run from the data directory so all generated files end up in one place.
    orig_path = os.getcwd()
    os.chdir(c.datadir)

    out_text = parse_outbox(outbox, c.parsing_exclude_mentions, c.parsing_arbitrary_exclude_fn)

    with open(c.parsed_posts_file, 'w') as f:
        f.write(out_text)

    # Train a BPE tokenizer on the parsed posts; this writes the vocab and merges files used below.
    train_tokenizer(c.parsed_posts_file, serialize=False, save_path=str(c.datadir), prefix=c.tokenizer_output_prefix)

    merges_file = f'{c.tokenizer_output_prefix}-merges.txt'
    vocab_file = f'{c.tokenizer_output_prefix}-vocab.json'

    # Small GPT-2 config intended for training on CPU.
    config = GPT2ConfigCPU()  # todo: maybe add params, probably not needed idfk

    ai = aitextgen(
        # tf_gpt2=gpt2_name_mapper[c.model_size],  # for some reason this makes it Really slow (4 days kind of slow)
        merges_file=merges_file, vocab_file=vocab_file, config=config)

    data = TokenDataset(c.parsed_posts_file,
                        vocab_file=vocab_file,
                        merges_file=merges_file,
                        block_size=c.training_block_size)

    typer.secho('About to start training. This will take a long time!', fg=typer.colors.BRIGHT_RED)

    if c.prompt_before_training:
        typer.confirm('Start training?', default=True, abort=True)

    ai.train(data, batch_size=c.training_batch_size, num_steps=c.training_num_steps,
             generate_every=c.training_sample_frequency, save_every=c.training_save_frequency,
             num_workers=c.training_num_workers)

    typer.secho('Training finished.', fg=typer.colors.GREEN, bold=True)
    os.chdir(orig_path)


if __name__ == '__main__':
    app()
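
# Example CLI usage (a sketch: the module name main.py and the outbox.json path are
# placeholders; all other defaults come from config.c):
#   python main.py init outbox.json
#   python main.py generate --prompt "hello" --temperature 0.9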