はじめに#

画像処理に、セグメンテーション(領域分割)という古くから研究されている技術が有ります。要は、画像に対して特定の領域を切り出す技術なのですが、様々な分野から自動化を切望されていた技術の一つです。

さて、そんな中、2023年4月にMeta社からSAM(Segmentation anything) という技術が発表されました。今までのものと比べ、“何でも分割"でき、マルチモーダルな入力に対応した画期的な技術で話題となりました。トレーニングデータも巨大、加えて商業フリーという形も有り、あっという間に話題になりました。 最近では動画も対応できるSAM2という形で配信されています。静止画性能について、SAM2は前バージョンに比べ若干早くなっているようです。

今回はこのSegmentation anythingを使って、ローカルでセグメンテーションを構築し、実行してみようかと思います。なお、すでにMeta社がDemoをオンラインで提供しているので、試すだけならそちらを利用しましょう。なお、動画のデモはこちらとなります。

Metaのオンラインデモは手軽ですが、プライバシーが気になる画像を扱いたい、バッチ処理で性能を測りたい、独自のアプリケーションに組み込みたいといったニーズには応えられません。この記事では、そうした方々に向けて、ローカルGPU環境でSAM2をセットアップし、インタラクティブに試せるWeb UIを構築する手順を解説します。

今回のソースコードはGithubにあげてあります。

SAM2 を 早速使ってみる。#

SAM2を公式がプログラムが提供しています。

このプログラムをベースにエンジンを作ります。最後、ここで使ったコードはGithubを参考ください。

今回はRyeを使って仮想環境を構築しています。RyeはPythonの仮想環境を管理するツールで、PyTorchのGPU環境を構築することができます。pyproject.tomlは次のようになります。

[project]
name = "segment2"
version = "0.1.0"
description = "Add your description here"
dependencies = [
    "huggingface-hub>=0.29.3",
    "torch>=2.6.0",
    "image>=1.5.33",
    "sam2>=1.1.0",
    "matplotlib>=3.10.1",
    "opencv-contrib-python>=4.11.0.86",
    "fastapi>=0.115.12",
    "uvicorn[standard]>=0.34.0",
    "python-multipart>=0.0.20",
    "opencv-python-headless>=4.11.0.86",
]
...

pipでやりたい方は、上記のライブラリを参考にインストールしてください。

エンジンの作成#

sam_engine.pyというファイルを定義します。なお、SAMはGPUではなくCPUでも演算可能ですが今回GPUを使うことを前提とします。


# 使用するモデルを選択 (tinyの方が精度は悪いがサイズは小さい)
SAM2_MODEL_NAME = "facebook/sam2-hiera-tiny"
# SAM2_MODEL_NAME = "facebook/sam2-hiera-large"

class Sam:
    def __init__(self):
        """SAM2 Predictor のインスタンスを作成する (FastAPI起動時)"""
        device = torch.device("cuda")
        if device.type == "cuda":
            torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
        else:
            raise NotImplementedError
        try:
            self.predictor = SAM2ImagePredictor.from_pretrained(SAM2_MODEL_NAME)
        except:
            self.predictor = None

上記クラスがSAM2のエンジンです。マスクを予測するpredict_mask関数を定義します。

    def predict_mask(
        self, image_path: str, input_points_list: list, input_labels_list: list
    ):
        # print(input_points_list)
        # print(input_labels_list)
        """指定された画像と点を使ってマスクを予測する (元のサンプルに近づける)"""
        if self.predictor is None:
            # load_model が FastAPI 起動時に呼ばれているはずだが、念のためチェック
            raise RuntimeError(
                "SAM2 predictor is not initialized. Application might not have started correctly."
            )

        try:
            img = Image.open(image_path).convert("RGB")
            img_array = np.array(img)
            image_np_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)

            input_point_np = np.array(input_points_list, dtype=np.float32)
            input_label_np = np.array(input_labels_list, dtype=np.int32)

            # --- 元のサンプルのデバイス設定、autocast/TF32設定を推論前に適用 ---

            with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
                self.predictor.set_image(img_array)
                masks, scores, logits = self.predictor.predict(
                    point_coords=input_point_np,
                    point_labels=input_label_np,
                    multimask_output=True,
                )
                sorted_ind = np.argsort(scores)[::-1]
                masks = masks[sorted_ind]
                scores = scores[sorted_ind]
                logits = logits[sorted_ind]

                if len(scores) == 0:
                    print("Warning: No masks predicted.")
                    # マスクが見つからなかった場合、元画像を返す
                    return image_np_bgr, 0.0

                # 最もスコアの高いマスクを選択
                best_mask_idx = np.argmax(scores)
                best_mask_np = masks[best_mask_idx]
                best_score = scores[best_mask_idx]

                # マスクを元画像(BGR)に適用
                result_image_np = self.apply_mask_to_image(
                    image_np_bgr, best_mask_np, borders=True
                )

                return result_image_np, float(best_score)

        except FileNotFoundError:
            print(f"Error: Image file not found at {image_path}")
            return None, None
        except Exception as e:
            print(f"Error during prediction in sam_engine: {e}")
            import traceback

            traceback.print_exc()
            return None, None

