"""Typer CLI for training a small GPT-2 model on posts parsed from an outbox.json
export (`init`) and for sampling text from the trained model (`generate`)."""
import json
import os
from pathlib import Path

import typer
from aitextgen import aitextgen
from aitextgen.TokenDataset import TokenDataset
from aitextgen.tokenizers import train_tokenizer
from aitextgen.utils import GPT2ConfigCPU

from outbox_parser import parse_outbox
from config import c, gpt2_name_mapper

app = typer.Typer()


@app.command()
def generate(prompt: str = c.generation_prompt, include_prompt: bool = c.generation_include_prompt,
             max_length: int = c.generation_max_length, temperature: float = c.generation_temperature):
    orig_path = os.getcwd()
    os.chdir(c.datadir)
    ai = aitextgen(model_folder=c.model_folder, vocab_file=f'{c.tokenizer_output_prefix}-vocab.json',
                   merges_file=f'{c.tokenizer_output_prefix}-merges.txt')
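    # guard against an empty result from ai.generate: retry until at least one sample comes back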
    generated_strings = []
    while not generated_strings:
        generated_strings = ai.generate(prompt=prompt or '', include_prompt=include_prompt,
                                        max_length=max_length, temperature=temperature, return_as_list=True)
    # drop the end-of-text token (str.strip('<|endoftext|>') would strip individual characters, not the token)
    generated_string = generated_strings[0].replace('<|endoftext|>', '').strip()
    if c.generation_zwsp_mentions:
        # break up @mentions by inserting a zero-width space after every '@'
        generated_string = generated_string.replace('@', '@\u200b')
    typer.echo(generated_string)
    os.chdir(orig_path)


@app.command()
def init(outbox_json_file: Path):
    with open(outbox_json_file) as f:
        outbox = json.load(f)
    orig_path = os.getcwd()
    os.chdir(c.datadir)
    out_text = parse_outbox(outbox, c.parsing_exclude_mentions, c.parsing_arbitrary_exclude_fn)
    with open(c.parsed_posts_file, 'w') as f:
        f.write(out_text)
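    # train a BPE tokenizer on the parsed posts; this writes <prefix>-vocab.json and <prefix>-merges.txt into the data dir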
    train_tokenizer(c.parsed_posts_file, serialize=False, save_path=str(c.datadir), prefix=c.tokenizer_output_prefix)
    merges_file = f'{c.tokenizer_output_prefix}-merges.txt'
    vocab_file = f'{c.tokenizer_output_prefix}-vocab.json'
    config = GPT2ConfigCPU()  # TODO: the config parameters could be tuned here, but the defaults are probably fine
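    # build a fresh, CPU-sized GPT-2 model (no pretrained weights) around the tokenizer trained above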
    ai = aitextgen(
        # tf_gpt2=gpt2_name_mapper[c.model_size],  # starting from the pretrained GPT-2 weights makes training far slower (on the order of 4 days)
        merges_file=merges_file, vocab_file=vocab_file, config=config)
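    # pack the parsed posts into fixed-length token blocks for training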
    data = TokenDataset(c.parsed_posts_file,
                        vocab_file=vocab_file,
                        merges_file=merges_file,
                        block_size=c.training_block_size)
    typer.secho('About to start training. This will take a long time!', fg=typer.colors.BRIGHT_RED)
    if c.prompt_before_training:
        typer.confirm('Start training?', default=True, abort=True)
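    # generate_every / save_every control how often sample output is printed and checkpoints are written during training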
    ai.train(data, batch_size=c.training_batch_size, num_steps=c.training_num_steps,
             generate_every=c.training_sample_frequency, save_every=c.training_save_frequency,
             num_workers=c.training_num_workers)
    typer.secho('Training finished.', fg=typer.colors.GREEN, bold=True)
    os.chdir(orig_path)


if __name__ == '__main__':
    app()