Home About Contact
Python , LLM

日本語 LLM ELYZA 追伸

torch_dtype=torch.float16 指定を外すとどうなるか試しました。

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto")
    #torch_dtype=torch.float16)

そこそこに速いシリコンマックでの実験。 再起動直後で何も極力アプリを起動していない状態で実行。 それならば、 offload オプションなしでモデルをロードすることができました。

ただし、 処理が完了するまでに 10分かかりました。 できるにはできるが実用からは程遠い。

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from bottle import Bottle, run, request

DEFAULT_SYSTEM_PROMPT = "テキストをJSON形式に変更して出力してください。"

#MODEL_NAME = "elyza/ELYZA-japanese-Llama-2-7b-instruct"
MODEL_NAME = 'elyza/ELYZA-japanese-Llama-2-7b-fast-instruct'

def to_prompt(text, tokenizer):
    inst_open, inst_close = "[INST]", "[/INST]"
    sys_open, sys_close = "<<SYS>>\n", "\n<</SYS>>\n\n"

    return "{bos_token}{b_inst} {system}{prompt} {e_inst} ".format(
        bos_token=tokenizer.bos_token,
        b_inst=inst_open,
        system=f"{sys_open}{DEFAULT_SYSTEM_PROMPT}{sys_close}",
        prompt=text,
        e_inst=inst_close)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto")

text = """
商品名 ベルギーワッフル
名称 洋菓子
原材料 小麦粉(国内製造)、液卵、砂糖、マーガリン、バター加工品
"""

prompt = to_prompt(text, tokenizer)
token_ids = tokenizer.encode(
    prompt,
    add_special_tokens=False,
    return_tensors='pt')

with torch.no_grad():
    output_ids = model.generate(
        token_ids.to(model.device),
        max_new_tokens=256,
        temperature=0.01,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id)

output = tokenizer.decode(output_ids.tolist()[0])
print( output )

実行結果。

$ time python main.py 
  warnings.warn(
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████| 2/2 [00:13<00:00,  6.98s/it]
<s> [INST] <<SYS>>
テキストをJSON形式に変更して出力してください。
<</SYS>>


商品名 ベルギーワッフル
名称 洋菓子
原材料 小麦粉(国内製造)、液卵、砂糖、マーガリン、バター加工品
 [/INST]  {
    "商品名": "ベルギーワッフル",
    "名称": "洋菓子",
    "原材料": [
        "小麦粉(国内製造)",
        "液卵",
        "砂糖",
        "マーガリン",
        "バター加工品"
    ]
}</s>

real	10m17.223s
user	0m38.739s
sys	6m23.799s

残念ながら LLM を ローカルで動かす時代はまだ遠いのかもしれない。 7b でこれなので、Llama 2 の 13b, 70b などをローカルで動かすなどというのは。