類推する部分など、公式に定義されているpredict関数を利用しています。上記中で、マスク画像を生成しているapply_mask_to_image 関数を定義しまが、その内容は次のとおりです。これは元画像にマスクを適用するための関数で、cv2 を使っています。

    def apply_mask_to_image(
        self, image_np_bgr, mask_np, random_color=False, borders=True
    ):
        """マスクを画像に重ね合わせる (OpenCVを使用 - この部分は変更なし)"""
        if mask_np.dtype != bool:
            mask_boolean = mask_np.astype(bool)
        else:
            mask_boolean = mask_np

        if random_color:
            color = np.array(
                [
                    np.random.randint(0, 256),
                    np.random.randint(0, 256),
                    np.random.randint(0, 256),
                ],
                dtype=np.uint8,
            )
        else:
            color = np.array([255, 144, 30], dtype=np.uint8)  # BGR: 青っぽい色

        h, w = mask_boolean.shape[-2:]
        mask_colored = np.zeros((h, w, 3), dtype=np.uint8)
        mask_colored[mask_boolean] = color

        alpha = 0.5
        blended = cv2.addWeighted(image_np_bgr, 1.0, mask_colored, alpha, 0)

        if borders:
            try:
                mask_uint8 = mask_boolean.astype(np.uint8) * 255
                contours, _ = cv2.findContours(
                    mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
                )
                cv2.drawContours(blended, contours, -1, (255, 255, 255), thickness=2)
            except Exception as e:
                print(f"Warning: Failed to draw contours: {e}")

        return blended

エンジンはこれだけです。フロント側からPos/Negの点座標と、ラベル番号(1: positive, 0: negative)を受取り、それらをpredict関数に突っ込んでいるだけです。 Negativeは含まない領域を出力します。

FastAPIでWeb APIを構築#

SAMエンジンを動かすバックエンドとしてFastAPIを利用します。ブラウザのフロントエンドからリクエストを受け取り、先に定義したエンジンにおくり、セグメンテーション結果を返すことができます。

主要なポイントは以下の通りです。

  1. モデルのロード: FastAPIの lifespan イベントハンドラを使用して、アプリケーション起動時にSAMモデル (Sam クラスのインスタンス) を一度だけロードし、GPUメモリ上に保持します。これにより、リクエストごとにモデルをロードするオーバーヘッドを避けることができます。

    # main.py (抜粋)
    from contextlib import asynccontextmanager
    from sam_engine import Sam
    import torch
    
    sam = None
    
    @asynccontextmanager
    async def lifespan(app: FastAPI):
        print("Loading SAM model...")
        global sam
        sam = Sam() # ここでインスタンス化 & モデルロード
        # 必要に応じてCUDAキャッシュクリアなど
        yield
        print("Application shutdown.")
    
    app = FastAPI(lifespan=lifespan)
    
  2. ルートエンドポイント (/):

    • GET リクエストを受け付けます。
    • フロントエンドの index.html を返します。これにより、ユーザーはブラウザで操作画面にアクセスできます。
    # main.py (抜粋)
    from fastapi.responses import HTMLResponse
    from fastapi.templating import Jinja2Templates
    
    templates = Jinja2Templates(directory="frontend")
    
    @app.get("/", response_class=HTMLResponse)
    async def read_root(request: Request):
        """フロントエンドのHTMLを返す"""
        return templates.TemplateResponse("index.html", {"request": request})
    
  3. 画像アップロードエンドポイント (/upload/):

    • POST リクエストを受け付けます。
    • フォームデータとして送信された画像ファイル (UploadFile) を受け取ります。
    • 画像をサーバー上の UPLOAD_DIR にユニークなファイル名で保存します。
    • 保存したファイル名と、画像の元の幅・高さをJSON形式で返します。これはフロントエンドで座標計算を行うために必要です。
    # main.py (抜粋)
    from fastapi import FastAPI, File, UploadFile, HTTPException
    from fastapi.responses import JSONResponse
    from PIL import Image
    import io, os, uuid
    
    UPLOAD_DIR = "backend/uploaded_images" # 保存先
    
    @app.post("/upload/")
    async def upload_image(file: UploadFile = File(...)):
        # ... (ファイル形式チェック、保存処理) ...
        contents = await file.read()
        # ... (ファイル書き込み) ...
        img = Image.open(io.BytesIO(contents))
        width, height = img.size
        return JSONResponse(
            content={"filename": filename, "width": width, "height": height}
        )
        # ... (エラーハンドリング) ...
    
  4. セグメンテーション実行エンドポイント (/segment/):

    • POST リクエストを受け付けます。
    • リクエストボディとしてJSONデータ (SegmentationRequest モデル) を受け取ります。これには、処理対象の filename、クリックされた点の座標リスト points ([[x1, y1], ...])、および各点に対応するラベルリスト labels ([1, 0, ...]) が含まれます。
    • 受け取った情報を使って、ロード済みの sam インスタンスの predict_mask メソッドを呼び出し、セグメンテーションを実行します。
    • 結果として得られたマスク適用済み画像 (NumPy配列) をJPEG形式にエンコードし、さらにBase64エンコードしてJSONレスポンス (result_image) に含めます。予測スコア (score) も一緒に返します。
    # main.py (抜粋)
    from pydantic import BaseModel
    import numpy as np
    import cv2, base64
    
    class SegmentationRequest(BaseModel):
        filename: str
        points: list[list[float]]
        labels: list[int]
    
    @app.post("/segment/")
    async def segment_image(request_data: SegmentationRequest):
        filepath = os.path.join(UPLOAD_DIR, request_data.filename)
        # ... (ファイル存在チェック、ポイント/ラベル検証) ...
    
        input_points_np = np.array(request_data.points, dtype=np.float32)
        input_labels_np = np.array(request_data.labels, dtype=np.int32)
    
        # SAMエンジンで予測を実行
        global sam
        result_image_np, score = sam.predict_mask(
            filepath, input_points_np.tolist(), input_labels_np.tolist()
        )
    
        # ... (エラーハンドリング: result_image_np is None) ...
    
        # 結果画像をBase64エンコード
        _, buffer = cv2.imencode(".jpg", result_image_np)
        img_base64 = base64.b64encode(buffer).decode("utf-8")
    
        return JSONResponse(
            content={
                "result_image": f"data:image/jpeg;base64,{img_base64}",
                "score": float(score) if score is not None else None,
            }
        )
        # ... (エラーハンドリング) ...
    

