srjamali
(Jamali)
December 1, 2019, 2:33am
1
Is there any way to save the trained model and then load it back for inference?
vova
(Vova Manannikov)
December 2, 2019, 11:41am
2
As far as I know, there’s no “official” way of doing this in S4TF yet (correct me if I’m wrong, @bradlarson).
Here are a couple of examples of loading model weights:
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import TensorFlow
public class PythonCheckpointReader {
    private let path: String
    private var layerCounts: [String: Int] = [:]
This file has been truncated.
import Foundation
import TensorFlow
import FastStyleTransfer
extension TransformerNet: ImportableLayer {}
enum FileError: Error {
    case fileNotFound
}
/// Updates `model` with parameters from V2 checkpoint in `path`.
func importWeights(_ model: inout TransformerNet, from path: String) throws {
    guard FileManager.default.fileExists(atPath: path + ".data-00000-of-00001") else {
        throw FileError.fileNotFound
    }
    // Names don't match exactly, and axes in filters need to be reversed.
    let map = [
        "conv1.conv2d.filter": ("conv1.conv2d.weight", [3, 2, 1, 0]),
        "conv2.conv2d.filter": ("conv2.conv2d.weight", [3, 2, 1, 0]),
        "conv3.conv2d.filter": ("conv3.conv2d.weight", [3, 2, 1, 0]),
This file has been truncated.
Saving would go in the opposite direction, using _Raw.saveV2(). So it boils down to calling _Raw.saveV2() on each parameter tensor to save it and _Raw.restoreV2() to load it back.
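To make that round trip concrete, here is a minimal sketch for a single tensor. The checkpoint prefix and the tensor name are made up for illustration. The argument labels are written from memory of the generated raw-op bindings, so treat them as an assumption and check them against your toolchain, since they have changed between S4TF releases.

```swift
import TensorFlow

// Hypothetical checkpoint prefix and tensor name, chosen only for this example.
let prefix = StringTensor("/tmp/model.ckpt")
let name = "dense.weight"

// Stand-in for one trained parameter of a model.
let weight = Tensor<Float>(randomNormal: [4, 2])

// Save: one name and one shape-and-slice spec per tensor.
// An empty spec string means "store the whole tensor".
// Assumed signature: saveV2(prefix:tensorNames:shapeAndSlices:tensors:).
_Raw.saveV2(
    prefix: prefix,
    tensorNames: StringTensor([name]),
    shapeAndSlices: StringTensor([""]),
    tensors: [weight])

// Load: the annotated result type tells the op which dtype to read.
// Assumed signature: a generic restoreV2(prefix:tensorNames:shapeAndSlices:).
let restored: Tensor<Float> = _Raw.restoreV2(
    prefix: prefix,
    tensorNames: StringTensor([name]),
    shapeAndSlices: StringTensor([""]))

// The restored values should match what was written.
print(restored == weight)
```

For a whole model you would repeat this for every parameter tensor, each under its own name, and assign the restored tensors back into the model's properties, which is roughly what the import examples above do with their name maps.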