Skip to content
This repository was archived by the owner on Feb 22, 2020. It is now read-only.

Commit 5de329d

Browse files
authored
Merge pull request #351 from gnes-ai/feat-standardscaler
feat(standarder): add standard scaler
2 parents 44a54be + 5d95c74 commit 5de329d

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

gnes/encoder/numeric/pca.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,16 +23,19 @@
2323
class PCAEncoder(BaseNumericEncoder):
2424
batch_size = 2048
2525

26-
def __init__(self, output_dim: int, *args, **kwargs):
26+
def __init__(self, output_dim: int, whiten: bool=False, *args, **kwargs):
2727
super().__init__(*args, **kwargs)
2828
self.output_dim = output_dim
29+
self.whiten = whiten
2930
self.pca_components = None
3031
self.mean = None
3132

33+
3234
def post_init(self):
3335
from sklearn.decomposition import IncrementalPCA
3436
self.pca = IncrementalPCA(n_components=self.output_dim)
3537

38+
3639
@batching
3740
def train(self, vecs: np.ndarray, *args, **kwargs) -> None:
3841
num_samples, num_dim = vecs.shape
@@ -49,11 +52,16 @@ def train(self, vecs: np.ndarray, *args, **kwargs) -> None:
4952

5053
self.pca_components = np.transpose(self.pca.components_)
5154
self.mean = self.pca.mean_.astype('float32')
55+
self.explained_variance = self.pca.explained_variance_.astype('float32')
56+
5257

5358
@train_required
5459
@batching
5560
def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray:
56-
return np.matmul(vecs - self.mean, self.pca_components)
61+
X_transformed = np.matmul(vecs - self.mean, self.pca_components)
62+
if self.whiten:
63+
X_transformed /= np.sqrt(self.explained_variance)
64+
return X_transformed
5765

5866

5967
class PCALocalEncoder(BaseNumericEncoder):

gnes/encoder/numeric/standarder.py

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,45 @@
1+
# Tencent is pleased to support the open source community by making GNES available.
2+
#
3+
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import numpy as np
18+
19+
from ..base import BaseNumericEncoder
20+
from ...helper import batching, train_required
21+
22+
23+
class StandarderEncoder(BaseNumericEncoder):
    """Numeric encoder that standardizes vectors with a fitted mean and scale.

    Training feeds the data to a scikit-learn ``StandardScaler`` through
    ``partial_fit`` and keeps float32 copies of its ``mean_`` and ``scale_``
    on the encoder. Encoding then applies ``(vecs - mean) / scale`` using
    those stored copies only.
    """

    # NOTE(review): presumably read by the ``batching`` decorator as the
    # chunk size -- confirm against ``helper.batching``.
    batch_size = 2048

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base encoder and reset statistics."""
        super().__init__(*args, **kwargs)
        # Both statistics remain ``None`` until ``train`` assigns them.
        self.mean = self.scale = None

    def post_init(self):
        """Create the scikit-learn scaler used by ``train``."""
        # Imported inside the method, so scikit-learn is not needed just to
        # import this module.
        from sklearn import preprocessing
        self.standarder = preprocessing.StandardScaler()

    @batching
    def train(self, vecs: np.ndarray, *args, **kwargs) -> None:
        """Update the running statistics with ``vecs``.

        :param vecs: array of training vectors passed to
            ``StandardScaler.partial_fit``.
        """
        scaler = self.standarder
        scaler.partial_fit(vecs)

        # Refresh the float32 snapshots after every call so they always
        # reflect everything the scaler has seen so far.
        self.mean, self.scale = (
            stat.astype('float32') for stat in (scaler.mean_, scaler.scale_)
        )

    @train_required
    @batching
    def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray:
        """Standardize ``vecs`` with the stored mean and scale.

        :param vecs: array of vectors to transform.
        :return: ``(vecs - mean) / scale`` computed with the stored
            statistics.
        """
        centered = vecs - self.mean
        return centered / self.scale

0 commit comments

Comments
 (0)