これらのエンドポイントにより、フロントエンドとバックエンドが連携し、インタラクティブなセグメンテーションが可能になります。

完全な main.py のコードは、提供されているGitHubリポジトリで確認できます。

フロントエンドの実装 (JavaScript抜粋)#

フロントエンドでは、HTMLファイルを用意します。frontendというフォルダを作り、そこにindex.htmlを配置します。

mkdir frontend
touch frontend/index.html

HTMLで画像アップロード用の <input type="file">、画像表示とクリック検出用の <img> とコンテナ <div>、ポジティブ/ネガティブ選択用のラジオボタン、セグメンテーション実行ボタン、結果表示用の <img> などを用意します。 それにともなったjsを書いていきます。

以下は、それらのHTML要素と連携するJavaScriptコードの主要部分です。画像のアップロード、クリック座標の取得と整形、そしてバックエンドAPIとの通信(セグメンテーション要求と結果表示)を担当します。

// --- 主要なDOM要素を取得 ---
const imageUpload = document.getElementById('image-upload');
const imageContainer = document.getElementById('image-container');
const uploadedImage = document.getElementById('uploaded-image');
const segmentBtn = document.getElementById('segment-btn');
const resultImage = document.getElementById('result-image');
const resultScore = document.getElementById('result-score');
// (他の要素取得は省略)

// --- 状態変数 ---
let currentFilename = null;     // サーバー上の画像ファイル名
let imageOriginalWidth = 0;     // 元画像の幅
let imageOriginalHeight = 0;    // 元画像の高さ
let points = [];                // クリックされた点の情報 [{originalX, originalY, label}, ...]


// --- 1. 画像アップロード処理 (抜粋) ---
imageUpload.addEventListener('change', async (event) => {
    const file = event.target.files[0];
    if (!file) return;
    // (UIリセット処理などは省略)

    const formData = new FormData();
    formData.append('file', file);

    try {
        // FastAPIの /upload/ エンドポイントに画像をPOST
        const response = await fetch('/upload/', { method: 'POST', body: formData });
        if (!response.ok) throw new Error('Upload failed');

        const data = await response.json();
        currentFilename = data.filename;      // ファイル名を取得
        imageOriginalWidth = data.width;    // 元画像のサイズを取得
        imageOriginalHeight = data.height;

        // 画像プレビュー表示 (FileReaderを使った処理は省略)
        // ... uploadedImage.src = ...
        console.log(`Uploaded: ${currentFilename}, Size: ${imageOriginalWidth}x${imageOriginalHeight}`);

    } catch (error) {
        console.error('Upload error:', error);
        // (エラー表示処理は省略)
    }
});

