Skip to content

Commit fd4461d

Browse files
Merge pull request #6196 from philpax/add-embeddings-api
feat(api): add /sdapi/v1/embeddings
2 parents f39a79d + c65909a commit fd4461d

File tree

3 files changed

+35
-4
lines changed

3 files changed

+35
-4
lines changed

modules/api/api.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
100100
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
101101
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
102102
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
103+
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
103104
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
104105
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
105106
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
@@ -327,6 +328,26 @@ def get_artists_categories(self):
327328
def get_artists(self):
328329
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
329330

331+
def get_embeddings(self):
332+
db = sd_hijack.model_hijack.embedding_db
333+
334+
def convert_embedding(embedding):
335+
return {
336+
"step": embedding.step,
337+
"sd_checkpoint": embedding.sd_checkpoint,
338+
"sd_checkpoint_name": embedding.sd_checkpoint_name,
339+
"shape": embedding.shape,
340+
"vectors": embedding.vectors,
341+
}
342+
343+
def convert_embeddings(embeddings):
344+
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
345+
346+
return {
347+
"loaded": convert_embeddings(db.word_embeddings),
348+
"skipped": convert_embeddings(db.skipped_embeddings),
349+
}
350+
330351
def refresh_checkpoints(self):
331352
shared.refresh_checkpoints()
332353

modules/api/models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,13 @@ class ArtistItem(BaseModel):
249249
score: float = Field(title="Score")
250250
category: str = Field(title="Category")
251251

252+
class EmbeddingItem(BaseModel):
253+
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
254+
sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
255+
sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
256+
shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
257+
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
258+
259+
class EmbeddingsResponse(BaseModel):
260+
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
261+
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")

modules/textual_inversion/textual_inversion.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class EmbeddingDatabase:
5959
def __init__(self, embeddings_dir):
6060
self.ids_lookup = {}
6161
self.word_embeddings = {}
62-
self.skipped_embeddings = []
62+
self.skipped_embeddings = {}
6363
self.dir_mtime = None
6464
self.embeddings_dir = embeddings_dir
6565
self.expected_shape = -1
@@ -91,7 +91,7 @@ def load_textual_inversion_embeddings(self, force_reload = False):
9191
self.dir_mtime = mt
9292
self.ids_lookup.clear()
9393
self.word_embeddings.clear()
94-
self.skipped_embeddings = []
94+
self.skipped_embeddings.clear()
9595
self.expected_shape = self.get_expected_shape()
9696

9797
def process_file(path, filename):
@@ -136,7 +136,7 @@ def process_file(path, filename):
136136
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
137137
self.register_embedding(embedding, shared.sd_model)
138138
else:
139-
self.skipped_embeddings.append(name)
139+
self.skipped_embeddings[name] = embedding
140140

141141
for fn in os.listdir(self.embeddings_dir):
142142
try:
@@ -153,7 +153,7 @@ def process_file(path, filename):
153153

154154
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
155155
if len(self.skipped_embeddings) > 0:
156-
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
156+
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
157157

158158
def find_embedding_at_position(self, tokens, offset):
159159
token = tokens[offset]

0 commit comments

Comments
 (0)