// --- 2. 画像クリック時の座標取得・保存 (抜粋) ---
imageContainer.addEventListener('click', (event) => {
    if (!currentFilename || !uploadedImage.complete) return; // 画像ロード完了前にクリックされるのを防ぐ

    const rect = imageContainer.getBoundingClientRect();
    const displayWidth = uploadedImage.offsetWidth;
    const displayHeight = uploadedImage.offsetHeight;

    // クリック座標 (表示されている画像に対する相対座標)
    const clickX = event.clientX - rect.left;
    const clickY = event.clientY - rect.top;

    // 元画像における座標にスケーリング
    const scaleX = imageOriginalWidth / displayWidth;
    const scaleY = imageOriginalHeight / displayHeight;
    const originalX = Math.round(clickX * scaleX);
    const originalY = Math.round(clickY * scaleY);

    // 画像範囲外なら無視
    if (originalX < 0 || originalX >= imageOriginalWidth || originalY < 0 || originalY >= imageOriginalHeight) return;

    // 選択されているラベル (1: Positive, 0: Negative) を取得
    const selectedLabel = parseInt(document.querySelector('input[name="click-mode"]:checked').value);

    // 点の情報を保存 (UIに点を描画する処理は省略)
    points.push({ originalX, originalY, label: selectedLabel });
    console.log('Point added:', { originalX, originalY, label: selectedLabel });
    segmentBtn.disabled = false; // 実行ボタンを有効化
});

// --- 3. セグメンテーション実行 (抜粋) ---
segmentBtn.addEventListener('click', async () => {
    if (!currentFilename || points.length === 0) return;
    // (ローディング表示などのUI制御は省略)

    // バックエンドAPIに送るデータを作成
    const requestData = {
        filename: currentFilename,
        points: points.map(p => [p.originalX, p.originalY]), // [[x1, y1], [x2, y2], ...]
        labels: points.map(p => p.label),                   // [1, 0, ...]
    };

    try {
        // FastAPIの /segment/ エンドポイントにJSONデータをPOST
        const response = await fetch('/segment/', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(requestData),
        });
        if (!response.ok) throw new Error('Segmentation failed');

        const data = await response.json();
        // 結果画像(Base64)を表示
        resultImage.src = data.result_image;
        resultImage.classList.remove('hidden');
        // スコア表示
        resultScore.textContent = data.score ? `予測スコア: ${data.score.toFixed(3)}` : 'スコアなし';

    } catch (error) {
        console.error('Segmentation error:', error);
        // (エラー表示処理は省略)
    } finally {
        // (ローディング解除などのUI制御は省略)
    }
});

// --- 4. 点をクリアする処理 (抜粋) ---
function clearPoints() {
    points = []; // 配列を空にする
    // (UI上のマーカー削除処理は省略)
    console.log('Points cleared');
    segmentBtn.disabled = true;
}

// clearPointsBtn のクリックイベントリスナーで clearPoints() を呼び出す (コード省略)

繰り返しになりますが、実際に動くコードはGithubにあげてあります。

ローカルで動かしてみる#

それでは早速起動してみましょう。今回はRye を使って動かします。

rye run uvicorn main:app --host 0.0.0.0 --port 8000

これで立ち上げあとは画像を加えて,セグメンテーションしたい領域を選択します。

GUI画面

上記のような画面が結果ですが、セグメンテーションが出来ていることが確認できます。

今回はTinyモデルで検証しましたが、GPUに余裕がある方はlargeなどを使っていきましょう。 ちなみにVRAMの使用容量ですが、 sam2.1_hiera_tiny, sam2.1_hiera_small, sam2.1_hiera_base_plus, sam2.1_hiera_large の4つがあり、一番小さいモデルと大きいモデルを比較すると次のとおりです。

名前使用中VRAM
facebook/sam2-hiera-tiny1.5GB
facebook/sam2-hiera-large3.2GB

ですので4GBのGPUでもギリギリLargeで使用ができそうです。

まとめ#

今回は、MetaのSegment Anything Model (SAM2) をローカルのGPU環境でセットアップし、FastAPIとシンプルなWeb UIを使ってインタラクティブにセグメンテーションを実行する手順を解説しました。

オンラインデモとは異なり、ローカル環境では以下のようなメリットがあります。

  • プライベートな画像を外部にアップロードせずに試せる。
  • 処理時間を計測し、特定のGPUでの性能を評価できる。
  • バッチ処理や他のアプリケーションへの組み込みなど、より高度な利用に向けた検証ができる。

今回構築したシステムは基本的なものですが、これをベースに、矩形(バウンディングボックス)入力への対応、複数オブジェクトの同時指定、あるいは動画への適用など、さらに発展させることも可能です。

SAM/SAM2は非常に強力なツールであり、ローカルで動かすことでその可能性をさらに引き出すことができます。この記事が、皆さんの研究開発やプロジェクトの一助となれば幸いです。コード全体はこちらで公開しています。